@jax-js/jax 0.0.4 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-EBRGmEYw.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DwIAd0AG.js";
3
3
 
4
4
  //#region src/tree.ts
5
5
  var tree_exports = {};
@@ -29,6 +29,10 @@ var JsTreeDef = class JsTreeDef {
29
29
  this.nodeMetadata = nodeMetadata;
30
30
  this.childTreedefs = childTreedefs;
31
31
  }
32
+ /** Get the total number of leaves in the tree. */
33
+ get size() {
34
+ return this.nodeType === NodeType.Leaf ? 1 : this.childTreedefs.reduce((a, b) => a + b.size, 0);
35
+ }
32
36
  /** Returns a string representation of this tree definition. */
33
37
  toString(root = true) {
34
38
  if (root) return "JsTreeDef(" + this.toString(false) + ")";
@@ -184,6 +188,16 @@ function pool(st, ks, strides = 1, dilation = 1) {
184
188
  const s_ = strides;
185
189
  const d_ = dilation;
186
190
  const o_ = zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
191
+ if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
192
+ st = st.padOrShrink([...noop.map(() => [0, 0]), ...zipn(i_, o_, s_).map(([i, o, s]) => [0, o * s - i])]);
193
+ st = st.reshape([...noop, ...zip(o_, s_).flatMap(([o, s]) => [o, s])]).shrink([...noop.map((x) => [0, x]), ...zip(o_, ks).flatMap(([o, k]) => [[0, o], [0, k]])]);
194
+ st = st.permute([
195
+ ...range(noop.length),
196
+ ...ks.map((_, j) => noop.length + 2 * j),
197
+ ...ks.map((_, j) => noop.length + 2 * j + 1)
198
+ ]);
199
+ return st;
200
+ }
187
201
  const f_ = zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
188
202
  const kidf = zipn(ks, i_, d_, f_);
189
203
  st = st.repeat([...rep(noop.length, 1), ...kidf.map(([k, i, d, f]) => Math.ceil(k * (i * f + d) / i))]);
@@ -218,6 +232,12 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
218
232
  const s_ = strides;
219
233
  const d_ = dilation;
220
234
  const o_ = zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
235
+ if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
236
+ st = st.permute([...range(noop.length), ...ks.flatMap((_, j) => [noop.length + j, noop.length + o_.length + j])]);
237
+ st = st.pad([...noop.map(() => [0, 0]), ...zip(s_, ks).flatMap(([s, k]) => [[0, 0], [0, s - k]])]).reshape([...noop, ...zip(o_, s_).map(([o, s]) => o * s)]);
238
+ st = st.padOrShrink([...noop.map(() => [0, 0]), ...zipn(i_, o_, s_).map(([i, o, s]) => [0, i - o * s])]);
239
+ return st.reshape(st.shape.concat(rep(ks.length, 1)));
240
+ }
221
241
  if (!deepEqual(o_, st.shape.slice(noop.length, noop.length + ks.length))) throw new Error("poolTranspose() called with mismatched output shape");
222
242
  const f_ = zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
223
243
  const kidf = zipn(ks, i_, d_, f_);
@@ -327,6 +347,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
327
347
  Primitive$1["Atan"] = "atan";
328
348
  Primitive$1["Exp"] = "exp";
329
349
  Primitive$1["Log"] = "log";
350
+ Primitive$1["Erf"] = "erf";
351
+ Primitive$1["Erfc"] = "erfc";
330
352
  Primitive$1["Sqrt"] = "sqrt";
331
353
  Primitive$1["Min"] = "min";
332
354
  Primitive$1["Max"] = "max";
@@ -404,6 +426,12 @@ function exp$1(x) {
404
426
  function log$1(x) {
405
427
  return bind1(Primitive.Log, [x]);
406
428
  }
429
+ function erf$1(x) {
430
+ return bind1(Primitive.Erf, [x]);
431
+ }
432
+ function erfc$1(x) {
433
+ return bind1(Primitive.Erfc, [x]);
434
+ }
407
435
  function sqrt$1(x) {
408
436
  return bind1(Primitive.Sqrt, [x]);
409
437
  }
@@ -565,6 +593,21 @@ var Trace = class {
565
593
  this.main = main;
566
594
  }
567
595
  };
596
+ /**
597
+ * Broadcast shapes and promote types with casting for two avals.
598
+ *
599
+ * This implements the weak type behavior described in `promoteTypes()`, but not
600
+ * implemented in that function as `weakType` is not passed.
601
+ */
602
+ function promoteAvals(a, b) {
603
+ const shape$1 = generalBroadcast(a.shape, b.shape);
604
+ const weakType = a.weakType && b.weakType;
605
+ let dtype;
606
+ if (a.weakType === b.weakType) dtype = promoteTypes(a.dtype, b.dtype);
607
+ else if (a.weakType) dtype = promoteTypes(b.dtype, DType.Uint32);
608
+ else dtype = promoteTypes(a.dtype, DType.Uint32);
609
+ return new ShapedArray(shape$1, dtype, weakType);
610
+ }
568
611
  var Tracer = class Tracer {
569
612
  /** @ignore */
570
613
  _trace;
@@ -579,10 +622,19 @@ var Tracer = class Tracer {
579
622
  get size() {
580
623
  return prod(this.shape);
581
624
  }
582
- /** The dtype of the array. */
625
+ /** The dtype of elements stored in the array. */
583
626
  get dtype() {
584
627
  return this.aval.dtype;
585
628
  }
629
+ /**
630
+ * Whether the array is weakly typed.
631
+ *
632
+ * Weakly typed arrays will cast to the dtype of the other operand. See
633
+ * `promoteTypes()` for details.
634
+ */
635
+ get weakType() {
636
+ return this.aval.weakType;
637
+ }
586
638
  /** The number of dimensions of the array. */
587
639
  get ndim() {
588
640
  return this.shape.length;
@@ -819,12 +871,13 @@ function getShape(x) {
819
871
  return x instanceof Tracer ? x.shape : [];
820
872
  }
821
873
  var ShapedArray = class ShapedArray {
822
- constructor(shape$1, dtype) {
874
+ constructor(shape$1, dtype, weakType) {
823
875
  this.shape = shape$1;
824
876
  this.dtype = dtype;
877
+ this.weakType = weakType;
825
878
  }
826
879
  static fromAval(aval) {
827
- return new ShapedArray(aval.shape, aval.dtype);
880
+ return new ShapedArray(aval.shape, aval.dtype, aval.weakType);
828
881
  }
829
882
  get ndim() {
830
883
  return this.shape.length;
@@ -838,7 +891,7 @@ var ShapedArray = class ShapedArray {
838
891
  };
839
892
  function getAval(x) {
840
893
  if (x instanceof Tracer) return x.aval;
841
- else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? DType.Bool : DType.Float32);
894
+ else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? DType.Bool : DType.Float32, typeof x === "boolean" ? false : true);
842
895
  else throw new TypeError(`Unknown value: ${x}`);
843
896
  }
844
897
  function bind(prim, args, params = {}) {
@@ -1121,12 +1174,18 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
1121
1174
  } else if (exp$3.op === AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
1122
1175
  });
1123
1176
  }
1124
- function broadcastedJit(fn) {
1177
+ function broadcastedJit(fn, opts) {
1125
1178
  return (nargs, exps, avals, params) => {
1126
- const newShape = avals.map((aval) => aval.shape).reduce(generalBroadcast);
1127
- exps = exps.map((exp$3) => reshapeViews(exp$3, (st) => {
1128
- if (!deepEqual(st.shape, newShape)) return st.broadcast(newShape, range(newShape.length - st.shape.length));
1129
- }));
1179
+ let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
1180
+ const skipCastIdx = opts?.skipCastIdx ?? [];
1181
+ if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
1182
+ exps = exps.map((exp$3, i) => {
1183
+ exp$3 = reshapeViews(exp$3, (st) => {
1184
+ if (!deepEqual(st.shape, newShape)) return st.broadcast(newShape, range(newShape.length - st.shape.length));
1185
+ });
1186
+ if (exp$3.dtype !== newDtype && !skipCastIdx.includes(i)) exp$3 = AluExp.cast(newDtype, exp$3);
1187
+ return exp$3;
1188
+ });
1130
1189
  const exp$2 = fn(exps, params);
1131
1190
  return new Kernel(nargs, prod(newShape), exp$2);
1132
1191
  };
@@ -1160,7 +1219,7 @@ const jitRules = {
1160
1219
  const k1 = reshapeViews(keys[1], mapping);
1161
1220
  const c0 = AluExp.u32(0);
1162
1221
  const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
1163
- const exp$2 = AluExp.threefry2x32(c0, c1, k0, k1, mode);
1222
+ const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
1164
1223
  return new Kernel(nargs, prod(shape$1), exp$2);
1165
1224
  },
1166
1225
  [Primitive.Sin]: unopJit(AluExp.sin),
@@ -1169,6 +1228,8 @@ const jitRules = {
1169
1228
  [Primitive.Atan]: unopJit(AluExp.atan),
1170
1229
  [Primitive.Exp]: unopJit(AluExp.exp),
1171
1230
  [Primitive.Log]: unopJit(AluExp.log),
1231
+ [Primitive.Erf]: unopJit(AluExp.erf),
1232
+ [Primitive.Erfc]: unopJit(AluExp.erfc),
1172
1233
  [Primitive.Sqrt]: unopJit(AluExp.sqrt),
1173
1234
  [Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
1174
1235
  [Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
@@ -1201,7 +1262,7 @@ const jitRules = {
1201
1262
  [Primitive.Dot](nargs, [a, b], [as, bs]) {
1202
1263
  const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
1203
1264
  const c = k1.exp;
1204
- const cs = new ShapedArray(generalBroadcast(as.shape, bs.shape), c.dtype);
1265
+ const cs = promoteAvals(as, bs);
1205
1266
  return jitRules[Primitive.Reduce](nargs, [c], [cs], {
1206
1267
  op: AluOp.Add,
1207
1268
  axis: [cs.ndim - 1]
@@ -1211,12 +1272,12 @@ const jitRules = {
1211
1272
  const [stX, stY] = prepareConv(ShapeTracker.fromShape(as.shape), ShapeTracker.fromShape(bs.shape), params);
1212
1273
  a = reshapeViews(a, (st) => st.compose(stX));
1213
1274
  b = reshapeViews(b, (st) => st.compose(stY));
1214
- as = new ShapedArray(stX.shape, as.dtype);
1215
- bs = new ShapedArray(stY.shape, bs.dtype);
1275
+ as = new ShapedArray(stX.shape, as.dtype, as.weakType);
1276
+ bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
1216
1277
  return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
1217
1278
  },
1218
1279
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
1219
- [Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b)),
1280
+ [Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
1220
1281
  [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
1221
1282
  [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
1222
1283
  [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
@@ -1265,9 +1326,10 @@ function splitGraphDataflow(backend, jaxpr) {
1265
1326
  Primitive.Conv,
1266
1327
  Primitive.PoolTranspose
1267
1328
  ];
1329
+ const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
1268
1330
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
1269
1331
  const eqn = jaxpr.eqns[i];
1270
- if (reducePrimitives.includes(eqn.primitive) || eqn.primitive === Primitive.Gather || eqn.outBinders.some((v) => blackNodes.has(v))) {
1332
+ if (reducePrimitives.includes(eqn.primitive) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
1271
1333
  for (const v of eqn.outBinders) {
1272
1334
  blackNodes.add(v);
1273
1335
  p1NextBlack.set(v, v);
@@ -1386,7 +1448,7 @@ var PendingExecute = class {
1386
1448
  /**
1387
1449
  * A multidimensional numeric array with data stored on CPU or GPU.
1388
1450
  *
1389
- * This is the library's core data type. Equivalent to `jnp.Array` from JAX, or
1451
+ * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
1390
1452
  * `torch.Tensor`.
1391
1453
  *
1392
1454
  * Not to be confused with the JavaScript "Array" constructor. Avoid importing
@@ -1397,9 +1459,11 @@ var Array$1 = class Array$1 extends Tracer {
1397
1459
  static #nextId = 1001;
1398
1460
  id;
1399
1461
  #dtype;
1462
+ #weakType;
1400
1463
  #source;
1401
1464
  #st;
1402
1465
  #backend;
1466
+ #committed;
1403
1467
  #rc;
1404
1468
  #pendingSet;
1405
1469
  /**
@@ -1408,21 +1472,23 @@ var Array$1 = class Array$1 extends Tracer {
1408
1472
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
1409
1473
  * will be freed when the array is disposed.
1410
1474
  */
1411
- constructor(source, st, dtype, backend, { pending = null } = {}) {
1475
+ constructor(args) {
1412
1476
  super(baseArrayTrace);
1413
1477
  this.id = Array$1.#nextId++;
1414
- this.#dtype = dtype;
1415
- this.#source = source;
1416
- this.#st = st;
1417
- this.#backend = backend;
1478
+ this.#dtype = args.dtype;
1479
+ this.#weakType = args.weakType;
1480
+ this.#source = args.source;
1481
+ this.#st = args.st;
1482
+ this.#backend = args.backend;
1483
+ this.#committed = args.committed;
1418
1484
  this.#rc = 1;
1419
- this.#pendingSet = new Set(pending);
1485
+ this.#pendingSet = new Set(args.pending);
1420
1486
  if (this.#pendingSet.size === 0) this.#pendingSet = null;
1421
- else if (source instanceof AluExp) throw new Error("internal: AluExp source cannot have pending executes");
1487
+ else if (this.#source instanceof AluExp) throw new Error("internal: AluExp source cannot have pending executes");
1422
1488
  }
1423
1489
  /** @ignore */
1424
1490
  get aval() {
1425
- return new ShapedArray(this.#st.shape, this.#dtype);
1491
+ return new ShapedArray(this.#st.shape, this.#dtype, this.#weakType);
1426
1492
  }
1427
1493
  /** Return a simple string representation of the array's dimensions. */
1428
1494
  toString() {
@@ -1434,6 +1500,18 @@ var Array$1 = class Array$1 extends Tracer {
1434
1500
  #check() {
1435
1501
  if (this.#rc <= 0) throw new UseAfterFreeError(this);
1436
1502
  }
1503
+ /** Construct an array, copying fields from `this`. */
1504
+ #newArrayFrom(args) {
1505
+ return new Array$1({
1506
+ source: args.source ?? this.#source,
1507
+ st: args.st ?? this.#st,
1508
+ dtype: args.dtype ?? this.#dtype,
1509
+ weakType: this.#weakType,
1510
+ backend: args.backend ?? this.#backend,
1511
+ committed: args.committed ?? this.#committed,
1512
+ pending: args.pending ?? this.#pending ?? void 0
1513
+ });
1514
+ }
1437
1515
  get ref() {
1438
1516
  this.#check();
1439
1517
  this.#rc++;
@@ -1473,7 +1551,10 @@ var Array$1 = class Array$1 extends Tracer {
1473
1551
  const pending = this.#pending;
1474
1552
  for (const exe of pending) exe.updateRc(1);
1475
1553
  if (typeof this.#source === "number") this.#backend.incRef(this.#source);
1476
- const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, { pending });
1554
+ const ar = this.#newArrayFrom({
1555
+ st,
1556
+ pending
1557
+ });
1477
1558
  this.dispose();
1478
1559
  return ar;
1479
1560
  }
@@ -1483,9 +1564,10 @@ var Array$1 = class Array$1 extends Tracer {
1483
1564
  */
1484
1565
  #gather(indices, axis, outDim) {
1485
1566
  this.#check();
1486
- if (indices.some((a) => a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
1487
1567
  const axisSet = new Set(axis);
1488
1568
  if (axisSet.size !== axis.length) throw new TypeError("Gather axis must not have duplicates");
1569
+ if (indices.some((a) => a.#committed && a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
1570
+ indices = indices.map((ar) => ar._putSync(this.#backend));
1489
1571
  indices = Array$1.#broadcastArrays(indices);
1490
1572
  const indexShape = indices[0].shape;
1491
1573
  const finalShape = this.shape.filter((_, i) => !axisSet.has(i));
@@ -1522,7 +1604,11 @@ var Array$1 = class Array$1 extends Tracer {
1522
1604
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1523
1605
  this.dispose();
1524
1606
  for (const ar of indices) ar.dispose();
1525
- return new Array$1(output, ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, { pending });
1607
+ return this.#newArrayFrom({
1608
+ source: output,
1609
+ st: ShapeTracker.fromShape(finalShape),
1610
+ pending
1611
+ });
1526
1612
  }
1527
1613
  /** Move axes to the rightmost dimension of the shape. */
1528
1614
  #moveAxesDown(axis) {
@@ -1545,11 +1631,17 @@ var Array$1 = class Array$1 extends Tracer {
1545
1631
  return this.#reshape(this.#st.permute(perm));
1546
1632
  }
1547
1633
  #unary(op, dtypeOutput) {
1634
+ const weakType = !dtypeOutput && this.#weakType;
1548
1635
  dtypeOutput ??= this.#dtype;
1549
1636
  this.#check();
1550
1637
  if (this.#source instanceof AluExp) {
1551
1638
  const exp$3 = new AluExp(op, dtypeOutput, [this.#source]);
1552
- return new Array$1(exp$3.simplify(), this.#st, dtypeOutput, this.#backend);
1639
+ this.dispose();
1640
+ return this.#newArrayFrom({
1641
+ source: exp$3.simplify(),
1642
+ dtype: dtypeOutput,
1643
+ weakType
1644
+ });
1553
1645
  }
1554
1646
  const indices = unravelAlu(this.#st.shape, AluVar.gidx);
1555
1647
  const exp$2 = new AluExp(op, dtypeOutput, [AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
@@ -1559,41 +1651,67 @@ var Array$1 = class Array$1 extends Tracer {
1559
1651
  for (const exe of pending) exe.updateRc(1);
1560
1652
  pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
1561
1653
  this.dispose();
1562
- return new Array$1(output, ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, { pending });
1654
+ return this.#newArrayFrom({
1655
+ source: output,
1656
+ st: ShapeTracker.fromShape(this.shape),
1657
+ dtype: dtypeOutput,
1658
+ weakType,
1659
+ pending
1660
+ });
1563
1661
  }
1564
1662
  #binary(op, other) {
1565
- const custom = (src) => new AluExp(op, this.#dtype, src);
1663
+ const custom = (src) => new AluExp(op, src[0].dtype, src);
1566
1664
  return Array$1.#naryCustom(op, custom, [this, other]);
1567
1665
  }
1568
- static #naryCustom(name, custom, arrays, { dtypeOverride, dtypeOutput, reduceAxis } = {}) {
1666
+ static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
1569
1667
  const n = arrays.length;
1570
- const backend = arrays[0].#backend;
1571
1668
  if (n === 0) throw new TypeError(`No inputs for ${name}`);
1572
1669
  for (const ar of arrays) ar.#check();
1573
- let dtype;
1574
- for (let i = 0; i < n; i++) {
1575
- if (dtypeOverride?.[i]) {
1576
- if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
1577
- } else if (!dtype) dtype = arrays[i].#dtype;
1578
- else if (arrays[i].#dtype !== dtype) throw new TypeError(`Dtype mismatch in ${name}: ${dtype} vs ${arrays[i].#dtype}`);
1579
- if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
1580
- }
1581
- dtypeOutput ??= dtype;
1582
- if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
1670
+ let castDtype;
1671
+ let castWeakType = true;
1672
+ for (let i = 0; i < n; i++) if (dtypeOverride?.[i]) {
1673
+ if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
1674
+ } else if (castDtype === void 0) {
1675
+ castDtype = arrays[i].#dtype;
1676
+ castWeakType = arrays[i].#weakType;
1677
+ } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
1678
+ const weakType = castWeakType && !strongTypeOutput;
1679
+ const { backend, committed } = Array$1.#computeBackend(name, arrays);
1680
+ arrays = arrays.map((ar) => ar._putSync(backend));
1583
1681
  arrays = Array$1.#broadcastArrays(arrays);
1584
1682
  const newShape = [...arrays[0].shape];
1585
1683
  if (arrays.every((ar) => ar.#source instanceof AluExp) && !reduceAxis) {
1684
+ const sources = arrays.map((ar, i) => {
1685
+ if (!dtypeOverride?.[i]) return AluExp.cast(castDtype, ar.#source);
1686
+ else return ar.#source;
1687
+ });
1586
1688
  if (arrays.every((ar) => deepEqual(ar.#st, arrays[0].#st))) {
1587
- const exp$4 = custom(arrays.map((ar) => ar.#source));
1588
- return new Array$1(exp$4.simplify(), arrays[0].#st, exp$4.dtype, backend);
1689
+ const exp$4 = custom(sources);
1690
+ arrays.forEach((ar) => ar.dispose());
1691
+ return new Array$1({
1692
+ source: exp$4.simplify(),
1693
+ st: arrays[0].#st,
1694
+ dtype: exp$4.dtype,
1695
+ weakType,
1696
+ backend,
1697
+ committed
1698
+ });
1589
1699
  }
1590
- const exp$3 = custom(arrays.map((ar) => {
1591
- const src$1 = ar.#source;
1700
+ const exp$3 = custom(arrays.map((ar, i) => {
1701
+ const src$1 = sources[i];
1592
1702
  if (ar.#st.contiguous) return src$1;
1593
1703
  return accessorAluExp(src$1, ar.#st, unravelAlu(newShape, AluVar.idx));
1594
1704
  }));
1595
1705
  const st = ShapeTracker.fromShape(newShape);
1596
- return new Array$1(exp$3.simplify(), st, exp$3.dtype, backend);
1706
+ arrays.forEach((ar) => ar.dispose());
1707
+ return new Array$1({
1708
+ source: exp$3.simplify(),
1709
+ st,
1710
+ dtype: exp$3.dtype,
1711
+ weakType,
1712
+ backend,
1713
+ committed
1714
+ });
1597
1715
  }
1598
1716
  let indices;
1599
1717
  if (!reduceAxis) indices = unravelAlu(newShape, AluVar.gidx);
@@ -1603,14 +1721,19 @@ var Array$1 = class Array$1 extends Tracer {
1603
1721
  }
1604
1722
  const inputs = [];
1605
1723
  const src = [];
1606
- for (const ar of arrays) if (ar.#source instanceof AluExp) src.push(accessorAluExp(ar.#source, ar.#st, indices));
1607
- else {
1608
- let gid = inputs.indexOf(ar.#source);
1609
- if (gid === -1) {
1610
- gid = inputs.length;
1611
- inputs.push(ar.#source);
1724
+ for (const [i, ar] of arrays.entries()) {
1725
+ let nextSrc;
1726
+ if (ar.#source instanceof AluExp) nextSrc = accessorAluExp(ar.#source, ar.#st, indices);
1727
+ else {
1728
+ let gid = inputs.indexOf(ar.#source);
1729
+ if (gid === -1) {
1730
+ gid = inputs.length;
1731
+ inputs.push(ar.#source);
1732
+ }
1733
+ nextSrc = AluExp.globalView(ar.#dtype, gid, ar.#st, indices);
1612
1734
  }
1613
- src.push(AluExp.globalView(ar.#dtype, gid, ar.#st, indices));
1735
+ if (!dtypeOverride?.[i]) nextSrc = AluExp.cast(castDtype, nextSrc);
1736
+ src.push(nextSrc);
1614
1737
  }
1615
1738
  const exp$2 = custom(src);
1616
1739
  let re = void 0;
@@ -1623,13 +1746,19 @@ var Array$1 = class Array$1 extends Tracer {
1623
1746
  const pending = new Set([...arrays.flatMap((ar) => ar.#pending)]);
1624
1747
  for (const exe of pending) exe.updateRc(1);
1625
1748
  pending.add(new PendingExecute(backend, kernel, inputs, [output]));
1626
- for (const ar of arrays) ar.dispose();
1627
- return new Array$1(output, ShapeTracker.fromShape(newShape), dtypeOutput, backend, { pending });
1749
+ arrays.forEach((ar) => ar.dispose());
1750
+ return new Array$1({
1751
+ source: output,
1752
+ st: ShapeTracker.fromShape(newShape),
1753
+ dtype: kernel.dtype,
1754
+ weakType,
1755
+ backend,
1756
+ committed,
1757
+ pending
1758
+ });
1628
1759
  }
1629
1760
  /** Reduce the last dimension of the array by an operation. */
1630
1761
  #reduce(op) {
1631
- this.#check();
1632
- if (this.ndim === 0) throw new Error("Cannot reduce a scalar");
1633
1762
  const shape$1 = this.shape;
1634
1763
  const reduction = new Reduction(this.#dtype, op, shape$1[shape$1.length - 1]);
1635
1764
  const newShape = shape$1.slice(0, -1);
@@ -1648,7 +1777,11 @@ var Array$1 = class Array$1 extends Tracer {
1648
1777
  for (const exe of pending) exe.updateRc(1);
1649
1778
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1650
1779
  this.dispose();
1651
- return new Array$1(output, ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, { pending });
1780
+ return this.#newArrayFrom({
1781
+ source: output,
1782
+ st: ShapeTracker.fromShape(newShape),
1783
+ pending
1784
+ });
1652
1785
  }
1653
1786
  /**
1654
1787
  * Normalizes this array into one backed by a `Slot`.
@@ -1684,8 +1817,8 @@ var Array$1 = class Array$1 extends Tracer {
1684
1817
  }
1685
1818
  #dataInline() {
1686
1819
  this.#check();
1687
- const exp$2 = this.#source;
1688
- const ar = new Array$1(exp$2, this.#st, this.dtype, getBackend("cpu"));
1820
+ if (!(this.#source instanceof AluExp)) throw new Error("internal: #dataInline called on non-AluExp source");
1821
+ const ar = this.#newArrayFrom({ backend: getBackend("cpu") });
1689
1822
  this.dispose();
1690
1823
  return ar.dataSync();
1691
1824
  }
@@ -1698,6 +1831,23 @@ var Array$1 = class Array$1 extends Tracer {
1698
1831
  return ar.#reshape(ar.#st.broadcast(newShape, range(newShape.length - ar.ndim)));
1699
1832
  });
1700
1833
  }
1834
+ static #computeBackend(name, arrays) {
1835
+ const committed = arrays.filter((ar) => ar.#committed);
1836
+ if (committed.length > 0) {
1837
+ const backend = committed[0].#backend;
1838
+ for (const ar of committed) if (ar.#backend !== backend) throw new Error(`Device mismatch in ${name} between committed arrays on (${backend.type}, ${ar.#backend.type}), please move to the same device with devicePut()`);
1839
+ return {
1840
+ backend,
1841
+ committed: true
1842
+ };
1843
+ } else {
1844
+ const backend = arrays.length > 0 ? arrays[0].#backend : getBackend();
1845
+ return {
1846
+ backend,
1847
+ committed: false
1848
+ };
1849
+ }
1850
+ }
1701
1851
  /** Realize the array and return it as data. */
1702
1852
  async data() {
1703
1853
  if (this.#source instanceof AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
@@ -1811,7 +1961,11 @@ var Array$1 = class Array$1 extends Tracer {
1811
1961
  x.#backend.incRef(x.#source);
1812
1962
  const pending = x.#pending;
1813
1963
  for (const exe of pending) exe.updateRc(1);
1814
- const y = new Array$1(x.#source, x.#st, dtype, x.#backend, { pending });
1964
+ const y = x.#newArrayFrom({
1965
+ dtype,
1966
+ weakType: false,
1967
+ pending
1968
+ });
1815
1969
  x.dispose();
1816
1970
  return [y];
1817
1971
  }
@@ -1853,6 +2007,12 @@ var Array$1 = class Array$1 extends Tracer {
1853
2007
  [Primitive.Log]([x]) {
1854
2008
  return [x.#unary(AluOp.Log)];
1855
2009
  },
2010
+ [Primitive.Erf]([x]) {
2011
+ return [x.#unary(AluOp.Erf)];
2012
+ },
2013
+ [Primitive.Erfc]([x]) {
2014
+ return [x.#unary(AluOp.Erfc)];
2015
+ },
1856
2016
  [Primitive.Sqrt]([x]) {
1857
2017
  return [x.#unary(AluOp.Sqrt)];
1858
2018
  },
@@ -1886,7 +2046,7 @@ var Array$1 = class Array$1 extends Tracer {
1886
2046
  },
1887
2047
  [Primitive.Compare]([x, y], { op }) {
1888
2048
  const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
1889
- return [Array$1.#naryCustom("compare", custom, [x, y], { dtypeOutput: DType.Bool })];
2049
+ return [Array$1.#naryCustom("compare", custom, [x, y], { strongTypeOutput: true })];
1890
2050
  },
1891
2051
  [Primitive.Where]([cond, x, y]) {
1892
2052
  const custom = ([cond$1, x$1, y$1]) => AluExp.where(cond$1, x$1, y$1);
@@ -1921,7 +2081,8 @@ var Array$1 = class Array$1 extends Tracer {
1921
2081
  },
1922
2082
  [Primitive.JitCall](args, { jaxpr, numConsts }) {
1923
2083
  if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
1924
- const backend = getBackend();
2084
+ const { backend, committed } = Array$1.#computeBackend("jit_call", args);
2085
+ args = args.map((ar) => ar._putSync(backend));
1925
2086
  const consts = args.slice(0, numConsts);
1926
2087
  const tracers = args.slice(numConsts);
1927
2088
  const jp = jitCompile(backend, jaxpr, consts);
@@ -1932,43 +2093,66 @@ var Array$1 = class Array$1 extends Tracer {
1932
2093
  pending.splice(0, 0, ...prevPending);
1933
2094
  args.forEach((x) => x.dispose());
1934
2095
  return outputs.map((source, i) => {
1935
- return new Array$1(source, ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, { pending });
2096
+ return new Array$1({
2097
+ source,
2098
+ st: ShapeTracker.fromShape(jaxpr.outs[i].aval.shape),
2099
+ dtype: jaxpr.outs[i].aval.dtype,
2100
+ weakType: jaxpr.outs[i].aval.weakType,
2101
+ backend,
2102
+ committed,
2103
+ pending
2104
+ });
1936
2105
  });
1937
2106
  }
1938
2107
  };
1939
2108
  }
2109
+ /** @private */
1940
2110
  _realizeSource() {
1941
2111
  this.#realize();
1942
2112
  return this.#source;
1943
2113
  }
2114
+ /** @private Put this array on a new backend, asynchronously. */
2115
+ async _put(backend) {
2116
+ if (this.#backend === backend) return this;
2117
+ if (this.#source instanceof AluExp) {
2118
+ const ar = this.#newArrayFrom({
2119
+ backend,
2120
+ committed: true
2121
+ });
2122
+ this.dispose();
2123
+ return ar;
2124
+ } else {
2125
+ const data = await this.data();
2126
+ return arrayFromData(data, this.shape, {
2127
+ dtype: this.#dtype,
2128
+ device: backend.type
2129
+ }, this.#weakType);
2130
+ }
2131
+ }
2132
+ /** @private Put this array on a new backend, synchronously. */
2133
+ _putSync(backend) {
2134
+ if (this.#backend === backend) return this;
2135
+ if (this.#source instanceof AluExp) {
2136
+ const ar = this.#newArrayFrom({
2137
+ backend,
2138
+ committed: true
2139
+ });
2140
+ this.dispose();
2141
+ return ar;
2142
+ } else {
2143
+ const data = this.dataSync();
2144
+ return arrayFromData(data, this.shape, {
2145
+ dtype: this.#dtype,
2146
+ device: backend.type
2147
+ }, this.#weakType);
2148
+ }
2149
+ }
1944
2150
  };
1945
- /** Construct an array from a single scalar constant. */
1946
- function scalar(value, { dtype, device } = {}) {
1947
- if (typeof value === "number") {
1948
- dtype ??= DType.Float32;
1949
- if (![
1950
- DType.Float32,
1951
- DType.Float16,
1952
- DType.Int32,
1953
- DType.Uint32
1954
- ].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
1955
- } else if (typeof value === "boolean") {
1956
- dtype ??= DType.Bool;
1957
- if (![
1958
- DType.Float32,
1959
- DType.Float16,
1960
- DType.Int32,
1961
- DType.Uint32,
1962
- DType.Bool
1963
- ].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
1964
- } else throw new TypeError(`Invalid type for scalar ${value}`);
1965
- return new Array$1(AluExp.const(dtype, value), ShapeTracker.fromShape([]), dtype, getBackend(device));
1966
- }
1967
2151
  /** Constructor for creating a new array from data. */
1968
2152
  function array(values, { shape: shape$1, dtype, device } = {}) {
1969
2153
  if (values instanceof Tracer) {
1970
2154
  if (shape$1 && !deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
1971
- if (dtype && values.dtype !== dtype) throw new Error("array astype not implemented yet");
2155
+ if (dtype && values.dtype !== dtype) values = values.astype(dtype);
1972
2156
  return values;
1973
2157
  } else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
1974
2158
  dtype,
@@ -1990,6 +2174,10 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
1990
2174
  dtype,
1991
2175
  device
1992
2176
  });
2177
+ if (size$1 === 1) return full(shape$1, flat[0], {
2178
+ dtype,
2179
+ device
2180
+ });
1993
2181
  if (typeof flat[0] === "boolean") {
1994
2182
  dtype = dtype ?? DType.Bool;
1995
2183
  const data = new Int32Array(flat.map((x) => x ? 1 : 0));
@@ -1998,46 +2186,52 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
1998
2186
  device
1999
2187
  });
2000
2188
  } else {
2189
+ const weakType = dtype == void 0;
2001
2190
  dtype = dtype ?? DType.Float32;
2002
2191
  const data = dtypedJsArray(dtype, flat);
2003
2192
  return arrayFromData(data, shape$1, {
2004
2193
  dtype,
2005
2194
  device
2006
- });
2195
+ }, weakType);
2007
2196
  }
2008
2197
  }
2009
2198
  }
2010
- function arrayFromData(data, shape$1, { dtype, device } = {}) {
2199
+ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
2200
+ if (data instanceof Float32Array) {
2201
+ if (dtype && dtype !== DType.Float32) throw new Error("Float32Array must have float32 type");
2202
+ dtype ??= DType.Float32;
2203
+ } else if (data instanceof Int32Array) {
2204
+ if (dtype && dtype !== DType.Int32 && dtype !== DType.Bool) throw new Error("Int32Array must have int32 or bool type");
2205
+ dtype ??= DType.Int32;
2206
+ } else if (data instanceof Uint32Array) {
2207
+ if (dtype && dtype !== DType.Uint32) throw new Error("Uint32Array must have uint32 type");
2208
+ dtype ??= DType.Uint32;
2209
+ } else if (data instanceof Float16Array) {
2210
+ if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
2211
+ dtype ??= DType.Float16;
2212
+ } else throw new Error("Unsupported data array type: " + data.constructor.name);
2011
2213
  if (data.length < inlineArrayLimit) {
2012
2214
  let allEqual = true;
2013
2215
  for (let i = 1; i < data.length; i++) if (data[i] !== data[0]) {
2014
2216
  allEqual = false;
2015
2217
  break;
2016
2218
  }
2017
- if (allEqual) return full(shape$1, data[0], {
2018
- dtype,
2019
- device
2020
- });
2219
+ if (allEqual) {
2220
+ const sa = new ShapedArray(shape$1, dtype, weakType);
2221
+ return fullInternal(sa, data[0], device);
2222
+ }
2021
2223
  }
2022
2224
  const backend = getBackend(device);
2023
- if (ArrayBuffer.isView(data)) {
2024
- const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
2025
- if (data instanceof Float32Array) {
2026
- if (dtype && dtype !== DType.Float32) throw new Error("Float32Array must have float32 type");
2027
- dtype ??= DType.Float32;
2028
- } else if (data instanceof Int32Array) {
2029
- if (dtype && dtype !== DType.Int32 && dtype !== DType.Bool) throw new Error("Int32Array must have int32 or bool type");
2030
- dtype ??= DType.Int32;
2031
- } else if (data instanceof Uint32Array) {
2032
- if (dtype && dtype !== DType.Uint32) throw new Error("Uint32Array must have uint32 type");
2033
- dtype ??= DType.Uint32;
2034
- } else if (data instanceof Float16Array) {
2035
- if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
2036
- dtype ??= DType.Float16;
2037
- } else throw new Error("Unsupported data array type: " + data.constructor.name);
2038
- const slot = backend.malloc(data.byteLength, buf);
2039
- return new Array$1(slot, ShapeTracker.fromShape(shape$1), dtype, backend);
2040
- } else throw new Error("Unsupported data type: " + data.constructor.name);
2225
+ const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
2226
+ const slot = backend.malloc(data.byteLength, buf);
2227
+ return new Array$1({
2228
+ source: slot,
2229
+ st: ShapeTracker.fromShape(shape$1),
2230
+ dtype,
2231
+ weakType,
2232
+ backend,
2233
+ committed: device != void 0
2234
+ });
2041
2235
  }
2042
2236
  function dataToJs(dtype, data, shape$1) {
2043
2237
  if (shape$1.length === 0) return dtype === DType.Bool ? Boolean(data[0]) : data[0];
@@ -2053,7 +2247,7 @@ function dataToJs(dtype, data, shape$1) {
2053
2247
  /** If x is a value, lift it into an array, otherwise leave it be. */
2054
2248
  function pureArray(x) {
2055
2249
  if (x instanceof Tracer) return x;
2056
- else return scalar(x);
2250
+ else return array(x);
2057
2251
  }
2058
2252
  var EvalTrace = class extends Trace {
2059
2253
  pure = (x) => pureArray(x);
@@ -2064,20 +2258,28 @@ var EvalTrace = class extends Trace {
2064
2258
  };
2065
2259
  const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
2066
2260
  const implRules = Array$1._implRules();
2261
+ function fullInternal(aval, fillValue, device) {
2262
+ return new Array$1({
2263
+ source: AluExp.const(aval.dtype, fillValue),
2264
+ st: ShapeTracker.fromShape(aval.shape),
2265
+ dtype: aval.dtype,
2266
+ weakType: aval.weakType,
2267
+ backend: getBackend(device),
2268
+ committed: device != void 0
2269
+ });
2270
+ }
2067
2271
  function zerosLike$1(val, dtype) {
2068
- const aval = getAval(val);
2069
- if (val instanceof Tracer) val.dispose();
2070
- return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
2272
+ return fullLike(val, 0, dtype);
2071
2273
  }
2072
2274
  function onesLike$1(val, dtype) {
2073
- const aval = getAval(val);
2074
- if (val instanceof Tracer) val.dispose();
2075
- return ones(aval.shape, { dtype: dtype ?? aval.dtype });
2275
+ return fullLike(val, 1, dtype);
2076
2276
  }
2077
2277
  function fullLike(val, fillValue, dtype) {
2078
2278
  const aval = getAval(val);
2079
2279
  if (val instanceof Tracer) val.dispose();
2080
- return full(aval.shape, fillValue, { dtype: dtype ?? aval.dtype });
2280
+ if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
2281
+ const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
2282
+ return fullInternal(sa, fillValue);
2081
2283
  }
2082
2284
  /** Return a new array of given shape and type, filled with zeros. */
2083
2285
  function zeros(shape$1, { dtype, device } = {}) {
@@ -2095,19 +2297,14 @@ function ones(shape$1, { dtype, device } = {}) {
2095
2297
  }
2096
2298
  /** Return a new array of given shape and type, filled with `fill_value`. */
2097
2299
  function full(shape$1, fillValue, { dtype, device } = {}) {
2098
- let source;
2099
- if (typeof fillValue === "number") {
2100
- dtype = dtype ?? DType.Float32;
2101
- source = AluExp.const(dtype, fillValue);
2102
- } else if (typeof fillValue === "bigint") {
2103
- dtype = dtype ?? DType.Int32;
2104
- source = AluExp.const(dtype, Number(fillValue));
2105
- } else if (typeof fillValue === "boolean") {
2300
+ let weakType = dtype == void 0;
2301
+ if (typeof fillValue === "number") dtype = dtype ?? DType.Float32;
2302
+ else if (typeof fillValue === "boolean") {
2106
2303
  dtype = dtype ?? DType.Bool;
2107
- source = AluExp.const(dtype, fillValue ? 1 : 0);
2304
+ weakType = false;
2108
2305
  } else if (fillValue instanceof Tracer) throw new Error("numpy.full() with array argument not implemented yet");
2109
2306
  else throw new TypeError(`Invalid type for full: ${fillValue}`);
2110
- return new Array$1(source, ShapeTracker.fromShape(shape$1), dtype ?? DType.Float32, getBackend(device));
2307
+ return fullInternal(new ShapedArray(shape$1, dtype, weakType), fillValue, device);
2111
2308
  }
2112
2309
  /**
2113
2310
  * Create an identity matrix.
@@ -2117,6 +2314,7 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
2117
2314
  */
2118
2315
  function eye(numRows, numCols, { dtype, device } = {}) {
2119
2316
  numCols = numCols ?? numRows;
2317
+ const weakType = dtype == void 0;
2120
2318
  dtype = dtype ?? DType.Float32;
2121
2319
  if (numCols < numRows) {
2122
2320
  const arr = eye(numCols, numRows, {
@@ -2130,7 +2328,14 @@ function eye(numRows, numCols, { dtype, device } = {}) {
2130
2328
  device
2131
2329
  });
2132
2330
  const exp$2 = AluExp.cmplt(AluExp.mod(AluVar.idx, AluExp.i32(numCols + 1)), AluExp.i32(1));
2133
- return new Array$1(AluExp.cast(dtype, exp$2), ShapeTracker.fromShape([numRows, numCols]), dtype, getBackend(device));
2331
+ return new Array$1({
2332
+ source: AluExp.cast(dtype, exp$2),
2333
+ st: ShapeTracker.fromShape([numRows, numCols]),
2334
+ dtype,
2335
+ weakType,
2336
+ backend: getBackend(device),
2337
+ committed: device != void 0
2338
+ });
2134
2339
  }
2135
2340
  /** Return the identity matrix, with ones on the main diagonal. */
2136
2341
  function identity$1(n, { dtype, device } = {}) {
@@ -2167,7 +2372,14 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
2167
2372
  });
2168
2373
  const exp$2 = AluExp.add(AluExp.const(dtype, start), AluExp.mul(AluExp.cast(dtype, AluVar.idx), AluExp.const(dtype, step)));
2169
2374
  const st = ShapeTracker.fromShape([size$1]);
2170
- return new Array$1(exp$2, st, dtype, getBackend(device));
2375
+ return new Array$1({
2376
+ source: exp$2,
2377
+ st,
2378
+ dtype,
2379
+ weakType: false,
2380
+ backend: getBackend(device),
2381
+ committed: device != void 0
2382
+ });
2171
2383
  }
2172
2384
  /**
2173
2385
  * Return evenly spaced numbers over a specified interval.
@@ -2185,10 +2397,10 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
2185
2397
  dtype,
2186
2398
  device
2187
2399
  });
2188
- else if (num === 1) return scalar(start, {
2400
+ else if (num === 1) return full([1], start, {
2189
2401
  dtype,
2190
2402
  device
2191
- }).reshape([1]);
2403
+ });
2192
2404
  else if (start === stop) return full([num], start, {
2193
2405
  dtype,
2194
2406
  device
@@ -2197,7 +2409,14 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
2197
2409
  const denom = endpoint ? num - 1 : num;
2198
2410
  const exp$2 = AluExp.cast(dtype, AluExp.add(AluExp.f32(start), AluExp.mul(AluExp.f32(delta / denom), AluExp.cast(DType.Float32, AluVar.idx))));
2199
2411
  const st = ShapeTracker.fromShape([num]);
2200
- return new Array$1(exp$2, st, dtype, getBackend(device));
2412
+ return new Array$1({
2413
+ source: exp$2,
2414
+ st,
2415
+ dtype,
2416
+ weakType: false,
2417
+ backend: getBackend(device),
2418
+ committed: device != void 0
2419
+ });
2201
2420
  }
2202
2421
  function aluCompare(a, b, op) {
2203
2422
  switch (op) {
@@ -2209,35 +2428,6 @@ function aluCompare(a, b, op) {
2209
2428
  case CompareOp.LessEqual: return AluExp.add(AluExp.cmplt(a, b), AluExp.cmpne(a, b).not());
2210
2429
  }
2211
2430
  }
2212
- /**
2213
- * Implements a NumPy-style generalized broadcast rule on two array shapes.
2214
- *
2215
- * "When operating on two arrays, NumPy compares their shapes element-wise. It
2216
- * starts with the trailing (i.e. rightmost) dimension and works its way left.
2217
- * Two dimensions are compatible when:
2218
- * 1. they are equal, or
2219
- * 2. one of them is 1."
2220
- *
2221
- * Throws a TypeError if the broadcast is not possible.
2222
- *
2223
- * <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
2224
- */
2225
- function generalBroadcast(a, b) {
2226
- const out = [];
2227
- let i = a.length - 1;
2228
- let j = b.length - 1;
2229
- for (; i >= 0 && j >= 0; i--, j--) {
2230
- const x = a[i];
2231
- const y = b[j];
2232
- if (x === y) out.push(x);
2233
- else if (x === 1) out.push(y);
2234
- else if (y === 1) out.push(x);
2235
- else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
2236
- }
2237
- for (; i >= 0; i--) out.push(a[i]);
2238
- for (; j >= 0; j--) out.push(b[j]);
2239
- return out.reverse();
2240
- }
2241
2431
 
2242
2432
  //#endregion
2243
2433
  //#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/esm/usingCtx.js
@@ -2313,13 +2503,15 @@ var Var = class Var {
2313
2503
  };
2314
2504
  /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
2315
2505
  var Lit = class {
2316
- dtype;
2317
2506
  value;
2318
2507
  aval;
2319
- constructor(dtype, value) {
2320
- this.dtype = dtype;
2508
+ get dtype() {
2509
+ return this.aval.dtype;
2510
+ }
2511
+ constructor(aval, value) {
2512
+ if (aval.shape.length !== 0) throw new Error(`internal: Lit must be a scalar`);
2321
2513
  this.value = value;
2322
- this.aval = new ShapedArray([], dtype);
2514
+ this.aval = ShapedArray.fromAval(aval);
2323
2515
  }
2324
2516
  };
2325
2517
  function atomIsLit(atom, literal) {
@@ -2443,14 +2635,19 @@ var Jaxpr = class Jaxpr {
2443
2635
  const c = eqn.outBinders[0];
2444
2636
  if (atomIsLit(a, 0)) context.set(c, b);
2445
2637
  else if (atomIsLit(b, 0)) context.set(c, a);
2446
- else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.dtype, a.dtype === DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
2638
+ else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.dtype === DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
2639
+ else newEqns.push(eqn);
2640
+ } else if (eqn.primitive === Primitive.Neg) {
2641
+ const [a] = inputs;
2642
+ const c = eqn.outBinders[0];
2643
+ if (atomIsLit(a)) context.set(c, new Lit(a.aval, -a.value));
2447
2644
  else newEqns.push(eqn);
2448
2645
  } else if (eqn.primitive === Primitive.Mul) {
2449
2646
  const [a, b] = inputs;
2450
2647
  const c = eqn.outBinders[0];
2451
2648
  if (atomIsLit(a, 1)) context.set(c, b);
2452
2649
  else if (atomIsLit(b, 1)) context.set(c, a);
2453
- else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.dtype, a.value * b.value));
2650
+ else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.value * b.value));
2454
2651
  else newEqns.push(eqn);
2455
2652
  } else if (eqn.primitive === Primitive.Idiv) {
2456
2653
  const [a, b] = inputs;
@@ -2548,7 +2745,7 @@ function evalJaxpr(jaxpr, args) {
2548
2745
  if (x instanceof Var) {
2549
2746
  remainingRefs.set(x, (remainingRefs.get(x) ?? 0) - 1);
2550
2747
  return env.get(x);
2551
- } else return scalar(x.value, { dtype: x.dtype });
2748
+ } else return array(x.value, { dtype: x.dtype });
2552
2749
  };
2553
2750
  const write = (v, val) => {
2554
2751
  if (env.has(v)) throw new Error(`Variable already bound: ${v}`);
@@ -2607,7 +2804,7 @@ var JaxprTrace = class extends Trace {
2607
2804
  let tracer = this.builder.constTracers.get(val);
2608
2805
  if (tracer === void 0) {
2609
2806
  tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
2610
- this.builder.addConst(tracer, val instanceof Tracer ? val.ref : scalar(val));
2807
+ this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
2611
2808
  }
2612
2809
  return tracer;
2613
2810
  }
@@ -2676,7 +2873,7 @@ function _inlineLiterals(jaxpr, consts) {
2676
2873
  const newConsts = [];
2677
2874
  for (let i = 0; i < consts.length; i++) if (ndim$1(consts[i]) === 0 && consts[i] instanceof Array$1) {
2678
2875
  const ar = consts[i];
2679
- literals.set(jaxpr.inBinders[i], new Lit(ar.dtype, ar.dataSync()[0]));
2876
+ literals.set(jaxpr.inBinders[i], new Lit(ar.aval, ar.dataSync()[0]));
2680
2877
  } else {
2681
2878
  constBinders.push(jaxpr.inBinders[i]);
2682
2879
  newConsts.push(consts[i]);
@@ -2689,13 +2886,12 @@ function _inlineLiterals(jaxpr, consts) {
2689
2886
  }
2690
2887
  function binopAbstractEval([x, y]) {
2691
2888
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
2692
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2693
- return [new ShapedArray(generalBroadcast(x.shape, y.shape), x.dtype)];
2889
+ return [promoteAvals(x, y)];
2694
2890
  }
2695
2891
  function compareAbstractEval([x, y]) {
2696
2892
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("compareAbstractEval expects ShapedArray inputs");
2697
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2698
- return [new ShapedArray(generalBroadcast(x.shape, y.shape), DType.Bool)];
2893
+ const aval = promoteAvals(x, y);
2894
+ return [new ShapedArray(aval.shape, DType.Bool, false)];
2699
2895
  }
2700
2896
  function vectorizedUnopAbstractEval([x]) {
2701
2897
  return [ShapedArray.fromAval(x)];
@@ -2708,18 +2904,18 @@ const abstractEvalRules = {
2708
2904
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
2709
2905
  [Primitive.StopGradient]: vectorizedUnopAbstractEval,
2710
2906
  [Primitive.Cast]([x], { dtype }) {
2711
- return [new ShapedArray(x.shape, dtype)];
2907
+ return [new ShapedArray(x.shape, dtype, false)];
2712
2908
  },
2713
2909
  [Primitive.Bitcast]([x], { dtype }) {
2714
2910
  if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
2715
2911
  if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
2716
- return [new ShapedArray(x.shape, dtype)];
2912
+ return [new ShapedArray(x.shape, dtype, false)];
2717
2913
  },
2718
2914
  [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
2719
2915
  if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
2720
2916
  const keyShape = generalBroadcast(k0.shape, k1.shape);
2721
2917
  if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2722
- return [new ShapedArray(shape$1, DType.Uint32)];
2918
+ return [new ShapedArray(shape$1, DType.Uint32, false)];
2723
2919
  },
2724
2920
  [Primitive.Sin]: vectorizedUnopAbstractEval,
2725
2921
  [Primitive.Cos]: vectorizedUnopAbstractEval,
@@ -2727,61 +2923,62 @@ const abstractEvalRules = {
2727
2923
  [Primitive.Atan]: vectorizedUnopAbstractEval,
2728
2924
  [Primitive.Exp]: vectorizedUnopAbstractEval,
2729
2925
  [Primitive.Log]: vectorizedUnopAbstractEval,
2926
+ [Primitive.Erf]: vectorizedUnopAbstractEval,
2927
+ [Primitive.Erfc]: vectorizedUnopAbstractEval,
2730
2928
  [Primitive.Sqrt]: vectorizedUnopAbstractEval,
2731
2929
  [Primitive.Min]: binopAbstractEval,
2732
2930
  [Primitive.Max]: binopAbstractEval,
2733
2931
  [Primitive.Reduce]([x], { axis }) {
2734
2932
  const axisSet = new Set(axis);
2735
2933
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
2736
- return [new ShapedArray(newShape, x.dtype)];
2934
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2737
2935
  },
2738
2936
  [Primitive.Pool]([x], { window, strides }) {
2739
2937
  const shape$1 = checkPoolShape(x.shape, window, strides);
2740
- return [new ShapedArray(shape$1, x.dtype)];
2938
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2741
2939
  },
2742
2940
  [Primitive.PoolTranspose]([x], { inShape, window, strides }) {
2743
2941
  const shape$1 = checkPoolShape(inShape, window, strides);
2744
2942
  if (!deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
2745
- return [new ShapedArray(inShape, x.dtype)];
2943
+ return [new ShapedArray(inShape, x.dtype, x.weakType)];
2746
2944
  },
2747
2945
  [Primitive.Dot]([x, y]) {
2748
- if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
2749
2946
  if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
2750
- const shape$1 = generalBroadcast(x.shape, y.shape);
2947
+ const { shape: shape$1, dtype, weakType } = promoteAvals(x, y);
2751
2948
  shape$1.splice(-1, 1);
2752
- return [new ShapedArray(shape$1, x.dtype)];
2949
+ return [new ShapedArray(shape$1, dtype, weakType)];
2753
2950
  },
2754
2951
  [Primitive.Conv]([lhs, rhs], params) {
2755
- if (lhs.dtype !== rhs.dtype) throw new TypeError(`Conv dtype mismatch, got ${lhs.dtype} vs ${rhs.dtype}`);
2952
+ const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
2756
2953
  const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
2757
- return [new ShapedArray(shape$1, lhs.dtype)];
2954
+ return [new ShapedArray(shape$1, dtype, weakType)];
2758
2955
  },
2759
2956
  [Primitive.Compare]: compareAbstractEval,
2760
2957
  [Primitive.Where]([cond, x, y]) {
2761
2958
  if (cond.dtype !== DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
2762
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2763
- const shape$1 = generalBroadcast(cond.shape, generalBroadcast(x.shape, y.shape));
2764
- return [new ShapedArray(shape$1, x.dtype)];
2959
+ const xy = promoteAvals(x, y);
2960
+ const shape$1 = generalBroadcast(cond.shape, xy.shape);
2961
+ return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
2765
2962
  },
2766
2963
  [Primitive.Transpose]([x], { perm }) {
2767
- return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype)];
2964
+ return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
2768
2965
  },
2769
2966
  [Primitive.Broadcast]([x], { shape: shape$1 }) {
2770
- return [new ShapedArray(shape$1, x.dtype)];
2967
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2771
2968
  },
2772
2969
  [Primitive.Reshape]([x], { shape: shape$1 }) {
2773
- return [new ShapedArray(shape$1, x.dtype)];
2970
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2774
2971
  },
2775
2972
  [Primitive.Flip]([x], _) {
2776
- return [new ShapedArray(x.shape, x.dtype)];
2973
+ return [ShapedArray.fromAval(x)];
2777
2974
  },
2778
2975
  [Primitive.Shrink]([x], { slice }) {
2779
2976
  const newShape = slice.map((s) => s[1] - s[0]);
2780
- return [new ShapedArray(newShape, x.dtype)];
2977
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2781
2978
  },
2782
2979
  [Primitive.Pad]([x], { width }) {
2783
2980
  const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
2784
- return [new ShapedArray(newShape, x.dtype)];
2981
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2785
2982
  },
2786
2983
  [Primitive.Gather]([x, ...indices], { axis, outDim }) {
2787
2984
  for (const a of indices) if (a.dtype !== DType.Int32 && a.dtype !== DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
@@ -2794,7 +2991,7 @@ const abstractEvalRules = {
2794
2991
  const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
2795
2992
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
2796
2993
  newShape.splice(outDim, 0, ...gatherShape);
2797
- return [new ShapedArray(newShape, x.dtype)];
2994
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2798
2995
  },
2799
2996
  [Primitive.JitCall](args, { jaxpr }) {
2800
2997
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
@@ -2861,6 +3058,7 @@ function jit$1(f, opts) {
2861
3058
  const cacheKey = JSON.stringify(jaxprArgs);
2862
3059
  const { jaxpr, consts, treedef: outTree } = runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
2863
3060
  const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
3061
+ name: f.name || "closure",
2864
3062
  jaxpr,
2865
3063
  numConsts: consts.length
2866
3064
  });
@@ -2979,6 +3177,16 @@ const jvpRules = {
2979
3177
  [Primitive.Log]([x], [dx]) {
2980
3178
  return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
2981
3179
  },
3180
+ [Primitive.Erf]([x], [dx]) {
3181
+ const coeff = 2 / Math.sqrt(Math.PI);
3182
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3183
+ return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
3184
+ },
3185
+ [Primitive.Erfc]([x], [dx]) {
3186
+ const coeff = -2 / Math.sqrt(Math.PI);
3187
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3188
+ return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
3189
+ },
2982
3190
  [Primitive.Sqrt]([x], [dx]) {
2983
3191
  const z = sqrt$1(x);
2984
3192
  return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
@@ -3022,13 +3230,14 @@ const jvpRules = {
3022
3230
  const indicesRef = indices.map((t) => t.ref);
3023
3231
  return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
3024
3232
  },
3025
- [Primitive.JitCall](primals, tangents, { jaxpr }) {
3233
+ [Primitive.JitCall](primals, tangents, { name, jaxpr }) {
3026
3234
  const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
3027
3235
  const outs = bind(Primitive.JitCall, [
3028
3236
  ...newConsts.map((c) => c.ref),
3029
3237
  ...primals,
3030
3238
  ...tangents
3031
3239
  ], {
3240
+ name: `${name}_jvp`,
3032
3241
  jaxpr: newJaxpr,
3033
3242
  numConsts: newConsts.length
3034
3243
  });
@@ -3082,7 +3291,7 @@ function jvp$1(f, primals, tangents) {
3082
3291
  function mappedAval(batchDim, aval) {
3083
3292
  const shape$1 = [...aval.shape];
3084
3293
  shape$1.splice(batchDim, 1);
3085
- return new ShapedArray(shape$1, aval.dtype);
3294
+ return new ShapedArray(shape$1, aval.dtype, aval.weakType);
3086
3295
  }
3087
3296
  /** Move one axis to a different index. */
3088
3297
  function moveaxis$1(x, src, dst) {
@@ -3139,6 +3348,10 @@ var BatchTrace = class extends Trace {
3139
3348
  const [valsIn, bdimsIn] = unzip2(tracers.map((t) => [t.val, t.batchDim]));
3140
3349
  const vmapRule = vmapRules[primitive];
3141
3350
  if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
3351
+ if (bdimsIn.every((d) => d === null)) {
3352
+ const valOuts$1 = bind(primitive, valsIn, params);
3353
+ return valOuts$1.map((x) => new BatchTracer(this, x, null));
3354
+ }
3142
3355
  const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3143
3356
  return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3144
3357
  }
@@ -3146,24 +3359,28 @@ var BatchTrace = class extends Trace {
3146
3359
  return this.main.globalData;
3147
3360
  }
3148
3361
  };
3149
- function handleScalarBroadcasting(nd, x, d) {
3150
- if (d === null || nd === ndim$1(x)) return x;
3151
- else {
3152
- const axis = range(ndim$1(x), nd);
3153
- const shape$1 = [...x.shape, ...axis.map(() => 1)];
3154
- return broadcast(x, shape$1, axis);
3155
- }
3156
- }
3157
- /** Process a primitive with built-in broadcasting. */
3362
+ /**
3363
+ * Process a primitive with built-in broadcasting.
3364
+ *
3365
+ * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3366
+ */
3158
3367
  function broadcastBatcher(op) {
3159
3368
  return (axisSize, args, dims) => {
3160
3369
  if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3161
- const idx = dims.findIndex((d) => d !== null);
3162
- if (idx === -1) return [[op(...args)], [null]];
3163
- if (zip(args, dims).every(([x, d]) => ndim$1(x) === 0 || deepEqual(x.shape, args[idx].shape) && d === dims[idx])) return [[op(...args)], [dims[idx]]];
3164
- args = args.map((x, i) => ndim$1(x) > 0 ? moveBatchAxis(axisSize, dims[i], 0, x) : x);
3165
- const nd = Math.max(...args.map(ndim$1));
3166
- args = args.map((x, i) => handleScalarBroadcasting(nd, x, dims[i]));
3370
+ const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3371
+ const firstIdx = dims.findIndex((d) => d !== null);
3372
+ const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3373
+ if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3374
+ args = args.map((x, i) => {
3375
+ if (dims[i] === null) return x;
3376
+ x = moveBatchAxis(axisSize, dims[i], 0, x);
3377
+ if (x.ndim < nd) x = x.reshape([
3378
+ x.shape[0],
3379
+ ...rep(nd - x.ndim, 1),
3380
+ ...x.shape.slice(1)
3381
+ ]);
3382
+ return x;
3383
+ });
3167
3384
  return [[op(...args)], [0]];
3168
3385
  };
3169
3386
  }
@@ -3187,17 +3404,18 @@ const vmapRules = {
3187
3404
  [Primitive.Atan]: unopBatcher(atan$1),
3188
3405
  [Primitive.Exp]: unopBatcher(exp$1),
3189
3406
  [Primitive.Log]: unopBatcher(log$1),
3407
+ [Primitive.Erf]: unopBatcher(erf$1),
3408
+ [Primitive.Erfc]: unopBatcher(erfc$1),
3190
3409
  [Primitive.Sqrt]: unopBatcher(sqrt$1),
3191
3410
  [Primitive.Min]: broadcastBatcher(min$1),
3192
3411
  [Primitive.Max]: broadcastBatcher(max$1),
3193
3412
  [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3194
- if (xBdim === null) return [[reduce(x, op, axis)], [null]];
3413
+ assertNonNull(xBdim);
3195
3414
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3196
3415
  const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3197
3416
  return [[reduce(x, op, newAxis)], [outBdim]];
3198
3417
  },
3199
3418
  [Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
3200
- if (xBdim === null && yBdim === null) return [[dot$1(x, y)], [null]];
3201
3419
  x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
3202
3420
  y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
3203
3421
  const z = dot$1(x, y);
@@ -3206,29 +3424,72 @@ const vmapRules = {
3206
3424
  [Primitive.Compare](axisSize, args, dims, { op }) {
3207
3425
  return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3208
3426
  },
3427
+ [Primitive.Where]: broadcastBatcher(where$1),
3428
+ [Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
3429
+ assertNonNull(xBdim);
3430
+ const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
3431
+ newPerm.splice(xBdim, 0, xBdim);
3432
+ return [[transpose$1(x, newPerm)], [xBdim]];
3433
+ },
3434
+ [Primitive.Broadcast](axisSize, [x], [xBdim], { shape: shape$1, axis }) {
3435
+ assertNonNull(xBdim);
3436
+ const newShape = shape$1.toSpliced(xBdim, 0, axisSize);
3437
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3438
+ return [[broadcast(x, newShape, newAxis)], [xBdim]];
3439
+ },
3209
3440
  [Primitive.Reshape](axisSize, [x], [xBdim], { shape: shape$1 }) {
3210
- if (xBdim === null) return [[reshape$1(x, shape$1)], [null]];
3211
3441
  x = moveBatchAxis(axisSize, xBdim, 0, x);
3212
3442
  return [[reshape$1(x, [axisSize, ...shape$1])], [0]];
3213
3443
  },
3214
3444
  [Primitive.Flip](axisSize, [x], [xBdim], { axis }) {
3215
- if (xBdim === null) return [[flip$1(x, axis)], [null]];
3445
+ assertNonNull(xBdim);
3216
3446
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3217
3447
  return [[flip$1(x, newAxis)], [xBdim]];
3218
3448
  },
3219
3449
  [Primitive.Shrink](axisSize, [x], [xBdim], { slice }) {
3220
- if (xBdim === null) return [[shrink(x, slice)], [null]];
3450
+ assertNonNull(xBdim);
3221
3451
  const newSlice = slice.toSpliced(xBdim, 0, [0, axisSize]);
3222
3452
  return [[shrink(x, newSlice)], [xBdim]];
3223
3453
  },
3224
3454
  [Primitive.Pad](axisSize, [x], [xBdim], { width }) {
3225
- if (xBdim === null) return [[pad$1(x, width)], [null]];
3455
+ assertNonNull(xBdim);
3226
3456
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3227
3457
  return [[pad$1(x, newWidth)], [xBdim]];
3228
3458
  },
3229
- [Primitive.JitCall](axisSize, args, dims, { jaxpr }) {
3459
+ [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3460
+ if (indicesBdim.every((d) => d === null)) {
3461
+ assertNonNull(xBdim);
3462
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3463
+ let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3464
+ let newOutDim = outDim;
3465
+ if (newOutDim < newBdim) newBdim += axis.length;
3466
+ else newOutDim += 1;
3467
+ return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3468
+ }
3469
+ const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3470
+ indices = indices.map((m, i) => {
3471
+ if (indicesBdim[i] === null) return m;
3472
+ m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3473
+ if (m.ndim < nd) m = m.reshape([
3474
+ m.shape[0],
3475
+ ...rep(nd - m.ndim, 1),
3476
+ ...m.shape.slice(1)
3477
+ ]);
3478
+ return m;
3479
+ });
3480
+ if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3481
+ else {
3482
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3483
+ const newAxis = [0, ...axis.map((ax) => ax + 1)];
3484
+ const extraBatchIndex = arange(axisSize).reshape([-1, ...rep(nd - 1, 1)]);
3485
+ indices.splice(0, 0, extraBatchIndex);
3486
+ return [[gather(x, indices, newAxis, outDim)], [outDim]];
3487
+ }
3488
+ },
3489
+ [Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
3230
3490
  const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
3231
3491
  const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
3492
+ name: `${name}_vmap`,
3232
3493
  jaxpr: newJaxpr,
3233
3494
  numConsts: newConsts.length
3234
3495
  });
@@ -3244,7 +3505,7 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
3244
3505
  if (dims[i] === null) return v.aval;
3245
3506
  const shape$1 = [...v.aval.shape];
3246
3507
  shape$1.splice(dims[i], 0, axisSize);
3247
- return new ShapedArray(shape$1, v.aval.dtype);
3508
+ return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
3248
3509
  });
3249
3510
  const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3250
3511
  const result = {
@@ -3284,12 +3545,14 @@ function vmapFlat(f, inAxes, args) {
3284
3545
  function vmap$1(f, inAxes = 0) {
3285
3546
  return (...args) => {
3286
3547
  const [argsFlat, inTree] = flatten(args);
3287
- let inAxesFlat;
3548
+ let inAxesFlat = [];
3288
3549
  if (typeof inAxes === "number") inAxesFlat = rep(argsFlat.length, inAxes);
3550
+ else for (let i = 0; i < args.length; i++) if (inAxes[i] == null) inAxesFlat.push(...rep(inTree.childTreedefs[i].size, null));
3551
+ else if (typeof inAxes[i] === "number") inAxesFlat.push(...rep(inTree.childTreedefs[i].size, inAxes[i]));
3289
3552
  else {
3290
- let inTree2;
3291
- [inAxesFlat, inTree2] = flatten(inAxes);
3292
- if (!inTree.equals(inTree2)) throw new TreeMismatchError("vmap", inTree, inTree2);
3553
+ const [axesFlat, axesTreeDef] = flatten(inAxes[i]);
3554
+ if (!inTree.childTreedefs[i].equals(axesTreeDef)) throw new TreeMismatchError("vmap", inTree.childTreedefs[i], axesTreeDef);
3555
+ inAxesFlat.push(...axesFlat);
3293
3556
  }
3294
3557
  const [fFlat, outTree] = flattenFun(f, inTree);
3295
3558
  const outsFlat = vmapFlat(fFlat, inAxesFlat, argsFlat);
@@ -3457,8 +3720,8 @@ var PartialEvalTrace = class extends Trace {
3457
3720
  processPrimitive(primitive, tracers, params) {
3458
3721
  if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
3459
3722
  if (primitive === Primitive.JitCall) {
3460
- const { jaxpr, numConsts } = params;
3461
- return this.#partialEvalJaxpr(jaxpr, numConsts, tracers);
3723
+ const { name, jaxpr, numConsts } = params;
3724
+ return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
3462
3725
  }
3463
3726
  const tracersIn = tracers.map((t) => this.instantiateConst(t));
3464
3727
  const avalsIn = tracersIn.map((t) => t.pval.aval);
@@ -3484,12 +3747,13 @@ var PartialEvalTrace = class extends Trace {
3484
3747
  *
3485
3748
  * Used when encountering a JitCall rule during the trace.
3486
3749
  */
3487
- #partialEvalJaxpr(jaxpr, numConsts, tracers) {
3750
+ #partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
3488
3751
  jaxpr = jaxpr.flatten();
3489
3752
  const inUnknowns = tracers.map((t) => !t.pval.isKnown);
3490
3753
  const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
3491
3754
  const [knownTracers, unknownTracers] = partitionList(inUnknowns, tracers);
3492
3755
  const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
3756
+ name: `${name}_peval`,
3493
3757
  jaxpr: jaxpr1,
3494
3758
  numConsts: 0
3495
3759
  });
@@ -3501,6 +3765,7 @@ var PartialEvalTrace = class extends Trace {
3501
3765
  prim: Primitive.JitCall,
3502
3766
  tracersIn: resTracers.concat(unknownTracers),
3503
3767
  params: {
3768
+ name: `${name}_resid`,
3504
3769
  jaxpr: jaxpr2,
3505
3770
  numConsts: 0
3506
3771
  },
@@ -3643,7 +3908,7 @@ function evalJaxprTransposed(jaxpr, args, cotangents) {
3643
3908
  }
3644
3909
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
3645
3910
  const eqn = jaxpr.eqns[i];
3646
- const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? scalar(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
3911
+ const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? array(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
3647
3912
  const cotangentsOut = eqn.outBinders.map(readCotangent);
3648
3913
  const rule = transposeRules[eqn.primitive];
3649
3914
  if (!rule) throw new TypeError(`Backward pass not implemented for ${eqn.primitive}`);
@@ -3823,7 +4088,7 @@ const transposeRules = {
3823
4088
  if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
3824
4089
  throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
3825
4090
  },
3826
- [Primitive.JitCall](cts, args, { jaxpr }) {
4091
+ [Primitive.JitCall](cts, args, { name, jaxpr }) {
3827
4092
  const undefPrimals = args.map((x) => x instanceof UndefPrimal);
3828
4093
  const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
3829
4094
  const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
@@ -3832,6 +4097,7 @@ const transposeRules = {
3832
4097
  ...residuals,
3833
4098
  ...cts
3834
4099
  ], {
4100
+ name: `${name}_t`,
3835
4101
  jaxpr: newJaxpr,
3836
4102
  numConsts: newConsts.length
3837
4103
  });
@@ -3906,7 +4172,7 @@ function valueAndGrad$1(f) {
3906
4172
  const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
3907
4173
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
3908
4174
  if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
3909
- const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
4175
+ const [ct, ...rest] = fVjp(onesLike$1(y.ref));
3910
4176
  for (const r of rest) dispose(r);
3911
4177
  fVjp.dispose();
3912
4178
  return [y, ct];
@@ -3934,7 +4200,10 @@ __export(lax_exports, {
3934
4200
  conv: () => conv$1,
3935
4201
  convGeneralDilated: () => convGeneralDilated,
3936
4202
  convWithGeneralPadding: () => convWithGeneralPadding,
3937
- reduceWindow: () => reduceWindow
4203
+ erf: () => erf,
4204
+ erfc: () => erfc,
4205
+ reduceWindow: () => reduceWindow,
4206
+ stopGradient: () => stopGradient$1
3938
4207
  });
3939
4208
  function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
3940
4209
  const padType = padding.toUpperCase();
@@ -3993,6 +4262,28 @@ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
3993
4262
  strides: windowStrides
3994
4263
  }));
3995
4264
  }
4265
+ /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
4266
+ function erf(x) {
4267
+ return erf$1(x);
4268
+ }
4269
+ /**
4270
+ * The complementary error function: `erfc(x) = 1 - erf(x)`.
4271
+ *
4272
+ * This function is more accurate than `1 - erf(x)` for large values of `x`,
4273
+ * where `erf(x)` is very close to 1.
4274
+ */
4275
+ function erfc(x) {
4276
+ return erfc$1(x);
4277
+ }
4278
+ /**
4279
+ * Stops gradient computation.
4280
+ *
4281
+ * Behaves as the identity function but prevents the flow of gradients during
4282
+ * forward or reverse-mode automatic differentiation.
4283
+ */
4284
+ function stopGradient$1(x) {
4285
+ return stopGradient(x);
4286
+ }
3996
4287
 
3997
4288
  //#endregion
3998
4289
  //#region src/numpy.ts
@@ -4055,6 +4346,9 @@ __export(numpy_exports, {
4055
4346
  fullLike: () => fullLike$1,
4056
4347
  greater: () => greater,
4057
4348
  greaterEqual: () => greaterEqual,
4349
+ hamming: () => hamming,
4350
+ hann: () => hann,
4351
+ heaviside: () => heaviside,
4058
4352
  hstack: () => hstack,
4059
4353
  hypot: () => hypot,
4060
4354
  identity: () => identity$1,
@@ -4276,7 +4570,7 @@ function argmin(a, axis, opts) {
4276
4570
  } else axis = checkAxis(axis, a.ndim);
4277
4571
  const shape$1 = a.shape;
4278
4572
  const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
4279
- const length = scalar(shape$1[axis], {
4573
+ const length = array(shape$1[axis], {
4280
4574
  dtype: int32,
4281
4575
  device: a.device
4282
4576
  });
@@ -4300,7 +4594,7 @@ function argmax(a, axis, opts) {
4300
4594
  } else axis = checkAxis(axis, a.ndim);
4301
4595
  const shape$1 = a.shape;
4302
4596
  const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
4303
- const length = scalar(shape$1[axis], {
4597
+ const length = array(shape$1[axis], {
4304
4598
  dtype: int32,
4305
4599
  device: a.device
4306
4600
  });
@@ -4694,6 +4988,32 @@ function sign(x) {
4694
4988
  x = fudgeArray(x);
4695
4989
  return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
4696
4990
  }
4991
+ /**
4992
+ * Return the Hamming window of size M, a taper with a weighted cosine bell.
4993
+ *
4994
+ * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
4995
+ */
4996
+ function hamming(M) {
4997
+ return cos(linspace(0, 2 * Math.PI, M)).mul(-.46).add(.54);
4998
+ }
4999
+ /**
5000
+ * Return the Hann window of size M, a taper with a weighted cosine bell.
5001
+ *
5002
+ * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
5003
+ */
5004
+ function hann(M) {
5005
+ return cos(linspace(0, 2 * Math.PI, M)).mul(-.5).add(.5);
5006
+ }
5007
+ /**
5008
+ * @function
5009
+ * Compute the Heaviside step function. It is defined piecewise:
5010
+ * - `heaviside(x1, x2) = 0` for `x1 < 0`,
5011
+ * - `heaviside(x1, x2) = x2` for `x1 == 0`,
5012
+ * - `heaviside(x1, x2) = 1` for `x1 > 0`.
5013
+ */
5014
+ const heaviside = jit$1(function heaviside$1(x1, x2) {
5015
+ return where(less(x1.ref, 0), 0, where(equal(x1, 0), x2, 1));
5016
+ });
4697
5017
  /** Calculate element-wise square of the input array. */
4698
5018
  function square(x) {
4699
5019
  x = fudgeArray(x);
@@ -4713,10 +5033,10 @@ function acos(x) {
4713
5033
  * Return element-wise hypotenuse for the given legs of a right triangle.
4714
5034
  *
4715
5035
  * In the original NumPy/JAX implementation, this function is more numerically
4716
- * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
4717
- * improvements.
5036
+ * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
5037
+ * stability improvements.
4718
5038
  */
4719
- const hypot = jit$1((x1, x2) => {
5039
+ const hypot = jit$1(function hypot$1(x1, x2) {
4720
5040
  return sqrt(square(x1).add(square(x2)));
4721
5041
  });
4722
5042
  /**
@@ -4732,7 +5052,7 @@ const hypot = jit$1((x1, x2) => {
4732
5052
  *
4733
5053
  * The output is ill-defined when both x and y are zero.
4734
5054
  */
4735
- const atan2 = jit$1((y, x) => {
5055
+ const atan2 = jit$1(function atan2$1(y, x) {
4736
5056
  const r = sqrt(square(x.ref).add(square(y.ref)));
4737
5057
  const xNeg = less(x.ref, 0);
4738
5058
  const numer = where(xNeg.ref, r.ref.sub(x.ref), y.ref);
@@ -4800,13 +5120,13 @@ const degrees = rad2deg;
4800
5120
  * @function
4801
5121
  * Computes first array raised to power of second array, element-wise.
4802
5122
  */
4803
- const power = jit$1((x1, x2) => {
5123
+ const power = jit$1(function power$1(x1, x2) {
4804
5124
  return exp(log(x1).mul(x2));
4805
5125
  });
4806
5126
  /** @function Alias of `jax.numpy.power()`. */
4807
5127
  const pow = power;
4808
5128
  /** @function Calculate the element-wise cube root of the input array. */
4809
- const cbrt = jit$1((x) => {
5129
+ const cbrt = jit$1(function cbrt$1(x) {
4810
5130
  const sgn = where(less(x.ref, 0), -1, 1);
4811
5131
  return sgn.ref.mul(exp(log(x.mul(sgn)).mul(1 / 3)));
4812
5132
  });
@@ -4816,7 +5136,7 @@ const cbrt = jit$1((x) => {
4816
5136
  *
4817
5137
  * `sinh(x) = (exp(x) - exp(-x)) / 2`
4818
5138
  */
4819
- const sinh = jit$1((x) => {
5139
+ const sinh = jit$1(function sinh$1(x) {
4820
5140
  const ex = exp(x);
4821
5141
  const emx = reciprocal(ex.ref);
4822
5142
  return ex.sub(emx).mul(.5);
@@ -4827,7 +5147,7 @@ const sinh = jit$1((x) => {
4827
5147
  *
4828
5148
  * `cosh(x) = (exp(x) + exp(-x)) / 2`
4829
5149
  */
4830
- const cosh = jit$1((x) => {
5150
+ const cosh = jit$1(function cosh$1(x) {
4831
5151
  const ex = exp(x);
4832
5152
  const emx = reciprocal(ex.ref);
4833
5153
  return ex.add(emx).mul(.5);
@@ -4838,7 +5158,7 @@ const cosh = jit$1((x) => {
4838
5158
  *
4839
5159
  * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
4840
5160
  */
4841
- const tanh = jit$1((x) => {
5161
+ const tanh = jit$1(function tanh$1(x) {
4842
5162
  const negsgn = where(less(x.ref, 0), 1, -1);
4843
5163
  const en2x = exp(x.mul(negsgn.ref).mul(2));
4844
5164
  return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
@@ -4849,7 +5169,7 @@ const tanh = jit$1((x) => {
4849
5169
  *
4850
5170
  * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
4851
5171
  */
4852
- const arcsinh = jit$1((x) => {
5172
+ const arcsinh = jit$1(function arcsinh$1(x) {
4853
5173
  return log(x.ref.add(sqrt(square(x).add(1))));
4854
5174
  });
4855
5175
  /**
@@ -4858,7 +5178,7 @@ const arcsinh = jit$1((x) => {
4858
5178
  *
4859
5179
  * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
4860
5180
  */
4861
- const arccosh = jit$1((x) => {
5181
+ const arccosh = jit$1(function arccosh$1(x) {
4862
5182
  return log(x.ref.add(sqrt(square(x).sub(1))));
4863
5183
  });
4864
5184
  /**
@@ -4867,7 +5187,7 @@ const arccosh = jit$1((x) => {
4867
5187
  *
4868
5188
  * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
4869
5189
  */
4870
- const arctanh = jit$1((x) => {
5190
+ const arctanh = jit$1(function arctanh$1(x) {
4871
5191
  return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
4872
5192
  });
4873
5193
  /** @function Alias of `jax.numpy.arcsinh()`. */
@@ -4983,7 +5303,9 @@ function softSign(x) {
4983
5303
  *
4984
5304
  * Reference: https://en.wikipedia.org/wiki/Swish_function
4985
5305
  */
4986
- const silu = jit$1((x) => x.ref.mul(sigmoid(x)));
5306
+ const silu = jit$1(function silu$1(x) {
5307
+ return x.ref.mul(sigmoid(x));
5308
+ });
4987
5309
  /**
4988
5310
  * @function
4989
5311
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
@@ -5036,18 +5358,20 @@ function celu(x, alpha = 1) {
5036
5358
  * @function
5037
5359
  * Gaussian error linear unit (GELU) activation function.
5038
5360
  *
5039
- * This is computed element-wise. Currently jax-js does not support the erf() or
5040
- * gelu() functions exactly as primitives, so an approximation is used:
5041
- * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
5361
+ * This is computed element-wise. There are two variants depending on whether
5362
+ * `approximate` is set (default true):
5042
5363
  *
5043
- * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
5364
+ * - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
5365
+ * - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
5044
5366
  *
5045
- * This will be improved in the future.
5367
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
5046
5368
  */
5047
- const gelu = jit$1((x) => {
5048
- const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
5049
- return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
5050
- });
5369
+ const gelu = jit$1(function gelu$1(x, opts) {
5370
+ if (opts?.approximate ?? true) {
5371
+ const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
5372
+ return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
5373
+ } else return x.ref.mul(.5).mul(erfc$1(negative(x.ref.mul(Math.SQRT1_2))));
5374
+ }, { staticArgnums: [1] });
5051
5375
  /**
5052
5376
  * Gated linear unit (GLU) activation function.
5053
5377
  *
@@ -5215,8 +5539,11 @@ function bits(key$1, shape$1 = []) {
5215
5539
  const keyShape = validateKeyShape(key$1);
5216
5540
  return randomBits(key$1.ref.slice(...keyShape.map(() => null), 0), key$1.slice(...keyShape.map(() => null), 1), shape$1);
5217
5541
  }
5218
- /** Sample uniform random values in [minval, maxval) with given shape. */
5219
- function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
5542
+ /**
5543
+ * @function
5544
+ * Sample uniform random values in [minval, maxval) with given shape.
5545
+ */
5546
+ const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
5220
5547
  if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
5221
5548
  const mantissa = bits(key$1, shape$1).div(array(512, {
5222
5549
  dtype: DType.Uint32,
@@ -5229,7 +5556,7 @@ function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
5229
5556
  const rand = bitcast(float12, DType.Float32).sub(1);
5230
5557
  if (minval === 0 && maxval === 1) return rand;
5231
5558
  else return rand.mul(maxval - minval).add(minval);
5232
- }
5559
+ }, { staticArgnums: [1, 2] });
5233
5560
  /**
5234
5561
  * Sample Bernoulli random variables with given mean (0,1 categorical).
5235
5562
  *
@@ -5240,26 +5567,49 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
5240
5567
  p = fudgeArray(p);
5241
5568
  return uniform(key$1, shape$1).less(p);
5242
5569
  }
5243
- /** Sample exponential random values according to `p(x) = exp(-x)`. */
5244
- function exponential(key$1, shape$1 = []) {
5570
+ /**
5571
+ * @function
5572
+ * Sample exponential random values according to `p(x) = exp(-x)`.
5573
+ */
5574
+ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
5245
5575
  const u = uniform(key$1, shape$1);
5246
5576
  return negative(log1p(negative(u)));
5247
- }
5577
+ }, { staticArgnums: [1] });
5248
5578
  /**
5579
+ * @function
5249
5580
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
5250
5581
  *
5251
5582
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
5252
5583
  * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
5253
5584
  * bitwise identical to JAX.
5254
5585
  */
5255
- function normal(key$1, shape$1 = []) {
5586
+ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
5256
5587
  const [k1, k2] = split(key$1, 2);
5257
5588
  const u1 = uniform(k1, shape$1);
5258
5589
  const u2 = uniform(k2, shape$1);
5259
5590
  const radius = sqrt(log1p(negative(u1)).mul(-2));
5260
5591
  const theta = u2.mul(2 * Math.PI);
5261
5592
  return radius.mul(cos(theta));
5262
- }
5593
+ }, { staticArgnums: [1] });
5594
+
5595
+ //#endregion
5596
+ //#region src/scipy-special.ts
5597
+ var scipy_special_exports = {};
5598
+ __export(scipy_special_exports, {
5599
+ erf: () => erf,
5600
+ erfc: () => erfc,
5601
+ logSoftmax: () => logSoftmax,
5602
+ logit: () => logit,
5603
+ logsumexp: () => logsumexp,
5604
+ softmax: () => softmax
5605
+ });
5606
+ /**
5607
+ * @function
5608
+ * The logit function, `logit(p) = log(p / (1-p))`.
5609
+ */
5610
+ const logit = jit$1(function logit$1(x) {
5611
+ return log(x.ref.div(subtract(1, x)));
5612
+ });
5263
5613
 
5264
5614
  //#endregion
5265
5615
  //#region src/polyfills.ts
@@ -5354,6 +5704,25 @@ async function blockUntilReady(x) {
5354
5704
  await Promise.all(promises);
5355
5705
  return x;
5356
5706
  }
5707
+ /**
5708
+ * Transfer `x` to `device`.
5709
+ *
5710
+ * `x` may be a nested container of arrays or scalars. The resulting structure
5711
+ * is committed to the device.
5712
+ *
5713
+ * If `device` is not specified, this function behaves as identity if the input
5714
+ * is already an `Array`, otherwise it places the scalar uncommitted on the
5715
+ * default device.
5716
+ */
5717
+ async function devicePut(x, device) {
5718
+ const [xflat, structure$1] = flatten(x);
5719
+ const yflat = await Promise.all(xflat.map((leaf) => {
5720
+ if (leaf instanceof Array$1) return device ? leaf._put(getBackend(device)) : Promise.resolve(leaf);
5721
+ else return Promise.resolve(array(leaf, { device }));
5722
+ }));
5723
+ return unflatten(structure$1, yflat);
5724
+ }
5357
5725
 
5358
5726
  //#endregion
5359
- export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
5727
+ export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
5728
+ //# sourceMappingURL=index.js.map