@jax-js/jax 0.1.2 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +11 -32
- package/dist/{backend-BqymqzuU.js → backend-BY8wlLEl.js} +58 -20
- package/dist/{backend-DeVfWEFS.cjs → backend-CmaidnkQ.cjs} +58 -20
- package/dist/index.cjs +298 -134
- package/dist/index.d.cts +21 -5
- package/dist/index.d.ts +21 -5
- package/dist/index.js +298 -134
- package/dist/{webgpu-CcGP160M.cjs → webgpu-BVns4DbI.cjs} +14 -6
- package/dist/{webgpu-BGuG58KZ.js → webgpu-C9iAP5h5.js} +14 -6
- package/package.json +1 -1
package/dist/index.js
CHANGED
|
@@ -1,28 +1,36 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-BY8wlLEl.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
6
6
|
* Check that the shapes and parameters passed to convolution are valid.
|
|
7
|
+
* Expected shapes of the lhs and rhs of the convolution are:
|
|
8
|
+
*
|
|
9
|
+
* - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
|
|
10
|
+
* - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
|
|
7
11
|
*
|
|
8
12
|
* If the check succeeds, returns the output shape.
|
|
9
13
|
*/
|
|
10
|
-
function checkConvShape(lhsShape, rhsShape, { strides, padding, lhsDilation, rhsDilation }) {
|
|
14
|
+
function checkConvShape(lhsShape, rhsShape, { vmapDims, strides, padding, lhsDilation, rhsDilation }) {
|
|
11
15
|
if (lhsShape.length !== rhsShape.length) throw new Error(`conv() requires inputs with the same number of dimensions, got ${lhsShape.length} and ${rhsShape.length}`);
|
|
12
|
-
const n = lhsShape.length - 2;
|
|
16
|
+
const n = lhsShape.length - 2 - vmapDims;
|
|
13
17
|
if (n < 0) throw new Error("conv() requires at least 2D inputs");
|
|
14
18
|
if (strides.length !== n) throw new Error("conv() strides != spatial dims");
|
|
15
19
|
if (padding.length !== n) throw new Error("conv() padding != spatial dims");
|
|
16
20
|
if (lhsDilation.length !== n) throw new Error("conv() lhsDilation != spatial dimensions");
|
|
17
21
|
if (rhsDilation.length !== n) throw new Error("conv() rhsDilation != spatial dimensions");
|
|
18
|
-
if (lhsShape[1] !== rhsShape[1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
|
|
19
|
-
const outShape = [
|
|
22
|
+
if (lhsShape[vmapDims + 1] !== rhsShape[vmapDims + 1]) throw new Error(`conv() input channels: ${lhsShape[vmapDims + 1]} != ${rhsShape[vmapDims + 1]}`);
|
|
23
|
+
const outShape = [
|
|
24
|
+
...generalBroadcast(lhsShape.slice(0, vmapDims), rhsShape.slice(0, vmapDims)),
|
|
25
|
+
lhsShape[vmapDims],
|
|
26
|
+
rhsShape[vmapDims]
|
|
27
|
+
];
|
|
20
28
|
for (let i = 0; i < n; i++) {
|
|
21
29
|
if (strides[i] <= 0 || !Number.isInteger(strides[i])) throw new Error(`conv() strides[${i}] must be a positive integer`);
|
|
22
30
|
if (padding[i].length !== 2 || !padding[i].every(Number.isInteger)) throw new Error(`conv() padding[${i}] must be a 2-tuple of integers`);
|
|
23
31
|
if (lhsDilation[i] <= 0 || !Number.isInteger(lhsDilation[i])) throw new Error(`conv() lhsDilation[${i}] must be a positive integer`);
|
|
24
32
|
if (rhsDilation[i] <= 0 || !Number.isInteger(rhsDilation[i])) throw new Error(`conv() rhsDilation[${i}] must be a positive integer`);
|
|
25
|
-
const [x, k] = [lhsShape[i + 2], rhsShape[i + 2]];
|
|
33
|
+
const [x, k] = [lhsShape[i + vmapDims + 2], rhsShape[i + vmapDims + 2]];
|
|
26
34
|
if (k <= 0) throw new Error("conv() kernel size must be positive");
|
|
27
35
|
const [pl, pr] = padding[i];
|
|
28
36
|
if (pl < -x || pr < -x || pl + pr < -x) throw new Error(`conv() padding[${i}]=(${pl},${pr}) is too negative for input size ${x}`);
|
|
@@ -147,27 +155,13 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
|
|
|
147
155
|
function applyDilation(st, dilation) {
|
|
148
156
|
if (dilation.every((s) => s === 1)) return st;
|
|
149
157
|
const s_ = dilation;
|
|
150
|
-
const
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
]);
|
|
156
|
-
st = st.
|
|
157
|
-
[0, 0],
|
|
158
|
-
[0, 0],
|
|
159
|
-
...s_.flatMap((s) => [[0, 0], [0, s - 1]])
|
|
160
|
-
]);
|
|
161
|
-
st = st.reshape([
|
|
162
|
-
a,
|
|
163
|
-
b,
|
|
164
|
-
...k_.map((k, i) => k * s_[i])
|
|
165
|
-
]);
|
|
166
|
-
st = st.shrink([
|
|
167
|
-
[0, a],
|
|
168
|
-
[0, b],
|
|
169
|
-
...k_.map((k, i) => [0, (k - 1) * s_[i] + 1])
|
|
170
|
-
]);
|
|
158
|
+
const n = s_.length;
|
|
159
|
+
const prefix = st.shape.slice(0, -n);
|
|
160
|
+
const k_ = st.shape.slice(-n);
|
|
161
|
+
st = st.reshape([...prefix, ...k_.flatMap((k) => [k, 1])]);
|
|
162
|
+
st = st.pad([...prefix.map(() => [0, 0]), ...s_.flatMap((s) => [[0, 0], [0, s - 1]])]);
|
|
163
|
+
st = st.reshape([...prefix, ...k_.map((k, i) => k * s_[i])]);
|
|
164
|
+
st = st.shrink([...prefix.map((p) => [0, p]), ...k_.map((k, i) => [0, (k - 1) * s_[i] + 1])]);
|
|
171
165
|
return st;
|
|
172
166
|
}
|
|
173
167
|
/**
|
|
@@ -177,25 +171,26 @@ function applyDilation(st, dilation) {
|
|
|
177
171
|
* beforehand using `checkConvShape()`.
|
|
178
172
|
*/
|
|
179
173
|
function prepareConv(stX, stY, params) {
|
|
180
|
-
const
|
|
174
|
+
const v = params.vmapDims;
|
|
175
|
+
const n = stX.shape.length - 2 - v;
|
|
176
|
+
const vmapShape = stX.shape.slice(0, v);
|
|
181
177
|
stX = applyDilation(stX, params.lhsDilation);
|
|
182
|
-
const ks = stY.shape.slice(2);
|
|
183
|
-
stX = stX.padOrShrink([
|
|
184
|
-
[0, 0],
|
|
185
|
-
[0, 0],
|
|
186
|
-
...params.padding
|
|
187
|
-
]);
|
|
178
|
+
const ks = stY.shape.slice(v + 2);
|
|
179
|
+
stX = stX.padOrShrink([...rep(v + 2, [0, 0]), ...params.padding]);
|
|
188
180
|
stX = pool(stX, ks, params.strides, params.rhsDilation);
|
|
189
|
-
stX = stX.moveaxis(1, n + 1).reshape([
|
|
190
|
-
|
|
181
|
+
stX = stX.moveaxis(v + 1, v + n + 1).reshape([
|
|
182
|
+
...vmapShape,
|
|
183
|
+
stX.shape[v],
|
|
191
184
|
1,
|
|
192
|
-
...stX.shape.slice(2, n + 2),
|
|
193
|
-
stX.shape[1] * prod(ks)
|
|
185
|
+
...stX.shape.slice(v + 2, v + n + 2),
|
|
186
|
+
stX.shape[v + 1] * prod(ks)
|
|
194
187
|
]);
|
|
195
188
|
stY = stY.reshape([
|
|
196
|
-
|
|
189
|
+
...vmapShape,
|
|
190
|
+
1,
|
|
191
|
+
stY.shape[v],
|
|
197
192
|
...rep(n, 1),
|
|
198
|
-
stY.shape[1] * prod(ks)
|
|
193
|
+
stY.shape[v + 1] * prod(ks)
|
|
199
194
|
]);
|
|
200
195
|
return [stX, stY];
|
|
201
196
|
}
|
|
@@ -467,9 +462,11 @@ function dot$2(x, y) {
|
|
|
467
462
|
}
|
|
468
463
|
function conv$1(x, y, params = {}) {
|
|
469
464
|
if (x.ndim !== y.ndim) throw new Error(`conv() requires inputs with the same number of dimensions, got ${x.ndim} and ${y.ndim}`);
|
|
470
|
-
const
|
|
465
|
+
const vmapDims = params.vmapDims ?? 0;
|
|
466
|
+
const n = x.ndim - 2 - vmapDims;
|
|
471
467
|
if (n < 0) throw new Error("conv() requires at least 2D inputs");
|
|
472
468
|
return bind1(Primitive.Conv, [x, y], {
|
|
469
|
+
vmapDims,
|
|
473
470
|
strides: params.strides ?? rep(n, 1),
|
|
474
471
|
padding: params.padding ?? rep(n, [0, 0]),
|
|
475
472
|
lhsDilation: params.lhsDilation ?? rep(n, 1),
|
|
@@ -693,8 +690,10 @@ var Tracer = class Tracer {
|
|
|
693
690
|
axis = normalizeAxis(axis, this.ndim);
|
|
694
691
|
const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
|
|
695
692
|
if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
|
|
696
|
-
const
|
|
697
|
-
|
|
693
|
+
const originalDtype = this.dtype;
|
|
694
|
+
const castDtype = promoteTypes(originalDtype, DType.Float32);
|
|
695
|
+
const result = reduce(this.astype(castDtype), AluOp.Add, axis, opts);
|
|
696
|
+
return result.mul(1 / n).astype(originalDtype);
|
|
698
697
|
}
|
|
699
698
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
700
699
|
transpose(perm) {
|
|
@@ -1170,7 +1169,7 @@ var Jaxpr = class Jaxpr {
|
|
|
1170
1169
|
} else if (eqn.primitive === Primitive.Idiv) {
|
|
1171
1170
|
const [a, b] = inputs;
|
|
1172
1171
|
const c = eqn.outBinders[0];
|
|
1173
|
-
if (atomIsLit(b, 1)) context.set(c, a);
|
|
1172
|
+
if (atomIsLit(b, 1) && !isFloatDtype(a.aval.dtype)) context.set(c, a);
|
|
1174
1173
|
else newEqns.push(eqn);
|
|
1175
1174
|
} else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
|
|
1176
1175
|
else newEqns.push(eqn);
|
|
@@ -1755,48 +1754,73 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1755
1754
|
const inputExps = [];
|
|
1756
1755
|
const inputAvals = [];
|
|
1757
1756
|
const inputArgs = [];
|
|
1758
|
-
|
|
1759
|
-
|
|
1760
|
-
|
|
1761
|
-
|
|
1762
|
-
|
|
1763
|
-
|
|
1764
|
-
|
|
1765
|
-
|
|
1766
|
-
inputArgs.push(jitId);
|
|
1767
|
-
}
|
|
1768
|
-
gidMap.set(gid, newGid);
|
|
1769
|
-
}
|
|
1770
|
-
inputExps.push(jitValue.exp.reindexGids(gidMap));
|
|
1771
|
-
} else if (jitValue.type === "imm") {
|
|
1772
|
-
let gid = inputArgs.indexOf(jitValue.arg);
|
|
1773
|
-
if (gid === -1) {
|
|
1774
|
-
gid = inputArgs.length;
|
|
1775
|
-
inputArgs.push(jitValue.arg);
|
|
1757
|
+
let inputReduction = null;
|
|
1758
|
+
const addArgs = (args) => {
|
|
1759
|
+
const newGids = [];
|
|
1760
|
+
for (const jitId of args) {
|
|
1761
|
+
let newGid = inputArgs.indexOf(jitId);
|
|
1762
|
+
if (newGid === -1) {
|
|
1763
|
+
newGid = inputArgs.length;
|
|
1764
|
+
inputArgs.push(jitId);
|
|
1776
1765
|
}
|
|
1766
|
+
newGids.push(newGid);
|
|
1767
|
+
}
|
|
1768
|
+
return newGids;
|
|
1769
|
+
};
|
|
1770
|
+
for (const input of eqn.inputs) if (input instanceof Var) {
|
|
1771
|
+
const jv = ctx.get(input);
|
|
1772
|
+
if (jv.type === "exp") {
|
|
1773
|
+
const newGids = addArgs(jv.args);
|
|
1774
|
+
inputExps.push(jv.exp.reindexGids(newGids));
|
|
1775
|
+
} else if (jv.type === "imm") {
|
|
1776
|
+
const [gid] = addArgs([jv.arg]);
|
|
1777
1777
|
const st = ShapeTracker.fromShape(input.aval.shape);
|
|
1778
1778
|
const indices = unravelAlu(st.shape, AluVar.gidx);
|
|
1779
1779
|
inputExps.push(AluExp.globalView(input.aval.dtype, gid, st, indices));
|
|
1780
|
+
} else if (jv.type === "red") {
|
|
1781
|
+
if (inputReduction) throw new Error("jit: unexpected, multiple red inputs");
|
|
1782
|
+
const newGids = addArgs(jv.args);
|
|
1783
|
+
inputExps.push(jv.reduction.epilogue.reindexGids(newGids));
|
|
1784
|
+
inputReduction = jv;
|
|
1780
1785
|
}
|
|
1781
1786
|
inputAvals.push(input.aval);
|
|
1782
1787
|
} else if (input instanceof Lit) {
|
|
1783
1788
|
inputExps.push(AluExp.const(input.dtype, input.value));
|
|
1784
1789
|
inputAvals.push(input.aval);
|
|
1785
1790
|
} else throw new TypeError(`Unexpected input in Jaxpr: ${input}`);
|
|
1786
|
-
const nargs$1 = inputArgs.length;
|
|
1787
1791
|
const rule = jitRules[eqn.primitive];
|
|
1788
1792
|
if (!rule) throw new TypeError(`JIT not implemented for primitive ${eqn.primitive}`);
|
|
1789
|
-
|
|
1793
|
+
let exp$2;
|
|
1794
|
+
let reduction;
|
|
1795
|
+
if (inputReduction) {
|
|
1796
|
+
const jv = inputReduction;
|
|
1797
|
+
const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
|
|
1798
|
+
exp$2 = jv.exp.reindexGids(addArgs(jv.args));
|
|
1799
|
+
reduction = new Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
|
|
1800
|
+
} else {
|
|
1801
|
+
const ruleOutput = rule(inputExps, inputAvals, eqn.params);
|
|
1802
|
+
exp$2 = ruleOutput.exp;
|
|
1803
|
+
reduction = ruleOutput.reduction;
|
|
1804
|
+
}
|
|
1790
1805
|
const outVar = eqn.outBinders[0];
|
|
1791
|
-
if (
|
|
1806
|
+
if (blackNodes.has(outVar)) {
|
|
1807
|
+
const nargs$1 = inputArgs.length;
|
|
1808
|
+
const size$1 = prod(outVar.aval.shape);
|
|
1809
|
+
const kernel = new Kernel(nargs$1, size$1, exp$2, reduction);
|
|
1792
1810
|
const outId = builder.pushKernel(kernel, inputArgs);
|
|
1793
1811
|
ctx.set(outVar, {
|
|
1794
1812
|
type: "imm",
|
|
1795
1813
|
arg: outId
|
|
1796
1814
|
});
|
|
1797
|
-
} else ctx.set(outVar, {
|
|
1815
|
+
} else if (reduction) ctx.set(outVar, {
|
|
1816
|
+
type: "red",
|
|
1817
|
+
exp: exp$2,
|
|
1818
|
+
reduction,
|
|
1819
|
+
args: inputArgs
|
|
1820
|
+
});
|
|
1821
|
+
else ctx.set(outVar, {
|
|
1798
1822
|
type: "exp",
|
|
1799
|
-
exp:
|
|
1823
|
+
exp: exp$2,
|
|
1800
1824
|
args: inputArgs
|
|
1801
1825
|
});
|
|
1802
1826
|
}
|
|
@@ -1828,31 +1852,28 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
|
|
|
1828
1852
|
});
|
|
1829
1853
|
}
|
|
1830
1854
|
function broadcastedJit(fn, opts) {
|
|
1831
|
-
return (
|
|
1855
|
+
return (exps, avals, params) => {
|
|
1832
1856
|
let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
|
|
1833
1857
|
const skipCastIdx = opts?.skipCastIdx ?? [];
|
|
1834
1858
|
if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
|
|
1835
|
-
exps = exps.map((exp$
|
|
1836
|
-
exp$
|
|
1859
|
+
exps = exps.map((exp$2, i) => {
|
|
1860
|
+
exp$2 = reshapeViews(exp$2, (st) => {
|
|
1837
1861
|
if (!deepEqual(st.shape, newShape)) return st.broadcast(newShape, range(newShape.length - st.shape.length));
|
|
1838
1862
|
});
|
|
1839
|
-
if (exp$
|
|
1840
|
-
return exp$
|
|
1863
|
+
if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = AluExp.cast(newDtype, exp$2);
|
|
1864
|
+
return exp$2;
|
|
1841
1865
|
});
|
|
1842
|
-
|
|
1843
|
-
return new Kernel(nargs, prod(newShape), exp$2);
|
|
1866
|
+
return { exp: fn(exps, params) };
|
|
1844
1867
|
};
|
|
1845
1868
|
}
|
|
1846
1869
|
function unopJit(fn) {
|
|
1847
|
-
return (
|
|
1848
|
-
return
|
|
1870
|
+
return ([a], [_as], params) => {
|
|
1871
|
+
return { exp: fn(a, params) };
|
|
1849
1872
|
};
|
|
1850
1873
|
}
|
|
1851
1874
|
function reshapeJit(fn) {
|
|
1852
|
-
return (
|
|
1853
|
-
|
|
1854
|
-
const newShape = fn(ShapeTracker.fromShape(as.shape), params).shape;
|
|
1855
|
-
return new Kernel(nargs, prod(newShape), a);
|
|
1875
|
+
return ([a], [_as], params) => {
|
|
1876
|
+
return { exp: reshapeViews(a, (st) => fn(st, params)) };
|
|
1856
1877
|
};
|
|
1857
1878
|
}
|
|
1858
1879
|
const jitRules = {
|
|
@@ -1867,7 +1888,7 @@ const jitRules = {
|
|
|
1867
1888
|
[Primitive.StopGradient]: unopJit((a) => a),
|
|
1868
1889
|
[Primitive.Cast]: unopJit((a, { dtype }) => AluExp.cast(dtype, a)),
|
|
1869
1890
|
[Primitive.Bitcast]: unopJit((a, { dtype }) => AluExp.bitcast(dtype, a)),
|
|
1870
|
-
[Primitive.RandomBits]: (
|
|
1891
|
+
[Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
|
|
1871
1892
|
const mapping = (st) => {
|
|
1872
1893
|
if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(shape$1.length - st.shape.length));
|
|
1873
1894
|
};
|
|
@@ -1876,7 +1897,7 @@ const jitRules = {
|
|
|
1876
1897
|
const c0 = AluExp.u32(0);
|
|
1877
1898
|
const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
|
|
1878
1899
|
const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
1879
|
-
return
|
|
1900
|
+
return { exp: exp$2 };
|
|
1880
1901
|
},
|
|
1881
1902
|
[Primitive.Sin]: unopJit(AluExp.sin),
|
|
1882
1903
|
[Primitive.Cos]: unopJit(AluExp.cos),
|
|
@@ -1889,7 +1910,7 @@ const jitRules = {
|
|
|
1889
1910
|
[Primitive.Sqrt]: unopJit(AluExp.sqrt),
|
|
1890
1911
|
[Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
|
|
1891
1912
|
[Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
|
|
1892
|
-
[Primitive.Reduce](
|
|
1913
|
+
[Primitive.Reduce]([a], [as], { op, axis }) {
|
|
1893
1914
|
const keptAxes = [];
|
|
1894
1915
|
const shiftedAxes = [];
|
|
1895
1916
|
const newShape = [];
|
|
@@ -1898,39 +1919,43 @@ const jitRules = {
|
|
|
1898
1919
|
keptAxes.push(i);
|
|
1899
1920
|
newShape.push(as.shape[i]);
|
|
1900
1921
|
}
|
|
1901
|
-
const size$1 = prod(newShape);
|
|
1902
1922
|
const reductionSize = prod(shiftedAxes.map((ax) => as.shape[ax]));
|
|
1903
1923
|
newShape.push(reductionSize);
|
|
1904
1924
|
const perm = keptAxes.concat(shiftedAxes);
|
|
1905
1925
|
a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
|
|
1906
1926
|
const reduction = new Reduction(a.dtype, op, reductionSize);
|
|
1907
|
-
return
|
|
1927
|
+
return {
|
|
1928
|
+
exp: a,
|
|
1929
|
+
reduction
|
|
1930
|
+
};
|
|
1908
1931
|
},
|
|
1909
1932
|
[Primitive.Pool]: reshapeJit((st, { window, strides }) => pool(st, window, strides)),
|
|
1910
|
-
[Primitive.PoolTranspose](
|
|
1933
|
+
[Primitive.PoolTranspose]([a], [as], { inShape, window, strides }) {
|
|
1911
1934
|
let stX = poolTranspose(ShapeTracker.fromShape(as.shape), inShape, window, strides);
|
|
1912
|
-
const size$1 = prod(inShape);
|
|
1913
1935
|
stX = stX.reshape([...inShape, prod(stX.shape.slice(inShape.length))]);
|
|
1914
1936
|
a = reshapeViews(a, (st) => st.compose(stX), true);
|
|
1915
1937
|
const reduction = new Reduction(a.dtype, AluOp.Add, stX.shape[stX.shape.length - 1]);
|
|
1916
|
-
return
|
|
1938
|
+
return {
|
|
1939
|
+
exp: a,
|
|
1940
|
+
reduction
|
|
1941
|
+
};
|
|
1917
1942
|
},
|
|
1918
|
-
[Primitive.Dot](
|
|
1919
|
-
const k1 = jitRules[Primitive.Mul](
|
|
1943
|
+
[Primitive.Dot]([a, b], [as, bs]) {
|
|
1944
|
+
const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
|
|
1920
1945
|
const c = k1.exp;
|
|
1921
1946
|
const cs = promoteAvals(as, bs);
|
|
1922
|
-
return jitRules[Primitive.Reduce](
|
|
1947
|
+
return jitRules[Primitive.Reduce]([c], [cs], {
|
|
1923
1948
|
op: AluOp.Add,
|
|
1924
1949
|
axis: [cs.ndim - 1]
|
|
1925
1950
|
});
|
|
1926
1951
|
},
|
|
1927
|
-
[Primitive.Conv](
|
|
1952
|
+
[Primitive.Conv]([a, b], [as, bs], params) {
|
|
1928
1953
|
const [stX, stY] = prepareConv(ShapeTracker.fromShape(as.shape), ShapeTracker.fromShape(bs.shape), params);
|
|
1929
1954
|
a = reshapeViews(a, (st) => st.compose(stX));
|
|
1930
1955
|
b = reshapeViews(b, (st) => st.compose(stY));
|
|
1931
1956
|
as = new ShapedArray(stX.shape, as.dtype, as.weakType);
|
|
1932
1957
|
bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
|
|
1933
|
-
return jitRules[Primitive.Dot](
|
|
1958
|
+
return jitRules[Primitive.Dot]([a, b], [as, bs], {});
|
|
1934
1959
|
},
|
|
1935
1960
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
1936
1961
|
[Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
@@ -1944,7 +1969,7 @@ const jitRules = {
|
|
|
1944
1969
|
}),
|
|
1945
1970
|
[Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
|
|
1946
1971
|
[Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
|
|
1947
|
-
[Primitive.Gather](
|
|
1972
|
+
[Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
1948
1973
|
const axisSet = new Set(axis);
|
|
1949
1974
|
const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
|
|
1950
1975
|
const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
|
|
@@ -1957,7 +1982,7 @@ const jitRules = {
|
|
|
1957
1982
|
for (const [i, iexp] of indices.entries()) src[axis[i]] = AluExp.cast(DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...range(outDim + indexShape.length - st.shape.length), ...range(outDim + indexShape.length, finalShape.length)])));
|
|
1958
1983
|
const [index, valid] = ShapeTracker.fromShape(xs.shape).toAluExp(src);
|
|
1959
1984
|
if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
|
|
1960
|
-
return
|
|
1985
|
+
return { exp: x.substitute({ gidx: index }) };
|
|
1961
1986
|
},
|
|
1962
1987
|
[Primitive.JitCall]() {
|
|
1963
1988
|
throw new Error("internal: JitCall should have been flattened before JIT compilation");
|
|
@@ -1965,16 +1990,16 @@ const jitRules = {
|
|
|
1965
1990
|
};
|
|
1966
1991
|
/** Determines how to split the Jaxpr into kernels via dataflow analysis. */
|
|
1967
1992
|
function splitGraphDataflow(backend, jaxpr) {
|
|
1968
|
-
const
|
|
1993
|
+
const varToDefn = /* @__PURE__ */ new Map();
|
|
1994
|
+
const varToUsages = /* @__PURE__ */ new Map();
|
|
1969
1995
|
for (let i = 0; i < jaxpr.eqns.length; i++) {
|
|
1970
1996
|
const eqn = jaxpr.eqns[i];
|
|
1971
|
-
for (const v of eqn.outBinders) if (v instanceof Var)
|
|
1972
|
-
|
|
1973
|
-
|
|
1974
|
-
|
|
1975
|
-
|
|
1976
|
-
|
|
1977
|
-
p1NextBlack.set(v, v);
|
|
1997
|
+
for (const v of eqn.outBinders) if (v instanceof Var) varToDefn.set(v, i);
|
|
1998
|
+
for (const input of eqn.inputs) if (input instanceof Var) {
|
|
1999
|
+
const usages = varToUsages.get(input);
|
|
2000
|
+
if (usages) usages.push(i);
|
|
2001
|
+
else varToUsages.set(input, [i]);
|
|
2002
|
+
}
|
|
1978
2003
|
}
|
|
1979
2004
|
const reducePrimitives = [
|
|
1980
2005
|
Primitive.Reduce,
|
|
@@ -1982,10 +2007,68 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
1982
2007
|
Primitive.Conv,
|
|
1983
2008
|
Primitive.PoolTranspose
|
|
1984
2009
|
];
|
|
2010
|
+
const reductionEpilogueEqns = /* @__PURE__ */ new Set();
|
|
2011
|
+
const reductionEndpointEqns = /* @__PURE__ */ new Set();
|
|
2012
|
+
for (let i = 0; i < jaxpr.eqns.length; i++) {
|
|
2013
|
+
const eqn = jaxpr.eqns[i];
|
|
2014
|
+
if (reducePrimitives.includes(eqn.primitive)) {
|
|
2015
|
+
let head = i;
|
|
2016
|
+
while (true) {
|
|
2017
|
+
reductionEpilogueEqns.add(head);
|
|
2018
|
+
const outVar = jaxpr.eqns[head].outBinders[0];
|
|
2019
|
+
const usages = varToUsages.get(outVar) ?? [];
|
|
2020
|
+
if (jaxpr.outs.includes(outVar) || usages.length !== 1) break;
|
|
2021
|
+
if (reductionEpilogueEqns.has(usages[0])) break;
|
|
2022
|
+
const nextEqn = jaxpr.eqns[usages[0]];
|
|
2023
|
+
switch (nextEqn.primitive) {
|
|
2024
|
+
case Primitive.Neg:
|
|
2025
|
+
case Primitive.Reciprocal:
|
|
2026
|
+
case Primitive.Floor:
|
|
2027
|
+
case Primitive.Ceil:
|
|
2028
|
+
case Primitive.StopGradient:
|
|
2029
|
+
case Primitive.Cast:
|
|
2030
|
+
case Primitive.Bitcast:
|
|
2031
|
+
case Primitive.Sin:
|
|
2032
|
+
case Primitive.Cos:
|
|
2033
|
+
case Primitive.Asin:
|
|
2034
|
+
case Primitive.Atan:
|
|
2035
|
+
case Primitive.Exp:
|
|
2036
|
+
case Primitive.Log:
|
|
2037
|
+
case Primitive.Erf:
|
|
2038
|
+
case Primitive.Erfc:
|
|
2039
|
+
case Primitive.Sqrt:
|
|
2040
|
+
head = usages[0];
|
|
2041
|
+
continue;
|
|
2042
|
+
case Primitive.Add:
|
|
2043
|
+
case Primitive.Mul:
|
|
2044
|
+
case Primitive.Idiv:
|
|
2045
|
+
case Primitive.Mod:
|
|
2046
|
+
case Primitive.Max:
|
|
2047
|
+
case Primitive.Min: {
|
|
2048
|
+
const otherInput = nextEqn.inputs.find((v) => v !== outVar);
|
|
2049
|
+
if (otherInput instanceof Lit || deepEqual(generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
|
|
2050
|
+
head = usages[0];
|
|
2051
|
+
continue;
|
|
2052
|
+
}
|
|
2053
|
+
break;
|
|
2054
|
+
}
|
|
2055
|
+
}
|
|
2056
|
+
break;
|
|
2057
|
+
}
|
|
2058
|
+
reductionEndpointEqns.add(head);
|
|
2059
|
+
}
|
|
2060
|
+
}
|
|
2061
|
+
const blackNodes = /* @__PURE__ */ new Set();
|
|
2062
|
+
const p1NextBlack = /* @__PURE__ */ new Map();
|
|
2063
|
+
for (const v of jaxpr.outs) if (v instanceof Var) {
|
|
2064
|
+
blackNodes.add(v);
|
|
2065
|
+
p1NextBlack.set(v, v);
|
|
2066
|
+
}
|
|
1985
2067
|
const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
|
|
2068
|
+
const needsCleanShapePrimitives = [Primitive.Pad];
|
|
1986
2069
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
1987
2070
|
const eqn = jaxpr.eqns[i];
|
|
1988
|
-
if (
|
|
2071
|
+
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
1989
2072
|
for (const v of eqn.outBinders) {
|
|
1990
2073
|
blackNodes.add(v);
|
|
1991
2074
|
p1NextBlack.set(v, v);
|
|
@@ -1993,17 +2076,25 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
1993
2076
|
continue;
|
|
1994
2077
|
}
|
|
1995
2078
|
const reach = /* @__PURE__ */ new Set();
|
|
1996
|
-
|
|
1997
|
-
|
|
1998
|
-
if (
|
|
2079
|
+
let needsCleanOutput = false;
|
|
2080
|
+
outer: for (const v of eqn.outBinders) for (const j of varToUsages.get(v) ?? []) {
|
|
2081
|
+
if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive)) {
|
|
2082
|
+
needsCleanOutput = true;
|
|
2083
|
+
break outer;
|
|
2084
|
+
}
|
|
2085
|
+
for (const o of jaxpr.eqns[j].outBinders) {
|
|
2086
|
+
const u = p1NextBlack.get(o);
|
|
2087
|
+
if (u) reach.add(u);
|
|
2088
|
+
}
|
|
1999
2089
|
}
|
|
2000
|
-
if (reach.size
|
|
2001
|
-
const b = reach.values().next().value;
|
|
2002
|
-
for (const v of eqn.outBinders) p1NextBlack.set(v, b);
|
|
2003
|
-
} else if (reach.size > 1) for (const v of eqn.outBinders) {
|
|
2090
|
+
if (reach.size > 1 || needsCleanOutput) for (const v of eqn.outBinders) {
|
|
2004
2091
|
blackNodes.add(v);
|
|
2005
2092
|
p1NextBlack.set(v, v);
|
|
2006
2093
|
}
|
|
2094
|
+
else if (reach.size === 1) {
|
|
2095
|
+
const b = reach.values().next().value;
|
|
2096
|
+
for (const v of eqn.outBinders) p1NextBlack.set(v, b);
|
|
2097
|
+
}
|
|
2007
2098
|
}
|
|
2008
2099
|
const p2Deps = /* @__PURE__ */ new Map();
|
|
2009
2100
|
for (const v of jaxpr.inBinders) p2Deps.set(v, new Set([v]));
|
|
@@ -2022,7 +2113,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2022
2113
|
let assocInput = -1;
|
|
2023
2114
|
for (let i = 0; i < eqn.inputs.length; i++) {
|
|
2024
2115
|
const input = eqn.inputs[i];
|
|
2025
|
-
if (input instanceof Var &&
|
|
2116
|
+
if (input instanceof Var && varToDefn.has(input)) {
|
|
2026
2117
|
let uniqueDeps = 0;
|
|
2027
2118
|
for (const dep of deps[i]) if (depCounter.get(dep) === 1) uniqueDeps++;
|
|
2028
2119
|
if (uniqueDeps > maxUniqueDeps) {
|
|
@@ -2033,7 +2124,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2033
2124
|
}
|
|
2034
2125
|
if (assocInput === -1) throw new Error(`internal: maxArgs, no input found to mark as black in Jaxpr equation ${eqn}`);
|
|
2035
2126
|
const assocVar = eqn.inputs[assocInput];
|
|
2036
|
-
p2idx =
|
|
2127
|
+
p2idx = varToDefn.get(assocVar);
|
|
2037
2128
|
for (const out of jaxpr.eqns[p2idx].outBinders) blackNodes.add(out);
|
|
2038
2129
|
} else {
|
|
2039
2130
|
const s = new Set(depCounter.keys());
|
|
@@ -3460,6 +3551,15 @@ const vmapRules = {
|
|
|
3460
3551
|
const z = dot$2(x, y);
|
|
3461
3552
|
return [[z], [z.ndim - 1]];
|
|
3462
3553
|
},
|
|
3554
|
+
[Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
|
|
3555
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3556
|
+
y = moveBatchAxis(axisSize, yBdim, 0, y);
|
|
3557
|
+
const z = conv$1(x, y, {
|
|
3558
|
+
...params,
|
|
3559
|
+
vmapDims: params.vmapDims + 1
|
|
3560
|
+
});
|
|
3561
|
+
return [[z], [0]];
|
|
3562
|
+
},
|
|
3463
3563
|
[Primitive.Compare](axisSize, args, dims, { op }) {
|
|
3464
3564
|
return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
|
|
3465
3565
|
},
|
|
@@ -3904,7 +4004,7 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
|
|
|
3904
4004
|
for (const t of tracersIn) t.dispose();
|
|
3905
4005
|
for (const t of tracersOut) t.dispose();
|
|
3906
4006
|
jaxpr = jaxpr.simplify();
|
|
3907
|
-
if (DEBUG >= 5) console.
|
|
4007
|
+
if (DEBUG >= 5) console.info("jaxpr from partial evaluation:\n" + jaxpr.toString());
|
|
3908
4008
|
return {
|
|
3909
4009
|
jaxpr,
|
|
3910
4010
|
consts
|
|
@@ -4038,22 +4138,25 @@ const transposeRules = {
|
|
|
4038
4138
|
},
|
|
4039
4139
|
[Primitive.Conv]([ct], [lhs, rhs], params) {
|
|
4040
4140
|
if (lhs instanceof UndefPrimal === rhs instanceof UndefPrimal) throw new NonlinearError(Primitive.Conv);
|
|
4141
|
+
const v = params.vmapDims;
|
|
4041
4142
|
const rev01 = [
|
|
4042
|
-
|
|
4043
|
-
|
|
4044
|
-
|
|
4143
|
+
...range(v),
|
|
4144
|
+
v + 1,
|
|
4145
|
+
v,
|
|
4146
|
+
...range(v + 2, ct.ndim)
|
|
4045
4147
|
];
|
|
4046
4148
|
if (lhs instanceof UndefPrimal) {
|
|
4047
4149
|
let kernel = rhs;
|
|
4048
4150
|
kernel = transpose$1(kernel, rev01);
|
|
4049
|
-
kernel = flip$1(kernel, range(2, kernel.ndim));
|
|
4151
|
+
kernel = flip$1(kernel, range(v + 2, kernel.ndim));
|
|
4050
4152
|
const result = conv$1(ct, kernel, {
|
|
4153
|
+
vmapDims: v,
|
|
4051
4154
|
strides: params.lhsDilation,
|
|
4052
4155
|
padding: params.padding.map(([pl, _pr], i) => {
|
|
4053
|
-
const dilatedKernel = (kernel.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
|
|
4054
|
-
const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
|
|
4156
|
+
const dilatedKernel = (kernel.shape[i + v + 2] - 1) * params.rhsDilation[i] + 1;
|
|
4157
|
+
const dilatedCt = (ct.shape[i + v + 2] - 1) * params.strides[i] + 1;
|
|
4055
4158
|
const padBefore = dilatedKernel - 1 - pl;
|
|
4056
|
-
const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
|
|
4159
|
+
const dilatedLhs = (lhs.aval.shape[i + v + 2] - 1) * params.lhsDilation[i] + 1;
|
|
4057
4160
|
const padAfter = dilatedLhs + dilatedKernel - 1 - dilatedCt - padBefore;
|
|
4058
4161
|
return [padBefore, padAfter];
|
|
4059
4162
|
}),
|
|
@@ -4065,11 +4168,12 @@ const transposeRules = {
|
|
|
4065
4168
|
const newLhs = transpose$1(lhs, rev01);
|
|
4066
4169
|
const newRhs = transpose$1(ct, rev01);
|
|
4067
4170
|
let result = conv$1(newLhs, newRhs, {
|
|
4171
|
+
vmapDims: v,
|
|
4068
4172
|
strides: params.rhsDilation,
|
|
4069
4173
|
padding: params.padding.map(([pl, _pr], i) => {
|
|
4070
|
-
const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
|
|
4071
|
-
const dilatedKernel = (rhs.aval.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
|
|
4072
|
-
const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
|
|
4174
|
+
const dilatedLhs = (lhs.aval.shape[i + v + 2] - 1) * params.lhsDilation[i] + 1;
|
|
4175
|
+
const dilatedKernel = (rhs.aval.shape[i + v + 2] - 1) * params.rhsDilation[i] + 1;
|
|
4176
|
+
const dilatedCt = (ct.shape[i + v + 2] - 1) * params.strides[i] + 1;
|
|
4073
4177
|
const padFromLhs = dilatedCt - dilatedLhs;
|
|
4074
4178
|
const padFromRhs = dilatedKernel - pl - 1;
|
|
4075
4179
|
return [pl, padFromLhs + padFromRhs];
|
|
@@ -4318,13 +4422,46 @@ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
|
4318
4422
|
*
|
|
4319
4423
|
* Grouped convolutions are supported via the `featureGroupCount` option (default 1), which must divide both the input and output channel counts.
|
|
4320
4424
|
*/
|
|
4321
|
-
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation } = {}) {
|
|
4425
|
+
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
|
|
4322
4426
|
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
4323
4427
|
if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
|
|
4324
4428
|
if (typeof padding === "string") {
|
|
4325
4429
|
if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
|
|
4326
4430
|
padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? rep(rhs.ndim - 2, 1), padding);
|
|
4327
4431
|
}
|
|
4432
|
+
if (featureGroupCount !== 1) {
|
|
4433
|
+
const G = featureGroupCount;
|
|
4434
|
+
const [N, C_in, ...xs] = lhs.shape;
|
|
4435
|
+
const [C_out, C_in_per_group, ...ks] = rhs.shape;
|
|
4436
|
+
if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
|
|
4437
|
+
if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
|
|
4438
|
+
if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
|
|
4439
|
+
const lhsGrouped = moveaxis(lhs.reshape([
|
|
4440
|
+
N,
|
|
4441
|
+
G,
|
|
4442
|
+
C_in / G,
|
|
4443
|
+
...xs
|
|
4444
|
+
]), 1, 0);
|
|
4445
|
+
const rhsGrouped = rhs.reshape([
|
|
4446
|
+
G,
|
|
4447
|
+
C_out / G,
|
|
4448
|
+
C_in_per_group,
|
|
4449
|
+
...ks
|
|
4450
|
+
]);
|
|
4451
|
+
const result = conv$1(lhsGrouped, rhsGrouped, {
|
|
4452
|
+
vmapDims: 1,
|
|
4453
|
+
strides: windowStrides,
|
|
4454
|
+
padding,
|
|
4455
|
+
lhsDilation,
|
|
4456
|
+
rhsDilation
|
|
4457
|
+
});
|
|
4458
|
+
const ys = result.shape.slice(3);
|
|
4459
|
+
return moveaxis(result, 0, 1).reshape([
|
|
4460
|
+
N,
|
|
4461
|
+
C_out,
|
|
4462
|
+
...ys
|
|
4463
|
+
]);
|
|
4464
|
+
}
|
|
4328
4465
|
return conv$1(lhs, rhs, {
|
|
4329
4466
|
strides: windowStrides,
|
|
4330
4467
|
padding,
|
|
@@ -4610,6 +4747,8 @@ __export(numpy_exports, {
|
|
|
4610
4747
|
concatenate: () => concatenate,
|
|
4611
4748
|
cos: () => cos,
|
|
4612
4749
|
cosh: () => cosh,
|
|
4750
|
+
cumsum: () => cumsum,
|
|
4751
|
+
cumulativeSum: () => cumulativeSum,
|
|
4613
4752
|
deg2rad: () => deg2rad,
|
|
4614
4753
|
degrees: () => degrees,
|
|
4615
4754
|
diag: () => diag,
|
|
@@ -4918,6 +5057,25 @@ function argmax(a, axis, opts) {
|
|
|
4918
5057
|
}).reshape([shape$1[axis], ...rep(shape$1.length - axis - 1, 1)]));
|
|
4919
5058
|
return length.sub(max(idx, axis, opts));
|
|
4920
5059
|
}
|
|
5060
|
+
/**
|
|
5061
|
+
* Cumulative sum of elements along an axis.
|
|
5062
|
+
*
|
|
5063
|
+
* Currently this function is `O(n^2)`, we'll improve this later on with a
|
|
5064
|
+
* two-phase parallel reduction algorithm.
|
|
5065
|
+
*/
|
|
5066
|
+
function cumsum(a, axis) {
|
|
5067
|
+
a = fudgeArray(a);
|
|
5068
|
+
if (axis === void 0) {
|
|
5069
|
+
a = a.ravel();
|
|
5070
|
+
axis = 0;
|
|
5071
|
+
} else axis = checkAxis(axis, a.ndim);
|
|
5072
|
+
const n = a.shape[axis];
|
|
5073
|
+
a = moveaxis$1(a, axis, -1);
|
|
5074
|
+
a = broadcast(a, a.shape.concat(n), [-2]);
|
|
5075
|
+
return moveaxis$1(tril(a).sum(-1), -1, axis);
|
|
5076
|
+
}
|
|
5077
|
+
/** @function Alternative name for `jax.numpy.cumsum()`. */
|
|
5078
|
+
const cumulativeSum = cumsum;
|
|
4921
5079
|
/** Reverse the elements in an array along the given axes. */
|
|
4922
5080
|
function flip(x, axis = null) {
|
|
4923
5081
|
const nd = ndim(x);
|
|
@@ -5153,7 +5311,10 @@ function allclose(actual, expected, options) {
|
|
|
5153
5311
|
if (!deepEqual(x.shape, y.shape)) return false;
|
|
5154
5312
|
const xData = x.dataSync();
|
|
5155
5313
|
const yData = y.dataSync();
|
|
5156
|
-
for (let i = 0; i < xData.length; i++)
|
|
5314
|
+
for (let i = 0; i < xData.length; i++) {
|
|
5315
|
+
if (isNaN(xData[i]) !== isNaN(yData[i])) return false;
|
|
5316
|
+
if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
|
|
5317
|
+
}
|
|
5157
5318
|
return true;
|
|
5158
5319
|
}
|
|
5159
5320
|
/** Matrix product of two arrays. */
|
|
@@ -5612,7 +5773,10 @@ const degrees = rad2deg;
|
|
|
5612
5773
|
* Computes first array raised to power of second array, element-wise.
|
|
5613
5774
|
*/
|
|
5614
5775
|
const power = jit$1(function power$1(x1, x2) {
|
|
5615
|
-
|
|
5776
|
+
const x2i = trunc(x2.ref);
|
|
5777
|
+
const shouldBeNaN = multiply(x2.ref.notEqual(x2i.ref), x1.ref.less(0));
|
|
5778
|
+
const resultSign = where(mod(x2i, 2).notEqual(0), where(x1.ref.less(0), -1, 1), 1);
|
|
5779
|
+
return where(shouldBeNaN, nan, exp(log(abs(x1)).mul(x2)).mul(resultSign));
|
|
5616
5780
|
});
|
|
5617
5781
|
/** @function Alias of `jax.numpy.power()`. */
|
|
5618
5782
|
const pow = power;
|
|
@@ -5968,22 +6132,22 @@ function logSoftmax(x, axis = -1) {
|
|
|
5968
6132
|
*
|
|
5969
6133
|
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
5970
6134
|
*/
|
|
5971
|
-
function logsumexp(x, axis = null) {
|
|
6135
|
+
function logsumexp(x, axis = null, opts) {
|
|
5972
6136
|
x = fudgeArray(x);
|
|
5973
6137
|
axis = normalizeAxis(axis, x.ndim);
|
|
5974
6138
|
if (axis.length === 0) return x;
|
|
5975
|
-
const xMax = stopGradient(max(x.ref, axis));
|
|
5976
|
-
const
|
|
5977
|
-
const
|
|
5978
|
-
return
|
|
6139
|
+
const xMax = stopGradient(max(x.ref, axis, { keepdims: true }));
|
|
6140
|
+
const shifted = x.sub(xMax.ref);
|
|
6141
|
+
const result = xMax.add(log(exp(shifted).sum(axis, { keepdims: true })));
|
|
6142
|
+
return opts?.keepdims ? result : squeeze(result, axis);
|
|
5979
6143
|
}
|
|
5980
6144
|
/** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
|
|
5981
|
-
function logmeanexp(x, axis = null) {
|
|
6145
|
+
function logmeanexp(x, axis = null, opts) {
|
|
5982
6146
|
x = fudgeArray(x);
|
|
5983
6147
|
axis = normalizeAxis(axis, x.ndim);
|
|
5984
6148
|
if (axis.length === 0) return x;
|
|
5985
6149
|
const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
|
|
5986
|
-
return logsumexp(x, axis).sub(Math.log(n));
|
|
6150
|
+
return logsumexp(x, axis, opts).sub(Math.log(n));
|
|
5987
6151
|
}
|
|
5988
6152
|
/**
|
|
5989
6153
|
* Standardizes input to zero mean and unit variance.
|