@jax-js/jax 0.0.5 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__ge
30
30
  }) : target, mod));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-yEU0L_ig.cjs');
33
+ const require_backend = require('./backend-FtkbO6pI.cjs');
34
34
 
35
35
  //#region src/tree.ts
36
36
  var tree_exports = {};
@@ -60,6 +60,10 @@ var JsTreeDef = class JsTreeDef {
60
60
  this.nodeMetadata = nodeMetadata;
61
61
  this.childTreedefs = childTreedefs;
62
62
  }
63
+ /** Get the total number of leaves in the tree. */
64
+ get size() {
65
+ return this.nodeType === NodeType.Leaf ? 1 : this.childTreedefs.reduce((a, b) => a + b.size, 0);
66
+ }
63
67
  /** Returns a string representation of this tree definition. */
64
68
  toString(root = true) {
65
69
  if (root) return "JsTreeDef(" + this.toString(false) + ")";
@@ -215,6 +219,16 @@ function pool(st, ks, strides = 1, dilation = 1) {
215
219
  const s_ = strides;
216
220
  const d_ = dilation;
217
221
  const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
222
+ if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
223
+ st = st.padOrShrink([...noop.map(() => [0, 0]), ...require_backend.zipn(i_, o_, s_).map(([i, o, s]) => [0, o * s - i])]);
224
+ st = st.reshape([...noop, ...require_backend.zip(o_, s_).flatMap(([o, s]) => [o, s])]).shrink([...noop.map((x) => [0, x]), ...require_backend.zip(o_, ks).flatMap(([o, k]) => [[0, o], [0, k]])]);
225
+ st = st.permute([
226
+ ...require_backend.range(noop.length),
227
+ ...ks.map((_, j) => noop.length + 2 * j),
228
+ ...ks.map((_, j) => noop.length + 2 * j + 1)
229
+ ]);
230
+ return st;
231
+ }
218
232
  const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
219
233
  const kidf = require_backend.zipn(ks, i_, d_, f_);
220
234
  st = st.repeat([...require_backend.rep(noop.length, 1), ...kidf.map(([k, i, d, f]) => Math.ceil(k * (i * f + d) / i))]);
@@ -249,6 +263,12 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
249
263
  const s_ = strides;
250
264
  const d_ = dilation;
251
265
  const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
266
+ if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
267
+ st = st.permute([...require_backend.range(noop.length), ...ks.flatMap((_, j) => [noop.length + j, noop.length + o_.length + j])]);
268
+ st = st.pad([...noop.map(() => [0, 0]), ...require_backend.zip(s_, ks).flatMap(([s, k]) => [[0, 0], [0, s - k]])]).reshape([...noop, ...require_backend.zip(o_, s_).map(([o, s]) => o * s)]);
269
+ st = st.padOrShrink([...noop.map(() => [0, 0]), ...require_backend.zipn(i_, o_, s_).map(([i, o, s]) => [0, i - o * s])]);
270
+ return st.reshape(st.shape.concat(require_backend.rep(ks.length, 1)));
271
+ }
252
272
  if (!require_backend.deepEqual(o_, st.shape.slice(noop.length, noop.length + ks.length))) throw new Error("poolTranspose() called with mismatched output shape");
253
273
  const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
254
274
  const kidf = require_backend.zipn(ks, i_, d_, f_);
@@ -358,6 +378,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
358
378
  Primitive$1["Atan"] = "atan";
359
379
  Primitive$1["Exp"] = "exp";
360
380
  Primitive$1["Log"] = "log";
381
+ Primitive$1["Erf"] = "erf";
382
+ Primitive$1["Erfc"] = "erfc";
361
383
  Primitive$1["Sqrt"] = "sqrt";
362
384
  Primitive$1["Min"] = "min";
363
385
  Primitive$1["Max"] = "max";
@@ -435,6 +457,12 @@ function exp$1(x) {
435
457
  function log$1(x) {
436
458
  return bind1(Primitive.Log, [x]);
437
459
  }
460
+ function erf$1(x) {
461
+ return bind1(Primitive.Erf, [x]);
462
+ }
463
+ function erfc$1(x) {
464
+ return bind1(Primitive.Erfc, [x]);
465
+ }
438
466
  function sqrt$1(x) {
439
467
  return bind1(Primitive.Sqrt, [x]);
440
468
  }
@@ -1177,12 +1205,18 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
1177
1205
  } else if (exp$3.op === require_backend.AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
1178
1206
  });
1179
1207
  }
