@jax-js/jax 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-DziQSaoQ.cjs');
33
+ const require_backend = require('./backend-B3foXiV_.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
@@ -240,7 +240,7 @@ __export(tree_exports, {
240
240
  structure: () => structure,
241
241
  unflatten: () => unflatten
242
242
  });
243
- const JsArray$1 = globalThis.Array;
243
+ const JsArray$2 = globalThis.Array;
244
244
  let NodeType = /* @__PURE__ */ function(NodeType$1) {
245
245
  NodeType$1["Array"] = "Array";
246
246
  NodeType$1["Object"] = "Object";
@@ -288,7 +288,7 @@ function flatten(tree) {
288
288
  return [leaves$1, treedef];
289
289
  }
290
290
  function _flatten(tree, leaves$1) {
291
- if (JsArray$1.isArray(tree)) {
291
+ if (JsArray$2.isArray(tree)) {
292
292
  const childTrees = tree.map((c) => _flatten(c, leaves$1));
293
293
  return new JsTreeDef(NodeType.Array, null, childTrees);
294
294
  } else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
@@ -337,11 +337,11 @@ function map(fn, tree, ...rest) {
337
337
  }
338
338
  /** Take a reference of every array in a tree. */
339
339
  function ref(tree) {
340
- return map((x) => x.ref, tree);
340
+ return map((x) => x instanceof Tracer ? x.ref : x, tree);
341
341
  }
342
342
  /** Dispose every array in a tree. */
343
343
  function dispose(tree) {
344
- if (tree) map((x) => x.dispose(), tree);
344
+ if (tree) map((x) => x instanceof Tracer ? x.dispose() : void 0, tree);
345
345
  }
346
346
 
347
347
  //#endregion
@@ -412,6 +412,13 @@ let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
412
412
  CompareOp$1["LessEqual"] = "less_equal";
413
413
  return CompareOp$1;
414
414
  }({});
415
+ const routinePrimitives = new Map([
416
+ [Primitive.Sort, require_backend.Routines.Sort],
417
+ [Primitive.Argsort, require_backend.Routines.Argsort],
418
+ [Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
419
+ [Primitive.Cholesky, require_backend.Routines.Cholesky],
420
+ [Primitive.LU, require_backend.Routines.LU]
421
+ ]);
415
422
  function add$1(x, y) {
416
423
  return bind1(Primitive.Add, [x, y]);
417
424
  }
@@ -608,14 +615,20 @@ function shrink(x, slice) {
608
615
  }
609
616
  function pad$1(x, width) {
610
617
  const nd = ndim$1(x);
611
- if (typeof width === "number") width = [[width, width]];
612
- else if (require_backend.isNumberPair(width)) width = [width];
613
- else if (!Array.isArray(width) || !width.every(require_backend.isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
614
- if (width.length === 1) {
615
- const [w0, w1] = width[0];
616
- width = require_backend.rep(nd, () => [w0, w1]);
617
- } else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
618
- return bind1(Primitive.Pad, [x], { width });
618
+ let w;
619
+ if (typeof width === "number") w = [[width, width]];
620
+ else if (require_backend.isNumberPair(width)) w = [width];
621
+ else if (!Array.isArray(width)) {
622
+ const indicesAndPairs = Object.entries(width);
623
+ w = require_backend.rep(nd, [0, 0]);
624
+ for (const [k, v] of indicesAndPairs) w[require_backend.checkAxis(parseInt(k), nd)] = v;
625
+ } else if (!width.every(require_backend.isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
626
+ else w = width;
627
+ if (w.length === 1) {
628
+ const [w0, w1] = w[0];
629
+ w = require_backend.rep(nd, () => [w0, w1]);
630
+ } else if (w.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${w.length}`);
631
+ return bind1(Primitive.Pad, [x], { width: w });
619
632
  }
620
633
  function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
621
634
  const as = getShape(a);
@@ -685,6 +698,9 @@ function newDynamic(main) {
685
698
  dynamicTrace = prevDynamicTrace;
686
699
  } };
687
700
  }
701
+ function currentTraceLevel() {
702
+ return traceStack[traceStack.length - 1].level;
703
+ }
688
704
  var Trace = class {
689
705
  constructor(main) {
690
706
  this.main = main;
@@ -788,6 +804,22 @@ var Tracer = class Tracer {
788
804
  const result = reduce(this.astype(castDtype), require_backend.AluOp.Add, axis, opts);
789
805
  return result.mul(1 / n).astype(originalDtype);
790
806
  }
807
+ /** Minimum of the elements of the array along a given axis. */
808
+ min(axis = null, opts) {
809
+ return reduce(this, require_backend.AluOp.Min, axis, opts);
810
+ }
811
+ /** Maximum of the elements of the array along a given axis. */
812
+ max(axis = null, opts) {
813
+ return reduce(this, require_backend.AluOp.Max, axis, opts);
814
+ }
815
+ /** Test whether all array elements along a given axis evaluate to true. */
816
+ all(axis = null, opts) {
817
+ return this.astype(require_backend.DType.Bool).min(axis, opts);
818
+ }
819
+ /** Test whether any array element along a given axis evaluates to true. */
820
+ any(axis = null, opts) {
821
+ return this.astype(require_backend.DType.Bool).max(axis, opts);
822
+ }
791
823
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
792
824
  transpose(perm) {
793
825
  return transpose$1(this, perm);
@@ -1062,6 +1094,7 @@ var TreeMismatchError = class extends TypeError {
1062
1094
  super(`Mismatched tree structures in ${where$2}: ${left} != ${right}`);
1063
1095
  }
1064
1096
  };
1097
+ /** Flatten a function of `JsTree` input/output for use in tracing. */
1065
1098
  function flattenFun(f, inTree) {
1066
1099
  const store = { value: void 0 };
1067
1100
  const flatFun = (...argsFlat) => {
@@ -1073,6 +1106,26 @@ function flattenFun(f, inTree) {
1073
1106
  };
1074
1107
  return [flatFun, store];
1075
1108
  }
1109
+ /** Like flattenFun, but expects f to return [main, aux] tuple. */
1110
+ function flattenFunWithAux(f, inTree) {
1111
+ const store = { value: void 0 };
1112
+ const auxStore = { value: void 0 };
1113
+ const flatFun = (...argsFlat) => {
1114
+ const pytreeArgs = unflatten(inTree, argsFlat);
1115
+ const result = f(...pytreeArgs);
1116
+ if (!Array.isArray(result) || result.length !== 2) throw new Error("Function with `hasAux: true` must return [output, aux] tuple");
1117
+ const [out, aux] = result;
1118
+ const [outFlat, outTree] = flatten(out);
1119
+ store.value = outTree;
1120
+ auxStore.value = aux;
1121
+ return outFlat;
1122
+ };
1123
+ return [
1124
+ flatFun,
1125
+ store,
1126
+ auxStore
1127
+ ];
1128
+ }
1076
1129
  var UseAfterFreeError = class extends ReferenceError {
1077
1130
  constructor(tracer) {
1078
1131
  super(`Referenced tracer ${tracer.toString()} freed, please use .ref move semantics`);
@@ -1806,13 +1859,6 @@ function jit$1(f, opts) {
1806
1859
 
1807
1860
  //#endregion
1808
1861
  //#region src/frontend/jit.ts
1809
- const routinePrimitives = new Map([
1810
- [Primitive.Sort, require_backend.Routines.Sort],
1811
- [Primitive.Argsort, require_backend.Routines.Argsort],
1812
- [Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
1813
- [Primitive.Cholesky, require_backend.Routines.Cholesky],
1814
- [Primitive.LU, require_backend.Routines.LU]
1815
- ]);
1816
1862
  /** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
1817
1863
  var JitProgram = class {
1818
1864
  constructor(backend, steps, inputs, outputs) {
@@ -2201,12 +2247,13 @@ const jitRules = {
2201
2247
  const ndim$2 = avals[0].ndim;
2202
2248
  const sizes = avals.map((x) => x.shape[axis]);
2203
2249
  const finalSize = sizes.reduce((a, b) => a + b, 0);
2250
+ const { dtype: dtypeOut } = avals.map((x) => x.scalar()).reduce(promoteAvals);
2204
2251
  const makePadAxis = (start, end) => require_backend.range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
2205
2252
  let cum = 0;
2206
2253
  const src = [];
2207
2254
  for (let i = 0; i < exps.length; i++) {
2208
2255
  const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
2209
- src.push(reshapeViews(exps[i], (st) => st.pad(padding)));
2256
+ src.push(reshapeViews(require_backend.AluExp.cast(dtypeOut, exps[i]), (st) => st.pad(padding)));
2210
2257
  cum += sizes[i];
2211
2258
  }
2212
2259
  return { exp: [src.reduce(require_backend.AluExp.add)] };
@@ -2344,7 +2391,7 @@ function splitGraphDataflow(backend, jaxpr) {
2344
2391
  p1NextBlack.set(v, v);
2345
2392
  }
2346
2393
  const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
2347
- const needsCleanShapePrimitives = [Primitive.Pad];
2394
+ const needsCleanShapePrimitives = [Primitive.Concatenate, Primitive.Pad];
2348
2395
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
2349
2396
  const eqn = jaxpr.eqns[i];
2350
2397
  if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
@@ -2414,7 +2461,7 @@ function splitGraphDataflow(backend, jaxpr) {
2414
2461
 
2415
2462
  //#endregion
2416
2463
  //#region src/frontend/array.ts
2417
- const JsArray = globalThis.Array;
2464
+ const JsArray$1 = globalThis.Array;
2418
2465
  const inlineArrayLimit = 128;
2419
2466
  /** Version of pureArray with fudged types. */
2420
2467
  const fudgeArray = pureArray;
@@ -2812,25 +2859,35 @@ var Array$1 = class Array$1 extends Tracer {
2812
2859
  });
2813
2860
  }
2814
2861
  /** Apply an operation with custom lowering to this array. */
2815
- static #routine(routine, arrays, outputWeakType) {
2816
- const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
2817
- for (const ar of arrays) ar.#realize();
2818
- const inputs = arrays.map((ar) => ar.#source);
2819
- const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(require_backend.byteWidth(dtype) * require_backend.prod(routine.type.outputShapes[i])));
2820
- const pending = arrays.flatMap((ar) => ar.#pending);
2821
- for (const exe of pending) exe.updateRc(+outputs.length);
2822
- pending.push(new PendingExecute(backend, routine, inputs, outputs));
2823
- pending[pending.length - 1].updateRc(+outputs.length - 1);
2824
- arrays.forEach((ar) => ar.dispose());
2825
- return outputs.map((output, i) => new Array$1({
2826
- source: output,
2827
- st: require_backend.ShapeTracker.fromShape(routine.type.outputShapes[i]),
2828
- dtype: routine.type.outputDtypes[i],
2829
- weakType: outputWeakType[i],
2830
- backend,
2831
- committed,
2832
- pending
2833
- }));
2862
+ static #routine(prim) {
2863
+ return (arrays, params) => {
2864
+ const { backend, committed } = Array$1.#computeBackend(prim, arrays);
2865
+ for (const ar of arrays) ar.#realize();
2866
+ const avals = arrays.map((ar) => ar.aval);
2867
+ const avalsOut = abstractEvalRules[prim](avals, params);
2868
+ const routine = new require_backend.Routine(routinePrimitives.get(prim), {
2869
+ inputShapes: avals.map((a) => a.shape),
2870
+ inputDtypes: avals.map((a) => a.dtype),
2871
+ outputShapes: avalsOut.map((a) => a.shape),
2872
+ outputDtypes: avalsOut.map((a) => a.dtype)
2873
+ }, params);
2874
+ const inputs = arrays.map((ar) => ar.#source);
2875
+ const outputs = avalsOut.map((x) => backend.malloc(require_backend.byteWidth(x.dtype) * x.size));
2876
+ const pending = arrays.flatMap((ar) => ar.#pending);
2877
+ for (const exe of pending) exe.updateRc(+outputs.length);
2878
+ pending.push(new PendingExecute(backend, routine, inputs, outputs));
2879
+ pending[pending.length - 1].updateRc(+outputs.length - 1);
2880
+ arrays.forEach((ar) => ar.dispose());
2881
+ return outputs.map((output, i) => new Array$1({
2882
+ source: output,
2883
+ st: require_backend.ShapeTracker.fromShape(avalsOut[i].shape),
2884
+ dtype: avalsOut[i].dtype,
2885
+ weakType: avalsOut[i].weakType,
2886
+ backend,
2887
+ committed,
2888
+ pending
2889
+ }));
2890
+ };
2834
2891
  }
2835
2892
  /**
2836
2893
  * Normalizes this array into one backed by a `Slot`.
@@ -3164,65 +3221,11 @@ var Array$1 = class Array$1 extends Tracer {
3164
3221
  [Primitive.Pad]([x], { width }) {
3165
3222
  return [x.#reshape(x.#st.pad(width))];
3166
3223
  },
3167
- [Primitive.Sort]([x]) {
3168
- const routine = new require_backend.Routine(require_backend.Routines.Sort, {
3169
- inputShapes: [x.shape],
3170
- inputDtypes: [x.dtype],
3171
- outputShapes: [x.shape],
3172
- outputDtypes: [x.dtype]
3173
- });
3174
- return Array$1.#routine(routine, [x], [x.#weakType]);
3175
- },
3176
- [Primitive.Argsort]([x]) {
3177
- const routine = new require_backend.Routine(require_backend.Routines.Argsort, {
3178
- inputShapes: [x.shape],
3179
- inputDtypes: [x.dtype],
3180
- outputShapes: [x.shape, x.shape],
3181
- outputDtypes: [x.dtype, require_backend.DType.Int32]
3182
- });
3183
- return Array$1.#routine(routine, [x], [x.#weakType, false]);
3184
- },
3185
- [Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
3186
- const routine = new require_backend.Routine(require_backend.Routines.TriangularSolve, {
3187
- inputShapes: [a.shape, b.shape],
3188
- inputDtypes: [a.dtype, b.dtype],
3189
- outputShapes: [b.shape],
3190
- outputDtypes: [b.dtype]
3191
- }, { unitDiagonal });
3192
- return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
3193
- },
3194
- [Primitive.Cholesky]([a]) {
3195
- const routine = new require_backend.Routine(require_backend.Routines.Cholesky, {
3196
- inputShapes: [a.shape],
3197
- inputDtypes: [a.dtype],
3198
- outputShapes: [a.shape],
3199
- outputDtypes: [a.dtype]
3200
- });
3201
- return Array$1.#routine(routine, [a], [a.#weakType]);
3202
- },
3203
- [Primitive.LU]([a]) {
3204
- const batch = a.shape.slice(0, -2);
3205
- const [m, n] = a.shape.slice(-2);
3206
- const routine = new require_backend.Routine(require_backend.Routines.LU, {
3207
- inputShapes: [a.shape],
3208
- inputDtypes: [a.dtype],
3209
- outputShapes: [
3210
- a.shape,
3211
- [...batch, Math.min(m, n)],
3212
- [...batch, m]
3213
- ],
3214
- outputDtypes: [
3215
- a.dtype,
3216
- require_backend.DType.Int32,
3217
- require_backend.DType.Int32
3218
- ]
3219
- });
3220
- return Array$1.#routine(routine, [a], [
3221
- a.#weakType,
3222
- false,
3223
- false
3224
- ]);
3225
- },
3224
+ [Primitive.Sort]: Array$1.#routine(Primitive.Sort),
3225
+ [Primitive.Argsort]: Array$1.#routine(Primitive.Argsort),
3226
+ [Primitive.TriangularSolve]: Array$1.#routine(Primitive.TriangularSolve),
3227
+ [Primitive.Cholesky]: Array$1.#routine(Primitive.Cholesky),
3228
+ [Primitive.LU]: Array$1.#routine(Primitive.LU),
3226
3229
  [Primitive.Jit](args, { jaxpr }) {
3227
3230
  if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
3228
3231
  const { backend, committed } = Array$1.#computeBackend("jit", args);
@@ -3304,7 +3307,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
3304
3307
  if (!shape$1) {
3305
3308
  shape$1 = [];
3306
3309
  let cur = values;
3307
- while (JsArray.isArray(cur)) {
3310
+ while (JsArray$1.isArray(cur)) {
3308
3311
  shape$1.push(cur.length);
3309
3312
  cur = cur[0];
3310
3313
  }
@@ -4260,17 +4263,39 @@ function jvpFlat(f, primals, tangents) {
4260
4263
  _usingCtx$1.d();
4261
4264
  }
4262
4265
  }
4263
- function jvp$1(f, primals, tangents) {
4266
+ function jvp$1(f, primals, tangents, { hasAux = false } = {}) {
4264
4267
  const [primalsFlat, inTree] = flatten(primals);
4265
4268
  const [tangentsFlat, inTree2] = flatten(tangents);
4266
4269
  if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
4267
- const [flatFun, outTree] = flattenFun(f, inTree);
4270
+ let flatFun, outTree, aux;
4271
+ if (hasAux) [flatFun, outTree, aux] = flattenFunWithAux(f, inTree);
4272
+ else [flatFun, outTree] = flattenFun(f, inTree);
4268
4273
  const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
4269
4274
  if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
4270
4275
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
4271
4276
  const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
4277
+ if (hasAux) return [
4278
+ primalsOut,
4279
+ tangentsOut,
4280
+ lowerAux(aux.value)
4281
+ ];
4272
4282
  return [primalsOut, tangentsOut];
4273
4283
  }
4284
+ /** Lowering for auxiliary data returned in `hasAux: true` methods. */
4285
+ function lowerAux(aux) {
4286
+ const level = currentTraceLevel();
4287
+ return map((x) => {
4288
+ if (x instanceof Tracer) while (x._trace.main.level > level) if (x instanceof JVPTracer) {
4289
+ x.tangent.dispose();
4290
+ x = x.primal;
4291
+ } else {
4292
+ const y = x.fullLower();
4293
+ if (y._trace.main.level >= x._trace.main.level) throw new Error("internal: lowerAux did not reduce trace level");
4294
+ x = y;
4295
+ }
4296
+ return x;
4297
+ }, aux);
4298
+ }
4274
4299
 
4275
4300
  //#endregion
4276
4301
  //#region src/frontend/linearize.ts
@@ -4341,9 +4366,11 @@ function linearizeFlat(f, primalsIn) {
4341
4366
  dispose$1
4342
4367
  ];
4343
4368
  }
4344
- function linearize$1(f, ...primalsIn) {
4369
+ function linearize$1(f, primalsIn, { hasAux = false } = {}) {
4345
4370
  const [primalsInFlat, inTree] = flatten(primalsIn);
4346
- const [fFlat, outTree] = flattenFun(f, inTree);
4371
+ let fFlat, outTree, aux;
4372
+ if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
4373
+ else [fFlat, outTree] = flattenFun(f, inTree);
4347
4374
  const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
4348
4375
  if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
4349
4376
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
@@ -4354,6 +4381,11 @@ function linearize$1(f, ...primalsIn) {
4354
4381
  return unflatten(outTree.value, tangentsOutFlat);
4355
4382
  });
4356
4383
  fLin.dispose = dispose$1;
4384
+ if (hasAux) return [
4385
+ primalsOut,
4386
+ fLin,
4387
+ lowerAux(aux.value)
4388
+ ];
4357
4389
  return [primalsOut, fLin];
4358
4390
  }
4359
4391
  var PartialEvalTracer = class extends Tracer {
@@ -4854,9 +4886,11 @@ function vjpFlat(f, primalsIn) {
4854
4886
  dispose$1
4855
4887
  ];
4856
4888
  }
4857
- function vjp$1(f, ...primalsIn) {
4889
+ function vjp$1(f, primalsIn, { hasAux = false } = {}) {
4858
4890
  const [primalsInFlat, inTree] = flatten(primalsIn);
4859
- const [fFlat, outTree] = flattenFun(f, inTree);
4891
+ let fFlat, outTree, aux;
4892
+ if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
4893
+ else [fFlat, outTree] = flattenFun(f, inTree);
4860
4894
  const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
4861
4895
  if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
4862
4896
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
@@ -4867,26 +4901,43 @@ function vjp$1(f, ...primalsIn) {
4867
4901
  return unflatten(inTree, cotangentsInFlat);
4868
4902
  });
4869
4903
  fVjp.dispose = dispose$1;
4904
+ if (hasAux) return [
4905
+ primalsOut,
4906
+ fVjp,
4907
+ lowerAux(aux.value)
4908
+ ];
4870
4909
  return [primalsOut, fVjp];
4871
4910
  }
4872
- function grad$1(f) {
4873
- const valueAndGradFn = valueAndGrad$1(f);
4911
+ function grad$1(f, opts) {
4912
+ const valueAndGradFn = valueAndGrad$1(f, opts);
4874
4913
  return (...x) => {
4875
- const [y, dx] = valueAndGradFn(...x);
4876
- y.dispose();
4877
- return dx;
4914
+ if (opts?.hasAux) {
4915
+ const [[y, aux], dx] = valueAndGradFn(...x);
4916
+ y.dispose();
4917
+ return [dx, aux];
4918
+ } else {
4919
+ const [y, dx] = valueAndGradFn(...x);
4920
+ y.dispose();
4921
+ return dx;
4922
+ }
4878
4923
  };
4879
4924
  }
4880
- function valueAndGrad$1(f) {
4925
+ function valueAndGrad$1(f, opts) {
4926
+ const argnums = opts?.argnums ?? 0;
4927
+ const hasAux = opts?.hasAux ?? false;
4928
+ require_backend.checkInts(argnums);
4929
+ const argnumsSet = new Set(typeof argnums === "number" ? [argnums] : argnums);
4881
4930
  return (...x) => {
4882
4931
  if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
4883
- const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
4932
+ for (let i = 0; i < x.length; i++) if (!argnumsSet.has(i)) x[i] = map(stopGradient, x[i]);
4933
+ const [y, fVjp, aux] = vjp$1(f, x, { hasAux });
4884
4934
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
4885
4935
  if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
4886
- const [ct, ...rest] = fVjp(onesLike$1(y.ref));
4887
- for (const r of rest) dispose(r);
4936
+ const cts = fVjp(onesLike$1(y.ref));
4888
4937
  fVjp.dispose();
4889
- return [y, ct];
4938
+ for (let i = 0; i < cts.length; i++) if (!argnumsSet.has(i)) dispose(cts[i]);
4939
+ const grads = typeof argnums === "number" ? cts[argnums] : argnums.map((i) => cts[i]);
4940
+ return hasAux ? [[y, aux], grads] : [y, grads];
4890
4941
  };
4891
4942
  }
4892
4943
  function jacrev$1(f) {
@@ -4894,7 +4945,7 @@ function jacrev$1(f) {
4894
4945
  if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
4895
4946
  const [size$1] = x.shape;
4896
4947
  const pullback = (ct) => {
4897
- const [y, fVjp] = vjp$1(f, x);
4948
+ const [y, fVjp] = vjp$1(f, [x]);
4898
4949
  y.dispose();
4899
4950
  const [ret] = fVjp(ct);
4900
4951
  fVjp.dispose();
@@ -4903,6 +4954,9 @@ function jacrev$1(f) {
4903
4954
  return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
4904
4955
  };
4905
4956
  }
4957
+ function hessian$1(f) {
4958
+ return jacfwd$1(grad$1(f));
4959
+ }
4906
4960
 
4907
4961
  //#endregion
4908
4962
  //#region src/library/numpy/einsum.ts
@@ -5575,6 +5629,7 @@ __export(numpy_exports, {
5575
5629
  moveaxis: () => moveaxis$1,
5576
5630
  multiply: () => multiply,
5577
5631
  nan: () => nan,
5632
+ nanToNum: () => nanToNum,
5578
5633
  ndim: () => ndim,
5579
5634
  negative: () => negative,
5580
5635
  notEqual: () => notEqual,
@@ -5612,6 +5667,7 @@ __export(numpy_exports, {
5612
5667
  std: () => std,
5613
5668
  subtract: () => subtract,
5614
5669
  sum: () => sum,
5670
+ swapaxes: () => swapaxes,
5615
5671
  take: () => take,
5616
5672
  tan: () => tan,
5617
5673
  tanh: () => tanh,
@@ -5771,24 +5827,22 @@ function max(a, axis = null, opts) {
5771
5827
  return reduce(a, require_backend.AluOp.Max, axis, opts);
5772
5828
  }
5773
5829
  /**
5774
- * Test whether all array elements along a given axis evaluate to True.
5830
+ * Test whether any array element along a given axis evaluates to True.
5775
5831
  *
5776
5832
  * Returns a boolean array with the same shape as `a` with the specified axis
5777
5833
  * removed. If axis is None, returns a scalar.
5778
5834
  */
5779
- function all(a, axis = null, opts) {
5780
- a = fudgeArray(a).astype(require_backend.DType.Bool);
5781
- return min(a, axis, opts);
5835
+ function any(a, axis = null, opts) {
5836
+ return fudgeArray(a).any(axis, opts);
5782
5837
  }
5783
5838
  /**
5784
- * Test whether any array element along a given axis evaluates to True.
5839
+ * Test whether all array elements along a given axis evaluate to True.
5785
5840
  *
5786
5841
  * Returns a boolean array with the same shape as `a` with the specified axis
5787
5842
  * removed. If axis is None, returns a scalar.
5788
5843
  */
5789
- function any(a, axis = null, opts) {
5790
- a = fudgeArray(a).astype(require_backend.DType.Bool);
5791
- return max(a, axis, opts);
5844
+ function all(a, axis = null, opts) {
5845
+ return fudgeArray(a).all(axis, opts);
5792
5846
  }
5793
5847
  /** Return the peak-to-peak range along a given axis (`max - min`). */
5794
5848
  function ptp(a, axis = null, opts) {
@@ -5889,7 +5943,7 @@ function split$1(a, indicesOrSections, axis = 0) {
5889
5943
  const partSize = size$1 / indicesOrSections;
5890
5944
  sizes = require_backend.rep(indicesOrSections, partSize);
5891
5945
  } else {
5892
- const indices = indicesOrSections;
5946
+ const indices = indicesOrSections.map((i) => i < 0 ? i + size$1 : i);
5893
5947
  sizes = [indices[0]];
5894
5948
  for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
5895
5949
  sizes.push(size$1 - indices[indices.length - 1]);
@@ -6010,6 +6064,17 @@ function flipud(x) {
6010
6064
  function fliplr(x) {
6011
6065
  return flip(x, 1);
6012
6066
  }
6067
+ /** Interchange two axes of an array. */
6068
+ function swapaxes(a, axis1, axis2) {
6069
+ a = fudgeArray(a);
6070
+ axis1 = require_backend.checkAxis(axis1, a.ndim);
6071
+ axis2 = require_backend.checkAxis(axis2, a.ndim);
6072
+ if (axis1 === axis2) return a;
6073
+ const perm = require_backend.range(a.ndim);
6074
+ perm[axis1] = axis2;
6075
+ perm[axis2] = axis1;
6076
+ return transpose(a, perm);
6077
+ }
6013
6078
  /** Transpose the last two dimensions of an array. */
6014
6079
  function matrixTranspose(a) {
6015
6080
  if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
@@ -6826,6 +6891,21 @@ function isposinf(x) {
6826
6891
  return require_backend.isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
6827
6892
  }
6828
6893
  /**
6894
+ * Replace NaN and infinite entries in an array.
6895
+ *
6896
+ * By default, NaNs are replaced with `0.0`, and infinities are are substituted
6897
+ * with the corresponding maximum or minimum finite values.
6898
+ */
6899
+ function nanToNum(x, { nan: nan$1 = 0, posinf = null, neginf = null } = {}) {
6900
+ x = fudgeArray(x);
6901
+ x = where(isnan(x.ref), nan$1, x);
6902
+ posinf ??= require_backend.isFloatDtype(x.dtype) ? finfo(x.dtype).max : iinfo(x.dtype).max;
6903
+ neginf ??= require_backend.isFloatDtype(x.dtype) ? finfo(x.dtype).min : iinfo(x.dtype).min;
6904
+ x = where(isposinf(x.ref), posinf, x);
6905
+ x = where(isneginf(x.ref), neginf, x);
6906
+ return x;
6907
+ }
6908
+ /**
6829
6909
  * @function
6830
6910
  * Test element-wise for finite values (not infinity or NaN).
6831
6911
  */
@@ -6938,6 +7018,7 @@ var lax_exports = {};
6938
7018
  __export(lax_exports, {
6939
7019
  conv: () => conv,
6940
7020
  convGeneralDilated: () => convGeneralDilated,
7021
+ convTranspose: () => convTranspose,
6941
7022
  convWithGeneralPadding: () => convWithGeneralPadding,
6942
7023
  dot: () => dot,
6943
7024
  erf: () => erf,
@@ -6946,6 +7027,7 @@ __export(lax_exports, {
6946
7027
  reduceWindow: () => reduceWindow,
6947
7028
  stopGradient: () => stopGradient$1
6948
7029
  });
7030
+ const JsArray = globalThis.Array;
6949
7031
  /**
6950
7032
  * General dot product/contraction operator.
6951
7033
  *
@@ -7017,7 +7099,11 @@ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
7017
7099
  * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
7018
7100
  * function in JAX, which wraps XLA's general convolution operator.
7019
7101
  *
7020
- * Grouped convolutions are not supported right now.
7102
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
7103
+ * @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
7104
+ * @param windowStrides - Strides for each spatial dimension
7105
+ * @param padding - Padding for each spatial dimension, or a string
7106
+ * (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
7021
7107
  */
7022
7108
  function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
7023
7109
  if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
@@ -7077,6 +7163,60 @@ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, r
7077
7163
  function conv(lhs, rhs, windowStrides, padding) {
7078
7164
  return convGeneralDilated(lhs, rhs, windowStrides, padding);
7079
7165
  }
7166
+ /**
7167
+ * Convenience wrapper for calculating the N-d convolution "transpose".
7168
+ *
7169
+ * This function directly calculates a fractionally strided conv rather than
7170
+ * indirectly calculating the gradient (transpose) of a forward convolution.
7171
+ * It is equivalent to the JAX version, except:
7172
+ *
7173
+ * - The `use_consistent_padding` option is not available. We only have the
7174
+ * consistent padding case (JAX version >0.8.4).
7175
+ * - The order of dimensions matches `lax.conv_general_dilated`.
7176
+ *
7177
+ * Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
7178
+ * dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
7179
+ * `transposeKernel` to true.
7180
+ *
7181
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
7182
+ * @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
7183
+ * @param strides - Sequence of n integers, sets fractional stride
7184
+ * @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
7185
+ * each side of the input, so it acts like gradient of `conv()`
7186
+ * @param rhsDilation - Atrous dilation for the kernel
7187
+ * @param transposeKernel - Flip spatial axes and swap the input/output channels
7188
+ * of the kernel; its shape should be `[C_in, C_out, ...ks]`
7189
+ */
7190
+ function convTranspose(lhs, rhs, strides, padding, { rhsDilation, transposeKernel = false } = {}) {
7191
+ const kernelShape = rhs.shape.slice(2);
7192
+ rhsDilation = rhsDilation ?? require_backend.rep(kernelShape.length, 1);
7193
+ const effectiveKernel = kernelShape.map((k, i) => Math.max(0, (k - 1) * rhsDilation[i] + 1));
7194
+ const pads = effectiveKernel.map((k, i) => convTransposePadding(k, strides[i], typeof padding === "string" ? padding : padding[i]));
7195
+ if (transposeKernel) {
7196
+ rhs = flip$1(rhs, require_backend.range(2, rhs.ndim));
7197
+ rhs = moveaxis(rhs, 0, 1);
7198
+ }
7199
+ return convGeneralDilated(lhs, rhs, require_backend.rep(lhs.ndim - 2, 1), pads, {
7200
+ lhsDilation: strides,
7201
+ rhsDilation
7202
+ });
7203
+ }
7204
+ function convTransposePadding(k, s, padding) {
7205
+ let padLen;
7206
+ let pad1;
7207
+ if (padding === "SAME") {
7208
+ padLen = k + s - 2;
7209
+ pad1 = s > k - 1 ? k - 1 : Math.ceil(padLen / 2);
7210
+ } else if (padding === "VALID") {
7211
+ padLen = k + s - 2 + Math.max(k - s, 0);
7212
+ pad1 = k - 1;
7213
+ } else if (JsArray.isArray(padding)) {
7214
+ const pads = [k - 1 - padding[0], k - 1 - padding[1]];
7215
+ pad1 = pads[0];
7216
+ padLen = pads[0] + pads[1];
7217
+ } else throw new Error(`convTranspose: Invalid padding type ${padding}`);
7218
+ return [pad1, padLen - pad1];
7219
+ }
7080
7220
  /** Reduce a computation over padded windows. */
7081
7221
  function reduceWindow(operand, computation, windowDimensions, windowStrides) {
7082
7222
  if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
@@ -7115,6 +7255,7 @@ function stopGradient$1(x) {
7115
7255
  var nn_exports = {};
7116
7256
  __export(nn_exports, {
7117
7257
  celu: () => celu,
7258
+ dotProductAttention: () => dotProductAttention,
7118
7259
  elu: () => elu,
7119
7260
  gelu: () => gelu,
7120
7261
  glu: () => glu,
@@ -7431,6 +7572,125 @@ function oneHot(x, numClasses) {
7431
7572
  if (require_backend.isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
7432
7573
  return eye(numClasses, void 0, { device: x.device }).slice(x);
7433
7574
  }
7575
/**
 * Scaled dot product attention (SDPA).
 *
 * Computes `softmax((Q @ K^T) * scale + bias) @ V`, where `Q` is the query,
 * `K` is the key, `V` is the value, and `scale` defaults to `1 / sqrt(H)`
 * (`H` being the dimensionality of each key and query vector).
 *
 * Grouped-query attention (GQA) is applied when `key` and `value` have fewer
 * heads than `query`; multi-query attention is the special case `K = 1`. The
 * `N` query heads are split into `K` contiguous groups of `G = N / K` heads.
 * Query head `n` attends to key/value head `Math.floor(n / G)`.
 *
 * Masked-out logits are replaced by a large finite negative number rather than
 * `-Infinity`. This means a query row whose keys are all masked (for example
 * a query position beyond its `querySeqLengths` entry) does not feed an
 * all-`-Infinity` row into `softmax`, which would yield NaN.
 *
 * We use the following uppercase letters to denote array shapes:
 * - `B` = batch size
 * - `S` = length of key/value sequences (source)
 * - `L` = length of query sequences
 * - `N` = number of attention heads
 * - `H` = dimensionality of each attention head
 * - `K` = number of key/value heads (for grouped-query attention)
 *
 * The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
 * case it must be omitted from all inputs.
 *
 * @param query - Query array; shape `[B, L, N, H]`
 * @param key - Key array; shape `[B, S, K, H]`
 * @param value - Value array; same shape as `key`
 * @param opts.bias - Optional bias to add to the attention logits; shape
 *   `[B, N, L, S]` or broadcastable to it.
 * @param opts.mask - Optional mask to apply to the attention logits; should be
 *   a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
 *   the element should take part in attention.
 * @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
 * @param opts.isCausal - If true, applies a causal mask (query `i` may only
 *   attend to keys `j <= i`).
 * @param opts.querySeqLengths - Optional sequence lengths for the queries;
 *   shape `[B]`. Taken from the beginning of the tensor.
 * @param opts.keyValueSeqLengths - Optional sequence lengths for the keys and
 *   values; shape `[B]`. Taken from the beginning of the tensor.
 * @param opts.localWindowSize - If specified, applies a local attention window
 *   of the given size. Can be a single number or a tuple `[left, right]`;
 *   query `i` then attends to keys `j` with `i - left <= j <= i + right`.
 *
 * @returns The result of the attention operation; shape is the same as query
 *   `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
 *
 * @throws If the inputs are not all rank 3 or all rank 4, or if `key` and
 *   `value` shapes differ. Also throws if `query` and `key` disagree on `B` or
 *   `H`, if `N` is not a multiple of `K`, or if `localWindowSize` contains a
 *   negative or non-integer value.
 */
function dotProductAttention(query, key$1, value, opts = {}) {
	query = fudgeArray(query);
	key$1 = fudgeArray(key$1);
	value = fudgeArray(value);
	if (query.ndim !== 3 && query.ndim !== 4 || query.ndim !== key$1.ndim || query.ndim !== value.ndim) throw new Error(`dotProductAttention: expected all tensors to have rank 3 or 4, got Q=${query.aval}, K=${key$1.aval}, V=${value.aval}`);
	if (!require_backend.deepEqual(key$1.shape, value.shape)) throw new Error(`dotProductAttention: key and value shapes must match, got K=${key$1.shape}, V=${value.shape}`);
	// Unbatched inputs are treated as a batch of one and reshaped back at the end.
	const isRank3 = query.ndim === 3;
	if (isRank3) {
		query = expandDims(query, 0);
		key$1 = expandDims(key$1, 0);
		value = expandDims(value, 0);
	}
	const [B, L, N, H] = query.shape;
	if (key$1.shape[0] !== B || key$1.shape[3] !== H) throw new Error(`dotProductAttention: query and key shapes mismatch, got Q=${query.aval}, K=${key$1.aval}`);
	const S = key$1.shape[1];
	const K = key$1.shape[2];
	if (N < K || N != K && N % K !== 0) throw new Error(`dotProductAttention: number of query heads N=${N} must be divisible by number of key/value heads K=${K} for GQA`);
	const G = N / K;
	// Repeat each key/value head G times in place ([k0, k0, k1, k1, ...]) so
	// that query head n lines up with key/value head floor(n / G). Tiling the
	// whole head axis instead ([k0, k1, k0, k1, ...]) would pair query head n
	// with key/value head n % K, which is not the standard GQA grouping.
	const repeatHeads = (x) => G === 1 ? x : tile(expandDims(x, 3), [
		1,
		1,
		1,
		G,
		1
	]).reshape([
		B,
		S,
		N,
		H
	]);
	key$1 = repeatHeads(key$1);
	value = repeatHeads(value);
	// Large finite negative logit for masked positions (-0.7 * float32 max,
	// the same choice JAX makes). Using -Infinity would make rows with every
	// key masked come out of softmax as NaN.
	const maskValue = -.7 * 34028234663852886e22;
	const scale = opts.scale ?? 1 / Math.sqrt(H);
	let scores = einsum("BLNH,BSNH->BNLS", query, key$1).mul(scale);
	if (opts.bias !== void 0) scores = scores.add(opts.bias);
	if (opts.mask !== void 0) scores = where(opts.mask, scores, maskValue);
	if (opts.isCausal) {
		const causalMask = tri(L, S, 0, { dtype: require_backend.DType.Bool });
		scores = where(causalMask, scores, maskValue);
	}
	if (opts.localWindowSize !== void 0) {
		const [before, after] = typeof opts.localWindowSize === "number" ? [opts.localWindowSize, opts.localWindowSize] : opts.localWindowSize;
		if (before < 0 || after < 0 || !Number.isInteger(before) || !Number.isInteger(after)) throw new Error(`dotProductAttention: localWindowSize values must be non-negative integers, got ${opts.localWindowSize}`);
		// Keep keys j <= i + after, and drop keys j <= i - before - 1.
		const localMask = tri(L, S, after, { dtype: require_backend.DType.Bool }).mul(tri(L, S, -before - 1, { dtype: require_backend.DType.Bool }).notEqual(true));
		scores = where(localMask, scores, maskValue);
	}
	if (opts.querySeqLengths !== void 0) {
		// [B] -> [B, 1, 1, 1]; masks query positions l >= querySeqLengths[b].
		const sl = expandDims(opts.querySeqLengths, [
			-1,
			-2,
			-3
		]);
		scores = where(arange(L).reshape([
			1,
			1,
			L,
			1
		]).less(sl), scores, maskValue);
	}
	if (opts.keyValueSeqLengths !== void 0) {
		// [B] -> [B, 1, 1, 1]; masks key positions s >= keyValueSeqLengths[b].
		const sl = expandDims(opts.keyValueSeqLengths, [
			-1,
			-2,
			-3
		]);
		scores = where(arange(S).reshape([
			1,
			1,
			1,
			S
		]).less(sl), scores, maskValue);
	}
	const attn = softmax(scores, -1);
	const out = einsum("BNLS,BSNH->BLNH", attn, value);
	return isRank3 ? out.reshape([
		L,
		N,
		H
	]) : out;
}
7434
7694
 
7435
7695
  //#endregion
7436
7696
  //#region src/library/random.ts
@@ -7666,17 +7926,62 @@ const linearize = linearize$1;
7666
7926
/**
* @function
* Reverse-mode vector-Jacobian product.
*
* `vjp(f, primals)` evaluates `f` at `primals` and returns `[out, vjpFn]`.
* Calling `vjpFn` with one cotangent per output of `f` produces one cotangent
* per input.
*
* With `{ hasAux: true }`, `f` must return an `[out, aux]` pair. The auxiliary
* value is passed through as a third element: `[out, vjpFn, aux]`.
*
* @example
* ```ts
* const [y, vjpFn] = vjp(f, [x]);
*
* // Passing through auxiliary data
* const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
* ```
*/
const vjp = vjp$1;
7671
7946
/**
* @function
* Gradient of a scalar-valued function.
*
* `grad(f)` returns a function taking the same arguments as `f`. That function
* returns the gradient of `f` with respect to its first argument.
*
* Pass `argnums` to differentiate with respect to a different argument. An
* array of indices yields an array of gradients, one per listed argument, in
* the same order.
*
* With `{ hasAux: true }`, `f` must return `[out, aux]`. The result is then
* `[gradient, aux]`.
*
* @example
* ```ts
* const gradient = grad(f)(x);
*
* // Several arguments at once
* const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
*
* // Passing through auxiliary data
* const [gradient, aux] = grad(f, { hasAux: true })(x);
* ```
*/
const grad = grad$1;
7677
7970
/**
* @function
* Build a function that computes `f` and its gradient together, returning
* `[value, gradient]`.
*
* With `{ hasAux: true }`, `f` must return `[out, aux]`. The result nests the
* value and auxiliary data: `[[value, aux], gradient]`.
*
* @example
* ```ts
* const [value, gradient] = valueAndGrad(f)(x);
*
* // Passing through auxiliary data
* const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
* ```
*/
const valueAndGrad = valueAndGrad$1;
7682
7987
  /**
@@ -7685,6 +7990,21 @@ const valueAndGrad = valueAndGrad$1;
7685
7990
  */
7686
7991
  const jacrev = jacrev$1;
7687
7992
/**
* @function
* Hessian of a scalar-valued function: the matrix of its second-order partial
* derivatives.
*
* Computed as forward-over-reverse differentiation, `jacfwd(grad(f))`.
*
* @example
* ```ts
* const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // sum of x^3
* const H = hessian(f)(np.array([1, 2, 3]));
* // H[i,j] = d^2f / dx_i dx_j
* ```
*/
const hessian = hessian$1;
8007
+ /**
7688
8008
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
7689
8009
  *
7690
8010
  * This can be used to wait for the results of an intermediate computation to
@@ -7728,6 +8048,7 @@ exports.defaultDevice = require_backend.defaultDevice;
7728
8048
  exports.devicePut = devicePut;
7729
8049
  exports.devices = require_backend.devices;
7730
8050
  exports.grad = grad;
8051
+ exports.hessian = hessian;
7731
8052
  exports.init = require_backend.init;
7732
8053
  exports.jacfwd = jacfwd;
7733
8054
  exports.jacobian = jacrev;