@jax-js/jax 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DaqL-MNz.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-nEolvdLv.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -209,7 +209,7 @@ __export(tree_exports, {
209
209
  structure: () => structure,
210
210
  unflatten: () => unflatten
211
211
  });
212
- const JsArray$1 = globalThis.Array;
212
+ const JsArray$2 = globalThis.Array;
213
213
  let NodeType = /* @__PURE__ */ function(NodeType$1) {
214
214
  NodeType$1["Array"] = "Array";
215
215
  NodeType$1["Object"] = "Object";
@@ -257,7 +257,7 @@ function flatten(tree) {
257
257
  return [leaves$1, treedef];
258
258
  }
259
259
  function _flatten(tree, leaves$1) {
260
- if (JsArray$1.isArray(tree)) {
260
+ if (JsArray$2.isArray(tree)) {
261
261
  const childTrees = tree.map((c) => _flatten(c, leaves$1));
262
262
  return new JsTreeDef(NodeType.Array, null, childTrees);
263
263
  } else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
@@ -306,11 +306,11 @@ function map(fn, tree, ...rest) {
306
306
  }
307
307
  /** Take a reference of every array in a tree. */
308
308
  function ref(tree) {
309
- return map((x) => x.ref, tree);
309
+ return map((x) => x instanceof Tracer ? x.ref : x, tree);
310
310
  }
311
311
  /** Dispose every array in a tree. */
312
312
  function dispose(tree) {
313
- if (tree) map((x) => x.dispose(), tree);
313
+ if (tree) map((x) => x instanceof Tracer ? x.dispose() : void 0, tree);
314
314
  }
315
315
 
316
316
  //#endregion
@@ -381,6 +381,13 @@ let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
381
381
  CompareOp$1["LessEqual"] = "less_equal";
382
382
  return CompareOp$1;
383
383
  }({});
384
+ const routinePrimitives = new Map([
385
+ [Primitive.Sort, Routines.Sort],
386
+ [Primitive.Argsort, Routines.Argsort],
387
+ [Primitive.TriangularSolve, Routines.TriangularSolve],
388
+ [Primitive.Cholesky, Routines.Cholesky],
389
+ [Primitive.LU, Routines.LU]
390
+ ]);
384
391
  function add$1(x, y) {
385
392
  return bind1(Primitive.Add, [x, y]);
386
393
  }
@@ -577,14 +584,20 @@ function shrink(x, slice) {
577
584
  }
578
585
  function pad$1(x, width) {
579
586
  const nd = ndim$1(x);
580
- if (typeof width === "number") width = [[width, width]];
581
- else if (isNumberPair(width)) width = [width];
582
- else if (!Array.isArray(width) || !width.every(isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
583
- if (width.length === 1) {
584
- const [w0, w1] = width[0];
585
- width = rep(nd, () => [w0, w1]);
586
- } else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
587
- return bind1(Primitive.Pad, [x], { width });
587
+ let w;
588
+ if (typeof width === "number") w = [[width, width]];
589
+ else if (isNumberPair(width)) w = [width];
590
+ else if (!Array.isArray(width)) {
591
+ const indicesAndPairs = Object.entries(width);
592
+ w = rep(nd, [0, 0]);
593
+ for (const [k, v] of indicesAndPairs) w[checkAxis(parseInt(k), nd)] = v;
594
+ } else if (!width.every(isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
595
+ else w = width;
596
+ if (w.length === 1) {
597
+ const [w0, w1] = w[0];
598
+ w = rep(nd, () => [w0, w1]);
599
+ } else if (w.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${w.length}`);
600
+ return bind1(Primitive.Pad, [x], { width: w });
588
601
  }
589
602
  function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
590
603
  const as = getShape(a);
@@ -654,6 +667,9 @@ function newDynamic(main) {
654
667
  dynamicTrace = prevDynamicTrace;
655
668
  } };
656
669
  }
670
+ function currentTraceLevel() {
671
+ return traceStack[traceStack.length - 1].level;
672
+ }
657
673
  var Trace = class {
658
674
  constructor(main) {
659
675
  this.main = main;
@@ -757,6 +773,22 @@ var Tracer = class Tracer {
757
773
  const result = reduce(this.astype(castDtype), AluOp.Add, axis, opts);
758
774
  return result.mul(1 / n).astype(originalDtype);
759
775
  }
776
+ /** Minimum of the elements of the array along a given axis. */
777
+ min(axis = null, opts) {
778
+ return reduce(this, AluOp.Min, axis, opts);
779
+ }
780
+ /** Maximum of the elements of the array along a given axis. */
781
+ max(axis = null, opts) {
782
+ return reduce(this, AluOp.Max, axis, opts);
783
+ }
784
+ /** Test whether all array elements along a given axis evaluate to true. */
785
+ all(axis = null, opts) {
786
+ return this.astype(DType.Bool).min(axis, opts);
787
+ }
788
+ /** Test whether any array element along a given axis evaluates to true. */
789
+ any(axis = null, opts) {
790
+ return this.astype(DType.Bool).max(axis, opts);
791
+ }
760
792
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
761
793
  transpose(perm) {
762
794
  return transpose$1(this, perm);
@@ -1031,6 +1063,7 @@ var TreeMismatchError = class extends TypeError {
1031
1063
  super(`Mismatched tree structures in ${where$2}: ${left} != ${right}`);
1032
1064
  }
1033
1065
  };
1066
+ /** Flatten a function of `JsTree` input/output for use in tracing. */
1034
1067
  function flattenFun(f, inTree) {
1035
1068
  const store = { value: void 0 };
1036
1069
  const flatFun = (...argsFlat) => {
@@ -1042,6 +1075,26 @@ function flattenFun(f, inTree) {
1042
1075
  };
1043
1076
  return [flatFun, store];
1044
1077
  }
1078
+ /** Like flattenFun, but expects f to return [main, aux] tuple. */
1079
+ function flattenFunWithAux(f, inTree) {
1080
+ const store = { value: void 0 };
1081
+ const auxStore = { value: void 0 };
1082
+ const flatFun = (...argsFlat) => {
1083
+ const pytreeArgs = unflatten(inTree, argsFlat);
1084
+ const result = f(...pytreeArgs);
1085
+ if (!Array.isArray(result) || result.length !== 2) throw new Error("Function with `hasAux: true` must return [output, aux] tuple");
1086
+ const [out, aux] = result;
1087
+ const [outFlat, outTree] = flatten(out);
1088
+ store.value = outTree;
1089
+ auxStore.value = aux;
1090
+ return outFlat;
1091
+ };
1092
+ return [
1093
+ flatFun,
1094
+ store,
1095
+ auxStore
1096
+ ];
1097
+ }
1045
1098
  var UseAfterFreeError = class extends ReferenceError {
1046
1099
  constructor(tracer) {
1047
1100
  super(`Referenced tracer ${tracer.toString()} freed, please use .ref move semantics`);
@@ -1771,13 +1824,6 @@ function jit$1(f, opts) {
1771
1824
 
1772
1825
  //#endregion
1773
1826
  //#region src/frontend/jit.ts
1774
- const routinePrimitives = new Map([
1775
- [Primitive.Sort, Routines.Sort],
1776
- [Primitive.Argsort, Routines.Argsort],
1777
- [Primitive.TriangularSolve, Routines.TriangularSolve],
1778
- [Primitive.Cholesky, Routines.Cholesky],
1779
- [Primitive.LU, Routines.LU]
1780
- ]);
1781
1827
  /** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
1782
1828
  var JitProgram = class {
1783
1829
  constructor(backend, steps, inputs, outputs) {
@@ -2166,12 +2212,13 @@ const jitRules = {
2166
2212
  const ndim$2 = avals[0].ndim;
2167
2213
  const sizes = avals.map((x) => x.shape[axis]);
2168
2214
  const finalSize = sizes.reduce((a, b) => a + b, 0);
2215
+ const { dtype: dtypeOut } = avals.map((x) => x.scalar()).reduce(promoteAvals);
2169
2216
  const makePadAxis = (start, end) => range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
2170
2217
  let cum = 0;
2171
2218
  const src = [];
2172
2219
  for (let i = 0; i < exps.length; i++) {
2173
2220
  const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
2174
- src.push(reshapeViews(exps[i], (st) => st.pad(padding)));
2221
+ src.push(reshapeViews(AluExp.cast(dtypeOut, exps[i]), (st) => st.pad(padding)));
2175
2222
  cum += sizes[i];
2176
2223
  }
2177
2224
  return { exp: [src.reduce(AluExp.add)] };
@@ -2309,7 +2356,7 @@ function splitGraphDataflow(backend, jaxpr) {
2309
2356
  p1NextBlack.set(v, v);
2310
2357
  }
2311
2358
  const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
2312
- const needsCleanShapePrimitives = [Primitive.Pad];
2359
+ const needsCleanShapePrimitives = [Primitive.Concatenate, Primitive.Pad];
2313
2360
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
2314
2361
  const eqn = jaxpr.eqns[i];
2315
2362
  if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
@@ -2379,7 +2426,7 @@ function splitGraphDataflow(backend, jaxpr) {
2379
2426
 
2380
2427
  //#endregion
2381
2428
  //#region src/frontend/array.ts
2382
- const JsArray = globalThis.Array;
2429
+ const JsArray$1 = globalThis.Array;
2383
2430
  const inlineArrayLimit = 128;
2384
2431
  /** Version of pureArray with fudged types. */
2385
2432
  const fudgeArray = pureArray;
@@ -2777,25 +2824,35 @@ var Array$1 = class Array$1 extends Tracer {
2777
2824
  });
2778
2825
  }
2779
2826
  /** Apply an operation with custom lowering to this array. */
2780
- static #routine(routine, arrays, outputWeakType) {
2781
- const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
2782
- for (const ar of arrays) ar.#realize();
2783
- const inputs = arrays.map((ar) => ar.#source);
2784
- const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(byteWidth(dtype) * prod(routine.type.outputShapes[i])));
2785
- const pending = arrays.flatMap((ar) => ar.#pending);
2786
- for (const exe of pending) exe.updateRc(+outputs.length);
2787
- pending.push(new PendingExecute(backend, routine, inputs, outputs));
2788
- pending[pending.length - 1].updateRc(+outputs.length - 1);
2789
- arrays.forEach((ar) => ar.dispose());
2790
- return outputs.map((output, i) => new Array$1({
2791
- source: output,
2792
- st: ShapeTracker.fromShape(routine.type.outputShapes[i]),
2793
- dtype: routine.type.outputDtypes[i],
2794
- weakType: outputWeakType[i],
2795
- backend,
2796
- committed,
2797
- pending
2798
- }));
2827
+ static #routine(prim) {
2828
+ return (arrays, params) => {
2829
+ const { backend, committed } = Array$1.#computeBackend(prim, arrays);
2830
+ for (const ar of arrays) ar.#realize();
2831
+ const avals = arrays.map((ar) => ar.aval);
2832
+ const avalsOut = abstractEvalRules[prim](avals, params);
2833
+ const routine = new Routine(routinePrimitives.get(prim), {
2834
+ inputShapes: avals.map((a) => a.shape),
2835
+ inputDtypes: avals.map((a) => a.dtype),
2836
+ outputShapes: avalsOut.map((a) => a.shape),
2837
+ outputDtypes: avalsOut.map((a) => a.dtype)
2838
+ }, params);
2839
+ const inputs = arrays.map((ar) => ar.#source);
2840
+ const outputs = avalsOut.map((x) => backend.malloc(byteWidth(x.dtype) * x.size));
2841
+ const pending = arrays.flatMap((ar) => ar.#pending);
2842
+ for (const exe of pending) exe.updateRc(+outputs.length);
2843
+ pending.push(new PendingExecute(backend, routine, inputs, outputs));
2844
+ pending[pending.length - 1].updateRc(+outputs.length - 1);
2845
+ arrays.forEach((ar) => ar.dispose());
2846
+ return outputs.map((output, i) => new Array$1({
2847
+ source: output,
2848
+ st: ShapeTracker.fromShape(avalsOut[i].shape),
2849
+ dtype: avalsOut[i].dtype,
2850
+ weakType: avalsOut[i].weakType,
2851
+ backend,
2852
+ committed,
2853
+ pending
2854
+ }));
2855
+ };
2799
2856
  }
2800
2857
  /**
2801
2858
  * Normalizes this array into one backed by a `Slot`.
@@ -3129,65 +3186,11 @@ var Array$1 = class Array$1 extends Tracer {
3129
3186
  [Primitive.Pad]([x], { width }) {
3130
3187
  return [x.#reshape(x.#st.pad(width))];
3131
3188
  },
3132
- [Primitive.Sort]([x]) {
3133
- const routine = new Routine(Routines.Sort, {
3134
- inputShapes: [x.shape],
3135
- inputDtypes: [x.dtype],
3136
- outputShapes: [x.shape],
3137
- outputDtypes: [x.dtype]
3138
- });
3139
- return Array$1.#routine(routine, [x], [x.#weakType]);
3140
- },
3141
- [Primitive.Argsort]([x]) {
3142
- const routine = new Routine(Routines.Argsort, {
3143
- inputShapes: [x.shape],
3144
- inputDtypes: [x.dtype],
3145
- outputShapes: [x.shape, x.shape],
3146
- outputDtypes: [x.dtype, DType.Int32]
3147
- });
3148
- return Array$1.#routine(routine, [x], [x.#weakType, false]);
3149
- },
3150
- [Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
3151
- const routine = new Routine(Routines.TriangularSolve, {
3152
- inputShapes: [a.shape, b.shape],
3153
- inputDtypes: [a.dtype, b.dtype],
3154
- outputShapes: [b.shape],
3155
- outputDtypes: [b.dtype]
3156
- }, { unitDiagonal });
3157
- return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
3158
- },
3159
- [Primitive.Cholesky]([a]) {
3160
- const routine = new Routine(Routines.Cholesky, {
3161
- inputShapes: [a.shape],
3162
- inputDtypes: [a.dtype],
3163
- outputShapes: [a.shape],
3164
- outputDtypes: [a.dtype]
3165
- });
3166
- return Array$1.#routine(routine, [a], [a.#weakType]);
3167
- },
3168
- [Primitive.LU]([a]) {
3169
- const batch = a.shape.slice(0, -2);
3170
- const [m, n] = a.shape.slice(-2);
3171
- const routine = new Routine(Routines.LU, {
3172
- inputShapes: [a.shape],
3173
- inputDtypes: [a.dtype],
3174
- outputShapes: [
3175
- a.shape,
3176
- [...batch, Math.min(m, n)],
3177
- [...batch, m]
3178
- ],
3179
- outputDtypes: [
3180
- a.dtype,
3181
- DType.Int32,
3182
- DType.Int32
3183
- ]
3184
- });
3185
- return Array$1.#routine(routine, [a], [
3186
- a.#weakType,
3187
- false,
3188
- false
3189
- ]);
3190
- },
3189
+ [Primitive.Sort]: Array$1.#routine(Primitive.Sort),
3190
+ [Primitive.Argsort]: Array$1.#routine(Primitive.Argsort),
3191
+ [Primitive.TriangularSolve]: Array$1.#routine(Primitive.TriangularSolve),
3192
+ [Primitive.Cholesky]: Array$1.#routine(Primitive.Cholesky),
3193
+ [Primitive.LU]: Array$1.#routine(Primitive.LU),
3191
3194
  [Primitive.Jit](args, { jaxpr }) {
3192
3195
  if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
3193
3196
  const { backend, committed } = Array$1.#computeBackend("jit", args);
@@ -3269,7 +3272,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
3269
3272
  if (!shape$1) {
3270
3273
  shape$1 = [];
3271
3274
  let cur = values;
3272
- while (JsArray.isArray(cur)) {
3275
+ while (JsArray$1.isArray(cur)) {
3273
3276
  shape$1.push(cur.length);
3274
3277
  cur = cur[0];
3275
3278
  }
@@ -4223,17 +4226,39 @@ function jvpFlat(f, primals, tangents) {
4223
4226
  _usingCtx$1.d();
4224
4227
  }
4225
4228
  }
4226
- function jvp$1(f, primals, tangents) {
4229
+ function jvp$1(f, primals, tangents, { hasAux = false } = {}) {
4227
4230
  const [primalsFlat, inTree] = flatten(primals);
4228
4231
  const [tangentsFlat, inTree2] = flatten(tangents);
4229
4232
  if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
4230
- const [flatFun, outTree] = flattenFun(f, inTree);
4233
+ let flatFun, outTree, aux;
4234
+ if (hasAux) [flatFun, outTree, aux] = flattenFunWithAux(f, inTree);
4235
+ else [flatFun, outTree] = flattenFun(f, inTree);
4231
4236
  const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
4232
4237
  if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
4233
4238
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
4234
4239
  const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
4240
+ if (hasAux) return [
4241
+ primalsOut,
4242
+ tangentsOut,
4243
+ lowerAux(aux.value)
4244
+ ];
4235
4245
  return [primalsOut, tangentsOut];
4236
4246
  }
4247
+ /** Lowering for auxiliary data returned in `hasAux: true` methods. */
4248
+ function lowerAux(aux) {
4249
+ const level = currentTraceLevel();
4250
+ return map((x) => {
4251
+ if (x instanceof Tracer) while (x._trace.main.level > level) if (x instanceof JVPTracer) {
4252
+ x.tangent.dispose();
4253
+ x = x.primal;
4254
+ } else {
4255
+ const y = x.fullLower();
4256
+ if (y._trace.main.level >= x._trace.main.level) throw new Error("internal: lowerAux did not reduce trace level");
4257
+ x = y;
4258
+ }
4259
+ return x;
4260
+ }, aux);
4261
+ }
4237
4262
 
4238
4263
  //#endregion
4239
4264
  //#region src/frontend/linearize.ts
@@ -4304,9 +4329,11 @@ function linearizeFlat(f, primalsIn) {
4304
4329
  dispose$1
4305
4330
  ];
4306
4331
  }
4307
- function linearize$1(f, ...primalsIn) {
4332
+ function linearize$1(f, primalsIn, { hasAux = false } = {}) {
4308
4333
  const [primalsInFlat, inTree] = flatten(primalsIn);
4309
- const [fFlat, outTree] = flattenFun(f, inTree);
4334
+ let fFlat, outTree, aux;
4335
+ if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
4336
+ else [fFlat, outTree] = flattenFun(f, inTree);
4310
4337
  const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
4311
4338
  if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
4312
4339
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
@@ -4317,6 +4344,11 @@ function linearize$1(f, ...primalsIn) {
4317
4344
  return unflatten(outTree.value, tangentsOutFlat);
4318
4345
  });
4319
4346
  fLin.dispose = dispose$1;
4347
+ if (hasAux) return [
4348
+ primalsOut,
4349
+ fLin,
4350
+ lowerAux(aux.value)
4351
+ ];
4320
4352
  return [primalsOut, fLin];
4321
4353
  }
4322
4354
  var PartialEvalTracer = class extends Tracer {
@@ -4817,9 +4849,11 @@ function vjpFlat(f, primalsIn) {
4817
4849
  dispose$1
4818
4850
  ];
4819
4851
  }
4820
- function vjp$1(f, ...primalsIn) {
4852
+ function vjp$1(f, primalsIn, { hasAux = false } = {}) {
4821
4853
  const [primalsInFlat, inTree] = flatten(primalsIn);
4822
- const [fFlat, outTree] = flattenFun(f, inTree);
4854
+ let fFlat, outTree, aux;
4855
+ if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
4856
+ else [fFlat, outTree] = flattenFun(f, inTree);
4823
4857
  const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
4824
4858
  if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
4825
4859
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
@@ -4830,26 +4864,43 @@ function vjp$1(f, ...primalsIn) {
4830
4864
  return unflatten(inTree, cotangentsInFlat);
4831
4865
  });
4832
4866
  fVjp.dispose = dispose$1;
4867
+ if (hasAux) return [
4868
+ primalsOut,
4869
+ fVjp,
4870
+ lowerAux(aux.value)
4871
+ ];
4833
4872
  return [primalsOut, fVjp];
4834
4873
  }
4835
- function grad$1(f) {
4836
- const valueAndGradFn = valueAndGrad$1(f);
4874
+ function grad$1(f, opts) {
4875
+ const valueAndGradFn = valueAndGrad$1(f, opts);
4837
4876
  return (...x) => {
4838
- const [y, dx] = valueAndGradFn(...x);
4839
- y.dispose();
4840
- return dx;
4877
+ if (opts?.hasAux) {
4878
+ const [[y, aux], dx] = valueAndGradFn(...x);
4879
+ y.dispose();
4880
+ return [dx, aux];
4881
+ } else {
4882
+ const [y, dx] = valueAndGradFn(...x);
4883
+ y.dispose();
4884
+ return dx;
4885
+ }
4841
4886
  };
4842
4887
  }
4843
- function valueAndGrad$1(f) {
4888
+ function valueAndGrad$1(f, opts) {
4889
+ const argnums = opts?.argnums ?? 0;
4890
+ const hasAux = opts?.hasAux ?? false;
4891
+ checkInts(argnums);
4892
+ const argnumsSet = new Set(typeof argnums === "number" ? [argnums] : argnums);
4844
4893
  return (...x) => {
4845
4894
  if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
4846
- const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
4895
+ for (let i = 0; i < x.length; i++) if (!argnumsSet.has(i)) x[i] = map(stopGradient, x[i]);
4896
+ const [y, fVjp, aux] = vjp$1(f, x, { hasAux });
4847
4897
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
4848
4898
  if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
4849
- const [ct, ...rest] = fVjp(onesLike$1(y.ref));
4850
- for (const r of rest) dispose(r);
4899
+ const cts = fVjp(onesLike$1(y.ref));
4851
4900
  fVjp.dispose();
4852
- return [y, ct];
4901
+ for (let i = 0; i < cts.length; i++) if (!argnumsSet.has(i)) dispose(cts[i]);
4902
+ const grads = typeof argnums === "number" ? cts[argnums] : argnums.map((i) => cts[i]);
4903
+ return hasAux ? [[y, aux], grads] : [y, grads];
4853
4904
  };
4854
4905
  }
4855
4906
  function jacrev$1(f) {
@@ -4857,7 +4908,7 @@ function jacrev$1(f) {
4857
4908
  if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
4858
4909
  const [size$1] = x.shape;
4859
4910
  const pullback = (ct) => {
4860
- const [y, fVjp] = vjp$1(f, x);
4911
+ const [y, fVjp] = vjp$1(f, [x]);
4861
4912
  y.dispose();
4862
4913
  const [ret] = fVjp(ct);
4863
4914
  fVjp.dispose();
@@ -4866,6 +4917,9 @@ function jacrev$1(f) {
4866
4917
  return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
4867
4918
  };
4868
4919
  }
4920
+ function hessian$1(f) {
4921
+ return jacfwd$1(grad$1(f));
4922
+ }
4869
4923
 
4870
4924
  //#endregion
4871
4925
  //#region src/library/numpy/einsum.ts
@@ -5538,6 +5592,7 @@ __export(numpy_exports, {
5538
5592
  moveaxis: () => moveaxis$1,
5539
5593
  multiply: () => multiply,
5540
5594
  nan: () => nan,
5595
+ nanToNum: () => nanToNum,
5541
5596
  ndim: () => ndim,
5542
5597
  negative: () => negative,
5543
5598
  notEqual: () => notEqual,
@@ -5575,6 +5630,7 @@ __export(numpy_exports, {
5575
5630
  std: () => std,
5576
5631
  subtract: () => subtract,
5577
5632
  sum: () => sum,
5633
+ swapaxes: () => swapaxes,
5578
5634
  take: () => take,
5579
5635
  tan: () => tan,
5580
5636
  tanh: () => tanh,
@@ -5734,24 +5790,22 @@ function max(a, axis = null, opts) {
5734
5790
  return reduce(a, AluOp.Max, axis, opts);
5735
5791
  }
5736
5792
  /**
5737
- * Test whether all array elements along a given axis evaluate to True.
5793
+ * Test whether any array element along a given axis evaluates to True.
5738
5794
  *
5739
5795
  * Returns a boolean array with the same shape as `a` with the specified axis
5740
5796
  * removed. If axis is None, returns a scalar.
5741
5797
  */
5742
- function all(a, axis = null, opts) {
5743
- a = fudgeArray(a).astype(DType.Bool);
5744
- return min(a, axis, opts);
5798
+ function any(a, axis = null, opts) {
5799
+ return fudgeArray(a).any(axis, opts);
5745
5800
  }
5746
5801
  /**
5747
- * Test whether any array element along a given axis evaluates to True.
5802
+ * Test whether all array elements along a given axis evaluate to True.
5748
5803
  *
5749
5804
  * Returns a boolean array with the same shape as `a` with the specified axis
5750
5805
  * removed. If axis is None, returns a scalar.
5751
5806
  */
5752
- function any(a, axis = null, opts) {
5753
- a = fudgeArray(a).astype(DType.Bool);
5754
- return max(a, axis, opts);
5807
+ function all(a, axis = null, opts) {
5808
+ return fudgeArray(a).all(axis, opts);
5755
5809
  }
5756
5810
  /** Return the peak-to-peak range along a given axis (`max - min`). */
5757
5811
  function ptp(a, axis = null, opts) {
@@ -5852,7 +5906,7 @@ function split$1(a, indicesOrSections, axis = 0) {
5852
5906
  const partSize = size$1 / indicesOrSections;
5853
5907
  sizes = rep(indicesOrSections, partSize);
5854
5908
  } else {
5855
- const indices = indicesOrSections;
5909
+ const indices = indicesOrSections.map((i) => i < 0 ? i + size$1 : i);
5856
5910
  sizes = [indices[0]];
5857
5911
  for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
5858
5912
  sizes.push(size$1 - indices[indices.length - 1]);
@@ -5973,6 +6027,17 @@ function flipud(x) {
5973
6027
  function fliplr(x) {
5974
6028
  return flip(x, 1);
5975
6029
  }
6030
+ /** Interchange two axes of an array. */
6031
+ function swapaxes(a, axis1, axis2) {
6032
+ a = fudgeArray(a);
6033
+ axis1 = checkAxis(axis1, a.ndim);
6034
+ axis2 = checkAxis(axis2, a.ndim);
6035
+ if (axis1 === axis2) return a;
6036
+ const perm = range(a.ndim);
6037
+ perm[axis1] = axis2;
6038
+ perm[axis2] = axis1;
6039
+ return transpose(a, perm);
6040
+ }
5976
6041
  /** Transpose the last two dimensions of an array. */
5977
6042
  function matrixTranspose(a) {
5978
6043
  if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
@@ -6789,6 +6854,21 @@ function isposinf(x) {
6789
6854
  return isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
6790
6855
  }
6791
6856
  /**
6857
+ * Replace NaN and infinite entries in an array.
6858
+ *
6859
+ * By default, NaNs are replaced with `0.0`, and infinities are substituted
6860
+ * with the corresponding maximum or minimum finite values.
6861
+ */
6862
+ function nanToNum(x, { nan: nan$1 = 0, posinf = null, neginf = null } = {}) {
6863
+ x = fudgeArray(x);
6864
+ x = where(isnan(x.ref), nan$1, x);
6865
+ posinf ??= isFloatDtype(x.dtype) ? finfo(x.dtype).max : iinfo(x.dtype).max;
6866
+ neginf ??= isFloatDtype(x.dtype) ? finfo(x.dtype).min : iinfo(x.dtype).min;
6867
+ x = where(isposinf(x.ref), posinf, x);
6868
+ x = where(isneginf(x.ref), neginf, x);
6869
+ return x;
6870
+ }
6871
+ /**
6792
6872
  * @function
6793
6873
  * Test element-wise for finite values (not infinity or NaN).
6794
6874
  */
@@ -6901,6 +6981,7 @@ var lax_exports = {};
6901
6981
  __export(lax_exports, {
6902
6982
  conv: () => conv,
6903
6983
  convGeneralDilated: () => convGeneralDilated,
6984
+ convTranspose: () => convTranspose,
6904
6985
  convWithGeneralPadding: () => convWithGeneralPadding,
6905
6986
  dot: () => dot,
6906
6987
  erf: () => erf,
@@ -6909,6 +6990,7 @@ __export(lax_exports, {
6909
6990
  reduceWindow: () => reduceWindow,
6910
6991
  stopGradient: () => stopGradient$1
6911
6992
  });
6993
+ const JsArray = globalThis.Array;
6912
6994
  /**
6913
6995
  * General dot product/contraction operator.
6914
6996
  *
@@ -6980,7 +7062,11 @@ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
6980
7062
  * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
6981
7063
  * function in JAX, which wraps XLA's general convolution operator.
6982
7064
  *
6983
- * Grouped convolutions are not supported right now.
7065
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
7066
+ * @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
7067
+ * @param windowStrides - Strides for each spatial dimension
7068
+ * @param padding - Padding for each spatial dimension, or a string
7069
+ * (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
6984
7070
  */
6985
7071
  function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
6986
7072
  if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
@@ -7040,6 +7126,60 @@ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, r
7040
7126
  function conv(lhs, rhs, windowStrides, padding) {
7041
7127
  return convGeneralDilated(lhs, rhs, windowStrides, padding);
7042
7128
  }
7129
+ /**
7130
+ * Convenience wrapper for calculating the N-d convolution "transpose".
7131
+ *
7132
+ * This function directly calculates a fractionally strided conv rather than
7133
+ * indirectly calculating the gradient (transpose) of a forward convolution.
7134
+ * It is equivalent to the JAX version, except:
7135
+ *
7136
+ * - The `use_consistent_padding` option is not available. We only have the
7137
+ * consistent padding case (JAX version >0.8.4).
7138
+ * - The order of dimensions matches `lax.conv_general_dilated`.
7139
+ *
7140
+ * Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
7141
+ * dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
7142
+ * `transposeKernel` to true.
7143
+ *
7144
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
7145
+ * @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
7146
+ * @param strides - Sequence of n integers, sets fractional stride
7147
+ * @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
7148
+ * each side of the input, so it acts like gradient of `conv()`
7149
+ * @param rhsDilation - Atrous dilation for the kernel
7150
+ * @param transposeKernel - Flip spatial axes and swap the input/output channels
7151
+ * of the kernel; its shape should be `[C_in, C_out, ...ks]`
7152
+ */
7153
+ function convTranspose(lhs, rhs, strides, padding, { rhsDilation, transposeKernel = false } = {}) {
7154
+ const kernelShape = rhs.shape.slice(2);
7155
+ rhsDilation = rhsDilation ?? rep(kernelShape.length, 1);
7156
+ const effectiveKernel = kernelShape.map((k, i) => Math.max(0, (k - 1) * rhsDilation[i] + 1));
7157
+ const pads = effectiveKernel.map((k, i) => convTransposePadding(k, strides[i], typeof padding === "string" ? padding : padding[i]));
7158
+ if (transposeKernel) {
7159
+ rhs = flip$1(rhs, range(2, rhs.ndim));
7160
+ rhs = moveaxis(rhs, 0, 1);
7161
+ }
7162
+ return convGeneralDilated(lhs, rhs, rep(lhs.ndim - 2, 1), pads, {
7163
+ lhsDilation: strides,
7164
+ rhsDilation
7165
+ });
7166
+ }
7167
+ function convTransposePadding(k, s, padding) {
7168
+ let padLen;
7169
+ let pad1;
7170
+ if (padding === "SAME") {
7171
+ padLen = k + s - 2;
7172
+ pad1 = s > k - 1 ? k - 1 : Math.ceil(padLen / 2);
7173
+ } else if (padding === "VALID") {
7174
+ padLen = k + s - 2 + Math.max(k - s, 0);
7175
+ pad1 = k - 1;
7176
+ } else if (JsArray.isArray(padding)) {
7177
+ const pads = [k - 1 - padding[0], k - 1 - padding[1]];
7178
+ pad1 = pads[0];
7179
+ padLen = pads[0] + pads[1];
7180
+ } else throw new Error(`convTranspose: Invalid padding type ${padding}`);
7181
+ return [pad1, padLen - pad1];
7182
+ }
7043
7183
  /** Reduce a computation over padded windows. */
7044
7184
  function reduceWindow(operand, computation, windowDimensions, windowStrides) {
7045
7185
  if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
@@ -7078,6 +7218,7 @@ function stopGradient$1(x) {
7078
7218
  var nn_exports = {};
7079
7219
  __export(nn_exports, {
7080
7220
  celu: () => celu,
7221
+ dotProductAttention: () => dotProductAttention,
7081
7222
  elu: () => elu,
7082
7223
  gelu: () => gelu,
7083
7224
  glu: () => glu,
@@ -7394,6 +7535,125 @@ function oneHot(x, numClasses) {
7394
7535
  if (isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
7395
7536
  return eye(numClasses, void 0, { device: x.device }).slice(x);
7396
7537
  }
7538
/**
 * Scaled dot product attention (SDPA).
 *
 * Computes `softmax((Q @ K^T) * scale + bias) @ V`, where `Q` is the query,
 * `K` is the key, `V` is the value, and `scale` defaults to `1 / sqrt(H)`.
 *
 * Grouped-query attention (GQA) is applied when input `key` and `value`
 * tensors have fewer heads than `query`. With `G = N / K`, the query heads are
 * split into `K` contiguous groups of `G`, and query head `n` attends with
 * key/value head `floor(n / G)`. Multi-query attention is the case `K = 1`.
 *
 * We use the following uppercase letters to denote array shapes:
 * - `B` = batch size
 * - `S` = length of key/value sequences (source)
 * - `L` = length of query sequences
 * - `N` = number of attention heads
 * - `H` = dimensionality of each attention head
 * - `K` = number of key/value heads (for grouped-query attention)
 *
 * The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
 * case it must be omitted from all inputs.
 *
 * Masked logits are set to `-Infinity`, so a query row whose logits are all
 * masked (e.g. a `keyValueSeqLengths` entry of 0) can produce NaN.
 *
 * @param query - Query array; shape `[B, L, N, H]`
 * @param key - Key array; shape `[B, S, K, H]`
 * @param value - Value array; same shape as `key`
 * @param opts.bias - Optional bias to add to the attention logits; shape
 *   `[B, N, L, S]` or broadcastable to it.
 * @param opts.mask - Optional mask to apply to the attention logits; should be
 *   a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
 *   the element should take part in attention.
 * @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
 * @param opts.isCausal - If true, applies a causal mask: query position `i`
 *   only attends to key positions `j <= i`.
 * @param opts.querySeqLengths - Optional sequence lengths for the queries;
 *   shape `(B,)`. Taken from the beginning of the tensor; output rows at query
 *   positions `>= querySeqLengths[b]` are set to zero.
 * @param opts.keyValueSeqLengths - Optional sequence lengths for the keys and
 *   values; shape `(B,)`. Taken from the beginning of the tensor.
 * @param opts.localWindowSize - If specified, applies a local attention window
 *   of the given size. Can be a single number or a tuple `[left, right]`;
 *   query position `i` attends to key positions `i - left <= j <= i + right`.
 *
 * @returns The result of the attention operation; shape is the same as query
 *   `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
 * @throws {Error} If the inputs are not all rank 3 or all rank 4, if `key` and
 *   `value` shapes differ, if the batch or head dimension of `query` and `key`
 *   disagree, if `N` is not a multiple of `K`, or if `localWindowSize` is not
 *   made of non-negative integers.
 */
function dotProductAttention(query, key$1, value, opts = {}) {
	query = fudgeArray(query);
	key$1 = fudgeArray(key$1);
	value = fudgeArray(value);
	if (query.ndim !== 3 && query.ndim !== 4 || query.ndim !== key$1.ndim || query.ndim !== value.ndim) throw new Error(`dotProductAttention: expected all tensors to have rank 3 or 4, got Q=${query.aval}, K=${key$1.aval}, V=${value.aval}`);
	if (!deepEqual(key$1.shape, value.shape)) throw new Error(`dotProductAttention: key and value shapes must match, got K=${key$1.shape}, V=${value.shape}`);
	// A missing batch axis is treated as B = 1 and dropped again on return.
	const isRank3 = query.ndim === 3;
	if (isRank3) {
		query = expandDims(query, 0);
		key$1 = expandDims(key$1, 0);
		value = expandDims(value, 0);
	}
	const [B, L, N, H] = query.shape;
	if (key$1.shape[0] !== B || key$1.shape[3] !== H) throw new Error(`dotProductAttention: query and key shapes mismatch, got Q=${query.aval}, K=${key$1.aval}`);
	const S = key$1.shape[1];
	const K = key$1.shape[2];
	if (N < K || N % K !== 0) throw new Error(`dotProductAttention: number of query heads N=${N} must be divisible by number of key/value heads K=${K} for GQA`);
	const G = N / K;
	const scale = opts.scale ?? 1 / Math.sqrt(H);
	// Split query heads as [K, G], so head n = k * G + g pairs with key/value
	// head k = floor(n / G). Tiling K/V along the head axis would instead pair
	// head n with head n % K, and would also copy K/V G times.
	const groupedQuery = query.reshape([
		B,
		L,
		K,
		G,
		H
	]);
	let scores = einsum("BLKGH,BSKH->BKGLS", groupedQuery, key$1).reshape([
		B,
		N,
		L,
		S
	]).mul(scale);
	if (opts.bias !== void 0) scores = scores.add(opts.bias);
	if (opts.mask !== void 0) scores = where(opts.mask, scores, -Infinity);
	if (opts.isCausal) {
		const causalMask = tri(L, S, 0, { dtype: DType.Bool });
		scores = where(causalMask, scores, -Infinity);
	}
	if (opts.localWindowSize !== void 0) {
		const [before, after] = typeof opts.localWindowSize === "number" ? [opts.localWindowSize, opts.localWindowSize] : opts.localWindowSize;
		if (before < 0 || after < 0 || !Number.isInteger(before) || !Number.isInteger(after)) throw new Error(`dotProductAttention: localWindowSize values must be non-negative integers, got ${opts.localWindowSize}`);
		// Keep j <= i + after, and remove j <= i - before - 1 (i.e. keep j >= i - before).
		const localMask = tri(L, S, after, { dtype: DType.Bool }).mul(tri(L, S, -before - 1, { dtype: DType.Bool }).notEqual(true));
		scores = where(localMask, scores, -Infinity);
	}
	if (opts.keyValueSeqLengths !== void 0) {
		const sl = expandDims(opts.keyValueSeqLengths, [
			-1,
			-2,
			-3
		]);
		scores = where(arange(S).reshape([
			1,
			1,
			1,
			S
		]).less(sl), scores, -Infinity);
	}
	const attn = softmax(scores, -1).reshape([
		B,
		K,
		G,
		L,
		S
	]);
	let out = einsum("BKGLS,BSKH->BLKGH", attn, value).reshape([
		B,
		L,
		N,
		H
	]);
	if (opts.querySeqLengths !== void 0) {
		// Zero padded query rows after attention instead of masking their logits.
		// An all -Infinity logit row gives NaN from softmax, and that NaN can
		// propagate into key/value gradients through the sum over query rows.
		const sl = expandDims(opts.querySeqLengths, [
			-1,
			-2,
			-3
		]);
		out = where(arange(L).reshape([
			1,
			L,
			1,
			1
		]).less(sl), out, 0);
	}
	return isRank3 ? out.reshape([
		L,
		N,
		H
	]) : out;
}
7397
7657
 
7398
7658
  //#endregion
7399
7659
  //#region src/library/random.ts
@@ -7629,17 +7889,62 @@ const linearize = linearize$1;
7629
7889
  /**
7630
7890
  * @function
7631
7891
  * Calculate the reverse-mode vector-Jacobian product for a function.
7892
+ *
7893
+ * The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
7894
+ * `f(primals)`, and `vjpFn` is a function that takes in cotangents for each
7895
+ * output and returns the cotangents for each input.
7896
+ *
7897
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
7898
+ * `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
7899
+ *
7900
+ * @example
7901
+ * ```ts
7902
+ * const [y, vjpFn] = vjp(f, [x]);
7903
+ *
7904
+ * // With hasAux
7905
+ * const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
7906
+ * ```
7632
7907
  */
7633
7908
  const vjp = vjp$1;
7634
7909
  /**
7635
7910
  * @function
7636
7911
  * Compute the gradient of a scalar-valued function `f` with respect to its
7637
7912
  * first argument.
7913
+ *
7914
+ * Pass in different `argnums` to differentiate with respect to other
7915
+ * arguments. If a tuple is provided, the return value will be a tuple of
7916
+ * gradients corresponding to each argument index.
7917
+ *
7918
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return a
7919
+ * `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
7920
+ *
7921
+ * @example
7922
+ * ```ts
7923
+ * const gradient = grad(f)(x);
7924
+ *
7925
+ * // With `argnums`
7926
+ * const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
7927
+ *
7928
+ * // With `hasAux`
7929
+ * const [gradient, aux] = grad(f, { hasAux: true })(x);
7930
+ * ```
7638
7931
  */
7639
7932
  const grad = grad$1;
7640
7933
  /**
7641
7934
  * @function
7642
7935
  * Create a function that evaluates both `f` and the gradient of `f`.
7936
+ *
7937
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
7938
+ * `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
7939
+ *
7940
+ * @example
7941
+ * ```ts
7942
+ * // Without hasAux
7943
+ * const [value, gradient] = valueAndGrad(f)(x);
7944
+ *
7945
+ * // With hasAux
7946
+ * const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
7947
+ * ```
7643
7948
  */
7644
7949
  const valueAndGrad = valueAndGrad$1;
7645
7950
  /**
@@ -7648,6 +7953,21 @@ const valueAndGrad = valueAndGrad$1;
7648
7953
  */
7649
7954
  const jacrev = jacrev$1;
7650
7955
  /**
  * @function
  * Compute the Hessian matrix of a scalar-valued function.
  *
  * The Hessian is the matrix of second-order partial derivatives of a function.
  * This is implemented as `jacfwd(grad(f))`.
  *
  * Public alias of the internal `hessian$1` implementation.
  *
  * @example
  * ```ts
  * const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // x^3
  * const H = hessian(f)(np.array([1, 2, 3]));
  * // H[i,j] = d^2f / dx_i dx_j
  * ```
  */
  const hessian = hessian$1;
7970
+ /**
7651
7971
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
7652
7972
  *
7653
7973
  * This can be used to wait for the results of an intermediate computation to
@@ -7682,4 +8002,4 @@ async function devicePut(x, device) {
7682
8002
  }
7683
8003
 
7684
8004
  //#endregion
7685
- export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
8005
+ export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };