@jax-js/jax 0.1.2 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,28 +1,36 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-BqymqzuU.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-tngXtWe4.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
6
6
  * Check that the shapes and parameters passed to convolution are valid.
7
+ * Expected shapes of the lhs and rhs of the convolution are:
8
+ *
9
+ * - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
10
+ * - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
7
11
  *
8
12
  * If the check succeeds, returns the output shape.
9
13
  */
10
- function checkConvShape(lhsShape, rhsShape, { strides, padding, lhsDilation, rhsDilation }) {
14
+ function checkConvShape(lhsShape, rhsShape, { vmapDims, strides, padding, lhsDilation, rhsDilation }) {
11
15
  if (lhsShape.length !== rhsShape.length) throw new Error(`conv() requires inputs with the same number of dimensions, got ${lhsShape.length} and ${rhsShape.length}`);
12
- const n = lhsShape.length - 2;
16
+ const n = lhsShape.length - 2 - vmapDims;
13
17
  if (n < 0) throw new Error("conv() requires at least 2D inputs");
14
18
  if (strides.length !== n) throw new Error("conv() strides != spatial dims");
15
19
  if (padding.length !== n) throw new Error("conv() padding != spatial dims");
16
20
  if (lhsDilation.length !== n) throw new Error("conv() lhsDilation != spatial dimensions");
17
21
  if (rhsDilation.length !== n) throw new Error("conv() rhsDilation != spatial dimensions");
18
- if (lhsShape[1] !== rhsShape[1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
19
- const outShape = [lhsShape[0], rhsShape[0]];
22
+ if (lhsShape[vmapDims + 1] !== rhsShape[vmapDims + 1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
23
+ const outShape = [
24
+ ...generalBroadcast(lhsShape.slice(0, vmapDims), rhsShape.slice(0, vmapDims)),
25
+ lhsShape[vmapDims],
26
+ rhsShape[vmapDims]
27
+ ];
20
28
  for (let i = 0; i < n; i++) {
21
29
  if (strides[i] <= 0 || !Number.isInteger(strides[i])) throw new Error(`conv() strides[${i}] must be a positive integer`);
22
30
  if (padding[i].length !== 2 || !padding[i].every(Number.isInteger)) throw new Error(`conv() padding[${i}] must be a 2-tuple of integers`);
23
31
  if (lhsDilation[i] <= 0 || !Number.isInteger(lhsDilation[i])) throw new Error(`conv() lhsDilation[${i}] must be a positive integer`);
24
32
  if (rhsDilation[i] <= 0 || !Number.isInteger(rhsDilation[i])) throw new Error(`conv() rhsDilation[${i}] must be a positive integer`);
25
- const [x, k] = [lhsShape[i + 2], rhsShape[i + 2]];
33
+ const [x, k] = [lhsShape[i + vmapDims + 2], rhsShape[i + vmapDims + 2]];
26
34
  if (k <= 0) throw new Error("conv() kernel size must be positive");
27
35
  const [pl, pr] = padding[i];
28
36
  if (pl < -x || pr < -x || pl + pr < -x) throw new Error(`conv() padding[${i}]=(${pl},${pr}) is too negative for input size ${x}`);
@@ -147,27 +155,13 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
147
155
  function applyDilation(st, dilation) {
148
156
  if (dilation.every((s) => s === 1)) return st;
149
157
  const s_ = dilation;
150
- const [a, b, ...k_] = st.shape;
151
- st = st.reshape([
152
- a,
153
- b,
154
- ...k_.flatMap((k) => [k, 1])
155
- ]);
156
- st = st.pad([
157
- [0, 0],
158
- [0, 0],
159
- ...s_.flatMap((s) => [[0, 0], [0, s - 1]])
160
- ]);
161
- st = st.reshape([
162
- a,
163
- b,
164
- ...k_.map((k, i) => k * s_[i])
165
- ]);
166
- st = st.shrink([
167
- [0, a],
168
- [0, b],
169
- ...k_.map((k, i) => [0, (k - 1) * s_[i] + 1])
170
- ]);
158
+ const n = s_.length;
159
+ const prefix = st.shape.slice(0, -n);
160
+ const k_ = st.shape.slice(-n);
161
+ st = st.reshape([...prefix, ...k_.flatMap((k) => [k, 1])]);
162
+ st = st.pad([...prefix.map(() => [0, 0]), ...s_.flatMap((s) => [[0, 0], [0, s - 1]])]);
163
+ st = st.reshape([...prefix, ...k_.map((k, i) => k * s_[i])]);
164
+ st = st.shrink([...prefix.map((p) => [0, p]), ...k_.map((k, i) => [0, (k - 1) * s_[i] + 1])]);
171
165
  return st;
172
166
  }
173
167
  /**
@@ -177,25 +171,26 @@ function applyDilation(st, dilation) {
177
171
  * beforehand using `checkConvShape()`.
178
172
  */
179
173
  function prepareConv(stX, stY, params) {
180
- const n = stX.shape.length - 2;
174
+ const v = params.vmapDims;
175
+ const n = stX.shape.length - 2 - v;
176
+ const vmapShape = stX.shape.slice(0, v);
181
177
  stX = applyDilation(stX, params.lhsDilation);
182
- const ks = stY.shape.slice(2);
183
- stX = stX.padOrShrink([
184
- [0, 0],
185
- [0, 0],
186
- ...params.padding
187
- ]);
178
+ const ks = stY.shape.slice(v + 2);
179
+ stX = stX.padOrShrink([...rep(v + 2, [0, 0]), ...params.padding]);
188
180
  stX = pool(stX, ks, params.strides, params.rhsDilation);
189
- stX = stX.moveaxis(1, n + 1).reshape([
190
- stX.shape[0],
181
+ stX = stX.moveaxis(v + 1, v + n + 1).reshape([
182
+ ...vmapShape,
183
+ stX.shape[v],
191
184
  1,
192
- ...stX.shape.slice(2, n + 2),
193
- stX.shape[1] * prod(ks)
185
+ ...stX.shape.slice(v + 2, v + n + 2),
186
+ stX.shape[v + 1] * prod(ks)
194
187
  ]);
195
188
  stY = stY.reshape([
196
- stY.shape[0],
189
+ ...vmapShape,
190
+ 1,
191
+ stY.shape[v],
197
192
  ...rep(n, 1),
198
- stY.shape[1] * prod(ks)
193
+ stY.shape[v + 1] * prod(ks)
199
194
  ]);
200
195
  return [stX, stY];
201
196
  }
@@ -336,6 +331,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
336
331
  Primitive$1["Mul"] = "mul";
337
332
  Primitive$1["Idiv"] = "idiv";
338
333
  Primitive$1["Mod"] = "mod";
334
+ Primitive$1["Min"] = "min";
335
+ Primitive$1["Max"] = "max";
339
336
  Primitive$1["Neg"] = "neg";
340
337
  Primitive$1["Reciprocal"] = "reciprocal";
341
338
  Primitive$1["Floor"] = "floor";
@@ -343,7 +340,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
343
340
  Primitive$1["StopGradient"] = "stop_gradient";
344
341
  Primitive$1["Cast"] = "cast";
345
342
  Primitive$1["Bitcast"] = "bitcast";
346
- Primitive$1["RandomBits"] = "random_bits";
347
343
  Primitive$1["Sin"] = "sin";
348
344
  Primitive$1["Cos"] = "cos";
349
345
  Primitive$1["Asin"] = "asin";
@@ -353,8 +349,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
353
349
  Primitive$1["Erf"] = "erf";
354
350
  Primitive$1["Erfc"] = "erfc";
355
351
  Primitive$1["Sqrt"] = "sqrt";
356
- Primitive$1["Min"] = "min";
357
- Primitive$1["Max"] = "max";
358
352
  Primitive$1["Reduce"] = "reduce";
359
353
  Primitive$1["Dot"] = "dot";
360
354
  Primitive$1["Conv"] = "conv";
@@ -362,14 +356,19 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
362
356
  Primitive$1["PoolTranspose"] = "pool_transpose";
363
357
  Primitive$1["Compare"] = "compare";
364
358
  Primitive$1["Where"] = "where";
359
+ Primitive$1["RandomBits"] = "random_bits";
360
+ Primitive$1["Gather"] = "gather";
365
361
  Primitive$1["Transpose"] = "transpose";
366
362
  Primitive$1["Broadcast"] = "broadcast";
367
363
  Primitive$1["Reshape"] = "reshape";
368
364
  Primitive$1["Flip"] = "flip";
369
365
  Primitive$1["Shrink"] = "shrink";
370
366
  Primitive$1["Pad"] = "pad";
371
- Primitive$1["Gather"] = "gather";
372
- Primitive$1["JitCall"] = "jit_call";
367
+ Primitive$1["Sort"] = "sort";
368
+ Primitive$1["Argsort"] = "argsort";
369
+ Primitive$1["TriangularSolve"] = "triangular_solve";
370
+ Primitive$1["Cholesky"] = "cholesky";
371
+ Primitive$1["Jit"] = "jit";
373
372
  return Primitive$1;
374
373
  }({});
375
374
  let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
@@ -391,6 +390,12 @@ function idiv(x, y) {
391
390
  function mod(x, y) {
392
391
  return bind1(Primitive.Mod, [x, y]);
393
392
  }
393
+ function min$1(x, y) {
394
+ return bind1(Primitive.Min, [x, y]);
395
+ }
396
+ function max$1(x, y) {
397
+ return bind1(Primitive.Max, [x, y]);
398
+ }
394
399
  function neg(x) {
395
400
  return bind1(Primitive.Neg, [x]);
396
401
  }
@@ -412,12 +417,6 @@ function cast(x, dtype) {
412
417
  function bitcast(x, dtype) {
413
418
  return bind1(Primitive.Bitcast, [x], { dtype });
414
419
  }
415
- function randomBits(k0, k1, shape$1, mode = "xor") {
416
- return bind1(Primitive.RandomBits, [k0, k1], {
417
- shape: shape$1,
418
- mode
419
- });
420
- }
421
420
  function sin$1(x) {
422
421
  return bind1(Primitive.Sin, [x]);
423
422
  }
@@ -445,12 +444,6 @@ function erfc$1(x) {
445
444
  function sqrt$1(x) {
446
445
  return bind1(Primitive.Sqrt, [x]);
447
446
  }
448
- function min$1(x, y) {
449
- return bind1(Primitive.Min, [x, y]);
450
- }
451
- function max$1(x, y) {
452
- return bind1(Primitive.Max, [x, y]);
453
- }
454
447
  function reduce(x, op, axis = null, opts) {
455
448
  if (!AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
456
449
  axis = normalizeAxis(axis, ndim$1(x));
@@ -467,9 +460,11 @@ function dot$2(x, y) {
467
460
  }
468
461
  function conv$1(x, y, params = {}) {
469
462
  if (x.ndim !== y.ndim) throw new Error(`conv() requires inputs with the same number of dimensions, got ${x.ndim} and ${y.ndim}`);
470
- const n = x.ndim - 2;
463
+ const vmapDims = params.vmapDims ?? 0;
464
+ const n = x.ndim - 2 - vmapDims;
471
465
  if (n < 0) throw new Error("conv() requires at least 2D inputs");
472
466
  return bind1(Primitive.Conv, [x, y], {
467
+ vmapDims,
473
468
  strides: params.strides ?? rep(n, 1),
474
469
  padding: params.padding ?? rep(n, [0, 0]),
475
470
  lhsDilation: params.lhsDilation ?? rep(n, 1),
@@ -504,6 +499,23 @@ function where$1(cond, x, y) {
504
499
  y
505
500
  ]);
506
501
  }
502
+ function randomBits(k0, k1, shape$1, mode = "xor") {
503
+ return bind1(Primitive.RandomBits, [k0, k1], {
504
+ shape: shape$1,
505
+ mode
506
+ });
507
+ }
508
+ function gather(x, indices, axis, outDim) {
509
+ if (indices.length === 0) throw new Error("gather() requires at least one index");
510
+ if (!Array.isArray(axis) || axis.length !== indices.length) throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
511
+ axis = axis.map((a) => checkAxis(a, ndim$1(x)));
512
+ if (new Set(axis).size !== axis.length) throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
513
+ outDim = checkAxis(outDim, ndim$1(x) - axis.length + 1);
514
+ return bind1(Primitive.Gather, [x, ...indices], {
515
+ axis,
516
+ outDim
517
+ });
518
+ }
507
519
  function transpose$1(x, perm) {
508
520
  perm = perm ? perm.map((a) => checkAxis(a, ndim$1(x))) : range(ndim$1(x)).reverse();
509
521
  if (!isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
@@ -553,16 +565,27 @@ function pad$1(x, width) {
553
565
  } else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
554
566
  return bind1(Primitive.Pad, [x], { width });
555
567
  }
556
- function gather(x, indices, axis, outDim) {
557
- if (indices.length === 0) throw new Error("gather() requires at least one index");
558
- if (!Array.isArray(axis) || axis.length !== indices.length) throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
559
- axis = axis.map((a) => checkAxis(a, ndim$1(x)));
560
- if (new Set(axis).size !== axis.length) throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
561
- outDim = checkAxis(outDim, ndim$1(x) - axis.length + 1);
562
- return bind1(Primitive.Gather, [x, ...indices], {
563
- axis,
564
- outDim
565
- });
568
+ function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
569
+ if (lower) {
570
+ a = flip$1(a, [-2, -1]);
571
+ b = flip$1(b, [-1]);
572
+ }
573
+ let x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
574
+ if (lower) x = flip$1(x, [-1]);
575
+ return x;
576
+ }
577
+ function cholesky$2(x) {
578
+ return bind1(Primitive.Cholesky, [x]);
579
+ }
580
+ function sort$1(x) {
581
+ const nd = ndim$1(x);
582
+ if (nd === 0) throw new Error("sort: requires at least 1D input");
583
+ return bind1(Primitive.Sort, [x]);
584
+ }
585
+ function argsort$1(x) {
586
+ const nd = ndim$1(x);
587
+ if (nd === 0) throw new Error("argsort: requires at least 1D input");
588
+ return bind(Primitive.Argsort, [x]);
566
589
  }
567
590
  function bind1(prim, args, params = {}) {
568
591
  const [results] = bind(prim, args, params);
@@ -693,8 +716,10 @@ var Tracer = class Tracer {
693
716
  axis = normalizeAxis(axis, this.ndim);
694
717
  const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
695
718
  if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
696
- const result = reduce(this, AluOp.Add, axis, opts);
697
- return result.mul(1 / n);
719
+ const originalDtype = this.dtype;
720
+ const castDtype = promoteTypes(originalDtype, DType.Float32);
721
+ const result = reduce(this.astype(castDtype), AluOp.Add, axis, opts);
722
+ return result.mul(1 / n).astype(originalDtype);
698
723
  }
699
724
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
700
725
  transpose(perm) {
@@ -723,7 +748,7 @@ var Tracer = class Tracer {
723
748
  if (isFloatDtype(this.dtype)) return this.mul(reciprocal$1(other));
724
749
  return idiv(this, other);
725
750
  }
726
- /** Return specified diagonals. See `numpy.diagonal` for full docs. */
751
+ /** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
727
752
  diagonal(offset = 0, axis1 = 0, axis2 = 1) {
728
753
  if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
729
754
  if (offset < 0) return this.diagonal(-offset, axis2, axis1);
@@ -776,6 +801,34 @@ var Tracer = class Tracer {
776
801
  this.dispose();
777
802
  }
778
803
  /**
804
+ * Return a sorted copy of an array in ascending order.
805
+ *
806
+ * See `jax.numpy.sort` for full docs.
807
+ */
808
+ sort(axis = -1) {
809
+ axis = checkAxis(axis, this.ndim);
810
+ if (this.shape[axis] <= 1) return this;
811
+ if (axis === this.ndim - 1) return sort$1(this);
812
+ const perm = range(this.ndim);
813
+ perm.splice(axis, 1);
814
+ perm.push(axis);
815
+ return sort$1(this.transpose(perm)).transpose(invertPermutation(perm));
816
+ }
817
+ /**
818
+ * Return the indices that would sort an array. This may not be a stable
819
+ * sorting algorithm; it need not preserve order of indices in ties.
820
+ *
821
+ * See `jax.numpy.argsort` for full docs.
822
+ */
823
+ argsort(axis = -1) {
824
+ axis = checkAxis(axis, this.ndim);
825
+ if (axis === this.ndim - 1) return argsort$1(this)[1];
826
+ const perm = range(this.ndim);
827
+ perm.splice(axis, 1);
828
+ perm.push(axis);
829
+ return argsort$1(this.transpose(perm))[1].transpose(invertPermutation(perm));
830
+ }
831
+ /**
779
832
  * Slice an array along one or more axes.
780
833
  *
781
834
  * This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
@@ -892,6 +945,9 @@ var ShapedArray = class ShapedArray {
892
945
  get ndim() {
893
946
  return this.shape.length;
894
947
  }
948
+ get size() {
949
+ return prod(this.shape);
950
+ }
895
951
  toString() {
896
952
  return `${this.dtype}[${this.shape.join(",")}]`;
897
953
  }
@@ -1170,7 +1226,7 @@ var Jaxpr = class Jaxpr {
1170
1226
  } else if (eqn.primitive === Primitive.Idiv) {
1171
1227
  const [a, b] = inputs;
1172
1228
  const c = eqn.outBinders[0];
1173
- if (atomIsLit(b, 1)) context.set(c, a);
1229
+ if (atomIsLit(b, 1) && !isFloatDtype(a.aval.dtype)) context.set(c, a);
1174
1230
  else newEqns.push(eqn);
1175
1231
  } else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
1176
1232
  else newEqns.push(eqn);
@@ -1187,13 +1243,13 @@ var Jaxpr = class Jaxpr {
1187
1243
  }
1188
1244
  return new Jaxpr(this.inBinders, liveEqns.reverse(), outs);
1189
1245
  }
1190
- /** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
1246
+ /** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
1191
1247
  flatten() {
1192
- if (!this.eqns.some((eqn) => eqn.primitive === Primitive.JitCall)) return this;
1248
+ if (!this.eqns.some((eqn) => eqn.primitive === Primitive.Jit)) return this;
1193
1249
  const newEqns = [];
1194
1250
  const varMap = /* @__PURE__ */ new Map();
1195
1251
  const varMapF = (x) => x instanceof Var ? varMap.get(x) ?? x : x;
1196
- for (const eqn of this.eqns) if (eqn.primitive === Primitive.JitCall) {
1252
+ for (const eqn of this.eqns) if (eqn.primitive === Primitive.Jit) {
1197
1253
  const jaxpr = eqn.params.jaxpr.flatten();
1198
1254
  const translation = /* @__PURE__ */ new Map();
1199
1255
  const translationF = (x) => x instanceof Var ? translation.get(x) : x;
@@ -1294,19 +1350,48 @@ function evalJaxpr(jaxpr, args) {
1294
1350
  function jaxprAsFun(jaxpr) {
1295
1351
  return (...args) => evalJaxpr(jaxpr, args);
1296
1352
  }
1353
+ /** Jaxpr with a collection of associated, traced constants. */
1354
+ var ClosedJaxpr = class ClosedJaxpr {
1355
+ constructor(jaxpr, consts) {
1356
+ this.jaxpr = jaxpr;
1357
+ this.consts = consts;
1358
+ }
1359
+ /** String representation of this Jaxpr. */
1360
+ toString() {
1361
+ return this.jaxpr.toString();
1362
+ }
1363
+ /** Apply a function to the underlying Jaxpr. */
1364
+ mapJaxpr(f) {
1365
+ return new ClosedJaxpr(f(this.jaxpr), this.consts);
1366
+ }
1367
+ /** Dispose of the constants in this Jaxpr. */
1368
+ dispose() {
1369
+ for (const c of this.consts) c.dispose();
1370
+ }
1371
+ };
1297
1372
  /** Tracer that records its operations to dynamically construct a Jaxpr. */
1298
1373
  var JaxprTracer = class extends Tracer {
1374
+ #rc;
1299
1375
  constructor(trace$1, aval) {
1300
1376
  super(trace$1);
1301
1377
  this.aval = aval;
1378
+ this.#rc = 1;
1302
1379
  }
1303
1380
  toString() {
1304
1381
  return `JaxprTracer(${this.aval.toString()})`;
1305
1382
  }
1306
1383
  get ref() {
1384
+ if (this.#rc <= 0) throw new UseAfterFreeError(this);
1385
+ this.#rc++;
1307
1386
  return this;
1308
1387
  }
1309
- dispose() {}
1388
+ dispose() {
1389
+ if (this.#rc <= 0) throw new UseAfterFreeError(this);
1390
+ this.#rc--;
1391
+ }
1392
+ trackLiftedConstant() {
1393
+ this.#rc++;
1394
+ }
1310
1395
  };
1311
1396
  /** Analogous to the 'DynamicJaxprTrace' class in JAX. */
1312
1397
  var JaxprTrace = class extends Trace {
@@ -1319,17 +1404,24 @@ var JaxprTrace = class extends Trace {
1319
1404
  }
1320
1405
  /** Register a constant / literal in this Jaxpr. */
1321
1406
  getOrMakeConstTracer(val) {
1407
+ if (!(val instanceof Tracer)) val = pureArray(val);
1322
1408
  let tracer = this.builder.constTracers.get(val);
1323
1409
  if (tracer === void 0) {
1324
1410
  tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
1325
- this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
1411
+ this.builder.addConst(tracer, val);
1412
+ } else {
1413
+ val.dispose();
1414
+ tracer.trackLiftedConstant();
1326
1415
  }
1327
1416
  return tracer;
1328
1417
  }
1329
1418
  pure = this.getOrMakeConstTracer;
1330
1419
  lift = this.getOrMakeConstTracer;
1331
1420
  processPrimitive(primitive, tracers, params) {
1332
- const avalsIn = tracers.map((t) => t.aval);
1421
+ const avalsIn = tracers.map((t) => {
1422
+ t.dispose();
1423
+ return t.aval;
1424
+ });
1333
1425
  const avalsOut = abstractEvalRules[primitive](avalsIn, params);
1334
1426
  const outTracers = avalsOut.map((aval) => this.builder.newTracer(this, aval));
1335
1427
  this.builder.addEqn(new JaxprEqn(primitive, tracers.map((t) => this.builder.getVar(t)), params, outTracers.map((t) => this.builder.addVar(t))));
@@ -1372,20 +1464,17 @@ var JaxprBuilder = class {
1372
1464
  return v;
1373
1465
  }
1374
1466
  build(inTracers, outTracers) {
1375
- let [constVars, consts] = unzip2(this.constVals.entries());
1467
+ const [constVars, consts] = unzip2(this.constVals.entries());
1376
1468
  const t2v = this.getVar.bind(this);
1377
1469
  const inBinders = [...constVars, ...inTracers.map(t2v)];
1378
1470
  const outVars = outTracers.map(t2v);
1379
- let jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
1471
+ const jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
1380
1472
  typecheckJaxpr(jaxpr);
1381
- [jaxpr, consts] = _inlineLiterals(jaxpr, consts);
1382
- return {
1383
- jaxpr,
1384
- consts
1385
- };
1473
+ const cjaxpr = new ClosedJaxpr(jaxpr, consts);
1474
+ return _inlineLiterals(cjaxpr);
1386
1475
  }
1387
1476
  };
1388
- function _inlineLiterals(jaxpr, consts) {
1477
+ function _inlineLiterals({ jaxpr, consts }) {
1389
1478
  const literals = /* @__PURE__ */ new Map();
1390
1479
  const constBinders = [];
1391
1480
  const newConsts = [];
@@ -1400,7 +1489,7 @@ function _inlineLiterals(jaxpr, consts) {
1400
1489
  const newOuts = jaxpr.outs.map((x) => literals.get(x) ?? x);
1401
1490
  const newJaxpr = new Jaxpr([...constBinders, ...jaxpr.inBinders.slice(consts.length)], newEqns, newOuts);
1402
1491
  typecheckJaxpr(newJaxpr);
1403
- return [newJaxpr, newConsts];
1492
+ return new ClosedJaxpr(newJaxpr, newConsts);
1404
1493
  }
1405
1494
  function binopAbstractEval([x, y]) {
1406
1495
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
@@ -1419,6 +1508,8 @@ const abstractEvalRules = {
1419
1508
  [Primitive.Mul]: binopAbstractEval,
1420
1509
  [Primitive.Idiv]: binopAbstractEval,
1421
1510
  [Primitive.Mod]: binopAbstractEval,
1511
+ [Primitive.Min]: binopAbstractEval,
1512
+ [Primitive.Max]: binopAbstractEval,
1422
1513
  [Primitive.Neg]: vectorizedUnopAbstractEval,
1423
1514
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
1424
1515
  [Primitive.Floor]: vectorizedUnopAbstractEval,
@@ -1432,12 +1523,6 @@ const abstractEvalRules = {
1432
1523
  if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
1433
1524
  return [new ShapedArray(x.shape, dtype, false)];
1434
1525
  },
1435
- [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
1436
- if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
1437
- const keyShape = generalBroadcast(k0.shape, k1.shape);
1438
- if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1439
- return [new ShapedArray(shape$1, DType.Uint32, false)];
1440
- },
1441
1526
  [Primitive.Sin]: vectorizedUnopAbstractEval,
1442
1527
  [Primitive.Cos]: vectorizedUnopAbstractEval,
1443
1528
  [Primitive.Asin]: vectorizedUnopAbstractEval,
@@ -1447,8 +1532,6 @@ const abstractEvalRules = {
1447
1532
  [Primitive.Erf]: vectorizedUnopAbstractEval,
1448
1533
  [Primitive.Erfc]: vectorizedUnopAbstractEval,
1449
1534
  [Primitive.Sqrt]: vectorizedUnopAbstractEval,
1450
- [Primitive.Min]: binopAbstractEval,
1451
- [Primitive.Max]: binopAbstractEval,
1452
1535
  [Primitive.Reduce]([x], { axis }) {
1453
1536
  const axisSet = new Set(axis);
1454
1537
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
@@ -1481,6 +1564,25 @@ const abstractEvalRules = {
1481
1564
  const shape$1 = generalBroadcast(cond.shape, xy.shape);
1482
1565
  return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
1483
1566
  },
1567
+ [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
1568
+ if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
1569
+ const keyShape = generalBroadcast(k0.shape, k1.shape);
1570
+ if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1571
+ return [new ShapedArray(shape$1, DType.Uint32, false)];
1572
+ },
1573
+ [Primitive.Gather]([x, ...indices], { axis, outDim }) {
1574
+ for (const a of indices) if (a.dtype !== DType.Int32 && a.dtype !== DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
1575
+ if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
1576
+ if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
1577
+ if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
1578
+ if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
1579
+ const axisSet = new Set(axis);
1580
+ if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
1581
+ const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
1582
+ const newShape = x.shape.filter((_, i) => !axisSet.has(i));
1583
+ newShape.splice(outDim, 0, ...gatherShape);
1584
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
1585
+ },
1484
1586
  [Primitive.Transpose]([x], { perm }) {
1485
1587
  return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
1486
1588
  },
@@ -1501,23 +1603,31 @@ const abstractEvalRules = {
1501
1603
  const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
1502
1604
  return [new ShapedArray(newShape, x.dtype, x.weakType)];
1503
1605
  },
1504
- [Primitive.Gather]([x, ...indices], { axis, outDim }) {
1505
- for (const a of indices) if (a.dtype !== DType.Int32 && a.dtype !== DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
1506
- if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
1507
- if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
1508
- if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
1509
- if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
1510
- const axisSet = new Set(axis);
1511
- if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
1512
- const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
1513
- const newShape = x.shape.filter((_, i) => !axisSet.has(i));
1514
- newShape.splice(outDim, 0, ...gatherShape);
1515
- return [new ShapedArray(newShape, x.dtype, x.weakType)];
1606
+ [Primitive.Sort]([x]) {
1607
+ if (x.ndim === 0) throw new TypeError("sort: requires at least 1D input");
1608
+ return [ShapedArray.fromAval(x)];
1609
+ },
1610
+ [Primitive.Argsort]([x]) {
1611
+ if (x.ndim === 0) throw new TypeError("argsort: requires at least 1D input");
1612
+ return [ShapedArray.fromAval(x), new ShapedArray(x.shape, DType.Int32, false)];
1613
+ },
1614
+ [Primitive.TriangularSolve]([a, b]) {
1615
+ if (a.ndim < 2) throw new TypeError(`triangular_solve: a must be at least 2D, got ${a}`);
1616
+ if (b.ndim < 2) throw new TypeError(`triangular_solve: b must be at least 2D, got ${b}`);
1617
+ const [m, n] = a.shape.slice(-2);
1618
+ const [_batch, q] = b.shape.slice(-2);
1619
+ if (!deepEqual(a.shape.slice(0, -2), b.shape.slice(0, -2)) || a.dtype !== b.dtype || m !== n || n !== q) throw new TypeError(`triangular_solve: mismatch ${a} vs ${b}`);
1620
+ return [new ShapedArray(b.shape, b.dtype, a.weakType && b.weakType)];
1621
+ },
1622
+ [Primitive.Cholesky]([a]) {
1623
+ if (a.ndim < 2) throw new TypeError(`cholesky: requires at least 2D input, got ${a}`);
1624
+ if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
1625
+ return [ShapedArray.fromAval(a)];
1516
1626
  },
1517
- [Primitive.JitCall](args, { jaxpr }) {
1627
+ [Primitive.Jit](args, { jaxpr }) {
1518
1628
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
1519
- if (args.length !== inTypes.length) throw new TypeError(`jit_call expected ${inTypes.length} arguments, got ${args.length}`);
1520
- for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit_call argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
1629
+ if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
1630
+ for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
1521
1631
  return outTypes;
1522
1632
  }
1523
1633
  };
@@ -1553,11 +1663,10 @@ function makeJaxpr$1(f, opts) {
1553
1663
  const tracersIn = avalsIn.map((aval) => trace$1.newArg(typeof aval === "object" ? aval : pureArray(aval)));
1554
1664
  const outs = fFlat(...tracersIn);
1555
1665
  const tracersOut = outs.map((out) => fullRaise(trace$1, out));
1556
- const { jaxpr, consts } = builder.build(tracersIn, tracersOut);
1666
+ const jaxpr = builder.build(tracersIn, tracersOut);
1557
1667
  if (outTree.value === void 0) throw new Error("outTree was not set in makeJaxpr");
1558
1668
  return {
1559
- jaxpr: jaxpr.simplify(),
1560
- consts,
1669
+ jaxpr: jaxpr.mapJaxpr((j) => j.simplify()),
1561
1670
  treedef: outTree.value
1562
1671
  };
1563
1672
  } catch (_) {
@@ -1576,22 +1685,28 @@ function jit$1(f, opts) {
1576
1685
  const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
1577
1686
  const avalsIn = unflatten(inTree, avalsInFlat);
1578
1687
  const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
1579
- const { jaxpr, consts, treedef: outTree } = runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
1580
- const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
1688
+ const { jaxpr, treedef: outTree } = runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
1689
+ const outs = bind(Primitive.Jit, [...jaxpr.consts.map((c) => c.ref), ...argsFlat], {
1581
1690
  name: f.name || "closure",
1582
- jaxpr,
1583
- numConsts: consts.length
1691
+ jaxpr: jaxpr.jaxpr,
1692
+ numConsts: jaxpr.consts.length
1584
1693
  });
1585
1694
  return unflatten(outTree, outs);
1586
1695
  });
1587
1696
  result.dispose = () => {
1588
- for (const { consts } of cache.values()) for (const c of consts) c.dispose();
1697
+ for (const { jaxpr } of cache.values()) jaxpr.dispose();
1589
1698
  };
1590
1699
  return result;
1591
1700
  }
1592
1701
 
1593
1702
  //#endregion
1594
1703
  //#region src/frontend/jit.ts
1704
+ const routinePrimitives = new Map([
1705
+ [Primitive.Sort, Routines.Sort],
1706
+ [Primitive.Argsort, Routines.Argsort],
1707
+ [Primitive.TriangularSolve, Routines.TriangularSolve],
1708
+ [Primitive.Cholesky, Routines.Cholesky]
1709
+ ]);
1595
1710
  /** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
1596
1711
  var JitProgram = class {
1597
1712
  constructor(backend, steps, inputs, outputs) {
@@ -1606,9 +1721,14 @@ var JitProgram = class {
1606
1721
  case "execute": {
1607
1722
  const inputsNice = step.inputs.map((id, i) => `${i}: %${id}`).join(", ");
1608
1723
  const outputsNice = step.outputs.map((id) => `%${id}`).join(", ");
1609
- return PPrint.pp(`execute (${inputsNice}) -> ${outputsNice}, kernel`).concat(step.kernel.pprint().indent(2));
1724
+ const executeText = `execute (${inputsNice}) -> ${outputsNice}`;
1725
+ if (step.source instanceof Kernel) return PPrint.pp(`${executeText}, kernel`).concat(step.source.pprint().indent(2));
1726
+ else if (step.source instanceof Routine) return PPrint.pp(`${executeText}, routine ${step.source.name}`);
1727
+ else {
1728
+ step.source;
1729
+ return PPrint.pp(executeText);
1730
+ }
1610
1731
  }
1611
- case "const": return PPrint.pp(`%${step.output} = const <Slot ${step.slot}>`);
1612
1732
  case "malloc": return PPrint.pp(`%${step.output} = malloc <${step.size} bytes>`);
1613
1733
  case "incref": return PPrint.pp(`incref ${step.input}`);
1614
1734
  case "free": return PPrint.pp(`free ${step.input}`);
@@ -1631,12 +1751,9 @@ var JitProgram = class {
1631
1751
  const inputs$1 = step.inputs.map((id) => scope.get(id));
1632
1752
  const outputs = step.outputs.map((id) => scope.get(id));
1633
1753
  if (inputs$1.some((s) => s === void 0) || outputs.some((s) => s === void 0)) throw new Error(`internal: JitProgram scope undefined`);
1634
- pending.push(new PendingExecute(this.backend, step.kernel, inputs$1, outputs));
1754
+ pending.push(new PendingExecute(this.backend, step.source, inputs$1, outputs));
1635
1755
  break;
1636
1756
  }
1637
- case "const":
1638
- scope.set(step.output, step.slot);
1639
- break;
1640
1757
  case "malloc": {
1641
1758
  const slot = this.backend.malloc(step.size);
1642
1759
  scope.set(step.output, slot);
@@ -1670,34 +1787,37 @@ var JitProgramBuilder = class {
1670
1787
  this.#nextId = nargs;
1671
1788
  this.steps = [];
1672
1789
  }
1673
- pushConst(slot) {
1674
- const id = this.#nextId++;
1675
- this.steps.push({
1676
- type: "const",
1677
- slot,
1678
- output: id
1679
- });
1680
- return id;
1681
- }
1682
1790
  pushLit(lit) {
1683
- const kernel = new Kernel(0, prod(lit.aval.shape), AluExp.const(lit.dtype, lit.value));
1791
+ const kernel = new Kernel(0, lit.aval.size, AluExp.const(lit.dtype, lit.value));
1684
1792
  return this.pushKernel(kernel, []);
1685
1793
  }
1686
- pushKernel(kernel, inputs) {
1794
+ pushBuffer(size$1) {
1687
1795
  const id = this.#nextId++;
1688
1796
  this.steps.push({
1689
1797
  type: "malloc",
1690
- size: kernel.bytes,
1798
+ size: size$1,
1691
1799
  output: id
1692
1800
  });
1801
+ return id;
1802
+ }
1803
+ pushKernel(kernel, inputs) {
1804
+ const id = this.pushBuffer(kernel.bytes);
1693
1805
  this.steps.push({
1694
1806
  type: "execute",
1695
- kernel,
1807
+ source: kernel,
1696
1808
  inputs,
1697
1809
  outputs: [id]
1698
1810
  });
1699
1811
  return id;
1700
1812
  }
1813
+ pushRoutine(routine, inputs, outputs) {
1814
+ this.steps.push({
1815
+ type: "execute",
1816
+ source: routine,
1817
+ inputs,
1818
+ outputs
1819
+ });
1820
+ }
1701
1821
  pushIncref(id) {
1702
1822
  this.steps.push({
1703
1823
  type: "incref",
@@ -1723,28 +1843,18 @@ var JitProgramBuilder = class {
1723
1843
  }
1724
1844
  };
1725
1845
  const jitCompileCache = /* @__PURE__ */ new Map();
1726
- function jitCompile(backend, jaxpr, consts) {
1727
- if (jaxpr.inBinders.length < consts.length) throw new TypeError(`Jaxpr has ${jaxpr.inBinders.length} inputs, but ${consts.length} consts were provided`);
1728
- for (let i = 0; i < consts.length; i++) if (consts[i].device !== backend.type) throw new TypeError(`Const ${i} has device ${consts[i].device}, but expected ${backend.type}`);
1729
- const cacheKey = backend.type + FpHash.hash(jaxpr, ...consts.map((c) => c.id));
1846
+ function jitCompile(backend, jaxpr) {
1847
+ const cacheKey = backend.type + "," + FpHash.hash(jaxpr);
1730
1848
  const cached = jitCompileCache.get(cacheKey);
1731
1849
  if (cached) return cached;
1732
1850
  if (DEBUG >= 1) console.info("=========== JIT Compile ===========\n" + jaxpr.toString());
1733
1851
  jaxpr = jaxpr.flatten().simplify();
1734
- const nargs = jaxpr.inBinders.length - consts.length;
1852
+ const nargs = jaxpr.inBinders.length;
1735
1853
  const builder = new JitProgramBuilder(backend, nargs);
1736
1854
  const blackNodes = splitGraphDataflow(backend, jaxpr);
1737
1855
  const ctx = /* @__PURE__ */ new Map();
1738
- for (let i = 0; i < consts.length; i++) {
1739
- const v = jaxpr.inBinders[i];
1740
- const slot = consts[i]._realizeSource();
1741
- ctx.set(v, {
1742
- type: "imm",
1743
- arg: builder.pushConst(slot)
1744
- });
1745
- }
1746
1856
  for (let i = 0; i < nargs; i++) {
1747
- const v = jaxpr.inBinders[consts.length + i];
1857
+ const v = jaxpr.inBinders[i];
1748
1858
  ctx.set(v, {
1749
1859
  type: "imm",
1750
1860
  arg: i
@@ -1752,51 +1862,101 @@ function jitCompile(backend, jaxpr, consts) {
1752
1862
  }
1753
1863
  for (let i = 0; i < jaxpr.eqns.length; i++) {
1754
1864
  const eqn = jaxpr.eqns[i];
1865
+ if (routinePrimitives.has(eqn.primitive)) {
1866
+ const routine = new Routine(routinePrimitives.get(eqn.primitive), {
1867
+ inputShapes: eqn.inputs.map((x) => x.aval.shape),
1868
+ inputDtypes: eqn.inputs.map((x) => x.aval.dtype),
1869
+ outputShapes: eqn.outBinders.map((x) => x.aval.shape),
1870
+ outputDtypes: eqn.outBinders.map((x) => x.aval.dtype)
1871
+ }, eqn.params);
1872
+ const inputs = [];
1873
+ for (const input of eqn.inputs) if (input instanceof Var) {
1874
+ const jv = ctx.get(input);
1875
+ if (jv.type !== "imm") throw new Error(`jit: routine primitive ${eqn.primitive} input is not imm`);
1876
+ inputs.push(jv.arg);
1877
+ } else if (input instanceof Lit) inputs.push(builder.pushLit(input));
1878
+ const outputs = [];
1879
+ for (const outVar$1 of eqn.outBinders) {
1880
+ const outId = builder.pushBuffer(outVar$1.aval.size * byteWidth(outVar$1.aval.dtype));
1881
+ outputs.push(outId);
1882
+ ctx.set(outVar$1, {
1883
+ type: "imm",
1884
+ arg: outId
1885
+ });
1886
+ }
1887
+ builder.pushRoutine(routine, inputs, outputs);
1888
+ continue;
1889
+ }
1755
1890
  const inputExps = [];
1756
1891
  const inputAvals = [];
1757
1892
  const inputArgs = [];
1758
- for (const input of eqn.inputs) if (input instanceof Var) {
1759
- const jitValue = ctx.get(input);
1760
- if (jitValue.type === "exp") {
1761
- const gidMap = /* @__PURE__ */ new Map();
1762
- for (const [gid, jitId] of jitValue.args.entries()) {
1763
- let newGid = inputArgs.indexOf(jitId);
1764
- if (newGid === -1) {
1765
- newGid = inputArgs.length;
1766
- inputArgs.push(jitId);
1767
- }
1768
- gidMap.set(gid, newGid);
1769
- }
1770
- inputExps.push(jitValue.exp.reindexGids(gidMap));
1771
- } else if (jitValue.type === "imm") {
1772
- let gid = inputArgs.indexOf(jitValue.arg);
1773
- if (gid === -1) {
1774
- gid = inputArgs.length;
1775
- inputArgs.push(jitValue.arg);
1893
+ let inputReduction = null;
1894
+ const addArgs = (args) => {
1895
+ const newGids = [];
1896
+ for (const jitId of args) {
1897
+ let newGid = inputArgs.indexOf(jitId);
1898
+ if (newGid === -1) {
1899
+ newGid = inputArgs.length;
1900
+ inputArgs.push(jitId);
1776
1901
  }
1902
+ newGids.push(newGid);
1903
+ }
1904
+ return newGids;
1905
+ };
1906
+ for (const input of eqn.inputs) if (input instanceof Var) {
1907
+ const jv = ctx.get(input);
1908
+ if (jv.type === "exp") {
1909
+ const newGids = addArgs(jv.args);
1910
+ inputExps.push(jv.exp.reindexGids(newGids));
1911
+ } else if (jv.type === "imm") {
1912
+ const [gid] = addArgs([jv.arg]);
1777
1913
  const st = ShapeTracker.fromShape(input.aval.shape);
1778
1914
  const indices = unravelAlu(st.shape, AluVar.gidx);
1779
1915
  inputExps.push(AluExp.globalView(input.aval.dtype, gid, st, indices));
1916
+ } else if (jv.type === "red") {
1917
+ if (inputReduction) throw new Error("jit: unexpected, multiple red inputs");
1918
+ const newGids = addArgs(jv.args);
1919
+ inputExps.push(jv.reduction.epilogue.reindexGids(newGids));
1920
+ inputReduction = jv;
1780
1921
  }
1781
1922
  inputAvals.push(input.aval);
1782
1923
  } else if (input instanceof Lit) {
1783
1924
  inputExps.push(AluExp.const(input.dtype, input.value));
1784
1925
  inputAvals.push(input.aval);
1785
1926
  } else throw new TypeError(`Unexpected input in Jaxpr: ${input}`);
1786
- const nargs$1 = inputArgs.length;
1787
1927
  const rule = jitRules[eqn.primitive];
1788
1928
  if (!rule) throw new TypeError(`JIT not implemented for primitive ${eqn.primitive}`);
1789
- const kernel = rule(nargs$1, inputExps, inputAvals, eqn.params);
1929
+ let exp$2;
1930
+ let reduction;
1931
+ if (inputReduction) {
1932
+ const jv = inputReduction;
1933
+ const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
1934
+ exp$2 = jv.exp.reindexGids(addArgs(jv.args));
1935
+ reduction = new Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
1936
+ } else {
1937
+ const ruleOutput = rule(inputExps, inputAvals, eqn.params);
1938
+ exp$2 = ruleOutput.exp;
1939
+ reduction = ruleOutput.reduction;
1940
+ }
1790
1941
  const outVar = eqn.outBinders[0];
1791
- if (kernel.reduction || blackNodes.has(outVar)) {
1942
+ if (blackNodes.has(outVar)) {
1943
+ const nargs$1 = inputArgs.length;
1944
+ const size$1 = outVar.aval.size;
1945
+ const kernel = new Kernel(nargs$1, size$1, exp$2, reduction);
1792
1946
  const outId = builder.pushKernel(kernel, inputArgs);
1793
1947
  ctx.set(outVar, {
1794
1948
  type: "imm",
1795
1949
  arg: outId
1796
1950
  });
1797
- } else ctx.set(outVar, {
1951
+ } else if (reduction) ctx.set(outVar, {
1952
+ type: "red",
1953
+ exp: exp$2,
1954
+ reduction,
1955
+ args: inputArgs
1956
+ });
1957
+ else ctx.set(outVar, {
1798
1958
  type: "exp",
1799
- exp: kernel.exp,
1959
+ exp: exp$2,
1800
1960
  args: inputArgs
1801
1961
  });
1802
1962
  }
@@ -1806,7 +1966,7 @@ function jitCompile(backend, jaxpr, consts) {
1806
1966
  if (jitValue.type !== "imm") throw new Error("internal: Expected imm, since outs are black nodes");
1807
1967
  outputIds.push(jitValue.arg);
1808
1968
  } else if (out instanceof Lit) outputIds.push(builder.pushLit(out));
1809
- const outputNeedsRef = new Set([...range(nargs), ...builder.steps.filter((s) => s.type === "const").map((s) => s.output)]);
1969
+ const outputNeedsRef = new Set(range(nargs));
1810
1970
  for (const outputId of outputIds) if (outputNeedsRef.has(outputId)) builder.pushIncref(outputId);
1811
1971
  else outputNeedsRef.add(outputId);
1812
1972
  builder.insertFreeSteps(outputIds);
@@ -1828,31 +1988,33 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
1828
1988
  });
1829
1989
  }
1830
1990
  function broadcastedJit(fn, opts) {
1831
- return (nargs, exps, avals, params) => {
1991
+ return (exps, avals, params) => {
1832
1992
  let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
1833
1993
  const skipCastIdx = opts?.skipCastIdx ?? [];
1834
1994
  if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
1835
- exps = exps.map((exp$3, i) => {
1836
- exp$3 = reshapeViews(exp$3, (st) => {
1995
+ exps = exps.map((exp$2, i) => {
1996
+ exp$2 = reshapeViews(exp$2, (st) => {
1837
1997
  if (!deepEqual(st.shape, newShape)) return st.broadcast(newShape, range(newShape.length - st.shape.length));
1838
1998
  });
1839
- if (exp$3.dtype !== newDtype && !skipCastIdx.includes(i)) exp$3 = AluExp.cast(newDtype, exp$3);
1840
- return exp$3;
1999
+ if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = AluExp.cast(newDtype, exp$2);
2000
+ return exp$2;
1841
2001
  });
1842
- const exp$2 = fn(exps, params);
1843
- return new Kernel(nargs, prod(newShape), exp$2);
2002
+ return { exp: fn(exps, params) };
1844
2003
  };
1845
2004
  }
1846
2005
  function unopJit(fn) {
1847
- return (nargs, [a], [as], params) => {
1848
- return new Kernel(nargs, prod(as.shape), fn(a, params));
2006
+ return ([a], [_as], params) => {
2007
+ return { exp: fn(a, params) };
1849
2008
  };
1850
2009
  }
1851
2010
  function reshapeJit(fn) {
1852
- return (nargs, [a], [as], params) => {
1853
- a = reshapeViews(a, (st) => fn(st, params));
1854
- const newShape = fn(ShapeTracker.fromShape(as.shape), params).shape;
1855
- return new Kernel(nargs, prod(newShape), a);
2011
+ return ([a], [_as], params) => {
2012
+ return { exp: reshapeViews(a, (st) => fn(st, params)) };
2013
+ };
2014
+ }
2015
+ function routineNoJit() {
2016
+ return () => {
2017
+ throw new Error("jit: rule is not implemented for routines");
1856
2018
  };
1857
2019
  }
1858
2020
  const jitRules = {
@@ -1860,6 +2022,8 @@ const jitRules = {
1860
2022
  [Primitive.Mul]: broadcastedJit(([a, b]) => AluExp.mul(a, b)),
1861
2023
  [Primitive.Idiv]: broadcastedJit(([a, b]) => AluExp.idiv(a, b)),
1862
2024
  [Primitive.Mod]: broadcastedJit(([a, b]) => AluExp.mod(a, b)),
2025
+ [Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
2026
+ [Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
1863
2027
  [Primitive.Neg]: unopJit((a) => AluExp.sub(AluExp.const(a.dtype, 0), a)),
1864
2028
  [Primitive.Reciprocal]: unopJit(AluExp.reciprocal),
1865
2029
  [Primitive.Floor]: unopJit(AluExp.floor),
@@ -1867,17 +2031,6 @@ const jitRules = {
1867
2031
  [Primitive.StopGradient]: unopJit((a) => a),
1868
2032
  [Primitive.Cast]: unopJit((a, { dtype }) => AluExp.cast(dtype, a)),
1869
2033
  [Primitive.Bitcast]: unopJit((a, { dtype }) => AluExp.bitcast(dtype, a)),
1870
- [Primitive.RandomBits]: (nargs, keys, keyShapes, { shape: shape$1, mode }) => {
1871
- const mapping = (st) => {
1872
- if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(shape$1.length - st.shape.length));
1873
- };
1874
- const k0 = reshapeViews(keys[0], mapping);
1875
- const k1 = reshapeViews(keys[1], mapping);
1876
- const c0 = AluExp.u32(0);
1877
- const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
1878
- const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
1879
- return new Kernel(nargs, prod(shape$1), exp$2);
1880
- },
1881
2034
  [Primitive.Sin]: unopJit(AluExp.sin),
1882
2035
  [Primitive.Cos]: unopJit(AluExp.cos),
1883
2036
  [Primitive.Asin]: unopJit(AluExp.asin),
@@ -1887,9 +2040,7 @@ const jitRules = {
1887
2040
  [Primitive.Erf]: unopJit(AluExp.erf),
1888
2041
  [Primitive.Erfc]: unopJit(AluExp.erfc),
1889
2042
  [Primitive.Sqrt]: unopJit(AluExp.sqrt),
1890
- [Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
1891
- [Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
1892
- [Primitive.Reduce](nargs, [a], [as], { op, axis }) {
2043
+ [Primitive.Reduce]([a], [as], { op, axis }) {
1893
2044
  const keptAxes = [];
1894
2045
  const shiftedAxes = [];
1895
2046
  const newShape = [];
@@ -1898,53 +2049,58 @@ const jitRules = {
1898
2049
  keptAxes.push(i);
1899
2050
  newShape.push(as.shape[i]);
1900
2051
  }
1901
- const size$1 = prod(newShape);
1902
2052
  const reductionSize = prod(shiftedAxes.map((ax) => as.shape[ax]));
1903
2053
  newShape.push(reductionSize);
1904
2054
  const perm = keptAxes.concat(shiftedAxes);
1905
2055
  a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
1906
2056
  const reduction = new Reduction(a.dtype, op, reductionSize);
1907
- return new Kernel(nargs, size$1, a, reduction);
2057
+ return {
2058
+ exp: a,
2059
+ reduction
2060
+ };
1908
2061
  },
1909
2062
  [Primitive.Pool]: reshapeJit((st, { window, strides }) => pool(st, window, strides)),
1910
- [Primitive.PoolTranspose](nargs, [a], [as], { inShape, window, strides }) {
2063
+ [Primitive.PoolTranspose]([a], [as], { inShape, window, strides }) {
1911
2064
  let stX = poolTranspose(ShapeTracker.fromShape(as.shape), inShape, window, strides);
1912
- const size$1 = prod(inShape);
1913
2065
  stX = stX.reshape([...inShape, prod(stX.shape.slice(inShape.length))]);
1914
2066
  a = reshapeViews(a, (st) => st.compose(stX), true);
1915
2067
  const reduction = new Reduction(a.dtype, AluOp.Add, stX.shape[stX.shape.length - 1]);
1916
- return new Kernel(nargs, size$1, a, reduction);
2068
+ return {
2069
+ exp: a,
2070
+ reduction
2071
+ };
1917
2072
  },
1918
- [Primitive.Dot](nargs, [a, b], [as, bs]) {
1919
- const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
2073
+ [Primitive.Dot]([a, b], [as, bs]) {
2074
+ const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
1920
2075
  const c = k1.exp;
1921
2076
  const cs = promoteAvals(as, bs);
1922
- return jitRules[Primitive.Reduce](nargs, [c], [cs], {
2077
+ return jitRules[Primitive.Reduce]([c], [cs], {
1923
2078
  op: AluOp.Add,
1924
2079
  axis: [cs.ndim - 1]
1925
2080
  });
1926
2081
  },
1927
- [Primitive.Conv](nargs, [a, b], [as, bs], params) {
2082
+ [Primitive.Conv]([a, b], [as, bs], params) {
1928
2083
  const [stX, stY] = prepareConv(ShapeTracker.fromShape(as.shape), ShapeTracker.fromShape(bs.shape), params);
1929
2084
  a = reshapeViews(a, (st) => st.compose(stX));
1930
2085
  b = reshapeViews(b, (st) => st.compose(stY));
1931
2086
  as = new ShapedArray(stX.shape, as.dtype, as.weakType);
1932
2087
  bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
1933
- return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
2088
+ return jitRules[Primitive.Dot]([a, b], [as, bs], {});
1934
2089
  },
1935
2090
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
1936
2091
  [Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
1937
- [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
1938
- [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
1939
- [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
1940
- [Primitive.Flip]: reshapeJit((st, { axis }) => {
1941
- const arg = rep(st.shape.length, false);
1942
- for (const ax of axis) arg[ax] = true;
1943
- return st.flip(arg);
1944
- }),
1945
- [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
1946
- [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
1947
- [Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
2092
+ [Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
2093
+ const mapping = (st) => {
2094
+ if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(shape$1.length - st.shape.length));
2095
+ };
2096
+ const k0 = reshapeViews(keys[0], mapping);
2097
+ const k1 = reshapeViews(keys[1], mapping);
2098
+ const c0 = AluExp.u32(0);
2099
+ const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
2100
+ const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
2101
+ return { exp: exp$2 };
2102
+ },
2103
+ [Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
1948
2104
  const axisSet = new Set(axis);
1949
2105
  const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
1950
2106
  const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
@@ -1957,24 +2113,38 @@ const jitRules = {
1957
2113
  for (const [i, iexp] of indices.entries()) src[axis[i]] = AluExp.cast(DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...range(outDim + indexShape.length - st.shape.length), ...range(outDim + indexShape.length, finalShape.length)])));
1958
2114
  const [index, valid] = ShapeTracker.fromShape(xs.shape).toAluExp(src);
1959
2115
  if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
1960
- return new Kernel(nargs, prod(finalShape), x.substitute({ gidx: index }));
2116
+ return { exp: x.substitute({ gidx: index }) };
1961
2117
  },
1962
- [Primitive.JitCall]() {
1963
- throw new Error("internal: JitCall should have been flattened before JIT compilation");
2118
+ [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
2119
+ [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
2120
+ [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
2121
+ [Primitive.Flip]: reshapeJit((st, { axis }) => {
2122
+ const arg = rep(st.shape.length, false);
2123
+ for (const ax of axis) arg[ax] = true;
2124
+ return st.flip(arg);
2125
+ }),
2126
+ [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
2127
+ [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
2128
+ [Primitive.Sort]: routineNoJit(),
2129
+ [Primitive.Argsort]: routineNoJit(),
2130
+ [Primitive.TriangularSolve]: routineNoJit(),
2131
+ [Primitive.Cholesky]: routineNoJit(),
2132
+ [Primitive.Jit]() {
2133
+ throw new Error("internal: Jit should have been flattened before JIT compilation");
1964
2134
  }
1965
2135
  };
1966
2136
  /** Determines how to split the Jaxpr into kernels via dataflow analysis. */
1967
2137
  function splitGraphDataflow(backend, jaxpr) {
1968
- const varToEqn = /* @__PURE__ */ new Map();
2138
+ const varToDefn = /* @__PURE__ */ new Map();
2139
+ const varToUsages = /* @__PURE__ */ new Map();
1969
2140
  for (let i = 0; i < jaxpr.eqns.length; i++) {
1970
2141
  const eqn = jaxpr.eqns[i];
1971
- for (const v of eqn.outBinders) if (v instanceof Var) varToEqn.set(v, i);
1972
- }
1973
- const blackNodes = /* @__PURE__ */ new Set();
1974
- const p1NextBlack = /* @__PURE__ */ new Map();
1975
- for (const v of jaxpr.outs) if (v instanceof Var) {
1976
- blackNodes.add(v);
1977
- p1NextBlack.set(v, v);
2142
+ for (const v of eqn.outBinders) if (v instanceof Var) varToDefn.set(v, i);
2143
+ for (const input of eqn.inputs) if (input instanceof Var) {
2144
+ const usages = varToUsages.get(input);
2145
+ if (usages) usages.push(i);
2146
+ else varToUsages.set(input, [i]);
2147
+ }
1978
2148
  }
1979
2149
  const reducePrimitives = [
1980
2150
  Primitive.Reduce,
@@ -1982,28 +2152,94 @@ function splitGraphDataflow(backend, jaxpr) {
1982
2152
  Primitive.Conv,
1983
2153
  Primitive.PoolTranspose
1984
2154
  ];
1985
- const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
1986
- for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
2155
+ const reductionEpilogueEqns = /* @__PURE__ */ new Set();
2156
+ const reductionEndpointEqns = /* @__PURE__ */ new Set();
2157
+ for (let i = 0; i < jaxpr.eqns.length; i++) {
1987
2158
  const eqn = jaxpr.eqns[i];
1988
- if (reducePrimitives.includes(eqn.primitive) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
1989
- for (const v of eqn.outBinders) {
1990
- blackNodes.add(v);
1991
- p1NextBlack.set(v, v);
2159
+ if (reducePrimitives.includes(eqn.primitive)) {
2160
+ let head = i;
2161
+ while (true) {
2162
+ reductionEpilogueEqns.add(head);
2163
+ const outVar = jaxpr.eqns[head].outBinders[0];
2164
+ const usages = varToUsages.get(outVar) ?? [];
2165
+ if (jaxpr.outs.includes(outVar) || usages.length !== 1) break;
2166
+ if (reductionEpilogueEqns.has(usages[0])) break;
2167
+ const nextEqn = jaxpr.eqns[usages[0]];
2168
+ switch (nextEqn.primitive) {
2169
+ case Primitive.Neg:
2170
+ case Primitive.Reciprocal:
2171
+ case Primitive.Floor:
2172
+ case Primitive.Ceil:
2173
+ case Primitive.StopGradient:
2174
+ case Primitive.Cast:
2175
+ case Primitive.Bitcast:
2176
+ case Primitive.Sin:
2177
+ case Primitive.Cos:
2178
+ case Primitive.Asin:
2179
+ case Primitive.Atan:
2180
+ case Primitive.Exp:
2181
+ case Primitive.Log:
2182
+ case Primitive.Erf:
2183
+ case Primitive.Erfc:
2184
+ case Primitive.Sqrt:
2185
+ head = usages[0];
2186
+ continue;
2187
+ case Primitive.Add:
2188
+ case Primitive.Mul:
2189
+ case Primitive.Idiv:
2190
+ case Primitive.Mod:
2191
+ case Primitive.Min:
2192
+ case Primitive.Max: {
2193
+ const otherInput = nextEqn.inputs.find((v) => v !== outVar);
2194
+ if (otherInput instanceof Lit || deepEqual(generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
2195
+ head = usages[0];
2196
+ continue;
2197
+ }
2198
+ break;
2199
+ }
2200
+ }
2201
+ break;
1992
2202
  }
1993
- continue;
2203
+ reductionEndpointEqns.add(head);
1994
2204
  }
1995
- const reach = /* @__PURE__ */ new Set();
1996
- for (let j = i + 1; j < jaxpr.eqns.length; j++) for (const v of jaxpr.eqns[j].inputs) if (v instanceof Var && eqn.outBinders.includes(v)) for (const o of jaxpr.eqns[j].outBinders) {
1997
- const u = p1NextBlack.get(o);
1998
- if (u) reach.add(u);
1999
- }
2000
- if (reach.size === 1) {
2001
- const b = reach.values().next().value;
2002
- for (const v of eqn.outBinders) p1NextBlack.set(v, b);
2003
- } else if (reach.size > 1) for (const v of eqn.outBinders) {
2205
+ }
2206
+ const blackNodes = /* @__PURE__ */ new Set();
2207
+ const p1NextBlack = /* @__PURE__ */ new Map();
2208
+ for (const v of jaxpr.outs) if (v instanceof Var) {
2209
+ blackNodes.add(v);
2210
+ p1NextBlack.set(v, v);
2211
+ }
2212
+ const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
2213
+ const needsCleanShapePrimitives = [Primitive.Pad];
2214
+ for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
2215
+ const eqn = jaxpr.eqns[i];
2216
+ if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
2217
+ for (const v of eqn.outBinders) {
2218
+ blackNodes.add(v);
2219
+ p1NextBlack.set(v, v);
2220
+ }
2221
+ continue;
2222
+ }
2223
+ const reach = /* @__PURE__ */ new Set();
2224
+ let needsCleanOutput = false;
2225
+ outer: for (const v of eqn.outBinders) for (const j of varToUsages.get(v) ?? []) {
2226
+ if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive) || routinePrimitives.has(jaxpr.eqns[j].primitive)) {
2227
+ needsCleanOutput = true;
2228
+ break outer;
2229
+ }
2230
+ for (const o of jaxpr.eqns[j].outBinders) {
2231
+ const u = p1NextBlack.get(o);
2232
+ if (u) reach.add(u);
2233
+ }
2234
+ }
2235
+ if (reach.size > 1 || needsCleanOutput) for (const v of eqn.outBinders) {
2004
2236
  blackNodes.add(v);
2005
2237
  p1NextBlack.set(v, v);
2006
2238
  }
2239
+ else if (reach.size === 1) {
2240
+ const b = reach.values().next().value;
2241
+ for (const v of eqn.outBinders) p1NextBlack.set(v, b);
2242
+ }
2007
2243
  }
2008
2244
  const p2Deps = /* @__PURE__ */ new Map();
2009
2245
  for (const v of jaxpr.inBinders) p2Deps.set(v, new Set([v]));
@@ -2011,7 +2247,6 @@ function splitGraphDataflow(backend, jaxpr) {
2011
2247
  while (p2idx < jaxpr.eqns.length) {
2012
2248
  const eqn = jaxpr.eqns[p2idx++];
2013
2249
  const deps = [];
2014
- if (eqn.outBinders.some((v) => blackNodes.has(v))) continue;
2015
2250
  for (const input of eqn.inputs) if (input instanceof Var) if (blackNodes.has(input)) deps.push(new Set([input]));
2016
2251
  else deps.push(p2Deps.get(input));
2017
2252
  else deps.push(/* @__PURE__ */ new Set());
@@ -2022,7 +2257,7 @@ function splitGraphDataflow(backend, jaxpr) {
2022
2257
  let assocInput = -1;
2023
2258
  for (let i = 0; i < eqn.inputs.length; i++) {
2024
2259
  const input = eqn.inputs[i];
2025
- if (input instanceof Var && varToEqn.has(input)) {
2260
+ if (input instanceof Var && varToDefn.has(input)) {
2026
2261
  let uniqueDeps = 0;
2027
2262
  for (const dep of deps[i]) if (depCounter.get(dep) === 1) uniqueDeps++;
2028
2263
  if (uniqueDeps > maxUniqueDeps) {
@@ -2033,8 +2268,8 @@ function splitGraphDataflow(backend, jaxpr) {
2033
2268
  }
2034
2269
  if (assocInput === -1) throw new Error(`internal: maxArgs, no input found to mark as black in Jaxpr equation ${eqn}`);
2035
2270
  const assocVar = eqn.inputs[assocInput];
2036
- p2idx = varToEqn.get(assocVar);
2037
- for (const out of jaxpr.eqns[p2idx].outBinders) blackNodes.add(out);
2271
+ p2idx = varToDefn.get(assocVar);
2272
+ for (const out of jaxpr.eqns[p2idx++].outBinders) blackNodes.add(out);
2038
2273
  } else {
2039
2274
  const s = new Set(depCounter.keys());
2040
2275
  for (const out of eqn.outBinders) p2Deps.set(out, s);
@@ -2060,9 +2295,9 @@ var PendingExecute = class {
2060
2295
  submitted = false;
2061
2296
  #promise = null;
2062
2297
  #rc = 1;
2063
- constructor(backend, kernel, inputs, outputs) {
2298
+ constructor(backend, source, inputs, outputs) {
2064
2299
  this.backend = backend;
2065
- this.kernel = kernel;
2300
+ this.source = source;
2066
2301
  this.inputs = inputs;
2067
2302
  this.outputs = outputs;
2068
2303
  for (const slot of inputs) this.backend.incRef(slot);
@@ -2083,13 +2318,15 @@ var PendingExecute = class {
2083
2318
  return;
2084
2319
  }
2085
2320
  this.#promise = (async () => {
2086
- this.prepared = await this.backend.prepare(this.kernel);
2321
+ if (this.source instanceof Kernel) this.prepared = await this.backend.prepareKernel(this.source);
2322
+ else this.prepared = await this.backend.prepareRoutine(this.source);
2087
2323
  })();
2088
2324
  await this.#promise;
2089
2325
  }
2090
2326
  prepareSync() {
2091
2327
  if (this.prepared) return;
2092
- this.prepared = this.backend.prepareSync(this.kernel);
2328
+ if (this.source instanceof Kernel) this.prepared = this.backend.prepareKernelSync(this.source);
2329
+ else this.prepared = this.backend.prepareRoutineSync(this.source);
2093
2330
  }
2094
2331
  submit() {
2095
2332
  if (this.submitted) return;
@@ -2112,8 +2349,6 @@ var PendingExecute = class {
2112
2349
  * "Array" type by name.
2113
2350
  */
2114
2351
  var Array$1 = class Array$1 extends Tracer {
2115
- static #nextId = 1001;
2116
- id;
2117
2352
  #dtype;
2118
2353
  #weakType;
2119
2354
  #source;
@@ -2130,7 +2365,6 @@ var Array$1 = class Array$1 extends Tracer {
2130
2365
  */
2131
2366
  constructor(args) {
2132
2367
  super(baseArrayTrace);
2133
- this.id = Array$1.#nextId++;
2134
2368
  this.#dtype = args.dtype;
2135
2369
  this.#weakType = args.weakType;
2136
2370
  this.#source = args.source;
@@ -2439,6 +2673,27 @@ var Array$1 = class Array$1 extends Tracer {
2439
2673
  pending
2440
2674
  });
2441
2675
  }
2676
+ /** Apply an operation with custom lowering to this array. */
2677
+ static #routine(routine, arrays, outputWeakType) {
2678
+ const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
2679
+ for (const ar of arrays) ar.#realize();
2680
+ const inputs = arrays.map((ar) => ar.#source);
2681
+ const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(byteWidth(dtype) * prod(routine.type.outputShapes[i])));
2682
+ const pending = arrays.flatMap((ar) => ar.#pending);
2683
+ for (const exe of pending) exe.updateRc(+outputs.length);
2684
+ pending.push(new PendingExecute(backend, routine, inputs, outputs));
2685
+ pending[pending.length - 1].updateRc(+outputs.length - 1);
2686
+ arrays.forEach((ar) => ar.dispose());
2687
+ return outputs.map((output, i) => new Array$1({
2688
+ source: output,
2689
+ st: ShapeTracker.fromShape(routine.type.outputShapes[i]),
2690
+ dtype: routine.type.outputDtypes[i],
2691
+ weakType: outputWeakType[i],
2692
+ backend,
2693
+ committed,
2694
+ pending
2695
+ }));
2696
+ }
2442
2697
  /**
2443
2698
  * Normalizes this array into one backed by a `Slot`.
2444
2699
  *
@@ -2599,6 +2854,12 @@ var Array$1 = class Array$1 extends Tracer {
2599
2854
  [Primitive.Mod]([x, y]) {
2600
2855
  return [x.#binary(AluOp.Mod, y)];
2601
2856
  },
2857
+ [Primitive.Min]([x, y]) {
2858
+ return [x.#binary(AluOp.Min, y)];
2859
+ },
2860
+ [Primitive.Max]([x, y]) {
2861
+ return [x.#binary(AluOp.Max, y)];
2862
+ },
2602
2863
  [Primitive.Neg]([x]) {
2603
2864
  return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
2604
2865
  },
@@ -2635,25 +2896,6 @@ var Array$1 = class Array$1 extends Tracer {
2635
2896
  return [y];
2636
2897
  }
2637
2898
  },
2638
- [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
2639
- const keyShape = generalBroadcast(k0.shape, k1.shape);
2640
- if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2641
- const c0 = zeros(shape$1, {
2642
- dtype: DType.Uint32,
2643
- device: k0.device
2644
- });
2645
- const c1 = arange(0, prod(shape$1), 1, {
2646
- dtype: DType.Uint32,
2647
- device: k0.device
2648
- }).reshape(shape$1);
2649
- const custom = ([k0$1, k1$1, c0$1, c1$1]) => AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
2650
- return [Array$1.#naryCustom("random_bits", custom, [
2651
- k0,
2652
- k1,
2653
- c0,
2654
- c1
2655
- ])];
2656
- },
2657
2899
  [Primitive.Sin]([x]) {
2658
2900
  return [x.#unary(AluOp.Sin)];
2659
2901
  },
@@ -2681,12 +2923,6 @@ var Array$1 = class Array$1 extends Tracer {
2681
2923
  [Primitive.Sqrt]([x]) {
2682
2924
  return [x.#unary(AluOp.Sqrt)];
2683
2925
  },
2684
- [Primitive.Min]([x, y]) {
2685
- return [x.#binary(AluOp.Min, y)];
2686
- },
2687
- [Primitive.Max]([x, y]) {
2688
- return [x.#binary(AluOp.Max, y)];
2689
- },
2690
2926
  [Primitive.Reduce]([x], { op, axis }) {
2691
2927
  if (axis.length === 0) return [x];
2692
2928
  return [x.#moveAxesDown(axis).#reduce(op)];
@@ -2721,6 +2957,28 @@ var Array$1 = class Array$1 extends Tracer {
2721
2957
  y
2722
2958
  ], { dtypeOverride: [DType.Bool] })];
2723
2959
  },
2960
+ [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
2961
+ const keyShape = generalBroadcast(k0.shape, k1.shape);
2962
+ if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2963
+ const c0 = zeros(shape$1, {
2964
+ dtype: DType.Uint32,
2965
+ device: k0.device
2966
+ });
2967
+ const c1 = arange(0, prod(shape$1), 1, {
2968
+ dtype: DType.Uint32,
2969
+ device: k0.device
2970
+ }).reshape(shape$1);
2971
+ const custom = ([k0$1, k1$1, c0$1, c1$1]) => AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
2972
+ return [Array$1.#naryCustom("random_bits", custom, [
2973
+ k0,
2974
+ k1,
2975
+ c0,
2976
+ c1
2977
+ ])];
2978
+ },
2979
+ [Primitive.Gather]([x, ...indices], { axis, outDim }) {
2980
+ return [x.#gather(indices, axis, outDim)];
2981
+ },
2724
2982
  [Primitive.Transpose]([x], { perm }) {
2725
2983
  return [x.#transpose(perm)];
2726
2984
  },
@@ -2741,17 +2999,48 @@ var Array$1 = class Array$1 extends Tracer {
2741
2999
  [Primitive.Pad]([x], { width }) {
2742
3000
  return [x.#reshape(x.#st.pad(width))];
2743
3001
  },
2744
- [Primitive.Gather]([x, ...indices], { axis, outDim }) {
2745
- return [x.#gather(indices, axis, outDim)];
3002
+ [Primitive.Sort]([x]) {
3003
+ const routine = new Routine(Routines.Sort, {
3004
+ inputShapes: [x.aval.shape],
3005
+ inputDtypes: [x.aval.dtype],
3006
+ outputShapes: [x.aval.shape],
3007
+ outputDtypes: [x.aval.dtype]
3008
+ });
3009
+ return Array$1.#routine(routine, [x], [x.#weakType]);
3010
+ },
3011
+ [Primitive.Argsort]([x]) {
3012
+ const routine = new Routine(Routines.Argsort, {
3013
+ inputShapes: [x.aval.shape],
3014
+ inputDtypes: [x.aval.dtype],
3015
+ outputShapes: [x.aval.shape, x.aval.shape],
3016
+ outputDtypes: [x.aval.dtype, DType.Int32]
3017
+ });
3018
+ return Array$1.#routine(routine, [x], [x.#weakType, false]);
3019
+ },
3020
+ [Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
3021
+ const routine = new Routine(Routines.TriangularSolve, {
3022
+ inputShapes: [a.aval.shape, b.aval.shape],
3023
+ inputDtypes: [a.aval.dtype, b.aval.dtype],
3024
+ outputShapes: [b.aval.shape],
3025
+ outputDtypes: [b.aval.dtype]
3026
+ }, { unitDiagonal });
3027
+ return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
2746
3028
  },
2747
- [Primitive.JitCall](args, { jaxpr, numConsts }) {
2748
- if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
2749
- const { backend, committed } = Array$1.#computeBackend("jit_call", args);
3029
+ [Primitive.Cholesky]([a]) {
3030
+ const routine = new Routine(Routines.Cholesky, {
3031
+ inputShapes: [a.aval.shape],
3032
+ inputDtypes: [a.aval.dtype],
3033
+ outputShapes: [a.aval.shape],
3034
+ outputDtypes: [a.aval.dtype]
3035
+ });
3036
+ return Array$1.#routine(routine, [a], [a.#weakType]);
3037
+ },
3038
+ [Primitive.Jit](args, { jaxpr }) {
3039
+ if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
3040
+ const { backend, committed } = Array$1.#computeBackend("jit", args);
2750
3041
  args = args.map((ar) => ar._putSync(backend));
2751
- const consts = args.slice(0, numConsts);
2752
- const tracers = args.slice(numConsts);
2753
- const jp = jitCompile(backend, jaxpr, consts);
2754
- const { outputs, pending } = jp.execute(tracers.map((x) => x._realizeSource()));
3042
+ const jp = jitCompile(backend, jaxpr);
3043
+ const { outputs, pending } = jp.execute(args.map((x) => x._realizeSource()));
2755
3044
  for (const exe of pending) exe.updateRc(+outputs.length - 1);
2756
3045
  const prevPending = [...new Set(args.flatMap((x) => x.#pending))];
2757
3046
  for (const exe of prevPending) exe.updateRc(+outputs.length);
@@ -3050,6 +3339,43 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
3050
3339
  });
3051
3340
  }
3052
3341
  /**
3342
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
3343
+ *
3344
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
3345
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
3346
+ * `k>0` is above it.
3347
+ */
3348
+ function tri(n, m, k = 0, { dtype, device } = {}) {
3349
+ m ??= n;
3350
+ dtype ??= DType.Float32;
3351
+ if (!Number.isInteger(n) || n < 0) throw new Error(`tri: n must be a non-negative integer, got ${n}`);
3352
+ if (!Number.isInteger(m) || m < 0) throw new Error(`tri: m must be a non-negative integer, got ${m}`);
3353
+ if (!Number.isInteger(k)) throw new Error(`tri: k must be an integer, got ${k}`);
3354
+ const rows = arange(k, n + k, 1, {
3355
+ dtype: DType.Int32,
3356
+ device
3357
+ });
3358
+ const cols = arange(0, m, 1, {
3359
+ dtype: DType.Int32,
3360
+ device
3361
+ });
3362
+ return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
3363
+ }
3364
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
3365
+ function tril(a, k = 0) {
3366
+ if (ndim$1(a) < 2) throw new Error(`tril: input array must be at least 2D, got ${ndim$1(a)}D`);
3367
+ a = fudgeArray(a);
3368
+ const [n, m] = a.shape.slice(-2);
3369
+ return where$1(tri(n, m, k, { dtype: DType.Bool }), a.ref, zerosLike$1(a));
3370
+ }
3371
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
3372
+ function triu(a, k = 0) {
3373
+ if (ndim$1(a) < 2) throw new Error(`tril: input array must be at least 2D, got ${ndim$1(a)}D`);
3374
+ a = fudgeArray(a);
3375
+ const [n, m] = a.shape.slice(-2);
3376
+ return where$1(tri(n, m, k - 1, { dtype: DType.Bool }), zerosLike$1(a.ref), a);
3377
+ }
3378
+ /**
3053
3379
  * Return evenly spaced numbers over a specified interval.
3054
3380
  *
3055
3381
  * Returns _num_ evenly spaced samples, calculated over the interval
@@ -3096,333 +3422,106 @@ function aluCompare(a, b, op) {
3096
3422
  }
3097
3423
 
3098
3424
  //#endregion
3099
- //#region src/frontend/jvp.ts
3100
- var JVPTracer = class extends Tracer {
3101
- constructor(trace$1, primal, tangent) {
3425
+ //#region src/frontend/vmap.ts
3426
+ function mappedAval(batchDim, aval) {
3427
+ const shape$1 = [...aval.shape];
3428
+ shape$1.splice(batchDim, 1);
3429
+ return new ShapedArray(shape$1, aval.dtype, aval.weakType);
3430
+ }
3431
+ /** Move one axis to a different index. */
3432
+ function moveaxis(x, src, dst) {
3433
+ const t = pureArray(x);
3434
+ src = checkAxis(src, t.ndim);
3435
+ dst = checkAxis(dst, t.ndim);
3436
+ if (src === dst) return t;
3437
+ const perm = range(t.ndim);
3438
+ perm.splice(src, 1);
3439
+ perm.splice(dst, 0, src);
3440
+ return transpose$1(t, perm);
3441
+ }
3442
+ function moveBatchAxis(axisSize, src, dst, x) {
3443
+ if (src === null) {
3444
+ const targetShape = [...x.shape];
3445
+ targetShape.splice(dst, 0, axisSize);
3446
+ return broadcast(x, targetShape, [dst]);
3447
+ } else if (src === dst) return x;
3448
+ else return moveaxis(x, src, dst);
3449
+ }
3450
+ var BatchTracer = class extends Tracer {
3451
+ constructor(trace$1, val, batchDim) {
3102
3452
  super(trace$1);
3103
- this.primal = primal;
3104
- this.tangent = tangent;
3453
+ this.val = val;
3454
+ this.batchDim = batchDim;
3105
3455
  }
3106
3456
  get aval() {
3107
- return this.primal.aval;
3457
+ if (this.batchDim === null) return this.val.aval;
3458
+ else return mappedAval(this.batchDim, this.val.aval);
3108
3459
  }
3109
3460
  toString() {
3110
- return `JVPTracer(${this.primal.toString()}, ${this.tangent.toString()})`;
3461
+ return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
3111
3462
  }
3112
3463
  get ref() {
3113
- this.primal.ref, this.tangent.ref;
3464
+ this.val.ref;
3114
3465
  return this;
3115
3466
  }
3116
3467
  dispose() {
3117
- this.primal.dispose();
3118
- this.tangent.dispose();
3468
+ this.val.dispose();
3469
+ }
3470
+ fullLower() {
3471
+ if (this.batchDim === null) return this.val.fullLower();
3472
+ else return this;
3119
3473
  }
3120
3474
  };
3121
- var JVPTrace = class extends Trace {
3475
+ var BatchTrace = class extends Trace {
3122
3476
  pure(val) {
3123
3477
  return this.lift(pureArray(val));
3124
3478
  }
3125
3479
  lift(val) {
3126
- return new JVPTracer(this, val, zerosLike$1(val.ref));
3480
+ return new BatchTracer(this, val, null);
3127
3481
  }
3128
3482
  processPrimitive(primitive, tracers, params) {
3129
- const [primalsIn, tangentsIn] = unzip2(tracers.map((x) => [x.primal, x.tangent]));
3130
- const jvpRule = jvpRules[primitive];
3131
- if (jvpRule === void 0) throw new Error(`No JVP rule for: ${primitive}`);
3132
- const [primalsOut, tangentsOut] = jvpRule(primalsIn, tangentsIn, params);
3133
- return zip(primalsOut, tangentsOut).map(([x, t]) => new JVPTracer(this, x, t));
3483
+ const [valsIn, bdimsIn] = unzip2(tracers.map((t) => [t.val, t.batchDim]));
3484
+ const vmapRule = vmapRules[primitive];
3485
+ if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
3486
+ if (bdimsIn.every((d) => d === null)) {
3487
+ const valOuts$1 = bind(primitive, valsIn, params);
3488
+ return valOuts$1.map((x) => new BatchTracer(this, x, null));
3489
+ }
3490
+ const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3491
+ return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3492
+ }
3493
+ get axisSize() {
3494
+ return this.main.globalData;
3134
3495
  }
3135
3496
  };
3136
- /** Rule that applies the same operation to primals and tangents. */
3137
- function linearTangentsJvp(primitive) {
3138
- return (primals, tangents, params) => {
3139
- const ys = bind(primitive, primals, params);
3140
- const dys = bind(primitive, tangents, params);
3141
- return [ys, dys];
3142
- };
3143
- }
3144
- /** Rule for product of gradients in bilinear operations. */
3145
- function bilinearTangentsJvp(primitive) {
3146
- return ([x, y], [dx, dy], params) => {
3147
- const primal = bind1(primitive, [x.ref, y.ref], params);
3148
- const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
3149
- return [[primal], [tangent]];
3497
+ /**
3498
+ * Process a primitive with built-in broadcasting.
3499
+ *
3500
+ * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3501
+ */
3502
+ function broadcastBatcher(op) {
3503
+ return (axisSize, args, dims) => {
3504
+ if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3505
+ const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3506
+ const firstIdx = dims.findIndex((d) => d !== null);
3507
+ const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3508
+ if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3509
+ args = args.map((x, i) => {
3510
+ if (dims[i] === null) return x;
3511
+ x = moveBatchAxis(axisSize, dims[i], 0, x);
3512
+ if (x.ndim < nd) x = x.reshape([
3513
+ x.shape[0],
3514
+ ...rep(nd - x.ndim, 1),
3515
+ ...x.shape.slice(1)
3516
+ ]);
3517
+ return x;
3518
+ });
3519
+ return [[op(...args)], [0]];
3150
3520
  };
3151
3521
  }
3152
- /** Rule that zeros out any tangents. */
3153
- function zeroTangentsJvp(primitive) {
3154
- return (primals, tangents, params) => {
3155
- for (const t of tangents) t.dispose();
3156
- const ys = bind(primitive, primals, params);
3157
- return [ys, ys.map((y) => zerosLike$1(y.ref))];
3158
- };
3159
- }
3160
- const jvpRules = {
3161
- [Primitive.Add]: linearTangentsJvp(Primitive.Add),
3162
- [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
3163
- [Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
3164
- [Primitive.Mod]([x, y], [dx, dy]) {
3165
- if (!isFloatDtype(x.dtype) && !isFloatDtype(y.dtype)) {
3166
- dx.dispose();
3167
- dy.dispose();
3168
- return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
3169
- }
3170
- const q = idiv(x.ref, y.ref);
3171
- return [[mod(x, y)], [dx.sub(dy.mul(q))]];
3172
- },
3173
- [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
3174
- [Primitive.Reciprocal]([x], [dx]) {
3175
- const xRecip = reciprocal$1(x.ref);
3176
- return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
3177
- },
3178
- [Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
3179
- [Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
3180
- [Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
3181
- [Primitive.Cast]([x], [dx], { dtype }) {
3182
- if (x.dtype === dtype) return [[x], [dx]];
3183
- if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
3184
- else {
3185
- dx.dispose();
3186
- return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
3187
- }
3188
- },
3189
- [Primitive.Bitcast]([x], [dx], { dtype }) {
3190
- if (x.dtype === dtype) return [[x], [dx]];
3191
- dx.dispose();
3192
- return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
3193
- },
3194
- [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
3195
- [Primitive.Sin]([x], [dx]) {
3196
- return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
3197
- },
3198
- [Primitive.Cos]([x], [dx]) {
3199
- return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
3200
- },
3201
- [Primitive.Asin]([x], [dx]) {
3202
- const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
3203
- return [[asin$1(x)], [denom.mul(dx)]];
3204
- },
3205
- [Primitive.Atan]([x], [dx]) {
3206
- const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
3207
- return [[atan$1(x)], [dx.div(denom)]];
3208
- },
3209
- [Primitive.Exp]([x], [dx]) {
3210
- const z = exp$1(x);
3211
- return [[z.ref], [z.mul(dx)]];
3212
- },
3213
- [Primitive.Log]([x], [dx]) {
3214
- return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
3215
- },
3216
- [Primitive.Erf]([x], [dx]) {
3217
- const coeff = 2 / Math.sqrt(Math.PI);
3218
- const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3219
- return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
3220
- },
3221
- [Primitive.Erfc]([x], [dx]) {
3222
- const coeff = -2 / Math.sqrt(Math.PI);
3223
- const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3224
- return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
3225
- },
3226
- [Primitive.Sqrt]([x], [dx]) {
3227
- const z = sqrt$1(x);
3228
- return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
3229
- },
3230
- [Primitive.Min]([x, y], [dx, dy]) {
3231
- return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
3232
- },
3233
- [Primitive.Max]([x, y], [dx, dy]) {
3234
- return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
3235
- },
3236
- [Primitive.Reduce]([x], [dx], { op, axis }) {
3237
- if (op === AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
3238
- else if (op === AluOp.Mul) {
3239
- const primal = reduce(x.ref, op, axis);
3240
- const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
3241
- return [[primal], [tangent]];
3242
- } else if (op === AluOp.Min || op === AluOp.Max) {
3243
- const primal = reduce(x.ref, op, axis);
3244
- const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
3245
- const minCount = where$1(notMin.ref, 0, 1).sum(axis);
3246
- const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
3247
- return [[primal], [tangent]];
3248
- } else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
3249
- },
3250
- [Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
3251
- [Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
3252
- [Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
3253
- [Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
3254
- [Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
3255
- [Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
3256
- dcond.dispose();
3257
- return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
3258
- },
3259
- [Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
3260
- [Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
3261
- [Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
3262
- [Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
3263
- [Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
3264
- [Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
3265
- [Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
3266
- const indicesRef = indices.map((t) => t.ref);
3267
- return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
3268
- },
3269
- [Primitive.JitCall](primals, tangents, { name, jaxpr }) {
3270
- const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
3271
- const outs = bind(Primitive.JitCall, [
3272
- ...newConsts.map((c) => c.ref),
3273
- ...primals,
3274
- ...tangents
3275
- ], {
3276
- name: `${name}_jvp`,
3277
- jaxpr: newJaxpr,
3278
- numConsts: newConsts.length
3279
- });
3280
- const n = outs.length / 2;
3281
- if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
3282
- const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
3283
- return [primalsOut, tangentsOut];
3284
- }
3285
- };
3286
- const jvpJaxprCache = /* @__PURE__ */ new Map();
3287
- function jvpJaxpr(jaxpr) {
3288
- if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
3289
- const inAvals = jaxpr.inBinders.map((v) => v.aval);
3290
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
3291
- const result = {
3292
- newJaxpr,
3293
- newConsts
3294
- };
3295
- jvpJaxprCache.set(jaxpr, result);
3296
- return result;
3297
- }
3298
- function jvpFlat(f, primals, tangents) {
3299
- try {
3300
- var _usingCtx$1 = _usingCtx();
3301
- const main = _usingCtx$1.u(newMain(JVPTrace));
3302
- const trace$1 = new JVPTrace(main);
3303
- const tracersIn = zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
3304
- const outs = f(...tracersIn);
3305
- const tracersOut = outs.map((out) => fullRaise(trace$1, out));
3306
- return unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
3307
- } catch (_) {
3308
- _usingCtx$1.e = _;
3309
- } finally {
3310
- _usingCtx$1.d();
3311
- }
3312
- }
3313
- function jvp$1(f, primals, tangents) {
3314
- const [primalsFlat, inTree] = flatten(primals);
3315
- const [tangentsFlat, inTree2] = flatten(tangents);
3316
- if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
3317
- const [flatFun, outTree] = flattenFun(f, inTree);
3318
- const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
3319
- if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
3320
- const primalsOut = unflatten(outTree.value, primalsOutFlat);
3321
- const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
3322
- return [primalsOut, tangentsOut];
3323
- }
3324
-
3325
- //#endregion
3326
- //#region src/frontend/vmap.ts
3327
- function mappedAval(batchDim, aval) {
3328
- const shape$1 = [...aval.shape];
3329
- shape$1.splice(batchDim, 1);
3330
- return new ShapedArray(shape$1, aval.dtype, aval.weakType);
3331
- }
3332
- /** Move one axis to a different index. */
3333
- function moveaxis(x, src, dst) {
3334
- const t = pureArray(x);
3335
- src = checkAxis(src, t.ndim);
3336
- dst = checkAxis(dst, t.ndim);
3337
- if (src === dst) return t;
3338
- const perm = range(t.ndim);
3339
- perm.splice(src, 1);
3340
- perm.splice(dst, 0, src);
3341
- return transpose$1(t, perm);
3342
- }
3343
- function moveBatchAxis(axisSize, src, dst, x) {
3344
- if (src === null) {
3345
- const targetShape = [...x.shape];
3346
- targetShape.splice(dst, 0, axisSize);
3347
- return broadcast(x, targetShape, [dst]);
3348
- } else if (src === dst) return x;
3349
- else return moveaxis(x, src, dst);
3350
- }
3351
- var BatchTracer = class extends Tracer {
3352
- constructor(trace$1, val, batchDim) {
3353
- super(trace$1);
3354
- this.val = val;
3355
- this.batchDim = batchDim;
3356
- }
3357
- get aval() {
3358
- if (this.batchDim === null) return this.val.aval;
3359
- else return mappedAval(this.batchDim, this.val.aval);
3360
- }
3361
- toString() {
3362
- return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
3363
- }
3364
- get ref() {
3365
- this.val.ref;
3366
- return this;
3367
- }
3368
- dispose() {
3369
- this.val.dispose();
3370
- }
3371
- fullLower() {
3372
- if (this.batchDim === null) return this.val.fullLower();
3373
- else return this;
3374
- }
3375
- };
3376
- var BatchTrace = class extends Trace {
3377
- pure(val) {
3378
- return this.lift(pureArray(val));
3379
- }
3380
- lift(val) {
3381
- return new BatchTracer(this, val, null);
3382
- }
3383
- processPrimitive(primitive, tracers, params) {
3384
- const [valsIn, bdimsIn] = unzip2(tracers.map((t) => [t.val, t.batchDim]));
3385
- const vmapRule = vmapRules[primitive];
3386
- if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
3387
- if (bdimsIn.every((d) => d === null)) {
3388
- const valOuts$1 = bind(primitive, valsIn, params);
3389
- return valOuts$1.map((x) => new BatchTracer(this, x, null));
3390
- }
3391
- const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3392
- return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3393
- }
3394
- get axisSize() {
3395
- return this.main.globalData;
3396
- }
3397
- };
3398
- /**
3399
- * Process a primitive with built-in broadcasting.
3400
- *
3401
- * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3402
- */
3403
- function broadcastBatcher(op) {
3404
- return (axisSize, args, dims) => {
3405
- if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3406
- const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3407
- const firstIdx = dims.findIndex((d) => d !== null);
3408
- const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3409
- if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3410
- args = args.map((x, i) => {
3411
- if (dims[i] === null) return x;
3412
- x = moveBatchAxis(axisSize, dims[i], 0, x);
3413
- if (x.ndim < nd) x = x.reshape([
3414
- x.shape[0],
3415
- ...rep(nd - x.ndim, 1),
3416
- ...x.shape.slice(1)
3417
- ]);
3418
- return x;
3419
- });
3420
- return [[op(...args)], [0]];
3421
- };
3422
- }
3423
- function unopBatcher(op) {
3424
- return (axisSize, [x], [xBdim], params) => {
3425
- return [[op(x, params)], [xBdim]];
3522
+ function unopBatcher(op) {
3523
+ return (axisSize, [x], [xBdim], params) => {
3524
+ return [[op(x, params)], [xBdim]];
3426
3525
  };
3427
3526
  }
3428
3527
  const vmapRules = {
@@ -3430,6 +3529,8 @@ const vmapRules = {
3430
3529
  [Primitive.Mul]: broadcastBatcher(mul),
3431
3530
  [Primitive.Idiv]: broadcastBatcher(idiv),
3432
3531
  [Primitive.Mod]: broadcastBatcher(mod),
3532
+ [Primitive.Min]: broadcastBatcher(min$1),
3533
+ [Primitive.Max]: broadcastBatcher(max$1),
3433
3534
  [Primitive.Neg]: unopBatcher(neg),
3434
3535
  [Primitive.Reciprocal]: unopBatcher(reciprocal$1),
3435
3536
  [Primitive.Floor]: unopBatcher(floor$1),
@@ -3446,8 +3547,6 @@ const vmapRules = {
3446
3547
  [Primitive.Erf]: unopBatcher(erf$1),
3447
3548
  [Primitive.Erfc]: unopBatcher(erfc$1),
3448
3549
  [Primitive.Sqrt]: unopBatcher(sqrt$1),
3449
- [Primitive.Min]: broadcastBatcher(min$1),
3450
- [Primitive.Max]: broadcastBatcher(max$1),
3451
3550
  [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3452
3551
  assertNonNull(xBdim);
3453
3552
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
@@ -3460,10 +3559,49 @@ const vmapRules = {
3460
3559
  const z = dot$2(x, y);
3461
3560
  return [[z], [z.ndim - 1]];
3462
3561
  },
3562
+ [Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
3563
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3564
+ y = moveBatchAxis(axisSize, yBdim, 0, y);
3565
+ const z = conv$1(x, y, {
3566
+ ...params,
3567
+ vmapDims: params.vmapDims + 1
3568
+ });
3569
+ return [[z], [0]];
3570
+ },
3463
3571
  [Primitive.Compare](axisSize, args, dims, { op }) {
3464
3572
  return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3465
3573
  },
3466
- [Primitive.Where]: broadcastBatcher(where$1),
3574
+ [Primitive.Where]: broadcastBatcher(where$1),
3575
+ [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3576
+ if (indicesBdim.every((d) => d === null)) {
3577
+ assertNonNull(xBdim);
3578
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3579
+ let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3580
+ let newOutDim = outDim;
3581
+ if (newOutDim < newBdim) newBdim += axis.length;
3582
+ else newOutDim += 1;
3583
+ return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3584
+ }
3585
+ const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3586
+ indices = indices.map((m, i) => {
3587
+ if (indicesBdim[i] === null) return m;
3588
+ m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3589
+ if (m.ndim < nd) m = m.reshape([
3590
+ m.shape[0],
3591
+ ...rep(nd - m.ndim, 1),
3592
+ ...m.shape.slice(1)
3593
+ ]);
3594
+ return m;
3595
+ });
3596
+ if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3597
+ else {
3598
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3599
+ const newAxis = [0, ...axis.map((ax) => ax + 1)];
3600
+ const extraBatchIndex = arange(axisSize).reshape([-1, ...rep(nd - 1, 1)]);
3601
+ indices.splice(0, 0, extraBatchIndex);
3602
+ return [[gather(x, indices, newAxis, outDim)], [outDim]];
3603
+ }
3604
+ },
3467
3605
  [Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
3468
3606
  assertNonNull(xBdim);
3469
3607
  const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
@@ -3495,42 +3633,53 @@ const vmapRules = {
3495
3633
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3496
3634
  return [[pad$1(x, newWidth)], [xBdim]];
3497
3635
  },
3498
- [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3499
- if (indicesBdim.every((d) => d === null)) {
3500
- assertNonNull(xBdim);
3501
- const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3502
- let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3503
- let newOutDim = outDim;
3504
- if (newOutDim < newBdim) newBdim += axis.length;
3505
- else newOutDim += 1;
3506
- return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3507
- }
3508
- const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3509
- indices = indices.map((m, i) => {
3510
- if (indicesBdim[i] === null) return m;
3511
- m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3512
- if (m.ndim < nd) m = m.reshape([
3513
- m.shape[0],
3514
- ...rep(nd - m.ndim, 1),
3515
- ...m.shape.slice(1)
3636
+ [Primitive.Sort](axisSize, [x], [xBdim]) {
3637
+ assertNonNull(xBdim);
3638
+ if (xBdim !== x.ndim - 1) return [[sort$1(x)], [xBdim]];
3639
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3640
+ return [[sort$1(x)], [0]];
3641
+ },
3642
+ [Primitive.Argsort](axisSize, [x], [xBdim]) {
3643
+ assertNonNull(xBdim);
3644
+ if (xBdim !== x.ndim - 1) return [argsort$1(x), [xBdim, xBdim]];
3645
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3646
+ return [argsort$1(x), [0, 0]];
3647
+ },
3648
+ [Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
3649
+ if (aBdim === null) {
3650
+ b = moveBatchAxis(axisSize, bBdim, -3, b);
3651
+ const [s, m, n] = b.shape.slice(-3);
3652
+ b = b.reshape([
3653
+ ...b.shape.slice(0, -3),
3654
+ s * m,
3655
+ n
3516
3656
  ]);
3517
- return m;
3518
- });
3519
- if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3520
- else {
3521
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3522
- const newAxis = [0, ...axis.map((ax) => ax + 1)];
3523
- const extraBatchIndex = arange(axisSize).reshape([-1, ...rep(nd - 1, 1)]);
3524
- indices.splice(0, 0, extraBatchIndex);
3525
- return [[gather(x, indices, newAxis, outDim)], [outDim]];
3657
+ let x$1 = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
3658
+ x$1 = x$1.reshape([
3659
+ ...b.shape.slice(0, -2),
3660
+ s,
3661
+ m,
3662
+ n
3663
+ ]);
3664
+ return [[x$1], [x$1.ndim - 3]];
3526
3665
  }
3666
+ a = moveBatchAxis(axisSize, aBdim, 0, a);
3667
+ b = moveBatchAxis(axisSize, bBdim, 0, b);
3668
+ const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
3669
+ return [[x], [0]];
3670
+ },
3671
+ [Primitive.Cholesky](axisSize, [x], [xBdim]) {
3672
+ assertNonNull(xBdim);
3673
+ if (xBdim < x.ndim - 2) return [[cholesky$2(x)], [xBdim]];
3674
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3675
+ return [[cholesky$2(x)], [0]];
3527
3676
  },
3528
- [Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
3529
- const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
3530
- const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
3677
+ [Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
3678
+ const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
3679
+ const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
3531
3680
  name: `${name}_vmap`,
3532
- jaxpr: newJaxpr,
3533
- numConsts: newConsts.length
3681
+ jaxpr: newJaxpr.jaxpr,
3682
+ numConsts: newJaxpr.consts.length
3534
3683
  });
3535
3684
  return [outs, rep(outs.length, 0)];
3536
3685
  }
@@ -3546,14 +3695,10 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
3546
3695
  shape$1.splice(dims[i], 0, axisSize);
3547
3696
  return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
3548
3697
  });
3549
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3550
- const result = {
3551
- newJaxpr,
3552
- newConsts
3553
- };
3698
+ const { jaxpr: newJaxpr } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3554
3699
  if (!vmapJaxprCache.has(jaxpr)) vmapJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
3555
- vmapJaxprCache.get(jaxpr).set(cacheKey, result);
3556
- return result;
3700
+ vmapJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
3701
+ return newJaxpr;
3557
3702
  }
3558
3703
  function vmapFlat(f, inAxes, args) {
3559
3704
  let axisSize = void 0;
@@ -3608,6 +3753,260 @@ function jacfwd$1(f) {
3608
3753
  };
3609
3754
  }
3610
3755
 
3756
+ //#endregion
3757
+ //#region src/frontend/jvp.ts
3758
+ var JVPTracer = class extends Tracer {
3759
+ constructor(trace$1, primal, tangent) {
3760
+ super(trace$1);
3761
+ this.primal = primal;
3762
+ this.tangent = tangent;
3763
+ }
3764
+ get aval() {
3765
+ return this.primal.aval;
3766
+ }
3767
+ toString() {
3768
+ return `JVPTracer(${this.primal.toString()}, ${this.tangent.toString()})`;
3769
+ }
3770
+ get ref() {
3771
+ this.primal.ref, this.tangent.ref;
3772
+ return this;
3773
+ }
3774
+ dispose() {
3775
+ this.primal.dispose();
3776
+ this.tangent.dispose();
3777
+ }
3778
+ };
3779
+ var JVPTrace = class extends Trace {
3780
+ pure(val) {
3781
+ return this.lift(pureArray(val));
3782
+ }
3783
+ lift(val) {
3784
+ return new JVPTracer(this, val, zerosLike$1(val.ref));
3785
+ }
3786
+ processPrimitive(primitive, tracers, params) {
3787
+ const [primalsIn, tangentsIn] = unzip2(tracers.map((x) => [x.primal, x.tangent]));
3788
+ const jvpRule = jvpRules[primitive];
3789
+ if (jvpRule === void 0) throw new Error(`No JVP rule for: ${primitive}`);
3790
+ const [primalsOut, tangentsOut] = jvpRule(primalsIn, tangentsIn, params);
3791
+ return zip(primalsOut, tangentsOut).map(([x, t]) => new JVPTracer(this, x, t));
3792
+ }
3793
+ };
3794
+ /** Rule that applies the same operation to primals and tangents. */
3795
+ function linearTangentsJvp(primitive) {
3796
+ return (primals, tangents, params) => {
3797
+ const ys = bind(primitive, primals, params);
3798
+ const dys = bind(primitive, tangents, params);
3799
+ return [ys, dys];
3800
+ };
3801
+ }
3802
+ /** Rule for product of gradients in bilinear operations. */
3803
+ function bilinearTangentsJvp(primitive) {
3804
+ return ([x, y], [dx, dy], params) => {
3805
+ const primal = bind1(primitive, [x.ref, y.ref], params);
3806
+ const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
3807
+ return [[primal], [tangent]];
3808
+ };
3809
+ }
3810
+ /** Rule that zeros out any tangents. */
3811
+ function zeroTangentsJvp(primitive) {
3812
+ return (primals, tangents, params) => {
3813
+ for (const t of tangents) t.dispose();
3814
+ const ys = bind(primitive, primals, params);
3815
+ return [ys, ys.map((y) => zerosLike$1(y.ref))];
3816
+ };
3817
+ }
3818
+ /** Compute `a @ b.T`, batched to last two axes. */
3819
+ function batchMatmulT(a, b) {
3820
+ return dot$2(a.reshape(a.shape.toSpliced(-1, 0, 1)), b.reshape(b.shape.toSpliced(-2, 0, 1)));
3821
+ }
3822
+ /** Batch matrix transpose. */
3823
+ function mT(a) {
3824
+ return moveaxis(a, -2, -1);
3825
+ }
3826
+ const jvpRules = {
3827
+ [Primitive.Add]: linearTangentsJvp(Primitive.Add),
3828
+ [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
3829
+ [Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
3830
+ [Primitive.Mod]([x, y], [dx, dy]) {
3831
+ if (!isFloatDtype(x.dtype) && !isFloatDtype(y.dtype)) {
3832
+ dx.dispose();
3833
+ dy.dispose();
3834
+ return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
3835
+ }
3836
+ const q = idiv(x.ref, y.ref);
3837
+ return [[mod(x, y)], [dx.sub(dy.mul(q))]];
3838
+ },
3839
+ [Primitive.Min]([x, y], [dx, dy]) {
3840
+ return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
3841
+ },
3842
+ [Primitive.Max]([x, y], [dx, dy]) {
3843
+ return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
3844
+ },
3845
+ [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
3846
+ [Primitive.Reciprocal]([x], [dx]) {
3847
+ const xRecip = reciprocal$1(x.ref);
3848
+ return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
3849
+ },
3850
+ [Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
3851
+ [Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
3852
+ [Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
3853
+ [Primitive.Cast]([x], [dx], { dtype }) {
3854
+ if (x.dtype === dtype) return [[x], [dx]];
3855
+ if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
3856
+ else {
3857
+ dx.dispose();
3858
+ return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
3859
+ }
3860
+ },
3861
+ [Primitive.Bitcast]([x], [dx], { dtype }) {
3862
+ if (x.dtype === dtype) return [[x], [dx]];
3863
+ dx.dispose();
3864
+ return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
3865
+ },
3866
+ [Primitive.Sin]([x], [dx]) {
3867
+ return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
3868
+ },
3869
+ [Primitive.Cos]([x], [dx]) {
3870
+ return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
3871
+ },
3872
+ [Primitive.Asin]([x], [dx]) {
3873
+ const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
3874
+ return [[asin$1(x)], [denom.mul(dx)]];
3875
+ },
3876
+ [Primitive.Atan]([x], [dx]) {
3877
+ const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
3878
+ return [[atan$1(x)], [dx.div(denom)]];
3879
+ },
3880
+ [Primitive.Exp]([x], [dx]) {
3881
+ const z = exp$1(x);
3882
+ return [[z.ref], [z.mul(dx)]];
3883
+ },
3884
+ [Primitive.Log]([x], [dx]) {
3885
+ return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
3886
+ },
3887
+ [Primitive.Erf]([x], [dx]) {
3888
+ const coeff = 2 / Math.sqrt(Math.PI);
3889
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3890
+ return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
3891
+ },
3892
+ [Primitive.Erfc]([x], [dx]) {
3893
+ const coeff = -2 / Math.sqrt(Math.PI);
3894
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3895
+ return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
3896
+ },
3897
+ [Primitive.Sqrt]([x], [dx]) {
3898
+ const z = sqrt$1(x);
3899
+ return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
3900
+ },
3901
+ [Primitive.Reduce]([x], [dx], { op, axis }) {
3902
+ if (op === AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
3903
+ else if (op === AluOp.Mul) {
3904
+ const primal = reduce(x.ref, op, axis);
3905
+ const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
3906
+ return [[primal], [tangent]];
3907
+ } else if (op === AluOp.Min || op === AluOp.Max) {
3908
+ const primal = reduce(x.ref, op, axis);
3909
+ const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
3910
+ const minCount = where$1(notMin.ref, 0, 1).sum(axis);
3911
+ const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
3912
+ return [[primal], [tangent]];
3913
+ } else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
3914
+ },
3915
+ [Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
3916
+ [Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
3917
+ [Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
3918
+ [Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
3919
+ [Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
3920
+ [Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
3921
+ dcond.dispose();
3922
+ return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
3923
+ },
3924
+ [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
3925
+ [Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
3926
+ const indicesRef = indices.map((t) => t.ref);
3927
+ return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
3928
+ },
3929
+ [Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
3930
+ [Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
3931
+ [Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
3932
+ [Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
3933
+ [Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
3934
+ [Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
3935
+ [Primitive.Sort]([x], [dx]) {
3936
+ const [y, idx] = argsort$1(x);
3937
+ return [[y], [gather(dx, [idx], [-1], -1)]];
3938
+ },
3939
+ [Primitive.Argsort]([x], [dx]) {
3940
+ const [y, idx] = argsort$1(x);
3941
+ return [[y, idx.ref], [gather(dx, [idx.ref], [-1], -1), zerosLike$1(idx)]];
3942
+ },
3943
+ [Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
3944
+ const x = triangularSolve$1(a.ref, b, { unitDiagonal });
3945
+ const dax = batchMatmulT(da, x.ref);
3946
+ const rhsT = db.sub(mT(dax));
3947
+ const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
3948
+ return [[x], [dx]];
3949
+ },
3950
+ [Primitive.Cholesky]([a], [da]) {
3951
+ const L = cholesky$2(a.ref);
3952
+ da = da.ref.add(mT(da)).mul(.5);
3953
+ const W = triangularSolve$1(L.ref, da, { lower: true });
3954
+ const ST = triangularSolve$1(L.ref, mT(W), { lower: true });
3955
+ const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
3956
+ return [[L], [dL]];
3957
+ },
3958
+ [Primitive.Jit](primals, tangents, { name, jaxpr }) {
3959
+ const newJaxpr = jvpJaxpr(jaxpr);
3960
+ const outs = bind(Primitive.Jit, [
3961
+ ...newJaxpr.consts.map((c) => c.ref),
3962
+ ...primals,
3963
+ ...tangents
3964
+ ], {
3965
+ name: `${name}_jvp`,
3966
+ jaxpr: newJaxpr.jaxpr,
3967
+ numConsts: newJaxpr.consts.length
3968
+ });
3969
+ const n = outs.length / 2;
3970
+ if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
3971
+ const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
3972
+ return [primalsOut, tangentsOut];
3973
+ }
3974
+ };
3975
+ const jvpJaxprCache = /* @__PURE__ */ new Map();
3976
+ function jvpJaxpr(jaxpr) {
3977
+ if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
3978
+ const inAvals = jaxpr.inBinders.map((v) => v.aval);
3979
+ const { jaxpr: newJaxpr } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
3980
+ jvpJaxprCache.set(jaxpr, newJaxpr);
3981
+ return newJaxpr;
3982
+ }
3983
+ function jvpFlat(f, primals, tangents) {
3984
+ try {
3985
+ var _usingCtx$1 = _usingCtx();
3986
+ const main = _usingCtx$1.u(newMain(JVPTrace));
3987
+ const trace$1 = new JVPTrace(main);
3988
+ const tracersIn = zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
3989
+ const outs = f(...tracersIn);
3990
+ const tracersOut = outs.map((out) => fullRaise(trace$1, out));
3991
+ return unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
3992
+ } catch (_) {
3993
+ _usingCtx$1.e = _;
3994
+ } finally {
3995
+ _usingCtx$1.d();
3996
+ }
3997
+ }
3998
+ function jvp$1(f, primals, tangents) {
3999
+ const [primalsFlat, inTree] = flatten(primals);
4000
+ const [tangentsFlat, inTree2] = flatten(tangents);
4001
+ if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
4002
+ const [flatFun, outTree] = flattenFun(f, inTree);
4003
+ const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
4004
+ if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
4005
+ const primalsOut = unflatten(outTree.value, primalsOutFlat);
4006
+ const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
4007
+ return [primalsOut, tangentsOut];
4008
+ }
4009
+
3611
4010
  //#endregion
3612
4011
  //#region src/frontend/linearize.ts
3613
4012
  /** Array value that can either be known or unknown. */
@@ -3638,11 +4037,10 @@ function partialEvalFlat(f, pvalsIn) {
3638
4037
  const tracersOut = outs.map((out) => fullRaise(trace$1, out));
3639
4038
  const pvalsOut = tracersOut.map((t) => t.pval);
3640
4039
  const unknownTracersOut = tracersOut.filter((t) => !t.pval.isKnown);
3641
- const { jaxpr, consts } = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
4040
+ const jaxpr = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
3642
4041
  return {
3643
4042
  jaxpr,
3644
- pvalsOut,
3645
- consts
4043
+ pvalsOut
3646
4044
  };
3647
4045
  }
3648
4046
  /**
@@ -3659,22 +4057,19 @@ function linearizeFlatUtil(f, primalsIn) {
3659
4057
  const [primalsOut$1, tangentsOut] = jvp$1(f, x.slice(0, k), x.slice(k, 2 * k));
3660
4058
  return [...primalsOut$1, ...tangentsOut];
3661
4059
  };
3662
- const { jaxpr, pvalsOut, consts } = partialEvalFlat(fJvp, pvalsIn);
4060
+ const { jaxpr, pvalsOut } = partialEvalFlat(fJvp, pvalsIn);
3663
4061
  const primalPvals = pvalsOut.slice(0, pvalsOut.length / 2);
3664
4062
  if (!primalPvals.every((pval) => pval.isKnown)) throw new Error("Not all primal values are known after partial evaluation");
3665
4063
  const primalsOut = primalPvals.map((pval) => pval.val);
3666
4064
  return {
3667
4065
  primalsOut,
3668
- jaxpr,
3669
- consts
4066
+ jaxpr
3670
4067
  };
3671
4068
  }
3672
4069
  function linearizeFlat(f, primalsIn) {
3673
- const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
3674
- const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
3675
- const dispose$1 = () => {
3676
- for (const c of consts) c.dispose();
3677
- };
4070
+ const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
4071
+ const fLin = (...tangents) => evalJaxpr(jaxpr.jaxpr, [...jaxpr.consts.map((c) => c.ref), ...tangents]);
4072
+ const dispose$1 = () => jaxpr.dispose();
3678
4073
  return [
3679
4074
  primalsOut,
3680
4075
  fLin,
@@ -3758,7 +4153,7 @@ var PartialEvalTrace = class extends Trace {
3758
4153
  }
3759
4154
  processPrimitive(primitive, tracers, params) {
3760
4155
  if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
3761
- if (primitive === Primitive.JitCall) {
4156
+ if (primitive === Primitive.Jit) {
3762
4157
  const { name, jaxpr, numConsts } = params;
3763
4158
  return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
3764
4159
  }
@@ -3784,14 +4179,14 @@ var PartialEvalTrace = class extends Trace {
3784
4179
  * Evaluate a Jaxpr on a set of PartialEvalTracers, computing as many known
3785
4180
  * values as possible (with JIT) and forwarding the unknown ones.
3786
4181
  *
3787
- * Used when encountering a JitCall rule during the trace.
4182
+ * Used when encountering a Jit rule during the trace.
3788
4183
  */
3789
4184
  #partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
3790
4185
  jaxpr = jaxpr.flatten();
3791
4186
  const inUnknowns = tracers.map((t) => !t.pval.isKnown);
3792
4187
  const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
3793
4188
  const [knownTracers, unknownTracers] = partitionList(inUnknowns, tracers);
3794
- const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
4189
+ const outs1Res = bind(Primitive.Jit, knownTracers.map((t) => t.ref.fullLower()), {
3795
4190
  name: `${name}_peval`,
3796
4191
  jaxpr: jaxpr1,
3797
4192
  numConsts: 0
@@ -3801,7 +4196,7 @@ var PartialEvalTrace = class extends Trace {
3801
4196
  const resTracers = res.map((x) => this.instantiateConst(fullRaise(this, x)));
3802
4197
  const recipe = {
3803
4198
  type: "JaxprEqn",
3804
- prim: Primitive.JitCall,
4199
+ prim: Primitive.Jit,
3805
4200
  tracersIn: resTracers.concat(unknownTracers),
3806
4201
  params: {
3807
4202
  name: `${name}_resid`,
@@ -3830,7 +4225,7 @@ function partialEvalJaxpr(jaxpr, inUnknowns, instantiate) {
3830
4225
  const eqns1 = [];
3831
4226
  const eqns2 = [];
3832
4227
  for (const eqn of jaxpr.eqns) {
3833
- if (eqn.primitive === Primitive.JitCall) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
4228
+ if (eqn.primitive === Primitive.Jit) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
3834
4229
  const hasUnknowns = eqn.inputs.some((x) => x instanceof Var && !knownVars.has(x));
3835
4230
  if (hasUnknowns) {
3836
4231
  for (const x of eqn.inputs) if (x instanceof Var && knownVars.has(x)) residuals.add(x);
@@ -3904,11 +4299,8 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
3904
4299
  for (const t of tracersIn) t.dispose();
3905
4300
  for (const t of tracersOut) t.dispose();
3906
4301
  jaxpr = jaxpr.simplify();
3907
- if (DEBUG >= 5) console.log("jaxpr from partial evaluation:\n" + jaxpr.toString());
3908
- return {
3909
- jaxpr,
3910
- consts
3911
- };
4302
+ if (DEBUG >= 5) console.info("jaxpr from partial evaluation:\n" + jaxpr.toString());
4303
+ return new ClosedJaxpr(jaxpr, consts);
3912
4304
  }
3913
4305
  /** Marker type for pullback, used by transpose rules. */
3914
4306
  var UndefPrimal = class {
@@ -4038,22 +4430,25 @@ const transposeRules = {
4038
4430
  },
4039
4431
  [Primitive.Conv]([ct], [lhs, rhs], params) {
4040
4432
  if (lhs instanceof UndefPrimal === rhs instanceof UndefPrimal) throw new NonlinearError(Primitive.Conv);
4433
+ const v = params.vmapDims;
4041
4434
  const rev01 = [
4042
- 1,
4043
- 0,
4044
- ...range(2, ct.ndim)
4435
+ ...range(v),
4436
+ v + 1,
4437
+ v,
4438
+ ...range(v + 2, ct.ndim)
4045
4439
  ];
4046
4440
  if (lhs instanceof UndefPrimal) {
4047
4441
  let kernel = rhs;
4048
4442
  kernel = transpose$1(kernel, rev01);
4049
- kernel = flip$1(kernel, range(2, kernel.ndim));
4443
+ kernel = flip$1(kernel, range(v + 2, kernel.ndim));
4050
4444
  const result = conv$1(ct, kernel, {
4445
+ vmapDims: v,
4051
4446
  strides: params.lhsDilation,
4052
4447
  padding: params.padding.map(([pl, _pr], i) => {
4053
- const dilatedKernel = (kernel.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
4054
- const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
4448
+ const dilatedKernel = (kernel.shape[i + v + 2] - 1) * params.rhsDilation[i] + 1;
4449
+ const dilatedCt = (ct.shape[i + v + 2] - 1) * params.strides[i] + 1;
4055
4450
  const padBefore = dilatedKernel - 1 - pl;
4056
- const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
4451
+ const dilatedLhs = (lhs.aval.shape[i + v + 2] - 1) * params.lhsDilation[i] + 1;
4057
4452
  const padAfter = dilatedLhs + dilatedKernel - 1 - dilatedCt - padBefore;
4058
4453
  return [padBefore, padAfter];
4059
4454
  }),
@@ -4065,11 +4460,12 @@ const transposeRules = {
4065
4460
  const newLhs = transpose$1(lhs, rev01);
4066
4461
  const newRhs = transpose$1(ct, rev01);
4067
4462
  let result = conv$1(newLhs, newRhs, {
4463
+ vmapDims: v,
4068
4464
  strides: params.rhsDilation,
4069
4465
  padding: params.padding.map(([pl, _pr], i) => {
4070
- const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
4071
- const dilatedKernel = (rhs.aval.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
4072
- const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
4466
+ const dilatedLhs = (lhs.aval.shape[i + v + 2] - 1) * params.lhsDilation[i] + 1;
4467
+ const dilatedKernel = (rhs.aval.shape[i + v + 2] - 1) * params.rhsDilation[i] + 1;
4468
+ const dilatedCt = (ct.shape[i + v + 2] - 1) * params.strides[i] + 1;
4073
4469
  const padFromLhs = dilatedCt - dilatedLhs;
4074
4470
  const padFromRhs = dilatedKernel - pl - 1;
4075
4471
  return [pl, padFromLhs + padFromRhs];
@@ -4096,6 +4492,11 @@ const transposeRules = {
4096
4492
  cond.dispose();
4097
4493
  return cts;
4098
4494
  },
4495
+ [Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
4496
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4497
+ if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4498
+ throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
4499
+ },
4099
4500
  [Primitive.Transpose]([ct], [x], { perm }) {
4100
4501
  if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
4101
4502
  return [transpose$1(ct, invertPermutation(perm))];
@@ -4122,23 +4523,26 @@ const transposeRules = {
4122
4523
  const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
4123
4524
  return [shrink(ct, slice)];
4124
4525
  },
4125
- [Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
4126
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4127
- if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4128
- throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
4526
+ [Primitive.TriangularSolve]([ct], [a, b], { unitDiagonal }) {
4527
+ if (a instanceof UndefPrimal || !(b instanceof UndefPrimal)) throw new NonlinearError(Primitive.TriangularSolve);
4528
+ const ctB = triangularSolve$1(moveaxis(a, -2, -1), ct, {
4529
+ lower: true,
4530
+ unitDiagonal
4531
+ });
4532
+ return [null, ctB];
4129
4533
  },
4130
- [Primitive.JitCall](cts, args, { name, jaxpr }) {
4534
+ [Primitive.Jit](cts, args, { name, jaxpr }) {
4131
4535
  const undefPrimals = args.map((x) => x instanceof UndefPrimal);
4132
- const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
4536
+ const newJaxpr = transposeJaxpr(jaxpr, undefPrimals);
4133
4537
  const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
4134
- const outs = bind(Primitive.JitCall, [
4135
- ...newConsts.map((c) => c.ref),
4538
+ const outs = bind(Primitive.Jit, [
4539
+ ...newJaxpr.consts.map((c) => c.ref),
4136
4540
  ...residuals,
4137
4541
  ...cts
4138
4542
  ], {
4139
4543
  name: `${name}_t`,
4140
- jaxpr: newJaxpr,
4141
- numConsts: newConsts.length
4544
+ jaxpr: newJaxpr.jaxpr,
4545
+ numConsts: newJaxpr.consts.length
4142
4546
  });
4143
4547
  let i = 0;
4144
4548
  return undefPrimals.map((isUndef) => isUndef ? outs[i++] : null);
@@ -4151,31 +4555,25 @@ function transposeJaxpr(jaxpr, undefPrimals) {
4151
4555
  if (prevResult) return prevResult;
4152
4556
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
4153
4557
  const forwardInTypes = inTypes.filter((_, i) => !undefPrimals[i]);
4154
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((forwardIn, cotangents) => {
4558
+ const { jaxpr: newJaxpr } = makeJaxpr$1((forwardIn, cotangents) => {
4155
4559
  const args = [];
4156
4560
  let forwardInIdx = 0;
4157
4561
  for (let i = 0; i < undefPrimals.length; i++) if (undefPrimals[i]) args.push(new UndefPrimal(inTypes[i]));
4158
4562
  else args.push(forwardIn[forwardInIdx++]);
4159
4563
  return evalJaxprTransposed(jaxpr, args, cotangents);
4160
4564
  })(forwardInTypes, outTypes);
4161
- typecheckJaxpr(newJaxpr);
4162
- const result = {
4163
- newJaxpr,
4164
- newConsts
4165
- };
4565
+ typecheckJaxpr(newJaxpr.jaxpr);
4166
4566
  if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
4167
- transposeJaxprCache.get(jaxpr).set(cacheKey, result);
4168
- return result;
4567
+ transposeJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
4568
+ return newJaxpr;
4169
4569
  }
4170
4570
  function vjpFlat(f, primalsIn) {
4171
- const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
4571
+ const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
4172
4572
  const fVjp = (...cotangents) => {
4173
- const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
4174
- return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
4175
- };
4176
- const dispose$1 = () => {
4177
- for (const c of consts) c.dispose();
4573
+ const transposeInputs = [...jaxpr.consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
4574
+ return evalJaxprTransposed(jaxpr.jaxpr, transposeInputs, cotangents);
4178
4575
  };
4576
+ const dispose$1 = () => jaxpr.dispose();
4179
4577
  return [
4180
4578
  primalsOut,
4181
4579
  fVjp,
@@ -4232,150 +4630,6 @@ function jacrev$1(f) {
4232
4630
  };
4233
4631
  }
4234
4632
 
4235
- //#endregion
4236
- //#region src/library/lax.ts
4237
- var lax_exports = {};
4238
- __export(lax_exports, {
4239
- conv: () => conv,
4240
- convGeneralDilated: () => convGeneralDilated,
4241
- convWithGeneralPadding: () => convWithGeneralPadding,
4242
- dot: () => dot$1,
4243
- erf: () => erf,
4244
- erfc: () => erfc,
4245
- reduceWindow: () => reduceWindow,
4246
- stopGradient: () => stopGradient$1
4247
- });
4248
- /**
4249
- * General dot product/contraction operator.
4250
- *
4251
- * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
4252
- * `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
4253
- */
4254
- function dot$1(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
4255
- if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
4256
- else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
4257
- lc = lc.map((a) => checkAxis(a, lhs.ndim));
4258
- rc = rc.map((a) => checkAxis(a, rhs.ndim));
4259
- lb = lb.map((a) => checkAxis(a, lhs.ndim));
4260
- rb = rb.map((a) => checkAxis(a, rhs.ndim));
4261
- if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
4262
- else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
4263
- const lf = range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
4264
- const rf = range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
4265
- const lhs2 = lhs.transpose([
4266
- ...lb,
4267
- ...lf,
4268
- ...lc
4269
- ]);
4270
- const rhs2 = rhs.transpose([
4271
- ...rb,
4272
- ...rf,
4273
- ...rc
4274
- ]);
4275
- if (lc.length === 0) return mul(lhs2.reshape([
4276
- ...lb.map((a) => lhs.shape[a]),
4277
- ...lf.map((a) => lhs.shape[a]),
4278
- ...rep(rf.length, 1)
4279
- ]), rhs2.reshape([
4280
- ...rb.map((a) => rhs.shape[a]),
4281
- ...rep(lf.length, 1),
4282
- ...rf.map((a) => rhs.shape[a])
4283
- ]));
4284
- const dotShapeX = lc.map((a) => lhs.shape[a]);
4285
- const dotShapeY = rc.map((a) => rhs.shape[a]);
4286
- if (!deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
4287
- return dot$2(lhs2.reshape([
4288
- ...lb.map((a) => lhs.shape[a]),
4289
- ...lf.map((a) => lhs.shape[a]),
4290
- ...rep(rf.length, 1),
4291
- prod(dotShapeX)
4292
- ]), rhs2.reshape([
4293
- ...rb.map((a) => rhs.shape[a]),
4294
- ...rep(lf.length, 1),
4295
- ...rf.map((a) => rhs.shape[a]),
4296
- prod(dotShapeY)
4297
- ]));
4298
- }
4299
- function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
4300
- const padType = padding.toUpperCase();
4301
- switch (padType) {
4302
- case "VALID": return rep(inShape.length, [0, 0]);
4303
- case "SAME":
4304
- case "SAME_LOWER": {
4305
- const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
4306
- const padSizes = zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
4307
- if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
4308
- else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
4309
- }
4310
- default: throw new Error(`Unknown padding type: ${padType}`);
4311
- }
4312
- }
4313
- /**
4314
- * General n-dimensional convolution operator, with optional dilation.
4315
- *
4316
- * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
4317
- * function in JAX, which wraps XLA's general convolution operator.
4318
- *
4319
- * Grouped convolutions are not supported right now.
4320
- */
4321
- function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation } = {}) {
4322
- if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
4323
- if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
4324
- if (typeof padding === "string") {
4325
- if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
4326
- padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? rep(rhs.ndim - 2, 1), padding);
4327
- }
4328
- return conv$1(lhs, rhs, {
4329
- strides: windowStrides,
4330
- padding,
4331
- lhsDilation,
4332
- rhsDilation
4333
- });
4334
- }
4335
- /** Convenience wrapper around `convGeneralDilated`. */
4336
- function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
4337
- return convGeneralDilated(lhs, rhs, windowStrides, padding, {
4338
- lhsDilation,
4339
- rhsDilation
4340
- });
4341
- }
4342
- /** Convenience wrapper around `convGeneralDilated`. */
4343
- function conv(lhs, rhs, windowStrides, padding) {
4344
- return convGeneralDilated(lhs, rhs, windowStrides, padding);
4345
- }
4346
- /** Reduce a computation over padded windows. */
4347
- function reduceWindow(operand, computation, windowDimensions, windowStrides) {
4348
- if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
4349
- if (!windowStrides) windowStrides = rep(windowDimensions.length, 1);
4350
- for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
4351
- return computation(bind1(Primitive.Pool, [operand], {
4352
- window: windowDimensions,
4353
- strides: windowStrides
4354
- }));
4355
- }
4356
- /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
4357
- function erf(x) {
4358
- return erf$1(x);
4359
- }
4360
- /**
4361
- * The complementary error function: `erfc(x) = 1 - erf(x)`.
4362
- *
4363
- * This function is more accurate than `1 - erf(x)` for large values of `x`,
4364
- * where `erf(x)` is very close to 1.
4365
- */
4366
- function erfc(x) {
4367
- return erfc$1(x);
4368
- }
4369
- /**
4370
- * Stops gradient computation.
4371
- *
4372
- * Behaves as the identity function but prevents the flow of gradients during
4373
- * forward or reverse-mode automatic differentiation.
4374
- */
4375
- function stopGradient$1(x) {
4376
- return stopGradient(x);
4377
- }
4378
-
4379
4633
  //#endregion
4380
4634
  //#region src/library/numpy/einsum.ts
4381
4635
  const bprod = (...xs) => xs.reduce((acc, x) => acc * BigInt(x), 1n);
@@ -4571,34 +4825,207 @@ function* allPaths(tensors, next) {
4571
4825
  }
4572
4826
  }
4573
4827
 
4828
+ //#endregion
4829
+ //#region src/library/numpy-fft.ts
4830
+ var numpy_fft_exports = {};
4831
+ __export(numpy_fft_exports, {
4832
+ fft: () => fft,
4833
+ ifft: () => ifft
4834
+ });
4835
+ function checkPairInput(name, a) {
4836
+ const fullName = `jax.numpy.fft.${name}`;
4837
+ if (!deepEqual(a.real.shape, a.imag.shape)) throw new Error(`${fullName}: real and imaginary parts must have the same shape, got ${JSON.stringify(a.real.shape)} and ${JSON.stringify(a.imag.shape)}`);
4838
+ if (a.real.dtype !== a.imag.dtype) throw new Error(`${fullName}: real and imaginary parts must have the same dtype, got ${a.real.dtype} and ${a.imag.dtype}`);
4839
+ if (!isFloatDtype(a.real.dtype)) throw new Error(`${fullName}: input must have a float dtype, got ${a.real.dtype}`);
4840
+ }
4841
+ function checkPowerOfTwo(name, n) {
4842
+ if ((n & n - 1) !== 0) throw new Error(`jax.numpy.fft.${name}: size must be a power of two, got ${n}`);
4843
+ }
4844
+ const fftUpdate = jit$1(function fftUpdate$1(i, { real, imag }) {
4845
+ const half = 2 ** i;
4846
+ real = real.reshape([-1, 2 * half]);
4847
+ imag = imag.reshape([-1, 2 * half]);
4848
+ const k = arange(0, half, 1, { dtype: real.dtype });
4849
+ const theta = k.mul(-Math.PI / half);
4850
+ const wr = cos(theta.ref);
4851
+ const wi = sin(theta);
4852
+ const ur = real.ref.slice([], [0, half]);
4853
+ const ui = imag.ref.slice([], [0, half]);
4854
+ const vr = real.slice([], [half, 2 * half]);
4855
+ const vi = imag.slice([], [half, 2 * half]);
4856
+ const tr = vr.ref.mul(wr.ref).sub(vi.ref.mul(wi.ref));
4857
+ const ti = vr.mul(wi).add(vi.mul(wr));
4858
+ return {
4859
+ real: concatenate([ur.ref.add(tr.ref), ur.sub(tr)], -1),
4860
+ imag: concatenate([ui.ref.add(ti.ref), ui.sub(ti)], -1)
4861
+ };
4862
+ }, { staticArgnums: [0] });
4863
+ /**
4864
+ * Compute a one-dimensional discrete Fourier transform.
4865
+ *
4866
+ * Currently, the size of the axis must be a power of two.
4867
+ */
4868
+ function fft(a, axis = -1) {
4869
+ checkPairInput("fft", a);
4870
+ let { real, imag } = a;
4871
+ axis = checkAxis(axis, real.ndim);
4872
+ const n = real.shape[axis];
4873
+ checkPowerOfTwo("fft", n);
4874
+ const logN = Math.log2(n);
4875
+ let perm = null;
4876
+ if (axis !== real.ndim - 1) {
4877
+ perm = range(real.ndim);
4878
+ perm.splice(axis, 1);
4879
+ perm.push(axis);
4880
+ real = real.transpose(perm);
4881
+ imag = imag.transpose(perm);
4882
+ }
4883
+ const originalShape = real.shape;
4884
+ real = real.reshape([-1, ...rep(logN, 2)]).transpose([0, ...range(1, logN + 1).reverse()]).flatten();
4885
+ imag = imag.reshape([-1, ...rep(logN, 2)]).transpose([0, ...range(1, logN + 1).reverse()]).flatten();
4886
+ for (let i = 0; i < logN; i++) ({real, imag} = fftUpdate(i, {
4887
+ real,
4888
+ imag
4889
+ }));
4890
+ real = real.reshape(originalShape);
4891
+ imag = imag.reshape(originalShape);
4892
+ if (perm !== null) {
4893
+ real = real.transpose(invertPermutation(perm));
4894
+ imag = imag.transpose(invertPermutation(perm));
4895
+ }
4896
+ return {
4897
+ real,
4898
+ imag
4899
+ };
4900
+ }
4901
+ /**
4902
+ * Compute a one-dimensional inverse discrete Fourier transform.
4903
+ *
4904
+ * Currently, the size of the axis must be a power of two.
4905
+ */
4906
+ function ifft(a, axis = -1) {
4907
+ checkPairInput("ifft", a);
4908
+ let { real, imag } = a;
4909
+ axis = checkAxis(axis, real.ndim);
4910
+ const n = real.shape[axis];
4911
+ checkPowerOfTwo("ifft", n);
4912
+ imag = imag.mul(-1);
4913
+ const result = fft({
4914
+ real,
4915
+ imag
4916
+ }, axis);
4917
+ return {
4918
+ real: result.real.div(n),
4919
+ imag: result.imag.mul(-1).div(n)
4920
+ };
4921
+ }
4922
+
4923
+ //#endregion
4924
+ //#region src/library/numpy-linalg.ts
4925
+ var numpy_linalg_exports = {};
4926
+ __export(numpy_linalg_exports, {
4927
+ cholesky: () => cholesky$1,
4928
+ diagonal: () => diagonal,
4929
+ lstsq: () => lstsq,
4930
+ matmul: () => matmul,
4931
+ matrixTranspose: () => matrixTranspose,
4932
+ outer: () => outer,
4933
+ tensordot: () => tensordot,
4934
+ trace: () => trace,
4935
+ vecdot: () => vecdot
4936
+ });
4937
+ /**
4938
+ * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
4939
+ *
4940
+ * This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
4941
+ * the input matrix, which is on by default.
4942
+ */
4943
+ function cholesky$1(a, { upper = false, symmetrizeInput = true } = {}) {
4944
+ a = fudgeArray(a);
4945
+ if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`cholesky: input must be at least 2D square matrix, got ${a.aval}`);
4946
+ if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
4947
+ return cholesky(a, { upper });
4948
+ }
4949
+ /**
4950
+ * Return the least-squares solution to a linear equation.
4951
+ *
4952
+ * For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
4953
+ * For underdetermined systems, this finds the minimum-norm solution for `x`.
4954
+ *
4955
+ * This currently uses Cholesky decomposition to solve the normal equations,
4956
+ * under the hood. The method is not as robust as QR or SVD.
4957
+ *
4958
+ * @param a coefficient matrix of shape `(M, N)`
4959
+ * @param b right-hand side of shape `(M,)` or `(M, K)`
4960
+ * @return least-squares solution of shape `(N,)` or `(N, K)`
4961
+ */
4962
+ function lstsq(a, b) {
4963
+ a = fudgeArray(a);
4964
+ b = fudgeArray(b);
4965
+ if (a.ndim !== 2) throw new Error(`lstsq: 'a' must be a 2D array, got ${a.aval}`);
4966
+ const [m, n] = a.shape;
4967
+ if (b.shape[0] !== m) throw new Error(`lstsq: leading dimension of 'b' must match number of rows of 'a', got ${b.aval}`);
4968
+ const at = matrixTranspose(a.ref);
4969
+ if (m <= n) {
4970
+ const aat = matmul(a, at.ref);
4971
+ const l = cholesky$1(aat, { symmetrizeInput: false });
4972
+ const lb = triangularSolve(l.ref, b, {
4973
+ leftSide: true,
4974
+ lower: true
4975
+ });
4976
+ const llb = triangularSolve(l, lb, {
4977
+ leftSide: true,
4978
+ transposeA: true
4979
+ });
4980
+ return matmul(at, llb.ref);
4981
+ } else {
4982
+ const ata = matmul(at.ref, a);
4983
+ const l = cholesky$1(ata, { symmetrizeInput: false });
4984
+ const atb = matmul(at, b);
4985
+ const lb = triangularSolve(l.ref, atb, {
4986
+ leftSide: true,
4987
+ lower: true
4988
+ });
4989
+ const llb = triangularSolve(l, lb, {
4990
+ leftSide: true,
4991
+ transposeA: true
4992
+ });
4993
+ return llb;
4994
+ }
4995
+ }
4996
+
4574
4997
  //#endregion
4575
4998
  //#region src/library/numpy.ts
4576
4999
  var numpy_exports = {};
4577
5000
  __export(numpy_exports, {
4578
5001
  Array: () => Array$1,
4579
5002
  DType: () => DType,
4580
- abs: () => abs,
5003
+ abs: () => absolute,
4581
5004
  absolute: () => absolute,
4582
5005
  acos: () => acos,
4583
- acosh: () => acosh,
5006
+ acosh: () => arccosh,
4584
5007
  add: () => add,
5008
+ all: () => all,
4585
5009
  allclose: () => allclose,
5010
+ any: () => any,
4586
5011
  arange: () => arange,
4587
- arccos: () => arccos,
5012
+ arccos: () => acos,
4588
5013
  arccosh: () => arccosh,
5014
+ arcsin: () => asin,
4589
5015
  arcsinh: () => arcsinh,
4590
- arctan: () => arctan,
4591
- arctan2: () => arctan2,
5016
+ arctan: () => atan,
5017
+ arctan2: () => atan2,
4592
5018
  arctanh: () => arctanh,
4593
5019
  argmax: () => argmax,
4594
5020
  argmin: () => argmin,
5021
+ argsort: () => argsort,
4595
5022
  array: () => array,
4596
5023
  asin: () => asin,
4597
- asinh: () => asinh,
5024
+ asinh: () => arcsinh,
4598
5025
  astype: () => astype,
4599
5026
  atan: () => atan,
4600
5027
  atan2: () => atan2,
4601
- atanh: () => atanh,
5028
+ atanh: () => arctanh,
4602
5029
  bool: () => bool,
4603
5030
  broadcastArrays: () => broadcastArrays,
4604
5031
  broadcastShapes: () => broadcastShapes,
@@ -4608,14 +5035,20 @@ __export(numpy_exports, {
4608
5035
  clip: () => clip,
4609
5036
  columnStack: () => columnStack,
4610
5037
  concatenate: () => concatenate,
5038
+ convolve: () => convolve,
5039
+ corrcoef: () => corrcoef,
5040
+ correlate: () => correlate,
4611
5041
  cos: () => cos,
4612
5042
  cosh: () => cosh,
5043
+ cov: () => cov,
5044
+ cumsum: () => cumsum,
5045
+ cumulativeSum: () => cumsum,
4613
5046
  deg2rad: () => deg2rad,
4614
5047
  degrees: () => degrees,
4615
5048
  diag: () => diag,
4616
5049
  diagonal: () => diagonal,
4617
- divide: () => divide,
4618
- dot: () => dot,
5050
+ divide: () => trueDivide,
5051
+ dot: () => dot$1,
4619
5052
  dstack: () => dstack,
4620
5053
  e: () => e,
4621
5054
  einsum: () => einsum,
@@ -4623,8 +5056,10 @@ __export(numpy_exports, {
4623
5056
  eulerGamma: () => eulerGamma,
4624
5057
  exp: () => exp,
4625
5058
  exp2: () => exp2,
5059
+ expandDims: () => expandDims,
4626
5060
  expm1: () => expm1,
4627
5061
  eye: () => eye,
5062
+ fft: () => numpy_fft_exports,
4628
5063
  flip: () => flip,
4629
5064
  fliplr: () => fliplr,
4630
5065
  flipud: () => flipud,
@@ -4655,12 +5090,14 @@ __export(numpy_exports, {
4655
5090
  ldexp: () => ldexp,
4656
5091
  less: () => less,
4657
5092
  lessEqual: () => lessEqual,
5093
+ linalg: () => numpy_linalg_exports,
4658
5094
  linspace: () => linspace,
4659
5095
  log: () => log,
4660
5096
  log10: () => log10,
4661
5097
  log1p: () => log1p,
4662
5098
  log2: () => log2,
4663
5099
  matmul: () => matmul,
5100
+ matrixTranspose: () => matrixTranspose,
4664
5101
  max: () => max,
4665
5102
  maximum: () => maximum,
4666
5103
  mean: () => mean,
@@ -4677,10 +5114,10 @@ __export(numpy_exports, {
4677
5114
  onesLike: () => onesLike,
4678
5115
  outer: () => outer,
4679
5116
  pad: () => pad,
4680
- permuteDims: () => permuteDims,
5117
+ permuteDims: () => transpose,
4681
5118
  pi: () => pi,
4682
5119
  positive: () => positive,
4683
- pow: () => pow,
5120
+ pow: () => power,
4684
5121
  power: () => power,
4685
5122
  prod: () => prod$1,
4686
5123
  promoteTypes: () => promoteTypes,
@@ -4697,6 +5134,7 @@ __export(numpy_exports, {
4697
5134
  sin: () => sin,
4698
5135
  sinh: () => sinh,
4699
5136
  size: () => size,
5137
+ sort: () => sort,
4700
5138
  sqrt: () => sqrt,
4701
5139
  square: () => square,
4702
5140
  squeeze: () => squeeze,
@@ -4861,6 +5299,26 @@ function min(a, axis = null, opts) {
4861
5299
  function max(a, axis = null, opts) {
4862
5300
  return reduce(a, AluOp.Max, axis, opts);
4863
5301
  }
5302
+ /**
5303
+ * Test whether all array elements along a given axis evaluate to True.
5304
+ *
5305
+ * Returns a boolean array with the same shape as `a` with the specified axis
5306
+ * removed. If axis is None, returns a scalar.
5307
+ */
5308
+ function all(a, axis = null, opts) {
5309
+ a = fudgeArray(a).astype(DType.Bool);
5310
+ return min(a, axis, opts);
5311
+ }
5312
+ /**
5313
+ * Test whether any array element along a given axis evaluates to True.
5314
+ *
5315
+ * Returns a boolean array with the same shape as `a` with the specified axis
5316
+ * removed. If axis is None, returns a scalar.
5317
+ */
5318
+ function any(a, axis = null, opts) {
5319
+ a = fudgeArray(a).astype(DType.Bool);
5320
+ return max(a, axis, opts);
5321
+ }
4864
5322
  /** Return the peak-to-peak range along a given axis (`max - min`). */
4865
5323
  function ptp(a, axis = null, opts) {
4866
5324
  a = fudgeArray(a);
@@ -4918,6 +5376,23 @@ function argmax(a, axis, opts) {
4918
5376
  }).reshape([shape$1[axis], ...rep(shape$1.length - axis - 1, 1)]));
4919
5377
  return length.sub(max(idx, axis, opts));
4920
5378
  }
5379
+ /**
5380
+ * Cumulative sum of elements along an axis.
5381
+ *
5382
+ * Currently this function is `O(n^2)`, we'll improve this later on with a
5383
+ * two-phase parallel reduction algorithm.
5384
+ */
5385
+ function cumsum(a, axis) {
5386
+ a = fudgeArray(a);
5387
+ if (axis === void 0) {
5388
+ a = a.ravel();
5389
+ axis = 0;
5390
+ } else axis = checkAxis(axis, a.ndim);
5391
+ const n = a.shape[axis];
5392
+ a = moveaxis$1(a, axis, -1);
5393
+ a = broadcast(a, a.shape.concat(n), [-2]);
5394
+ return moveaxis$1(tril(a).sum(-1), -1, axis);
5395
+ }
4921
5396
  /** Reverse the elements in an array along the given axes. */
4922
5397
  function flip(x, axis = null) {
4923
5398
  const nd = ndim(x);
@@ -5027,8 +5502,11 @@ function flipud(x) {
5027
5502
  function fliplr(x) {
5028
5503
  return flip(x, 1);
5029
5504
  }
5030
- /** @function Alternative name for `numpy.transpose()`. */
5031
- const permuteDims = transpose;
5505
+ /** Transpose the last two dimensions of an array. */
5506
+ function matrixTranspose(a) {
5507
+ if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
5508
+ return moveaxis$1(a, -1, -2);
5509
+ }
5032
5510
  /** Return a 1-D flattened array containing the elements of the input. */
5033
5511
  function ravel(a) {
5034
5512
  return fudgeArray(a).ravel();
@@ -5044,6 +5522,32 @@ function squeeze(a, axis = null) {
5044
5522
  return reshape(a, newShape);
5045
5523
  }
5046
5524
  /**
5525
+ * Expand the shape of an array by inserting new axes of length 1.
5526
+ *
5527
+ * @param a - Input array.
5528
+ * @param axis - Position(s) in the expanded axes where the new axis (or axes)
5529
+ * is placed. Can be a single integer or an array of integers.
5530
+ * @returns Array with the number of dimensions increased.
5531
+ *
5532
+ * @example
5533
+ * ```ts
5534
+ * const x = np.array([1, 2]);
5535
+ * np.expandDims(x, 0); // Shape [1, 2]
5536
+ * np.expandDims(x, 1); // Shape [2, 1]
5537
+ * np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
5538
+ * ```
5539
+ */
5540
+ function expandDims(a, axis) {
5541
+ const as = shape(a);
5542
+ axis = typeof axis === "number" ? [axis] : axis;
5543
+ axis = normalizeAxis(axis, as.length + axis.length);
5544
+ const newShape = [];
5545
+ let srcIdx = 0;
5546
+ for (let i = 0; i < as.length + axis.length; i++) if (axis.includes(i)) newShape.push(1);
5547
+ else newShape.push(as[srcIdx++]);
5548
+ return reshape(a, newShape);
5549
+ }
5550
+ /**
5047
5551
  * Repeat each element of an array after themselves.
5048
5552
  *
5049
5553
  * If no axis is provided, use the flattened input array, and return a flat
@@ -5131,7 +5635,7 @@ function diagonal(a, offset, axis1, axis2) {
5131
5635
  */
5132
5636
  function diag(v, k = 0) {
5133
5637
  const a = fudgeArray(v);
5134
- if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
5638
+ if (!Number.isInteger(k)) throw new Error(`k must be an integer, got ${k}`);
5135
5639
  if (a.ndim === 1) {
5136
5640
  const n = a.shape[0];
5137
5641
  const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
@@ -5139,12 +5643,32 @@ function diag(v, k = 0) {
5139
5643
  else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
5140
5644
  else return ret;
5141
5645
  } else if (a.ndim === 2) return diagonal(a, k);
5142
- else throw new TypeError("numpy.diag only supports 1D and 2D arrays");
5646
+ else throw new Error("numpy.diag only supports 1D and 2D arrays");
5143
5647
  }
5144
5648
  /** Calculate the sum of the diagonal of an array along the given axes. */
5145
5649
  function trace(a, offset = 0, axis1 = 0, axis2 = 1) {
5146
5650
  return diagonal(a, offset, axis1, axis2).sum(-1);
5147
5651
  }
5652
+ /**
5653
+ * Return a sorted copy of an array.
5654
+ *
5655
+ * The array is sorted along a specified axis (the last by default). This may be
5656
+ * an unstable sort, and it dispatches to device-specific implementation.
5657
+ */
5658
+ function sort(a, axis = -1) {
5659
+ return fudgeArray(a).sort(axis);
5660
+ }
5661
+ /**
5662
+ * Return indices that would sort an array. This may be an unstable sorting
5663
+ * algorithm; it need not preserve order of indices in ties.
5664
+ *
5665
+ * Returns an array of `int32` indices.
5666
+ *
5667
+ * The array is sorted along a specified axis (the last by default).
5668
+ */
5669
+ function argsort(a, axis = -1) {
5670
+ return fudgeArray(a).argsort(axis);
5671
+ }
5148
5672
  /** Return if two arrays are element-wise equal within a tolerance. */
5149
5673
  function allclose(actual, expected, options) {
5150
5674
  const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
@@ -5153,16 +5677,19 @@ function allclose(actual, expected, options) {
5153
5677
  if (!deepEqual(x.shape, y.shape)) return false;
5154
5678
  const xData = x.dataSync();
5155
5679
  const yData = y.dataSync();
5156
- for (let i = 0; i < xData.length; i++) if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
5680
+ for (let i = 0; i < xData.length; i++) {
5681
+ if (isNaN(xData[i]) !== isNaN(yData[i])) return false;
5682
+ if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
5683
+ }
5157
5684
  return true;
5158
5685
  }
5159
5686
  /** Matrix product of two arrays. */
5160
5687
  function matmul(x, y) {
5161
- if (ndim(x) === 0 || ndim(y) === 0) throw new TypeError("matmul: x and y must be at least 1D");
5688
+ if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
5162
5689
  x = x, y = y;
5163
5690
  if (y.ndim === 1) return dot$2(x, y);
5164
5691
  const numBatchDims = Math.min(Math.max(x.ndim, 2), y.ndim) - 2;
5165
- return dot$1(x, y, {
5692
+ return dot(x, y, {
5166
5693
  lhsContractingDims: [-1],
5167
5694
  rhsContractingDims: [-2],
5168
5695
  lhsBatchDims: range(-2 - numBatchDims, -2),
@@ -5170,11 +5697,11 @@ function matmul(x, y) {
5170
5697
  });
5171
5698
  }
5172
5699
  /** Dot product of two arrays. */
5173
- function dot(x, y) {
5700
+ function dot$1(x, y) {
5174
5701
  if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
5175
5702
  x = x, y = y;
5176
5703
  if (y.ndim === 1) return dot$2(x, y);
5177
- return dot$1(x, y, {
5704
+ return dot(x, y, {
5178
5705
  lhsContractingDims: [-1],
5179
5706
  rhsContractingDims: [-2]
5180
5707
  });
@@ -5190,7 +5717,7 @@ function tensordot(x, y, axes = 2) {
5190
5717
  x = fudgeArray(x);
5191
5718
  y = fudgeArray(y);
5192
5719
  if (typeof axes === "number") axes = [range(-axes, 0), range(axes)];
5193
- return dot$1(x, y, {
5720
+ return dot(x, y, {
5194
5721
  lhsContractingDims: axes[0],
5195
5722
  rhsContractingDims: axes[1]
5196
5723
  });
@@ -5283,7 +5810,7 @@ function einsum(...args) {
5283
5810
  const [b, bidx] = processSingleTensor(operands[j], indices[j], indices[i]);
5284
5811
  indexReduced = indexReduced.filter((idx) => aidx.includes(idx));
5285
5812
  const indexBatch = aidx.filter((idx) => bidx.includes(idx) && !indexReduced.includes(idx));
5286
- const result = dot$1(a, b, {
5813
+ const result = dot(a, b, {
5287
5814
  lhsContractingDims: indexReduced.map((idx) => aidx.indexOf(idx)),
5288
5815
  rhsContractingDims: indexReduced.map((idx) => bidx.indexOf(idx)),
5289
5816
  lhsBatchDims: indexBatch.map((idx) => aidx.indexOf(idx)),
@@ -5311,7 +5838,7 @@ function einsum(...args) {
5311
5838
  * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
5312
5839
  */
5313
5840
  function inner(x, y) {
5314
- return dot$1(fudgeArray(x), fudgeArray(y), {
5841
+ return dot(fudgeArray(x), fudgeArray(y), {
5315
5842
  lhsContractingDims: [-1],
5316
5843
  rhsContractingDims: [-1]
5317
5844
  });
@@ -5344,6 +5871,30 @@ function vecdot(x, y, { axis } = {}) {
5344
5871
  function vdot(x, y) {
5345
5872
  return dot$2(ravel(x), ravel(y));
5346
5873
  }
5874
+ function _convImpl(name, x, y, mode) {
5875
+ if (x.ndim !== 1 || y.ndim !== 1) throw new Error(`${name}: both inputs must be 1D arrays, got ${x.ndim}D and ${y.ndim}D`);
5876
+ let flipOutput = false;
5877
+ if (x.shape[0] < y.shape[0]) {
5878
+ [x, y] = [y, x];
5879
+ if (name === "correlate") flipOutput = true;
5880
+ }
5881
+ if (name === "convolve") y = flip(y);
5882
+ let padding;
5883
+ if (mode === "valid") padding = "VALID";
5884
+ else if (mode === "same") padding = "SAME_LOWER";
5885
+ else if (mode === "full") padding = [[y.shape[0] - 1, y.shape[0] - 1]];
5886
+ else throw new Error(`${name}: invalid mode ${mode}, expected "full", "same", or "valid"`);
5887
+ const z = conv(x.slice(null, null), y.slice(null, null), [1], padding).slice(0, 0);
5888
+ return flipOutput ? flip(z) : z;
5889
+ }
5890
+ /** Convolution of two one-dimensional arrays. */
5891
+ function convolve(x, y, mode = "full") {
5892
+ return _convImpl("convolve", x, y, mode);
5893
+ }
5894
+ /** Correlation of two one dimensional arrays. */
5895
+ function correlate(x, y, mode = "valid") {
5896
+ return _convImpl("correlate", x, y, mode);
5897
+ }
5347
5898
  /**
5348
5899
  * Return a tuple of coordinate matrices from coordinate vectors.
5349
5900
  *
@@ -5352,7 +5903,7 @@ function vdot(x, y) {
5352
5903
  */
5353
5904
  function meshgrid(xs, { indexing } = {}) {
5354
5905
  indexing ??= "xy";
5355
- for (const x of xs) if (x.ndim !== 1) throw new TypeError(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
5906
+ for (const x of xs) if (x.ndim !== 1) throw new Error(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
5356
5907
  if (xs.length <= 1) return xs;
5357
5908
  if (indexing === "xy") {
5358
5909
  const [a, b, ...rest] = xs;
@@ -5371,43 +5922,6 @@ function meshgrid(xs, { indexing } = {}) {
5371
5922
  return xs.map((x, i) => broadcast(x, shape$1, [...range(i), ...range(i + 1, xs.length)]));
5372
5923
  }
5373
5924
  /**
5374
- * Return an array with ones on and below the diagonal and zeros elsewhere.
5375
- *
5376
- * If `k` is provided, it specifies the sub-diagonal on and below which the
5377
- * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
5378
- * `k>0` is above it.
5379
- */
5380
- function tri(n, m, k = 0, { dtype, device } = {}) {
5381
- m ??= n;
5382
- dtype ??= DType.Float32;
5383
- if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
5384
- if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
5385
- if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
5386
- const rows = arange(k, n + k, 1, {
5387
- dtype: DType.Int32,
5388
- device
5389
- });
5390
- const cols = arange(0, m, 1, {
5391
- dtype: DType.Int32,
5392
- device
5393
- });
5394
- return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
5395
- }
5396
- /** Return the lower triangle of an array. Must be of dimension >= 2. */
5397
- function tril(a, k = 0) {
5398
- if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
5399
- a = fudgeArray(a);
5400
- const [n, m] = a.shape.slice(-2);
5401
- return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
5402
- }
5403
- /** Return the upper triangle of an array. Must be of dimension >= 2. */
5404
- function triu(a, k = 0) {
5405
- if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
5406
- a = fudgeArray(a);
5407
- const [n, m] = a.shape.slice(-2);
5408
- return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
5409
- }
5410
- /**
5411
5925
  * Clip (limit) the values in an array.
5412
5926
  *
5413
5927
  * Given an interval, values outside the interval are clipped to the interval
@@ -5431,8 +5945,6 @@ function absolute(x) {
5431
5945
  x = fudgeArray(x);
5432
5946
  return where(less(x.ref, 0), x.ref.mul(-1), x);
5433
5947
  }
5434
- /** @function Alias of `jax.numpy.absolute()`. */
5435
- const abs = absolute;
5436
5948
  /** Return an element-wise indication of sign of the input. */
5437
5949
  function sign(x) {
5438
5950
  x = fudgeArray(x);
@@ -5511,12 +6023,6 @@ const atan2 = jit$1(function atan2$1(y, x) {
5511
6023
  const denom = where(xNeg, y, r.add(x));
5512
6024
  return atan(numer.div(denom)).mul(2);
5513
6025
  });
5514
- /** @function Alias of `jax.numpy.acos()`. */
5515
- const arccos = acos;
5516
- /** @function Alias of `jax.numpy.atan()`. */
5517
- const arctan = atan;
5518
- /** @function Alias of `jax.numpy.atan2()`. */
5519
- const arctan2 = atan2;
5520
6026
  /** Element-wise subtraction, with broadcasting. */
5521
6027
  function subtract(x, y) {
5522
6028
  x = fudgeArray(x);
@@ -5547,8 +6053,6 @@ const fmod = jit$1(function fmod$1(x, y) {
5547
6053
  const remainder = jit$1(function remainder$1(x, y) {
5548
6054
  return mod(mod(x, y.ref).add(y.ref), y);
5549
6055
  });
5550
- /** @function Alias of `jax.numpy.trueDivide()`. */
5551
- const divide = trueDivide;
5552
6056
  /** Round input to the nearest integer towards zero. */
5553
6057
  function trunc(x) {
5554
6058
  return idiv(x, 1);
@@ -5570,9 +6074,9 @@ function ldexp(x1, x2) {
5570
6074
  */
5571
6075
  function frexp(x) {
5572
6076
  x = fudgeArray(x);
5573
- const absx = abs(x.ref);
6077
+ const absx = absolute(x.ref);
5574
6078
  const exponent = where(equal(x.ref, 0), 0, floor(log2(absx)).add(1).astype(DType.Int32));
5575
- const mantissa = divide(x, exp2(exponent.ref.astype(x.dtype)));
6079
+ const mantissa = x.div(exp2(exponent.ref.astype(x.dtype)));
5576
6080
  return [mantissa, exponent];
5577
6081
  }
5578
6082
  /** Calculate `2**p` for all p in the input array. */
@@ -5612,10 +6116,11 @@ const degrees = rad2deg;
5612
6116
  * Computes first array raised to power of second array, element-wise.
5613
6117
  */
5614
6118
  const power = jit$1(function power$1(x1, x2) {
5615
- return exp(log(x1).mul(x2));
6119
+ const x2i = trunc(x2.ref);
6120
+ const shouldBeNaN = multiply(x2.ref.notEqual(x2i.ref), x1.ref.less(0));
6121
+ const resultSign = where(mod(x2i, 2).notEqual(0), where(x1.ref.less(0), -1, 1), 1);
6122
+ return where(shouldBeNaN, nan, exp(log(absolute(x1)).mul(x2)).mul(resultSign));
5616
6123
  });
5617
- /** @function Alias of `jax.numpy.power()`. */
5618
- const pow = power;
5619
6124
  /** @function Calculate the element-wise cube root of the input array. */
5620
6125
  const cbrt = jit$1(function cbrt$1(x) {
5621
6126
  const sgn = where(less(x.ref, 0), -1, 1);
@@ -5681,12 +6186,6 @@ const arccosh = jit$1(function arccosh$1(x) {
5681
6186
  const arctanh = jit$1(function arctanh$1(x) {
5682
6187
  return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
5683
6188
  });
5684
- /** @function Alias of `jax.numpy.arcsinh()`. */
5685
- const asinh = arcsinh;
5686
- /** @function Alias of `jax.numpy.arccosh()`. */
5687
- const acosh = arccosh;
5688
- /** @function Alias of `jax.numpy.arctanh()`. */
5689
- const atanh = arctanh;
5690
6189
  /**
5691
6190
  * Compute the variance of an array.
5692
6191
  *
@@ -5716,6 +6215,26 @@ function var_(x, axis = null, opts) {
5716
6215
  function std(x, axis = null, opts) {
5717
6216
  return sqrt(var_(x, axis, opts));
5718
6217
  }
6218
+ /** Estimate the sample covariance of a set of variables. */
6219
+ function cov(x, y) {
6220
+ x = fudgeArray(x);
6221
+ if (x.ndim === 1) x = x.reshape([1, x.shape[0]]);
6222
+ if (y !== void 0) {
6223
+ y = fudgeArray(y);
6224
+ if (y.ndim === 1) y = y.reshape([1, y.shape[0]]);
6225
+ x = vstack([x, y]);
6226
+ }
6227
+ const [_M, N] = x.shape;
6228
+ x = x.ref.sub(x.mean(1, { keepdims: true }));
6229
+ return dot$1(x.ref, x.transpose()).div(N - 1);
6230
+ }
6231
+ /** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
6232
+ function corrcoef(x, y) {
6233
+ const c = cov(x, y);
6234
+ const variances = diag(c.ref);
6235
+ const norm = sqrt(outer(variances.ref, variances));
6236
+ return c.div(norm);
6237
+ }
5719
6238
  /** Test element-wise for positive or negative infinity, return bool array. */
5720
6239
  function isinf(x) {
5721
6240
  x = fudgeArray(x);
@@ -5745,6 +6264,253 @@ const isfinite = jit$1(function isfinite$1(x) {
5745
6264
  return isnan(x.ref).add(isinf(x)).notEqual(true);
5746
6265
  });
5747
6266
 
6267
+ //#endregion
6268
+ //#region src/library/lax-linalg.ts
6269
var lax_linalg_exports = {};
// Getter table for the public `lax.linalg` namespace. Each entry returns the
// function of the same name defined below; the bundler's `__export` helper
// (from another chunk) presumably installs these as lazy getters — confirm.
const laxLinalgNamespaceGetters = {
	cholesky: () => cholesky,
	triangularSolve: () => triangularSolve
};
__export(lax_linalg_exports, laxLinalgNamespaceGetters);
6274
/**
 * Compute the Cholesky decomposition of a symmetric positive-definite matrix.
 *
 * The factorization of a matrix `A` is:
 *
 * - A = L @ L^T (for upper=false, default)
 * - A = U^T @ U (for upper=true)
 *
 * where `L` is a lower-triangular matrix and `U = L^T` is an upper-triangular
 * matrix. The input matrix must be symmetric and positive-definite.
 *
 * @example
 * ```ts
 * import { lax, numpy as np } from "@jax-js/jax";
 *
 * const x = np.array([[2., 1.], [1., 2.]]);
 *
 * // Lower Cholesky factorization (default):
 * const L = lax.linalg.cholesky(x);
 * // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
 *
 * // Upper Cholesky factorization:
 * const U = lax.linalg.cholesky(x, { upper: true });
 * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
 * ```
 */
function cholesky(a, { upper = false } = {}) {
	const lowerFactor = cholesky$2(a);
	if (!upper) return lowerFactor;
	// U = L^T: swap the two trailing (matrix) axes.
	return moveaxis$1(lowerFactor, -2, -1);
}
6304
/**
 * Solve a triangular linear system.
 *
 * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
 * where `a` is a triangular matrix.
 *
 * Options:
 * - `leftSide`: whether `a` multiplies `x` from the left (default `false`).
 * - `lower`: `a` is lower-triangular when true, upper-triangular otherwise.
 * - `transposeA`: solve with `a^T` in place of `a`.
 * - `unitDiagonal`: forwarded to the solver (presumably treats the diagonal
 *   of `a` as all ones, as in JAX).
 *
 * @example
 * ```ts
 * import { lax, numpy as np } from "@jax-js/jax";
 *
 * const L = np.array([[2., 0.], [1., 3.]]);
 * const b = np.array([4., 7.]).reshape([2, 1]);
 *
 * // Solve L @ x = b
 * const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
 * // x = [[2.], [5./3.]]
 * ```
 */
function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = false, unitDiagonal = false } = {}) {
	a = fudgeArray(a);
	b = fudgeArray(b);
	// NOTE(review): the primitive appears to solve `a @ x_i = b_i` for every row
	// `b_i` of `b` (i.e. `x @ a^T = b`); this is the reading under which the
	// left-side example above holds — confirm against `triangularSolve$1`.
	if (leftSide) {
		// a @ x = b: solve the columns of b, i.e. the rows of b^T.
		b = moveaxis$1(b, -2, -1);
	} else {
		// x @ a = b  <=>  a^T @ x_i = b_i for each row, so solve against a^T.
		transposeA = !transposeA;
	}
	if (transposeA) {
		a = moveaxis$1(a, -2, -1);
		// Transposing moves the data to the other triangle, so the `lower` flag
		// handed to the primitive must describe the transposed matrix.
		lower = !lower;
	}
	let x = triangularSolve$1(a, b, {
		lower,
		unitDiagonal
	});
	if (leftSide) x = moveaxis$1(x, -2, -1);
	return x;
}
6335
+
6336
+ //#endregion
6337
+ //#region src/library/lax.ts
6338
var lax_exports = {};
// Getter table for the public `lax` namespace (alphabetical, as before);
// `linalg` exposes the nested `lax.linalg` namespace object. The bundler's
// `__export` helper presumably installs these as lazy getters — confirm.
const laxNamespaceGetters = {
	conv: () => conv,
	convGeneralDilated: () => convGeneralDilated,
	convWithGeneralPadding: () => convWithGeneralPadding,
	dot: () => dot,
	erf: () => erf,
	erfc: () => erfc,
	linalg: () => lax_linalg_exports,
	reduceWindow: () => reduceWindow,
	stopGradient: () => stopGradient$1
};
__export(lax_exports, laxNamespaceGetters);
6350
/**
 * General dot product/contraction operator.
 *
 * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
 * `jax.numpy.tensordot()`, and `jax.numpy.einsum()` where possible.
 *
 * Axes of each operand are split into batch axes, contracting axes (summed
 * over), and the remaining "free" axes. Both operands are laid out as
 * `[...batch, ...lhsFree, ...rhsFree]` (with size-1 placeholders for the other
 * operand's free axes), so the result presumably has that shape — assuming
 * `dot$2` contracts the trailing axis.
 *
 * @param lhs - Left operand.
 * @param rhs - Right operand.
 * @param opts.lhsContractingDims, opts.rhsContractingDims - Axes to contract,
 *   paired positionally; paired axes must have equal sizes.
 * @param opts.lhsBatchDims, opts.rhsBatchDims - Batch axes, paired
 *   positionally; paired axes must have equal sizes.
 * @throws {Error} On mismatched dim-list lengths, contracting/batch overlap, or
 *   mismatched sizes along paired batch or contracting axes.
 */
function dot(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
	if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
	if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
	lc = lc.map((a) => checkAxis(a, lhs.ndim));
	rc = rc.map((a) => checkAxis(a, rhs.ndim));
	lb = lb.map((a) => checkAxis(a, lhs.ndim));
	rb = rb.map((a) => checkAxis(a, rhs.ndim));
	if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
	if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
	// Validate every size up front, before any transpose consumes the inputs.
	const lhsBatchShape = lb.map((a) => lhs.shape[a]);
	const rhsBatchShape = rb.map((a) => rhs.shape[a]);
	// Batch axes are combined by broadcasting below, so without this check a
	// size-1 batch axis would silently broadcast instead of raising an error.
	if (!deepEqual(lhsBatchShape, rhsBatchShape)) throw new Error(`dot: batch dims shapes mismatch, got ${JSON.stringify(lhsBatchShape)} and ${JSON.stringify(rhsBatchShape)}`);
	const dotShapeX = lc.map((a) => lhs.shape[a]);
	const dotShapeY = rc.map((a) => rhs.shape[a]);
	if (!deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
	const lf = range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
	const rf = range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
	const lhsShape = [
		...lhsBatchShape,
		...lf.map((a) => lhs.shape[a]),
		...rep(rf.length, 1)
	];
	const rhsShape = [
		...rhsBatchShape,
		...rep(lf.length, 1),
		...rf.map((a) => rhs.shape[a])
	];
	const lhs2 = lhs.transpose([
		...lb,
		...lf,
		...lc
	]);
	const rhs2 = rhs.transpose([
		...rb,
		...rf,
		...rc
	]);
	// Nothing to contract: a plain broadcasted (outer) product.
	if (lc.length === 0) return mul(lhs2.reshape(lhsShape), rhs2.reshape(rhsShape));
	// Flatten all contracted axes into a single trailing axis for `dot$2`.
	const contractedSize = prod(dotShapeX);
	return dot$2(lhs2.reshape([...lhsShape, contractedSize]), rhs2.reshape([...rhsShape, contractedSize]));
}
6401
/**
 * Convert a string padding mode into explicit `[low, high]` pads per spatial axis.
 *
 * - `"VALID"`: no padding.
 * - `"SAME"`: pad so the output size is `ceil(in / stride)`; when the total
 *   padding is odd, the extra element goes at the high end.
 * - `"SAME_LOWER"`: like `"SAME"`, but the extra element goes at the low end.
 *
 * The mode name is matched case-insensitively; unknown modes throw.
 */
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
	const padType = padding.toUpperCase();
	if (padType === "VALID") return rep(inShape.length, [0, 0]);
	if (padType !== "SAME" && padType !== "SAME_LOWER") throw new Error(`Unknown padding type: ${padType}`);
	const outShape = inShape.map((inSize, axis) => Math.ceil(inSize / strides[axis]));
	const totals = zipn(outShape, strides, filterShape, dilation, inShape).map(([outSize, stride, kernel, dil, inSize]) => {
		// Input extent covered by the strided, dilated windows.
		const covered = (outSize - 1) * stride + 1 + (kernel - 1) * dil;
		return Math.max(0, covered - inSize);
	});
	const extraAtLow = padType === "SAME_LOWER";
	return totals.map((total) => {
		const small = total >> 1;
		const large = total - small;
		return extraAtLow ? [large, small] : [small, large];
	});
}
6415
/**
 * General n-dimensional convolution operator, with optional dilation.
 *
 * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
 * function in JAX, which wraps XLA's general convolution operator.
 *
 * - `lhs` has shape `[N, C_in, ...spatial]` and `rhs` has shape
 *   `[C_out, C_in / featureGroupCount, ...kernel]`.
 * - `padding` is either explicit `[low, high]` pairs per spatial axis or one of
 *   `"VALID"`, `"SAME"`, `"SAME_LOWER"` (case-insensitive). String padding is
 *   rejected when `lhsDilation` has any entry other than 1.
 * - `featureGroupCount` (default 1) splits the channels into that many
 *   independent groups (grouped convolution). It must be a positive integer
 *   dividing both `C_in` and `C_out`.
 *
 * @throws {Error} On operands with fewer than 2 dimensions, unsupported string
 *   padding, or an invalid `featureGroupCount`.
 */
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
	if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
	if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
	if (typeof padding === "string") {
		if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
		// Resolve the padding mode from the spatial dims (after batch/channel axes).
		padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? rep(rhs.ndim - 2, 1), padding);
	}
	if (featureGroupCount === 1) {
		return conv$1(lhs, rhs, {
			strides: windowStrides,
			padding,
			lhsDilation,
			rhsDilation
		});
	}
	const G = featureGroupCount;
	if (!Number.isInteger(G) || G < 1) throw new Error(`featureGroupCount must be a positive integer, got ${G}`);
	const [N, C_in, ...xs] = lhs.shape;
	const [C_out, C_in_per_group, ...ks] = rhs.shape;
	if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
	if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
	if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
	// Split channels into groups with the group axis in front:
	// lhs -> [G, N, C_in/G, ...xs], rhs -> [G, C_out/G, C_in/G, ...ks].
	const lhsGrouped = moveaxis(lhs.reshape([
		N,
		G,
		C_in / G,
		...xs
	]), 1, 0);
	const rhsGrouped = rhs.reshape([
		G,
		C_out / G,
		C_in_per_group,
		...ks
	]);
	// One leading vmap dim runs an independent convolution per group,
	// giving [G, N, C_out/G, ...ys].
	const result = conv$1(lhsGrouped, rhsGrouped, {
		vmapDims: 1,
		strides: windowStrides,
		padding,
		lhsDilation,
		rhsDilation
	});
	const ys = result.shape.slice(3);
	// Move batch back to the front and merge the groups into the channel axis.
	return moveaxis(result, 0, 1).reshape([
		N,
		C_out,
		...ys
	]);
}
6470
/**
 * Convenience wrapper around `convGeneralDilated` that takes the two dilation
 * factors as positional arguments instead of an options object.
 */
function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
	const dilationOptions = { lhsDilation, rhsDilation };
	return convGeneralDilated(lhs, rhs, windowStrides, padding, dilationOptions);
}
6477
/**
 * Convenience wrapper around `convGeneralDilated` with no dilation and no
 * feature grouping (all options left at their defaults).
 */
function conv(lhs, rhs, windowStrides, padding) {
	const output = convGeneralDilated(lhs, rhs, windowStrides, padding);
	return output;
}
6481
/**
 * Reduce a computation over sliding windows of `operand`.
 *
 * No padding is applied (there is no padding argument). Windows are
 * extracted by the `Pool` primitive from `windowDimensions` and
 * `windowStrides` (default: stride 1 along every window axis). `computation`
 * is then vmapped over axis 0 once per axis of `operand` and applied to the
 * pooled array, so it is meant to reduce a single window.
 *
 * NOTE(review): assumes `Pool` emits the `operand.ndim` leading (output
 * position) axes followed by the window axes — confirm against the primitive.
 *
 * @param operand - Array to reduce; needs at least `windowDimensions.length` axes.
 * @param computation - Function reducing one window.
 * @param windowDimensions - Window size along each windowed axis.
 * @param windowStrides - Optional stride per windowed axis; same length as `windowDimensions`.
 * @throws {Error} If `operand` has fewer axes than the window, or the stride
 *   and window lengths differ.
 */
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
	if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
	const strides = windowStrides ?? rep(windowDimensions.length, 1);
	if (strides.length !== windowDimensions.length) throw new Error(`Window strides length ${strides.length} != window ${windowDimensions.length}`);
	let reducer = computation;
	for (let i = 0; i < operand.ndim; i++) reducer = vmap$1(reducer, 0);
	return reducer(bind1(Primitive.Pool, [operand], {
		window: windowDimensions,
		strides
	}));
}
6491
/**
 * The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`.
 *
 * Public `lax` entry point; delegates to the internal `erf$1`.
 */
function erf(x) {
	const value = erf$1(x);
	return value;
}
6495
/**
 * The complementary error function: `erfc(x) = 1 - erf(x)`.
 *
 * Prefer this over computing `1 - erf(x)` when `x` is large: there `erf(x)`
 * is very close to 1 and the subtraction loses precision.
 *
 * Public `lax` entry point; delegates to the internal `erfc$1`.
 */
function erfc(x) {
	const value = erfc$1(x);
	return value;
}
6504
/**
 * Stops gradient computation.
 *
 * Acts as the identity on values, but blocks gradients from flowing through
 * it in both forward-mode and reverse-mode automatic differentiation.
 *
 * Public `lax.stopGradient` entry point; delegates to the internal `stopGradient`.
 */
function stopGradient$1(x) {
	const detached = stopGradient(x);
	return detached;
}
6513
+
5748
6514
  //#endregion
5749
6515
  //#region src/library/nn.ts
5750
6516
  var nn_exports = {};
@@ -5753,6 +6519,10 @@ __export(nn_exports, {
5753
6519
  elu: () => elu,
5754
6520
  gelu: () => gelu,
5755
6521
  glu: () => glu,
6522
+ hardSigmoid: () => hardSigmoid,
6523
+ hardSilu: () => hardSilu,
6524
+ hardSwish: () => hardSilu,
6525
+ hardTanh: () => hardTanh,
5756
6526
  identity: () => identity,
5757
6527
  leakyRelu: () => leakyRelu,
5758
6528
  logSigmoid: () => logSigmoid,
@@ -5763,14 +6533,17 @@ __export(nn_exports, {
5763
6533
  oneHot: () => oneHot,
5764
6534
  relu: () => relu,
5765
6535
  relu6: () => relu6,
6536
+ selu: () => selu,
5766
6537
  sigmoid: () => sigmoid,
5767
6538
  silu: () => silu,
5768
6539
  softSign: () => softSign,
5769
6540
  softmax: () => softmax,
5770
6541
  softplus: () => softplus,
6542
+ sparsePlus: () => sparsePlus,
6543
+ sparseSigmoid: () => sparseSigmoid,
5771
6544
  squareplus: () => squareplus,
5772
6545
  standardize: () => standardize,
5773
- swish: () => swish
6546
+ swish: () => silu
5774
6547
  });
5775
6548
  /**
5776
6549
  * Rectified Linear Unit (ReLU) activation function:
@@ -5805,6 +6578,28 @@ function softplus(x) {
5805
6578
  return log(exp(x).add(1));
5806
6579
  }
5807
6580
  /**
6581
+ * @function
6582
+ * Sparse plus function:
6583
+ *
6584
+ * - When `x <= -1`: `0`
6585
+ * - When `-1 < x < 1`: `(x+1)**2 / 4`
6586
+ * - When `x >= 1`: `x`
6587
+ */
6588
+ const sparsePlus = jit$1((x) => {
6589
+ return where(x.ref.lessEqual(-1), 0, where(x.ref.less(1), square(x.ref.add(1)).mul(.25), x));
6590
+ });
6591
/**
 * @function
 * Sparse sigmoid activation function: a linear ramp from 0 to 1.
 *
 * - `x <= -1`: `0`
 * - `-1 < x < 1`: `(x + 1) / 2`
 * - `x >= 1`: `1`
 */
const sparseSigmoid = jit$1((x) => {
	const ramp = x.add(1).mul(.5);
	return clip(ramp, 0, 1);
});
6602
+ /**
5808
6603
  * Soft-sign activation function, computed element-wise:
5809
6604
  * `softsign(x) = x / (|x| + 1)`.
5810
6605
  */
@@ -5826,17 +6621,6 @@ const silu = jit$1(function silu$1(x) {
5826
6621
  return x.ref.mul(sigmoid(x));
5827
6622
  });
5828
6623
  /**
5829
- * @function
5830
- * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
5831
- * Swish, computed element-wise:
5832
- * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
5833
- *
5834
- * `swish()` and `silu()` are both aliases for the same function.
5835
- *
5836
- * Reference: https://en.wikipedia.org/wiki/Swish_function
5837
- */
5838
- const swish = silu;
5839
- /**
5840
6624
  * Log-sigmoid activation function, computed element-wise:
5841
6625
  * `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
5842
6626
  */
@@ -5853,6 +6637,19 @@ function leakyRelu(x, negativeSlope = .01) {
5853
6637
  x = fudgeArray(x);
5854
6638
  return where(less(x.ref, 0), x.ref.mul(negativeSlope), x);
5855
6639
  }
6640
/**
 * Hard sigmoid activation function: `relu6(x + 3) / 6`, a piecewise-linear
 * stand-in for the sigmoid.
 */
function hardSigmoid(x) {
	const shifted = add(x, 3);
	return relu6(shifted).mul(1 / 6);
}
6644
/**
 * Hard SiLU activation function (also exported as `hardSwish`):
 * `x * hardSigmoid(x)`.
 */
function hardSilu(x) {
	const input = fudgeArray(x);
	const gate = hardSigmoid(input.ref);
	return input.mul(gate);
}
6649
/** Hard tanh activation function: clamps `x` to the interval `[-1, 1]`. */
function hardTanh(x) {
	const lowerBound = -1;
	const upperBound = 1;
	return clip(x, lowerBound, upperBound);
}
5856
6653
  /**
5857
6654
  * Exponential linear unit activation function.
5858
6655
  *
@@ -5875,6 +6672,20 @@ function celu(x, alpha = 1) {
5875
6672
  }
5876
6673
  /**
5877
6674
  * @function
6675
+ * Scaled exponential linear unit activation.
6676
+ *
6677
+ * Computes the element-wise function:
6678
+ * `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
6679
+ *
6680
+ * Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
6681
+ */
6682
+ const selu = jit$1(function selu$1(x) {
6683
+ const alpha = 1.6732632423543772;
6684
+ const lambda = 1.0507009873554805;
6685
+ return where(x.ref.less(0), expm1(x.ref).mul(alpha), x).mul(lambda);
6686
+ });
6687
+ /**
6688
+ * @function
5878
6689
  * Gaussian error linear unit (GELU) activation function.
5879
6690
  *
5880
6691
  * This is computed element-wise. There are two variants depending on whether
@@ -5968,22 +6779,22 @@ function logSoftmax(x, axis = -1) {
5968
6779
  *
5969
6780
  * Reference: https://en.wikipedia.org/wiki/LogSumExp
5970
6781
  */
5971
function logsumexp(x, axis = null, opts) {
	x = fudgeArray(x);
	axis = normalizeAxis(axis, x.ndim);
	if (axis.length === 0) return x;
	// Shift by the maximum for numerical stability. A non-finite maximum (every
	// entry -inf, e.g. a fully masked row, or any +inf entry) would make
	// `x - max` NaN, so such maxima are replaced by 0, as JAX does. The shift is
	// wrapped in stopGradient because it cancels out mathematically.
	const rawMax = max(x.ref, axis, { keepdims: true });
	const xMax = stopGradient(where(isfinite(rawMax.ref), rawMax, 0));
	const shifted = x.sub(xMax.ref);
	const result = xMax.add(log(exp(shifted).sum(axis, { keepdims: true })));
	// `opts.keepdims` keeps the reduced axes as size-1 dims, like other reductions.
	return opts?.keepdims ? result : squeeze(result, axis);
}
5980
6791
  /** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
5981
- function logmeanexp(x, axis = null) {
6792
+ function logmeanexp(x, axis = null, opts) {
5982
6793
  x = fudgeArray(x);
5983
6794
  axis = normalizeAxis(axis, x.ndim);
5984
6795
  if (axis.length === 0) return x;
5985
6796
  const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
5986
- return logsumexp(x, axis).sub(Math.log(n));
6797
+ return logsumexp(x, axis, opts).sub(Math.log(n));
5987
6798
  }
5988
6799
  /**
5989
6800
  * Standardizes input to zero mean and unit variance.
@@ -6028,8 +6839,11 @@ var random_exports = {};
6028
6839
  __export(random_exports, {
6029
6840
  bernoulli: () => bernoulli,
6030
6841
  bits: () => bits,
6842
+ cauchy: () => cauchy,
6031
6843
  exponential: () => exponential,
6844
+ gumbel: () => gumbel,
6032
6845
  key: () => key,
6846
+ laplace: () => laplace,
6033
6847
  normal: () => normal,
6034
6848
  split: () => split,
6035
6849
  uniform: () => uniform
@@ -6088,6 +6902,16 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
6088
6902
  }
6089
6903
  /**
6090
6904
  * @function
6905
+ * Sample from a Cauchy distribution with location 0 and scale 1.
6906
+ *
6907
+ * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
6908
+ */
6909
+ const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
6910
+ const u = uniform(key$1, shape$1);
6911
+ return tan(u.sub(.5).mul(Math.PI));
6912
+ }, { staticArgnums: [1] });
6913
+ /**
6914
+ * @function
6091
6915
  * Sample exponential random values according to `p(x) = exp(-x)`.
6092
6916
  */
6093
6917
  const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
@@ -6096,6 +6920,30 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
6096
6920
  }, { staticArgnums: [1] });
6097
6921
  /**
6098
6922
  * @function
6923
+ * Sample from a Gumbel distribution with location 0 and scale 1.
6924
+ *
6925
+ * Uses inverse transform sampling: `x = -log(-log(u))` where u ~ Uniform(0, 1).
6926
+ */
6927
+ const gumbel = jit$1(function gumbel$1(key$1, shape$1 = []) {
6928
+ const u = uniform(key$1, shape$1);
6929
+ return negative(log(negative(log1p(negative(u)))));
6930
+ }, { staticArgnums: [1] });
6931
/**
 * @function
 * Sample from a Laplace distribution with location 0 and scale 1.
 *
 * Inverse transform sampling: the CDF is
 * `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`; inverting it for
 * `u ~ Uniform(0, 1)` gives `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
 */
const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
	const centered = uniform(key$1, shape$1).sub(.5);
	const direction = sign(centered.ref);
	// -log(1 - 2|u - 0.5|), using log1p for the logarithm.
	const magnitude = log1p(absolute(centered).mul(-2)).mul(-1);
	return direction.mul(magnitude);
}, { staticArgnums: [1] });
6945
+ /**
6946
+ * @function
6099
6947
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
6100
6948
  *
6101
6949
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
@@ -6204,11 +7052,6 @@ const valueAndGrad = valueAndGrad$1;
6204
7052
  */
6205
7053
  const jacrev = jacrev$1;
6206
7054
  /**
6207
- * @function
6208
- * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
6209
- */
6210
- const jacobian = jacrev;
6211
- /**
6212
7055
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
6213
7056
  *
6214
7057
  * This can be used to wait for the results of an intermediate computation to
@@ -6243,5 +7086,4 @@ async function devicePut(x, device) {
6243
7086
  }
6244
7087
 
6245
7088
  //#endregion
6246
- export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
6247
- //# sourceMappingURL=index.js.map
7089
+ export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };