@jax-js/jax 0.0.5 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-CdcTZEOF.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-CoVtc9dx.js";
3
3
 
4
4
  //#region src/tree.ts
5
5
  var tree_exports = {};
@@ -29,6 +29,10 @@ var JsTreeDef = class JsTreeDef {
29
29
  this.nodeMetadata = nodeMetadata;
30
30
  this.childTreedefs = childTreedefs;
31
31
  }
32
+ /** Get the total number of leaves in the tree. */
33
+ get size() {
34
+ return this.nodeType === NodeType.Leaf ? 1 : this.childTreedefs.reduce((a, b) => a + b.size, 0);
35
+ }
32
36
  /** Returns a string representation of this tree definition. */
33
37
  toString(root = true) {
34
38
  if (root) return "JsTreeDef(" + this.toString(false) + ")";
@@ -184,6 +188,16 @@ function pool(st, ks, strides = 1, dilation = 1) {
184
188
  const s_ = strides;
185
189
  const d_ = dilation;
186
190
  const o_ = zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
191
+ if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
192
+ st = st.padOrShrink([...noop.map(() => [0, 0]), ...zipn(i_, o_, s_).map(([i, o, s]) => [0, o * s - i])]);
193
+ st = st.reshape([...noop, ...zip(o_, s_).flatMap(([o, s]) => [o, s])]).shrink([...noop.map((x) => [0, x]), ...zip(o_, ks).flatMap(([o, k]) => [[0, o], [0, k]])]);
194
+ st = st.permute([
195
+ ...range(noop.length),
196
+ ...ks.map((_, j) => noop.length + 2 * j),
197
+ ...ks.map((_, j) => noop.length + 2 * j + 1)
198
+ ]);
199
+ return st;
200
+ }
187
201
  const f_ = zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
188
202
  const kidf = zipn(ks, i_, d_, f_);
189
203
  st = st.repeat([...rep(noop.length, 1), ...kidf.map(([k, i, d, f]) => Math.ceil(k * (i * f + d) / i))]);
@@ -218,6 +232,12 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
218
232
  const s_ = strides;
219
233
  const d_ = dilation;
220
234
  const o_ = zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
235
+ if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
236
+ st = st.permute([...range(noop.length), ...ks.flatMap((_, j) => [noop.length + j, noop.length + o_.length + j])]);
237
+ st = st.pad([...noop.map(() => [0, 0]), ...zip(s_, ks).flatMap(([s, k]) => [[0, 0], [0, s - k]])]).reshape([...noop, ...zip(o_, s_).map(([o, s]) => o * s)]);
238
+ st = st.padOrShrink([...noop.map(() => [0, 0]), ...zipn(i_, o_, s_).map(([i, o, s]) => [0, i - o * s])]);
239
+ return st.reshape(st.shape.concat(rep(ks.length, 1)));
240
+ }
221
241
  if (!deepEqual(o_, st.shape.slice(noop.length, noop.length + ks.length))) throw new Error("poolTranspose() called with mismatched output shape");
222
242
  const f_ = zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
223
243
  const kidf = zipn(ks, i_, d_, f_);
@@ -327,6 +347,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
327
347
  Primitive$1["Atan"] = "atan";
328
348
  Primitive$1["Exp"] = "exp";
329
349
  Primitive$1["Log"] = "log";
350
+ Primitive$1["Erf"] = "erf";
351
+ Primitive$1["Erfc"] = "erfc";
330
352
  Primitive$1["Sqrt"] = "sqrt";
331
353
  Primitive$1["Min"] = "min";
332
354
  Primitive$1["Max"] = "max";
@@ -348,11 +370,9 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
348
370
  return Primitive$1;
349
371
  }({});
350
372
  let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
351
- CompareOp$1["Greater"] = "greater";
352
373
  CompareOp$1["Less"] = "less";
353
374
  CompareOp$1["Equal"] = "equal";
354
375
  CompareOp$1["NotEqual"] = "not_equal";
355
- CompareOp$1["GreaterEqual"] = "greater_equal";
356
376
  CompareOp$1["LessEqual"] = "less_equal";
357
377
  return CompareOp$1;
358
378
  }({});
@@ -404,6 +424,12 @@ function exp$1(x) {
404
424
  function log$1(x) {
405
425
  return bind1(Primitive.Log, [x]);
406
426
  }
427
+ function erf$1(x) {
428
+ return bind1(Primitive.Erf, [x]);
429
+ }
430
+ function erfc$1(x) {
431
+ return bind1(Primitive.Erfc, [x]);
432
+ }
407
433
  function sqrt$1(x) {
408
434
  return bind1(Primitive.Sqrt, [x]);
409
435
  }
@@ -442,7 +468,7 @@ function compare(x, y, op) {
442
468
  return bind1(Primitive.Compare, [x, y], { op });
443
469
  }
444
470
  function greater$1(x, y) {
445
- return compare(x, y, CompareOp.Greater);
471
+ return compare(y, x, CompareOp.Less);
446
472
  }
447
473
  function less$1(x, y) {
448
474
  return compare(x, y, CompareOp.Less);
@@ -454,7 +480,7 @@ function notEqual$1(x, y) {
454
480
  return compare(x, y, CompareOp.NotEqual);
455
481
  }
456
482
  function greaterEqual$1(x, y) {
457
- return compare(x, y, CompareOp.GreaterEqual);
483
+ return compare(y, x, CompareOp.LessEqual);
458
484
  }
459
485
  function lessEqual$1(x, y) {
460
486
  return compare(x, y, CompareOp.LessEqual);
@@ -1146,12 +1172,18 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
1146
1172
  } else if (exp$3.op === AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
1147
1173
  });
1148
1174
  }
