@jax-js/jax 0.1.4 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +10 -7
- package/dist/{backend-Bu9GY6sK.cjs → backend-D7s-Retx.cjs} +122 -8
- package/dist/{backend-tngXtWe4.js → backend-Dx6Ob2D1.js} +111 -9
- package/dist/index.cjs +1059 -208
- package/dist/index.d.cts +429 -21
- package/dist/index.d.ts +429 -21
- package/dist/index.js +1059 -209
- package/dist/webgl-CLLvzJlO.js +522 -0
- package/dist/webgl-CyfzNW8T.cjs +522 -0
- package/dist/{webgpu-ChVgx3b6.js → webgpu-C-VfevQW.js} +296 -3
- package/dist/{webgpu-Oj3Kd-kd.cjs → webgpu-rraa6dfz.cjs} +296 -3
- package/package.json +1 -1
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-Dx6Ob2D1.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -209,7 +209,7 @@ __export(tree_exports, {
|
|
|
209
209
|
structure: () => structure,
|
|
210
210
|
unflatten: () => unflatten
|
|
211
211
|
});
|
|
212
|
-
const JsArray$
|
|
212
|
+
const JsArray$2 = globalThis.Array;
|
|
213
213
|
let NodeType = /* @__PURE__ */ function(NodeType$1) {
|
|
214
214
|
NodeType$1["Array"] = "Array";
|
|
215
215
|
NodeType$1["Object"] = "Object";
|
|
@@ -257,7 +257,7 @@ function flatten(tree) {
|
|
|
257
257
|
return [leaves$1, treedef];
|
|
258
258
|
}
|
|
259
259
|
function _flatten(tree, leaves$1) {
|
|
260
|
-
if (JsArray$
|
|
260
|
+
if (JsArray$2.isArray(tree)) {
|
|
261
261
|
const childTrees = tree.map((c) => _flatten(c, leaves$1));
|
|
262
262
|
return new JsTreeDef(NodeType.Array, null, childTrees);
|
|
263
263
|
} else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
|
|
@@ -356,6 +356,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
356
356
|
Primitive$1["PoolTranspose"] = "pool_transpose";
|
|
357
357
|
Primitive$1["Compare"] = "compare";
|
|
358
358
|
Primitive$1["Where"] = "where";
|
|
359
|
+
Primitive$1["Concatenate"] = "concatenate";
|
|
360
|
+
Primitive$1["Split"] = "split";
|
|
359
361
|
Primitive$1["RandomBits"] = "random_bits";
|
|
360
362
|
Primitive$1["Gather"] = "gather";
|
|
361
363
|
Primitive$1["Transpose"] = "transpose";
|
|
@@ -368,6 +370,7 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
368
370
|
Primitive$1["Argsort"] = "argsort";
|
|
369
371
|
Primitive$1["TriangularSolve"] = "triangular_solve";
|
|
370
372
|
Primitive$1["Cholesky"] = "cholesky";
|
|
373
|
+
Primitive$1["LU"] = "lu";
|
|
371
374
|
Primitive$1["Jit"] = "jit";
|
|
372
375
|
return Primitive$1;
|
|
373
376
|
}({});
|
|
@@ -378,6 +381,13 @@ let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
|
|
|
378
381
|
CompareOp$1["LessEqual"] = "less_equal";
|
|
379
382
|
return CompareOp$1;
|
|
380
383
|
}({});
|
|
384
|
+
const routinePrimitives = new Map([
|
|
385
|
+
[Primitive.Sort, Routines.Sort],
|
|
386
|
+
[Primitive.Argsort, Routines.Argsort],
|
|
387
|
+
[Primitive.TriangularSolve, Routines.TriangularSolve],
|
|
388
|
+
[Primitive.Cholesky, Routines.Cholesky],
|
|
389
|
+
[Primitive.LU, Routines.LU]
|
|
390
|
+
]);
|
|
381
391
|
function add$1(x, y) {
|
|
382
392
|
return bind1(Primitive.Add, [x, y]);
|
|
383
393
|
}
|
|
@@ -499,7 +509,25 @@ function where$1(cond, x, y) {
|
|
|
499
509
|
y
|
|
500
510
|
]);
|
|
501
511
|
}
|
|
512
|
+
function concatenate$1(xs, axis) {
|
|
513
|
+
if (xs.length === 0) throw new Error("concatenate requires at least one input");
|
|
514
|
+
const avals = xs.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
515
|
+
axis = checkAxis(axis, avals[0].ndim);
|
|
516
|
+
for (const x of avals) if (x.ndim !== avals[0].ndim || !x.shape.every((s, i) => i === axis || s === avals[0].shape[i])) throw new Error(`Concatenate: inputs ${avals[0]} and ${x} must match shapes except on axis ${axis}`);
|
|
517
|
+
return bind1(Primitive.Concatenate, xs, { axis });
|
|
518
|
+
}
|
|
519
|
+
function split$2(x, axis, sizes) {
|
|
520
|
+
axis = checkAxis(axis, ndim$1(x));
|
|
521
|
+
if (sizes.some((s) => s < 0 || !Number.isInteger(s))) throw new Error(`split: sizes must be nonnegative integers, got ${JSON.stringify(sizes)}`);
|
|
522
|
+
const totalSize = sizes.reduce((a, b) => a + b, 0);
|
|
523
|
+
if (totalSize !== getShape(x)[axis]) throw new Error(`split: sizes must sum to the size of the axis ${axis}, got ${totalSize}`);
|
|
524
|
+
return bind(Primitive.Split, [x], {
|
|
525
|
+
axis,
|
|
526
|
+
sizes
|
|
527
|
+
});
|
|
528
|
+
}
|
|
502
529
|
function randomBits(k0, k1, shape$1, mode = "xor") {
|
|
530
|
+
if (!deepEqual(k0.shape, k1.shape) || k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new Error(`randomBits: key parts must be uint32 with the same shape, got ${ShapedArray.fromAval(k0.aval)} and ${ShapedArray.fromAval(k1.aval)}`);
|
|
503
531
|
return bind1(Primitive.RandomBits, [k0, k1], {
|
|
504
532
|
shape: shape$1,
|
|
505
533
|
mode
|
|
@@ -566,6 +594,11 @@ function pad$1(x, width) {
|
|
|
566
594
|
return bind1(Primitive.Pad, [x], { width });
|
|
567
595
|
}
|
|
568
596
|
function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
597
|
+
const as = getShape(a);
|
|
598
|
+
const bs = getShape(b);
|
|
599
|
+
if (as.length < 2 || bs.length < 2) throw new Error(`triangular_solve: must be >=2D, got a=${as}, b=${bs}`);
|
|
600
|
+
const n = as[as.length - 2];
|
|
601
|
+
if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
|
|
569
602
|
if (lower) {
|
|
570
603
|
a = flip$1(a, [-2, -1]);
|
|
571
604
|
b = flip$1(b, [-1]);
|
|
@@ -575,8 +608,15 @@ function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
|
575
608
|
return x;
|
|
576
609
|
}
|
|
577
610
|
function cholesky$2(x) {
|
|
611
|
+
const aval = ShapedArray.fromAval(getAval(x));
|
|
612
|
+
if (aval.ndim < 2 || aval.shape[aval.ndim - 1] !== aval.shape[aval.ndim - 2]) throw new Error(`cholesky: expected batch of square matrices, got ${aval}`);
|
|
578
613
|
return bind1(Primitive.Cholesky, [x]);
|
|
579
614
|
}
|
|
615
|
+
function lu$1(x) {
|
|
616
|
+
const aval = ShapedArray.fromAval(getAval(x));
|
|
617
|
+
if (aval.ndim < 2) throw new Error(`lu: expected batch of matrices, got ${aval}`);
|
|
618
|
+
return bind(Primitive.LU, [x]);
|
|
619
|
+
}
|
|
580
620
|
function sort$1(x) {
|
|
581
621
|
const nd = ndim$1(x);
|
|
582
622
|
if (nd === 0) throw new Error("sort: requires at least 1D input");
|
|
@@ -621,6 +661,9 @@ function newDynamic(main) {
|
|
|
621
661
|
dynamicTrace = prevDynamicTrace;
|
|
622
662
|
} };
|
|
623
663
|
}
|
|
664
|
+
function currentTraceLevel() {
|
|
665
|
+
return traceStack[traceStack.length - 1].level;
|
|
666
|
+
}
|
|
624
667
|
var Trace = class {
|
|
625
668
|
constructor(main) {
|
|
626
669
|
this.main = main;
|
|
@@ -685,6 +728,9 @@ var Tracer = class Tracer {
|
|
|
685
728
|
mul(other) {
|
|
686
729
|
return mul(this, other);
|
|
687
730
|
}
|
|
731
|
+
mod(other) {
|
|
732
|
+
return mod(this, other);
|
|
733
|
+
}
|
|
688
734
|
greater(other) {
|
|
689
735
|
return greater$1(this, other);
|
|
690
736
|
}
|
|
@@ -797,8 +843,14 @@ var Tracer = class Tracer {
|
|
|
797
843
|
*/
|
|
798
844
|
*[Symbol.iterator]() {
|
|
799
845
|
if (this.ndim === 0) throw new Error("Cannot iterate over a scalar array");
|
|
800
|
-
|
|
801
|
-
this.
|
|
846
|
+
let residual = this;
|
|
847
|
+
const subarrayShape = this.shape.slice(1);
|
|
848
|
+
for (let i = 0; i < this.shape[0]; i++) {
|
|
849
|
+
const lr = split$2(residual, 0, [1, residual.shape[0] - 1]);
|
|
850
|
+
yield lr[0].reshape(subarrayShape);
|
|
851
|
+
residual = lr[1];
|
|
852
|
+
}
|
|
853
|
+
residual.dispose();
|
|
802
854
|
}
|
|
803
855
|
/**
|
|
804
856
|
* Return a sorted copy of an array in ascending order.
|
|
@@ -948,6 +1000,9 @@ var ShapedArray = class ShapedArray {
|
|
|
948
1000
|
get size() {
|
|
949
1001
|
return prod(this.shape);
|
|
950
1002
|
}
|
|
1003
|
+
scalar() {
|
|
1004
|
+
return new ShapedArray([], this.dtype, this.weakType);
|
|
1005
|
+
}
|
|
951
1006
|
toString() {
|
|
952
1007
|
return `${this.dtype}[${this.shape.join(",")}]`;
|
|
953
1008
|
}
|
|
@@ -986,6 +1041,7 @@ var TreeMismatchError = class extends TypeError {
|
|
|
986
1041
|
super(`Mismatched tree structures in ${where$2}: ${left} != ${right}`);
|
|
987
1042
|
}
|
|
988
1043
|
};
|
|
1044
|
+
/** Flatten a function of `JsTree` input/output for use in tracing. */
|
|
989
1045
|
function flattenFun(f, inTree) {
|
|
990
1046
|
const store = { value: void 0 };
|
|
991
1047
|
const flatFun = (...argsFlat) => {
|
|
@@ -997,6 +1053,26 @@ function flattenFun(f, inTree) {
|
|
|
997
1053
|
};
|
|
998
1054
|
return [flatFun, store];
|
|
999
1055
|
}
|
|
1056
|
+
/** Like flattenFun, but expects f to return [main, aux] tuple. */
|
|
1057
|
+
function flattenFunWithAux(f, inTree) {
|
|
1058
|
+
const store = { value: void 0 };
|
|
1059
|
+
const auxStore = { value: void 0 };
|
|
1060
|
+
const flatFun = (...argsFlat) => {
|
|
1061
|
+
const pytreeArgs = unflatten(inTree, argsFlat);
|
|
1062
|
+
const result = f(...pytreeArgs);
|
|
1063
|
+
if (!Array.isArray(result) || result.length !== 2) throw new Error("Function with `hasAux: true` must return [output, aux] tuple");
|
|
1064
|
+
const [out, aux] = result;
|
|
1065
|
+
const [outFlat, outTree] = flatten(out);
|
|
1066
|
+
store.value = outTree;
|
|
1067
|
+
auxStore.value = aux;
|
|
1068
|
+
return outFlat;
|
|
1069
|
+
};
|
|
1070
|
+
return [
|
|
1071
|
+
flatFun,
|
|
1072
|
+
store,
|
|
1073
|
+
auxStore
|
|
1074
|
+
];
|
|
1075
|
+
}
|
|
1000
1076
|
var UseAfterFreeError = class extends ReferenceError {
|
|
1001
1077
|
constructor(tracer) {
|
|
1002
1078
|
super(`Referenced tracer ${tracer.toString()} freed, please use .ref move semantics`);
|
|
@@ -1553,7 +1629,7 @@ const abstractEvalRules = {
|
|
|
1553
1629
|
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1554
1630
|
},
|
|
1555
1631
|
[Primitive.Conv]([lhs, rhs], params) {
|
|
1556
|
-
const { dtype, weakType } = promoteAvals(
|
|
1632
|
+
const { dtype, weakType } = promoteAvals(lhs.scalar(), rhs.scalar());
|
|
1557
1633
|
const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
|
|
1558
1634
|
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1559
1635
|
},
|
|
@@ -1564,10 +1640,25 @@ const abstractEvalRules = {
|
|
|
1564
1640
|
const shape$1 = generalBroadcast(cond.shape, xy.shape);
|
|
1565
1641
|
return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
|
|
1566
1642
|
},
|
|
1643
|
+
[Primitive.Concatenate](xs, { axis }) {
|
|
1644
|
+
if (xs.length === 0) throw new TypeError("Concatenate requires at least one input");
|
|
1645
|
+
for (const x of xs) if (x.ndim !== xs[0].ndim || !x.shape.every((s, i) => i === axis || s === xs[0].shape[i])) throw new TypeError(`Concatenate: inputs ${xs[0]} and ${x} must match shapes except on axis ${axis}`);
|
|
1646
|
+
const shape$1 = xs[0].shape.slice();
|
|
1647
|
+
shape$1[axis] = xs.reduce((sum$1, x) => sum$1 + x.shape[axis], 0);
|
|
1648
|
+
const { dtype, weakType } = xs.map((x) => x.scalar()).reduce(promoteAvals);
|
|
1649
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1650
|
+
},
|
|
1651
|
+
[Primitive.Split]([x], { axis, sizes }) {
|
|
1652
|
+
const totalSize = sizes.reduce((a, b) => a + b, 0);
|
|
1653
|
+
if (x.shape[axis] !== totalSize) throw new TypeError(`Split: sizes ${sizes} do not sum to dimension ${x.shape[axis]} on axis ${axis}`);
|
|
1654
|
+
return sizes.map((size$1) => {
|
|
1655
|
+
return new ShapedArray(x.shape.toSpliced(axis, 1, size$1), x.dtype, x.weakType);
|
|
1656
|
+
});
|
|
1657
|
+
},
|
|
1567
1658
|
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
1568
1659
|
if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
1569
|
-
|
|
1570
|
-
if (!deepEqual(
|
|
1660
|
+
if (!deepEqual(k0.shape, k1.shape)) throw new TypeError(`RandomBits: Keys have different shapes ${k0.shape} and ${k1.shape}`);
|
|
1661
|
+
if (!deepEqual(shape$1.slice(0, k0.ndim), k0.shape)) throw new TypeError(`RandomBits: generated shape ${shape$1} must match key shape ${k0.shape}`);
|
|
1571
1662
|
return [new ShapedArray(shape$1, DType.Uint32, false)];
|
|
1572
1663
|
},
|
|
1573
1664
|
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
@@ -1624,6 +1715,16 @@ const abstractEvalRules = {
|
|
|
1624
1715
|
if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
|
|
1625
1716
|
return [ShapedArray.fromAval(a)];
|
|
1626
1717
|
},
|
|
1718
|
+
[Primitive.LU]([a]) {
|
|
1719
|
+
if (a.ndim < 2) throw new TypeError(`lu: requires at least 2D input, got ${a}`);
|
|
1720
|
+
const batch = a.shape.slice(0, -2);
|
|
1721
|
+
const [m, n] = a.shape.slice(-2);
|
|
1722
|
+
return [
|
|
1723
|
+
ShapedArray.fromAval(a),
|
|
1724
|
+
new ShapedArray([...batch, Math.min(m, n)], DType.Int32, false),
|
|
1725
|
+
new ShapedArray([...batch, m], DType.Int32, false)
|
|
1726
|
+
];
|
|
1727
|
+
},
|
|
1627
1728
|
[Primitive.Jit](args, { jaxpr }) {
|
|
1628
1729
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
1629
1730
|
if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
|
|
@@ -1701,12 +1802,6 @@ function jit$1(f, opts) {
|
|
|
1701
1802
|
|
|
1702
1803
|
//#endregion
|
|
1703
1804
|
//#region src/frontend/jit.ts
|
|
1704
|
-
const routinePrimitives = new Map([
|
|
1705
|
-
[Primitive.Sort, Routines.Sort],
|
|
1706
|
-
[Primitive.Argsort, Routines.Argsort],
|
|
1707
|
-
[Primitive.TriangularSolve, Routines.TriangularSolve],
|
|
1708
|
-
[Primitive.Cholesky, Routines.Cholesky]
|
|
1709
|
-
]);
|
|
1710
1805
|
/** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
|
|
1711
1806
|
var JitProgram = class {
|
|
1712
1807
|
constructor(backend, steps, inputs, outputs) {
|
|
@@ -1876,10 +1971,10 @@ function jitCompile(backend, jaxpr) {
|
|
|
1876
1971
|
inputs.push(jv.arg);
|
|
1877
1972
|
} else if (input instanceof Lit) inputs.push(builder.pushLit(input));
|
|
1878
1973
|
const outputs = [];
|
|
1879
|
-
for (const outVar
|
|
1880
|
-
const outId = builder.pushBuffer(outVar
|
|
1974
|
+
for (const outVar of eqn.outBinders) {
|
|
1975
|
+
const outId = builder.pushBuffer(outVar.aval.size * byteWidth(outVar.aval.dtype));
|
|
1881
1976
|
outputs.push(outId);
|
|
1882
|
-
ctx.set(outVar
|
|
1977
|
+
ctx.set(outVar, {
|
|
1883
1978
|
type: "imm",
|
|
1884
1979
|
arg: outId
|
|
1885
1980
|
});
|
|
@@ -1930,35 +2025,37 @@ function jitCompile(backend, jaxpr) {
|
|
|
1930
2025
|
let reduction;
|
|
1931
2026
|
if (inputReduction) {
|
|
1932
2027
|
const jv = inputReduction;
|
|
1933
|
-
const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
|
|
1934
|
-
exp$2 = jv.exp.reindexGids(addArgs(jv.args));
|
|
2028
|
+
const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp[0];
|
|
2029
|
+
exp$2 = [jv.exp.reindexGids(addArgs(jv.args))];
|
|
1935
2030
|
reduction = new Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
|
|
1936
2031
|
} else {
|
|
1937
2032
|
const ruleOutput = rule(inputExps, inputAvals, eqn.params);
|
|
1938
2033
|
exp$2 = ruleOutput.exp;
|
|
1939
2034
|
reduction = ruleOutput.reduction;
|
|
1940
2035
|
}
|
|
1941
|
-
|
|
1942
|
-
|
|
1943
|
-
|
|
1944
|
-
|
|
1945
|
-
|
|
1946
|
-
|
|
1947
|
-
|
|
1948
|
-
|
|
1949
|
-
|
|
2036
|
+
for (let i$1 = 0; i$1 < eqn.outBinders.length; i$1++) {
|
|
2037
|
+
const outVar = eqn.outBinders[i$1];
|
|
2038
|
+
if (blackNodes.has(outVar)) {
|
|
2039
|
+
const nargs$1 = inputArgs.length;
|
|
2040
|
+
const size$1 = outVar.aval.size;
|
|
2041
|
+
const kernel = new Kernel(nargs$1, size$1, exp$2[i$1], reduction);
|
|
2042
|
+
const outId = builder.pushKernel(kernel, inputArgs);
|
|
2043
|
+
ctx.set(outVar, {
|
|
2044
|
+
type: "imm",
|
|
2045
|
+
arg: outId
|
|
2046
|
+
});
|
|
2047
|
+
} else if (reduction) ctx.set(outVar, {
|
|
2048
|
+
type: "red",
|
|
2049
|
+
exp: exp$2[i$1],
|
|
2050
|
+
reduction,
|
|
2051
|
+
args: inputArgs
|
|
1950
2052
|
});
|
|
1951
|
-
|
|
1952
|
-
|
|
1953
|
-
|
|
1954
|
-
|
|
1955
|
-
|
|
1956
|
-
}
|
|
1957
|
-
else ctx.set(outVar, {
|
|
1958
|
-
type: "exp",
|
|
1959
|
-
exp: exp$2,
|
|
1960
|
-
args: inputArgs
|
|
1961
|
-
});
|
|
2053
|
+
else ctx.set(outVar, {
|
|
2054
|
+
type: "exp",
|
|
2055
|
+
exp: exp$2[i$1],
|
|
2056
|
+
args: inputArgs
|
|
2057
|
+
});
|
|
2058
|
+
}
|
|
1962
2059
|
}
|
|
1963
2060
|
const outputIds = [];
|
|
1964
2061
|
for (const out of jaxpr.outs) if (out instanceof Var) {
|
|
@@ -1999,17 +2096,17 @@ function broadcastedJit(fn, opts) {
|
|
|
1999
2096
|
if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = AluExp.cast(newDtype, exp$2);
|
|
2000
2097
|
return exp$2;
|
|
2001
2098
|
});
|
|
2002
|
-
return { exp: fn(exps, params) };
|
|
2099
|
+
return { exp: [fn(exps, params)] };
|
|
2003
2100
|
};
|
|
2004
2101
|
}
|
|
2005
2102
|
function unopJit(fn) {
|
|
2006
2103
|
return ([a], [_as], params) => {
|
|
2007
|
-
return { exp: fn(a, params) };
|
|
2104
|
+
return { exp: [fn(a, params)] };
|
|
2008
2105
|
};
|
|
2009
2106
|
}
|
|
2010
2107
|
function reshapeJit(fn) {
|
|
2011
2108
|
return ([a], [_as], params) => {
|
|
2012
|
-
return { exp: reshapeViews(a, (st) => fn(st, params)) };
|
|
2109
|
+
return { exp: [reshapeViews(a, (st) => fn(st, params))] };
|
|
2013
2110
|
};
|
|
2014
2111
|
}
|
|
2015
2112
|
function routineNoJit() {
|
|
@@ -2055,7 +2152,7 @@ const jitRules = {
|
|
|
2055
2152
|
a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
|
|
2056
2153
|
const reduction = new Reduction(a.dtype, op, reductionSize);
|
|
2057
2154
|
return {
|
|
2058
|
-
exp: a,
|
|
2155
|
+
exp: [a],
|
|
2059
2156
|
reduction
|
|
2060
2157
|
};
|
|
2061
2158
|
},
|
|
@@ -2066,13 +2163,13 @@ const jitRules = {
|
|
|
2066
2163
|
a = reshapeViews(a, (st) => st.compose(stX), true);
|
|
2067
2164
|
const reduction = new Reduction(a.dtype, AluOp.Add, stX.shape[stX.shape.length - 1]);
|
|
2068
2165
|
return {
|
|
2069
|
-
exp: a,
|
|
2166
|
+
exp: [a],
|
|
2070
2167
|
reduction
|
|
2071
2168
|
};
|
|
2072
2169
|
},
|
|
2073
2170
|
[Primitive.Dot]([a, b], [as, bs]) {
|
|
2074
2171
|
const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
|
|
2075
|
-
const c = k1.exp;
|
|
2172
|
+
const [c] = k1.exp;
|
|
2076
2173
|
const cs = promoteAvals(as, bs);
|
|
2077
2174
|
return jitRules[Primitive.Reduce]([c], [cs], {
|
|
2078
2175
|
op: AluOp.Add,
|
|
@@ -2089,16 +2186,42 @@ const jitRules = {
|
|
|
2089
2186
|
},
|
|
2090
2187
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
2091
2188
|
[Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
2189
|
+
[Primitive.Concatenate](exps, avals, { axis }) {
|
|
2190
|
+
const ndim$2 = avals[0].ndim;
|
|
2191
|
+
const sizes = avals.map((x) => x.shape[axis]);
|
|
2192
|
+
const finalSize = sizes.reduce((a, b) => a + b, 0);
|
|
2193
|
+
const { dtype: dtypeOut } = avals.map((x) => x.scalar()).reduce(promoteAvals);
|
|
2194
|
+
const makePadAxis = (start, end) => range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
|
|
2195
|
+
let cum = 0;
|
|
2196
|
+
const src = [];
|
|
2197
|
+
for (let i = 0; i < exps.length; i++) {
|
|
2198
|
+
const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
|
|
2199
|
+
src.push(reshapeViews(AluExp.cast(dtypeOut, exps[i]), (st) => st.pad(padding)));
|
|
2200
|
+
cum += sizes[i];
|
|
2201
|
+
}
|
|
2202
|
+
return { exp: [src.reduce(AluExp.add)] };
|
|
2203
|
+
},
|
|
2204
|
+
[Primitive.Split]([a], [as], { axis, sizes }) {
|
|
2205
|
+
const exp$2 = [];
|
|
2206
|
+
let start = 0;
|
|
2207
|
+
for (const size$1 of sizes) {
|
|
2208
|
+
const slice = range(as.ndim).map((d) => d === axis ? [start, start + size$1] : [0, as.shape[d]]);
|
|
2209
|
+
exp$2.push(reshapeViews(a, (st) => st.shrink(slice)));
|
|
2210
|
+
start += size$1;
|
|
2211
|
+
}
|
|
2212
|
+
return { exp: exp$2 };
|
|
2213
|
+
},
|
|
2092
2214
|
[Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
|
|
2215
|
+
const keyShape = keyShapes[0].shape;
|
|
2093
2216
|
const mapping = (st) => {
|
|
2094
|
-
if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(shape
|
|
2217
|
+
if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(st.shape.length, shape$1.length));
|
|
2095
2218
|
};
|
|
2096
2219
|
const k0 = reshapeViews(keys[0], mapping);
|
|
2097
2220
|
const k1 = reshapeViews(keys[1], mapping);
|
|
2098
2221
|
const c0 = AluExp.u32(0);
|
|
2099
|
-
const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
|
|
2222
|
+
const c1 = AluExp.mod(AluExp.cast(DType.Uint32, AluVar.gidx), AluExp.u32(Math.max(prod(shape$1.slice(keyShape.length)), 1)));
|
|
2100
2223
|
const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
2101
|
-
return { exp: exp$2 };
|
|
2224
|
+
return { exp: [exp$2] };
|
|
2102
2225
|
},
|
|
2103
2226
|
[Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
2104
2227
|
const axisSet = new Set(axis);
|
|
@@ -2113,7 +2236,7 @@ const jitRules = {
|
|
|
2113
2236
|
for (const [i, iexp] of indices.entries()) src[axis[i]] = AluExp.cast(DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...range(outDim + indexShape.length - st.shape.length), ...range(outDim + indexShape.length, finalShape.length)])));
|
|
2114
2237
|
const [index, valid] = ShapeTracker.fromShape(xs.shape).toAluExp(src);
|
|
2115
2238
|
if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
|
|
2116
|
-
return { exp: x.substitute({ gidx: index }) };
|
|
2239
|
+
return { exp: [x.substitute({ gidx: index })] };
|
|
2117
2240
|
},
|
|
2118
2241
|
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
2119
2242
|
[Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
|
|
@@ -2129,6 +2252,7 @@ const jitRules = {
|
|
|
2129
2252
|
[Primitive.Argsort]: routineNoJit(),
|
|
2130
2253
|
[Primitive.TriangularSolve]: routineNoJit(),
|
|
2131
2254
|
[Primitive.Cholesky]: routineNoJit(),
|
|
2255
|
+
[Primitive.LU]: routineNoJit(),
|
|
2132
2256
|
[Primitive.Jit]() {
|
|
2133
2257
|
throw new Error("internal: Jit should have been flattened before JIT compilation");
|
|
2134
2258
|
}
|
|
@@ -2210,7 +2334,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2210
2334
|
p1NextBlack.set(v, v);
|
|
2211
2335
|
}
|
|
2212
2336
|
const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
|
|
2213
|
-
const needsCleanShapePrimitives = [Primitive.Pad];
|
|
2337
|
+
const needsCleanShapePrimitives = [Primitive.Concatenate, Primitive.Pad];
|
|
2214
2338
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
2215
2339
|
const eqn = jaxpr.eqns[i];
|
|
2216
2340
|
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
@@ -2280,7 +2404,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2280
2404
|
|
|
2281
2405
|
//#endregion
|
|
2282
2406
|
//#region src/frontend/array.ts
|
|
2283
|
-
const JsArray = globalThis.Array;
|
|
2407
|
+
const JsArray$1 = globalThis.Array;
|
|
2284
2408
|
const inlineArrayLimit = 128;
|
|
2285
2409
|
/** Version of pureArray with fudged types. */
|
|
2286
2410
|
const fudgeArray = pureArray;
|
|
@@ -2407,6 +2531,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2407
2531
|
this.#rc++;
|
|
2408
2532
|
return this;
|
|
2409
2533
|
}
|
|
2534
|
+
/** Get the current reference count (for debugging memory management). */
|
|
2535
|
+
get refCount() {
|
|
2536
|
+
return this.#rc;
|
|
2537
|
+
}
|
|
2410
2538
|
dispose() {
|
|
2411
2539
|
this.#check();
|
|
2412
2540
|
if (--this.#rc === 0) {
|
|
@@ -2564,7 +2692,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2564
2692
|
} else if (castDtype === void 0) {
|
|
2565
2693
|
castDtype = arrays[i].#dtype;
|
|
2566
2694
|
castWeakType = arrays[i].#weakType;
|
|
2567
|
-
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType),
|
|
2695
|
+
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), arrays[i].aval.scalar()));
|
|
2568
2696
|
const weakType = castWeakType && !strongTypeOutput;
|
|
2569
2697
|
const { backend, committed } = Array$1.#computeBackend(name, arrays);
|
|
2570
2698
|
arrays = arrays.map((ar) => ar._putSync(backend));
|
|
@@ -2674,25 +2802,35 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2674
2802
|
});
|
|
2675
2803
|
}
|
|
2676
2804
|
/** Apply an operation with custom lowering to this array. */
|
|
2677
|
-
static #routine(
|
|
2678
|
-
|
|
2679
|
-
|
|
2680
|
-
|
|
2681
|
-
|
|
2682
|
-
|
|
2683
|
-
|
|
2684
|
-
|
|
2685
|
-
|
|
2686
|
-
|
|
2687
|
-
|
|
2688
|
-
|
|
2689
|
-
|
|
2690
|
-
dtype
|
|
2691
|
-
|
|
2692
|
-
|
|
2693
|
-
|
|
2694
|
-
pending
|
|
2695
|
-
|
|
2805
|
+
static #routine(prim) {
|
|
2806
|
+
return (arrays, params) => {
|
|
2807
|
+
const { backend, committed } = Array$1.#computeBackend(prim, arrays);
|
|
2808
|
+
for (const ar of arrays) ar.#realize();
|
|
2809
|
+
const avals = arrays.map((ar) => ar.aval);
|
|
2810
|
+
const avalsOut = abstractEvalRules[prim](avals, params);
|
|
2811
|
+
const routine = new Routine(routinePrimitives.get(prim), {
|
|
2812
|
+
inputShapes: avals.map((a) => a.shape),
|
|
2813
|
+
inputDtypes: avals.map((a) => a.dtype),
|
|
2814
|
+
outputShapes: avalsOut.map((a) => a.shape),
|
|
2815
|
+
outputDtypes: avalsOut.map((a) => a.dtype)
|
|
2816
|
+
}, params);
|
|
2817
|
+
const inputs = arrays.map((ar) => ar.#source);
|
|
2818
|
+
const outputs = avalsOut.map((x) => backend.malloc(byteWidth(x.dtype) * x.size));
|
|
2819
|
+
const pending = arrays.flatMap((ar) => ar.#pending);
|
|
2820
|
+
for (const exe of pending) exe.updateRc(+outputs.length);
|
|
2821
|
+
pending.push(new PendingExecute(backend, routine, inputs, outputs));
|
|
2822
|
+
pending[pending.length - 1].updateRc(+outputs.length - 1);
|
|
2823
|
+
arrays.forEach((ar) => ar.dispose());
|
|
2824
|
+
return outputs.map((output, i) => new Array$1({
|
|
2825
|
+
source: output,
|
|
2826
|
+
st: ShapeTracker.fromShape(avalsOut[i].shape),
|
|
2827
|
+
dtype: avalsOut[i].dtype,
|
|
2828
|
+
weakType: avalsOut[i].weakType,
|
|
2829
|
+
backend,
|
|
2830
|
+
committed,
|
|
2831
|
+
pending
|
|
2832
|
+
}));
|
|
2833
|
+
};
|
|
2696
2834
|
}
|
|
2697
2835
|
/**
|
|
2698
2836
|
* Normalizes this array into one backed by a `Slot`.
|
|
@@ -2957,17 +3095,44 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2957
3095
|
y
|
|
2958
3096
|
], { dtypeOverride: [DType.Bool] })];
|
|
2959
3097
|
},
|
|
3098
|
+
[Primitive.Concatenate](xs, { axis }) {
|
|
3099
|
+
const ndim$2 = xs[0].ndim;
|
|
3100
|
+
const sizes = xs.map((x) => x.shape[axis]);
|
|
3101
|
+
const finalSize = sizes.reduce((a, b) => a + b, 0);
|
|
3102
|
+
const makePadAxis = (start, end) => range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
|
|
3103
|
+
let cum = 0;
|
|
3104
|
+
const xsPadded = [];
|
|
3105
|
+
for (let i = 0; i < xs.length; i++) {
|
|
3106
|
+
const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
|
|
3107
|
+
xsPadded.push(xs[i].#reshape(xs[i].#st.pad(padding)));
|
|
3108
|
+
cum += sizes[i];
|
|
3109
|
+
}
|
|
3110
|
+
const custom = (exps) => exps.reduce(AluExp.add);
|
|
3111
|
+
return [Array$1.#naryCustom("concatenate", custom, xsPadded)];
|
|
3112
|
+
},
|
|
3113
|
+
[Primitive.Split]([x], { axis, sizes }) {
|
|
3114
|
+
const outputs = [];
|
|
3115
|
+
for (let i = 0, start = 0; i < sizes.length; i++) {
|
|
3116
|
+
const slice = range(x.ndim).map((d) => d === axis ? [start, start + sizes[i]] : [0, x.shape[d]]);
|
|
3117
|
+
outputs.push(x.ref.#reshape(x.#st.shrink(slice)));
|
|
3118
|
+
start += sizes[i];
|
|
3119
|
+
}
|
|
3120
|
+
x.dispose();
|
|
3121
|
+
return outputs;
|
|
3122
|
+
},
|
|
2960
3123
|
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
2961
|
-
const keyShape =
|
|
2962
|
-
|
|
2963
|
-
const c0 = zeros(
|
|
3124
|
+
const keyShape = k0.shape;
|
|
3125
|
+
const genShape = shape$1.slice(keyShape.length);
|
|
3126
|
+
const c0 = zeros(genShape, {
|
|
2964
3127
|
dtype: DType.Uint32,
|
|
2965
3128
|
device: k0.device
|
|
2966
3129
|
});
|
|
2967
|
-
const c1 = arange(0, prod(
|
|
3130
|
+
const c1 = arange(0, prod(genShape), 1, {
|
|
2968
3131
|
dtype: DType.Uint32,
|
|
2969
3132
|
device: k0.device
|
|
2970
|
-
}).reshape(
|
|
3133
|
+
}).reshape(genShape);
|
|
3134
|
+
k0 = k0.#reshape(k0.#st.reshape(keyShape.concat(rep(genShape.length, 1))));
|
|
3135
|
+
k1 = k1.#reshape(k1.#st.reshape(keyShape.concat(rep(genShape.length, 1))));
|
|
2971
3136
|
const custom = ([k0$1, k1$1, c0$1, c1$1]) => AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
|
|
2972
3137
|
return [Array$1.#naryCustom("random_bits", custom, [
|
|
2973
3138
|
k0,
|
|
@@ -2999,42 +3164,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2999
3164
|
[Primitive.Pad]([x], { width }) {
|
|
3000
3165
|
return [x.#reshape(x.#st.pad(width))];
|
|
3001
3166
|
},
|
|
3002
|
-
[Primitive.Sort](
|
|
3003
|
-
|
|
3004
|
-
|
|
3005
|
-
|
|
3006
|
-
|
|
3007
|
-
outputDtypes: [x.aval.dtype]
|
|
3008
|
-
});
|
|
3009
|
-
return Array$1.#routine(routine, [x], [x.#weakType]);
|
|
3010
|
-
},
|
|
3011
|
-
[Primitive.Argsort]([x]) {
|
|
3012
|
-
const routine = new Routine(Routines.Argsort, {
|
|
3013
|
-
inputShapes: [x.aval.shape],
|
|
3014
|
-
inputDtypes: [x.aval.dtype],
|
|
3015
|
-
outputShapes: [x.aval.shape, x.aval.shape],
|
|
3016
|
-
outputDtypes: [x.aval.dtype, DType.Int32]
|
|
3017
|
-
});
|
|
3018
|
-
return Array$1.#routine(routine, [x], [x.#weakType, false]);
|
|
3019
|
-
},
|
|
3020
|
-
[Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
|
|
3021
|
-
const routine = new Routine(Routines.TriangularSolve, {
|
|
3022
|
-
inputShapes: [a.aval.shape, b.aval.shape],
|
|
3023
|
-
inputDtypes: [a.aval.dtype, b.aval.dtype],
|
|
3024
|
-
outputShapes: [b.aval.shape],
|
|
3025
|
-
outputDtypes: [b.aval.dtype]
|
|
3026
|
-
}, { unitDiagonal });
|
|
3027
|
-
return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
|
|
3028
|
-
},
|
|
3029
|
-
[Primitive.Cholesky]([a]) {
|
|
3030
|
-
const routine = new Routine(Routines.Cholesky, {
|
|
3031
|
-
inputShapes: [a.aval.shape],
|
|
3032
|
-
inputDtypes: [a.aval.dtype],
|
|
3033
|
-
outputShapes: [a.aval.shape],
|
|
3034
|
-
outputDtypes: [a.aval.dtype]
|
|
3035
|
-
});
|
|
3036
|
-
return Array$1.#routine(routine, [a], [a.#weakType]);
|
|
3037
|
-
},
|
|
3167
|
+
[Primitive.Sort]: Array$1.#routine(Primitive.Sort),
|
|
3168
|
+
[Primitive.Argsort]: Array$1.#routine(Primitive.Argsort),
|
|
3169
|
+
[Primitive.TriangularSolve]: Array$1.#routine(Primitive.TriangularSolve),
|
|
3170
|
+
[Primitive.Cholesky]: Array$1.#routine(Primitive.Cholesky),
|
|
3171
|
+
[Primitive.LU]: Array$1.#routine(Primitive.LU),
|
|
3038
3172
|
[Primitive.Jit](args, { jaxpr }) {
|
|
3039
3173
|
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
3040
3174
|
const { backend, committed } = Array$1.#computeBackend("jit", args);
|
|
@@ -3116,7 +3250,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
3116
3250
|
if (!shape$1) {
|
|
3117
3251
|
shape$1 = [];
|
|
3118
3252
|
let cur = values;
|
|
3119
|
-
while (JsArray.isArray(cur)) {
|
|
3253
|
+
while (JsArray$1.isArray(cur)) {
|
|
3120
3254
|
shape$1.push(cur.length);
|
|
3121
3255
|
cur = cur[0];
|
|
3122
3256
|
}
|
|
@@ -3140,7 +3274,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
3140
3274
|
device
|
|
3141
3275
|
});
|
|
3142
3276
|
} else {
|
|
3143
|
-
const weakType = dtype == void 0;
|
|
3277
|
+
const weakType = dtype == void 0 && shape$1.length === 0;
|
|
3144
3278
|
dtype = dtype ?? DType.Float32;
|
|
3145
3279
|
const data = dtypedJsArray(dtype, flat);
|
|
3146
3280
|
return arrayFromData(data, shape$1, {
|
|
@@ -3254,7 +3388,7 @@ function ones(shape$1, { dtype, device } = {}) {
|
|
|
3254
3388
|
}
|
|
3255
3389
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
3256
3390
|
function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
3257
|
-
let weakType = dtype == void 0;
|
|
3391
|
+
let weakType = dtype == void 0 && shape$1.length === 0;
|
|
3258
3392
|
if (typeof fillValue === "number") dtype = dtype ?? DType.Float32;
|
|
3259
3393
|
else if (typeof fillValue === "boolean") {
|
|
3260
3394
|
dtype = dtype ?? DType.Bool;
|
|
@@ -3412,6 +3546,27 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
3412
3546
|
committed: device != void 0
|
|
3413
3547
|
});
|
|
3414
3548
|
}
|
|
3549
|
+
/**
|
|
3550
|
+
* Return numbers spaced evenly on a log scale.
|
|
3551
|
+
*
|
|
3552
|
+
* In linear space, the sequence starts at `base ** start` and ends at
|
|
3553
|
+
* `base ** stop` (see `endpoint` below).
|
|
3554
|
+
*
|
|
3555
|
+
* @param start - `base ** start` is the starting value of the sequence.
|
|
3556
|
+
* @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
|
|
3557
|
+
* @param num - Number of samples to generate. Default is 50.
|
|
3558
|
+
* @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
|
|
3559
|
+
* @param base - The base of the log space. Default is 10.
|
|
3560
|
+
* @returns Array of evenly spaced values on a log scale.
|
|
3561
|
+
*/
|
|
3562
|
+
function logspace(start, stop, num = 50, endpoint = true, base = 10, { dtype, device } = {}) {
|
|
3563
|
+
const y = linspace(start, stop, num, endpoint, {
|
|
3564
|
+
dtype,
|
|
3565
|
+
device
|
|
3566
|
+
});
|
|
3567
|
+
const logBase = Math.log(base);
|
|
3568
|
+
return exp$1(mul(y, logBase));
|
|
3569
|
+
}
|
|
3415
3570
|
function aluCompare(a, b, op) {
|
|
3416
3571
|
switch (op) {
|
|
3417
3572
|
case CompareOp.Less: return AluExp.cmplt(a, b);
|
|
@@ -3488,6 +3643,7 @@ var BatchTrace = class extends Trace {
|
|
|
3488
3643
|
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3489
3644
|
}
|
|
3490
3645
|
const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
|
|
3646
|
+
if (valOuts.length !== bdimOuts.length) throw new Error(`vmap rule for ${primitive} returned mismatched lengths: ${valOuts.length} vs ${bdimOuts.length}`);
|
|
3491
3647
|
return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
|
|
3492
3648
|
}
|
|
3493
3649
|
get axisSize() {
|
|
@@ -3499,13 +3655,13 @@ var BatchTrace = class extends Trace {
|
|
|
3499
3655
|
*
|
|
3500
3656
|
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3501
3657
|
*/
|
|
3502
|
-
function broadcastBatcher(
|
|
3503
|
-
return (axisSize, args, dims) => {
|
|
3658
|
+
function broadcastBatcher(prim) {
|
|
3659
|
+
return (axisSize, args, dims, params) => {
|
|
3504
3660
|
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3505
3661
|
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3506
3662
|
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3507
3663
|
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3508
|
-
if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[
|
|
3664
|
+
if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[bind1(prim, args, params)], [nd + firstBdim]];
|
|
3509
3665
|
args = args.map((x, i) => {
|
|
3510
3666
|
if (dims[i] === null) return x;
|
|
3511
3667
|
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
@@ -3516,37 +3672,45 @@ function broadcastBatcher(op) {
|
|
|
3516
3672
|
]);
|
|
3517
3673
|
return x;
|
|
3518
3674
|
});
|
|
3519
|
-
return [[
|
|
3675
|
+
return [[bind1(prim, args, params)], [0]];
|
|
3520
3676
|
};
|
|
3521
3677
|
}
|
|
3522
|
-
function unopBatcher(
|
|
3678
|
+
function unopBatcher(prim) {
|
|
3523
3679
|
return (axisSize, [x], [xBdim], params) => {
|
|
3524
|
-
return [[
|
|
3680
|
+
return [[bind1(prim, [x], params)], [xBdim]];
|
|
3681
|
+
};
|
|
3682
|
+
}
|
|
3683
|
+
function lastDimsBatcher(prim, inputDims, numOutputs = 1) {
|
|
3684
|
+
return (axisSize, [x], [xBdim], params) => {
|
|
3685
|
+
assertNonNull(xBdim);
|
|
3686
|
+
if (xBdim < x.ndim - inputDims) return [bind(prim, [x], params), rep(numOutputs, xBdim)];
|
|
3687
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3688
|
+
return [bind(prim, [x], params), rep(numOutputs, 0)];
|
|
3525
3689
|
};
|
|
3526
3690
|
}
|
|
3527
3691
|
const vmapRules = {
|
|
3528
|
-
[Primitive.Add]: broadcastBatcher(
|
|
3529
|
-
[Primitive.Mul]: broadcastBatcher(
|
|
3530
|
-
[Primitive.Idiv]: broadcastBatcher(
|
|
3531
|
-
[Primitive.Mod]: broadcastBatcher(
|
|
3532
|
-
[Primitive.Min]: broadcastBatcher(
|
|
3533
|
-
[Primitive.Max]: broadcastBatcher(
|
|
3534
|
-
[Primitive.Neg]: unopBatcher(
|
|
3535
|
-
[Primitive.Reciprocal]: unopBatcher(
|
|
3536
|
-
[Primitive.Floor]: unopBatcher(
|
|
3537
|
-
[Primitive.Ceil]: unopBatcher(
|
|
3538
|
-
[Primitive.StopGradient]: unopBatcher(
|
|
3539
|
-
[Primitive.Cast]: unopBatcher(
|
|
3540
|
-
[Primitive.Bitcast]: unopBatcher(
|
|
3541
|
-
[Primitive.Sin]: unopBatcher(
|
|
3542
|
-
[Primitive.Cos]: unopBatcher(
|
|
3543
|
-
[Primitive.Asin]: unopBatcher(
|
|
3544
|
-
[Primitive.Atan]: unopBatcher(
|
|
3545
|
-
[Primitive.Exp]: unopBatcher(
|
|
3546
|
-
[Primitive.Log]: unopBatcher(
|
|
3547
|
-
[Primitive.Erf]: unopBatcher(
|
|
3548
|
-
[Primitive.Erfc]: unopBatcher(
|
|
3549
|
-
[Primitive.Sqrt]: unopBatcher(
|
|
3692
|
+
[Primitive.Add]: broadcastBatcher(Primitive.Add),
|
|
3693
|
+
[Primitive.Mul]: broadcastBatcher(Primitive.Mul),
|
|
3694
|
+
[Primitive.Idiv]: broadcastBatcher(Primitive.Idiv),
|
|
3695
|
+
[Primitive.Mod]: broadcastBatcher(Primitive.Mod),
|
|
3696
|
+
[Primitive.Min]: broadcastBatcher(Primitive.Min),
|
|
3697
|
+
[Primitive.Max]: broadcastBatcher(Primitive.Max),
|
|
3698
|
+
[Primitive.Neg]: unopBatcher(Primitive.Neg),
|
|
3699
|
+
[Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
|
|
3700
|
+
[Primitive.Floor]: unopBatcher(Primitive.Floor),
|
|
3701
|
+
[Primitive.Ceil]: unopBatcher(Primitive.Ceil),
|
|
3702
|
+
[Primitive.StopGradient]: unopBatcher(Primitive.StopGradient),
|
|
3703
|
+
[Primitive.Cast]: unopBatcher(Primitive.Cast),
|
|
3704
|
+
[Primitive.Bitcast]: unopBatcher(Primitive.Bitcast),
|
|
3705
|
+
[Primitive.Sin]: unopBatcher(Primitive.Sin),
|
|
3706
|
+
[Primitive.Cos]: unopBatcher(Primitive.Cos),
|
|
3707
|
+
[Primitive.Asin]: unopBatcher(Primitive.Asin),
|
|
3708
|
+
[Primitive.Atan]: unopBatcher(Primitive.Atan),
|
|
3709
|
+
[Primitive.Exp]: unopBatcher(Primitive.Exp),
|
|
3710
|
+
[Primitive.Log]: unopBatcher(Primitive.Log),
|
|
3711
|
+
[Primitive.Erf]: unopBatcher(Primitive.Erf),
|
|
3712
|
+
[Primitive.Erfc]: unopBatcher(Primitive.Erfc),
|
|
3713
|
+
[Primitive.Sqrt]: unopBatcher(Primitive.Sqrt),
|
|
3550
3714
|
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3551
3715
|
assertNonNull(xBdim);
|
|
3552
3716
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
@@ -3568,10 +3732,25 @@ const vmapRules = {
|
|
|
3568
3732
|
});
|
|
3569
3733
|
return [[z], [0]];
|
|
3570
3734
|
},
|
|
3571
|
-
[Primitive.Compare](
|
|
3572
|
-
|
|
3735
|
+
[Primitive.Compare]: broadcastBatcher(Primitive.Compare),
|
|
3736
|
+
[Primitive.Where]: broadcastBatcher(Primitive.Where),
|
|
3737
|
+
[Primitive.Concatenate](axisSize, xs, xBdims, { axis }) {
|
|
3738
|
+
const minBdim = Math.min(...xBdims.filter((d) => d !== null));
|
|
3739
|
+
xs = xs.map((x, i) => moveBatchAxis(axisSize, xBdims[i], minBdim, x));
|
|
3740
|
+
const newAxis = axis + (minBdim <= axis ? 1 : 0);
|
|
3741
|
+
return [[concatenate$1(xs, newAxis)], [minBdim]];
|
|
3742
|
+
},
|
|
3743
|
+
[Primitive.Split](axisSize, [x], [xBdim], { axis, sizes }) {
|
|
3744
|
+
assertNonNull(xBdim);
|
|
3745
|
+
const newAxis = axis + (xBdim <= axis ? 1 : 0);
|
|
3746
|
+
const outs = split$2(x, newAxis, sizes);
|
|
3747
|
+
return [outs, rep(outs.length, xBdim)];
|
|
3748
|
+
},
|
|
3749
|
+
[Primitive.RandomBits](axisSize, [k0, k1], [bdim0, bdim1], { shape: shape$1, mode }) {
|
|
3750
|
+
k0 = moveBatchAxis(axisSize, bdim0, 0, k0);
|
|
3751
|
+
k1 = moveBatchAxis(axisSize, bdim1, 0, k1);
|
|
3752
|
+
return [[randomBits(k0, k1, [axisSize, ...shape$1], mode)], [0]];
|
|
3573
3753
|
},
|
|
3574
|
-
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3575
3754
|
[Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
|
|
3576
3755
|
if (indicesBdim.every((d) => d === null)) {
|
|
3577
3756
|
assertNonNull(xBdim);
|
|
@@ -3633,18 +3812,8 @@ const vmapRules = {
|
|
|
3633
3812
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3634
3813
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3635
3814
|
},
|
|
3636
|
-
[Primitive.Sort](
|
|
3637
|
-
|
|
3638
|
-
if (xBdim !== x.ndim - 1) return [[sort$1(x)], [xBdim]];
|
|
3639
|
-
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3640
|
-
return [[sort$1(x)], [0]];
|
|
3641
|
-
},
|
|
3642
|
-
[Primitive.Argsort](axisSize, [x], [xBdim]) {
|
|
3643
|
-
assertNonNull(xBdim);
|
|
3644
|
-
if (xBdim !== x.ndim - 1) return [argsort$1(x), [xBdim, xBdim]];
|
|
3645
|
-
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3646
|
-
return [argsort$1(x), [0, 0]];
|
|
3647
|
-
},
|
|
3815
|
+
[Primitive.Sort]: lastDimsBatcher(Primitive.Sort, 1),
|
|
3816
|
+
[Primitive.Argsort]: lastDimsBatcher(Primitive.Argsort, 1, 2),
|
|
3648
3817
|
[Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
|
|
3649
3818
|
if (aBdim === null) {
|
|
3650
3819
|
b = moveBatchAxis(axisSize, bBdim, -3, b);
|
|
@@ -3668,12 +3837,8 @@ const vmapRules = {
|
|
|
3668
3837
|
const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
3669
3838
|
return [[x], [0]];
|
|
3670
3839
|
},
|
|
3671
|
-
[Primitive.Cholesky](
|
|
3672
|
-
|
|
3673
|
-
if (xBdim < x.ndim - 2) return [[cholesky$2(x)], [xBdim]];
|
|
3674
|
-
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3675
|
-
return [[cholesky$2(x)], [0]];
|
|
3676
|
-
},
|
|
3840
|
+
[Primitive.Cholesky]: lastDimsBatcher(Primitive.Cholesky, 2),
|
|
3841
|
+
[Primitive.LU]: lastDimsBatcher(Primitive.LU, 2, 3),
|
|
3677
3842
|
[Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
|
|
3678
3843
|
const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3679
3844
|
const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
|
|
@@ -3823,6 +3988,16 @@ function batchMatmulT(a, b) {
|
|
|
3823
3988
|
function mT(a) {
|
|
3824
3989
|
return moveaxis(a, -2, -1);
|
|
3825
3990
|
}
|
|
3991
|
+
function sliceAxis(a, axis, p) {
|
|
3992
|
+
const slices = Array(a.shape.length).fill([]);
|
|
3993
|
+
slices[checkAxis(axis, a.ndim)] = p;
|
|
3994
|
+
return a.slice(...slices);
|
|
3995
|
+
}
|
|
3996
|
+
function padAxis(a, axis, p) {
|
|
3997
|
+
const pads = Array(a.shape.length).fill([0, 0]);
|
|
3998
|
+
pads[checkAxis(axis, a.ndim)] = p;
|
|
3999
|
+
return pad$1(a, pads);
|
|
4000
|
+
}
|
|
3826
4001
|
const jvpRules = {
|
|
3827
4002
|
[Primitive.Add]: linearTangentsJvp(Primitive.Add),
|
|
3828
4003
|
[Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
|
|
@@ -3921,6 +4096,8 @@ const jvpRules = {
|
|
|
3921
4096
|
dcond.dispose();
|
|
3922
4097
|
return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
|
|
3923
4098
|
},
|
|
4099
|
+
[Primitive.Concatenate]: linearTangentsJvp(Primitive.Concatenate),
|
|
4100
|
+
[Primitive.Split]: linearTangentsJvp(Primitive.Split),
|
|
3924
4101
|
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
3925
4102
|
[Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
|
|
3926
4103
|
const indicesRef = indices.map((t) => t.ref);
|
|
@@ -3955,6 +4132,38 @@ const jvpRules = {
|
|
|
3955
4132
|
const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
|
|
3956
4133
|
return [[L], [dL]];
|
|
3957
4134
|
},
|
|
4135
|
+
[Primitive.LU]([a], [da]) {
|
|
4136
|
+
const [luMatrix, pivots, permutation] = lu$1(a);
|
|
4137
|
+
const [m, n] = a.shape.slice(-2);
|
|
4138
|
+
const k = Math.min(m, n);
|
|
4139
|
+
const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
|
|
4140
|
+
const lLower = tril(luSliceL, -1);
|
|
4141
|
+
const lPadded = m > k ? padAxis(lLower, -1, [0, m - k]) : lLower;
|
|
4142
|
+
const L = lPadded.add(eye(m));
|
|
4143
|
+
const luSliceU = sliceAxis(luMatrix.ref, -2, [0, k]);
|
|
4144
|
+
const uUpper = triu(luSliceU);
|
|
4145
|
+
const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
|
|
4146
|
+
const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
|
|
4147
|
+
const U = uPadded.add(uEye);
|
|
4148
|
+
const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
|
|
4149
|
+
const pda = batchMatmulT(P, mT(da));
|
|
4150
|
+
const la = mT(triangularSolve$1(L.ref, mT(pda), {
|
|
4151
|
+
lower: true,
|
|
4152
|
+
unitDiagonal: true
|
|
4153
|
+
}));
|
|
4154
|
+
const lau = triangularSolve$1(mT(U.ref), la, { lower: true });
|
|
4155
|
+
const lDot = batchMatmulT(L, mT(tril(lau.ref, -1)));
|
|
4156
|
+
const uDot = batchMatmulT(triu(lau), mT(U));
|
|
4157
|
+
return [[
|
|
4158
|
+
luMatrix,
|
|
4159
|
+
pivots,
|
|
4160
|
+
permutation
|
|
4161
|
+
], [
|
|
4162
|
+
lDot.add(uDot),
|
|
4163
|
+
zerosLike$1(pivots.ref),
|
|
4164
|
+
zerosLike$1(permutation.ref)
|
|
4165
|
+
]];
|
|
4166
|
+
},
|
|
3958
4167
|
[Primitive.Jit](primals, tangents, { name, jaxpr }) {
|
|
3959
4168
|
const newJaxpr = jvpJaxpr(jaxpr);
|
|
3960
4169
|
const outs = bind(Primitive.Jit, [
|
|
@@ -3995,17 +4204,39 @@ function jvpFlat(f, primals, tangents) {
|
|
|
3995
4204
|
_usingCtx$1.d();
|
|
3996
4205
|
}
|
|
3997
4206
|
}
|
|
3998
|
-
function jvp$1(f, primals, tangents) {
|
|
4207
|
+
function jvp$1(f, primals, tangents, { hasAux = false } = {}) {
|
|
3999
4208
|
const [primalsFlat, inTree] = flatten(primals);
|
|
4000
4209
|
const [tangentsFlat, inTree2] = flatten(tangents);
|
|
4001
4210
|
if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
|
|
4002
|
-
|
|
4211
|
+
let flatFun, outTree, aux;
|
|
4212
|
+
if (hasAux) [flatFun, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4213
|
+
else [flatFun, outTree] = flattenFun(f, inTree);
|
|
4003
4214
|
const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
|
|
4004
4215
|
if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
|
|
4005
4216
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
4006
4217
|
const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
|
|
4218
|
+
if (hasAux) return [
|
|
4219
|
+
primalsOut,
|
|
4220
|
+
tangentsOut,
|
|
4221
|
+
lowerAux(aux.value)
|
|
4222
|
+
];
|
|
4007
4223
|
return [primalsOut, tangentsOut];
|
|
4008
4224
|
}
|
|
4225
|
+
/** Lowering for auxiliary data returned in `hasAux: true` methods. */
|
|
4226
|
+
function lowerAux(aux) {
|
|
4227
|
+
const level = currentTraceLevel();
|
|
4228
|
+
return map((x) => {
|
|
4229
|
+
if (x instanceof Tracer) while (x._trace.main.level > level) if (x instanceof JVPTracer) {
|
|
4230
|
+
x.tangent.dispose();
|
|
4231
|
+
x = x.primal;
|
|
4232
|
+
} else {
|
|
4233
|
+
const y = x.fullLower();
|
|
4234
|
+
if (y._trace.main.level >= x._trace.main.level) throw new Error("internal: lowerAux did not reduce trace level");
|
|
4235
|
+
x = y;
|
|
4236
|
+
}
|
|
4237
|
+
return x;
|
|
4238
|
+
}, aux);
|
|
4239
|
+
}
|
|
4009
4240
|
|
|
4010
4241
|
//#endregion
|
|
4011
4242
|
//#region src/frontend/linearize.ts
|
|
@@ -4076,9 +4307,11 @@ function linearizeFlat(f, primalsIn) {
|
|
|
4076
4307
|
dispose$1
|
|
4077
4308
|
];
|
|
4078
4309
|
}
|
|
4079
|
-
function linearize$1(f,
|
|
4310
|
+
function linearize$1(f, primalsIn, { hasAux = false } = {}) {
|
|
4080
4311
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4081
|
-
|
|
4312
|
+
let fFlat, outTree, aux;
|
|
4313
|
+
if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4314
|
+
else [fFlat, outTree] = flattenFun(f, inTree);
|
|
4082
4315
|
const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4083
4316
|
if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
|
|
4084
4317
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
@@ -4089,6 +4322,11 @@ function linearize$1(f, ...primalsIn) {
|
|
|
4089
4322
|
return unflatten(outTree.value, tangentsOutFlat);
|
|
4090
4323
|
});
|
|
4091
4324
|
fLin.dispose = dispose$1;
|
|
4325
|
+
if (hasAux) return [
|
|
4326
|
+
primalsOut,
|
|
4327
|
+
fLin,
|
|
4328
|
+
lowerAux(aux.value)
|
|
4329
|
+
];
|
|
4092
4330
|
return [primalsOut, fLin];
|
|
4093
4331
|
}
|
|
4094
4332
|
var PartialEvalTracer = class extends Tracer {
|
|
@@ -4492,6 +4730,15 @@ const transposeRules = {
|
|
|
4492
4730
|
cond.dispose();
|
|
4493
4731
|
return cts;
|
|
4494
4732
|
},
|
|
4733
|
+
[Primitive.Concatenate]([ct], inputs, { axis }) {
|
|
4734
|
+
if (inputs.some((x) => !(x instanceof UndefPrimal))) throw new NonlinearError(Primitive.Concatenate);
|
|
4735
|
+
const sizes = inputs.map((x) => x.aval.shape[axis]);
|
|
4736
|
+
return split$2(ct, axis, sizes);
|
|
4737
|
+
},
|
|
4738
|
+
[Primitive.Split](cts, [x], { axis }) {
|
|
4739
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Split);
|
|
4740
|
+
return [concatenate$1(cts, axis)];
|
|
4741
|
+
},
|
|
4495
4742
|
[Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
|
|
4496
4743
|
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
4497
4744
|
if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
@@ -4580,9 +4827,11 @@ function vjpFlat(f, primalsIn) {
|
|
|
4580
4827
|
dispose$1
|
|
4581
4828
|
];
|
|
4582
4829
|
}
|
|
4583
|
-
function vjp$1(f,
|
|
4830
|
+
function vjp$1(f, primalsIn, { hasAux = false } = {}) {
|
|
4584
4831
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4585
|
-
|
|
4832
|
+
let fFlat, outTree, aux;
|
|
4833
|
+
if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4834
|
+
else [fFlat, outTree] = flattenFun(f, inTree);
|
|
4586
4835
|
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4587
4836
|
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
4588
4837
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
@@ -4593,26 +4842,43 @@ function vjp$1(f, ...primalsIn) {
|
|
|
4593
4842
|
return unflatten(inTree, cotangentsInFlat);
|
|
4594
4843
|
});
|
|
4595
4844
|
fVjp.dispose = dispose$1;
|
|
4845
|
+
if (hasAux) return [
|
|
4846
|
+
primalsOut,
|
|
4847
|
+
fVjp,
|
|
4848
|
+
lowerAux(aux.value)
|
|
4849
|
+
];
|
|
4596
4850
|
return [primalsOut, fVjp];
|
|
4597
4851
|
}
|
|
4598
|
-
function grad$1(f) {
|
|
4599
|
-
const valueAndGradFn = valueAndGrad$1(f);
|
|
4852
|
+
function grad$1(f, opts) {
|
|
4853
|
+
const valueAndGradFn = valueAndGrad$1(f, opts);
|
|
4600
4854
|
return (...x) => {
|
|
4601
|
-
|
|
4602
|
-
|
|
4603
|
-
|
|
4855
|
+
if (opts?.hasAux) {
|
|
4856
|
+
const [[y, aux], dx] = valueAndGradFn(...x);
|
|
4857
|
+
y.dispose();
|
|
4858
|
+
return [dx, aux];
|
|
4859
|
+
} else {
|
|
4860
|
+
const [y, dx] = valueAndGradFn(...x);
|
|
4861
|
+
y.dispose();
|
|
4862
|
+
return dx;
|
|
4863
|
+
}
|
|
4604
4864
|
};
|
|
4605
4865
|
}
|
|
4606
|
-
function valueAndGrad$1(f) {
|
|
4866
|
+
function valueAndGrad$1(f, opts) {
|
|
4867
|
+
const argnums = opts?.argnums ?? 0;
|
|
4868
|
+
const hasAux = opts?.hasAux ?? false;
|
|
4869
|
+
checkInts(argnums);
|
|
4870
|
+
const argnumsSet = new Set(typeof argnums === "number" ? [argnums] : argnums);
|
|
4607
4871
|
return (...x) => {
|
|
4608
4872
|
if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
|
|
4609
|
-
|
|
4873
|
+
for (let i = 0; i < x.length; i++) if (!argnumsSet.has(i)) x[i] = map(stopGradient, x[i]);
|
|
4874
|
+
const [y, fVjp, aux] = vjp$1(f, x, { hasAux });
|
|
4610
4875
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
4611
4876
|
if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
4612
|
-
const
|
|
4613
|
-
for (const r of rest) dispose(r);
|
|
4877
|
+
const cts = fVjp(onesLike$1(y.ref));
|
|
4614
4878
|
fVjp.dispose();
|
|
4615
|
-
|
|
4879
|
+
for (let i = 0; i < cts.length; i++) if (!argnumsSet.has(i)) dispose(cts[i]);
|
|
4880
|
+
const grads = typeof argnums === "number" ? cts[argnums] : argnums.map((i) => cts[i]);
|
|
4881
|
+
return hasAux ? [[y, aux], grads] : [y, grads];
|
|
4616
4882
|
};
|
|
4617
4883
|
}
|
|
4618
4884
|
function jacrev$1(f) {
|
|
@@ -4620,7 +4886,7 @@ function jacrev$1(f) {
|
|
|
4620
4886
|
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
4621
4887
|
const [size$1] = x.shape;
|
|
4622
4888
|
const pullback = (ct) => {
|
|
4623
|
-
const [y, fVjp] = vjp$1(f, x);
|
|
4889
|
+
const [y, fVjp] = vjp$1(f, [x]);
|
|
4624
4890
|
y.dispose();
|
|
4625
4891
|
const [ret] = fVjp(ct);
|
|
4626
4892
|
fVjp.dispose();
|
|
@@ -4629,6 +4895,9 @@ function jacrev$1(f) {
|
|
|
4629
4895
|
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
4630
4896
|
};
|
|
4631
4897
|
}
|
|
4898
|
+
function hessian$1(f) {
|
|
4899
|
+
return jacfwd$1(grad$1(f));
|
|
4900
|
+
}
|
|
4632
4901
|
|
|
4633
4902
|
//#endregion
|
|
4634
4903
|
//#region src/library/numpy/einsum.ts
|
|
@@ -4767,8 +5036,8 @@ function computeSizeMap({ shapes, lhsIndices, rhsIndex }) {
|
|
|
4767
5036
|
const idx = lhsIndex[j];
|
|
4768
5037
|
const dim = shape$1[j];
|
|
4769
5038
|
const existing = sizeMap.get(idx);
|
|
4770
|
-
if (existing === void 0) sizeMap.set(idx, dim);
|
|
4771
|
-
else if (existing !== dim) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
|
|
5039
|
+
if (existing === void 0 || existing === 1) sizeMap.set(idx, dim);
|
|
5040
|
+
else if (existing !== dim && dim !== 1) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
|
|
4772
5041
|
}
|
|
4773
5042
|
}
|
|
4774
5043
|
for (const [idx, size$1] of sizeMap) if (!Number.isInteger(idx) || idx < 0) throw new Error(`Invalid index ${idx} in einsum expression, must be non-negative integer`);
|
|
@@ -4924,27 +5193,53 @@ function ifft(a, axis = -1) {
|
|
|
4924
5193
|
//#region src/library/numpy-linalg.ts
|
|
4925
5194
|
var numpy_linalg_exports = {};
|
|
4926
5195
|
__export(numpy_linalg_exports, {
|
|
4927
|
-
cholesky: () => cholesky
|
|
5196
|
+
cholesky: () => cholesky,
|
|
5197
|
+
det: () => det,
|
|
4928
5198
|
diagonal: () => diagonal,
|
|
5199
|
+
inv: () => inv,
|
|
4929
5200
|
lstsq: () => lstsq,
|
|
4930
5201
|
matmul: () => matmul,
|
|
5202
|
+
matrixPower: () => matrixPower,
|
|
4931
5203
|
matrixTranspose: () => matrixTranspose,
|
|
4932
5204
|
outer: () => outer,
|
|
5205
|
+
slogdet: () => slogdet,
|
|
5206
|
+
solve: () => solve,
|
|
4933
5207
|
tensordot: () => tensordot,
|
|
4934
5208
|
trace: () => trace,
|
|
4935
5209
|
vecdot: () => vecdot
|
|
4936
5210
|
});
|
|
5211
|
+
function checkSquare(name, a) {
|
|
5212
|
+
if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
|
|
5213
|
+
return a.shape[a.ndim - 1];
|
|
5214
|
+
}
|
|
4937
5215
|
/**
|
|
4938
5216
|
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
4939
5217
|
*
|
|
4940
5218
|
* This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
|
|
4941
5219
|
* the input matrix, which is on by default.
|
|
4942
5220
|
*/
|
|
4943
|
-
function cholesky
|
|
5221
|
+
function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
|
|
4944
5222
|
a = fudgeArray(a);
|
|
4945
|
-
|
|
5223
|
+
checkSquare("cholesky", a);
|
|
4946
5224
|
if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
|
|
4947
|
-
return cholesky(a, { upper });
|
|
5225
|
+
return cholesky$1(a, { upper });
|
|
5226
|
+
}
|
|
5227
|
+
/** Compute the determinant of a square matrix (batched). */
|
|
5228
|
+
function det(a) {
|
|
5229
|
+
a = fudgeArray(a);
|
|
5230
|
+
const n = checkSquare("det", a);
|
|
5231
|
+
const [lu$2, pivots, permutation] = lu(a);
|
|
5232
|
+
permutation.dispose();
|
|
5233
|
+
const parity = pivots.notEqual(arange(n)).astype(int32).sum(-1).mod(2);
|
|
5234
|
+
const sign$1 = parity.mul(-2).add(1);
|
|
5235
|
+
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
5236
|
+
return prod$1(diag$1, -1).mul(sign$1);
|
|
5237
|
+
}
|
|
5238
|
+
/** Compute the inverse of a square matrix (batched). */
|
|
5239
|
+
function inv(a) {
|
|
5240
|
+
a = fudgeArray(a);
|
|
5241
|
+
const n = checkSquare("inv", a);
|
|
5242
|
+
return solve(a, eye(n));
|
|
4948
5243
|
}
|
|
4949
5244
|
/**
|
|
4950
5245
|
* Return the least-squares solution to a linear equation.
|
|
@@ -4968,7 +5263,7 @@ function lstsq(a, b) {
|
|
|
4968
5263
|
const at = matrixTranspose(a.ref);
|
|
4969
5264
|
if (m <= n) {
|
|
4970
5265
|
const aat = matmul(a, at.ref);
|
|
4971
|
-
const l = cholesky
|
|
5266
|
+
const l = cholesky(aat, { symmetrizeInput: false });
|
|
4972
5267
|
const lb = triangularSolve(l.ref, b, {
|
|
4973
5268
|
leftSide: true,
|
|
4974
5269
|
lower: true
|
|
@@ -4980,7 +5275,7 @@ function lstsq(a, b) {
|
|
|
4980
5275
|
return matmul(at, llb.ref);
|
|
4981
5276
|
} else {
|
|
4982
5277
|
const ata = matmul(at.ref, a);
|
|
4983
|
-
const l = cholesky
|
|
5278
|
+
const l = cholesky(ata, { symmetrizeInput: false });
|
|
4984
5279
|
const atb = matmul(at, b);
|
|
4985
5280
|
const lb = triangularSolve(l.ref, atb, {
|
|
4986
5281
|
leftSide: true,
|
|
@@ -4993,6 +5288,169 @@ function lstsq(a, b) {
|
|
|
4993
5288
|
return llb;
|
|
4994
5289
|
}
|
|
4995
5290
|
}
|
|
5291
|
+
/** Raise a square matrix to an integer power, via repeated squarings. */
|
|
5292
|
+
function matrixPower(a, n) {
|
|
5293
|
+
if (!Number.isInteger(n)) throw new Error(`matrixPower: exponent must be an integer, got ${n}`);
|
|
5294
|
+
a = fudgeArray(a);
|
|
5295
|
+
const m = checkSquare("matrixPower", a);
|
|
5296
|
+
if (n === 0) {
|
|
5297
|
+
a.dispose();
|
|
5298
|
+
return broadcastTo(eye(m), a.shape);
|
|
5299
|
+
}
|
|
5300
|
+
if (n < 0) {
|
|
5301
|
+
a = inv(a);
|
|
5302
|
+
n = -n;
|
|
5303
|
+
}
|
|
5304
|
+
let result = null;
|
|
5305
|
+
let a2k = a;
|
|
5306
|
+
for (let k = 0; n; k++) {
|
|
5307
|
+
if (k > 0) a2k = matmul(a2k.ref, a2k);
|
|
5308
|
+
if (n % 2 === 1) result = result === null ? a2k.ref : matmul(result, a2k.ref);
|
|
5309
|
+
n = Math.floor(n / 2);
|
|
5310
|
+
}
|
|
5311
|
+
a2k.dispose();
|
|
5312
|
+
return result;
|
|
5313
|
+
}
|
|
5314
|
+
/** Return sign and natural logarithm of the determinant of `a`. */
|
|
5315
|
+
function slogdet(a) {
|
|
5316
|
+
a = fudgeArray(a);
|
|
5317
|
+
const n = checkSquare("slogdet", a);
|
|
5318
|
+
const [lu$2, pivots, permutation] = lu(a);
|
|
5319
|
+
permutation.dispose();
|
|
5320
|
+
let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
|
|
5321
|
+
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
5322
|
+
parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
|
|
5323
|
+
const logabsdet = log(absolute(diag$1)).sum(-1);
|
|
5324
|
+
const sign$1 = parity.mul(-2).add(1);
|
|
5325
|
+
return [sign$1, logabsdet];
|
|
5326
|
+
}
|
|
5327
|
+
/**
|
|
5328
|
+
* Solve a linear system of equations.
|
|
5329
|
+
*
|
|
5330
|
+
* This solves a (batched) linear system of equations `a @ x = b` for `x` given
|
|
5331
|
+
* `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
|
|
5332
|
+
*
|
|
5333
|
+
* @param a - Coefficient matrix of shape `(..., N, N)`.
|
|
5334
|
+
* @param b - Values of shape `(N,)` or `(..., N, M)`.
|
|
5335
|
+
* @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
|
|
5336
|
+
*/
|
|
5337
|
+
function solve(a, b) {
|
|
5338
|
+
a = fudgeArray(a);
|
|
5339
|
+
b = fudgeArray(b);
|
|
5340
|
+
const n = checkSquare("solve", a);
|
|
5341
|
+
if (b.ndim === 0) throw new Error(`solve: b cannot be scalar`);
|
|
5342
|
+
const bIs1d = b.ndim === 1;
|
|
5343
|
+
if (bIs1d) b = b.reshape([...b.shape, 1]);
|
|
5344
|
+
if (b.shape[b.ndim - 2] !== n) throw new Error(`solve: leading dimension of b must match size of a, got a=${a.aval}, b=${b.aval}`);
|
|
5345
|
+
const m = b.shape[b.ndim - 1];
|
|
5346
|
+
const batchDims = generalBroadcast(a.shape.slice(0, -2), b.shape.slice(0, -2));
|
|
5347
|
+
a = broadcastTo(a, [
|
|
5348
|
+
...batchDims,
|
|
5349
|
+
n,
|
|
5350
|
+
n
|
|
5351
|
+
]);
|
|
5352
|
+
b = broadcastTo(b, [
|
|
5353
|
+
...batchDims,
|
|
5354
|
+
n,
|
|
5355
|
+
m
|
|
5356
|
+
]);
|
|
5357
|
+
const [lu$2, pivots, permutation] = lu(a);
|
|
5358
|
+
pivots.dispose();
|
|
5359
|
+
const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
|
|
5360
|
+
const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
|
|
5361
|
+
leftSide: true,
|
|
5362
|
+
lower: true,
|
|
5363
|
+
unitDiagonal: true
|
|
5364
|
+
});
|
|
5365
|
+
let x = triangularSolve(lu$2, LPb.ref, {
|
|
5366
|
+
leftSide: true,
|
|
5367
|
+
lower: false
|
|
5368
|
+
});
|
|
5369
|
+
if (bIs1d) x = squeeze(x, -1);
|
|
5370
|
+
return x;
|
|
5371
|
+
}
|
|
5372
|
+
|
|
5373
|
+
//#endregion
|
|
5374
|
+
//#region src/library/numpy/dtype-info.ts
|
|
5375
|
+
/**
 * Machine limits for floating-point types.
 *
 * Fields mirror `numpy.finfo` (camelCased). `eps = 2 ** machep` is the gap
 * between 1 and the next larger representable number. `epsneg = 2 ** negep` is
 * the gap between 1 and the next smaller representable number.
 *
 * @param dtype - A floating-point dtype (`float16`, `float32` or `float64`).
 * @returns A frozen object describing the limits of `dtype`.
 * @throws If `dtype` is not a floating-point type.
 */
function finfo(dtype) {
	if (!isFloatDtype(dtype)) throw new Error(`finfo: received ${dtype}, must be a floating-point type`);
	switch (dtype) {
		case DType.Float16: return Object.freeze({
			bits: 16,
			dtype: DType.Float16,
			eps: 2 ** -10,
			epsneg: 2 ** -11,
			machep: -10,
			max: 65504,
			maxexp: 16,
			min: -65504,
			minexp: -14,
			// Must satisfy epsneg === 2 ** negep (numpy reports -11 for float16).
			negep: -11,
			nexp: 5,
			nmant: 10,
			precision: 3,
			resolution: .001,
			smallestNormal: 2 ** -14,
			smallestSubnormal: 2 ** -24
		});
		case DType.Float32: return Object.freeze({
			bits: 32,
			dtype: DType.Float32,
			eps: 2 ** -23,
			epsneg: 2 ** -24,
			machep: -23,
			max: 34028234663852886e22,
			maxexp: 128,
			min: -34028234663852886e22,
			minexp: -126,
			negep: -24,
			nexp: 8,
			nmant: 23,
			precision: 6,
			resolution: 1e-6,
			smallestNormal: 2 ** -126,
			smallestSubnormal: 2 ** -149
		});
		case DType.Float64: return Object.freeze({
			bits: 64,
			dtype: DType.Float64,
			eps: 2 ** -52,
			epsneg: 2 ** -53,
			machep: -52,
			max: Number.MAX_VALUE,
			maxexp: 1024,
			min: -Number.MAX_VALUE,
			minexp: -1022,
			negep: -53,
			nexp: 11,
			nmant: 52,
			precision: 15,
			resolution: 1e-15,
			smallestNormal: 2 ** -1022,
			smallestSubnormal: 2 ** -1074
		});
		default: throw new Error(`finfo: unsupported dtype ${dtype}`);
	}
}
|
|
5436
|
+
/**
 * Machine limits for integer types.
 *
 * @param dtype - `int32` or `uint32`.
 * @returns A frozen `{ bits, dtype, max, min }` record for `dtype`.
 * @throws For any other dtype.
 */
function iinfo(dtype) {
	let limits;
	if (dtype === DType.Int32) {
		limits = {
			bits: 32,
			dtype: DType.Int32,
			max: 2147483647,
			min: -2147483648
		};
	} else if (dtype === DType.Uint32) {
		limits = {
			bits: 32,
			dtype: DType.Uint32,
			max: 4294967295,
			min: 0
		};
	} else {
		throw new Error(`iinfo: unsupported dtype ${dtype}`);
	}
	return Object.freeze(limits);
}
|
|
4996
5454
|
|
|
4997
5455
|
//#endregion
|
|
4998
5456
|
//#region src/library/numpy.ts
|
|
@@ -5048,6 +5506,7 @@ __export(numpy_exports, {
|
|
|
5048
5506
|
diag: () => diag,
|
|
5049
5507
|
diagonal: () => diagonal,
|
|
5050
5508
|
divide: () => trueDivide,
|
|
5509
|
+
divmod: () => divmod,
|
|
5051
5510
|
dot: () => dot$1,
|
|
5052
5511
|
dstack: () => dstack,
|
|
5053
5512
|
e: () => e,
|
|
@@ -5060,6 +5519,7 @@ __export(numpy_exports, {
|
|
|
5060
5519
|
expm1: () => expm1,
|
|
5061
5520
|
eye: () => eye,
|
|
5062
5521
|
fft: () => numpy_fft_exports,
|
|
5522
|
+
finfo: () => finfo,
|
|
5063
5523
|
flip: () => flip,
|
|
5064
5524
|
fliplr: () => fliplr,
|
|
5065
5525
|
flipud: () => flipud,
|
|
@@ -5067,6 +5527,7 @@ __export(numpy_exports, {
|
|
|
5067
5527
|
float32: () => float32,
|
|
5068
5528
|
float64: () => float64,
|
|
5069
5529
|
floor: () => floor,
|
|
5530
|
+
floorDivide: () => floorDivide,
|
|
5070
5531
|
fmod: () => fmod,
|
|
5071
5532
|
frexp: () => frexp,
|
|
5072
5533
|
full: () => full,
|
|
@@ -5079,6 +5540,7 @@ __export(numpy_exports, {
|
|
|
5079
5540
|
hstack: () => hstack,
|
|
5080
5541
|
hypot: () => hypot,
|
|
5081
5542
|
identity: () => identity$1,
|
|
5543
|
+
iinfo: () => iinfo,
|
|
5082
5544
|
inf: () => inf,
|
|
5083
5545
|
inner: () => inner,
|
|
5084
5546
|
int32: () => int32,
|
|
@@ -5096,6 +5558,7 @@ __export(numpy_exports, {
|
|
|
5096
5558
|
log10: () => log10,
|
|
5097
5559
|
log1p: () => log1p,
|
|
5098
5560
|
log2: () => log2,
|
|
5561
|
+
logspace: () => logspace,
|
|
5099
5562
|
matmul: () => matmul,
|
|
5100
5563
|
matrixTranspose: () => matrixTranspose,
|
|
5101
5564
|
max: () => max,
|
|
@@ -5132,9 +5595,11 @@ __export(numpy_exports, {
|
|
|
5132
5595
|
shape: () => shape,
|
|
5133
5596
|
sign: () => sign,
|
|
5134
5597
|
sin: () => sin,
|
|
5598
|
+
sinc: () => sinc,
|
|
5135
5599
|
sinh: () => sinh,
|
|
5136
5600
|
size: () => size,
|
|
5137
5601
|
sort: () => sort,
|
|
5602
|
+
split: () => split$1,
|
|
5138
5603
|
sqrt: () => sqrt,
|
|
5139
5604
|
square: () => square,
|
|
5140
5605
|
squeeze: () => squeeze,
|
|
@@ -5142,6 +5607,8 @@ __export(numpy_exports, {
|
|
|
5142
5607
|
std: () => std,
|
|
5143
5608
|
subtract: () => subtract,
|
|
5144
5609
|
sum: () => sum,
|
|
5610
|
+
swapaxes: () => swapaxes,
|
|
5611
|
+
take: () => take,
|
|
5145
5612
|
tan: () => tan,
|
|
5146
5613
|
tanh: () => tanh,
|
|
5147
5614
|
tensordot: () => tensordot,
|
|
@@ -5400,6 +5867,45 @@ function flip(x, axis = null) {
|
|
|
5400
5867
|
return flip$1(x, axis);
|
|
5401
5868
|
}
|
|
5402
5869
|
/**
 * Split an array into multiple sub-arrays along an axis.
 *
 * @param a - The input array to split.
 * @param indicesOrSections - If an integer, it indicates the number of equal
 *   sections to create along the specified axis. It must be a positive integer
 *   that divides the axis length. If a list of integers, it specifies the
 *   indices at which to split the array.
 *   - Negative indices count from the end of the axis.
 *   - Indices past either end are clamped to the axis, as with slicing.
 *   - After that resolution, indices must be non-decreasing.
 *   - An empty list returns `[a]`.
 * @param axis - The axis along which to split the array. Default is 0.
 * @returns The sub-arrays, in order along `axis`.
 */
function split$1(a, indicesOrSections, axis = 0) {
	a = fudgeArray(a);
	axis = checkAxis(axis, a.ndim);
	const size$1 = a.shape[axis];
	let sizes;
	if (typeof indicesOrSections === "number") {
		const sections = indicesOrSections;
		if (!Number.isInteger(sections) || sections <= 0) throw new Error(`split: number of sections must be a positive integer, got ${sections}`);
		if (size$1 % sections !== 0) throw new Error(`Array of size ${size$1} cannot be split into ${sections} equal parts`);
		sizes = rep(sections, size$1 / sections);
	} else {
		sizes = [];
		let prev = 0;
		for (const index of indicesOrSections) {
			// Resolve like a slice bound: negative counts from the end, then clamp to [0, size].
			const bound = Math.min(Math.max(index < 0 ? index + size$1 : index, 0), size$1);
			if (bound < prev) throw new Error(`split: indices must be non-decreasing, got ${JSON.stringify(indicesOrSections)}`);
			sizes.push(bound - prev);
			prev = bound;
		}
		sizes.push(size$1 - prev);
	}
	if (sizes.length === 1) return [a];
	// Each split$2 call receives at most MAX_PARTS sizes. Longer splits peel off
	// MAX_PARTS - 1 parts at a time and carry the remainder forward as one part.
	const MAX_PARTS = 8;
	const results = [];
	let rest = sizes;
	while (rest.length > MAX_PARTS) {
		const head = rest.slice(0, MAX_PARTS - 1);
		const tail = rest.slice(MAX_PARTS - 1);
		const outs = split$2(a, axis, [...head, tail.reduce((s, x) => s + x, 0)]);
		results.push(...outs.slice(0, -1));
		a = outs[outs.length - 1];
		rest = tail;
	}
	results.push(...split$2(a, axis, rest));
	return results;
}
|
|
5908
|
+
/**
|
|
5403
5909
|
* Join a sequence of arrays along an existing axis.
|
|
5404
5910
|
*
|
|
5405
5911
|
* The arrays must have the same shape, except in the dimension corresponding to
|
|
@@ -5411,13 +5917,11 @@ function concatenate(xs, axis = 0) {
|
|
|
5411
5917
|
if (xs.length === 0) throw new Error("Need at least one array to concatenate");
|
|
5412
5918
|
const shapes = xs.map(shape);
|
|
5413
5919
|
axis = checkAxis(axis, shapes[0].length);
|
|
5414
|
-
for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays
|
|
5415
|
-
const makePadAxis = (start, end) => shapes[0].map((_, i) => i === axis ? [start, end] : [0, 0]);
|
|
5920
|
+
for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays ${xs[0].aval} and ${xs[i].aval} along axis ${axis}`);
|
|
5416
5921
|
let result = xs[0];
|
|
5417
|
-
for (let i = 1; i < xs.length; i
|
|
5418
|
-
const
|
|
5419
|
-
|
|
5420
|
-
result = pad(result, makePadAxis(0, len2)).add(pad(xs[i], makePadAxis(len1, 0)));
|
|
5922
|
+
for (let i = 1; i < xs.length; i += 7) {
|
|
5923
|
+
const group = xs.slice(i, i + 7);
|
|
5924
|
+
result = concatenate$1([result, ...group], axis);
|
|
5421
5925
|
}
|
|
5422
5926
|
return result;
|
|
5423
5927
|
}
|
|
@@ -5502,6 +6006,17 @@ function flipud(x) {
|
|
|
5502
6006
|
function fliplr(x) {
	// Left/right runs along axis 1 (the columns).
	const columnAxis = 1;
	return flip(x, columnAxis);
}
|
|
6009
|
+
/**
 * Interchange two axes of an array.
 *
 * Negative axes are accepted. If both axes resolve to the same dimension, the
 * input is returned unchanged.
 */
function swapaxes(a, axis1, axis2) {
	const arr = fudgeArray(a);
	const first = checkAxis(axis1, arr.ndim);
	const second = checkAxis(axis2, arr.ndim);
	if (first === second) return arr;
	const perm = range(arr.ndim).map((ax) => ax === first ? second : ax === second ? first : ax);
	return transpose(arr, perm);
}
|
|
5505
6020
|
/** Transpose the last two dimensions of an array. */
|
|
5506
6021
|
function matrixTranspose(a) {
|
|
5507
6022
|
if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
|
|
@@ -5669,6 +6184,20 @@ function sort(a, axis = -1) {
|
|
|
5669
6184
|
function argsort(a, axis = -1) {
	// Delegate to the array method after coercing array-likes.
	const arr = fudgeArray(a);
	return arr.argsort(axis);
}
|
|
6187
|
+
/**
 * Take elements from an array along an axis.
 *
 * Behaves like advanced indexing with integer `indices` on the given axis.
 * When `axis` is `null` (the default), the array is flattened first and
 * indexed along axis 0.
 */
function take(a, indices, axis = null) {
	let source = a;
	let gatherAxis;
	if (axis === null) {
		source = ravel(a);
		gatherAxis = 0;
	} else {
		gatherAxis = axis;
	}
	gatherAxis = checkAxis(gatherAxis, ndim(source));
	return gather(source, [indices], [gatherAxis], gatherAxis);
}
|
|
5672
6201
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
5673
6202
|
function allclose(actual, expected, options) {
|
|
5674
6203
|
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
@@ -5988,6 +6517,20 @@ function tan(x) {
|
|
|
5988
6517
|
x = fudgeArray(x);
|
|
5989
6518
|
return sin(x.ref).div(cos(x));
|
|
5990
6519
|
}
|
|
6520
|
+
/**
 * @function
 * Return the normalized sinc function.
 *
 * The sinc function is defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`.
 * This is the normalized sinc function commonly used in signal processing.
 *
 * Zeros of `x` are replaced by `1` before the division (the "double where"
 * trick). The discarded `0 / 0` branch therefore never enters the computation,
 * and the first derivative at `x = 0` is `0` (the true value) rather than
 * `NaN`. Higher-order derivatives at `x = 0` are not exact: the true second
 * derivative is `-π²/3`. Getting them right would need a custom JVP rule like
 * JAX's.
 */
const sinc = jit$1(function sinc$1(x) {
	const isZero = equal(x.ref, 0);
	// Substitute a harmless nonzero value where x == 0. Those lanes are overwritten
	// with 1 below, so neither the value nor the gradient of this branch is NaN there.
	const safeX = where(isZero.ref, 1, x);
	const pix = safeX.mul(Math.PI);
	return where(isZero, 1, sin(pix.ref).div(pix));
});
|
|
5991
6534
|
/** Element-wise inverse cosine function (inverse of cos). */
|
|
5992
6535
|
function acos(x) {
|
|
5993
6536
|
return subtract(pi / 2, asin(x));
|
|
@@ -6040,6 +6583,25 @@ function trueDivide(x, y) {
|
|
|
6040
6583
|
return x.div(y);
|
|
6041
6584
|
}
|
|
6042
6585
|
/**
 * Return the largest integer smaller than or equal to the division of the inputs.
 *
 * The result is always rounded towards negative infinity.
 *
 * - If either operand is floating-point, this is `floor(x / y)`.
 * - For integer operands it is `(x - remainder(x, y)) / y`, which also handles
 *   negative values correctly. The intermediate subtraction may overflow near
 *   int32 boundaries.
 *
 * @param x - Dividend array.
 * @param y - Divisor array.
 * @returns Element-wise floor division of x by y.
 */
function floorDivide(x, y) {
	const dividend = fudgeArray(x);
	const divisor = fudgeArray(y);
	const usesFloat = isFloatDtype(dividend.dtype) || isFloatDtype(divisor.dtype);
	if (usesFloat) return floor(trueDivide(dividend, divisor));
	// `rem` has the sign of the divisor, so `dividend - rem` is an exact multiple of it.
	const rem = remainder(dividend.ref, divisor.ref);
	return subtract(dividend, rem).div(divisor);
}
|
|
6604
|
+
/**
|
|
6043
6605
|
* @function
|
|
6044
6606
|
* Calculate element-wise floating-point modulo operation.
|
|
6045
6607
|
*/
|
|
@@ -6053,6 +6615,20 @@ const fmod = jit$1(function fmod$1(x, y) {
|
|
|
6053
6615
|
const remainder = jit$1(function remainder$1(x, y) {
	// `mod` gives the truncated remainder, whose sign follows the dividend `x`.
	const truncated = mod(x, y.ref);
	// Shift by one multiple of `y` only when the truncated remainder is nonzero
	// and its sign differs from `y`'s. sign(r) + sign(y) is 0 exactly in that case,
	// or when both are 0, where r + y is still 0.
	// This mirrors jnp.remainder. Unlike mod(r + y, y), it never adds two
	// same-signed values, so large int32 inputs cannot overflow. Tiny float
	// remainders are also not absorbed (e.g. 1e-20 % 1 stays 1e-20).
	const shift = equal(sign(truncated.ref).add(sign(y.ref)), 0);
	return where(shift, truncated.ref.add(y), truncated);
});
|
|
6618
|
+
/**
 * Return element-wise quotient and remainder simultaneously.
 *
 * The result is the pair `[floorDivide(x, y), remainder(x, y)]`.
 *
 * @param x - Dividend array.
 * @param y - Divisor array.
 * @returns Tuple of [quotient, remainder].
 */
function divmod(x, y) {
	const dividend = fudgeArray(x);
	const divisor = fudgeArray(y);
	// The quotient borrows extra references; the remainder consumes the originals.
	const quotient = floorDivide(dividend.ref, divisor.ref);
	const rem = remainder(dividend, divisor);
	return [quotient, rem];
}
|
|
6056
6632
|
/** Round input to the nearest integer towards zero. */
|
|
6057
6633
|
function trunc(x) {
|
|
6058
6634
|
return idiv(x, 1);
|
|
@@ -6216,14 +6792,15 @@ function std(x, axis = null, opts) {
|
|
|
6216
6792
|
return sqrt(var_(x, axis, opts));
|
|
6217
6793
|
}
|
|
6218
6794
|
/** Estimate the sample covariance of a set of variables. */
|
|
6219
|
-
function cov(x, y) {
|
|
6795
|
+
function cov(x, y = null, { rowvar = true } = {}) {
|
|
6220
6796
|
x = fudgeArray(x);
|
|
6221
6797
|
if (x.ndim === 1) x = x.reshape([1, x.shape[0]]);
|
|
6222
|
-
if (y !==
|
|
6798
|
+
if (y !== null) {
|
|
6223
6799
|
y = fudgeArray(y);
|
|
6224
6800
|
if (y.ndim === 1) y = y.reshape([1, y.shape[0]]);
|
|
6225
6801
|
x = vstack([x, y]);
|
|
6226
6802
|
}
|
|
6803
|
+
if (!rowvar) x = x.transpose();
|
|
6227
6804
|
const [_M, N] = x.shape;
|
|
6228
6805
|
x = x.ref.sub(x.mean(1, { keepdims: true }));
|
|
6229
6806
|
return dot$1(x.ref, x.transpose()).div(N - 1);
|
|
@@ -6268,7 +6845,8 @@ const isfinite = jit$1(function isfinite$1(x) {
|
|
|
6268
6845
|
//#region src/library/lax-linalg.ts
|
|
6269
6846
|
var lax_linalg_exports = {};
// Namespace object for `lax.linalg`. Each entry returns the implementing
// function; presumably installed as a getter by `__export`.
__export(lax_linalg_exports, {
	cholesky: function() {
		return cholesky$1;
	},
	lu: function() {
		return lu;
	},
	triangularSolve: function() {
		return triangularSolve;
	}
});
|
|
6274
6852
|
/**
|
|
@@ -6297,11 +6875,39 @@ __export(lax_linalg_exports, {
|
|
|
6297
6875
|
* // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
|
|
6298
6876
|
* ```
|
|
6299
6877
|
*/
|
|
6300
|
-
function cholesky$1(a, { upper = false } = {}) {
	const lower = cholesky$2(a);
	if (!upper) return lower;
	// Upper factor: swap the last two axes of the lower factor.
	return moveaxis$1(lower, -2, -1);
}
|
|
6304
6882
|
/**
 * LU decomposition with partial pivoting.
 *
 * Factors `P @ A = L @ U`, where:
 * - `P` reorders the rows of `A`,
 * - `L` is lower-triangular with an implicit unit diagonal,
 * - `U` is upper-triangular.
 *
 * @param x - A batch of matrices with shape `[..., m, n]`.
 *
 * @returns A tuple `(lu, pivots, permutation)`:
 *   - `lu`: `L` (strictly below the diagonal) and `U` packed into one matrix.
 *   - `pivots`: pivot indices with shape `[..., min(m, n)]`.
 *   - `permutation`: the row permutation implied by `pivots`, shape `[..., m]`.
 *
 * @example
 * ```ts
 * import { lax, numpy as np } from "@jax-js/jax";
 *
 * const A = np.array([[4., 3.], [6., 3.]]);
 * const [lu, pivots, permutation] = lax.linalg.lu(A);
 * // lu ≈ [[6., 3.], [0.6666667, 1.0]]
 * // pivots = [1, 1]
 * // permutation = [1, 0]
 * ```
 */
function lu(x) {
	const factors = lu$1(x);
	return factors;
}
|
|
6910
|
+
/**
|
|
6305
6911
|
* Solve a triangular linear system.
|
|
6306
6912
|
*
|
|
6307
6913
|
* Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
|
|
@@ -6339,6 +6945,7 @@ var lax_exports = {};
|
|
|
6339
6945
|
__export(lax_exports, {
|
|
6340
6946
|
conv: () => conv,
|
|
6341
6947
|
convGeneralDilated: () => convGeneralDilated,
|
|
6948
|
+
convTranspose: () => convTranspose,
|
|
6342
6949
|
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
6343
6950
|
dot: () => dot,
|
|
6344
6951
|
erf: () => erf,
|
|
@@ -6347,6 +6954,7 @@ __export(lax_exports, {
|
|
|
6347
6954
|
reduceWindow: () => reduceWindow,
|
|
6348
6955
|
stopGradient: () => stopGradient$1
|
|
6349
6956
|
});
|
|
6957
|
+
// The native JS Array constructor, reached through globalThis. Presumably the bare
// name `Array` refers to the library's array type elsewhere in this module.
const { Array: JsArray } = globalThis;
|
|
6350
6958
|
/**
|
|
6351
6959
|
* General dot product/contraction operator.
|
|
6352
6960
|
*
|
|
@@ -6418,7 +7026,11 @@ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
|
6418
7026
|
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
6419
7027
|
* function in JAX, which wraps XLA's general convolution operator.
|
|
6420
7028
|
*
|
|
6421
|
-
*
|
|
7029
|
+
* @param lhs - Input tensor; shape `[N, C_in, ...xs]`
|
|
7030
|
+
* @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
|
|
7031
|
+
* @param windowStrides - Strides for each spatial dimension
|
|
7032
|
+
* @param padding - Padding for each spatial dimension, or a string
|
|
7033
|
+
* (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
|
|
6422
7034
|
*/
|
|
6423
7035
|
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
|
|
6424
7036
|
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
@@ -6478,6 +7090,60 @@ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, r
|
|
|
6478
7090
|
function conv(lhs, rhs, windowStrides, padding) {
	// No dilation and a single feature group: use convGeneralDilated's defaults.
	const result = convGeneralDilated(lhs, rhs, windowStrides, padding);
	return result;
}
|
|
7093
|
+
/**
 * Convenience wrapper for calculating the N-d convolution "transpose".
 *
 * This computes a fractionally strided convolution directly, rather than
 * taking the gradient (transpose) of a forward convolution. The input is
 * dilated by `strides` and then convolved with unit window strides. It matches
 * the JAX version except that:
 *
 * - There is no `use_consistent_padding` option. Only the consistent-padding
 *   case (JAX version >0.8.4) is implemented.
 * - The order of dimensions matches `lax.conv_general_dilated`.
 *
 * Unlike PyTorch/TensorFlow, the kernel's spatial dimensions and its
 * `(C_out, C_in)` axis order are not reversed by default. Set
 * `transposeKernel` to true to get that behavior.
 *
 * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
 * @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
 * @param strides - Sequence of n integers, sets fractional stride
 * @param padding - `"SAME"`, `"VALID"`, or one `[lo, hi]` pair per spatial
 *   dimension. A pair applies padding of `dilation * (kernel_size - 1) - padding`
 *   to each side of the input, so the result acts like the gradient of `conv()`.
 * @param rhsDilation - Atrous dilation for the kernel (defaults to all ones)
 * @param transposeKernel - Flip spatial axes and swap the input/output channels
 *   of the kernel; its shape should then be `[C_in, C_out, ...ks]`
 */
function convTranspose(lhs, rhs, strides, padding, { rhsDilation, transposeKernel = false } = {}) {
	const spatialKernel = rhs.shape.slice(2);
	const dilation = rhsDilation ?? rep(spatialKernel.length, 1);
	const pads = [];
	for (let i = 0; i < spatialKernel.length; i++) {
		// Footprint of the dilated kernel along this dimension (0 for an empty kernel dim).
		const dilatedSize = Math.max(0, (spatialKernel[i] - 1) * dilation[i] + 1);
		const dimPadding = typeof padding === "string" ? padding : padding[i];
		pads.push(convTransposePadding(dilatedSize, strides[i], dimPadding));
	}
	let kernel = rhs;
	if (transposeKernel) kernel = moveaxis(flip$1(kernel, range(2, kernel.ndim)), 0, 1);
	return convGeneralDilated(lhs, kernel, rep(lhs.ndim - 2, 1), pads, {
		lhsDilation: strides,
		rhsDilation: dilation
	});
}
|
|
7131
|
+
/**
 * Compute the `[before, after]` input padding for one spatial dimension of
 * `convTranspose`.
 *
 * @param k - Effective (dilated) kernel size.
 * @param s - Stride for this dimension.
 * @param padding - `"SAME"`, `"VALID"`, or an explicit `[lo, hi]` pair.
 */
function convTransposePadding(k, s, padding) {
	let total;
	let before;
	switch (padding) {
		case "SAME":
			total = k + s - 2;
			before = s > k - 1 ? k - 1 : Math.ceil(total / 2);
			break;
		case "VALID":
			total = k + s - 2 + Math.max(k - s, 0);
			before = k - 1;
			break;
		default:
			if (!JsArray.isArray(padding)) throw new Error(`convTranspose: Invalid padding type ${padding}`);
			before = k - 1 - padding[0];
			total = before + (k - 1 - padding[1]);
	}
	return [before, total - before];
}
|
|
6481
7147
|
/** Reduce a computation over padded windows. */
|
|
6482
7148
|
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
6483
7149
|
if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
|
|
@@ -6516,6 +7182,7 @@ function stopGradient$1(x) {
|
|
|
6516
7182
|
var nn_exports = {};
|
|
6517
7183
|
__export(nn_exports, {
|
|
6518
7184
|
celu: () => celu,
|
|
7185
|
+
dotProductAttention: () => dotProductAttention,
|
|
6519
7186
|
elu: () => elu,
|
|
6520
7187
|
gelu: () => gelu,
|
|
6521
7188
|
glu: () => glu,
|
|
@@ -6832,6 +7499,95 @@ function oneHot(x, numClasses) {
|
|
|
6832
7499
|
if (isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
|
|
6833
7500
|
return eye(numClasses, void 0, { device: x.device }).slice(x);
|
|
6834
7501
|
}
|
|
7502
|
+
/**
 * Scaled dot product attention (SDPA).
 *
 * Computes `softmax((Q @ K^T) * scale + bias) @ V`, where:
 * - `Q` is the query, `K` is the key, and `V` is the value;
 * - `scale` defaults to `1 / sqrt(d)`, with `d` the dimensionality of each key
 *   and query vector.
 *
 * Grouped-query attention (including multi-query attention) is applied when
 * `key` and `value` have fewer heads than `query`. As in JAX, the `N` query
 * heads form `K` contiguous groups of `G = N / K` heads, and query head `n`
 * attends with key/value head `floor(n / G)`.
 *
 * We use the following uppercase letters to denote array shapes:
 * - `B` = batch size
 * - `S` = length of key/value sequences (source)
 * - `L` = length of query sequences
 * - `N` = number of attention heads
 * - `H` = dimensionality of each attention head
 * - `K` = number of key/value heads (for grouped-query attention)
 *
 * The batch size `B` may be omitted, which is equivalent to `B = 1`. If it is
 * omitted, it must be omitted from all inputs.
 *
 * @param query - Query array; shape `[B, L, N, H]`
 * @param key - Key array; shape `[B, S, K, H]`
 * @param value - Value array; same shape as `key`
 * @param opts.bias - Optional bias to add to the attention logits; shape
 *   `[B, N, L, S]` or broadcastable to it.
 * @param opts.mask - Optional mask for the attention logits. It should be a
 *   boolean array broadcastable to `[B, N, L, S]`, where `true` means the
 *   element takes part in attention.
 * @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
 * @param opts.isCausal - If true, applies a causal (lower-triangular) mask.
 * @param opts.querySeqLengths - Optional sequence lengths for the queries;
 *   shape `(B,)`. Taken from the beginning of the tensor. Not yet implemented.
 * @param opts.keyValueSeqLengths - Optional sequence lengths for the keys and
 *   values; shape `(B,)`. Taken from the beginning of the tensor. Not yet
 *   implemented.
 * @param opts.localWindowSize - If specified, applies a local attention window
 *   of the given size. Can be a single number or a tuple `[left, right]`.
 *   Not yet implemented.
 *
 * Masked-out logits (from `mask` or `isCausal`) are replaced with
 * `-0.7 * finfo(dtype).max` instead of `-Infinity`, as JAX does. A row that is
 * fully masked therefore gets uniform weights instead of `NaN`.
 *
 * @returns The result of the attention operation. Its shape matches the query:
 *   `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
 * @throws If sequence-length masking or local attention is requested, or if
 *   the input shapes are inconsistent.
 */
function dotProductAttention(query, key$1, value, opts = {}) {
	if (opts.querySeqLengths !== void 0 || opts.keyValueSeqLengths !== void 0) throw new Error("Sequence length masking is not yet implemented");
	if (opts.localWindowSize !== void 0) throw new Error("Local attention is not yet implemented");
	query = fudgeArray(query);
	key$1 = fudgeArray(key$1);
	value = fudgeArray(value);
	if (query.ndim !== 3 && query.ndim !== 4 || query.ndim !== key$1.ndim || query.ndim !== value.ndim) throw new Error(`dotProductAttention: expected all tensors to have rank 3 or 4, got Q=${query.aval}, K=${key$1.aval}, V=${value.aval}`);
	if (!deepEqual(key$1.shape, value.shape)) throw new Error(`dotProductAttention: key and value shapes must match, got K=${key$1.shape}, V=${value.shape}`);
	const isRank3 = query.ndim === 3;
	if (isRank3) {
		// An omitted batch dimension is treated as B = 1.
		query = expandDims(query, 0);
		key$1 = expandDims(key$1, 0);
		value = expandDims(value, 0);
	}
	const [B, L, N, H] = query.shape;
	if (key$1.shape[0] !== B || key$1.shape[3] !== H) throw new Error(`dotProductAttention: query and key shapes mismatch, got Q=${query.aval}, K=${key$1.aval}`);
	const S = key$1.shape[1];
	const K = key$1.shape[2];
	if (N < K || N !== K && N % K !== 0) throw new Error(`dotProductAttention: number of query heads N=${N} must be divisible by number of key/value heads K=${K} for GQA`);
	const G = N / K;
	if (G > 1) {
		// Repeat each key/value head G times consecutively ([k0, k0, ..., k1, k1, ...]),
		// so query head n pairs with key/value head floor(n / G), as in JAX.
		// Tiling the head axis would instead pair head n with head n % K.
		const repeatHeads = (x) => broadcastTo(expandDims(x, 3), [
			B,
			S,
			K,
			G,
			H
		]).reshape([
			B,
			S,
			N,
			H
		]);
		key$1 = repeatHeads(key$1);
		value = repeatHeads(value);
	}
	const scale = opts.scale ?? 1 / Math.sqrt(H);
	let scores = einsum("BLNH,BSNH->BNLS", query, key$1).mul(scale);
	if (opts.bias !== void 0) scores = scores.add(opts.bias);
	// Large finite negative (as in JAX) so that fully masked rows do not softmax to NaN.
	const maskValue = isFloatDtype(scores.dtype) ? -0.7 * finfo(scores.dtype).max : -Infinity;
	if (opts.mask !== void 0) scores = where(opts.mask, scores, maskValue);
	if (opts.isCausal) {
		const causalMask = tri(L, S, 0, { dtype: DType.Bool });
		scores = where(causalMask, scores, maskValue);
	}
	const attn = softmax(scores, -1);
	const out = einsum("BNLS,BSNH->BLNH", attn, value);
	return isRank3 ? out.reshape([
		L,
		N,
		H
	]) : out;
}
|
|
6835
7591
|
|
|
6836
7592
|
//#endregion
|
|
6837
7593
|
//#region src/library/random.ts
|
|
@@ -6844,33 +7600,41 @@ __export(random_exports, {
|
|
|
6844
7600
|
gumbel: () => gumbel,
|
|
6845
7601
|
key: () => key,
|
|
6846
7602
|
laplace: () => laplace,
|
|
7603
|
+
multivariateNormal: () => multivariateNormal,
|
|
6847
7604
|
normal: () => normal,
|
|
6848
7605
|
split: () => split,
|
|
6849
7606
|
uniform: () => uniform
|
|
6850
7607
|
});
|
|
6851
|
-
/**
 * Check that `key$1` looks like a PRNG key array (trailing dimension of 2) and
 * return its batch shape (all but the last dimension). With `scalar` set, a
 * batch of keys is rejected.
 */
function validateKeyShape(key$1, scalar = false) {
	if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
	const keyDims = key$1.shape;
	if (keyDims[keyDims.length - 1] !== 2) throw new Error(`Invalid key shape: ${key$1.shape}. Expected last dimension to be 2.`);
	const batchShape = keyDims.slice(0, -1);
	if (scalar && batchShape.length > 0) throw new Error(`Expected a single PRNG key, but got a batch of keys with shape ${JSON.stringify(key$1.shape)} - use jax.vmap for batching.`);
	return batchShape;
}
|
|
7614
|
+
/**
 * Split a single (non-batched) PRNG key into its two halves, each reshaped to
 * the key's batch shape (empty, since batches are rejected).
 */
function getK01(key$1) {
	const keyShape = validateKeyShape(key$1, true);
	const halves = split$2(key$1, -1, [1, 1]);
	return halves.map((half) => half.reshape(keyShape));
}
|
|
6856
7621
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
function key(seed) {
	// The key is the uint32 pair [0, seed].
	const seedArr = array(seed, { dtype: DType.Uint32 });
	if (seedArr.ndim === 0) return stack([0, seedArr]);
	throw new Error(`key: seed must be a scalar integer, but got shape ${seedArr.shape} - use jax.vmap for batching.`);
}
|
|
6861
7627
|
/** Splits a PRNG key into `num` new keys by adding a leading axis. */
function split(key$1, num = 2) {
	const shape$1 = typeof num === "number" ? [num] : num;
	for (const len of shape$1) {
		if (!(len > 0 && Number.isInteger(len))) throw new Error(`Invalid split length: ${len}. Must be a positive integer.`);
	}
	const [k0, k1] = getK01(key$1);
	// Two draws (trailing argument 0 and 1) are stacked into the final axis of size 2.
	const firstWords = randomBits(k0.ref, k1.ref, shape$1, 0);
	const secondWords = randomBits(k0, k1, shape$1, 1);
	return stack([firstWords, secondWords], -1);
}
|
|
6870
7634
|
/** Sample uniform bits in the form of unsigned integers. */
function bits(key$1, shape$1 = []) {
	const keyHalves = getK01(key$1);
	return randomBits(keyHalves[0], keyHalves[1], shape$1);
}
|
|
6875
7639
|
/**
|
|
6876
7640
|
* @function
|
|
@@ -6944,6 +7708,32 @@ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
|
|
|
6944
7708
|
}, { staticArgnums: [1] });
|
|
6945
7709
|
/**
 * @function
 * Sample multivariate normal random values with given mean and covariance.
 *
 * Samples are formed as `mean + L @ z`, where `L = cholesky(cov)` and `z` is
 * standard normal noise. The batch shape is the broadcast of `shape`,
 * `mean.shape[:-1]` and `cov.shape[:-2]`. The final dimension holds the `n`
 * components of each sample.
 *
 * - `key` - PRNG key
 * - `mean` - Mean vector of shape `[..., n]`
 * - `cov` - Covariance of shape `[..., n, n]`, must be positive-definite
 * - `shape` - Result batch shape, must be broadcastable with
 *   `mean.shape[:-1]` and `cov.shape[:-2]`
 * @returns Random samples of shape `[...shape, n]` (after broadcasting)
 */
const multivariateNormal = jit$1(function multivariateNormal$1(key$1, mean$1, cov$1, shape$1 = []) {
	const mu = fudgeArray(mean$1);
	const sigma = fudgeArray(cov$1);
	const n = mu.shape[mu.ndim - 1];
	const lastDim = sigma.shape[sigma.ndim - 1];
	const secondLastDim = sigma.shape[sigma.ndim - 2];
	if (lastDim !== n || secondLastDim !== n) throw new Error(`Invalid covariance shape: ${sigma.shape}. Expected last two dimensions to be [${n}, ${n}].`);
	const batchShape = broadcastShapes(shape$1, mu.shape.slice(0, -1), sigma.shape.slice(0, -2));
	const factor = cholesky(sigma);
	const noise = normal(key$1, [...batchShape, n]);
	return einsum("...ij,...j->...i", factor, noise).add(mu);
}, { staticArgnums: [3] });
|
|
7735
|
+
/**
|
|
7736
|
+
* @function
|
|
6947
7737
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
6948
7738
|
*
|
|
6949
7739
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
@@ -7033,17 +7823,62 @@ const linearize = linearize$1;
|
|
|
7033
7823
|
/**
 * @function
 * Calculate the reverse-mode vector-Jacobian product for a function.
 *
 * `vjp(f, primals)` evaluates `f` at `primals` and hands back a pair
 * `[out, vjpFn]`. Here `out` is the value of `f(primals)`, and `vjpFn` accepts
 * cotangents for each output and produces cotangents for each input.
 *
 * If `{ hasAux: true }` is supplied, `f` must return an `[out, aux]` tuple,
 * and the result grows a third element: `[out, vjpFn, aux]`.
 *
 * @example
 * ```ts
 * const [y, vjpFn] = vjp(f, [x]);
 *
 * // Returning auxiliary data alongside the output
 * const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
 * ```
 */
const vjp = vjp$1;
|
|
7038
7843
|
/**
 * @function
 * Compute the gradient of a scalar-valued function `f` with respect to its
 * first argument.
 *
 * Use the `argnums` option to differentiate with respect to a different
 * argument. When `argnums` is an array of indices, the result is an array
 * holding one gradient for each listed argument index.
 *
 * If `{ hasAux: true }` is supplied, `f` must return an `[out, aux]` tuple, and
 * the returned function produces `[gradient, aux]`.
 *
 * @example
 * ```ts
 * const gradient = grad(f)(x);
 *
 * // Selecting arguments via `argnums`
 * const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
 *
 * // Returning auxiliary data via `hasAux`
 * const [gradient, aux] = grad(f, { hasAux: true })(x);
 * ```
 */
const grad = grad$1;
|
|
7044
7867
|
/**
 * @function
 * Build a function that computes both `f` and its gradient in one call.
 *
 * If `{ hasAux: true }` is supplied, `f` must return an `[out, aux]` tuple, and
 * the built function returns `[[out, aux], gradient]`.
 *
 * @example
 * ```ts
 * // Plain usage
 * const [value, gradient] = valueAndGrad(f)(x);
 *
 * // Returning auxiliary data via hasAux
 * const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
 * ```
 */
const valueAndGrad = valueAndGrad$1;
|
|
7049
7884
|
/**
|
|
@@ -7052,6 +7887,21 @@ const valueAndGrad = valueAndGrad$1;
|
|
|
7052
7887
|
*/
|
|
7053
7888
|
const jacrev = jacrev$1;
|
|
7054
7889
|
/**
 * @function
 * Compute the Hessian matrix of a scalar-valued function.
 *
 * The Hessian collects every second-order partial derivative of `f`; entry
 * `H[i, j]` is `d^2 f / (dx_i dx_j)`. It is implemented as `jacfwd(grad(f))`.
 *
 * @example
 * ```ts
 * const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // x^3
 * const H = hessian(f)(np.array([1, 2, 3]));
 * // H[i,j] = d^2f / dx_i dx_j
 * ```
 */
const hessian = hessian$1;
|
|
7904
|
+
/**
|
|
7055
7905
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
7056
7906
|
*
|
|
7057
7907
|
* This can be used to wait for the results of an intermediate computation to
|
|
@@ -7086,4 +7936,4 @@ async function devicePut(x, device) {
|
|
|
7086
7936
|
}
|
|
7087
7937
|
|
|
7088
7938
|
//#endregion
|
|
7089
|
-
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
7939
|
+
// Public API of the package. Some entries are re-exported under a different
// name (e.g. `Array$1` as `Array`, the `*_exports` namespace objects as `lax`,
// `nn`, `numpy`, `random`, `scipySpecial`, `tree`), and `jacrev` is exported
// twice: under its own name and as the alias `jacobian`.
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|