@jax-js/jax 0.0.5 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-CdcTZEOF.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DwIAd0AG.js";
3
3
 
4
4
  //#region src/tree.ts
5
5
  var tree_exports = {};
@@ -29,6 +29,10 @@ var JsTreeDef = class JsTreeDef {
29
29
  this.nodeMetadata = nodeMetadata;
30
30
  this.childTreedefs = childTreedefs;
31
31
  }
32
+ /** Get the total number of leaves in the tree. */
33
+ get size() {
34
+ return this.nodeType === NodeType.Leaf ? 1 : this.childTreedefs.reduce((a, b) => a + b.size, 0);
35
+ }
32
36
  /** Returns a string representation of this tree definition. */
33
37
  toString(root = true) {
34
38
  if (root) return "JsTreeDef(" + this.toString(false) + ")";
@@ -184,6 +188,16 @@ function pool(st, ks, strides = 1, dilation = 1) {
184
188
  const s_ = strides;
185
189
  const d_ = dilation;
186
190
  const o_ = zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
191
+ if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
192
+ st = st.padOrShrink([...noop.map(() => [0, 0]), ...zipn(i_, o_, s_).map(([i, o, s]) => [0, o * s - i])]);
193
+ st = st.reshape([...noop, ...zip(o_, s_).flatMap(([o, s]) => [o, s])]).shrink([...noop.map((x) => [0, x]), ...zip(o_, ks).flatMap(([o, k]) => [[0, o], [0, k]])]);
194
+ st = st.permute([
195
+ ...range(noop.length),
196
+ ...ks.map((_, j) => noop.length + 2 * j),
197
+ ...ks.map((_, j) => noop.length + 2 * j + 1)
198
+ ]);
199
+ return st;
200
+ }
187
201
  const f_ = zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
188
202
  const kidf = zipn(ks, i_, d_, f_);
189
203
  st = st.repeat([...rep(noop.length, 1), ...kidf.map(([k, i, d, f]) => Math.ceil(k * (i * f + d) / i))]);
@@ -218,6 +232,12 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
218
232
  const s_ = strides;
219
233
  const d_ = dilation;
220
234
  const o_ = zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
235
+ if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
236
+ st = st.permute([...range(noop.length), ...ks.flatMap((_, j) => [noop.length + j, noop.length + o_.length + j])]);
237
+ st = st.pad([...noop.map(() => [0, 0]), ...zip(s_, ks).flatMap(([s, k]) => [[0, 0], [0, s - k]])]).reshape([...noop, ...zip(o_, s_).map(([o, s]) => o * s)]);
238
+ st = st.padOrShrink([...noop.map(() => [0, 0]), ...zipn(i_, o_, s_).map(([i, o, s]) => [0, i - o * s])]);
239
+ return st.reshape(st.shape.concat(rep(ks.length, 1)));
240
+ }
221
241
  if (!deepEqual(o_, st.shape.slice(noop.length, noop.length + ks.length))) throw new Error("poolTranspose() called with mismatched output shape");
222
242
  const f_ = zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
223
243
  const kidf = zipn(ks, i_, d_, f_);
@@ -327,6 +347,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
327
347
  Primitive$1["Atan"] = "atan";
328
348
  Primitive$1["Exp"] = "exp";
329
349
  Primitive$1["Log"] = "log";
350
+ Primitive$1["Erf"] = "erf";
351
+ Primitive$1["Erfc"] = "erfc";
330
352
  Primitive$1["Sqrt"] = "sqrt";
331
353
  Primitive$1["Min"] = "min";
332
354
  Primitive$1["Max"] = "max";
@@ -404,6 +426,12 @@ function exp$1(x) {
404
426
  function log$1(x) {
405
427
  return bind1(Primitive.Log, [x]);
406
428
  }
429
+ function erf$1(x) {
430
+ return bind1(Primitive.Erf, [x]);
431
+ }
432
+ function erfc$1(x) {
433
+ return bind1(Primitive.Erfc, [x]);
434
+ }
407
435
  function sqrt$1(x) {
408
436
  return bind1(Primitive.Sqrt, [x]);
409
437
  }
@@ -1146,12 +1174,18 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
1146
1174
  } else if (exp$3.op === AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
1147
1175
  });
1148
1176
  }
