@jax-js/jax 0.1.5 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DaqL-MNz.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-Dx6Ob2D1.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -209,7 +209,7 @@ __export(tree_exports, {
209
209
  structure: () => structure,
210
210
  unflatten: () => unflatten
211
211
  });
212
- const JsArray$1 = globalThis.Array;
212
+ const JsArray$2 = globalThis.Array;
213
213
  let NodeType = /* @__PURE__ */ function(NodeType$1) {
214
214
  NodeType$1["Array"] = "Array";
215
215
  NodeType$1["Object"] = "Object";
@@ -257,7 +257,7 @@ function flatten(tree) {
257
257
  return [leaves$1, treedef];
258
258
  }
259
259
  function _flatten(tree, leaves$1) {
260
- if (JsArray$1.isArray(tree)) {
260
+ if (JsArray$2.isArray(tree)) {
261
261
  const childTrees = tree.map((c) => _flatten(c, leaves$1));
262
262
  return new JsTreeDef(NodeType.Array, null, childTrees);
263
263
  } else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
@@ -381,6 +381,13 @@ let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
381
381
  CompareOp$1["LessEqual"] = "less_equal";
382
382
  return CompareOp$1;
383
383
  }({});
384
+ const routinePrimitives = new Map([
385
+ [Primitive.Sort, Routines.Sort],
386
+ [Primitive.Argsort, Routines.Argsort],
387
+ [Primitive.TriangularSolve, Routines.TriangularSolve],
388
+ [Primitive.Cholesky, Routines.Cholesky],
389
+ [Primitive.LU, Routines.LU]
390
+ ]);
384
391
  function add$1(x, y) {
385
392
  return bind1(Primitive.Add, [x, y]);
386
393
  }
@@ -654,6 +661,9 @@ function newDynamic(main) {
654
661
  dynamicTrace = prevDynamicTrace;
655
662
  } };
656
663
  }