1180
- function broadcastedJit(fn) {
1208
+ function broadcastedJit(fn, opts) {
1181
1209
  return (nargs, exps, avals, params) => {
1182
- const newShape = avals.map((aval) => aval.shape).reduce(require_backend.generalBroadcast);
1183
- exps = exps.map((exp$3) => reshapeViews(exp$3, (st) => {
1184
- if (!require_backend.deepEqual(st.shape, newShape)) return st.broadcast(newShape, require_backend.range(newShape.length - st.shape.length));
1185
- }));
1210
+ let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
1211
+ const skipCastIdx = opts?.skipCastIdx ?? [];
1212
+ if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
1213
+ exps = exps.map((exp$3, i) => {
1214
+ exp$3 = reshapeViews(exp$3, (st) => {
1215
+ if (!require_backend.deepEqual(st.shape, newShape)) return st.broadcast(newShape, require_backend.range(newShape.length - st.shape.length));
1216
+ });
1217
+ if (exp$3.dtype !== newDtype && !skipCastIdx.includes(i)) exp$3 = require_backend.AluExp.cast(newDtype, exp$3);
1218
+ return exp$3;
1219
+ });
1186
1220
  const exp$2 = fn(exps, params);
1187
1221
  return new require_backend.Kernel(nargs, require_backend.prod(newShape), exp$2);
1188
1222
  };
@@ -1225,6 +1259,8 @@ const jitRules = {
1225
1259
  [Primitive.Atan]: unopJit(require_backend.AluExp.atan),
1226
1260
  [Primitive.Exp]: unopJit(require_backend.AluExp.exp),
1227
1261
  [Primitive.Log]: unopJit(require_backend.AluExp.log),
1262
+ [Primitive.Erf]: unopJit(require_backend.AluExp.erf),
1263
+ [Primitive.Erfc]: unopJit(require_backend.AluExp.erfc),
1228
1264
  [Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
1229
1265
  [Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
1230
1266
  [Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
@@ -1272,7 +1308,7 @@ const jitRules = {
1272
1308
  return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
1273
1309
  },
1274
1310
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
1275
- [Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b)),
1311
+ [Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
1276
1312
  [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
1277
1313
  [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
1278
1314
  [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
@@ -1443,7 +1479,7 @@ var PendingExecute = class {
1443
1479
  /**
1444
1480
  * A multidimensional numeric array with data stored on CPU or GPU.
1445
1481
  *
1446
- * This is the library's core data type. Equivalent to `jnp.Array` from JAX, or
1482
+ * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
1447
1483
  * `torch.Tensor`.
1448
1484
  *
1449
1485
  * Not to be confused with the JavaScript "Array" constructor. Avoid importing
@@ -1458,6 +1494,7 @@ var Array$1 = class Array$1 extends Tracer {
1458
1494
  #source;
1459
1495
  #st;
1460
1496
  #backend;
1497
+ #committed;
1461
1498
  #rc;
1462
1499
  #pendingSet;
1463
1500
  /**
@@ -1474,6 +1511,7 @@ var Array$1 = class Array$1 extends Tracer {
1474
1511
  this.#source = args.source;
1475
1512
  this.#st = args.st;
1476
1513
  this.#backend = args.backend;
1514
+ this.#committed = args.committed;
1477
1515
  this.#rc = 1;
1478
1516
  this.#pendingSet = new Set(args.pending);
1479
1517
  if (this.#pendingSet.size === 0) this.#pendingSet = null;
@@ -1501,6 +1539,7 @@ var Array$1 = class Array$1 extends Tracer {
1501
1539
  dtype: args.dtype ?? this.#dtype,
1502
1540
  weakType: this.#weakType,
1503
1541
  backend: args.backend ?? this.#backend,
1542
+ committed: args.committed ?? this.#committed,
1504
1543
  pending: args.pending ?? this.#pending ?? void 0
1505
1544
  });
1506
1545
  }
@@ -1556,9 +1595,10 @@ var Array$1 = class Array$1 extends Tracer {
1556
1595
  */
1557
1596
  #gather(indices, axis, outDim) {
1558
1597
  this.#check();
1559
- if (indices.some((a) => a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
1560
1598
  const axisSet = new Set(axis);
1561
1599
  if (axisSet.size !== axis.length) throw new TypeError("Gather axis must not have duplicates");
1600
+ if (indices.some((a) => a.#committed && a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
1601
+ indices = indices.map((ar) => ar._putSync(this.#backend));
1562
1602
  indices = Array$1.#broadcastArrays(indices);
1563
1603
  const indexShape = indices[0].shape;
1564
1604
  const finalShape = this.shape.filter((_, i) => !axisSet.has(i));
@@ -1627,6 +1667,7 @@ var Array$1 = class Array$1 extends Tracer {
1627
1667
  this.#check();
1628
1668
  if (this.#source instanceof require_backend.AluExp) {
1629
1669
  const exp$3 = new require_backend.AluExp(op, dtypeOutput, [this.#source]);
1670
+ this.dispose();
1630
1671
  return this.#newArrayFrom({
1631
1672
  source: exp$3.simplify(),
1632
1673
  dtype: dtypeOutput,
@@ -1655,21 +1696,19 @@ var Array$1 = class Array$1 extends Tracer {
1655
1696
  }
1656
1697
  static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
1657
1698
  const n = arrays.length;
1658
- const backend = arrays[0].#backend;
1659
1699
  if (n === 0) throw new TypeError(`No inputs for ${name}`);
1660
1700
  for (const ar of arrays) ar.#check();
1661
1701
  let castDtype;
1662
1702
  let castWeakType = true;
1663
- for (let i = 0; i < n; i++) {
1664
- if (dtypeOverride?.[i]) {
1665
- if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
1666
- } else if (castDtype === void 0) {
1667
- castDtype = arrays[i].#dtype;
1668
- castWeakType = arrays[i].#weakType;
1669
- } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
1670
- if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
1671
- }
1703
+ for (let i = 0; i < n; i++) if (dtypeOverride?.[i]) {
1704
+ if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
1705
+ } else if (castDtype === void 0) {
1706
+ castDtype = arrays[i].#dtype;
1707
+ castWeakType = arrays[i].#weakType;
1708
+ } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
1672
1709
  const weakType = castWeakType && !strongTypeOutput;
1710
+ const { backend, committed } = Array$1.#computeBackend(name, arrays);
1711
+ arrays = arrays.map((ar) => ar._putSync(backend));
1673
1712
  arrays = Array$1.#broadcastArrays(arrays);
1674
1713
  const newShape = [...arrays[0].shape];
1675
1714
  if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && !reduceAxis) {
@@ -1679,12 +1718,14 @@ var Array$1 = class Array$1 extends Tracer {
1679
1718
  });
1680
1719
  if (arrays.every((ar) => require_backend.deepEqual(ar.#st, arrays[0].#st))) {
1681
1720
  const exp$4 = custom(sources);
1721
+ arrays.forEach((ar) => ar.dispose());
1682
1722
  return new Array$1({
1683
1723
  source: exp$4.simplify(),
1684
1724
  st: arrays[0].#st,
1685
1725
  dtype: exp$4.dtype,
1686
1726
  weakType,
1687
- backend
1727
+ backend,
1728
+ committed
1688
1729
  });
1689
1730
  }
1690
1731
  const exp$3 = custom(arrays.map((ar, i) => {
@@ -1693,12 +1734,14 @@ var Array$1 = class Array$1 extends Tracer {
1693
1734
  return require_backend.accessorAluExp(src$1, ar.#st, require_backend.unravelAlu(newShape, require_backend.AluVar.idx));
1694
1735
  }));
1695
1736
  const st = require_backend.ShapeTracker.fromShape(newShape);
1737
+ arrays.forEach((ar) => ar.dispose());
1696
1738
  return new Array$1({
1697
1739
  source: exp$3.simplify(),
1698
1740
  st,
1699
1741
  dtype: exp$3.dtype,
1700
1742
  weakType,
1701
- backend
1743
+ backend,
1744
+ committed
1702
1745
  });
1703
1746
  }
1704
1747
  let indices;
@@ -1734,13 +1777,14 @@ var Array$1 = class Array$1 extends Tracer {
1734
1777
  const pending = new Set([...arrays.flatMap((ar) => ar.#pending)]);
1735
1778
  for (const exe of pending) exe.updateRc(1);
1736
1779
  pending.add(new PendingExecute(backend, kernel, inputs, [output]));
1737
- for (const ar of arrays) ar.dispose();
1780
+ arrays.forEach((ar) => ar.dispose());
1738
1781
  return new Array$1({
1739
1782
  source: output,
1740
1783
  st: require_backend.ShapeTracker.fromShape(newShape),
1741
1784
  dtype: kernel.dtype,
1742
1785
  weakType,
1743
1786
  backend,
1787
+ committed,
1744
1788
  pending
1745
1789
  });
1746
1790
  }
@@ -1818,6 +1862,23 @@ var Array$1 = class Array$1 extends Tracer {
1818
1862
  return ar.#reshape(ar.#st.broadcast(newShape, require_backend.range(newShape.length - ar.ndim)));
1819
1863
  });
1820
1864
  }
1865
+ static #computeBackend(name, arrays) {
1866
+ const committed = arrays.filter((ar) => ar.#committed);
1867
+ if (committed.length > 0) {
1868
+ const backend = committed[0].#backend;
1869
+ for (const ar of committed) if (ar.#backend !== backend) throw new Error(`Device mismatch in ${name} between committed arrays on (${backend.type}, ${ar.#backend.type}), please move to the same device with devicePut()`);
1870
+ return {
1871
+ backend,
1872
+ committed: true
1873
+ };
1874
+ } else {
1875
+ const backend = arrays.length > 0 ? arrays[0].#backend : require_backend.getBackend();
1876
+ return {
1877
+ backend,
1878
+ committed: false
1879
+ };
1880
+ }
1881
+ }
1821
1882
  /** Realize the array and return it as data. */
1822
1883
  async data() {
1823
1884
  if (this.#source instanceof require_backend.AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
@@ -1977,6 +2038,12 @@ var Array$1 = class Array$1 extends Tracer {
1977
2038
  [Primitive.Log]([x]) {
1978
2039
  return [x.#unary(require_backend.AluOp.Log)];
1979
2040
  },
2041
+ [Primitive.Erf]([x]) {
2042
+ return [x.#unary(require_backend.AluOp.Erf)];
2043
+ },
2044
+ [Primitive.Erfc]([x]) {
2045
+ return [x.#unary(require_backend.AluOp.Erfc)];
2046
+ },
1980
2047
  [Primitive.Sqrt]([x]) {
1981
2048
  return [x.#unary(require_backend.AluOp.Sqrt)];
1982
2049
  },
@@ -2045,7 +2112,8 @@ var Array$1 = class Array$1 extends Tracer {
2045
2112
  },
2046
2113
  [Primitive.JitCall](args, { jaxpr, numConsts }) {
2047
2114
  if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
2048
- const backend = require_backend.getBackend();
2115
+ const { backend, committed } = Array$1.#computeBackend("jit_call", args);
2116
+ args = args.map((ar) => ar._putSync(backend));
2049
2117
  const consts = args.slice(0, numConsts);
2050
2118
  const tracers = args.slice(numConsts);
2051
2119
  const jp = jitCompile(backend, jaxpr, consts);
@@ -2062,16 +2130,54 @@ var Array$1 = class Array$1 extends Tracer {
2062
2130
  dtype: jaxpr.outs[i].aval.dtype,
2063
2131
  weakType: jaxpr.outs[i].aval.weakType,
2064
2132
  backend,
2133
+ committed,
2065
2134
  pending
2066
2135
  });
2067
2136
  });
2068
2137
  }
2069
2138
  };
2070
2139
  }
2140
+ /** @private */
2071
2141
  _realizeSource() {
2072
2142
  this.#realize();
2073
2143
  return this.#source;
2074
2144
  }
2145
+ /** @private Put this array on a new backend, asynchronously. */
2146
+ async _put(backend) {
2147
+ if (this.#backend === backend) return this;
2148
+ if (this.#source instanceof require_backend.AluExp) {
2149
+ const ar = this.#newArrayFrom({
2150
+ backend,
2151
+ committed: true
2152
+ });
2153
+ this.dispose();
2154
+ return ar;
2155
+ } else {
2156
+ const data = await this.data();
2157
+ return arrayFromData(data, this.shape, {
2158
+ dtype: this.#dtype,
2159
+ device: backend.type
2160
+ }, this.#weakType);
2161
+ }
2162
+ }
2163
+ /** @private Put this array on a new backend, synchronously. */
2164
+ _putSync(backend) {
2165
+ if (this.#backend === backend) return this;
2166
+ if (this.#source instanceof require_backend.AluExp) {
2167
+ const ar = this.#newArrayFrom({
2168
+ backend,
2169
+ committed: true
2170
+ });
2171
+ this.dispose();
2172
+ return ar;
2173
+ } else {
2174
+ const data = this.dataSync();
2175
+ return arrayFromData(data, this.shape, {
2176
+ dtype: this.#dtype,
2177
+ device: backend.type
2178
+ }, this.#weakType);
2179
+ }
2180
+ }
2075
2181
  };
2076
2182
  /** Constructor for creating a new array from data. */
2077
2183
  function array(values, { shape: shape$1, dtype, device } = {}) {
@@ -2154,7 +2260,8 @@ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
2154
2260
  st: require_backend.ShapeTracker.fromShape(shape$1),
2155
2261
  dtype,
2156
2262
  weakType,
2157
- backend
2263
+ backend,
2264
+ committed: device != void 0
2158
2265
  });
2159
2266
  }
2160
2267
  function dataToJs(dtype, data, shape$1) {
@@ -2188,7 +2295,8 @@ function fullInternal(aval, fillValue, device) {
2188
2295
  st: require_backend.ShapeTracker.fromShape(aval.shape),
2189
2296
  dtype: aval.dtype,
2190
2297
  weakType: aval.weakType,
2191
- backend: require_backend.getBackend(device)
2298
+ backend: require_backend.getBackend(device),
2299
+ committed: device != void 0
2192
2300
  });
2193
2301
  }
2194
2302
  function zerosLike$1(val, dtype) {
@@ -2256,7 +2364,8 @@ function eye(numRows, numCols, { dtype, device } = {}) {
2256
2364
  st: require_backend.ShapeTracker.fromShape([numRows, numCols]),
2257
2365
  dtype,
2258
2366
  weakType,
2259
- backend: require_backend.getBackend(device)
2367
+ backend: require_backend.getBackend(device),
2368
+ committed: device != void 0
2260
2369
  });
2261
2370
  }
2262
2371
  /** Return the identity matrix, with ones on the main diagonal. */
@@ -2299,7 +2408,8 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
2299
2408
  st,
2300
2409
  dtype,
2301
2410
  weakType: false,
2302
- backend: require_backend.getBackend(device)
2411
+ backend: require_backend.getBackend(device),
2412
+ committed: device != void 0
2303
2413
  });
2304
2414
  }
2305
2415
  /**
@@ -2335,7 +2445,8 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
2335
2445
  st,
2336
2446
  dtype,
2337
2447
  weakType: false,
2338
- backend: require_backend.getBackend(device)
2448
+ backend: require_backend.getBackend(device),
2449
+ committed: device != void 0
2339
2450
  });
2340
2451
  }
2341
2452
  function aluCompare(a, b, op) {
@@ -2847,6 +2958,8 @@ const abstractEvalRules = {
2847
2958
  [Primitive.Atan]: vectorizedUnopAbstractEval,
2848
2959
  [Primitive.Exp]: vectorizedUnopAbstractEval,
2849
2960
  [Primitive.Log]: vectorizedUnopAbstractEval,
2961
+ [Primitive.Erf]: vectorizedUnopAbstractEval,
2962
+ [Primitive.Erfc]: vectorizedUnopAbstractEval,
2850
2963
  [Primitive.Sqrt]: vectorizedUnopAbstractEval,
2851
2964
  [Primitive.Min]: binopAbstractEval,
2852
2965
  [Primitive.Max]: binopAbstractEval,
@@ -3100,6 +3213,16 @@ const jvpRules = {
3100
3213
  [Primitive.Log]([x], [dx]) {
3101
3214
  return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
3102
3215
  },
3216
+ [Primitive.Erf]([x], [dx]) {
3217
+ const coeff = 2 / Math.sqrt(Math.PI);
3218
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3219
+ return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
3220
+ },
3221
+ [Primitive.Erfc]([x], [dx]) {
3222
+ const coeff = -2 / Math.sqrt(Math.PI);
3223
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3224
+ return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
3225
+ },
3103
3226
  [Primitive.Sqrt]([x], [dx]) {
3104
3227
  const z = sqrt$1(x);
3105
3228
  return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
@@ -3262,6 +3385,10 @@ var BatchTrace = class extends Trace {
3262
3385
  const [valsIn, bdimsIn] = require_backend.unzip2(tracers.map((t) => [t.val, t.batchDim]));
3263
3386
  const vmapRule = vmapRules[primitive];
3264
3387
  if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
3388
+ if (bdimsIn.every((d) => d === null)) {
3389
+ const valOuts$1 = bind(primitive, valsIn, params);
3390
+ return valOuts$1.map((x) => new BatchTracer(this, x, null));
3391
+ }
3265
3392
  const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3266
3393
  return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3267
3394
  }
@@ -3269,24 +3396,28 @@ var BatchTrace = class extends Trace {
3269
3396
  return this.main.globalData;
3270
3397
  }
3271
3398
  };
3272
- function handleScalarBroadcasting(nd, x, d) {
3273
- if (d === null || nd === ndim$1(x)) return x;
3274
- else {
3275
- const axis = require_backend.range(ndim$1(x), nd);
3276
- const shape$1 = [...x.shape, ...axis.map(() => 1)];
3277
- return broadcast(x, shape$1, axis);
3278
- }
3279
- }
3280
- /** Process a primitive with built-in broadcasting. */
3399
+ /**
3400
+ * Process a primitive with built-in broadcasting.
3401
+ *
3402
+ * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3403
+ */
3281
3404
  function broadcastBatcher(op) {
3282
3405
  return (axisSize, args, dims) => {
3283
3406
  if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3284
- const idx = dims.findIndex((d) => d !== null);
3285
- if (idx === -1) return [[op(...args)], [null]];
3286
- if (require_backend.zip(args, dims).every(([x, d]) => ndim$1(x) === 0 || require_backend.deepEqual(x.shape, args[idx].shape) && d === dims[idx])) return [[op(...args)], [dims[idx]]];
3287
- args = args.map((x, i) => ndim$1(x) > 0 ? moveBatchAxis(axisSize, dims[i], 0, x) : x);
3288
- const nd = Math.max(...args.map(ndim$1));
3289
- args = args.map((x, i) => handleScalarBroadcasting(nd, x, dims[i]));
3407
+ const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3408
+ const firstIdx = dims.findIndex((d) => d !== null);
3409
+ const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3410
+ if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3411
+ args = args.map((x, i) => {
3412
+ if (dims[i] === null) return x;
3413
+ x = moveBatchAxis(axisSize, dims[i], 0, x);
3414
+ if (x.ndim < nd) x = x.reshape([
3415
+ x.shape[0],
3416
+ ...require_backend.rep(nd - x.ndim, 1),
3417
+ ...x.shape.slice(1)
3418
+ ]);
3419
+ return x;
3420
+ });
3290
3421
  return [[op(...args)], [0]];
3291
3422
  };
3292
3423
  }
@@ -3310,17 +3441,18 @@ const vmapRules = {
3310
3441
  [Primitive.Atan]: unopBatcher(atan$1),
3311
3442
  [Primitive.Exp]: unopBatcher(exp$1),
3312
3443
  [Primitive.Log]: unopBatcher(log$1),
3444
+ [Primitive.Erf]: unopBatcher(erf$1),
3445
+ [Primitive.Erfc]: unopBatcher(erfc$1),
3313
3446
  [Primitive.Sqrt]: unopBatcher(sqrt$1),
3314
3447
  [Primitive.Min]: broadcastBatcher(min$1),
3315
3448
  [Primitive.Max]: broadcastBatcher(max$1),
3316
3449
  [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3317
- if (xBdim === null) return [[reduce(x, op, axis)], [null]];
3450
+ require_backend.assertNonNull(xBdim);
3318
3451
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3319
3452
  const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3320
3453
  return [[reduce(x, op, newAxis)], [outBdim]];
3321
3454
  },
3322
3455
  [Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
3323
- if (xBdim === null && yBdim === null) return [[dot$1(x, y)], [null]];
3324
3456
  x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
3325
3457
  y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
3326
3458
  const z = dot$1(x, y);
@@ -3329,26 +3461,68 @@ const vmapRules = {
3329
3461
  [Primitive.Compare](axisSize, args, dims, { op }) {
3330
3462
  return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3331
3463
  },
3464
+ [Primitive.Where]: broadcastBatcher(where$1),
3465
+ [Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
3466
+ require_backend.assertNonNull(xBdim);
3467
+ const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
3468
+ newPerm.splice(xBdim, 0, xBdim);
3469
+ return [[transpose$1(x, newPerm)], [xBdim]];
3470
+ },
3471
+ [Primitive.Broadcast](axisSize, [x], [xBdim], { shape: shape$1, axis }) {
3472
+ require_backend.assertNonNull(xBdim);
3473
+ const newShape = shape$1.toSpliced(xBdim, 0, axisSize);
3474
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3475
+ return [[broadcast(x, newShape, newAxis)], [xBdim]];
3476
+ },
3332
3477
  [Primitive.Reshape](axisSize, [x], [xBdim], { shape: shape$1 }) {
3333
- if (xBdim === null) return [[reshape$1(x, shape$1)], [null]];
3334
3478
  x = moveBatchAxis(axisSize, xBdim, 0, x);
3335
3479
  return [[reshape$1(x, [axisSize, ...shape$1])], [0]];
3336
3480
  },
3337
3481
  [Primitive.Flip](axisSize, [x], [xBdim], { axis }) {
3338
- if (xBdim === null) return [[flip$1(x, axis)], [null]];
3482
+ require_backend.assertNonNull(xBdim);
3339
3483
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3340
3484
  return [[flip$1(x, newAxis)], [xBdim]];
3341
3485
  },
3342
3486
  [Primitive.Shrink](axisSize, [x], [xBdim], { slice }) {
3343
- if (xBdim === null) return [[shrink(x, slice)], [null]];
3487
+ require_backend.assertNonNull(xBdim);
3344
3488
  const newSlice = slice.toSpliced(xBdim, 0, [0, axisSize]);
3345
3489
  return [[shrink(x, newSlice)], [xBdim]];
3346
3490
  },
3347
3491
  [Primitive.Pad](axisSize, [x], [xBdim], { width }) {
3348
- if (xBdim === null) return [[pad$1(x, width)], [null]];
3492
+ require_backend.assertNonNull(xBdim);
3349
3493
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3350
3494
  return [[pad$1(x, newWidth)], [xBdim]];
3351
3495
  },
3496
+ [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3497
+ if (indicesBdim.every((d) => d === null)) {
3498
+ require_backend.assertNonNull(xBdim);
3499
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3500
+ let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3501
+ let newOutDim = outDim;
3502
+ if (newOutDim < newBdim) newBdim += axis.length;
3503
+ else newOutDim += 1;
3504
+ return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3505
+ }
3506
+ const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3507
+ indices = indices.map((m, i) => {
3508
+ if (indicesBdim[i] === null) return m;
3509
+ m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3510
+ if (m.ndim < nd) m = m.reshape([
3511
+ m.shape[0],
3512
+ ...require_backend.rep(nd - m.ndim, 1),
3513
+ ...m.shape.slice(1)
3514
+ ]);
3515
+ return m;
3516
+ });
3517
+ if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3518
+ else {
3519
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3520
+ const newAxis = [0, ...axis.map((ax) => ax + 1)];
3521
+ const extraBatchIndex = arange(axisSize).reshape([-1, ...require_backend.rep(nd - 1, 1)]);
3522
+ indices.splice(0, 0, extraBatchIndex);
3523
+ return [[gather(x, indices, newAxis, outDim)], [outDim]];
3524
+ }
3525
+ },
3352
3526
  [Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
3353
3527
  const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
3354
3528
  const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
@@ -3408,12 +3582,14 @@ function vmapFlat(f, inAxes, args) {
3408
3582
  function vmap$1(f, inAxes = 0) {
3409
3583
  return (...args) => {
3410
3584
  const [argsFlat, inTree] = flatten(args);
3411
- let inAxesFlat;
3585
+ let inAxesFlat = [];
3412
3586
  if (typeof inAxes === "number") inAxesFlat = require_backend.rep(argsFlat.length, inAxes);
3587
+ else for (let i = 0; i < args.length; i++) if (inAxes[i] == null) inAxesFlat.push(...require_backend.rep(inTree.childTreedefs[i].size, null));
3588
+ else if (typeof inAxes[i] === "number") inAxesFlat.push(...require_backend.rep(inTree.childTreedefs[i].size, inAxes[i]));
3413
3589
  else {
3414
- let inTree2;
3415
- [inAxesFlat, inTree2] = flatten(inAxes);
3416
- if (!inTree.equals(inTree2)) throw new TreeMismatchError("vmap", inTree, inTree2);
3590
+ const [axesFlat, axesTreeDef] = flatten(inAxes[i]);
3591
+ if (!inTree.childTreedefs[i].equals(axesTreeDef)) throw new TreeMismatchError("vmap", inTree.childTreedefs[i], axesTreeDef);
3592
+ inAxesFlat.push(...axesFlat);
3417
3593
  }
3418
3594
  const [fFlat, outTree] = flattenFun(f, inTree);
3419
3595
  const outsFlat = vmapFlat(fFlat, inAxesFlat, argsFlat);
@@ -4033,7 +4209,7 @@ function valueAndGrad$1(f) {
4033
4209
  const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
4034
4210
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
4035
4211
  if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
4036
- const [ct, ...rest] = fVjp(array(1, { dtype: y.dtype }));
4212
+ const [ct, ...rest] = fVjp(onesLike$1(y.ref));
4037
4213
  for (const r of rest) dispose(r);
4038
4214
  fVjp.dispose();
4039
4215
  return [y, ct];
@@ -4061,7 +4237,10 @@ __export(lax_exports, {
4061
4237
  conv: () => conv$1,
4062
4238
  convGeneralDilated: () => convGeneralDilated,
4063
4239
  convWithGeneralPadding: () => convWithGeneralPadding,
4064
- reduceWindow: () => reduceWindow
4240
+ erf: () => erf,
4241
+ erfc: () => erfc,
4242
+ reduceWindow: () => reduceWindow,
4243
+ stopGradient: () => stopGradient$1
4065
4244
  });
4066
4245
  function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
4067
4246
  const padType = padding.toUpperCase();
@@ -4120,6 +4299,28 @@ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
4120
4299
  strides: windowStrides
4121
4300
  }));
4122
4301
  }
4302
+ /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
4303
+ function erf(x) {
4304
+ return erf$1(x);
4305
+ }
4306
+ /**
4307
+ * The complementary error function: `erfc(x) = 1 - erf(x)`.
4308
+ *
4309
+ * This function is more accurate than `1 - erf(x)` for large values of `x`,
4310
+ * where `erf(x)` is very close to 1.
4311
+ */
4312
+ function erfc(x) {
4313
+ return erfc$1(x);
4314
+ }
4315
+ /**
4316
+ * Stops gradient computation.
4317
+ *
4318
+ * Behaves as the identity function but prevents the flow of gradients during
4319
+ * forward or reverse-mode automatic differentiation.
4320
+ */
4321
+ function stopGradient$1(x) {
4322
+ return stopGradient(x);
4323
+ }
4123
4324
 
4124
4325
  //#endregion
4125
4326
  //#region src/numpy.ts
@@ -4182,6 +4383,9 @@ __export(numpy_exports, {
4182
4383
  fullLike: () => fullLike$1,
4183
4384
  greater: () => greater,
4184
4385
  greaterEqual: () => greaterEqual,
4386
+ hamming: () => hamming,
4387
+ hann: () => hann,
4388
+ heaviside: () => heaviside,
4185
4389
  hstack: () => hstack,
4186
4390
  hypot: () => hypot,
4187
4391
  identity: () => identity$1,
@@ -4821,6 +5025,32 @@ function sign(x) {
4821
5025
  x = fudgeArray(x);
4822
5026
  return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
4823
5027
  }
5028
+ /**
5029
+ * Return the Hamming window of size M, a taper with a weighted cosine bell.
5030
+ *
5031
+ * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
5032
+ */
5033
+ function hamming(M) {
5034
+ return cos(linspace(0, 2 * Math.PI, M)).mul(-.46).add(.54);
5035
+ }
5036
+ /**
5037
+ * Return the Hann window of size M, a taper with a weighted cosine bell.
5038
+ *
5039
+ * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
5040
+ */
5041
+ function hann(M) {
5042
+ return cos(linspace(0, 2 * Math.PI, M)).mul(-.5).add(.5);
5043
+ }
5044
+ /**
5045
+ * @function
5046
+ * Compute the Heaviside step function. It is defined piecewise:
5047
+ * - `heaviside(x1, x2) = 0` for `x1 < 0`,
5048
+ * - `heaviside(x1, x2) = x2` for `x1 == 0`,
5049
+ * - `heaviside(x1, x2) = 1` for `x1 > 0`.
5050
+ */
5051
+ const heaviside = jit$1(function heaviside$1(x1, x2) {
5052
+ return where(less(x1.ref, 0), 0, where(equal(x1, 0), x2, 1));
5053
+ });
4824
5054
  /** Calculate element-wise square of the input array. */
4825
5055
  function square(x) {
4826
5056
  x = fudgeArray(x);
@@ -4840,8 +5070,8 @@ function acos(x) {
4840
5070
  * Return element-wise hypotenuse for the given legs of a right triangle.
4841
5071
  *
4842
5072
  * In the original NumPy/JAX implementation, this function is more numerically
4843
- * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
4844
- * improvements.
5073
+ * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
5074
+ * stability improvements.
4845
5075
  */
4846
5076
  const hypot = jit$1(function hypot$1(x1, x2) {
4847
5077
  return sqrt(square(x1).add(square(x2)));
@@ -5165,18 +5395,20 @@ function celu(x, alpha = 1) {
5165
5395
  * @function
5166
5396
  * Gaussian error linear unit (GELU) activation function.
5167
5397
  *
5168
- * This is computed element-wise. Currently jax-js does not support the erf() or
5169
- * gelu() functions exactly as primitives, so an approximation is used:
5170
- * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
5398
+ * This is computed element-wise. There are two variants depending on whether
5399
+ * `approximate` is set (default true):
5171
5400
  *
5172
- * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
5401
+ * - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
5402
+ * - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
5173
5403
  *
5174
- * This will be improved in the future.
5404
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
5175
5405
  */
5176
- const gelu = jit$1(function gelu$1(x) {
5177
- const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
5178
- return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
5179
- });
5406
+ const gelu = jit$1(function gelu$1(x, opts) {
5407
+ if (opts?.approximate ?? true) {
5408
+ const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
5409
+ return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
5410
+ } else return x.ref.mul(.5).mul(erfc$1(negative(x.ref.mul(Math.SQRT1_2))));
5411
+ }, { staticArgnums: [1] });
5180
5412
  /**
5181
5413
  * Gated linear unit (GLU) activation function.
5182
5414
  *
@@ -5397,6 +5629,25 @@ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
5397
5629
  return radius.mul(cos(theta));
5398
5630
  }, { staticArgnums: [1] });
5399
5631
 
5632
+ //#endregion
5633
+ //#region src/scipy-special.ts
5634
+ var scipy_special_exports = {};
5635
+ __export(scipy_special_exports, {
5636
+ erf: () => erf,
5637
+ erfc: () => erfc,
5638
+ logSoftmax: () => logSoftmax,
5639
+ logit: () => logit,
5640
+ logsumexp: () => logsumexp,
5641
+ softmax: () => softmax
5642
+ });
5643
+ /**
5644
+ * @function
5645
+ * The logit function, `logit(p) = log(p / (1-p))`.
5646
+ */
5647
+ const logit = jit$1(function logit$1(x) {
5648
+ return log(x.ref.div(subtract(1, x)));
5649
+ });
5650
+
5400
5651
  //#endregion
5401
5652
  //#region src/polyfills.ts
5402
5653
  /** @file Polyfills for using this library. */
@@ -5490,6 +5741,24 @@ async function blockUntilReady(x) {
5490
5741
  await Promise.all(promises);
5491
5742
  return x;
5492
5743
  }
5744
+ /**
5745
+ * Transfer `x` to `device`.
5746
+ *
5747
+ * `x` may be a nested container of arrays or scalars. The resulting structure
5748
+ * is committed to the device.
5749
+ *
5750
+ * If `device` is not specified, this function behaves as identity if the input
5751
+ * is already an `Array`, otherwise it places the scalar uncommitted on the
5752
+ * default device.
5753
+ */
5754
+ async function devicePut(x, device) {
5755
+ const [xflat, structure$1] = flatten(x);
5756
+ const yflat = await Promise.all(xflat.map((leaf) => {
5757
+ if (leaf instanceof Array$1) return device ? leaf._put(require_backend.getBackend(device)) : Promise.resolve(leaf);
5758
+ else return Promise.resolve(array(leaf, { device }));
5759
+ }));
5760
+ return unflatten(structure$1, yflat);
5761
+ }
5493
5762
 
5494
5763
  //#endregion
5495
5764
  exports.Array = Array$1;
@@ -5497,6 +5766,7 @@ exports.DType = require_backend.DType;
5497
5766
  exports.Jaxpr = Jaxpr;
5498
5767
  exports.blockUntilReady = blockUntilReady;
5499
5768
  exports.defaultDevice = require_backend.defaultDevice;
5769
+ exports.devicePut = devicePut;
5500
5770
  exports.devices = require_backend.devices;
5501
5771
  exports.grad = grad;
5502
5772
  exports.init = require_backend.init;
@@ -5531,6 +5801,12 @@ Object.defineProperty(exports, 'random', {
5531
5801
  return random_exports;
5532
5802
  }
5533
5803
  });
5804
+ Object.defineProperty(exports, 'scipySpecial', {
5805
+ enumerable: true,
5806
+ get: function () {
5807
+ return scipy_special_exports;
5808
+ }
5809
+ });
5534
5810
  exports.setDebug = require_backend.setDebug;
5535
5811
  Object.defineProperty(exports, 'tree', {
5536
5812
  enumerable: true,
@@ -5540,4 +5816,5 @@ Object.defineProperty(exports, 'tree', {
5540
5816
  });
5541
5817
  exports.valueAndGrad = valueAndGrad;
5542
5818
  exports.vjp = vjp;
5543
- exports.vmap = vmap;
5819
+ exports.vmap = vmap;
5820
+ //# sourceMappingURL=index.cjs.map