1149
- function broadcastedJit(fn) {
1177
+ function broadcastedJit(fn, opts) {
1150
1178
  return (nargs, exps, avals, params) => {
1151
- const newShape = avals.map((aval) => aval.shape).reduce(generalBroadcast);
1152
- exps = exps.map((exp$3) => reshapeViews(exp$3, (st) => {
1153
- if (!deepEqual(st.shape, newShape)) return st.broadcast(newShape, range(newShape.length - st.shape.length));
1154
- }));
1179
+ let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
1180
+ const skipCastIdx = opts?.skipCastIdx ?? [];
1181
+ if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
1182
+ exps = exps.map((exp$3, i) => {
1183
+ exp$3 = reshapeViews(exp$3, (st) => {
1184
+ if (!deepEqual(st.shape, newShape)) return st.broadcast(newShape, range(newShape.length - st.shape.length));
1185
+ });
1186
+ if (exp$3.dtype !== newDtype && !skipCastIdx.includes(i)) exp$3 = AluExp.cast(newDtype, exp$3);
1187
+ return exp$3;
1188
+ });
1155
1189
  const exp$2 = fn(exps, params);
1156
1190
  return new Kernel(nargs, prod(newShape), exp$2);
1157
1191
  };
@@ -1194,6 +1228,8 @@ const jitRules = {
1194
1228
  [Primitive.Atan]: unopJit(AluExp.atan),
1195
1229
  [Primitive.Exp]: unopJit(AluExp.exp),
1196
1230
  [Primitive.Log]: unopJit(AluExp.log),
1231
+ [Primitive.Erf]: unopJit(AluExp.erf),
1232
+ [Primitive.Erfc]: unopJit(AluExp.erfc),
1197
1233
  [Primitive.Sqrt]: unopJit(AluExp.sqrt),
1198
1234
  [Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
1199
1235
  [Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
@@ -1241,7 +1277,7 @@ const jitRules = {
1241
1277
  return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
1242
1278
  },
1243
1279
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
1244
- [Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b)),
1280
+ [Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
1245
1281
  [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
1246
1282
  [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
1247
1283
  [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
@@ -1412,7 +1448,7 @@ var PendingExecute = class {
1412
1448
  /**
1413
1449
  * A multidimensional numeric array with data stored on CPU or GPU.
1414
1450
  *
1415
- * This is the library's core data type. Equivalent to `jnp.Array` from JAX, or
1451
+ * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
1416
1452
  * `torch.Tensor`.
1417
1453
  *
1418
1454
  * Not to be confused with the JavaScript "Array" constructor. Avoid importing
@@ -1427,6 +1463,7 @@ var Array$1 = class Array$1 extends Tracer {
1427
1463
  #source;
1428
1464
  #st;
1429
1465
  #backend;
1466
+ #committed;
1430
1467
  #rc;
1431
1468
  #pendingSet;
1432
1469
  /**
@@ -1443,6 +1480,7 @@ var Array$1 = class Array$1 extends Tracer {
1443
1480
  this.#source = args.source;
1444
1481
  this.#st = args.st;
1445
1482
  this.#backend = args.backend;
1483
+ this.#committed = args.committed;
1446
1484
  this.#rc = 1;
1447
1485
  this.#pendingSet = new Set(args.pending);
1448
1486
  if (this.#pendingSet.size === 0) this.#pendingSet = null;
@@ -1470,6 +1508,7 @@ var Array$1 = class Array$1 extends Tracer {
1470
1508
  dtype: args.dtype ?? this.#dtype,
1471
1509
  weakType: this.#weakType,
1472
1510
  backend: args.backend ?? this.#backend,
1511
+ committed: args.committed ?? this.#committed,
1473
1512
  pending: args.pending ?? this.#pending ?? void 0
1474
1513
  });
1475
1514
  }
@@ -1525,9 +1564,10 @@ var Array$1 = class Array$1 extends Tracer {
1525
1564
  */
1526
1565
  #gather(indices, axis, outDim) {
1527
1566
  this.#check();
1528
- if (indices.some((a) => a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
1529
1567
  const axisSet = new Set(axis);
1530
1568
  if (axisSet.size !== axis.length) throw new TypeError("Gather axis must not have duplicates");
1569
+ if (indices.some((a) => a.#committed && a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
1570
+ indices = indices.map((ar) => ar._putSync(this.#backend));
1531
1571
  indices = Array$1.#broadcastArrays(indices);
1532
1572
  const indexShape = indices[0].shape;
1533
1573
  const finalShape = this.shape.filter((_, i) => !axisSet.has(i));
@@ -1596,6 +1636,7 @@ var Array$1 = class Array$1 extends Tracer {
1596
1636
  this.#check();
1597
1637
  if (this.#source instanceof AluExp) {
1598
1638
  const exp$3 = new AluExp(op, dtypeOutput, [this.#source]);
1639
+ this.dispose();
1599
1640
  return this.#newArrayFrom({
1600
1641
  source: exp$3.simplify(),
1601
1642
  dtype: dtypeOutput,
@@ -1624,21 +1665,19 @@ var Array$1 = class Array$1 extends Tracer {
1624
1665
  }
1625
1666
  static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
1626
1667
  const n = arrays.length;
1627
- const backend = arrays[0].#backend;
1628
1668
  if (n === 0) throw new TypeError(`No inputs for ${name}`);
1629
1669
  for (const ar of arrays) ar.#check();
1630
1670
  let castDtype;
1631
1671
  let castWeakType = true;
1632
- for (let i = 0; i < n; i++) {
1633
- if (dtypeOverride?.[i]) {
1634
- if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
1635
- } else if (castDtype === void 0) {
1636
- castDtype = arrays[i].#dtype;
1637
- castWeakType = arrays[i].#weakType;
1638
- } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
1639
- if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
1640
- }
1672
+ for (let i = 0; i < n; i++) if (dtypeOverride?.[i]) {
1673
+ if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
1674
+ } else if (castDtype === void 0) {
1675
+ castDtype = arrays[i].#dtype;
1676
+ castWeakType = arrays[i].#weakType;
1677
+ } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
1641
1678
  const weakType = castWeakType && !strongTypeOutput;
1679
+ const { backend, committed } = Array$1.#computeBackend(name, arrays);
1680
+ arrays = arrays.map((ar) => ar._putSync(backend));
1642
1681
  arrays = Array$1.#broadcastArrays(arrays);
1643
1682
  const newShape = [...arrays[0].shape];
1644
1683
  if (arrays.every((ar) => ar.#source instanceof AluExp) && !reduceAxis) {
@@ -1648,12 +1687,14 @@ var Array$1 = class Array$1 extends Tracer {
1648
1687
  });
1649
1688
  if (arrays.every((ar) => deepEqual(ar.#st, arrays[0].#st))) {
1650
1689
  const exp$4 = custom(sources);
1690
+ arrays.forEach((ar) => ar.dispose());
1651
1691
  return new Array$1({
1652
1692
  source: exp$4.simplify(),
1653
1693
  st: arrays[0].#st,
1654
1694
  dtype: exp$4.dtype,
1655
1695
  weakType,
1656
- backend
1696
+ backend,
1697
+ committed
1657
1698
  });
1658
1699
  }
1659
1700
  const exp$3 = custom(arrays.map((ar, i) => {
@@ -1662,12 +1703,14 @@ var Array$1 = class Array$1 extends Tracer {
1662
1703
  return accessorAluExp(src$1, ar.#st, unravelAlu(newShape, AluVar.idx));
1663
1704
  }));
1664
1705
  const st = ShapeTracker.fromShape(newShape);
1706
+ arrays.forEach((ar) => ar.dispose());
1665
1707
  return new Array$1({
1666
1708
  source: exp$3.simplify(),
1667
1709
  st,
1668
1710
  dtype: exp$3.dtype,
1669
1711
  weakType,
1670
- backend
1712
+ backend,
1713
+ committed
1671
1714
  });
1672
1715
  }
1673
1716
  let indices;
@@ -1703,13 +1746,14 @@ var Array$1 = class Array$1 extends Tracer {
1703
1746
  const pending = new Set([...arrays.flatMap((ar) => ar.#pending)]);
1704
1747
  for (const exe of pending) exe.updateRc(1);
1705
1748
  pending.add(new PendingExecute(backend, kernel, inputs, [output]));
1706
- for (const ar of arrays) ar.dispose();
1749
+ arrays.forEach((ar) => ar.dispose());
1707
1750
  return new Array$1({
1708
1751
  source: output,
1709
1752
  st: ShapeTracker.fromShape(newShape),
1710
1753
  dtype: kernel.dtype,
1711
1754
  weakType,
1712
1755
  backend,
1756
+ committed,
1713
1757
  pending
1714
1758
  });
1715
1759
  }
@@ -1787,6 +1831,23 @@ var Array$1 = class Array$1 extends Tracer {
1787
1831
  return ar.#reshape(ar.#st.broadcast(newShape, range(newShape.length - ar.ndim)));
1788
1832
  });
1789
1833
  }
1834
+ static #computeBackend(name, arrays) {
1835
+ const committed = arrays.filter((ar) => ar.#committed);
1836
+ if (committed.length > 0) {
1837
+ const backend = committed[0].#backend;
1838
+ for (const ar of committed) if (ar.#backend !== backend) throw new Error(`Device mismatch in ${name} between committed arrays on (${backend.type}, ${ar.#backend.type}), please move to the same device with devicePut()`);
1839
+ return {
1840
+ backend,
1841
+ committed: true
1842
+ };
1843
+ } else {
1844
+ const backend = arrays.length > 0 ? arrays[0].#backend : getBackend();
1845
+ return {
1846
+ backend,
1847
+ committed: false
1848
+ };
1849
+ }
1850
+ }
1790
1851
  /** Realize the array and return it as data. */
1791
1852
  async data() {
1792
1853
  if (this.#source instanceof AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
@@ -1946,6 +2007,12 @@ var Array$1 = class Array$1 extends Tracer {
1946
2007
  [Primitive.Log]([x]) {
1947
2008
  return [x.#unary(AluOp.Log)];
1948
2009
  },
2010
+ [Primitive.Erf]([x]) {
2011
+ return [x.#unary(AluOp.Erf)];
2012
+ },
2013
+ [Primitive.Erfc]([x]) {
2014
+ return [x.#unary(AluOp.Erfc)];
2015
+ },
1949
2016
  [Primitive.Sqrt]([x]) {
1950
2017
  return [x.#unary(AluOp.Sqrt)];
1951
2018
  },
@@ -2014,7 +2081,8 @@ var Array$1 = class Array$1 extends Tracer {
2014
2081
  },
2015
2082
  [Primitive.JitCall](args, { jaxpr, numConsts }) {
2016
2083
  if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
2017
- const backend = getBackend();
2084
+ const { backend, committed } = Array$1.#computeBackend("jit_call", args);
2085
+ args = args.map((ar) => ar._putSync(backend));
2018
2086
  const consts = args.slice(0, numConsts);
2019
2087
  const tracers = args.slice(numConsts);
2020
2088
  const jp = jitCompile(backend, jaxpr, consts);
@@ -2031,16 +2099,54 @@ var Array$1 = class Array$1 extends Tracer {
2031
2099
  dtype: jaxpr.outs[i].aval.dtype,
2032
2100
  weakType: jaxpr.outs[i].aval.weakType,
2033
2101
  backend,
2102
+ committed,
2034
2103
  pending
2035
2104
  });
2036
2105
  });
2037
2106
  }
2038
2107
  };
2039
2108
  }
2109
+ /** @private */
2040
2110
  _realizeSource() {
2041
2111
  this.#realize();
2042
2112
  return this.#source;
2043
2113
  }
2114
+ /** @private Put this array on a new backend, asynchronously. */
2115
+ async _put(backend) {
2116
+ if (this.#backend === backend) return this;
2117
+ if (this.#source instanceof AluExp) {
2118
+ const ar = this.#newArrayFrom({
2119
+ backend,
2120
+ committed: true
2121
+ });
2122
+ this.dispose();
2123
+ return ar;
2124
+ } else {
2125
+ const data = await this.data();
2126
+ return arrayFromData(data, this.shape, {
2127
+ dtype: this.#dtype,
2128
+ device: backend.type
2129
+ }, this.#weakType);
2130
+ }
2131
+ }
2132
+ /** @private Put this array on a new backend, synchronously. */
2133
+ _putSync(backend) {
2134
+ if (this.#backend === backend) return this;
2135
+ if (this.#source instanceof AluExp) {
2136
+ const ar = this.#newArrayFrom({
2137
+ backend,
2138
+ committed: true
2139
+ });
2140
+ this.dispose();
2141
+ return ar;
2142
+ } else {
2143
+ const data = this.dataSync();
2144
+ return arrayFromData(data, this.shape, {
2145
+ dtype: this.#dtype,
2146
+ device: backend.type
2147
+ }, this.#weakType);
2148
+ }
2149
+ }
2044
2150
  };
2045
2151
  /** Constructor for creating a new array from data. */
2046
2152
  function array(values, { shape: shape$1, dtype, device } = {}) {
@@ -2123,7 +2229,8 @@ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
2123
2229
  st: ShapeTracker.fromShape(shape$1),
2124
2230
  dtype,
2125
2231
  weakType,
2126
- backend
2232
+ backend,
2233
+ committed: device != void 0
2127
2234
  });
2128
2235
  }
2129
2236
  function dataToJs(dtype, data, shape$1) {
@@ -2157,7 +2264,8 @@ function fullInternal(aval, fillValue, device) {
2157
2264
  st: ShapeTracker.fromShape(aval.shape),
2158
2265
  dtype: aval.dtype,
2159
2266
  weakType: aval.weakType,
2160
- backend: getBackend(device)
2267
+ backend: getBackend(device),
2268
+ committed: device != void 0
2161
2269
  });
2162
2270
  }
2163
2271
  function zerosLike$1(val, dtype) {
@@ -2225,7 +2333,8 @@ function eye(numRows, numCols, { dtype, device } = {}) {
2225
2333
  st: ShapeTracker.fromShape([numRows, numCols]),
2226
2334
  dtype,
2227
2335
  weakType,
2228
- backend: getBackend(device)
2336
+ backend: getBackend(device),
2337
+ committed: device != void 0
2229
2338
  });
2230
2339
  }
2231
2340
  /** Return the identity matrix, with ones on the main diagonal. */
@@ -2268,7 +2377,8 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
2268
2377
  st,
2269
2378
  dtype,
2270
2379
  weakType: false,
2271
- backend: getBackend(device)
2380
+ backend: getBackend(device),
2381
+ committed: device != void 0
2272
2382
  });
2273
2383
  }
2274
2384
  /**
@@ -2304,7 +2414,8 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
2304
2414
  st,
2305
2415
  dtype,
2306
2416
  weakType: false,
2307
- backend: getBackend(device)
2417
+ backend: getBackend(device),
2418
+ committed: device != void 0
2308
2419
  });
2309
2420
  }
2310
2421
  function aluCompare(a, b, op) {
@@ -2812,6 +2923,8 @@ const abstractEvalRules = {
2812
2923
  [Primitive.Atan]: vectorizedUnopAbstractEval,
2813
2924
  [Primitive.Exp]: vectorizedUnopAbstractEval,
2814
2925
  [Primitive.Log]: vectorizedUnopAbstractEval,
2926
+ [Primitive.Erf]: vectorizedUnopAbstractEval,
2927
+ [Primitive.Erfc]: vectorizedUnopAbstractEval,
2815
2928
  [Primitive.Sqrt]: vectorizedUnopAbstractEval,
2816
2929
  [Primitive.Min]: binopAbstractEval,
2817
2930
  [Primitive.Max]: binopAbstractEval,
@@ -3064,6 +3177,16 @@ const jvpRules = {
3064
3177
  [Primitive.Log]([x], [dx]) {
3065
3178
  return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
3066
3179
  },
3180
+ [Primitive.Erf]([x], [dx]) {
3181
+ const coeff = 2 / Math.sqrt(Math.PI);
3182
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3183
+ return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
3184
+ },
3185
+ [Primitive.Erfc]([x], [dx]) {
3186
+ const coeff = -2 / Math.sqrt(Math.PI);
3187
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3188
+ return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
3189
+ },
3067
3190
  [Primitive.Sqrt]([x], [dx]) {
3068
3191
  const z = sqrt$1(x);
3069
3192
  return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
@@ -3225,6 +3348,10 @@ var BatchTrace = class extends Trace {
3225
3348
  const [valsIn, bdimsIn] = unzip2(tracers.map((t) => [t.val, t.batchDim]));
3226
3349
  const vmapRule = vmapRules[primitive];
3227
3350
  if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
3351
+ if (bdimsIn.every((d) => d === null)) {
3352
+ const valOuts$1 = bind(primitive, valsIn, params);
3353
+ return valOuts$1.map((x) => new BatchTracer(this, x, null));
3354
+ }
3228
3355
  const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3229
3356
  return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3230
3357
  }
@@ -3232,24 +3359,28 @@ var BatchTrace = class extends Trace {
3232
3359
  return this.main.globalData;
3233
3360
  }
3234
3361
  };
3235
- function handleScalarBroadcasting(nd, x, d) {
3236
- if (d === null || nd === ndim$1(x)) return x;
3237
- else {
3238
- const axis = range(ndim$1(x), nd);
3239
- const shape$1 = [...x.shape, ...axis.map(() => 1)];
3240
- return broadcast(x, shape$1, axis);
3241
- }
3242
- }
3243
- /** Process a primitive with built-in broadcasting. */
3362
+ /**
3363
+ * Process a primitive with built-in broadcasting.
3364
+ *
3365
+ * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3366
+ */
3244
3367
  function broadcastBatcher(op) {
3245
3368
  return (axisSize, args, dims) => {
3246
3369
  if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3247
- const idx = dims.findIndex((d) => d !== null);
3248
- if (idx === -1) return [[op(...args)], [null]];
3249
- if (zip(args, dims).every(([x, d]) => ndim$1(x) === 0 || deepEqual(x.shape, args[idx].shape) && d === dims[idx])) return [[op(...args)], [dims[idx]]];
3250
- args = args.map((x, i) => ndim$1(x) > 0 ? moveBatchAxis(axisSize, dims[i], 0, x) : x);
3251
- const nd = Math.max(...args.map(ndim$1));
3252
- args = args.map((x, i) => handleScalarBroadcasting(nd, x, dims[i]));
3370
+ const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3371
+ const firstIdx = dims.findIndex((d) => d !== null);
3372
+ const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3373
+ if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3374
+ args = args.map((x, i) => {
3375
+ if (dims[i] === null) return x;
3376
+ x = moveBatchAxis(axisSize, dims[i], 0, x);
3377
+ if (x.ndim < nd) x = x.reshape([
3378
+ x.shape[0],
3379
+ ...rep(nd - x.ndim, 1),
3380
+ ...x.shape.slice(1)
3381
+ ]);
3382
+ return x;
3383
+ });
3253
3384
  return [[op(...args)], [0]];
3254
3385
  };
3255
3386
  }
@@ -3273,17 +3404,18 @@ const vmapRules = {
3273
3404
  [Primitive.Atan]: unopBatcher(atan$1),
3274
3405
  [Primitive.Exp]: unopBatcher(exp$1),
3275
3406
  [Primitive.Log]: unopBatcher(log$1),
3407
+ [Primitive.Erf]: unopBatcher(erf$1),
3408
+ [Primitive.Erfc]: unopBatcher(erfc$1),
3276
3409
  [Primitive.Sqrt]: unopBatcher(sqrt$1),
3277
3410
  [Primitive.Min]: broadcastBatcher(min$1),
3278
3411
  [Primitive.Max]: broadcastBatcher(max$1),
3279
3412
  [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3280
- if (xBdim === null) return [[reduce(x, op, axis)], [null]];
3413
+ assertNonNull(xBdim);
3281
3414
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3282
3415
  const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3283
3416
  return [[reduce(x, op, newAxis)], [outBdim]];
3284
3417
  },
3285
3418
  [Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
3286
- if (xBdim === null && yBdim === null) return [[dot$1(x, y)], [null]];
3287
3419
  x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
3288
3420
  y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
3289
3421
  const z = dot$1(x, y);
@@ -3292,26 +3424,68 @@ const vmapRules = {
3292
3424
  [Primitive.Compare](axisSize, args, dims, { op }) {
3293
3425
  return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3294
3426
  },
3427
+ [Primitive.Where]: broadcastBatcher(where$1),
3428
+ [Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
3429
+ assertNonNull(xBdim);
3430
+ const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
3431
+ newPerm.splice(xBdim, 0, xBdim);
3432
+ return [[transpose$1(x, newPerm)], [xBdim]];
3433
+ },
3434
+ [Primitive.Broadcast](axisSize, [x], [xBdim], { shape: shape$1, axis }) {
3435
+ assertNonNull(xBdim);
3436
+ const newShape = shape$1.toSpliced(xBdim, 0, axisSize);
3437
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3438
+ return [[broadcast(x, newShape, newAxis)], [xBdim]];
3439
+ },
3295
3440
  [Primitive.Reshape](axisSize, [x], [xBdim], { shape: shape$1 }) {
3296
- if (xBdim === null) return [[reshape$1(x, shape$1)], [null]];
3297
3441
  x = moveBatchAxis(axisSize, xBdim, 0, x);
3298
3442
  return [[reshape$1(x, [axisSize, ...shape$1])], [0]];
3299
3443
  },
3300
3444
  [Primitive.Flip](axisSize, [x], [xBdim], { axis }) {
3301
- if (xBdim === null) return [[flip$1(x, axis)], [null]];
3445
+ assertNonNull(xBdim);
3302
3446
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3303
3447
  return [[flip$1(x, newAxis)], [xBdim]];
3304
3448
  },
3305
3449
  [Primitive.Shrink](axisSize, [x], [xBdim], { slice }) {
3306
- if (xBdim === null) return [[shrink(x, slice)], [null]];
3450
+ assertNonNull(xBdim);
3307
3451
  const newSlice = slice.toSpliced(xBdim, 0, [0, axisSize]);
3308
3452
  return [[shrink(x, newSlice)], [xBdim]];
3309
3453
  },
3310
3454
  [Primitive.Pad](axisSize, [x], [xBdim], { width }) {
3311
- if (xBdim === null) return [[pad$1(x, width)], [null]];
3455
+ assertNonNull(xBdim);
3312
3456
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3313
3457
  return [[pad$1(x, newWidth)], [xBdim]];
3314
3458
  },
3459
+ [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3460
+ if (indicesBdim.every((d) => d === null)) {
3461
+ assertNonNull(xBdim);
3462
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3463
+ let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3464
+ let newOutDim = outDim;
3465
+ if (newOutDim < newBdim) newBdim += axis.length;
3466
+ else newOutDim += 1;
3467
+ return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3468
+ }
3469
+ const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3470
+ indices = indices.map((m, i) => {
3471
+ if (indicesBdim[i] === null) return m;
3472
+ m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3473
+ if (m.ndim < nd) m = m.reshape([
3474
+ m.shape[0],
3475
+ ...rep(nd - m.ndim, 1),
3476
+ ...m.shape.slice(1)
3477
+ ]);
3478
+ return m;
3479
+ });
3480
+ if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3481
+ else {
3482
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3483
+ const newAxis = [0, ...axis.map((ax) => ax + 1)];
3484
+ const extraBatchIndex = arange(axisSize).reshape([-1, ...rep(nd - 1, 1)]);
3485
+ indices.splice(0, 0, extraBatchIndex);
3486
+ return [[gather(x, indices, newAxis, outDim)], [outDim]];
3487
+ }
3488
+ },
3315
3489
  [Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
3316
3490
  const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
3317
3491
  const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
@@ -3371,12 +3545,14 @@ function vmapFlat(f, inAxes, args) {
3371
3545
  function vmap$1(f, inAxes = 0) {
3372
3546
  return (...args) => {
3373
3547
  const [argsFlat, inTree] = flatten(args);
3374
- let inAxesFlat;
3548
+ let inAxesFlat = [];
3375
3549
  if (typeof inAxes === "number") inAxesFlat = rep(argsFlat.length, inAxes);
3550
+ else for (let i = 0; i < args.length; i++) if (inAxes[i] == null) inAxesFlat.push(...rep(inTree.childTreedefs[i].size, null));
3551
+ else if (typeof inAxes[i] === "number") inAxesFlat.push(...rep(inTree.childTreedefs[i].size, inAxes[i]));
3376
3552
  else {
3377
- let inTree2;
3378
- [inAxesFlat, inTree2] = flatten(inAxes);
3379
- if (!inTree.equals(inTree2)) throw new TreeMismatchError("vmap", inTree, inTree2);
3553
+ const [axesFlat, axesTreeDef] = flatten(inAxes[i]);
3554
+ if (!inTree.childTreedefs[i].equals(axesTreeDef)) throw new TreeMismatchError("vmap", inTree.childTreedefs[i], axesTreeDef);
3555
+ inAxesFlat.push(...axesFlat);
3380
3556
  }
3381
3557
  const [fFlat, outTree] = flattenFun(f, inTree);
3382
3558
  const outsFlat = vmapFlat(fFlat, inAxesFlat, argsFlat);
@@ -3996,7 +4172,7 @@ function valueAndGrad$1(f) {
3996
4172
  const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
3997
4173
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
3998
4174
  if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
3999
- const [ct, ...rest] = fVjp(array(1, { dtype: y.dtype }));
4175
+ const [ct, ...rest] = fVjp(onesLike$1(y.ref));
4000
4176
  for (const r of rest) dispose(r);
4001
4177
  fVjp.dispose();
4002
4178
  return [y, ct];
@@ -4024,7 +4200,10 @@ __export(lax_exports, {
4024
4200
  conv: () => conv$1,
4025
4201
  convGeneralDilated: () => convGeneralDilated,
4026
4202
  convWithGeneralPadding: () => convWithGeneralPadding,
4027
- reduceWindow: () => reduceWindow
4203
+ erf: () => erf,
4204
+ erfc: () => erfc,
4205
+ reduceWindow: () => reduceWindow,
4206
+ stopGradient: () => stopGradient$1
4028
4207
  });
4029
4208
  function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
4030
4209
  const padType = padding.toUpperCase();
@@ -4083,6 +4262,28 @@ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
4083
4262
  strides: windowStrides
4084
4263
  }));
4085
4264
  }
4265
+ /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
4266
+ function erf(x) {
4267
+ return erf$1(x);
4268
+ }
4269
+ /**
4270
+ * The complementary error function: `erfc(x) = 1 - erf(x)`.
4271
+ *
4272
+ * This function is more accurate than `1 - erf(x)` for large values of `x`,
4273
+ * where `erf(x)` is very close to 1.
4274
+ */
4275
+ function erfc(x) {
4276
+ return erfc$1(x);
4277
+ }
4278
+ /**
4279
+ * Stops gradient computation.
4280
+ *
4281
+ * Behaves as the identity function but prevents the flow of gradients during
4282
+ * forward or reverse-mode automatic differentiation.
4283
+ */
4284
+ function stopGradient$1(x) {
4285
+ return stopGradient(x);
4286
+ }
4086
4287
 
4087
4288
  //#endregion
4088
4289
  //#region src/numpy.ts
@@ -4145,6 +4346,9 @@ __export(numpy_exports, {
4145
4346
  fullLike: () => fullLike$1,
4146
4347
  greater: () => greater,
4147
4348
  greaterEqual: () => greaterEqual,
4349
+ hamming: () => hamming,
4350
+ hann: () => hann,
4351
+ heaviside: () => heaviside,
4148
4352
  hstack: () => hstack,
4149
4353
  hypot: () => hypot,
4150
4354
  identity: () => identity$1,
@@ -4784,6 +4988,32 @@ function sign(x) {
4784
4988
  x = fudgeArray(x);
4785
4989
  return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
4786
4990
  }
4991
+ /**
4992
+ * Return the Hamming window of size M, a taper with a weighted cosine bell.
4993
+ *
4994
+ * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
4995
+ */
4996
+ function hamming(M) {
4997
+ return cos(linspace(0, 2 * Math.PI, M)).mul(-.46).add(.54);
4998
+ }
4999
+ /**
5000
+ * Return the Hann window of size M, a taper with a weighted cosine bell.
5001
+ *
5002
+ * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
5003
+ */
5004
+ function hann(M) {
5005
+ return cos(linspace(0, 2 * Math.PI, M)).mul(-.5).add(.5);
5006
+ }
5007
+ /**
5008
+ * @function
5009
+ * Compute the Heaviside step function. It is defined piecewise:
5010
+ * - `heaviside(x1, x2) = 0` for `x1 < 0`,
5011
+ * - `heaviside(x1, x2) = x2` for `x1 == 0`,
5012
+ * - `heaviside(x1, x2) = 1` for `x1 > 0`.
5013
+ */
5014
+ const heaviside = jit$1(function heaviside$1(x1, x2) {
5015
+ return where(less(x1.ref, 0), 0, where(equal(x1, 0), x2, 1));
5016
+ });
4787
5017
  /** Calculate element-wise square of the input array. */
4788
5018
  function square(x) {
4789
5019
  x = fudgeArray(x);
@@ -4803,8 +5033,8 @@ function acos(x) {
4803
5033
  * Return element-wise hypotenuse for the given legs of a right triangle.
4804
5034
  *
4805
5035
  * In the original NumPy/JAX implementation, this function is more numerically
4806
- * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
4807
- * improvements.
5036
+ * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
5037
+ * stability improvements.
4808
5038
  */
4809
5039
  const hypot = jit$1(function hypot$1(x1, x2) {
4810
5040
  return sqrt(square(x1).add(square(x2)));
@@ -5128,18 +5358,20 @@ function celu(x, alpha = 1) {
5128
5358
  * @function
5129
5359
  * Gaussian error linear unit (GELU) activation function.
5130
5360
  *
5131
- * This is computed element-wise. Currently jax-js does not support the erf() or
5132
- * gelu() functions exactly as primitives, so an approximation is used:
5133
- * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
5361
+ * This is computed element-wise. There are two variants depending on whether
5362
+ * `approximate` is set (default true):
5134
5363
  *
5135
- * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
5364
+ * - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
5365
+ * - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
5136
5366
  *
5137
- * This will be improved in the future.
5367
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
5138
5368
  */
5139
- const gelu = jit$1(function gelu$1(x) {
5140
- const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
5141
- return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
5142
- });
5369
+ const gelu = jit$1(function gelu$1(x, opts) {
5370
+ if (opts?.approximate ?? true) {
5371
+ const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
5372
+ return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
5373
+ } else return x.ref.mul(.5).mul(erfc$1(negative(x.ref.mul(Math.SQRT1_2))));
5374
+ }, { staticArgnums: [1] });
5143
5375
  /**
5144
5376
  * Gated linear unit (GLU) activation function.
5145
5377
  *
@@ -5360,6 +5592,25 @@ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
5360
5592
  return radius.mul(cos(theta));
5361
5593
  }, { staticArgnums: [1] });
5362
5594
 
5595
+ //#endregion
5596
+ //#region src/scipy-special.ts
5597
+ var scipy_special_exports = {};
5598
+ __export(scipy_special_exports, {
5599
+ erf: () => erf,
5600
+ erfc: () => erfc,
5601
+ logSoftmax: () => logSoftmax,
5602
+ logit: () => logit,
5603
+ logsumexp: () => logsumexp,
5604
+ softmax: () => softmax
5605
+ });
5606
+ /**
5607
+ * @function
5608
+ * The logit function, `logit(p) = log(p / (1-p))`.
5609
+ */
5610
+ const logit = jit$1(function logit$1(x) {
5611
+ return log(x.ref.div(subtract(1, x)));
5612
+ });
5613
+
5363
5614
  //#endregion
5364
5615
  //#region src/polyfills.ts
5365
5616
  /** @file Polyfills for using this library. */
@@ -5453,6 +5704,25 @@ async function blockUntilReady(x) {
5453
5704
  await Promise.all(promises);
5454
5705
  return x;
5455
5706
  }
5707
+ /**
5708
+ * Transfer `x` to `device`.
5709
+ *
5710
+ * `x` may be a nested container of arrays or scalars. The resulting structure
5711
+ * is committed to the device.
5712
+ *
5713
+ * If `device` is not specified, this function behaves as identity if the input
5714
+ * is already an `Array`, otherwise it places the scalar uncommitted on the
5715
+ * default device.
5716
+ */
5717
+ async function devicePut(x, device) {
5718
+ const [xflat, structure$1] = flatten(x);
5719
+ const yflat = await Promise.all(xflat.map((leaf) => {
5720
+ if (leaf instanceof Array$1) return device ? leaf._put(getBackend(device)) : Promise.resolve(leaf);
5721
+ else return Promise.resolve(array(leaf, { device }));
5722
+ }));
5723
+ return unflatten(structure$1, yflat);
5724
+ }
5456
5725
 
5457
5726
  //#endregion
5458
- export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
5727
+ export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
5728
+ //# sourceMappingURL=index.js.map