664
+ function currentTraceLevel() {
665
+ return traceStack[traceStack.length - 1].level;
666
+ }
657
667
  var Trace = class {
658
668
  constructor(main) {
659
669
  this.main = main;
@@ -1031,6 +1041,7 @@ var TreeMismatchError = class extends TypeError {
1031
1041
  super(`Mismatched tree structures in ${where$2}: ${left} != ${right}`);
1032
1042
  }
1033
1043
  };
1044
+ /** Flatten a function of `JsTree` input/output for use in tracing. */
1034
1045
  function flattenFun(f, inTree) {
1035
1046
  const store = { value: void 0 };
1036
1047
  const flatFun = (...argsFlat) => {
@@ -1042,6 +1053,26 @@ function flattenFun(f, inTree) {
1042
1053
  };
1043
1054
  return [flatFun, store];
1044
1055
  }
1056
+ /** Like flattenFun, but expects f to return [main, aux] tuple. */
1057
+ function flattenFunWithAux(f, inTree) {
1058
+ const store = { value: void 0 };
1059
+ const auxStore = { value: void 0 };
1060
+ const flatFun = (...argsFlat) => {
1061
+ const pytreeArgs = unflatten(inTree, argsFlat);
1062
+ const result = f(...pytreeArgs);
1063
+ if (!Array.isArray(result) || result.length !== 2) throw new Error("Function with `hasAux: true` must return [output, aux] tuple");
1064
+ const [out, aux] = result;
1065
+ const [outFlat, outTree] = flatten(out);
1066
+ store.value = outTree;
1067
+ auxStore.value = aux;
1068
+ return outFlat;
1069
+ };
1070
+ return [
1071
+ flatFun,
1072
+ store,
1073
+ auxStore
1074
+ ];
1075
+ }
1045
1076
  var UseAfterFreeError = class extends ReferenceError {
1046
1077
  constructor(tracer) {
1047
1078
  super(`Referenced tracer ${tracer.toString()} freed, please use .ref move semantics`);
@@ -1771,13 +1802,6 @@ function jit$1(f, opts) {
1771
1802
 
1772
1803
  //#endregion
1773
1804
  //#region src/frontend/jit.ts
1774
- const routinePrimitives = new Map([
1775
- [Primitive.Sort, Routines.Sort],
1776
- [Primitive.Argsort, Routines.Argsort],
1777
- [Primitive.TriangularSolve, Routines.TriangularSolve],
1778
- [Primitive.Cholesky, Routines.Cholesky],
1779
- [Primitive.LU, Routines.LU]
1780
- ]);
1781
1805
  /** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
1782
1806
  var JitProgram = class {
1783
1807
  constructor(backend, steps, inputs, outputs) {
@@ -2166,12 +2190,13 @@ const jitRules = {
2166
2190
  const ndim$2 = avals[0].ndim;
2167
2191
  const sizes = avals.map((x) => x.shape[axis]);
2168
2192
  const finalSize = sizes.reduce((a, b) => a + b, 0);
2193
+ const { dtype: dtypeOut } = avals.map((x) => x.scalar()).reduce(promoteAvals);
2169
2194
  const makePadAxis = (start, end) => range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
2170
2195
  let cum = 0;
2171
2196
  const src = [];
2172
2197
  for (let i = 0; i < exps.length; i++) {
2173
2198
  const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
2174
- src.push(reshapeViews(exps[i], (st) => st.pad(padding)));
2199
+ src.push(reshapeViews(AluExp.cast(dtypeOut, exps[i]), (st) => st.pad(padding)));
2175
2200
  cum += sizes[i];
2176
2201
  }
2177
2202
  return { exp: [src.reduce(AluExp.add)] };
@@ -2309,7 +2334,7 @@ function splitGraphDataflow(backend, jaxpr) {
2309
2334
  p1NextBlack.set(v, v);
2310
2335
  }
2311
2336
  const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
2312
- const needsCleanShapePrimitives = [Primitive.Pad];
2337
+ const needsCleanShapePrimitives = [Primitive.Concatenate, Primitive.Pad];
2313
2338
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
2314
2339
  const eqn = jaxpr.eqns[i];
2315
2340
  if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
@@ -2379,7 +2404,7 @@ function splitGraphDataflow(backend, jaxpr) {
2379
2404
 
2380
2405
  //#endregion
2381
2406
  //#region src/frontend/array.ts
2382
- const JsArray = globalThis.Array;
2407
+ const JsArray$1 = globalThis.Array;
2383
2408
  const inlineArrayLimit = 128;
2384
2409
  /** Version of pureArray with fudged types. */
2385
2410
  const fudgeArray = pureArray;
@@ -2777,25 +2802,35 @@ var Array$1 = class Array$1 extends Tracer {
2777
2802
  });
2778
2803
  }
2779
2804
  /** Apply an operation with custom lowering to this array. */
2780
- static #routine(routine, arrays, outputWeakType) {
2781
- const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
2782
- for (const ar of arrays) ar.#realize();
2783
- const inputs = arrays.map((ar) => ar.#source);
2784
- const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(byteWidth(dtype) * prod(routine.type.outputShapes[i])));
2785
- const pending = arrays.flatMap((ar) => ar.#pending);
2786
- for (const exe of pending) exe.updateRc(+outputs.length);
2787
- pending.push(new PendingExecute(backend, routine, inputs, outputs));
2788
- pending[pending.length - 1].updateRc(+outputs.length - 1);
2789
- arrays.forEach((ar) => ar.dispose());
2790
- return outputs.map((output, i) => new Array$1({
2791
- source: output,
2792
- st: ShapeTracker.fromShape(routine.type.outputShapes[i]),
2793
- dtype: routine.type.outputDtypes[i],
2794
- weakType: outputWeakType[i],
2795
- backend,
2796
- committed,
2797
- pending
2798
- }));
2805
+ static #routine(prim) {
2806
+ return (arrays, params) => {
2807
+ const { backend, committed } = Array$1.#computeBackend(prim, arrays);
2808
+ for (const ar of arrays) ar.#realize();
2809
+ const avals = arrays.map((ar) => ar.aval);
2810
+ const avalsOut = abstractEvalRules[prim](avals, params);
2811
+ const routine = new Routine(routinePrimitives.get(prim), {
2812
+ inputShapes: avals.map((a) => a.shape),
2813
+ inputDtypes: avals.map((a) => a.dtype),
2814
+ outputShapes: avalsOut.map((a) => a.shape),
2815
+ outputDtypes: avalsOut.map((a) => a.dtype)
2816
+ }, params);
2817
+ const inputs = arrays.map((ar) => ar.#source);
2818
+ const outputs = avalsOut.map((x) => backend.malloc(byteWidth(x.dtype) * x.size));
2819
+ const pending = arrays.flatMap((ar) => ar.#pending);
2820
+ for (const exe of pending) exe.updateRc(+outputs.length);
2821
+ pending.push(new PendingExecute(backend, routine, inputs, outputs));
2822
+ pending[pending.length - 1].updateRc(+outputs.length - 1);
2823
+ arrays.forEach((ar) => ar.dispose());
2824
+ return outputs.map((output, i) => new Array$1({
2825
+ source: output,
2826
+ st: ShapeTracker.fromShape(avalsOut[i].shape),
2827
+ dtype: avalsOut[i].dtype,
2828
+ weakType: avalsOut[i].weakType,
2829
+ backend,
2830
+ committed,
2831
+ pending
2832
+ }));
2833
+ };
2799
2834
  }
2800
2835
  /**
2801
2836
  * Normalizes this array into one backed by a `Slot`.
@@ -3129,65 +3164,11 @@ var Array$1 = class Array$1 extends Tracer {
3129
3164
  [Primitive.Pad]([x], { width }) {
3130
3165
  return [x.#reshape(x.#st.pad(width))];
3131
3166
  },
3132
- [Primitive.Sort]([x]) {
3133
- const routine = new Routine(Routines.Sort, {
3134
- inputShapes: [x.shape],
3135
- inputDtypes: [x.dtype],
3136
- outputShapes: [x.shape],
3137
- outputDtypes: [x.dtype]
3138
- });
3139
- return Array$1.#routine(routine, [x], [x.#weakType]);
3140
- },
3141
- [Primitive.Argsort]([x]) {
3142
- const routine = new Routine(Routines.Argsort, {
3143
- inputShapes: [x.shape],
3144
- inputDtypes: [x.dtype],
3145
- outputShapes: [x.shape, x.shape],
3146
- outputDtypes: [x.dtype, DType.Int32]
3147
- });
3148
- return Array$1.#routine(routine, [x], [x.#weakType, false]);
3149
- },
3150
- [Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
3151
- const routine = new Routine(Routines.TriangularSolve, {
3152
- inputShapes: [a.shape, b.shape],
3153
- inputDtypes: [a.dtype, b.dtype],
3154
- outputShapes: [b.shape],
3155
- outputDtypes: [b.dtype]
3156
- }, { unitDiagonal });
3157
- return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
3158
- },
3159
- [Primitive.Cholesky]([a]) {
3160
- const routine = new Routine(Routines.Cholesky, {
3161
- inputShapes: [a.shape],
3162
- inputDtypes: [a.dtype],
3163
- outputShapes: [a.shape],
3164
- outputDtypes: [a.dtype]
3165
- });
3166
- return Array$1.#routine(routine, [a], [a.#weakType]);
3167
- },
3168
- [Primitive.LU]([a]) {
3169
- const batch = a.shape.slice(0, -2);
3170
- const [m, n] = a.shape.slice(-2);
3171
- const routine = new Routine(Routines.LU, {
3172
- inputShapes: [a.shape],
3173
- inputDtypes: [a.dtype],
3174
- outputShapes: [
3175
- a.shape,
3176
- [...batch, Math.min(m, n)],
3177
- [...batch, m]
3178
- ],
3179
- outputDtypes: [
3180
- a.dtype,
3181
- DType.Int32,
3182
- DType.Int32
3183
- ]
3184
- });
3185
- return Array$1.#routine(routine, [a], [
3186
- a.#weakType,
3187
- false,
3188
- false
3189
- ]);
3190
- },
3167
+ [Primitive.Sort]: Array$1.#routine(Primitive.Sort),
3168
+ [Primitive.Argsort]: Array$1.#routine(Primitive.Argsort),
3169
+ [Primitive.TriangularSolve]: Array$1.#routine(Primitive.TriangularSolve),
3170
+ [Primitive.Cholesky]: Array$1.#routine(Primitive.Cholesky),
3171
+ [Primitive.LU]: Array$1.#routine(Primitive.LU),
3191
3172
  [Primitive.Jit](args, { jaxpr }) {
3192
3173
  if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
3193
3174
  const { backend, committed } = Array$1.#computeBackend("jit", args);
@@ -3269,7 +3250,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
3269
3250
  if (!shape$1) {
3270
3251
  shape$1 = [];
3271
3252
  let cur = values;
3272
- while (JsArray.isArray(cur)) {
3253
+ while (JsArray$1.isArray(cur)) {
3273
3254
  shape$1.push(cur.length);
3274
3255
  cur = cur[0];
3275
3256
  }
@@ -4223,17 +4204,39 @@ function jvpFlat(f, primals, tangents) {
4223
4204
  _usingCtx$1.d();
4224
4205
  }
4225
4206
  }
4226
- function jvp$1(f, primals, tangents) {
4207
+ function jvp$1(f, primals, tangents, { hasAux = false } = {}) {
4227
4208
  const [primalsFlat, inTree] = flatten(primals);
4228
4209
  const [tangentsFlat, inTree2] = flatten(tangents);
4229
4210
  if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
4230
- const [flatFun, outTree] = flattenFun(f, inTree);
4211
+ let flatFun, outTree, aux;
4212
+ if (hasAux) [flatFun, outTree, aux] = flattenFunWithAux(f, inTree);
4213
+ else [flatFun, outTree] = flattenFun(f, inTree);
4231
4214
  const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
4232
4215
  if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
4233
4216
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
4234
4217
  const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
4218
+ if (hasAux) return [
4219
+ primalsOut,
4220
+ tangentsOut,
4221
+ lowerAux(aux.value)
4222
+ ];
4235
4223
  return [primalsOut, tangentsOut];
4236
4224
  }
4225
+ /** Lowering for auxiliary data returned in `hasAux: true` methods. */
4226
+ function lowerAux(aux) {
4227
+ const level = currentTraceLevel();
4228
+ return map((x) => {
4229
+ if (x instanceof Tracer) while (x._trace.main.level > level) if (x instanceof JVPTracer) {
4230
+ x.tangent.dispose();
4231
+ x = x.primal;
4232
+ } else {
4233
+ const y = x.fullLower();
4234
+ if (y._trace.main.level >= x._trace.main.level) throw new Error("internal: lowerAux did not reduce trace level");
4235
+ x = y;
4236
+ }
4237
+ return x;
4238
+ }, aux);
4239
+ }
4237
4240
 
4238
4241
  //#endregion
4239
4242
  //#region src/frontend/linearize.ts
@@ -4304,9 +4307,11 @@ function linearizeFlat(f, primalsIn) {
4304
4307
  dispose$1
4305
4308
  ];
4306
4309
  }
4307
- function linearize$1(f, ...primalsIn) {
4310
+ function linearize$1(f, primalsIn, { hasAux = false } = {}) {
4308
4311
  const [primalsInFlat, inTree] = flatten(primalsIn);
4309
- const [fFlat, outTree] = flattenFun(f, inTree);
4312
+ let fFlat, outTree, aux;
4313
+ if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
4314
+ else [fFlat, outTree] = flattenFun(f, inTree);
4310
4315
  const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
4311
4316
  if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
4312
4317
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
@@ -4317,6 +4322,11 @@ function linearize$1(f, ...primalsIn) {
4317
4322
  return unflatten(outTree.value, tangentsOutFlat);
4318
4323
  });
4319
4324
  fLin.dispose = dispose$1;
4325
+ if (hasAux) return [
4326
+ primalsOut,
4327
+ fLin,
4328
+ lowerAux(aux.value)
4329
+ ];
4320
4330
  return [primalsOut, fLin];
4321
4331
  }
4322
4332
  var PartialEvalTracer = class extends Tracer {
@@ -4817,9 +4827,11 @@ function vjpFlat(f, primalsIn) {
4817
4827
  dispose$1
4818
4828
  ];
4819
4829
  }
4820
- function vjp$1(f, ...primalsIn) {
4830
+ function vjp$1(f, primalsIn, { hasAux = false } = {}) {
4821
4831
  const [primalsInFlat, inTree] = flatten(primalsIn);
4822
- const [fFlat, outTree] = flattenFun(f, inTree);
4832
+ let fFlat, outTree, aux;
4833
+ if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
4834
+ else [fFlat, outTree] = flattenFun(f, inTree);
4823
4835
  const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
4824
4836
  if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
4825
4837
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
@@ -4830,26 +4842,43 @@ function vjp$1(f, ...primalsIn) {
4830
4842
  return unflatten(inTree, cotangentsInFlat);
4831
4843
  });
4832
4844
  fVjp.dispose = dispose$1;
4845
+ if (hasAux) return [
4846
+ primalsOut,
4847
+ fVjp,
4848
+ lowerAux(aux.value)
4849
+ ];
4833
4850
  return [primalsOut, fVjp];
4834
4851
  }
4835
- function grad$1(f) {
4836
- const valueAndGradFn = valueAndGrad$1(f);
4852
+ function grad$1(f, opts) {
4853
+ const valueAndGradFn = valueAndGrad$1(f, opts);
4837
4854
  return (...x) => {
4838
- const [y, dx] = valueAndGradFn(...x);
4839
- y.dispose();
4840
- return dx;
4855
+ if (opts?.hasAux) {
4856
+ const [[y, aux], dx] = valueAndGradFn(...x);
4857
+ y.dispose();
4858
+ return [dx, aux];
4859
+ } else {
4860
+ const [y, dx] = valueAndGradFn(...x);
4861
+ y.dispose();
4862
+ return dx;
4863
+ }
4841
4864
  };
4842
4865
  }
4843
- function valueAndGrad$1(f) {
4866
+ function valueAndGrad$1(f, opts) {
4867
+ const argnums = opts?.argnums ?? 0;
4868
+ const hasAux = opts?.hasAux ?? false;
4869
+ checkInts(argnums);
4870
+ const argnumsSet = new Set(typeof argnums === "number" ? [argnums] : argnums);
4844
4871
  return (...x) => {
4845
4872
  if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
4846
- const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
4873
+ for (let i = 0; i < x.length; i++) if (!argnumsSet.has(i)) x[i] = map(stopGradient, x[i]);
4874
+ const [y, fVjp, aux] = vjp$1(f, x, { hasAux });
4847
4875
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
4848
4876
  if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
4849
- const [ct, ...rest] = fVjp(onesLike$1(y.ref));
4850
- for (const r of rest) dispose(r);
4877
+ const cts = fVjp(onesLike$1(y.ref));
4851
4878
  fVjp.dispose();
4852
- return [y, ct];
4879
+ for (let i = 0; i < cts.length; i++) if (!argnumsSet.has(i)) dispose(cts[i]);
4880
+ const grads = typeof argnums === "number" ? cts[argnums] : argnums.map((i) => cts[i]);
4881
+ return hasAux ? [[y, aux], grads] : [y, grads];
4853
4882
  };
4854
4883
  }
4855
4884
  function jacrev$1(f) {
@@ -4857,7 +4886,7 @@ function jacrev$1(f) {
4857
4886
  if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
4858
4887
  const [size$1] = x.shape;
4859
4888
  const pullback = (ct) => {
4860
- const [y, fVjp] = vjp$1(f, x);
4889
+ const [y, fVjp] = vjp$1(f, [x]);
4861
4890
  y.dispose();
4862
4891
  const [ret] = fVjp(ct);
4863
4892
  fVjp.dispose();
@@ -4866,6 +4895,9 @@ function jacrev$1(f) {
4866
4895
  return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
4867
4896
  };
4868
4897
  }
4898
+ function hessian$1(f) {
4899
+ return jacfwd$1(grad$1(f));
4900
+ }
4869
4901
 
4870
4902
  //#endregion
4871
4903
  //#region src/library/numpy/einsum.ts
@@ -5575,6 +5607,7 @@ __export(numpy_exports, {
5575
5607
  std: () => std,
5576
5608
  subtract: () => subtract,
5577
5609
  sum: () => sum,
5610
+ swapaxes: () => swapaxes,
5578
5611
  take: () => take,
5579
5612
  tan: () => tan,
5580
5613
  tanh: () => tanh,
@@ -5973,6 +6006,17 @@ function flipud(x) {
5973
6006
  function fliplr(x) {
5974
6007
  return flip(x, 1);
5975
6008
  }
6009
+ /** Interchange two axes of an array. */
6010
+ function swapaxes(a, axis1, axis2) {
6011
+ a = fudgeArray(a);
6012
+ axis1 = checkAxis(axis1, a.ndim);
6013
+ axis2 = checkAxis(axis2, a.ndim);
6014
+ if (axis1 === axis2) return a;
6015
+ const perm = range(a.ndim);
6016
+ perm[axis1] = axis2;
6017
+ perm[axis2] = axis1;
6018
+ return transpose(a, perm);
6019
+ }
5976
6020
  /** Transpose the last two dimensions of an array. */
5977
6021
  function matrixTranspose(a) {
5978
6022
  if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
@@ -6901,6 +6945,7 @@ var lax_exports = {};
6901
6945
  __export(lax_exports, {
6902
6946
  conv: () => conv,
6903
6947
  convGeneralDilated: () => convGeneralDilated,
6948
+ convTranspose: () => convTranspose,
6904
6949
  convWithGeneralPadding: () => convWithGeneralPadding,
6905
6950
  dot: () => dot,
6906
6951
  erf: () => erf,
@@ -6909,6 +6954,7 @@ __export(lax_exports, {
6909
6954
  reduceWindow: () => reduceWindow,
6910
6955
  stopGradient: () => stopGradient$1
6911
6956
  });
6957
+ const JsArray = globalThis.Array;
6912
6958
  /**
6913
6959
  * General dot product/contraction operator.
6914
6960
  *
@@ -6980,7 +7026,11 @@ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
6980
7026
  * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
6981
7027
  * function in JAX, which wraps XLA's general convolution operator.
6982
7028
  *
6983
- * Grouped convolutions are not supported right now.
7029
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
7030
+ * @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
7031
+ * @param windowStrides - Strides for each spatial dimension
7032
+ * @param padding - Padding for each spatial dimension, or a string
7033
+ * (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
6984
7034
  */
6985
7035
  function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
6986
7036
  if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
@@ -7040,6 +7090,60 @@ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, r
7040
7090
  function conv(lhs, rhs, windowStrides, padding) {
7041
7091
  return convGeneralDilated(lhs, rhs, windowStrides, padding);
7042
7092
  }
7093
+ /**
7094
+ * Convenience wrapper for calculating the N-d convolution "transpose".
7095
+ *
7096
+ * This function directly calculates a fractionally strided conv rather than
7097
+ * indirectly calculating the gradient (transpose) of a forward convolution.
7098
+ * It is equivalent to the JAX version, except:
7099
+ *
7100
+ * - The `use_consistent_padding` option is not available. We only have the
7101
+ * consistent padding case (JAX version >0.8.4).
7102
+ * - The order of dimensions matches `lax.conv_general_dilated`.
7103
+ *
7104
+ * Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
7105
+ * dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
7106
+ * `transposeKernel` to true.
7107
+ *
7108
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
7109
+ * @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
7110
+ * @param strides - Sequence of n integers, sets fractional stride
7111
+ * @param padding - `"SAME"`, `"VALID"`, or `[low, high]` pairs per spatial axis that
7112
+ * pad each side by `dilation * (kernel_size - 1) - padding`, acting like gradient of `conv()`
7113
+ * @param rhsDilation - Atrous dilation for the kernel
7114
+ * @param transposeKernel - Flip spatial axes and swap the input/output channels
7115
+ * of the kernel; its shape should be `[C_in, C_out, ...ks]`
7116
+ */
7117
+ function convTranspose(lhs, rhs, strides, padding, { rhsDilation, transposeKernel = false } = {}) {
7118
+ const kernelShape = rhs.shape.slice(2);
7119
+ rhsDilation = rhsDilation ?? rep(kernelShape.length, 1);
7120
+ const effectiveKernel = kernelShape.map((k, i) => Math.max(0, (k - 1) * rhsDilation[i] + 1));
7121
+ const pads = effectiveKernel.map((k, i) => convTransposePadding(k, strides[i], typeof padding === "string" ? padding : padding[i]));
7122
+ if (transposeKernel) {
7123
+ rhs = flip$1(rhs, range(2, rhs.ndim));
7124
+ rhs = moveaxis(rhs, 0, 1);
7125
+ }
7126
+ return convGeneralDilated(lhs, rhs, rep(lhs.ndim - 2, 1), pads, {
7127
+ lhsDilation: strides,
7128
+ rhsDilation
7129
+ });
7130
+ }
7131
+ function convTransposePadding(k, s, padding) {
7132
+ let padLen;
7133
+ let pad1;
7134
+ if (padding === "SAME") {
7135
+ padLen = k + s - 2;
7136
+ pad1 = s > k - 1 ? k - 1 : Math.ceil(padLen / 2);
7137
+ } else if (padding === "VALID") {
7138
+ padLen = k + s - 2 + Math.max(k - s, 0);
7139
+ pad1 = k - 1;
7140
+ } else if (JsArray.isArray(padding)) {
7141
+ const pads = [k - 1 - padding[0], k - 1 - padding[1]];
7142
+ pad1 = pads[0];
7143
+ padLen = pads[0] + pads[1];
7144
+ } else throw new Error(`convTranspose: Invalid padding type ${padding}`);
7145
+ return [pad1, padLen - pad1];
7146
+ }
7043
7147
  /** Reduce a computation over padded windows. */
7044
7148
  function reduceWindow(operand, computation, windowDimensions, windowStrides) {
7045
7149
  if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
@@ -7078,6 +7182,7 @@ function stopGradient$1(x) {
7078
7182
  var nn_exports = {};
7079
7183
  __export(nn_exports, {
7080
7184
  celu: () => celu,
7185
+ dotProductAttention: () => dotProductAttention,
7081
7186
  elu: () => elu,
7082
7187
  gelu: () => gelu,
7083
7188
  glu: () => glu,
@@ -7394,6 +7499,95 @@ function oneHot(x, numClasses) {
7394
7499
  if (isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
7395
7500
  return eye(numClasses, void 0, { device: x.device }).slice(x);
7396
7501
  }
7502
+ /**
7503
+ * Scaled dot product attention (SDPA).
7504
+ *
7505
+ * Computes `softmax((Q @ K^T) / sqrt(d) + bias) @ V`, where `Q` is the query,
7506
+ * `K` is the key, `V` is the value, and `d` is the dimensionality of each key
7507
+ * and query vector.
7508
+ *
7509
+ * Grouped-query attention (multi-query when `K = 1`) is applied when input `key`
7510
+ * and `value` tensors have fewer heads than `query`.
7511
+ *
7512
+ * We use the following uppercase letters to denote array shapes:
7513
+ * - `B` = batch size
7514
+ * - `S` = length of key/value sequences (source)
7515
+ * - `L` = length of query sequences
7516
+ * - `N` = number of attention heads
7517
+ * - `H` = dimensionality of each attention head
7518
+ * - `K` = number of key/value heads (for grouped-query attention)
7519
+ *
7520
+ * The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
7521
+ * case it must be omitted from all inputs.
7522
+ *
7523
+ * @param query - Query array; shape `[B, L, N, H]`
7524
+ * @param key - Key array; shape `[B, S, K, H]`
7525
+ * @param value - Value array; same shape as `key`
7526
+ * @param opts.bias - Optional bias to add to the attention logits; shape
7527
+ * `[B, N, L, S]` or broadcastable to it.
7528
+ * @param opts.mask - Optional mask to apply to the attention logits; should be
7529
+ * a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
7530
+ * the element should take part in attention.
7531
+ * @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
7532
+ * @param opts.isCausal - If true, applies a causal mask.
7533
+ * @param opts.querySeqLengths - Optional sequence lengths for the queries;
7534
+ * shape `(B,)`. Taken from the beginning of the tensor.
7535
+ * @param opts.keyValueSeqLengths - Optional sequence lengths for the keys and
7536
+ * values; shape `(B,)`. Taken from the beginning of the tensor.
7537
+ * @param opts.localWindowSize - If specified, applies a local attention window
7538
+ * of the given size. Can be a single number or a tuple `[left, right]`.
7539
+ *
7540
+ * @returns The result of the attention operation; shape is the same as query
7541
+ * `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
7542
+ */
7543
+ function dotProductAttention(query, key$1, value, opts = {}) {
7544
+ if (opts.querySeqLengths !== void 0 || opts.keyValueSeqLengths !== void 0) throw new Error("Sequence length masking is not yet implemented");
7545
+ if (opts.localWindowSize !== void 0) throw new Error("Local attention is not yet implemented");
7546
+ query = fudgeArray(query);
7547
+ key$1 = fudgeArray(key$1);
7548
+ value = fudgeArray(value);
7549
+ if (query.ndim !== 3 && query.ndim !== 4 || query.ndim !== key$1.ndim || query.ndim !== value.ndim) throw new Error(`dotProductAttention: expected all tensors to have rank 3 or 4, got Q=${query.aval}, K=${key$1.aval}, V=${value.aval}`);
7550
+ if (!deepEqual(key$1.shape, value.shape)) throw new Error(`dotProductAttention: key and value shapes must match, got K=${key$1.shape}, V=${value.shape}`);
7551
+ const isRank3 = query.ndim === 3;
7552
+ if (isRank3) {
7553
+ query = expandDims(query, 0);
7554
+ key$1 = expandDims(key$1, 0);
7555
+ value = expandDims(value, 0);
7556
+ }
7557
+ const [B, L, N, H] = query.shape;
7558
+ if (key$1.shape[0] !== B || key$1.shape[3] !== H) throw new Error(`dotProductAttention: query and key shapes mismatch, got Q=${query.aval}, K=${key$1.aval}`);
7559
+ const S = key$1.shape[1];
7560
+ const K = key$1.shape[2];
7561
+ if (N < K || N != K && N % K !== 0) throw new Error(`dotProductAttention: number of query heads N=${N} must be divisible by number of key/value heads K=${K} for GQA`);
7562
+ const G = N / K;
7563
+ key$1 = tile(key$1, [
7564
+ 1,
7565
+ 1,
7566
+ G,
7567
+ 1
7568
+ ]);
7569
+ value = tile(value, [
7570
+ 1,
7571
+ 1,
7572
+ G,
7573
+ 1
7574
+ ]);
7575
+ const scale = opts.scale ?? 1 / Math.sqrt(H);
7576
+ let scores = einsum("BLNH,BSNH->BNLS", query, key$1).mul(scale);
7577
+ if (opts.bias !== void 0) scores = scores.add(opts.bias);
7578
+ if (opts.mask !== void 0) scores = where(opts.mask, scores, -Infinity);
7579
+ if (opts.isCausal) {
7580
+ const causalMask = tri(L, S, 0, { dtype: DType.Bool });
7581
+ scores = where(causalMask, scores, -Infinity);
7582
+ }
7583
+ const attn = softmax(scores, -1);
7584
+ const out = einsum("BNLS,BSNH->BLNH", attn, value);
7585
+ return isRank3 ? out.reshape([
7586
+ L,
7587
+ N,
7588
+ H
7589
+ ]) : out;
7590
+ }
7397
7591
 
7398
7592
  //#endregion
7399
7593
  //#region src/library/random.ts
@@ -7629,17 +7823,62 @@ const linearize = linearize$1;
7629
7823
  /**
7630
7824
  * @function
7631
7825
  * Calculate the reverse-mode vector-Jacobian product for a function.
7826
+ *
7827
+ * The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
7828
+ * `f(primals)`, and `vjpFn` is a function that takes in cotangents for each
7829
+ * output and returns the cotangents for each input.
7830
+ *
7831
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
7832
+ * `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
7833
+ *
7834
+ * @example
7835
+ * ```ts
7836
+ * const [y, vjpFn] = vjp(f, [x]);
7837
+ *
7838
+ * // With hasAux
7839
+ * const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
7840
+ * ```
7632
7841
  */
7633
7842
  const vjp = vjp$1;
7634
7843
  /**
7635
7844
  * @function
7636
7845
  * Compute the gradient of a scalar-valued function `f` with respect to its
7637
7846
  * first argument.
7847
+ *
7848
+ * Pass in different `argnums` to differentiate with respect to other
7849
+ * arguments. If a tuple is provided, the return value will be a tuple of
7850
+ * gradients corresponding to each argument index.
7851
+ *
7852
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return a
7853
+ * `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
7854
+ *
7855
+ * @example
7856
+ * ```ts
7857
+ * const gradient = grad(f)(x);
7858
+ *
7859
+ * // With `argnums`
7860
+ * const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
7861
+ *
7862
+ * // With `hasAux`
7863
+ * const [gradient, aux] = grad(f, { hasAux: true })(x);
7864
+ * ```
7638
7865
  */
7639
7866
  const grad = grad$1;
7640
7867
  /**
7641
7868
  * @function
7642
7869
  * Create a function that evaluates both `f` and the gradient of `f`.
7870
+ *
7871
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
7872
+ * `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
7873
+ *
7874
+ * @example
7875
+ * ```ts
7876
+ * // Without hasAux
7877
+ * const [value, gradient] = valueAndGrad(f)(x);
7878
+ *
7879
+ * // With hasAux
7880
+ * const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
7881
+ * ```
7643
7882
  */
7644
7883
  const valueAndGrad = valueAndGrad$1;
7645
7884
  /**
@@ -7648,6 +7887,21 @@ const valueAndGrad = valueAndGrad$1;
7648
7887
  */
7649
7888
  const jacrev = jacrev$1;
7650
7889
  /**
7890
+ * @function
7891
+ * Compute the Hessian matrix of a scalar-valued function.
7892
+ *
7893
+ * The Hessian is the matrix of second-order partial derivatives of a function.
7894
+ * This is implemented as `jacfwd(grad(f))`.
7895
+ *
7896
+ * @example
7897
+ * ```ts
7898
+ * const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // x^3
7899
+ * const H = hessian(f)(np.array([1, 2, 3]));
7900
+ * // H[i,j] = d^2f / dx_i dx_j
7901
+ * ```
7902
+ */
7903
+ const hessian = hessian$1;
7904
+ /**
7651
7905
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
7652
7906
  *
7653
7907
  * This can be used to wait for the results of an intermediate computation to
@@ -7682,4 +7936,4 @@ async function devicePut(x, device) {
7682
7936
  }
7683
7937
 
7684
7938
  //#endregion
7685
- export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
7939
+ export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };