@jax-js/jax 0.1.5 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/{backend-DziQSaoQ.cjs → backend-D7s-Retx.cjs} +23 -4
- package/dist/{backend-DaqL-MNz.js → backend-Dx6Ob2D1.js} +18 -5
- package/dist/index.cjs +365 -110
- package/dist/index.d.cts +192 -13
- package/dist/index.d.ts +192 -13
- package/dist/index.js +365 -111
- package/dist/{webgl-RSuZKvgc.js → webgl-CLLvzJlO.js} +1 -1
- package/dist/{webgl-ClIYb8jP.cjs → webgl-CyfzNW8T.cjs} +1 -1
- package/dist/{webgpu-Dh7k9io0.js → webgpu-C-VfevQW.js} +1 -1
- package/dist/{webgpu-Db2JrNBr.cjs → webgpu-rraa6dfz.cjs} +1 -1
- package/package.json +1 -1
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-DziQSaoQ.cjs');
|
|
33
|
+
const require_backend = require('./backend-D7s-Retx.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -240,7 +240,7 @@ __export(tree_exports, {
|
|
|
240
240
|
structure: () => structure,
|
|
241
241
|
unflatten: () => unflatten
|
|
242
242
|
});
|
|
243
|
-
const JsArray$
|
|
243
|
+
const JsArray$2 = globalThis.Array;
|
|
244
244
|
let NodeType = /* @__PURE__ */ function(NodeType$1) {
|
|
245
245
|
NodeType$1["Array"] = "Array";
|
|
246
246
|
NodeType$1["Object"] = "Object";
|
|
@@ -288,7 +288,7 @@ function flatten(tree) {
|
|
|
288
288
|
return [leaves$1, treedef];
|
|
289
289
|
}
|
|
290
290
|
function _flatten(tree, leaves$1) {
|
|
291
|
-
if (JsArray$
|
|
291
|
+
if (JsArray$2.isArray(tree)) {
|
|
292
292
|
const childTrees = tree.map((c) => _flatten(c, leaves$1));
|
|
293
293
|
return new JsTreeDef(NodeType.Array, null, childTrees);
|
|
294
294
|
} else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
|
|
@@ -412,6 +412,13 @@ let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
|
|
|
412
412
|
CompareOp$1["LessEqual"] = "less_equal";
|
|
413
413
|
return CompareOp$1;
|
|
414
414
|
}({});
|
|
415
|
+
const routinePrimitives = new Map([
|
|
416
|
+
[Primitive.Sort, require_backend.Routines.Sort],
|
|
417
|
+
[Primitive.Argsort, require_backend.Routines.Argsort],
|
|
418
|
+
[Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
|
|
419
|
+
[Primitive.Cholesky, require_backend.Routines.Cholesky],
|
|
420
|
+
[Primitive.LU, require_backend.Routines.LU]
|
|
421
|
+
]);
|
|
415
422
|
function add$1(x, y) {
|
|
416
423
|
return bind1(Primitive.Add, [x, y]);
|
|
417
424
|
}
|
|
@@ -685,6 +692,9 @@ function newDynamic(main) {
|
|
|
685
692
|
dynamicTrace = prevDynamicTrace;
|
|
686
693
|
} };
|
|
687
694
|
}
|
|
695
|
+
function currentTraceLevel() {
|
|
696
|
+
return traceStack[traceStack.length - 1].level;
|
|
697
|
+
}
|
|
688
698
|
var Trace = class {
|
|
689
699
|
constructor(main) {
|
|
690
700
|
this.main = main;
|
|
@@ -1062,6 +1072,7 @@ var TreeMismatchError = class extends TypeError {
|
|
|
1062
1072
|
super(`Mismatched tree structures in ${where$2}: ${left} != ${right}`);
|
|
1063
1073
|
}
|
|
1064
1074
|
};
|
|
1075
|
+
/** Flatten a function of `JsTree` input/output for use in tracing. */
|
|
1065
1076
|
function flattenFun(f, inTree) {
|
|
1066
1077
|
const store = { value: void 0 };
|
|
1067
1078
|
const flatFun = (...argsFlat) => {
|
|
@@ -1073,6 +1084,26 @@ function flattenFun(f, inTree) {
|
|
|
1073
1084
|
};
|
|
1074
1085
|
return [flatFun, store];
|
|
1075
1086
|
}
|
|
1087
|
+
/** Like flattenFun, but expects f to return [main, aux] tuple. */
|
|
1088
|
+
function flattenFunWithAux(f, inTree) {
|
|
1089
|
+
const store = { value: void 0 };
|
|
1090
|
+
const auxStore = { value: void 0 };
|
|
1091
|
+
const flatFun = (...argsFlat) => {
|
|
1092
|
+
const pytreeArgs = unflatten(inTree, argsFlat);
|
|
1093
|
+
const result = f(...pytreeArgs);
|
|
1094
|
+
if (!Array.isArray(result) || result.length !== 2) throw new Error("Function with `hasAux: true` must return [output, aux] tuple");
|
|
1095
|
+
const [out, aux] = result;
|
|
1096
|
+
const [outFlat, outTree] = flatten(out);
|
|
1097
|
+
store.value = outTree;
|
|
1098
|
+
auxStore.value = aux;
|
|
1099
|
+
return outFlat;
|
|
1100
|
+
};
|
|
1101
|
+
return [
|
|
1102
|
+
flatFun,
|
|
1103
|
+
store,
|
|
1104
|
+
auxStore
|
|
1105
|
+
];
|
|
1106
|
+
}
|
|
1076
1107
|
var UseAfterFreeError = class extends ReferenceError {
|
|
1077
1108
|
constructor(tracer) {
|
|
1078
1109
|
super(`Referenced tracer ${tracer.toString()} freed, please use .ref move semantics`);
|
|
@@ -1806,13 +1837,6 @@ function jit$1(f, opts) {
|
|
|
1806
1837
|
|
|
1807
1838
|
//#endregion
|
|
1808
1839
|
//#region src/frontend/jit.ts
|
|
1809
|
-
const routinePrimitives = new Map([
|
|
1810
|
-
[Primitive.Sort, require_backend.Routines.Sort],
|
|
1811
|
-
[Primitive.Argsort, require_backend.Routines.Argsort],
|
|
1812
|
-
[Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
|
|
1813
|
-
[Primitive.Cholesky, require_backend.Routines.Cholesky],
|
|
1814
|
-
[Primitive.LU, require_backend.Routines.LU]
|
|
1815
|
-
]);
|
|
1816
1840
|
/** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
|
|
1817
1841
|
var JitProgram = class {
|
|
1818
1842
|
constructor(backend, steps, inputs, outputs) {
|
|
@@ -2201,12 +2225,13 @@ const jitRules = {
|
|
|
2201
2225
|
const ndim$2 = avals[0].ndim;
|
|
2202
2226
|
const sizes = avals.map((x) => x.shape[axis]);
|
|
2203
2227
|
const finalSize = sizes.reduce((a, b) => a + b, 0);
|
|
2228
|
+
const { dtype: dtypeOut } = avals.map((x) => x.scalar()).reduce(promoteAvals);
|
|
2204
2229
|
const makePadAxis = (start, end) => require_backend.range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
|
|
2205
2230
|
let cum = 0;
|
|
2206
2231
|
const src = [];
|
|
2207
2232
|
for (let i = 0; i < exps.length; i++) {
|
|
2208
2233
|
const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
|
|
2209
|
-
src.push(reshapeViews(exps[i], (st) => st.pad(padding)));
|
|
2234
|
+
src.push(reshapeViews(require_backend.AluExp.cast(dtypeOut, exps[i]), (st) => st.pad(padding)));
|
|
2210
2235
|
cum += sizes[i];
|
|
2211
2236
|
}
|
|
2212
2237
|
return { exp: [src.reduce(require_backend.AluExp.add)] };
|
|
@@ -2344,7 +2369,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2344
2369
|
p1NextBlack.set(v, v);
|
|
2345
2370
|
}
|
|
2346
2371
|
const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
|
|
2347
|
-
const needsCleanShapePrimitives = [Primitive.Pad];
|
|
2372
|
+
const needsCleanShapePrimitives = [Primitive.Concatenate, Primitive.Pad];
|
|
2348
2373
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
2349
2374
|
const eqn = jaxpr.eqns[i];
|
|
2350
2375
|
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
@@ -2414,7 +2439,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2414
2439
|
|
|
2415
2440
|
//#endregion
|
|
2416
2441
|
//#region src/frontend/array.ts
|
|
2417
|
-
const JsArray = globalThis.Array;
|
|
2442
|
+
const JsArray$1 = globalThis.Array;
|
|
2418
2443
|
const inlineArrayLimit = 128;
|
|
2419
2444
|
/** Version of pureArray with fudged types. */
|
|
2420
2445
|
const fudgeArray = pureArray;
|
|
@@ -2812,25 +2837,35 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2812
2837
|
});
|
|
2813
2838
|
}
|
|
2814
2839
|
/** Apply an operation with custom lowering to this array. */
|
|
2815
|
-
static #routine(
|
|
2816
|
-
|
|
2817
|
-
|
|
2818
|
-
|
|
2819
|
-
|
|
2820
|
-
|
|
2821
|
-
|
|
2822
|
-
|
|
2823
|
-
|
|
2824
|
-
|
|
2825
|
-
|
|
2826
|
-
|
|
2827
|
-
|
|
2828
|
-
dtype
|
|
2829
|
-
|
|
2830
|
-
|
|
2831
|
-
|
|
2832
|
-
pending
|
|
2833
|
-
|
|
2840
|
+
static #routine(prim) {
|
|
2841
|
+
return (arrays, params) => {
|
|
2842
|
+
const { backend, committed } = Array$1.#computeBackend(prim, arrays);
|
|
2843
|
+
for (const ar of arrays) ar.#realize();
|
|
2844
|
+
const avals = arrays.map((ar) => ar.aval);
|
|
2845
|
+
const avalsOut = abstractEvalRules[prim](avals, params);
|
|
2846
|
+
const routine = new require_backend.Routine(routinePrimitives.get(prim), {
|
|
2847
|
+
inputShapes: avals.map((a) => a.shape),
|
|
2848
|
+
inputDtypes: avals.map((a) => a.dtype),
|
|
2849
|
+
outputShapes: avalsOut.map((a) => a.shape),
|
|
2850
|
+
outputDtypes: avalsOut.map((a) => a.dtype)
|
|
2851
|
+
}, params);
|
|
2852
|
+
const inputs = arrays.map((ar) => ar.#source);
|
|
2853
|
+
const outputs = avalsOut.map((x) => backend.malloc(require_backend.byteWidth(x.dtype) * x.size));
|
|
2854
|
+
const pending = arrays.flatMap((ar) => ar.#pending);
|
|
2855
|
+
for (const exe of pending) exe.updateRc(+outputs.length);
|
|
2856
|
+
pending.push(new PendingExecute(backend, routine, inputs, outputs));
|
|
2857
|
+
pending[pending.length - 1].updateRc(+outputs.length - 1);
|
|
2858
|
+
arrays.forEach((ar) => ar.dispose());
|
|
2859
|
+
return outputs.map((output, i) => new Array$1({
|
|
2860
|
+
source: output,
|
|
2861
|
+
st: require_backend.ShapeTracker.fromShape(avalsOut[i].shape),
|
|
2862
|
+
dtype: avalsOut[i].dtype,
|
|
2863
|
+
weakType: avalsOut[i].weakType,
|
|
2864
|
+
backend,
|
|
2865
|
+
committed,
|
|
2866
|
+
pending
|
|
2867
|
+
}));
|
|
2868
|
+
};
|
|
2834
2869
|
}
|
|
2835
2870
|
/**
|
|
2836
2871
|
* Normalizes this array into one backed by a `Slot`.
|
|
@@ -3164,65 +3199,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3164
3199
|
[Primitive.Pad]([x], { width }) {
|
|
3165
3200
|
return [x.#reshape(x.#st.pad(width))];
|
|
3166
3201
|
},
|
|
3167
|
-
[Primitive.Sort](
|
|
3168
|
-
|
|
3169
|
-
|
|
3170
|
-
|
|
3171
|
-
|
|
3172
|
-
outputDtypes: [x.dtype]
|
|
3173
|
-
});
|
|
3174
|
-
return Array$1.#routine(routine, [x], [x.#weakType]);
|
|
3175
|
-
},
|
|
3176
|
-
[Primitive.Argsort]([x]) {
|
|
3177
|
-
const routine = new require_backend.Routine(require_backend.Routines.Argsort, {
|
|
3178
|
-
inputShapes: [x.shape],
|
|
3179
|
-
inputDtypes: [x.dtype],
|
|
3180
|
-
outputShapes: [x.shape, x.shape],
|
|
3181
|
-
outputDtypes: [x.dtype, require_backend.DType.Int32]
|
|
3182
|
-
});
|
|
3183
|
-
return Array$1.#routine(routine, [x], [x.#weakType, false]);
|
|
3184
|
-
},
|
|
3185
|
-
[Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
|
|
3186
|
-
const routine = new require_backend.Routine(require_backend.Routines.TriangularSolve, {
|
|
3187
|
-
inputShapes: [a.shape, b.shape],
|
|
3188
|
-
inputDtypes: [a.dtype, b.dtype],
|
|
3189
|
-
outputShapes: [b.shape],
|
|
3190
|
-
outputDtypes: [b.dtype]
|
|
3191
|
-
}, { unitDiagonal });
|
|
3192
|
-
return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
|
|
3193
|
-
},
|
|
3194
|
-
[Primitive.Cholesky]([a]) {
|
|
3195
|
-
const routine = new require_backend.Routine(require_backend.Routines.Cholesky, {
|
|
3196
|
-
inputShapes: [a.shape],
|
|
3197
|
-
inputDtypes: [a.dtype],
|
|
3198
|
-
outputShapes: [a.shape],
|
|
3199
|
-
outputDtypes: [a.dtype]
|
|
3200
|
-
});
|
|
3201
|
-
return Array$1.#routine(routine, [a], [a.#weakType]);
|
|
3202
|
-
},
|
|
3203
|
-
[Primitive.LU]([a]) {
|
|
3204
|
-
const batch = a.shape.slice(0, -2);
|
|
3205
|
-
const [m, n] = a.shape.slice(-2);
|
|
3206
|
-
const routine = new require_backend.Routine(require_backend.Routines.LU, {
|
|
3207
|
-
inputShapes: [a.shape],
|
|
3208
|
-
inputDtypes: [a.dtype],
|
|
3209
|
-
outputShapes: [
|
|
3210
|
-
a.shape,
|
|
3211
|
-
[...batch, Math.min(m, n)],
|
|
3212
|
-
[...batch, m]
|
|
3213
|
-
],
|
|
3214
|
-
outputDtypes: [
|
|
3215
|
-
a.dtype,
|
|
3216
|
-
require_backend.DType.Int32,
|
|
3217
|
-
require_backend.DType.Int32
|
|
3218
|
-
]
|
|
3219
|
-
});
|
|
3220
|
-
return Array$1.#routine(routine, [a], [
|
|
3221
|
-
a.#weakType,
|
|
3222
|
-
false,
|
|
3223
|
-
false
|
|
3224
|
-
]);
|
|
3225
|
-
},
|
|
3202
|
+
[Primitive.Sort]: Array$1.#routine(Primitive.Sort),
|
|
3203
|
+
[Primitive.Argsort]: Array$1.#routine(Primitive.Argsort),
|
|
3204
|
+
[Primitive.TriangularSolve]: Array$1.#routine(Primitive.TriangularSolve),
|
|
3205
|
+
[Primitive.Cholesky]: Array$1.#routine(Primitive.Cholesky),
|
|
3206
|
+
[Primitive.LU]: Array$1.#routine(Primitive.LU),
|
|
3226
3207
|
[Primitive.Jit](args, { jaxpr }) {
|
|
3227
3208
|
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
3228
3209
|
const { backend, committed } = Array$1.#computeBackend("jit", args);
|
|
@@ -3304,7 +3285,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
3304
3285
|
if (!shape$1) {
|
|
3305
3286
|
shape$1 = [];
|
|
3306
3287
|
let cur = values;
|
|
3307
|
-
while (JsArray.isArray(cur)) {
|
|
3288
|
+
while (JsArray$1.isArray(cur)) {
|
|
3308
3289
|
shape$1.push(cur.length);
|
|
3309
3290
|
cur = cur[0];
|
|
3310
3291
|
}
|
|
@@ -4260,17 +4241,39 @@ function jvpFlat(f, primals, tangents) {
|
|
|
4260
4241
|
_usingCtx$1.d();
|
|
4261
4242
|
}
|
|
4262
4243
|
}
|
|
4263
|
-
function jvp$1(f, primals, tangents) {
|
|
4244
|
+
function jvp$1(f, primals, tangents, { hasAux = false } = {}) {
|
|
4264
4245
|
const [primalsFlat, inTree] = flatten(primals);
|
|
4265
4246
|
const [tangentsFlat, inTree2] = flatten(tangents);
|
|
4266
4247
|
if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
|
|
4267
|
-
|
|
4248
|
+
let flatFun, outTree, aux;
|
|
4249
|
+
if (hasAux) [flatFun, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4250
|
+
else [flatFun, outTree] = flattenFun(f, inTree);
|
|
4268
4251
|
const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
|
|
4269
4252
|
if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
|
|
4270
4253
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
4271
4254
|
const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
|
|
4255
|
+
if (hasAux) return [
|
|
4256
|
+
primalsOut,
|
|
4257
|
+
tangentsOut,
|
|
4258
|
+
lowerAux(aux.value)
|
|
4259
|
+
];
|
|
4272
4260
|
return [primalsOut, tangentsOut];
|
|
4273
4261
|
}
|
|
4262
|
+
/** Lowering for auxiliary data returned in `hasAux: true` methods. */
|
|
4263
|
+
function lowerAux(aux) {
|
|
4264
|
+
const level = currentTraceLevel();
|
|
4265
|
+
return map((x) => {
|
|
4266
|
+
if (x instanceof Tracer) while (x._trace.main.level > level) if (x instanceof JVPTracer) {
|
|
4267
|
+
x.tangent.dispose();
|
|
4268
|
+
x = x.primal;
|
|
4269
|
+
} else {
|
|
4270
|
+
const y = x.fullLower();
|
|
4271
|
+
if (y._trace.main.level >= x._trace.main.level) throw new Error("internal: lowerAux did not reduce trace level");
|
|
4272
|
+
x = y;
|
|
4273
|
+
}
|
|
4274
|
+
return x;
|
|
4275
|
+
}, aux);
|
|
4276
|
+
}
|
|
4274
4277
|
|
|
4275
4278
|
//#endregion
|
|
4276
4279
|
//#region src/frontend/linearize.ts
|
|
@@ -4341,9 +4344,11 @@ function linearizeFlat(f, primalsIn) {
|
|
|
4341
4344
|
dispose$1
|
|
4342
4345
|
];
|
|
4343
4346
|
}
|
|
4344
|
-
function linearize$1(f, ...primalsIn) {
|
|
4347
|
+
function linearize$1(f, primalsIn, { hasAux = false } = {}) {
|
|
4345
4348
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4346
|
-
|
|
4349
|
+
let fFlat, outTree, aux;
|
|
4350
|
+
if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4351
|
+
else [fFlat, outTree] = flattenFun(f, inTree);
|
|
4347
4352
|
const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4348
4353
|
if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
|
|
4349
4354
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
@@ -4354,6 +4359,11 @@ function linearize$1(f, ...primalsIn) {
|
|
|
4354
4359
|
return unflatten(outTree.value, tangentsOutFlat);
|
|
4355
4360
|
});
|
|
4356
4361
|
fLin.dispose = dispose$1;
|
|
4362
|
+
if (hasAux) return [
|
|
4363
|
+
primalsOut,
|
|
4364
|
+
fLin,
|
|
4365
|
+
lowerAux(aux.value)
|
|
4366
|
+
];
|
|
4357
4367
|
return [primalsOut, fLin];
|
|
4358
4368
|
}
|
|
4359
4369
|
var PartialEvalTracer = class extends Tracer {
|
|
@@ -4854,9 +4864,11 @@ function vjpFlat(f, primalsIn) {
|
|
|
4854
4864
|
dispose$1
|
|
4855
4865
|
];
|
|
4856
4866
|
}
|
|
4857
|
-
function vjp$1(f, ...primalsIn) {
|
|
4867
|
+
function vjp$1(f, primalsIn, { hasAux = false } = {}) {
|
|
4858
4868
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4859
|
-
|
|
4869
|
+
let fFlat, outTree, aux;
|
|
4870
|
+
if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
|
|
4871
|
+
else [fFlat, outTree] = flattenFun(f, inTree);
|
|
4860
4872
|
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4861
4873
|
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
4862
4874
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
@@ -4867,26 +4879,43 @@ function vjp$1(f, ...primalsIn) {
|
|
|
4867
4879
|
return unflatten(inTree, cotangentsInFlat);
|
|
4868
4880
|
});
|
|
4869
4881
|
fVjp.dispose = dispose$1;
|
|
4882
|
+
if (hasAux) return [
|
|
4883
|
+
primalsOut,
|
|
4884
|
+
fVjp,
|
|
4885
|
+
lowerAux(aux.value)
|
|
4886
|
+
];
|
|
4870
4887
|
return [primalsOut, fVjp];
|
|
4871
4888
|
}
|
|
4872
|
-
function grad$1(f) {
|
|
4873
|
-
const valueAndGradFn = valueAndGrad$1(f);
|
|
4889
|
+
function grad$1(f, opts) {
|
|
4890
|
+
const valueAndGradFn = valueAndGrad$1(f, opts);
|
|
4874
4891
|
return (...x) => {
|
|
4875
|
-
|
|
4876
|
-
|
|
4877
|
-
|
|
4892
|
+
if (opts?.hasAux) {
|
|
4893
|
+
const [[y, aux], dx] = valueAndGradFn(...x);
|
|
4894
|
+
y.dispose();
|
|
4895
|
+
return [dx, aux];
|
|
4896
|
+
} else {
|
|
4897
|
+
const [y, dx] = valueAndGradFn(...x);
|
|
4898
|
+
y.dispose();
|
|
4899
|
+
return dx;
|
|
4900
|
+
}
|
|
4878
4901
|
};
|
|
4879
4902
|
}
|
|
4880
|
-
function valueAndGrad$1(f) {
|
|
4903
|
+
function valueAndGrad$1(f, opts) {
|
|
4904
|
+
const argnums = opts?.argnums ?? 0;
|
|
4905
|
+
const hasAux = opts?.hasAux ?? false;
|
|
4906
|
+
require_backend.checkInts(argnums);
|
|
4907
|
+
const argnumsSet = new Set(typeof argnums === "number" ? [argnums] : argnums);
|
|
4881
4908
|
return (...x) => {
|
|
4882
4909
|
if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
|
|
4883
|
-
|
|
4910
|
+
for (let i = 0; i < x.length; i++) if (!argnumsSet.has(i)) x[i] = map(stopGradient, x[i]);
|
|
4911
|
+
const [y, fVjp, aux] = vjp$1(f, x, { hasAux });
|
|
4884
4912
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
4885
4913
|
if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
4886
|
-
const
|
|
4887
|
-
for (const r of rest) dispose(r);
|
|
4914
|
+
const cts = fVjp(onesLike$1(y.ref));
|
|
4888
4915
|
fVjp.dispose();
|
|
4889
|
-
|
|
4916
|
+
for (let i = 0; i < cts.length; i++) if (!argnumsSet.has(i)) dispose(cts[i]);
|
|
4917
|
+
const grads = typeof argnums === "number" ? cts[argnums] : argnums.map((i) => cts[i]);
|
|
4918
|
+
return hasAux ? [[y, aux], grads] : [y, grads];
|
|
4890
4919
|
};
|
|
4891
4920
|
}
|
|
4892
4921
|
function jacrev$1(f) {
|
|
@@ -4894,7 +4923,7 @@ function jacrev$1(f) {
|
|
|
4894
4923
|
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
4895
4924
|
const [size$1] = x.shape;
|
|
4896
4925
|
const pullback = (ct) => {
|
|
4897
|
-
const [y, fVjp] = vjp$1(f, x);
|
|
4926
|
+
const [y, fVjp] = vjp$1(f, [x]);
|
|
4898
4927
|
y.dispose();
|
|
4899
4928
|
const [ret] = fVjp(ct);
|
|
4900
4929
|
fVjp.dispose();
|
|
@@ -4903,6 +4932,9 @@ function jacrev$1(f) {
|
|
|
4903
4932
|
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
4904
4933
|
};
|
|
4905
4934
|
}
|
|
4935
|
+
function hessian$1(f) {
|
|
4936
|
+
return jacfwd$1(grad$1(f));
|
|
4937
|
+
}
|
|
4906
4938
|
|
|
4907
4939
|
//#endregion
|
|
4908
4940
|
//#region src/library/numpy/einsum.ts
|
|
@@ -5612,6 +5644,7 @@ __export(numpy_exports, {
|
|
|
5612
5644
|
std: () => std,
|
|
5613
5645
|
subtract: () => subtract,
|
|
5614
5646
|
sum: () => sum,
|
|
5647
|
+
swapaxes: () => swapaxes,
|
|
5615
5648
|
take: () => take,
|
|
5616
5649
|
tan: () => tan,
|
|
5617
5650
|
tanh: () => tanh,
|
|
@@ -6010,6 +6043,17 @@ function flipud(x) {
|
|
|
6010
6043
|
function fliplr(x) {
|
|
6011
6044
|
return flip(x, 1);
|
|
6012
6045
|
}
|
|
6046
|
+
/** Interchange two axes of an array. */
|
|
6047
|
+
function swapaxes(a, axis1, axis2) {
|
|
6048
|
+
a = fudgeArray(a);
|
|
6049
|
+
axis1 = require_backend.checkAxis(axis1, a.ndim);
|
|
6050
|
+
axis2 = require_backend.checkAxis(axis2, a.ndim);
|
|
6051
|
+
if (axis1 === axis2) return a;
|
|
6052
|
+
const perm = require_backend.range(a.ndim);
|
|
6053
|
+
perm[axis1] = axis2;
|
|
6054
|
+
perm[axis2] = axis1;
|
|
6055
|
+
return transpose(a, perm);
|
|
6056
|
+
}
|
|
6013
6057
|
/** Transpose the last two dimensions of an array. */
|
|
6014
6058
|
function matrixTranspose(a) {
|
|
6015
6059
|
if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
|
|
@@ -6938,6 +6982,7 @@ var lax_exports = {};
|
|
|
6938
6982
|
__export(lax_exports, {
|
|
6939
6983
|
conv: () => conv,
|
|
6940
6984
|
convGeneralDilated: () => convGeneralDilated,
|
|
6985
|
+
convTranspose: () => convTranspose,
|
|
6941
6986
|
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
6942
6987
|
dot: () => dot,
|
|
6943
6988
|
erf: () => erf,
|
|
@@ -6946,6 +6991,7 @@ __export(lax_exports, {
|
|
|
6946
6991
|
reduceWindow: () => reduceWindow,
|
|
6947
6992
|
stopGradient: () => stopGradient$1
|
|
6948
6993
|
});
|
|
6994
|
+
const JsArray = globalThis.Array;
|
|
6949
6995
|
/**
|
|
6950
6996
|
* General dot product/contraction operator.
|
|
6951
6997
|
*
|
|
@@ -7017,7 +7063,11 @@ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
|
7017
7063
|
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
7018
7064
|
* function in JAX, which wraps XLA's general convolution operator.
|
|
7019
7065
|
*
|
|
7020
|
-
*
|
|
7066
|
+
* @param lhs - Input tensor; shape `[N, C_in, ...xs]`
|
|
7067
|
+
* @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
|
|
7068
|
+
* @param windowStrides - Strides for each spatial dimension
|
|
7069
|
+
* @param padding - Padding for each spatial dimension, or a string
|
|
7070
|
+
* (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
|
|
7021
7071
|
*/
|
|
7022
7072
|
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
|
|
7023
7073
|
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
@@ -7077,6 +7127,60 @@ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, r
|
|
|
7077
7127
|
function conv(lhs, rhs, windowStrides, padding) {
|
|
7078
7128
|
return convGeneralDilated(lhs, rhs, windowStrides, padding);
|
|
7079
7129
|
}
|
|
7130
|
+
/**
|
|
7131
|
+
* Convenience wrapper for calculating the N-d convolution "transpose".
|
|
7132
|
+
*
|
|
7133
|
+
* This function directly calculates a fractionally strided conv rather than
|
|
7134
|
+
* indirectly calculating the gradient (transpose) of a forward convolution.
|
|
7135
|
+
* It is equivalent to the JAX version, except:
|
|
7136
|
+
*
|
|
7137
|
+
* - The `use_consistent_padding` option is not available. We only have the
|
|
7138
|
+
* consistent padding case (JAX version >0.8.4).
|
|
7139
|
+
* - The order of dimensions matches `lax.conv_general_dilated`.
|
|
7140
|
+
*
|
|
7141
|
+
* Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
|
|
7142
|
+
* dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
|
|
7143
|
+
* `transposeKernel` to true.
|
|
7144
|
+
*
|
|
7145
|
+
* @param lhs - Input tensor; shape `[N, C_in, ...xs]`
|
|
7146
|
+
* @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
|
|
7147
|
+
* @param strides - Sequence of n integers, sets fractional stride
|
|
7148
|
+
* @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
|
|
7149
|
+
* each side of the input, so it acts like gradient of `conv()`
|
|
7150
|
+
* @param rhsDilation - Atrous dilation for the kernel
|
|
7151
|
+
* @param transposeKernel - Flip spatial axes and swap the input/output channels
|
|
7152
|
+
* of the kernel; its shape should be `[C_in, C_out, ...ks]`
|
|
7153
|
+
*/
|
|
7154
|
+
function convTranspose(lhs, rhs, strides, padding, { rhsDilation, transposeKernel = false } = {}) {
|
|
7155
|
+
const kernelShape = rhs.shape.slice(2);
|
|
7156
|
+
rhsDilation = rhsDilation ?? require_backend.rep(kernelShape.length, 1);
|
|
7157
|
+
const effectiveKernel = kernelShape.map((k, i) => Math.max(0, (k - 1) * rhsDilation[i] + 1));
|
|
7158
|
+
const pads = effectiveKernel.map((k, i) => convTransposePadding(k, strides[i], typeof padding === "string" ? padding : padding[i]));
|
|
7159
|
+
if (transposeKernel) {
|
|
7160
|
+
rhs = flip$1(rhs, require_backend.range(2, rhs.ndim));
|
|
7161
|
+
rhs = moveaxis(rhs, 0, 1);
|
|
7162
|
+
}
|
|
7163
|
+
return convGeneralDilated(lhs, rhs, require_backend.rep(lhs.ndim - 2, 1), pads, {
|
|
7164
|
+
lhsDilation: strides,
|
|
7165
|
+
rhsDilation
|
|
7166
|
+
});
|
|
7167
|
+
}
|
|
7168
|
+
function convTransposePadding(k, s, padding) {
|
|
7169
|
+
let padLen;
|
|
7170
|
+
let pad1;
|
|
7171
|
+
if (padding === "SAME") {
|
|
7172
|
+
padLen = k + s - 2;
|
|
7173
|
+
pad1 = s > k - 1 ? k - 1 : Math.ceil(padLen / 2);
|
|
7174
|
+
} else if (padding === "VALID") {
|
|
7175
|
+
padLen = k + s - 2 + Math.max(k - s, 0);
|
|
7176
|
+
pad1 = k - 1;
|
|
7177
|
+
} else if (JsArray.isArray(padding)) {
|
|
7178
|
+
const pads = [k - 1 - padding[0], k - 1 - padding[1]];
|
|
7179
|
+
pad1 = pads[0];
|
|
7180
|
+
padLen = pads[0] + pads[1];
|
|
7181
|
+
} else throw new Error(`convTranspose: Invalid padding type ${padding}`);
|
|
7182
|
+
return [pad1, padLen - pad1];
|
|
7183
|
+
}
|
|
7080
7184
|
/** Reduce a computation over padded windows. */
|
|
7081
7185
|
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
7082
7186
|
if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
|
|
@@ -7115,6 +7219,7 @@ function stopGradient$1(x) {
|
|
|
7115
7219
|
var nn_exports = {};
|
|
7116
7220
|
__export(nn_exports, {
|
|
7117
7221
|
celu: () => celu,
|
|
7222
|
+
dotProductAttention: () => dotProductAttention,
|
|
7118
7223
|
elu: () => elu,
|
|
7119
7224
|
gelu: () => gelu,
|
|
7120
7225
|
glu: () => glu,
|
|
@@ -7431,6 +7536,95 @@ function oneHot(x, numClasses) {
|
|
|
7431
7536
|
if (require_backend.isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
|
|
7432
7537
|
return eye(numClasses, void 0, { device: x.device }).slice(x);
|
|
7433
7538
|
}
|
|
7539
|
+
/**
|
|
7540
|
+
* Scaled dot product attention (SDPA).
|
|
7541
|
+
*
|
|
7542
|
+
* Computes `softmax((Q @ K^T) / sqrt(d) + bias) @ V`, where `Q` is the query,
|
|
7543
|
+
* `K` is the key, `V` is the value, and `d` is the dimensionality of each key
|
|
7544
|
+
* and query vector.
|
|
7545
|
+
*
|
|
7546
|
+
* Multi-query attention is applied when input `key` and `value` tensors have
|
|
7547
|
+
* fewer heads than `query`.
|
|
7548
|
+
*
|
|
7549
|
+
* We use the following uppercase letters to denote array shapes:
|
|
7550
|
+
* - `B` = batch size
|
|
7551
|
+
* - `S` = length of key/value sequences (source)
|
|
7552
|
+
* - `L` = length of query sequences
|
|
7553
|
+
* - `N` = number of attention heads
|
|
7554
|
+
* - `H` = dimensionality of each attention head
|
|
7555
|
+
* - `K` = number of key/value heads (for grouped-query attention)
|
|
7556
|
+
*
|
|
7557
|
+
* The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
|
|
7558
|
+
* case it must be omitted from all inputs.
|
|
7559
|
+
*
|
|
7560
|
+
* @param query - Query array; shape `[B, L, N, H]`
|
|
7561
|
+
* @param key - Key array; shape `[B, S, K, H]`
|
|
7562
|
+
* @param value - Value array; same shape as `key`
|
|
7563
|
+
* @param opts.bias - Optional bias to add to the attention logits; shape
|
|
7564
|
+
* `[B, N, L, S]` or broadcastable to it.
|
|
7565
|
+
* @param opts.mask - Optional mask to apply to the attention logits; should be
|
|
7566
|
+
* a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
|
|
7567
|
+
* the element should take part in attention.
|
|
7568
|
+
* @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
|
|
7569
|
+
* @param opts.isCausal - If true, applies a causal mask.
|
|
7570
|
+
* @param opts.querySeqLengths - Optional sequence lengths for the queries;
|
|
7571
|
+
* shape `(B,)`. Taken from the beginning of the tensor.
|
|
7572
|
+
* @param opts.keyValueSeqLengths - Optional sequence lengths for the keys and
|
|
7573
|
+
* values; shape `(B,)`. Taken from the beginning of the tensor.
|
|
7574
|
+
* @param opts.localWindowSize - If specified, applies a local attention window
|
|
7575
|
+
* of the given size. Can be a single number or a tuple `[left, right]`.
|
|
7576
|
+
*
|
|
7577
|
+
* @returns The result of the attention operation; shape is the same as query
|
|
7578
|
+
* `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
|
|
7579
|
+
*/
|
|
7580
|
+
function dotProductAttention(query, key$1, value, opts = {}) {
|
|
7581
|
+
if (opts.querySeqLengths !== void 0 || opts.keyValueSeqLengths !== void 0) throw new Error("Sequence length masking is not yet implemented");
|
|
7582
|
+
if (opts.localWindowSize !== void 0) throw new Error("Local attention is not yet implemented");
|
|
7583
|
+
query = fudgeArray(query);
|
|
7584
|
+
key$1 = fudgeArray(key$1);
|
|
7585
|
+
value = fudgeArray(value);
|
|
7586
|
+
if (query.ndim !== 3 && query.ndim !== 4 || query.ndim !== key$1.ndim || query.ndim !== value.ndim) throw new Error(`dotProductAttention: expected all tensors to have rank 3 or 4, got Q=${query.aval}, K=${key$1.aval}, V=${value.aval}`);
|
|
7587
|
+
if (!require_backend.deepEqual(key$1.shape, value.shape)) throw new Error(`dotProductAttention: key and value shapes must match, got K=${key$1.shape}, V=${value.shape}`);
|
|
7588
|
+
const isRank3 = query.ndim === 3;
|
|
7589
|
+
if (isRank3) {
|
|
7590
|
+
query = expandDims(query, 0);
|
|
7591
|
+
key$1 = expandDims(key$1, 0);
|
|
7592
|
+
value = expandDims(value, 0);
|
|
7593
|
+
}
|
|
7594
|
+
const [B, L, N, H] = query.shape;
|
|
7595
|
+
if (key$1.shape[0] !== B || key$1.shape[3] !== H) throw new Error(`dotProductAttention: query and key shapes mismatch, got Q=${query.aval}, K=${key$1.aval}`);
|
|
7596
|
+
const S = key$1.shape[1];
|
|
7597
|
+
const K = key$1.shape[2];
|
|
7598
|
+
if (N < K || N != K && N % K !== 0) throw new Error(`dotProductAttention: number of query heads N=${N} must be divisible by number of key/value heads K=${K} for GQA`);
|
|
7599
|
+
const G = N / K;
|
|
7600
|
+
key$1 = tile(key$1, [
|
|
7601
|
+
1,
|
|
7602
|
+
1,
|
|
7603
|
+
G,
|
|
7604
|
+
1
|
|
7605
|
+
]);
|
|
7606
|
+
value = tile(value, [
|
|
7607
|
+
1,
|
|
7608
|
+
1,
|
|
7609
|
+
G,
|
|
7610
|
+
1
|
|
7611
|
+
]);
|
|
7612
|
+
const scale = opts.scale ?? 1 / Math.sqrt(H);
|
|
7613
|
+
let scores = einsum("BLNH,BSNH->BNLS", query, key$1).mul(scale);
|
|
7614
|
+
if (opts.bias !== void 0) scores = scores.add(opts.bias);
|
|
7615
|
+
if (opts.mask !== void 0) scores = where(opts.mask, scores, -Infinity);
|
|
7616
|
+
if (opts.isCausal) {
|
|
7617
|
+
const causalMask = tri(L, S, 0, { dtype: require_backend.DType.Bool });
|
|
7618
|
+
scores = where(causalMask, scores, -Infinity);
|
|
7619
|
+
}
|
|
7620
|
+
const attn = softmax(scores, -1);
|
|
7621
|
+
const out = einsum("BNLS,BSNH->BLNH", attn, value);
|
|
7622
|
+
return isRank3 ? out.reshape([
|
|
7623
|
+
L,
|
|
7624
|
+
N,
|
|
7625
|
+
H
|
|
7626
|
+
]) : out;
|
|
7627
|
+
}
|
|
7434
7628
|
|
|
7435
7629
|
//#endregion
|
|
7436
7630
|
//#region src/library/random.ts
|
|
@@ -7666,17 +7860,62 @@ const linearize = linearize$1;
|
|
|
7666
7860
|
/**
|
|
7667
7861
|
* @function
|
|
7668
7862
|
* Calculate the reverse-mode vector-Jacobian product for a function.
|
|
7863
|
+
*
|
|
7864
|
+
* The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
|
|
7865
|
+
* `f(primals)`, and `vjpFn` is a function that takes in cotangents for each
|
|
7866
|
+
* output and returns the cotangents for each input.
|
|
7867
|
+
*
|
|
7868
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
|
|
7869
|
+
* `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
|
|
7870
|
+
*
|
|
7871
|
+
* @example
|
|
7872
|
+
* ```ts
|
|
7873
|
+
* const [y, vjpFn] = vjp(f, [x]);
|
|
7874
|
+
*
|
|
7875
|
+
* // With hasAux
|
|
7876
|
+
* const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
|
|
7877
|
+
* ```
|
|
7669
7878
|
*/
|
|
7670
7879
|
const vjp = vjp$1;
|
|
7671
7880
|
/**
|
|
7672
7881
|
* @function
|
|
7673
7882
|
* Compute the gradient of a scalar-valued function `f` with respect to its
|
|
7674
7883
|
* first argument.
|
|
7884
|
+
*
|
|
7885
|
+
* Pass in different `argnums` to differentiate with respect to other
|
|
7886
|
+
* arguments. If a tuple is provided, the return value will be a tuple of
|
|
7887
|
+
* gradients corresponding to each argument index.
|
|
7888
|
+
*
|
|
7889
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
|
|
7890
|
+
* `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
|
|
7891
|
+
*
|
|
7892
|
+
* @example
|
|
7893
|
+
* ```ts
|
|
7894
|
+
* const gradient = grad(f)(x);
|
|
7895
|
+
*
|
|
7896
|
+
* // With `argnums`
|
|
7897
|
+
* const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
|
|
7898
|
+
*
|
|
7899
|
+
* // With `hasAux`
|
|
7900
|
+
* const [gradient, aux] = grad(f, { hasAux: true })(x);
|
|
7901
|
+
* ```
|
|
7675
7902
|
*/
|
|
7676
7903
|
const grad = grad$1;
|
|
7677
7904
|
/**
|
|
7678
7905
|
* @function
|
|
7679
7906
|
* Create a function that evaluates both `f` and the gradient of `f`.
|
|
7907
|
+
*
|
|
7908
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
|
|
7909
|
+
* `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
|
|
7910
|
+
*
|
|
7911
|
+
* @example
|
|
7912
|
+
* ```ts
|
|
7913
|
+
* // Without hasAux
|
|
7914
|
+
* const [value, gradient] = valueAndGrad(f)(x);
|
|
7915
|
+
*
|
|
7916
|
+
* // With hasAux
|
|
7917
|
+
* const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
|
|
7918
|
+
* ```
|
|
7680
7919
|
*/
|
|
7681
7920
|
const valueAndGrad = valueAndGrad$1;
|
|
7682
7921
|
/**
|
|
@@ -7685,6 +7924,21 @@ const valueAndGrad = valueAndGrad$1;
|
|
|
7685
7924
|
*/
|
|
7686
7925
|
const jacrev = jacrev$1;
|
|
7687
7926
|
/**
|
|
7927
|
+
* @function
|
|
7928
|
+
* Compute the Hessian matrix of a scalar-valued function.
|
|
7929
|
+
*
|
|
7930
|
+
* The Hessian is the matrix of second-order partial derivatives of a function.
|
|
7931
|
+
* This is implemented as `jacfwd(grad(f))`.
|
|
7932
|
+
*
|
|
7933
|
+
* @example
|
|
7934
|
+
* ```ts
|
|
7935
|
+
* const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // sum of x_i^3 (scalar)
|
|
7936
|
+
* const H = hessian(f)(np.array([1, 2, 3]));
|
|
7937
|
+
* // H[i,j] = d^2f / dx_i dx_j
|
|
7938
|
+
* ```
|
|
7939
|
+
*/
|
|
7940
|
+
const hessian = hessian$1;
|
|
7941
|
+
/**
|
|
7688
7942
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
7689
7943
|
*
|
|
7690
7944
|
* This can be used to wait for the results of an intermediate computation to
|
|
@@ -7728,6 +7982,7 @@ exports.defaultDevice = require_backend.defaultDevice;
|
|
|
7728
7982
|
exports.devicePut = devicePut;
|
|
7729
7983
|
exports.devices = require_backend.devices;
|
|
7730
7984
|
exports.grad = grad;
|
|
7985
|
+
exports.hessian = hessian;
|
|
7731
7986
|
exports.init = require_backend.init;
|
|
7732
7987
|
exports.jacfwd = jacfwd;
|
|
7733
7988
|
exports.jacobian = jacrev;
|