@jax-js/jax 0.1.2 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,28 +1,36 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-BqymqzuU.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-BY8wlLEl.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
6
6
  * Check that the shapes and parameters passed to convolution are valid.
7
+ * Expected shapes of the lhs and rhs of the convolution are:
8
+ *
9
+ * - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
10
+ * - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
7
11
  *
8
12
  * If the check succeeds, returns the output shape.
9
13
  */
10
- function checkConvShape(lhsShape, rhsShape, { strides, padding, lhsDilation, rhsDilation }) {
14
+ function checkConvShape(lhsShape, rhsShape, { vmapDims, strides, padding, lhsDilation, rhsDilation }) {
11
15
  if (lhsShape.length !== rhsShape.length) throw new Error(`conv() requires inputs with the same number of dimensions, got ${lhsShape.length} and ${rhsShape.length}`);
12
- const n = lhsShape.length - 2;
16
+ const n = lhsShape.length - 2 - vmapDims;
13
17
  if (n < 0) throw new Error("conv() requires at least 2D inputs");
14
18
  if (strides.length !== n) throw new Error("conv() strides != spatial dims");
15
19
  if (padding.length !== n) throw new Error("conv() padding != spatial dims");
16
20
  if (lhsDilation.length !== n) throw new Error("conv() lhsDilation != spatial dimensions");
17
21
  if (rhsDilation.length !== n) throw new Error("conv() rhsDilation != spatial dimensions");
18
- if (lhsShape[1] !== rhsShape[1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
19
- const outShape = [lhsShape[0], rhsShape[0]];
22
+ if (lhsShape[vmapDims + 1] !== rhsShape[vmapDims + 1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
23
+ const outShape = [
24
+ ...generalBroadcast(lhsShape.slice(0, vmapDims), rhsShape.slice(0, vmapDims)),
25
+ lhsShape[vmapDims],
26
+ rhsShape[vmapDims]
27
+ ];
20
28
  for (let i = 0; i < n; i++) {
21
29
  if (strides[i] <= 0 || !Number.isInteger(strides[i])) throw new Error(`conv() strides[${i}] must be a positive integer`);
22
30
  if (padding[i].length !== 2 || !padding[i].every(Number.isInteger)) throw new Error(`conv() padding[${i}] must be a 2-tuple of integers`);
23
31
  if (lhsDilation[i] <= 0 || !Number.isInteger(lhsDilation[i])) throw new Error(`conv() lhsDilation[${i}] must be a positive integer`);
24
32
  if (rhsDilation[i] <= 0 || !Number.isInteger(rhsDilation[i])) throw new Error(`conv() rhsDilation[${i}] must be a positive integer`);
25
- const [x, k] = [lhsShape[i + 2], rhsShape[i + 2]];
33
+ const [x, k] = [lhsShape[i + vmapDims + 2], rhsShape[i + vmapDims + 2]];
26
34
  if (k <= 0) throw new Error("conv() kernel size must be positive");
27
35
  const [pl, pr] = padding[i];
28
36
  if (pl < -x || pr < -x || pl + pr < -x) throw new Error(`conv() padding[${i}]=(${pl},${pr}) is too negative for input size ${x}`);
@@ -147,27 +155,13 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
147
155
  function applyDilation(st, dilation) {
148
156
  if (dilation.every((s) => s === 1)) return st;
149
157
  const s_ = dilation;
150
- const [a, b, ...k_] = st.shape;
151
- st = st.reshape([
152
- a,
153
- b,
154
- ...k_.flatMap((k) => [k, 1])
155
- ]);
156
- st = st.pad([
157
- [0, 0],
158
- [0, 0],
159
- ...s_.flatMap((s) => [[0, 0], [0, s - 1]])
160
- ]);
161
- st = st.reshape([
162
- a,
163
- b,
164
- ...k_.map((k, i) => k * s_[i])
165
- ]);
166
- st = st.shrink([
167
- [0, a],
168
- [0, b],
169
- ...k_.map((k, i) => [0, (k - 1) * s_[i] + 1])
170
- ]);
158
+ const n = s_.length;
159
+ const prefix = st.shape.slice(0, -n);
160
+ const k_ = st.shape.slice(-n);
161
+ st = st.reshape([...prefix, ...k_.flatMap((k) => [k, 1])]);
162
+ st = st.pad([...prefix.map(() => [0, 0]), ...s_.flatMap((s) => [[0, 0], [0, s - 1]])]);
163
+ st = st.reshape([...prefix, ...k_.map((k, i) => k * s_[i])]);
164
+ st = st.shrink([...prefix.map((p) => [0, p]), ...k_.map((k, i) => [0, (k - 1) * s_[i] + 1])]);
171
165
  return st;
172
166
  }
173
167
  /**
@@ -177,25 +171,26 @@ function applyDilation(st, dilation) {
177
171
  * beforehand using `checkConvShape()`.
178
172
  */
179
173
  function prepareConv(stX, stY, params) {
180
- const n = stX.shape.length - 2;
174
+ const v = params.vmapDims;
175
+ const n = stX.shape.length - 2 - v;
176
+ const vmapShape = stX.shape.slice(0, v);
181
177
  stX = applyDilation(stX, params.lhsDilation);
182
- const ks = stY.shape.slice(2);
183
- stX = stX.padOrShrink([
184
- [0, 0],
185
- [0, 0],
186
- ...params.padding
187
- ]);
178
+ const ks = stY.shape.slice(v + 2);
179
+ stX = stX.padOrShrink([...rep(v + 2, [0, 0]), ...params.padding]);
188
180
  stX = pool(stX, ks, params.strides, params.rhsDilation);
189
- stX = stX.moveaxis(1, n + 1).reshape([
190
- stX.shape[0],
181
+ stX = stX.moveaxis(v + 1, v + n + 1).reshape([
182
+ ...vmapShape,
183
+ stX.shape[v],
191
184
  1,
192
- ...stX.shape.slice(2, n + 2),
193
- stX.shape[1] * prod(ks)
185
+ ...stX.shape.slice(v + 2, v + n + 2),
186
+ stX.shape[v + 1] * prod(ks)
194
187
  ]);
195
188
  stY = stY.reshape([
196
- stY.shape[0],
189
+ ...vmapShape,
190
+ 1,
191
+ stY.shape[v],
197
192
  ...rep(n, 1),
198
- stY.shape[1] * prod(ks)
193
+ stY.shape[v + 1] * prod(ks)
199
194
  ]);
200
195
  return [stX, stY];
201
196
  }
@@ -467,9 +462,11 @@ function dot$2(x, y) {
467
462
  }
468
463
  function conv$1(x, y, params = {}) {
469
464
  if (x.ndim !== y.ndim) throw new Error(`conv() requires inputs with the same number of dimensions, got ${x.ndim} and ${y.ndim}`);
470
- const n = x.ndim - 2;
465
+ const vmapDims = params.vmapDims ?? 0;
466
+ const n = x.ndim - 2 - vmapDims;
471
467
  if (n < 0) throw new Error("conv() requires at least 2D inputs");
472
468
  return bind1(Primitive.Conv, [x, y], {
469
+ vmapDims,
473
470
  strides: params.strides ?? rep(n, 1),
474
471
  padding: params.padding ?? rep(n, [0, 0]),
475
472
  lhsDilation: params.lhsDilation ?? rep(n, 1),
@@ -693,8 +690,10 @@ var Tracer = class Tracer {
693
690
  axis = normalizeAxis(axis, this.ndim);
694
691
  const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
695
692
  if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
696
- const result = reduce(this, AluOp.Add, axis, opts);
697
- return result.mul(1 / n);
693
+ const originalDtype = this.dtype;
694
+ const castDtype = promoteTypes(originalDtype, DType.Float32);
695
+ const result = reduce(this.astype(castDtype), AluOp.Add, axis, opts);
696
+ return result.mul(1 / n).astype(originalDtype);
698
697
  }
699
698
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
700
699
  transpose(perm) {
@@ -1170,7 +1169,7 @@ var Jaxpr = class Jaxpr {
1170
1169
  } else if (eqn.primitive === Primitive.Idiv) {
1171
1170
  const [a, b] = inputs;
1172
1171
  const c = eqn.outBinders[0];
1173
- if (atomIsLit(b, 1)) context.set(c, a);
1172
+ if (atomIsLit(b, 1) && !isFloatDtype(a.aval.dtype)) context.set(c, a);
1174
1173
  else newEqns.push(eqn);
1175
1174
  } else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
1176
1175
  else newEqns.push(eqn);
@@ -1755,48 +1754,73 @@ function jitCompile(backend, jaxpr, consts) {
1755
1754
  const inputExps = [];
1756
1755
  const inputAvals = [];
1757
1756
  const inputArgs = [];
1758
- for (const input of eqn.inputs) if (input instanceof Var) {
1759
- const jitValue = ctx.get(input);
1760
- if (jitValue.type === "exp") {
1761
- const gidMap = /* @__PURE__ */ new Map();
1762
- for (const [gid, jitId] of jitValue.args.entries()) {
1763
- let newGid = inputArgs.indexOf(jitId);
1764
- if (newGid === -1) {
1765
- newGid = inputArgs.length;
1766
- inputArgs.push(jitId);
1767
- }
1768
- gidMap.set(gid, newGid);
1769
- }
1770
- inputExps.push(jitValue.exp.reindexGids(gidMap));
1771
- } else if (jitValue.type === "imm") {
1772
- let gid = inputArgs.indexOf(jitValue.arg);
1773
- if (gid === -1) {
1774
- gid = inputArgs.length;
1775
- inputArgs.push(jitValue.arg);
1757
+ let inputReduction = null;
1758
+ const addArgs = (args) => {
1759
+ const newGids = [];
1760
+ for (const jitId of args) {
1761
+ let newGid = inputArgs.indexOf(jitId);
1762
+ if (newGid === -1) {
1763
+ newGid = inputArgs.length;
1764
+ inputArgs.push(jitId);
1776
1765
  }
1766
+ newGids.push(newGid);
1767
+ }
1768
+ return newGids;
1769
+ };
1770
+ for (const input of eqn.inputs) if (input instanceof Var) {
1771
+ const jv = ctx.get(input);
1772
+ if (jv.type === "exp") {
1773
+ const newGids = addArgs(jv.args);
1774
+ inputExps.push(jv.exp.reindexGids(newGids));
1775
+ } else if (jv.type === "imm") {
1776
+ const [gid] = addArgs([jv.arg]);
1777
1777
  const st = ShapeTracker.fromShape(input.aval.shape);
1778
1778
  const indices = unravelAlu(st.shape, AluVar.gidx);
1779
1779
  inputExps.push(AluExp.globalView(input.aval.dtype, gid, st, indices));
1780
+ } else if (jv.type === "red") {
1781
+ if (inputReduction) throw new Error("jit: unexpected, multiple red inputs");
1782
+ const newGids = addArgs(jv.args);
1783
+ inputExps.push(jv.reduction.epilogue.reindexGids(newGids));
1784
+ inputReduction = jv;
1780
1785
  }
1781
1786
  inputAvals.push(input.aval);
1782
1787
  } else if (input instanceof Lit) {
1783
1788
  inputExps.push(AluExp.const(input.dtype, input.value));
1784
1789
  inputAvals.push(input.aval);
1785
1790
  } else throw new TypeError(`Unexpected input in Jaxpr: ${input}`);
1786
- const nargs$1 = inputArgs.length;
1787
1791
  const rule = jitRules[eqn.primitive];
1788
1792
  if (!rule) throw new TypeError(`JIT not implemented for primitive ${eqn.primitive}`);
1789
- const kernel = rule(nargs$1, inputExps, inputAvals, eqn.params);
1793
+ let exp$2;
1794
+ let reduction;
1795
+ if (inputReduction) {
1796
+ const jv = inputReduction;
1797
+ const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
1798
+ exp$2 = jv.exp.reindexGids(addArgs(jv.args));
1799
+ reduction = new Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
1800
+ } else {
1801
+ const ruleOutput = rule(inputExps, inputAvals, eqn.params);
1802
+ exp$2 = ruleOutput.exp;
1803
+ reduction = ruleOutput.reduction;
1804
+ }
1790
1805
  const outVar = eqn.outBinders[0];
1791
- if (kernel.reduction || blackNodes.has(outVar)) {
1806
+ if (blackNodes.has(outVar)) {
1807
+ const nargs$1 = inputArgs.length;
1808
+ const size$1 = prod(outVar.aval.shape);
1809
+ const kernel = new Kernel(nargs$1, size$1, exp$2, reduction);
1792
1810
  const outId = builder.pushKernel(kernel, inputArgs);
1793
1811
  ctx.set(outVar, {
1794
1812
  type: "imm",
1795
1813
  arg: outId
1796
1814
  });
1797
- } else ctx.set(outVar, {
1815
+ } else if (reduction) ctx.set(outVar, {
1816
+ type: "red",
1817
+ exp: exp$2,
1818
+ reduction,
1819
+ args: inputArgs
1820
+ });
1821
+ else ctx.set(outVar, {
1798
1822
  type: "exp",
1799
- exp: kernel.exp,
1823
+ exp: exp$2,
1800
1824
  args: inputArgs
1801
1825
  });
1802
1826
  }
@@ -1828,31 +1852,28 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
1828
1852
  });
1829
1853
  }
1830
1854
  function broadcastedJit(fn, opts) {
1831
- return (nargs, exps, avals, params) => {
1855
+ return (exps, avals, params) => {
1832
1856
  let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
1833
1857
  const skipCastIdx = opts?.skipCastIdx ?? [];
1834
1858
  if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
1835
- exps = exps.map((exp$3, i) => {
1836
- exp$3 = reshapeViews(exp$3, (st) => {
1859
+ exps = exps.map((exp$2, i) => {
1860
+ exp$2 = reshapeViews(exp$2, (st) => {
1837
1861
  if (!deepEqual(st.shape, newShape)) return st.broadcast(newShape, range(newShape.length - st.shape.length));
1838
1862
  });
1839
- if (exp$3.dtype !== newDtype && !skipCastIdx.includes(i)) exp$3 = AluExp.cast(newDtype, exp$3);
1840
- return exp$3;
1863
+ if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = AluExp.cast(newDtype, exp$2);
1864
+ return exp$2;
1841
1865
  });
1842
- const exp$2 = fn(exps, params);
1843
- return new Kernel(nargs, prod(newShape), exp$2);
1866
+ return { exp: fn(exps, params) };
1844
1867
  };
1845
1868
  }
1846
1869
  function unopJit(fn) {
1847
- return (nargs, [a], [as], params) => {
1848
- return new Kernel(nargs, prod(as.shape), fn(a, params));
1870
+ return ([a], [_as], params) => {
1871
+ return { exp: fn(a, params) };
1849
1872
  };
1850
1873
  }
1851
1874
  function reshapeJit(fn) {
1852
- return (nargs, [a], [as], params) => {
1853
- a = reshapeViews(a, (st) => fn(st, params));
1854
- const newShape = fn(ShapeTracker.fromShape(as.shape), params).shape;
1855
- return new Kernel(nargs, prod(newShape), a);
1875
+ return ([a], [_as], params) => {
1876
+ return { exp: reshapeViews(a, (st) => fn(st, params)) };
1856
1877
  };
1857
1878
  }
1858
1879
  const jitRules = {
@@ -1867,7 +1888,7 @@ const jitRules = {
1867
1888
  [Primitive.StopGradient]: unopJit((a) => a),
1868
1889
  [Primitive.Cast]: unopJit((a, { dtype }) => AluExp.cast(dtype, a)),
1869
1890
  [Primitive.Bitcast]: unopJit((a, { dtype }) => AluExp.bitcast(dtype, a)),
1870
- [Primitive.RandomBits]: (nargs, keys, keyShapes, { shape: shape$1, mode }) => {
1891
+ [Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
1871
1892
  const mapping = (st) => {
1872
1893
  if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(shape$1.length - st.shape.length));
1873
1894
  };
@@ -1876,7 +1897,7 @@ const jitRules = {
1876
1897
  const c0 = AluExp.u32(0);
1877
1898
  const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
1878
1899
  const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
1879
- return new Kernel(nargs, prod(shape$1), exp$2);
1900
+ return { exp: exp$2 };
1880
1901
  },
1881
1902
  [Primitive.Sin]: unopJit(AluExp.sin),
1882
1903
  [Primitive.Cos]: unopJit(AluExp.cos),
@@ -1889,7 +1910,7 @@ const jitRules = {
1889
1910
  [Primitive.Sqrt]: unopJit(AluExp.sqrt),
1890
1911
  [Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
1891
1912
  [Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
1892
- [Primitive.Reduce](nargs, [a], [as], { op, axis }) {
1913
+ [Primitive.Reduce]([a], [as], { op, axis }) {
1893
1914
  const keptAxes = [];
1894
1915
  const shiftedAxes = [];
1895
1916
  const newShape = [];
@@ -1898,39 +1919,43 @@ const jitRules = {
1898
1919
  keptAxes.push(i);
1899
1920
  newShape.push(as.shape[i]);
1900
1921
  }
1901
- const size$1 = prod(newShape);
1902
1922
  const reductionSize = prod(shiftedAxes.map((ax) => as.shape[ax]));
1903
1923
  newShape.push(reductionSize);
1904
1924
  const perm = keptAxes.concat(shiftedAxes);
1905
1925
  a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
1906
1926
  const reduction = new Reduction(a.dtype, op, reductionSize);
1907
- return new Kernel(nargs, size$1, a, reduction);
1927
+ return {
1928
+ exp: a,
1929
+ reduction
1930
+ };
1908
1931
  },
1909
1932
  [Primitive.Pool]: reshapeJit((st, { window, strides }) => pool(st, window, strides)),
1910
- [Primitive.PoolTranspose](nargs, [a], [as], { inShape, window, strides }) {
1933
+ [Primitive.PoolTranspose]([a], [as], { inShape, window, strides }) {
1911
1934
  let stX = poolTranspose(ShapeTracker.fromShape(as.shape), inShape, window, strides);
1912
- const size$1 = prod(inShape);
1913
1935
  stX = stX.reshape([...inShape, prod(stX.shape.slice(inShape.length))]);
1914
1936
  a = reshapeViews(a, (st) => st.compose(stX), true);
1915
1937
  const reduction = new Reduction(a.dtype, AluOp.Add, stX.shape[stX.shape.length - 1]);
1916
- return new Kernel(nargs, size$1, a, reduction);
1938
+ return {
1939
+ exp: a,
1940
+ reduction
1941
+ };
1917
1942
  },
1918
- [Primitive.Dot](nargs, [a, b], [as, bs]) {
1919
- const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
1943
+ [Primitive.Dot]([a, b], [as, bs]) {
1944
+ const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
1920
1945
  const c = k1.exp;
1921
1946
  const cs = promoteAvals(as, bs);
1922
- return jitRules[Primitive.Reduce](nargs, [c], [cs], {
1947
+ return jitRules[Primitive.Reduce]([c], [cs], {
1923
1948
  op: AluOp.Add,
1924
1949
  axis: [cs.ndim - 1]
1925
1950
  });
1926
1951
  },
1927
- [Primitive.Conv](nargs, [a, b], [as, bs], params) {
1952
+ [Primitive.Conv]([a, b], [as, bs], params) {
1928
1953
  const [stX, stY] = prepareConv(ShapeTracker.fromShape(as.shape), ShapeTracker.fromShape(bs.shape), params);
1929
1954
  a = reshapeViews(a, (st) => st.compose(stX));
1930
1955
  b = reshapeViews(b, (st) => st.compose(stY));
1931
1956
  as = new ShapedArray(stX.shape, as.dtype, as.weakType);
1932
1957
  bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
1933
- return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
1958
+ return jitRules[Primitive.Dot]([a, b], [as, bs], {});
1934
1959
  },
1935
1960
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
1936
1961
  [Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
@@ -1944,7 +1969,7 @@ const jitRules = {
1944
1969
  }),
1945
1970
  [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
1946
1971
  [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
1947
- [Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
1972
+ [Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
1948
1973
  const axisSet = new Set(axis);
1949
1974
  const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
1950
1975
  const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
@@ -1957,7 +1982,7 @@ const jitRules = {
1957
1982
  for (const [i, iexp] of indices.entries()) src[axis[i]] = AluExp.cast(DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...range(outDim + indexShape.length - st.shape.length), ...range(outDim + indexShape.length, finalShape.length)])));
1958
1983
  const [index, valid] = ShapeTracker.fromShape(xs.shape).toAluExp(src);
1959
1984
  if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
1960
- return new Kernel(nargs, prod(finalShape), x.substitute({ gidx: index }));
1985
+ return { exp: x.substitute({ gidx: index }) };
1961
1986
  },
1962
1987
  [Primitive.JitCall]() {
1963
1988
  throw new Error("internal: JitCall should have been flattened before JIT compilation");
@@ -1965,16 +1990,16 @@ const jitRules = {
1965
1990
  };
1966
1991
  /** Determines how to split the Jaxpr into kernels via dataflow analysis. */
1967
1992
  function splitGraphDataflow(backend, jaxpr) {
1968
- const varToEqn = /* @__PURE__ */ new Map();
1993
+ const varToDefn = /* @__PURE__ */ new Map();
1994
+ const varToUsages = /* @__PURE__ */ new Map();
1969
1995
  for (let i = 0; i < jaxpr.eqns.length; i++) {
1970
1996
  const eqn = jaxpr.eqns[i];
1971
- for (const v of eqn.outBinders) if (v instanceof Var) varToEqn.set(v, i);
1972
- }
1973
- const blackNodes = /* @__PURE__ */ new Set();
1974
- const p1NextBlack = /* @__PURE__ */ new Map();
1975
- for (const v of jaxpr.outs) if (v instanceof Var) {
1976
- blackNodes.add(v);
1977
- p1NextBlack.set(v, v);
1997
+ for (const v of eqn.outBinders) if (v instanceof Var) varToDefn.set(v, i);
1998
+ for (const input of eqn.inputs) if (input instanceof Var) {
1999
+ const usages = varToUsages.get(input);
2000
+ if (usages) usages.push(i);
2001
+ else varToUsages.set(input, [i]);
2002
+ }
1978
2003
  }
1979
2004
  const reducePrimitives = [
1980
2005
  Primitive.Reduce,
@@ -1982,10 +2007,68 @@ function splitGraphDataflow(backend, jaxpr) {
1982
2007
  Primitive.Conv,
1983
2008
  Primitive.PoolTranspose
1984
2009
  ];
2010
+ const reductionEpilogueEqns = /* @__PURE__ */ new Set();
2011
+ const reductionEndpointEqns = /* @__PURE__ */ new Set();
2012
+ for (let i = 0; i < jaxpr.eqns.length; i++) {
2013
+ const eqn = jaxpr.eqns[i];
2014
+ if (reducePrimitives.includes(eqn.primitive)) {
2015
+ let head = i;
2016
+ while (true) {
2017
+ reductionEpilogueEqns.add(head);
2018
+ const outVar = jaxpr.eqns[head].outBinders[0];
2019
+ const usages = varToUsages.get(outVar) ?? [];
2020
+ if (jaxpr.outs.includes(outVar) || usages.length !== 1) break;
2021
+ if (reductionEpilogueEqns.has(usages[0])) break;
2022
+ const nextEqn = jaxpr.eqns[usages[0]];
2023
+ switch (nextEqn.primitive) {
2024
+ case Primitive.Neg:
2025
+ case Primitive.Reciprocal:
2026
+ case Primitive.Floor:
2027
+ case Primitive.Ceil:
2028
+ case Primitive.StopGradient:
2029
+ case Primitive.Cast:
2030
+ case Primitive.Bitcast:
2031
+ case Primitive.Sin:
2032
+ case Primitive.Cos:
2033
+ case Primitive.Asin:
2034
+ case Primitive.Atan:
2035
+ case Primitive.Exp:
2036
+ case Primitive.Log:
2037
+ case Primitive.Erf:
2038
+ case Primitive.Erfc:
2039
+ case Primitive.Sqrt:
2040
+ head = usages[0];
2041
+ continue;
2042
+ case Primitive.Add:
2043
+ case Primitive.Mul:
2044
+ case Primitive.Idiv:
2045
+ case Primitive.Mod:
2046
+ case Primitive.Max:
2047
+ case Primitive.Min: {
2048
+ const otherInput = nextEqn.inputs.find((v) => v !== outVar);
2049
+ if (otherInput instanceof Lit || deepEqual(generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
2050
+ head = usages[0];
2051
+ continue;
2052
+ }
2053
+ break;
2054
+ }
2055
+ }
2056
+ break;
2057
+ }
2058
+ reductionEndpointEqns.add(head);
2059
+ }
2060
+ }
2061
+ const blackNodes = /* @__PURE__ */ new Set();
2062
+ const p1NextBlack = /* @__PURE__ */ new Map();
2063
+ for (const v of jaxpr.outs) if (v instanceof Var) {
2064
+ blackNodes.add(v);
2065
+ p1NextBlack.set(v, v);
2066
+ }
1985
2067
  const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
2068
+ const needsCleanShapePrimitives = [Primitive.Pad];
1986
2069
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
1987
2070
  const eqn = jaxpr.eqns[i];
1988
- if (reducePrimitives.includes(eqn.primitive) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
2071
+ if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
1989
2072
  for (const v of eqn.outBinders) {
1990
2073
  blackNodes.add(v);
1991
2074
  p1NextBlack.set(v, v);
@@ -1993,17 +2076,25 @@ function splitGraphDataflow(backend, jaxpr) {
1993
2076
  continue;
1994
2077
  }
1995
2078
  const reach = /* @__PURE__ */ new Set();
1996
- for (let j = i + 1; j < jaxpr.eqns.length; j++) for (const v of jaxpr.eqns[j].inputs) if (v instanceof Var && eqn.outBinders.includes(v)) for (const o of jaxpr.eqns[j].outBinders) {
1997
- const u = p1NextBlack.get(o);
1998
- if (u) reach.add(u);
2079
+ let needsCleanOutput = false;
2080
+ outer: for (const v of eqn.outBinders) for (const j of varToUsages.get(v) ?? []) {
2081
+ if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive)) {
2082
+ needsCleanOutput = true;
2083
+ break outer;
2084
+ }
2085
+ for (const o of jaxpr.eqns[j].outBinders) {
2086
+ const u = p1NextBlack.get(o);
2087
+ if (u) reach.add(u);
2088
+ }
1999
2089
  }
2000
- if (reach.size === 1) {
2001
- const b = reach.values().next().value;
2002
- for (const v of eqn.outBinders) p1NextBlack.set(v, b);
2003
- } else if (reach.size > 1) for (const v of eqn.outBinders) {
2090
+ if (reach.size > 1 || needsCleanOutput) for (const v of eqn.outBinders) {
2004
2091
  blackNodes.add(v);
2005
2092
  p1NextBlack.set(v, v);
2006
2093
  }
2094
+ else if (reach.size === 1) {
2095
+ const b = reach.values().next().value;
2096
+ for (const v of eqn.outBinders) p1NextBlack.set(v, b);
2097
+ }
2007
2098
  }
2008
2099
  const p2Deps = /* @__PURE__ */ new Map();
2009
2100
  for (const v of jaxpr.inBinders) p2Deps.set(v, new Set([v]));
@@ -2022,7 +2113,7 @@ function splitGraphDataflow(backend, jaxpr) {
2022
2113
  let assocInput = -1;
2023
2114
  for (let i = 0; i < eqn.inputs.length; i++) {
2024
2115
  const input = eqn.inputs[i];
2025
- if (input instanceof Var && varToEqn.has(input)) {
2116
+ if (input instanceof Var && varToDefn.has(input)) {
2026
2117
  let uniqueDeps = 0;
2027
2118
  for (const dep of deps[i]) if (depCounter.get(dep) === 1) uniqueDeps++;
2028
2119
  if (uniqueDeps > maxUniqueDeps) {
@@ -2033,7 +2124,7 @@ function splitGraphDataflow(backend, jaxpr) {
2033
2124
  }
2034
2125
  if (assocInput === -1) throw new Error(`internal: maxArgs, no input found to mark as black in Jaxpr equation ${eqn}`);
2035
2126
  const assocVar = eqn.inputs[assocInput];
2036
- p2idx = varToEqn.get(assocVar);
2127
+ p2idx = varToDefn.get(assocVar);
2037
2128
  for (const out of jaxpr.eqns[p2idx].outBinders) blackNodes.add(out);
2038
2129
  } else {
2039
2130
  const s = new Set(depCounter.keys());
@@ -3460,6 +3551,15 @@ const vmapRules = {
3460
3551
  const z = dot$2(x, y);
3461
3552
  return [[z], [z.ndim - 1]];
3462
3553
  },
3554
+ [Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
3555
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3556
+ y = moveBatchAxis(axisSize, yBdim, 0, y);
3557
+ const z = conv$1(x, y, {
3558
+ ...params,
3559
+ vmapDims: params.vmapDims + 1
3560
+ });
3561
+ return [[z], [0]];
3562
+ },
3463
3563
  [Primitive.Compare](axisSize, args, dims, { op }) {
3464
3564
  return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3465
3565
  },
@@ -3904,7 +4004,7 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
3904
4004
  for (const t of tracersIn) t.dispose();
3905
4005
  for (const t of tracersOut) t.dispose();
3906
4006
  jaxpr = jaxpr.simplify();
3907
- if (DEBUG >= 5) console.log("jaxpr from partial evaluation:\n" + jaxpr.toString());
4007
+ if (DEBUG >= 5) console.info("jaxpr from partial evaluation:\n" + jaxpr.toString());
3908
4008
  return {
3909
4009
  jaxpr,
3910
4010
  consts
@@ -4038,22 +4138,25 @@ const transposeRules = {
4038
4138
  },
4039
4139
  [Primitive.Conv]([ct], [lhs, rhs], params) {
4040
4140
  if (lhs instanceof UndefPrimal === rhs instanceof UndefPrimal) throw new NonlinearError(Primitive.Conv);
4141
+ const v = params.vmapDims;
4041
4142
  const rev01 = [
4042
- 1,
4043
- 0,
4044
- ...range(2, ct.ndim)
4143
+ ...range(v),
4144
+ v + 1,
4145
+ v,
4146
+ ...range(v + 2, ct.ndim)
4045
4147
  ];
4046
4148
  if (lhs instanceof UndefPrimal) {
4047
4149
  let kernel = rhs;
4048
4150
  kernel = transpose$1(kernel, rev01);
4049
- kernel = flip$1(kernel, range(2, kernel.ndim));
4151
+ kernel = flip$1(kernel, range(v + 2, kernel.ndim));
4050
4152
  const result = conv$1(ct, kernel, {
4153
+ vmapDims: v,
4051
4154
  strides: params.lhsDilation,
4052
4155
  padding: params.padding.map(([pl, _pr], i) => {
4053
- const dilatedKernel = (kernel.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
4054
- const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
4156
+ const dilatedKernel = (kernel.shape[i + v + 2] - 1) * params.rhsDilation[i] + 1;
4157
+ const dilatedCt = (ct.shape[i + v + 2] - 1) * params.strides[i] + 1;
4055
4158
  const padBefore = dilatedKernel - 1 - pl;
4056
- const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
4159
+ const dilatedLhs = (lhs.aval.shape[i + v + 2] - 1) * params.lhsDilation[i] + 1;
4057
4160
  const padAfter = dilatedLhs + dilatedKernel - 1 - dilatedCt - padBefore;
4058
4161
  return [padBefore, padAfter];
4059
4162
  }),
@@ -4065,11 +4168,12 @@ const transposeRules = {
4065
4168
  const newLhs = transpose$1(lhs, rev01);
4066
4169
  const newRhs = transpose$1(ct, rev01);
4067
4170
  let result = conv$1(newLhs, newRhs, {
4171
+ vmapDims: v,
4068
4172
  strides: params.rhsDilation,
4069
4173
  padding: params.padding.map(([pl, _pr], i) => {
4070
- const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
4071
- const dilatedKernel = (rhs.aval.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
4072
- const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
4174
+ const dilatedLhs = (lhs.aval.shape[i + v + 2] - 1) * params.lhsDilation[i] + 1;
4175
+ const dilatedKernel = (rhs.aval.shape[i + v + 2] - 1) * params.rhsDilation[i] + 1;
4176
+ const dilatedCt = (ct.shape[i + v + 2] - 1) * params.strides[i] + 1;
4073
4177
  const padFromLhs = dilatedCt - dilatedLhs;
4074
4178
  const padFromRhs = dilatedKernel - pl - 1;
4075
4179
  return [pl, padFromLhs + padFromRhs];
@@ -4318,13 +4422,46 @@ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
4318
4422
  *
4319
4423
  * Grouped convolutions are not supported right now.
4320
4424
  */
4321
- function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation } = {}) {
4425
+ function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
4322
4426
  if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
4323
4427
  if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
4324
4428
  if (typeof padding === "string") {
4325
4429
  if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
4326
4430
  padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? rep(rhs.ndim - 2, 1), padding);
4327
4431
  }
4432
+ if (featureGroupCount !== 1) {
4433
+ const G = featureGroupCount;
4434
+ const [N, C_in, ...xs] = lhs.shape;
4435
+ const [C_out, C_in_per_group, ...ks] = rhs.shape;
4436
+ if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
4437
+ if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
4438
+ if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
4439
+ const lhsGrouped = moveaxis(lhs.reshape([
4440
+ N,
4441
+ G,
4442
+ C_in / G,
4443
+ ...xs
4444
+ ]), 1, 0);
4445
+ const rhsGrouped = rhs.reshape([
4446
+ G,
4447
+ C_out / G,
4448
+ C_in_per_group,
4449
+ ...ks
4450
+ ]);
4451
+ const result = conv$1(lhsGrouped, rhsGrouped, {
4452
+ vmapDims: 1,
4453
+ strides: windowStrides,
4454
+ padding,
4455
+ lhsDilation,
4456
+ rhsDilation
4457
+ });
4458
+ const ys = result.shape.slice(3);
4459
+ return moveaxis(result, 0, 1).reshape([
4460
+ N,
4461
+ C_out,
4462
+ ...ys
4463
+ ]);
4464
+ }
4328
4465
  return conv$1(lhs, rhs, {
4329
4466
  strides: windowStrides,
4330
4467
  padding,
@@ -4610,6 +4747,8 @@ __export(numpy_exports, {
4610
4747
  concatenate: () => concatenate,
4611
4748
  cos: () => cos,
4612
4749
  cosh: () => cosh,
4750
+ cumsum: () => cumsum,
4751
+ cumulativeSum: () => cumulativeSum,
4613
4752
  deg2rad: () => deg2rad,
4614
4753
  degrees: () => degrees,
4615
4754
  diag: () => diag,
@@ -4918,6 +5057,25 @@ function argmax(a, axis, opts) {
4918
5057
  }).reshape([shape$1[axis], ...rep(shape$1.length - axis - 1, 1)]));
4919
5058
  return length.sub(max(idx, axis, opts));
4920
5059
  }
5060
/**
 * Cumulative sum of the elements of `a` along `axis`.
 *
 * When `axis` is omitted the input is flattened with `ravel()` first and the
 * running sum is taken over the flattened array.
 *
 * The current implementation is `O(n^2)` in the length of the axis; a
 * two-phase parallel reduction is planned to replace it.
 */
function cumsum(a, axis) {
  let arr = fudgeArray(a);
  let ax;
  if (axis === void 0) {
    arr = arr.ravel();
    ax = 0;
  } else {
    ax = checkAxis(axis, arr.ndim);
  }
  const length = arr.shape[ax];
  // Work on the last axis so the triangular mask lines up with it.
  const moved = moveaxis$1(arr, ax, -1);
  // Target shape ends in (length, length); the shape is read before `moved`
  // is handed to `broadcast`. NOTE(review): presumably `[-2]` marks the
  // newly inserted axis, so each row repeats the original vector - confirm.
  const pairShape = [...moved.shape, length];
  const expanded = broadcast(moved, pairShape, [-2]);
  // Keep the lower triangle and sum the last axis to get the running totals.
  const prefixSums = tril(expanded).sum(-1);
  return moveaxis$1(prefixSums, -1, ax);
}
/** @function Alternative name for `jax.numpy.cumsum()`. */
const cumulativeSum = cumsum;
4921
5079
  /** Reverse the elements in an array along the given axes. */
4922
5080
  function flip(x, axis = null) {
4923
5081
  const nd = ndim(x);
@@ -5153,7 +5311,10 @@ function allclose(actual, expected, options) {
5153
5311
  if (!deepEqual(x.shape, y.shape)) return false;
5154
5312
  const xData = x.dataSync();
5155
5313
  const yData = y.dataSync();
5156
- for (let i = 0; i < xData.length; i++) if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
5314
+ for (let i = 0; i < xData.length; i++) {
5315
+ if (isNaN(xData[i]) !== isNaN(yData[i])) return false;
5316
+ if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
5317
+ }
5157
5318
  return true;
5158
5319
  }
5159
5320
  /** Matrix product of two arrays. */
@@ -5612,7 +5773,10 @@ const degrees = rad2deg;
5612
5773
  * Computes first array raised to power of second array, element-wise.
5613
5774
  */
5614
5775
  const power = jit$1(function power$1(x1, x2) {
5615
- return exp(log(x1).mul(x2));
5776
+ const x2i = trunc(x2.ref);
5777
+ const shouldBeNaN = multiply(x2.ref.notEqual(x2i.ref), x1.ref.less(0));
5778
+ const resultSign = where(mod(x2i, 2).notEqual(0), where(x1.ref.less(0), -1, 1), 1);
5779
+ return where(shouldBeNaN, nan, exp(log(abs(x1)).mul(x2)).mul(resultSign));
5616
5780
  });
5617
5781
  /** @function Alias of `jax.numpy.power()`. */
5618
5782
  const pow = power;
@@ -5968,22 +6132,22 @@ function logSoftmax(x, axis = -1) {
5968
6132
  *
5969
6133
  * Reference: https://en.wikipedia.org/wiki/LogSumExp
5970
6134
  */
5971
function logsumexp(x, axis = null, opts) {
  const input = fudgeArray(x);
  const axes = normalizeAxis(axis, input.ndim);
  // Nothing to reduce over: hand the input back unchanged.
  if (axes.length === 0) return input;
  // Subtract the (gradient-stopped) maximum before exponentiating; it is
  // added back after the log. Reductions keep their dims so the maximum
  // lines up with `input` for the subtraction.
  const peak = stopGradient(max(input.ref, axes, { keepdims: true }));
  const centered = input.sub(peak.ref);
  const sumOfExp = exp(centered).sum(axes, { keepdims: true });
  const total = peak.add(log(sumOfExp));
  // Callers that did not ask for `keepdims` get the reduced axes squeezed out.
  if (opts?.keepdims) return total;
  return squeeze(total, axes);
}
5980
6144
  /** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
5981
- function logmeanexp(x, axis = null) {
6145
+ function logmeanexp(x, axis = null, opts) {
5982
6146
  x = fudgeArray(x);
5983
6147
  axis = normalizeAxis(axis, x.ndim);
5984
6148
  if (axis.length === 0) return x;
5985
6149
  const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
5986
- return logsumexp(x, axis).sub(Math.log(n));
6150
+ return logsumexp(x, axis, opts).sub(Math.log(n));
5987
6151
  }
5988
6152
  /**
5989
6153
  * Standardizes input to zero mean and unit variance.