@jax-js/jax 0.0.5 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__ge
30
30
  }) : target, mod));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-yEU0L_ig.cjs');
33
+ const require_backend = require('./backend-BbrKEB18.cjs');
34
34
 
35
35
  //#region src/tree.ts
36
36
  var tree_exports = {};
@@ -60,6 +60,10 @@ var JsTreeDef = class JsTreeDef {
60
60
  this.nodeMetadata = nodeMetadata;
61
61
  this.childTreedefs = childTreedefs;
62
62
  }
63
+ /** Get the total number of leaves in the tree. */
64
+ get size() {
65
+ return this.nodeType === NodeType.Leaf ? 1 : this.childTreedefs.reduce((a, b) => a + b.size, 0);
66
+ }
63
67
  /** Returns a string representation of this tree definition. */
64
68
  toString(root = true) {
65
69
  if (root) return "JsTreeDef(" + this.toString(false) + ")";
@@ -215,6 +219,16 @@ function pool(st, ks, strides = 1, dilation = 1) {
215
219
  const s_ = strides;
216
220
  const d_ = dilation;
217
221
  const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
222
+ if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
223
+ st = st.padOrShrink([...noop.map(() => [0, 0]), ...require_backend.zipn(i_, o_, s_).map(([i, o, s]) => [0, o * s - i])]);
224
+ st = st.reshape([...noop, ...require_backend.zip(o_, s_).flatMap(([o, s]) => [o, s])]).shrink([...noop.map((x) => [0, x]), ...require_backend.zip(o_, ks).flatMap(([o, k]) => [[0, o], [0, k]])]);
225
+ st = st.permute([
226
+ ...require_backend.range(noop.length),
227
+ ...ks.map((_, j) => noop.length + 2 * j),
228
+ ...ks.map((_, j) => noop.length + 2 * j + 1)
229
+ ]);
230
+ return st;
231
+ }
218
232
  const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
219
233
  const kidf = require_backend.zipn(ks, i_, d_, f_);
220
234
  st = st.repeat([...require_backend.rep(noop.length, 1), ...kidf.map(([k, i, d, f]) => Math.ceil(k * (i * f + d) / i))]);
@@ -249,6 +263,12 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
249
263
  const s_ = strides;
250
264
  const d_ = dilation;
251
265
  const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
266
+ if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
267
+ st = st.permute([...require_backend.range(noop.length), ...ks.flatMap((_, j) => [noop.length + j, noop.length + o_.length + j])]);
268
+ st = st.pad([...noop.map(() => [0, 0]), ...require_backend.zip(s_, ks).flatMap(([s, k]) => [[0, 0], [0, s - k]])]).reshape([...noop, ...require_backend.zip(o_, s_).map(([o, s]) => o * s)]);
269
+ st = st.padOrShrink([...noop.map(() => [0, 0]), ...require_backend.zipn(i_, o_, s_).map(([i, o, s]) => [0, i - o * s])]);
270
+ return st.reshape(st.shape.concat(require_backend.rep(ks.length, 1)));
271
+ }
252
272
  if (!require_backend.deepEqual(o_, st.shape.slice(noop.length, noop.length + ks.length))) throw new Error("poolTranspose() called with mismatched output shape");
253
273
  const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
254
274
  const kidf = require_backend.zipn(ks, i_, d_, f_);
@@ -358,6 +378,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
358
378
  Primitive$1["Atan"] = "atan";
359
379
  Primitive$1["Exp"] = "exp";
360
380
  Primitive$1["Log"] = "log";
381
+ Primitive$1["Erf"] = "erf";
382
+ Primitive$1["Erfc"] = "erfc";
361
383
  Primitive$1["Sqrt"] = "sqrt";
362
384
  Primitive$1["Min"] = "min";
363
385
  Primitive$1["Max"] = "max";
@@ -379,11 +401,9 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
379
401
  return Primitive$1;
380
402
  }({});
381
403
  let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
382
- CompareOp$1["Greater"] = "greater";
383
404
  CompareOp$1["Less"] = "less";
384
405
  CompareOp$1["Equal"] = "equal";
385
406
  CompareOp$1["NotEqual"] = "not_equal";
386
- CompareOp$1["GreaterEqual"] = "greater_equal";
387
407
  CompareOp$1["LessEqual"] = "less_equal";
388
408
  return CompareOp$1;
389
409
  }({});
@@ -435,6 +455,12 @@ function exp$1(x) {
435
455
  function log$1(x) {
436
456
  return bind1(Primitive.Log, [x]);
437
457
  }
458
+ function erf$1(x) {
459
+ return bind1(Primitive.Erf, [x]);
460
+ }
461
+ function erfc$1(x) {
462
+ return bind1(Primitive.Erfc, [x]);
463
+ }
438
464
  function sqrt$1(x) {
439
465
  return bind1(Primitive.Sqrt, [x]);
440
466
  }
@@ -473,7 +499,7 @@ function compare(x, y, op) {
473
499
  return bind1(Primitive.Compare, [x, y], { op });
474
500
  }
475
501
  function greater$1(x, y) {
476
- return compare(x, y, CompareOp.Greater);
502
+ return compare(y, x, CompareOp.Less);
477
503
  }
478
504
  function less$1(x, y) {
479
505
  return compare(x, y, CompareOp.Less);
@@ -485,7 +511,7 @@ function notEqual$1(x, y) {
485
511
  return compare(x, y, CompareOp.NotEqual);
486
512
  }
487
513
  function greaterEqual$1(x, y) {
488
- return compare(x, y, CompareOp.GreaterEqual);
514
+ return compare(y, x, CompareOp.LessEqual);
489
515
  }
490
516
  function lessEqual$1(x, y) {
491
517
  return compare(x, y, CompareOp.LessEqual);
@@ -1177,12 +1203,18 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
1177
1203
  } else if (exp$3.op === require_backend.AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
1178
1204
  });
1179
1205
  }
