@jax-js/jax 0.1.4 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +10 -7
- package/dist/{backend-tngXtWe4.js → backend-DaqL-MNz.js} +96 -7
- package/dist/{backend-Bu9GY6sK.cjs → backend-DziQSaoQ.cjs} +101 -6
- package/dist/index.cjs +737 -141
- package/dist/index.d.cts +238 -9
- package/dist/index.d.ts +238 -9
- package/dist/index.js +737 -141
- package/dist/webgl-ClIYb8jP.cjs +522 -0
- package/dist/webgl-RSuZKvgc.js +522 -0
- package/dist/{webgpu-Oj3Kd-kd.cjs → webgpu-Db2JrNBr.cjs} +296 -3
- package/dist/{webgpu-ChVgx3b6.js → webgpu-Dh7k9io0.js} +296 -3
- package/package.json +1 -1
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DaqL-MNz.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -356,6 +356,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
356
356
|
Primitive$1["PoolTranspose"] = "pool_transpose";
|
|
357
357
|
Primitive$1["Compare"] = "compare";
|
|
358
358
|
Primitive$1["Where"] = "where";
|
|
359
|
+
Primitive$1["Concatenate"] = "concatenate";
|
|
360
|
+
Primitive$1["Split"] = "split";
|
|
359
361
|
Primitive$1["RandomBits"] = "random_bits";
|
|
360
362
|
Primitive$1["Gather"] = "gather";
|
|
361
363
|
Primitive$1["Transpose"] = "transpose";
|
|
@@ -368,6 +370,7 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
368
370
|
Primitive$1["Argsort"] = "argsort";
|
|
369
371
|
Primitive$1["TriangularSolve"] = "triangular_solve";
|
|
370
372
|
Primitive$1["Cholesky"] = "cholesky";
|
|
373
|
+
Primitive$1["LU"] = "lu";
|
|
371
374
|
Primitive$1["Jit"] = "jit";
|
|
372
375
|
return Primitive$1;
|
|
373
376
|
}({});
|
|
@@ -499,7 +502,25 @@ function where$1(cond, x, y) {
|
|
|
499
502
|
y
|
|
500
503
|
]);
|
|
501
504
|
}
|
|
505
|
+
function concatenate$1(xs, axis) {
|
|
506
|
+
if (xs.length === 0) throw new Error("concatenate requires at least one input");
|
|
507
|
+
const avals = xs.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
508
|
+
axis = checkAxis(axis, avals[0].ndim);
|
|
509
|
+
for (const x of avals) if (x.ndim !== avals[0].ndim || !x.shape.every((s, i) => i === axis || s === avals[0].shape[i])) throw new Error(`Concatenate: inputs ${avals[0]} and ${x} must match shapes except on axis ${axis}`);
|
|
510
|
+
return bind1(Primitive.Concatenate, xs, { axis });
|
|
511
|
+
}
|
|
512
|
+
function split$2(x, axis, sizes) {
|
|
513
|
+
axis = checkAxis(axis, ndim$1(x));
|
|
514
|
+
if (sizes.some((s) => s < 0 || !Number.isInteger(s))) throw new Error(`split: sizes must be nonnegative integers, got ${JSON.stringify(sizes)}`);
|
|
515
|
+
const totalSize = sizes.reduce((a, b) => a + b, 0);
|
|
516
|
+
if (totalSize !== getShape(x)[axis]) throw new Error(`split: sizes must sum to the size of the axis ${axis}, got ${totalSize}`);
|
|
517
|
+
return bind(Primitive.Split, [x], {
|
|
518
|
+
axis,
|
|
519
|
+
sizes
|
|
520
|
+
});
|
|
521
|
+
}
|
|
502
522
|
function randomBits(k0, k1, shape$1, mode = "xor") {
|
|
523
|
+
if (!deepEqual(k0.shape, k1.shape) || k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new Error(`randomBits: key parts must be uint32 with the same shape, got ${ShapedArray.fromAval(k0.aval)} and ${ShapedArray.fromAval(k1.aval)}`);
|
|
503
524
|
return bind1(Primitive.RandomBits, [k0, k1], {
|
|
504
525
|
shape: shape$1,
|
|
505
526
|
mode
|
|
@@ -566,6 +587,11 @@ function pad$1(x, width) {
|
|
|
566
587
|
return bind1(Primitive.Pad, [x], { width });
|
|
567
588
|
}
|
|
568
589
|
function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
590
|
+
const as = getShape(a);
|
|
591
|
+
const bs = getShape(b);
|
|
592
|
+
if (as.length < 2 || bs.length < 2) throw new Error(`triangular_solve: must be >=2D, got a=${as}, b=${bs}`);
|
|
593
|
+
const n = as[as.length - 2];
|
|
594
|
+
if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
|
|
569
595
|
if (lower) {
|
|
570
596
|
a = flip$1(a, [-2, -1]);
|
|
571
597
|
b = flip$1(b, [-1]);
|
|
@@ -575,8 +601,15 @@ function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
|
575
601
|
return x;
|
|
576
602
|
}
|
|
577
603
|
function cholesky$2(x) {
|
|
604
|
+
const aval = ShapedArray.fromAval(getAval(x));
|
|
605
|
+
if (aval.ndim < 2 || aval.shape[aval.ndim - 1] !== aval.shape[aval.ndim - 2]) throw new Error(`cholesky: expected batch of square matrices, got ${aval}`);
|
|
578
606
|
return bind1(Primitive.Cholesky, [x]);
|
|
579
607
|
}
|
|
608
|
+
function lu$1(x) {
|
|
609
|
+
const aval = ShapedArray.fromAval(getAval(x));
|
|
610
|
+
if (aval.ndim < 2) throw new Error(`lu: expected batch of matrices, got ${aval}`);
|
|
611
|
+
return bind(Primitive.LU, [x]);
|
|
612
|
+
}
|
|
580
613
|
function sort$1(x) {
|
|
581
614
|
const nd = ndim$1(x);
|
|
582
615
|
if (nd === 0) throw new Error("sort: requires at least 1D input");
|
|
@@ -685,6 +718,9 @@ var Tracer = class Tracer {
|
|
|
685
718
|
mul(other) {
|
|
686
719
|
return mul(this, other);
|
|
687
720
|
}
|
|
721
|
+
mod(other) {
|
|
722
|
+
return mod(this, other);
|
|
723
|
+
}
|
|
688
724
|
greater(other) {
|
|
689
725
|
return greater$1(this, other);
|
|
690
726
|
}
|
|
@@ -797,8 +833,14 @@ var Tracer = class Tracer {
|
|
|
797
833
|
*/
|
|
798
834
|
*[Symbol.iterator]() {
|
|
799
835
|
if (this.ndim === 0) throw new Error("Cannot iterate over a scalar array");
|
|
800
|
-
|
|
801
|
-
this.
|
|
836
|
+
let residual = this;
|
|
837
|
+
const subarrayShape = this.shape.slice(1);
|
|
838
|
+
for (let i = 0; i < this.shape[0]; i++) {
|
|
839
|
+
const lr = split$2(residual, 0, [1, residual.shape[0] - 1]);
|
|
840
|
+
yield lr[0].reshape(subarrayShape);
|
|
841
|
+
residual = lr[1];
|
|
842
|
+
}
|
|
843
|
+
residual.dispose();
|
|
802
844
|
}
|
|
803
845
|
/**
|
|
804
846
|
* Return a sorted copy of an array in ascending order.
|
|
@@ -948,6 +990,9 @@ var ShapedArray = class ShapedArray {
|
|
|
948
990
|
get size() {
|
|
949
991
|
return prod(this.shape);
|
|
950
992
|
}
|
|
993
|
+
scalar() {
|
|
994
|
+
return new ShapedArray([], this.dtype, this.weakType);
|
|
995
|
+
}
|
|
951
996
|
toString() {
|
|
952
997
|
return `${this.dtype}[${this.shape.join(",")}]`;
|
|
953
998
|
}
|
|
@@ -1553,7 +1598,7 @@ const abstractEvalRules = {
|
|
|
1553
1598
|
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1554
1599
|
},
|
|
1555
1600
|
[Primitive.Conv]([lhs, rhs], params) {
|
|
1556
|
-
const { dtype, weakType } = promoteAvals(
|
|
1601
|
+
const { dtype, weakType } = promoteAvals(lhs.scalar(), rhs.scalar());
|
|
1557
1602
|
const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
|
|
1558
1603
|
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1559
1604
|
},
|
|
@@ -1564,10 +1609,25 @@ const abstractEvalRules = {
|
|
|
1564
1609
|
const shape$1 = generalBroadcast(cond.shape, xy.shape);
|
|
1565
1610
|
return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
|
|
1566
1611
|
},
|
|
1612
|
+
[Primitive.Concatenate](xs, { axis }) {
|
|
1613
|
+
if (xs.length === 0) throw new TypeError("Concatenate requires at least one input");
|
|
1614
|
+
for (const x of xs) if (x.ndim !== xs[0].ndim || !x.shape.every((s, i) => i === axis || s === xs[0].shape[i])) throw new TypeError(`Concatenate: inputs ${xs[0]} and ${x} must match shapes except on axis ${axis}`);
|
|
1615
|
+
const shape$1 = xs[0].shape.slice();
|
|
1616
|
+
shape$1[axis] = xs.reduce((sum$1, x) => sum$1 + x.shape[axis], 0);
|
|
1617
|
+
const { dtype, weakType } = xs.map((x) => x.scalar()).reduce(promoteAvals);
|
|
1618
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1619
|
+
},
|
|
1620
|
+
[Primitive.Split]([x], { axis, sizes }) {
|
|
1621
|
+
const totalSize = sizes.reduce((a, b) => a + b, 0);
|
|
1622
|
+
if (x.shape[axis] !== totalSize) throw new TypeError(`Split: sizes ${sizes} do not sum to dimension ${x.shape[axis]} on axis ${axis}`);
|
|
1623
|
+
return sizes.map((size$1) => {
|
|
1624
|
+
return new ShapedArray(x.shape.toSpliced(axis, 1, size$1), x.dtype, x.weakType);
|
|
1625
|
+
});
|
|
1626
|
+
},
|
|
1567
1627
|
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
1568
1628
|
if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
1569
|
-
|
|
1570
|
-
if (!deepEqual(
|
|
1629
|
+
if (!deepEqual(k0.shape, k1.shape)) throw new TypeError(`RandomBits: Keys have different shapes ${k0.shape} and ${k1.shape}`);
|
|
1630
|
+
if (!deepEqual(shape$1.slice(0, k0.ndim), k0.shape)) throw new TypeError(`RandomBits: generated shape ${shape$1} must match key shape ${k0.shape}`);
|
|
1571
1631
|
return [new ShapedArray(shape$1, DType.Uint32, false)];
|
|
1572
1632
|
},
|
|
1573
1633
|
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
@@ -1624,6 +1684,16 @@ const abstractEvalRules = {
|
|
|
1624
1684
|
if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
|
|
1625
1685
|
return [ShapedArray.fromAval(a)];
|
|
1626
1686
|
},
|
|
1687
|
+
[Primitive.LU]([a]) {
|
|
1688
|
+
if (a.ndim < 2) throw new TypeError(`lu: requires at least 2D input, got ${a}`);
|
|
1689
|
+
const batch = a.shape.slice(0, -2);
|
|
1690
|
+
const [m, n] = a.shape.slice(-2);
|
|
1691
|
+
return [
|
|
1692
|
+
ShapedArray.fromAval(a),
|
|
1693
|
+
new ShapedArray([...batch, Math.min(m, n)], DType.Int32, false),
|
|
1694
|
+
new ShapedArray([...batch, m], DType.Int32, false)
|
|
1695
|
+
];
|
|
1696
|
+
},
|
|
1627
1697
|
[Primitive.Jit](args, { jaxpr }) {
|
|
1628
1698
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
1629
1699
|
if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
|
|
@@ -1705,7 +1775,8 @@ const routinePrimitives = new Map([
|
|
|
1705
1775
|
[Primitive.Sort, Routines.Sort],
|
|
1706
1776
|
[Primitive.Argsort, Routines.Argsort],
|
|
1707
1777
|
[Primitive.TriangularSolve, Routines.TriangularSolve],
|
|
1708
|
-
[Primitive.Cholesky, Routines.Cholesky]
|
|
1778
|
+
[Primitive.Cholesky, Routines.Cholesky],
|
|
1779
|
+
[Primitive.LU, Routines.LU]
|
|
1709
1780
|
]);
|
|
1710
1781
|
/** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
|
|
1711
1782
|
var JitProgram = class {
|
|
@@ -1876,10 +1947,10 @@ function jitCompile(backend, jaxpr) {
|
|
|
1876
1947
|
inputs.push(jv.arg);
|
|
1877
1948
|
} else if (input instanceof Lit) inputs.push(builder.pushLit(input));
|
|
1878
1949
|
const outputs = [];
|
|
1879
|
-
for (const outVar
|
|
1880
|
-
const outId = builder.pushBuffer(outVar
|
|
1950
|
+
for (const outVar of eqn.outBinders) {
|
|
1951
|
+
const outId = builder.pushBuffer(outVar.aval.size * byteWidth(outVar.aval.dtype));
|
|
1881
1952
|
outputs.push(outId);
|
|
1882
|
-
ctx.set(outVar
|
|
1953
|
+
ctx.set(outVar, {
|
|
1883
1954
|
type: "imm",
|
|
1884
1955
|
arg: outId
|
|
1885
1956
|
});
|
|
@@ -1930,35 +2001,37 @@ function jitCompile(backend, jaxpr) {
|
|
|
1930
2001
|
let reduction;
|
|
1931
2002
|
if (inputReduction) {
|
|
1932
2003
|
const jv = inputReduction;
|
|
1933
|
-
const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
|
|
1934
|
-
exp$2 = jv.exp.reindexGids(addArgs(jv.args));
|
|
2004
|
+
const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp[0];
|
|
2005
|
+
exp$2 = [jv.exp.reindexGids(addArgs(jv.args))];
|
|
1935
2006
|
reduction = new Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
|
|
1936
2007
|
} else {
|
|
1937
2008
|
const ruleOutput = rule(inputExps, inputAvals, eqn.params);
|
|
1938
2009
|
exp$2 = ruleOutput.exp;
|
|
1939
2010
|
reduction = ruleOutput.reduction;
|
|
1940
2011
|
}
|
|
1941
|
-
|
|
1942
|
-
|
|
1943
|
-
|
|
1944
|
-
|
|
1945
|
-
|
|
1946
|
-
|
|
1947
|
-
|
|
1948
|
-
|
|
1949
|
-
|
|
2012
|
+
for (let i$1 = 0; i$1 < eqn.outBinders.length; i$1++) {
|
|
2013
|
+
const outVar = eqn.outBinders[i$1];
|
|
2014
|
+
if (blackNodes.has(outVar)) {
|
|
2015
|
+
const nargs$1 = inputArgs.length;
|
|
2016
|
+
const size$1 = outVar.aval.size;
|
|
2017
|
+
const kernel = new Kernel(nargs$1, size$1, exp$2[i$1], reduction);
|
|
2018
|
+
const outId = builder.pushKernel(kernel, inputArgs);
|
|
2019
|
+
ctx.set(outVar, {
|
|
2020
|
+
type: "imm",
|
|
2021
|
+
arg: outId
|
|
2022
|
+
});
|
|
2023
|
+
} else if (reduction) ctx.set(outVar, {
|
|
2024
|
+
type: "red",
|
|
2025
|
+
exp: exp$2[i$1],
|
|
2026
|
+
reduction,
|
|
2027
|
+
args: inputArgs
|
|
1950
2028
|
});
|
|
1951
|
-
|
|
1952
|
-
|
|
1953
|
-
|
|
1954
|
-
|
|
1955
|
-
|
|
1956
|
-
}
|
|
1957
|
-
else ctx.set(outVar, {
|
|
1958
|
-
type: "exp",
|
|
1959
|
-
exp: exp$2,
|
|
1960
|
-
args: inputArgs
|
|
1961
|
-
});
|
|
2029
|
+
else ctx.set(outVar, {
|
|
2030
|
+
type: "exp",
|
|
2031
|
+
exp: exp$2[i$1],
|
|
2032
|
+
args: inputArgs
|
|
2033
|
+
});
|
|
2034
|
+
}
|
|
1962
2035
|
}
|
|
1963
2036
|
const outputIds = [];
|
|
1964
2037
|
for (const out of jaxpr.outs) if (out instanceof Var) {
|
|
@@ -1999,17 +2072,17 @@ function broadcastedJit(fn, opts) {
|
|
|
1999
2072
|
if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = AluExp.cast(newDtype, exp$2);
|
|
2000
2073
|
return exp$2;
|
|
2001
2074
|
});
|
|
2002
|
-
return { exp: fn(exps, params) };
|
|
2075
|
+
return { exp: [fn(exps, params)] };
|
|
2003
2076
|
};
|
|
2004
2077
|
}
|
|
2005
2078
|
function unopJit(fn) {
|
|
2006
2079
|
return ([a], [_as], params) => {
|
|
2007
|
-
return { exp: fn(a, params) };
|
|
2080
|
+
return { exp: [fn(a, params)] };
|
|
2008
2081
|
};
|
|
2009
2082
|
}
|
|
2010
2083
|
function reshapeJit(fn) {
|
|
2011
2084
|
return ([a], [_as], params) => {
|
|
2012
|
-
return { exp: reshapeViews(a, (st) => fn(st, params)) };
|
|
2085
|
+
return { exp: [reshapeViews(a, (st) => fn(st, params))] };
|
|
2013
2086
|
};
|
|
2014
2087
|
}
|
|
2015
2088
|
function routineNoJit() {
|
|
@@ -2055,7 +2128,7 @@ const jitRules = {
|
|
|
2055
2128
|
a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
|
|
2056
2129
|
const reduction = new Reduction(a.dtype, op, reductionSize);
|
|
2057
2130
|
return {
|
|
2058
|
-
exp: a,
|
|
2131
|
+
exp: [a],
|
|
2059
2132
|
reduction
|
|
2060
2133
|
};
|
|
2061
2134
|
},
|
|
@@ -2066,13 +2139,13 @@ const jitRules = {
|
|
|
2066
2139
|
a = reshapeViews(a, (st) => st.compose(stX), true);
|
|
2067
2140
|
const reduction = new Reduction(a.dtype, AluOp.Add, stX.shape[stX.shape.length - 1]);
|
|
2068
2141
|
return {
|
|
2069
|
-
exp: a,
|
|
2142
|
+
exp: [a],
|
|
2070
2143
|
reduction
|
|
2071
2144
|
};
|
|
2072
2145
|
},
|
|
2073
2146
|
[Primitive.Dot]([a, b], [as, bs]) {
|
|
2074
2147
|
const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
|
|
2075
|
-
const c = k1.exp;
|
|
2148
|
+
const [c] = k1.exp;
|
|
2076
2149
|
const cs = promoteAvals(as, bs);
|
|
2077
2150
|
return jitRules[Primitive.Reduce]([c], [cs], {
|
|
2078
2151
|
op: AluOp.Add,
|
|
@@ -2089,16 +2162,41 @@ const jitRules = {
|
|
|
2089
2162
|
},
|
|
2090
2163
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
2091
2164
|
[Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
2165
|
+
[Primitive.Concatenate](exps, avals, { axis }) {
|
|
2166
|
+
const ndim$2 = avals[0].ndim;
|
|
2167
|
+
const sizes = avals.map((x) => x.shape[axis]);
|
|
2168
|
+
const finalSize = sizes.reduce((a, b) => a + b, 0);
|
|
2169
|
+
const makePadAxis = (start, end) => range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
|
|
2170
|
+
let cum = 0;
|
|
2171
|
+
const src = [];
|
|
2172
|
+
for (let i = 0; i < exps.length; i++) {
|
|
2173
|
+
const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
|
|
2174
|
+
src.push(reshapeViews(exps[i], (st) => st.pad(padding)));
|
|
2175
|
+
cum += sizes[i];
|
|
2176
|
+
}
|
|
2177
|
+
return { exp: [src.reduce(AluExp.add)] };
|
|
2178
|
+
},
|
|
2179
|
+
[Primitive.Split]([a], [as], { axis, sizes }) {
|
|
2180
|
+
const exp$2 = [];
|
|
2181
|
+
let start = 0;
|
|
2182
|
+
for (const size$1 of sizes) {
|
|
2183
|
+
const slice = range(as.ndim).map((d) => d === axis ? [start, start + size$1] : [0, as.shape[d]]);
|
|
2184
|
+
exp$2.push(reshapeViews(a, (st) => st.shrink(slice)));
|
|
2185
|
+
start += size$1;
|
|
2186
|
+
}
|
|
2187
|
+
return { exp: exp$2 };
|
|
2188
|
+
},
|
|
2092
2189
|
[Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
|
|
2190
|
+
const keyShape = keyShapes[0].shape;
|
|
2093
2191
|
const mapping = (st) => {
|
|
2094
|
-
if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(shape
|
|
2192
|
+
if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(st.shape.length, shape$1.length));
|
|
2095
2193
|
};
|
|
2096
2194
|
const k0 = reshapeViews(keys[0], mapping);
|
|
2097
2195
|
const k1 = reshapeViews(keys[1], mapping);
|
|
2098
2196
|
const c0 = AluExp.u32(0);
|
|
2099
|
-
const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
|
|
2197
|
+
const c1 = AluExp.mod(AluExp.cast(DType.Uint32, AluVar.gidx), AluExp.u32(Math.max(prod(shape$1.slice(keyShape.length)), 1)));
|
|
2100
2198
|
const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
2101
|
-
return { exp: exp$2 };
|
|
2199
|
+
return { exp: [exp$2] };
|
|
2102
2200
|
},
|
|
2103
2201
|
[Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
2104
2202
|
const axisSet = new Set(axis);
|
|
@@ -2113,7 +2211,7 @@ const jitRules = {
|
|
|
2113
2211
|
for (const [i, iexp] of indices.entries()) src[axis[i]] = AluExp.cast(DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...range(outDim + indexShape.length - st.shape.length), ...range(outDim + indexShape.length, finalShape.length)])));
|
|
2114
2212
|
const [index, valid] = ShapeTracker.fromShape(xs.shape).toAluExp(src);
|
|
2115
2213
|
if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
|
|
2116
|
-
return { exp: x.substitute({ gidx: index }) };
|
|
2214
|
+
return { exp: [x.substitute({ gidx: index })] };
|
|
2117
2215
|
},
|
|
2118
2216
|
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
2119
2217
|
[Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
|
|
@@ -2129,6 +2227,7 @@ const jitRules = {
|
|
|
2129
2227
|
[Primitive.Argsort]: routineNoJit(),
|
|
2130
2228
|
[Primitive.TriangularSolve]: routineNoJit(),
|
|
2131
2229
|
[Primitive.Cholesky]: routineNoJit(),
|
|
2230
|
+
[Primitive.LU]: routineNoJit(),
|
|
2132
2231
|
[Primitive.Jit]() {
|
|
2133
2232
|
throw new Error("internal: Jit should have been flattened before JIT compilation");
|
|
2134
2233
|
}
|
|
@@ -2407,6 +2506,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2407
2506
|
this.#rc++;
|
|
2408
2507
|
return this;
|
|
2409
2508
|
}
|
|
2509
|
+
/** Get the current reference count (for debugging memory management). */
|
|
2510
|
+
get refCount() {
|
|
2511
|
+
return this.#rc;
|
|
2512
|
+
}
|
|
2410
2513
|
dispose() {
|
|
2411
2514
|
this.#check();
|
|
2412
2515
|
if (--this.#rc === 0) {
|
|
@@ -2564,7 +2667,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2564
2667
|
} else if (castDtype === void 0) {
|
|
2565
2668
|
castDtype = arrays[i].#dtype;
|
|
2566
2669
|
castWeakType = arrays[i].#weakType;
|
|
2567
|
-
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType),
|
|
2670
|
+
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), arrays[i].aval.scalar()));
|
|
2568
2671
|
const weakType = castWeakType && !strongTypeOutput;
|
|
2569
2672
|
const { backend, committed } = Array$1.#computeBackend(name, arrays);
|
|
2570
2673
|
arrays = arrays.map((ar) => ar._putSync(backend));
|
|
@@ -2957,17 +3060,44 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2957
3060
|
y
|
|
2958
3061
|
], { dtypeOverride: [DType.Bool] })];
|
|
2959
3062
|
},
|
|
3063
|
+
[Primitive.Concatenate](xs, { axis }) {
|
|
3064
|
+
const ndim$2 = xs[0].ndim;
|
|
3065
|
+
const sizes = xs.map((x) => x.shape[axis]);
|
|
3066
|
+
const finalSize = sizes.reduce((a, b) => a + b, 0);
|
|
3067
|
+
const makePadAxis = (start, end) => range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
|
|
3068
|
+
let cum = 0;
|
|
3069
|
+
const xsPadded = [];
|
|
3070
|
+
for (let i = 0; i < xs.length; i++) {
|
|
3071
|
+
const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
|
|
3072
|
+
xsPadded.push(xs[i].#reshape(xs[i].#st.pad(padding)));
|
|
3073
|
+
cum += sizes[i];
|
|
3074
|
+
}
|
|
3075
|
+
const custom = (exps) => exps.reduce(AluExp.add);
|
|
3076
|
+
return [Array$1.#naryCustom("concatenate", custom, xsPadded)];
|
|
3077
|
+
},
|
|
3078
|
+
[Primitive.Split]([x], { axis, sizes }) {
|
|
3079
|
+
const outputs = [];
|
|
3080
|
+
for (let i = 0, start = 0; i < sizes.length; i++) {
|
|
3081
|
+
const slice = range(x.ndim).map((d) => d === axis ? [start, start + sizes[i]] : [0, x.shape[d]]);
|
|
3082
|
+
outputs.push(x.ref.#reshape(x.#st.shrink(slice)));
|
|
3083
|
+
start += sizes[i];
|
|
3084
|
+
}
|
|
3085
|
+
x.dispose();
|
|
3086
|
+
return outputs;
|
|
3087
|
+
},
|
|
2960
3088
|
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
2961
|
-
const keyShape =
|
|
2962
|
-
|
|
2963
|
-
const c0 = zeros(
|
|
3089
|
+
const keyShape = k0.shape;
|
|
3090
|
+
const genShape = shape$1.slice(keyShape.length);
|
|
3091
|
+
const c0 = zeros(genShape, {
|
|
2964
3092
|
dtype: DType.Uint32,
|
|
2965
3093
|
device: k0.device
|
|
2966
3094
|
});
|
|
2967
|
-
const c1 = arange(0, prod(
|
|
3095
|
+
const c1 = arange(0, prod(genShape), 1, {
|
|
2968
3096
|
dtype: DType.Uint32,
|
|
2969
3097
|
device: k0.device
|
|
2970
|
-
}).reshape(
|
|
3098
|
+
}).reshape(genShape);
|
|
3099
|
+
k0 = k0.#reshape(k0.#st.reshape(keyShape.concat(rep(genShape.length, 1))));
|
|
3100
|
+
k1 = k1.#reshape(k1.#st.reshape(keyShape.concat(rep(genShape.length, 1))));
|
|
2971
3101
|
const custom = ([k0$1, k1$1, c0$1, c1$1]) => AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
|
|
2972
3102
|
return [Array$1.#naryCustom("random_bits", custom, [
|
|
2973
3103
|
k0,
|
|
@@ -3001,40 +3131,63 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3001
3131
|
},
|
|
3002
3132
|
[Primitive.Sort]([x]) {
|
|
3003
3133
|
const routine = new Routine(Routines.Sort, {
|
|
3004
|
-
inputShapes: [x.
|
|
3005
|
-
inputDtypes: [x.
|
|
3006
|
-
outputShapes: [x.
|
|
3007
|
-
outputDtypes: [x.
|
|
3134
|
+
inputShapes: [x.shape],
|
|
3135
|
+
inputDtypes: [x.dtype],
|
|
3136
|
+
outputShapes: [x.shape],
|
|
3137
|
+
outputDtypes: [x.dtype]
|
|
3008
3138
|
});
|
|
3009
3139
|
return Array$1.#routine(routine, [x], [x.#weakType]);
|
|
3010
3140
|
},
|
|
3011
3141
|
[Primitive.Argsort]([x]) {
|
|
3012
3142
|
const routine = new Routine(Routines.Argsort, {
|
|
3013
|
-
inputShapes: [x.
|
|
3014
|
-
inputDtypes: [x.
|
|
3015
|
-
outputShapes: [x.
|
|
3016
|
-
outputDtypes: [x.
|
|
3143
|
+
inputShapes: [x.shape],
|
|
3144
|
+
inputDtypes: [x.dtype],
|
|
3145
|
+
outputShapes: [x.shape, x.shape],
|
|
3146
|
+
outputDtypes: [x.dtype, DType.Int32]
|
|
3017
3147
|
});
|
|
3018
3148
|
return Array$1.#routine(routine, [x], [x.#weakType, false]);
|
|
3019
3149
|
},
|
|
3020
3150
|
[Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
|
|
3021
3151
|
const routine = new Routine(Routines.TriangularSolve, {
|
|
3022
|
-
inputShapes: [a.
|
|
3023
|
-
inputDtypes: [a.
|
|
3024
|
-
outputShapes: [b.
|
|
3025
|
-
outputDtypes: [b.
|
|
3152
|
+
inputShapes: [a.shape, b.shape],
|
|
3153
|
+
inputDtypes: [a.dtype, b.dtype],
|
|
3154
|
+
outputShapes: [b.shape],
|
|
3155
|
+
outputDtypes: [b.dtype]
|
|
3026
3156
|
}, { unitDiagonal });
|
|
3027
3157
|
return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
|
|
3028
3158
|
},
|
|
3029
3159
|
[Primitive.Cholesky]([a]) {
|
|
3030
3160
|
const routine = new Routine(Routines.Cholesky, {
|
|
3031
|
-
inputShapes: [a.
|
|
3032
|
-
inputDtypes: [a.
|
|
3033
|
-
outputShapes: [a.
|
|
3034
|
-
outputDtypes: [a.
|
|
3161
|
+
inputShapes: [a.shape],
|
|
3162
|
+
inputDtypes: [a.dtype],
|
|
3163
|
+
outputShapes: [a.shape],
|
|
3164
|
+
outputDtypes: [a.dtype]
|
|
3035
3165
|
});
|
|
3036
3166
|
return Array$1.#routine(routine, [a], [a.#weakType]);
|
|
3037
3167
|
},
|
|
3168
|
+
[Primitive.LU]([a]) {
|
|
3169
|
+
const batch = a.shape.slice(0, -2);
|
|
3170
|
+
const [m, n] = a.shape.slice(-2);
|
|
3171
|
+
const routine = new Routine(Routines.LU, {
|
|
3172
|
+
inputShapes: [a.shape],
|
|
3173
|
+
inputDtypes: [a.dtype],
|
|
3174
|
+
outputShapes: [
|
|
3175
|
+
a.shape,
|
|
3176
|
+
[...batch, Math.min(m, n)],
|
|
3177
|
+
[...batch, m]
|
|
3178
|
+
],
|
|
3179
|
+
outputDtypes: [
|
|
3180
|
+
a.dtype,
|
|
3181
|
+
DType.Int32,
|
|
3182
|
+
DType.Int32
|
|
3183
|
+
]
|
|
3184
|
+
});
|
|
3185
|
+
return Array$1.#routine(routine, [a], [
|
|
3186
|
+
a.#weakType,
|
|
3187
|
+
false,
|
|
3188
|
+
false
|
|
3189
|
+
]);
|
|
3190
|
+
},
|
|
3038
3191
|
[Primitive.Jit](args, { jaxpr }) {
|
|
3039
3192
|
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
3040
3193
|
const { backend, committed } = Array$1.#computeBackend("jit", args);
|
|
@@ -3140,7 +3293,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
3140
3293
|
device
|
|
3141
3294
|
});
|
|
3142
3295
|
} else {
|
|
3143
|
-
const weakType = dtype == void 0;
|
|
3296
|
+
const weakType = dtype == void 0 && shape$1.length === 0;
|
|
3144
3297
|
dtype = dtype ?? DType.Float32;
|
|
3145
3298
|
const data = dtypedJsArray(dtype, flat);
|
|
3146
3299
|
return arrayFromData(data, shape$1, {
|
|
@@ -3254,7 +3407,7 @@ function ones(shape$1, { dtype, device } = {}) {
|
|
|
3254
3407
|
}
|
|
3255
3408
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
3256
3409
|
function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
3257
|
-
let weakType = dtype == void 0;
|
|
3410
|
+
let weakType = dtype == void 0 && shape$1.length === 0;
|
|
3258
3411
|
if (typeof fillValue === "number") dtype = dtype ?? DType.Float32;
|
|
3259
3412
|
else if (typeof fillValue === "boolean") {
|
|
3260
3413
|
dtype = dtype ?? DType.Bool;
|
|
@@ -3412,6 +3565,27 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
3412
3565
|
committed: device != void 0
|
|
3413
3566
|
});
|
|
3414
3567
|
}
|
|
3568
|
+
/**
|
|
3569
|
+
* Return numbers spaced evenly on a log scale.
|
|
3570
|
+
*
|
|
3571
|
+
* In linear space, the sequence starts at `base ** start` and ends at
|
|
3572
|
+
* `base ** stop` (see `endpoint` below).
|
|
3573
|
+
*
|
|
3574
|
+
* @param start - `base ** start` is the starting value of the sequence.
|
|
3575
|
+
* @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
|
|
3576
|
+
* @param num - Number of samples to generate. Default is 50.
|
|
3577
|
+
* @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
|
|
3578
|
+
* @param base - The base of the log space. Default is 10.
|
|
3579
|
+
* @returns Array of evenly spaced values on a log scale.
|
|
3580
|
+
*/
|
|
3581
|
+
function logspace(start, stop, num = 50, endpoint = true, base = 10, { dtype, device } = {}) {
|
|
3582
|
+
const y = linspace(start, stop, num, endpoint, {
|
|
3583
|
+
dtype,
|
|
3584
|
+
device
|
|
3585
|
+
});
|
|
3586
|
+
const logBase = Math.log(base);
|
|
3587
|
+
return exp$1(mul(y, logBase));
|
|
3588
|
+
}
|
|
3415
3589
|
function aluCompare(a, b, op) {
|
|
3416
3590
|
switch (op) {
|
|
3417
3591
|
case CompareOp.Less: return AluExp.cmplt(a, b);
|
|
@@ -3488,6 +3662,7 @@ var BatchTrace = class extends Trace {
|
|
|
3488
3662
|
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3489
3663
|
}
|
|
3490
3664
|
const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
|
|
3665
|
+
if (valOuts.length !== bdimOuts.length) throw new Error(`vmap rule for ${primitive} returned mismatched lengths: ${valOuts.length} vs ${bdimOuts.length}`);
|
|
3491
3666
|
return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
|
|
3492
3667
|
}
|
|
3493
3668
|
get axisSize() {
|
|
@@ -3499,13 +3674,13 @@ var BatchTrace = class extends Trace {
|
|
|
3499
3674
|
*
|
|
3500
3675
|
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3501
3676
|
*/
|
|
3502
|
-
function broadcastBatcher(
|
|
3503
|
-
return (axisSize, args, dims) => {
|
|
3677
|
+
function broadcastBatcher(prim) {
|
|
3678
|
+
return (axisSize, args, dims, params) => {
|
|
3504
3679
|
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3505
3680
|
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3506
3681
|
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3507
3682
|
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3508
|
-
if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[
|
|
3683
|
+
if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[bind1(prim, args, params)], [nd + firstBdim]];
|
|
3509
3684
|
args = args.map((x, i) => {
|
|
3510
3685
|
if (dims[i] === null) return x;
|
|
3511
3686
|
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
@@ -3516,37 +3691,45 @@ function broadcastBatcher(op) {
|
|
|
3516
3691
|
]);
|
|
3517
3692
|
return x;
|
|
3518
3693
|
});
|
|
3519
|
-
return [[
|
|
3694
|
+
return [[bind1(prim, args, params)], [0]];
|
|
3520
3695
|
};
|
|
3521
3696
|
}
|
|
3522
|
-
function unopBatcher(
|
|
3697
|
+
function unopBatcher(prim) {
|
|
3523
3698
|
return (axisSize, [x], [xBdim], params) => {
|
|
3524
|
-
return [[
|
|
3699
|
+
return [[bind1(prim, [x], params)], [xBdim]];
|
|
3700
|
+
};
|
|
3701
|
+
}
|
|
3702
|
+
function lastDimsBatcher(prim, inputDims, numOutputs = 1) {
|
|
3703
|
+
return (axisSize, [x], [xBdim], params) => {
|
|
3704
|
+
assertNonNull(xBdim);
|
|
3705
|
+
if (xBdim < x.ndim - inputDims) return [bind(prim, [x], params), rep(numOutputs, xBdim)];
|
|
3706
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3707
|
+
return [bind(prim, [x], params), rep(numOutputs, 0)];
|
|
3525
3708
|
};
|
|
3526
3709
|
}
|
|
3527
3710
|
const vmapRules = {
|
|
3528
|
-
[Primitive.Add]: broadcastBatcher(
|
|
3529
|
-
[Primitive.Mul]: broadcastBatcher(
|
|
3530
|
-
[Primitive.Idiv]: broadcastBatcher(
|
|
3531
|
-
[Primitive.Mod]: broadcastBatcher(
|
|
3532
|
-
[Primitive.Min]: broadcastBatcher(
|
|
3533
|
-
[Primitive.Max]: broadcastBatcher(
|
|
3534
|
-
[Primitive.Neg]: unopBatcher(
|
|
3535
|
-
[Primitive.Reciprocal]: unopBatcher(
|
|
3536
|
-
[Primitive.Floor]: unopBatcher(
|
|
3537
|
-
[Primitive.Ceil]: unopBatcher(
|
|
3538
|
-
[Primitive.StopGradient]: unopBatcher(
|
|
3539
|
-
[Primitive.Cast]: unopBatcher(
|
|
3540
|
-
[Primitive.Bitcast]: unopBatcher(
|
|
3541
|
-
[Primitive.Sin]: unopBatcher(
|
|
3542
|
-
[Primitive.Cos]: unopBatcher(
|
|
3543
|
-
[Primitive.Asin]: unopBatcher(
|
|
3544
|
-
[Primitive.Atan]: unopBatcher(
|
|
3545
|
-
[Primitive.Exp]: unopBatcher(
|
|
3546
|
-
[Primitive.Log]: unopBatcher(
|
|
3547
|
-
[Primitive.Erf]: unopBatcher(
|
|
3548
|
-
[Primitive.Erfc]: unopBatcher(
|
|
3549
|
-
[Primitive.Sqrt]: unopBatcher(
|
|
3711
|
+
[Primitive.Add]: broadcastBatcher(Primitive.Add),
|
|
3712
|
+
[Primitive.Mul]: broadcastBatcher(Primitive.Mul),
|
|
3713
|
+
[Primitive.Idiv]: broadcastBatcher(Primitive.Idiv),
|
|
3714
|
+
[Primitive.Mod]: broadcastBatcher(Primitive.Mod),
|
|
3715
|
+
[Primitive.Min]: broadcastBatcher(Primitive.Min),
|
|
3716
|
+
[Primitive.Max]: broadcastBatcher(Primitive.Max),
|
|
3717
|
+
[Primitive.Neg]: unopBatcher(Primitive.Neg),
|
|
3718
|
+
[Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
|
|
3719
|
+
[Primitive.Floor]: unopBatcher(Primitive.Floor),
|
|
3720
|
+
[Primitive.Ceil]: unopBatcher(Primitive.Ceil),
|
|
3721
|
+
[Primitive.StopGradient]: unopBatcher(Primitive.StopGradient),
|
|
3722
|
+
[Primitive.Cast]: unopBatcher(Primitive.Cast),
|
|
3723
|
+
[Primitive.Bitcast]: unopBatcher(Primitive.Bitcast),
|
|
3724
|
+
[Primitive.Sin]: unopBatcher(Primitive.Sin),
|
|
3725
|
+
[Primitive.Cos]: unopBatcher(Primitive.Cos),
|
|
3726
|
+
[Primitive.Asin]: unopBatcher(Primitive.Asin),
|
|
3727
|
+
[Primitive.Atan]: unopBatcher(Primitive.Atan),
|
|
3728
|
+
[Primitive.Exp]: unopBatcher(Primitive.Exp),
|
|
3729
|
+
[Primitive.Log]: unopBatcher(Primitive.Log),
|
|
3730
|
+
[Primitive.Erf]: unopBatcher(Primitive.Erf),
|
|
3731
|
+
[Primitive.Erfc]: unopBatcher(Primitive.Erfc),
|
|
3732
|
+
[Primitive.Sqrt]: unopBatcher(Primitive.Sqrt),
|
|
3550
3733
|
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3551
3734
|
assertNonNull(xBdim);
|
|
3552
3735
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
@@ -3568,10 +3751,25 @@ const vmapRules = {
|
|
|
3568
3751
|
});
|
|
3569
3752
|
return [[z], [0]];
|
|
3570
3753
|
},
|
|
3571
|
-
[Primitive.Compare](
|
|
3572
|
-
|
|
3754
|
+
[Primitive.Compare]: broadcastBatcher(Primitive.Compare),
|
|
3755
|
+
[Primitive.Where]: broadcastBatcher(Primitive.Where),
|
|
3756
|
+
[Primitive.Concatenate](axisSize, xs, xBdims, { axis }) {
|
|
3757
|
+
const minBdim = Math.min(...xBdims.filter((d) => d !== null));
|
|
3758
|
+
xs = xs.map((x, i) => moveBatchAxis(axisSize, xBdims[i], minBdim, x));
|
|
3759
|
+
const newAxis = axis + (minBdim <= axis ? 1 : 0);
|
|
3760
|
+
return [[concatenate$1(xs, newAxis)], [minBdim]];
|
|
3761
|
+
},
|
|
3762
|
+
[Primitive.Split](axisSize, [x], [xBdim], { axis, sizes }) {
|
|
3763
|
+
assertNonNull(xBdim);
|
|
3764
|
+
const newAxis = axis + (xBdim <= axis ? 1 : 0);
|
|
3765
|
+
const outs = split$2(x, newAxis, sizes);
|
|
3766
|
+
return [outs, rep(outs.length, xBdim)];
|
|
3767
|
+
},
|
|
3768
|
+
[Primitive.RandomBits](axisSize, [k0, k1], [bdim0, bdim1], { shape: shape$1, mode }) {
|
|
3769
|
+
k0 = moveBatchAxis(axisSize, bdim0, 0, k0);
|
|
3770
|
+
k1 = moveBatchAxis(axisSize, bdim1, 0, k1);
|
|
3771
|
+
return [[randomBits(k0, k1, [axisSize, ...shape$1], mode)], [0]];
|
|
3573
3772
|
},
|
|
3574
|
-
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3575
3773
|
[Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
|
|
3576
3774
|
if (indicesBdim.every((d) => d === null)) {
|
|
3577
3775
|
assertNonNull(xBdim);
|
|
@@ -3633,18 +3831,8 @@ const vmapRules = {
|
|
|
3633
3831
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3634
3832
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3635
3833
|
},
|
|
3636
|
-
[Primitive.Sort](
|
|
3637
|
-
|
|
3638
|
-
if (xBdim !== x.ndim - 1) return [[sort$1(x)], [xBdim]];
|
|
3639
|
-
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3640
|
-
return [[sort$1(x)], [0]];
|
|
3641
|
-
},
|
|
3642
|
-
[Primitive.Argsort](axisSize, [x], [xBdim]) {
|
|
3643
|
-
assertNonNull(xBdim);
|
|
3644
|
-
if (xBdim !== x.ndim - 1) return [argsort$1(x), [xBdim, xBdim]];
|
|
3645
|
-
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3646
|
-
return [argsort$1(x), [0, 0]];
|
|
3647
|
-
},
|
|
3834
|
+
[Primitive.Sort]: lastDimsBatcher(Primitive.Sort, 1),
|
|
3835
|
+
[Primitive.Argsort]: lastDimsBatcher(Primitive.Argsort, 1, 2),
|
|
3648
3836
|
[Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
|
|
3649
3837
|
if (aBdim === null) {
|
|
3650
3838
|
b = moveBatchAxis(axisSize, bBdim, -3, b);
|
|
@@ -3668,12 +3856,8 @@ const vmapRules = {
|
|
|
3668
3856
|
const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
3669
3857
|
return [[x], [0]];
|
|
3670
3858
|
},
|
|
3671
|
-
[Primitive.Cholesky](
|
|
3672
|
-
|
|
3673
|
-
if (xBdim < x.ndim - 2) return [[cholesky$2(x)], [xBdim]];
|
|
3674
|
-
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3675
|
-
return [[cholesky$2(x)], [0]];
|
|
3676
|
-
},
|
|
3859
|
+
[Primitive.Cholesky]: lastDimsBatcher(Primitive.Cholesky, 2),
|
|
3860
|
+
[Primitive.LU]: lastDimsBatcher(Primitive.LU, 2, 3),
|
|
3677
3861
|
[Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
|
|
3678
3862
|
const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3679
3863
|
const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
|
|
@@ -3823,6 +4007,16 @@ function batchMatmulT(a, b) {
|
|
|
3823
4007
|
/** Return `a` with axis -2 moved to position -1 (last two axes exchanged). */
function mT(a) {
	const swapped = moveaxis(a, -2, -1);
	return swapped;
}
|
|
4010
|
+
/**
 * Slice `a` along a single axis, passing an empty spec for all other axes.
 *
 * @param a - Array to slice.
 * @param axis - Axis to slice; validated with `checkAxis` against `a.ndim`.
 * @param p - Slice spec forwarded to `a.slice` for that axis.
 */
function sliceAxis(a, axis, p) {
	const untouched = [];
	const target = checkAxis(axis, a.ndim);
	const specs = Array.from({ length: a.shape.length }, () => untouched);
	specs[target] = p;
	return a.slice(...specs);
}
|
|
4015
|
+
/**
 * Pad `a` along a single axis, using `[0, 0]` padding for all other axes.
 *
 * @param a - Array to pad.
 * @param axis - Axis to pad; validated with `checkAxis` against `a.ndim`.
 * @param p - Padding spec for that axis, forwarded to `pad$1`.
 */
function padAxis(a, axis, p) {
	const noPad = [0, 0];
	const target = checkAxis(axis, a.ndim);
	const widths = Array.from({ length: a.shape.length }, () => noPad);
	widths[target] = p;
	return pad$1(a, widths);
}
|
|
3826
4020
|
const jvpRules = {
|
|
3827
4021
|
[Primitive.Add]: linearTangentsJvp(Primitive.Add),
|
|
3828
4022
|
[Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
|
|
@@ -3921,6 +4115,8 @@ const jvpRules = {
|
|
|
3921
4115
|
dcond.dispose();
|
|
3922
4116
|
return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
|
|
3923
4117
|
},
|
|
4118
|
+
[Primitive.Concatenate]: linearTangentsJvp(Primitive.Concatenate),
|
|
4119
|
+
[Primitive.Split]: linearTangentsJvp(Primitive.Split),
|
|
3924
4120
|
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
3925
4121
|
[Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
|
|
3926
4122
|
const indicesRef = indices.map((t) => t.ref);
|
|
@@ -3955,6 +4151,38 @@ const jvpRules = {
|
|
|
3955
4151
|
const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
|
|
3956
4152
|
return [[L], [dL]];
|
|
3957
4153
|
},
|
|
4154
|
+
[Primitive.LU]([a], [da]) {
|
|
4155
|
+
const [luMatrix, pivots, permutation] = lu$1(a);
|
|
4156
|
+
const [m, n] = a.shape.slice(-2);
|
|
4157
|
+
const k = Math.min(m, n);
|
|
4158
|
+
const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
|
|
4159
|
+
const lLower = tril(luSliceL, -1);
|
|
4160
|
+
const lPadded = m > k ? padAxis(lLower, -1, [0, m - k]) : lLower;
|
|
4161
|
+
const L = lPadded.add(eye(m));
|
|
4162
|
+
const luSliceU = sliceAxis(luMatrix.ref, -2, [0, k]);
|
|
4163
|
+
const uUpper = triu(luSliceU);
|
|
4164
|
+
const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
|
|
4165
|
+
const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
|
|
4166
|
+
const U = uPadded.add(uEye);
|
|
4167
|
+
const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
|
|
4168
|
+
const pda = batchMatmulT(P, mT(da));
|
|
4169
|
+
const la = mT(triangularSolve$1(L.ref, mT(pda), {
|
|
4170
|
+
lower: true,
|
|
4171
|
+
unitDiagonal: true
|
|
4172
|
+
}));
|
|
4173
|
+
const lau = triangularSolve$1(mT(U.ref), la, { lower: true });
|
|
4174
|
+
const lDot = batchMatmulT(L, mT(tril(lau.ref, -1)));
|
|
4175
|
+
const uDot = batchMatmulT(triu(lau), mT(U));
|
|
4176
|
+
return [[
|
|
4177
|
+
luMatrix,
|
|
4178
|
+
pivots,
|
|
4179
|
+
permutation
|
|
4180
|
+
], [
|
|
4181
|
+
lDot.add(uDot),
|
|
4182
|
+
zerosLike$1(pivots.ref),
|
|
4183
|
+
zerosLike$1(permutation.ref)
|
|
4184
|
+
]];
|
|
4185
|
+
},
|
|
3958
4186
|
[Primitive.Jit](primals, tangents, { name, jaxpr }) {
|
|
3959
4187
|
const newJaxpr = jvpJaxpr(jaxpr);
|
|
3960
4188
|
const outs = bind(Primitive.Jit, [
|
|
@@ -4492,6 +4720,15 @@ const transposeRules = {
|
|
|
4492
4720
|
cond.dispose();
|
|
4493
4721
|
return cts;
|
|
4494
4722
|
},
|
|
4723
|
+
[Primitive.Concatenate]([ct], inputs, { axis }) {
|
|
4724
|
+
if (inputs.some((x) => !(x instanceof UndefPrimal))) throw new NonlinearError(Primitive.Concatenate);
|
|
4725
|
+
const sizes = inputs.map((x) => x.aval.shape[axis]);
|
|
4726
|
+
return split$2(ct, axis, sizes);
|
|
4727
|
+
},
|
|
4728
|
+
[Primitive.Split](cts, [x], { axis }) {
|
|
4729
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Split);
|
|
4730
|
+
return [concatenate$1(cts, axis)];
|
|
4731
|
+
},
|
|
4495
4732
|
[Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
|
|
4496
4733
|
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
4497
4734
|
if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
@@ -4767,8 +5004,8 @@ function computeSizeMap({ shapes, lhsIndices, rhsIndex }) {
|
|
|
4767
5004
|
const idx = lhsIndex[j];
|
|
4768
5005
|
const dim = shape$1[j];
|
|
4769
5006
|
const existing = sizeMap.get(idx);
|
|
4770
|
-
if (existing === void 0) sizeMap.set(idx, dim);
|
|
4771
|
-
else if (existing !== dim) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
|
|
5007
|
+
if (existing === void 0 || existing === 1) sizeMap.set(idx, dim);
|
|
5008
|
+
else if (existing !== dim && dim !== 1) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
|
|
4772
5009
|
}
|
|
4773
5010
|
}
|
|
4774
5011
|
for (const [idx, size$1] of sizeMap) if (!Number.isInteger(idx) || idx < 0) throw new Error(`Invalid index ${idx} in einsum expression, must be non-negative integer`);
|
|
@@ -4924,27 +5161,53 @@ function ifft(a, axis = -1) {
|
|
|
4924
5161
|
//#region src/library/numpy-linalg.ts
|
|
4925
5162
|
var numpy_linalg_exports = {};
|
|
4926
5163
|
__export(numpy_linalg_exports, {
|
|
4927
|
-
cholesky: () => cholesky
|
|
5164
|
+
cholesky: () => cholesky,
|
|
5165
|
+
det: () => det,
|
|
4928
5166
|
diagonal: () => diagonal,
|
|
5167
|
+
inv: () => inv,
|
|
4929
5168
|
lstsq: () => lstsq,
|
|
4930
5169
|
matmul: () => matmul,
|
|
5170
|
+
matrixPower: () => matrixPower,
|
|
4931
5171
|
matrixTranspose: () => matrixTranspose,
|
|
4932
5172
|
outer: () => outer,
|
|
5173
|
+
slogdet: () => slogdet,
|
|
5174
|
+
solve: () => solve,
|
|
4933
5175
|
tensordot: () => tensordot,
|
|
4934
5176
|
trace: () => trace,
|
|
4935
5177
|
vecdot: () => vecdot
|
|
4936
5178
|
});
|
|
5179
|
+
/**
 * Validate that `a` has at least two axes and that its last two axes have
 * equal length.
 *
 * @param name - Caller name, used as the error-message prefix.
 * @param a - Array-like with `ndim`, `shape` and `aval`.
 * @returns The length of the last axis.
 * @throws Error if `a` has fewer than two axes or is not square.
 */
function checkSquare(name, a) {
	const rank = a.ndim;
	const isSquare = rank >= 2 && a.shape[rank - 1] === a.shape[rank - 2];
	if (!isSquare) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
	return a.shape[rank - 1];
}
|
|
4937
5183
|
/**
 * Cholesky decomposition of a (batched) positive-definite matrix.
 *
 * Behaves like `jax.lax.linalg.cholesky()`, with one addition: unless
 * `symmetrizeInput` is turned off, the input is first replaced by the
 * average of itself and its matrix transpose.
 *
 * @param a - Input whose last two axes must be square.
 * @param options.upper - Forwarded to the underlying `cholesky$1` call.
 * @param options.symmetrizeInput - Symmetrize before factoring (default true).
 */
function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
	a = fudgeArray(a);
	checkSquare("cholesky", a);
	if (symmetrizeInput) {
		const doubled = a.ref.add(matrixTranspose(a));
		a = doubled.mul(.5);
	}
	return cholesky$1(a, { upper });
}
|
|
5195
|
+
/**
 * Compute the determinant of a square matrix (batched).
 *
 * Uses the LU factors: the product of the factor diagonal, with the sign
 * flipped once for every pivot entry that differs from its own index.
 */
function det(a) {
	a = fudgeArray(a);
	const n = checkSquare("det", a);
	const [factors, pivots, permutation] = lu(a);
	permutation.dispose();
	const swapCount = pivots.notEqual(arange(n)).astype(int32).sum(-1);
	const parity = swapCount.mod(2);
	// parity 0 -> +1, parity 1 -> -1
	const sign$1 = parity.mul(-2).add(1);
	const diagonalEntries = factors.diagonal(0, -1, -2);
	return prod$1(diagonalEntries, -1).mul(sign$1);
}
|
|
5206
|
+
/**
 * Compute the inverse of a square matrix (batched).
 *
 * Implemented as `solve(a, I)` where `I` is the identity of matching size.
 */
function inv(a) {
	a = fudgeArray(a);
	const size$1 = checkSquare("inv", a);
	const identity = eye(size$1);
	return solve(a, identity);
}
|
|
4949
5212
|
/**
|
|
4950
5213
|
* Return the least-squares solution to a linear equation.
|
|
@@ -4968,7 +5231,7 @@ function lstsq(a, b) {
|
|
|
4968
5231
|
const at = matrixTranspose(a.ref);
|
|
4969
5232
|
if (m <= n) {
|
|
4970
5233
|
const aat = matmul(a, at.ref);
|
|
4971
|
-
const l = cholesky
|
|
5234
|
+
const l = cholesky(aat, { symmetrizeInput: false });
|
|
4972
5235
|
const lb = triangularSolve(l.ref, b, {
|
|
4973
5236
|
leftSide: true,
|
|
4974
5237
|
lower: true
|
|
@@ -4980,7 +5243,7 @@ function lstsq(a, b) {
|
|
|
4980
5243
|
return matmul(at, llb.ref);
|
|
4981
5244
|
} else {
|
|
4982
5245
|
const ata = matmul(at.ref, a);
|
|
4983
|
-
const l = cholesky
|
|
5246
|
+
const l = cholesky(ata, { symmetrizeInput: false });
|
|
4984
5247
|
const atb = matmul(at, b);
|
|
4985
5248
|
const lb = triangularSolve(l.ref, atb, {
|
|
4986
5249
|
leftSide: true,
|
|
@@ -4993,6 +5256,169 @@ function lstsq(a, b) {
|
|
|
4993
5256
|
return llb;
|
|
4994
5257
|
}
|
|
4995
5258
|
}
|
|
5259
|
+
/**
 * Raise a square matrix to an integer power, via repeated squarings.
 *
 * A zero exponent returns an identity broadcast to the input's shape; a
 * negative exponent inverts the matrix first and uses the absolute exponent.
 *
 * @throws Error if `n` is not an integer.
 */
function matrixPower(a, n) {
	if (!Number.isInteger(n)) throw new Error(`matrixPower: exponent must be an integer, got ${n}`);
	a = fudgeArray(a);
	const m = checkSquare("matrixPower", a);
	if (n === 0) {
		a.dispose();
		return broadcastTo(eye(m), a.shape);
	}
	let remaining = n;
	let base = a;
	if (remaining < 0) {
		base = inv(base);
		remaining = -remaining;
	}
	let acc = null;
	let isFirstBit = true;
	while (remaining !== 0) {
		// Square the base for every bit after the lowest one.
		if (!isFirstBit) base = matmul(base.ref, base);
		isFirstBit = false;
		if (remaining % 2 === 1) acc = acc === null ? base.ref : matmul(acc, base.ref);
		remaining = Math.floor(remaining / 2);
	}
	base.dispose();
	return acc;
}
|
|
5282
|
+
/**
 * Return sign and natural logarithm of the determinant of `a`.
 *
 * The sign combines the number of pivots that differ from their own index
 * with the number of negative entries on the LU diagonal; the log-magnitude
 * is the sum of `log(abs(...))` over that diagonal.
 *
 * @returns `[sign, logabsdet]`.
 */
function slogdet(a) {
	a = fudgeArray(a);
	const n = checkSquare("slogdet", a);
	const [factors, pivots, permutation] = lu(a);
	permutation.dispose();
	const swapCount = pivots.notEqual(arange(n)).astype(int32).sum(-1);
	const diagonalEntries = factors.diagonal(0, -1, -2);
	const negativeCount = diagonalEntries.ref.less(0).astype(int32).sum(-1);
	const parity = swapCount.add(negativeCount).mod(2);
	const logabsdet = log(absolute(diagonalEntries)).sum(-1);
	// parity 0 -> +1, parity 1 -> -1
	const sign$1 = parity.mul(-2).add(1);
	return [sign$1, logabsdet];
}
|
|
5295
|
+
/**
 * Solve a linear system of equations.
 *
 * Solves the (batched) system `a @ x = b` for `x` using an LU factorization
 * of `a` followed by two triangular solves. If `a` is singular, the result
 * contains `nan` or `inf` values.
 *
 * @param a - Coefficient matrix of shape `(..., N, N)`.
 * @param b - Values of shape `(N,)` or `(..., N, M)`.
 * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
 * @throws Error if `b` is a scalar or its row count differs from `N`.
 */
function solve(a, b) {
	a = fudgeArray(a);
	b = fudgeArray(b);
	const n = checkSquare("solve", a);
	if (b.ndim === 0) throw new Error(`solve: b cannot be scalar`);
	// A 1-D right-hand side is treated as a single column and squeezed back at the end.
	const vectorRhs = b.ndim === 1;
	if (vectorRhs) b = b.reshape([...b.shape, 1]);
	const rhsRows = b.shape[b.ndim - 2];
	if (rhsRows !== n) throw new Error(`solve: leading dimension of b must match size of a, got a=${a.aval}, b=${b.aval}`);
	const m = b.shape[b.ndim - 1];
	const batchDims = generalBroadcast(a.shape.slice(0, -2), b.shape.slice(0, -2));
	a = broadcastTo(a, [...batchDims, n, n]);
	b = broadcastTo(b, [...batchDims, n, m]);
	const [factors, pivots, permutation] = lu(a);
	pivots.dispose();
	// One-hot matrix: entry (i, j) is set where j equals permutation[i].
	const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
	const forward = triangularSolve(factors.ref, matmul(P, b), {
		leftSide: true,
		lower: true,
		unitDiagonal: true
	});
	// NOTE(review): the `.ref` on the intermediate is kept exactly as in the original; confirm ownership semantics.
	const x = triangularSolve(factors, forward.ref, {
		leftSide: true,
		lower: false
	});
	return vectorRhs ? squeeze(x, -1) : x;
}
|
|
5340
|
+
|
|
5341
|
+
//#endregion
|
|
5342
|
+
//#region src/library/numpy/dtype-info.ts
|
|
5343
|
+
/**
 * Machine limits for floating-point types.
 *
 * Field names follow NumPy's `finfo`. In particular `eps === 2 ** machep`
 * and `epsneg === 2 ** negep` hold for every entry.
 *
 * @param dtype - One of `DType.Float16`, `DType.Float32`, `DType.Float64`.
 * @returns A frozen record of limits for `dtype` (a fresh object per call).
 * @throws Error if `dtype` is not a floating-point type or is unsupported.
 */
function finfo(dtype) {
	if (!isFloatDtype(dtype)) throw new Error(`finfo: received ${dtype}, must be a floating-point type`);
	switch (dtype) {
		case DType.Float16: return Object.freeze({
			bits: 16,
			dtype: DType.Float16,
			eps: 2 ** -10,
			epsneg: 2 ** -11,
			machep: -10,
			max: 65504,
			maxexp: 16,
			min: -65504,
			minexp: -14,
			// Exponent of epsneg (2 ** -11); this was previously -24, copied from the float32 entry.
			negep: -11,
			nexp: 5,
			nmant: 10,
			precision: 3,
			resolution: .001,
			smallestNormal: 2 ** -14,
			smallestSubnormal: 2 ** -24
		});
		case DType.Float32: return Object.freeze({
			bits: 32,
			dtype: DType.Float32,
			eps: 2 ** -23,
			epsneg: 2 ** -24,
			machep: -23,
			max: 3.4028234663852886e38,
			maxexp: 128,
			min: -3.4028234663852886e38,
			minexp: -126,
			negep: -24,
			nexp: 8,
			nmant: 23,
			precision: 6,
			resolution: 1e-6,
			smallestNormal: 2 ** -126,
			smallestSubnormal: 2 ** -149
		});
		case DType.Float64: return Object.freeze({
			bits: 64,
			dtype: DType.Float64,
			eps: 2 ** -52,
			epsneg: 2 ** -53,
			machep: -52,
			max: Number.MAX_VALUE,
			maxexp: 1024,
			min: -Number.MAX_VALUE,
			minexp: -1022,
			negep: -53,
			nexp: 11,
			nmant: 52,
			precision: 15,
			resolution: 1e-15,
			smallestNormal: 2 ** -1022,
			smallestSubnormal: 2 ** -1074
		});
		default: throw new Error(`finfo: unsupported dtype ${dtype}`);
	}
}
|
|
5404
|
+
/**
 * Machine limits for integer types.
 *
 * @param dtype - `DType.Int32` or `DType.Uint32`.
 * @returns A frozen `{ bits, dtype, max, min }` record (a fresh object per call).
 * @throws Error for any other dtype.
 */
function iinfo(dtype) {
	if (dtype === DType.Int32) {
		return Object.freeze({
			bits: 32,
			dtype: DType.Int32,
			max: 2147483647,
			min: -2147483648
		});
	}
	if (dtype === DType.Uint32) {
		return Object.freeze({
			bits: 32,
			dtype: DType.Uint32,
			max: 4294967295,
			min: 0
		});
	}
	throw new Error(`iinfo: unsupported dtype ${dtype}`);
}
|
|
4996
5422
|
|
|
4997
5423
|
//#endregion
|
|
4998
5424
|
//#region src/library/numpy.ts
|
|
@@ -5048,6 +5474,7 @@ __export(numpy_exports, {
|
|
|
5048
5474
|
diag: () => diag,
|
|
5049
5475
|
diagonal: () => diagonal,
|
|
5050
5476
|
divide: () => trueDivide,
|
|
5477
|
+
divmod: () => divmod,
|
|
5051
5478
|
dot: () => dot$1,
|
|
5052
5479
|
dstack: () => dstack,
|
|
5053
5480
|
e: () => e,
|
|
@@ -5060,6 +5487,7 @@ __export(numpy_exports, {
|
|
|
5060
5487
|
expm1: () => expm1,
|
|
5061
5488
|
eye: () => eye,
|
|
5062
5489
|
fft: () => numpy_fft_exports,
|
|
5490
|
+
finfo: () => finfo,
|
|
5063
5491
|
flip: () => flip,
|
|
5064
5492
|
fliplr: () => fliplr,
|
|
5065
5493
|
flipud: () => flipud,
|
|
@@ -5067,6 +5495,7 @@ __export(numpy_exports, {
|
|
|
5067
5495
|
float32: () => float32,
|
|
5068
5496
|
float64: () => float64,
|
|
5069
5497
|
floor: () => floor,
|
|
5498
|
+
floorDivide: () => floorDivide,
|
|
5070
5499
|
fmod: () => fmod,
|
|
5071
5500
|
frexp: () => frexp,
|
|
5072
5501
|
full: () => full,
|
|
@@ -5079,6 +5508,7 @@ __export(numpy_exports, {
|
|
|
5079
5508
|
hstack: () => hstack,
|
|
5080
5509
|
hypot: () => hypot,
|
|
5081
5510
|
identity: () => identity$1,
|
|
5511
|
+
iinfo: () => iinfo,
|
|
5082
5512
|
inf: () => inf,
|
|
5083
5513
|
inner: () => inner,
|
|
5084
5514
|
int32: () => int32,
|
|
@@ -5096,6 +5526,7 @@ __export(numpy_exports, {
|
|
|
5096
5526
|
log10: () => log10,
|
|
5097
5527
|
log1p: () => log1p,
|
|
5098
5528
|
log2: () => log2,
|
|
5529
|
+
logspace: () => logspace,
|
|
5099
5530
|
matmul: () => matmul,
|
|
5100
5531
|
matrixTranspose: () => matrixTranspose,
|
|
5101
5532
|
max: () => max,
|
|
@@ -5132,9 +5563,11 @@ __export(numpy_exports, {
|
|
|
5132
5563
|
shape: () => shape,
|
|
5133
5564
|
sign: () => sign,
|
|
5134
5565
|
sin: () => sin,
|
|
5566
|
+
sinc: () => sinc,
|
|
5135
5567
|
sinh: () => sinh,
|
|
5136
5568
|
size: () => size,
|
|
5137
5569
|
sort: () => sort,
|
|
5570
|
+
split: () => split$1,
|
|
5138
5571
|
sqrt: () => sqrt,
|
|
5139
5572
|
square: () => square,
|
|
5140
5573
|
squeeze: () => squeeze,
|
|
@@ -5142,6 +5575,7 @@ __export(numpy_exports, {
|
|
|
5142
5575
|
std: () => std,
|
|
5143
5576
|
subtract: () => subtract,
|
|
5144
5577
|
sum: () => sum,
|
|
5578
|
+
take: () => take,
|
|
5145
5579
|
tan: () => tan,
|
|
5146
5580
|
tanh: () => tanh,
|
|
5147
5581
|
tensordot: () => tensordot,
|
|
@@ -5400,6 +5834,45 @@ function flip(x, axis = null) {
|
|
|
5400
5834
|
return flip$1(x, axis);
|
|
5401
5835
|
}
|
|
5402
5836
|
/**
 * Split an array into multiple sub-arrays along an axis.
 *
 * @param a - The input array to split.
 * @param indicesOrSections - If an integer, it indicates the number of equal
 *   sections to create along the specified axis. If a list of integers, it
 *   specifies the indices at which to split the array; an empty list yields
 *   a single piece covering the whole axis.
 * @param axis - The axis along which to split the array. Default is 0.
 * @returns The list of sub-arrays, in order along `axis`.
 * @throws Error if a section count is not a positive integer or does not
 *   divide the axis length evenly.
 */
function split$1(a, indicesOrSections, axis = 0) {
	a = fudgeArray(a);
	axis = checkAxis(axis, a.ndim);
	const size$1 = a.shape[axis];
	let sizes;
	if (typeof indicesOrSections === "number") {
		if (!Number.isInteger(indicesOrSections) || indicesOrSections <= 0) throw new Error(`split: number of sections must be a positive integer, got ${indicesOrSections}`);
		if (size$1 % indicesOrSections !== 0) throw new Error(`Array of size ${size$1} cannot be split into ${indicesOrSections} equal parts`);
		sizes = rep(indicesOrSections, size$1 / indicesOrSections);
	} else {
		// Piece sizes are the gaps between consecutive boundaries 0, ...indices, size.
		// With no indices this degenerates to one piece of the full size.
		const bounds = [0, ...indicesOrSections, size$1];
		sizes = bounds.slice(1).map((end, i) => end - bounds[i]);
	}
	// The underlying split is issued with at most 8 output sizes per call
	// (presumably a backend limit -- TODO confirm). While more than 8 pieces
	// remain, peel off 7 and keep the rest together as an 8th piece to be
	// split again on the next round.
	const results = [];
	let offset = 0;
	while (sizes.length - offset > 8) {
		const restSize = sizes.slice(offset + 7).reduce((x, y) => x + y, 0);
		const outs = split$2(a, axis, [...sizes.slice(offset, offset + 7), restSize]);
		results.push(...outs.slice(0, -1));
		a = outs[outs.length - 1];
		offset += 7;
	}
	results.push(...split$2(a, axis, sizes.slice(offset)));
	return results;
}
|
|
5875
|
+
/**
|
|
5403
5876
|
* Join a sequence of arrays along an existing axis.
|
|
5404
5877
|
*
|
|
5405
5878
|
* The arrays must have the same shape, except in the dimension corresponding to
|
|
@@ -5411,13 +5884,11 @@ function concatenate(xs, axis = 0) {
|
|
|
5411
5884
|
if (xs.length === 0) throw new Error("Need at least one array to concatenate");
|
|
5412
5885
|
const shapes = xs.map(shape);
|
|
5413
5886
|
axis = checkAxis(axis, shapes[0].length);
|
|
5414
|
-
for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays
|
|
5415
|
-
const makePadAxis = (start, end) => shapes[0].map((_, i) => i === axis ? [start, end] : [0, 0]);
|
|
5887
|
+
for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays ${xs[0].aval} and ${xs[i].aval} along axis ${axis}`);
|
|
5416
5888
|
let result = xs[0];
|
|
5417
|
-
for (let i = 1; i < xs.length; i
|
|
5418
|
-
const
|
|
5419
|
-
|
|
5420
|
-
result = pad(result, makePadAxis(0, len2)).add(pad(xs[i], makePadAxis(len1, 0)));
|
|
5889
|
+
for (let i = 1; i < xs.length; i += 7) {
|
|
5890
|
+
const group = xs.slice(i, i + 7);
|
|
5891
|
+
result = concatenate$1([result, ...group], axis);
|
|
5421
5892
|
}
|
|
5422
5893
|
return result;
|
|
5423
5894
|
}
|
|
@@ -5669,6 +6140,20 @@ function sort(a, axis = -1) {
|
|
|
5669
6140
|
function argsort(a, axis = -1) {
|
|
5670
6141
|
return fudgeArray(a).argsort(axis);
|
|
5671
6142
|
}
|
|
6143
|
+
/**
 * Take elements from an array along an axis.
 *
 * Equivalent to advanced indexing with integer indices over the given axis.
 * When `axis` is `null` (the default) the input is flattened with `ravel`
 * and axis 0 of the flattened array is used.
 */
function take(a, indices, axis = null) {
	const flatten = axis === null;
	const source = flatten ? ravel(a) : a;
	const resolvedAxis = checkAxis(flatten ? 0 : axis, ndim(source));
	return gather(source, [indices], [resolvedAxis], resolvedAxis);
}
|
|
5672
6157
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
5673
6158
|
function allclose(actual, expected, options) {
|
|
5674
6159
|
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
@@ -5988,6 +6473,20 @@ function tan(x) {
|
|
|
5988
6473
|
x = fudgeArray(x);
|
|
5989
6474
|
return sin(x.ref).div(cos(x));
|
|
5990
6475
|
}
|
|
6476
|
+
/**
 * @function
 * Return the normalized sinc function.
 *
 * Defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`; this is the
 * normalized form commonly used in signal processing.
 *
 * **Note:** JVP is not supported at x=0 due to discontinuous derivative. This
 * requires a custom JVP rule to handle properly (see JAX implementation).
 */
const sinc = jit$1(function sinc$1(x) {
	const scaled = x.ref.mul(Math.PI);
	const atZero = equal(x, 0);
	const ratio = sin(scaled.ref).div(scaled);
	return where(atZero, 1, ratio);
});
|
|
5991
6490
|
/** Element-wise inverse cosine function (inverse of cos). */
|
|
5992
6491
|
function acos(x) {
|
|
5993
6492
|
return subtract(pi / 2, asin(x));
|
|
@@ -6040,6 +6539,25 @@ function trueDivide(x, y) {
|
|
|
6040
6539
|
return x.div(y);
|
|
6041
6540
|
}
|
|
6042
6541
|
/**
 * Return the largest integer smaller or equal to the division of the inputs.
 *
 * The result is always rounded towards negative infinity.
 *
 * If either input has a floating-point dtype this is `floor(x / y)`.
 * Otherwise it is computed as `(x - remainder(x, y)) / y` so that negative
 * values are handled correctly (note: may overflow near int32 boundaries).
 *
 * @param x - Dividend array.
 * @param y - Divisor array.
 * @returns Element-wise floor division of x by y.
 */
function floorDivide(x, y) {
	x = fudgeArray(x);
	y = fudgeArray(y);
	const anyFloat = isFloatDtype(x.dtype) || isFloatDtype(y.dtype);
	if (!anyFloat) {
		const rem = remainder(x.ref, y.ref);
		return subtract(x, rem).div(y);
	}
	return floor(trueDivide(x, y));
}
|
|
6560
|
+
/**
|
|
6043
6561
|
* @function
|
|
6044
6562
|
* Calculate element-wise floating-point modulo operation.
|
|
6045
6563
|
*/
|
|
@@ -6053,6 +6571,20 @@ const fmod = jit$1(function fmod$1(x, y) {
|
|
|
6053
6571
|
// Computes mod(mod(x, y) + y, y): adding `y` before the second `mod` shifts
// the intermediate result before reducing it again.
const remainder = jit$1(function remainder$1(x, y) {
	const firstPass = mod(x, y.ref);
	const shifted = firstPass.add(y.ref);
	return mod(shifted, y);
});
|
|
6574
|
+
/**
 * Return element-wise quotient and remainder simultaneously.
 *
 * Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
 *
 * @param x - Dividend array.
 * @param y - Divisor array.
 * @returns Tuple of [quotient, remainder].
 */
function divmod(x, y) {
	const dividend = fudgeArray(x);
	const divisor = fudgeArray(y);
	// The quotient works on extra references; the remainder call receives the arrays themselves.
	const quotient = floorDivide(dividend.ref, divisor.ref);
	const rest = remainder(dividend, divisor);
	return [quotient, rest];
}
|
|
6056
6588
|
/** Round input to the nearest integer towards zero. */
|
|
6057
6589
|
function trunc(x) {
|
|
6058
6590
|
return idiv(x, 1);
|
|
@@ -6216,14 +6748,15 @@ function std(x, axis = null, opts) {
|
|
|
6216
6748
|
return sqrt(var_(x, axis, opts));
|
|
6217
6749
|
}
|
|
6218
6750
|
/** Estimate the sample covariance of a set of variables. */
|
|
6219
|
-
function cov(x, y) {
|
|
6751
|
+
function cov(x, y = null, { rowvar = true } = {}) {
|
|
6220
6752
|
x = fudgeArray(x);
|
|
6221
6753
|
if (x.ndim === 1) x = x.reshape([1, x.shape[0]]);
|
|
6222
|
-
if (y !==
|
|
6754
|
+
if (y !== null) {
|
|
6223
6755
|
y = fudgeArray(y);
|
|
6224
6756
|
if (y.ndim === 1) y = y.reshape([1, y.shape[0]]);
|
|
6225
6757
|
x = vstack([x, y]);
|
|
6226
6758
|
}
|
|
6759
|
+
if (!rowvar) x = x.transpose();
|
|
6227
6760
|
const [_M, N] = x.shape;
|
|
6228
6761
|
x = x.ref.sub(x.mean(1, { keepdims: true }));
|
|
6229
6762
|
return dot$1(x.ref, x.transpose()).div(N - 1);
|
|
@@ -6268,7 +6801,8 @@ const isfinite = jit$1(function isfinite$1(x) {
|
|
|
6268
6801
|
//#region src/library/lax-linalg.ts
|
|
6269
6802
|
var lax_linalg_exports = {};
|
|
6270
6803
|
__export(lax_linalg_exports, {
|
|
6271
|
-
cholesky: () => cholesky,
|
|
6804
|
+
cholesky: () => cholesky$1,
|
|
6805
|
+
lu: () => lu,
|
|
6272
6806
|
triangularSolve: () => triangularSolve
|
|
6273
6807
|
});
|
|
6274
6808
|
/**
|
|
@@ -6297,11 +6831,39 @@ __export(lax_linalg_exports, {
|
|
|
6297
6831
|
* // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
|
|
6298
6832
|
* ```
|
|
6299
6833
|
*/
|
|
6300
|
-
function cholesky(a, { upper = false } = {}) {
|
|
6834
|
+
function cholesky$1(a, { upper = false } = {}) {
  // `cholesky$2` yields the factor that is returned as-is for `upper = false`.
  const factor = cholesky$2(a);
  if (!upper) return factor;
  // For the upper form, swap the last two axes of that factor.
  return moveaxis$1(factor, -2, -1);
}
|
|
6304
6838
|
/**
 * LU decomposition with partial pivoting.
 *
 * Factors the input so that `P @ A = L @ U`, where `P` permutes the rows of
 * `A`, `L` is lower-triangular with a unit diagonal, and `U` is
 * upper-triangular.
 *
 * @param x - A batch of matrices with shape `[..., m, n]`.
 *
 * @returns A tuple `(lu, pivots, permutation)` where:
 * - `lu`: the lower and upper triangular factors packed into one matrix.
 * - `pivots`: pivot indices with shape `[..., min(m, n)]`.
 * - `permutation`: the row permutation implied by `pivots`, with shape
 *   `[..., m]`.
 *
 * @example
 * ```ts
 * import { lax, numpy as np } from "@jax-js/jax";
 *
 * const A = np.array([[4., 3.], [6., 3.]]);
 * const [lu, pivots, permutation] = lax.linalg.lu(A);
 * // lu ≈ [[6., 3.], [0.6666667, 1.0]]
 * // pivots = [1, 1]
 * // permutation = [1, 0]
 * ```
 */
function lu(x) {
  // Thin public wrapper: all of the work happens in `lu$1`.
  const result = lu$1(x);
  return result;
}
|
|
6866
|
+
/**
|
|
6305
6867
|
* Solve a triangular linear system.
|
|
6306
6868
|
*
|
|
6307
6869
|
* Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
|
|
@@ -6844,33 +7406,41 @@ __export(random_exports, {
|
|
|
6844
7406
|
gumbel: () => gumbel,
|
|
6845
7407
|
key: () => key,
|
|
6846
7408
|
laplace: () => laplace,
|
|
7409
|
+
multivariateNormal: () => multivariateNormal,
|
|
6847
7410
|
normal: () => normal,
|
|
6848
7411
|
split: () => split,
|
|
6849
7412
|
uniform: () => uniform
|
|
6850
7413
|
});
|
|
6851
|
-
function validateKeyShape(key$1) {
|
|
7414
|
+
/**
 * Check that `key$1` has the layout of a PRNG key (trailing dimension of 2)
 * and return its batch shape, i.e. the shape without that trailing dimension.
 *
 * When `scalar` is true, a batch of keys is rejected as well, so the returned
 * batch shape is always empty.
 */
function validateKeyShape(key$1, scalar = false) {
  if (key$1.ndim === 0) {
    throw new Error("Key must have at least one dimension.");
  }
  const dims = key$1.shape;
  const last = dims[dims.length - 1];
  if (last !== 2) {
    throw new Error(`Invalid key shape: ${dims}. Expected last dimension to be 2.`);
  }
  const isBatch = dims.length > 1;
  if (scalar && isBatch) {
    throw new Error(`Expected a single PRNG key, but got a batch of keys with shape ${JSON.stringify(dims)} - use jax.vmap for batching.`);
  }
  return dims.slice(0, -1);
}
|
|
7420
|
+
/**
 * Split a single PRNG key into its two halves.
 *
 * The key is validated as a single (non-batched) key, split along its last
 * axis into two pieces of size 1, and each piece is reshaped to the key's
 * batch shape.
 */
function getK01(key$1) {
  const batchShape = validateKeyShape(key$1, true);
  const [first, second] = split$2(key$1, -1, [1, 1]);
  // Reshape in the same order as the pieces were produced.
  const k0 = first.reshape(batchShape);
  const k1 = second.reshape(batchShape);
  return [k0, k1];
}
|
|
6856
7427
|
/**
 * Create a pseudo-random number generator (PRNG) key from a 32-bit integer seed.
 *
 * The seed is converted to a `Uint32` array and must be a scalar; the key is
 * the stack `[0, seed]`.
 */
function key(seed) {
  const seedArr = array(seed, { dtype: DType.Uint32 });
  if (seedArr.ndim === 0) {
    return stack([0, seedArr]);
  }
  throw new Error(`key: seed must be a scalar integer, but got shape ${seedArr.shape} - use jax.vmap for batching.`);
}
|
|
6861
7433
|
/**
 * Splits a PRNG key into `num` new keys by adding a leading axis.
 *
 * `num` may be a single count or a list of dimensions; every entry must be a
 * positive integer.
 */
function split(key$1, num = 2) {
  const dims = typeof num === "number" ? [num] : num;
  for (const len of dims) {
    const isPositiveInt = Number.isInteger(len) && len > 0;
    if (!isPositiveInt) {
      throw new Error(`Invalid split length: ${len}. Must be a positive integer.`);
    }
  }
  const [k0, k1] = getK01(key$1);
  // The first draw goes through `.ref` because both halves are used again for
  // the second draw.
  const first = randomBits(k0.ref, k1.ref, dims, 0);
  const second = randomBits(k0, k1, dims, 1);
  return stack([first, second], -1);
}
|
|
6870
7440
|
/**
 * Sample uniform bits in the form of unsigned integers.
 *
 * `key$1` must be a single PRNG key; `shape$1` is the shape of the result and
 * defaults to a scalar.
 */
function bits(key$1, shape$1 = []) {
  const halves = getK01(key$1);
  const k0 = halves[0];
  const k1 = halves[1];
  return randomBits(k0, k1, shape$1);
}
|
|
6875
7445
|
/**
|
|
6876
7446
|
* @function
|
|
@@ -6944,6 +7514,32 @@ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
|
|
|
6944
7514
|
}, { staticArgnums: [1] });
|
|
6945
7515
|
/**
 * @function
 * Sample multivariate normal random values with given mean and covariance.
 *
 * The values are returned with the given shape, along with the final dimension
 * used to represent the n-dimensional multivariate normal factors.
 *
 * This uses Cholesky decomposition on the covariance matrix.
 *
 * - `key` - PRNG key
 * - `mean` - Mean vector of shape `[..., n]`; must have at least one dimension
 * - `cov` - Covariance of shape `[..., n, n]`, must be positive-definite
 * - `shape` - Result batch shape, must be broadcastable with
 *             `mean.shape[:-1]` and `cov.shape[:-2]`
 * @returns Random samples of shape `[...shape, n]`
 * @throws Error if `mean` is 0-dimensional, or if the last two dimensions of
 *   `cov` are not `[n, n]`.
 */
const multivariateNormal = jit$1(function multivariateNormal$1(key$1, mean$1, cov$1, shape$1 = []) {
  mean$1 = fudgeArray(mean$1);
  cov$1 = fudgeArray(cov$1);
  // A 0-d mean has no trailing dimension, so `n` would be `undefined`. Without
  // this guard a 0-d mean together with a 0-d covariance slips past the shape
  // check below (undefined === undefined) and puts `undefined` into the output
  // shape.
  if (mean$1.ndim < 1) throw new Error(`Invalid mean shape: [${mean$1.shape}]. Expected a vector of shape [..., n].`);
  const n = mean$1.shape[mean$1.ndim - 1];
  if (cov$1.shape[cov$1.ndim - 1] !== n || cov$1.shape[cov$1.ndim - 2] !== n) throw new Error(`Invalid covariance shape: ${cov$1.shape}. Expected last two dimensions to be [${n}, ${n}].`);
  // Batch dimensions of the request, the mean and the covariance are broadcast
  // together; the event dimension `n` is appended last.
  const outputShape = broadcastShapes(shape$1, mean$1.shape.slice(0, -1), cov$1.shape.slice(0, -2)).concat(n);
  const L = cholesky(cov$1);
  const z = normal(key$1, outputShape);
  // x = L @ z + mean, applied over the trailing event dimension.
  return einsum("...ij,...j->...i", L, z).add(mean$1);
}, { staticArgnums: [3] });
|
|
7541
|
+
/**
|
|
7542
|
+
* @function
|
|
6947
7543
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
6948
7544
|
*
|
|
6949
7545
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|