@jax-js/jax 0.1.2 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,30 +30,38 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-DeVfWEFS.cjs');
33
+ const require_backend = require('./backend-CmaidnkQ.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
37
37
  * Check that the shapes and parameters passed to convolution are valid.
38
+ * Expected shapes of the lhs and rhs of the convolution are:
39
+ *
40
+ * - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
41
+ * - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
38
42
  *
39
43
  * If the check succeeds, returns the output shape.
40
44
  */
41
- function checkConvShape(lhsShape, rhsShape, { strides, padding, lhsDilation, rhsDilation }) {
45
+ function checkConvShape(lhsShape, rhsShape, { vmapDims, strides, padding, lhsDilation, rhsDilation }) {
42
46
  if (lhsShape.length !== rhsShape.length) throw new Error(`conv() requires inputs with the same number of dimensions, got ${lhsShape.length} and ${rhsShape.length}`);
43
- const n = lhsShape.length - 2;
47
+ const n = lhsShape.length - 2 - vmapDims;
44
48
  if (n < 0) throw new Error("conv() requires at least 2D inputs");
45
49
  if (strides.length !== n) throw new Error("conv() strides != spatial dims");
46
50
  if (padding.length !== n) throw new Error("conv() padding != spatial dims");
47
51
  if (lhsDilation.length !== n) throw new Error("conv() lhsDilation != spatial dimensions");
48
52
  if (rhsDilation.length !== n) throw new Error("conv() rhsDilation != spatial dimensions");
49
- if (lhsShape[1] !== rhsShape[1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
50
- const outShape = [lhsShape[0], rhsShape[0]];
53
+ if (lhsShape[vmapDims + 1] !== rhsShape[vmapDims + 1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
54
+ const outShape = [
55
+ ...require_backend.generalBroadcast(lhsShape.slice(0, vmapDims), rhsShape.slice(0, vmapDims)),
56
+ lhsShape[vmapDims],
57
+ rhsShape[vmapDims]
58
+ ];
51
59
  for (let i = 0; i < n; i++) {
52
60
  if (strides[i] <= 0 || !Number.isInteger(strides[i])) throw new Error(`conv() strides[${i}] must be a positive integer`);
53
61
  if (padding[i].length !== 2 || !padding[i].every(Number.isInteger)) throw new Error(`conv() padding[${i}] must be a 2-tuple of integers`);
54
62
  if (lhsDilation[i] <= 0 || !Number.isInteger(lhsDilation[i])) throw new Error(`conv() lhsDilation[${i}] must be a positive integer`);
55
63
  if (rhsDilation[i] <= 0 || !Number.isInteger(rhsDilation[i])) throw new Error(`conv() rhsDilation[${i}] must be a positive integer`);
56
- const [x, k] = [lhsShape[i + 2], rhsShape[i + 2]];
64
+ const [x, k] = [lhsShape[i + vmapDims + 2], rhsShape[i + vmapDims + 2]];
57
65
  if (k <= 0) throw new Error("conv() kernel size must be positive");
58
66
  const [pl, pr] = padding[i];
59
67
  if (pl < -x || pr < -x || pl + pr < -x) throw new Error(`conv() padding[${i}]=(${pl},${pr}) is too negative for input size ${x}`);
@@ -178,27 +186,13 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
178
186
  function applyDilation(st, dilation) {
179
187
  if (dilation.every((s) => s === 1)) return st;
180
188
  const s_ = dilation;
181
- const [a, b, ...k_] = st.shape;
182
- st = st.reshape([
183
- a,
184
- b,
185
- ...k_.flatMap((k) => [k, 1])
186
- ]);
187
- st = st.pad([
188
- [0, 0],
189
- [0, 0],
190
- ...s_.flatMap((s) => [[0, 0], [0, s - 1]])
191
- ]);
192
- st = st.reshape([
193
- a,
194
- b,
195
- ...k_.map((k, i) => k * s_[i])
196
- ]);
197
- st = st.shrink([
198
- [0, a],
199
- [0, b],
200
- ...k_.map((k, i) => [0, (k - 1) * s_[i] + 1])
201
- ]);
189
+ const n = s_.length;
190
+ const prefix = st.shape.slice(0, -n);
191
+ const k_ = st.shape.slice(-n);
192
+ st = st.reshape([...prefix, ...k_.flatMap((k) => [k, 1])]);
193
+ st = st.pad([...prefix.map(() => [0, 0]), ...s_.flatMap((s) => [[0, 0], [0, s - 1]])]);
194
+ st = st.reshape([...prefix, ...k_.map((k, i) => k * s_[i])]);
195
+ st = st.shrink([...prefix.map((p) => [0, p]), ...k_.map((k, i) => [0, (k - 1) * s_[i] + 1])]);
202
196
  return st;
203
197
  }
204
198
  /**
@@ -208,25 +202,26 @@ function applyDilation(st, dilation) {
208
202
  * beforehand using `checkConvShape()`.
209
203
  */
210
204
  function prepareConv(stX, stY, params) {
211
- const n = stX.shape.length - 2;
205
+ const v = params.vmapDims;
206
+ const n = stX.shape.length - 2 - v;
207
+ const vmapShape = stX.shape.slice(0, v);
212
208
  stX = applyDilation(stX, params.lhsDilation);
213
- const ks = stY.shape.slice(2);
214
- stX = stX.padOrShrink([
215
- [0, 0],
216
- [0, 0],
217
- ...params.padding
218
- ]);
209
+ const ks = stY.shape.slice(v + 2);
210
+ stX = stX.padOrShrink([...require_backend.rep(v + 2, [0, 0]), ...params.padding]);
219
211
  stX = pool(stX, ks, params.strides, params.rhsDilation);
220
- stX = stX.moveaxis(1, n + 1).reshape([
221
- stX.shape[0],
212
+ stX = stX.moveaxis(v + 1, v + n + 1).reshape([
213
+ ...vmapShape,
214
+ stX.shape[v],
222
215
  1,
223
- ...stX.shape.slice(2, n + 2),
224
- stX.shape[1] * require_backend.prod(ks)
216
+ ...stX.shape.slice(v + 2, v + n + 2),
217
+ stX.shape[v + 1] * require_backend.prod(ks)
225
218
  ]);
226
219
  stY = stY.reshape([
227
- stY.shape[0],
220
+ ...vmapShape,
221
+ 1,
222
+ stY.shape[v],
228
223
  ...require_backend.rep(n, 1),
229
- stY.shape[1] * require_backend.prod(ks)
224
+ stY.shape[v + 1] * require_backend.prod(ks)
230
225
  ]);
231
226
  return [stX, stY];
232
227
  }
@@ -498,9 +493,11 @@ function dot$2(x, y) {
498
493
  }
499
494
  function conv$1(x, y, params = {}) {
500
495
  if (x.ndim !== y.ndim) throw new Error(`conv() requires inputs with the same number of dimensions, got ${x.ndim} and ${y.ndim}`);
501
- const n = x.ndim - 2;
496
+ const vmapDims = params.vmapDims ?? 0;
497
+ const n = x.ndim - 2 - vmapDims;
502
498
  if (n < 0) throw new Error("conv() requires at least 2D inputs");
503
499
  return bind1(Primitive.Conv, [x, y], {
500
+ vmapDims,
504
501
  strides: params.strides ?? require_backend.rep(n, 1),
505
502
  padding: params.padding ?? require_backend.rep(n, [0, 0]),
506
503
  lhsDilation: params.lhsDilation ?? require_backend.rep(n, 1),
@@ -724,8 +721,10 @@ var Tracer = class Tracer {
724
721
  axis = require_backend.normalizeAxis(axis, this.ndim);
725
722
  const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
726
723
  if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
727
- const result = reduce(this, require_backend.AluOp.Add, axis, opts);
728
- return result.mul(1 / n);
724
+ const originalDtype = this.dtype;
725
+ const castDtype = require_backend.promoteTypes(originalDtype, require_backend.DType.Float32);
726
+ const result = reduce(this.astype(castDtype), require_backend.AluOp.Add, axis, opts);
727
+ return result.mul(1 / n).astype(originalDtype);
729
728
  }
730
729
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
731
730
  transpose(perm) {
@@ -1205,7 +1204,7 @@ var Jaxpr = class Jaxpr {
1205
1204
  } else if (eqn.primitive === Primitive.Idiv) {
1206
1205
  const [a, b] = inputs;
1207
1206
  const c = eqn.outBinders[0];
1208
- if (atomIsLit(b, 1)) context.set(c, a);
1207
+ if (atomIsLit(b, 1) && !require_backend.isFloatDtype(a.aval.dtype)) context.set(c, a);
1209
1208
  else newEqns.push(eqn);
1210
1209
  } else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && require_backend.deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
1211
1210
  else newEqns.push(eqn);
@@ -1790,48 +1789,73 @@ function jitCompile(backend, jaxpr, consts) {
1790
1789
  const inputExps = [];
1791
1790
  const inputAvals = [];
1792
1791
  const inputArgs = [];
1793
- for (const input of eqn.inputs) if (input instanceof Var) {
1794
- const jitValue = ctx.get(input);
1795
- if (jitValue.type === "exp") {
1796
- const gidMap = /* @__PURE__ */ new Map();
1797
- for (const [gid, jitId] of jitValue.args.entries()) {
1798
- let newGid = inputArgs.indexOf(jitId);
1799
- if (newGid === -1) {
1800
- newGid = inputArgs.length;
1801
- inputArgs.push(jitId);
1802
- }
1803
- gidMap.set(gid, newGid);
1804
- }
1805
- inputExps.push(jitValue.exp.reindexGids(gidMap));
1806
- } else if (jitValue.type === "imm") {
1807
- let gid = inputArgs.indexOf(jitValue.arg);
1808
- if (gid === -1) {
1809
- gid = inputArgs.length;
1810
- inputArgs.push(jitValue.arg);
1792
+ let inputReduction = null;
1793
+ const addArgs = (args) => {
1794
+ const newGids = [];
1795
+ for (const jitId of args) {
1796
+ let newGid = inputArgs.indexOf(jitId);
1797
+ if (newGid === -1) {
1798
+ newGid = inputArgs.length;
1799
+ inputArgs.push(jitId);
1811
1800
  }
1801
+ newGids.push(newGid);
1802
+ }
1803
+ return newGids;
1804
+ };
1805
+ for (const input of eqn.inputs) if (input instanceof Var) {
1806
+ const jv = ctx.get(input);
1807
+ if (jv.type === "exp") {
1808
+ const newGids = addArgs(jv.args);
1809
+ inputExps.push(jv.exp.reindexGids(newGids));
1810
+ } else if (jv.type === "imm") {
1811
+ const [gid] = addArgs([jv.arg]);
1812
1812
  const st = require_backend.ShapeTracker.fromShape(input.aval.shape);
1813
1813
  const indices = require_backend.unravelAlu(st.shape, require_backend.AluVar.gidx);
1814
1814
  inputExps.push(require_backend.AluExp.globalView(input.aval.dtype, gid, st, indices));
1815
+ } else if (jv.type === "red") {
1816
+ if (inputReduction) throw new Error("jit: unexpected, multiple red inputs");
1817
+ const newGids = addArgs(jv.args);
1818
+ inputExps.push(jv.reduction.epilogue.reindexGids(newGids));
1819
+ inputReduction = jv;
1815
1820
  }
1816
1821
  inputAvals.push(input.aval);
1817
1822
  } else if (input instanceof Lit) {
1818
1823
  inputExps.push(require_backend.AluExp.const(input.dtype, input.value));
1819
1824
  inputAvals.push(input.aval);
1820
1825
  } else throw new TypeError(`Unexpected input in Jaxpr: ${input}`);
1821
- const nargs$1 = inputArgs.length;
1822
1826
  const rule = jitRules[eqn.primitive];
1823
1827
  if (!rule) throw new TypeError(`JIT not implemented for primitive ${eqn.primitive}`);
1824
- const kernel = rule(nargs$1, inputExps, inputAvals, eqn.params);
1828
+ let exp$2;
1829
+ let reduction;
1830
+ if (inputReduction) {
1831
+ const jv = inputReduction;
1832
+ const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
1833
+ exp$2 = jv.exp.reindexGids(addArgs(jv.args));
1834
+ reduction = new require_backend.Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
1835
+ } else {
1836
+ const ruleOutput = rule(inputExps, inputAvals, eqn.params);
1837
+ exp$2 = ruleOutput.exp;
1838
+ reduction = ruleOutput.reduction;
1839
+ }
1825
1840
  const outVar = eqn.outBinders[0];
1826
- if (kernel.reduction || blackNodes.has(outVar)) {
1841
+ if (blackNodes.has(outVar)) {
1842
+ const nargs$1 = inputArgs.length;
1843
+ const size$1 = require_backend.prod(outVar.aval.shape);
1844
+ const kernel = new require_backend.Kernel(nargs$1, size$1, exp$2, reduction);
1827
1845
  const outId = builder.pushKernel(kernel, inputArgs);
1828
1846
  ctx.set(outVar, {
1829
1847
  type: "imm",
1830
1848
  arg: outId
1831
1849
  });
1832
- } else ctx.set(outVar, {
1850
+ } else if (reduction) ctx.set(outVar, {
1851
+ type: "red",
1852
+ exp: exp$2,
1853
+ reduction,
1854
+ args: inputArgs
1855
+ });
1856
+ else ctx.set(outVar, {
1833
1857
  type: "exp",
1834
- exp: kernel.exp,
1858
+ exp: exp$2,
1835
1859
  args: inputArgs
1836
1860
  });
1837
1861
  }
@@ -1863,31 +1887,28 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
1863
1887
  });
1864
1888
  }
1865
1889
  function broadcastedJit(fn, opts) {
1866
- return (nargs, exps, avals, params) => {
1890
+ return (exps, avals, params) => {
1867
1891
  let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
1868
1892
  const skipCastIdx = opts?.skipCastIdx ?? [];
1869
1893
  if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
1870
- exps = exps.map((exp$3, i) => {
1871
- exp$3 = reshapeViews(exp$3, (st) => {
1894
+ exps = exps.map((exp$2, i) => {
1895
+ exp$2 = reshapeViews(exp$2, (st) => {
1872
1896
  if (!require_backend.deepEqual(st.shape, newShape)) return st.broadcast(newShape, require_backend.range(newShape.length - st.shape.length));
1873
1897
  });
1874
- if (exp$3.dtype !== newDtype && !skipCastIdx.includes(i)) exp$3 = require_backend.AluExp.cast(newDtype, exp$3);
1875
- return exp$3;
1898
+ if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = require_backend.AluExp.cast(newDtype, exp$2);
1899
+ return exp$2;
1876
1900
  });
1877
- const exp$2 = fn(exps, params);
1878
- return new require_backend.Kernel(nargs, require_backend.prod(newShape), exp$2);
1901
+ return { exp: fn(exps, params) };
1879
1902
  };
1880
1903
  }
1881
1904
  function unopJit(fn) {
1882
- return (nargs, [a], [as], params) => {
1883
- return new require_backend.Kernel(nargs, require_backend.prod(as.shape), fn(a, params));
1905
+ return ([a], [_as], params) => {
1906
+ return { exp: fn(a, params) };
1884
1907
  };
1885
1908
  }
1886
1909
  function reshapeJit(fn) {
1887
- return (nargs, [a], [as], params) => {
1888
- a = reshapeViews(a, (st) => fn(st, params));
1889
- const newShape = fn(require_backend.ShapeTracker.fromShape(as.shape), params).shape;
1890
- return new require_backend.Kernel(nargs, require_backend.prod(newShape), a);
1910
+ return ([a], [_as], params) => {
1911
+ return { exp: reshapeViews(a, (st) => fn(st, params)) };
1891
1912
  };
1892
1913
  }
1893
1914
  const jitRules = {
@@ -1902,7 +1923,7 @@ const jitRules = {
1902
1923
  [Primitive.StopGradient]: unopJit((a) => a),
1903
1924
  [Primitive.Cast]: unopJit((a, { dtype }) => require_backend.AluExp.cast(dtype, a)),
1904
1925
  [Primitive.Bitcast]: unopJit((a, { dtype }) => require_backend.AluExp.bitcast(dtype, a)),
1905
- [Primitive.RandomBits]: (nargs, keys, keyShapes, { shape: shape$1, mode }) => {
1926
+ [Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
1906
1927
  const mapping = (st) => {
1907
1928
  if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(shape$1.length - st.shape.length));
1908
1929
  };
@@ -1911,7 +1932,7 @@ const jitRules = {
1911
1932
  const c0 = require_backend.AluExp.u32(0);
1912
1933
  const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
1913
1934
  const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
1914
- return new require_backend.Kernel(nargs, require_backend.prod(shape$1), exp$2);
1935
+ return { exp: exp$2 };
1915
1936
  },
1916
1937
  [Primitive.Sin]: unopJit(require_backend.AluExp.sin),
1917
1938
  [Primitive.Cos]: unopJit(require_backend.AluExp.cos),
@@ -1924,7 +1945,7 @@ const jitRules = {
1924
1945
  [Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
1925
1946
  [Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
1926
1947
  [Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
1927
- [Primitive.Reduce](nargs, [a], [as], { op, axis }) {
1948
+ [Primitive.Reduce]([a], [as], { op, axis }) {
1928
1949
  const keptAxes = [];
1929
1950
  const shiftedAxes = [];
1930
1951
  const newShape = [];
@@ -1933,39 +1954,43 @@ const jitRules = {
1933
1954
  keptAxes.push(i);
1934
1955
  newShape.push(as.shape[i]);
1935
1956
  }
1936
- const size$1 = require_backend.prod(newShape);
1937
1957
  const reductionSize = require_backend.prod(shiftedAxes.map((ax) => as.shape[ax]));
1938
1958
  newShape.push(reductionSize);
1939
1959
  const perm = keptAxes.concat(shiftedAxes);
1940
1960
  a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
1941
1961
  const reduction = new require_backend.Reduction(a.dtype, op, reductionSize);
1942
- return new require_backend.Kernel(nargs, size$1, a, reduction);
1962
+ return {
1963
+ exp: a,
1964
+ reduction
1965
+ };
1943
1966
  },
1944
1967
  [Primitive.Pool]: reshapeJit((st, { window, strides }) => pool(st, window, strides)),
1945
- [Primitive.PoolTranspose](nargs, [a], [as], { inShape, window, strides }) {
1968
+ [Primitive.PoolTranspose]([a], [as], { inShape, window, strides }) {
1946
1969
  let stX = poolTranspose(require_backend.ShapeTracker.fromShape(as.shape), inShape, window, strides);
1947
- const size$1 = require_backend.prod(inShape);
1948
1970
  stX = stX.reshape([...inShape, require_backend.prod(stX.shape.slice(inShape.length))]);
1949
1971
  a = reshapeViews(a, (st) => st.compose(stX), true);
1950
1972
  const reduction = new require_backend.Reduction(a.dtype, require_backend.AluOp.Add, stX.shape[stX.shape.length - 1]);
1951
- return new require_backend.Kernel(nargs, size$1, a, reduction);
1973
+ return {
1974
+ exp: a,
1975
+ reduction
1976
+ };
1952
1977
  },
1953
- [Primitive.Dot](nargs, [a, b], [as, bs]) {
1954
- const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
1978
+ [Primitive.Dot]([a, b], [as, bs]) {
1979
+ const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
1955
1980
  const c = k1.exp;
1956
1981
  const cs = promoteAvals(as, bs);
1957
- return jitRules[Primitive.Reduce](nargs, [c], [cs], {
1982
+ return jitRules[Primitive.Reduce]([c], [cs], {
1958
1983
  op: require_backend.AluOp.Add,
1959
1984
  axis: [cs.ndim - 1]
1960
1985
  });
1961
1986
  },
1962
- [Primitive.Conv](nargs, [a, b], [as, bs], params) {
1987
+ [Primitive.Conv]([a, b], [as, bs], params) {
1963
1988
  const [stX, stY] = prepareConv(require_backend.ShapeTracker.fromShape(as.shape), require_backend.ShapeTracker.fromShape(bs.shape), params);
1964
1989
  a = reshapeViews(a, (st) => st.compose(stX));
1965
1990
  b = reshapeViews(b, (st) => st.compose(stY));
1966
1991
  as = new ShapedArray(stX.shape, as.dtype, as.weakType);
1967
1992
  bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
1968
- return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
1993
+ return jitRules[Primitive.Dot]([a, b], [as, bs], {});
1969
1994
  },
1970
1995
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
1971
1996
  [Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
@@ -1979,7 +2004,7 @@ const jitRules = {
1979
2004
  }),
1980
2005
  [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
1981
2006
  [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
1982
- [Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
2007
+ [Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
1983
2008
  const axisSet = new Set(axis);
1984
2009
  const indexShape = indicesShapes.map((c) => c.shape).reduce(require_backend.generalBroadcast);
1985
2010
  const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
@@ -1992,7 +2017,7 @@ const jitRules = {
1992
2017
  for (const [i, iexp] of indices.entries()) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...require_backend.range(outDim + indexShape.length - st.shape.length), ...require_backend.range(outDim + indexShape.length, finalShape.length)])));
1993
2018
  const [index, valid] = require_backend.ShapeTracker.fromShape(xs.shape).toAluExp(src);
1994
2019
  if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
1995
- return new require_backend.Kernel(nargs, require_backend.prod(finalShape), x.substitute({ gidx: index }));
2020
+ return { exp: x.substitute({ gidx: index }) };
1996
2021
  },
1997
2022
  [Primitive.JitCall]() {
1998
2023
  throw new Error("internal: JitCall should have been flattened before JIT compilation");
@@ -2000,16 +2025,16 @@ const jitRules = {
2000
2025
  };
2001
2026
  /** Determines how to split the Jaxpr into kernels via dataflow analysis. */
2002
2027
  function splitGraphDataflow(backend, jaxpr) {
2003
- const varToEqn = /* @__PURE__ */ new Map();
2028
+ const varToDefn = /* @__PURE__ */ new Map();
2029
+ const varToUsages = /* @__PURE__ */ new Map();
2004
2030
  for (let i = 0; i < jaxpr.eqns.length; i++) {
2005
2031
  const eqn = jaxpr.eqns[i];
2006
- for (const v of eqn.outBinders) if (v instanceof Var) varToEqn.set(v, i);
2007
- }
2008
- const blackNodes = /* @__PURE__ */ new Set();
2009
- const p1NextBlack = /* @__PURE__ */ new Map();
2010
- for (const v of jaxpr.outs) if (v instanceof Var) {
2011
- blackNodes.add(v);
2012
- p1NextBlack.set(v, v);
2032
+ for (const v of eqn.outBinders) if (v instanceof Var) varToDefn.set(v, i);
2033
+ for (const input of eqn.inputs) if (input instanceof Var) {
2034
+ const usages = varToUsages.get(input);
2035
+ if (usages) usages.push(i);
2036
+ else varToUsages.set(input, [i]);
2037
+ }
2013
2038
  }
2014
2039
  const reducePrimitives = [
2015
2040
  Primitive.Reduce,
@@ -2017,10 +2042,68 @@ function splitGraphDataflow(backend, jaxpr) {
2017
2042
  Primitive.Conv,
2018
2043
  Primitive.PoolTranspose
2019
2044
  ];
2045
+ const reductionEpilogueEqns = /* @__PURE__ */ new Set();
2046
+ const reductionEndpointEqns = /* @__PURE__ */ new Set();
2047
+ for (let i = 0; i < jaxpr.eqns.length; i++) {
2048
+ const eqn = jaxpr.eqns[i];
2049
+ if (reducePrimitives.includes(eqn.primitive)) {
2050
+ let head = i;
2051
+ while (true) {
2052
+ reductionEpilogueEqns.add(head);
2053
+ const outVar = jaxpr.eqns[head].outBinders[0];
2054
+ const usages = varToUsages.get(outVar) ?? [];
2055
+ if (jaxpr.outs.includes(outVar) || usages.length !== 1) break;
2056
+ if (reductionEpilogueEqns.has(usages[0])) break;
2057
+ const nextEqn = jaxpr.eqns[usages[0]];
2058
+ switch (nextEqn.primitive) {
2059
+ case Primitive.Neg:
2060
+ case Primitive.Reciprocal:
2061
+ case Primitive.Floor:
2062
+ case Primitive.Ceil:
2063
+ case Primitive.StopGradient:
2064
+ case Primitive.Cast:
2065
+ case Primitive.Bitcast:
2066
+ case Primitive.Sin:
2067
+ case Primitive.Cos:
2068
+ case Primitive.Asin:
2069
+ case Primitive.Atan:
2070
+ case Primitive.Exp:
2071
+ case Primitive.Log:
2072
+ case Primitive.Erf:
2073
+ case Primitive.Erfc:
2074
+ case Primitive.Sqrt:
2075
+ head = usages[0];
2076
+ continue;
2077
+ case Primitive.Add:
2078
+ case Primitive.Mul:
2079
+ case Primitive.Idiv:
2080
+ case Primitive.Mod:
2081
+ case Primitive.Max:
2082
+ case Primitive.Min: {
2083
+ const otherInput = nextEqn.inputs.find((v) => v !== outVar);
2084
+ if (otherInput instanceof Lit || require_backend.deepEqual(require_backend.generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
2085
+ head = usages[0];
2086
+ continue;
2087
+ }
2088
+ break;
2089
+ }
2090
+ }
2091
+ break;
2092
+ }
2093
+ reductionEndpointEqns.add(head);
2094
+ }
2095
+ }
2096
+ const blackNodes = /* @__PURE__ */ new Set();
2097
+ const p1NextBlack = /* @__PURE__ */ new Map();
2098
+ for (const v of jaxpr.outs) if (v instanceof Var) {
2099
+ blackNodes.add(v);
2100
+ p1NextBlack.set(v, v);
2101
+ }
2020
2102
  const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
2103
+ const needsCleanShapePrimitives = [Primitive.Pad];
2021
2104
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
2022
2105
  const eqn = jaxpr.eqns[i];
2023
- if (reducePrimitives.includes(eqn.primitive) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
2106
+ if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
2024
2107
  for (const v of eqn.outBinders) {
2025
2108
  blackNodes.add(v);
2026
2109
  p1NextBlack.set(v, v);
@@ -2028,17 +2111,25 @@ function splitGraphDataflow(backend, jaxpr) {
2028
2111
  continue;
2029
2112
  }
2030
2113
  const reach = /* @__PURE__ */ new Set();
2031
- for (let j = i + 1; j < jaxpr.eqns.length; j++) for (const v of jaxpr.eqns[j].inputs) if (v instanceof Var && eqn.outBinders.includes(v)) for (const o of jaxpr.eqns[j].outBinders) {
2032
- const u = p1NextBlack.get(o);
2033
- if (u) reach.add(u);
2114
+ let needsCleanOutput = false;
2115
+ outer: for (const v of eqn.outBinders) for (const j of varToUsages.get(v) ?? []) {
2116
+ if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive)) {
2117
+ needsCleanOutput = true;
2118
+ break outer;
2119
+ }
2120
+ for (const o of jaxpr.eqns[j].outBinders) {
2121
+ const u = p1NextBlack.get(o);
2122
+ if (u) reach.add(u);
2123
+ }
2034
2124
  }
2035
- if (reach.size === 1) {
2036
- const b = reach.values().next().value;
2037
- for (const v of eqn.outBinders) p1NextBlack.set(v, b);
2038
- } else if (reach.size > 1) for (const v of eqn.outBinders) {
2125
+ if (reach.size > 1 || needsCleanOutput) for (const v of eqn.outBinders) {
2039
2126
  blackNodes.add(v);
2040
2127
  p1NextBlack.set(v, v);
2041
2128
  }
2129
+ else if (reach.size === 1) {
2130
+ const b = reach.values().next().value;
2131
+ for (const v of eqn.outBinders) p1NextBlack.set(v, b);
2132
+ }
2042
2133
  }
2043
2134
  const p2Deps = /* @__PURE__ */ new Map();
2044
2135
  for (const v of jaxpr.inBinders) p2Deps.set(v, new Set([v]));
@@ -2057,7 +2148,7 @@ function splitGraphDataflow(backend, jaxpr) {
2057
2148
  let assocInput = -1;
2058
2149
  for (let i = 0; i < eqn.inputs.length; i++) {
2059
2150
  const input = eqn.inputs[i];
2060
- if (input instanceof Var && varToEqn.has(input)) {
2151
+ if (input instanceof Var && varToDefn.has(input)) {
2061
2152
  let uniqueDeps = 0;
2062
2153
  for (const dep of deps[i]) if (depCounter.get(dep) === 1) uniqueDeps++;
2063
2154
  if (uniqueDeps > maxUniqueDeps) {
@@ -2068,7 +2159,7 @@ function splitGraphDataflow(backend, jaxpr) {
2068
2159
  }
2069
2160
  if (assocInput === -1) throw new Error(`internal: maxArgs, no input found to mark as black in Jaxpr equation ${eqn}`);
2070
2161
  const assocVar = eqn.inputs[assocInput];
2071
- p2idx = varToEqn.get(assocVar);
2162
+ p2idx = varToDefn.get(assocVar);
2072
2163
  for (const out of jaxpr.eqns[p2idx].outBinders) blackNodes.add(out);
2073
2164
  } else {
2074
2165
  const s = new Set(depCounter.keys());
@@ -3497,6 +3588,15 @@ const vmapRules = {
3497
3588
  const z = dot$2(x, y);
3498
3589
  return [[z], [z.ndim - 1]];
3499
3590
  },
3591
+ [Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
3592
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3593
+ y = moveBatchAxis(axisSize, yBdim, 0, y);
3594
+ const z = conv$1(x, y, {
3595
+ ...params,
3596
+ vmapDims: params.vmapDims + 1
3597
+ });
3598
+ return [[z], [0]];
3599
+ },
3500
3600
  [Primitive.Compare](axisSize, args, dims, { op }) {
3501
3601
  return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3502
3602
  },
@@ -3941,7 +4041,7 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
3941
4041
  for (const t of tracersIn) t.dispose();
3942
4042
  for (const t of tracersOut) t.dispose();
3943
4043
  jaxpr = jaxpr.simplify();
3944
- if (require_backend.DEBUG >= 5) console.log("jaxpr from partial evaluation:\n" + jaxpr.toString());
4044
+ if (require_backend.DEBUG >= 5) console.info("jaxpr from partial evaluation:\n" + jaxpr.toString());
3945
4045
  return {
3946
4046
  jaxpr,
3947
4047
  consts
@@ -4075,22 +4175,25 @@ const transposeRules = {
4075
4175
  },
4076
4176
  [Primitive.Conv]([ct], [lhs, rhs], params) {
4077
4177
  if (lhs instanceof UndefPrimal === rhs instanceof UndefPrimal) throw new NonlinearError(Primitive.Conv);
4178
+ const v = params.vmapDims;
4078
4179
  const rev01 = [
4079
- 1,
4080
- 0,
4081
- ...require_backend.range(2, ct.ndim)
4180
+ ...require_backend.range(v),
4181
+ v + 1,
4182
+ v,
4183
+ ...require_backend.range(v + 2, ct.ndim)
4082
4184
  ];
4083
4185
  if (lhs instanceof UndefPrimal) {
4084
4186
  let kernel = rhs;
4085
4187
  kernel = transpose$1(kernel, rev01);
4086
- kernel = flip$1(kernel, require_backend.range(2, kernel.ndim));
4188
+ kernel = flip$1(kernel, require_backend.range(v + 2, kernel.ndim));
4087
4189
  const result = conv$1(ct, kernel, {
4190
+ vmapDims: v,
4088
4191
  strides: params.lhsDilation,
4089
4192
  padding: params.padding.map(([pl, _pr], i) => {
4090
- const dilatedKernel = (kernel.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
4091
- const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
4193
+ const dilatedKernel = (kernel.shape[i + v + 2] - 1) * params.rhsDilation[i] + 1;
4194
+ const dilatedCt = (ct.shape[i + v + 2] - 1) * params.strides[i] + 1;
4092
4195
  const padBefore = dilatedKernel - 1 - pl;
4093
- const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
4196
+ const dilatedLhs = (lhs.aval.shape[i + v + 2] - 1) * params.lhsDilation[i] + 1;
4094
4197
  const padAfter = dilatedLhs + dilatedKernel - 1 - dilatedCt - padBefore;
4095
4198
  return [padBefore, padAfter];
4096
4199
  }),
@@ -4102,11 +4205,12 @@ const transposeRules = {
4102
4205
  const newLhs = transpose$1(lhs, rev01);
4103
4206
  const newRhs = transpose$1(ct, rev01);
4104
4207
  let result = conv$1(newLhs, newRhs, {
4208
+ vmapDims: v,
4105
4209
  strides: params.rhsDilation,
4106
4210
  padding: params.padding.map(([pl, _pr], i) => {
4107
- const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
4108
- const dilatedKernel = (rhs.aval.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
4109
- const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
4211
+ const dilatedLhs = (lhs.aval.shape[i + v + 2] - 1) * params.lhsDilation[i] + 1;
4212
+ const dilatedKernel = (rhs.aval.shape[i + v + 2] - 1) * params.rhsDilation[i] + 1;
4213
+ const dilatedCt = (ct.shape[i + v + 2] - 1) * params.strides[i] + 1;
4110
4214
  const padFromLhs = dilatedCt - dilatedLhs;
4111
4215
  const padFromRhs = dilatedKernel - pl - 1;
4112
4216
  return [pl, padFromLhs + padFromRhs];
@@ -4355,13 +4459,46 @@ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
4355
4459
  *
4356
4460
  * Grouped convolutions are supported through the `featureGroupCount` option.
4357
4461
  */
4358
- function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation } = {}) {
4462
+ function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
4359
4463
  if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
4360
4464
  if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
4361
4465
  if (typeof padding === "string") {
4362
4466
  if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
4363
4467
  padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1), padding);
4364
4468
  }
4469
+ if (featureGroupCount !== 1) {
4470
+ const G = featureGroupCount;
4471
+ const [N, C_in, ...xs] = lhs.shape;
4472
+ const [C_out, C_in_per_group, ...ks] = rhs.shape;
4473
+ if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
4474
+ if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
4475
+ if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
4476
+ const lhsGrouped = moveaxis(lhs.reshape([
4477
+ N,
4478
+ G,
4479
+ C_in / G,
4480
+ ...xs
4481
+ ]), 1, 0);
4482
+ const rhsGrouped = rhs.reshape([
4483
+ G,
4484
+ C_out / G,
4485
+ C_in_per_group,
4486
+ ...ks
4487
+ ]);
4488
+ const result = conv$1(lhsGrouped, rhsGrouped, {
4489
+ vmapDims: 1,
4490
+ strides: windowStrides,
4491
+ padding,
4492
+ lhsDilation,
4493
+ rhsDilation
4494
+ });
4495
+ const ys = result.shape.slice(3);
4496
+ return moveaxis(result, 0, 1).reshape([
4497
+ N,
4498
+ C_out,
4499
+ ...ys
4500
+ ]);
4501
+ }
4365
4502
  return conv$1(lhs, rhs, {
4366
4503
  strides: windowStrides,
4367
4504
  padding,
@@ -4647,6 +4784,8 @@ __export(numpy_exports, {
4647
4784
  concatenate: () => concatenate,
4648
4785
  cos: () => cos,
4649
4786
  cosh: () => cosh,
4787
+ cumsum: () => cumsum,
4788
+ cumulativeSum: () => cumulativeSum,
4650
4789
  deg2rad: () => deg2rad,
4651
4790
  degrees: () => degrees,
4652
4791
  diag: () => diag,
@@ -4955,6 +5094,25 @@ function argmax(a, axis, opts) {
4955
5094
  }).reshape([shape$1[axis], ...require_backend.rep(shape$1.length - axis - 1, 1)]));
4956
5095
  return length.sub(max(idx, axis, opts));
4957
5096
  }
5097
+ /**
5098
+ * Cumulative sum of elements along an axis.
5099
+ *
5100
+ * Currently this function is `O(n^2)`; we'll improve this later on with a
5101
+ * two-phase parallel reduction algorithm.
5102
+ */
5103
+ function cumsum(a, axis) {
5104
+ a = fudgeArray(a);
5105
+ if (axis === void 0) {
5106
+ a = a.ravel();
5107
+ axis = 0;
5108
+ } else axis = require_backend.checkAxis(axis, a.ndim);
5109
+ const n = a.shape[axis];
5110
+ a = moveaxis$1(a, axis, -1);
5111
+ a = broadcast(a, a.shape.concat(n), [-2]);
5112
+ return moveaxis$1(tril(a).sum(-1), -1, axis);
5113
+ }
5114
+ /** @function Alternative name for `jax.numpy.cumsum()`. */
5115
+ const cumulativeSum = cumsum;
4958
5116
  /** Reverse the elements in an array along the given axes. */
4959
5117
  function flip(x, axis = null) {
4960
5118
  const nd = ndim(x);
@@ -5190,7 +5348,10 @@ function allclose(actual, expected, options) {
5190
5348
  if (!require_backend.deepEqual(x.shape, y.shape)) return false;
5191
5349
  const xData = x.dataSync();
5192
5350
  const yData = y.dataSync();
5193
- for (let i = 0; i < xData.length; i++) if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
5351
+ for (let i = 0; i < xData.length; i++) {
5352
+ if (isNaN(xData[i]) !== isNaN(yData[i])) return false;
5353
+ if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
5354
+ }
5194
5355
  return true;
5195
5356
  }
5196
5357
  /** Matrix product of two arrays. */
@@ -5649,7 +5810,10 @@ const degrees = rad2deg;
5649
5810
  * Computes first array raised to power of second array, element-wise.
5650
5811
  */
5651
5812
  const power = jit$1(function power$1(x1, x2) {
5652
- return exp(log(x1).mul(x2));
5813
+ const x2i = trunc(x2.ref);
5814
+ const shouldBeNaN = multiply(x2.ref.notEqual(x2i.ref), x1.ref.less(0));
5815
+ const resultSign = where(mod(x2i, 2).notEqual(0), where(x1.ref.less(0), -1, 1), 1);
5816
+ return where(shouldBeNaN, nan, exp(log(abs(x1)).mul(x2)).mul(resultSign));
5653
5817
  });
5654
5818
  /** @function Alias of `jax.numpy.power()`. */
5655
5819
  const pow = power;
@@ -6005,22 +6169,22 @@ function logSoftmax(x, axis = -1) {
6005
6169
  *
6006
6170
  * Reference: https://en.wikipedia.org/wiki/LogSumExp
6007
6171
  */
6008
- function logsumexp(x, axis = null) {
6172
+ function logsumexp(x, axis = null, opts) {
6009
6173
  x = fudgeArray(x);
6010
6174
  axis = require_backend.normalizeAxis(axis, x.ndim);
6011
6175
  if (axis.length === 0) return x;
6012
- const xMax = stopGradient(max(x.ref, axis));
6013
- const xMaxDims = broadcast(xMax.ref, x.shape, axis);
6014
- const shifted = x.sub(xMaxDims);
6015
- return xMax.add(log(exp(shifted).sum(axis)));
6176
+ const xMax = stopGradient(max(x.ref, axis, { keepdims: true }));
6177
+ const shifted = x.sub(xMax.ref);
6178
+ const result = xMax.add(log(exp(shifted).sum(axis, { keepdims: true })));
6179
+ return opts?.keepdims ? result : squeeze(result, axis);
6016
6180
  }
6017
6181
  /** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
6018
- function logmeanexp(x, axis = null) {
6182
+ function logmeanexp(x, axis = null, opts) {
6019
6183
  x = fudgeArray(x);
6020
6184
  axis = require_backend.normalizeAxis(axis, x.ndim);
6021
6185
  if (axis.length === 0) return x;
6022
6186
  const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
6023
- return logsumexp(x, axis).sub(Math.log(n));
6187
+ return logsumexp(x, axis, opts).sub(Math.log(n));
6024
6188
  }
6025
6189
  /**
6026
6190
  * Standardizes input to zero mean and unit variance.