1149
- function broadcastedJit(fn) {
1175
+ function broadcastedJit(fn, opts) {
1150
1176
  return (nargs, exps, avals, params) => {
1151
- const newShape = avals.map((aval) => aval.shape).reduce(generalBroadcast);
1152
- exps = exps.map((exp$3) => reshapeViews(exp$3, (st) => {
1153
- if (!deepEqual(st.shape, newShape)) return st.broadcast(newShape, range(newShape.length - st.shape.length));
1154
- }));
1177
+ let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
1178
+ const skipCastIdx = opts?.skipCastIdx ?? [];
1179
+ if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
1180
+ exps = exps.map((exp$3, i) => {
1181
+ exp$3 = reshapeViews(exp$3, (st) => {
1182
+ if (!deepEqual(st.shape, newShape)) return st.broadcast(newShape, range(newShape.length - st.shape.length));
1183
+ });
1184
+ if (exp$3.dtype !== newDtype && !skipCastIdx.includes(i)) exp$3 = AluExp.cast(newDtype, exp$3);
1185
+ return exp$3;
1186
+ });
1155
1187
  const exp$2 = fn(exps, params);
1156
1188
  return new Kernel(nargs, prod(newShape), exp$2);
1157
1189
  };
@@ -1194,6 +1226,8 @@ const jitRules = {
1194
1226
  [Primitive.Atan]: unopJit(AluExp.atan),
1195
1227
  [Primitive.Exp]: unopJit(AluExp.exp),
1196
1228
  [Primitive.Log]: unopJit(AluExp.log),
1229
+ [Primitive.Erf]: unopJit(AluExp.erf),
1230
+ [Primitive.Erfc]: unopJit(AluExp.erfc),
1197
1231
  [Primitive.Sqrt]: unopJit(AluExp.sqrt),
1198
1232
  [Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
1199
1233
  [Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
@@ -1241,7 +1275,7 @@ const jitRules = {
1241
1275
  return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
1242
1276
  },
1243
1277
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
1244
- [Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b)),
1278
+ [Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
1245
1279
  [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
1246
1280
  [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
1247
1281
  [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
@@ -1412,7 +1446,7 @@ var PendingExecute = class {
1412
1446
  /**
1413
1447
  * A multidimensional numeric array with data stored on CPU or GPU.
1414
1448
  *
1415
- * This is the library's core data type. Equivalent to `jnp.Array` from JAX, or
1449
+ * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
1416
1450
  * `torch.Tensor`.
1417
1451
  *
1418
1452
  * Not to be confused with the JavaScript "Array" constructor. Avoid importing
@@ -1427,6 +1461,7 @@ var Array$1 = class Array$1 extends Tracer {
1427
1461
  #source;
1428
1462
  #st;
1429
1463
  #backend;
1464
+ #committed;
1430
1465
  #rc;
1431
1466
  #pendingSet;
1432
1467
  /**
@@ -1443,6 +1478,7 @@ var Array$1 = class Array$1 extends Tracer {
1443
1478
  this.#source = args.source;
1444
1479
  this.#st = args.st;
1445
1480
  this.#backend = args.backend;
1481
+ this.#committed = args.committed;
1446
1482
  this.#rc = 1;
1447
1483
  this.#pendingSet = new Set(args.pending);
1448
1484
  if (this.#pendingSet.size === 0) this.#pendingSet = null;
@@ -1470,6 +1506,7 @@ var Array$1 = class Array$1 extends Tracer {
1470
1506
  dtype: args.dtype ?? this.#dtype,
1471
1507
  weakType: this.#weakType,
1472
1508
  backend: args.backend ?? this.#backend,
1509
+ committed: args.committed ?? this.#committed,
1473
1510
  pending: args.pending ?? this.#pending ?? void 0
1474
1511
  });
1475
1512
  }
@@ -1525,9 +1562,10 @@ var Array$1 = class Array$1 extends Tracer {
1525
1562
  */
1526
1563
  #gather(indices, axis, outDim) {
1527
1564
  this.#check();
1528
- if (indices.some((a) => a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
1529
1565
  const axisSet = new Set(axis);
1530
1566
  if (axisSet.size !== axis.length) throw new TypeError("Gather axis must not have duplicates");
1567
+ if (indices.some((a) => a.#committed && a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
1568
+ indices = indices.map((ar) => ar._putSync(this.#backend));
1531
1569
  indices = Array$1.#broadcastArrays(indices);
1532
1570
  const indexShape = indices[0].shape;
1533
1571
  const finalShape = this.shape.filter((_, i) => !axisSet.has(i));
@@ -1596,6 +1634,7 @@ var Array$1 = class Array$1 extends Tracer {
1596
1634
  this.#check();
1597
1635
  if (this.#source instanceof AluExp) {
1598
1636
  const exp$3 = new AluExp(op, dtypeOutput, [this.#source]);
1637
+ this.dispose();
1599
1638
  return this.#newArrayFrom({
1600
1639
  source: exp$3.simplify(),
1601
1640
  dtype: dtypeOutput,
@@ -1624,21 +1663,19 @@ var Array$1 = class Array$1 extends Tracer {
1624
1663
  }
1625
1664
  static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
1626
1665
  const n = arrays.length;
1627
- const backend = arrays[0].#backend;
1628
1666
  if (n === 0) throw new TypeError(`No inputs for ${name}`);
1629
1667
  for (const ar of arrays) ar.#check();
1630
1668
  let castDtype;
1631
1669
  let castWeakType = true;
1632
- for (let i = 0; i < n; i++) {
1633
- if (dtypeOverride?.[i]) {
1634
- if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
1635
- } else if (castDtype === void 0) {
1636
- castDtype = arrays[i].#dtype;
1637
- castWeakType = arrays[i].#weakType;
1638
- } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
1639
- if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
1640
- }
1670
+ for (let i = 0; i < n; i++) if (dtypeOverride?.[i]) {
1671
+ if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
1672
+ } else if (castDtype === void 0) {
1673
+ castDtype = arrays[i].#dtype;
1674
+ castWeakType = arrays[i].#weakType;
1675
+ } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
1641
1676
  const weakType = castWeakType && !strongTypeOutput;
1677
+ const { backend, committed } = Array$1.#computeBackend(name, arrays);
1678
+ arrays = arrays.map((ar) => ar._putSync(backend));
1642
1679
  arrays = Array$1.#broadcastArrays(arrays);
1643
1680
  const newShape = [...arrays[0].shape];
1644
1681
  if (arrays.every((ar) => ar.#source instanceof AluExp) && !reduceAxis) {
@@ -1648,12 +1685,14 @@ var Array$1 = class Array$1 extends Tracer {
1648
1685
  });
1649
1686
  if (arrays.every((ar) => deepEqual(ar.#st, arrays[0].#st))) {
1650
1687
  const exp$4 = custom(sources);
1688
+ arrays.forEach((ar) => ar.dispose());
1651
1689
  return new Array$1({
1652
1690
  source: exp$4.simplify(),
1653
1691
  st: arrays[0].#st,
1654
1692
  dtype: exp$4.dtype,
1655
1693
  weakType,
1656
- backend
1694
+ backend,
1695
+ committed
1657
1696
  });
1658
1697
  }
1659
1698
  const exp$3 = custom(arrays.map((ar, i) => {
@@ -1662,12 +1701,14 @@ var Array$1 = class Array$1 extends Tracer {
1662
1701
  return accessorAluExp(src$1, ar.#st, unravelAlu(newShape, AluVar.idx));
1663
1702
  }));
1664
1703
  const st = ShapeTracker.fromShape(newShape);
1704
+ arrays.forEach((ar) => ar.dispose());
1665
1705
  return new Array$1({
1666
1706
  source: exp$3.simplify(),
1667
1707
  st,
1668
1708
  dtype: exp$3.dtype,
1669
1709
  weakType,
1670
- backend
1710
+ backend,
1711
+ committed
1671
1712
  });
1672
1713
  }
1673
1714
  let indices;
@@ -1703,13 +1744,14 @@ var Array$1 = class Array$1 extends Tracer {
1703
1744
  const pending = new Set([...arrays.flatMap((ar) => ar.#pending)]);
1704
1745
  for (const exe of pending) exe.updateRc(1);
1705
1746
  pending.add(new PendingExecute(backend, kernel, inputs, [output]));
1706
- for (const ar of arrays) ar.dispose();
1747
+ arrays.forEach((ar) => ar.dispose());
1707
1748
  return new Array$1({
1708
1749
  source: output,
1709
1750
  st: ShapeTracker.fromShape(newShape),
1710
1751
  dtype: kernel.dtype,
1711
1752
  weakType,
1712
1753
  backend,
1754
+ committed,
1713
1755
  pending
1714
1756
  });
1715
1757
  }
@@ -1787,6 +1829,23 @@ var Array$1 = class Array$1 extends Tracer {
1787
1829
  return ar.#reshape(ar.#st.broadcast(newShape, range(newShape.length - ar.ndim)));
1788
1830
  });
1789
1831
  }
1832
+ static #computeBackend(name, arrays) {
1833
+ const committed = arrays.filter((ar) => ar.#committed);
1834
+ if (committed.length > 0) {
1835
+ const backend = committed[0].#backend;
1836
+ for (const ar of committed) if (ar.#backend !== backend) throw new Error(`Device mismatch in ${name} between committed arrays on (${backend.type}, ${ar.#backend.type}), please move to the same device with devicePut()`);
1837
+ return {
1838
+ backend,
1839
+ committed: true
1840
+ };
1841
+ } else {
1842
+ const backend = arrays.length > 0 ? arrays[0].#backend : getBackend();
1843
+ return {
1844
+ backend,
1845
+ committed: false
1846
+ };
1847
+ }
1848
+ }
1790
1849
  /** Realize the array and return it as data. */
1791
1850
  async data() {
1792
1851
  if (this.#source instanceof AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
@@ -1946,6 +2005,12 @@ var Array$1 = class Array$1 extends Tracer {
1946
2005
  [Primitive.Log]([x]) {
1947
2006
  return [x.#unary(AluOp.Log)];
1948
2007
  },
2008
+ [Primitive.Erf]([x]) {
2009
+ return [x.#unary(AluOp.Erf)];
2010
+ },
2011
+ [Primitive.Erfc]([x]) {
2012
+ return [x.#unary(AluOp.Erfc)];
2013
+ },
1949
2014
  [Primitive.Sqrt]([x]) {
1950
2015
  return [x.#unary(AluOp.Sqrt)];
1951
2016
  },
@@ -2014,7 +2079,8 @@ var Array$1 = class Array$1 extends Tracer {
2014
2079
  },
2015
2080
  [Primitive.JitCall](args, { jaxpr, numConsts }) {
2016
2081
  if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
2017
- const backend = getBackend();
2082
+ const { backend, committed } = Array$1.#computeBackend("jit_call", args);
2083
+ args = args.map((ar) => ar._putSync(backend));
2018
2084
  const consts = args.slice(0, numConsts);
2019
2085
  const tracers = args.slice(numConsts);
2020
2086
  const jp = jitCompile(backend, jaxpr, consts);
@@ -2031,16 +2097,54 @@ var Array$1 = class Array$1 extends Tracer {
2031
2097
  dtype: jaxpr.outs[i].aval.dtype,
2032
2098
  weakType: jaxpr.outs[i].aval.weakType,
2033
2099
  backend,
2100
+ committed,
2034
2101
  pending
2035
2102
  });
2036
2103
  });
2037
2104
  }
2038
2105
  };
2039
2106
  }
2107
+ /** @private */
2040
2108
  _realizeSource() {
2041
2109
  this.#realize();
2042
2110
  return this.#source;
2043
2111
  }
2112
+ /** @private Put this array on a new backend, asynchronously. */
2113
+ async _put(backend) {
2114
+ if (this.#backend === backend) return this;
2115
+ if (this.#source instanceof AluExp) {
2116
+ const ar = this.#newArrayFrom({
2117
+ backend,
2118
+ committed: true
2119
+ });
2120
+ this.dispose();
2121
+ return ar;
2122
+ } else {
2123
+ const data = await this.data();
2124
+ return arrayFromData(data, this.shape, {
2125
+ dtype: this.#dtype,
2126
+ device: backend.type
2127
+ }, this.#weakType);
2128
+ }
2129
+ }
2130
+ /** @private Put this array on a new backend, synchronously. */
2131
+ _putSync(backend) {
2132
+ if (this.#backend === backend) return this;
2133
+ if (this.#source instanceof AluExp) {
2134
+ const ar = this.#newArrayFrom({
2135
+ backend,
2136
+ committed: true
2137
+ });
2138
+ this.dispose();
2139
+ return ar;
2140
+ } else {
2141
+ const data = this.dataSync();
2142
+ return arrayFromData(data, this.shape, {
2143
+ dtype: this.#dtype,
2144
+ device: backend.type
2145
+ }, this.#weakType);
2146
+ }
2147
+ }
2044
2148
  };
2045
2149
  /** Constructor for creating a new array from data. */
2046
2150
  function array(values, { shape: shape$1, dtype, device } = {}) {
@@ -2103,6 +2207,9 @@ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
2103
2207
  } else if (data instanceof Float16Array) {
2104
2208
  if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
2105
2209
  dtype ??= DType.Float16;
2210
+ } else if (data instanceof Float64Array) {
2211
+ if (dtype && dtype !== DType.Float64) throw new Error("Float64Array must have float64 type");
2212
+ dtype ??= DType.Float64;
2106
2213
  } else throw new Error("Unsupported data array type: " + data.constructor.name);
2107
2214
  if (data.length < inlineArrayLimit) {
2108
2215
  let allEqual = true;
@@ -2123,7 +2230,8 @@ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
2123
2230
  st: ShapeTracker.fromShape(shape$1),
2124
2231
  dtype,
2125
2232
  weakType,
2126
- backend
2233
+ backend,
2234
+ committed: device != void 0
2127
2235
  });
2128
2236
  }
2129
2237
  function dataToJs(dtype, data, shape$1) {
@@ -2157,7 +2265,8 @@ function fullInternal(aval, fillValue, device) {
2157
2265
  st: ShapeTracker.fromShape(aval.shape),
2158
2266
  dtype: aval.dtype,
2159
2267
  weakType: aval.weakType,
2160
- backend: getBackend(device)
2268
+ backend: getBackend(device),
2269
+ committed: device != void 0
2161
2270
  });
2162
2271
  }
2163
2272
  function zerosLike$1(val, dtype) {
@@ -2225,7 +2334,8 @@ function eye(numRows, numCols, { dtype, device } = {}) {
2225
2334
  st: ShapeTracker.fromShape([numRows, numCols]),
2226
2335
  dtype,
2227
2336
  weakType,
2228
- backend: getBackend(device)
2337
+ backend: getBackend(device),
2338
+ committed: device != void 0
2229
2339
  });
2230
2340
  }
2231
2341
  /** Return the identity matrix, with ones on the main diagonal. */
@@ -2268,7 +2378,8 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
2268
2378
  st,
2269
2379
  dtype,
2270
2380
  weakType: false,
2271
- backend: getBackend(device)
2381
+ backend: getBackend(device),
2382
+ committed: device != void 0
2272
2383
  });
2273
2384
  }
2274
2385
  /**
@@ -2304,16 +2415,15 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
2304
2415
  st,
2305
2416
  dtype,
2306
2417
  weakType: false,
2307
- backend: getBackend(device)
2418
+ backend: getBackend(device),
2419
+ committed: device != void 0
2308
2420
  });
2309
2421
  }
2310
2422
  function aluCompare(a, b, op) {
2311
2423
  switch (op) {
2312
- case CompareOp.Greater: return AluExp.mul(AluExp.cmpne(a, b), AluExp.cmplt(a, b).not());
2313
2424
  case CompareOp.Less: return AluExp.cmplt(a, b);
2314
2425
  case CompareOp.Equal: return AluExp.cmpne(a, b).not();
2315
2426
  case CompareOp.NotEqual: return AluExp.cmpne(a, b);
2316
- case CompareOp.GreaterEqual: return AluExp.cmplt(a, b).not();
2317
2427
  case CompareOp.LessEqual: return AluExp.add(AluExp.cmplt(a, b), AluExp.cmpne(a, b).not());
2318
2428
  }
2319
2429
  }
@@ -2446,7 +2556,7 @@ var JaxprEqn = class {
2446
2556
  const paramsList = Object.entries(this.params).map(([k, v]) => PPrint.pp(`${k}=${v}`));
2447
2557
  if (paramsList.length > 0) rhs = rhs.stack(PPrint.pp(" [ ")).stack(PPrint.prototype.concat(...paramsList)).stack(PPrint.pp(" ] "));
2448
2558
  else rhs = rhs.stack(PPrint.pp(" "));
2449
- rhs = rhs.stack(PPrint.pp(this.inputs.map((x) => x instanceof Var ? vp.name(x) : JSON.stringify(x.value)).join(" ")));
2559
+ rhs = rhs.stack(PPrint.pp(this.inputs.map((x) => x instanceof Var ? vp.name(x) : String(x.value)).join(" ")));
2450
2560
  return lhs.stack(PPrint.pp(" = ")).stack(rhs);
2451
2561
  }
2452
2562
  toString() {
@@ -2812,6 +2922,8 @@ const abstractEvalRules = {
2812
2922
  [Primitive.Atan]: vectorizedUnopAbstractEval,
2813
2923
  [Primitive.Exp]: vectorizedUnopAbstractEval,
2814
2924
  [Primitive.Log]: vectorizedUnopAbstractEval,
2925
+ [Primitive.Erf]: vectorizedUnopAbstractEval,
2926
+ [Primitive.Erfc]: vectorizedUnopAbstractEval,
2815
2927
  [Primitive.Sqrt]: vectorizedUnopAbstractEval,
2816
2928
  [Primitive.Min]: binopAbstractEval,
2817
2929
  [Primitive.Max]: binopAbstractEval,
@@ -3064,6 +3176,16 @@ const jvpRules = {
3064
3176
  [Primitive.Log]([x], [dx]) {
3065
3177
  return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
3066
3178
  },
3179
+ [Primitive.Erf]([x], [dx]) {
3180
+ const coeff = 2 / Math.sqrt(Math.PI);
3181
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3182
+ return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
3183
+ },
3184
+ [Primitive.Erfc]([x], [dx]) {
3185
+ const coeff = -2 / Math.sqrt(Math.PI);
3186
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3187
+ return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
3188
+ },
3067
3189
  [Primitive.Sqrt]([x], [dx]) {
3068
3190
  const z = sqrt$1(x);
3069
3191
  return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
@@ -3225,6 +3347,10 @@ var BatchTrace = class extends Trace {
3225
3347
  const [valsIn, bdimsIn] = unzip2(tracers.map((t) => [t.val, t.batchDim]));
3226
3348
  const vmapRule = vmapRules[primitive];
3227
3349
  if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
3350
+ if (bdimsIn.every((d) => d === null)) {
3351
+ const valOuts$1 = bind(primitive, valsIn, params);
3352
+ return valOuts$1.map((x) => new BatchTracer(this, x, null));
3353
+ }
3228
3354
  const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3229
3355
  return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3230
3356
  }
@@ -3232,24 +3358,28 @@ var BatchTrace = class extends Trace {
3232
3358
  return this.main.globalData;
3233
3359
  }
3234
3360
  };
3235
- function handleScalarBroadcasting(nd, x, d) {
3236
- if (d === null || nd === ndim$1(x)) return x;
3237
- else {
3238
- const axis = range(ndim$1(x), nd);
3239
- const shape$1 = [...x.shape, ...axis.map(() => 1)];
3240
- return broadcast(x, shape$1, axis);
3241
- }
3242
- }
3243
- /** Process a primitive with built-in broadcasting. */
3361
+ /**
3362
+ * Process a primitive with built-in broadcasting.
3363
+ *
3364
+ * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3365
+ */
3244
3366
  function broadcastBatcher(op) {
3245
3367
  return (axisSize, args, dims) => {
3246
3368
  if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3247
- const idx = dims.findIndex((d) => d !== null);
3248
- if (idx === -1) return [[op(...args)], [null]];
3249
- if (zip(args, dims).every(([x, d]) => ndim$1(x) === 0 || deepEqual(x.shape, args[idx].shape) && d === dims[idx])) return [[op(...args)], [dims[idx]]];
3250
- args = args.map((x, i) => ndim$1(x) > 0 ? moveBatchAxis(axisSize, dims[i], 0, x) : x);
3251
- const nd = Math.max(...args.map(ndim$1));
3252
- args = args.map((x, i) => handleScalarBroadcasting(nd, x, dims[i]));
3369
+ const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3370
+ const firstIdx = dims.findIndex((d) => d !== null);
3371
+ const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3372
+ if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3373
+ args = args.map((x, i) => {
3374
+ if (dims[i] === null) return x;
3375
+ x = moveBatchAxis(axisSize, dims[i], 0, x);
3376
+ if (x.ndim < nd) x = x.reshape([
3377
+ x.shape[0],
3378
+ ...rep(nd - x.ndim, 1),
3379
+ ...x.shape.slice(1)
3380
+ ]);
3381
+ return x;
3382
+ });
3253
3383
  return [[op(...args)], [0]];
3254
3384
  };
3255
3385
  }
@@ -3273,17 +3403,18 @@ const vmapRules = {
3273
3403
  [Primitive.Atan]: unopBatcher(atan$1),
3274
3404
  [Primitive.Exp]: unopBatcher(exp$1),
3275
3405
  [Primitive.Log]: unopBatcher(log$1),
3406
+ [Primitive.Erf]: unopBatcher(erf$1),
3407
+ [Primitive.Erfc]: unopBatcher(erfc$1),
3276
3408
  [Primitive.Sqrt]: unopBatcher(sqrt$1),
3277
3409
  [Primitive.Min]: broadcastBatcher(min$1),
3278
3410
  [Primitive.Max]: broadcastBatcher(max$1),
3279
3411
  [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3280
- if (xBdim === null) return [[reduce(x, op, axis)], [null]];
3412
+ assertNonNull(xBdim);
3281
3413
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3282
3414
  const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3283
3415
  return [[reduce(x, op, newAxis)], [outBdim]];
3284
3416
  },
3285
3417
  [Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
3286
- if (xBdim === null && yBdim === null) return [[dot$1(x, y)], [null]];
3287
3418
  x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
3288
3419
  y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
3289
3420
  const z = dot$1(x, y);
@@ -3292,26 +3423,68 @@ const vmapRules = {
3292
3423
  [Primitive.Compare](axisSize, args, dims, { op }) {
3293
3424
  return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3294
3425
  },
3426
+ [Primitive.Where]: broadcastBatcher(where$1),
3427
+ [Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
3428
+ assertNonNull(xBdim);
3429
+ const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
3430
+ newPerm.splice(xBdim, 0, xBdim);
3431
+ return [[transpose$1(x, newPerm)], [xBdim]];
3432
+ },
3433
+ [Primitive.Broadcast](axisSize, [x], [xBdim], { shape: shape$1, axis }) {
3434
+ assertNonNull(xBdim);
3435
+ const newShape = shape$1.toSpliced(xBdim, 0, axisSize);
3436
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3437
+ return [[broadcast(x, newShape, newAxis)], [xBdim]];
3438
+ },
3295
3439
  [Primitive.Reshape](axisSize, [x], [xBdim], { shape: shape$1 }) {
3296
- if (xBdim === null) return [[reshape$1(x, shape$1)], [null]];
3297
3440
  x = moveBatchAxis(axisSize, xBdim, 0, x);
3298
3441
  return [[reshape$1(x, [axisSize, ...shape$1])], [0]];
3299
3442
  },
3300
3443
  [Primitive.Flip](axisSize, [x], [xBdim], { axis }) {
3301
- if (xBdim === null) return [[flip$1(x, axis)], [null]];
3444
+ assertNonNull(xBdim);
3302
3445
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3303
3446
  return [[flip$1(x, newAxis)], [xBdim]];
3304
3447
  },
3305
3448
  [Primitive.Shrink](axisSize, [x], [xBdim], { slice }) {
3306
- if (xBdim === null) return [[shrink(x, slice)], [null]];
3449
+ assertNonNull(xBdim);
3307
3450
  const newSlice = slice.toSpliced(xBdim, 0, [0, axisSize]);
3308
3451
  return [[shrink(x, newSlice)], [xBdim]];
3309
3452
  },
3310
3453
  [Primitive.Pad](axisSize, [x], [xBdim], { width }) {
3311
- if (xBdim === null) return [[pad$1(x, width)], [null]];
3454
+ assertNonNull(xBdim);
3312
3455
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3313
3456
  return [[pad$1(x, newWidth)], [xBdim]];
3314
3457
  },
3458
+ [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3459
+ if (indicesBdim.every((d) => d === null)) {
3460
+ assertNonNull(xBdim);
3461
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3462
+ let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3463
+ let newOutDim = outDim;
3464
+ if (newOutDim < newBdim) newBdim += axis.length;
3465
+ else newOutDim += 1;
3466
+ return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3467
+ }
3468
+ const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3469
+ indices = indices.map((m, i) => {
3470
+ if (indicesBdim[i] === null) return m;
3471
+ m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3472
+ if (m.ndim < nd) m = m.reshape([
3473
+ m.shape[0],
3474
+ ...rep(nd - m.ndim, 1),
3475
+ ...m.shape.slice(1)
3476
+ ]);
3477
+ return m;
3478
+ });
3479
+ if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3480
+ else {
3481
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3482
+ const newAxis = [0, ...axis.map((ax) => ax + 1)];
3483
+ const extraBatchIndex = arange(axisSize).reshape([-1, ...rep(nd - 1, 1)]);
3484
+ indices.splice(0, 0, extraBatchIndex);
3485
+ return [[gather(x, indices, newAxis, outDim)], [outDim]];
3486
+ }
3487
+ },
3315
3488
  [Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
3316
3489
  const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
3317
3490
  const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
@@ -3371,12 +3544,14 @@ function vmapFlat(f, inAxes, args) {
3371
3544
  function vmap$1(f, inAxes = 0) {
3372
3545
  return (...args) => {
3373
3546
  const [argsFlat, inTree] = flatten(args);
3374
- let inAxesFlat;
3547
+ let inAxesFlat = [];
3375
3548
  if (typeof inAxes === "number") inAxesFlat = rep(argsFlat.length, inAxes);
3549
+ else for (let i = 0; i < args.length; i++) if (inAxes[i] == null) inAxesFlat.push(...rep(inTree.childTreedefs[i].size, null));
3550
+ else if (typeof inAxes[i] === "number") inAxesFlat.push(...rep(inTree.childTreedefs[i].size, inAxes[i]));
3376
3551
  else {
3377
- let inTree2;
3378
- [inAxesFlat, inTree2] = flatten(inAxes);
3379
- if (!inTree.equals(inTree2)) throw new TreeMismatchError("vmap", inTree, inTree2);
3552
+ const [axesFlat, axesTreeDef] = flatten(inAxes[i]);
3553
+ if (!inTree.childTreedefs[i].equals(axesTreeDef)) throw new TreeMismatchError("vmap", inTree.childTreedefs[i], axesTreeDef);
3554
+ inAxesFlat.push(...axesFlat);
3380
3555
  }
3381
3556
  const [fFlat, outTree] = flattenFun(f, inTree);
3382
3557
  const outsFlat = vmapFlat(fFlat, inAxesFlat, argsFlat);
@@ -3996,7 +4171,7 @@ function valueAndGrad$1(f) {
3996
4171
  const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
3997
4172
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
3998
4173
  if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
3999
- const [ct, ...rest] = fVjp(array(1, { dtype: y.dtype }));
4174
+ const [ct, ...rest] = fVjp(onesLike$1(y.ref));
4000
4175
  for (const r of rest) dispose(r);
4001
4176
  fVjp.dispose();
4002
4177
  return [y, ct];
@@ -4024,7 +4199,10 @@ __export(lax_exports, {
4024
4199
  conv: () => conv$1,
4025
4200
  convGeneralDilated: () => convGeneralDilated,
4026
4201
  convWithGeneralPadding: () => convWithGeneralPadding,
4027
- reduceWindow: () => reduceWindow
4202
+ erf: () => erf,
4203
+ erfc: () => erfc,
4204
+ reduceWindow: () => reduceWindow,
4205
+ stopGradient: () => stopGradient$1
4028
4206
  });
4029
4207
  function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
4030
4208
  const padType = padding.toUpperCase();
@@ -4083,6 +4261,28 @@ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
4083
4261
  strides: windowStrides
4084
4262
  }));
4085
4263
  }
4264
/**
 * The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`.
 *
 * Public `lax` entry point; it hands its argument to the internal `erf$1`
 * and returns that result unchanged.
 */
function erf(operand) {
  const result = erf$1(operand);
  return result;
}
4268
/**
 * The complementary error function: `erfc(x) = 1 - erf(x)`.
 *
 * This function is more accurate than `1 - erf(x)` for large values of `x`,
 * where `erf(x)` is very close to 1.
 *
 * Public `lax` entry point; the work is delegated to the internal `erfc$1`.
 */
function erfc(operand) {
  const result = erfc$1(operand);
  return result;
}
4277
/**
 * Stops gradient computation.
 *
 * Behaves as the identity function but prevents the flow of gradients during
 * forward or reverse-mode automatic differentiation.
 *
 * This is the `lax`-namespace alias; it forwards to `stopGradient`.
 */
function stopGradient$1(operand) {
  const result = stopGradient(operand);
  return result;
}
4086
4286
 
4087
4287
  //#endregion
4088
4288
  //#region src/numpy.ts
@@ -4141,16 +4341,25 @@ __export(numpy_exports, {
4141
4341
  flipud: () => flipud,
4142
4342
  float16: () => float16,
4143
4343
  float32: () => float32,
4344
+ float64: () => float64,
4144
4345
  full: () => full,
4145
4346
  fullLike: () => fullLike$1,
4146
4347
  greater: () => greater,
4147
4348
  greaterEqual: () => greaterEqual,
4349
+ hamming: () => hamming,
4350
+ hann: () => hann,
4351
+ heaviside: () => heaviside,
4148
4352
  hstack: () => hstack,
4149
4353
  hypot: () => hypot,
4150
4354
  identity: () => identity$1,
4151
4355
  inf: () => inf,
4152
4356
  inner: () => inner,
4153
4357
  int32: () => int32,
4358
+ isfinite: () => isfinite,
4359
+ isinf: () => isinf,
4360
+ isnan: () => isnan,
4361
+ isneginf: () => isneginf,
4362
+ isposinf: () => isposinf,
4154
4363
  less: () => less,
4155
4364
  lessEqual: () => lessEqual,
4156
4365
  linspace: () => linspace,
@@ -4221,6 +4430,7 @@ const int32 = DType.Int32;
4221
4430
// Public dtype aliases, read straight off the `DType` enum object.
const { Uint32: uint32, Bool: bool, Float16: float16, Float64: float64 } = DType;
/** Euler's number, `e = 2.7182818284590...` (the base of the natural log). */
const e = Math.E;
4226
4436
  /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
@@ -4784,6 +4994,32 @@ function sign(x) {
4784
4994
  x = fudgeArray(x);
4785
4995
  return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
4786
4996
  }
4997
/**
 * Return the Hamming window of size M, a taper with a weighted cosine bell.
 *
 * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
 *
 * For `M = 1` the formula divides by `M-1 = 0`. In that case the window is
 * `[1]`, matching `numpy.hamming` / `jax.numpy.hamming`.
 *
 * @param {number} M - Number of points in the output window.
 */
function hamming(M) {
  if (M === 1) return full([1], 1);
  // linspace(0, 2π, M) yields exactly the sample angles 2πn/(M-1).
  return cos(linspace(0, 2 * Math.PI, M)).mul(-.46).add(.54);
}
5005
/**
 * Return the Hann window of size M, a taper with a weighted cosine bell.
 *
 * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
 *
 * For `M = 1` the formula divides by `M-1 = 0`. In that case the window is
 * `[1]`, matching `numpy.hanning` / `jax.numpy.hanning`.
 *
 * @param {number} M - Number of points in the output window.
 */
function hann(M) {
  if (M === 1) return full([1], 1);
  // linspace(0, 2π, M) yields exactly the sample angles 2πn/(M-1).
  return cos(linspace(0, 2 * Math.PI, M)).mul(-.5).add(.5);
}
5013
/**
 * @function
 * Compute the Heaviside step function. It is defined piecewise:
 * - `heaviside(x1, x2) = 0` for `x1 < 0`,
 * - `heaviside(x1, x2) = x2` for `x1 == 0`,
 * - `heaviside(x1, x2) = 1` for `x1 > 0`.
 *
 * Inputs that are neither `< 0` nor `> 0` (zero, and also NaN) take the value
 * of `x2`, matching `jax.numpy.heaviside`.
 */
const heaviside = jit$1(function heaviside$1(x1, x2) {
  // Test `x1 > 0` rather than `x1 == 0`, so NaN cannot fall through to 1.
  // `x1` is used twice: the first use takes a `.ref`, and the last consumes it.
  return where(less(x1.ref, 0), 0, where(greater(x1, 0), 1, x2));
});
4787
5023
  /** Calculate element-wise square of the input array. */
4788
5024
  function square(x) {
4789
5025
  x = fudgeArray(x);
@@ -4803,8 +5039,8 @@ function acos(x) {
4803
5039
  * Return element-wise hypotenuse for the given legs of a right triangle.
4804
5040
  *
4805
5041
  * In the original NumPy/JAX implementation, this function is more numerically
4806
- * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
4807
- * improvements.
5042
+ * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
5043
+ * stability improvements.
4808
5044
  */
4809
5045
  const hypot = jit$1(function hypot$1(x1, x2) {
4810
5046
  return sqrt(square(x1).add(square(x2)));
@@ -4995,6 +5231,34 @@ function var_(x, axis = null, opts) {
4995
5231
function std(x, axis = null, opts) {
  // Standard deviation is the square root of the variance. `axis` and `opts`
  // are passed through to `var_` unchanged.
  const variance = var_(x, axis, opts);
  return sqrt(variance);
}
5234
/**
 * Test element-wise for positive or negative infinity, return bool array.
 *
 * A non-floating-point input can never hold an infinity, so it maps to an
 * all-`false` array shaped like the input.
 */
function isinf(x) {
  const arr = fudgeArray(x);
  if (!isFloatDtype(arr.dtype)) return fullLike$1(arr, false);
  const posMask = arr.ref.equal(Infinity);
  const negMask = arr.equal(-Infinity);
  // Merge the two boolean masks into a single mask.
  return posMask.add(negMask);
}
5239
/**
 * Test element-wise for NaN (Not a Number).
 *
 * Non-floating-point inputs map to an all-`false` array shaped like the input.
 */
function isnan(x) {
  const arr = fudgeArray(x);
  if (!isFloatDtype(arr.dtype)) return fullLike$1(arr, false);
  // NaN is the only value that compares unequal to itself.
  return arr.ref.notEqual(arr);
}
5244
/**
 * Test element-wise for negative infinity, return bool array.
 *
 * Non-floating-point inputs map to an all-`false` array shaped like the input.
 */
function isneginf(x) {
  const arr = fudgeArray(x);
  if (!isFloatDtype(arr.dtype)) return fullLike$1(arr, false);
  return arr.equal(-Infinity);
}
5249
/**
 * Test element-wise for positive infinity, return bool array.
 *
 * Non-floating-point inputs map to an all-`false` array shaped like the input.
 */
function isposinf(x) {
  const arr = fudgeArray(x);
  if (!isFloatDtype(arr.dtype)) return fullLike$1(arr, false);
  return arr.equal(Infinity);
}
5254
/**
 * @function
 * Test element-wise for finite values (not infinity or NaN).
 *
 * Non-floating-point inputs are always finite and map to an all-`true` array.
 */
const isfinite = jit$1(function isfinite$1(x) {
  if (isFloatDtype(x.dtype)) {
    // A value is finite exactly when it is neither NaN nor ±Infinity.
    const nonFinite = isnan(x.ref).add(isinf(x));
    return nonFinite.notEqual(true);
  }
  return fullLike$1(x, true);
});
4998
5262
 
4999
5263
  //#endregion
5000
5264
  //#region src/nn.ts
@@ -5128,18 +5392,20 @@ function celu(x, alpha = 1) {
5128
5392
  * @function
5129
5393
  * Gaussian error linear unit (GELU) activation function.
5130
5394
  *
5131
- * This is computed element-wise. Currently jax-js does not support the erf() or
5132
- * gelu() functions exactly as primitives, so an approximation is used:
5133
- * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
5395
+ * This is computed element-wise. There are two variants depending on whether
5396
+ * `approximate` is set (default true):
5134
5397
  *
5135
- * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
5398
+ * - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
5399
+ * - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
5136
5400
  *
5137
- * This will be improved in the future.
5401
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
5138
5402
  */
5139
- const gelu = jit$1(function gelu$1(x) {
5140
- const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
5141
- return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
5142
- });
5403
const gelu = jit$1(function gelu$1(x, opts) {
  // `opts.approximate` defaults to true. `opts` is declared static via
  // `staticArgnums: [1]`.
  if (opts?.approximate ?? true) {
    const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
    // x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    // `x` appears four times: three `.ref`s plus one final consuming use.
    return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
  }
  // Exact form: x * 0.5 * erfc(-x / sqrt(2)). `x` appears twice, so the first
  // use takes a `.ref` and the second consumes `x` itself; otherwise one
  // reference to `x` would never be released.
  return x.ref.mul(.5).mul(erfc$1(negative(x.mul(Math.SQRT1_2))));
}, { staticArgnums: [1] });
5143
5409
  /**
5144
5410
  * Gated linear unit (GLU) activation function.
5145
5411
  *
@@ -5360,6 +5626,25 @@ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
5360
5626
  return radius.mul(cos(theta));
5361
5627
  }, { staticArgnums: [1] });
5362
5628
 
5629
+ //#endregion
5630
+ //#region src/scipy-special.ts
5631
var scipy_special_exports = {};
/** Lazy getters backing the `scipySpecial` namespace object, in key order. */
const scipySpecialGetters = {
  erf: () => erf,
  erfc: () => erfc,
  logSoftmax: () => logSoftmax,
  logit: () => logit,
  logsumexp: () => logsumexp,
  softmax: () => softmax
};
__export(scipy_special_exports, scipySpecialGetters);
5640
/**
 * @function
 * The logit function, `logit(p) = log(p / (1-p))` — the log-odds of `p`.
 */
const logit = jit$1(function logit$1(x) {
  // `x` is used twice: the numerator takes a `.ref`, and `subtract` gets the
  // final use.
  const odds = x.ref.div(subtract(1, x));
  return log(odds);
});
5647
+
5363
5648
  //#endregion
5364
5649
  //#region src/polyfills.ts
5365
5650
  /** @file Polyfills for using this library. */
@@ -5453,6 +5738,25 @@ async function blockUntilReady(x) {
5453
5738
  await Promise.all(promises);
5454
5739
  return x;
5455
5740
  }
5741
/**
 * Transfer `x` to `device`.
 *
 * `x` may be a nested container of arrays or scalars. The resulting structure
 * is committed to the device.
 *
 * If `device` is not specified, this function behaves as identity if the input
 * is already an `Array`, otherwise it places the scalar uncommitted on the
 * default device.
 */
async function devicePut(x, device) {
  const [leaves, treedef] = flatten(x);
  const pending = [];
  for (const leaf of leaves) {
    if (!(leaf instanceof Array$1)) {
      // Scalars (and other non-Array leaves) are wrapped via `array`.
      pending.push(Promise.resolve(array(leaf, { device })));
    } else if (device) {
      pending.push(leaf._put(getBackend(device)));
    } else {
      pending.push(Promise.resolve(leaf));
    }
  }
  const placed = await Promise.all(pending);
  return unflatten(treedef, placed);
}
5456
5759
 
5457
5760
  //#endregion
5458
- export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
5761
+ export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
5762
+ //# sourceMappingURL=index.js.map