@jax-js/jax 0.1.2 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +11 -32
- package/dist/{backend-BqymqzuU.js → backend-BY8wlLEl.js} +58 -20
- package/dist/{backend-DeVfWEFS.cjs → backend-CmaidnkQ.cjs} +58 -20
- package/dist/index.cjs +298 -134
- package/dist/index.d.cts +21 -5
- package/dist/index.d.ts +21 -5
- package/dist/index.js +298 -134
- package/dist/{webgpu-CcGP160M.cjs → webgpu-BVns4DbI.cjs} +14 -6
- package/dist/{webgpu-BGuG58KZ.js → webgpu-C9iAP5h5.js} +14 -6
- package/package.json +1 -1
package/dist/index.cjs
CHANGED
|
@@ -30,30 +30,38 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-CmaidnkQ.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
37
37
|
* Check that the shapes and parameters passed to convolution are valid.
|
|
38
|
+
* Expected shapes of the lhs and rhs of the convolution are:
|
|
39
|
+
*
|
|
40
|
+
* - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
|
|
41
|
+
* - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
|
|
38
42
|
*
|
|
39
43
|
* If the check succeeds, returns the output shape.
|
|
40
44
|
*/
|
|
41
|
-
function checkConvShape(lhsShape, rhsShape, { strides, padding, lhsDilation, rhsDilation }) {
|
|
45
|
+
function checkConvShape(lhsShape, rhsShape, { vmapDims, strides, padding, lhsDilation, rhsDilation }) {
|
|
42
46
|
if (lhsShape.length !== rhsShape.length) throw new Error(`conv() requires inputs with the same number of dimensions, got ${lhsShape.length} and ${rhsShape.length}`);
|
|
43
|
-
const n = lhsShape.length - 2;
|
|
47
|
+
const n = lhsShape.length - 2 - vmapDims;
|
|
44
48
|
if (n < 0) throw new Error("conv() requires at least 2D inputs");
|
|
45
49
|
if (strides.length !== n) throw new Error("conv() strides != spatial dims");
|
|
46
50
|
if (padding.length !== n) throw new Error("conv() padding != spatial dims");
|
|
47
51
|
if (lhsDilation.length !== n) throw new Error("conv() lhsDilation != spatial dimensions");
|
|
48
52
|
if (rhsDilation.length !== n) throw new Error("conv() rhsDilation != spatial dimensions");
|
|
49
|
-
if (lhsShape[1] !== rhsShape[1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
|
|
50
|
-
const outShape = [
|
|
53
|
+
if (lhsShape[vmapDims + 1] !== rhsShape[vmapDims + 1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
|
|
54
|
+
const outShape = [
|
|
55
|
+
...require_backend.generalBroadcast(lhsShape.slice(0, vmapDims), rhsShape.slice(0, vmapDims)),
|
|
56
|
+
lhsShape[vmapDims],
|
|
57
|
+
rhsShape[vmapDims]
|
|
58
|
+
];
|
|
51
59
|
for (let i = 0; i < n; i++) {
|
|
52
60
|
if (strides[i] <= 0 || !Number.isInteger(strides[i])) throw new Error(`conv() strides[${i}] must be a positive integer`);
|
|
53
61
|
if (padding[i].length !== 2 || !padding[i].every(Number.isInteger)) throw new Error(`conv() padding[${i}] must be a 2-tuple of integers`);
|
|
54
62
|
if (lhsDilation[i] <= 0 || !Number.isInteger(lhsDilation[i])) throw new Error(`conv() lhsDilation[${i}] must be a positive integer`);
|
|
55
63
|
if (rhsDilation[i] <= 0 || !Number.isInteger(rhsDilation[i])) throw new Error(`conv() rhsDilation[${i}] must be a positive integer`);
|
|
56
|
-
const [x, k] = [lhsShape[i + 2], rhsShape[i + 2]];
|
|
64
|
+
const [x, k] = [lhsShape[i + vmapDims + 2], rhsShape[i + vmapDims + 2]];
|
|
57
65
|
if (k <= 0) throw new Error("conv() kernel size must be positive");
|
|
58
66
|
const [pl, pr] = padding[i];
|
|
59
67
|
if (pl < -x || pr < -x || pl + pr < -x) throw new Error(`conv() padding[${i}]=(${pl},${pr}) is too negative for input size ${x}`);
|
|
@@ -178,27 +186,13 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
|
|
|
178
186
|
function applyDilation(st, dilation) {
|
|
179
187
|
if (dilation.every((s) => s === 1)) return st;
|
|
180
188
|
const s_ = dilation;
|
|
181
|
-
const
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
]);
|
|
187
|
-
st = st.
|
|
188
|
-
[0, 0],
|
|
189
|
-
[0, 0],
|
|
190
|
-
...s_.flatMap((s) => [[0, 0], [0, s - 1]])
|
|
191
|
-
]);
|
|
192
|
-
st = st.reshape([
|
|
193
|
-
a,
|
|
194
|
-
b,
|
|
195
|
-
...k_.map((k, i) => k * s_[i])
|
|
196
|
-
]);
|
|
197
|
-
st = st.shrink([
|
|
198
|
-
[0, a],
|
|
199
|
-
[0, b],
|
|
200
|
-
...k_.map((k, i) => [0, (k - 1) * s_[i] + 1])
|
|
201
|
-
]);
|
|
189
|
+
const n = s_.length;
|
|
190
|
+
const prefix = st.shape.slice(0, -n);
|
|
191
|
+
const k_ = st.shape.slice(-n);
|
|
192
|
+
st = st.reshape([...prefix, ...k_.flatMap((k) => [k, 1])]);
|
|
193
|
+
st = st.pad([...prefix.map(() => [0, 0]), ...s_.flatMap((s) => [[0, 0], [0, s - 1]])]);
|
|
194
|
+
st = st.reshape([...prefix, ...k_.map((k, i) => k * s_[i])]);
|
|
195
|
+
st = st.shrink([...prefix.map((p) => [0, p]), ...k_.map((k, i) => [0, (k - 1) * s_[i] + 1])]);
|
|
202
196
|
return st;
|
|
203
197
|
}
|
|
204
198
|
/**
|
|
@@ -208,25 +202,26 @@ function applyDilation(st, dilation) {
|
|
|
208
202
|
* beforehand using `checkConvShape()`.
|
|
209
203
|
*/
|
|
210
204
|
function prepareConv(stX, stY, params) {
|
|
211
|
-
const
|
|
205
|
+
const v = params.vmapDims;
|
|
206
|
+
const n = stX.shape.length - 2 - v;
|
|
207
|
+
const vmapShape = stX.shape.slice(0, v);
|
|
212
208
|
stX = applyDilation(stX, params.lhsDilation);
|
|
213
|
-
const ks = stY.shape.slice(2);
|
|
214
|
-
stX = stX.padOrShrink([
|
|
215
|
-
[0, 0],
|
|
216
|
-
[0, 0],
|
|
217
|
-
...params.padding
|
|
218
|
-
]);
|
|
209
|
+
const ks = stY.shape.slice(v + 2);
|
|
210
|
+
stX = stX.padOrShrink([...require_backend.rep(v + 2, [0, 0]), ...params.padding]);
|
|
219
211
|
stX = pool(stX, ks, params.strides, params.rhsDilation);
|
|
220
|
-
stX = stX.moveaxis(1, n + 1).reshape([
|
|
221
|
-
|
|
212
|
+
stX = stX.moveaxis(v + 1, v + n + 1).reshape([
|
|
213
|
+
...vmapShape,
|
|
214
|
+
stX.shape[v],
|
|
222
215
|
1,
|
|
223
|
-
...stX.shape.slice(2, n + 2),
|
|
224
|
-
stX.shape[1] * require_backend.prod(ks)
|
|
216
|
+
...stX.shape.slice(v + 2, v + n + 2),
|
|
217
|
+
stX.shape[v + 1] * require_backend.prod(ks)
|
|
225
218
|
]);
|
|
226
219
|
stY = stY.reshape([
|
|
227
|
-
|
|
220
|
+
...vmapShape,
|
|
221
|
+
1,
|
|
222
|
+
stY.shape[v],
|
|
228
223
|
...require_backend.rep(n, 1),
|
|
229
|
-
stY.shape[1] * require_backend.prod(ks)
|
|
224
|
+
stY.shape[v + 1] * require_backend.prod(ks)
|
|
230
225
|
]);
|
|
231
226
|
return [stX, stY];
|
|
232
227
|
}
|
|
@@ -498,9 +493,11 @@ function dot$2(x, y) {
|
|
|
498
493
|
}
|
|
499
494
|
function conv$1(x, y, params = {}) {
|
|
500
495
|
if (x.ndim !== y.ndim) throw new Error(`conv() requires inputs with the same number of dimensions, got ${x.ndim} and ${y.ndim}`);
|
|
501
|
-
const
|
|
496
|
+
const vmapDims = params.vmapDims ?? 0;
|
|
497
|
+
const n = x.ndim - 2 - vmapDims;
|
|
502
498
|
if (n < 0) throw new Error("conv() requires at least 2D inputs");
|
|
503
499
|
return bind1(Primitive.Conv, [x, y], {
|
|
500
|
+
vmapDims,
|
|
504
501
|
strides: params.strides ?? require_backend.rep(n, 1),
|
|
505
502
|
padding: params.padding ?? require_backend.rep(n, [0, 0]),
|
|
506
503
|
lhsDilation: params.lhsDilation ?? require_backend.rep(n, 1),
|
|
@@ -724,8 +721,10 @@ var Tracer = class Tracer {
|
|
|
724
721
|
axis = require_backend.normalizeAxis(axis, this.ndim);
|
|
725
722
|
const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
|
|
726
723
|
if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
|
|
727
|
-
const
|
|
728
|
-
|
|
724
|
+
const originalDtype = this.dtype;
|
|
725
|
+
const castDtype = require_backend.promoteTypes(originalDtype, require_backend.DType.Float32);
|
|
726
|
+
const result = reduce(this.astype(castDtype), require_backend.AluOp.Add, axis, opts);
|
|
727
|
+
return result.mul(1 / n).astype(originalDtype);
|
|
729
728
|
}
|
|
730
729
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
731
730
|
transpose(perm) {
|
|
@@ -1205,7 +1204,7 @@ var Jaxpr = class Jaxpr {
|
|
|
1205
1204
|
} else if (eqn.primitive === Primitive.Idiv) {
|
|
1206
1205
|
const [a, b] = inputs;
|
|
1207
1206
|
const c = eqn.outBinders[0];
|
|
1208
|
-
if (atomIsLit(b, 1)) context.set(c, a);
|
|
1207
|
+
if (atomIsLit(b, 1) && !require_backend.isFloatDtype(a.aval.dtype)) context.set(c, a);
|
|
1209
1208
|
else newEqns.push(eqn);
|
|
1210
1209
|
} else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && require_backend.deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
|
|
1211
1210
|
else newEqns.push(eqn);
|
|
@@ -1790,48 +1789,73 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1790
1789
|
const inputExps = [];
|
|
1791
1790
|
const inputAvals = [];
|
|
1792
1791
|
const inputArgs = [];
|
|
1793
|
-
|
|
1794
|
-
|
|
1795
|
-
|
|
1796
|
-
|
|
1797
|
-
|
|
1798
|
-
|
|
1799
|
-
|
|
1800
|
-
|
|
1801
|
-
inputArgs.push(jitId);
|
|
1802
|
-
}
|
|
1803
|
-
gidMap.set(gid, newGid);
|
|
1804
|
-
}
|
|
1805
|
-
inputExps.push(jitValue.exp.reindexGids(gidMap));
|
|
1806
|
-
} else if (jitValue.type === "imm") {
|
|
1807
|
-
let gid = inputArgs.indexOf(jitValue.arg);
|
|
1808
|
-
if (gid === -1) {
|
|
1809
|
-
gid = inputArgs.length;
|
|
1810
|
-
inputArgs.push(jitValue.arg);
|
|
1792
|
+
let inputReduction = null;
|
|
1793
|
+
const addArgs = (args) => {
|
|
1794
|
+
const newGids = [];
|
|
1795
|
+
for (const jitId of args) {
|
|
1796
|
+
let newGid = inputArgs.indexOf(jitId);
|
|
1797
|
+
if (newGid === -1) {
|
|
1798
|
+
newGid = inputArgs.length;
|
|
1799
|
+
inputArgs.push(jitId);
|
|
1811
1800
|
}
|
|
1801
|
+
newGids.push(newGid);
|
|
1802
|
+
}
|
|
1803
|
+
return newGids;
|
|
1804
|
+
};
|
|
1805
|
+
for (const input of eqn.inputs) if (input instanceof Var) {
|
|
1806
|
+
const jv = ctx.get(input);
|
|
1807
|
+
if (jv.type === "exp") {
|
|
1808
|
+
const newGids = addArgs(jv.args);
|
|
1809
|
+
inputExps.push(jv.exp.reindexGids(newGids));
|
|
1810
|
+
} else if (jv.type === "imm") {
|
|
1811
|
+
const [gid] = addArgs([jv.arg]);
|
|
1812
1812
|
const st = require_backend.ShapeTracker.fromShape(input.aval.shape);
|
|
1813
1813
|
const indices = require_backend.unravelAlu(st.shape, require_backend.AluVar.gidx);
|
|
1814
1814
|
inputExps.push(require_backend.AluExp.globalView(input.aval.dtype, gid, st, indices));
|
|
1815
|
+
} else if (jv.type === "red") {
|
|
1816
|
+
if (inputReduction) throw new Error("jit: unexpected, multiple red inputs");
|
|
1817
|
+
const newGids = addArgs(jv.args);
|
|
1818
|
+
inputExps.push(jv.reduction.epilogue.reindexGids(newGids));
|
|
1819
|
+
inputReduction = jv;
|
|
1815
1820
|
}
|
|
1816
1821
|
inputAvals.push(input.aval);
|
|
1817
1822
|
} else if (input instanceof Lit) {
|
|
1818
1823
|
inputExps.push(require_backend.AluExp.const(input.dtype, input.value));
|
|
1819
1824
|
inputAvals.push(input.aval);
|
|
1820
1825
|
} else throw new TypeError(`Unexpected input in Jaxpr: ${input}`);
|
|
1821
|
-
const nargs$1 = inputArgs.length;
|
|
1822
1826
|
const rule = jitRules[eqn.primitive];
|
|
1823
1827
|
if (!rule) throw new TypeError(`JIT not implemented for primitive ${eqn.primitive}`);
|
|
1824
|
-
|
|
1828
|
+
let exp$2;
|
|
1829
|
+
let reduction;
|
|
1830
|
+
if (inputReduction) {
|
|
1831
|
+
const jv = inputReduction;
|
|
1832
|
+
const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
|
|
1833
|
+
exp$2 = jv.exp.reindexGids(addArgs(jv.args));
|
|
1834
|
+
reduction = new require_backend.Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
|
|
1835
|
+
} else {
|
|
1836
|
+
const ruleOutput = rule(inputExps, inputAvals, eqn.params);
|
|
1837
|
+
exp$2 = ruleOutput.exp;
|
|
1838
|
+
reduction = ruleOutput.reduction;
|
|
1839
|
+
}
|
|
1825
1840
|
const outVar = eqn.outBinders[0];
|
|
1826
|
-
if (
|
|
1841
|
+
if (blackNodes.has(outVar)) {
|
|
1842
|
+
const nargs$1 = inputArgs.length;
|
|
1843
|
+
const size$1 = require_backend.prod(outVar.aval.shape);
|
|
1844
|
+
const kernel = new require_backend.Kernel(nargs$1, size$1, exp$2, reduction);
|
|
1827
1845
|
const outId = builder.pushKernel(kernel, inputArgs);
|
|
1828
1846
|
ctx.set(outVar, {
|
|
1829
1847
|
type: "imm",
|
|
1830
1848
|
arg: outId
|
|
1831
1849
|
});
|
|
1832
|
-
} else ctx.set(outVar, {
|
|
1850
|
+
} else if (reduction) ctx.set(outVar, {
|
|
1851
|
+
type: "red",
|
|
1852
|
+
exp: exp$2,
|
|
1853
|
+
reduction,
|
|
1854
|
+
args: inputArgs
|
|
1855
|
+
});
|
|
1856
|
+
else ctx.set(outVar, {
|
|
1833
1857
|
type: "exp",
|
|
1834
|
-
exp:
|
|
1858
|
+
exp: exp$2,
|
|
1835
1859
|
args: inputArgs
|
|
1836
1860
|
});
|
|
1837
1861
|
}
|
|
@@ -1863,31 +1887,28 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
|
|
|
1863
1887
|
});
|
|
1864
1888
|
}
|
|
1865
1889
|
function broadcastedJit(fn, opts) {
|
|
1866
|
-
return (
|
|
1890
|
+
return (exps, avals, params) => {
|
|
1867
1891
|
let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
|
|
1868
1892
|
const skipCastIdx = opts?.skipCastIdx ?? [];
|
|
1869
1893
|
if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
|
|
1870
|
-
exps = exps.map((exp$
|
|
1871
|
-
exp$
|
|
1894
|
+
exps = exps.map((exp$2, i) => {
|
|
1895
|
+
exp$2 = reshapeViews(exp$2, (st) => {
|
|
1872
1896
|
if (!require_backend.deepEqual(st.shape, newShape)) return st.broadcast(newShape, require_backend.range(newShape.length - st.shape.length));
|
|
1873
1897
|
});
|
|
1874
|
-
if (exp$
|
|
1875
|
-
return exp$
|
|
1898
|
+
if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = require_backend.AluExp.cast(newDtype, exp$2);
|
|
1899
|
+
return exp$2;
|
|
1876
1900
|
});
|
|
1877
|
-
|
|
1878
|
-
return new require_backend.Kernel(nargs, require_backend.prod(newShape), exp$2);
|
|
1901
|
+
return { exp: fn(exps, params) };
|
|
1879
1902
|
};
|
|
1880
1903
|
}
|
|
1881
1904
|
function unopJit(fn) {
|
|
1882
|
-
return (
|
|
1883
|
-
return
|
|
1905
|
+
return ([a], [_as], params) => {
|
|
1906
|
+
return { exp: fn(a, params) };
|
|
1884
1907
|
};
|
|
1885
1908
|
}
|
|
1886
1909
|
function reshapeJit(fn) {
|
|
1887
|
-
return (
|
|
1888
|
-
|
|
1889
|
-
const newShape = fn(require_backend.ShapeTracker.fromShape(as.shape), params).shape;
|
|
1890
|
-
return new require_backend.Kernel(nargs, require_backend.prod(newShape), a);
|
|
1910
|
+
return ([a], [_as], params) => {
|
|
1911
|
+
return { exp: reshapeViews(a, (st) => fn(st, params)) };
|
|
1891
1912
|
};
|
|
1892
1913
|
}
|
|
1893
1914
|
const jitRules = {
|
|
@@ -1902,7 +1923,7 @@ const jitRules = {
|
|
|
1902
1923
|
[Primitive.StopGradient]: unopJit((a) => a),
|
|
1903
1924
|
[Primitive.Cast]: unopJit((a, { dtype }) => require_backend.AluExp.cast(dtype, a)),
|
|
1904
1925
|
[Primitive.Bitcast]: unopJit((a, { dtype }) => require_backend.AluExp.bitcast(dtype, a)),
|
|
1905
|
-
[Primitive.RandomBits]: (
|
|
1926
|
+
[Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
|
|
1906
1927
|
const mapping = (st) => {
|
|
1907
1928
|
if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(shape$1.length - st.shape.length));
|
|
1908
1929
|
};
|
|
@@ -1911,7 +1932,7 @@ const jitRules = {
|
|
|
1911
1932
|
const c0 = require_backend.AluExp.u32(0);
|
|
1912
1933
|
const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
|
|
1913
1934
|
const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
1914
|
-
return
|
|
1935
|
+
return { exp: exp$2 };
|
|
1915
1936
|
},
|
|
1916
1937
|
[Primitive.Sin]: unopJit(require_backend.AluExp.sin),
|
|
1917
1938
|
[Primitive.Cos]: unopJit(require_backend.AluExp.cos),
|
|
@@ -1924,7 +1945,7 @@ const jitRules = {
|
|
|
1924
1945
|
[Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
|
|
1925
1946
|
[Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
|
|
1926
1947
|
[Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
|
|
1927
|
-
[Primitive.Reduce](
|
|
1948
|
+
[Primitive.Reduce]([a], [as], { op, axis }) {
|
|
1928
1949
|
const keptAxes = [];
|
|
1929
1950
|
const shiftedAxes = [];
|
|
1930
1951
|
const newShape = [];
|
|
@@ -1933,39 +1954,43 @@ const jitRules = {
|
|
|
1933
1954
|
keptAxes.push(i);
|
|
1934
1955
|
newShape.push(as.shape[i]);
|
|
1935
1956
|
}
|
|
1936
|
-
const size$1 = require_backend.prod(newShape);
|
|
1937
1957
|
const reductionSize = require_backend.prod(shiftedAxes.map((ax) => as.shape[ax]));
|
|
1938
1958
|
newShape.push(reductionSize);
|
|
1939
1959
|
const perm = keptAxes.concat(shiftedAxes);
|
|
1940
1960
|
a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
|
|
1941
1961
|
const reduction = new require_backend.Reduction(a.dtype, op, reductionSize);
|
|
1942
|
-
return
|
|
1962
|
+
return {
|
|
1963
|
+
exp: a,
|
|
1964
|
+
reduction
|
|
1965
|
+
};
|
|
1943
1966
|
},
|
|
1944
1967
|
[Primitive.Pool]: reshapeJit((st, { window, strides }) => pool(st, window, strides)),
|
|
1945
|
-
[Primitive.PoolTranspose](
|
|
1968
|
+
[Primitive.PoolTranspose]([a], [as], { inShape, window, strides }) {
|
|
1946
1969
|
let stX = poolTranspose(require_backend.ShapeTracker.fromShape(as.shape), inShape, window, strides);
|
|
1947
|
-
const size$1 = require_backend.prod(inShape);
|
|
1948
1970
|
stX = stX.reshape([...inShape, require_backend.prod(stX.shape.slice(inShape.length))]);
|
|
1949
1971
|
a = reshapeViews(a, (st) => st.compose(stX), true);
|
|
1950
1972
|
const reduction = new require_backend.Reduction(a.dtype, require_backend.AluOp.Add, stX.shape[stX.shape.length - 1]);
|
|
1951
|
-
return
|
|
1973
|
+
return {
|
|
1974
|
+
exp: a,
|
|
1975
|
+
reduction
|
|
1976
|
+
};
|
|
1952
1977
|
},
|
|
1953
|
-
[Primitive.Dot](
|
|
1954
|
-
const k1 = jitRules[Primitive.Mul](
|
|
1978
|
+
[Primitive.Dot]([a, b], [as, bs]) {
|
|
1979
|
+
const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
|
|
1955
1980
|
const c = k1.exp;
|
|
1956
1981
|
const cs = promoteAvals(as, bs);
|
|
1957
|
-
return jitRules[Primitive.Reduce](
|
|
1982
|
+
return jitRules[Primitive.Reduce]([c], [cs], {
|
|
1958
1983
|
op: require_backend.AluOp.Add,
|
|
1959
1984
|
axis: [cs.ndim - 1]
|
|
1960
1985
|
});
|
|
1961
1986
|
},
|
|
1962
|
-
[Primitive.Conv](
|
|
1987
|
+
[Primitive.Conv]([a, b], [as, bs], params) {
|
|
1963
1988
|
const [stX, stY] = prepareConv(require_backend.ShapeTracker.fromShape(as.shape), require_backend.ShapeTracker.fromShape(bs.shape), params);
|
|
1964
1989
|
a = reshapeViews(a, (st) => st.compose(stX));
|
|
1965
1990
|
b = reshapeViews(b, (st) => st.compose(stY));
|
|
1966
1991
|
as = new ShapedArray(stX.shape, as.dtype, as.weakType);
|
|
1967
1992
|
bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
|
|
1968
|
-
return jitRules[Primitive.Dot](
|
|
1993
|
+
return jitRules[Primitive.Dot]([a, b], [as, bs], {});
|
|
1969
1994
|
},
|
|
1970
1995
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
1971
1996
|
[Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
@@ -1979,7 +2004,7 @@ const jitRules = {
|
|
|
1979
2004
|
}),
|
|
1980
2005
|
[Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
|
|
1981
2006
|
[Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
|
|
1982
|
-
[Primitive.Gather](
|
|
2007
|
+
[Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
1983
2008
|
const axisSet = new Set(axis);
|
|
1984
2009
|
const indexShape = indicesShapes.map((c) => c.shape).reduce(require_backend.generalBroadcast);
|
|
1985
2010
|
const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
|
|
@@ -1992,7 +2017,7 @@ const jitRules = {
|
|
|
1992
2017
|
for (const [i, iexp] of indices.entries()) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...require_backend.range(outDim + indexShape.length - st.shape.length), ...require_backend.range(outDim + indexShape.length, finalShape.length)])));
|
|
1993
2018
|
const [index, valid] = require_backend.ShapeTracker.fromShape(xs.shape).toAluExp(src);
|
|
1994
2019
|
if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
|
|
1995
|
-
return
|
|
2020
|
+
return { exp: x.substitute({ gidx: index }) };
|
|
1996
2021
|
},
|
|
1997
2022
|
[Primitive.JitCall]() {
|
|
1998
2023
|
throw new Error("internal: JitCall should have been flattened before JIT compilation");
|
|
@@ -2000,16 +2025,16 @@ const jitRules = {
|
|
|
2000
2025
|
};
|
|
2001
2026
|
/** Determines how to split the Jaxpr into kernels via dataflow analysis. */
|
|
2002
2027
|
function splitGraphDataflow(backend, jaxpr) {
|
|
2003
|
-
const
|
|
2028
|
+
const varToDefn = /* @__PURE__ */ new Map();
|
|
2029
|
+
const varToUsages = /* @__PURE__ */ new Map();
|
|
2004
2030
|
for (let i = 0; i < jaxpr.eqns.length; i++) {
|
|
2005
2031
|
const eqn = jaxpr.eqns[i];
|
|
2006
|
-
for (const v of eqn.outBinders) if (v instanceof Var)
|
|
2007
|
-
|
|
2008
|
-
|
|
2009
|
-
|
|
2010
|
-
|
|
2011
|
-
|
|
2012
|
-
p1NextBlack.set(v, v);
|
|
2032
|
+
for (const v of eqn.outBinders) if (v instanceof Var) varToDefn.set(v, i);
|
|
2033
|
+
for (const input of eqn.inputs) if (input instanceof Var) {
|
|
2034
|
+
const usages = varToUsages.get(input);
|
|
2035
|
+
if (usages) usages.push(i);
|
|
2036
|
+
else varToUsages.set(input, [i]);
|
|
2037
|
+
}
|
|
2013
2038
|
}
|
|
2014
2039
|
const reducePrimitives = [
|
|
2015
2040
|
Primitive.Reduce,
|
|
@@ -2017,10 +2042,68 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2017
2042
|
Primitive.Conv,
|
|
2018
2043
|
Primitive.PoolTranspose
|
|
2019
2044
|
];
|
|
2045
|
+
const reductionEpilogueEqns = /* @__PURE__ */ new Set();
|
|
2046
|
+
const reductionEndpointEqns = /* @__PURE__ */ new Set();
|
|
2047
|
+
for (let i = 0; i < jaxpr.eqns.length; i++) {
|
|
2048
|
+
const eqn = jaxpr.eqns[i];
|
|
2049
|
+
if (reducePrimitives.includes(eqn.primitive)) {
|
|
2050
|
+
let head = i;
|
|
2051
|
+
while (true) {
|
|
2052
|
+
reductionEpilogueEqns.add(head);
|
|
2053
|
+
const outVar = jaxpr.eqns[head].outBinders[0];
|
|
2054
|
+
const usages = varToUsages.get(outVar) ?? [];
|
|
2055
|
+
if (jaxpr.outs.includes(outVar) || usages.length !== 1) break;
|
|
2056
|
+
if (reductionEpilogueEqns.has(usages[0])) break;
|
|
2057
|
+
const nextEqn = jaxpr.eqns[usages[0]];
|
|
2058
|
+
switch (nextEqn.primitive) {
|
|
2059
|
+
case Primitive.Neg:
|
|
2060
|
+
case Primitive.Reciprocal:
|
|
2061
|
+
case Primitive.Floor:
|
|
2062
|
+
case Primitive.Ceil:
|
|
2063
|
+
case Primitive.StopGradient:
|
|
2064
|
+
case Primitive.Cast:
|
|
2065
|
+
case Primitive.Bitcast:
|
|
2066
|
+
case Primitive.Sin:
|
|
2067
|
+
case Primitive.Cos:
|
|
2068
|
+
case Primitive.Asin:
|
|
2069
|
+
case Primitive.Atan:
|
|
2070
|
+
case Primitive.Exp:
|
|
2071
|
+
case Primitive.Log:
|
|
2072
|
+
case Primitive.Erf:
|
|
2073
|
+
case Primitive.Erfc:
|
|
2074
|
+
case Primitive.Sqrt:
|
|
2075
|
+
head = usages[0];
|
|
2076
|
+
continue;
|
|
2077
|
+
case Primitive.Add:
|
|
2078
|
+
case Primitive.Mul:
|
|
2079
|
+
case Primitive.Idiv:
|
|
2080
|
+
case Primitive.Mod:
|
|
2081
|
+
case Primitive.Max:
|
|
2082
|
+
case Primitive.Min: {
|
|
2083
|
+
const otherInput = nextEqn.inputs.find((v) => v !== outVar);
|
|
2084
|
+
if (otherInput instanceof Lit || require_backend.deepEqual(require_backend.generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
|
|
2085
|
+
head = usages[0];
|
|
2086
|
+
continue;
|
|
2087
|
+
}
|
|
2088
|
+
break;
|
|
2089
|
+
}
|
|
2090
|
+
}
|
|
2091
|
+
break;
|
|
2092
|
+
}
|
|
2093
|
+
reductionEndpointEqns.add(head);
|
|
2094
|
+
}
|
|
2095
|
+
}
|
|
2096
|
+
const blackNodes = /* @__PURE__ */ new Set();
|
|
2097
|
+
const p1NextBlack = /* @__PURE__ */ new Map();
|
|
2098
|
+
for (const v of jaxpr.outs) if (v instanceof Var) {
|
|
2099
|
+
blackNodes.add(v);
|
|
2100
|
+
p1NextBlack.set(v, v);
|
|
2101
|
+
}
|
|
2020
2102
|
const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
|
|
2103
|
+
const needsCleanShapePrimitives = [Primitive.Pad];
|
|
2021
2104
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
2022
2105
|
const eqn = jaxpr.eqns[i];
|
|
2023
|
-
if (
|
|
2106
|
+
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
2024
2107
|
for (const v of eqn.outBinders) {
|
|
2025
2108
|
blackNodes.add(v);
|
|
2026
2109
|
p1NextBlack.set(v, v);
|
|
@@ -2028,17 +2111,25 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2028
2111
|
continue;
|
|
2029
2112
|
}
|
|
2030
2113
|
const reach = /* @__PURE__ */ new Set();
|
|
2031
|
-
|
|
2032
|
-
|
|
2033
|
-
if (
|
|
2114
|
+
let needsCleanOutput = false;
|
|
2115
|
+
outer: for (const v of eqn.outBinders) for (const j of varToUsages.get(v) ?? []) {
|
|
2116
|
+
if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive)) {
|
|
2117
|
+
needsCleanOutput = true;
|
|
2118
|
+
break outer;
|
|
2119
|
+
}
|
|
2120
|
+
for (const o of jaxpr.eqns[j].outBinders) {
|
|
2121
|
+
const u = p1NextBlack.get(o);
|
|
2122
|
+
if (u) reach.add(u);
|
|
2123
|
+
}
|
|
2034
2124
|
}
|
|
2035
|
-
if (reach.size
|
|
2036
|
-
const b = reach.values().next().value;
|
|
2037
|
-
for (const v of eqn.outBinders) p1NextBlack.set(v, b);
|
|
2038
|
-
} else if (reach.size > 1) for (const v of eqn.outBinders) {
|
|
2125
|
+
if (reach.size > 1 || needsCleanOutput) for (const v of eqn.outBinders) {
|
|
2039
2126
|
blackNodes.add(v);
|
|
2040
2127
|
p1NextBlack.set(v, v);
|
|
2041
2128
|
}
|
|
2129
|
+
else if (reach.size === 1) {
|
|
2130
|
+
const b = reach.values().next().value;
|
|
2131
|
+
for (const v of eqn.outBinders) p1NextBlack.set(v, b);
|
|
2132
|
+
}
|
|
2042
2133
|
}
|
|
2043
2134
|
const p2Deps = /* @__PURE__ */ new Map();
|
|
2044
2135
|
for (const v of jaxpr.inBinders) p2Deps.set(v, new Set([v]));
|
|
@@ -2057,7 +2148,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2057
2148
|
let assocInput = -1;
|
|
2058
2149
|
for (let i = 0; i < eqn.inputs.length; i++) {
|
|
2059
2150
|
const input = eqn.inputs[i];
|
|
2060
|
-
if (input instanceof Var &&
|
|
2151
|
+
if (input instanceof Var && varToDefn.has(input)) {
|
|
2061
2152
|
let uniqueDeps = 0;
|
|
2062
2153
|
for (const dep of deps[i]) if (depCounter.get(dep) === 1) uniqueDeps++;
|
|
2063
2154
|
if (uniqueDeps > maxUniqueDeps) {
|
|
@@ -2068,7 +2159,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2068
2159
|
}
|
|
2069
2160
|
if (assocInput === -1) throw new Error(`internal: maxArgs, no input found to mark as black in Jaxpr equation ${eqn}`);
|
|
2070
2161
|
const assocVar = eqn.inputs[assocInput];
|
|
2071
|
-
p2idx =
|
|
2162
|
+
p2idx = varToDefn.get(assocVar);
|
|
2072
2163
|
for (const out of jaxpr.eqns[p2idx].outBinders) blackNodes.add(out);
|
|
2073
2164
|
} else {
|
|
2074
2165
|
const s = new Set(depCounter.keys());
|
|
@@ -3497,6 +3588,15 @@ const vmapRules = {
|
|
|
3497
3588
|
const z = dot$2(x, y);
|
|
3498
3589
|
return [[z], [z.ndim - 1]];
|
|
3499
3590
|
},
|
|
3591
|
+
[Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
|
|
3592
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3593
|
+
y = moveBatchAxis(axisSize, yBdim, 0, y);
|
|
3594
|
+
const z = conv$1(x, y, {
|
|
3595
|
+
...params,
|
|
3596
|
+
vmapDims: params.vmapDims + 1
|
|
3597
|
+
});
|
|
3598
|
+
return [[z], [0]];
|
|
3599
|
+
},
|
|
3500
3600
|
[Primitive.Compare](axisSize, args, dims, { op }) {
|
|
3501
3601
|
return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
|
|
3502
3602
|
},
|
|
@@ -3941,7 +4041,7 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
|
|
|
3941
4041
|
for (const t of tracersIn) t.dispose();
|
|
3942
4042
|
for (const t of tracersOut) t.dispose();
|
|
3943
4043
|
jaxpr = jaxpr.simplify();
|
|
3944
|
-
if (require_backend.DEBUG >= 5) console.
|
|
4044
|
+
if (require_backend.DEBUG >= 5) console.info("jaxpr from partial evaluation:\n" + jaxpr.toString());
|
|
3945
4045
|
return {
|
|
3946
4046
|
jaxpr,
|
|
3947
4047
|
consts
|
|
@@ -4075,22 +4175,25 @@ const transposeRules = {
|
|
|
4075
4175
|
},
|
|
4076
4176
|
[Primitive.Conv]([ct], [lhs, rhs], params) {
|
|
4077
4177
|
if (lhs instanceof UndefPrimal === rhs instanceof UndefPrimal) throw new NonlinearError(Primitive.Conv);
|
|
4178
|
+
const v = params.vmapDims;
|
|
4078
4179
|
const rev01 = [
|
|
4079
|
-
|
|
4080
|
-
|
|
4081
|
-
|
|
4180
|
+
...require_backend.range(v),
|
|
4181
|
+
v + 1,
|
|
4182
|
+
v,
|
|
4183
|
+
...require_backend.range(v + 2, ct.ndim)
|
|
4082
4184
|
];
|
|
4083
4185
|
if (lhs instanceof UndefPrimal) {
|
|
4084
4186
|
let kernel = rhs;
|
|
4085
4187
|
kernel = transpose$1(kernel, rev01);
|
|
4086
|
-
kernel = flip$1(kernel, require_backend.range(2, kernel.ndim));
|
|
4188
|
+
kernel = flip$1(kernel, require_backend.range(v + 2, kernel.ndim));
|
|
4087
4189
|
const result = conv$1(ct, kernel, {
|
|
4190
|
+
vmapDims: v,
|
|
4088
4191
|
strides: params.lhsDilation,
|
|
4089
4192
|
padding: params.padding.map(([pl, _pr], i) => {
|
|
4090
|
-
const dilatedKernel = (kernel.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
|
|
4091
|
-
const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
|
|
4193
|
+
const dilatedKernel = (kernel.shape[i + v + 2] - 1) * params.rhsDilation[i] + 1;
|
|
4194
|
+
const dilatedCt = (ct.shape[i + v + 2] - 1) * params.strides[i] + 1;
|
|
4092
4195
|
const padBefore = dilatedKernel - 1 - pl;
|
|
4093
|
-
const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
|
|
4196
|
+
const dilatedLhs = (lhs.aval.shape[i + v + 2] - 1) * params.lhsDilation[i] + 1;
|
|
4094
4197
|
const padAfter = dilatedLhs + dilatedKernel - 1 - dilatedCt - padBefore;
|
|
4095
4198
|
return [padBefore, padAfter];
|
|
4096
4199
|
}),
|
|
@@ -4102,11 +4205,12 @@ const transposeRules = {
|
|
|
4102
4205
|
const newLhs = transpose$1(lhs, rev01);
|
|
4103
4206
|
const newRhs = transpose$1(ct, rev01);
|
|
4104
4207
|
let result = conv$1(newLhs, newRhs, {
|
|
4208
|
+
vmapDims: v,
|
|
4105
4209
|
strides: params.rhsDilation,
|
|
4106
4210
|
padding: params.padding.map(([pl, _pr], i) => {
|
|
4107
|
-
const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
|
|
4108
|
-
const dilatedKernel = (rhs.aval.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
|
|
4109
|
-
const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
|
|
4211
|
+
const dilatedLhs = (lhs.aval.shape[i + v + 2] - 1) * params.lhsDilation[i] + 1;
|
|
4212
|
+
const dilatedKernel = (rhs.aval.shape[i + v + 2] - 1) * params.rhsDilation[i] + 1;
|
|
4213
|
+
const dilatedCt = (ct.shape[i + v + 2] - 1) * params.strides[i] + 1;
|
|
4110
4214
|
const padFromLhs = dilatedCt - dilatedLhs;
|
|
4111
4215
|
const padFromRhs = dilatedKernel - pl - 1;
|
|
4112
4216
|
return [pl, padFromLhs + padFromRhs];
|
|
@@ -4355,13 +4459,46 @@ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
|
4355
4459
|
*
|
|
4356
4460
|
* Grouped convolutions are not supported right now.
|
|
4357
4461
|
*/
|
|
4358
|
-
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation } = {}) {
|
|
4462
|
+
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
|
|
4359
4463
|
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
4360
4464
|
if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
|
|
4361
4465
|
if (typeof padding === "string") {
|
|
4362
4466
|
if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
|
|
4363
4467
|
padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1), padding);
|
|
4364
4468
|
}
|
|
4469
|
+
if (featureGroupCount !== 1) {
|
|
4470
|
+
const G = featureGroupCount;
|
|
4471
|
+
const [N, C_in, ...xs] = lhs.shape;
|
|
4472
|
+
const [C_out, C_in_per_group, ...ks] = rhs.shape;
|
|
4473
|
+
if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
|
|
4474
|
+
if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
|
|
4475
|
+
if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
|
|
4476
|
+
const lhsGrouped = moveaxis(lhs.reshape([
|
|
4477
|
+
N,
|
|
4478
|
+
G,
|
|
4479
|
+
C_in / G,
|
|
4480
|
+
...xs
|
|
4481
|
+
]), 1, 0);
|
|
4482
|
+
const rhsGrouped = rhs.reshape([
|
|
4483
|
+
G,
|
|
4484
|
+
C_out / G,
|
|
4485
|
+
C_in_per_group,
|
|
4486
|
+
...ks
|
|
4487
|
+
]);
|
|
4488
|
+
const result = conv$1(lhsGrouped, rhsGrouped, {
|
|
4489
|
+
vmapDims: 1,
|
|
4490
|
+
strides: windowStrides,
|
|
4491
|
+
padding,
|
|
4492
|
+
lhsDilation,
|
|
4493
|
+
rhsDilation
|
|
4494
|
+
});
|
|
4495
|
+
const ys = result.shape.slice(3);
|
|
4496
|
+
return moveaxis(result, 0, 1).reshape([
|
|
4497
|
+
N,
|
|
4498
|
+
C_out,
|
|
4499
|
+
...ys
|
|
4500
|
+
]);
|
|
4501
|
+
}
|
|
4365
4502
|
return conv$1(lhs, rhs, {
|
|
4366
4503
|
strides: windowStrides,
|
|
4367
4504
|
padding,
|
|
@@ -4647,6 +4784,8 @@ __export(numpy_exports, {
|
|
|
4647
4784
|
concatenate: () => concatenate,
|
|
4648
4785
|
cos: () => cos,
|
|
4649
4786
|
cosh: () => cosh,
|
|
4787
|
+
cumsum: () => cumsum,
|
|
4788
|
+
cumulativeSum: () => cumulativeSum,
|
|
4650
4789
|
deg2rad: () => deg2rad,
|
|
4651
4790
|
degrees: () => degrees,
|
|
4652
4791
|
diag: () => diag,
|
|
@@ -4955,6 +5094,25 @@ function argmax(a, axis, opts) {
|
|
|
4955
5094
|
}).reshape([shape$1[axis], ...require_backend.rep(shape$1.length - axis - 1, 1)]));
|
|
4956
5095
|
return length.sub(max(idx, axis, opts));
|
|
4957
5096
|
}
|
|
5097
|
+
/**
|
|
5098
|
+
* Cumulative sum of elements along an axis.
|
|
5099
|
+
*
|
|
5100
|
+
* Currently this function is `O(n^2)`, we'll improve this later on with a
|
|
5101
|
+
* two-phase parallel reduction algorithm.
|
|
5102
|
+
*/
|
|
5103
|
+
function cumsum(a, axis) {
|
|
5104
|
+
a = fudgeArray(a);
|
|
5105
|
+
if (axis === void 0) {
|
|
5106
|
+
a = a.ravel();
|
|
5107
|
+
axis = 0;
|
|
5108
|
+
} else axis = require_backend.checkAxis(axis, a.ndim);
|
|
5109
|
+
const n = a.shape[axis];
|
|
5110
|
+
a = moveaxis$1(a, axis, -1);
|
|
5111
|
+
a = broadcast(a, a.shape.concat(n), [-2]);
|
|
5112
|
+
return moveaxis$1(tril(a).sum(-1), -1, axis);
|
|
5113
|
+
}
|
|
5114
|
+
/** @function Alternative name for `jax.numpy.cumsum()`. */
|
|
5115
|
+
const cumulativeSum = cumsum;
|
|
4958
5116
|
/** Reverse the elements in an array along the given axes. */
|
|
4959
5117
|
function flip(x, axis = null) {
|
|
4960
5118
|
const nd = ndim(x);
|
|
@@ -5190,7 +5348,10 @@ function allclose(actual, expected, options) {
|
|
|
5190
5348
|
if (!require_backend.deepEqual(x.shape, y.shape)) return false;
|
|
5191
5349
|
const xData = x.dataSync();
|
|
5192
5350
|
const yData = y.dataSync();
|
|
5193
|
-
for (let i = 0; i < xData.length; i++)
|
|
5351
|
+
for (let i = 0; i < xData.length; i++) {
|
|
5352
|
+
if (isNaN(xData[i]) !== isNaN(yData[i])) return false;
|
|
5353
|
+
if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
|
|
5354
|
+
}
|
|
5194
5355
|
return true;
|
|
5195
5356
|
}
|
|
5196
5357
|
/** Matrix product of two arrays. */
|
|
@@ -5649,7 +5810,10 @@ const degrees = rad2deg;
|
|
|
5649
5810
|
* Computes first array raised to power of second array, element-wise.
|
|
5650
5811
|
*/
|
|
5651
5812
|
const power = jit$1(function power$1(x1, x2) {
|
|
5652
|
-
|
|
5813
|
+
const x2i = trunc(x2.ref);
|
|
5814
|
+
const shouldBeNaN = multiply(x2.ref.notEqual(x2i.ref), x1.ref.less(0));
|
|
5815
|
+
const resultSign = where(mod(x2i, 2).notEqual(0), where(x1.ref.less(0), -1, 1), 1);
|
|
5816
|
+
return where(shouldBeNaN, nan, exp(log(abs(x1)).mul(x2)).mul(resultSign));
|
|
5653
5817
|
});
|
|
5654
5818
|
/** @function Alias of `jax.numpy.power()`. */
|
|
5655
5819
|
const pow = power;
|
|
@@ -6005,22 +6169,22 @@ function logSoftmax(x, axis = -1) {
|
|
|
6005
6169
|
*
|
|
6006
6170
|
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
6007
6171
|
*/
|
|
6008
|
-
function logsumexp(x, axis = null) {
|
|
6172
|
+
function logsumexp(x, axis = null, opts) {
|
|
6009
6173
|
x = fudgeArray(x);
|
|
6010
6174
|
axis = require_backend.normalizeAxis(axis, x.ndim);
|
|
6011
6175
|
if (axis.length === 0) return x;
|
|
6012
|
-
const xMax = stopGradient(max(x.ref, axis));
|
|
6013
|
-
const
|
|
6014
|
-
const
|
|
6015
|
-
return
|
|
6176
|
+
const xMax = stopGradient(max(x.ref, axis, { keepdims: true }));
|
|
6177
|
+
const shifted = x.sub(xMax.ref);
|
|
6178
|
+
const result = xMax.add(log(exp(shifted).sum(axis, { keepdims: true })));
|
|
6179
|
+
return opts?.keepdims ? result : squeeze(result, axis);
|
|
6016
6180
|
}
|
|
6017
6181
|
/** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
|
|
6018
|
-
function logmeanexp(x, axis = null) {
|
|
6182
|
+
function logmeanexp(x, axis = null, opts) {
|
|
6019
6183
|
x = fudgeArray(x);
|
|
6020
6184
|
axis = require_backend.normalizeAxis(axis, x.ndim);
|
|
6021
6185
|
if (axis.length === 0) return x;
|
|
6022
6186
|
const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
|
|
6023
|
-
return logsumexp(x, axis).sub(Math.log(n));
|
|
6187
|
+
return logsumexp(x, axis, opts).sub(Math.log(n));
|
|
6024
6188
|
}
|
|
6025
6189
|
/**
|
|
6026
6190
|
* Standardizes input to zero mean and unit variance.
|