@jax-js/jax 0.1.4 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +10 -7
- package/dist/{backend-tngXtWe4.js → backend-DaqL-MNz.js} +96 -7
- package/dist/{backend-Bu9GY6sK.cjs → backend-DziQSaoQ.cjs} +101 -6
- package/dist/index.cjs +737 -141
- package/dist/index.d.cts +238 -9
- package/dist/index.d.ts +238 -9
- package/dist/index.js +737 -141
- package/dist/webgl-ClIYb8jP.cjs +522 -0
- package/dist/webgl-RSuZKvgc.js +522 -0
- package/dist/{webgpu-Oj3Kd-kd.cjs → webgpu-Db2JrNBr.cjs} +296 -3
- package/dist/{webgpu-ChVgx3b6.js → webgpu-Dh7k9io0.js} +296 -3
- package/package.json +1 -1
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-DziQSaoQ.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -387,6 +387,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
387
387
|
Primitive$1["PoolTranspose"] = "pool_transpose";
|
|
388
388
|
Primitive$1["Compare"] = "compare";
|
|
389
389
|
Primitive$1["Where"] = "where";
|
|
390
|
+
Primitive$1["Concatenate"] = "concatenate";
|
|
391
|
+
Primitive$1["Split"] = "split";
|
|
390
392
|
Primitive$1["RandomBits"] = "random_bits";
|
|
391
393
|
Primitive$1["Gather"] = "gather";
|
|
392
394
|
Primitive$1["Transpose"] = "transpose";
|
|
@@ -399,6 +401,7 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
399
401
|
Primitive$1["Argsort"] = "argsort";
|
|
400
402
|
Primitive$1["TriangularSolve"] = "triangular_solve";
|
|
401
403
|
Primitive$1["Cholesky"] = "cholesky";
|
|
404
|
+
Primitive$1["LU"] = "lu";
|
|
402
405
|
Primitive$1["Jit"] = "jit";
|
|
403
406
|
return Primitive$1;
|
|
404
407
|
}({});
|
|
@@ -530,7 +533,25 @@ function where$1(cond, x, y) {
|
|
|
530
533
|
y
|
|
531
534
|
]);
|
|
532
535
|
}
|
|
536
|
+
function concatenate$1(xs, axis) {
|
|
537
|
+
if (xs.length === 0) throw new Error("concatenate requires at least one input");
|
|
538
|
+
const avals = xs.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
539
|
+
axis = require_backend.checkAxis(axis, avals[0].ndim);
|
|
540
|
+
for (const x of avals) if (x.ndim !== avals[0].ndim || !x.shape.every((s, i) => i === axis || s === avals[0].shape[i])) throw new Error(`Concatenate: inputs ${avals[0]} and ${x} must match shapes except on axis ${axis}`);
|
|
541
|
+
return bind1(Primitive.Concatenate, xs, { axis });
|
|
542
|
+
}
|
|
543
|
+
function split$2(x, axis, sizes) {
|
|
544
|
+
axis = require_backend.checkAxis(axis, ndim$1(x));
|
|
545
|
+
if (sizes.some((s) => s < 0 || !Number.isInteger(s))) throw new Error(`split: sizes must be nonnegative integers, got ${JSON.stringify(sizes)}`);
|
|
546
|
+
const totalSize = sizes.reduce((a, b) => a + b, 0);
|
|
547
|
+
if (totalSize !== getShape(x)[axis]) throw new Error(`split: sizes must sum to the size of the axis ${axis}, got ${totalSize}`);
|
|
548
|
+
return bind(Primitive.Split, [x], {
|
|
549
|
+
axis,
|
|
550
|
+
sizes
|
|
551
|
+
});
|
|
552
|
+
}
|
|
533
553
|
function randomBits(k0, k1, shape$1, mode = "xor") {
|
|
554
|
+
if (!require_backend.deepEqual(k0.shape, k1.shape) || k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new Error(`randomBits: key parts must be uint32 with the same shape, got ${ShapedArray.fromAval(k0.aval)} and ${ShapedArray.fromAval(k1.aval)}`);
|
|
534
555
|
return bind1(Primitive.RandomBits, [k0, k1], {
|
|
535
556
|
shape: shape$1,
|
|
536
557
|
mode
|
|
@@ -597,6 +618,11 @@ function pad$1(x, width) {
|
|
|
597
618
|
return bind1(Primitive.Pad, [x], { width });
|
|
598
619
|
}
|
|
599
620
|
function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
621
|
+
const as = getShape(a);
|
|
622
|
+
const bs = getShape(b);
|
|
623
|
+
if (as.length < 2 || bs.length < 2) throw new Error(`triangular_solve: must be >=2D, got a=${as}, b=${bs}`);
|
|
624
|
+
const n = as[as.length - 2];
|
|
625
|
+
if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
|
|
600
626
|
if (lower) {
|
|
601
627
|
a = flip$1(a, [-2, -1]);
|
|
602
628
|
b = flip$1(b, [-1]);
|
|
@@ -606,8 +632,15 @@ function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
|
606
632
|
return x;
|
|
607
633
|
}
|
|
608
634
|
function cholesky$2(x) {
|
|
635
|
+
const aval = ShapedArray.fromAval(getAval(x));
|
|
636
|
+
if (aval.ndim < 2 || aval.shape[aval.ndim - 1] !== aval.shape[aval.ndim - 2]) throw new Error(`cholesky: expected batch of square matrices, got ${aval}`);
|
|
609
637
|
return bind1(Primitive.Cholesky, [x]);
|
|
610
638
|
}
|
|
639
|
+
function lu$1(x) {
|
|
640
|
+
const aval = ShapedArray.fromAval(getAval(x));
|
|
641
|
+
if (aval.ndim < 2) throw new Error(`lu: expected batch of matrices, got ${aval}`);
|
|
642
|
+
return bind(Primitive.LU, [x]);
|
|
643
|
+
}
|
|
611
644
|
function sort$1(x) {
|
|
612
645
|
const nd = ndim$1(x);
|
|
613
646
|
if (nd === 0) throw new Error("sort: requires at least 1D input");
|
|
@@ -716,6 +749,9 @@ var Tracer = class Tracer {
|
|
|
716
749
|
mul(other) {
|
|
717
750
|
return mul(this, other);
|
|
718
751
|
}
|
|
752
|
+
mod(other) {
|
|
753
|
+
return mod(this, other);
|
|
754
|
+
}
|
|
719
755
|
greater(other) {
|
|
720
756
|
return greater$1(this, other);
|
|
721
757
|
}
|
|
@@ -828,8 +864,14 @@ var Tracer = class Tracer {
|
|
|
828
864
|
*/
|
|
829
865
|
*[Symbol.iterator]() {
|
|
830
866
|
if (this.ndim === 0) throw new Error("Cannot iterate over a scalar array");
|
|
831
|
-
|
|
832
|
-
this.
|
|
867
|
+
let residual = this;
|
|
868
|
+
const subarrayShape = this.shape.slice(1);
|
|
869
|
+
for (let i = 0; i < this.shape[0]; i++) {
|
|
870
|
+
const lr = split$2(residual, 0, [1, residual.shape[0] - 1]);
|
|
871
|
+
yield lr[0].reshape(subarrayShape);
|
|
872
|
+
residual = lr[1];
|
|
873
|
+
}
|
|
874
|
+
residual.dispose();
|
|
833
875
|
}
|
|
834
876
|
/**
|
|
835
877
|
* Return a sorted copy of an array in ascending order.
|
|
@@ -979,6 +1021,9 @@ var ShapedArray = class ShapedArray {
|
|
|
979
1021
|
get size() {
|
|
980
1022
|
return require_backend.prod(this.shape);
|
|
981
1023
|
}
|
|
1024
|
+
scalar() {
|
|
1025
|
+
return new ShapedArray([], this.dtype, this.weakType);
|
|
1026
|
+
}
|
|
982
1027
|
toString() {
|
|
983
1028
|
return `${this.dtype}[${this.shape.join(",")}]`;
|
|
984
1029
|
}
|
|
@@ -1588,7 +1633,7 @@ const abstractEvalRules = {
|
|
|
1588
1633
|
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1589
1634
|
},
|
|
1590
1635
|
[Primitive.Conv]([lhs, rhs], params) {
|
|
1591
|
-
const { dtype, weakType } = promoteAvals(
|
|
1636
|
+
const { dtype, weakType } = promoteAvals(lhs.scalar(), rhs.scalar());
|
|
1592
1637
|
const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
|
|
1593
1638
|
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1594
1639
|
},
|
|
@@ -1599,10 +1644,25 @@ const abstractEvalRules = {
|
|
|
1599
1644
|
const shape$1 = require_backend.generalBroadcast(cond.shape, xy.shape);
|
|
1600
1645
|
return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
|
|
1601
1646
|
},
|
|
1647
|
+
[Primitive.Concatenate](xs, { axis }) {
|
|
1648
|
+
if (xs.length === 0) throw new TypeError("Concatenate requires at least one input");
|
|
1649
|
+
for (const x of xs) if (x.ndim !== xs[0].ndim || !x.shape.every((s, i) => i === axis || s === xs[0].shape[i])) throw new TypeError(`Concatenate: inputs ${xs[0]} and ${x} must match shapes except on axis ${axis}`);
|
|
1650
|
+
const shape$1 = xs[0].shape.slice();
|
|
1651
|
+
shape$1[axis] = xs.reduce((sum$1, x) => sum$1 + x.shape[axis], 0);
|
|
1652
|
+
const { dtype, weakType } = xs.map((x) => x.scalar()).reduce(promoteAvals);
|
|
1653
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1654
|
+
},
|
|
1655
|
+
[Primitive.Split]([x], { axis, sizes }) {
|
|
1656
|
+
const totalSize = sizes.reduce((a, b) => a + b, 0);
|
|
1657
|
+
if (x.shape[axis] !== totalSize) throw new TypeError(`Split: sizes ${sizes} do not sum to dimension ${x.shape[axis]} on axis ${axis}`);
|
|
1658
|
+
return sizes.map((size$1) => {
|
|
1659
|
+
return new ShapedArray(x.shape.toSpliced(axis, 1, size$1), x.dtype, x.weakType);
|
|
1660
|
+
});
|
|
1661
|
+
},
|
|
1602
1662
|
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
1603
1663
|
if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
1604
|
-
|
|
1605
|
-
if (!require_backend.deepEqual(
|
|
1664
|
+
if (!require_backend.deepEqual(k0.shape, k1.shape)) throw new TypeError(`RandomBits: Keys have different shapes ${k0.shape} and ${k1.shape}`);
|
|
1665
|
+
if (!require_backend.deepEqual(shape$1.slice(0, k0.ndim), k0.shape)) throw new TypeError(`RandomBits: generated shape ${shape$1} must match key shape ${k0.shape}`);
|
|
1606
1666
|
return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
|
|
1607
1667
|
},
|
|
1608
1668
|
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
@@ -1659,6 +1719,16 @@ const abstractEvalRules = {
|
|
|
1659
1719
|
if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
|
|
1660
1720
|
return [ShapedArray.fromAval(a)];
|
|
1661
1721
|
},
|
|
1722
|
+
[Primitive.LU]([a]) {
|
|
1723
|
+
if (a.ndim < 2) throw new TypeError(`lu: requires at least 2D input, got ${a}`);
|
|
1724
|
+
const batch = a.shape.slice(0, -2);
|
|
1725
|
+
const [m, n] = a.shape.slice(-2);
|
|
1726
|
+
return [
|
|
1727
|
+
ShapedArray.fromAval(a),
|
|
1728
|
+
new ShapedArray([...batch, Math.min(m, n)], require_backend.DType.Int32, false),
|
|
1729
|
+
new ShapedArray([...batch, m], require_backend.DType.Int32, false)
|
|
1730
|
+
];
|
|
1731
|
+
},
|
|
1662
1732
|
[Primitive.Jit](args, { jaxpr }) {
|
|
1663
1733
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
1664
1734
|
if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
|
|
@@ -1740,7 +1810,8 @@ const routinePrimitives = new Map([
|
|
|
1740
1810
|
[Primitive.Sort, require_backend.Routines.Sort],
|
|
1741
1811
|
[Primitive.Argsort, require_backend.Routines.Argsort],
|
|
1742
1812
|
[Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
|
|
1743
|
-
[Primitive.Cholesky, require_backend.Routines.Cholesky]
|
|
1813
|
+
[Primitive.Cholesky, require_backend.Routines.Cholesky],
|
|
1814
|
+
[Primitive.LU, require_backend.Routines.LU]
|
|
1744
1815
|
]);
|
|
1745
1816
|
/** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
|
|
1746
1817
|
var JitProgram = class {
|
|
@@ -1911,10 +1982,10 @@ function jitCompile(backend, jaxpr) {
|
|
|
1911
1982
|
inputs.push(jv.arg);
|
|
1912
1983
|
} else if (input instanceof Lit) inputs.push(builder.pushLit(input));
|
|
1913
1984
|
const outputs = [];
|
|
1914
|
-
for (const outVar
|
|
1915
|
-
const outId = builder.pushBuffer(outVar
|
|
1985
|
+
for (const outVar of eqn.outBinders) {
|
|
1986
|
+
const outId = builder.pushBuffer(outVar.aval.size * require_backend.byteWidth(outVar.aval.dtype));
|
|
1916
1987
|
outputs.push(outId);
|
|
1917
|
-
ctx.set(outVar
|
|
1988
|
+
ctx.set(outVar, {
|
|
1918
1989
|
type: "imm",
|
|
1919
1990
|
arg: outId
|
|
1920
1991
|
});
|
|
@@ -1965,35 +2036,37 @@ function jitCompile(backend, jaxpr) {
|
|
|
1965
2036
|
let reduction;
|
|
1966
2037
|
if (inputReduction) {
|
|
1967
2038
|
const jv = inputReduction;
|
|
1968
|
-
const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
|
|
1969
|
-
exp$2 = jv.exp.reindexGids(addArgs(jv.args));
|
|
2039
|
+
const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp[0];
|
|
2040
|
+
exp$2 = [jv.exp.reindexGids(addArgs(jv.args))];
|
|
1970
2041
|
reduction = new require_backend.Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
|
|
1971
2042
|
} else {
|
|
1972
2043
|
const ruleOutput = rule(inputExps, inputAvals, eqn.params);
|
|
1973
2044
|
exp$2 = ruleOutput.exp;
|
|
1974
2045
|
reduction = ruleOutput.reduction;
|
|
1975
2046
|
}
|
|
1976
|
-
|
|
1977
|
-
|
|
1978
|
-
|
|
1979
|
-
|
|
1980
|
-
|
|
1981
|
-
|
|
1982
|
-
|
|
1983
|
-
|
|
1984
|
-
|
|
2047
|
+
for (let i$1 = 0; i$1 < eqn.outBinders.length; i$1++) {
|
|
2048
|
+
const outVar = eqn.outBinders[i$1];
|
|
2049
|
+
if (blackNodes.has(outVar)) {
|
|
2050
|
+
const nargs$1 = inputArgs.length;
|
|
2051
|
+
const size$1 = outVar.aval.size;
|
|
2052
|
+
const kernel = new require_backend.Kernel(nargs$1, size$1, exp$2[i$1], reduction);
|
|
2053
|
+
const outId = builder.pushKernel(kernel, inputArgs);
|
|
2054
|
+
ctx.set(outVar, {
|
|
2055
|
+
type: "imm",
|
|
2056
|
+
arg: outId
|
|
2057
|
+
});
|
|
2058
|
+
} else if (reduction) ctx.set(outVar, {
|
|
2059
|
+
type: "red",
|
|
2060
|
+
exp: exp$2[i$1],
|
|
2061
|
+
reduction,
|
|
2062
|
+
args: inputArgs
|
|
1985
2063
|
});
|
|
1986
|
-
|
|
1987
|
-
|
|
1988
|
-
|
|
1989
|
-
|
|
1990
|
-
|
|
1991
|
-
}
|
|
1992
|
-
else ctx.set(outVar, {
|
|
1993
|
-
type: "exp",
|
|
1994
|
-
exp: exp$2,
|
|
1995
|
-
args: inputArgs
|
|
1996
|
-
});
|
|
2064
|
+
else ctx.set(outVar, {
|
|
2065
|
+
type: "exp",
|
|
2066
|
+
exp: exp$2[i$1],
|
|
2067
|
+
args: inputArgs
|
|
2068
|
+
});
|
|
2069
|
+
}
|
|
1997
2070
|
}
|
|
1998
2071
|
const outputIds = [];
|
|
1999
2072
|
for (const out of jaxpr.outs) if (out instanceof Var) {
|
|
@@ -2034,17 +2107,17 @@ function broadcastedJit(fn, opts) {
|
|
|
2034
2107
|
if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = require_backend.AluExp.cast(newDtype, exp$2);
|
|
2035
2108
|
return exp$2;
|
|
2036
2109
|
});
|
|
2037
|
-
return { exp: fn(exps, params) };
|
|
2110
|
+
return { exp: [fn(exps, params)] };
|
|
2038
2111
|
};
|
|
2039
2112
|
}
|
|
2040
2113
|
function unopJit(fn) {
|
|
2041
2114
|
return ([a], [_as], params) => {
|
|
2042
|
-
return { exp: fn(a, params) };
|
|
2115
|
+
return { exp: [fn(a, params)] };
|
|
2043
2116
|
};
|
|
2044
2117
|
}
|
|
2045
2118
|
function reshapeJit(fn) {
|
|
2046
2119
|
return ([a], [_as], params) => {
|
|
2047
|
-
return { exp: reshapeViews(a, (st) => fn(st, params)) };
|
|
2120
|
+
return { exp: [reshapeViews(a, (st) => fn(st, params))] };
|
|
2048
2121
|
};
|
|
2049
2122
|
}
|
|
2050
2123
|
function routineNoJit() {
|
|
@@ -2090,7 +2163,7 @@ const jitRules = {
|
|
|
2090
2163
|
a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
|
|
2091
2164
|
const reduction = new require_backend.Reduction(a.dtype, op, reductionSize);
|
|
2092
2165
|
return {
|
|
2093
|
-
exp: a,
|
|
2166
|
+
exp: [a],
|
|
2094
2167
|
reduction
|
|
2095
2168
|
};
|
|
2096
2169
|
},
|
|
@@ -2101,13 +2174,13 @@ const jitRules = {
|
|
|
2101
2174
|
a = reshapeViews(a, (st) => st.compose(stX), true);
|
|
2102
2175
|
const reduction = new require_backend.Reduction(a.dtype, require_backend.AluOp.Add, stX.shape[stX.shape.length - 1]);
|
|
2103
2176
|
return {
|
|
2104
|
-
exp: a,
|
|
2177
|
+
exp: [a],
|
|
2105
2178
|
reduction
|
|
2106
2179
|
};
|
|
2107
2180
|
},
|
|
2108
2181
|
[Primitive.Dot]([a, b], [as, bs]) {
|
|
2109
2182
|
const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
|
|
2110
|
-
const c = k1.exp;
|
|
2183
|
+
const [c] = k1.exp;
|
|
2111
2184
|
const cs = promoteAvals(as, bs);
|
|
2112
2185
|
return jitRules[Primitive.Reduce]([c], [cs], {
|
|
2113
2186
|
op: require_backend.AluOp.Add,
|
|
@@ -2124,16 +2197,41 @@ const jitRules = {
|
|
|
2124
2197
|
},
|
|
2125
2198
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
2126
2199
|
[Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
2200
|
+
[Primitive.Concatenate](exps, avals, { axis }) {
|
|
2201
|
+
const ndim$2 = avals[0].ndim;
|
|
2202
|
+
const sizes = avals.map((x) => x.shape[axis]);
|
|
2203
|
+
const finalSize = sizes.reduce((a, b) => a + b, 0);
|
|
2204
|
+
const makePadAxis = (start, end) => require_backend.range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
|
|
2205
|
+
let cum = 0;
|
|
2206
|
+
const src = [];
|
|
2207
|
+
for (let i = 0; i < exps.length; i++) {
|
|
2208
|
+
const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
|
|
2209
|
+
src.push(reshapeViews(exps[i], (st) => st.pad(padding)));
|
|
2210
|
+
cum += sizes[i];
|
|
2211
|
+
}
|
|
2212
|
+
return { exp: [src.reduce(require_backend.AluExp.add)] };
|
|
2213
|
+
},
|
|
2214
|
+
[Primitive.Split]([a], [as], { axis, sizes }) {
|
|
2215
|
+
const exp$2 = [];
|
|
2216
|
+
let start = 0;
|
|
2217
|
+
for (const size$1 of sizes) {
|
|
2218
|
+
const slice = require_backend.range(as.ndim).map((d) => d === axis ? [start, start + size$1] : [0, as.shape[d]]);
|
|
2219
|
+
exp$2.push(reshapeViews(a, (st) => st.shrink(slice)));
|
|
2220
|
+
start += size$1;
|
|
2221
|
+
}
|
|
2222
|
+
return { exp: exp$2 };
|
|
2223
|
+
},
|
|
2127
2224
|
[Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
|
|
2225
|
+
const keyShape = keyShapes[0].shape;
|
|
2128
2226
|
const mapping = (st) => {
|
|
2129
|
-
if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(shape
|
|
2227
|
+
if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(st.shape.length, shape$1.length));
|
|
2130
2228
|
};
|
|
2131
2229
|
const k0 = reshapeViews(keys[0], mapping);
|
|
2132
2230
|
const k1 = reshapeViews(keys[1], mapping);
|
|
2133
2231
|
const c0 = require_backend.AluExp.u32(0);
|
|
2134
|
-
const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
|
|
2232
|
+
const c1 = require_backend.AluExp.mod(require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx), require_backend.AluExp.u32(Math.max(require_backend.prod(shape$1.slice(keyShape.length)), 1)));
|
|
2135
2233
|
const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
2136
|
-
return { exp: exp$2 };
|
|
2234
|
+
return { exp: [exp$2] };
|
|
2137
2235
|
},
|
|
2138
2236
|
[Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
2139
2237
|
const axisSet = new Set(axis);
|
|
@@ -2148,7 +2246,7 @@ const jitRules = {
|
|
|
2148
2246
|
for (const [i, iexp] of indices.entries()) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...require_backend.range(outDim + indexShape.length - st.shape.length), ...require_backend.range(outDim + indexShape.length, finalShape.length)])));
|
|
2149
2247
|
const [index, valid] = require_backend.ShapeTracker.fromShape(xs.shape).toAluExp(src);
|
|
2150
2248
|
if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
|
|
2151
|
-
return { exp: x.substitute({ gidx: index }) };
|
|
2249
|
+
return { exp: [x.substitute({ gidx: index })] };
|
|
2152
2250
|
},
|
|
2153
2251
|
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
2154
2252
|
[Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
|
|
@@ -2164,6 +2262,7 @@ const jitRules = {
|
|
|
2164
2262
|
[Primitive.Argsort]: routineNoJit(),
|
|
2165
2263
|
[Primitive.TriangularSolve]: routineNoJit(),
|
|
2166
2264
|
[Primitive.Cholesky]: routineNoJit(),
|
|
2265
|
+
[Primitive.LU]: routineNoJit(),
|
|
2167
2266
|
[Primitive.Jit]() {
|
|
2168
2267
|
throw new Error("internal: Jit should have been flattened before JIT compilation");
|
|
2169
2268
|
}
|
|
@@ -2442,6 +2541,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2442
2541
|
this.#rc++;
|
|
2443
2542
|
return this;
|
|
2444
2543
|
}
|
|
2544
|
+
/** Get the current reference count (for debugging memory management). */
|
|
2545
|
+
get refCount() {
|
|
2546
|
+
return this.#rc;
|
|
2547
|
+
}
|
|
2445
2548
|
dispose() {
|
|
2446
2549
|
this.#check();
|
|
2447
2550
|
if (--this.#rc === 0) {
|
|
@@ -2599,7 +2702,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2599
2702
|
} else if (castDtype === void 0) {
|
|
2600
2703
|
castDtype = arrays[i].#dtype;
|
|
2601
2704
|
castWeakType = arrays[i].#weakType;
|
|
2602
|
-
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType),
|
|
2705
|
+
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), arrays[i].aval.scalar()));
|
|
2603
2706
|
const weakType = castWeakType && !strongTypeOutput;
|
|
2604
2707
|
const { backend, committed } = Array$1.#computeBackend(name, arrays);
|
|
2605
2708
|
arrays = arrays.map((ar) => ar._putSync(backend));
|
|
@@ -2992,17 +3095,44 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2992
3095
|
y
|
|
2993
3096
|
], { dtypeOverride: [require_backend.DType.Bool] })];
|
|
2994
3097
|
},
|
|
3098
|
+
[Primitive.Concatenate](xs, { axis }) {
|
|
3099
|
+
const ndim$2 = xs[0].ndim;
|
|
3100
|
+
const sizes = xs.map((x) => x.shape[axis]);
|
|
3101
|
+
const finalSize = sizes.reduce((a, b) => a + b, 0);
|
|
3102
|
+
const makePadAxis = (start, end) => require_backend.range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
|
|
3103
|
+
let cum = 0;
|
|
3104
|
+
const xsPadded = [];
|
|
3105
|
+
for (let i = 0; i < xs.length; i++) {
|
|
3106
|
+
const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
|
|
3107
|
+
xsPadded.push(xs[i].#reshape(xs[i].#st.pad(padding)));
|
|
3108
|
+
cum += sizes[i];
|
|
3109
|
+
}
|
|
3110
|
+
const custom = (exps) => exps.reduce(require_backend.AluExp.add);
|
|
3111
|
+
return [Array$1.#naryCustom("concatenate", custom, xsPadded)];
|
|
3112
|
+
},
|
|
3113
|
+
[Primitive.Split]([x], { axis, sizes }) {
|
|
3114
|
+
const outputs = [];
|
|
3115
|
+
for (let i = 0, start = 0; i < sizes.length; i++) {
|
|
3116
|
+
const slice = require_backend.range(x.ndim).map((d) => d === axis ? [start, start + sizes[i]] : [0, x.shape[d]]);
|
|
3117
|
+
outputs.push(x.ref.#reshape(x.#st.shrink(slice)));
|
|
3118
|
+
start += sizes[i];
|
|
3119
|
+
}
|
|
3120
|
+
x.dispose();
|
|
3121
|
+
return outputs;
|
|
3122
|
+
},
|
|
2995
3123
|
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
2996
|
-
const keyShape =
|
|
2997
|
-
|
|
2998
|
-
const c0 = zeros(
|
|
3124
|
+
const keyShape = k0.shape;
|
|
3125
|
+
const genShape = shape$1.slice(keyShape.length);
|
|
3126
|
+
const c0 = zeros(genShape, {
|
|
2999
3127
|
dtype: require_backend.DType.Uint32,
|
|
3000
3128
|
device: k0.device
|
|
3001
3129
|
});
|
|
3002
|
-
const c1 = arange(0, require_backend.prod(
|
|
3130
|
+
const c1 = arange(0, require_backend.prod(genShape), 1, {
|
|
3003
3131
|
dtype: require_backend.DType.Uint32,
|
|
3004
3132
|
device: k0.device
|
|
3005
|
-
}).reshape(
|
|
3133
|
+
}).reshape(genShape);
|
|
3134
|
+
k0 = k0.#reshape(k0.#st.reshape(keyShape.concat(require_backend.rep(genShape.length, 1))));
|
|
3135
|
+
k1 = k1.#reshape(k1.#st.reshape(keyShape.concat(require_backend.rep(genShape.length, 1))));
|
|
3006
3136
|
const custom = ([k0$1, k1$1, c0$1, c1$1]) => require_backend.AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
|
|
3007
3137
|
return [Array$1.#naryCustom("random_bits", custom, [
|
|
3008
3138
|
k0,
|
|
@@ -3036,40 +3166,63 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3036
3166
|
},
|
|
3037
3167
|
[Primitive.Sort]([x]) {
|
|
3038
3168
|
const routine = new require_backend.Routine(require_backend.Routines.Sort, {
|
|
3039
|
-
inputShapes: [x.
|
|
3040
|
-
inputDtypes: [x.
|
|
3041
|
-
outputShapes: [x.
|
|
3042
|
-
outputDtypes: [x.
|
|
3169
|
+
inputShapes: [x.shape],
|
|
3170
|
+
inputDtypes: [x.dtype],
|
|
3171
|
+
outputShapes: [x.shape],
|
|
3172
|
+
outputDtypes: [x.dtype]
|
|
3043
3173
|
});
|
|
3044
3174
|
return Array$1.#routine(routine, [x], [x.#weakType]);
|
|
3045
3175
|
},
|
|
3046
3176
|
[Primitive.Argsort]([x]) {
|
|
3047
3177
|
const routine = new require_backend.Routine(require_backend.Routines.Argsort, {
|
|
3048
|
-
inputShapes: [x.
|
|
3049
|
-
inputDtypes: [x.
|
|
3050
|
-
outputShapes: [x.
|
|
3051
|
-
outputDtypes: [x.
|
|
3178
|
+
inputShapes: [x.shape],
|
|
3179
|
+
inputDtypes: [x.dtype],
|
|
3180
|
+
outputShapes: [x.shape, x.shape],
|
|
3181
|
+
outputDtypes: [x.dtype, require_backend.DType.Int32]
|
|
3052
3182
|
});
|
|
3053
3183
|
return Array$1.#routine(routine, [x], [x.#weakType, false]);
|
|
3054
3184
|
},
|
|
3055
3185
|
[Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
|
|
3056
3186
|
const routine = new require_backend.Routine(require_backend.Routines.TriangularSolve, {
|
|
3057
|
-
inputShapes: [a.
|
|
3058
|
-
inputDtypes: [a.
|
|
3059
|
-
outputShapes: [b.
|
|
3060
|
-
outputDtypes: [b.
|
|
3187
|
+
inputShapes: [a.shape, b.shape],
|
|
3188
|
+
inputDtypes: [a.dtype, b.dtype],
|
|
3189
|
+
outputShapes: [b.shape],
|
|
3190
|
+
outputDtypes: [b.dtype]
|
|
3061
3191
|
}, { unitDiagonal });
|
|
3062
3192
|
return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
|
|
3063
3193
|
},
|
|
3064
3194
|
[Primitive.Cholesky]([a]) {
|
|
3065
3195
|
const routine = new require_backend.Routine(require_backend.Routines.Cholesky, {
|
|
3066
|
-
inputShapes: [a.
|
|
3067
|
-
inputDtypes: [a.
|
|
3068
|
-
outputShapes: [a.
|
|
3069
|
-
outputDtypes: [a.
|
|
3196
|
+
inputShapes: [a.shape],
|
|
3197
|
+
inputDtypes: [a.dtype],
|
|
3198
|
+
outputShapes: [a.shape],
|
|
3199
|
+
outputDtypes: [a.dtype]
|
|
3070
3200
|
});
|
|
3071
3201
|
return Array$1.#routine(routine, [a], [a.#weakType]);
|
|
3072
3202
|
},
|
|
3203
|
+
[Primitive.LU]([a]) {
|
|
3204
|
+
const batch = a.shape.slice(0, -2);
|
|
3205
|
+
const [m, n] = a.shape.slice(-2);
|
|
3206
|
+
const routine = new require_backend.Routine(require_backend.Routines.LU, {
|
|
3207
|
+
inputShapes: [a.shape],
|
|
3208
|
+
inputDtypes: [a.dtype],
|
|
3209
|
+
outputShapes: [
|
|
3210
|
+
a.shape,
|
|
3211
|
+
[...batch, Math.min(m, n)],
|
|
3212
|
+
[...batch, m]
|
|
3213
|
+
],
|
|
3214
|
+
outputDtypes: [
|
|
3215
|
+
a.dtype,
|
|
3216
|
+
require_backend.DType.Int32,
|
|
3217
|
+
require_backend.DType.Int32
|
|
3218
|
+
]
|
|
3219
|
+
});
|
|
3220
|
+
return Array$1.#routine(routine, [a], [
|
|
3221
|
+
a.#weakType,
|
|
3222
|
+
false,
|
|
3223
|
+
false
|
|
3224
|
+
]);
|
|
3225
|
+
},
|
|
3073
3226
|
[Primitive.Jit](args, { jaxpr }) {
|
|
3074
3227
|
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
3075
3228
|
const { backend, committed } = Array$1.#computeBackend("jit", args);
|
|
@@ -3175,7 +3328,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
3175
3328
|
device
|
|
3176
3329
|
});
|
|
3177
3330
|
} else {
|
|
3178
|
-
const weakType = dtype == void 0;
|
|
3331
|
+
const weakType = dtype == void 0 && shape$1.length === 0;
|
|
3179
3332
|
dtype = dtype ?? require_backend.DType.Float32;
|
|
3180
3333
|
const data = require_backend.dtypedJsArray(dtype, flat);
|
|
3181
3334
|
return arrayFromData(data, shape$1, {
|
|
@@ -3289,7 +3442,7 @@ function ones(shape$1, { dtype, device } = {}) {
|
|
|
3289
3442
|
}
|
|
3290
3443
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
3291
3444
|
function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
3292
|
-
let weakType = dtype == void 0;
|
|
3445
|
+
let weakType = dtype == void 0 && shape$1.length === 0;
|
|
3293
3446
|
if (typeof fillValue === "number") dtype = dtype ?? require_backend.DType.Float32;
|
|
3294
3447
|
else if (typeof fillValue === "boolean") {
|
|
3295
3448
|
dtype = dtype ?? require_backend.DType.Bool;
|
|
@@ -3447,6 +3600,27 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
3447
3600
|
committed: device != void 0
|
|
3448
3601
|
});
|
|
3449
3602
|
}
|
|
3603
|
+
/**
|
|
3604
|
+
* Return numbers spaced evenly on a log scale.
|
|
3605
|
+
*
|
|
3606
|
+
* In linear space, the sequence starts at `base ** start` and ends at
|
|
3607
|
+
* `base ** stop` (see `endpoint` below).
|
|
3608
|
+
*
|
|
3609
|
+
* @param start - `base ** start` is the starting value of the sequence.
|
|
3610
|
+
* @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
|
|
3611
|
+
* @param num - Number of samples to generate. Default is 50.
|
|
3612
|
+
* @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
|
|
3613
|
+
* @param base - The base of the log space. Default is 10.
|
|
3614
|
+
* @returns Array of evenly spaced values on a log scale.
|
|
3615
|
+
*/
|
|
3616
|
+
function logspace(start, stop, num = 50, endpoint = true, base = 10, { dtype, device } = {}) {
|
|
3617
|
+
const y = linspace(start, stop, num, endpoint, {
|
|
3618
|
+
dtype,
|
|
3619
|
+
device
|
|
3620
|
+
});
|
|
3621
|
+
const logBase = Math.log(base);
|
|
3622
|
+
return exp$1(mul(y, logBase));
|
|
3623
|
+
}
|
|
3450
3624
|
function aluCompare(a, b, op) {
|
|
3451
3625
|
switch (op) {
|
|
3452
3626
|
case CompareOp.Less: return require_backend.AluExp.cmplt(a, b);
|
|
@@ -3524,6 +3698,7 @@ var BatchTrace = class extends Trace {
|
|
|
3524
3698
|
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3525
3699
|
}
|
|
3526
3700
|
const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
|
|
3701
|
+
if (valOuts.length !== bdimOuts.length) throw new Error(`vmap rule for ${primitive} returned mismatched lengths: ${valOuts.length} vs ${bdimOuts.length}`);
|
|
3527
3702
|
return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
|
|
3528
3703
|
}
|
|
3529
3704
|
get axisSize() {
|
|
@@ -3535,13 +3710,13 @@ var BatchTrace = class extends Trace {
|
|
|
3535
3710
|
*
|
|
3536
3711
|
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3537
3712
|
*/
|
|
3538
|
-
function broadcastBatcher(
|
|
3539
|
-
return (axisSize, args, dims) => {
|
|
3713
|
+
function broadcastBatcher(prim) {
|
|
3714
|
+
return (axisSize, args, dims, params) => {
|
|
3540
3715
|
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3541
3716
|
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3542
3717
|
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3543
3718
|
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3544
|
-
if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[
|
|
3719
|
+
if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[bind1(prim, args, params)], [nd + firstBdim]];
|
|
3545
3720
|
args = args.map((x, i) => {
|
|
3546
3721
|
if (dims[i] === null) return x;
|
|
3547
3722
|
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
@@ -3552,37 +3727,45 @@ function broadcastBatcher(op) {
|
|
|
3552
3727
|
]);
|
|
3553
3728
|
return x;
|
|
3554
3729
|
});
|
|
3555
|
-
return [[
|
|
3730
|
+
return [[bind1(prim, args, params)], [0]];
|
|
3556
3731
|
};
|
|
3557
3732
|
}
|
|
3558
|
-
function unopBatcher(
|
|
3733
|
+
function unopBatcher(prim) {
|
|
3559
3734
|
return (axisSize, [x], [xBdim], params) => {
|
|
3560
|
-
return [[
|
|
3735
|
+
return [[bind1(prim, [x], params)], [xBdim]];
|
|
3736
|
+
};
|
|
3737
|
+
}
|
|
3738
|
+
function lastDimsBatcher(prim, inputDims, numOutputs = 1) {
|
|
3739
|
+
return (axisSize, [x], [xBdim], params) => {
|
|
3740
|
+
require_backend.assertNonNull(xBdim);
|
|
3741
|
+
if (xBdim < x.ndim - inputDims) return [bind(prim, [x], params), require_backend.rep(numOutputs, xBdim)];
|
|
3742
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3743
|
+
return [bind(prim, [x], params), require_backend.rep(numOutputs, 0)];
|
|
3561
3744
|
};
|
|
3562
3745
|
}
|
|
3563
3746
|
const vmapRules = {
|
|
3564
|
-
[Primitive.Add]: broadcastBatcher(
|
|
3565
|
-
[Primitive.Mul]: broadcastBatcher(
|
|
3566
|
-
[Primitive.Idiv]: broadcastBatcher(
|
|
3567
|
-
[Primitive.Mod]: broadcastBatcher(
|
|
3568
|
-
[Primitive.Min]: broadcastBatcher(
|
|
3569
|
-
[Primitive.Max]: broadcastBatcher(
|
|
3570
|
-
[Primitive.Neg]: unopBatcher(
|
|
3571
|
-
[Primitive.Reciprocal]: unopBatcher(
|
|
3572
|
-
[Primitive.Floor]: unopBatcher(
|
|
3573
|
-
[Primitive.Ceil]: unopBatcher(
|
|
3574
|
-
[Primitive.StopGradient]: unopBatcher(
|
|
3575
|
-
[Primitive.Cast]: unopBatcher(
|
|
3576
|
-
[Primitive.Bitcast]: unopBatcher(
|
|
3577
|
-
[Primitive.Sin]: unopBatcher(
|
|
3578
|
-
[Primitive.Cos]: unopBatcher(
|
|
3579
|
-
[Primitive.Asin]: unopBatcher(
|
|
3580
|
-
[Primitive.Atan]: unopBatcher(
|
|
3581
|
-
[Primitive.Exp]: unopBatcher(
|
|
3582
|
-
[Primitive.Log]: unopBatcher(
|
|
3583
|
-
[Primitive.Erf]: unopBatcher(
|
|
3584
|
-
[Primitive.Erfc]: unopBatcher(
|
|
3585
|
-
[Primitive.Sqrt]: unopBatcher(
|
|
3747
|
+
[Primitive.Add]: broadcastBatcher(Primitive.Add),
|
|
3748
|
+
[Primitive.Mul]: broadcastBatcher(Primitive.Mul),
|
|
3749
|
+
[Primitive.Idiv]: broadcastBatcher(Primitive.Idiv),
|
|
3750
|
+
[Primitive.Mod]: broadcastBatcher(Primitive.Mod),
|
|
3751
|
+
[Primitive.Min]: broadcastBatcher(Primitive.Min),
|
|
3752
|
+
[Primitive.Max]: broadcastBatcher(Primitive.Max),
|
|
3753
|
+
[Primitive.Neg]: unopBatcher(Primitive.Neg),
|
|
3754
|
+
[Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
|
|
3755
|
+
[Primitive.Floor]: unopBatcher(Primitive.Floor),
|
|
3756
|
+
[Primitive.Ceil]: unopBatcher(Primitive.Ceil),
|
|
3757
|
+
[Primitive.StopGradient]: unopBatcher(Primitive.StopGradient),
|
|
3758
|
+
[Primitive.Cast]: unopBatcher(Primitive.Cast),
|
|
3759
|
+
[Primitive.Bitcast]: unopBatcher(Primitive.Bitcast),
|
|
3760
|
+
[Primitive.Sin]: unopBatcher(Primitive.Sin),
|
|
3761
|
+
[Primitive.Cos]: unopBatcher(Primitive.Cos),
|
|
3762
|
+
[Primitive.Asin]: unopBatcher(Primitive.Asin),
|
|
3763
|
+
[Primitive.Atan]: unopBatcher(Primitive.Atan),
|
|
3764
|
+
[Primitive.Exp]: unopBatcher(Primitive.Exp),
|
|
3765
|
+
[Primitive.Log]: unopBatcher(Primitive.Log),
|
|
3766
|
+
[Primitive.Erf]: unopBatcher(Primitive.Erf),
|
|
3767
|
+
[Primitive.Erfc]: unopBatcher(Primitive.Erfc),
|
|
3768
|
+
[Primitive.Sqrt]: unopBatcher(Primitive.Sqrt),
|
|
3586
3769
|
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3587
3770
|
require_backend.assertNonNull(xBdim);
|
|
3588
3771
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
@@ -3604,10 +3787,25 @@ const vmapRules = {
|
|
|
3604
3787
|
});
|
|
3605
3788
|
return [[z], [0]];
|
|
3606
3789
|
},
|
|
3607
|
-
[Primitive.Compare](
|
|
3608
|
-
|
|
3790
|
+
[Primitive.Compare]: broadcastBatcher(Primitive.Compare),
|
|
3791
|
+
[Primitive.Where]: broadcastBatcher(Primitive.Where),
|
|
3792
|
+
[Primitive.Concatenate](axisSize, xs, xBdims, { axis }) {
|
|
3793
|
+
const minBdim = Math.min(...xBdims.filter((d) => d !== null));
|
|
3794
|
+
xs = xs.map((x, i) => moveBatchAxis(axisSize, xBdims[i], minBdim, x));
|
|
3795
|
+
const newAxis = axis + (minBdim <= axis ? 1 : 0);
|
|
3796
|
+
return [[concatenate$1(xs, newAxis)], [minBdim]];
|
|
3797
|
+
},
|
|
3798
|
+
[Primitive.Split](axisSize, [x], [xBdim], { axis, sizes }) {
|
|
3799
|
+
require_backend.assertNonNull(xBdim);
|
|
3800
|
+
const newAxis = axis + (xBdim <= axis ? 1 : 0);
|
|
3801
|
+
const outs = split$2(x, newAxis, sizes);
|
|
3802
|
+
return [outs, require_backend.rep(outs.length, xBdim)];
|
|
3803
|
+
},
|
|
3804
|
+
[Primitive.RandomBits](axisSize, [k0, k1], [bdim0, bdim1], { shape: shape$1, mode }) {
|
|
3805
|
+
k0 = moveBatchAxis(axisSize, bdim0, 0, k0);
|
|
3806
|
+
k1 = moveBatchAxis(axisSize, bdim1, 0, k1);
|
|
3807
|
+
return [[randomBits(k0, k1, [axisSize, ...shape$1], mode)], [0]];
|
|
3609
3808
|
},
|
|
3610
|
-
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3611
3809
|
[Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
|
|
3612
3810
|
if (indicesBdim.every((d) => d === null)) {
|
|
3613
3811
|
require_backend.assertNonNull(xBdim);
|
|
@@ -3669,18 +3867,8 @@ const vmapRules = {
|
|
|
3669
3867
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3670
3868
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3671
3869
|
},
|
|
3672
|
-
[Primitive.Sort](
|
|
3673
|
-
|
|
3674
|
-
if (xBdim !== x.ndim - 1) return [[sort$1(x)], [xBdim]];
|
|
3675
|
-
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3676
|
-
return [[sort$1(x)], [0]];
|
|
3677
|
-
},
|
|
3678
|
-
[Primitive.Argsort](axisSize, [x], [xBdim]) {
|
|
3679
|
-
require_backend.assertNonNull(xBdim);
|
|
3680
|
-
if (xBdim !== x.ndim - 1) return [argsort$1(x), [xBdim, xBdim]];
|
|
3681
|
-
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3682
|
-
return [argsort$1(x), [0, 0]];
|
|
3683
|
-
},
|
|
3870
|
+
[Primitive.Sort]: lastDimsBatcher(Primitive.Sort, 1),
|
|
3871
|
+
[Primitive.Argsort]: lastDimsBatcher(Primitive.Argsort, 1, 2),
|
|
3684
3872
|
[Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
|
|
3685
3873
|
if (aBdim === null) {
|
|
3686
3874
|
b = moveBatchAxis(axisSize, bBdim, -3, b);
|
|
@@ -3704,12 +3892,8 @@ const vmapRules = {
|
|
|
3704
3892
|
const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
3705
3893
|
return [[x], [0]];
|
|
3706
3894
|
},
|
|
3707
|
-
[Primitive.Cholesky](
|
|
3708
|
-
|
|
3709
|
-
if (xBdim < x.ndim - 2) return [[cholesky$2(x)], [xBdim]];
|
|
3710
|
-
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3711
|
-
return [[cholesky$2(x)], [0]];
|
|
3712
|
-
},
|
|
3895
|
+
[Primitive.Cholesky]: lastDimsBatcher(Primitive.Cholesky, 2),
|
|
3896
|
+
[Primitive.LU]: lastDimsBatcher(Primitive.LU, 2, 3),
|
|
3713
3897
|
[Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
|
|
3714
3898
|
const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3715
3899
|
const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
|
|
@@ -3860,6 +4044,16 @@ function batchMatmulT(a, b) {
|
|
|
3860
4044
|
function mT(a) {
|
|
3861
4045
|
return moveaxis(a, -2, -1);
|
|
3862
4046
|
}
|
|
4047
|
+
function sliceAxis(a, axis, p) {
|
|
4048
|
+
const slices = Array(a.shape.length).fill([]);
|
|
4049
|
+
slices[require_backend.checkAxis(axis, a.ndim)] = p;
|
|
4050
|
+
return a.slice(...slices);
|
|
4051
|
+
}
|
|
4052
|
+
function padAxis(a, axis, p) {
|
|
4053
|
+
const pads = Array(a.shape.length).fill([0, 0]);
|
|
4054
|
+
pads[require_backend.checkAxis(axis, a.ndim)] = p;
|
|
4055
|
+
return pad$1(a, pads);
|
|
4056
|
+
}
|
|
3863
4057
|
const jvpRules = {
|
|
3864
4058
|
[Primitive.Add]: linearTangentsJvp(Primitive.Add),
|
|
3865
4059
|
[Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
|
|
@@ -3958,6 +4152,8 @@ const jvpRules = {
|
|
|
3958
4152
|
dcond.dispose();
|
|
3959
4153
|
return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
|
|
3960
4154
|
},
|
|
4155
|
+
[Primitive.Concatenate]: linearTangentsJvp(Primitive.Concatenate),
|
|
4156
|
+
[Primitive.Split]: linearTangentsJvp(Primitive.Split),
|
|
3961
4157
|
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
3962
4158
|
[Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
|
|
3963
4159
|
const indicesRef = indices.map((t) => t.ref);
|
|
@@ -3992,6 +4188,38 @@ const jvpRules = {
|
|
|
3992
4188
|
const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
|
|
3993
4189
|
return [[L], [dL]];
|
|
3994
4190
|
},
|
|
4191
|
+
[Primitive.LU]([a], [da]) {
|
|
4192
|
+
const [luMatrix, pivots, permutation] = lu$1(a);
|
|
4193
|
+
const [m, n] = a.shape.slice(-2);
|
|
4194
|
+
const k = Math.min(m, n);
|
|
4195
|
+
const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
|
|
4196
|
+
const lLower = tril(luSliceL, -1);
|
|
4197
|
+
const lPadded = m > k ? padAxis(lLower, -1, [0, m - k]) : lLower;
|
|
4198
|
+
const L = lPadded.add(eye(m));
|
|
4199
|
+
const luSliceU = sliceAxis(luMatrix.ref, -2, [0, k]);
|
|
4200
|
+
const uUpper = triu(luSliceU);
|
|
4201
|
+
const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
|
|
4202
|
+
const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
|
|
4203
|
+
const U = uPadded.add(uEye);
|
|
4204
|
+
const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
|
|
4205
|
+
const pda = batchMatmulT(P, mT(da));
|
|
4206
|
+
const la = mT(triangularSolve$1(L.ref, mT(pda), {
|
|
4207
|
+
lower: true,
|
|
4208
|
+
unitDiagonal: true
|
|
4209
|
+
}));
|
|
4210
|
+
const lau = triangularSolve$1(mT(U.ref), la, { lower: true });
|
|
4211
|
+
const lDot = batchMatmulT(L, mT(tril(lau.ref, -1)));
|
|
4212
|
+
const uDot = batchMatmulT(triu(lau), mT(U));
|
|
4213
|
+
return [[
|
|
4214
|
+
luMatrix,
|
|
4215
|
+
pivots,
|
|
4216
|
+
permutation
|
|
4217
|
+
], [
|
|
4218
|
+
lDot.add(uDot),
|
|
4219
|
+
zerosLike$1(pivots.ref),
|
|
4220
|
+
zerosLike$1(permutation.ref)
|
|
4221
|
+
]];
|
|
4222
|
+
},
|
|
3995
4223
|
[Primitive.Jit](primals, tangents, { name, jaxpr }) {
|
|
3996
4224
|
const newJaxpr = jvpJaxpr(jaxpr);
|
|
3997
4225
|
const outs = bind(Primitive.Jit, [
|
|
@@ -4529,6 +4757,15 @@ const transposeRules = {
|
|
|
4529
4757
|
cond.dispose();
|
|
4530
4758
|
return cts;
|
|
4531
4759
|
},
|
|
4760
|
+
[Primitive.Concatenate]([ct], inputs, { axis }) {
|
|
4761
|
+
if (inputs.some((x) => !(x instanceof UndefPrimal))) throw new NonlinearError(Primitive.Concatenate);
|
|
4762
|
+
const sizes = inputs.map((x) => x.aval.shape[axis]);
|
|
4763
|
+
return split$2(ct, axis, sizes);
|
|
4764
|
+
},
|
|
4765
|
+
[Primitive.Split](cts, [x], { axis }) {
|
|
4766
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Split);
|
|
4767
|
+
return [concatenate$1(cts, axis)];
|
|
4768
|
+
},
|
|
4532
4769
|
[Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
|
|
4533
4770
|
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
4534
4771
|
if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
@@ -4804,8 +5041,8 @@ function computeSizeMap({ shapes, lhsIndices, rhsIndex }) {
|
|
|
4804
5041
|
const idx = lhsIndex[j];
|
|
4805
5042
|
const dim = shape$1[j];
|
|
4806
5043
|
const existing = sizeMap.get(idx);
|
|
4807
|
-
if (existing === void 0) sizeMap.set(idx, dim);
|
|
4808
|
-
else if (existing !== dim) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
|
|
5044
|
+
if (existing === void 0 || existing === 1) sizeMap.set(idx, dim);
|
|
5045
|
+
else if (existing !== dim && dim !== 1) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
|
|
4809
5046
|
}
|
|
4810
5047
|
}
|
|
4811
5048
|
for (const [idx, size$1] of sizeMap) if (!Number.isInteger(idx) || idx < 0) throw new Error(`Invalid index ${idx} in einsum expression, must be non-negative integer`);
|
|
@@ -4961,27 +5198,53 @@ function ifft(a, axis = -1) {
|
|
|
4961
5198
|
//#region src/library/numpy-linalg.ts
|
|
4962
5199
|
var numpy_linalg_exports = {};
|
|
4963
5200
|
__export(numpy_linalg_exports, {
|
|
4964
|
-
cholesky: () => cholesky
|
|
5201
|
+
cholesky: () => cholesky,
|
|
5202
|
+
det: () => det,
|
|
4965
5203
|
diagonal: () => diagonal,
|
|
5204
|
+
inv: () => inv,
|
|
4966
5205
|
lstsq: () => lstsq,
|
|
4967
5206
|
matmul: () => matmul,
|
|
5207
|
+
matrixPower: () => matrixPower,
|
|
4968
5208
|
matrixTranspose: () => matrixTranspose,
|
|
4969
5209
|
outer: () => outer,
|
|
5210
|
+
slogdet: () => slogdet,
|
|
5211
|
+
solve: () => solve,
|
|
4970
5212
|
tensordot: () => tensordot,
|
|
4971
5213
|
trace: () => trace,
|
|
4972
5214
|
vecdot: () => vecdot
|
|
4973
5215
|
});
|
|
5216
|
+
function checkSquare(name, a) {
|
|
5217
|
+
if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
|
|
5218
|
+
return a.shape[a.ndim - 1];
|
|
5219
|
+
}
|
|
4974
5220
|
/**
|
|
4975
5221
|
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
4976
5222
|
*
|
|
4977
5223
|
* This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
|
|
4978
5224
|
* the input matrix, which is on by default.
|
|
4979
5225
|
*/
|
|
4980
|
-
function cholesky
|
|
5226
|
+
function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
|
|
4981
5227
|
a = fudgeArray(a);
|
|
4982
|
-
|
|
5228
|
+
checkSquare("cholesky", a);
|
|
4983
5229
|
if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
|
|
4984
|
-
return cholesky(a, { upper });
|
|
5230
|
+
return cholesky$1(a, { upper });
|
|
5231
|
+
}
|
|
5232
|
+
/** Compute the determinant of a square matrix (batched). */
|
|
5233
|
+
function det(a) {
|
|
5234
|
+
a = fudgeArray(a);
|
|
5235
|
+
const n = checkSquare("det", a);
|
|
5236
|
+
const [lu$2, pivots, permutation] = lu(a);
|
|
5237
|
+
permutation.dispose();
|
|
5238
|
+
const parity = pivots.notEqual(arange(n)).astype(int32).sum(-1).mod(2);
|
|
5239
|
+
const sign$1 = parity.mul(-2).add(1);
|
|
5240
|
+
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
5241
|
+
return prod$1(diag$1, -1).mul(sign$1);
|
|
5242
|
+
}
|
|
5243
|
+
/** Compute the inverse of a square matrix (batched). */
|
|
5244
|
+
function inv(a) {
|
|
5245
|
+
a = fudgeArray(a);
|
|
5246
|
+
const n = checkSquare("inv", a);
|
|
5247
|
+
return solve(a, eye(n));
|
|
4985
5248
|
}
|
|
4986
5249
|
/**
|
|
4987
5250
|
* Return the least-squares solution to a linear equation.
|
|
@@ -5005,7 +5268,7 @@ function lstsq(a, b) {
|
|
|
5005
5268
|
const at = matrixTranspose(a.ref);
|
|
5006
5269
|
if (m <= n) {
|
|
5007
5270
|
const aat = matmul(a, at.ref);
|
|
5008
|
-
const l = cholesky
|
|
5271
|
+
const l = cholesky(aat, { symmetrizeInput: false });
|
|
5009
5272
|
const lb = triangularSolve(l.ref, b, {
|
|
5010
5273
|
leftSide: true,
|
|
5011
5274
|
lower: true
|
|
@@ -5017,7 +5280,7 @@ function lstsq(a, b) {
|
|
|
5017
5280
|
return matmul(at, llb.ref);
|
|
5018
5281
|
} else {
|
|
5019
5282
|
const ata = matmul(at.ref, a);
|
|
5020
|
-
const l = cholesky
|
|
5283
|
+
const l = cholesky(ata, { symmetrizeInput: false });
|
|
5021
5284
|
const atb = matmul(at, b);
|
|
5022
5285
|
const lb = triangularSolve(l.ref, atb, {
|
|
5023
5286
|
leftSide: true,
|
|
@@ -5030,6 +5293,169 @@ function lstsq(a, b) {
|
|
|
5030
5293
|
return llb;
|
|
5031
5294
|
}
|
|
5032
5295
|
}
|
|
5296
|
+
/** Raise a square matrix to an integer power, via repeated squarings. */
|
|
5297
|
+
function matrixPower(a, n) {
|
|
5298
|
+
if (!Number.isInteger(n)) throw new Error(`matrixPower: exponent must be an integer, got ${n}`);
|
|
5299
|
+
a = fudgeArray(a);
|
|
5300
|
+
const m = checkSquare("matrixPower", a);
|
|
5301
|
+
if (n === 0) {
|
|
5302
|
+
a.dispose();
|
|
5303
|
+
return broadcastTo(eye(m), a.shape);
|
|
5304
|
+
}
|
|
5305
|
+
if (n < 0) {
|
|
5306
|
+
a = inv(a);
|
|
5307
|
+
n = -n;
|
|
5308
|
+
}
|
|
5309
|
+
let result = null;
|
|
5310
|
+
let a2k = a;
|
|
5311
|
+
for (let k = 0; n; k++) {
|
|
5312
|
+
if (k > 0) a2k = matmul(a2k.ref, a2k);
|
|
5313
|
+
if (n % 2 === 1) result = result === null ? a2k.ref : matmul(result, a2k.ref);
|
|
5314
|
+
n = Math.floor(n / 2);
|
|
5315
|
+
}
|
|
5316
|
+
a2k.dispose();
|
|
5317
|
+
return result;
|
|
5318
|
+
}
|
|
5319
|
+
/** Return sign and natural logarithm of the determinant of `a`. */
|
|
5320
|
+
function slogdet(a) {
|
|
5321
|
+
a = fudgeArray(a);
|
|
5322
|
+
const n = checkSquare("slogdet", a);
|
|
5323
|
+
const [lu$2, pivots, permutation] = lu(a);
|
|
5324
|
+
permutation.dispose();
|
|
5325
|
+
let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
|
|
5326
|
+
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
5327
|
+
parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
|
|
5328
|
+
const logabsdet = log(absolute(diag$1)).sum(-1);
|
|
5329
|
+
const sign$1 = parity.mul(-2).add(1);
|
|
5330
|
+
return [sign$1, logabsdet];
|
|
5331
|
+
}
|
|
5332
|
+
/**
|
|
5333
|
+
* Solve a linear system of equations.
|
|
5334
|
+
*
|
|
5335
|
+
* This solves a (batched) linear system of equations `a @ x = b` for `x` given
|
|
5336
|
+
* `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
|
|
5337
|
+
*
|
|
5338
|
+
* @param a - Coefficient matrix of shape `(..., N, N)`.
|
|
5339
|
+
* @param b - Values of shape `(N,)` or `(..., N, M)`.
|
|
5340
|
+
* @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
|
|
5341
|
+
*/
|
|
5342
|
+
function solve(a, b) {
|
|
5343
|
+
a = fudgeArray(a);
|
|
5344
|
+
b = fudgeArray(b);
|
|
5345
|
+
const n = checkSquare("solve", a);
|
|
5346
|
+
if (b.ndim === 0) throw new Error(`solve: b cannot be scalar`);
|
|
5347
|
+
const bIs1d = b.ndim === 1;
|
|
5348
|
+
if (bIs1d) b = b.reshape([...b.shape, 1]);
|
|
5349
|
+
if (b.shape[b.ndim - 2] !== n) throw new Error(`solve: leading dimension of b must match size of a, got a=${a.aval}, b=${b.aval}`);
|
|
5350
|
+
const m = b.shape[b.ndim - 1];
|
|
5351
|
+
const batchDims = require_backend.generalBroadcast(a.shape.slice(0, -2), b.shape.slice(0, -2));
|
|
5352
|
+
a = broadcastTo(a, [
|
|
5353
|
+
...batchDims,
|
|
5354
|
+
n,
|
|
5355
|
+
n
|
|
5356
|
+
]);
|
|
5357
|
+
b = broadcastTo(b, [
|
|
5358
|
+
...batchDims,
|
|
5359
|
+
n,
|
|
5360
|
+
m
|
|
5361
|
+
]);
|
|
5362
|
+
const [lu$2, pivots, permutation] = lu(a);
|
|
5363
|
+
pivots.dispose();
|
|
5364
|
+
const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
|
|
5365
|
+
const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
|
|
5366
|
+
leftSide: true,
|
|
5367
|
+
lower: true,
|
|
5368
|
+
unitDiagonal: true
|
|
5369
|
+
});
|
|
5370
|
+
let x = triangularSolve(lu$2, LPb.ref, {
|
|
5371
|
+
leftSide: true,
|
|
5372
|
+
lower: false
|
|
5373
|
+
});
|
|
5374
|
+
if (bIs1d) x = squeeze(x, -1);
|
|
5375
|
+
return x;
|
|
5376
|
+
}
|
|
5377
|
+
|
|
5378
|
+
//#endregion
|
|
5379
|
+
//#region src/library/numpy/dtype-info.ts
|
|
5380
|
+
/** Machine limits for floating-point types. */
|
|
5381
|
+
function finfo(dtype) {
|
|
5382
|
+
if (!require_backend.isFloatDtype(dtype)) throw new Error(`finfo: received ${dtype}, must be a floating-point type`);
|
|
5383
|
+
switch (dtype) {
|
|
5384
|
+
case require_backend.DType.Float16: return Object.freeze({
|
|
5385
|
+
bits: 16,
|
|
5386
|
+
dtype: require_backend.DType.Float16,
|
|
5387
|
+
eps: 2 ** -10,
|
|
5388
|
+
epsneg: 2 ** -11,
|
|
5389
|
+
machep: -10,
|
|
5390
|
+
max: 65504,
|
|
5391
|
+
maxexp: 16,
|
|
5392
|
+
min: -65504,
|
|
5393
|
+
minexp: -14,
|
|
5394
|
+
negep: -24,
|
|
5395
|
+
nexp: 5,
|
|
5396
|
+
nmant: 10,
|
|
5397
|
+
precision: 3,
|
|
5398
|
+
resolution: .001,
|
|
5399
|
+
smallestNormal: 2 ** -14,
|
|
5400
|
+
smallestSubnormal: 2 ** -24
|
|
5401
|
+
});
|
|
5402
|
+
case require_backend.DType.Float32: return Object.freeze({
|
|
5403
|
+
bits: 32,
|
|
5404
|
+
dtype: require_backend.DType.Float32,
|
|
5405
|
+
eps: 2 ** -23,
|
|
5406
|
+
epsneg: 2 ** -24,
|
|
5407
|
+
machep: -23,
|
|
5408
|
+
max: 34028234663852886e22,
|
|
5409
|
+
maxexp: 128,
|
|
5410
|
+
min: -34028234663852886e22,
|
|
5411
|
+
minexp: -126,
|
|
5412
|
+
negep: -24,
|
|
5413
|
+
nexp: 8,
|
|
5414
|
+
nmant: 23,
|
|
5415
|
+
precision: 6,
|
|
5416
|
+
resolution: 1e-6,
|
|
5417
|
+
smallestNormal: 2 ** -126,
|
|
5418
|
+
smallestSubnormal: 2 ** -149
|
|
5419
|
+
});
|
|
5420
|
+
case require_backend.DType.Float64: return Object.freeze({
|
|
5421
|
+
bits: 64,
|
|
5422
|
+
dtype: require_backend.DType.Float64,
|
|
5423
|
+
eps: 2 ** -52,
|
|
5424
|
+
epsneg: 2 ** -53,
|
|
5425
|
+
machep: -52,
|
|
5426
|
+
max: Number.MAX_VALUE,
|
|
5427
|
+
maxexp: 1024,
|
|
5428
|
+
min: -Number.MAX_VALUE,
|
|
5429
|
+
minexp: -1022,
|
|
5430
|
+
negep: -53,
|
|
5431
|
+
nexp: 11,
|
|
5432
|
+
nmant: 52,
|
|
5433
|
+
precision: 15,
|
|
5434
|
+
resolution: 1e-15,
|
|
5435
|
+
smallestNormal: 2 ** -1022,
|
|
5436
|
+
smallestSubnormal: 2 ** -1074
|
|
5437
|
+
});
|
|
5438
|
+
default: throw new Error(`finfo: unsupported dtype ${dtype}`);
|
|
5439
|
+
}
|
|
5440
|
+
}
|
|
5441
|
+
/** Machine limits for integer types. */
|
|
5442
|
+
function iinfo(dtype) {
|
|
5443
|
+
switch (dtype) {
|
|
5444
|
+
case require_backend.DType.Int32: return Object.freeze({
|
|
5445
|
+
bits: 32,
|
|
5446
|
+
dtype: require_backend.DType.Int32,
|
|
5447
|
+
max: 2147483647,
|
|
5448
|
+
min: -2147483648
|
|
5449
|
+
});
|
|
5450
|
+
case require_backend.DType.Uint32: return Object.freeze({
|
|
5451
|
+
bits: 32,
|
|
5452
|
+
dtype: require_backend.DType.Uint32,
|
|
5453
|
+
max: 4294967295,
|
|
5454
|
+
min: 0
|
|
5455
|
+
});
|
|
5456
|
+
default: throw new Error(`iinfo: unsupported dtype ${dtype}`);
|
|
5457
|
+
}
|
|
5458
|
+
}
|
|
5033
5459
|
|
|
5034
5460
|
//#endregion
|
|
5035
5461
|
//#region src/library/numpy.ts
|
|
@@ -5085,6 +5511,7 @@ __export(numpy_exports, {
|
|
|
5085
5511
|
diag: () => diag,
|
|
5086
5512
|
diagonal: () => diagonal,
|
|
5087
5513
|
divide: () => trueDivide,
|
|
5514
|
+
divmod: () => divmod,
|
|
5088
5515
|
dot: () => dot$1,
|
|
5089
5516
|
dstack: () => dstack,
|
|
5090
5517
|
e: () => e,
|
|
@@ -5097,6 +5524,7 @@ __export(numpy_exports, {
|
|
|
5097
5524
|
expm1: () => expm1,
|
|
5098
5525
|
eye: () => eye,
|
|
5099
5526
|
fft: () => numpy_fft_exports,
|
|
5527
|
+
finfo: () => finfo,
|
|
5100
5528
|
flip: () => flip,
|
|
5101
5529
|
fliplr: () => fliplr,
|
|
5102
5530
|
flipud: () => flipud,
|
|
@@ -5104,6 +5532,7 @@ __export(numpy_exports, {
|
|
|
5104
5532
|
float32: () => float32,
|
|
5105
5533
|
float64: () => float64,
|
|
5106
5534
|
floor: () => floor,
|
|
5535
|
+
floorDivide: () => floorDivide,
|
|
5107
5536
|
fmod: () => fmod,
|
|
5108
5537
|
frexp: () => frexp,
|
|
5109
5538
|
full: () => full,
|
|
@@ -5116,6 +5545,7 @@ __export(numpy_exports, {
|
|
|
5116
5545
|
hstack: () => hstack,
|
|
5117
5546
|
hypot: () => hypot,
|
|
5118
5547
|
identity: () => identity$1,
|
|
5548
|
+
iinfo: () => iinfo,
|
|
5119
5549
|
inf: () => inf,
|
|
5120
5550
|
inner: () => inner,
|
|
5121
5551
|
int32: () => int32,
|
|
@@ -5133,6 +5563,7 @@ __export(numpy_exports, {
|
|
|
5133
5563
|
log10: () => log10,
|
|
5134
5564
|
log1p: () => log1p,
|
|
5135
5565
|
log2: () => log2,
|
|
5566
|
+
logspace: () => logspace,
|
|
5136
5567
|
matmul: () => matmul,
|
|
5137
5568
|
matrixTranspose: () => matrixTranspose,
|
|
5138
5569
|
max: () => max,
|
|
@@ -5169,9 +5600,11 @@ __export(numpy_exports, {
|
|
|
5169
5600
|
shape: () => shape,
|
|
5170
5601
|
sign: () => sign,
|
|
5171
5602
|
sin: () => sin,
|
|
5603
|
+
sinc: () => sinc,
|
|
5172
5604
|
sinh: () => sinh,
|
|
5173
5605
|
size: () => size,
|
|
5174
5606
|
sort: () => sort,
|
|
5607
|
+
split: () => split$1,
|
|
5175
5608
|
sqrt: () => sqrt,
|
|
5176
5609
|
square: () => square,
|
|
5177
5610
|
squeeze: () => squeeze,
|
|
@@ -5179,6 +5612,7 @@ __export(numpy_exports, {
|
|
|
5179
5612
|
std: () => std,
|
|
5180
5613
|
subtract: () => subtract,
|
|
5181
5614
|
sum: () => sum,
|
|
5615
|
+
take: () => take,
|
|
5182
5616
|
tan: () => tan,
|
|
5183
5617
|
tanh: () => tanh,
|
|
5184
5618
|
tensordot: () => tensordot,
|
|
@@ -5437,6 +5871,45 @@ function flip(x, axis = null) {
|
|
|
5437
5871
|
return flip$1(x, axis);
|
|
5438
5872
|
}
|
|
5439
5873
|
/**
|
|
5874
|
+
* Split an array into multiple sub-arrays along an axis.
|
|
5875
|
+
*
|
|
5876
|
+
* @param a - The input array to split.
|
|
5877
|
+
* @param indicesOrSections - If an integer, it indicates the number of equal
|
|
5878
|
+
* sections to create along the specified axis. If a list of integers, it
|
|
5879
|
+
* specifies the indices at which to split the array.
|
|
5880
|
+
* @param axis - The axis along which to split the array. Default is 0.
|
|
5881
|
+
*/
|
|
5882
|
+
function split$1(a, indicesOrSections, axis = 0) {
|
|
5883
|
+
a = fudgeArray(a);
|
|
5884
|
+
axis = require_backend.checkAxis(axis, a.ndim);
|
|
5885
|
+
const size$1 = a.shape[axis];
|
|
5886
|
+
let sizes;
|
|
5887
|
+
if (typeof indicesOrSections === "number") {
|
|
5888
|
+
if (size$1 % indicesOrSections !== 0) throw new Error(`Array of size ${size$1} cannot be split into ${indicesOrSections} equal parts`);
|
|
5889
|
+
const partSize = size$1 / indicesOrSections;
|
|
5890
|
+
sizes = require_backend.rep(indicesOrSections, partSize);
|
|
5891
|
+
} else {
|
|
5892
|
+
const indices = indicesOrSections;
|
|
5893
|
+
sizes = [indices[0]];
|
|
5894
|
+
for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
|
|
5895
|
+
sizes.push(size$1 - indices[indices.length - 1]);
|
|
5896
|
+
}
|
|
5897
|
+
const results = [];
|
|
5898
|
+
for (let i = 0; i < sizes.length; i += 7) if (i === sizes.length) {
|
|
5899
|
+
results.push(a);
|
|
5900
|
+
break;
|
|
5901
|
+
} else if (i + 8 >= sizes.length) {
|
|
5902
|
+
results.push(...split$2(a, axis, sizes.slice(i)));
|
|
5903
|
+
break;
|
|
5904
|
+
} else {
|
|
5905
|
+
const groupSizes = [...sizes.slice(i, i + 7), sizes.slice(i + 7).reduce((x, y) => x + y, 0)];
|
|
5906
|
+
const outs = split$2(a, axis, groupSizes);
|
|
5907
|
+
results.push(...outs.slice(0, -1));
|
|
5908
|
+
a = outs[outs.length - 1];
|
|
5909
|
+
}
|
|
5910
|
+
return results;
|
|
5911
|
+
}
|
|
5912
|
+
/**
|
|
5440
5913
|
* Join a sequence of arrays along an existing axis.
|
|
5441
5914
|
*
|
|
5442
5915
|
* The arrays must have the same shape, except in the dimension corresponding to
|
|
@@ -5448,13 +5921,11 @@ function concatenate(xs, axis = 0) {
|
|
|
5448
5921
|
if (xs.length === 0) throw new Error("Need at least one array to concatenate");
|
|
5449
5922
|
const shapes = xs.map(shape);
|
|
5450
5923
|
axis = require_backend.checkAxis(axis, shapes[0].length);
|
|
5451
|
-
for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays
|
|
5452
|
-
const makePadAxis = (start, end) => shapes[0].map((_, i) => i === axis ? [start, end] : [0, 0]);
|
|
5924
|
+
for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays ${xs[0].aval} and ${xs[i].aval} along axis ${axis}`);
|
|
5453
5925
|
let result = xs[0];
|
|
5454
|
-
for (let i = 1; i < xs.length; i
|
|
5455
|
-
const
|
|
5456
|
-
|
|
5457
|
-
result = pad(result, makePadAxis(0, len2)).add(pad(xs[i], makePadAxis(len1, 0)));
|
|
5926
|
+
for (let i = 1; i < xs.length; i += 7) {
|
|
5927
|
+
const group = xs.slice(i, i + 7);
|
|
5928
|
+
result = concatenate$1([result, ...group], axis);
|
|
5458
5929
|
}
|
|
5459
5930
|
return result;
|
|
5460
5931
|
}
|
|
@@ -5706,6 +6177,20 @@ function sort(a, axis = -1) {
|
|
|
5706
6177
|
function argsort(a, axis = -1) {
|
|
5707
6178
|
return fudgeArray(a).argsort(axis);
|
|
5708
6179
|
}
|
|
6180
|
+
/**
|
|
6181
|
+
* Take elements from an array along an axis.
|
|
6182
|
+
*
|
|
6183
|
+
* This is equivalent to advanced indexing with integer indices over that
|
|
6184
|
+
* numbered axis. By default, the flattened array is used.
|
|
6185
|
+
*/
|
|
6186
|
+
function take(a, indices, axis = null) {
|
|
6187
|
+
if (axis === null) {
|
|
6188
|
+
a = ravel(a);
|
|
6189
|
+
axis = 0;
|
|
6190
|
+
}
|
|
6191
|
+
axis = require_backend.checkAxis(axis, ndim(a));
|
|
6192
|
+
return gather(a, [indices], [axis], axis);
|
|
6193
|
+
}
|
|
5709
6194
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
5710
6195
|
function allclose(actual, expected, options) {
|
|
5711
6196
|
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
@@ -6025,6 +6510,20 @@ function tan(x) {
|
|
|
6025
6510
|
x = fudgeArray(x);
|
|
6026
6511
|
return sin(x.ref).div(cos(x));
|
|
6027
6512
|
}
|
|
6513
|
+
/**
|
|
6514
|
+
* @function
|
|
6515
|
+
* Return the normalized sinc function.
|
|
6516
|
+
*
|
|
6517
|
+
* The sinc function is defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`.
|
|
6518
|
+
* This is the normalized sinc function commonly used in signal processing.
|
|
6519
|
+
*
|
|
6520
|
+
* **Note:** JVP is not supported at x=0 due to discontinuous derivative. This
|
|
6521
|
+
* requires a custom JVP rule to handle properly (see JAX implementation).
|
|
6522
|
+
*/
|
|
6523
|
+
const sinc = jit$1(function sinc$1(x) {
|
|
6524
|
+
const pix = x.ref.mul(Math.PI);
|
|
6525
|
+
return where(equal(x, 0), 1, sin(pix.ref).div(pix));
|
|
6526
|
+
});
|
|
6028
6527
|
/** Element-wise inverse cosine function (inverse of cos). */
|
|
6029
6528
|
function acos(x) {
|
|
6030
6529
|
return subtract(pi / 2, asin(x));
|
|
@@ -6077,6 +6576,25 @@ function trueDivide(x, y) {
|
|
|
6077
6576
|
return x.div(y);
|
|
6078
6577
|
}
|
|
6079
6578
|
/**
|
|
6579
|
+
* Return the largest integer smaller or equal to the division of the inputs.
|
|
6580
|
+
*
|
|
6581
|
+
* The result is always rounded towards negative infinity.
|
|
6582
|
+
*
|
|
6583
|
+
* For floating-point inputs, this is equivalent to `floor(x / y)`.
|
|
6584
|
+
* For integer inputs, we use `(x - remainder(x, y)) / y` to handle
|
|
6585
|
+
* negative values correctly (note: may overflow near int32 boundaries).
|
|
6586
|
+
*
|
|
6587
|
+
* @param x - Dividend array.
|
|
6588
|
+
* @param y - Divisor array.
|
|
6589
|
+
* @returns Element-wise floor division of x by y.
|
|
6590
|
+
*/
|
|
6591
|
+
function floorDivide(x, y) {
|
|
6592
|
+
x = fudgeArray(x);
|
|
6593
|
+
y = fudgeArray(y);
|
|
6594
|
+
if (require_backend.isFloatDtype(x.dtype) || require_backend.isFloatDtype(y.dtype)) return floor(trueDivide(x, y));
|
|
6595
|
+
return subtract(x, remainder(x.ref, y.ref)).div(y);
|
|
6596
|
+
}
|
|
6597
|
+
/**
|
|
6080
6598
|
* @function
|
|
6081
6599
|
* Calculate element-wise floating-point modulo operation.
|
|
6082
6600
|
*/
|
|
@@ -6090,6 +6608,20 @@ const fmod = jit$1(function fmod$1(x, y) {
|
|
|
6090
6608
|
const remainder = jit$1(function remainder$1(x, y) {
|
|
6091
6609
|
return mod(mod(x, y.ref).add(y.ref), y);
|
|
6092
6610
|
});
|
|
6611
|
+
/**
|
|
6612
|
+
* Return element-wise quotient and remainder simultaneously.
|
|
6613
|
+
*
|
|
6614
|
+
* Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
|
|
6615
|
+
*
|
|
6616
|
+
* @param x - Dividend array.
|
|
6617
|
+
* @param y - Divisor array.
|
|
6618
|
+
* @returns Tuple of [quotient, remainder].
|
|
6619
|
+
*/
|
|
6620
|
+
function divmod(x, y) {
|
|
6621
|
+
const xArr = fudgeArray(x);
|
|
6622
|
+
const yArr = fudgeArray(y);
|
|
6623
|
+
return [floorDivide(xArr.ref, yArr.ref), remainder(xArr, yArr)];
|
|
6624
|
+
}
|
|
6093
6625
|
/** Round input to the nearest integer towards zero. */
|
|
6094
6626
|
function trunc(x) {
|
|
6095
6627
|
return idiv(x, 1);
|
|
@@ -6253,14 +6785,15 @@ function std(x, axis = null, opts) {
|
|
|
6253
6785
|
return sqrt(var_(x, axis, opts));
|
|
6254
6786
|
}
|
|
6255
6787
|
/**
 * Estimate the sample covariance matrix of a set of variables.
 *
 * Follows the numpy.cov / jax.numpy.cov convention. A 1-D input is a single
 * variable. With `rowvar: false`, a 2-D input is transposed (columns become
 * variables) unless it has exactly one row. `x` and `y` are oriented
 * independently and then stacked as additional variables, so both must have
 * the same number of observations.
 *
 * @param x - Observations: shape `[N]`, or `[variables, N]`
 *   (`[N, variables]` when `rowvar` is false).
 * @param y - Optional extra variables laid out like `x`. Defaults to `null`.
 * @param opts.rowvar - If true (default), each row is a variable; otherwise
 *   each column is.
 * @returns Covariance matrix of shape `[M, M]`, normalized by `N - 1`.
 */
function cov(x, y = null, { rowvar = true } = {}) {
	// Bring one input into a 2-D, one-row-per-variable layout. Orientation must
	// be resolved per input *before* stacking. Transposing after vstack would
	// stack observations rather than variables when `rowvar` is false.
	const toRows = (a) => {
		a = fudgeArray(a);
		if (a.ndim === 1) a = a.reshape([1, a.shape[0]]);
		// Matches numpy/jax: a single row is never transposed.
		if (!rowvar && a.shape[0] !== 1) a = a.transpose();
		return a;
	};
	x = toRows(x);
	if (y !== null) x = vstack([x, toRows(y)]);
	const [_M, N] = x.shape;
	x = x.ref.sub(x.mean(1, { keepdims: true }));
	return dot$1(x.ref, x.transpose()).div(N - 1);
}
|
|
@@ -6305,7 +6838,8 @@ const isfinite = jit$1(function isfinite$1(x) {
|
|
|
6305
6838
|
//#region src/library/lax-linalg.ts
// Namespace object backing `lax.linalg`. Each entry maps a public name to a
// thunk that returns the implementation defined in this region.
var lax_linalg_exports = {};
__export(lax_linalg_exports, {
	cholesky: () => {
		return cholesky$1;
	},
	lu: () => {
		return lu;
	},
	triangularSolve: () => {
		return triangularSolve;
	}
});
|
|
6311
6845
|
/**
|
|
@@ -6334,11 +6868,39 @@ __export(lax_linalg_exports, {
|
|
|
6334
6868
|
* // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
|
|
6335
6869
|
* ```
|
|
6336
6870
|
*/
|
|
6337
|
-
function cholesky$1(a, { upper = false } = {}) {
	// `cholesky$2` yields the lower factor. For `upper`, swap the last two axes
	// to return its transpose instead.
	const lower = cholesky$2(a);
	if (!upper) return lower;
	return moveaxis$1(lower, -2, -1);
}
|
|
6341
6875
|
/**
 * LU decomposition with partial pivoting.
 *
 * Factorizes each matrix as `P @ A = L @ U`, where `P` permutes the rows of
 * `A`, `L` is lower-triangular with a unit diagonal, and `U` is
 * upper-triangular.
 *
 * @param x - A batch of matrices with shape `[..., m, n]`.
 *
 * @returns A tuple `(lu, pivots, permutation)` where:
 * - `lu`: the combined lower and upper triangular factors.
 * - `pivots`: pivot indices with shape `[..., min(m, n)]`.
 * - `permutation`: the row permutation implied by the pivots, shape `[..., m]`.
 *
 * @example
 * ```ts
 * import { lax, numpy as np } from "@jax-js/jax";
 *
 * const A = np.array([[4., 3.], [6., 3.]]);
 * const [lu, pivots, permutation] = lax.linalg.lu(A);
 * // lu ≈ [[6., 3.], [0.6666667, 1.0]]
 * // pivots = [1, 1]
 * // permutation = [1, 0]
 * ```
 */
function lu(x) {
	const factorization = lu$1(x);
	return factorization;
}
|
|
6903
|
+
/**
|
|
6342
6904
|
* Solve a triangular linear system.
|
|
6343
6905
|
*
|
|
6344
6906
|
* Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
|
|
@@ -6881,33 +7443,41 @@ __export(random_exports, {
|
|
|
6881
7443
|
gumbel: () => gumbel,
|
|
6882
7444
|
key: () => key,
|
|
6883
7445
|
laplace: () => laplace,
|
|
7446
|
+
multivariateNormal: () => multivariateNormal,
|
|
6884
7447
|
normal: () => normal,
|
|
6885
7448
|
split: () => split,
|
|
6886
7449
|
uniform: () => uniform
|
|
6887
7450
|
});
|
|
6888
|
-
// Check that `key$1` is a PRNG key (or batch of keys) whose last axis has
// length 2, and return the batch shape (every axis but the last). When
// `scalar` is set, batches are rejected, so the returned shape is empty.
function validateKeyShape(key$1, scalar = false) {
	if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
	const dims = key$1.shape;
	const rank = dims.length;
	if (dims[rank - 1] !== 2) throw new Error(`Invalid key shape: ${dims}. Expected last dimension to be 2.`);
	const isBatched = rank > 1;
	if (scalar && isBatched) throw new Error(`Expected a single PRNG key, but got a batch of keys with shape ${JSON.stringify(dims)} - use jax.vmap for batching.`);
	return dims.slice(0, -1);
}
|
|
7457
|
+
// Split a single PRNG key along its last axis into its two components,
// each reshaped to the key's (necessarily empty) batch shape.
function getK01(key$1) {
	const batchShape = validateKeyShape(key$1, true);
	const [first, second] = split$2(key$1, -1, [1, 1]);
	return [first.reshape(batchShape), second.reshape(batchShape)];
}
|
|
6893
7464
|
/** Create a pseudo-random number generator (PRNG) key, the pair `[0, seed]`, from a scalar 32-bit integer seed. */
function key(seed) {
	const seedArr = array(seed, { dtype: require_backend.DType.Uint32 });
	if (seedArr.ndim === 0) return stack([0, seedArr]);
	throw new Error(`key: seed must be a scalar integer, but got shape ${seedArr.shape} - use jax.vmap for batching.`);
}
|
|
6898
7470
|
/** Derive `num` new PRNG keys from `key$1`, laid out along new leading axis (or axes when `num` is a shape). */
function split(key$1, num = 2) {
	const shape$1 = typeof num === "number" ? [num] : num;
	for (const len of shape$1) {
		const isPositiveInt = len > 0 && Number.isInteger(len);
		if (!isPositiveInt) throw new Error(`Invalid split length: ${len}. Must be a positive integer.`);
	}
	const [keyA, keyB] = getK01(key$1);
	// The first draw borrows references; the second consumes the halves.
	const firstWords = randomBits(keyA.ref, keyB.ref, shape$1, 0);
	const secondWords = randomBits(keyA, keyB, shape$1, 1);
	return stack([firstWords, secondWords], -1);
}
|
|
6907
7477
|
/** Draw uniformly distributed random bits, as unsigned integers of the given shape, from a single PRNG key. */
function bits(key$1, shape$1 = []) {
	return randomBits(...getK01(key$1), shape$1);
}
|
|
6912
7482
|
/**
|
|
6913
7483
|
* @function
|
|
@@ -6981,6 +7551,32 @@ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
|
|
|
6981
7551
|
}, { staticArgnums: [1] });
|
|
6982
7552
|
/**
 * @function
 * Sample multivariate normal random values with given mean and covariance.
 *
 * The values are returned with the given shape, along with the final dimension
 * used to represent the n-dimensional multivariate normal factors.
 *
 * This uses Cholesky decomposition on the covariance matrix.
 *
 * - `key` - PRNG key
 * - `mean` - Mean vector of shape `[..., n]` (at least one dimension)
 * - `cov` - Covariance of shape `[..., n, n]`, must be positive-definite
 * - `shape` - Result batch shape, must be broadcastable with
 *   `mean.shape[:-1]` and `cov.shape[:-2]`
 * @returns Random samples of shape `[...shape, n]`, broadcast against the
 *   batch dimensions of `mean` and `cov`
 * @throws Error if `mean` is 0-d or `cov` does not end in `[n, n]`.
 */
const multivariateNormal = jit$1(function multivariateNormal$1(key$1, mean$1, cov$1, shape$1 = []) {
	mean$1 = fudgeArray(mean$1);
	cov$1 = fudgeArray(cov$1);
	// A 0-d mean has no event dimension. Reject it explicitly instead of
	// comparing the covariance shape against `undefined`.
	if (mean$1.ndim < 1) throw new Error(`multivariateNormal: mean must have at least one dimension, but got shape ${JSON.stringify(mean$1.shape)}.`);
	const n = mean$1.shape[mean$1.ndim - 1];
	if (cov$1.ndim < 2 || cov$1.shape[cov$1.ndim - 1] !== n || cov$1.shape[cov$1.ndim - 2] !== n) throw new Error(`Invalid covariance shape: ${cov$1.shape}. Expected last two dimensions to be [${n}, ${n}].`);
	const outputShape = broadcastShapes(shape$1, mean$1.shape.slice(0, -1), cov$1.shape.slice(0, -2)).concat(n);
	const L = cholesky(cov$1);
	const z = normal(key$1, outputShape);
	// x = mean + L @ z, batched over the leading dimensions.
	return einsum("...ij,...j->...i", L, z).add(mean$1);
}, { staticArgnums: [3] });
|
|
7578
|
+
/**
|
|
7579
|
+
* @function
|
|
6984
7580
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
6985
7581
|
*
|
|
6986
7582
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|