@jax-js/jax 0.1.5 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-DziQSaoQ.cjs');
33
+ const require_backend = require('./backend-D7s-Retx.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
@@ -240,7 +240,7 @@ __export(tree_exports, {
240
240
  structure: () => structure,
241
241
  unflatten: () => unflatten
242
242
  });
243
/** Native `Array` constructor under a distinct name (the library's own array class is `Array$1` in this bundle). */
const { Array: JsArray$2 } = globalThis;
244
244
  let NodeType = /* @__PURE__ */ function(NodeType$1) {
245
245
  NodeType$1["Array"] = "Array";
246
246
  NodeType$1["Object"] = "Object";
@@ -288,7 +288,7 @@ function flatten(tree) {
288
288
  return [leaves$1, treedef];
289
289
  }
290
290
  function _flatten(tree, leaves$1) {
291
- if (JsArray$1.isArray(tree)) {
291
+ if (JsArray$2.isArray(tree)) {
292
292
  const childTrees = tree.map((c) => _flatten(c, leaves$1));
293
293
  return new JsTreeDef(NodeType.Array, null, childTrees);
294
294
  } else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
@@ -412,6 +412,13 @@ let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
412
412
  CompareOp$1["LessEqual"] = "less_equal";
413
413
  return CompareOp$1;
414
414
  }({});
415
/**
 * Primitives that are executed as backend `Routine`s, keyed by primitive and
 * mapped to the `require_backend.Routines` entry of the same name.
 * Insertion order: Sort, Argsort, TriangularSolve, Cholesky, LU.
 */
const routinePrimitives = new Map(
  ["Sort", "Argsort", "TriangularSolve", "Cholesky", "LU"].map((name) => [
    Primitive[name],
    require_backend.Routines[name]
  ])
);
415
422
  function add$1(x, y) {
416
423
  return bind1(Primitive.Add, [x, y]);
417
424
  }
@@ -685,6 +692,9 @@ function newDynamic(main) {
685
692
  dynamicTrace = prevDynamicTrace;
686
693
  } };
687
694
  }
695
/** Return the `level` of the last entry (top) of `traceStack`. */
function currentTraceLevel() {
  const top = traceStack[traceStack.length - 1];
  return top.level;
}
688
698
  var Trace = class {
689
699
  constructor(main) {
690
700
  this.main = main;
@@ -1062,6 +1072,7 @@ var TreeMismatchError = class extends TypeError {
1062
1072
  super(`Mismatched tree structures in ${where$2}: ${left} != ${right}`);
1063
1073
  }
1064
1074
  };
1075
+ /** Flatten a function of `JsTree` input/output for use in tracing. */
1065
1076
  function flattenFun(f, inTree) {
1066
1077
  const store = { value: void 0 };
1067
1078
  const flatFun = (...argsFlat) => {
@@ -1073,6 +1084,26 @@ function flattenFun(f, inTree) {
1073
1084
  };
1074
1085
  return [flatFun, store];
1075
1086
  }
1087
/**
 * Like `flattenFun`, but for a function `f` that returns an `[out, aux]` pair.
 *
 * Returns `[flatFun, store, auxStore]`. Calling `flatFun(...argsFlat)` rebuilds
 * the arguments with `unflatten(inTree, argsFlat)` and calls `f` on them. It then
 * records the tree structure of `out` in `store.value` and the raw `aux` in
 * `auxStore.value`, and returns the flattened leaves of `out`. `aux` is stored
 * as-is and is never flattened.
 *
 * @throws Error (from `flatFun`) if `f` does not return a two-element array.
 */
function flattenFunWithAux(f, inTree) {
  const store = { value: void 0 };
  const auxStore = { value: void 0 };
  function flatFun(...argsFlat) {
    const result = f(...unflatten(inTree, argsFlat));
    const isPair = Array.isArray(result) && result.length === 2;
    if (!isPair) throw new Error("Function with `hasAux: true` must return [output, aux] tuple");
    const [outFlat, outTree] = flatten(result[0]);
    store.value = outTree;
    auxStore.value = result[1];
    return outFlat;
  }
  return [flatFun, store, auxStore];
}
1076
1107
  var UseAfterFreeError = class extends ReferenceError {
1077
1108
  constructor(tracer) {
1078
1109
  super(`Referenced tracer ${tracer.toString()} freed, please use .ref move semantics`);
@@ -1806,13 +1837,6 @@ function jit$1(f, opts) {
1806
1837
 
1807
1838
  //#endregion
1808
1839
  //#region src/frontend/jit.ts
1809
- const routinePrimitives = new Map([
1810
- [Primitive.Sort, require_backend.Routines.Sort],
1811
- [Primitive.Argsort, require_backend.Routines.Argsort],
1812
- [Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
1813
- [Primitive.Cholesky, require_backend.Routines.Cholesky],
1814
- [Primitive.LU, require_backend.Routines.LU]
1815
- ]);
1816
1840
  /** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
1817
1841
  var JitProgram = class {
1818
1842
  constructor(backend, steps, inputs, outputs) {
@@ -2201,12 +2225,13 @@ const jitRules = {
2201
2225
  const ndim$2 = avals[0].ndim;
2202
2226
  const sizes = avals.map((x) => x.shape[axis]);
2203
2227
  const finalSize = sizes.reduce((a, b) => a + b, 0);
2228
+ const { dtype: dtypeOut } = avals.map((x) => x.scalar()).reduce(promoteAvals);
2204
2229
  const makePadAxis = (start, end) => require_backend.range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
2205
2230
  let cum = 0;
2206
2231
  const src = [];
2207
2232
  for (let i = 0; i < exps.length; i++) {
2208
2233
  const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
2209
- src.push(reshapeViews(exps[i], (st) => st.pad(padding)));
2234
+ src.push(reshapeViews(require_backend.AluExp.cast(dtypeOut, exps[i]), (st) => st.pad(padding)));
2210
2235
  cum += sizes[i];
2211
2236
  }
2212
2237
  return { exp: [src.reduce(require_backend.AluExp.add)] };
@@ -2344,7 +2369,7 @@ function splitGraphDataflow(backend, jaxpr) {
2344
2369
  p1NextBlack.set(v, v);
2345
2370
  }
2346
2371
  const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
2347
- const needsCleanShapePrimitives = [Primitive.Pad];
2372
+ const needsCleanShapePrimitives = [Primitive.Concatenate, Primitive.Pad];
2348
2373
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
2349
2374
  const eqn = jaxpr.eqns[i];
2350
2375
  if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
@@ -2414,7 +2439,7 @@ function splitGraphDataflow(backend, jaxpr) {
2414
2439
 
2415
2440
  //#endregion
2416
2441
  //#region src/frontend/array.ts
2417
/** Native `Array` constructor under a distinct name (the library's own array class is `Array$1` in this bundle). */
const { Array: JsArray$1 } = globalThis;
/** Inline-array size limit (128); its consumers are outside this section. */
const inlineArrayLimit = 2 ** 7;
/** Version of `pureArray` with fudged types; at runtime it is the very same function. */
const fudgeArray = pureArray;
@@ -2812,25 +2837,35 @@ var Array$1 = class Array$1 extends Tracer {
2812
2837
  });
2813
2838
  }
2814
2839
  /** Apply an operation with custom lowering to this array. */
2815
- static #routine(routine, arrays, outputWeakType) {
2816
- const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
2817
- for (const ar of arrays) ar.#realize();
2818
- const inputs = arrays.map((ar) => ar.#source);
2819
- const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(require_backend.byteWidth(dtype) * require_backend.prod(routine.type.outputShapes[i])));
2820
- const pending = arrays.flatMap((ar) => ar.#pending);
2821
- for (const exe of pending) exe.updateRc(+outputs.length);
2822
- pending.push(new PendingExecute(backend, routine, inputs, outputs));
2823
- pending[pending.length - 1].updateRc(+outputs.length - 1);
2824
- arrays.forEach((ar) => ar.dispose());
2825
- return outputs.map((output, i) => new Array$1({
2826
- source: output,
2827
- st: require_backend.ShapeTracker.fromShape(routine.type.outputShapes[i]),
2828
- dtype: routine.type.outputDtypes[i],
2829
- weakType: outputWeakType[i],
2830
- backend,
2831
- committed,
2832
- pending
2833
- }));
2840
+ static #routine(prim) {
2841
+ return (arrays, params) => {
2842
+ const { backend, committed } = Array$1.#computeBackend(prim, arrays);
2843
+ for (const ar of arrays) ar.#realize();
2844
+ const avals = arrays.map((ar) => ar.aval);
2845
+ const avalsOut = abstractEvalRules[prim](avals, params);
2846
+ const routine = new require_backend.Routine(routinePrimitives.get(prim), {
2847
+ inputShapes: avals.map((a) => a.shape),
2848
+ inputDtypes: avals.map((a) => a.dtype),
2849
+ outputShapes: avalsOut.map((a) => a.shape),
2850
+ outputDtypes: avalsOut.map((a) => a.dtype)
2851
+ }, params);
2852
+ const inputs = arrays.map((ar) => ar.#source);
2853
+ const outputs = avalsOut.map((x) => backend.malloc(require_backend.byteWidth(x.dtype) * x.size));
2854
+ const pending = arrays.flatMap((ar) => ar.#pending);
2855
+ for (const exe of pending) exe.updateRc(+outputs.length);
2856
+ pending.push(new PendingExecute(backend, routine, inputs, outputs));
2857
+ pending[pending.length - 1].updateRc(+outputs.length - 1);
2858
+ arrays.forEach((ar) => ar.dispose());
2859
+ return outputs.map((output, i) => new Array$1({
2860
+ source: output,
2861
+ st: require_backend.ShapeTracker.fromShape(avalsOut[i].shape),
2862
+ dtype: avalsOut[i].dtype,
2863
+ weakType: avalsOut[i].weakType,
2864
+ backend,
2865
+ committed,
2866
+ pending
2867
+ }));
2868
+ };
2834
2869
  }
2835
2870
  /**
2836
2871
  * Normalizes this array into one backed by a `Slot`.
@@ -3164,65 +3199,11 @@ var Array$1 = class Array$1 extends Tracer {
3164
3199
  [Primitive.Pad]([x], { width }) {
3165
3200
  return [x.#reshape(x.#st.pad(width))];
3166
3201
  },
3167
- [Primitive.Sort]([x]) {
3168
- const routine = new require_backend.Routine(require_backend.Routines.Sort, {
3169
- inputShapes: [x.shape],
3170
- inputDtypes: [x.dtype],
3171
- outputShapes: [x.shape],
3172
- outputDtypes: [x.dtype]
3173
- });
3174
- return Array$1.#routine(routine, [x], [x.#weakType]);
3175
- },
3176
- [Primitive.Argsort]([x]) {
3177
- const routine = new require_backend.Routine(require_backend.Routines.Argsort, {
3178
- inputShapes: [x.shape],
3179
- inputDtypes: [x.dtype],
3180
- outputShapes: [x.shape, x.shape],
3181
- outputDtypes: [x.dtype, require_backend.DType.Int32]
3182
- });
3183
- return Array$1.#routine(routine, [x], [x.#weakType, false]);
3184
- },
3185
- [Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
3186
- const routine = new require_backend.Routine(require_backend.Routines.TriangularSolve, {
3187
- inputShapes: [a.shape, b.shape],
3188
- inputDtypes: [a.dtype, b.dtype],
3189
- outputShapes: [b.shape],
3190
- outputDtypes: [b.dtype]
3191
- }, { unitDiagonal });
3192
- return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
3193
- },
3194
- [Primitive.Cholesky]([a]) {
3195
- const routine = new require_backend.Routine(require_backend.Routines.Cholesky, {
3196
- inputShapes: [a.shape],
3197
- inputDtypes: [a.dtype],
3198
- outputShapes: [a.shape],
3199
- outputDtypes: [a.dtype]
3200
- });
3201
- return Array$1.#routine(routine, [a], [a.#weakType]);
3202
- },
3203
- [Primitive.LU]([a]) {
3204
- const batch = a.shape.slice(0, -2);
3205
- const [m, n] = a.shape.slice(-2);
3206
- const routine = new require_backend.Routine(require_backend.Routines.LU, {
3207
- inputShapes: [a.shape],
3208
- inputDtypes: [a.dtype],
3209
- outputShapes: [
3210
- a.shape,
3211
- [...batch, Math.min(m, n)],
3212
- [...batch, m]
3213
- ],
3214
- outputDtypes: [
3215
- a.dtype,
3216
- require_backend.DType.Int32,
3217
- require_backend.DType.Int32
3218
- ]
3219
- });
3220
- return Array$1.#routine(routine, [a], [
3221
- a.#weakType,
3222
- false,
3223
- false
3224
- ]);
3225
- },
3202
+ [Primitive.Sort]: Array$1.#routine(Primitive.Sort),
3203
+ [Primitive.Argsort]: Array$1.#routine(Primitive.Argsort),
3204
+ [Primitive.TriangularSolve]: Array$1.#routine(Primitive.TriangularSolve),
3205
+ [Primitive.Cholesky]: Array$1.#routine(Primitive.Cholesky),
3206
+ [Primitive.LU]: Array$1.#routine(Primitive.LU),
3226
3207
  [Primitive.Jit](args, { jaxpr }) {
3227
3208
  if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
3228
3209
  const { backend, committed } = Array$1.#computeBackend("jit", args);
@@ -3304,7 +3285,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
3304
3285
  if (!shape$1) {
3305
3286
  shape$1 = [];
3306
3287
  let cur = values;
3307
- while (JsArray.isArray(cur)) {
3288
+ while (JsArray$1.isArray(cur)) {
3308
3289
  shape$1.push(cur.length);
3309
3290
  cur = cur[0];
3310
3291
  }
@@ -4260,17 +4241,39 @@ function jvpFlat(f, primals, tangents) {
4260
4241
  _usingCtx$1.d();
4261
4242
  }
4262
4243
  }
4263
/**
 * Forward-mode Jacobian-vector product of `f` at `primals` along `tangents`.
 *
 * `primals` and `tangents` are JsTrees with identical structure; a mismatch
 * throws `TreeMismatchError`. Returns `[primalsOut, tangentsOut]`, both
 * unflattened to the output tree of `f`.
 *
 * With `{ hasAux: true }`, `f` must return `[out, aux]`. The result then gains
 * a third element: `aux` passed through `lowerAux`.
 */
function jvp$1(f, primals, tangents, { hasAux = false } = {}) {
  const [primalsFlat, inTree] = flatten(primals);
  const [tangentsFlat, tangentTree] = flatten(tangents);
  if (!inTree.equals(tangentTree)) throw new TreeMismatchError("jvp", inTree, tangentTree);
  const [flatFun, outTree, aux] = hasAux ? flattenFunWithAux(f, inTree) : flattenFun(f, inTree);
  const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
  if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
  const treedef = outTree.value;
  const result = [unflatten(treedef, primalsOutFlat), unflatten(treedef, tangentsOutFlat)];
  if (hasAux) result.push(lowerAux(aux.value));
  return result;
}
4262
/**
 * Lower auxiliary data returned by `hasAux: true` functions down to the level
 * returned by `currentTraceLevel()`.
 *
 * A `Tracer` leaf of `aux` is unwrapped while its trace level is above that
 * level. A `JVPTracer` has its tangent disposed and is replaced by its primal.
 * Any other tracer is replaced by `fullLower()`, which must strictly reduce
 * the level. Non-tracer leaves are returned unchanged.
 *
 * @throws Error if `fullLower()` does not reduce a tracer's trace level.
 */
function lowerAux(aux) {
  const targetLevel = currentTraceLevel();
  const lowerLeaf = (leaf) => {
    if (!(leaf instanceof Tracer)) return leaf;
    let cur = leaf;
    while (cur._trace.main.level > targetLevel) {
      if (cur instanceof JVPTracer) {
        cur.tangent.dispose();
        cur = cur.primal;
        continue;
      }
      const lowered = cur.fullLower();
      if (lowered._trace.main.level >= cur._trace.main.level) throw new Error("internal: lowerAux did not reduce trace level");
      cur = lowered;
    }
    return cur;
  };
  return map(lowerLeaf, aux);
}
4274
4277
 
4275
4278
  //#endregion
4276
4279
  //#region src/frontend/linearize.ts
@@ -4341,9 +4344,11 @@ function linearizeFlat(f, primalsIn) {
4341
4344
  dispose$1
4342
4345
  ];
4343
4346
  }
4344
- function linearize$1(f, ...primalsIn) {
4347
+ function linearize$1(f, primalsIn, { hasAux = false } = {}) {
4345
4348
  const [primalsInFlat, inTree] = flatten(primalsIn);
4346
- const [fFlat, outTree] = flattenFun(f, inTree);
4349
+ let fFlat, outTree, aux;
4350
+ if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
4351
+ else [fFlat, outTree] = flattenFun(f, inTree);
4347
4352
  const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
4348
4353
  if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
4349
4354
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
@@ -4354,6 +4359,11 @@ function linearize$1(f, ...primalsIn) {
4354
4359
  return unflatten(outTree.value, tangentsOutFlat);
4355
4360
  });
4356
4361
  fLin.dispose = dispose$1;
4362
+ if (hasAux) return [
4363
+ primalsOut,
4364
+ fLin,
4365
+ lowerAux(aux.value)
4366
+ ];
4357
4367
  return [primalsOut, fLin];
4358
4368
  }
4359
4369
  var PartialEvalTracer = class extends Tracer {
@@ -4854,9 +4864,11 @@ function vjpFlat(f, primalsIn) {
4854
4864
  dispose$1
4855
4865
  ];
4856
4866
  }
4857
- function vjp$1(f, ...primalsIn) {
4867
+ function vjp$1(f, primalsIn, { hasAux = false } = {}) {
4858
4868
  const [primalsInFlat, inTree] = flatten(primalsIn);
4859
- const [fFlat, outTree] = flattenFun(f, inTree);
4869
+ let fFlat, outTree, aux;
4870
+ if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
4871
+ else [fFlat, outTree] = flattenFun(f, inTree);
4860
4872
  const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
4861
4873
  if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
4862
4874
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
@@ -4867,26 +4879,43 @@ function vjp$1(f, ...primalsIn) {
4867
4879
  return unflatten(inTree, cotangentsInFlat);
4868
4880
  });
4869
4881
  fVjp.dispose = dispose$1;
4882
+ if (hasAux) return [
4883
+ primalsOut,
4884
+ fVjp,
4885
+ lowerAux(aux.value)
4886
+ ];
4870
4887
  return [primalsOut, fVjp];
4871
4888
  }
4872
/**
 * Gradient of a scalar-valued `f`, built on `valueAndGrad$1(f, opts)`.
 *
 * The returned function disposes the primal output value. It returns only the
 * gradient, or `[gradient, aux]` when `opts.hasAux` is set.
 */
function grad$1(f, opts) {
  const valueAndGradFn = valueAndGrad$1(f, opts);
  return (...x) => {
    const [value, dx] = valueAndGradFn(...x);
    if (!opts?.hasAux) {
      value.dispose();
      return dx;
    }
    const [y, aux] = value;
    y.dispose();
    return [dx, aux];
  };
}
4880
/**
 * Create a function that evaluates `f` and its gradient with respect to the
 * positional arguments selected by `opts.argnums` (default `0`).
 *
 * Arguments not listed in `argnums` are wrapped with `stopGradient`, and their
 * cotangents are disposed. The returned function yields `[value, grads]`, or
 * `[[value, aux], grads]` with `opts.hasAux`. `grads` is a single gradient
 * when `argnums` is a number; otherwise it is an array in `argnums` order.
 *
 * @throws Error if called with no arguments.
 * @throws TypeError if an `argnums` index is negative or not smaller than the
 *   number of arguments passed, or if `f`'s output is not a scalar
 *   floating-point array.
 */
function valueAndGrad$1(f, opts) {
  const argnums = opts?.argnums ?? 0;
  const hasAux = opts?.hasAux ?? false;
  require_backend.checkInts(argnums);
  const argnumsList = typeof argnums === "number" ? [argnums] : argnums;
  const argnumsSet = new Set(argnumsList);
  return (...x) => {
    if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
    // Reject bad indices before tracing (as JAX does). Otherwise the gradient
    // for that index would silently come back as `undefined`.
    for (const i of argnumsList) {
      if (i < 0 || i >= x.length) throw new TypeError(`grad: argnums index ${i} is out of range for ${x.length} positional argument(s)`);
    }
    for (let i = 0; i < x.length; i++) if (!argnumsSet.has(i)) x[i] = map(stopGradient, x[i]);
    const [y, fVjp, aux] = vjp$1(f, x, { hasAux });
    if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
    if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
    const cts = fVjp(onesLike$1(y.ref));
    fVjp.dispose();
    // Cotangents of non-differentiated arguments are not returned; free them.
    for (let i = 0; i < cts.length; i++) if (!argnumsSet.has(i)) dispose(cts[i]);
    const grads = typeof argnums === "number" ? cts[argnums] : argnums.map((i) => cts[i]);
    return hasAux ? [[y, aux], grads] : [y, grads];
  };
}
4892
4921
  function jacrev$1(f) {
@@ -4894,7 +4923,7 @@ function jacrev$1(f) {
4894
4923
  if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
4895
4924
  const [size$1] = x.shape;
4896
4925
  const pullback = (ct) => {
4897
- const [y, fVjp] = vjp$1(f, x);
4926
+ const [y, fVjp] = vjp$1(f, [x]);
4898
4927
  y.dispose();
4899
4928
  const [ret] = fVjp(ct);
4900
4929
  fVjp.dispose();
@@ -4903,6 +4932,9 @@ function jacrev$1(f) {
4903
4932
  return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
4904
4933
  };
4905
4934
  }
4935
/** Hessian of a scalar-valued `f`: `jacfwd$1` applied to `grad$1(f)` (default grad options). */
function hessian$1(f) {
  const gradient = grad$1(f);
  return jacfwd$1(gradient);
}
4906
4938
 
4907
4939
  //#endregion
4908
4940
  //#region src/library/numpy/einsum.ts
@@ -5612,6 +5644,7 @@ __export(numpy_exports, {
5612
5644
  std: () => std,
5613
5645
  subtract: () => subtract,
5614
5646
  sum: () => sum,
5647
+ swapaxes: () => swapaxes,
5615
5648
  take: () => take,
5616
5649
  tan: () => tan,
5617
5650
  tanh: () => tanh,
@@ -6010,6 +6043,17 @@ function flipud(x) {
6010
6043
  function fliplr(x) {
6011
6044
  return flip(x, 1);
6012
6045
  }
6046
/**
 * Interchange two axes of an array.
 *
 * `axis1` and `axis2` are normalized with `checkAxis` against the array's
 * `ndim`. If they name the same axis, the (fudged) input is returned as-is.
 * Otherwise this is `transpose` with those two entries of the identity
 * permutation swapped.
 */
function swapaxes(a, axis1, axis2) {
  const arr = fudgeArray(a);
  const i = require_backend.checkAxis(axis1, arr.ndim);
  const j = require_backend.checkAxis(axis2, arr.ndim);
  if (i === j) return arr;
  const perm = require_backend.range(arr.ndim);
  [perm[i], perm[j]] = [perm[j], perm[i]];
  return transpose(arr, perm);
}
6013
6057
  /** Transpose the last two dimensions of an array. */
6014
6058
  function matrixTranspose(a) {
6015
6059
  if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
@@ -6938,6 +6982,7 @@ var lax_exports = {};
6938
6982
  __export(lax_exports, {
6939
6983
  conv: () => conv,
6940
6984
  convGeneralDilated: () => convGeneralDilated,
6985
+ convTranspose: () => convTranspose,
6941
6986
  convWithGeneralPadding: () => convWithGeneralPadding,
6942
6987
  dot: () => dot,
6943
6988
  erf: () => erf,
@@ -6946,6 +6991,7 @@ __export(lax_exports, {
6946
6991
  reduceWindow: () => reduceWindow,
6947
6992
  stopGradient: () => stopGradient$1
6948
6993
  });
6994
/** Native `Array` constructor under a distinct name (the library's own array class is `Array$1` in this bundle). */
const { Array: JsArray } = globalThis;
6949
6995
  /**
6950
6996
  * General dot product/contraction operator.
6951
6997
  *
@@ -7017,7 +7063,11 @@ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
7017
7063
  * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
7018
7064
  * function in JAX, which wraps XLA's general convolution operator.
7019
7065
  *
7020
- * Grouped convolutions are not supported right now.
7066
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
7067
+ * @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
7068
+ * @param windowStrides - Strides for each spatial dimension
7069
+ * @param padding - Padding for each spatial dimension, or a string
7070
+ * (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
7021
7071
  */
7022
7072
  function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
7023
7073
  if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
@@ -7077,6 +7127,60 @@ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, r
7077
7127
  function conv(lhs, rhs, windowStrides, padding) {
7078
7128
  return convGeneralDilated(lhs, rhs, windowStrides, padding);
7079
7129
  }
7130
/**
 * Convenience wrapper for the N-d convolution "transpose" (a fractionally
 * strided convolution).
 *
 * The result is computed directly as a dilated convolution, not as the
 * gradient (transpose) of a forward `conv()`. It is equivalent to the JAX
 * version, except:
 *
 * - The `use_consistent_padding` option is not available. Only the consistent
 *   padding case (JAX version >0.8.4) is implemented.
 * - The order of dimensions matches `lax.conv_general_dilated`.
 *
 * Unlike PyTorch/TensorFlow, by default the kernel's spatial dimensions and
 * its `(C_out, C_in)` axis order are not reversed. Set `transposeKernel` to
 * true to get that behavior.
 *
 * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
 * @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
 * @param strides - Sequence of n integers; used as the lhs (input) dilation,
 *   i.e. the fractional stride
 * @param padding - `"SAME"`, `"VALID"`, or one `[lo, hi]` pair per spatial
 *   dimension. Explicit pairs apply `dilation * (kernel_size - 1) - pad` to
 *   each side, so the result acts like the gradient of `conv()`.
 * @param rhsDilation - Atrous dilation for the kernel (defaults to all 1s)
 * @param transposeKernel - Flip the spatial axes and swap the first two
 *   kernel axes; the kernel's shape should then be `[C_in, C_out, ...ks]`
 */
function convTranspose(lhs, rhs, strides, padding, { rhsDilation, transposeKernel = false } = {}) {
  const spatialKernel = rhs.shape.slice(2);
  const dilation = rhsDilation ?? require_backend.rep(spatialKernel.length, 1);
  const pads = spatialKernel.map((size, i) => {
    const dilatedSize = Math.max(0, (size - 1) * dilation[i] + 1);
    const spec = typeof padding === "string" ? padding : padding[i];
    return convTransposePadding(dilatedSize, strides[i], spec);
  });
  let kernel = rhs;
  if (transposeKernel) kernel = moveaxis(flip$1(rhs, require_backend.range(2, rhs.ndim)), 0, 1);
  const unitStrides = require_backend.rep(lhs.ndim - 2, 1);
  return convGeneralDilated(lhs, kernel, unitStrides, pads, {
    lhsDilation: strides,
    rhsDilation: dilation
  });
}
7168
/**
 * Per-dimension `[before, after]` padding used by `convTranspose`.
 *
 * `k` is the effective (dilated) kernel size and `s` the stride.
 * - `"SAME"` / `"VALID"`: the same formulas as JAX's `_conv_transpose_padding`.
 * - `[lo, hi]`: consistent explicit padding, `k - 1 - lo` before and
 *   `k - 1 - hi` after.
 *
 * @throws Error for any other padding value.
 */
function convTransposePadding(k, s, padding) {
  if (padding === "SAME") {
    const total = k + s - 2;
    const before = s > k - 1 ? k - 1 : Math.ceil(total / 2);
    return [before, total - before];
  }
  if (padding === "VALID") {
    const total = k + s - 2 + Math.max(k - s, 0);
    const before = k - 1;
    return [before, total - before];
  }
  if (JsArray.isArray(padding)) {
    const before = k - 1 - padding[0];
    const total = before + (k - 1 - padding[1]);
    return [before, total - before];
  }
  throw new Error(`convTranspose: Invalid padding type ${padding}`);
}
7080
7184
  /** Reduce a computation over padded windows. */
7081
7185
  function reduceWindow(operand, computation, windowDimensions, windowStrides) {
7082
7186
  if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
@@ -7115,6 +7219,7 @@ function stopGradient$1(x) {
7115
7219
  var nn_exports = {};
7116
7220
  __export(nn_exports, {
7117
7221
  celu: () => celu,
7222
+ dotProductAttention: () => dotProductAttention,
7118
7223
  elu: () => elu,
7119
7224
  gelu: () => gelu,
7120
7225
  glu: () => glu,
@@ -7431,6 +7536,95 @@ function oneHot(x, numClasses) {
7431
7536
  if (require_backend.isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
7432
7537
  return eye(numClasses, void 0, { device: x.device }).slice(x);
7433
7538
  }
7539
/**
 * Scaled dot product attention (SDPA).
 *
 * Computes `softmax(scale * (Q @ K^T) + bias) @ V`, where `Q` is the query,
 * `K` is the key, `V` is the value, and `scale` defaults to `1 / sqrt(H)`.
 *
 * Grouped-query attention (GQA) is applied when the input `key` and `value`
 * tensors have fewer heads than `query`; multi-query attention is the special
 * case `K = 1`. As in JAX, the `N` query heads form `K` contiguous groups of
 * `G = N / K` heads, and query head `n` attends with key/value head
 * `Math.floor(n / G)`.
 *
 * We use the following uppercase letters to denote array shapes:
 * - `B` = batch size
 * - `S` = length of key/value sequences (source)
 * - `L` = length of query sequences
 * - `N` = number of attention heads
 * - `H` = dimensionality of each attention head
 * - `K` = number of key/value heads (for grouped-query attention)
 *
 * The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
 * case it must be omitted from all inputs.
 *
 * @param query - Query array; shape `[B, L, N, H]`
 * @param key - Key array; shape `[B, S, K, H]`
 * @param value - Value array; same shape as `key`
 * @param opts.bias - Optional bias to add to the attention logits; shape
 *   `[B, N, L, S]` or broadcastable to it.
 * @param opts.mask - Optional mask to apply to the attention logits; should be
 *   a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
 *   the element should take part in attention.
 * @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
 * @param opts.isCausal - If true, applies a causal mask: the lower-triangular
 *   `[L, S]` matrix `tri(L, S, 0)`, aligned at the top-left corner.
 * @param opts.querySeqLengths - Not yet implemented (passing it throws).
 *   Intended: sequence lengths for the queries, shape `(B,)`, taken from the
 *   beginning of the tensor.
 * @param opts.keyValueSeqLengths - Not yet implemented (passing it throws).
 *   Intended: sequence lengths for the keys and values, shape `(B,)`.
 * @param opts.localWindowSize - Not yet implemented (passing it throws).
 *   Intended: a local attention window, either a number or `[left, right]`.
 *
 * @returns The result of the attention operation; shape is the same as query
 *   `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
 * @throws Error for unsupported options or mismatched/invalid input shapes.
 *
 * NOTE(review): masked logits are set to `-Infinity`, so a query row whose
 * positions are all masked presumably softmaxes to NaN. JAX masks with a large
 * finite negative value instead; confirm before relying on fully-masked rows.
 */
function dotProductAttention(query, key$1, value, opts = {}) {
  if (opts.querySeqLengths !== void 0 || opts.keyValueSeqLengths !== void 0) throw new Error("Sequence length masking is not yet implemented");
  if (opts.localWindowSize !== void 0) throw new Error("Local attention is not yet implemented");
  query = fudgeArray(query);
  key$1 = fudgeArray(key$1);
  value = fudgeArray(value);
  if (query.ndim !== 3 && query.ndim !== 4 || query.ndim !== key$1.ndim || query.ndim !== value.ndim) throw new Error(`dotProductAttention: expected all tensors to have rank 3 or 4, got Q=${query.aval}, K=${key$1.aval}, V=${value.aval}`);
  if (!require_backend.deepEqual(key$1.shape, value.shape)) throw new Error(`dotProductAttention: key and value shapes must match, got K=${key$1.shape}, V=${value.shape}`);
  const isRank3 = query.ndim === 3;
  if (isRank3) {
    query = expandDims(query, 0);
    key$1 = expandDims(key$1, 0);
    value = expandDims(value, 0);
  }
  const [B, L, N, H] = query.shape;
  if (key$1.shape[0] !== B || key$1.shape[3] !== H) throw new Error(`dotProductAttention: query and key shapes mismatch, got Q=${query.aval}, K=${key$1.aval}`);
  const S = key$1.shape[1];
  const K = key$1.shape[2];
  if (N < K || N % K !== 0) throw new Error(`dotProductAttention: number of query heads N=${N} must be divisible by number of key/value heads K=${K} for GQA`);
  const G = N / K;
  if (G > 1) {
    // Query head `n = k * G + g` must pair with key/value head `k` (JAX's grouping).
    // Insert a group axis after the head axis, tile it G times, then merge:
    // [B, S, K, H] -> [B, S, K, 1, H] -> [B, S, K, G, H] -> [B, S, N, H].
    // Tiling the head axis directly would pair head `n` with `n % K` instead.
    key$1 = tile(expandDims(key$1, 3), [1, 1, 1, G, 1]).reshape([B, S, N, H]);
    value = tile(expandDims(value, 3), [1, 1, 1, G, 1]).reshape([B, S, N, H]);
  }
  const scale = opts.scale ?? 1 / Math.sqrt(H);
  let scores = einsum("BLNH,BSNH->BNLS", query, key$1).mul(scale);
  if (opts.bias !== void 0) scores = scores.add(opts.bias);
  if (opts.mask !== void 0) scores = where(opts.mask, scores, -Infinity);
  if (opts.isCausal) {
    const causalMask = tri(L, S, 0, { dtype: require_backend.DType.Bool });
    scores = where(causalMask, scores, -Infinity);
  }
  const attn = softmax(scores, -1);
  const out = einsum("BNLS,BSNH->BLNH", attn, value);
  return isRank3 ? out.reshape([L, N, H]) : out;
}
7434
7628
 
7435
7629
  //#endregion
7436
7630
  //#region src/library/random.ts
@@ -7666,17 +7860,62 @@ const linearize = linearize$1;
7666
7860
  /**
7667
7861
  * @function
7668
7862
  * Calculate the reverse-mode vector-Jacobian product for a function.
7863
+ *
7864
+ * The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
7865
+ * `f(primals)`, and `vjpFn` is a function that takes in cotangents for each
7866
+ * output and returns the cotangents for each input.
7867
+ *
7868
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
7869
+ * `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
7870
+ *
7871
+ * @example
7872
+ * ```ts
7873
+ * const [y, vjpFn] = vjp(f, [x]);
7874
+ *
7875
+ * // With hasAux
7876
+ * const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
7877
+ * ```
7669
7878
  */
7670
7879
  const vjp = vjp$1;
7671
7880
  /**
7672
7881
  * @function
7673
7882
  * Compute the gradient of a scalar-valued function `f` with respect to its
7674
7883
  * first argument.
7884
+ *
7885
+ * Pass in different `argnums` to differentiate with respect to other
7886
+ * arguments. If a tuple is provided, the return value will be a tuple of
7887
+ * gradients corresponding to each argument index.
7888
+ *
7889
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return a
7890
+ * `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
7891
+ *
7892
+ * @example
7893
+ * ```ts
7894
+ * const gradient = grad(f)(x);
7895
+ *
7896
+ * // With `argnums`
7897
+ * const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
7898
+ *
7899
+ * // With `hasAux`
7900
+ * const [gradient, aux] = grad(f, { hasAux: true })(x);
7901
+ * ```
7675
7902
  */
7676
7903
  const grad = grad$1;
7677
7904
  /**
7678
7905
  * @function
7679
7906
  * Create a function that evaluates both `f` and the gradient of `f`.
7907
+ *
7908
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
7909
+ * `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
7910
+ *
7911
+ * @example
7912
+ * ```ts
7913
+ * // Without hasAux
7914
+ * const [value, gradient] = valueAndGrad(f)(x);
7915
+ *
7916
+ * // With hasAux
7917
+ * const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
7918
+ * ```
7680
7919
  */
7681
7920
  const valueAndGrad = valueAndGrad$1;
7682
7921
  /**
@@ -7685,6 +7924,21 @@ const valueAndGrad = valueAndGrad$1;
7685
7924
  */
7686
7925
  const jacrev = jacrev$1;
7687
7926
  /**
7927
+ * @function
7928
+ * Compute the Hessian matrix of a scalar-valued function.
7929
+ *
7930
+ * The Hessian is the matrix of second-order partial derivatives of a function.
7931
+ * This is implemented as `jacfwd(grad(f))`.
7932
+ *
7933
+ * @example
7934
+ * ```ts
7935
+ * const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // x^3
7936
+ * const H = hessian(f)(np.array([1, 2, 3]));
7937
+ * // H[i,j] = d^2f / dx_i dx_j
7938
+ * ```
7939
+ */
7940
+ const hessian = hessian$1;
7941
+ /**
7688
7942
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
7689
7943
  *
7690
7944
  * This can be used to wait for the results of an intermediate computation to
@@ -7728,6 +7982,7 @@ exports.defaultDevice = require_backend.defaultDevice;
7728
7982
  exports.devicePut = devicePut;
7729
7983
  exports.devices = require_backend.devices;
7730
7984
  exports.grad = grad;
7985
+ exports.hessian = hessian;
7731
7986
  exports.init = require_backend.init;
7732
7987
  exports.jacfwd = jacfwd;
7733
7988
  exports.jacobian = jacrev;