@jax-js/jax 0.1.5 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/{backend-DziQSaoQ.cjs → backend-D7s-Retx.cjs} +23 -4
- package/dist/{backend-DaqL-MNz.js → backend-Dx6Ob2D1.js} +18 -5
- package/dist/index.cjs +365 -110
- package/dist/index.d.cts +192 -13
- package/dist/index.d.ts +192 -13
- package/dist/index.js +365 -111
- package/dist/{webgl-RSuZKvgc.js → webgl-CLLvzJlO.js} +1 -1
- package/dist/{webgl-ClIYb8jP.cjs → webgl-CyfzNW8T.cjs} +1 -1
- package/dist/{webgpu-Dh7k9io0.js → webgpu-C-VfevQW.js} +1 -1
- package/dist/{webgpu-Db2JrNBr.cjs → webgpu-rraa6dfz.cjs} +1 -1
- package/package.json +1 -1
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-Dx6Ob2D1.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -209,7 +209,7 @@ __export(tree_exports, {
|
|
|
209
209
|
structure: () => structure,
|
|
210
210
|
unflatten: () => unflatten
|
|
211
211
|
});
|
|
212
|
-
const JsArray$
|
|
212
|
+
const JsArray$2 = globalThis.Array;
|
|
213
213
|
let NodeType = /* @__PURE__ */ function(NodeType$1) {
|
|
214
214
|
NodeType$1["Array"] = "Array";
|
|
215
215
|
NodeType$1["Object"] = "Object";
|
|
@@ -257,7 +257,7 @@ function flatten(tree) {
|
|
|
257
257
|
return [leaves$1, treedef];
|
|
258
258
|
}
|
|
259
259
|
function _flatten(tree, leaves$1) {
|
|
260
|
-
if (JsArray$
|
|
260
|
+
if (JsArray$2.isArray(tree)) {
|
|
261
261
|
const childTrees = tree.map((c) => _flatten(c, leaves$1));
|
|
262
262
|
return new JsTreeDef(NodeType.Array, null, childTrees);
|
|
263
263
|
} else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
|
|
@@ -381,6 +381,13 @@ let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
|
|
|
381
381
|
CompareOp$1["LessEqual"] = "less_equal";
|
|
382
382
|
return CompareOp$1;
|
|
383
383
|
}({});
|
|
384
|
+
const routinePrimitives = new Map([
|
|
385
|
+
[Primitive.Sort, Routines.Sort],
|
|
386
|
+
[Primitive.Argsort, Routines.Argsort],
|
|
387
|
+
[Primitive.TriangularSolve, Routines.TriangularSolve],
|
|
388
|
+
[Primitive.Cholesky, Routines.Cholesky],
|
|
389
|
+
[Primitive.LU, Routines.LU]
|
|
390
|
+
]);
|
|
384
391
|
function add$1(x, y) {
|
|
385
392
|
return bind1(Primitive.Add, [x, y]);
|
|
386
393
|
}
|
|
@@ -654,6 +661,9 @@ function newDynamic(main) {
|
|
|
654
661
|
dynamicTrace = prevDynamicTrace;
|
|
655
662
|
} };
|
|
656
663
|
}
|
|
664
|
+
function currentTraceLevel() {
|
|
665
|
+
return traceStack[traceStack.length - 1].level;
|
|
666
|
+
}
|
|
657
667
|
var Trace = class {
|
|
658
668
|
constructor(main) {
|
|
659
669
|
this.main = main;
|
|
@@ -1031,6 +1041,7 @@ var TreeMismatchError = class extends TypeError {
|
|
|
1031
1041
|
super(`Mismatched tree structures in ${where$2}: ${left} != ${right}`);
|
|
1032
1042
|
}
|
|
1033
1043
|
};
|
|
1044
|
+
/** Flatten a function of `JsTree` input/output for use in tracing. */
|
|
1034
1045
|
function flattenFun(f, inTree) {
|
|
1035
1046
|
const store = { value: void 0 };
|
|
1036
1047
|
const flatFun = (...argsFlat) => {
|
|
@@ -1042,6 +1053,26 @@ function flattenFun(f, inTree) {
|
|
|
1042
1053
|
};
|
|
1043
1054
|
return [flatFun, store];
|
|
1044
1055
|
}
|
|
1056
|
+
/** Like flattenFun, but expects f to return [main, aux] tuple. */
|
|
1057
|
+
function flattenFunWithAux(f, inTree) {
|
|
1058
|
+
const store = { value: void 0 };
|
|
1059
|
+
const auxStore = { value: void 0 };
|
|
1060
|
+
const flatFun = (...argsFlat) => {
|
|
1061
|
+
const pytreeArgs = unflatten(inTree, argsFlat);
|
|
1062
|
+
const result = f(...pytreeArgs);
|
|
1063
|
+
if (!Array.isArray(result) || result.length !== 2) throw new Error("Function with `hasAux: true` must return [output, aux] tuple");
|
|
1064
|
+
const [out, aux] = result;
|
|
1065
|
+
const [outFlat, outTree] = flatten(out);
|
|
1066
|
+
store.value = outTree;
|
|
1067
|
+
auxStore.value = aux;
|
|
1068
|
+
return outFlat;
|
|
1069
|
+
};
|
|
1070
|
+
return [
|
|
1071
|
+
flatFun,
|
|
1072
|
+
store,
|
|
1073
|
+
auxStore
|
|
1074
|
+
];
|
|
1075
|
+
}
|
|
1045
1076
|
var UseAfterFreeError = class extends ReferenceError {
|
|
1046
1077
|
constructor(tracer) {
|
|
1047
1078
|
super(`Referenced tracer ${tracer.toString()} freed, please use .ref move semantics`);
|
|
@@ -1771,13 +1802,6 @@ function jit$1(f, opts) {
|
|
|
1771
1802
|
|
|
1772
1803
|
//#endregion
|
|
1773
1804
|
//#region src/frontend/jit.ts
|
|
1774
|
-
const routinePrimitives = new Map([
|
|
1775
|
-
[Primitive.Sort, Routines.Sort],
|
|
1776
|
-
[Primitive.Argsort, Routines.Argsort],
|
|
1777
|
-
[Primitive.TriangularSolve, Routines.TriangularSolve],
|
|
1778
|
-
[Primitive.Cholesky, Routines.Cholesky],
|
|
1779
|
-
[Primitive.LU, Routines.LU]
|
|
1780
|
-
]);
|
|
1781
1805
|
/** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
|
|
1782
1806
|
var JitProgram = class {
|
|
1783
1807
|
constructor(backend, steps, inputs, outputs) {
|
|
@@ -2166,12 +2190,13 @@ const jitRules = {
|
|
|
2166
2190
|
const ndim$2 = avals[0].ndim;
|
|
2167
2191
|
const sizes = avals.map((x) => x.shape[axis]);
|
|
2168
2192
|
const finalSize = sizes.reduce((a, b) => a + b, 0);
|
|
2193
|
+
const { dtype: dtypeOut } = avals.map((x) => x.scalar()).reduce(promoteAvals);
|
|
2169
2194
|
const makePadAxis = (start, end) => range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
|
|
2170
2195
|
let cum = 0;
|
|
2171
2196
|
const src = [];
|
|
2172
2197
|
for (let i = 0; i < exps.length; i++) {
|
|
2173
2198
|
const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
|
|
2174
|
-
src.push(reshapeViews(exps[i], (st) => st.pad(padding)));
|
|
2199
|
+
src.push(reshapeViews(AluExp.cast(dtypeOut, exps[i]), (st) => st.pad(padding)));
|
|
2175
2200
|
cum += sizes[i];
|
|
2176
2201
|
}
|
|
2177
2202
|
return { exp: [src.reduce(AluExp.add)] };
|
|
@@ -2309,7 +2334,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2309
2334
|
p1NextBlack.set(v, v);
|
|
2310
2335
|
}
|
|
2311
2336
|
const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
|
|
2312
|
-
const needsCleanShapePrimitives = [Primitive.Pad];
|
|
2337
|
+
const needsCleanShapePrimitives = [Primitive.Concatenate, Primitive.Pad];
|
|
2313
2338
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
2314
2339
|
const eqn = jaxpr.eqns[i];
|
|
2315
2340
|
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
@@ -2379,7 +2404,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2379
2404
|
|
|
2380
2405
|
//#endregion
|
|
2381
2406
|
//#region src/frontend/array.ts
|
|
2382
|
-
const JsArray = globalThis.Array;
|
|
2407
|
+
const JsArray$1 = globalThis.Array;
|
|
2383
2408
|
const inlineArrayLimit = 128;
|
|
2384
2409
|
/** Version of pureArray with fudged types. */
|
|
2385
2410
|
const fudgeArray = pureArray;
|
|
@@ -2777,25 +2802,35 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2777
2802
|
});
|
|
2778
2803
|
}
|
|
2779
2804
|
/** Apply an operation with custom lowering to this array. */
|
|
2780
|
-
static #routine(
|
|
2781
|
-
|
|
2782
|
-
|
|
2783
|
-
|
|
2784
|
-
|
|
2785
|
-
|
|
2786
|
-
|
|
2787
|
-
|
|
2788
|
-
|
|
2789
|
-
|
|
2790
|
-
|
|
2791
|
-
|
|
2792
|
-
|
|
2793
|
-
dtype
|
|
2794
|
-
|
|
2795
|
-
|
|
2796
|
-
|
|
2797
|
-
pending
|
|
2798
|
-
|
|
2805
|
+
static #routine(prim) {
|
|
2806
|
+
return (arrays, params) => {
|
|
2807
|
+
const { backend, committed } = Array$1.#computeBackend(prim, arrays);
|
|
2808
|
+
for (const ar of arrays) ar.#realize();
|
|
2809
|
+
const avals = arrays.map((ar) => ar.aval);
|
|
2810
|
+
const avalsOut = abstractEvalRules[prim](avals, params);
|
|
2811
|
+
const routine = new Routine(routinePrimitives.get(prim), {
|
|
2812
|
+
inputShapes: avals.map((a) => a.shape),
|
|
2813
|
+
inputDtypes: avals.map((a) => a.dtype),
|
|
2814
|
+
outputShapes: avalsOut.map((a) => a.shape),
|
|
2815
|
+
outputDtypes: avalsOut.map((a) => a.dtype)
|
|
2816
|
+
}, params);
|
|
2817
|
+
const inputs = arrays.map((ar) => ar.#source);
|
|
2818
|
+
const outputs = avalsOut.map((x) => backend.malloc(byteWidth(x.dtype) * x.size));
|
|
2819
|
+
const pending = arrays.flatMap((ar) => ar.#pending);
|
|
2820
|
+
for (const exe of pending) exe.updateRc(+outputs.length);
|
|
2821
|
+
pending.push(new PendingExecute(backend, routine, inputs, outputs));
|
|
2822
|
+
pending[pending.length - 1].updateRc(+outputs.length - 1);
|
|
2823
|
+
arrays.forEach((ar) => ar.dispose());
|
|
2824
|
+
return outputs.map((output, i) => new Array$1({
|
|
2825
|
+
source: output,
|
|
2826
|
+
st: ShapeTracker.fromShape(avalsOut[i].shape),
|
|
2827
|
+
dtype: avalsOut[i].dtype,
|
|
2828
|
+
weakType: avalsOut[i].weakType,
|
|
2829
|
+
backend,
|
|
2830
|
+
committed,
|
|
2831
|
+
pending
|
|
2832
|
+
}));
|
|
2833
|
+
};
|
|
2799
2834
|
}
|
|
2800
2835
|
/**
|
|
2801
2836
|
* Normalizes this array into one backed by a `Slot`.
|
|
@@ -3129,65 +3164,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3129
3164
|
[Primitive.Pad]([x], { width }) {
|
|
3130
3165
|
return [x.#reshape(x.#st.pad(width))];
|
|
3131
3166
|
},
|
|
3132
|
-
[Primitive.Sort](
|
|
3133
|
-
|
|
3134
|
-
|
|
3135
|
-
|
|
3136
|
-
|
|
3137
|
-
outputDtypes: [x.dtype]
|
|
3138
|
-
});
|
|
3139
|
-
return Array$1.#routine(routine, [x], [x.#weakType]);
|
|
3140
|
-
},
|
|
3141
|
-
[Primitive.Argsort]([x]) {
|
|
3142
|
-
const routine = new Routine(Routines.Argsort, {
|
|
3143
|
-
inputShapes: [x.shape],
|
|
3144
|
-
inputDtypes: [x.dtype],
|
|
3145
|
-
outputShapes: [x.shape, x.shape],
|
|
3146
|
-
outputDtypes: [x.dtype, DType.Int32]
|
|
3147
|
-
});
|
|
3148
|
-
return Array$1.#routine(routine, [x], [x.#weakType, false]);
|
|
3149
|
-
},
|
|
3150
|
-
[Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
|
|
3151
|
-
const routine = new Routine(Routines.TriangularSolve, {
|
|
3152
|
-
inputShapes: [a.shape, b.shape],
|
|
3153
|
-
inputDtypes: [a.dtype, b.dtype],
|
|
3154
|
-
outputShapes: [b.shape],
|
|
3155
|
-
outputDtypes: [b.dtype]
|
|
3156
|
-
}, { unitDiagonal });
|
|
3157
|
-
return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
|
|
3158
|
-
},
|
|
3159
|
-
[Primitive.Cholesky]([a]) {
|
|
3160
|
-
const routine = new Routine(Routines.Cholesky, {
|
|
3161
|
-
inputShapes: [a.shape],
|
|
3162
|
-
inputDtypes: [a.dtype],
|
|
3163
|
-
outputShapes: [a.shape],
|
|
3164
|
-
outputDtypes: [a.dtype]
|
|
3165
|
-
});
|
|
3166
|
-
return Array$1.#routine(routine, [a], [a.#weakType]);
|
|
3167
|
-
},
|
|
3168
|
-
[Primitive.LU]([a]) {
|
|
3169
|
-
const batch = a.shape.slice(0, -2);
|
|
3170
|
-
const [m, n] = a.shape.slice(-2);
|
|
3171
|
-
const routine = new Routine(Routines.LU, {
|
|
3172
|
-
inputShapes: [a.shape],
|
|
3173
|
-
inputDtypes: [a.dtype],
|
|
3174
|
-
outputShapes: [
|
|
3175
|
-
a.shape,
|
|
3176
|
-
[...batch, Math.min(m, n)],
|
|
3177
|
-
[...batch, m]
|
|
3178
|
-
],
|
|
3179
|
-
outputDtypes: [
|
|
3180
|
-
a.dtype,
|
|
3181
|
-
DType.Int32,
|
|
3182
|
-
DType.Int32
|
|
3183
|
-
]
|
|
3184
|
-
});
|
|
3185
|
-
return Array$1.#routine(routine, [a], [
|
|
3186
|
-
a.#weakType,
|
|
3187
|
-
false,
|
|
3188
|
-
false
|
|
3189
|
-
]);
|
|
3190
|
-
},
|
|
3167
|
+
[Primitive.Sort]: Array$1.#routine(Primitive.Sort),
|
|
3168
|
+
[Primitive.Argsort]: Array$1.#routine(Primitive.Argsort),
|
|
3169
|
+
[Primitive.TriangularSolve]: Array$1.#routine(Primitive.TriangularSolve),
|
|
3170
|
+
[Primitive.Cholesky]: Array$1.#routine(Primitive.Cholesky),
|
|
3171
|
+
[Primitive.LU]: Array$1.#routine(Primitive.LU),
|
|
3191
3172
|
[Primitive.Jit](args, { jaxpr }) {
|
|
3192
3173
|
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
3193
3174
|
const { backend, committed } = Array$1.#computeBackend("jit", args);
|
|
@@ -3269,7 +3250,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
3269
3250
|
if (!shape$1) {
|
|
3270
3251
|
shape$1 = [];
|
|
3271
3252
|
let cur = values;
|
|
3272
|
-
while (JsArray.isArray(cur)) {
|
|
3253
|
+
while (JsArray$1.isArray(cur)) {
|
|
3273
3254
|
shape$1.push(cur.length);
|
|
3274
3255
|
cur = cur[0];
|
|
3275
3256
|
}
|
|
@@ -4223,17 +4204,39 @@ function jvpFlat(f, primals, tangents) {
|
|
|
4223
4204
|
_usingCtx$1.d();
|
|
4224
4205
|
}
|
|
4225
4206
|
}
|
|
4226
|
-
function jvp$1(f, primals, tangents) {
|
|
4207
|
+
function jvp$1(f, primals, tangents, { hasAux = false } = {}) {
|
|
4227
4208
|
const [primalsFlat, inTree] = flatten(primals);
|
|
4228
4209
|
const [tangentsFlat, inTree2] = flatten(tangents);
|
|
4229
4210
|
if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
|
|
4230
|
-
|
|
4211
|
+
let flatFun, outTree, aux;
|
|
4212
|
+
if (hasAux) [flatFun, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4213
|
+
else [flatFun, outTree] = flattenFun(f, inTree);
|
|
4231
4214
|
const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
|
|
4232
4215
|
if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
|
|
4233
4216
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
4234
4217
|
const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
|
|
4218
|
+
if (hasAux) return [
|
|
4219
|
+
primalsOut,
|
|
4220
|
+
tangentsOut,
|
|
4221
|
+
lowerAux(aux.value)
|
|
4222
|
+
];
|
|
4235
4223
|
return [primalsOut, tangentsOut];
|
|
4236
4224
|
}
|
|
4225
|
+
/** Lowering for auxiliary data returned in `hasAux: true` methods. */
|
|
4226
|
+
function lowerAux(aux) {
|
|
4227
|
+
const level = currentTraceLevel();
|
|
4228
|
+
return map((x) => {
|
|
4229
|
+
if (x instanceof Tracer) while (x._trace.main.level > level) if (x instanceof JVPTracer) {
|
|
4230
|
+
x.tangent.dispose();
|
|
4231
|
+
x = x.primal;
|
|
4232
|
+
} else {
|
|
4233
|
+
const y = x.fullLower();
|
|
4234
|
+
if (y._trace.main.level >= x._trace.main.level) throw new Error("internal: lowerAux did not reduce trace level");
|
|
4235
|
+
x = y;
|
|
4236
|
+
}
|
|
4237
|
+
return x;
|
|
4238
|
+
}, aux);
|
|
4239
|
+
}
|
|
4237
4240
|
|
|
4238
4241
|
//#endregion
|
|
4239
4242
|
//#region src/frontend/linearize.ts
|
|
@@ -4304,9 +4307,11 @@ function linearizeFlat(f, primalsIn) {
|
|
|
4304
4307
|
dispose$1
|
|
4305
4308
|
];
|
|
4306
4309
|
}
|
|
4307
|
-
function linearize$1(f,
|
|
4310
|
+
function linearize$1(f, primalsIn, { hasAux = false } = {}) {
|
|
4308
4311
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4309
|
-
|
|
4312
|
+
let fFlat, outTree, aux;
|
|
4313
|
+
if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4314
|
+
else [fFlat, outTree] = flattenFun(f, inTree);
|
|
4310
4315
|
const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4311
4316
|
if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
|
|
4312
4317
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
@@ -4317,6 +4322,11 @@ function linearize$1(f, ...primalsIn) {
|
|
|
4317
4322
|
return unflatten(outTree.value, tangentsOutFlat);
|
|
4318
4323
|
});
|
|
4319
4324
|
fLin.dispose = dispose$1;
|
|
4325
|
+
if (hasAux) return [
|
|
4326
|
+
primalsOut,
|
|
4327
|
+
fLin,
|
|
4328
|
+
lowerAux(aux.value)
|
|
4329
|
+
];
|
|
4320
4330
|
return [primalsOut, fLin];
|
|
4321
4331
|
}
|
|
4322
4332
|
var PartialEvalTracer = class extends Tracer {
|
|
@@ -4817,9 +4827,11 @@ function vjpFlat(f, primalsIn) {
|
|
|
4817
4827
|
dispose$1
|
|
4818
4828
|
];
|
|
4819
4829
|
}
|
|
4820
|
-
function vjp$1(f,
|
|
4830
|
+
function vjp$1(f, primalsIn, { hasAux = false } = {}) {
|
|
4821
4831
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4822
|
-
|
|
4832
|
+
let fFlat, outTree, aux;
|
|
4833
|
+
if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4834
|
+
else [fFlat, outTree] = flattenFun(f, inTree);
|
|
4823
4835
|
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4824
4836
|
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
4825
4837
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
@@ -4830,26 +4842,43 @@ function vjp$1(f, ...primalsIn) {
|
|
|
4830
4842
|
return unflatten(inTree, cotangentsInFlat);
|
|
4831
4843
|
});
|
|
4832
4844
|
fVjp.dispose = dispose$1;
|
|
4845
|
+
if (hasAux) return [
|
|
4846
|
+
primalsOut,
|
|
4847
|
+
fVjp,
|
|
4848
|
+
lowerAux(aux.value)
|
|
4849
|
+
];
|
|
4833
4850
|
return [primalsOut, fVjp];
|
|
4834
4851
|
}
|
|
4835
|
-
function grad$1(f) {
|
|
4836
|
-
const valueAndGradFn = valueAndGrad$1(f);
|
|
4852
|
+
function grad$1(f, opts) {
|
|
4853
|
+
const valueAndGradFn = valueAndGrad$1(f, opts);
|
|
4837
4854
|
return (...x) => {
|
|
4838
|
-
|
|
4839
|
-
|
|
4840
|
-
|
|
4855
|
+
if (opts?.hasAux) {
|
|
4856
|
+
const [[y, aux], dx] = valueAndGradFn(...x);
|
|
4857
|
+
y.dispose();
|
|
4858
|
+
return [dx, aux];
|
|
4859
|
+
} else {
|
|
4860
|
+
const [y, dx] = valueAndGradFn(...x);
|
|
4861
|
+
y.dispose();
|
|
4862
|
+
return dx;
|
|
4863
|
+
}
|
|
4841
4864
|
};
|
|
4842
4865
|
}
|
|
4843
|
-
function valueAndGrad$1(f) {
|
|
4866
|
+
function valueAndGrad$1(f, opts) {
|
|
4867
|
+
const argnums = opts?.argnums ?? 0;
|
|
4868
|
+
const hasAux = opts?.hasAux ?? false;
|
|
4869
|
+
checkInts(argnums);
|
|
4870
|
+
const argnumsSet = new Set(typeof argnums === "number" ? [argnums] : argnums);
|
|
4844
4871
|
return (...x) => {
|
|
4845
4872
|
if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
|
|
4846
|
-
|
|
4873
|
+
for (let i = 0; i < x.length; i++) if (!argnumsSet.has(i)) x[i] = map(stopGradient, x[i]);
|
|
4874
|
+
const [y, fVjp, aux] = vjp$1(f, x, { hasAux });
|
|
4847
4875
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
4848
4876
|
if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
4849
|
-
const
|
|
4850
|
-
for (const r of rest) dispose(r);
|
|
4877
|
+
const cts = fVjp(onesLike$1(y.ref));
|
|
4851
4878
|
fVjp.dispose();
|
|
4852
|
-
|
|
4879
|
+
for (let i = 0; i < cts.length; i++) if (!argnumsSet.has(i)) dispose(cts[i]);
|
|
4880
|
+
const grads = typeof argnums === "number" ? cts[argnums] : argnums.map((i) => cts[i]);
|
|
4881
|
+
return hasAux ? [[y, aux], grads] : [y, grads];
|
|
4853
4882
|
};
|
|
4854
4883
|
}
|
|
4855
4884
|
function jacrev$1(f) {
|
|
@@ -4857,7 +4886,7 @@ function jacrev$1(f) {
|
|
|
4857
4886
|
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
4858
4887
|
const [size$1] = x.shape;
|
|
4859
4888
|
const pullback = (ct) => {
|
|
4860
|
-
const [y, fVjp] = vjp$1(f, x);
|
|
4889
|
+
const [y, fVjp] = vjp$1(f, [x]);
|
|
4861
4890
|
y.dispose();
|
|
4862
4891
|
const [ret] = fVjp(ct);
|
|
4863
4892
|
fVjp.dispose();
|
|
@@ -4866,6 +4895,9 @@ function jacrev$1(f) {
|
|
|
4866
4895
|
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
4867
4896
|
};
|
|
4868
4897
|
}
|
|
4898
|
+
function hessian$1(f) {
|
|
4899
|
+
return jacfwd$1(grad$1(f));
|
|
4900
|
+
}
|
|
4869
4901
|
|
|
4870
4902
|
//#endregion
|
|
4871
4903
|
//#region src/library/numpy/einsum.ts
|
|
@@ -5575,6 +5607,7 @@ __export(numpy_exports, {
|
|
|
5575
5607
|
std: () => std,
|
|
5576
5608
|
subtract: () => subtract,
|
|
5577
5609
|
sum: () => sum,
|
|
5610
|
+
swapaxes: () => swapaxes,
|
|
5578
5611
|
take: () => take,
|
|
5579
5612
|
tan: () => tan,
|
|
5580
5613
|
tanh: () => tanh,
|
|
@@ -5973,6 +6006,17 @@ function flipud(x) {
|
|
|
5973
6006
|
function fliplr(x) {
|
|
5974
6007
|
return flip(x, 1);
|
|
5975
6008
|
}
|
|
6009
|
+
/** Interchange two axes of an array. */
|
|
6010
|
+
function swapaxes(a, axis1, axis2) {
|
|
6011
|
+
a = fudgeArray(a);
|
|
6012
|
+
axis1 = checkAxis(axis1, a.ndim);
|
|
6013
|
+
axis2 = checkAxis(axis2, a.ndim);
|
|
6014
|
+
if (axis1 === axis2) return a;
|
|
6015
|
+
const perm = range(a.ndim);
|
|
6016
|
+
perm[axis1] = axis2;
|
|
6017
|
+
perm[axis2] = axis1;
|
|
6018
|
+
return transpose(a, perm);
|
|
6019
|
+
}
|
|
5976
6020
|
/** Transpose the last two dimensions of an array. */
|
|
5977
6021
|
function matrixTranspose(a) {
|
|
5978
6022
|
if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
|
|
@@ -6901,6 +6945,7 @@ var lax_exports = {};
|
|
|
6901
6945
|
__export(lax_exports, {
|
|
6902
6946
|
conv: () => conv,
|
|
6903
6947
|
convGeneralDilated: () => convGeneralDilated,
|
|
6948
|
+
convTranspose: () => convTranspose,
|
|
6904
6949
|
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
6905
6950
|
dot: () => dot,
|
|
6906
6951
|
erf: () => erf,
|
|
@@ -6909,6 +6954,7 @@ __export(lax_exports, {
|
|
|
6909
6954
|
reduceWindow: () => reduceWindow,
|
|
6910
6955
|
stopGradient: () => stopGradient$1
|
|
6911
6956
|
});
|
|
6957
|
+
const JsArray = globalThis.Array;
|
|
6912
6958
|
/**
|
|
6913
6959
|
* General dot product/contraction operator.
|
|
6914
6960
|
*
|
|
@@ -6980,7 +7026,11 @@ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
|
6980
7026
|
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
6981
7027
|
* function in JAX, which wraps XLA's general convolution operator.
|
|
6982
7028
|
*
|
|
6983
|
-
*
|
|
7029
|
+
* @param lhs - Input tensor; shape `[N, C_in, ...xs]`
|
|
7030
|
+
* @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
|
|
7031
|
+
* @param windowStrides - Strides for each spatial dimension
|
|
7032
|
+
* @param padding - Padding for each spatial dimension, or a string
|
|
7033
|
+
* (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
|
|
6984
7034
|
*/
|
|
6985
7035
|
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
|
|
6986
7036
|
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
@@ -7040,6 +7090,60 @@ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, r
|
|
|
7040
7090
|
function conv(lhs, rhs, windowStrides, padding) {
|
|
7041
7091
|
return convGeneralDilated(lhs, rhs, windowStrides, padding);
|
|
7042
7092
|
}
|
|
7093
|
+
/**
|
|
7094
|
+
* Convenience wrapper for calculating the N-d convolution "transpose".
|
|
7095
|
+
*
|
|
7096
|
+
* This function directly calculates a fractionally strided conv rather than
|
|
7097
|
+
* indirectly calculating the gradient (transpose) of a forward convolution.
|
|
7098
|
+
* It is equivalent to the JAX version, except:
|
|
7099
|
+
*
|
|
7100
|
+
* - The `use_consistent_padding` option is not available. We only have the
|
|
7101
|
+
* consistent padding case (JAX version >0.8.4).
|
|
7102
|
+
* - The order of dimensions matches `lax.conv_general_dilated`.
|
|
7103
|
+
*
|
|
7104
|
+
* Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
|
|
7105
|
+
* dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
|
|
7106
|
+
* `transposeKernel` to true.
|
|
7107
|
+
*
|
|
7108
|
+
* @param lhs - Input tensor; shape `[N, C_in, ...xs]`
|
|
7109
|
+
* @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
|
|
7110
|
+
* @param strides - Sequence of n integers, sets fractional stride
|
|
7111
|
+
* @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
|
|
7112
|
+
* each side of the input, so it acts like gradient of `conv()`
|
|
7113
|
+
* @param rhsDilation - Atrous dilation for the kernel
|
|
7114
|
+
* @param transposeKernel - Flip spatial axes and swap the input/output channels
|
|
7115
|
+
* of the kernel; its shape should be `[C_in, C_out, ...ks]`
|
|
7116
|
+
*/
|
|
7117
|
+
function convTranspose(lhs, rhs, strides, padding, { rhsDilation, transposeKernel = false } = {}) {
|
|
7118
|
+
const kernelShape = rhs.shape.slice(2);
|
|
7119
|
+
rhsDilation = rhsDilation ?? rep(kernelShape.length, 1);
|
|
7120
|
+
const effectiveKernel = kernelShape.map((k, i) => Math.max(0, (k - 1) * rhsDilation[i] + 1));
|
|
7121
|
+
const pads = effectiveKernel.map((k, i) => convTransposePadding(k, strides[i], typeof padding === "string" ? padding : padding[i]));
|
|
7122
|
+
if (transposeKernel) {
|
|
7123
|
+
rhs = flip$1(rhs, range(2, rhs.ndim));
|
|
7124
|
+
rhs = moveaxis(rhs, 0, 1);
|
|
7125
|
+
}
|
|
7126
|
+
return convGeneralDilated(lhs, rhs, rep(lhs.ndim - 2, 1), pads, {
|
|
7127
|
+
lhsDilation: strides,
|
|
7128
|
+
rhsDilation
|
|
7129
|
+
});
|
|
7130
|
+
}
|
|
7131
|
+
function convTransposePadding(k, s, padding) {
|
|
7132
|
+
let padLen;
|
|
7133
|
+
let pad1;
|
|
7134
|
+
if (padding === "SAME") {
|
|
7135
|
+
padLen = k + s - 2;
|
|
7136
|
+
pad1 = s > k - 1 ? k - 1 : Math.ceil(padLen / 2);
|
|
7137
|
+
} else if (padding === "VALID") {
|
|
7138
|
+
padLen = k + s - 2 + Math.max(k - s, 0);
|
|
7139
|
+
pad1 = k - 1;
|
|
7140
|
+
} else if (JsArray.isArray(padding)) {
|
|
7141
|
+
const pads = [k - 1 - padding[0], k - 1 - padding[1]];
|
|
7142
|
+
pad1 = pads[0];
|
|
7143
|
+
padLen = pads[0] + pads[1];
|
|
7144
|
+
} else throw new Error(`convTranspose: Invalid padding type ${padding}`);
|
|
7145
|
+
return [pad1, padLen - pad1];
|
|
7146
|
+
}
|
|
7043
7147
|
/** Reduce a computation over padded windows. */
|
|
7044
7148
|
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
7045
7149
|
if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
|
|
@@ -7078,6 +7182,7 @@ function stopGradient$1(x) {
|
|
|
7078
7182
|
var nn_exports = {};
|
|
7079
7183
|
__export(nn_exports, {
|
|
7080
7184
|
celu: () => celu,
|
|
7185
|
+
dotProductAttention: () => dotProductAttention,
|
|
7081
7186
|
elu: () => elu,
|
|
7082
7187
|
gelu: () => gelu,
|
|
7083
7188
|
glu: () => glu,
|
|
@@ -7394,6 +7499,95 @@ function oneHot(x, numClasses) {
|
|
|
7394
7499
|
if (isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
|
|
7395
7500
|
return eye(numClasses, void 0, { device: x.device }).slice(x);
|
|
7396
7501
|
}
|
|
7502
|
+
/**
|
|
7503
|
+
* Scaled dot product attention (SDPA).
|
|
7504
|
+
*
|
|
7505
|
+
* Computes `softmax((Q @ K^T) / sqrt(d) + bias) @ V`, where `Q` is the query,
|
|
7506
|
+
* `K` is the key, `V` is the value, and `d` is the dimensionality of each key
|
|
7507
|
+
* and query vector.
|
|
7508
|
+
*
|
|
7509
|
+
* Multi-query attention is applied when input `key` and `value` tensors have
|
|
7510
|
+
* fewer heads than `query`.
|
|
7511
|
+
*
|
|
7512
|
+
* We use the following uppercase letters to denote array shapes:
|
|
7513
|
+
* - `B` = batch size
|
|
7514
|
+
* - `S` = length of key/value sequences (source)
|
|
7515
|
+
* - `L` = length of query sequences
|
|
7516
|
+
* - `N` = number of attention heads
|
|
7517
|
+
* - `H` = dimensionality of each attention head
|
|
7518
|
+
* - `K` = number of key/value heads (for grouped-query attention)
|
|
7519
|
+
*
|
|
7520
|
+
* The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
|
|
7521
|
+
* case it must be omitted from all inputs.
|
|
7522
|
+
*
|
|
7523
|
+
* @param query - Query array; shape `[B, L, N, H]`
|
|
7524
|
+
* @param key - Key array; shape `[B, S, K, H]`
|
|
7525
|
+
* @param value - Value array; same shape as `key`
|
|
7526
|
+
* @param opts.bias - Optional bias to add to the attention logits; shape
|
|
7527
|
+
* `[B, N, L, S]` or broadcastable to it.
|
|
7528
|
+
* @param opts.mask - Optional mask to apply to the attention logits; should be
|
|
7529
|
+
* a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
|
|
7530
|
+
* the element should take part in attention.
|
|
7531
|
+
* @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
|
|
7532
|
+
* @param opts.isCausal - If true, applies a casual mask.
|
|
7533
|
+
* @param opts.querySeqLengths - Optional sequence lengths for the queries;
|
|
7534
|
+
* shape `(B,)`. Taken from the beginning of the tensor.
|
|
7535
|
+
* @param opts.keyValueSeqLengths - Optional sequence lengths for the keys and
|
|
7536
|
+
* values; shape `(B,)`. Taken from the beginning of the tensor.
|
|
7537
|
+
* @param opts.localWindowSize - If specified, applies a local attention window
|
|
7538
|
+
* of the given size. Can be a single number or a tuple `[left, right]`.
|
|
7539
|
+
*
|
|
7540
|
+
* @returns The result of the attention operation; shape is the same as query
|
|
7541
|
+
* `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
|
|
7542
|
+
*/
|
|
7543
|
+
function dotProductAttention(query, key$1, value, opts = {}) {
	// Options that are part of the API but not supported yet: fail loudly instead of silently ignoring them.
	if (opts.querySeqLengths !== void 0 || opts.keyValueSeqLengths !== void 0) throw new Error("Sequence length masking is not yet implemented");
	if (opts.localWindowSize !== void 0) throw new Error("Local attention is not yet implemented");
	query = fudgeArray(query);
	key$1 = fudgeArray(key$1);
	value = fudgeArray(value);
	if (query.ndim !== 3 && query.ndim !== 4 || query.ndim !== key$1.ndim || query.ndim !== value.ndim) throw new Error(`dotProductAttention: expected all tensors to have rank 3 or 4, got Q=${query.aval}, K=${key$1.aval}, V=${value.aval}`);
	if (!deepEqual(key$1.shape, value.shape)) throw new Error(`dotProductAttention: key and value shapes must match, got K=${key$1.shape}, V=${value.shape}`);
	// Rank-3 inputs omit the batch axis: add a unit batch axis now and drop it from the result.
	const isRank3 = query.ndim === 3;
	if (isRank3) {
		query = expandDims(query, 0);
		key$1 = expandDims(key$1, 0);
		value = expandDims(value, 0);
	}
	// Query is [B, L, N, H]; key/value are [B, S, K, H].
	const [B, L, N, H] = query.shape;
	if (key$1.shape[0] !== B || key$1.shape[3] !== H) throw new Error(`dotProductAttention: query and key shapes mismatch, got Q=${query.aval}, K=${key$1.aval}`);
	const S = key$1.shape[1];
	const K = key$1.shape[2];
	// N must be a positive multiple of K. `N % K` is NaN for K === 0, so that case is rejected too.
	if (N < K || N % K !== 0) throw new Error(`dotProductAttention: number of query heads N=${N} must be divisible by number of key/value heads K=${K} for GQA`);
	const G = N / K;
	if (G > 1) {
		// Grouped-query attention: query head n shares key/value head floor(n / G), so each
		// KV head must be repeated G times *consecutively* ([k0, k0, k1, k1, ...]).
		// A plain `tile` along the head axis would interleave them ([k0, k1, k0, k1, ...])
		// and pair query heads with the wrong KV heads whenever 1 < K < N.
		const repeatHeads = (x) => tile(expandDims(x, 3), [
			1,
			1,
			1,
			G,
			1
		]).reshape([
			B,
			S,
			N,
			H
		]);
		key$1 = repeatHeads(key$1);
		value = repeatHeads(value);
	}
	const scale = opts.scale ?? 1 / Math.sqrt(H);
	let scores = einsum("BLNH,BSNH->BNLS", query, key$1).mul(scale);
	if (opts.bias !== void 0) scores = scores.add(opts.bias);
	if (opts.mask !== void 0) scores = where(opts.mask, scores, -Infinity);
	if (opts.isCausal) {
		// Lower-triangular [L, S] mask: query position i may attend to key positions j <= i.
		const causalMask = tri(L, S, 0, { dtype: DType.Bool });
		scores = where(causalMask, scores, -Infinity);
	}
	const attn = softmax(scores, -1);
	const out = einsum("BNLS,BSNH->BLNH", attn, value);
	return isRank3 ? out.reshape([
		L,
		N,
		H
	]) : out;
}
|
|
7397
7591
|
|
|
7398
7592
|
//#endregion
|
|
7399
7593
|
//#region src/library/random.ts
|
|
@@ -7629,17 +7823,62 @@ const linearize = linearize$1;
|
|
|
7629
7823
|
/**
* @function
* Calculate the reverse-mode vector-Jacobian product for a function.
*
* The primal inputs are passed as an array of arguments (e.g. `[x]`).
*
* The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
* `f(primals)`, and `vjpFn` is a function that takes in cotangents for each
* output and returns the cotangents for each input.
*
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
* `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
*
* @example
* ```ts
* const [y, vjpFn] = vjp(f, [x]);
*
* // With hasAux (`g` returns `[out, aux]`)
* const [y2, vjpFn2, aux] = vjp(g, [x], { hasAux: true });
* ```
*/
const vjp = vjp$1;
|
|
7634
7843
|
/**
* @function
* Compute the gradient of a scalar-valued function `f` with respect to its
* first argument.
*
* Pass in a different `argnums` to differentiate with respect to other
* arguments. If an array of indices is provided (e.g. `[0, 2]`), the return
* value will be a tuple of gradients, one per listed argument index, in order.
*
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
* `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
*
* @example
* ```ts
* const gradient = grad(f)(x);
*
* // With `argnums`
* const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
*
* // With `hasAux` (`g` returns `[out, aux]`)
* const [gradientG, aux] = grad(g, { hasAux: true })(x);
* ```
*/
const grad = grad$1;
|
|
7640
7867
|
/**
* @function
* Create a function that evaluates both `f` and the gradient of `f`.
*
* The returned function yields `[value, gradient]`.
*
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
* `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
*
* @example
* ```ts
* // Without hasAux
* const [value, gradient] = valueAndGrad(f)(x);
*
* // With hasAux (`g` returns `[out, aux]`)
* const [[valueG, aux], gradientG] = valueAndGrad(g, { hasAux: true })(x);
* ```
*/
const valueAndGrad = valueAndGrad$1;
|
|
7645
7884
|
/**
|
|
@@ -7648,6 +7887,21 @@ const valueAndGrad = valueAndGrad$1;
|
|
|
7648
7887
|
*/
|
|
7649
7888
|
const jacrev = jacrev$1;
|
|
7650
7889
|
/**
* @function
* Compute the Hessian matrix of a scalar-valued function.
*
* The Hessian is the matrix of second-order partial derivatives of a function.
* This is implemented as `jacfwd(grad(f))`.
*
* @example
* ```ts
* const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // sum(x^3)
* const H = hessian(f)(np.array([1, 2, 3]));
* // H[i,j] = d^2f / dx_i dx_j, here 6 * x_i on the diagonal and 0 elsewhere:
* // [[6, 0, 0], [0, 12, 0], [0, 0, 18]]
* ```
*/
const hessian = hessian$1;
|
|
7904
|
+
/**
|
|
7651
7905
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
7652
7906
|
*
|
|
7653
7907
|
* This can be used to wait for the results of an intermediate computation to
|
|
@@ -7682,4 +7936,4 @@ async function devicePut(x, device) {
|
|
|
7682
7936
|
}
|
|
7683
7937
|
|
|
7684
7938
|
//#endregion
|
|
7685
|
-
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
7939
|
+
// Public API of the package. `jacobian` is exported as an alias of `jacrev`.
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|