@jax-js/jax 0.1.4 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +10 -7
- package/dist/{backend-Bu9GY6sK.cjs → backend-D7s-Retx.cjs} +122 -8
- package/dist/{backend-tngXtWe4.js → backend-Dx6Ob2D1.js} +111 -9
- package/dist/index.cjs +1059 -208
- package/dist/index.d.cts +429 -21
- package/dist/index.d.ts +429 -21
- package/dist/index.js +1059 -209
- package/dist/webgl-CLLvzJlO.js +522 -0
- package/dist/webgl-CyfzNW8T.cjs +522 -0
- package/dist/{webgpu-ChVgx3b6.js → webgpu-C-VfevQW.js} +296 -3
- package/dist/{webgpu-Oj3Kd-kd.cjs → webgpu-rraa6dfz.cjs} +296 -3
- package/package.json +1 -1
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-D7s-Retx.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -240,7 +240,7 @@ __export(tree_exports, {
|
|
|
240
240
|
structure: () => structure,
|
|
241
241
|
unflatten: () => unflatten
|
|
242
242
|
});
|
|
243
|
-
const JsArray$
|
|
243
|
+
const JsArray$2 = globalThis.Array;
|
|
244
244
|
let NodeType = /* @__PURE__ */ function(NodeType$1) {
|
|
245
245
|
NodeType$1["Array"] = "Array";
|
|
246
246
|
NodeType$1["Object"] = "Object";
|
|
@@ -288,7 +288,7 @@ function flatten(tree) {
|
|
|
288
288
|
return [leaves$1, treedef];
|
|
289
289
|
}
|
|
290
290
|
function _flatten(tree, leaves$1) {
|
|
291
|
-
if (JsArray$
|
|
291
|
+
if (JsArray$2.isArray(tree)) {
|
|
292
292
|
const childTrees = tree.map((c) => _flatten(c, leaves$1));
|
|
293
293
|
return new JsTreeDef(NodeType.Array, null, childTrees);
|
|
294
294
|
} else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
|
|
@@ -387,6 +387,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
387
387
|
Primitive$1["PoolTranspose"] = "pool_transpose";
|
|
388
388
|
Primitive$1["Compare"] = "compare";
|
|
389
389
|
Primitive$1["Where"] = "where";
|
|
390
|
+
Primitive$1["Concatenate"] = "concatenate";
|
|
391
|
+
Primitive$1["Split"] = "split";
|
|
390
392
|
Primitive$1["RandomBits"] = "random_bits";
|
|
391
393
|
Primitive$1["Gather"] = "gather";
|
|
392
394
|
Primitive$1["Transpose"] = "transpose";
|
|
@@ -399,6 +401,7 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
399
401
|
Primitive$1["Argsort"] = "argsort";
|
|
400
402
|
Primitive$1["TriangularSolve"] = "triangular_solve";
|
|
401
403
|
Primitive$1["Cholesky"] = "cholesky";
|
|
404
|
+
Primitive$1["LU"] = "lu";
|
|
402
405
|
Primitive$1["Jit"] = "jit";
|
|
403
406
|
return Primitive$1;
|
|
404
407
|
}({});
|
|
@@ -409,6 +412,13 @@ let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
|
|
|
409
412
|
CompareOp$1["LessEqual"] = "less_equal";
|
|
410
413
|
return CompareOp$1;
|
|
411
414
|
}({});
|
|
415
|
+
const routinePrimitives = new Map([
|
|
416
|
+
[Primitive.Sort, require_backend.Routines.Sort],
|
|
417
|
+
[Primitive.Argsort, require_backend.Routines.Argsort],
|
|
418
|
+
[Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
|
|
419
|
+
[Primitive.Cholesky, require_backend.Routines.Cholesky],
|
|
420
|
+
[Primitive.LU, require_backend.Routines.LU]
|
|
421
|
+
]);
|
|
412
422
|
function add$1(x, y) {
|
|
413
423
|
return bind1(Primitive.Add, [x, y]);
|
|
414
424
|
}
|
|
@@ -530,7 +540,25 @@ function where$1(cond, x, y) {
|
|
|
530
540
|
y
|
|
531
541
|
]);
|
|
532
542
|
}
|
|
543
|
+
function concatenate$1(xs, axis) {
|
|
544
|
+
if (xs.length === 0) throw new Error("concatenate requires at least one input");
|
|
545
|
+
const avals = xs.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
546
|
+
axis = require_backend.checkAxis(axis, avals[0].ndim);
|
|
547
|
+
for (const x of avals) if (x.ndim !== avals[0].ndim || !x.shape.every((s, i) => i === axis || s === avals[0].shape[i])) throw new Error(`Concatenate: inputs ${avals[0]} and ${x} must match shapes except on axis ${axis}`);
|
|
548
|
+
return bind1(Primitive.Concatenate, xs, { axis });
|
|
549
|
+
}
|
|
550
|
+
function split$2(x, axis, sizes) {
|
|
551
|
+
axis = require_backend.checkAxis(axis, ndim$1(x));
|
|
552
|
+
if (sizes.some((s) => s < 0 || !Number.isInteger(s))) throw new Error(`split: sizes must be nonnegative integers, got ${JSON.stringify(sizes)}`);
|
|
553
|
+
const totalSize = sizes.reduce((a, b) => a + b, 0);
|
|
554
|
+
if (totalSize !== getShape(x)[axis]) throw new Error(`split: sizes must sum to the size of the axis ${axis}, got ${totalSize}`);
|
|
555
|
+
return bind(Primitive.Split, [x], {
|
|
556
|
+
axis,
|
|
557
|
+
sizes
|
|
558
|
+
});
|
|
559
|
+
}
|
|
533
560
|
function randomBits(k0, k1, shape$1, mode = "xor") {
|
|
561
|
+
if (!require_backend.deepEqual(k0.shape, k1.shape) || k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new Error(`randomBits: key parts must be uint32 with the same shape, got ${ShapedArray.fromAval(k0.aval)} and ${ShapedArray.fromAval(k1.aval)}`);
|
|
534
562
|
return bind1(Primitive.RandomBits, [k0, k1], {
|
|
535
563
|
shape: shape$1,
|
|
536
564
|
mode
|
|
@@ -597,6 +625,11 @@ function pad$1(x, width) {
|
|
|
597
625
|
return bind1(Primitive.Pad, [x], { width });
|
|
598
626
|
}
|
|
599
627
|
function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
628
|
+
const as = getShape(a);
|
|
629
|
+
const bs = getShape(b);
|
|
630
|
+
if (as.length < 2 || bs.length < 2) throw new Error(`triangular_solve: must be >=2D, got a=${as}, b=${bs}`);
|
|
631
|
+
const n = as[as.length - 2];
|
|
632
|
+
if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
|
|
600
633
|
if (lower) {
|
|
601
634
|
a = flip$1(a, [-2, -1]);
|
|
602
635
|
b = flip$1(b, [-1]);
|
|
@@ -606,8 +639,15 @@ function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
|
606
639
|
return x;
|
|
607
640
|
}
|
|
608
641
|
function cholesky$2(x) {
|
|
642
|
+
const aval = ShapedArray.fromAval(getAval(x));
|
|
643
|
+
if (aval.ndim < 2 || aval.shape[aval.ndim - 1] !== aval.shape[aval.ndim - 2]) throw new Error(`cholesky: expected batch of square matrices, got ${aval}`);
|
|
609
644
|
return bind1(Primitive.Cholesky, [x]);
|
|
610
645
|
}
|
|
646
|
+
function lu$1(x) {
|
|
647
|
+
const aval = ShapedArray.fromAval(getAval(x));
|
|
648
|
+
if (aval.ndim < 2) throw new Error(`lu: expected batch of matrices, got ${aval}`);
|
|
649
|
+
return bind(Primitive.LU, [x]);
|
|
650
|
+
}
|
|
611
651
|
function sort$1(x) {
|
|
612
652
|
const nd = ndim$1(x);
|
|
613
653
|
if (nd === 0) throw new Error("sort: requires at least 1D input");
|
|
@@ -652,6 +692,9 @@ function newDynamic(main) {
|
|
|
652
692
|
dynamicTrace = prevDynamicTrace;
|
|
653
693
|
} };
|
|
654
694
|
}
|
|
695
|
+
function currentTraceLevel() {
|
|
696
|
+
return traceStack[traceStack.length - 1].level;
|
|
697
|
+
}
|
|
655
698
|
var Trace = class {
|
|
656
699
|
constructor(main) {
|
|
657
700
|
this.main = main;
|
|
@@ -716,6 +759,9 @@ var Tracer = class Tracer {
|
|
|
716
759
|
mul(other) {
|
|
717
760
|
return mul(this, other);
|
|
718
761
|
}
|
|
762
|
+
mod(other) {
|
|
763
|
+
return mod(this, other);
|
|
764
|
+
}
|
|
719
765
|
greater(other) {
|
|
720
766
|
return greater$1(this, other);
|
|
721
767
|
}
|
|
@@ -828,8 +874,14 @@ var Tracer = class Tracer {
|
|
|
828
874
|
*/
|
|
829
875
|
*[Symbol.iterator]() {
|
|
830
876
|
if (this.ndim === 0) throw new Error("Cannot iterate over a scalar array");
|
|
831
|
-
|
|
832
|
-
this.
|
|
877
|
+
let residual = this;
|
|
878
|
+
const subarrayShape = this.shape.slice(1);
|
|
879
|
+
for (let i = 0; i < this.shape[0]; i++) {
|
|
880
|
+
const lr = split$2(residual, 0, [1, residual.shape[0] - 1]);
|
|
881
|
+
yield lr[0].reshape(subarrayShape);
|
|
882
|
+
residual = lr[1];
|
|
883
|
+
}
|
|
884
|
+
residual.dispose();
|
|
833
885
|
}
|
|
834
886
|
/**
|
|
835
887
|
* Return a sorted copy of an array in ascending order.
|
|
@@ -979,6 +1031,9 @@ var ShapedArray = class ShapedArray {
|
|
|
979
1031
|
get size() {
|
|
980
1032
|
return require_backend.prod(this.shape);
|
|
981
1033
|
}
|
|
1034
|
+
scalar() {
|
|
1035
|
+
return new ShapedArray([], this.dtype, this.weakType);
|
|
1036
|
+
}
|
|
982
1037
|
toString() {
|
|
983
1038
|
return `${this.dtype}[${this.shape.join(",")}]`;
|
|
984
1039
|
}
|
|
@@ -1017,6 +1072,7 @@ var TreeMismatchError = class extends TypeError {
|
|
|
1017
1072
|
super(`Mismatched tree structures in ${where$2}: ${left} != ${right}`);
|
|
1018
1073
|
}
|
|
1019
1074
|
};
|
|
1075
|
+
/** Flatten a function of `JsTree` input/output for use in tracing. */
|
|
1020
1076
|
function flattenFun(f, inTree) {
|
|
1021
1077
|
const store = { value: void 0 };
|
|
1022
1078
|
const flatFun = (...argsFlat) => {
|
|
@@ -1028,6 +1084,26 @@ function flattenFun(f, inTree) {
|
|
|
1028
1084
|
};
|
|
1029
1085
|
return [flatFun, store];
|
|
1030
1086
|
}
|
|
1087
|
+
/** Like flattenFun, but expects f to return [main, aux] tuple. */
|
|
1088
|
+
function flattenFunWithAux(f, inTree) {
|
|
1089
|
+
const store = { value: void 0 };
|
|
1090
|
+
const auxStore = { value: void 0 };
|
|
1091
|
+
const flatFun = (...argsFlat) => {
|
|
1092
|
+
const pytreeArgs = unflatten(inTree, argsFlat);
|
|
1093
|
+
const result = f(...pytreeArgs);
|
|
1094
|
+
if (!Array.isArray(result) || result.length !== 2) throw new Error("Function with `hasAux: true` must return [output, aux] tuple");
|
|
1095
|
+
const [out, aux] = result;
|
|
1096
|
+
const [outFlat, outTree] = flatten(out);
|
|
1097
|
+
store.value = outTree;
|
|
1098
|
+
auxStore.value = aux;
|
|
1099
|
+
return outFlat;
|
|
1100
|
+
};
|
|
1101
|
+
return [
|
|
1102
|
+
flatFun,
|
|
1103
|
+
store,
|
|
1104
|
+
auxStore
|
|
1105
|
+
];
|
|
1106
|
+
}
|
|
1031
1107
|
var UseAfterFreeError = class extends ReferenceError {
|
|
1032
1108
|
constructor(tracer) {
|
|
1033
1109
|
super(`Referenced tracer ${tracer.toString()} freed, please use .ref move semantics`);
|
|
@@ -1588,7 +1664,7 @@ const abstractEvalRules = {
|
|
|
1588
1664
|
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1589
1665
|
},
|
|
1590
1666
|
[Primitive.Conv]([lhs, rhs], params) {
|
|
1591
|
-
const { dtype, weakType } = promoteAvals(
|
|
1667
|
+
const { dtype, weakType } = promoteAvals(lhs.scalar(), rhs.scalar());
|
|
1592
1668
|
const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
|
|
1593
1669
|
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1594
1670
|
},
|
|
@@ -1599,10 +1675,25 @@ const abstractEvalRules = {
|
|
|
1599
1675
|
const shape$1 = require_backend.generalBroadcast(cond.shape, xy.shape);
|
|
1600
1676
|
return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
|
|
1601
1677
|
},
|
|
1678
|
+
[Primitive.Concatenate](xs, { axis }) {
|
|
1679
|
+
if (xs.length === 0) throw new TypeError("Concatenate requires at least one input");
|
|
1680
|
+
for (const x of xs) if (x.ndim !== xs[0].ndim || !x.shape.every((s, i) => i === axis || s === xs[0].shape[i])) throw new TypeError(`Concatenate: inputs ${xs[0]} and ${x} must match shapes except on axis ${axis}`);
|
|
1681
|
+
const shape$1 = xs[0].shape.slice();
|
|
1682
|
+
shape$1[axis] = xs.reduce((sum$1, x) => sum$1 + x.shape[axis], 0);
|
|
1683
|
+
const { dtype, weakType } = xs.map((x) => x.scalar()).reduce(promoteAvals);
|
|
1684
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1685
|
+
},
|
|
1686
|
+
[Primitive.Split]([x], { axis, sizes }) {
|
|
1687
|
+
const totalSize = sizes.reduce((a, b) => a + b, 0);
|
|
1688
|
+
if (x.shape[axis] !== totalSize) throw new TypeError(`Split: sizes ${sizes} do not sum to dimension ${x.shape[axis]} on axis ${axis}`);
|
|
1689
|
+
return sizes.map((size$1) => {
|
|
1690
|
+
return new ShapedArray(x.shape.toSpliced(axis, 1, size$1), x.dtype, x.weakType);
|
|
1691
|
+
});
|
|
1692
|
+
},
|
|
1602
1693
|
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
1603
1694
|
if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
1604
|
-
|
|
1605
|
-
if (!require_backend.deepEqual(
|
|
1695
|
+
if (!require_backend.deepEqual(k0.shape, k1.shape)) throw new TypeError(`RandomBits: Keys have different shapes ${k0.shape} and ${k1.shape}`);
|
|
1696
|
+
if (!require_backend.deepEqual(shape$1.slice(0, k0.ndim), k0.shape)) throw new TypeError(`RandomBits: generated shape ${shape$1} must match key shape ${k0.shape}`);
|
|
1606
1697
|
return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
|
|
1607
1698
|
},
|
|
1608
1699
|
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
@@ -1659,6 +1750,16 @@ const abstractEvalRules = {
|
|
|
1659
1750
|
if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
|
|
1660
1751
|
return [ShapedArray.fromAval(a)];
|
|
1661
1752
|
},
|
|
1753
|
+
[Primitive.LU]([a]) {
|
|
1754
|
+
if (a.ndim < 2) throw new TypeError(`lu: requires at least 2D input, got ${a}`);
|
|
1755
|
+
const batch = a.shape.slice(0, -2);
|
|
1756
|
+
const [m, n] = a.shape.slice(-2);
|
|
1757
|
+
return [
|
|
1758
|
+
ShapedArray.fromAval(a),
|
|
1759
|
+
new ShapedArray([...batch, Math.min(m, n)], require_backend.DType.Int32, false),
|
|
1760
|
+
new ShapedArray([...batch, m], require_backend.DType.Int32, false)
|
|
1761
|
+
];
|
|
1762
|
+
},
|
|
1662
1763
|
[Primitive.Jit](args, { jaxpr }) {
|
|
1663
1764
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
1664
1765
|
if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
|
|
@@ -1736,12 +1837,6 @@ function jit$1(f, opts) {
|
|
|
1736
1837
|
|
|
1737
1838
|
//#endregion
|
|
1738
1839
|
//#region src/frontend/jit.ts
|
|
1739
|
-
const routinePrimitives = new Map([
|
|
1740
|
-
[Primitive.Sort, require_backend.Routines.Sort],
|
|
1741
|
-
[Primitive.Argsort, require_backend.Routines.Argsort],
|
|
1742
|
-
[Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
|
|
1743
|
-
[Primitive.Cholesky, require_backend.Routines.Cholesky]
|
|
1744
|
-
]);
|
|
1745
1840
|
/** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
|
|
1746
1841
|
var JitProgram = class {
|
|
1747
1842
|
constructor(backend, steps, inputs, outputs) {
|
|
@@ -1911,10 +2006,10 @@ function jitCompile(backend, jaxpr) {
|
|
|
1911
2006
|
inputs.push(jv.arg);
|
|
1912
2007
|
} else if (input instanceof Lit) inputs.push(builder.pushLit(input));
|
|
1913
2008
|
const outputs = [];
|
|
1914
|
-
for (const outVar
|
|
1915
|
-
const outId = builder.pushBuffer(outVar
|
|
2009
|
+
for (const outVar of eqn.outBinders) {
|
|
2010
|
+
const outId = builder.pushBuffer(outVar.aval.size * require_backend.byteWidth(outVar.aval.dtype));
|
|
1916
2011
|
outputs.push(outId);
|
|
1917
|
-
ctx.set(outVar
|
|
2012
|
+
ctx.set(outVar, {
|
|
1918
2013
|
type: "imm",
|
|
1919
2014
|
arg: outId
|
|
1920
2015
|
});
|
|
@@ -1965,35 +2060,37 @@ function jitCompile(backend, jaxpr) {
|
|
|
1965
2060
|
let reduction;
|
|
1966
2061
|
if (inputReduction) {
|
|
1967
2062
|
const jv = inputReduction;
|
|
1968
|
-
const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
|
|
1969
|
-
exp$2 = jv.exp.reindexGids(addArgs(jv.args));
|
|
2063
|
+
const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp[0];
|
|
2064
|
+
exp$2 = [jv.exp.reindexGids(addArgs(jv.args))];
|
|
1970
2065
|
reduction = new require_backend.Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
|
|
1971
2066
|
} else {
|
|
1972
2067
|
const ruleOutput = rule(inputExps, inputAvals, eqn.params);
|
|
1973
2068
|
exp$2 = ruleOutput.exp;
|
|
1974
2069
|
reduction = ruleOutput.reduction;
|
|
1975
2070
|
}
|
|
1976
|
-
|
|
1977
|
-
|
|
1978
|
-
|
|
1979
|
-
|
|
1980
|
-
|
|
1981
|
-
|
|
1982
|
-
|
|
1983
|
-
|
|
1984
|
-
|
|
2071
|
+
for (let i$1 = 0; i$1 < eqn.outBinders.length; i$1++) {
|
|
2072
|
+
const outVar = eqn.outBinders[i$1];
|
|
2073
|
+
if (blackNodes.has(outVar)) {
|
|
2074
|
+
const nargs$1 = inputArgs.length;
|
|
2075
|
+
const size$1 = outVar.aval.size;
|
|
2076
|
+
const kernel = new require_backend.Kernel(nargs$1, size$1, exp$2[i$1], reduction);
|
|
2077
|
+
const outId = builder.pushKernel(kernel, inputArgs);
|
|
2078
|
+
ctx.set(outVar, {
|
|
2079
|
+
type: "imm",
|
|
2080
|
+
arg: outId
|
|
2081
|
+
});
|
|
2082
|
+
} else if (reduction) ctx.set(outVar, {
|
|
2083
|
+
type: "red",
|
|
2084
|
+
exp: exp$2[i$1],
|
|
2085
|
+
reduction,
|
|
2086
|
+
args: inputArgs
|
|
1985
2087
|
});
|
|
1986
|
-
|
|
1987
|
-
|
|
1988
|
-
|
|
1989
|
-
|
|
1990
|
-
|
|
1991
|
-
}
|
|
1992
|
-
else ctx.set(outVar, {
|
|
1993
|
-
type: "exp",
|
|
1994
|
-
exp: exp$2,
|
|
1995
|
-
args: inputArgs
|
|
1996
|
-
});
|
|
2088
|
+
else ctx.set(outVar, {
|
|
2089
|
+
type: "exp",
|
|
2090
|
+
exp: exp$2[i$1],
|
|
2091
|
+
args: inputArgs
|
|
2092
|
+
});
|
|
2093
|
+
}
|
|
1997
2094
|
}
|
|
1998
2095
|
const outputIds = [];
|
|
1999
2096
|
for (const out of jaxpr.outs) if (out instanceof Var) {
|
|
@@ -2034,17 +2131,17 @@ function broadcastedJit(fn, opts) {
|
|
|
2034
2131
|
if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = require_backend.AluExp.cast(newDtype, exp$2);
|
|
2035
2132
|
return exp$2;
|
|
2036
2133
|
});
|
|
2037
|
-
return { exp: fn(exps, params) };
|
|
2134
|
+
return { exp: [fn(exps, params)] };
|
|
2038
2135
|
};
|
|
2039
2136
|
}
|
|
2040
2137
|
function unopJit(fn) {
|
|
2041
2138
|
return ([a], [_as], params) => {
|
|
2042
|
-
return { exp: fn(a, params) };
|
|
2139
|
+
return { exp: [fn(a, params)] };
|
|
2043
2140
|
};
|
|
2044
2141
|
}
|
|
2045
2142
|
function reshapeJit(fn) {
|
|
2046
2143
|
return ([a], [_as], params) => {
|
|
2047
|
-
return { exp: reshapeViews(a, (st) => fn(st, params)) };
|
|
2144
|
+
return { exp: [reshapeViews(a, (st) => fn(st, params))] };
|
|
2048
2145
|
};
|
|
2049
2146
|
}
|
|
2050
2147
|
function routineNoJit() {
|
|
@@ -2090,7 +2187,7 @@ const jitRules = {
|
|
|
2090
2187
|
a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
|
|
2091
2188
|
const reduction = new require_backend.Reduction(a.dtype, op, reductionSize);
|
|
2092
2189
|
return {
|
|
2093
|
-
exp: a,
|
|
2190
|
+
exp: [a],
|
|
2094
2191
|
reduction
|
|
2095
2192
|
};
|
|
2096
2193
|
},
|
|
@@ -2101,13 +2198,13 @@ const jitRules = {
|
|
|
2101
2198
|
a = reshapeViews(a, (st) => st.compose(stX), true);
|
|
2102
2199
|
const reduction = new require_backend.Reduction(a.dtype, require_backend.AluOp.Add, stX.shape[stX.shape.length - 1]);
|
|
2103
2200
|
return {
|
|
2104
|
-
exp: a,
|
|
2201
|
+
exp: [a],
|
|
2105
2202
|
reduction
|
|
2106
2203
|
};
|
|
2107
2204
|
},
|
|
2108
2205
|
[Primitive.Dot]([a, b], [as, bs]) {
|
|
2109
2206
|
const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
|
|
2110
|
-
const c = k1.exp;
|
|
2207
|
+
const [c] = k1.exp;
|
|
2111
2208
|
const cs = promoteAvals(as, bs);
|
|
2112
2209
|
return jitRules[Primitive.Reduce]([c], [cs], {
|
|
2113
2210
|
op: require_backend.AluOp.Add,
|
|
@@ -2124,16 +2221,42 @@ const jitRules = {
|
|
|
2124
2221
|
},
|
|
2125
2222
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
2126
2223
|
[Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
2224
|
+
[Primitive.Concatenate](exps, avals, { axis }) {
|
|
2225
|
+
const ndim$2 = avals[0].ndim;
|
|
2226
|
+
const sizes = avals.map((x) => x.shape[axis]);
|
|
2227
|
+
const finalSize = sizes.reduce((a, b) => a + b, 0);
|
|
2228
|
+
const { dtype: dtypeOut } = avals.map((x) => x.scalar()).reduce(promoteAvals);
|
|
2229
|
+
const makePadAxis = (start, end) => require_backend.range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
|
|
2230
|
+
let cum = 0;
|
|
2231
|
+
const src = [];
|
|
2232
|
+
for (let i = 0; i < exps.length; i++) {
|
|
2233
|
+
const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
|
|
2234
|
+
src.push(reshapeViews(require_backend.AluExp.cast(dtypeOut, exps[i]), (st) => st.pad(padding)));
|
|
2235
|
+
cum += sizes[i];
|
|
2236
|
+
}
|
|
2237
|
+
return { exp: [src.reduce(require_backend.AluExp.add)] };
|
|
2238
|
+
},
|
|
2239
|
+
[Primitive.Split]([a], [as], { axis, sizes }) {
|
|
2240
|
+
const exp$2 = [];
|
|
2241
|
+
let start = 0;
|
|
2242
|
+
for (const size$1 of sizes) {
|
|
2243
|
+
const slice = require_backend.range(as.ndim).map((d) => d === axis ? [start, start + size$1] : [0, as.shape[d]]);
|
|
2244
|
+
exp$2.push(reshapeViews(a, (st) => st.shrink(slice)));
|
|
2245
|
+
start += size$1;
|
|
2246
|
+
}
|
|
2247
|
+
return { exp: exp$2 };
|
|
2248
|
+
},
|
|
2127
2249
|
[Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
|
|
2250
|
+
const keyShape = keyShapes[0].shape;
|
|
2128
2251
|
const mapping = (st) => {
|
|
2129
|
-
if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(shape
|
|
2252
|
+
if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(st.shape.length, shape$1.length));
|
|
2130
2253
|
};
|
|
2131
2254
|
const k0 = reshapeViews(keys[0], mapping);
|
|
2132
2255
|
const k1 = reshapeViews(keys[1], mapping);
|
|
2133
2256
|
const c0 = require_backend.AluExp.u32(0);
|
|
2134
|
-
const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
|
|
2257
|
+
const c1 = require_backend.AluExp.mod(require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx), require_backend.AluExp.u32(Math.max(require_backend.prod(shape$1.slice(keyShape.length)), 1)));
|
|
2135
2258
|
const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
2136
|
-
return { exp: exp$2 };
|
|
2259
|
+
return { exp: [exp$2] };
|
|
2137
2260
|
},
|
|
2138
2261
|
[Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
2139
2262
|
const axisSet = new Set(axis);
|
|
@@ -2148,7 +2271,7 @@ const jitRules = {
|
|
|
2148
2271
|
for (const [i, iexp] of indices.entries()) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...require_backend.range(outDim + indexShape.length - st.shape.length), ...require_backend.range(outDim + indexShape.length, finalShape.length)])));
|
|
2149
2272
|
const [index, valid] = require_backend.ShapeTracker.fromShape(xs.shape).toAluExp(src);
|
|
2150
2273
|
if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
|
|
2151
|
-
return { exp: x.substitute({ gidx: index }) };
|
|
2274
|
+
return { exp: [x.substitute({ gidx: index })] };
|
|
2152
2275
|
},
|
|
2153
2276
|
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
2154
2277
|
[Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
|
|
@@ -2164,6 +2287,7 @@ const jitRules = {
|
|
|
2164
2287
|
[Primitive.Argsort]: routineNoJit(),
|
|
2165
2288
|
[Primitive.TriangularSolve]: routineNoJit(),
|
|
2166
2289
|
[Primitive.Cholesky]: routineNoJit(),
|
|
2290
|
+
[Primitive.LU]: routineNoJit(),
|
|
2167
2291
|
[Primitive.Jit]() {
|
|
2168
2292
|
throw new Error("internal: Jit should have been flattened before JIT compilation");
|
|
2169
2293
|
}
|
|
@@ -2245,7 +2369,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2245
2369
|
p1NextBlack.set(v, v);
|
|
2246
2370
|
}
|
|
2247
2371
|
const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
|
|
2248
|
-
const needsCleanShapePrimitives = [Primitive.Pad];
|
|
2372
|
+
const needsCleanShapePrimitives = [Primitive.Concatenate, Primitive.Pad];
|
|
2249
2373
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
2250
2374
|
const eqn = jaxpr.eqns[i];
|
|
2251
2375
|
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
@@ -2315,7 +2439,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2315
2439
|
|
|
2316
2440
|
//#endregion
|
|
2317
2441
|
//#region src/frontend/array.ts
|
|
2318
|
-
const JsArray = globalThis.Array;
|
|
2442
|
+
const JsArray$1 = globalThis.Array;
|
|
2319
2443
|
const inlineArrayLimit = 128;
|
|
2320
2444
|
/** Version of pureArray with fudged types. */
|
|
2321
2445
|
const fudgeArray = pureArray;
|
|
@@ -2442,6 +2566,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2442
2566
|
this.#rc++;
|
|
2443
2567
|
return this;
|
|
2444
2568
|
}
|
|
2569
|
+
/** Get the current reference count (for debugging memory management). */
|
|
2570
|
+
get refCount() {
|
|
2571
|
+
return this.#rc;
|
|
2572
|
+
}
|
|
2445
2573
|
dispose() {
|
|
2446
2574
|
this.#check();
|
|
2447
2575
|
if (--this.#rc === 0) {
|
|
@@ -2599,7 +2727,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2599
2727
|
} else if (castDtype === void 0) {
|
|
2600
2728
|
castDtype = arrays[i].#dtype;
|
|
2601
2729
|
castWeakType = arrays[i].#weakType;
|
|
2602
|
-
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType),
|
|
2730
|
+
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), arrays[i].aval.scalar()));
|
|
2603
2731
|
const weakType = castWeakType && !strongTypeOutput;
|
|
2604
2732
|
const { backend, committed } = Array$1.#computeBackend(name, arrays);
|
|
2605
2733
|
arrays = arrays.map((ar) => ar._putSync(backend));
|
|
@@ -2709,25 +2837,35 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2709
2837
|
});
|
|
2710
2838
|
}
|
|
2711
2839
|
/** Apply an operation with custom lowering to this array. */
|
|
2712
|
-
static #routine(
|
|
2713
|
-
|
|
2714
|
-
|
|
2715
|
-
|
|
2716
|
-
|
|
2717
|
-
|
|
2718
|
-
|
|
2719
|
-
|
|
2720
|
-
|
|
2721
|
-
|
|
2722
|
-
|
|
2723
|
-
|
|
2724
|
-
|
|
2725
|
-
dtype
|
|
2726
|
-
|
|
2727
|
-
|
|
2728
|
-
|
|
2729
|
-
pending
|
|
2730
|
-
|
|
2840
|
+
static #routine(prim) {
|
|
2841
|
+
return (arrays, params) => {
|
|
2842
|
+
const { backend, committed } = Array$1.#computeBackend(prim, arrays);
|
|
2843
|
+
for (const ar of arrays) ar.#realize();
|
|
2844
|
+
const avals = arrays.map((ar) => ar.aval);
|
|
2845
|
+
const avalsOut = abstractEvalRules[prim](avals, params);
|
|
2846
|
+
const routine = new require_backend.Routine(routinePrimitives.get(prim), {
|
|
2847
|
+
inputShapes: avals.map((a) => a.shape),
|
|
2848
|
+
inputDtypes: avals.map((a) => a.dtype),
|
|
2849
|
+
outputShapes: avalsOut.map((a) => a.shape),
|
|
2850
|
+
outputDtypes: avalsOut.map((a) => a.dtype)
|
|
2851
|
+
}, params);
|
|
2852
|
+
const inputs = arrays.map((ar) => ar.#source);
|
|
2853
|
+
const outputs = avalsOut.map((x) => backend.malloc(require_backend.byteWidth(x.dtype) * x.size));
|
|
2854
|
+
const pending = arrays.flatMap((ar) => ar.#pending);
|
|
2855
|
+
for (const exe of pending) exe.updateRc(+outputs.length);
|
|
2856
|
+
pending.push(new PendingExecute(backend, routine, inputs, outputs));
|
|
2857
|
+
pending[pending.length - 1].updateRc(+outputs.length - 1);
|
|
2858
|
+
arrays.forEach((ar) => ar.dispose());
|
|
2859
|
+
return outputs.map((output, i) => new Array$1({
|
|
2860
|
+
source: output,
|
|
2861
|
+
st: require_backend.ShapeTracker.fromShape(avalsOut[i].shape),
|
|
2862
|
+
dtype: avalsOut[i].dtype,
|
|
2863
|
+
weakType: avalsOut[i].weakType,
|
|
2864
|
+
backend,
|
|
2865
|
+
committed,
|
|
2866
|
+
pending
|
|
2867
|
+
}));
|
|
2868
|
+
};
|
|
2731
2869
|
}
|
|
2732
2870
|
/**
|
|
2733
2871
|
* Normalizes this array into one backed by a `Slot`.
|
|
@@ -2992,17 +3130,44 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2992
3130
|
y
|
|
2993
3131
|
], { dtypeOverride: [require_backend.DType.Bool] })];
|
|
2994
3132
|
},
|
|
3133
|
+
[Primitive.Concatenate](xs, { axis }) {
|
|
3134
|
+
const ndim$2 = xs[0].ndim;
|
|
3135
|
+
const sizes = xs.map((x) => x.shape[axis]);
|
|
3136
|
+
const finalSize = sizes.reduce((a, b) => a + b, 0);
|
|
3137
|
+
const makePadAxis = (start, end) => require_backend.range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
|
|
3138
|
+
let cum = 0;
|
|
3139
|
+
const xsPadded = [];
|
|
3140
|
+
for (let i = 0; i < xs.length; i++) {
|
|
3141
|
+
const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
|
|
3142
|
+
xsPadded.push(xs[i].#reshape(xs[i].#st.pad(padding)));
|
|
3143
|
+
cum += sizes[i];
|
|
3144
|
+
}
|
|
3145
|
+
const custom = (exps) => exps.reduce(require_backend.AluExp.add);
|
|
3146
|
+
return [Array$1.#naryCustom("concatenate", custom, xsPadded)];
|
|
3147
|
+
},
|
|
3148
|
+
[Primitive.Split]([x], { axis, sizes }) {
|
|
3149
|
+
const outputs = [];
|
|
3150
|
+
for (let i = 0, start = 0; i < sizes.length; i++) {
|
|
3151
|
+
const slice = require_backend.range(x.ndim).map((d) => d === axis ? [start, start + sizes[i]] : [0, x.shape[d]]);
|
|
3152
|
+
outputs.push(x.ref.#reshape(x.#st.shrink(slice)));
|
|
3153
|
+
start += sizes[i];
|
|
3154
|
+
}
|
|
3155
|
+
x.dispose();
|
|
3156
|
+
return outputs;
|
|
3157
|
+
},
|
|
2995
3158
|
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
2996
|
-
const keyShape =
|
|
2997
|
-
|
|
2998
|
-
const c0 = zeros(
|
|
3159
|
+
const keyShape = k0.shape;
|
|
3160
|
+
const genShape = shape$1.slice(keyShape.length);
|
|
3161
|
+
const c0 = zeros(genShape, {
|
|
2999
3162
|
dtype: require_backend.DType.Uint32,
|
|
3000
3163
|
device: k0.device
|
|
3001
3164
|
});
|
|
3002
|
-
const c1 = arange(0, require_backend.prod(
|
|
3165
|
+
const c1 = arange(0, require_backend.prod(genShape), 1, {
|
|
3003
3166
|
dtype: require_backend.DType.Uint32,
|
|
3004
3167
|
device: k0.device
|
|
3005
|
-
}).reshape(
|
|
3168
|
+
}).reshape(genShape);
|
|
3169
|
+
k0 = k0.#reshape(k0.#st.reshape(keyShape.concat(require_backend.rep(genShape.length, 1))));
|
|
3170
|
+
k1 = k1.#reshape(k1.#st.reshape(keyShape.concat(require_backend.rep(genShape.length, 1))));
|
|
3006
3171
|
const custom = ([k0$1, k1$1, c0$1, c1$1]) => require_backend.AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
|
|
3007
3172
|
return [Array$1.#naryCustom("random_bits", custom, [
|
|
3008
3173
|
k0,
|
|
@@ -3034,42 +3199,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3034
3199
|
[Primitive.Pad]([x], { width }) {
|
|
3035
3200
|
return [x.#reshape(x.#st.pad(width))];
|
|
3036
3201
|
},
|
|
3037
|
-
[Primitive.Sort](
|
|
3038
|
-
|
|
3039
|
-
|
|
3040
|
-
|
|
3041
|
-
|
|
3042
|
-
outputDtypes: [x.aval.dtype]
|
|
3043
|
-
});
|
|
3044
|
-
return Array$1.#routine(routine, [x], [x.#weakType]);
|
|
3045
|
-
},
|
|
3046
|
-
[Primitive.Argsort]([x]) {
|
|
3047
|
-
const routine = new require_backend.Routine(require_backend.Routines.Argsort, {
|
|
3048
|
-
inputShapes: [x.aval.shape],
|
|
3049
|
-
inputDtypes: [x.aval.dtype],
|
|
3050
|
-
outputShapes: [x.aval.shape, x.aval.shape],
|
|
3051
|
-
outputDtypes: [x.aval.dtype, require_backend.DType.Int32]
|
|
3052
|
-
});
|
|
3053
|
-
return Array$1.#routine(routine, [x], [x.#weakType, false]);
|
|
3054
|
-
},
|
|
3055
|
-
[Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
|
|
3056
|
-
const routine = new require_backend.Routine(require_backend.Routines.TriangularSolve, {
|
|
3057
|
-
inputShapes: [a.aval.shape, b.aval.shape],
|
|
3058
|
-
inputDtypes: [a.aval.dtype, b.aval.dtype],
|
|
3059
|
-
outputShapes: [b.aval.shape],
|
|
3060
|
-
outputDtypes: [b.aval.dtype]
|
|
3061
|
-
}, { unitDiagonal });
|
|
3062
|
-
return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
|
|
3063
|
-
},
|
|
3064
|
-
[Primitive.Cholesky]([a]) {
|
|
3065
|
-
const routine = new require_backend.Routine(require_backend.Routines.Cholesky, {
|
|
3066
|
-
inputShapes: [a.aval.shape],
|
|
3067
|
-
inputDtypes: [a.aval.dtype],
|
|
3068
|
-
outputShapes: [a.aval.shape],
|
|
3069
|
-
outputDtypes: [a.aval.dtype]
|
|
3070
|
-
});
|
|
3071
|
-
return Array$1.#routine(routine, [a], [a.#weakType]);
|
|
3072
|
-
},
|
|
3202
|
+
[Primitive.Sort]: Array$1.#routine(Primitive.Sort),
|
|
3203
|
+
[Primitive.Argsort]: Array$1.#routine(Primitive.Argsort),
|
|
3204
|
+
[Primitive.TriangularSolve]: Array$1.#routine(Primitive.TriangularSolve),
|
|
3205
|
+
[Primitive.Cholesky]: Array$1.#routine(Primitive.Cholesky),
|
|
3206
|
+
[Primitive.LU]: Array$1.#routine(Primitive.LU),
|
|
3073
3207
|
[Primitive.Jit](args, { jaxpr }) {
|
|
3074
3208
|
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
3075
3209
|
const { backend, committed } = Array$1.#computeBackend("jit", args);
|
|
@@ -3151,7 +3285,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
3151
3285
|
if (!shape$1) {
|
|
3152
3286
|
shape$1 = [];
|
|
3153
3287
|
let cur = values;
|
|
3154
|
-
while (JsArray.isArray(cur)) {
|
|
3288
|
+
while (JsArray$1.isArray(cur)) {
|
|
3155
3289
|
shape$1.push(cur.length);
|
|
3156
3290
|
cur = cur[0];
|
|
3157
3291
|
}
|
|
@@ -3175,7 +3309,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
3175
3309
|
device
|
|
3176
3310
|
});
|
|
3177
3311
|
} else {
|
|
3178
|
-
const weakType = dtype == void 0;
|
|
3312
|
+
const weakType = dtype == void 0 && shape$1.length === 0;
|
|
3179
3313
|
dtype = dtype ?? require_backend.DType.Float32;
|
|
3180
3314
|
const data = require_backend.dtypedJsArray(dtype, flat);
|
|
3181
3315
|
return arrayFromData(data, shape$1, {
|
|
@@ -3289,7 +3423,7 @@ function ones(shape$1, { dtype, device } = {}) {
|
|
|
3289
3423
|
}
|
|
3290
3424
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
3291
3425
|
function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
3292
|
-
let weakType = dtype == void 0;
|
|
3426
|
+
let weakType = dtype == void 0 && shape$1.length === 0;
|
|
3293
3427
|
if (typeof fillValue === "number") dtype = dtype ?? require_backend.DType.Float32;
|
|
3294
3428
|
else if (typeof fillValue === "boolean") {
|
|
3295
3429
|
dtype = dtype ?? require_backend.DType.Bool;
|
|
@@ -3447,6 +3581,27 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
3447
3581
|
committed: device != void 0
|
|
3448
3582
|
});
|
|
3449
3583
|
}
|
|
3584
|
+
/**
|
|
3585
|
+
* Return numbers spaced evenly on a log scale.
|
|
3586
|
+
*
|
|
3587
|
+
* In linear space, the sequence starts at `base ** start` and ends at
|
|
3588
|
+
* `base ** stop` (see `endpoint` below).
|
|
3589
|
+
*
|
|
3590
|
+
* @param start - `base ** start` is the starting value of the sequence.
|
|
3591
|
+
* @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
|
|
3592
|
+
* @param num - Number of samples to generate. Default is 50.
|
|
3593
|
+
* @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
|
|
3594
|
+
* @param base - The base of the log space. Default is 10.
|
|
3595
|
+
* @returns Array of evenly spaced values on a log scale.
|
|
3596
|
+
*/
|
|
3597
|
+
function logspace(start, stop, num = 50, endpoint = true, base = 10, { dtype, device } = {}) {
|
|
3598
|
+
const y = linspace(start, stop, num, endpoint, {
|
|
3599
|
+
dtype,
|
|
3600
|
+
device
|
|
3601
|
+
});
|
|
3602
|
+
const logBase = Math.log(base);
|
|
3603
|
+
return exp$1(mul(y, logBase));
|
|
3604
|
+
}
|
|
3450
3605
|
function aluCompare(a, b, op) {
|
|
3451
3606
|
switch (op) {
|
|
3452
3607
|
case CompareOp.Less: return require_backend.AluExp.cmplt(a, b);
|
|
@@ -3524,6 +3679,7 @@ var BatchTrace = class extends Trace {
|
|
|
3524
3679
|
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3525
3680
|
}
|
|
3526
3681
|
const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
|
|
3682
|
+
if (valOuts.length !== bdimOuts.length) throw new Error(`vmap rule for ${primitive} returned mismatched lengths: ${valOuts.length} vs ${bdimOuts.length}`);
|
|
3527
3683
|
return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
|
|
3528
3684
|
}
|
|
3529
3685
|
get axisSize() {
|
|
@@ -3535,13 +3691,13 @@ var BatchTrace = class extends Trace {
|
|
|
3535
3691
|
*
|
|
3536
3692
|
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3537
3693
|
*/
|
|
3538
|
-
function broadcastBatcher(
|
|
3539
|
-
return (axisSize, args, dims) => {
|
|
3694
|
+
function broadcastBatcher(prim) {
|
|
3695
|
+
return (axisSize, args, dims, params) => {
|
|
3540
3696
|
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3541
3697
|
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3542
3698
|
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3543
3699
|
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3544
|
-
if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[
|
|
3700
|
+
if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[bind1(prim, args, params)], [nd + firstBdim]];
|
|
3545
3701
|
args = args.map((x, i) => {
|
|
3546
3702
|
if (dims[i] === null) return x;
|
|
3547
3703
|
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
@@ -3552,37 +3708,45 @@ function broadcastBatcher(op) {
|
|
|
3552
3708
|
]);
|
|
3553
3709
|
return x;
|
|
3554
3710
|
});
|
|
3555
|
-
return [[
|
|
3711
|
+
return [[bind1(prim, args, params)], [0]];
|
|
3556
3712
|
};
|
|
3557
3713
|
}
|
|
3558
|
-
function unopBatcher(
|
|
3714
|
+
function unopBatcher(prim) {
|
|
3559
3715
|
return (axisSize, [x], [xBdim], params) => {
|
|
3560
|
-
return [[
|
|
3716
|
+
return [[bind1(prim, [x], params)], [xBdim]];
|
|
3717
|
+
};
|
|
3718
|
+
}
|
|
3719
|
+
function lastDimsBatcher(prim, inputDims, numOutputs = 1) {
|
|
3720
|
+
return (axisSize, [x], [xBdim], params) => {
|
|
3721
|
+
require_backend.assertNonNull(xBdim);
|
|
3722
|
+
if (xBdim < x.ndim - inputDims) return [bind(prim, [x], params), require_backend.rep(numOutputs, xBdim)];
|
|
3723
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3724
|
+
return [bind(prim, [x], params), require_backend.rep(numOutputs, 0)];
|
|
3561
3725
|
};
|
|
3562
3726
|
}
|
|
3563
3727
|
const vmapRules = {
|
|
3564
|
-
[Primitive.Add]: broadcastBatcher(
|
|
3565
|
-
[Primitive.Mul]: broadcastBatcher(
|
|
3566
|
-
[Primitive.Idiv]: broadcastBatcher(
|
|
3567
|
-
[Primitive.Mod]: broadcastBatcher(
|
|
3568
|
-
[Primitive.Min]: broadcastBatcher(
|
|
3569
|
-
[Primitive.Max]: broadcastBatcher(
|
|
3570
|
-
[Primitive.Neg]: unopBatcher(
|
|
3571
|
-
[Primitive.Reciprocal]: unopBatcher(
|
|
3572
|
-
[Primitive.Floor]: unopBatcher(
|
|
3573
|
-
[Primitive.Ceil]: unopBatcher(
|
|
3574
|
-
[Primitive.StopGradient]: unopBatcher(
|
|
3575
|
-
[Primitive.Cast]: unopBatcher(
|
|
3576
|
-
[Primitive.Bitcast]: unopBatcher(
|
|
3577
|
-
[Primitive.Sin]: unopBatcher(
|
|
3578
|
-
[Primitive.Cos]: unopBatcher(
|
|
3579
|
-
[Primitive.Asin]: unopBatcher(
|
|
3580
|
-
[Primitive.Atan]: unopBatcher(
|
|
3581
|
-
[Primitive.Exp]: unopBatcher(
|
|
3582
|
-
[Primitive.Log]: unopBatcher(
|
|
3583
|
-
[Primitive.Erf]: unopBatcher(
|
|
3584
|
-
[Primitive.Erfc]: unopBatcher(
|
|
3585
|
-
[Primitive.Sqrt]: unopBatcher(
|
|
3728
|
+
[Primitive.Add]: broadcastBatcher(Primitive.Add),
|
|
3729
|
+
[Primitive.Mul]: broadcastBatcher(Primitive.Mul),
|
|
3730
|
+
[Primitive.Idiv]: broadcastBatcher(Primitive.Idiv),
|
|
3731
|
+
[Primitive.Mod]: broadcastBatcher(Primitive.Mod),
|
|
3732
|
+
[Primitive.Min]: broadcastBatcher(Primitive.Min),
|
|
3733
|
+
[Primitive.Max]: broadcastBatcher(Primitive.Max),
|
|
3734
|
+
[Primitive.Neg]: unopBatcher(Primitive.Neg),
|
|
3735
|
+
[Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
|
|
3736
|
+
[Primitive.Floor]: unopBatcher(Primitive.Floor),
|
|
3737
|
+
[Primitive.Ceil]: unopBatcher(Primitive.Ceil),
|
|
3738
|
+
[Primitive.StopGradient]: unopBatcher(Primitive.StopGradient),
|
|
3739
|
+
[Primitive.Cast]: unopBatcher(Primitive.Cast),
|
|
3740
|
+
[Primitive.Bitcast]: unopBatcher(Primitive.Bitcast),
|
|
3741
|
+
[Primitive.Sin]: unopBatcher(Primitive.Sin),
|
|
3742
|
+
[Primitive.Cos]: unopBatcher(Primitive.Cos),
|
|
3743
|
+
[Primitive.Asin]: unopBatcher(Primitive.Asin),
|
|
3744
|
+
[Primitive.Atan]: unopBatcher(Primitive.Atan),
|
|
3745
|
+
[Primitive.Exp]: unopBatcher(Primitive.Exp),
|
|
3746
|
+
[Primitive.Log]: unopBatcher(Primitive.Log),
|
|
3747
|
+
[Primitive.Erf]: unopBatcher(Primitive.Erf),
|
|
3748
|
+
[Primitive.Erfc]: unopBatcher(Primitive.Erfc),
|
|
3749
|
+
[Primitive.Sqrt]: unopBatcher(Primitive.Sqrt),
|
|
3586
3750
|
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3587
3751
|
require_backend.assertNonNull(xBdim);
|
|
3588
3752
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
@@ -3604,10 +3768,25 @@ const vmapRules = {
|
|
|
3604
3768
|
});
|
|
3605
3769
|
return [[z], [0]];
|
|
3606
3770
|
},
|
|
3607
|
-
[Primitive.Compare](
|
|
3608
|
-
|
|
3771
|
+
[Primitive.Compare]: broadcastBatcher(Primitive.Compare),
|
|
3772
|
+
[Primitive.Where]: broadcastBatcher(Primitive.Where),
|
|
3773
|
+
[Primitive.Concatenate](axisSize, xs, xBdims, { axis }) {
|
|
3774
|
+
const minBdim = Math.min(...xBdims.filter((d) => d !== null));
|
|
3775
|
+
xs = xs.map((x, i) => moveBatchAxis(axisSize, xBdims[i], minBdim, x));
|
|
3776
|
+
const newAxis = axis + (minBdim <= axis ? 1 : 0);
|
|
3777
|
+
return [[concatenate$1(xs, newAxis)], [minBdim]];
|
|
3778
|
+
},
|
|
3779
|
+
[Primitive.Split](axisSize, [x], [xBdim], { axis, sizes }) {
|
|
3780
|
+
require_backend.assertNonNull(xBdim);
|
|
3781
|
+
const newAxis = axis + (xBdim <= axis ? 1 : 0);
|
|
3782
|
+
const outs = split$2(x, newAxis, sizes);
|
|
3783
|
+
return [outs, require_backend.rep(outs.length, xBdim)];
|
|
3784
|
+
},
|
|
3785
|
+
[Primitive.RandomBits](axisSize, [k0, k1], [bdim0, bdim1], { shape: shape$1, mode }) {
|
|
3786
|
+
k0 = moveBatchAxis(axisSize, bdim0, 0, k0);
|
|
3787
|
+
k1 = moveBatchAxis(axisSize, bdim1, 0, k1);
|
|
3788
|
+
return [[randomBits(k0, k1, [axisSize, ...shape$1], mode)], [0]];
|
|
3609
3789
|
},
|
|
3610
|
-
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3611
3790
|
[Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
|
|
3612
3791
|
if (indicesBdim.every((d) => d === null)) {
|
|
3613
3792
|
require_backend.assertNonNull(xBdim);
|
|
@@ -3669,18 +3848,8 @@ const vmapRules = {
|
|
|
3669
3848
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3670
3849
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3671
3850
|
},
|
|
3672
|
-
[Primitive.Sort](
|
|
3673
|
-
|
|
3674
|
-
if (xBdim !== x.ndim - 1) return [[sort$1(x)], [xBdim]];
|
|
3675
|
-
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3676
|
-
return [[sort$1(x)], [0]];
|
|
3677
|
-
},
|
|
3678
|
-
[Primitive.Argsort](axisSize, [x], [xBdim]) {
|
|
3679
|
-
require_backend.assertNonNull(xBdim);
|
|
3680
|
-
if (xBdim !== x.ndim - 1) return [argsort$1(x), [xBdim, xBdim]];
|
|
3681
|
-
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3682
|
-
return [argsort$1(x), [0, 0]];
|
|
3683
|
-
},
|
|
3851
|
+
[Primitive.Sort]: lastDimsBatcher(Primitive.Sort, 1),
|
|
3852
|
+
[Primitive.Argsort]: lastDimsBatcher(Primitive.Argsort, 1, 2),
|
|
3684
3853
|
[Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
|
|
3685
3854
|
if (aBdim === null) {
|
|
3686
3855
|
b = moveBatchAxis(axisSize, bBdim, -3, b);
|
|
@@ -3704,12 +3873,8 @@ const vmapRules = {
|
|
|
3704
3873
|
const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
3705
3874
|
return [[x], [0]];
|
|
3706
3875
|
},
|
|
3707
|
-
[Primitive.Cholesky](
|
|
3708
|
-
|
|
3709
|
-
if (xBdim < x.ndim - 2) return [[cholesky$2(x)], [xBdim]];
|
|
3710
|
-
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3711
|
-
return [[cholesky$2(x)], [0]];
|
|
3712
|
-
},
|
|
3876
|
+
[Primitive.Cholesky]: lastDimsBatcher(Primitive.Cholesky, 2),
|
|
3877
|
+
[Primitive.LU]: lastDimsBatcher(Primitive.LU, 2, 3),
|
|
3713
3878
|
[Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
|
|
3714
3879
|
const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3715
3880
|
const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
|
|
@@ -3860,6 +4025,16 @@ function batchMatmulT(a, b) {
|
|
|
3860
4025
|
function mT(a) {
|
|
3861
4026
|
return moveaxis(a, -2, -1);
|
|
3862
4027
|
}
|
|
4028
|
+
function sliceAxis(a, axis, p) {
|
|
4029
|
+
const slices = Array(a.shape.length).fill([]);
|
|
4030
|
+
slices[require_backend.checkAxis(axis, a.ndim)] = p;
|
|
4031
|
+
return a.slice(...slices);
|
|
4032
|
+
}
|
|
4033
|
+
function padAxis(a, axis, p) {
|
|
4034
|
+
const pads = Array(a.shape.length).fill([0, 0]);
|
|
4035
|
+
pads[require_backend.checkAxis(axis, a.ndim)] = p;
|
|
4036
|
+
return pad$1(a, pads);
|
|
4037
|
+
}
|
|
3863
4038
|
const jvpRules = {
|
|
3864
4039
|
[Primitive.Add]: linearTangentsJvp(Primitive.Add),
|
|
3865
4040
|
[Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
|
|
@@ -3958,6 +4133,8 @@ const jvpRules = {
|
|
|
3958
4133
|
dcond.dispose();
|
|
3959
4134
|
return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
|
|
3960
4135
|
},
|
|
4136
|
+
[Primitive.Concatenate]: linearTangentsJvp(Primitive.Concatenate),
|
|
4137
|
+
[Primitive.Split]: linearTangentsJvp(Primitive.Split),
|
|
3961
4138
|
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
3962
4139
|
[Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
|
|
3963
4140
|
const indicesRef = indices.map((t) => t.ref);
|
|
@@ -3992,6 +4169,38 @@ const jvpRules = {
|
|
|
3992
4169
|
const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
|
|
3993
4170
|
return [[L], [dL]];
|
|
3994
4171
|
},
|
|
4172
|
+
[Primitive.LU]([a], [da]) {
|
|
4173
|
+
const [luMatrix, pivots, permutation] = lu$1(a);
|
|
4174
|
+
const [m, n] = a.shape.slice(-2);
|
|
4175
|
+
const k = Math.min(m, n);
|
|
4176
|
+
const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
|
|
4177
|
+
const lLower = tril(luSliceL, -1);
|
|
4178
|
+
const lPadded = m > k ? padAxis(lLower, -1, [0, m - k]) : lLower;
|
|
4179
|
+
const L = lPadded.add(eye(m));
|
|
4180
|
+
const luSliceU = sliceAxis(luMatrix.ref, -2, [0, k]);
|
|
4181
|
+
const uUpper = triu(luSliceU);
|
|
4182
|
+
const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
|
|
4183
|
+
const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
|
|
4184
|
+
const U = uPadded.add(uEye);
|
|
4185
|
+
const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
|
|
4186
|
+
const pda = batchMatmulT(P, mT(da));
|
|
4187
|
+
const la = mT(triangularSolve$1(L.ref, mT(pda), {
|
|
4188
|
+
lower: true,
|
|
4189
|
+
unitDiagonal: true
|
|
4190
|
+
}));
|
|
4191
|
+
const lau = triangularSolve$1(mT(U.ref), la, { lower: true });
|
|
4192
|
+
const lDot = batchMatmulT(L, mT(tril(lau.ref, -1)));
|
|
4193
|
+
const uDot = batchMatmulT(triu(lau), mT(U));
|
|
4194
|
+
return [[
|
|
4195
|
+
luMatrix,
|
|
4196
|
+
pivots,
|
|
4197
|
+
permutation
|
|
4198
|
+
], [
|
|
4199
|
+
lDot.add(uDot),
|
|
4200
|
+
zerosLike$1(pivots.ref),
|
|
4201
|
+
zerosLike$1(permutation.ref)
|
|
4202
|
+
]];
|
|
4203
|
+
},
|
|
3995
4204
|
[Primitive.Jit](primals, tangents, { name, jaxpr }) {
|
|
3996
4205
|
const newJaxpr = jvpJaxpr(jaxpr);
|
|
3997
4206
|
const outs = bind(Primitive.Jit, [
|
|
@@ -4032,17 +4241,39 @@ function jvpFlat(f, primals, tangents) {
|
|
|
4032
4241
|
_usingCtx$1.d();
|
|
4033
4242
|
}
|
|
4034
4243
|
}
|
|
4035
|
-
function jvp$1(f, primals, tangents) {
|
|
4244
|
+
function jvp$1(f, primals, tangents, { hasAux = false } = {}) {
|
|
4036
4245
|
const [primalsFlat, inTree] = flatten(primals);
|
|
4037
4246
|
const [tangentsFlat, inTree2] = flatten(tangents);
|
|
4038
4247
|
if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
|
|
4039
|
-
|
|
4248
|
+
let flatFun, outTree, aux;
|
|
4249
|
+
if (hasAux) [flatFun, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4250
|
+
else [flatFun, outTree] = flattenFun(f, inTree);
|
|
4040
4251
|
const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
|
|
4041
4252
|
if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
|
|
4042
4253
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
4043
4254
|
const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
|
|
4255
|
+
if (hasAux) return [
|
|
4256
|
+
primalsOut,
|
|
4257
|
+
tangentsOut,
|
|
4258
|
+
lowerAux(aux.value)
|
|
4259
|
+
];
|
|
4044
4260
|
return [primalsOut, tangentsOut];
|
|
4045
4261
|
}
|
|
4262
|
+
/** Lowering for auxiliary data returned in `hasAux: true` methods. */
|
|
4263
|
+
function lowerAux(aux) {
|
|
4264
|
+
const level = currentTraceLevel();
|
|
4265
|
+
return map((x) => {
|
|
4266
|
+
if (x instanceof Tracer) while (x._trace.main.level > level) if (x instanceof JVPTracer) {
|
|
4267
|
+
x.tangent.dispose();
|
|
4268
|
+
x = x.primal;
|
|
4269
|
+
} else {
|
|
4270
|
+
const y = x.fullLower();
|
|
4271
|
+
if (y._trace.main.level >= x._trace.main.level) throw new Error("internal: lowerAux did not reduce trace level");
|
|
4272
|
+
x = y;
|
|
4273
|
+
}
|
|
4274
|
+
return x;
|
|
4275
|
+
}, aux);
|
|
4276
|
+
}
|
|
4046
4277
|
|
|
4047
4278
|
//#endregion
|
|
4048
4279
|
//#region src/frontend/linearize.ts
|
|
@@ -4113,9 +4344,11 @@ function linearizeFlat(f, primalsIn) {
|
|
|
4113
4344
|
dispose$1
|
|
4114
4345
|
];
|
|
4115
4346
|
}
|
|
4116
|
-
function linearize$1(f,
|
|
4347
|
+
function linearize$1(f, primalsIn, { hasAux = false } = {}) {
|
|
4117
4348
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4118
|
-
|
|
4349
|
+
let fFlat, outTree, aux;
|
|
4350
|
+
if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4351
|
+
else [fFlat, outTree] = flattenFun(f, inTree);
|
|
4119
4352
|
const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4120
4353
|
if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
|
|
4121
4354
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
@@ -4126,6 +4359,11 @@ function linearize$1(f, ...primalsIn) {
|
|
|
4126
4359
|
return unflatten(outTree.value, tangentsOutFlat);
|
|
4127
4360
|
});
|
|
4128
4361
|
fLin.dispose = dispose$1;
|
|
4362
|
+
if (hasAux) return [
|
|
4363
|
+
primalsOut,
|
|
4364
|
+
fLin,
|
|
4365
|
+
lowerAux(aux.value)
|
|
4366
|
+
];
|
|
4129
4367
|
return [primalsOut, fLin];
|
|
4130
4368
|
}
|
|
4131
4369
|
var PartialEvalTracer = class extends Tracer {
|
|
@@ -4529,6 +4767,15 @@ const transposeRules = {
|
|
|
4529
4767
|
cond.dispose();
|
|
4530
4768
|
return cts;
|
|
4531
4769
|
},
|
|
4770
|
+
[Primitive.Concatenate]([ct], inputs, { axis }) {
|
|
4771
|
+
if (inputs.some((x) => !(x instanceof UndefPrimal))) throw new NonlinearError(Primitive.Concatenate);
|
|
4772
|
+
const sizes = inputs.map((x) => x.aval.shape[axis]);
|
|
4773
|
+
return split$2(ct, axis, sizes);
|
|
4774
|
+
},
|
|
4775
|
+
[Primitive.Split](cts, [x], { axis }) {
|
|
4776
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Split);
|
|
4777
|
+
return [concatenate$1(cts, axis)];
|
|
4778
|
+
},
|
|
4532
4779
|
[Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
|
|
4533
4780
|
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
4534
4781
|
if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
@@ -4617,9 +4864,11 @@ function vjpFlat(f, primalsIn) {
|
|
|
4617
4864
|
dispose$1
|
|
4618
4865
|
];
|
|
4619
4866
|
}
|
|
4620
|
-
function vjp$1(f,
|
|
4867
|
+
function vjp$1(f, primalsIn, { hasAux = false } = {}) {
|
|
4621
4868
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4622
|
-
|
|
4869
|
+
let fFlat, outTree, aux;
|
|
4870
|
+
if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4871
|
+
else [fFlat, outTree] = flattenFun(f, inTree);
|
|
4623
4872
|
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4624
4873
|
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
4625
4874
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
@@ -4630,26 +4879,43 @@ function vjp$1(f, ...primalsIn) {
|
|
|
4630
4879
|
return unflatten(inTree, cotangentsInFlat);
|
|
4631
4880
|
});
|
|
4632
4881
|
fVjp.dispose = dispose$1;
|
|
4882
|
+
if (hasAux) return [
|
|
4883
|
+
primalsOut,
|
|
4884
|
+
fVjp,
|
|
4885
|
+
lowerAux(aux.value)
|
|
4886
|
+
];
|
|
4633
4887
|
return [primalsOut, fVjp];
|
|
4634
4888
|
}
|
|
4635
|
-
function grad$1(f) {
|
|
4636
|
-
const valueAndGradFn = valueAndGrad$1(f);
|
|
4889
|
+
function grad$1(f, opts) {
|
|
4890
|
+
const valueAndGradFn = valueAndGrad$1(f, opts);
|
|
4637
4891
|
return (...x) => {
|
|
4638
|
-
|
|
4639
|
-
|
|
4640
|
-
|
|
4892
|
+
if (opts?.hasAux) {
|
|
4893
|
+
const [[y, aux], dx] = valueAndGradFn(...x);
|
|
4894
|
+
y.dispose();
|
|
4895
|
+
return [dx, aux];
|
|
4896
|
+
} else {
|
|
4897
|
+
const [y, dx] = valueAndGradFn(...x);
|
|
4898
|
+
y.dispose();
|
|
4899
|
+
return dx;
|
|
4900
|
+
}
|
|
4641
4901
|
};
|
|
4642
4902
|
}
|
|
4643
|
-
function valueAndGrad$1(f) {
|
|
4903
|
+
function valueAndGrad$1(f, opts) {
|
|
4904
|
+
const argnums = opts?.argnums ?? 0;
|
|
4905
|
+
const hasAux = opts?.hasAux ?? false;
|
|
4906
|
+
require_backend.checkInts(argnums);
|
|
4907
|
+
const argnumsSet = new Set(typeof argnums === "number" ? [argnums] : argnums);
|
|
4644
4908
|
return (...x) => {
|
|
4645
4909
|
if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
|
|
4646
|
-
|
|
4910
|
+
for (let i = 0; i < x.length; i++) if (!argnumsSet.has(i)) x[i] = map(stopGradient, x[i]);
|
|
4911
|
+
const [y, fVjp, aux] = vjp$1(f, x, { hasAux });
|
|
4647
4912
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
4648
4913
|
if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
4649
|
-
const
|
|
4650
|
-
for (const r of rest) dispose(r);
|
|
4914
|
+
const cts = fVjp(onesLike$1(y.ref));
|
|
4651
4915
|
fVjp.dispose();
|
|
4652
|
-
|
|
4916
|
+
for (let i = 0; i < cts.length; i++) if (!argnumsSet.has(i)) dispose(cts[i]);
|
|
4917
|
+
const grads = typeof argnums === "number" ? cts[argnums] : argnums.map((i) => cts[i]);
|
|
4918
|
+
return hasAux ? [[y, aux], grads] : [y, grads];
|
|
4653
4919
|
};
|
|
4654
4920
|
}
|
|
4655
4921
|
function jacrev$1(f) {
|
|
@@ -4657,7 +4923,7 @@ function jacrev$1(f) {
|
|
|
4657
4923
|
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
4658
4924
|
const [size$1] = x.shape;
|
|
4659
4925
|
const pullback = (ct) => {
|
|
4660
|
-
const [y, fVjp] = vjp$1(f, x);
|
|
4926
|
+
const [y, fVjp] = vjp$1(f, [x]);
|
|
4661
4927
|
y.dispose();
|
|
4662
4928
|
const [ret] = fVjp(ct);
|
|
4663
4929
|
fVjp.dispose();
|
|
@@ -4666,6 +4932,9 @@ function jacrev$1(f) {
|
|
|
4666
4932
|
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
4667
4933
|
};
|
|
4668
4934
|
}
|
|
4935
|
+
function hessian$1(f) {
|
|
4936
|
+
return jacfwd$1(grad$1(f));
|
|
4937
|
+
}
|
|
4669
4938
|
|
|
4670
4939
|
//#endregion
|
|
4671
4940
|
//#region src/library/numpy/einsum.ts
|
|
@@ -4804,8 +5073,8 @@ function computeSizeMap({ shapes, lhsIndices, rhsIndex }) {
|
|
|
4804
5073
|
const idx = lhsIndex[j];
|
|
4805
5074
|
const dim = shape$1[j];
|
|
4806
5075
|
const existing = sizeMap.get(idx);
|
|
4807
|
-
if (existing === void 0) sizeMap.set(idx, dim);
|
|
4808
|
-
else if (existing !== dim) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
|
|
5076
|
+
if (existing === void 0 || existing === 1) sizeMap.set(idx, dim);
|
|
5077
|
+
else if (existing !== dim && dim !== 1) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
|
|
4809
5078
|
}
|
|
4810
5079
|
}
|
|
4811
5080
|
for (const [idx, size$1] of sizeMap) if (!Number.isInteger(idx) || idx < 0) throw new Error(`Invalid index ${idx} in einsum expression, must be non-negative integer`);
|
|
@@ -4961,27 +5230,53 @@ function ifft(a, axis = -1) {
|
|
|
4961
5230
|
//#region src/library/numpy-linalg.ts
|
|
4962
5231
|
var numpy_linalg_exports = {};
|
|
4963
5232
|
__export(numpy_linalg_exports, {
|
|
4964
|
-
cholesky: () => cholesky
|
|
5233
|
+
cholesky: () => cholesky,
|
|
5234
|
+
det: () => det,
|
|
4965
5235
|
diagonal: () => diagonal,
|
|
5236
|
+
inv: () => inv,
|
|
4966
5237
|
lstsq: () => lstsq,
|
|
4967
5238
|
matmul: () => matmul,
|
|
5239
|
+
matrixPower: () => matrixPower,
|
|
4968
5240
|
matrixTranspose: () => matrixTranspose,
|
|
4969
5241
|
outer: () => outer,
|
|
5242
|
+
slogdet: () => slogdet,
|
|
5243
|
+
solve: () => solve,
|
|
4970
5244
|
tensordot: () => tensordot,
|
|
4971
5245
|
trace: () => trace,
|
|
4972
5246
|
vecdot: () => vecdot
|
|
4973
5247
|
});
|
|
5248
|
+
function checkSquare(name, a) {
|
|
5249
|
+
if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
|
|
5250
|
+
return a.shape[a.ndim - 1];
|
|
5251
|
+
}
|
|
4974
5252
|
/**
|
|
4975
5253
|
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
4976
5254
|
*
|
|
4977
5255
|
* This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
|
|
4978
5256
|
* the input matrix, which is on by default.
|
|
4979
5257
|
*/
|
|
4980
|
-
function cholesky
|
|
5258
|
+
function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
|
|
4981
5259
|
a = fudgeArray(a);
|
|
4982
|
-
|
|
5260
|
+
checkSquare("cholesky", a);
|
|
4983
5261
|
if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
|
|
4984
|
-
return cholesky(a, { upper });
|
|
5262
|
+
return cholesky$1(a, { upper });
|
|
5263
|
+
}
|
|
5264
|
+
/** Compute the determinant of a square matrix (batched). */
|
|
5265
|
+
function det(a) {
|
|
5266
|
+
a = fudgeArray(a);
|
|
5267
|
+
const n = checkSquare("det", a);
|
|
5268
|
+
const [lu$2, pivots, permutation] = lu(a);
|
|
5269
|
+
permutation.dispose();
|
|
5270
|
+
const parity = pivots.notEqual(arange(n)).astype(int32).sum(-1).mod(2);
|
|
5271
|
+
const sign$1 = parity.mul(-2).add(1);
|
|
5272
|
+
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
5273
|
+
return prod$1(diag$1, -1).mul(sign$1);
|
|
5274
|
+
}
|
|
5275
|
+
/** Compute the inverse of a square matrix (batched). */
|
|
5276
|
+
function inv(a) {
|
|
5277
|
+
a = fudgeArray(a);
|
|
5278
|
+
const n = checkSquare("inv", a);
|
|
5279
|
+
return solve(a, eye(n));
|
|
4985
5280
|
}
|
|
4986
5281
|
/**
|
|
4987
5282
|
* Return the least-squares solution to a linear equation.
|
|
@@ -5005,7 +5300,7 @@ function lstsq(a, b) {
|
|
|
5005
5300
|
const at = matrixTranspose(a.ref);
|
|
5006
5301
|
if (m <= n) {
|
|
5007
5302
|
const aat = matmul(a, at.ref);
|
|
5008
|
-
const l = cholesky
|
|
5303
|
+
const l = cholesky(aat, { symmetrizeInput: false });
|
|
5009
5304
|
const lb = triangularSolve(l.ref, b, {
|
|
5010
5305
|
leftSide: true,
|
|
5011
5306
|
lower: true
|
|
@@ -5017,7 +5312,7 @@ function lstsq(a, b) {
|
|
|
5017
5312
|
return matmul(at, llb.ref);
|
|
5018
5313
|
} else {
|
|
5019
5314
|
const ata = matmul(at.ref, a);
|
|
5020
|
-
const l = cholesky
|
|
5315
|
+
const l = cholesky(ata, { symmetrizeInput: false });
|
|
5021
5316
|
const atb = matmul(at, b);
|
|
5022
5317
|
const lb = triangularSolve(l.ref, atb, {
|
|
5023
5318
|
leftSide: true,
|
|
@@ -5030,6 +5325,169 @@ function lstsq(a, b) {
|
|
|
5030
5325
|
return llb;
|
|
5031
5326
|
}
|
|
5032
5327
|
}
|
|
5328
|
+
/** Raise a square matrix to an integer power, via repeated squarings. */
|
|
5329
|
+
function matrixPower(a, n) {
|
|
5330
|
+
if (!Number.isInteger(n)) throw new Error(`matrixPower: exponent must be an integer, got ${n}`);
|
|
5331
|
+
a = fudgeArray(a);
|
|
5332
|
+
const m = checkSquare("matrixPower", a);
|
|
5333
|
+
if (n === 0) {
|
|
5334
|
+
a.dispose();
|
|
5335
|
+
return broadcastTo(eye(m), a.shape);
|
|
5336
|
+
}
|
|
5337
|
+
if (n < 0) {
|
|
5338
|
+
a = inv(a);
|
|
5339
|
+
n = -n;
|
|
5340
|
+
}
|
|
5341
|
+
let result = null;
|
|
5342
|
+
let a2k = a;
|
|
5343
|
+
for (let k = 0; n; k++) {
|
|
5344
|
+
if (k > 0) a2k = matmul(a2k.ref, a2k);
|
|
5345
|
+
if (n % 2 === 1) result = result === null ? a2k.ref : matmul(result, a2k.ref);
|
|
5346
|
+
n = Math.floor(n / 2);
|
|
5347
|
+
}
|
|
5348
|
+
a2k.dispose();
|
|
5349
|
+
return result;
|
|
5350
|
+
}
|
|
5351
|
+
/** Return sign and natural logarithm of the determinant of `a`. */
|
|
5352
|
+
function slogdet(a) {
|
|
5353
|
+
a = fudgeArray(a);
|
|
5354
|
+
const n = checkSquare("slogdet", a);
|
|
5355
|
+
const [lu$2, pivots, permutation] = lu(a);
|
|
5356
|
+
permutation.dispose();
|
|
5357
|
+
let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
|
|
5358
|
+
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
5359
|
+
parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
|
|
5360
|
+
const logabsdet = log(absolute(diag$1)).sum(-1);
|
|
5361
|
+
const sign$1 = parity.mul(-2).add(1);
|
|
5362
|
+
return [sign$1, logabsdet];
|
|
5363
|
+
}
|
|
5364
|
+
/**
 * Solve a linear system of equations.
 *
 * This solves a (batched) linear system of equations `a @ x = b` for `x` given
 * `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
 *
 * Internally this factors `a` with a pivoted LU decomposition, solves
 * `L @ y = P @ b` by forward substitution (unit-diagonal `L`), and then solves
 * `U @ x = y` by back substitution.
 *
 * @param a - Coefficient matrix of shape `(..., N, N)`.
 * @param b - Values of shape `(N,)` or `(..., N, M)`.
 * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
 */
function solve(a, b) {
	a = fudgeArray(a);
	b = fudgeArray(b);
	const n = checkSquare("solve", a);
	if (b.ndim === 0) throw new Error(`solve: b cannot be scalar`);
	const bIs1d = b.ndim === 1;
	// A vector right-hand side is treated as a single column and squeezed back at the end.
	if (bIs1d) b = b.reshape([...b.shape, 1]);
	if (b.shape[b.ndim - 2] !== n) throw new Error(`solve: leading dimension of b must match size of a, got a=${a.aval}, b=${b.aval}`);
	const m = b.shape[b.ndim - 1];
	// Broadcast the leading (batch) dimensions of `a` and `b` against each other.
	const batchDims = require_backend.generalBroadcast(a.shape.slice(0, -2), b.shape.slice(0, -2));
	a = broadcastTo(a, [
		...batchDims,
		n,
		n
	]);
	b = broadcastTo(b, [
		...batchDims,
		n,
		m
	]);
	const [lu$2, pivots, permutation] = lu(a);
	pivots.dispose();
	// One-hot permutation matrix: P[..., i, j] = 1 iff j === permutation[..., i],
	// so that (P @ b)[i] = b[permutation[i]].
	const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
	// Forward substitution with the unit-diagonal lower factor.
	const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
		leftSide: true,
		lower: true,
		unitDiagonal: true
	});
	// Back substitution with the upper factor. `LPb` is not used afterwards, so
	// its reference is handed over directly; taking an extra `.ref` here would
	// leave one reference that is never released.
	let x = triangularSolve(lu$2, LPb, {
		leftSide: true,
		lower: false
	});
	if (bIs1d) x = squeeze(x, -1);
	return x;
}
|
|
5409
|
+
|
|
5410
|
+
//#endregion
|
|
5411
|
+
//#region src/library/numpy/dtype-info.ts
|
|
5412
|
+
/**
 * Machine limits for floating-point types, following `numpy.finfo`.
 *
 * As in NumPy, `eps = 2 ** machep` is the gap between 1 and the next larger
 * representable value, and `epsneg = 2 ** negep` is the gap between 1 and the
 * next smaller representable value.
 *
 * @param dtype - A floating-point dtype (float16, float32 or float64).
 * @returns A frozen object describing the limits of `dtype`.
 * @throws Error if `dtype` is not a supported floating-point type.
 */
function finfo(dtype) {
	if (!require_backend.isFloatDtype(dtype)) throw new Error(`finfo: received ${dtype}, must be a floating-point type`);
	switch (dtype) {
		case require_backend.DType.Float16: return Object.freeze({
			bits: 16,
			dtype: require_backend.DType.Float16,
			eps: 2 ** -10,
			epsneg: 2 ** -11,
			machep: -10,
			max: 65504,
			maxexp: 16,
			min: -65504,
			minexp: -14,
			// Exponent of `epsneg` (2 ** -11); numpy.finfo(float16).negep is -11.
			negep: -11,
			nexp: 5,
			nmant: 10,
			precision: 3,
			resolution: .001,
			smallestNormal: 2 ** -14,
			smallestSubnormal: 2 ** -24
		});
		case require_backend.DType.Float32: return Object.freeze({
			bits: 32,
			dtype: require_backend.DType.Float32,
			eps: 2 ** -23,
			epsneg: 2 ** -24,
			machep: -23,
			max: 34028234663852886e22,
			maxexp: 128,
			min: -34028234663852886e22,
			minexp: -126,
			negep: -24,
			nexp: 8,
			nmant: 23,
			precision: 6,
			resolution: 1e-6,
			smallestNormal: 2 ** -126,
			smallestSubnormal: 2 ** -149
		});
		case require_backend.DType.Float64: return Object.freeze({
			bits: 64,
			dtype: require_backend.DType.Float64,
			eps: 2 ** -52,
			epsneg: 2 ** -53,
			machep: -52,
			max: Number.MAX_VALUE,
			maxexp: 1024,
			min: -Number.MAX_VALUE,
			minexp: -1022,
			negep: -53,
			nexp: 11,
			nmant: 52,
			precision: 15,
			resolution: 1e-15,
			smallestNormal: 2 ** -1022,
			smallestSubnormal: 2 ** -1074
		});
		default: throw new Error(`finfo: unsupported dtype ${dtype}`);
	}
}
|
|
5473
|
+
/**
 * Machine limits for integer types.
 *
 * @param dtype - An integer dtype; `int32` and `uint32` are recognized.
 * @returns A frozen `{ bits, dtype, max, min }` record for `dtype`.
 * @throws Error for any other dtype.
 */
function iinfo(dtype) {
	const { DType } = require_backend;
	if (dtype === DType.Int32) return Object.freeze({
		bits: 32,
		dtype: DType.Int32,
		max: 2 ** 31 - 1,
		min: -(2 ** 31)
	});
	if (dtype === DType.Uint32) return Object.freeze({
		bits: 32,
		dtype: DType.Uint32,
		max: 2 ** 32 - 1,
		min: 0
	});
	throw new Error(`iinfo: unsupported dtype ${dtype}`);
}
|
|
5033
5491
|
|
|
5034
5492
|
//#endregion
|
|
5035
5493
|
//#region src/library/numpy.ts
|
|
@@ -5085,6 +5543,7 @@ __export(numpy_exports, {
|
|
|
5085
5543
|
diag: () => diag,
|
|
5086
5544
|
diagonal: () => diagonal,
|
|
5087
5545
|
divide: () => trueDivide,
|
|
5546
|
+
divmod: () => divmod,
|
|
5088
5547
|
dot: () => dot$1,
|
|
5089
5548
|
dstack: () => dstack,
|
|
5090
5549
|
e: () => e,
|
|
@@ -5097,6 +5556,7 @@ __export(numpy_exports, {
|
|
|
5097
5556
|
expm1: () => expm1,
|
|
5098
5557
|
eye: () => eye,
|
|
5099
5558
|
fft: () => numpy_fft_exports,
|
|
5559
|
+
finfo: () => finfo,
|
|
5100
5560
|
flip: () => flip,
|
|
5101
5561
|
fliplr: () => fliplr,
|
|
5102
5562
|
flipud: () => flipud,
|
|
@@ -5104,6 +5564,7 @@ __export(numpy_exports, {
|
|
|
5104
5564
|
float32: () => float32,
|
|
5105
5565
|
float64: () => float64,
|
|
5106
5566
|
floor: () => floor,
|
|
5567
|
+
floorDivide: () => floorDivide,
|
|
5107
5568
|
fmod: () => fmod,
|
|
5108
5569
|
frexp: () => frexp,
|
|
5109
5570
|
full: () => full,
|
|
@@ -5116,6 +5577,7 @@ __export(numpy_exports, {
|
|
|
5116
5577
|
hstack: () => hstack,
|
|
5117
5578
|
hypot: () => hypot,
|
|
5118
5579
|
identity: () => identity$1,
|
|
5580
|
+
iinfo: () => iinfo,
|
|
5119
5581
|
inf: () => inf,
|
|
5120
5582
|
inner: () => inner,
|
|
5121
5583
|
int32: () => int32,
|
|
@@ -5133,6 +5595,7 @@ __export(numpy_exports, {
|
|
|
5133
5595
|
log10: () => log10,
|
|
5134
5596
|
log1p: () => log1p,
|
|
5135
5597
|
log2: () => log2,
|
|
5598
|
+
logspace: () => logspace,
|
|
5136
5599
|
matmul: () => matmul,
|
|
5137
5600
|
matrixTranspose: () => matrixTranspose,
|
|
5138
5601
|
max: () => max,
|
|
@@ -5169,9 +5632,11 @@ __export(numpy_exports, {
|
|
|
5169
5632
|
shape: () => shape,
|
|
5170
5633
|
sign: () => sign,
|
|
5171
5634
|
sin: () => sin,
|
|
5635
|
+
sinc: () => sinc,
|
|
5172
5636
|
sinh: () => sinh,
|
|
5173
5637
|
size: () => size,
|
|
5174
5638
|
sort: () => sort,
|
|
5639
|
+
split: () => split$1,
|
|
5175
5640
|
sqrt: () => sqrt,
|
|
5176
5641
|
square: () => square,
|
|
5177
5642
|
squeeze: () => squeeze,
|
|
@@ -5179,6 +5644,8 @@ __export(numpy_exports, {
|
|
|
5179
5644
|
std: () => std,
|
|
5180
5645
|
subtract: () => subtract,
|
|
5181
5646
|
sum: () => sum,
|
|
5647
|
+
swapaxes: () => swapaxes,
|
|
5648
|
+
take: () => take,
|
|
5182
5649
|
tan: () => tan,
|
|
5183
5650
|
tanh: () => tanh,
|
|
5184
5651
|
tensordot: () => tensordot,
|
|
@@ -5437,6 +5904,45 @@ function flip(x, axis = null) {
|
|
|
5437
5904
|
return flip$1(x, axis);
|
|
5438
5905
|
}
|
|
5439
5906
|
/**
 * Split an array into multiple sub-arrays along an axis.
 *
 * @param a - The input array to split.
 * @param indicesOrSections - If an integer, it indicates the number of equal
 *   sections to create along the specified axis. If a list of integers, it
 *   specifies the indices at which to split the array; an empty list yields a
 *   single piece spanning the whole axis.
 * @param axis - The axis along which to split the array. Default is 0.
 * @returns The sub-arrays, in order along `axis`.
 * @throws Error if a section count is not a positive integer or does not
 *   evenly divide the length of `axis`.
 */
function split$1(a, indicesOrSections, axis = 0) {
	a = fudgeArray(a);
	axis = require_backend.checkAxis(axis, a.ndim);
	const size$1 = a.shape[axis];
	let sizes;
	if (typeof indicesOrSections === "number") {
		const sections = indicesOrSections;
		if (!Number.isInteger(sections) || sections <= 0) throw new Error(`split: number of sections must be a positive integer, got ${sections}`);
		if (size$1 % sections !== 0) throw new Error(`Array of size ${size$1} cannot be split into ${sections} equal parts`);
		sizes = require_backend.rep(sections, size$1 / sections);
	} else {
		// Turn split points into piece lengths, including the leading piece
		// [0, indices[0]) and the trailing piece [indices[last], size). With no
		// split points this is a single piece of length `size`.
		const bounds = [0, ...indicesOrSections, size$1];
		sizes = [];
		for (let i = 1; i < bounds.length; i++) sizes.push(bounds[i] - bounds[i - 1]);
	}
	// `split$2` is never asked for more than 8 pieces in one call (presumably a
	// backend limit — TODO confirm): peel off 7 pieces plus one remainder piece,
	// then continue splitting the remainder.
	const results = [];
	let i = 0;
	while (sizes.length - i > 8) {
		const rest = sizes.slice(i + 7).reduce((x, y) => x + y, 0);
		const outs = split$2(a, axis, [...sizes.slice(i, i + 7), rest]);
		results.push(...outs.slice(0, -1));
		a = outs[outs.length - 1];
		i += 7;
	}
	results.push(...split$2(a, axis, sizes.slice(i)));
	return results;
}
|
|
5945
|
+
/**
|
|
5440
5946
|
* Join a sequence of arrays along an existing axis.
|
|
5441
5947
|
*
|
|
5442
5948
|
* The arrays must have the same shape, except in the dimension corresponding to
|
|
@@ -5448,13 +5954,11 @@ function concatenate(xs, axis = 0) {
|
|
|
5448
5954
|
if (xs.length === 0) throw new Error("Need at least one array to concatenate");
|
|
5449
5955
|
const shapes = xs.map(shape);
|
|
5450
5956
|
axis = require_backend.checkAxis(axis, shapes[0].length);
|
|
5451
|
-
for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays
|
|
5452
|
-
const makePadAxis = (start, end) => shapes[0].map((_, i) => i === axis ? [start, end] : [0, 0]);
|
|
5957
|
+
for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays ${xs[0].aval} and ${xs[i].aval} along axis ${axis}`);
|
|
5453
5958
|
let result = xs[0];
|
|
5454
|
-
for (let i = 1; i < xs.length; i
|
|
5455
|
-
const
|
|
5456
|
-
|
|
5457
|
-
result = pad(result, makePadAxis(0, len2)).add(pad(xs[i], makePadAxis(len1, 0)));
|
|
5959
|
+
for (let i = 1; i < xs.length; i += 7) {
|
|
5960
|
+
const group = xs.slice(i, i + 7);
|
|
5961
|
+
result = concatenate$1([result, ...group], axis);
|
|
5458
5962
|
}
|
|
5459
5963
|
return result;
|
|
5460
5964
|
}
|
|
@@ -5539,6 +6043,17 @@ function flipud(x) {
|
|
|
5539
6043
|
function fliplr(x) {
	// Reverse the order of entries along axis 1 (columns).
	const columnAxis = 1;
	return flip(x, columnAxis);
}
|
|
6046
|
+
/**
 * Interchange two axes of an array.
 *
 * Negative axes are normalized first; when both refer to the same axis the
 * input is returned unchanged.
 */
function swapaxes(a, axis1, axis2) {
	a = fudgeArray(a);
	const rank = a.ndim;
	const first = require_backend.checkAxis(axis1, rank);
	const second = require_backend.checkAxis(axis2, rank);
	if (first === second) return a;
	// Identity permutation with the two positions exchanged.
	const order = require_backend.range(rank);
	order[first] = second;
	order[second] = first;
	return transpose(a, order);
}
|
|
5542
6057
|
/** Transpose the last two dimensions of an array. */
|
|
5543
6058
|
function matrixTranspose(a) {
|
|
5544
6059
|
if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
|
|
@@ -5706,6 +6221,20 @@ function sort(a, axis = -1) {
|
|
|
5706
6221
|
function argsort(a, axis = -1) {
	// Delegate to the array method once the input has been coerced to an array.
	const arr = fudgeArray(a);
	return arr.argsort(axis);
}
|
|
6224
|
+
/**
 * Take elements from an array along an axis.
 *
 * This is equivalent to advanced indexing with integer indices over that
 * numbered axis. When `axis` is `null` (the default), the array is flattened
 * first and indexed along its single axis.
 */
function take(a, indices, axis = null) {
	const flatten = axis === null;
	const source = flatten ? ravel(a) : a;
	const gatherAxis = require_backend.checkAxis(flatten ? 0 : axis, ndim(source));
	return gather(source, [indices], [gatherAxis], gatherAxis);
}
|
|
5709
6238
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
5710
6239
|
function allclose(actual, expected, options) {
|
|
5711
6240
|
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
@@ -6025,6 +6554,20 @@ function tan(x) {
|
|
|
6025
6554
|
x = fudgeArray(x);
|
|
6026
6555
|
return sin(x.ref).div(cos(x));
|
|
6027
6556
|
}
|
|
6557
|
+
/**
 * @function
 * Return the normalized sinc function.
 *
 * Defined as `sin(πx) / (πx)` for `x != 0` and `1` at `x = 0`; this is the
 * normalized variant commonly used in signal processing.
 *
 * **Note:** differentiating at `x = 0` is not supported. The `sin(πx) / (πx)`
 * branch evaluates `0 / 0` there and `where` only masks its value, so the
 * derivative at 0 is likely NaN. JAX avoids this with a custom JVP rule.
 */
const sinc = jit$1(function sinc$1(x) {
	const scaled = x.ref.mul(Math.PI);
	const atOrigin = equal(x, 0);
	return where(atOrigin, 1, sin(scaled.ref).div(scaled));
});
|
|
6028
6571
|
/** Element-wise inverse cosine function (inverse of cos). */
|
|
6029
6572
|
function acos(x) {
|
|
6030
6573
|
return subtract(pi / 2, asin(x));
|
|
@@ -6077,6 +6620,25 @@ function trueDivide(x, y) {
|
|
|
6077
6620
|
return x.div(y);
|
|
6078
6621
|
}
|
|
6079
6622
|
/**
 * Return the largest integer smaller or equal to the division of the inputs.
 *
 * The quotient is always rounded towards negative infinity.
 *
 * Floating-point inputs (either operand) compute `floor(x / y)`. Integer
 * inputs compute `(x - remainder(x, y)) / y`, which handles negative operands
 * correctly (note: may overflow near int32 boundaries).
 *
 * @param x - Dividend array.
 * @param y - Divisor array.
 * @returns Element-wise floor division of x by y.
 */
function floorDivide(x, y) {
	x = fudgeArray(x);
	y = fudgeArray(y);
	const anyFloat = require_backend.isFloatDtype(x.dtype) || require_backend.isFloatDtype(y.dtype);
	if (!anyFloat) {
		// Removing the floored remainder first leaves a multiple of y, so the
		// following division is exact.
		const rem = remainder(x.ref, y.ref);
		return subtract(x, rem).div(y);
	}
	return floor(trueDivide(x, y));
}
|
|
6641
|
+
/**
|
|
6080
6642
|
* @function
|
|
6081
6643
|
* Calculate element-wise floating-point modulo operation.
|
|
6082
6644
|
*/
|
|
@@ -6090,6 +6652,20 @@ const fmod = jit$1(function fmod$1(x, y) {
|
|
|
6090
6652
|
const remainder = jit$1(function remainder$1(x, y) {
	// `mod` alone follows the dividend's sign; adding y and reducing again makes
	// the result follow the divisor's sign.
	const truncated = mod(x, y.ref);
	const shifted = truncated.add(y.ref);
	return mod(shifted, y);
});
|
|
6655
|
+
/**
 * Return element-wise quotient and remainder simultaneously.
 *
 * Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
 *
 * @param x - Dividend array.
 * @param y - Divisor array.
 * @returns Tuple of [quotient, remainder].
 */
function divmod(x, y) {
	const dividend = fudgeArray(x);
	const divisor = fudgeArray(y);
	// Both results need the inputs, so the first call borrows extra references.
	const quotient = floorDivide(dividend.ref, divisor.ref);
	const rest = remainder(dividend, divisor);
	return [quotient, rest];
}
|
|
6093
6669
|
/** Round input to the nearest integer towards zero. */
|
|
6094
6670
|
function trunc(x) {
|
|
6095
6671
|
return idiv(x, 1);
|
|
@@ -6253,14 +6829,15 @@ function std(x, axis = null, opts) {
|
|
|
6253
6829
|
return sqrt(var_(x, axis, opts));
|
|
6254
6830
|
}
|
|
6255
6831
|
/** Estimate the sample covariance of a set of variables. */
|
|
6256
|
-
function cov(x, y) {
|
|
6832
|
+
function cov(x, y = null, { rowvar = true } = {}) {
|
|
6257
6833
|
x = fudgeArray(x);
|
|
6258
6834
|
if (x.ndim === 1) x = x.reshape([1, x.shape[0]]);
|
|
6259
|
-
if (y !==
|
|
6835
|
+
if (y !== null) {
|
|
6260
6836
|
y = fudgeArray(y);
|
|
6261
6837
|
if (y.ndim === 1) y = y.reshape([1, y.shape[0]]);
|
|
6262
6838
|
x = vstack([x, y]);
|
|
6263
6839
|
}
|
|
6840
|
+
if (!rowvar) x = x.transpose();
|
|
6264
6841
|
const [_M, N] = x.shape;
|
|
6265
6842
|
x = x.ref.sub(x.mean(1, { keepdims: true }));
|
|
6266
6843
|
return dot$1(x.ref, x.transpose()).div(N - 1);
|
|
@@ -6305,7 +6882,8 @@ const isfinite = jit$1(function isfinite$1(x) {
|
|
|
6305
6882
|
//#region src/library/lax-linalg.ts
|
|
6306
6883
|
// Namespace object for `lax.linalg` (e.g. `lax.linalg.lu`). Each entry maps a
// public name to a thunk returning the implementation defined below; the
// public `cholesky` is backed by the bundled `cholesky$1`.
var lax_linalg_exports = {};
__export(lax_linalg_exports, {
	cholesky: () => cholesky$1,
	lu: () => lu,
	triangularSolve: () => triangularSolve
});
|
|
6311
6889
|
/**
|
|
@@ -6334,11 +6912,39 @@ __export(lax_linalg_exports, {
|
|
|
6334
6912
|
* // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
|
|
6335
6913
|
* ```
|
|
6336
6914
|
*/
|
|
6337
|
-
function cholesky(a, { upper = false } = {}) {
|
|
6915
|
+
function cholesky$1(a, { upper = false } = {}) {
	const lower = cholesky$2(a);
	if (!upper) return lower;
	// The upper factor is the lower one with its two trailing axes swapped.
	return moveaxis$1(lower, -2, -1);
}
|
|
6341
6919
|
/**
 * LU decomposition with partial pivoting.
 *
 * Computes `P @ A = L @ U`, where `P` permutes the rows of `A`, `L` is
 * lower-triangular with unit diagonal, and `U` is upper-triangular.
 *
 * @param x - A batch of matrices with shape `[..., m, n]`.
 *
 * @returns A tuple `(lu, pivots, permutation)` where:
 * - `lu`: the combined lower and upper triangular factors.
 * - `pivots`: pivot indices with shape `[..., min(m, n)]`.
 * - `permutation`: the row permutation generated by the pivots, shape `[..., m]`.
 *
 * @example
 * ```ts
 * import { lax, numpy as np } from "@jax-js/jax";
 *
 * const A = np.array([[4., 3.], [6., 3.]]);
 * const [lu, pivots, permutation] = lax.linalg.lu(A);
 * // lu ≈ [[6., 3.], [0.6666667, 1.0]]
 * // pivots = [1, 1]
 * // permutation = [1, 0]
 * ```
 */
function lu(x) {
	const factors = lu$1(x);
	return factors;
}
|
|
6947
|
+
/**
|
|
6342
6948
|
* Solve a triangular linear system.
|
|
6343
6949
|
*
|
|
6344
6950
|
* Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
|
|
@@ -6376,6 +6982,7 @@ var lax_exports = {};
|
|
|
6376
6982
|
__export(lax_exports, {
|
|
6377
6983
|
conv: () => conv,
|
|
6378
6984
|
convGeneralDilated: () => convGeneralDilated,
|
|
6985
|
+
convTranspose: () => convTranspose,
|
|
6379
6986
|
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
6380
6987
|
dot: () => dot,
|
|
6381
6988
|
erf: () => erf,
|
|
@@ -6384,6 +6991,7 @@ __export(lax_exports, {
|
|
|
6384
6991
|
reduceWindow: () => reduceWindow,
|
|
6385
6992
|
stopGradient: () => stopGradient$1
|
|
6386
6993
|
});
|
|
6994
|
+
// Alias for the built-in Array constructor, used via `JsArray.isArray` in
// `convTransposePadding` (presumably because the bare name `Array` may be
// shadowed elsewhere in this bundle).
const JsArray = globalThis.Array;
|
|
6387
6995
|
/**
|
|
6388
6996
|
* General dot product/contraction operator.
|
|
6389
6997
|
*
|
|
@@ -6455,7 +7063,11 @@ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
|
6455
7063
|
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
6456
7064
|
* function in JAX, which wraps XLA's general convolution operator.
|
|
6457
7065
|
*
|
|
6458
|
-
*
|
|
7066
|
+
* @param lhs - Input tensor; shape `[N, C_in, ...xs]`
|
|
7067
|
+
* @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
|
|
7068
|
+
* @param windowStrides - Strides for each spatial dimension
|
|
7069
|
+
* @param padding - Padding for each spatial dimension, or a string
|
|
7070
|
+
* (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
|
|
6459
7071
|
*/
|
|
6460
7072
|
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
|
|
6461
7073
|
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
@@ -6515,6 +7127,60 @@ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, r
|
|
|
6515
7127
|
function conv(lhs, rhs, windowStrides, padding) {
	// General dilated convolution with all options left at their defaults.
	const result = convGeneralDilated(lhs, rhs, windowStrides, padding);
	return result;
}
|
|
7130
|
+
/**
 * Convenience wrapper for calculating the N-d convolution "transpose".
 *
 * This function directly calculates a fractionally strided conv rather than
 * indirectly calculating the gradient (transpose) of a forward convolution.
 * It is equivalent to the JAX version, except:
 *
 * - The `use_consistent_padding` option is not available. We only have the
 *   consistent padding case (JAX version >0.8.4).
 * - The order of dimensions matches `lax.conv_general_dilated`.
 *
 * Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
 * dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
 * `transposeKernel` to true.
 *
 * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
 * @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
 * @param strides - Sequence of n integers, sets fractional stride (applied as
 *   input dilation)
 * @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
 *   each side of the input, so it acts like gradient of `conv()`
 * @param rhsDilation - Atrous dilation for the kernel (defaults to all ones)
 * @param transposeKernel - Flip spatial axes and swap the input/output channels
 *   of the kernel; its shape should be `[C_in, C_out, ...ks]`
 */
function convTranspose(lhs, rhs, strides, padding, { rhsDilation, transposeKernel = false } = {}) {
	const spatialKernel = rhs.shape.slice(2);
	rhsDilation = rhsDilation ?? require_backend.rep(spatialKernel.length, 1);
	const pads = [];
	for (let i = 0; i < spatialKernel.length; i++) {
		// Extent of the kernel after atrous dilation; a zero-size axis stays zero.
		const dilatedExtent = Math.max(0, (spatialKernel[i] - 1) * rhsDilation[i] + 1);
		const axisPadding = typeof padding === "string" ? padding : padding[i];
		pads.push(convTransposePadding(dilatedExtent, strides[i], axisPadding));
	}
	if (transposeKernel) {
		// Reverse every spatial axis, then swap the two channel axes.
		rhs = moveaxis(flip$1(rhs, require_backend.range(2, rhs.ndim)), 0, 1);
	}
	const unitStrides = require_backend.rep(lhs.ndim - 2, 1);
	return convGeneralDilated(lhs, rhs, unitStrides, pads, {
		lhsDilation: strides,
		rhsDilation
	});
}
|
|
7168
|
+
/**
 * Compute the `[before, after]` padding of one spatial axis for `convTranspose`.
 *
 * @param k - Dilated kernel extent along the axis.
 * @param s - Stride (input dilation) along the axis.
 * @param padding - `"SAME"`, `"VALID"`, or an explicit `[lo, hi]` pair of the
 *   forward convolution, which maps to `[k - 1 - lo, k - 1 - hi]`.
 * @throws Error for any other padding value.
 */
function convTransposePadding(k, s, padding) {
	if (padding === "SAME") {
		const total = k + s - 2;
		const before = s > k - 1 ? k - 1 : Math.ceil(total / 2);
		return [before, total - before];
	}
	if (padding === "VALID") {
		const total = k + s - 2 + Math.max(k - s, 0);
		return [k - 1, total - (k - 1)];
	}
	if (!JsArray.isArray(padding)) throw new Error(`convTranspose: Invalid padding type ${padding}`);
	const before = k - 1 - padding[0];
	const total = before + (k - 1 - padding[1]);
	return [before, total - before];
}
|
|
6518
7184
|
/** Reduce a computation over padded windows. */
|
|
6519
7185
|
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
6520
7186
|
if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
|
|
@@ -6553,6 +7219,7 @@ function stopGradient$1(x) {
|
|
|
6553
7219
|
var nn_exports = {};
|
|
6554
7220
|
__export(nn_exports, {
|
|
6555
7221
|
celu: () => celu,
|
|
7222
|
+
dotProductAttention: () => dotProductAttention,
|
|
6556
7223
|
elu: () => elu,
|
|
6557
7224
|
gelu: () => gelu,
|
|
6558
7225
|
glu: () => glu,
|
|
@@ -6869,6 +7536,95 @@ function oneHot(x, numClasses) {
|
|
|
6869
7536
|
if (require_backend.isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
|
|
6870
7537
|
return eye(numClasses, void 0, { device: x.device }).slice(x);
|
|
6871
7538
|
}
|
|
7539
|
+
/**
 * Scaled dot product attention (SDPA).
 *
 * Computes `softmax((Q @ K^T) / sqrt(d) + bias) @ V`, where `Q` is the query,
 * `K` is the key, `V` is the value, and `d` is the dimensionality of each key
 * and query vector.
 *
 * Multi-query / grouped-query attention is applied when input `key` and
 * `value` tensors have fewer heads than `query`. With `G = N / K`, query head
 * `n` attends using key/value head `floor(n / G)`, the same grouping as
 * `jax.nn.dot_product_attention`.
 *
 * We use the following uppercase letters to denote array shapes:
 * - `B` = batch size
 * - `S` = length of key/value sequences (source)
 * - `L` = length of query sequences
 * - `N` = number of attention heads
 * - `H` = dimensionality of each attention head
 * - `K` = number of key/value heads (for grouped-query attention)
 *
 * The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
 * case it must be omitted from all inputs.
 *
 * @param query - Query array; shape `[B, L, N, H]`
 * @param key - Key array; shape `[B, S, K, H]`
 * @param value - Value array; same shape as `key`
 * @param opts.bias - Optional bias to add to the attention logits; shape
 *   `[B, N, L, S]` or broadcastable to it.
 * @param opts.mask - Optional mask to apply to the attention logits; should be
 *   a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
 *   the element should take part in attention. Masked logits are set to
 *   `-Infinity`, so a fully masked row likely produces NaN.
 * @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
 * @param opts.isCausal - If true, applies a causal mask (`tri(L, S)`: query
 *   position `i` attends to key positions `j <= i`).
 * @param opts.querySeqLengths - Not yet implemented; passing it throws.
 * @param opts.keyValueSeqLengths - Not yet implemented; passing it throws.
 * @param opts.localWindowSize - Not yet implemented; passing it throws.
 *
 * @returns The result of the attention operation; shape is the same as query
 *   `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
 */
function dotProductAttention(query, key$1, value, opts = {}) {
	if (opts.querySeqLengths !== void 0 || opts.keyValueSeqLengths !== void 0) throw new Error("Sequence length masking is not yet implemented");
	if (opts.localWindowSize !== void 0) throw new Error("Local attention is not yet implemented");
	query = fudgeArray(query);
	key$1 = fudgeArray(key$1);
	value = fudgeArray(value);
	if (query.ndim !== 3 && query.ndim !== 4 || query.ndim !== key$1.ndim || query.ndim !== value.ndim) throw new Error(`dotProductAttention: expected all tensors to have rank 3 or 4, got Q=${query.aval}, K=${key$1.aval}, V=${value.aval}`);
	if (!require_backend.deepEqual(key$1.shape, value.shape)) throw new Error(`dotProductAttention: key and value shapes must match, got K=${key$1.shape}, V=${value.shape}`);
	const isRank3 = query.ndim === 3;
	if (isRank3) {
		query = expandDims(query, 0);
		key$1 = expandDims(key$1, 0);
		value = expandDims(value, 0);
	}
	const [B, L, N, H] = query.shape;
	if (key$1.shape[0] !== B || key$1.shape[3] !== H) throw new Error(`dotProductAttention: query and key shapes mismatch, got Q=${query.aval}, K=${key$1.aval}`);
	const S = key$1.shape[1];
	const K = key$1.shape[2];
	if (N < K || N !== K && N % K !== 0) throw new Error(`dotProductAttention: number of query heads N=${N} must be divisible by number of key/value heads K=${K} for GQA`);
	const G = N / K;
	if (G > 1) {
		// Repeat each key/value head G times consecutively (k0, k0, k1, k1, ...)
		// so query head n = k * G + g pairs with kv head k. Tiling the head axis
		// would interleave them (k0, k1, k0, k1, ...) and mismatch heads when K > 1.
		const repeatHeads = (t) => broadcastTo(expandDims(t, 3), [
			B,
			S,
			K,
			G,
			H
		]).reshape([
			B,
			S,
			N,
			H
		]);
		key$1 = repeatHeads(key$1);
		value = repeatHeads(value);
	}
	const scale = opts.scale ?? 1 / Math.sqrt(H);
	let scores = einsum("BLNH,BSNH->BNLS", query, key$1).mul(scale);
	if (opts.bias !== void 0) scores = scores.add(opts.bias);
	if (opts.mask !== void 0) scores = where(opts.mask, scores, -Infinity);
	if (opts.isCausal) {
		const causalMask = tri(L, S, 0, { dtype: require_backend.DType.Bool });
		scores = where(causalMask, scores, -Infinity);
	}
	const attn = softmax(scores, -1);
	const out = einsum("BNLS,BSNH->BLNH", attn, value);
	return isRank3 ? out.reshape([
		L,
		N,
		H
	]) : out;
}
|
|
6872
7628
|
|
|
6873
7629
|
//#endregion
|
|
6874
7630
|
//#region src/library/random.ts
|
|
@@ -6881,33 +7637,41 @@ __export(random_exports, {
|
|
|
6881
7637
|
gumbel: () => gumbel,
|
|
6882
7638
|
key: () => key,
|
|
6883
7639
|
laplace: () => laplace,
|
|
7640
|
+
multivariateNormal: () => multivariateNormal,
|
|
6884
7641
|
normal: () => normal,
|
|
6885
7642
|
split: () => split,
|
|
6886
7643
|
uniform: () => uniform
|
|
6887
7644
|
});
|
|
6888
|
-
function validateKeyShape(key$1) {
|
|
7645
|
+
/**
 * Check that `key$1` looks like a PRNG key (trailing axis of length 2) and
 * return its batch shape, i.e. every dimension except the last.
 * With `scalar` set, a batch of keys (rank > 1) is rejected.
 */
function validateKeyShape(key$1, scalar = false) {
	if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
	const dims = key$1.shape;
	const last = dims.length - 1;
	if (dims[last] !== 2) throw new Error(`Invalid key shape: ${dims}. Expected last dimension to be 2.`);
	if (scalar && dims.length > 1) throw new Error(`Expected a single PRNG key, but got a batch of keys with shape ${JSON.stringify(dims)} - use jax.vmap for batching.`);
	return dims.slice(0, last);
}
|
|
7651
|
+
/**
 * Split a single PRNG key into its two components along the last axis,
 * returning them with that axis removed.
 */
function getK01(key$1) {
	const keyShape = validateKeyShape(key$1, true);
	const [first, second] = split$2(key$1, -1, [1, 1]);
	return [first.reshape(keyShape), second.reshape(keyShape)];
}
|
|
6893
7658
|
/**
 * Create a pseudo-random number generator (PRNG) key from 32-bit integer seed.
 *
 * The seed is converted to a uint32 scalar and the key is the pair `[0, seed]`.
 */
function key(seed) {
	const seedArray = array(seed, { dtype: require_backend.DType.Uint32 });
	if (seedArray.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seedArray.shape} - use jax.vmap for batching.`);
	return stack([0, seedArray]);
}
|
|
6898
7664
|
/**
 * Splits a PRNG key into `num` new keys by adding a leading axis.
 *
 * `num` may be a single count or a shape; every entry must be a positive
 * integer.
 */
function split(key$1, num = 2) {
	const shape$1 = typeof num === "number" ? [num] : num;
	for (const len of shape$1) {
		if (!(len > 0 && Number.isInteger(len))) throw new Error(`Invalid split length: ${len}. Must be a positive integer.`);
	}
	const [k0, k1] = getK01(key$1);
	const firstWords = randomBits(k0.ref, k1.ref, shape$1, 0);
	const secondWords = randomBits(k0, k1, shape$1, 1);
	return stack([firstWords, secondWords], -1);
}
|
|
6907
7671
|
/** Sample uniform bits in the form of unsigned integers, with shape `shape`. */
function bits(key$1, shape$1 = []) {
	const parts = getK01(key$1);
	return randomBits(parts[0], parts[1], shape$1);
}
|
|
6912
7676
|
/**
|
|
6913
7677
|
* @function
|
|
@@ -6981,6 +7745,32 @@ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
|
|
|
6981
7745
|
}, { staticArgnums: [1] });
|
|
6982
7746
|
/**
 * @function
 * Sample multivariate normal random values with given mean and covariance.
 *
 * The values are returned with the given batch shape, plus a final dimension
 * of size `n` holding each n-dimensional sample.
 *
 * This uses a Cholesky factor `L` of the covariance. Assuming `cholesky`
 * returns the lower factor (`cov = L @ L^T`), the samples are `mean + L @ z`
 * with `z` standard normal.
 *
 * - `key` - PRNG key
 * - `mean` - Mean vector of shape `[..., n]` (at least one dimension)
 * - `cov` - Covariance of shape `[..., n, n]`, must be positive-definite
 * - `shape` - Result batch shape, must be broadcastable with
 *   `mean.shape[:-1]` and `cov.shape[:-2]`
 * @returns Random samples of shape
 *   `[...broadcast(shape, mean.shape[:-1], cov.shape[:-2]), n]`
 * @throws Error if `mean` is a scalar or `cov` does not end in `[n, n]`.
 */
const multivariateNormal = jit$1(function multivariateNormal$1(key$1, mean$1, cov$1, shape$1 = []) {
	mean$1 = fudgeArray(mean$1);
	cov$1 = fudgeArray(cov$1);
	// Without this, a scalar mean makes `n` undefined, which a scalar `cov`
	// would silently satisfy below.
	if (mean$1.ndim < 1) throw new Error(`Invalid mean shape: ${mean$1.shape}. Expected at least one dimension.`);
	const n = mean$1.shape[mean$1.ndim - 1];
	if (cov$1.ndim < 2 || cov$1.shape[cov$1.ndim - 1] !== n || cov$1.shape[cov$1.ndim - 2] !== n) throw new Error(`Invalid covariance shape: ${cov$1.shape}. Expected last two dimensions to be [${n}, ${n}].`);
	const outputShape = broadcastShapes(shape$1, mean$1.shape.slice(0, -1), cov$1.shape.slice(0, -2)).concat(n);
	const L = cholesky(cov$1);
	const z = normal(key$1, outputShape);
	// Batched matrix-vector product L @ z, then shift by the mean.
	return einsum("...ij,...j->...i", L, z).add(mean$1);
}, { staticArgnums: [3] });
|
|
7772
|
+
/**
|
|
7773
|
+
* @function
|
|
6984
7774
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
6985
7775
|
*
|
|
6986
7776
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
@@ -7070,17 +7860,62 @@ const linearize = linearize$1;
|
|
|
7070
7860
|
/**
|
|
7071
7861
|
* @function
|
|
7072
7862
|
* Calculate the reverse-mode vector-Jacobian product for a function.
|
|
7863
|
+
*
|
|
7864
|
+
* The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
|
|
7865
|
+
* `f(primals)`, and `vjpFn` is a function that takes in cotangents for each
|
|
7866
|
+
* output and returns the cotangents for each input.
|
|
7867
|
+
*
|
|
7868
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
|
|
7869
|
+
* `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
|
|
7870
|
+
*
|
|
7871
|
+
* @example
|
|
7872
|
+
* ```ts
|
|
7873
|
+
* const [y, vjpFn] = vjp(f, [x]);
|
|
7874
|
+
*
|
|
7875
|
+
* // With hasAux
|
|
7876
|
+
* const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
|
|
7877
|
+
* ```
|
|
7073
7878
|
*/
|
|
7074
7879
|
const vjp = vjp$1;
|
|
7075
7880
|
/**
|
|
7076
7881
|
* @function
|
|
7077
7882
|
* Compute the gradient of a scalar-valued function `f` with respect to its
|
|
7078
7883
|
* first argument.
|
|
7884
|
+
*
|
|
7885
|
+
* Pass in different `argnums` to differentiate with respect to other
|
|
7886
|
+
* arguments. If a tuple is provided, the return value will be a tuple of
|
|
7887
|
+
* gradients corresponding to each argument index.
|
|
7888
|
+
*
|
|
7889
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return a
|
|
7890
|
+
* `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
|
|
7891
|
+
*
|
|
7892
|
+
* @example
|
|
7893
|
+
* ```ts
|
|
7894
|
+
* const gradient = grad(f)(x);
|
|
7895
|
+
*
|
|
7896
|
+
* // With `argnums`
|
|
7897
|
+
* const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
|
|
7898
|
+
*
|
|
7899
|
+
* // With `hasAux`
|
|
7900
|
+
* const [gradient, aux] = grad(f, { hasAux: true })(x);
|
|
7901
|
+
* ```
|
|
7079
7902
|
*/
|
|
7080
7903
|
const grad = grad$1;
|
|
7081
7904
|
/**
|
|
7082
7905
|
* @function
|
|
7083
7906
|
* Create a function that evaluates both `f` and the gradient of `f`.
|
|
7907
|
+
*
|
|
7908
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
|
|
7909
|
+
* `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
|
|
7910
|
+
*
|
|
7911
|
+
* @example
|
|
7912
|
+
* ```ts
|
|
7913
|
+
* // Without hasAux
|
|
7914
|
+
* const [value, gradient] = valueAndGrad(f)(x);
|
|
7915
|
+
*
|
|
7916
|
+
* // With hasAux
|
|
7917
|
+
* const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
|
|
7918
|
+
* ```
|
|
7084
7919
|
*/
|
|
7085
7920
|
const valueAndGrad = valueAndGrad$1;
|
|
7086
7921
|
/**
|
|
@@ -7089,6 +7924,21 @@ const valueAndGrad = valueAndGrad$1;
|
|
|
7089
7924
|
*/
|
|
7090
7925
|
const jacrev = jacrev$1;
|
|
7091
7926
|
/**
|
|
7927
|
+
* @function
|
|
7928
|
+
* Compute the Hessian matrix of a scalar-valued function.
|
|
7929
|
+
*
|
|
7930
|
+
* The Hessian is the matrix of second-order partial derivatives of a function.
|
|
7931
|
+
* This is implemented as `jacfwd(grad(f))`.
|
|
7932
|
+
*
|
|
7933
|
+
* @example
|
|
7934
|
+
* ```ts
|
|
7935
|
+
* const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // x^3
|
|
7936
|
+
* const H = hessian(f)(np.array([1, 2, 3]));
|
|
7937
|
+
* // H[i,j] = d^2f / dx_i dx_j
|
|
7938
|
+
* ```
|
|
7939
|
+
*/
|
|
7940
|
+
const hessian = hessian$1;
|
|
7941
|
+
/**
|
|
7092
7942
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
7093
7943
|
*
|
|
7094
7944
|
* This can be used to wait for the results of an intermediate computation to
|
|
@@ -7132,6 +7982,7 @@ exports.defaultDevice = require_backend.defaultDevice;
|
|
|
7132
7982
|
exports.devicePut = devicePut;
|
|
7133
7983
|
exports.devices = require_backend.devices;
|
|
7134
7984
|
exports.grad = grad;
|
|
7985
|
+
exports.hessian = hessian;
|
|
7135
7986
|
exports.init = require_backend.init;
|
|
7136
7987
|
exports.jacfwd = jacfwd;
|
|
7137
7988
|
exports.jacobian = jacrev;
|