1180
- function broadcastedJit(fn) {
1206
+ function broadcastedJit(fn, opts) {
1181
1207
  return (nargs, exps, avals, params) => {
1182
- const newShape = avals.map((aval) => aval.shape).reduce(require_backend.generalBroadcast);
1183
- exps = exps.map((exp$3) => reshapeViews(exp$3, (st) => {
1184
- if (!require_backend.deepEqual(st.shape, newShape)) return st.broadcast(newShape, require_backend.range(newShape.length - st.shape.length));
1185
- }));
1208
+ let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
1209
+ const skipCastIdx = opts?.skipCastIdx ?? [];
1210
+ if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
1211
+ exps = exps.map((exp$3, i) => {
1212
+ exp$3 = reshapeViews(exp$3, (st) => {
1213
+ if (!require_backend.deepEqual(st.shape, newShape)) return st.broadcast(newShape, require_backend.range(newShape.length - st.shape.length));
1214
+ });
1215
+ if (exp$3.dtype !== newDtype && !skipCastIdx.includes(i)) exp$3 = require_backend.AluExp.cast(newDtype, exp$3);
1216
+ return exp$3;
1217
+ });
1186
1218
  const exp$2 = fn(exps, params);
1187
1219
  return new require_backend.Kernel(nargs, require_backend.prod(newShape), exp$2);
1188
1220
  };
@@ -1225,6 +1257,8 @@ const jitRules = {
1225
1257
  [Primitive.Atan]: unopJit(require_backend.AluExp.atan),
1226
1258
  [Primitive.Exp]: unopJit(require_backend.AluExp.exp),
1227
1259
  [Primitive.Log]: unopJit(require_backend.AluExp.log),
1260
+ [Primitive.Erf]: unopJit(require_backend.AluExp.erf),
1261
+ [Primitive.Erfc]: unopJit(require_backend.AluExp.erfc),
1228
1262
  [Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
1229
1263
  [Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
1230
1264
  [Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
@@ -1272,7 +1306,7 @@ const jitRules = {
1272
1306
  return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
1273
1307
  },
1274
1308
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
1275
- [Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b)),
1309
+ [Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
1276
1310
  [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
1277
1311
  [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
1278
1312
  [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
@@ -1443,7 +1477,7 @@ var PendingExecute = class {
1443
1477
  /**
1444
1478
  * A multidimensional numeric array with data stored on CPU or GPU.
1445
1479
  *
1446
- * This is the library's core data type. Equivalent to `jnp.Array` from JAX, or
1480
+ * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
1447
1481
  * `torch.Tensor`.
1448
1482
  *
1449
1483
  * Not to be confused with the JavaScript "Array" constructor. Avoid importing
@@ -1458,6 +1492,7 @@ var Array$1 = class Array$1 extends Tracer {
1458
1492
  #source;
1459
1493
  #st;
1460
1494
  #backend;
1495
+ #committed;
1461
1496
  #rc;
1462
1497
  #pendingSet;
1463
1498
  /**
@@ -1474,6 +1509,7 @@ var Array$1 = class Array$1 extends Tracer {
1474
1509
  this.#source = args.source;
1475
1510
  this.#st = args.st;
1476
1511
  this.#backend = args.backend;
1512
+ this.#committed = args.committed;
1477
1513
  this.#rc = 1;
1478
1514
  this.#pendingSet = new Set(args.pending);
1479
1515
  if (this.#pendingSet.size === 0) this.#pendingSet = null;
@@ -1501,6 +1537,7 @@ var Array$1 = class Array$1 extends Tracer {
1501
1537
  dtype: args.dtype ?? this.#dtype,
1502
1538
  weakType: this.#weakType,
1503
1539
  backend: args.backend ?? this.#backend,
1540
+ committed: args.committed ?? this.#committed,
1504
1541
  pending: args.pending ?? this.#pending ?? void 0
1505
1542
  });
1506
1543
  }
@@ -1556,9 +1593,10 @@ var Array$1 = class Array$1 extends Tracer {
1556
1593
  */
1557
1594
  #gather(indices, axis, outDim) {
1558
1595
  this.#check();
1559
- if (indices.some((a) => a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
1560
1596
  const axisSet = new Set(axis);
1561
1597
  if (axisSet.size !== axis.length) throw new TypeError("Gather axis must not have duplicates");
1598
+ if (indices.some((a) => a.#committed && a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
1599
+ indices = indices.map((ar) => ar._putSync(this.#backend));
1562
1600
  indices = Array$1.#broadcastArrays(indices);
1563
1601
  const indexShape = indices[0].shape;
1564
1602
  const finalShape = this.shape.filter((_, i) => !axisSet.has(i));
@@ -1627,6 +1665,7 @@ var Array$1 = class Array$1 extends Tracer {
1627
1665
  this.#check();
1628
1666
  if (this.#source instanceof require_backend.AluExp) {
1629
1667
  const exp$3 = new require_backend.AluExp(op, dtypeOutput, [this.#source]);
1668
+ this.dispose();
1630
1669
  return this.#newArrayFrom({
1631
1670
  source: exp$3.simplify(),
1632
1671
  dtype: dtypeOutput,
@@ -1655,21 +1694,19 @@ var Array$1 = class Array$1 extends Tracer {
1655
1694
  }
1656
1695
  static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
1657
1696
  const n = arrays.length;
1658
- const backend = arrays[0].#backend;
1659
1697
  if (n === 0) throw new TypeError(`No inputs for ${name}`);
1660
1698
  for (const ar of arrays) ar.#check();
1661
1699
  let castDtype;
1662
1700
  let castWeakType = true;
1663
- for (let i = 0; i < n; i++) {
1664
- if (dtypeOverride?.[i]) {
1665
- if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
1666
- } else if (castDtype === void 0) {
1667
- castDtype = arrays[i].#dtype;
1668
- castWeakType = arrays[i].#weakType;
1669
- } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
1670
- if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
1671
- }
1701
+ for (let i = 0; i < n; i++) if (dtypeOverride?.[i]) {
1702
+ if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
1703
+ } else if (castDtype === void 0) {
1704
+ castDtype = arrays[i].#dtype;
1705
+ castWeakType = arrays[i].#weakType;
1706
+ } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
1672
1707
  const weakType = castWeakType && !strongTypeOutput;
1708
+ const { backend, committed } = Array$1.#computeBackend(name, arrays);
1709
+ arrays = arrays.map((ar) => ar._putSync(backend));
1673
1710
  arrays = Array$1.#broadcastArrays(arrays);
1674
1711
  const newShape = [...arrays[0].shape];
1675
1712
  if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && !reduceAxis) {
@@ -1679,12 +1716,14 @@ var Array$1 = class Array$1 extends Tracer {
1679
1716
  });
1680
1717
  if (arrays.every((ar) => require_backend.deepEqual(ar.#st, arrays[0].#st))) {
1681
1718
  const exp$4 = custom(sources);
1719
+ arrays.forEach((ar) => ar.dispose());
1682
1720
  return new Array$1({
1683
1721
  source: exp$4.simplify(),
1684
1722
  st: arrays[0].#st,
1685
1723
  dtype: exp$4.dtype,
1686
1724
  weakType,
1687
- backend
1725
+ backend,
1726
+ committed
1688
1727
  });
1689
1728
  }
1690
1729
  const exp$3 = custom(arrays.map((ar, i) => {
@@ -1693,12 +1732,14 @@ var Array$1 = class Array$1 extends Tracer {
1693
1732
  return require_backend.accessorAluExp(src$1, ar.#st, require_backend.unravelAlu(newShape, require_backend.AluVar.idx));
1694
1733
  }));
1695
1734
  const st = require_backend.ShapeTracker.fromShape(newShape);
1735
+ arrays.forEach((ar) => ar.dispose());
1696
1736
  return new Array$1({
1697
1737
  source: exp$3.simplify(),
1698
1738
  st,
1699
1739
  dtype: exp$3.dtype,
1700
1740
  weakType,
1701
- backend
1741
+ backend,
1742
+ committed
1702
1743
  });
1703
1744
  }
1704
1745
  let indices;
@@ -1734,13 +1775,14 @@ var Array$1 = class Array$1 extends Tracer {
1734
1775
  const pending = new Set([...arrays.flatMap((ar) => ar.#pending)]);
1735
1776
  for (const exe of pending) exe.updateRc(1);
1736
1777
  pending.add(new PendingExecute(backend, kernel, inputs, [output]));
1737
- for (const ar of arrays) ar.dispose();
1778
+ arrays.forEach((ar) => ar.dispose());
1738
1779
  return new Array$1({
1739
1780
  source: output,
1740
1781
  st: require_backend.ShapeTracker.fromShape(newShape),
1741
1782
  dtype: kernel.dtype,
1742
1783
  weakType,
1743
1784
  backend,
1785
+ committed,
1744
1786
  pending
1745
1787
  });
1746
1788
  }
@@ -1818,6 +1860,23 @@ var Array$1 = class Array$1 extends Tracer {
1818
1860
  return ar.#reshape(ar.#st.broadcast(newShape, require_backend.range(newShape.length - ar.ndim)));
1819
1861
  });
1820
1862
  }
1863
+ static #computeBackend(name, arrays) {
1864
+ const committed = arrays.filter((ar) => ar.#committed);
1865
+ if (committed.length > 0) {
1866
+ const backend = committed[0].#backend;
1867
+ for (const ar of committed) if (ar.#backend !== backend) throw new Error(`Device mismatch in ${name} between committed arrays on (${backend.type}, ${ar.#backend.type}), please move to the same device with devicePut()`);
1868
+ return {
1869
+ backend,
1870
+ committed: true
1871
+ };
1872
+ } else {
1873
+ const backend = arrays.length > 0 ? arrays[0].#backend : require_backend.getBackend();
1874
+ return {
1875
+ backend,
1876
+ committed: false
1877
+ };
1878
+ }
1879
+ }
1821
1880
  /** Realize the array and return it as data. */
1822
1881
  async data() {
1823
1882
  if (this.#source instanceof require_backend.AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
@@ -1977,6 +2036,12 @@ var Array$1 = class Array$1 extends Tracer {
1977
2036
  [Primitive.Log]([x]) {
1978
2037
  return [x.#unary(require_backend.AluOp.Log)];
1979
2038
  },
2039
+ [Primitive.Erf]([x]) {
2040
+ return [x.#unary(require_backend.AluOp.Erf)];
2041
+ },
2042
+ [Primitive.Erfc]([x]) {
2043
+ return [x.#unary(require_backend.AluOp.Erfc)];
2044
+ },
1980
2045
  [Primitive.Sqrt]([x]) {
1981
2046
  return [x.#unary(require_backend.AluOp.Sqrt)];
1982
2047
  },
@@ -2045,7 +2110,8 @@ var Array$1 = class Array$1 extends Tracer {
2045
2110
  },
2046
2111
  [Primitive.JitCall](args, { jaxpr, numConsts }) {
2047
2112
  if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
2048
- const backend = require_backend.getBackend();
2113
+ const { backend, committed } = Array$1.#computeBackend("jit_call", args);
2114
+ args = args.map((ar) => ar._putSync(backend));
2049
2115
  const consts = args.slice(0, numConsts);
2050
2116
  const tracers = args.slice(numConsts);
2051
2117
  const jp = jitCompile(backend, jaxpr, consts);
@@ -2062,16 +2128,54 @@ var Array$1 = class Array$1 extends Tracer {
2062
2128
  dtype: jaxpr.outs[i].aval.dtype,
2063
2129
  weakType: jaxpr.outs[i].aval.weakType,
2064
2130
  backend,
2131
+ committed,
2065
2132
  pending
2066
2133
  });
2067
2134
  });
2068
2135
  }
2069
2136
  };
2070
2137
  }
2138
+ /** @private */
2071
2139
  _realizeSource() {
2072
2140
  this.#realize();
2073
2141
  return this.#source;
2074
2142
  }
2143
+ /** @private Put this array on a new backend, asynchronously. */
2144
+ async _put(backend) {
2145
+ if (this.#backend === backend) return this;
2146
+ if (this.#source instanceof require_backend.AluExp) {
2147
+ const ar = this.#newArrayFrom({
2148
+ backend,
2149
+ committed: true
2150
+ });
2151
+ this.dispose();
2152
+ return ar;
2153
+ } else {
2154
+ const data = await this.data();
2155
+ return arrayFromData(data, this.shape, {
2156
+ dtype: this.#dtype,
2157
+ device: backend.type
2158
+ }, this.#weakType);
2159
+ }
2160
+ }
2161
+ /** @private Put this array on a new backend, synchronously. */
2162
+ _putSync(backend) {
2163
+ if (this.#backend === backend) return this;
2164
+ if (this.#source instanceof require_backend.AluExp) {
2165
+ const ar = this.#newArrayFrom({
2166
+ backend,
2167
+ committed: true
2168
+ });
2169
+ this.dispose();
2170
+ return ar;
2171
+ } else {
2172
+ const data = this.dataSync();
2173
+ return arrayFromData(data, this.shape, {
2174
+ dtype: this.#dtype,
2175
+ device: backend.type
2176
+ }, this.#weakType);
2177
+ }
2178
+ }
2075
2179
  };
2076
2180
  /** Constructor for creating a new array from data. */
2077
2181
  function array(values, { shape: shape$1, dtype, device } = {}) {
@@ -2134,6 +2238,9 @@ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
2134
2238
  } else if (data instanceof Float16Array) {
2135
2239
  if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
2136
2240
  dtype ??= require_backend.DType.Float16;
2241
+ } else if (data instanceof Float64Array) {
2242
+ if (dtype && dtype !== require_backend.DType.Float64) throw new Error("Float64Array must have float64 type");
2243
+ dtype ??= require_backend.DType.Float64;
2137
2244
  } else throw new Error("Unsupported data array type: " + data.constructor.name);
2138
2245
  if (data.length < inlineArrayLimit) {
2139
2246
  let allEqual = true;
@@ -2154,7 +2261,8 @@ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
2154
2261
  st: require_backend.ShapeTracker.fromShape(shape$1),
2155
2262
  dtype,
2156
2263
  weakType,
2157
- backend
2264
+ backend,
2265
+ committed: device != void 0
2158
2266
  });
2159
2267
  }
2160
2268
  function dataToJs(dtype, data, shape$1) {
@@ -2188,7 +2296,8 @@ function fullInternal(aval, fillValue, device) {
2188
2296
  st: require_backend.ShapeTracker.fromShape(aval.shape),
2189
2297
  dtype: aval.dtype,
2190
2298
  weakType: aval.weakType,
2191
- backend: require_backend.getBackend(device)
2299
+ backend: require_backend.getBackend(device),
2300
+ committed: device != void 0
2192
2301
  });
2193
2302
  }
2194
2303
  function zerosLike$1(val, dtype) {
@@ -2256,7 +2365,8 @@ function eye(numRows, numCols, { dtype, device } = {}) {
2256
2365
  st: require_backend.ShapeTracker.fromShape([numRows, numCols]),
2257
2366
  dtype,
2258
2367
  weakType,
2259
- backend: require_backend.getBackend(device)
2368
+ backend: require_backend.getBackend(device),
2369
+ committed: device != void 0
2260
2370
  });
2261
2371
  }
2262
2372
  /** Return the identity matrix, with ones on the main diagonal. */
@@ -2299,7 +2409,8 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
2299
2409
  st,
2300
2410
  dtype,
2301
2411
  weakType: false,
2302
- backend: require_backend.getBackend(device)
2412
+ backend: require_backend.getBackend(device),
2413
+ committed: device != void 0
2303
2414
  });
2304
2415
  }
2305
2416
  /**
@@ -2335,16 +2446,15 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
2335
2446
  st,
2336
2447
  dtype,
2337
2448
  weakType: false,
2338
- backend: require_backend.getBackend(device)
2449
+ backend: require_backend.getBackend(device),
2450
+ committed: device != void 0
2339
2451
  });
2340
2452
  }
2341
2453
  function aluCompare(a, b, op) {
2342
2454
  switch (op) {
2343
- case CompareOp.Greater: return require_backend.AluExp.mul(require_backend.AluExp.cmpne(a, b), require_backend.AluExp.cmplt(a, b).not());
2344
2455
  case CompareOp.Less: return require_backend.AluExp.cmplt(a, b);
2345
2456
  case CompareOp.Equal: return require_backend.AluExp.cmpne(a, b).not();
2346
2457
  case CompareOp.NotEqual: return require_backend.AluExp.cmpne(a, b);
2347
- case CompareOp.GreaterEqual: return require_backend.AluExp.cmplt(a, b).not();
2348
2458
  case CompareOp.LessEqual: return require_backend.AluExp.add(require_backend.AluExp.cmplt(a, b), require_backend.AluExp.cmpne(a, b).not());
2349
2459
  }
2350
2460
  }
@@ -2481,7 +2591,7 @@ var JaxprEqn = class {
2481
2591
  const paramsList = Object.entries(this.params).map(([k, v]) => require_backend.PPrint.pp(`${k}=${v}`));
2482
2592
  if (paramsList.length > 0) rhs = rhs.stack(require_backend.PPrint.pp(" [ ")).stack(require_backend.PPrint.prototype.concat(...paramsList)).stack(require_backend.PPrint.pp(" ] "));
2483
2593
  else rhs = rhs.stack(require_backend.PPrint.pp(" "));
2484
- rhs = rhs.stack(require_backend.PPrint.pp(this.inputs.map((x) => x instanceof Var ? vp.name(x) : JSON.stringify(x.value)).join(" ")));
2594
+ rhs = rhs.stack(require_backend.PPrint.pp(this.inputs.map((x) => x instanceof Var ? vp.name(x) : String(x.value)).join(" ")));
2485
2595
  return lhs.stack(require_backend.PPrint.pp(" = ")).stack(rhs);
2486
2596
  }
2487
2597
  toString() {
@@ -2847,6 +2957,8 @@ const abstractEvalRules = {
2847
2957
  [Primitive.Atan]: vectorizedUnopAbstractEval,
2848
2958
  [Primitive.Exp]: vectorizedUnopAbstractEval,
2849
2959
  [Primitive.Log]: vectorizedUnopAbstractEval,
2960
+ [Primitive.Erf]: vectorizedUnopAbstractEval,
2961
+ [Primitive.Erfc]: vectorizedUnopAbstractEval,
2850
2962
  [Primitive.Sqrt]: vectorizedUnopAbstractEval,
2851
2963
  [Primitive.Min]: binopAbstractEval,
2852
2964
  [Primitive.Max]: binopAbstractEval,
@@ -3100,6 +3212,16 @@ const jvpRules = {
3100
3212
  [Primitive.Log]([x], [dx]) {
3101
3213
  return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
3102
3214
  },
3215
+ [Primitive.Erf]([x], [dx]) {
3216
+ const coeff = 2 / Math.sqrt(Math.PI);
3217
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3218
+ return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
3219
+ },
3220
+ [Primitive.Erfc]([x], [dx]) {
3221
+ const coeff = -2 / Math.sqrt(Math.PI);
3222
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3223
+ return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
3224
+ },
3103
3225
  [Primitive.Sqrt]([x], [dx]) {
3104
3226
  const z = sqrt$1(x);
3105
3227
  return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
@@ -3262,6 +3384,10 @@ var BatchTrace = class extends Trace {
3262
3384
  const [valsIn, bdimsIn] = require_backend.unzip2(tracers.map((t) => [t.val, t.batchDim]));
3263
3385
  const vmapRule = vmapRules[primitive];
3264
3386
  if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
3387
+ if (bdimsIn.every((d) => d === null)) {
3388
+ const valOuts$1 = bind(primitive, valsIn, params);
3389
+ return valOuts$1.map((x) => new BatchTracer(this, x, null));
3390
+ }
3265
3391
  const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3266
3392
  return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3267
3393
  }
@@ -3269,24 +3395,28 @@ var BatchTrace = class extends Trace {
3269
3395
  return this.main.globalData;
3270
3396
  }
3271
3397
  };
3272
- function handleScalarBroadcasting(nd, x, d) {
3273
- if (d === null || nd === ndim$1(x)) return x;
3274
- else {
3275
- const axis = require_backend.range(ndim$1(x), nd);
3276
- const shape$1 = [...x.shape, ...axis.map(() => 1)];
3277
- return broadcast(x, shape$1, axis);
3278
- }
3279
- }
3280
- /** Process a primitive with built-in broadcasting. */
3398
+ /**
3399
+ * Process a primitive with built-in broadcasting.
3400
+ *
3401
+ * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3402
+ */
3281
3403
  function broadcastBatcher(op) {
3282
3404
  return (axisSize, args, dims) => {
3283
3405
  if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3284
- const idx = dims.findIndex((d) => d !== null);
3285
- if (idx === -1) return [[op(...args)], [null]];
3286
- if (require_backend.zip(args, dims).every(([x, d]) => ndim$1(x) === 0 || require_backend.deepEqual(x.shape, args[idx].shape) && d === dims[idx])) return [[op(...args)], [dims[idx]]];
3287
- args = args.map((x, i) => ndim$1(x) > 0 ? moveBatchAxis(axisSize, dims[i], 0, x) : x);
3288
- const nd = Math.max(...args.map(ndim$1));
3289
- args = args.map((x, i) => handleScalarBroadcasting(nd, x, dims[i]));
3406
+ const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3407
+ const firstIdx = dims.findIndex((d) => d !== null);
3408
+ const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3409
+ if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3410
+ args = args.map((x, i) => {
3411
+ if (dims[i] === null) return x;
3412
+ x = moveBatchAxis(axisSize, dims[i], 0, x);
3413
+ if (x.ndim < nd) x = x.reshape([
3414
+ x.shape[0],
3415
+ ...require_backend.rep(nd - x.ndim, 1),
3416
+ ...x.shape.slice(1)
3417
+ ]);
3418
+ return x;
3419
+ });
3290
3420
  return [[op(...args)], [0]];
3291
3421
  };
3292
3422
  }
@@ -3310,17 +3440,18 @@ const vmapRules = {
3310
3440
  [Primitive.Atan]: unopBatcher(atan$1),
3311
3441
  [Primitive.Exp]: unopBatcher(exp$1),
3312
3442
  [Primitive.Log]: unopBatcher(log$1),
3443
+ [Primitive.Erf]: unopBatcher(erf$1),
3444
+ [Primitive.Erfc]: unopBatcher(erfc$1),
3313
3445
  [Primitive.Sqrt]: unopBatcher(sqrt$1),
3314
3446
  [Primitive.Min]: broadcastBatcher(min$1),
3315
3447
  [Primitive.Max]: broadcastBatcher(max$1),
3316
3448
  [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3317
- if (xBdim === null) return [[reduce(x, op, axis)], [null]];
3449
+ require_backend.assertNonNull(xBdim);
3318
3450
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3319
3451
  const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3320
3452
  return [[reduce(x, op, newAxis)], [outBdim]];
3321
3453
  },
3322
3454
  [Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
3323
- if (xBdim === null && yBdim === null) return [[dot$1(x, y)], [null]];
3324
3455
  x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
3325
3456
  y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
3326
3457
  const z = dot$1(x, y);
@@ -3329,26 +3460,68 @@ const vmapRules = {
3329
3460
  [Primitive.Compare](axisSize, args, dims, { op }) {
3330
3461
  return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3331
3462
  },
3463
+ [Primitive.Where]: broadcastBatcher(where$1),
3464
+ [Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
3465
+ require_backend.assertNonNull(xBdim);
3466
+ const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
3467
+ newPerm.splice(xBdim, 0, xBdim);
3468
+ return [[transpose$1(x, newPerm)], [xBdim]];
3469
+ },
3470
+ [Primitive.Broadcast](axisSize, [x], [xBdim], { shape: shape$1, axis }) {
3471
+ require_backend.assertNonNull(xBdim);
3472
+ const newShape = shape$1.toSpliced(xBdim, 0, axisSize);
3473
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3474
+ return [[broadcast(x, newShape, newAxis)], [xBdim]];
3475
+ },
3332
3476
  [Primitive.Reshape](axisSize, [x], [xBdim], { shape: shape$1 }) {
3333
- if (xBdim === null) return [[reshape$1(x, shape$1)], [null]];
3334
3477
  x = moveBatchAxis(axisSize, xBdim, 0, x);
3335
3478
  return [[reshape$1(x, [axisSize, ...shape$1])], [0]];
3336
3479
  },
3337
3480
  [Primitive.Flip](axisSize, [x], [xBdim], { axis }) {
3338
- if (xBdim === null) return [[flip$1(x, axis)], [null]];
3481
+ require_backend.assertNonNull(xBdim);
3339
3482
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3340
3483
  return [[flip$1(x, newAxis)], [xBdim]];
3341
3484
  },
3342
3485
  [Primitive.Shrink](axisSize, [x], [xBdim], { slice }) {
3343
- if (xBdim === null) return [[shrink(x, slice)], [null]];
3486
+ require_backend.assertNonNull(xBdim);
3344
3487
  const newSlice = slice.toSpliced(xBdim, 0, [0, axisSize]);
3345
3488
  return [[shrink(x, newSlice)], [xBdim]];
3346
3489
  },
3347
3490
  [Primitive.Pad](axisSize, [x], [xBdim], { width }) {
3348
- if (xBdim === null) return [[pad$1(x, width)], [null]];
3491
+ require_backend.assertNonNull(xBdim);
3349
3492
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3350
3493
  return [[pad$1(x, newWidth)], [xBdim]];
3351
3494
  },
3495
+ [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3496
+ if (indicesBdim.every((d) => d === null)) {
3497
+ require_backend.assertNonNull(xBdim);
3498
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3499
+ let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3500
+ let newOutDim = outDim;
3501
+ if (newOutDim < newBdim) newBdim += axis.length;
3502
+ else newOutDim += 1;
3503
+ return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3504
+ }
3505
+ const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3506
+ indices = indices.map((m, i) => {
3507
+ if (indicesBdim[i] === null) return m;
3508
+ m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3509
+ if (m.ndim < nd) m = m.reshape([
3510
+ m.shape[0],
3511
+ ...require_backend.rep(nd - m.ndim, 1),
3512
+ ...m.shape.slice(1)
3513
+ ]);
3514
+ return m;
3515
+ });
3516
+ if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3517
+ else {
3518
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3519
+ const newAxis = [0, ...axis.map((ax) => ax + 1)];
3520
+ const extraBatchIndex = arange(axisSize).reshape([-1, ...require_backend.rep(nd - 1, 1)]);
3521
+ indices.splice(0, 0, extraBatchIndex);
3522
+ return [[gather(x, indices, newAxis, outDim)], [outDim]];
3523
+ }
3524
+ },
3352
3525
  [Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
3353
3526
  const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
3354
3527
  const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
@@ -3408,12 +3581,14 @@ function vmapFlat(f, inAxes, args) {
3408
3581
  function vmap$1(f, inAxes = 0) {
3409
3582
  return (...args) => {
3410
3583
  const [argsFlat, inTree] = flatten(args);
3411
- let inAxesFlat;
3584
+ let inAxesFlat = [];
3412
3585
  if (typeof inAxes === "number") inAxesFlat = require_backend.rep(argsFlat.length, inAxes);
3586
+ else for (let i = 0; i < args.length; i++) if (inAxes[i] == null) inAxesFlat.push(...require_backend.rep(inTree.childTreedefs[i].size, null));
3587
+ else if (typeof inAxes[i] === "number") inAxesFlat.push(...require_backend.rep(inTree.childTreedefs[i].size, inAxes[i]));
3413
3588
  else {
3414
- let inTree2;
3415
- [inAxesFlat, inTree2] = flatten(inAxes);
3416
- if (!inTree.equals(inTree2)) throw new TreeMismatchError("vmap", inTree, inTree2);
3589
+ const [axesFlat, axesTreeDef] = flatten(inAxes[i]);
3590
+ if (!inTree.childTreedefs[i].equals(axesTreeDef)) throw new TreeMismatchError("vmap", inTree.childTreedefs[i], axesTreeDef);
3591
+ inAxesFlat.push(...axesFlat);
3417
3592
  }
3418
3593
  const [fFlat, outTree] = flattenFun(f, inTree);
3419
3594
  const outsFlat = vmapFlat(fFlat, inAxesFlat, argsFlat);
@@ -4033,7 +4208,7 @@ function valueAndGrad$1(f) {
4033
4208
  const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
4034
4209
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
4035
4210
  if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
4036
- const [ct, ...rest] = fVjp(array(1, { dtype: y.dtype }));
4211
+ const [ct, ...rest] = fVjp(onesLike$1(y.ref));
4037
4212
  for (const r of rest) dispose(r);
4038
4213
  fVjp.dispose();
4039
4214
  return [y, ct];
@@ -4061,7 +4236,10 @@ __export(lax_exports, {
4061
4236
  conv: () => conv$1,
4062
4237
  convGeneralDilated: () => convGeneralDilated,
4063
4238
  convWithGeneralPadding: () => convWithGeneralPadding,
4064
- reduceWindow: () => reduceWindow
4239
+ erf: () => erf,
4240
+ erfc: () => erfc,
4241
+ reduceWindow: () => reduceWindow,
4242
+ stopGradient: () => stopGradient$1
4065
4243
  });
4066
4244
  function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
4067
4245
  const padType = padding.toUpperCase();
@@ -4120,6 +4298,28 @@ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
4120
4298
  strides: windowStrides
4121
4299
  }));
4122
4300
  }
4301
/**
 * The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`.
 *
 * Public `lax.erf` entry point; it delegates to the internal `erf$1`.
 */
function erf(x) {
  const result = erf$1(x);
  return result;
}
4305
/**
 * The complementary error function: `erfc(x) = 1 - erf(x)`.
 *
 * This is more accurate than computing `1 - erf(x)` for large `x`, where
 * `erf(x)` is very close to 1. Public `lax.erfc` entry point; it delegates to
 * the internal `erfc$1`.
 */
function erfc(x) {
  const result = erfc$1(x);
  return result;
}
4314
/**
 * Stops gradient computation.
 *
 * Acts as the identity function but blocks gradients from flowing through
 * during forward- or reverse-mode automatic differentiation. Exported as
 * `lax.stopGradient`; it delegates to the internal `stopGradient`.
 */
function stopGradient$1(x) {
  const blocked = stopGradient(x);
  return blocked;
}
4123
4323
 
4124
4324
  //#endregion
4125
4325
  //#region src/numpy.ts
@@ -4178,16 +4378,25 @@ __export(numpy_exports, {
4178
4378
  flipud: () => flipud,
4179
4379
  float16: () => float16,
4180
4380
  float32: () => float32,
4381
+ float64: () => float64,
4181
4382
  full: () => full,
4182
4383
  fullLike: () => fullLike$1,
4183
4384
  greater: () => greater,
4184
4385
  greaterEqual: () => greaterEqual,
4386
+ hamming: () => hamming,
4387
+ hann: () => hann,
4388
+ heaviside: () => heaviside,
4185
4389
  hstack: () => hstack,
4186
4390
  hypot: () => hypot,
4187
4391
  identity: () => identity$1,
4188
4392
  inf: () => inf,
4189
4393
  inner: () => inner,
4190
4394
  int32: () => int32,
4395
+ isfinite: () => isfinite,
4396
+ isinf: () => isinf,
4397
+ isnan: () => isnan,
4398
+ isneginf: () => isneginf,
4399
+ isposinf: () => isposinf,
4191
4400
  less: () => less,
4192
4401
  lessEqual: () => lessEqual,
4193
4402
  linspace: () => linspace,
@@ -4258,6 +4467,7 @@ const int32 = require_backend.DType.Int32;
4258
4467
  const uint32 = require_backend.DType.Uint32;
4259
4468
  const bool = require_backend.DType.Bool;
4260
4469
  const float16 = require_backend.DType.Float16;
4470
+ const float64 = require_backend.DType.Float64;
4261
4471
  /** Euler's constant, `e = 2.7182818284590...` */
4262
4472
  const e = Math.E;
4263
4473
  /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
@@ -4821,6 +5031,32 @@ function sign(x) {
4821
5031
  x = fudgeArray(x);
4822
5032
  return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
4823
5033
  }
5034
/**
 * Generalized cosine window `w(n) = a - b * cos(2πn/(M-1))` for
 * `0 <= n <= M-1`, shared by `hamming()` and `hann()`.
 *
 * For `M === 1` the formula divides by zero; NumPy and JAX define that window
 * as `[1]`, so it is returned directly.
 */
function cosineWindowImpl(M, a, b) {
  // linspace(1, 1, 1) is a single 1 in the same dtype that linspace produces
  // for the general case, so both paths return the same dtype.
  if (M === 1) return linspace(1, 1, 1);
  // linspace(0, 2π, M) includes both endpoints, so entry n equals 2πn/(M-1).
  return cos(linspace(0, 2 * Math.PI, M)).mul(-b).add(a);
}
/**
 * Return the Hamming window of size M, a taper with a weighted cosine bell.
 *
 * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`; `hamming(1)` is `[1]`.
 */
function hamming(M) {
  return cosineWindowImpl(M, .54, .46);
}
/**
 * Return the Hann window of size M, a taper with a weighted cosine bell.
 *
 * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`; `hann(1)` is `[1]`.
 */
function hann(M) {
  return cosineWindowImpl(M, .5, .5);
}
5050
/**
 * @function
 * Compute the Heaviside step function, defined piecewise as:
 * - `heaviside(x1, x2) = 0` for `x1 < 0`,
 * - `heaviside(x1, x2) = x2` for `x1 == 0`,
 * - `heaviside(x1, x2) = 1` for `x1 > 0`.
 *
 * Any `x1` that is neither `< 0` nor `== 0` maps to 1. Under IEEE comparisons
 * this includes NaN.
 */
const heaviside = jit$1(function heaviside$1(x1, x2) {
  const isNegative = less(x1.ref, 0);
  const zeroOrAbove = where(equal(x1, 0), x2, 1);
  return where(isNegative, 0, zeroOrAbove);
});
4824
5060
  /** Calculate element-wise square of the input array. */
4825
5061
  function square(x) {
4826
5062
  x = fudgeArray(x);
@@ -4840,8 +5076,8 @@ function acos(x) {
4840
5076
  * Return element-wise hypotenuse for the given legs of a right triangle.
4841
5077
  *
4842
5078
  * In the original NumPy/JAX implementation, this function is more numerically
4843
- * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
4844
- * improvements.
5079
+ * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
5080
+ * stability improvements.
4845
5081
  */
4846
5082
  const hypot = jit$1(function hypot$1(x1, x2) {
4847
5083
  return sqrt(square(x1).add(square(x2)));
@@ -5032,6 +5268,34 @@ function var_(x, axis = null, opts) {
5032
5268
  function std(x, axis = null, opts) {
5033
5269
  return sqrt(var_(x, axis, opts));
5034
5270
  }
5271
/** Test element-wise for positive or negative infinity, return bool array. */
function isinf(x) {
  x = fudgeArray(x);
  // Non-float dtypes cannot hold infinities, so every entry is false.
  // NOTE(review): confirm fullLike$1(x, false) yields bool rather than x.dtype.
  if (!require_backend.isFloatDtype(x.dtype)) return fullLike$1(x, false);
  const isPosInf = x.ref.equal(Infinity);
  // Adding the two bool masks combines them (presumably a logical OR).
  return isPosInf.add(x.equal(-Infinity));
}
5276
/** Test element-wise for NaN (Not a Number). */
function isnan(x) {
  x = fudgeArray(x);
  // Non-float dtypes cannot hold NaN, so every entry is false.
  // NOTE(review): confirm fullLike$1(x, false) yields bool rather than x.dtype.
  if (!require_backend.isFloatDtype(x.dtype)) return fullLike$1(x, false);
  // NaN is the only value that compares unequal to itself.
  return x.ref.notEqual(x);
}
5281
/** Test element-wise for negative infinity, return bool array. */
function isneginf(x) {
  x = fudgeArray(x);
  // Non-float dtypes cannot hold infinities, so every entry is false.
  // NOTE(review): confirm fullLike$1(x, false) yields bool rather than x.dtype.
  if (!require_backend.isFloatDtype(x.dtype)) return fullLike$1(x, false);
  return x.equal(-Infinity);
}
5286
/** Test element-wise for positive infinity, return bool array. */
function isposinf(x) {
  x = fudgeArray(x);
  // Non-float dtypes cannot hold infinities, so every entry is false.
  // NOTE(review): confirm fullLike$1(x, false) yields bool rather than x.dtype.
  if (!require_backend.isFloatDtype(x.dtype)) return fullLike$1(x, false);
  return x.equal(Infinity);
}
5291
/**
 * @function
 * Test element-wise for finite values (not infinity or NaN).
 *
 * Non-float dtypes are always finite. Float entries are finite when they are
 * neither NaN nor +/-infinity.
 * NOTE(review): confirm fullLike$1(x, true) yields bool rather than x.dtype.
 */
const isfinite = jit$1(function isfinite$1(x) {
  return require_backend.isFloatDtype(x.dtype)
    ? isnan(x.ref).add(isinf(x)).notEqual(true)
    : fullLike$1(x, true);
});
5035
5299
 
5036
5300
  //#endregion
5037
5301
  //#region src/nn.ts
@@ -5165,18 +5429,20 @@ function celu(x, alpha = 1) {
5165
5429
  * @function
5166
5430
  * Gaussian error linear unit (GELU) activation function.
5167
5431
  *
5168
- * This is computed element-wise. Currently jax-js does not support the erf() or
5169
- * gelu() functions exactly as primitives, so an approximation is used:
5170
- * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
5432
+ * This is computed element-wise. There are two variants depending on whether
5433
+ * `approximate` is set (default true):
5171
5434
  *
5172
- * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
5435
+ * - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
5436
+ * - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
5173
5437
  *
5174
- * This will be improved in the future.
5438
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
5175
5439
  */
5176
const gelu = jit$1(function gelu$1(x, opts) {
  if (opts?.approximate ?? true) {
    // tanh approximation: x/2 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
    // with x + 0.044715 * x^3 written as x * (1 + 0.044715 * x^2).
    const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
    return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
  }
  // Exact form: x/2 * erfc(-x / sqrt(2)). `x` is used twice, so only the first
  // use takes `.ref` and the last use consumes `x`, matching the approximate
  // branch; previously both uses took `.ref`, leaving one reference dangling.
  return x.ref.mul(.5).mul(erfc$1(negative(x.mul(Math.SQRT1_2))));
}, { staticArgnums: [1] });
5180
5446
  /**
5181
5447
  * Gated linear unit (GLU) activation function.
5182
5448
  *
@@ -5397,6 +5663,25 @@ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
5397
5663
  return radius.mul(cos(theta));
5398
5664
  }, { staticArgnums: [1] });
5399
5665
 
5666
+ //#endregion
5667
+ //#region src/scipy-special.ts
5668
var scipy_special_exports = {};
// Public `scipySpecial` namespace, populated via `__export` with one accessor
// per name.
__export(scipy_special_exports, {
  erf: () => erf,
  erfc: () => erfc,
  logSoftmax: () => logSoftmax,
  logit: () => logit,
  logsumexp: () => logsumexp,
  softmax: () => softmax
});
/**
 * @function
 * The logit function, `logit(p) = log(p / (1 - p))`, the inverse of the
 * logistic sigmoid.
 */
const logit = jit$1(function logit$1(x) {
  // `x` is used twice: take a reference for the first use, consume it last.
  const complement = subtract(1, x.ref);
  return log(x.div(complement));
});
5684
+
5400
5685
  //#endregion
5401
5686
  //#region src/polyfills.ts
5402
5687
  /** @file Polyfills for using this library. */
@@ -5490,6 +5775,24 @@ async function blockUntilReady(x) {
5490
5775
  await Promise.all(promises);
5491
5776
  return x;
5492
5777
  }
5778
/**
 * Transfer `x` to `device`.
 *
 * `x` may be a nested container of arrays or scalars. It is flattened into
 * leaves, each leaf is placed, and the original structure is rebuilt. The
 * resulting structure is committed to the device.
 *
 * If `device` is not specified, this function behaves as identity if the input
 * is already an `Array`, otherwise it places the scalar uncommitted on the
 * default device.
 */
async function devicePut(x, device) {
  const [leaves, treedef] = flatten(x);
  const pending = [];
  for (const leaf of leaves) {
    if (!(leaf instanceof Array$1)) {
      // Scalars (and other non-Array leaves) are wrapped in a new array.
      pending.push(Promise.resolve(array(leaf, { device })));
    } else if (device) {
      pending.push(leaf._put(require_backend.getBackend(device)));
    } else {
      pending.push(Promise.resolve(leaf));
    }
  }
  const placed = await Promise.all(pending);
  return unflatten(treedef, placed);
}
5493
5796
 
5494
5797
  //#endregion
5495
5798
  exports.Array = Array$1;
@@ -5497,6 +5800,7 @@ exports.DType = require_backend.DType;
5497
5800
  exports.Jaxpr = Jaxpr;
5498
5801
  exports.blockUntilReady = blockUntilReady;
5499
5802
  exports.defaultDevice = require_backend.defaultDevice;
5803
+ exports.devicePut = devicePut;
5500
5804
  exports.devices = require_backend.devices;
5501
5805
  exports.grad = grad;
5502
5806
  exports.init = require_backend.init;
@@ -5531,6 +5835,12 @@ Object.defineProperty(exports, 'random', {
5531
5835
  return random_exports;
5532
5836
  }
5533
5837
  });
5838
// Expose the `scipySpecial` namespace through an enumerable getter, the same
// pattern used for the `random` and `tree` namespaces.
Object.defineProperty(exports, 'scipySpecial', {
  enumerable: true,
  get() {
    return scipy_special_exports;
  }
});
5534
5844
  exports.setDebug = require_backend.setDebug;
5535
5845
  Object.defineProperty(exports, 'tree', {
5536
5846
  enumerable: true,
@@ -5540,4 +5850,5 @@ Object.defineProperty(exports, 'tree', {
5540
5850
  });
5541
5851
  exports.valueAndGrad = valueAndGrad;
5542
5852
  exports.vjp = vjp;
5543
- exports.vmap = vmap;
5853
+ exports.vmap = vmap;
5854
+ //# sourceMappingURL=index.cjs.map