@jax-js/jax 0.1.5 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +60 -7
- package/dist/{backend-DziQSaoQ.cjs → backend-B3foXiV_.cjs} +25 -6
- package/dist/{backend-DaqL-MNz.js → backend-nEolvdLv.js} +20 -7
- package/dist/index.cjs +450 -129
- package/dist/index.d.cts +1669 -1467
- package/dist/index.d.ts +1669 -1467
- package/dist/index.js +450 -130
- package/dist/{webgl-ClIYb8jP.cjs → webgl-DIIbKJ0G.cjs} +1 -1
- package/dist/{webgl-RSuZKvgc.js → webgl-DweKSWEm.js} +1 -1
- package/dist/{webgpu-Dh7k9io0.js → webgpu-B96vzWGE.js} +1 -1
- package/dist/{webgpu-Db2JrNBr.cjs → webgpu-BykvF26B.cjs} +1 -1
- package/package.json +1 -1
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-DziQSaoQ.cjs');
|
|
33
|
+
const require_backend = require('./backend-B3foXiV_.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -240,7 +240,7 @@ __export(tree_exports, {
|
|
|
240
240
|
structure: () => structure,
|
|
241
241
|
unflatten: () => unflatten
|
|
242
242
|
});
|
|
243
|
-
const JsArray$1 = globalThis.Array;
|
|
243
|
+
const JsArray$2 = globalThis.Array;
|
|
244
244
|
let NodeType = /* @__PURE__ */ function(NodeType$1) {
|
|
245
245
|
NodeType$1["Array"] = "Array";
|
|
246
246
|
NodeType$1["Object"] = "Object";
|
|
@@ -288,7 +288,7 @@ function flatten(tree) {
|
|
|
288
288
|
return [leaves$1, treedef];
|
|
289
289
|
}
|
|
290
290
|
function _flatten(tree, leaves$1) {
|
|
291
|
-
if (JsArray$1.isArray(tree)) {
|
|
291
|
+
if (JsArray$2.isArray(tree)) {
|
|
292
292
|
const childTrees = tree.map((c) => _flatten(c, leaves$1));
|
|
293
293
|
return new JsTreeDef(NodeType.Array, null, childTrees);
|
|
294
294
|
} else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
|
|
@@ -337,11 +337,11 @@ function map(fn, tree, ...rest) {
|
|
|
337
337
|
}
|
|
338
338
|
/** Take a reference of every array in a tree. */
|
|
339
339
|
function ref(tree) {
|
|
340
|
-
return map((x) => x.ref, tree);
|
|
340
|
+
return map((x) => x instanceof Tracer ? x.ref : x, tree);
|
|
341
341
|
}
|
|
342
342
|
/** Dispose every array in a tree. */
|
|
343
343
|
function dispose(tree) {
|
|
344
|
-
if (tree) map((x) => x.dispose(), tree);
|
|
344
|
+
if (tree) map((x) => x instanceof Tracer ? x.dispose() : void 0, tree);
|
|
345
345
|
}
|
|
346
346
|
|
|
347
347
|
//#endregion
|
|
@@ -412,6 +412,13 @@ let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
|
|
|
412
412
|
CompareOp$1["LessEqual"] = "less_equal";
|
|
413
413
|
return CompareOp$1;
|
|
414
414
|
}({});
|
|
415
|
+
const routinePrimitives = new Map([
|
|
416
|
+
[Primitive.Sort, require_backend.Routines.Sort],
|
|
417
|
+
[Primitive.Argsort, require_backend.Routines.Argsort],
|
|
418
|
+
[Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
|
|
419
|
+
[Primitive.Cholesky, require_backend.Routines.Cholesky],
|
|
420
|
+
[Primitive.LU, require_backend.Routines.LU]
|
|
421
|
+
]);
|
|
415
422
|
function add$1(x, y) {
|
|
416
423
|
return bind1(Primitive.Add, [x, y]);
|
|
417
424
|
}
|
|
@@ -608,14 +615,20 @@ function shrink(x, slice) {
|
|
|
608
615
|
}
|
|
609
616
|
function pad$1(x, width) {
|
|
610
617
|
const nd = ndim$1(x);
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
else if (
|
|
614
|
-
if (width
|
|
615
|
-
const
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
618
|
+
let w;
|
|
619
|
+
if (typeof width === "number") w = [[width, width]];
|
|
620
|
+
else if (require_backend.isNumberPair(width)) w = [width];
|
|
621
|
+
else if (!Array.isArray(width)) {
|
|
622
|
+
const indicesAndPairs = Object.entries(width);
|
|
623
|
+
w = require_backend.rep(nd, [0, 0]);
|
|
624
|
+
for (const [k, v] of indicesAndPairs) w[require_backend.checkAxis(parseInt(k), nd)] = v;
|
|
625
|
+
} else if (!width.every(require_backend.isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
|
|
626
|
+
else w = width;
|
|
627
|
+
if (w.length === 1) {
|
|
628
|
+
const [w0, w1] = w[0];
|
|
629
|
+
w = require_backend.rep(nd, () => [w0, w1]);
|
|
630
|
+
} else if (w.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${w.length}`);
|
|
631
|
+
return bind1(Primitive.Pad, [x], { width: w });
|
|
619
632
|
}
|
|
620
633
|
function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
621
634
|
const as = getShape(a);
|
|
@@ -685,6 +698,9 @@ function newDynamic(main) {
|
|
|
685
698
|
dynamicTrace = prevDynamicTrace;
|
|
686
699
|
} };
|
|
687
700
|
}
|
|
701
|
+
function currentTraceLevel() {
|
|
702
|
+
return traceStack[traceStack.length - 1].level;
|
|
703
|
+
}
|
|
688
704
|
var Trace = class {
|
|
689
705
|
constructor(main) {
|
|
690
706
|
this.main = main;
|
|
@@ -788,6 +804,22 @@ var Tracer = class Tracer {
|
|
|
788
804
|
const result = reduce(this.astype(castDtype), require_backend.AluOp.Add, axis, opts);
|
|
789
805
|
return result.mul(1 / n).astype(originalDtype);
|
|
790
806
|
}
|
|
807
|
+
/** Minimum of the elements of the array along a given axis. */
|
|
808
|
+
min(axis = null, opts) {
|
|
809
|
+
return reduce(this, require_backend.AluOp.Min, axis, opts);
|
|
810
|
+
}
|
|
811
|
+
/** Maximum of the elements of the array along a given axis. */
|
|
812
|
+
max(axis = null, opts) {
|
|
813
|
+
return reduce(this, require_backend.AluOp.Max, axis, opts);
|
|
814
|
+
}
|
|
815
|
+
/** Test whether all array elements along a given axis evaluate to true. */
|
|
816
|
+
all(axis = null, opts) {
|
|
817
|
+
return this.astype(require_backend.DType.Bool).min(axis, opts);
|
|
818
|
+
}
|
|
819
|
+
/** Test whether any array element along a given axis evaluates to true. */
|
|
820
|
+
any(axis = null, opts) {
|
|
821
|
+
return this.astype(require_backend.DType.Bool).max(axis, opts);
|
|
822
|
+
}
|
|
791
823
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
792
824
|
transpose(perm) {
|
|
793
825
|
return transpose$1(this, perm);
|
|
@@ -1062,6 +1094,7 @@ var TreeMismatchError = class extends TypeError {
|
|
|
1062
1094
|
super(`Mismatched tree structures in ${where$2}: ${left} != ${right}`);
|
|
1063
1095
|
}
|
|
1064
1096
|
};
|
|
1097
|
+
/** Flatten a function of `JsTree` input/output for use in tracing. */
|
|
1065
1098
|
function flattenFun(f, inTree) {
|
|
1066
1099
|
const store = { value: void 0 };
|
|
1067
1100
|
const flatFun = (...argsFlat) => {
|
|
@@ -1073,6 +1106,26 @@ function flattenFun(f, inTree) {
|
|
|
1073
1106
|
};
|
|
1074
1107
|
return [flatFun, store];
|
|
1075
1108
|
}
|
|
1109
|
+
/** Like flattenFun, but expects f to return [main, aux] tuple. */
|
|
1110
|
+
function flattenFunWithAux(f, inTree) {
|
|
1111
|
+
const store = { value: void 0 };
|
|
1112
|
+
const auxStore = { value: void 0 };
|
|
1113
|
+
const flatFun = (...argsFlat) => {
|
|
1114
|
+
const pytreeArgs = unflatten(inTree, argsFlat);
|
|
1115
|
+
const result = f(...pytreeArgs);
|
|
1116
|
+
if (!Array.isArray(result) || result.length !== 2) throw new Error("Function with `hasAux: true` must return [output, aux] tuple");
|
|
1117
|
+
const [out, aux] = result;
|
|
1118
|
+
const [outFlat, outTree] = flatten(out);
|
|
1119
|
+
store.value = outTree;
|
|
1120
|
+
auxStore.value = aux;
|
|
1121
|
+
return outFlat;
|
|
1122
|
+
};
|
|
1123
|
+
return [
|
|
1124
|
+
flatFun,
|
|
1125
|
+
store,
|
|
1126
|
+
auxStore
|
|
1127
|
+
];
|
|
1128
|
+
}
|
|
1076
1129
|
var UseAfterFreeError = class extends ReferenceError {
|
|
1077
1130
|
constructor(tracer) {
|
|
1078
1131
|
super(`Referenced tracer ${tracer.toString()} freed, please use .ref move semantics`);
|
|
@@ -1806,13 +1859,6 @@ function jit$1(f, opts) {
|
|
|
1806
1859
|
|
|
1807
1860
|
//#endregion
|
|
1808
1861
|
//#region src/frontend/jit.ts
|
|
1809
|
-
const routinePrimitives = new Map([
|
|
1810
|
-
[Primitive.Sort, require_backend.Routines.Sort],
|
|
1811
|
-
[Primitive.Argsort, require_backend.Routines.Argsort],
|
|
1812
|
-
[Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
|
|
1813
|
-
[Primitive.Cholesky, require_backend.Routines.Cholesky],
|
|
1814
|
-
[Primitive.LU, require_backend.Routines.LU]
|
|
1815
|
-
]);
|
|
1816
1862
|
/** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
|
|
1817
1863
|
var JitProgram = class {
|
|
1818
1864
|
constructor(backend, steps, inputs, outputs) {
|
|
@@ -2201,12 +2247,13 @@ const jitRules = {
|
|
|
2201
2247
|
const ndim$2 = avals[0].ndim;
|
|
2202
2248
|
const sizes = avals.map((x) => x.shape[axis]);
|
|
2203
2249
|
const finalSize = sizes.reduce((a, b) => a + b, 0);
|
|
2250
|
+
const { dtype: dtypeOut } = avals.map((x) => x.scalar()).reduce(promoteAvals);
|
|
2204
2251
|
const makePadAxis = (start, end) => require_backend.range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
|
|
2205
2252
|
let cum = 0;
|
|
2206
2253
|
const src = [];
|
|
2207
2254
|
for (let i = 0; i < exps.length; i++) {
|
|
2208
2255
|
const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
|
|
2209
|
-
src.push(reshapeViews(exps[i], (st) => st.pad(padding)));
|
|
2256
|
+
src.push(reshapeViews(require_backend.AluExp.cast(dtypeOut, exps[i]), (st) => st.pad(padding)));
|
|
2210
2257
|
cum += sizes[i];
|
|
2211
2258
|
}
|
|
2212
2259
|
return { exp: [src.reduce(require_backend.AluExp.add)] };
|
|
@@ -2344,7 +2391,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2344
2391
|
p1NextBlack.set(v, v);
|
|
2345
2392
|
}
|
|
2346
2393
|
const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
|
|
2347
|
-
const needsCleanShapePrimitives = [Primitive.Pad];
|
|
2394
|
+
const needsCleanShapePrimitives = [Primitive.Concatenate, Primitive.Pad];
|
|
2348
2395
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
2349
2396
|
const eqn = jaxpr.eqns[i];
|
|
2350
2397
|
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
@@ -2414,7 +2461,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2414
2461
|
|
|
2415
2462
|
//#endregion
|
|
2416
2463
|
//#region src/frontend/array.ts
|
|
2417
|
-
const JsArray = globalThis.Array;
|
|
2464
|
+
const JsArray$1 = globalThis.Array;
|
|
2418
2465
|
const inlineArrayLimit = 128;
|
|
2419
2466
|
/** Version of pureArray with fudged types. */
|
|
2420
2467
|
const fudgeArray = pureArray;
|
|
@@ -2812,25 +2859,35 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2812
2859
|
});
|
|
2813
2860
|
}
|
|
2814
2861
|
/** Apply an operation with custom lowering to this array. */
|
|
2815
|
-
static #routine(
|
|
2816
|
-
|
|
2817
|
-
|
|
2818
|
-
|
|
2819
|
-
|
|
2820
|
-
|
|
2821
|
-
|
|
2822
|
-
|
|
2823
|
-
|
|
2824
|
-
|
|
2825
|
-
|
|
2826
|
-
|
|
2827
|
-
|
|
2828
|
-
dtype
|
|
2829
|
-
|
|
2830
|
-
|
|
2831
|
-
|
|
2832
|
-
pending
|
|
2833
|
-
|
|
2862
|
+
static #routine(prim) {
|
|
2863
|
+
return (arrays, params) => {
|
|
2864
|
+
const { backend, committed } = Array$1.#computeBackend(prim, arrays);
|
|
2865
|
+
for (const ar of arrays) ar.#realize();
|
|
2866
|
+
const avals = arrays.map((ar) => ar.aval);
|
|
2867
|
+
const avalsOut = abstractEvalRules[prim](avals, params);
|
|
2868
|
+
const routine = new require_backend.Routine(routinePrimitives.get(prim), {
|
|
2869
|
+
inputShapes: avals.map((a) => a.shape),
|
|
2870
|
+
inputDtypes: avals.map((a) => a.dtype),
|
|
2871
|
+
outputShapes: avalsOut.map((a) => a.shape),
|
|
2872
|
+
outputDtypes: avalsOut.map((a) => a.dtype)
|
|
2873
|
+
}, params);
|
|
2874
|
+
const inputs = arrays.map((ar) => ar.#source);
|
|
2875
|
+
const outputs = avalsOut.map((x) => backend.malloc(require_backend.byteWidth(x.dtype) * x.size));
|
|
2876
|
+
const pending = arrays.flatMap((ar) => ar.#pending);
|
|
2877
|
+
for (const exe of pending) exe.updateRc(+outputs.length);
|
|
2878
|
+
pending.push(new PendingExecute(backend, routine, inputs, outputs));
|
|
2879
|
+
pending[pending.length - 1].updateRc(+outputs.length - 1);
|
|
2880
|
+
arrays.forEach((ar) => ar.dispose());
|
|
2881
|
+
return outputs.map((output, i) => new Array$1({
|
|
2882
|
+
source: output,
|
|
2883
|
+
st: require_backend.ShapeTracker.fromShape(avalsOut[i].shape),
|
|
2884
|
+
dtype: avalsOut[i].dtype,
|
|
2885
|
+
weakType: avalsOut[i].weakType,
|
|
2886
|
+
backend,
|
|
2887
|
+
committed,
|
|
2888
|
+
pending
|
|
2889
|
+
}));
|
|
2890
|
+
};
|
|
2834
2891
|
}
|
|
2835
2892
|
/**
|
|
2836
2893
|
* Normalizes this array into one backed by a `Slot`.
|
|
@@ -3164,65 +3221,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3164
3221
|
[Primitive.Pad]([x], { width }) {
|
|
3165
3222
|
return [x.#reshape(x.#st.pad(width))];
|
|
3166
3223
|
},
|
|
3167
|
-
[Primitive.Sort](
|
|
3168
|
-
|
|
3169
|
-
|
|
3170
|
-
|
|
3171
|
-
|
|
3172
|
-
outputDtypes: [x.dtype]
|
|
3173
|
-
});
|
|
3174
|
-
return Array$1.#routine(routine, [x], [x.#weakType]);
|
|
3175
|
-
},
|
|
3176
|
-
[Primitive.Argsort]([x]) {
|
|
3177
|
-
const routine = new require_backend.Routine(require_backend.Routines.Argsort, {
|
|
3178
|
-
inputShapes: [x.shape],
|
|
3179
|
-
inputDtypes: [x.dtype],
|
|
3180
|
-
outputShapes: [x.shape, x.shape],
|
|
3181
|
-
outputDtypes: [x.dtype, require_backend.DType.Int32]
|
|
3182
|
-
});
|
|
3183
|
-
return Array$1.#routine(routine, [x], [x.#weakType, false]);
|
|
3184
|
-
},
|
|
3185
|
-
[Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
|
|
3186
|
-
const routine = new require_backend.Routine(require_backend.Routines.TriangularSolve, {
|
|
3187
|
-
inputShapes: [a.shape, b.shape],
|
|
3188
|
-
inputDtypes: [a.dtype, b.dtype],
|
|
3189
|
-
outputShapes: [b.shape],
|
|
3190
|
-
outputDtypes: [b.dtype]
|
|
3191
|
-
}, { unitDiagonal });
|
|
3192
|
-
return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
|
|
3193
|
-
},
|
|
3194
|
-
[Primitive.Cholesky]([a]) {
|
|
3195
|
-
const routine = new require_backend.Routine(require_backend.Routines.Cholesky, {
|
|
3196
|
-
inputShapes: [a.shape],
|
|
3197
|
-
inputDtypes: [a.dtype],
|
|
3198
|
-
outputShapes: [a.shape],
|
|
3199
|
-
outputDtypes: [a.dtype]
|
|
3200
|
-
});
|
|
3201
|
-
return Array$1.#routine(routine, [a], [a.#weakType]);
|
|
3202
|
-
},
|
|
3203
|
-
[Primitive.LU]([a]) {
|
|
3204
|
-
const batch = a.shape.slice(0, -2);
|
|
3205
|
-
const [m, n] = a.shape.slice(-2);
|
|
3206
|
-
const routine = new require_backend.Routine(require_backend.Routines.LU, {
|
|
3207
|
-
inputShapes: [a.shape],
|
|
3208
|
-
inputDtypes: [a.dtype],
|
|
3209
|
-
outputShapes: [
|
|
3210
|
-
a.shape,
|
|
3211
|
-
[...batch, Math.min(m, n)],
|
|
3212
|
-
[...batch, m]
|
|
3213
|
-
],
|
|
3214
|
-
outputDtypes: [
|
|
3215
|
-
a.dtype,
|
|
3216
|
-
require_backend.DType.Int32,
|
|
3217
|
-
require_backend.DType.Int32
|
|
3218
|
-
]
|
|
3219
|
-
});
|
|
3220
|
-
return Array$1.#routine(routine, [a], [
|
|
3221
|
-
a.#weakType,
|
|
3222
|
-
false,
|
|
3223
|
-
false
|
|
3224
|
-
]);
|
|
3225
|
-
},
|
|
3224
|
+
[Primitive.Sort]: Array$1.#routine(Primitive.Sort),
|
|
3225
|
+
[Primitive.Argsort]: Array$1.#routine(Primitive.Argsort),
|
|
3226
|
+
[Primitive.TriangularSolve]: Array$1.#routine(Primitive.TriangularSolve),
|
|
3227
|
+
[Primitive.Cholesky]: Array$1.#routine(Primitive.Cholesky),
|
|
3228
|
+
[Primitive.LU]: Array$1.#routine(Primitive.LU),
|
|
3226
3229
|
[Primitive.Jit](args, { jaxpr }) {
|
|
3227
3230
|
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
3228
3231
|
const { backend, committed } = Array$1.#computeBackend("jit", args);
|
|
@@ -3304,7 +3307,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
3304
3307
|
if (!shape$1) {
|
|
3305
3308
|
shape$1 = [];
|
|
3306
3309
|
let cur = values;
|
|
3307
|
-
while (JsArray.isArray(cur)) {
|
|
3310
|
+
while (JsArray$1.isArray(cur)) {
|
|
3308
3311
|
shape$1.push(cur.length);
|
|
3309
3312
|
cur = cur[0];
|
|
3310
3313
|
}
|
|
@@ -4260,17 +4263,39 @@ function jvpFlat(f, primals, tangents) {
|
|
|
4260
4263
|
_usingCtx$1.d();
|
|
4261
4264
|
}
|
|
4262
4265
|
}
|
|
4263
|
-
function jvp$1(f, primals, tangents) {
|
|
4266
|
+
function jvp$1(f, primals, tangents, { hasAux = false } = {}) {
|
|
4264
4267
|
const [primalsFlat, inTree] = flatten(primals);
|
|
4265
4268
|
const [tangentsFlat, inTree2] = flatten(tangents);
|
|
4266
4269
|
if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
|
|
4267
|
-
|
|
4270
|
+
let flatFun, outTree, aux;
|
|
4271
|
+
if (hasAux) [flatFun, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4272
|
+
else [flatFun, outTree] = flattenFun(f, inTree);
|
|
4268
4273
|
const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
|
|
4269
4274
|
if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
|
|
4270
4275
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
4271
4276
|
const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
|
|
4277
|
+
if (hasAux) return [
|
|
4278
|
+
primalsOut,
|
|
4279
|
+
tangentsOut,
|
|
4280
|
+
lowerAux(aux.value)
|
|
4281
|
+
];
|
|
4272
4282
|
return [primalsOut, tangentsOut];
|
|
4273
4283
|
}
|
|
4284
|
+
/** Lowering for auxiliary data returned in `hasAux: true` methods. */
|
|
4285
|
+
function lowerAux(aux) {
|
|
4286
|
+
const level = currentTraceLevel();
|
|
4287
|
+
return map((x) => {
|
|
4288
|
+
if (x instanceof Tracer) while (x._trace.main.level > level) if (x instanceof JVPTracer) {
|
|
4289
|
+
x.tangent.dispose();
|
|
4290
|
+
x = x.primal;
|
|
4291
|
+
} else {
|
|
4292
|
+
const y = x.fullLower();
|
|
4293
|
+
if (y._trace.main.level >= x._trace.main.level) throw new Error("internal: lowerAux did not reduce trace level");
|
|
4294
|
+
x = y;
|
|
4295
|
+
}
|
|
4296
|
+
return x;
|
|
4297
|
+
}, aux);
|
|
4298
|
+
}
|
|
4274
4299
|
|
|
4275
4300
|
//#endregion
|
|
4276
4301
|
//#region src/frontend/linearize.ts
|
|
@@ -4341,9 +4366,11 @@ function linearizeFlat(f, primalsIn) {
|
|
|
4341
4366
|
dispose$1
|
|
4342
4367
|
];
|
|
4343
4368
|
}
|
|
4344
|
-
function linearize$1(f, ...primalsIn) {
|
|
4369
|
+
function linearize$1(f, primalsIn, { hasAux = false } = {}) {
|
|
4345
4370
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4346
|
-
|
|
4371
|
+
let fFlat, outTree, aux;
|
|
4372
|
+
if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4373
|
+
else [fFlat, outTree] = flattenFun(f, inTree);
|
|
4347
4374
|
const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4348
4375
|
if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
|
|
4349
4376
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
@@ -4354,6 +4381,11 @@ function linearize$1(f, ...primalsIn) {
|
|
|
4354
4381
|
return unflatten(outTree.value, tangentsOutFlat);
|
|
4355
4382
|
});
|
|
4356
4383
|
fLin.dispose = dispose$1;
|
|
4384
|
+
if (hasAux) return [
|
|
4385
|
+
primalsOut,
|
|
4386
|
+
fLin,
|
|
4387
|
+
lowerAux(aux.value)
|
|
4388
|
+
];
|
|
4357
4389
|
return [primalsOut, fLin];
|
|
4358
4390
|
}
|
|
4359
4391
|
var PartialEvalTracer = class extends Tracer {
|
|
@@ -4854,9 +4886,11 @@ function vjpFlat(f, primalsIn) {
|
|
|
4854
4886
|
dispose$1
|
|
4855
4887
|
];
|
|
4856
4888
|
}
|
|
4857
|
-
function vjp$1(f, ...primalsIn) {
|
|
4889
|
+
function vjp$1(f, primalsIn, { hasAux = false } = {}) {
|
|
4858
4890
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4859
|
-
|
|
4891
|
+
let fFlat, outTree, aux;
|
|
4892
|
+
if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4893
|
+
else [fFlat, outTree] = flattenFun(f, inTree);
|
|
4860
4894
|
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4861
4895
|
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
4862
4896
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
@@ -4867,26 +4901,43 @@ function vjp$1(f, ...primalsIn) {
|
|
|
4867
4901
|
return unflatten(inTree, cotangentsInFlat);
|
|
4868
4902
|
});
|
|
4869
4903
|
fVjp.dispose = dispose$1;
|
|
4904
|
+
if (hasAux) return [
|
|
4905
|
+
primalsOut,
|
|
4906
|
+
fVjp,
|
|
4907
|
+
lowerAux(aux.value)
|
|
4908
|
+
];
|
|
4870
4909
|
return [primalsOut, fVjp];
|
|
4871
4910
|
}
|
|
4872
|
-
function grad$1(f) {
|
|
4873
|
-
const valueAndGradFn = valueAndGrad$1(f);
|
|
4911
|
+
function grad$1(f, opts) {
|
|
4912
|
+
const valueAndGradFn = valueAndGrad$1(f, opts);
|
|
4874
4913
|
return (...x) => {
|
|
4875
|
-
|
|
4876
|
-
|
|
4877
|
-
|
|
4914
|
+
if (opts?.hasAux) {
|
|
4915
|
+
const [[y, aux], dx] = valueAndGradFn(...x);
|
|
4916
|
+
y.dispose();
|
|
4917
|
+
return [dx, aux];
|
|
4918
|
+
} else {
|
|
4919
|
+
const [y, dx] = valueAndGradFn(...x);
|
|
4920
|
+
y.dispose();
|
|
4921
|
+
return dx;
|
|
4922
|
+
}
|
|
4878
4923
|
};
|
|
4879
4924
|
}
|
|
4880
|
-
function valueAndGrad$1(f) {
|
|
4925
|
+
function valueAndGrad$1(f, opts) {
|
|
4926
|
+
const argnums = opts?.argnums ?? 0;
|
|
4927
|
+
const hasAux = opts?.hasAux ?? false;
|
|
4928
|
+
require_backend.checkInts(argnums);
|
|
4929
|
+
const argnumsSet = new Set(typeof argnums === "number" ? [argnums] : argnums);
|
|
4881
4930
|
return (...x) => {
|
|
4882
4931
|
if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
|
|
4883
|
-
|
|
4932
|
+
for (let i = 0; i < x.length; i++) if (!argnumsSet.has(i)) x[i] = map(stopGradient, x[i]);
|
|
4933
|
+
const [y, fVjp, aux] = vjp$1(f, x, { hasAux });
|
|
4884
4934
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
4885
4935
|
if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
4886
|
-
const
|
|
4887
|
-
for (const r of rest) dispose(r);
|
|
4936
|
+
const cts = fVjp(onesLike$1(y.ref));
|
|
4888
4937
|
fVjp.dispose();
|
|
4889
|
-
|
|
4938
|
+
for (let i = 0; i < cts.length; i++) if (!argnumsSet.has(i)) dispose(cts[i]);
|
|
4939
|
+
const grads = typeof argnums === "number" ? cts[argnums] : argnums.map((i) => cts[i]);
|
|
4940
|
+
return hasAux ? [[y, aux], grads] : [y, grads];
|
|
4890
4941
|
};
|
|
4891
4942
|
}
|
|
4892
4943
|
function jacrev$1(f) {
|
|
@@ -4894,7 +4945,7 @@ function jacrev$1(f) {
|
|
|
4894
4945
|
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
4895
4946
|
const [size$1] = x.shape;
|
|
4896
4947
|
const pullback = (ct) => {
|
|
4897
|
-
const [y, fVjp] = vjp$1(f, x);
|
|
4948
|
+
const [y, fVjp] = vjp$1(f, [x]);
|
|
4898
4949
|
y.dispose();
|
|
4899
4950
|
const [ret] = fVjp(ct);
|
|
4900
4951
|
fVjp.dispose();
|
|
@@ -4903,6 +4954,9 @@ function jacrev$1(f) {
|
|
|
4903
4954
|
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
4904
4955
|
};
|
|
4905
4956
|
}
|
|
4957
|
+
function hessian$1(f) {
|
|
4958
|
+
return jacfwd$1(grad$1(f));
|
|
4959
|
+
}
|
|
4906
4960
|
|
|
4907
4961
|
//#endregion
|
|
4908
4962
|
//#region src/library/numpy/einsum.ts
|
|
@@ -5575,6 +5629,7 @@ __export(numpy_exports, {
|
|
|
5575
5629
|
moveaxis: () => moveaxis$1,
|
|
5576
5630
|
multiply: () => multiply,
|
|
5577
5631
|
nan: () => nan,
|
|
5632
|
+
nanToNum: () => nanToNum,
|
|
5578
5633
|
ndim: () => ndim,
|
|
5579
5634
|
negative: () => negative,
|
|
5580
5635
|
notEqual: () => notEqual,
|
|
@@ -5612,6 +5667,7 @@ __export(numpy_exports, {
|
|
|
5612
5667
|
std: () => std,
|
|
5613
5668
|
subtract: () => subtract,
|
|
5614
5669
|
sum: () => sum,
|
|
5670
|
+
swapaxes: () => swapaxes,
|
|
5615
5671
|
take: () => take,
|
|
5616
5672
|
tan: () => tan,
|
|
5617
5673
|
tanh: () => tanh,
|
|
@@ -5771,24 +5827,22 @@ function max(a, axis = null, opts) {
|
|
|
5771
5827
|
return reduce(a, require_backend.AluOp.Max, axis, opts);
|
|
5772
5828
|
}
|
|
5773
5829
|
/**
|
|
5774
|
-
* Test whether
|
|
5830
|
+
* Test whether any array element along a given axis evaluates to True.
|
|
5775
5831
|
*
|
|
5776
5832
|
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
5777
5833
|
* removed. If axis is None, returns a scalar.
|
|
5778
5834
|
*/
|
|
5779
|
-
function
|
|
5780
|
-
|
|
5781
|
-
return min(a, axis, opts);
|
|
5835
|
+
function any(a, axis = null, opts) {
|
|
5836
|
+
return fudgeArray(a).any(axis, opts);
|
|
5782
5837
|
}
|
|
5783
5838
|
/**
|
|
5784
|
-
* Test whether
|
|
5839
|
+
* Test whether all array elements along a given axis evaluate to True.
|
|
5785
5840
|
*
|
|
5786
5841
|
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
5787
5842
|
* removed. If axis is None, returns a scalar.
|
|
5788
5843
|
*/
|
|
5789
|
-
function
|
|
5790
|
-
|
|
5791
|
-
return max(a, axis, opts);
|
|
5844
|
+
function all(a, axis = null, opts) {
|
|
5845
|
+
return fudgeArray(a).all(axis, opts);
|
|
5792
5846
|
}
|
|
5793
5847
|
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
5794
5848
|
function ptp(a, axis = null, opts) {
|
|
@@ -5889,7 +5943,7 @@ function split$1(a, indicesOrSections, axis = 0) {
|
|
|
5889
5943
|
const partSize = size$1 / indicesOrSections;
|
|
5890
5944
|
sizes = require_backend.rep(indicesOrSections, partSize);
|
|
5891
5945
|
} else {
|
|
5892
|
-
const indices = indicesOrSections;
|
|
5946
|
+
const indices = indicesOrSections.map((i) => i < 0 ? i + size$1 : i);
|
|
5893
5947
|
sizes = [indices[0]];
|
|
5894
5948
|
for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
|
|
5895
5949
|
sizes.push(size$1 - indices[indices.length - 1]);
|
|
@@ -6010,6 +6064,17 @@ function flipud(x) {
|
|
|
6010
6064
|
function fliplr(x) {
|
|
6011
6065
|
return flip(x, 1);
|
|
6012
6066
|
}
|
|
6067
|
+
/** Interchange two axes of an array. */
|
|
6068
|
+
function swapaxes(a, axis1, axis2) {
|
|
6069
|
+
a = fudgeArray(a);
|
|
6070
|
+
axis1 = require_backend.checkAxis(axis1, a.ndim);
|
|
6071
|
+
axis2 = require_backend.checkAxis(axis2, a.ndim);
|
|
6072
|
+
if (axis1 === axis2) return a;
|
|
6073
|
+
const perm = require_backend.range(a.ndim);
|
|
6074
|
+
perm[axis1] = axis2;
|
|
6075
|
+
perm[axis2] = axis1;
|
|
6076
|
+
return transpose(a, perm);
|
|
6077
|
+
}
|
|
6013
6078
|
/** Transpose the last two dimensions of an array. */
|
|
6014
6079
|
function matrixTranspose(a) {
|
|
6015
6080
|
if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
|
|
@@ -6826,6 +6891,21 @@ function isposinf(x) {
|
|
|
6826
6891
|
return require_backend.isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
|
|
6827
6892
|
}
|
|
6828
6893
|
/**
|
|
6894
|
+
* Replace NaN and infinite entries in an array.
|
|
6895
|
+
*
|
|
6896
|
+
* By default, NaNs are replaced with `0.0`, and infinities are substituted
|
|
6897
|
+
* with the corresponding maximum or minimum finite values.
|
|
6898
|
+
*/
|
|
6899
|
+
function nanToNum(x, { nan: nan$1 = 0, posinf = null, neginf = null } = {}) {
|
|
6900
|
+
x = fudgeArray(x);
|
|
6901
|
+
x = where(isnan(x.ref), nan$1, x);
|
|
6902
|
+
posinf ??= require_backend.isFloatDtype(x.dtype) ? finfo(x.dtype).max : iinfo(x.dtype).max;
|
|
6903
|
+
neginf ??= require_backend.isFloatDtype(x.dtype) ? finfo(x.dtype).min : iinfo(x.dtype).min;
|
|
6904
|
+
x = where(isposinf(x.ref), posinf, x);
|
|
6905
|
+
x = where(isneginf(x.ref), neginf, x);
|
|
6906
|
+
return x;
|
|
6907
|
+
}
|
|
6908
|
+
/**
|
|
6829
6909
|
* @function
|
|
6830
6910
|
* Test element-wise for finite values (not infinity or NaN).
|
|
6831
6911
|
*/
|
|
@@ -6938,6 +7018,7 @@ var lax_exports = {};
|
|
|
6938
7018
|
__export(lax_exports, {
|
|
6939
7019
|
conv: () => conv,
|
|
6940
7020
|
convGeneralDilated: () => convGeneralDilated,
|
|
7021
|
+
convTranspose: () => convTranspose,
|
|
6941
7022
|
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
6942
7023
|
dot: () => dot,
|
|
6943
7024
|
erf: () => erf,
|
|
@@ -6946,6 +7027,7 @@ __export(lax_exports, {
|
|
|
6946
7027
|
reduceWindow: () => reduceWindow,
|
|
6947
7028
|
stopGradient: () => stopGradient$1
|
|
6948
7029
|
});
|
|
7030
|
+
const JsArray = globalThis.Array;
|
|
6949
7031
|
/**
|
|
6950
7032
|
* General dot product/contraction operator.
|
|
6951
7033
|
*
|
|
@@ -7017,7 +7099,11 @@ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
|
7017
7099
|
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
7018
7100
|
* function in JAX, which wraps XLA's general convolution operator.
|
|
7019
7101
|
*
|
|
7020
|
-
*
|
|
7102
|
+
* @param lhs - Input tensor; shape `[N, C_in, ...xs]`
|
|
7103
|
+
* @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
|
|
7104
|
+
* @param windowStrides - Strides for each spatial dimension
|
|
7105
|
+
* @param padding - Padding for each spatial dimension, or a string
|
|
7106
|
+
* (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
|
|
7021
7107
|
*/
|
|
7022
7108
|
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
|
|
7023
7109
|
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
@@ -7077,6 +7163,60 @@ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, r
|
|
|
7077
7163
|
function conv(lhs, rhs, windowStrides, padding) {
|
|
7078
7164
|
return convGeneralDilated(lhs, rhs, windowStrides, padding);
|
|
7079
7165
|
}
|
|
7166
|
+
/**
|
|
7167
|
+
* Convenience wrapper for calculating the N-d convolution "transpose".
|
|
7168
|
+
*
|
|
7169
|
+
* This function directly calculates a fractionally strided conv rather than
|
|
7170
|
+
* indirectly calculating the gradient (transpose) of a forward convolution.
|
|
7171
|
+
* It is equivalent to the JAX version, except:
|
|
7172
|
+
*
|
|
7173
|
+
* - The `use_consistent_padding` option is not available. We only have the
|
|
7174
|
+
* consistent padding case (JAX version >0.8.4).
|
|
7175
|
+
* - The order of dimensions matches `lax.conv_general_dilated`.
|
|
7176
|
+
*
|
|
7177
|
+
* Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
|
|
7178
|
+
* dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
|
|
7179
|
+
* `transposeKernel` to true.
|
|
7180
|
+
*
|
|
7181
|
+
* @param lhs - Input tensor; shape `[N, C_in, ...xs]`
|
|
7182
|
+
* @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
|
|
7183
|
+
* @param strides - Sequence of n integers, sets fractional stride
|
|
7184
|
+
* @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
|
|
7185
|
+
* each side of the input, so it acts like gradient of `conv()`
|
|
7186
|
+
* @param rhsDilation - Atrous dilation for the kernel
|
|
7187
|
+
* @param transposeKernel - Flip spatial axes and swap the input/output channels
|
|
7188
|
+
* of the kernel; its shape should be `[C_in, C_out, ...ks]`
|
|
7189
|
+
*/
|
|
7190
|
+
function convTranspose(lhs, rhs, strides, padding, { rhsDilation, transposeKernel = false } = {}) {
  // Spatial kernel extents are everything after the two leading channel axes.
  const spatialKernel = rhs.shape.slice(2);
  // Without an explicit dilation, use 1 along every spatial axis.
  const dilation = rhsDilation ?? require_backend.rep(spatialKernel.length, 1);

  // Per spatial axis: compute the dilated kernel extent, then turn the
  // requested padding (a mode string shared by all axes, or one entry per
  // axis) into an explicit [low, high] pair.
  const pads = [];
  for (let axis = 0; axis < spatialKernel.length; axis++) {
    const dilatedSize = Math.max(0, (spatialKernel[axis] - 1) * dilation[axis] + 1);
    const axisPadding = typeof padding === "string" ? padding : padding[axis];
    pads.push(convTransposePadding(dilatedSize, strides[axis], axisPadding));
  }

  // Optionally reverse every spatial axis of the kernel and swap its two
  // leading (channel) axes.
  let kernel = rhs;
  if (transposeKernel) {
    kernel = moveaxis(flip$1(kernel, require_backend.range(2, kernel.ndim)), 0, 1);
  }

  // The fractional stride is expressed as input (lhs) dilation, so the
  // window strides handed to the underlying convolution are all 1.
  const unitStrides = require_backend.rep(lhs.ndim - 2, 1);
  return convGeneralDilated(lhs, kernel, unitStrides, pads, {
    lhsDilation: strides,
    rhsDilation: dilation
  });
}
|
|
7204
|
+
function convTransposePadding(k, s, padding) {
|
|
7205
|
+
let padLen;
|
|
7206
|
+
let pad1;
|
|
7207
|
+
if (padding === "SAME") {
|
|
7208
|
+
padLen = k + s - 2;
|
|
7209
|
+
pad1 = s > k - 1 ? k - 1 : Math.ceil(padLen / 2);
|
|
7210
|
+
} else if (padding === "VALID") {
|
|
7211
|
+
padLen = k + s - 2 + Math.max(k - s, 0);
|
|
7212
|
+
pad1 = k - 1;
|
|
7213
|
+
} else if (JsArray.isArray(padding)) {
|
|
7214
|
+
const pads = [k - 1 - padding[0], k - 1 - padding[1]];
|
|
7215
|
+
pad1 = pads[0];
|
|
7216
|
+
padLen = pads[0] + pads[1];
|
|
7217
|
+
} else throw new Error(`convTranspose: Invalid padding type ${padding}`);
|
|
7218
|
+
return [pad1, padLen - pad1];
|
|
7219
|
+
}
|
|
7080
7220
|
/** Reduce a computation over padded windows. */
|
|
7081
7221
|
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
7082
7222
|
if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
|
|
@@ -7115,6 +7255,7 @@ function stopGradient$1(x) {
|
|
|
7115
7255
|
var nn_exports = {};
|
|
7116
7256
|
__export(nn_exports, {
|
|
7117
7257
|
celu: () => celu,
|
|
7258
|
+
dotProductAttention: () => dotProductAttention,
|
|
7118
7259
|
elu: () => elu,
|
|
7119
7260
|
gelu: () => gelu,
|
|
7120
7261
|
glu: () => glu,
|
|
@@ -7431,6 +7572,125 @@ function oneHot(x, numClasses) {
|
|
|
7431
7572
|
if (require_backend.isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
|
|
7432
7573
|
return eye(numClasses, void 0, { device: x.device }).slice(x);
|
|
7433
7574
|
}
|
|
7575
|
+
/**
|
|
7576
|
+
* Scaled dot product attention (SDPA).
|
|
7577
|
+
*
|
|
7578
|
+
* Computes `softmax((Q @ K^T) / sqrt(d) + bias) @ V`, where `Q` is the query,
|
|
7579
|
+
* `K` is the key, `V` is the value, and `d` is the dimensionality of each key
|
|
7580
|
+
* and query vector.
|
|
7581
|
+
*
|
|
7582
|
+
* Multi-query attention is applied when input `key` and `value` tensors have
|
|
7583
|
+
* fewer heads than `query`.
|
|
7584
|
+
*
|
|
7585
|
+
* We use the following uppercase letters to denote array shapes:
|
|
7586
|
+
* - `B` = batch size
|
|
7587
|
+
* - `S` = length of key/value sequences (source)
|
|
7588
|
+
* - `L` = length of query sequences
|
|
7589
|
+
* - `N` = number of attention heads
|
|
7590
|
+
* - `H` = dimensionality of each attention head
|
|
7591
|
+
* - `K` = number of key/value heads (for grouped-query attention)
|
|
7592
|
+
*
|
|
7593
|
+
* The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
|
|
7594
|
+
* case it must be omitted from all inputs.
|
|
7595
|
+
*
|
|
7596
|
+
* @param query - Query array; shape `[B, L, N, H]`
|
|
7597
|
+
* @param key - Key array; shape `[B, S, K, H]`
|
|
7598
|
+
* @param value - Value array; same shape as `key`
|
|
7599
|
+
* @param opts.bias - Optional bias to add to the attention logits; shape
|
|
7600
|
+
* `[B, N, L, S]` or broadcastable to it.
|
|
7601
|
+
* @param opts.mask - Optional mask to apply to the attention logits; should be
|
|
7602
|
+
* a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
|
|
7603
|
+
* the element should take part in attention.
|
|
7604
|
+
* @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
|
|
7605
|
+
* @param opts.isCausal - If true, applies a causal mask.
|
|
7606
|
+
* @param opts.querySeqLengths - Optional sequence lengths for the queries;
|
|
7607
|
+
* shape `(B,)`. Taken from the beginning of the tensor.
|
|
7608
|
+
* @param opts.keyValueSeqLengths - Optional sequence lengths for the keys and
|
|
7609
|
+
* values; shape `(B,)`. Taken from the beginning of the tensor.
|
|
7610
|
+
* @param opts.localWindowSize - If specified, applies a local attention window
|
|
7611
|
+
* of the given size. Can be a single number or a tuple `[left, right]`.
|
|
7612
|
+
*
|
|
7613
|
+
* @returns The result of the attention operation; shape is the same as query
|
|
7614
|
+
* `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
|
|
7615
|
+
*/
|
|
7616
|
+
function dotProductAttention(query, key$1, value, opts = {}) {
  query = fudgeArray(query);
  key$1 = fudgeArray(key$1);
  value = fudgeArray(value);
  if (query.ndim !== 3 && query.ndim !== 4 || query.ndim !== key$1.ndim || query.ndim !== value.ndim) throw new Error(`dotProductAttention: expected all tensors to have rank 3 or 4, got Q=${query.aval}, K=${key$1.aval}, V=${value.aval}`);
  if (!require_backend.deepEqual(key$1.shape, value.shape)) throw new Error(`dotProductAttention: key and value shapes must match, got K=${key$1.shape}, V=${value.shape}`);

  // Rank-3 inputs have no batch axis: add one here and strip it on return.
  const isRank3 = query.ndim === 3;
  if (isRank3) {
    query = expandDims(query, 0);
    key$1 = expandDims(key$1, 0);
    value = expandDims(value, 0);
  }

  const [B, L, N, H] = query.shape;
  if (key$1.shape[0] !== B || key$1.shape[3] !== H) throw new Error(`dotProductAttention: query and key shapes mismatch, got Q=${query.aval}, K=${key$1.aval}`);
  const S = key$1.shape[1];
  const K = key$1.shape[2];
  if (N < K || (N !== K && N % K !== 0)) throw new Error(`dotProductAttention: number of query heads N=${N} must be divisible by number of key/value heads K=${K} for GQA`);

  // Grouped-query attention: each key/value head serves G consecutive query
  // heads, i.e. query head n pairs with key/value head floor(n / G). Each
  // key/value head is therefore repeated G times *contiguously*
  // ([k0, k0, ..., k1, k1, ...]). Tiling the whole head axis would instead
  // produce [k0, k1, ..., k0, k1, ...] and pair head n with n % K.
  const G = N / K;
  if (G > 1) {
    const repeatHeads = (t) => tile(expandDims(t, 3), [
      1,
      1,
      1,
      G,
      1
    ]).reshape([
      B,
      S,
      N,
      H
    ]);
    key$1 = repeatHeads(key$1);
    value = repeatHeads(value);
  }

  const scale = opts.scale ?? 1 / Math.sqrt(H);
  let scores = einsum("BLNH,BSNH->BNLS", query, key$1).mul(scale);
  if (opts.bias !== void 0) scores = scores.add(opts.bias);

  // Every mask below replaces excluded logits with -Infinity before softmax.
  if (opts.mask !== void 0) scores = where(opts.mask, scores, -Infinity);
  if (opts.isCausal) {
    const causalMask = tri(L, S, 0, { dtype: require_backend.DType.Bool });
    scores = where(causalMask, scores, -Infinity);
  }
  if (opts.localWindowSize !== void 0) {
    const [before, after] = typeof opts.localWindowSize === "number" ? [opts.localWindowSize, opts.localWindowSize] : opts.localWindowSize;
    if (before < 0 || after < 0 || !Number.isInteger(before) || !Number.isInteger(after)) throw new Error(`dotProductAttention: localWindowSize values must be non-negative, got ${opts.localWindowSize}`);
    // Band mask: inside the upper bound (offset `after`) and not inside the
    // region strictly below the lower bound (offset `-before - 1`).
    const localMask = tri(L, S, after, { dtype: require_backend.DType.Bool }).mul(tri(L, S, -before - 1, { dtype: require_backend.DType.Bool }).notEqual(true));
    scores = where(localMask, scores, -Infinity);
  }
  if (opts.querySeqLengths !== void 0) {
    // Lengths gain three trailing unit axes so they broadcast against a
    // [1, 1, L, 1] index grid over query positions.
    const sl = expandDims(opts.querySeqLengths, [
      -1,
      -2,
      -3
    ]);
    scores = where(arange(L).reshape([
      1,
      1,
      L,
      1
    ]).less(sl), scores, -Infinity);
  }
  if (opts.keyValueSeqLengths !== void 0) {
    // Same construction, with a [1, 1, 1, S] index grid over key positions.
    const sl = expandDims(opts.keyValueSeqLengths, [
      -1,
      -2,
      -3
    ]);
    scores = where(arange(S).reshape([
      1,
      1,
      1,
      S
    ]).less(sl), scores, -Infinity);
  }

  const attn = softmax(scores, -1);
  const out = einsum("BNLS,BSNH->BLNH", attn, value);
  return isRank3 ? out.reshape([
    L,
    N,
    H
  ]) : out;
}
|
|
7434
7694
|
|
|
7435
7695
|
//#endregion
|
|
7436
7696
|
//#region src/library/random.ts
|
|
@@ -7666,17 +7926,62 @@ const linearize = linearize$1;
|
|
|
7666
7926
|
/**
|
|
7667
7927
|
* @function
|
|
7668
7928
|
* Calculate the reverse-mode vector-Jacobian product for a function.
|
|
7929
|
+
*
|
|
7930
|
+
* The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
|
|
7931
|
+
* `f(primals)`, and `vjpFn` is a function that takes in cotangents for each
|
|
7932
|
+
* output and returns the cotangents for each input.
|
|
7933
|
+
*
|
|
7934
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
|
|
7935
|
+
* `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
|
|
7936
|
+
*
|
|
7937
|
+
* @example
|
|
7938
|
+
* ```ts
|
|
7939
|
+
* const [y, vjpFn] = vjp(f, [x]);
|
|
7940
|
+
*
|
|
7941
|
+
* // With hasAux
|
|
7942
|
+
* const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
|
|
7943
|
+
* ```
|
|
7669
7944
|
*/
|
|
7670
7945
|
const vjp = vjp$1;
|
|
7671
7946
|
/**
|
|
7672
7947
|
* @function
|
|
7673
7948
|
* Compute the gradient of a scalar-valued function `f` with respect to its
|
|
7674
7949
|
* first argument.
|
|
7950
|
+
*
|
|
7951
|
+
* Pass in different `argnums` to differentiate with respect to other
|
|
7952
|
+
* arguments. If a tuple is provided, the return value will be a tuple of
|
|
7953
|
+
* gradients corresponding to each argument index.
|
|
7954
|
+
*
|
|
7955
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
|
|
7956
|
+
* `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
|
|
7957
|
+
*
|
|
7958
|
+
* @example
|
|
7959
|
+
* ```ts
|
|
7960
|
+
* const gradient = grad(f)(x);
|
|
7961
|
+
*
|
|
7962
|
+
* // With `argnums`
|
|
7963
|
+
* const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
|
|
7964
|
+
*
|
|
7965
|
+
* // With `hasAux`
|
|
7966
|
+
* const [gradient, aux] = grad(f, { hasAux: true })(x);
|
|
7967
|
+
* ```
|
|
7675
7968
|
*/
|
|
7676
7969
|
const grad = grad$1;
|
|
7677
7970
|
/**
|
|
7678
7971
|
* @function
|
|
7679
7972
|
* Create a function that evaluates both `f` and the gradient of `f`.
|
|
7973
|
+
*
|
|
7974
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
|
|
7975
|
+
* `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
|
|
7976
|
+
*
|
|
7977
|
+
* @example
|
|
7978
|
+
* ```ts
|
|
7979
|
+
* // Without hasAux
|
|
7980
|
+
* const [value, gradient] = valueAndGrad(f)(x);
|
|
7981
|
+
*
|
|
7982
|
+
* // With hasAux
|
|
7983
|
+
* const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
|
|
7984
|
+
* ```
|
|
7680
7985
|
*/
|
|
7681
7986
|
const valueAndGrad = valueAndGrad$1;
|
|
7682
7987
|
/**
|
|
@@ -7685,6 +7990,21 @@ const valueAndGrad = valueAndGrad$1;
|
|
|
7685
7990
|
*/
|
|
7686
7991
|
const jacrev = jacrev$1;
|
|
7687
7992
|
/**
|
|
7993
|
+
* @function
|
|
7994
|
+
* Compute the Hessian matrix of a scalar-valued function.
|
|
7995
|
+
*
|
|
7996
|
+
* The Hessian is the matrix of second-order partial derivatives of a function.
|
|
7997
|
+
* This is implemented as `jacfwd(grad(f))`.
|
|
7998
|
+
*
|
|
7999
|
+
* @example
|
|
8000
|
+
* ```ts
|
|
8001
|
+
* const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // x^3
|
|
8002
|
+
* const H = hessian(f)(np.array([1, 2, 3]));
|
|
8003
|
+
* // H[i,j] = d^2f / dx_i dx_j
|
|
8004
|
+
* ```
|
|
8005
|
+
*/
|
|
8006
|
+
const hessian = hessian$1;
|
|
8007
|
+
/**
|
|
7688
8008
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
7689
8009
|
*
|
|
7690
8010
|
* This can be used to wait for the results of an intermediate computation to
|
|
@@ -7728,6 +8048,7 @@ exports.defaultDevice = require_backend.defaultDevice;
|
|
|
7728
8048
|
exports.devicePut = devicePut;
|
|
7729
8049
|
exports.devices = require_backend.devices;
|
|
7730
8050
|
exports.grad = grad;
|
|
8051
|
+
exports.hessian = hessian;
|
|
7731
8052
|
exports.init = require_backend.init;
|
|
7732
8053
|
exports.jacfwd = jacfwd;
|
|
7733
8054
|
exports.jacobian = jacrev;
|