@jax-js/jax 0.1.5 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +60 -7
- package/dist/{backend-DziQSaoQ.cjs → backend-B3foXiV_.cjs} +25 -6
- package/dist/{backend-DaqL-MNz.js → backend-nEolvdLv.js} +20 -7
- package/dist/index.cjs +450 -129
- package/dist/index.d.cts +1669 -1467
- package/dist/index.d.ts +1669 -1467
- package/dist/index.js +450 -130
- package/dist/{webgl-ClIYb8jP.cjs → webgl-DIIbKJ0G.cjs} +1 -1
- package/dist/{webgl-RSuZKvgc.js → webgl-DweKSWEm.js} +1 -1
- package/dist/{webgpu-Dh7k9io0.js → webgpu-B96vzWGE.js} +1 -1
- package/dist/{webgpu-Db2JrNBr.cjs → webgpu-BykvF26B.cjs} +1 -1
- package/package.json +1 -1
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-nEolvdLv.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -209,7 +209,7 @@ __export(tree_exports, {
|
|
|
209
209
|
structure: () => structure,
|
|
210
210
|
unflatten: () => unflatten
|
|
211
211
|
});
|
|
212
|
-
const JsArray$
|
|
212
|
+
const JsArray$2 = globalThis.Array;
|
|
213
213
|
let NodeType = /* @__PURE__ */ function(NodeType$1) {
|
|
214
214
|
NodeType$1["Array"] = "Array";
|
|
215
215
|
NodeType$1["Object"] = "Object";
|
|
@@ -257,7 +257,7 @@ function flatten(tree) {
|
|
|
257
257
|
return [leaves$1, treedef];
|
|
258
258
|
}
|
|
259
259
|
function _flatten(tree, leaves$1) {
|
|
260
|
-
if (JsArray$
|
|
260
|
+
if (JsArray$2.isArray(tree)) {
|
|
261
261
|
const childTrees = tree.map((c) => _flatten(c, leaves$1));
|
|
262
262
|
return new JsTreeDef(NodeType.Array, null, childTrees);
|
|
263
263
|
} else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
|
|
@@ -306,11 +306,11 @@ function map(fn, tree, ...rest) {
|
|
|
306
306
|
}
|
|
307
307
|
/** Take a reference of every array in a tree. */
|
|
308
308
|
function ref(tree) {
|
|
309
|
-
return map((x) => x.ref, tree);
|
|
309
|
+
return map((x) => x instanceof Tracer ? x.ref : x, tree);
|
|
310
310
|
}
|
|
311
311
|
/** Dispose every array in a tree. */
|
|
312
312
|
function dispose(tree) {
|
|
313
|
-
if (tree) map((x) => x.dispose(), tree);
|
|
313
|
+
if (tree) map((x) => x instanceof Tracer ? x.dispose() : void 0, tree);
|
|
314
314
|
}
|
|
315
315
|
|
|
316
316
|
//#endregion
|
|
@@ -381,6 +381,13 @@ let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
|
|
|
381
381
|
CompareOp$1["LessEqual"] = "less_equal";
|
|
382
382
|
return CompareOp$1;
|
|
383
383
|
}({});
|
|
384
|
+
const routinePrimitives = new Map([
|
|
385
|
+
[Primitive.Sort, Routines.Sort],
|
|
386
|
+
[Primitive.Argsort, Routines.Argsort],
|
|
387
|
+
[Primitive.TriangularSolve, Routines.TriangularSolve],
|
|
388
|
+
[Primitive.Cholesky, Routines.Cholesky],
|
|
389
|
+
[Primitive.LU, Routines.LU]
|
|
390
|
+
]);
|
|
384
391
|
function add$1(x, y) {
|
|
385
392
|
return bind1(Primitive.Add, [x, y]);
|
|
386
393
|
}
|
|
@@ -577,14 +584,20 @@ function shrink(x, slice) {
|
|
|
577
584
|
}
|
|
578
585
|
function pad$1(x, width) {
|
|
579
586
|
const nd = ndim$1(x);
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
else if (
|
|
583
|
-
if (width
|
|
584
|
-
const
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
587
|
+
let w;
|
|
588
|
+
if (typeof width === "number") w = [[width, width]];
|
|
589
|
+
else if (isNumberPair(width)) w = [width];
|
|
590
|
+
else if (!Array.isArray(width)) {
|
|
591
|
+
const indicesAndPairs = Object.entries(width);
|
|
592
|
+
w = rep(nd, [0, 0]);
|
|
593
|
+
for (const [k, v] of indicesAndPairs) w[checkAxis(parseInt(k), nd)] = v;
|
|
594
|
+
} else if (!width.every(isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
|
|
595
|
+
else w = width;
|
|
596
|
+
if (w.length === 1) {
|
|
597
|
+
const [w0, w1] = w[0];
|
|
598
|
+
w = rep(nd, () => [w0, w1]);
|
|
599
|
+
} else if (w.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${w.length}`);
|
|
600
|
+
return bind1(Primitive.Pad, [x], { width: w });
|
|
588
601
|
}
|
|
589
602
|
function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
590
603
|
const as = getShape(a);
|
|
@@ -654,6 +667,9 @@ function newDynamic(main) {
|
|
|
654
667
|
dynamicTrace = prevDynamicTrace;
|
|
655
668
|
} };
|
|
656
669
|
}
|
|
670
|
+
function currentTraceLevel() {
|
|
671
|
+
return traceStack[traceStack.length - 1].level;
|
|
672
|
+
}
|
|
657
673
|
var Trace = class {
|
|
658
674
|
constructor(main) {
|
|
659
675
|
this.main = main;
|
|
@@ -757,6 +773,22 @@ var Tracer = class Tracer {
|
|
|
757
773
|
const result = reduce(this.astype(castDtype), AluOp.Add, axis, opts);
|
|
758
774
|
return result.mul(1 / n).astype(originalDtype);
|
|
759
775
|
}
|
|
776
|
+
/** Minimum of the elements of the array along a given axis. */
|
|
777
|
+
min(axis = null, opts) {
|
|
778
|
+
return reduce(this, AluOp.Min, axis, opts);
|
|
779
|
+
}
|
|
780
|
+
/** Maximum of the elements of the array along a given axis. */
|
|
781
|
+
max(axis = null, opts) {
|
|
782
|
+
return reduce(this, AluOp.Max, axis, opts);
|
|
783
|
+
}
|
|
784
|
+
/** Test whether all array elements along a given axis evaluate to true. */
|
|
785
|
+
all(axis = null, opts) {
|
|
786
|
+
return this.astype(DType.Bool).min(axis, opts);
|
|
787
|
+
}
|
|
788
|
+
/** Test whether any array element along a given axis evaluates to true. */
|
|
789
|
+
any(axis = null, opts) {
|
|
790
|
+
return this.astype(DType.Bool).max(axis, opts);
|
|
791
|
+
}
|
|
760
792
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
761
793
|
transpose(perm) {
|
|
762
794
|
return transpose$1(this, perm);
|
|
@@ -1031,6 +1063,7 @@ var TreeMismatchError = class extends TypeError {
|
|
|
1031
1063
|
super(`Mismatched tree structures in ${where$2}: ${left} != ${right}`);
|
|
1032
1064
|
}
|
|
1033
1065
|
};
|
|
1066
|
+
/** Flatten a function of `JsTree` input/output for use in tracing. */
|
|
1034
1067
|
function flattenFun(f, inTree) {
|
|
1035
1068
|
const store = { value: void 0 };
|
|
1036
1069
|
const flatFun = (...argsFlat) => {
|
|
@@ -1042,6 +1075,26 @@ function flattenFun(f, inTree) {
|
|
|
1042
1075
|
};
|
|
1043
1076
|
return [flatFun, store];
|
|
1044
1077
|
}
|
|
1078
|
+
/** Like flattenFun, but expects f to return [main, aux] tuple. */
|
|
1079
|
+
function flattenFunWithAux(f, inTree) {
|
|
1080
|
+
const store = { value: void 0 };
|
|
1081
|
+
const auxStore = { value: void 0 };
|
|
1082
|
+
const flatFun = (...argsFlat) => {
|
|
1083
|
+
const pytreeArgs = unflatten(inTree, argsFlat);
|
|
1084
|
+
const result = f(...pytreeArgs);
|
|
1085
|
+
if (!Array.isArray(result) || result.length !== 2) throw new Error("Function with `hasAux: true` must return [output, aux] tuple");
|
|
1086
|
+
const [out, aux] = result;
|
|
1087
|
+
const [outFlat, outTree] = flatten(out);
|
|
1088
|
+
store.value = outTree;
|
|
1089
|
+
auxStore.value = aux;
|
|
1090
|
+
return outFlat;
|
|
1091
|
+
};
|
|
1092
|
+
return [
|
|
1093
|
+
flatFun,
|
|
1094
|
+
store,
|
|
1095
|
+
auxStore
|
|
1096
|
+
];
|
|
1097
|
+
}
|
|
1045
1098
|
var UseAfterFreeError = class extends ReferenceError {
|
|
1046
1099
|
constructor(tracer) {
|
|
1047
1100
|
super(`Referenced tracer ${tracer.toString()} freed, please use .ref move semantics`);
|
|
@@ -1771,13 +1824,6 @@ function jit$1(f, opts) {
|
|
|
1771
1824
|
|
|
1772
1825
|
//#endregion
|
|
1773
1826
|
//#region src/frontend/jit.ts
|
|
1774
|
-
const routinePrimitives = new Map([
|
|
1775
|
-
[Primitive.Sort, Routines.Sort],
|
|
1776
|
-
[Primitive.Argsort, Routines.Argsort],
|
|
1777
|
-
[Primitive.TriangularSolve, Routines.TriangularSolve],
|
|
1778
|
-
[Primitive.Cholesky, Routines.Cholesky],
|
|
1779
|
-
[Primitive.LU, Routines.LU]
|
|
1780
|
-
]);
|
|
1781
1827
|
/** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
|
|
1782
1828
|
var JitProgram = class {
|
|
1783
1829
|
constructor(backend, steps, inputs, outputs) {
|
|
@@ -2166,12 +2212,13 @@ const jitRules = {
|
|
|
2166
2212
|
const ndim$2 = avals[0].ndim;
|
|
2167
2213
|
const sizes = avals.map((x) => x.shape[axis]);
|
|
2168
2214
|
const finalSize = sizes.reduce((a, b) => a + b, 0);
|
|
2215
|
+
const { dtype: dtypeOut } = avals.map((x) => x.scalar()).reduce(promoteAvals);
|
|
2169
2216
|
const makePadAxis = (start, end) => range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
|
|
2170
2217
|
let cum = 0;
|
|
2171
2218
|
const src = [];
|
|
2172
2219
|
for (let i = 0; i < exps.length; i++) {
|
|
2173
2220
|
const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
|
|
2174
|
-
src.push(reshapeViews(exps[i], (st) => st.pad(padding)));
|
|
2221
|
+
src.push(reshapeViews(AluExp.cast(dtypeOut, exps[i]), (st) => st.pad(padding)));
|
|
2175
2222
|
cum += sizes[i];
|
|
2176
2223
|
}
|
|
2177
2224
|
return { exp: [src.reduce(AluExp.add)] };
|
|
@@ -2309,7 +2356,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2309
2356
|
p1NextBlack.set(v, v);
|
|
2310
2357
|
}
|
|
2311
2358
|
const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
|
|
2312
|
-
const needsCleanShapePrimitives = [Primitive.Pad];
|
|
2359
|
+
const needsCleanShapePrimitives = [Primitive.Concatenate, Primitive.Pad];
|
|
2313
2360
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
2314
2361
|
const eqn = jaxpr.eqns[i];
|
|
2315
2362
|
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
@@ -2379,7 +2426,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2379
2426
|
|
|
2380
2427
|
//#endregion
|
|
2381
2428
|
//#region src/frontend/array.ts
|
|
2382
|
-
const JsArray = globalThis.Array;
|
|
2429
|
+
const JsArray$1 = globalThis.Array;
|
|
2383
2430
|
const inlineArrayLimit = 128;
|
|
2384
2431
|
/** Version of pureArray with fudged types. */
|
|
2385
2432
|
const fudgeArray = pureArray;
|
|
@@ -2777,25 +2824,35 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2777
2824
|
});
|
|
2778
2825
|
}
|
|
2779
2826
|
/** Apply an operation with custom lowering to this array. */
|
|
2780
|
-
static #routine(
|
|
2781
|
-
|
|
2782
|
-
|
|
2783
|
-
|
|
2784
|
-
|
|
2785
|
-
|
|
2786
|
-
|
|
2787
|
-
|
|
2788
|
-
|
|
2789
|
-
|
|
2790
|
-
|
|
2791
|
-
|
|
2792
|
-
|
|
2793
|
-
dtype
|
|
2794
|
-
|
|
2795
|
-
|
|
2796
|
-
|
|
2797
|
-
pending
|
|
2798
|
-
|
|
2827
|
+
static #routine(prim) {
|
|
2828
|
+
return (arrays, params) => {
|
|
2829
|
+
const { backend, committed } = Array$1.#computeBackend(prim, arrays);
|
|
2830
|
+
for (const ar of arrays) ar.#realize();
|
|
2831
|
+
const avals = arrays.map((ar) => ar.aval);
|
|
2832
|
+
const avalsOut = abstractEvalRules[prim](avals, params);
|
|
2833
|
+
const routine = new Routine(routinePrimitives.get(prim), {
|
|
2834
|
+
inputShapes: avals.map((a) => a.shape),
|
|
2835
|
+
inputDtypes: avals.map((a) => a.dtype),
|
|
2836
|
+
outputShapes: avalsOut.map((a) => a.shape),
|
|
2837
|
+
outputDtypes: avalsOut.map((a) => a.dtype)
|
|
2838
|
+
}, params);
|
|
2839
|
+
const inputs = arrays.map((ar) => ar.#source);
|
|
2840
|
+
const outputs = avalsOut.map((x) => backend.malloc(byteWidth(x.dtype) * x.size));
|
|
2841
|
+
const pending = arrays.flatMap((ar) => ar.#pending);
|
|
2842
|
+
for (const exe of pending) exe.updateRc(+outputs.length);
|
|
2843
|
+
pending.push(new PendingExecute(backend, routine, inputs, outputs));
|
|
2844
|
+
pending[pending.length - 1].updateRc(+outputs.length - 1);
|
|
2845
|
+
arrays.forEach((ar) => ar.dispose());
|
|
2846
|
+
return outputs.map((output, i) => new Array$1({
|
|
2847
|
+
source: output,
|
|
2848
|
+
st: ShapeTracker.fromShape(avalsOut[i].shape),
|
|
2849
|
+
dtype: avalsOut[i].dtype,
|
|
2850
|
+
weakType: avalsOut[i].weakType,
|
|
2851
|
+
backend,
|
|
2852
|
+
committed,
|
|
2853
|
+
pending
|
|
2854
|
+
}));
|
|
2855
|
+
};
|
|
2799
2856
|
}
|
|
2800
2857
|
/**
|
|
2801
2858
|
* Normalizes this array into one backed by a `Slot`.
|
|
@@ -3129,65 +3186,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3129
3186
|
[Primitive.Pad]([x], { width }) {
|
|
3130
3187
|
return [x.#reshape(x.#st.pad(width))];
|
|
3131
3188
|
},
|
|
3132
|
-
[Primitive.Sort](
|
|
3133
|
-
|
|
3134
|
-
|
|
3135
|
-
|
|
3136
|
-
|
|
3137
|
-
outputDtypes: [x.dtype]
|
|
3138
|
-
});
|
|
3139
|
-
return Array$1.#routine(routine, [x], [x.#weakType]);
|
|
3140
|
-
},
|
|
3141
|
-
[Primitive.Argsort]([x]) {
|
|
3142
|
-
const routine = new Routine(Routines.Argsort, {
|
|
3143
|
-
inputShapes: [x.shape],
|
|
3144
|
-
inputDtypes: [x.dtype],
|
|
3145
|
-
outputShapes: [x.shape, x.shape],
|
|
3146
|
-
outputDtypes: [x.dtype, DType.Int32]
|
|
3147
|
-
});
|
|
3148
|
-
return Array$1.#routine(routine, [x], [x.#weakType, false]);
|
|
3149
|
-
},
|
|
3150
|
-
[Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
|
|
3151
|
-
const routine = new Routine(Routines.TriangularSolve, {
|
|
3152
|
-
inputShapes: [a.shape, b.shape],
|
|
3153
|
-
inputDtypes: [a.dtype, b.dtype],
|
|
3154
|
-
outputShapes: [b.shape],
|
|
3155
|
-
outputDtypes: [b.dtype]
|
|
3156
|
-
}, { unitDiagonal });
|
|
3157
|
-
return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
|
|
3158
|
-
},
|
|
3159
|
-
[Primitive.Cholesky]([a]) {
|
|
3160
|
-
const routine = new Routine(Routines.Cholesky, {
|
|
3161
|
-
inputShapes: [a.shape],
|
|
3162
|
-
inputDtypes: [a.dtype],
|
|
3163
|
-
outputShapes: [a.shape],
|
|
3164
|
-
outputDtypes: [a.dtype]
|
|
3165
|
-
});
|
|
3166
|
-
return Array$1.#routine(routine, [a], [a.#weakType]);
|
|
3167
|
-
},
|
|
3168
|
-
[Primitive.LU]([a]) {
|
|
3169
|
-
const batch = a.shape.slice(0, -2);
|
|
3170
|
-
const [m, n] = a.shape.slice(-2);
|
|
3171
|
-
const routine = new Routine(Routines.LU, {
|
|
3172
|
-
inputShapes: [a.shape],
|
|
3173
|
-
inputDtypes: [a.dtype],
|
|
3174
|
-
outputShapes: [
|
|
3175
|
-
a.shape,
|
|
3176
|
-
[...batch, Math.min(m, n)],
|
|
3177
|
-
[...batch, m]
|
|
3178
|
-
],
|
|
3179
|
-
outputDtypes: [
|
|
3180
|
-
a.dtype,
|
|
3181
|
-
DType.Int32,
|
|
3182
|
-
DType.Int32
|
|
3183
|
-
]
|
|
3184
|
-
});
|
|
3185
|
-
return Array$1.#routine(routine, [a], [
|
|
3186
|
-
a.#weakType,
|
|
3187
|
-
false,
|
|
3188
|
-
false
|
|
3189
|
-
]);
|
|
3190
|
-
},
|
|
3189
|
+
[Primitive.Sort]: Array$1.#routine(Primitive.Sort),
|
|
3190
|
+
[Primitive.Argsort]: Array$1.#routine(Primitive.Argsort),
|
|
3191
|
+
[Primitive.TriangularSolve]: Array$1.#routine(Primitive.TriangularSolve),
|
|
3192
|
+
[Primitive.Cholesky]: Array$1.#routine(Primitive.Cholesky),
|
|
3193
|
+
[Primitive.LU]: Array$1.#routine(Primitive.LU),
|
|
3191
3194
|
[Primitive.Jit](args, { jaxpr }) {
|
|
3192
3195
|
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
3193
3196
|
const { backend, committed } = Array$1.#computeBackend("jit", args);
|
|
@@ -3269,7 +3272,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
3269
3272
|
if (!shape$1) {
|
|
3270
3273
|
shape$1 = [];
|
|
3271
3274
|
let cur = values;
|
|
3272
|
-
while (JsArray.isArray(cur)) {
|
|
3275
|
+
while (JsArray$1.isArray(cur)) {
|
|
3273
3276
|
shape$1.push(cur.length);
|
|
3274
3277
|
cur = cur[0];
|
|
3275
3278
|
}
|
|
@@ -4223,17 +4226,39 @@ function jvpFlat(f, primals, tangents) {
|
|
|
4223
4226
|
_usingCtx$1.d();
|
|
4224
4227
|
}
|
|
4225
4228
|
}
|
|
4226
|
-
function jvp$1(f, primals, tangents) {
|
|
4229
|
+
function jvp$1(f, primals, tangents, { hasAux = false } = {}) {
|
|
4227
4230
|
const [primalsFlat, inTree] = flatten(primals);
|
|
4228
4231
|
const [tangentsFlat, inTree2] = flatten(tangents);
|
|
4229
4232
|
if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
|
|
4230
|
-
|
|
4233
|
+
let flatFun, outTree, aux;
|
|
4234
|
+
if (hasAux) [flatFun, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4235
|
+
else [flatFun, outTree] = flattenFun(f, inTree);
|
|
4231
4236
|
const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
|
|
4232
4237
|
if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
|
|
4233
4238
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
4234
4239
|
const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
|
|
4240
|
+
if (hasAux) return [
|
|
4241
|
+
primalsOut,
|
|
4242
|
+
tangentsOut,
|
|
4243
|
+
lowerAux(aux.value)
|
|
4244
|
+
];
|
|
4235
4245
|
return [primalsOut, tangentsOut];
|
|
4236
4246
|
}
|
|
4247
|
+
/** Lowering for auxiliary data returned in `hasAux: true` methods. */
|
|
4248
|
+
function lowerAux(aux) {
|
|
4249
|
+
const level = currentTraceLevel();
|
|
4250
|
+
return map((x) => {
|
|
4251
|
+
if (x instanceof Tracer) while (x._trace.main.level > level) if (x instanceof JVPTracer) {
|
|
4252
|
+
x.tangent.dispose();
|
|
4253
|
+
x = x.primal;
|
|
4254
|
+
} else {
|
|
4255
|
+
const y = x.fullLower();
|
|
4256
|
+
if (y._trace.main.level >= x._trace.main.level) throw new Error("internal: lowerAux did not reduce trace level");
|
|
4257
|
+
x = y;
|
|
4258
|
+
}
|
|
4259
|
+
return x;
|
|
4260
|
+
}, aux);
|
|
4261
|
+
}
|
|
4237
4262
|
|
|
4238
4263
|
//#endregion
|
|
4239
4264
|
//#region src/frontend/linearize.ts
|
|
@@ -4304,9 +4329,11 @@ function linearizeFlat(f, primalsIn) {
|
|
|
4304
4329
|
dispose$1
|
|
4305
4330
|
];
|
|
4306
4331
|
}
|
|
4307
|
-
function linearize$1(f,
|
|
4332
|
+
function linearize$1(f, primalsIn, { hasAux = false } = {}) {
|
|
4308
4333
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4309
|
-
|
|
4334
|
+
let fFlat, outTree, aux;
|
|
4335
|
+
if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4336
|
+
else [fFlat, outTree] = flattenFun(f, inTree);
|
|
4310
4337
|
const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4311
4338
|
if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
|
|
4312
4339
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
@@ -4317,6 +4344,11 @@ function linearize$1(f, ...primalsIn) {
|
|
|
4317
4344
|
return unflatten(outTree.value, tangentsOutFlat);
|
|
4318
4345
|
});
|
|
4319
4346
|
fLin.dispose = dispose$1;
|
|
4347
|
+
if (hasAux) return [
|
|
4348
|
+
primalsOut,
|
|
4349
|
+
fLin,
|
|
4350
|
+
lowerAux(aux.value)
|
|
4351
|
+
];
|
|
4320
4352
|
return [primalsOut, fLin];
|
|
4321
4353
|
}
|
|
4322
4354
|
var PartialEvalTracer = class extends Tracer {
|
|
@@ -4817,9 +4849,11 @@ function vjpFlat(f, primalsIn) {
|
|
|
4817
4849
|
dispose$1
|
|
4818
4850
|
];
|
|
4819
4851
|
}
|
|
4820
|
-
function vjp$1(f,
|
|
4852
|
+
function vjp$1(f, primalsIn, { hasAux = false } = {}) {
|
|
4821
4853
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4822
|
-
|
|
4854
|
+
let fFlat, outTree, aux;
|
|
4855
|
+
if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4856
|
+
else [fFlat, outTree] = flattenFun(f, inTree);
|
|
4823
4857
|
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4824
4858
|
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
4825
4859
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
@@ -4830,26 +4864,43 @@ function vjp$1(f, ...primalsIn) {
|
|
|
4830
4864
|
return unflatten(inTree, cotangentsInFlat);
|
|
4831
4865
|
});
|
|
4832
4866
|
fVjp.dispose = dispose$1;
|
|
4867
|
+
if (hasAux) return [
|
|
4868
|
+
primalsOut,
|
|
4869
|
+
fVjp,
|
|
4870
|
+
lowerAux(aux.value)
|
|
4871
|
+
];
|
|
4833
4872
|
return [primalsOut, fVjp];
|
|
4834
4873
|
}
|
|
4835
|
-
function grad$1(f) {
|
|
4836
|
-
const valueAndGradFn = valueAndGrad$1(f);
|
|
4874
|
+
function grad$1(f, opts) {
|
|
4875
|
+
const valueAndGradFn = valueAndGrad$1(f, opts);
|
|
4837
4876
|
return (...x) => {
|
|
4838
|
-
|
|
4839
|
-
|
|
4840
|
-
|
|
4877
|
+
if (opts?.hasAux) {
|
|
4878
|
+
const [[y, aux], dx] = valueAndGradFn(...x);
|
|
4879
|
+
y.dispose();
|
|
4880
|
+
return [dx, aux];
|
|
4881
|
+
} else {
|
|
4882
|
+
const [y, dx] = valueAndGradFn(...x);
|
|
4883
|
+
y.dispose();
|
|
4884
|
+
return dx;
|
|
4885
|
+
}
|
|
4841
4886
|
};
|
|
4842
4887
|
}
|
|
4843
|
-
function valueAndGrad$1(f) {
|
|
4888
|
+
function valueAndGrad$1(f, opts) {
|
|
4889
|
+
const argnums = opts?.argnums ?? 0;
|
|
4890
|
+
const hasAux = opts?.hasAux ?? false;
|
|
4891
|
+
checkInts(argnums);
|
|
4892
|
+
const argnumsSet = new Set(typeof argnums === "number" ? [argnums] : argnums);
|
|
4844
4893
|
return (...x) => {
|
|
4845
4894
|
if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
|
|
4846
|
-
|
|
4895
|
+
for (let i = 0; i < x.length; i++) if (!argnumsSet.has(i)) x[i] = map(stopGradient, x[i]);
|
|
4896
|
+
const [y, fVjp, aux] = vjp$1(f, x, { hasAux });
|
|
4847
4897
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
4848
4898
|
if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
4849
|
-
const
|
|
4850
|
-
for (const r of rest) dispose(r);
|
|
4899
|
+
const cts = fVjp(onesLike$1(y.ref));
|
|
4851
4900
|
fVjp.dispose();
|
|
4852
|
-
|
|
4901
|
+
for (let i = 0; i < cts.length; i++) if (!argnumsSet.has(i)) dispose(cts[i]);
|
|
4902
|
+
const grads = typeof argnums === "number" ? cts[argnums] : argnums.map((i) => cts[i]);
|
|
4903
|
+
return hasAux ? [[y, aux], grads] : [y, grads];
|
|
4853
4904
|
};
|
|
4854
4905
|
}
|
|
4855
4906
|
function jacrev$1(f) {
|
|
@@ -4857,7 +4908,7 @@ function jacrev$1(f) {
|
|
|
4857
4908
|
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
4858
4909
|
const [size$1] = x.shape;
|
|
4859
4910
|
const pullback = (ct) => {
|
|
4860
|
-
const [y, fVjp] = vjp$1(f, x);
|
|
4911
|
+
const [y, fVjp] = vjp$1(f, [x]);
|
|
4861
4912
|
y.dispose();
|
|
4862
4913
|
const [ret] = fVjp(ct);
|
|
4863
4914
|
fVjp.dispose();
|
|
@@ -4866,6 +4917,9 @@ function jacrev$1(f) {
|
|
|
4866
4917
|
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
4867
4918
|
};
|
|
4868
4919
|
}
|
|
4920
|
+
function hessian$1(f) {
|
|
4921
|
+
return jacfwd$1(grad$1(f));
|
|
4922
|
+
}
|
|
4869
4923
|
|
|
4870
4924
|
//#endregion
|
|
4871
4925
|
//#region src/library/numpy/einsum.ts
|
|
@@ -5538,6 +5592,7 @@ __export(numpy_exports, {
|
|
|
5538
5592
|
moveaxis: () => moveaxis$1,
|
|
5539
5593
|
multiply: () => multiply,
|
|
5540
5594
|
nan: () => nan,
|
|
5595
|
+
nanToNum: () => nanToNum,
|
|
5541
5596
|
ndim: () => ndim,
|
|
5542
5597
|
negative: () => negative,
|
|
5543
5598
|
notEqual: () => notEqual,
|
|
@@ -5575,6 +5630,7 @@ __export(numpy_exports, {
|
|
|
5575
5630
|
std: () => std,
|
|
5576
5631
|
subtract: () => subtract,
|
|
5577
5632
|
sum: () => sum,
|
|
5633
|
+
swapaxes: () => swapaxes,
|
|
5578
5634
|
take: () => take,
|
|
5579
5635
|
tan: () => tan,
|
|
5580
5636
|
tanh: () => tanh,
|
|
@@ -5734,24 +5790,22 @@ function max(a, axis = null, opts) {
|
|
|
5734
5790
|
return reduce(a, AluOp.Max, axis, opts);
|
|
5735
5791
|
}
|
|
5736
5792
|
/**
|
|
5737
|
-
* Test whether
|
|
5793
|
+
* Test whether any array element along a given axis evaluates to True.
|
|
5738
5794
|
*
|
|
5739
5795
|
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
5740
5796
|
* removed. If axis is None, returns a scalar.
|
|
5741
5797
|
*/
|
|
5742
|
-
function
|
|
5743
|
-
|
|
5744
|
-
return min(a, axis, opts);
|
|
5798
|
+
function any(a, axis = null, opts) {
|
|
5799
|
+
return fudgeArray(a).any(axis, opts);
|
|
5745
5800
|
}
|
|
5746
5801
|
/**
|
|
5747
|
-
* Test whether
|
|
5802
|
+
* Test whether all array elements along a given axis evaluate to True.
|
|
5748
5803
|
*
|
|
5749
5804
|
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
5750
5805
|
* removed. If axis is None, returns a scalar.
|
|
5751
5806
|
*/
|
|
5752
|
-
function
|
|
5753
|
-
|
|
5754
|
-
return max(a, axis, opts);
|
|
5807
|
+
function all(a, axis = null, opts) {
|
|
5808
|
+
return fudgeArray(a).all(axis, opts);
|
|
5755
5809
|
}
|
|
5756
5810
|
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
5757
5811
|
function ptp(a, axis = null, opts) {
|
|
@@ -5852,7 +5906,7 @@ function split$1(a, indicesOrSections, axis = 0) {
|
|
|
5852
5906
|
const partSize = size$1 / indicesOrSections;
|
|
5853
5907
|
sizes = rep(indicesOrSections, partSize);
|
|
5854
5908
|
} else {
|
|
5855
|
-
const indices = indicesOrSections;
|
|
5909
|
+
const indices = indicesOrSections.map((i) => i < 0 ? i + size$1 : i);
|
|
5856
5910
|
sizes = [indices[0]];
|
|
5857
5911
|
for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
|
|
5858
5912
|
sizes.push(size$1 - indices[indices.length - 1]);
|
|
@@ -5973,6 +6027,17 @@ function flipud(x) {
|
|
|
5973
6027
|
function fliplr(x) {
|
|
5974
6028
|
return flip(x, 1);
|
|
5975
6029
|
}
|
|
6030
|
+
/** Interchange two axes of an array. */
|
|
6031
|
+
function swapaxes(a, axis1, axis2) {
|
|
6032
|
+
a = fudgeArray(a);
|
|
6033
|
+
axis1 = checkAxis(axis1, a.ndim);
|
|
6034
|
+
axis2 = checkAxis(axis2, a.ndim);
|
|
6035
|
+
if (axis1 === axis2) return a;
|
|
6036
|
+
const perm = range(a.ndim);
|
|
6037
|
+
perm[axis1] = axis2;
|
|
6038
|
+
perm[axis2] = axis1;
|
|
6039
|
+
return transpose(a, perm);
|
|
6040
|
+
}
|
|
5976
6041
|
/** Transpose the last two dimensions of an array. */
|
|
5977
6042
|
function matrixTranspose(a) {
|
|
5978
6043
|
if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
|
|
@@ -6789,6 +6854,21 @@ function isposinf(x) {
|
|
|
6789
6854
|
return isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
|
|
6790
6855
|
}
|
|
6791
6856
|
/**
|
|
6857
|
+
* Replace NaN and infinite entries in an array.
|
|
6858
|
+
*
|
|
6859
|
+
* By default, NaNs are replaced with `0.0`, and infinities are are substituted
|
|
6860
|
+
* with the corresponding maximum or minimum finite values.
|
|
6861
|
+
*/
|
|
6862
|
+
function nanToNum(x, { nan: nan$1 = 0, posinf = null, neginf = null } = {}) {
|
|
6863
|
+
x = fudgeArray(x);
|
|
6864
|
+
x = where(isnan(x.ref), nan$1, x);
|
|
6865
|
+
posinf ??= isFloatDtype(x.dtype) ? finfo(x.dtype).max : iinfo(x.dtype).max;
|
|
6866
|
+
neginf ??= isFloatDtype(x.dtype) ? finfo(x.dtype).min : iinfo(x.dtype).min;
|
|
6867
|
+
x = where(isposinf(x.ref), posinf, x);
|
|
6868
|
+
x = where(isneginf(x.ref), neginf, x);
|
|
6869
|
+
return x;
|
|
6870
|
+
}
|
|
6871
|
+
/**
|
|
6792
6872
|
* @function
|
|
6793
6873
|
* Test element-wise for finite values (not infinity or NaN).
|
|
6794
6874
|
*/
|
|
@@ -6901,6 +6981,7 @@ var lax_exports = {};
|
|
|
6901
6981
|
__export(lax_exports, {
|
|
6902
6982
|
conv: () => conv,
|
|
6903
6983
|
convGeneralDilated: () => convGeneralDilated,
|
|
6984
|
+
convTranspose: () => convTranspose,
|
|
6904
6985
|
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
6905
6986
|
dot: () => dot,
|
|
6906
6987
|
erf: () => erf,
|
|
@@ -6909,6 +6990,7 @@ __export(lax_exports, {
|
|
|
6909
6990
|
reduceWindow: () => reduceWindow,
|
|
6910
6991
|
stopGradient: () => stopGradient$1
|
|
6911
6992
|
});
|
|
6993
|
+
const JsArray = globalThis.Array;
|
|
6912
6994
|
/**
|
|
6913
6995
|
* General dot product/contraction operator.
|
|
6914
6996
|
*
|
|
@@ -6980,7 +7062,11 @@ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
|
6980
7062
|
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
6981
7063
|
* function in JAX, which wraps XLA's general convolution operator.
|
|
6982
7064
|
*
|
|
6983
|
-
*
|
|
7065
|
+
* @param lhs - Input tensor; shape `[N, C_in, ...xs]`
|
|
7066
|
+
* @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
|
|
7067
|
+
* @param windowStrides - Strides for each spatial dimension
|
|
7068
|
+
* @param padding - Padding for each spatial dimension, or a string
|
|
7069
|
+
* (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
|
|
6984
7070
|
*/
|
|
6985
7071
|
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
|
|
6986
7072
|
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
@@ -7040,6 +7126,60 @@ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, r
|
|
|
7040
7126
|
function conv(lhs, rhs, windowStrides, padding) {
|
|
7041
7127
|
return convGeneralDilated(lhs, rhs, windowStrides, padding);
|
|
7042
7128
|
}
|
|
7129
|
+
/**
|
|
7130
|
+
* Convenience wrapper for calculating the N-d convolution "transpose".
|
|
7131
|
+
*
|
|
7132
|
+
* This function directly calculates a fractionally strided conv rather than
|
|
7133
|
+
* indirectly calculating the gradient (transpose) of a forward convolution.
|
|
7134
|
+
* It is equivalent to the JAX version, except:
|
|
7135
|
+
*
|
|
7136
|
+
* - The `use_consistent_padding` option is not available. We only have the
|
|
7137
|
+
* consistent padding case (JAX version >0.8.4).
|
|
7138
|
+
* - The order of dimensions matches `lax.conv_general_dilated`.
|
|
7139
|
+
*
|
|
7140
|
+
* Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
|
|
7141
|
+
* dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
|
|
7142
|
+
* `transposeKernel` to true.
|
|
7143
|
+
*
|
|
7144
|
+
* @param lhs - Input tensor; shape `[N, C_in, ...xs]`
|
|
7145
|
+
* @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
|
|
7146
|
+
* @param strides - Sequence of n integers, sets fractional stride
|
|
7147
|
+
* @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
|
|
7148
|
+
* each side of the input, so it acts like gradient of `conv()`
|
|
7149
|
+
* @param rhsDilation - Atrous dilation for the kernel
|
|
7150
|
+
* @param transposeKernel - Flip spatial axes and swap the input/output channels
|
|
7151
|
+
* of the kernel; its shape should be `[C_in, C_out, ...ks]`
|
|
7152
|
+
*/
|
|
7153
|
+
function convTranspose(lhs, rhs, strides, padding, { rhsDilation, transposeKernel = false } = {}) {
|
|
7154
|
+
const kernelShape = rhs.shape.slice(2);
|
|
7155
|
+
rhsDilation = rhsDilation ?? rep(kernelShape.length, 1);
|
|
7156
|
+
const effectiveKernel = kernelShape.map((k, i) => Math.max(0, (k - 1) * rhsDilation[i] + 1));
|
|
7157
|
+
const pads = effectiveKernel.map((k, i) => convTransposePadding(k, strides[i], typeof padding === "string" ? padding : padding[i]));
|
|
7158
|
+
if (transposeKernel) {
|
|
7159
|
+
rhs = flip$1(rhs, range(2, rhs.ndim));
|
|
7160
|
+
rhs = moveaxis(rhs, 0, 1);
|
|
7161
|
+
}
|
|
7162
|
+
return convGeneralDilated(lhs, rhs, rep(lhs.ndim - 2, 1), pads, {
|
|
7163
|
+
lhsDilation: strides,
|
|
7164
|
+
rhsDilation
|
|
7165
|
+
});
|
|
7166
|
+
}
|
|
7167
|
+
/**
 * Per-axis `[lo, hi]` padding for `convTranspose`.
 *
 * `k` is the (dilated) kernel size and `s` the fractional stride along one
 * spatial axis. `"SAME"` and `"VALID"` follow JAX's `_conv_transpose_padding`;
 * an explicit `[lo, hi]` pair is converted to `[k - 1 - lo, k - 1 - hi]`.
 */
function convTransposePadding(k, s, padding) {
  if (padding === "SAME") {
    const total = k + s - 2;
    const lo = s > k - 1 ? k - 1 : Math.ceil(total / 2);
    return [lo, total - lo];
  }
  if (padding === "VALID") {
    const total = k + s - 2 + Math.max(k - s, 0);
    const lo = k - 1;
    return [lo, total - lo];
  }
  if (JsArray.isArray(padding)) {
    const [lo, hi] = padding;
    return [k - 1 - lo, k - 1 - hi];
  }
  throw new Error(`convTranspose: Invalid padding type ${padding}`);
}
|
|
7043
7183
|
/** Reduce a computation over padded windows. */
|
|
7044
7184
|
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
7045
7185
|
if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
|
|
@@ -7078,6 +7218,7 @@ function stopGradient$1(x) {
|
|
|
7078
7218
|
var nn_exports = {};
|
|
7079
7219
|
__export(nn_exports, {
|
|
7080
7220
|
celu: () => celu,
|
|
7221
|
+
dotProductAttention: () => dotProductAttention,
|
|
7081
7222
|
elu: () => elu,
|
|
7082
7223
|
gelu: () => gelu,
|
|
7083
7224
|
glu: () => glu,
|
|
@@ -7394,6 +7535,125 @@ function oneHot(x, numClasses) {
|
|
|
7394
7535
|
if (isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
|
|
7395
7536
|
return eye(numClasses, void 0, { device: x.device }).slice(x);
|
|
7396
7537
|
}
|
|
7538
|
+
/**
 * Scaled dot product attention (SDPA).
 *
 * Computes `softmax((Q @ K^T) * scale + bias) @ V`, where `Q` is the query,
 * `K` is the key, `V` is the value, and `scale` defaults to `1 / sqrt(d)` with
 * `d` the dimensionality of each key and query vector.
 *
 * Grouped-query (and multi-query) attention is applied when the `key` and
 * `value` tensors have fewer heads than `query`. `N` must be a multiple of `K`,
 * and query head `n` attends with key/value head `floor(n / (N / K))`, i.e.
 * consecutive query heads share a key/value head (same grouping as JAX).
 *
 * We use the following uppercase letters to denote array shapes:
 * - `B` = batch size
 * - `S` = length of key/value sequences (source)
 * - `L` = length of query sequences
 * - `N` = number of attention heads
 * - `H` = dimensionality of each attention head
 * - `K` = number of key/value heads (for grouped-query attention)
 *
 * The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
 * case it must be omitted from all inputs.
 *
 * Masked logits are replaced with `-Infinity`. A query row whose keys are all
 * masked out (via `mask`, `keyValueSeqLengths`, or `localWindowSize`)
 * therefore has no finite logit, and its softmax is not well defined
 * (presumably NaN). Keep at least one key per attended row.
 *
 * @param query - Query array; shape `[B, L, N, H]`
 * @param key - Key array; shape `[B, S, K, H]`
 * @param value - Value array; same shape as `key`
 * @param opts.bias - Optional bias to add to the attention logits; shape
 *   `[B, N, L, S]` or broadcastable to it.
 * @param opts.mask - Optional mask to apply to the attention logits; should be
 *   a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
 *   the element should take part in attention.
 * @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
 * @param opts.isCausal - If true, applies a causal mask (query position `i`
 *   only attends to key positions `j <= i`).
 * @param opts.querySeqLengths - Optional sequence lengths for the queries;
 *   shape `(B,)`. Taken from the beginning of the tensor. Outputs at query
 *   positions at or beyond the length are set to zero.
 * @param opts.keyValueSeqLengths - Optional sequence lengths for the keys and
 *   values; shape `(B,)`. Taken from the beginning of the tensor.
 * @param opts.localWindowSize - If specified, applies a local attention window
 *   of the given size. Can be a single number or a tuple `[left, right]`;
 *   query `i` then attends to keys `j` with `i - left <= j <= i + right`.
 *
 * @throws Error if the input ranks or shapes are inconsistent, if `N` is not a
 *   multiple of `K`, or if `localWindowSize` has negative or non-integer values.
 *
 * @returns The result of the attention operation; shape is the same as query
 *   `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
 */
function dotProductAttention(query, key$1, value, opts = {}) {
  query = fudgeArray(query);
  key$1 = fudgeArray(key$1);
  value = fudgeArray(value);
  if (query.ndim !== 3 && query.ndim !== 4 || query.ndim !== key$1.ndim || query.ndim !== value.ndim) throw new Error(`dotProductAttention: expected all tensors to have rank 3 or 4, got Q=${query.aval}, K=${key$1.aval}, V=${value.aval}`);
  if (!deepEqual(key$1.shape, value.shape)) throw new Error(`dotProductAttention: key and value shapes must match, got K=${key$1.shape}, V=${value.shape}`);
  // A missing batch axis is handled by adding a unit batch axis here and
  // dropping it again from the result.
  const isRank3 = query.ndim === 3;
  if (isRank3) {
    query = expandDims(query, 0);
    key$1 = expandDims(key$1, 0);
    value = expandDims(value, 0);
  }
  const [B, L, N, H] = query.shape;
  if (key$1.shape[0] !== B || key$1.shape[3] !== H) throw new Error(`dotProductAttention: query and key shapes mismatch, got Q=${query.aval}, K=${key$1.aval}`);
  const S = key$1.shape[1];
  const K = key$1.shape[2];
  if (N < K || N !== K && N % K !== 0) throw new Error(`dotProductAttention: number of query heads N=${N} must be divisible by number of key/value heads K=${K} for GQA`);
  const G = N / K;
  if (G > 1) {
    // GQA: query head n = k * G + g shares key/value head k, so each KV head
    // must be repeated G times *consecutively* ([k0, k0, k1, k1, ...]).
    // Tiling the head axis ([k0, k1, k0, k1, ...]) would instead pair query
    // head n with KV head n % K.
    // [B, S, K, H] -> [B, S, K, 1, H] -> [B, S, K, G, H] -> [B, S, N, H]
    const repeatHeads = (x) => tile(expandDims(x, 3), [1, 1, 1, G, 1]).reshape([B, S, N, H]);
    key$1 = repeatHeads(key$1);
    value = repeatHeads(value);
  }
  const scale = opts.scale ?? 1 / Math.sqrt(H);
  let scores = einsum("BLNH,BSNH->BNLS", query, key$1).mul(scale);
  if (opts.bias !== void 0) scores = scores.add(opts.bias);
  if (opts.mask !== void 0) scores = where(opts.mask, scores, -Infinity);
  if (opts.isCausal) {
    const causalMask = tri(L, S, 0, { dtype: DType.Bool });
    scores = where(causalMask, scores, -Infinity);
  }
  if (opts.localWindowSize !== void 0) {
    const [before, after] = typeof opts.localWindowSize === "number" ? [opts.localWindowSize, opts.localWindowSize] : opts.localWindowSize;
    if (before < 0 || after < 0 || !Number.isInteger(before) || !Number.isInteger(after)) throw new Error(`dotProductAttention: localWindowSize values must be non-negative integers, got ${opts.localWindowSize}`);
    // Keep key j for query i iff j <= i + after AND NOT (j <= i - before - 1),
    // i.e. i - before <= j <= i + after (numpy `tri` offset convention).
    const localMask = tri(L, S, after, { dtype: DType.Bool }).mul(tri(L, S, -before - 1, { dtype: DType.Bool }).notEqual(true));
    scores = where(localMask, scores, -Infinity);
  }
  if (opts.keyValueSeqLengths !== void 0) {
    // (B,) -> (B, 1, 1, 1), compared against key positions along the last axis.
    const sl = expandDims(opts.keyValueSeqLengths, [-1, -2, -3]);
    scores = where(arange(S).reshape([1, 1, 1, S]).less(sl), scores, -Infinity);
  }
  const attn = softmax(scores, -1);
  let out = einsum("BNLS,BSNH->BLNH", attn, value);
  if (opts.querySeqLengths !== void 0) {
    // Padded query positions are zeroed in the output rather than masking
    // their whole score row with -Infinity: an all -Infinity row makes the
    // softmax NaN, and that NaN propagates into the gradients of every input.
    // (B,) -> (B, 1, 1, 1), compared against query positions along axis 1.
    const sl = expandDims(opts.querySeqLengths, [-1, -2, -3]);
    out = where(arange(L).reshape([1, L, 1, 1]).less(sl), out, 0);
  }
  return isRank3 ? out.reshape([L, N, H]) : out;
}
|
|
7397
7657
|
|
|
7398
7658
|
//#endregion
|
|
7399
7659
|
//#region src/library/random.ts
|
|
@@ -7629,17 +7889,62 @@ const linearize = linearize$1;
|
|
|
7629
7889
|
/**
 * @function
 * Calculate the reverse-mode vector-Jacobian product for a function.
 *
 * The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
 * `f(...primals)` (the primals are passed as an array), and `vjpFn` is a
 * function that takes in cotangents for each output and returns the
 * cotangents for each input.
 *
 * When `{ hasAux: true }` is passed, the function `f` is expected to return an
 * `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
 *
 * @example
 * ```ts
 * const [y, vjpFn] = vjp(f, [x]);
 *
 * // With hasAux
 * const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
 * ```
 */
const vjp = vjp$1;
/**
 * @function
 * Compute the gradient of a scalar-valued function `f` with respect to its
 * first argument.
 *
 * Pass in different `argnums` to differentiate with respect to other
 * arguments. If an array of indices is provided, the return value will be a
 * tuple of gradients, one per argument index, in the same order.
 *
 * When `{ hasAux: true }` is passed, the function `f` is expected to return an
 * `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
 *
 * @example
 * ```ts
 * const gradient = grad(f)(x);
 *
 * // With `argnums`
 * const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
 *
 * // With `hasAux`
 * const [gradient, aux] = grad(f, { hasAux: true })(x);
 * ```
 */
const grad = grad$1;
/**
 * @function
 * Create a function that evaluates both `f` and the gradient of `f`.
 *
 * When `{ hasAux: true }` is passed, the function `f` is expected to return an
 * `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
 *
 * @example
 * ```ts
 * // Without hasAux
 * const [value, gradient] = valueAndGrad(f)(x);
 *
 * // With hasAux
 * const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
 * ```
 */
const valueAndGrad = valueAndGrad$1;
|
|
7645
7950
|
/**
|
|
@@ -7648,6 +7953,21 @@ const valueAndGrad = valueAndGrad$1;
|
|
|
7648
7953
|
*/
|
|
7649
7954
|
const jacrev = jacrev$1;
|
|
7650
7955
|
/**
 * @function
 * Compute the Hessian, i.e. the matrix of second-order partial derivatives,
 * of a scalar-valued function.
 *
 * Implemented as `jacfwd(grad(f))`: the forward-mode Jacobian of the
 * reverse-mode gradient.
 *
 * @example
 * ```ts
 * const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // sum of x^3
 * const H = hessian(f)(np.array([1, 2, 3]));
 * // H[i, j] = d^2 f / (dx_i dx_j); here H is diagonal with entries 6 * x,
 * // i.e. [[6, 0, 0], [0, 12, 0], [0, 0, 18]]
 * ```
 */
const hessian = hessian$1;
|
|
7970
|
+
/**
|
|
7651
7971
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
7652
7972
|
*
|
|
7653
7973
|
* This can be used to wait for the results of an intermediate computation to
|
|
@@ -7682,4 +8002,4 @@ async function devicePut(x, device) {
|
|
|
7682
8002
|
}
|
|
7683
8003
|
|
|
7684
8004
|
//#endregion
|
|
7685
|
-
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
8005
|
+
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|