@jax-js/jax 0.1.3 → 0.1.4

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the differences between those versions as they appear in their public registry.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-BY8wlLEl.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-tngXtWe4.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -331,6 +331,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
331
331
  Primitive$1["Mul"] = "mul";
332
332
  Primitive$1["Idiv"] = "idiv";
333
333
  Primitive$1["Mod"] = "mod";
334
+ Primitive$1["Min"] = "min";
335
+ Primitive$1["Max"] = "max";
334
336
  Primitive$1["Neg"] = "neg";
335
337
  Primitive$1["Reciprocal"] = "reciprocal";
336
338
  Primitive$1["Floor"] = "floor";
@@ -338,7 +340,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
338
340
  Primitive$1["StopGradient"] = "stop_gradient";
339
341
  Primitive$1["Cast"] = "cast";
340
342
  Primitive$1["Bitcast"] = "bitcast";
341
- Primitive$1["RandomBits"] = "random_bits";
342
343
  Primitive$1["Sin"] = "sin";
343
344
  Primitive$1["Cos"] = "cos";
344
345
  Primitive$1["Asin"] = "asin";
@@ -348,8 +349,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
348
349
  Primitive$1["Erf"] = "erf";
349
350
  Primitive$1["Erfc"] = "erfc";
350
351
  Primitive$1["Sqrt"] = "sqrt";
351
- Primitive$1["Min"] = "min";
352
- Primitive$1["Max"] = "max";
353
352
  Primitive$1["Reduce"] = "reduce";
354
353
  Primitive$1["Dot"] = "dot";
355
354
  Primitive$1["Conv"] = "conv";
@@ -357,14 +356,19 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
357
356
  Primitive$1["PoolTranspose"] = "pool_transpose";
358
357
  Primitive$1["Compare"] = "compare";
359
358
  Primitive$1["Where"] = "where";
359
+ Primitive$1["RandomBits"] = "random_bits";
360
+ Primitive$1["Gather"] = "gather";
360
361
  Primitive$1["Transpose"] = "transpose";
361
362
  Primitive$1["Broadcast"] = "broadcast";
362
363
  Primitive$1["Reshape"] = "reshape";
363
364
  Primitive$1["Flip"] = "flip";
364
365
  Primitive$1["Shrink"] = "shrink";
365
366
  Primitive$1["Pad"] = "pad";
366
- Primitive$1["Gather"] = "gather";
367
- Primitive$1["JitCall"] = "jit_call";
367
+ Primitive$1["Sort"] = "sort";
368
+ Primitive$1["Argsort"] = "argsort";
369
+ Primitive$1["TriangularSolve"] = "triangular_solve";
370
+ Primitive$1["Cholesky"] = "cholesky";
371
+ Primitive$1["Jit"] = "jit";
368
372
  return Primitive$1;
369
373
  }({});
370
374
  let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
@@ -386,6 +390,12 @@ function idiv(x, y) {
386
390
  function mod(x, y) {
387
391
  return bind1(Primitive.Mod, [x, y]);
388
392
  }
393
+ function min$1(x, y) {
394
+ return bind1(Primitive.Min, [x, y]);
395
+ }
396
+ function max$1(x, y) {
397
+ return bind1(Primitive.Max, [x, y]);
398
+ }
389
399
  function neg(x) {
390
400
  return bind1(Primitive.Neg, [x]);
391
401
  }
@@ -407,12 +417,6 @@ function cast(x, dtype) {
407
417
  function bitcast(x, dtype) {
408
418
  return bind1(Primitive.Bitcast, [x], { dtype });
409
419
  }
410
- function randomBits(k0, k1, shape$1, mode = "xor") {
411
- return bind1(Primitive.RandomBits, [k0, k1], {
412
- shape: shape$1,
413
- mode
414
- });
415
- }
416
420
  function sin$1(x) {
417
421
  return bind1(Primitive.Sin, [x]);
418
422
  }
@@ -440,12 +444,6 @@ function erfc$1(x) {
440
444
  function sqrt$1(x) {
441
445
  return bind1(Primitive.Sqrt, [x]);
442
446
  }
443
- function min$1(x, y) {
444
- return bind1(Primitive.Min, [x, y]);
445
- }
446
- function max$1(x, y) {
447
- return bind1(Primitive.Max, [x, y]);
448
- }
449
447
  function reduce(x, op, axis = null, opts) {
450
448
  if (!AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
451
449
  axis = normalizeAxis(axis, ndim$1(x));
@@ -501,6 +499,23 @@ function where$1(cond, x, y) {
501
499
  y
502
500
  ]);
503
501
  }
502
+ function randomBits(k0, k1, shape$1, mode = "xor") {
503
+ return bind1(Primitive.RandomBits, [k0, k1], {
504
+ shape: shape$1,
505
+ mode
506
+ });
507
+ }
508
+ function gather(x, indices, axis, outDim) {
509
+ if (indices.length === 0) throw new Error("gather() requires at least one index");
510
+ if (!Array.isArray(axis) || axis.length !== indices.length) throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
511
+ axis = axis.map((a) => checkAxis(a, ndim$1(x)));
512
+ if (new Set(axis).size !== axis.length) throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
513
+ outDim = checkAxis(outDim, ndim$1(x) - axis.length + 1);
514
+ return bind1(Primitive.Gather, [x, ...indices], {
515
+ axis,
516
+ outDim
517
+ });
518
+ }
504
519
  function transpose$1(x, perm) {
505
520
  perm = perm ? perm.map((a) => checkAxis(a, ndim$1(x))) : range(ndim$1(x)).reverse();
506
521
  if (!isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
@@ -550,16 +565,27 @@ function pad$1(x, width) {
550
565
  } else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
551
566
  return bind1(Primitive.Pad, [x], { width });
552
567
  }
553
- function gather(x, indices, axis, outDim) {
554
- if (indices.length === 0) throw new Error("gather() requires at least one index");
555
- if (!Array.isArray(axis) || axis.length !== indices.length) throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
556
- axis = axis.map((a) => checkAxis(a, ndim$1(x)));
557
- if (new Set(axis).size !== axis.length) throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
558
- outDim = checkAxis(outDim, ndim$1(x) - axis.length + 1);
559
- return bind1(Primitive.Gather, [x, ...indices], {
560
- axis,
561
- outDim
562
- });
568
+ function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
569
+ if (lower) {
570
+ a = flip$1(a, [-2, -1]);
571
+ b = flip$1(b, [-1]);
572
+ }
573
+ let x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
574
+ if (lower) x = flip$1(x, [-1]);
575
+ return x;
576
+ }
577
+ function cholesky$2(x) {
578
+ return bind1(Primitive.Cholesky, [x]);
579
+ }
580
+ function sort$1(x) {
581
+ const nd = ndim$1(x);
582
+ if (nd === 0) throw new Error("sort: requires at least 1D input");
583
+ return bind1(Primitive.Sort, [x]);
584
+ }
585
+ function argsort$1(x) {
586
+ const nd = ndim$1(x);
587
+ if (nd === 0) throw new Error("argsort: requires at least 1D input");
588
+ return bind(Primitive.Argsort, [x]);
563
589
  }
564
590
  function bind1(prim, args, params = {}) {
565
591
  const [results] = bind(prim, args, params);
@@ -722,7 +748,7 @@ var Tracer = class Tracer {
722
748
  if (isFloatDtype(this.dtype)) return this.mul(reciprocal$1(other));
723
749
  return idiv(this, other);
724
750
  }
725
- /** Return specified diagonals. See `numpy.diagonal` for full docs. */
751
+ /** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
726
752
  diagonal(offset = 0, axis1 = 0, axis2 = 1) {
727
753
  if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
728
754
  if (offset < 0) return this.diagonal(-offset, axis2, axis1);
@@ -775,6 +801,34 @@ var Tracer = class Tracer {
775
801
  this.dispose();
776
802
  }
777
803
  /**
804
+ * Return a sorted copy of an array in ascending order.
805
+ *
806
+ * See `jax.numpy.sort` for full docs.
807
+ */
808
+ sort(axis = -1) {
809
+ axis = checkAxis(axis, this.ndim);
810
+ if (this.shape[axis] <= 1) return this;
811
+ if (axis === this.ndim - 1) return sort$1(this);
812
+ const perm = range(this.ndim);
813
+ perm.splice(axis, 1);
814
+ perm.push(axis);
815
+ return sort$1(this.transpose(perm)).transpose(invertPermutation(perm));
816
+ }
817
+ /**
818
+ * Return the indices that would sort an array. This may not be a stable
819
+ * sorting algorithm; it need not preserve order of indices in ties.
820
+ *
821
+ * See `jax.numpy.argsort` for full docs.
822
+ */
823
+ argsort(axis = -1) {
824
+ axis = checkAxis(axis, this.ndim);
825
+ if (axis === this.ndim - 1) return argsort$1(this)[1];
826
+ const perm = range(this.ndim);
827
+ perm.splice(axis, 1);
828
+ perm.push(axis);
829
+ return argsort$1(this.transpose(perm))[1].transpose(invertPermutation(perm));
830
+ }
831
+ /**
778
832
  * Slice an array along one or more axes.
779
833
  *
780
834
  * This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
@@ -891,6 +945,9 @@ var ShapedArray = class ShapedArray {
891
945
  get ndim() {
892
946
  return this.shape.length;
893
947
  }
948
+ get size() {
949
+ return prod(this.shape);
950
+ }
894
951
  toString() {
895
952
  return `${this.dtype}[${this.shape.join(",")}]`;
896
953
  }
@@ -1186,13 +1243,13 @@ var Jaxpr = class Jaxpr {
1186
1243
  }
1187
1244
  return new Jaxpr(this.inBinders, liveEqns.reverse(), outs);
1188
1245
  }
1189
- /** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
1246
+ /** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
1190
1247
  flatten() {
1191
- if (!this.eqns.some((eqn) => eqn.primitive === Primitive.JitCall)) return this;
1248
+ if (!this.eqns.some((eqn) => eqn.primitive === Primitive.Jit)) return this;
1192
1249
  const newEqns = [];
1193
1250
  const varMap = /* @__PURE__ */ new Map();
1194
1251
  const varMapF = (x) => x instanceof Var ? varMap.get(x) ?? x : x;
1195
- for (const eqn of this.eqns) if (eqn.primitive === Primitive.JitCall) {
1252
+ for (const eqn of this.eqns) if (eqn.primitive === Primitive.Jit) {
1196
1253
  const jaxpr = eqn.params.jaxpr.flatten();
1197
1254
  const translation = /* @__PURE__ */ new Map();
1198
1255
  const translationF = (x) => x instanceof Var ? translation.get(x) : x;
@@ -1293,19 +1350,48 @@ function evalJaxpr(jaxpr, args) {
1293
1350
  function jaxprAsFun(jaxpr) {
1294
1351
  return (...args) => evalJaxpr(jaxpr, args);
1295
1352
  }
1353
+ /** Jaxpr with a collection of associated, traced constants. */
1354
+ var ClosedJaxpr = class ClosedJaxpr {
1355
+ constructor(jaxpr, consts) {
1356
+ this.jaxpr = jaxpr;
1357
+ this.consts = consts;
1358
+ }
1359
+ /** String representation of this Jaxpr. */
1360
+ toString() {
1361
+ return this.jaxpr.toString();
1362
+ }
1363
+ /** Apply a function to the underlying Jaxpr. */
1364
+ mapJaxpr(f) {
1365
+ return new ClosedJaxpr(f(this.jaxpr), this.consts);
1366
+ }
1367
+ /** Dispose of the constants in this Jaxpr. */
1368
+ dispose() {
1369
+ for (const c of this.consts) c.dispose();
1370
+ }
1371
+ };
1296
1372
  /** Tracer that records its operations to dynamically construct a Jaxpr. */
1297
1373
  var JaxprTracer = class extends Tracer {
1374
+ #rc;
1298
1375
  constructor(trace$1, aval) {
1299
1376
  super(trace$1);
1300
1377
  this.aval = aval;
1378
+ this.#rc = 1;
1301
1379
  }
1302
1380
  toString() {
1303
1381
  return `JaxprTracer(${this.aval.toString()})`;
1304
1382
  }
1305
1383
  get ref() {
1384
+ if (this.#rc <= 0) throw new UseAfterFreeError(this);
1385
+ this.#rc++;
1306
1386
  return this;
1307
1387
  }
1308
- dispose() {}
1388
+ dispose() {
1389
+ if (this.#rc <= 0) throw new UseAfterFreeError(this);
1390
+ this.#rc--;
1391
+ }
1392
+ trackLiftedConstant() {
1393
+ this.#rc++;
1394
+ }
1309
1395
  };
1310
1396
  /** Analogous to the 'DynamicJaxprTrace' class in JAX. */
1311
1397
  var JaxprTrace = class extends Trace {
@@ -1318,17 +1404,24 @@ var JaxprTrace = class extends Trace {
1318
1404
  }
1319
1405
  /** Register a constant / literal in this Jaxpr. */
1320
1406
  getOrMakeConstTracer(val) {
1407
+ if (!(val instanceof Tracer)) val = pureArray(val);
1321
1408
  let tracer = this.builder.constTracers.get(val);
1322
1409
  if (tracer === void 0) {
1323
1410
  tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
1324
- this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
1411
+ this.builder.addConst(tracer, val);
1412
+ } else {
1413
+ val.dispose();
1414
+ tracer.trackLiftedConstant();
1325
1415
  }
1326
1416
  return tracer;
1327
1417
  }
1328
1418
  pure = this.getOrMakeConstTracer;
1329
1419
  lift = this.getOrMakeConstTracer;
1330
1420
  processPrimitive(primitive, tracers, params) {
1331
- const avalsIn = tracers.map((t) => t.aval);
1421
+ const avalsIn = tracers.map((t) => {
1422
+ t.dispose();
1423
+ return t.aval;
1424
+ });
1332
1425
  const avalsOut = abstractEvalRules[primitive](avalsIn, params);
1333
1426
  const outTracers = avalsOut.map((aval) => this.builder.newTracer(this, aval));
1334
1427
  this.builder.addEqn(new JaxprEqn(primitive, tracers.map((t) => this.builder.getVar(t)), params, outTracers.map((t) => this.builder.addVar(t))));
@@ -1371,20 +1464,17 @@ var JaxprBuilder = class {
1371
1464
  return v;
1372
1465
  }
1373
1466
  build(inTracers, outTracers) {
1374
- let [constVars, consts] = unzip2(this.constVals.entries());
1467
+ const [constVars, consts] = unzip2(this.constVals.entries());
1375
1468
  const t2v = this.getVar.bind(this);
1376
1469
  const inBinders = [...constVars, ...inTracers.map(t2v)];
1377
1470
  const outVars = outTracers.map(t2v);
1378
- let jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
1471
+ const jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
1379
1472
  typecheckJaxpr(jaxpr);
1380
- [jaxpr, consts] = _inlineLiterals(jaxpr, consts);
1381
- return {
1382
- jaxpr,
1383
- consts
1384
- };
1473
+ const cjaxpr = new ClosedJaxpr(jaxpr, consts);
1474
+ return _inlineLiterals(cjaxpr);
1385
1475
  }
1386
1476
  };
1387
- function _inlineLiterals(jaxpr, consts) {
1477
+ function _inlineLiterals({ jaxpr, consts }) {
1388
1478
  const literals = /* @__PURE__ */ new Map();
1389
1479
  const constBinders = [];
1390
1480
  const newConsts = [];
@@ -1399,7 +1489,7 @@ function _inlineLiterals(jaxpr, consts) {
1399
1489
  const newOuts = jaxpr.outs.map((x) => literals.get(x) ?? x);
1400
1490
  const newJaxpr = new Jaxpr([...constBinders, ...jaxpr.inBinders.slice(consts.length)], newEqns, newOuts);
1401
1491
  typecheckJaxpr(newJaxpr);
1402
- return [newJaxpr, newConsts];
1492
+ return new ClosedJaxpr(newJaxpr, newConsts);
1403
1493
  }
1404
1494
  function binopAbstractEval([x, y]) {
1405
1495
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
@@ -1418,6 +1508,8 @@ const abstractEvalRules = {
1418
1508
  [Primitive.Mul]: binopAbstractEval,
1419
1509
  [Primitive.Idiv]: binopAbstractEval,
1420
1510
  [Primitive.Mod]: binopAbstractEval,
1511
+ [Primitive.Min]: binopAbstractEval,
1512
+ [Primitive.Max]: binopAbstractEval,
1421
1513
  [Primitive.Neg]: vectorizedUnopAbstractEval,
1422
1514
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
1423
1515
  [Primitive.Floor]: vectorizedUnopAbstractEval,
@@ -1431,12 +1523,6 @@ const abstractEvalRules = {
1431
1523
  if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
1432
1524
  return [new ShapedArray(x.shape, dtype, false)];
1433
1525
  },
1434
- [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
1435
- if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
1436
- const keyShape = generalBroadcast(k0.shape, k1.shape);
1437
- if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1438
- return [new ShapedArray(shape$1, DType.Uint32, false)];
1439
- },
1440
1526
  [Primitive.Sin]: vectorizedUnopAbstractEval,
1441
1527
  [Primitive.Cos]: vectorizedUnopAbstractEval,
1442
1528
  [Primitive.Asin]: vectorizedUnopAbstractEval,
@@ -1446,8 +1532,6 @@ const abstractEvalRules = {
1446
1532
  [Primitive.Erf]: vectorizedUnopAbstractEval,
1447
1533
  [Primitive.Erfc]: vectorizedUnopAbstractEval,
1448
1534
  [Primitive.Sqrt]: vectorizedUnopAbstractEval,
1449
- [Primitive.Min]: binopAbstractEval,
1450
- [Primitive.Max]: binopAbstractEval,
1451
1535
  [Primitive.Reduce]([x], { axis }) {
1452
1536
  const axisSet = new Set(axis);
1453
1537
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
@@ -1480,6 +1564,25 @@ const abstractEvalRules = {
1480
1564
  const shape$1 = generalBroadcast(cond.shape, xy.shape);
1481
1565
  return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
1482
1566
  },
1567
+ [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
1568
+ if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
1569
+ const keyShape = generalBroadcast(k0.shape, k1.shape);
1570
+ if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1571
+ return [new ShapedArray(shape$1, DType.Uint32, false)];
1572
+ },
1573
+ [Primitive.Gather]([x, ...indices], { axis, outDim }) {
1574
+ for (const a of indices) if (a.dtype !== DType.Int32 && a.dtype !== DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
1575
+ if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
1576
+ if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
1577
+ if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
1578
+ if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
1579
+ const axisSet = new Set(axis);
1580
+ if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
1581
+ const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
1582
+ const newShape = x.shape.filter((_, i) => !axisSet.has(i));
1583
+ newShape.splice(outDim, 0, ...gatherShape);
1584
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
1585
+ },
1483
1586
  [Primitive.Transpose]([x], { perm }) {
1484
1587
  return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
1485
1588
  },
@@ -1500,23 +1603,31 @@ const abstractEvalRules = {
1500
1603
  const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
1501
1604
  return [new ShapedArray(newShape, x.dtype, x.weakType)];
1502
1605
  },
1503
- [Primitive.Gather]([x, ...indices], { axis, outDim }) {
1504
- for (const a of indices) if (a.dtype !== DType.Int32 && a.dtype !== DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
1505
- if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
1506
- if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
1507
- if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
1508
- if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
1509
- const axisSet = new Set(axis);
1510
- if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
1511
- const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
1512
- const newShape = x.shape.filter((_, i) => !axisSet.has(i));
1513
- newShape.splice(outDim, 0, ...gatherShape);
1514
- return [new ShapedArray(newShape, x.dtype, x.weakType)];
1606
+ [Primitive.Sort]([x]) {
1607
+ if (x.ndim === 0) throw new TypeError("sort: requires at least 1D input");
1608
+ return [ShapedArray.fromAval(x)];
1609
+ },
1610
+ [Primitive.Argsort]([x]) {
1611
+ if (x.ndim === 0) throw new TypeError("argsort: requires at least 1D input");
1612
+ return [ShapedArray.fromAval(x), new ShapedArray(x.shape, DType.Int32, false)];
1613
+ },
1614
+ [Primitive.TriangularSolve]([a, b]) {
1615
+ if (a.ndim < 2) throw new TypeError(`triangular_solve: a must be at least 2D, got ${a}`);
1616
+ if (b.ndim < 2) throw new TypeError(`triangular_solve: b must be at least 2D, got ${b}`);
1617
+ const [m, n] = a.shape.slice(-2);
1618
+ const [_batch, q] = b.shape.slice(-2);
1619
+ if (!deepEqual(a.shape.slice(0, -2), b.shape.slice(0, -2)) || a.dtype !== b.dtype || m !== n || n !== q) throw new TypeError(`triangular_solve: mismatch ${a} vs ${b}`);
1620
+ return [new ShapedArray(b.shape, b.dtype, a.weakType && b.weakType)];
1621
+ },
1622
+ [Primitive.Cholesky]([a]) {
1623
+ if (a.ndim < 2) throw new TypeError(`cholesky: requires at least 2D input, got ${a}`);
1624
+ if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
1625
+ return [ShapedArray.fromAval(a)];
1515
1626
  },
1516
- [Primitive.JitCall](args, { jaxpr }) {
1627
+ [Primitive.Jit](args, { jaxpr }) {
1517
1628
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
1518
- if (args.length !== inTypes.length) throw new TypeError(`jit_call expected ${inTypes.length} arguments, got ${args.length}`);
1519
- for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit_call argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
1629
+ if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
1630
+ for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
1520
1631
  return outTypes;
1521
1632
  }
1522
1633
  };
@@ -1552,11 +1663,10 @@ function makeJaxpr$1(f, opts) {
1552
1663
  const tracersIn = avalsIn.map((aval) => trace$1.newArg(typeof aval === "object" ? aval : pureArray(aval)));
1553
1664
  const outs = fFlat(...tracersIn);
1554
1665
  const tracersOut = outs.map((out) => fullRaise(trace$1, out));
1555
- const { jaxpr, consts } = builder.build(tracersIn, tracersOut);
1666
+ const jaxpr = builder.build(tracersIn, tracersOut);
1556
1667
  if (outTree.value === void 0) throw new Error("outTree was not set in makeJaxpr");
1557
1668
  return {
1558
- jaxpr: jaxpr.simplify(),
1559
- consts,
1669
+ jaxpr: jaxpr.mapJaxpr((j) => j.simplify()),
1560
1670
  treedef: outTree.value
1561
1671
  };
1562
1672
  } catch (_) {
@@ -1575,22 +1685,28 @@ function jit$1(f, opts) {
1575
1685
  const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
1576
1686
  const avalsIn = unflatten(inTree, avalsInFlat);
1577
1687
  const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
1578
- const { jaxpr, consts, treedef: outTree } = runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
1579
- const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
1688
+ const { jaxpr, treedef: outTree } = runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
1689
+ const outs = bind(Primitive.Jit, [...jaxpr.consts.map((c) => c.ref), ...argsFlat], {
1580
1690
  name: f.name || "closure",
1581
- jaxpr,
1582
- numConsts: consts.length
1691
+ jaxpr: jaxpr.jaxpr,
1692
+ numConsts: jaxpr.consts.length
1583
1693
  });
1584
1694
  return unflatten(outTree, outs);
1585
1695
  });
1586
1696
  result.dispose = () => {
1587
- for (const { consts } of cache.values()) for (const c of consts) c.dispose();
1697
+ for (const { jaxpr } of cache.values()) jaxpr.dispose();
1588
1698
  };
1589
1699
  return result;
1590
1700
  }
1591
1701
 
1592
1702
  //#endregion
1593
1703
  //#region src/frontend/jit.ts
1704
+ const routinePrimitives = new Map([
1705
+ [Primitive.Sort, Routines.Sort],
1706
+ [Primitive.Argsort, Routines.Argsort],
1707
+ [Primitive.TriangularSolve, Routines.TriangularSolve],
1708
+ [Primitive.Cholesky, Routines.Cholesky]
1709
+ ]);
1594
1710
  /** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
1595
1711
  var JitProgram = class {
1596
1712
  constructor(backend, steps, inputs, outputs) {
@@ -1605,9 +1721,14 @@ var JitProgram = class {
1605
1721
  case "execute": {
1606
1722
  const inputsNice = step.inputs.map((id, i) => `${i}: %${id}`).join(", ");
1607
1723
  const outputsNice = step.outputs.map((id) => `%${id}`).join(", ");
1608
- return PPrint.pp(`execute (${inputsNice}) -> ${outputsNice}, kernel`).concat(step.kernel.pprint().indent(2));
1724
+ const executeText = `execute (${inputsNice}) -> ${outputsNice}`;
1725
+ if (step.source instanceof Kernel) return PPrint.pp(`${executeText}, kernel`).concat(step.source.pprint().indent(2));
1726
+ else if (step.source instanceof Routine) return PPrint.pp(`${executeText}, routine ${step.source.name}`);
1727
+ else {
1728
+ step.source;
1729
+ return PPrint.pp(executeText);
1730
+ }
1609
1731
  }
1610
- case "const": return PPrint.pp(`%${step.output} = const <Slot ${step.slot}>`);
1611
1732
  case "malloc": return PPrint.pp(`%${step.output} = malloc <${step.size} bytes>`);
1612
1733
  case "incref": return PPrint.pp(`incref ${step.input}`);
1613
1734
  case "free": return PPrint.pp(`free ${step.input}`);
@@ -1630,12 +1751,9 @@ var JitProgram = class {
1630
1751
  const inputs$1 = step.inputs.map((id) => scope.get(id));
1631
1752
  const outputs = step.outputs.map((id) => scope.get(id));
1632
1753
  if (inputs$1.some((s) => s === void 0) || outputs.some((s) => s === void 0)) throw new Error(`internal: JitProgram scope undefined`);
1633
- pending.push(new PendingExecute(this.backend, step.kernel, inputs$1, outputs));
1754
+ pending.push(new PendingExecute(this.backend, step.source, inputs$1, outputs));
1634
1755
  break;
1635
1756
  }
1636
- case "const":
1637
- scope.set(step.output, step.slot);
1638
- break;
1639
1757
  case "malloc": {
1640
1758
  const slot = this.backend.malloc(step.size);
1641
1759
  scope.set(step.output, slot);
@@ -1669,34 +1787,37 @@ var JitProgramBuilder = class {
1669
1787
  this.#nextId = nargs;
1670
1788
  this.steps = [];
1671
1789
  }
1672
- pushConst(slot) {
1673
- const id = this.#nextId++;
1674
- this.steps.push({
1675
- type: "const",
1676
- slot,
1677
- output: id
1678
- });
1679
- return id;
1680
- }
1681
1790
  pushLit(lit) {
1682
- const kernel = new Kernel(0, prod(lit.aval.shape), AluExp.const(lit.dtype, lit.value));
1791
+ const kernel = new Kernel(0, lit.aval.size, AluExp.const(lit.dtype, lit.value));
1683
1792
  return this.pushKernel(kernel, []);
1684
1793
  }
1685
- pushKernel(kernel, inputs) {
1794
+ pushBuffer(size$1) {
1686
1795
  const id = this.#nextId++;
1687
1796
  this.steps.push({
1688
1797
  type: "malloc",
1689
- size: kernel.bytes,
1798
+ size: size$1,
1690
1799
  output: id
1691
1800
  });
1801
+ return id;
1802
+ }
1803
+ pushKernel(kernel, inputs) {
1804
+ const id = this.pushBuffer(kernel.bytes);
1692
1805
  this.steps.push({
1693
1806
  type: "execute",
1694
- kernel,
1807
+ source: kernel,
1695
1808
  inputs,
1696
1809
  outputs: [id]
1697
1810
  });
1698
1811
  return id;
1699
1812
  }
1813
+ pushRoutine(routine, inputs, outputs) {
1814
+ this.steps.push({
1815
+ type: "execute",
1816
+ source: routine,
1817
+ inputs,
1818
+ outputs
1819
+ });
1820
+ }
1700
1821
  pushIncref(id) {
1701
1822
  this.steps.push({
1702
1823
  type: "incref",
@@ -1722,28 +1843,18 @@ var JitProgramBuilder = class {
1722
1843
  }
1723
1844
  };
1724
1845
  const jitCompileCache = /* @__PURE__ */ new Map();
1725
- function jitCompile(backend, jaxpr, consts) {
1726
- if (jaxpr.inBinders.length < consts.length) throw new TypeError(`Jaxpr has ${jaxpr.inBinders.length} inputs, but ${consts.length} consts were provided`);
1727
- for (let i = 0; i < consts.length; i++) if (consts[i].device !== backend.type) throw new TypeError(`Const ${i} has device ${consts[i].device}, but expected ${backend.type}`);
1728
- const cacheKey = backend.type + FpHash.hash(jaxpr, ...consts.map((c) => c.id));
1846
+ function jitCompile(backend, jaxpr) {
1847
+ const cacheKey = backend.type + "," + FpHash.hash(jaxpr);
1729
1848
  const cached = jitCompileCache.get(cacheKey);
1730
1849
  if (cached) return cached;
1731
1850
  if (DEBUG >= 1) console.info("=========== JIT Compile ===========\n" + jaxpr.toString());
1732
1851
  jaxpr = jaxpr.flatten().simplify();
1733
- const nargs = jaxpr.inBinders.length - consts.length;
1852
+ const nargs = jaxpr.inBinders.length;
1734
1853
  const builder = new JitProgramBuilder(backend, nargs);
1735
1854
  const blackNodes = splitGraphDataflow(backend, jaxpr);
1736
1855
  const ctx = /* @__PURE__ */ new Map();
1737
- for (let i = 0; i < consts.length; i++) {
1738
- const v = jaxpr.inBinders[i];
1739
- const slot = consts[i]._realizeSource();
1740
- ctx.set(v, {
1741
- type: "imm",
1742
- arg: builder.pushConst(slot)
1743
- });
1744
- }
1745
1856
  for (let i = 0; i < nargs; i++) {
1746
- const v = jaxpr.inBinders[consts.length + i];
1857
+ const v = jaxpr.inBinders[i];
1747
1858
  ctx.set(v, {
1748
1859
  type: "imm",
1749
1860
  arg: i
@@ -1751,6 +1862,31 @@ function jitCompile(backend, jaxpr, consts) {
1751
1862
  }
1752
1863
  for (let i = 0; i < jaxpr.eqns.length; i++) {
1753
1864
  const eqn = jaxpr.eqns[i];
1865
+ if (routinePrimitives.has(eqn.primitive)) {
1866
+ const routine = new Routine(routinePrimitives.get(eqn.primitive), {
1867
+ inputShapes: eqn.inputs.map((x) => x.aval.shape),
1868
+ inputDtypes: eqn.inputs.map((x) => x.aval.dtype),
1869
+ outputShapes: eqn.outBinders.map((x) => x.aval.shape),
1870
+ outputDtypes: eqn.outBinders.map((x) => x.aval.dtype)
1871
+ }, eqn.params);
1872
+ const inputs = [];
1873
+ for (const input of eqn.inputs) if (input instanceof Var) {
1874
+ const jv = ctx.get(input);
1875
+ if (jv.type !== "imm") throw new Error(`jit: routine primitive ${eqn.primitive} input is not imm`);
1876
+ inputs.push(jv.arg);
1877
+ } else if (input instanceof Lit) inputs.push(builder.pushLit(input));
1878
+ const outputs = [];
1879
+ for (const outVar$1 of eqn.outBinders) {
1880
+ const outId = builder.pushBuffer(outVar$1.aval.size * byteWidth(outVar$1.aval.dtype));
1881
+ outputs.push(outId);
1882
+ ctx.set(outVar$1, {
1883
+ type: "imm",
1884
+ arg: outId
1885
+ });
1886
+ }
1887
+ builder.pushRoutine(routine, inputs, outputs);
1888
+ continue;
1889
+ }
1754
1890
  const inputExps = [];
1755
1891
  const inputAvals = [];
1756
1892
  const inputArgs = [];
@@ -1805,7 +1941,7 @@ function jitCompile(backend, jaxpr, consts) {
1805
1941
  const outVar = eqn.outBinders[0];
1806
1942
  if (blackNodes.has(outVar)) {
1807
1943
  const nargs$1 = inputArgs.length;
1808
- const size$1 = prod(outVar.aval.shape);
1944
+ const size$1 = outVar.aval.size;
1809
1945
  const kernel = new Kernel(nargs$1, size$1, exp$2, reduction);
1810
1946
  const outId = builder.pushKernel(kernel, inputArgs);
1811
1947
  ctx.set(outVar, {
@@ -1830,7 +1966,7 @@ function jitCompile(backend, jaxpr, consts) {
1830
1966
  if (jitValue.type !== "imm") throw new Error("internal: Expected imm, since outs are black nodes");
1831
1967
  outputIds.push(jitValue.arg);
1832
1968
  } else if (out instanceof Lit) outputIds.push(builder.pushLit(out));
1833
- const outputNeedsRef = new Set([...range(nargs), ...builder.steps.filter((s) => s.type === "const").map((s) => s.output)]);
1969
+ const outputNeedsRef = new Set(range(nargs));
1834
1970
  for (const outputId of outputIds) if (outputNeedsRef.has(outputId)) builder.pushIncref(outputId);
1835
1971
  else outputNeedsRef.add(outputId);
1836
1972
  builder.insertFreeSteps(outputIds);
@@ -1876,11 +2012,18 @@ function reshapeJit(fn) {
1876
2012
  return { exp: reshapeViews(a, (st) => fn(st, params)) };
1877
2013
  };
1878
2014
  }
2015
+ function routineNoJit() {
2016
+ return () => {
2017
+ throw new Error("jit: rule is not implemented for routines");
2018
+ };
2019
+ }
1879
2020
  const jitRules = {
1880
2021
  [Primitive.Add]: broadcastedJit(([a, b]) => AluExp.add(a, b)),
1881
2022
  [Primitive.Mul]: broadcastedJit(([a, b]) => AluExp.mul(a, b)),
1882
2023
  [Primitive.Idiv]: broadcastedJit(([a, b]) => AluExp.idiv(a, b)),
1883
2024
  [Primitive.Mod]: broadcastedJit(([a, b]) => AluExp.mod(a, b)),
2025
+ [Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
2026
+ [Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
1884
2027
  [Primitive.Neg]: unopJit((a) => AluExp.sub(AluExp.const(a.dtype, 0), a)),
1885
2028
  [Primitive.Reciprocal]: unopJit(AluExp.reciprocal),
1886
2029
  [Primitive.Floor]: unopJit(AluExp.floor),
@@ -1888,17 +2031,6 @@ const jitRules = {
1888
2031
  [Primitive.StopGradient]: unopJit((a) => a),
1889
2032
  [Primitive.Cast]: unopJit((a, { dtype }) => AluExp.cast(dtype, a)),
1890
2033
  [Primitive.Bitcast]: unopJit((a, { dtype }) => AluExp.bitcast(dtype, a)),
1891
- [Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
1892
- const mapping = (st) => {
1893
- if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(shape$1.length - st.shape.length));
1894
- };
1895
- const k0 = reshapeViews(keys[0], mapping);
1896
- const k1 = reshapeViews(keys[1], mapping);
1897
- const c0 = AluExp.u32(0);
1898
- const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
1899
- const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
1900
- return { exp: exp$2 };
1901
- },
1902
2034
  [Primitive.Sin]: unopJit(AluExp.sin),
1903
2035
  [Primitive.Cos]: unopJit(AluExp.cos),
1904
2036
  [Primitive.Asin]: unopJit(AluExp.asin),
@@ -1908,8 +2040,6 @@ const jitRules = {
1908
2040
  [Primitive.Erf]: unopJit(AluExp.erf),
1909
2041
  [Primitive.Erfc]: unopJit(AluExp.erfc),
1910
2042
  [Primitive.Sqrt]: unopJit(AluExp.sqrt),
1911
- [Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
1912
- [Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
1913
2043
  [Primitive.Reduce]([a], [as], { op, axis }) {
1914
2044
  const keptAxes = [];
1915
2045
  const shiftedAxes = [];
@@ -1959,16 +2089,17 @@ const jitRules = {
1959
2089
  },
1960
2090
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
1961
2091
  [Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
1962
- [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
1963
- [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
1964
- [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
1965
- [Primitive.Flip]: reshapeJit((st, { axis }) => {
1966
- const arg = rep(st.shape.length, false);
1967
- for (const ax of axis) arg[ax] = true;
1968
- return st.flip(arg);
1969
- }),
1970
- [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
1971
- [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
2092
+ [Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
2093
+ const mapping = (st) => {
2094
+ if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(shape$1.length - st.shape.length));
2095
+ };
2096
+ const k0 = reshapeViews(keys[0], mapping);
2097
+ const k1 = reshapeViews(keys[1], mapping);
2098
+ const c0 = AluExp.u32(0);
2099
+ const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
2100
+ const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
2101
+ return { exp: exp$2 };
2102
+ },
1972
2103
  [Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
1973
2104
  const axisSet = new Set(axis);
1974
2105
  const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
@@ -1984,8 +2115,22 @@ const jitRules = {
1984
2115
  if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
1985
2116
  return { exp: x.substitute({ gidx: index }) };
1986
2117
  },
1987
- [Primitive.JitCall]() {
1988
- throw new Error("internal: JitCall should have been flattened before JIT compilation");
2118
+ [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
2119
+ [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
2120
+ [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
2121
+ [Primitive.Flip]: reshapeJit((st, { axis }) => {
2122
+ const arg = rep(st.shape.length, false);
2123
+ for (const ax of axis) arg[ax] = true;
2124
+ return st.flip(arg);
2125
+ }),
2126
+ [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
2127
+ [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
2128
+ [Primitive.Sort]: routineNoJit(),
2129
+ [Primitive.Argsort]: routineNoJit(),
2130
+ [Primitive.TriangularSolve]: routineNoJit(),
2131
+ [Primitive.Cholesky]: routineNoJit(),
2132
+ [Primitive.Jit]() {
2133
+ throw new Error("internal: Jit should have been flattened before JIT compilation");
1989
2134
  }
1990
2135
  };
1991
2136
  /** Determines how to split the Jaxpr into kernels via dataflow analysis. */
@@ -2043,8 +2188,8 @@ function splitGraphDataflow(backend, jaxpr) {
2043
2188
  case Primitive.Mul:
2044
2189
  case Primitive.Idiv:
2045
2190
  case Primitive.Mod:
2046
- case Primitive.Max:
2047
- case Primitive.Min: {
2191
+ case Primitive.Min:
2192
+ case Primitive.Max: {
2048
2193
  const otherInput = nextEqn.inputs.find((v) => v !== outVar);
2049
2194
  if (otherInput instanceof Lit || deepEqual(generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
2050
2195
  head = usages[0];
@@ -2064,11 +2209,11 @@ function splitGraphDataflow(backend, jaxpr) {
2064
2209
  blackNodes.add(v);
2065
2210
  p1NextBlack.set(v, v);
2066
2211
  }
2067
- const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
2212
+ const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
2068
2213
  const needsCleanShapePrimitives = [Primitive.Pad];
2069
2214
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
2070
2215
  const eqn = jaxpr.eqns[i];
2071
- if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
2216
+ if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
2072
2217
  for (const v of eqn.outBinders) {
2073
2218
  blackNodes.add(v);
2074
2219
  p1NextBlack.set(v, v);
@@ -2078,7 +2223,7 @@ function splitGraphDataflow(backend, jaxpr) {
2078
2223
  const reach = /* @__PURE__ */ new Set();
2079
2224
  let needsCleanOutput = false;
2080
2225
  outer: for (const v of eqn.outBinders) for (const j of varToUsages.get(v) ?? []) {
2081
- if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive)) {
2226
+ if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive) || routinePrimitives.has(jaxpr.eqns[j].primitive)) {
2082
2227
  needsCleanOutput = true;
2083
2228
  break outer;
2084
2229
  }
@@ -2102,7 +2247,6 @@ function splitGraphDataflow(backend, jaxpr) {
2102
2247
  while (p2idx < jaxpr.eqns.length) {
2103
2248
  const eqn = jaxpr.eqns[p2idx++];
2104
2249
  const deps = [];
2105
- if (eqn.outBinders.some((v) => blackNodes.has(v))) continue;
2106
2250
  for (const input of eqn.inputs) if (input instanceof Var) if (blackNodes.has(input)) deps.push(new Set([input]));
2107
2251
  else deps.push(p2Deps.get(input));
2108
2252
  else deps.push(/* @__PURE__ */ new Set());
@@ -2125,7 +2269,7 @@ function splitGraphDataflow(backend, jaxpr) {
2125
2269
  if (assocInput === -1) throw new Error(`internal: maxArgs, no input found to mark as black in Jaxpr equation ${eqn}`);
2126
2270
  const assocVar = eqn.inputs[assocInput];
2127
2271
  p2idx = varToDefn.get(assocVar);
2128
- for (const out of jaxpr.eqns[p2idx].outBinders) blackNodes.add(out);
2272
+ for (const out of jaxpr.eqns[p2idx++].outBinders) blackNodes.add(out);
2129
2273
  } else {
2130
2274
  const s = new Set(depCounter.keys());
2131
2275
  for (const out of eqn.outBinders) p2Deps.set(out, s);
@@ -2151,9 +2295,9 @@ var PendingExecute = class {
2151
2295
  submitted = false;
2152
2296
  #promise = null;
2153
2297
  #rc = 1;
2154
- constructor(backend, kernel, inputs, outputs) {
2298
+ constructor(backend, source, inputs, outputs) {
2155
2299
  this.backend = backend;
2156
- this.kernel = kernel;
2300
+ this.source = source;
2157
2301
  this.inputs = inputs;
2158
2302
  this.outputs = outputs;
2159
2303
  for (const slot of inputs) this.backend.incRef(slot);
@@ -2174,13 +2318,15 @@ var PendingExecute = class {
2174
2318
  return;
2175
2319
  }
2176
2320
  this.#promise = (async () => {
2177
- this.prepared = await this.backend.prepare(this.kernel);
2321
+ if (this.source instanceof Kernel) this.prepared = await this.backend.prepareKernel(this.source);
2322
+ else this.prepared = await this.backend.prepareRoutine(this.source);
2178
2323
  })();
2179
2324
  await this.#promise;
2180
2325
  }
2181
2326
  prepareSync() {
2182
2327
  if (this.prepared) return;
2183
- this.prepared = this.backend.prepareSync(this.kernel);
2328
+ if (this.source instanceof Kernel) this.prepared = this.backend.prepareKernelSync(this.source);
2329
+ else this.prepared = this.backend.prepareRoutineSync(this.source);
2184
2330
  }
2185
2331
  submit() {
2186
2332
  if (this.submitted) return;
@@ -2203,8 +2349,6 @@ var PendingExecute = class {
2203
2349
  * "Array" type by name.
2204
2350
  */
2205
2351
  var Array$1 = class Array$1 extends Tracer {
2206
- static #nextId = 1001;
2207
- id;
2208
2352
  #dtype;
2209
2353
  #weakType;
2210
2354
  #source;
@@ -2221,7 +2365,6 @@ var Array$1 = class Array$1 extends Tracer {
2221
2365
  */
2222
2366
  constructor(args) {
2223
2367
  super(baseArrayTrace);
2224
- this.id = Array$1.#nextId++;
2225
2368
  this.#dtype = args.dtype;
2226
2369
  this.#weakType = args.weakType;
2227
2370
  this.#source = args.source;
@@ -2530,6 +2673,27 @@ var Array$1 = class Array$1 extends Tracer {
2530
2673
  pending
2531
2674
  });
2532
2675
  }
2676
+ /** Apply an operation with custom lowering to this array. */
2677
+ static #routine(routine, arrays, outputWeakType) {
2678
+ const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
2679
+ for (const ar of arrays) ar.#realize();
2680
+ const inputs = arrays.map((ar) => ar.#source);
2681
+ const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(byteWidth(dtype) * prod(routine.type.outputShapes[i])));
2682
+ const pending = arrays.flatMap((ar) => ar.#pending);
2683
+ for (const exe of pending) exe.updateRc(+outputs.length);
2684
+ pending.push(new PendingExecute(backend, routine, inputs, outputs));
2685
+ pending[pending.length - 1].updateRc(+outputs.length - 1);
2686
+ arrays.forEach((ar) => ar.dispose());
2687
+ return outputs.map((output, i) => new Array$1({
2688
+ source: output,
2689
+ st: ShapeTracker.fromShape(routine.type.outputShapes[i]),
2690
+ dtype: routine.type.outputDtypes[i],
2691
+ weakType: outputWeakType[i],
2692
+ backend,
2693
+ committed,
2694
+ pending
2695
+ }));
2696
+ }
2533
2697
  /**
2534
2698
  * Normalizes this array into one backed by a `Slot`.
2535
2699
  *
@@ -2690,6 +2854,12 @@ var Array$1 = class Array$1 extends Tracer {
2690
2854
  [Primitive.Mod]([x, y]) {
2691
2855
  return [x.#binary(AluOp.Mod, y)];
2692
2856
  },
2857
+ [Primitive.Min]([x, y]) {
2858
+ return [x.#binary(AluOp.Min, y)];
2859
+ },
2860
+ [Primitive.Max]([x, y]) {
2861
+ return [x.#binary(AluOp.Max, y)];
2862
+ },
2693
2863
  [Primitive.Neg]([x]) {
2694
2864
  return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
2695
2865
  },
@@ -2726,25 +2896,6 @@ var Array$1 = class Array$1 extends Tracer {
2726
2896
  return [y];
2727
2897
  }
2728
2898
  },
2729
- [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
2730
- const keyShape = generalBroadcast(k0.shape, k1.shape);
2731
- if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2732
- const c0 = zeros(shape$1, {
2733
- dtype: DType.Uint32,
2734
- device: k0.device
2735
- });
2736
- const c1 = arange(0, prod(shape$1), 1, {
2737
- dtype: DType.Uint32,
2738
- device: k0.device
2739
- }).reshape(shape$1);
2740
- const custom = ([k0$1, k1$1, c0$1, c1$1]) => AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
2741
- return [Array$1.#naryCustom("random_bits", custom, [
2742
- k0,
2743
- k1,
2744
- c0,
2745
- c1
2746
- ])];
2747
- },
2748
2899
  [Primitive.Sin]([x]) {
2749
2900
  return [x.#unary(AluOp.Sin)];
2750
2901
  },
@@ -2772,12 +2923,6 @@ var Array$1 = class Array$1 extends Tracer {
2772
2923
  [Primitive.Sqrt]([x]) {
2773
2924
  return [x.#unary(AluOp.Sqrt)];
2774
2925
  },
2775
- [Primitive.Min]([x, y]) {
2776
- return [x.#binary(AluOp.Min, y)];
2777
- },
2778
- [Primitive.Max]([x, y]) {
2779
- return [x.#binary(AluOp.Max, y)];
2780
- },
2781
2926
  [Primitive.Reduce]([x], { op, axis }) {
2782
2927
  if (axis.length === 0) return [x];
2783
2928
  return [x.#moveAxesDown(axis).#reduce(op)];
@@ -2812,13 +2957,35 @@ var Array$1 = class Array$1 extends Tracer {
2812
2957
  y
2813
2958
  ], { dtypeOverride: [DType.Bool] })];
2814
2959
  },
2815
- [Primitive.Transpose]([x], { perm }) {
2816
- return [x.#transpose(perm)];
2817
- },
2818
- [Primitive.Broadcast]([x], { shape: shape$1, axis }) {
2819
- return [x.#reshape(x.#st.broadcast(shape$1, axis))];
2820
- },
2821
- [Primitive.Reshape]([x], { shape: shape$1 }) {
2960
+ [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
2961
+ const keyShape = generalBroadcast(k0.shape, k1.shape);
2962
+ if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2963
+ const c0 = zeros(shape$1, {
2964
+ dtype: DType.Uint32,
2965
+ device: k0.device
2966
+ });
2967
+ const c1 = arange(0, prod(shape$1), 1, {
2968
+ dtype: DType.Uint32,
2969
+ device: k0.device
2970
+ }).reshape(shape$1);
2971
+ const custom = ([k0$1, k1$1, c0$1, c1$1]) => AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
2972
+ return [Array$1.#naryCustom("random_bits", custom, [
2973
+ k0,
2974
+ k1,
2975
+ c0,
2976
+ c1
2977
+ ])];
2978
+ },
2979
+ [Primitive.Gather]([x, ...indices], { axis, outDim }) {
2980
+ return [x.#gather(indices, axis, outDim)];
2981
+ },
2982
+ [Primitive.Transpose]([x], { perm }) {
2983
+ return [x.#transpose(perm)];
2984
+ },
2985
+ [Primitive.Broadcast]([x], { shape: shape$1, axis }) {
2986
+ return [x.#reshape(x.#st.broadcast(shape$1, axis))];
2987
+ },
2988
+ [Primitive.Reshape]([x], { shape: shape$1 }) {
2822
2989
  return [x.#reshape(x.#st.reshape(shape$1))];
2823
2990
  },
2824
2991
  [Primitive.Flip]([x], { axis }) {
@@ -2832,17 +2999,48 @@ var Array$1 = class Array$1 extends Tracer {
2832
2999
  [Primitive.Pad]([x], { width }) {
2833
3000
  return [x.#reshape(x.#st.pad(width))];
2834
3001
  },
2835
- [Primitive.Gather]([x, ...indices], { axis, outDim }) {
2836
- return [x.#gather(indices, axis, outDim)];
3002
+ [Primitive.Sort]([x]) {
3003
+ const routine = new Routine(Routines.Sort, {
3004
+ inputShapes: [x.aval.shape],
3005
+ inputDtypes: [x.aval.dtype],
3006
+ outputShapes: [x.aval.shape],
3007
+ outputDtypes: [x.aval.dtype]
3008
+ });
3009
+ return Array$1.#routine(routine, [x], [x.#weakType]);
2837
3010
  },
2838
- [Primitive.JitCall](args, { jaxpr, numConsts }) {
2839
- if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
2840
- const { backend, committed } = Array$1.#computeBackend("jit_call", args);
3011
+ [Primitive.Argsort]([x]) {
3012
+ const routine = new Routine(Routines.Argsort, {
3013
+ inputShapes: [x.aval.shape],
3014
+ inputDtypes: [x.aval.dtype],
3015
+ outputShapes: [x.aval.shape, x.aval.shape],
3016
+ outputDtypes: [x.aval.dtype, DType.Int32]
3017
+ });
3018
+ return Array$1.#routine(routine, [x], [x.#weakType, false]);
3019
+ },
3020
+ [Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
3021
+ const routine = new Routine(Routines.TriangularSolve, {
3022
+ inputShapes: [a.aval.shape, b.aval.shape],
3023
+ inputDtypes: [a.aval.dtype, b.aval.dtype],
3024
+ outputShapes: [b.aval.shape],
3025
+ outputDtypes: [b.aval.dtype]
3026
+ }, { unitDiagonal });
3027
+ return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
3028
+ },
3029
+ [Primitive.Cholesky]([a]) {
3030
+ const routine = new Routine(Routines.Cholesky, {
3031
+ inputShapes: [a.aval.shape],
3032
+ inputDtypes: [a.aval.dtype],
3033
+ outputShapes: [a.aval.shape],
3034
+ outputDtypes: [a.aval.dtype]
3035
+ });
3036
+ return Array$1.#routine(routine, [a], [a.#weakType]);
3037
+ },
3038
+ [Primitive.Jit](args, { jaxpr }) {
3039
+ if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
3040
+ const { backend, committed } = Array$1.#computeBackend("jit", args);
2841
3041
  args = args.map((ar) => ar._putSync(backend));
2842
- const consts = args.slice(0, numConsts);
2843
- const tracers = args.slice(numConsts);
2844
- const jp = jitCompile(backend, jaxpr, consts);
2845
- const { outputs, pending } = jp.execute(tracers.map((x) => x._realizeSource()));
3042
+ const jp = jitCompile(backend, jaxpr);
3043
+ const { outputs, pending } = jp.execute(args.map((x) => x._realizeSource()));
2846
3044
  for (const exe of pending) exe.updateRc(+outputs.length - 1);
2847
3045
  const prevPending = [...new Set(args.flatMap((x) => x.#pending))];
2848
3046
  for (const exe of prevPending) exe.updateRc(+outputs.length);
@@ -3141,6 +3339,43 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
3141
3339
  });
3142
3340
  }
3143
3341
  /**
3342
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
3343
+ *
3344
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
3345
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
3346
+ * `k>0` is above it.
3347
+ */
3348
+ function tri(n, m, k = 0, { dtype, device } = {}) {
3349
+ m ??= n;
3350
+ dtype ??= DType.Float32;
3351
+ if (!Number.isInteger(n) || n < 0) throw new Error(`tri: n must be a non-negative integer, got ${n}`);
3352
+ if (!Number.isInteger(m) || m < 0) throw new Error(`tri: m must be a non-negative integer, got ${m}`);
3353
+ if (!Number.isInteger(k)) throw new Error(`tri: k must be an integer, got ${k}`);
3354
+ const rows = arange(k, n + k, 1, {
3355
+ dtype: DType.Int32,
3356
+ device
3357
+ });
3358
+ const cols = arange(0, m, 1, {
3359
+ dtype: DType.Int32,
3360
+ device
3361
+ });
3362
+ return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
3363
+ }
3364
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
3365
+ function tril(a, k = 0) {
3366
+ if (ndim$1(a) < 2) throw new Error(`tril: input array must be at least 2D, got ${ndim$1(a)}D`);
3367
+ a = fudgeArray(a);
3368
+ const [n, m] = a.shape.slice(-2);
3369
+ return where$1(tri(n, m, k, { dtype: DType.Bool }), a.ref, zerosLike$1(a));
3370
+ }
3371
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
3372
+ function triu(a, k = 0) {
3373
+ if (ndim$1(a) < 2) throw new Error(`tril: input array must be at least 2D, got ${ndim$1(a)}D`);
3374
+ a = fudgeArray(a);
3375
+ const [n, m] = a.shape.slice(-2);
3376
+ return where$1(tri(n, m, k - 1, { dtype: DType.Bool }), zerosLike$1(a.ref), a);
3377
+ }
3378
+ /**
3144
3379
  * Return evenly spaced numbers over a specified interval.
3145
3380
  *
3146
3381
  * Returns _num_ evenly spaced samples, calculated over the interval
@@ -3187,383 +3422,186 @@ function aluCompare(a, b, op) {
3187
3422
  }
3188
3423
 
3189
3424
  //#endregion
3190
- //#region src/frontend/jvp.ts
3191
- var JVPTracer = class extends Tracer {
3192
- constructor(trace$1, primal, tangent) {
3425
+ //#region src/frontend/vmap.ts
3426
+ function mappedAval(batchDim, aval) {
3427
+ const shape$1 = [...aval.shape];
3428
+ shape$1.splice(batchDim, 1);
3429
+ return new ShapedArray(shape$1, aval.dtype, aval.weakType);
3430
+ }
3431
+ /** Move one axis to a different index. */
3432
+ function moveaxis(x, src, dst) {
3433
+ const t = pureArray(x);
3434
+ src = checkAxis(src, t.ndim);
3435
+ dst = checkAxis(dst, t.ndim);
3436
+ if (src === dst) return t;
3437
+ const perm = range(t.ndim);
3438
+ perm.splice(src, 1);
3439
+ perm.splice(dst, 0, src);
3440
+ return transpose$1(t, perm);
3441
+ }
3442
+ function moveBatchAxis(axisSize, src, dst, x) {
3443
+ if (src === null) {
3444
+ const targetShape = [...x.shape];
3445
+ targetShape.splice(dst, 0, axisSize);
3446
+ return broadcast(x, targetShape, [dst]);
3447
+ } else if (src === dst) return x;
3448
+ else return moveaxis(x, src, dst);
3449
+ }
3450
+ var BatchTracer = class extends Tracer {
3451
+ constructor(trace$1, val, batchDim) {
3193
3452
  super(trace$1);
3194
- this.primal = primal;
3195
- this.tangent = tangent;
3453
+ this.val = val;
3454
+ this.batchDim = batchDim;
3196
3455
  }
3197
3456
  get aval() {
3198
- return this.primal.aval;
3457
+ if (this.batchDim === null) return this.val.aval;
3458
+ else return mappedAval(this.batchDim, this.val.aval);
3199
3459
  }
3200
3460
  toString() {
3201
- return `JVPTracer(${this.primal.toString()}, ${this.tangent.toString()})`;
3461
+ return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
3202
3462
  }
3203
3463
  get ref() {
3204
- this.primal.ref, this.tangent.ref;
3464
+ this.val.ref;
3205
3465
  return this;
3206
3466
  }
3207
3467
  dispose() {
3208
- this.primal.dispose();
3209
- this.tangent.dispose();
3468
+ this.val.dispose();
3469
+ }
3470
+ fullLower() {
3471
+ if (this.batchDim === null) return this.val.fullLower();
3472
+ else return this;
3210
3473
  }
3211
3474
  };
3212
- var JVPTrace = class extends Trace {
3475
+ var BatchTrace = class extends Trace {
3213
3476
  pure(val) {
3214
3477
  return this.lift(pureArray(val));
3215
3478
  }
3216
3479
  lift(val) {
3217
- return new JVPTracer(this, val, zerosLike$1(val.ref));
3480
+ return new BatchTracer(this, val, null);
3218
3481
  }
3219
3482
  processPrimitive(primitive, tracers, params) {
3220
- const [primalsIn, tangentsIn] = unzip2(tracers.map((x) => [x.primal, x.tangent]));
3221
- const jvpRule = jvpRules[primitive];
3222
- if (jvpRule === void 0) throw new Error(`No JVP rule for: ${primitive}`);
3223
- const [primalsOut, tangentsOut] = jvpRule(primalsIn, tangentsIn, params);
3224
- return zip(primalsOut, tangentsOut).map(([x, t]) => new JVPTracer(this, x, t));
3483
+ const [valsIn, bdimsIn] = unzip2(tracers.map((t) => [t.val, t.batchDim]));
3484
+ const vmapRule = vmapRules[primitive];
3485
+ if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
3486
+ if (bdimsIn.every((d) => d === null)) {
3487
+ const valOuts$1 = bind(primitive, valsIn, params);
3488
+ return valOuts$1.map((x) => new BatchTracer(this, x, null));
3489
+ }
3490
+ const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3491
+ return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3492
+ }
3493
+ get axisSize() {
3494
+ return this.main.globalData;
3225
3495
  }
3226
3496
  };
3227
- /** Rule that applies the same operation to primals and tangents. */
3228
- function linearTangentsJvp(primitive) {
3229
- return (primals, tangents, params) => {
3230
- const ys = bind(primitive, primals, params);
3231
- const dys = bind(primitive, tangents, params);
3232
- return [ys, dys];
3233
- };
3234
- }
3235
- /** Rule for product of gradients in bilinear operations. */
3236
- function bilinearTangentsJvp(primitive) {
3237
- return ([x, y], [dx, dy], params) => {
3238
- const primal = bind1(primitive, [x.ref, y.ref], params);
3239
- const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
3240
- return [[primal], [tangent]];
3497
+ /**
3498
+ * Process a primitive with built-in broadcasting.
3499
+ *
3500
+ * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3501
+ */
3502
+ function broadcastBatcher(op) {
3503
+ return (axisSize, args, dims) => {
3504
+ if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3505
+ const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3506
+ const firstIdx = dims.findIndex((d) => d !== null);
3507
+ const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3508
+ if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3509
+ args = args.map((x, i) => {
3510
+ if (dims[i] === null) return x;
3511
+ x = moveBatchAxis(axisSize, dims[i], 0, x);
3512
+ if (x.ndim < nd) x = x.reshape([
3513
+ x.shape[0],
3514
+ ...rep(nd - x.ndim, 1),
3515
+ ...x.shape.slice(1)
3516
+ ]);
3517
+ return x;
3518
+ });
3519
+ return [[op(...args)], [0]];
3241
3520
  };
3242
3521
  }
3243
- /** Rule that zeros out any tangents. */
3244
- function zeroTangentsJvp(primitive) {
3245
- return (primals, tangents, params) => {
3246
- for (const t of tangents) t.dispose();
3247
- const ys = bind(primitive, primals, params);
3248
- return [ys, ys.map((y) => zerosLike$1(y.ref))];
3522
+ function unopBatcher(op) {
3523
+ return (axisSize, [x], [xBdim], params) => {
3524
+ return [[op(x, params)], [xBdim]];
3249
3525
  };
3250
3526
  }
3251
- const jvpRules = {
3252
- [Primitive.Add]: linearTangentsJvp(Primitive.Add),
3253
- [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
3254
- [Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
3255
- [Primitive.Mod]([x, y], [dx, dy]) {
3256
- if (!isFloatDtype(x.dtype) && !isFloatDtype(y.dtype)) {
3257
- dx.dispose();
3258
- dy.dispose();
3259
- return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
3260
- }
3261
- const q = idiv(x.ref, y.ref);
3262
- return [[mod(x, y)], [dx.sub(dy.mul(q))]];
3527
+ const vmapRules = {
3528
+ [Primitive.Add]: broadcastBatcher(add$1),
3529
+ [Primitive.Mul]: broadcastBatcher(mul),
3530
+ [Primitive.Idiv]: broadcastBatcher(idiv),
3531
+ [Primitive.Mod]: broadcastBatcher(mod),
3532
+ [Primitive.Min]: broadcastBatcher(min$1),
3533
+ [Primitive.Max]: broadcastBatcher(max$1),
3534
+ [Primitive.Neg]: unopBatcher(neg),
3535
+ [Primitive.Reciprocal]: unopBatcher(reciprocal$1),
3536
+ [Primitive.Floor]: unopBatcher(floor$1),
3537
+ [Primitive.Ceil]: unopBatcher(ceil$1),
3538
+ [Primitive.StopGradient]: unopBatcher(stopGradient),
3539
+ [Primitive.Cast]: unopBatcher((x, { dtype }) => cast(x, dtype)),
3540
+ [Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
3541
+ [Primitive.Sin]: unopBatcher(sin$1),
3542
+ [Primitive.Cos]: unopBatcher(cos$1),
3543
+ [Primitive.Asin]: unopBatcher(asin$1),
3544
+ [Primitive.Atan]: unopBatcher(atan$1),
3545
+ [Primitive.Exp]: unopBatcher(exp$1),
3546
+ [Primitive.Log]: unopBatcher(log$1),
3547
+ [Primitive.Erf]: unopBatcher(erf$1),
3548
+ [Primitive.Erfc]: unopBatcher(erfc$1),
3549
+ [Primitive.Sqrt]: unopBatcher(sqrt$1),
3550
+ [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3551
+ assertNonNull(xBdim);
3552
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3553
+ const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3554
+ return [[reduce(x, op, newAxis)], [outBdim]];
3263
3555
  },
3264
- [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
3265
- [Primitive.Reciprocal]([x], [dx]) {
3266
- const xRecip = reciprocal$1(x.ref);
3267
- return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
3556
+ [Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
3557
+ x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
3558
+ y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
3559
+ const z = dot$2(x, y);
3560
+ return [[z], [z.ndim - 1]];
3268
3561
  },
3269
- [Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
3270
- [Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
3271
- [Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
3272
- [Primitive.Cast]([x], [dx], { dtype }) {
3273
- if (x.dtype === dtype) return [[x], [dx]];
3274
- if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
3275
- else {
3276
- dx.dispose();
3277
- return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
3278
- }
3562
+ [Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
3563
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3564
+ y = moveBatchAxis(axisSize, yBdim, 0, y);
3565
+ const z = conv$1(x, y, {
3566
+ ...params,
3567
+ vmapDims: params.vmapDims + 1
3568
+ });
3569
+ return [[z], [0]];
3279
3570
  },
3280
- [Primitive.Bitcast]([x], [dx], { dtype }) {
3281
- if (x.dtype === dtype) return [[x], [dx]];
3282
- dx.dispose();
3283
- return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
3571
+ [Primitive.Compare](axisSize, args, dims, { op }) {
3572
+ return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3284
3573
  },
3285
- [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
3286
- [Primitive.Sin]([x], [dx]) {
3287
- return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
3288
- },
3289
- [Primitive.Cos]([x], [dx]) {
3290
- return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
3291
- },
3292
- [Primitive.Asin]([x], [dx]) {
3293
- const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
3294
- return [[asin$1(x)], [denom.mul(dx)]];
3295
- },
3296
- [Primitive.Atan]([x], [dx]) {
3297
- const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
3298
- return [[atan$1(x)], [dx.div(denom)]];
3299
- },
3300
- [Primitive.Exp]([x], [dx]) {
3301
- const z = exp$1(x);
3302
- return [[z.ref], [z.mul(dx)]];
3303
- },
3304
- [Primitive.Log]([x], [dx]) {
3305
- return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
3306
- },
3307
- [Primitive.Erf]([x], [dx]) {
3308
- const coeff = 2 / Math.sqrt(Math.PI);
3309
- const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3310
- return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
3311
- },
3312
- [Primitive.Erfc]([x], [dx]) {
3313
- const coeff = -2 / Math.sqrt(Math.PI);
3314
- const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3315
- return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
3316
- },
3317
- [Primitive.Sqrt]([x], [dx]) {
3318
- const z = sqrt$1(x);
3319
- return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
3320
- },
3321
- [Primitive.Min]([x, y], [dx, dy]) {
3322
- return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
3323
- },
3324
- [Primitive.Max]([x, y], [dx, dy]) {
3325
- return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
3326
- },
3327
- [Primitive.Reduce]([x], [dx], { op, axis }) {
3328
- if (op === AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
3329
- else if (op === AluOp.Mul) {
3330
- const primal = reduce(x.ref, op, axis);
3331
- const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
3332
- return [[primal], [tangent]];
3333
- } else if (op === AluOp.Min || op === AluOp.Max) {
3334
- const primal = reduce(x.ref, op, axis);
3335
- const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
3336
- const minCount = where$1(notMin.ref, 0, 1).sum(axis);
3337
- const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
3338
- return [[primal], [tangent]];
3339
- } else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
3340
- },
3341
- [Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
3342
- [Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
3343
- [Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
3344
- [Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
3345
- [Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
3346
- [Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
3347
- dcond.dispose();
3348
- return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
3349
- },
3350
- [Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
3351
- [Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
3352
- [Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
3353
- [Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
3354
- [Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
3355
- [Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
3356
- [Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
3357
- const indicesRef = indices.map((t) => t.ref);
3358
- return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
3359
- },
3360
- [Primitive.JitCall](primals, tangents, { name, jaxpr }) {
3361
- const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
3362
- const outs = bind(Primitive.JitCall, [
3363
- ...newConsts.map((c) => c.ref),
3364
- ...primals,
3365
- ...tangents
3366
- ], {
3367
- name: `${name}_jvp`,
3368
- jaxpr: newJaxpr,
3369
- numConsts: newConsts.length
3370
- });
3371
- const n = outs.length / 2;
3372
- if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
3373
- const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
3374
- return [primalsOut, tangentsOut];
3375
- }
3376
- };
3377
- const jvpJaxprCache = /* @__PURE__ */ new Map();
3378
- function jvpJaxpr(jaxpr) {
3379
- if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
3380
- const inAvals = jaxpr.inBinders.map((v) => v.aval);
3381
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
3382
- const result = {
3383
- newJaxpr,
3384
- newConsts
3385
- };
3386
- jvpJaxprCache.set(jaxpr, result);
3387
- return result;
3388
- }
3389
- function jvpFlat(f, primals, tangents) {
3390
- try {
3391
- var _usingCtx$1 = _usingCtx();
3392
- const main = _usingCtx$1.u(newMain(JVPTrace));
3393
- const trace$1 = new JVPTrace(main);
3394
- const tracersIn = zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
3395
- const outs = f(...tracersIn);
3396
- const tracersOut = outs.map((out) => fullRaise(trace$1, out));
3397
- return unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
3398
- } catch (_) {
3399
- _usingCtx$1.e = _;
3400
- } finally {
3401
- _usingCtx$1.d();
3402
- }
3403
- }
3404
- function jvp$1(f, primals, tangents) {
3405
- const [primalsFlat, inTree] = flatten(primals);
3406
- const [tangentsFlat, inTree2] = flatten(tangents);
3407
- if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
3408
- const [flatFun, outTree] = flattenFun(f, inTree);
3409
- const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
3410
- if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
3411
- const primalsOut = unflatten(outTree.value, primalsOutFlat);
3412
- const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
3413
- return [primalsOut, tangentsOut];
3414
- }
3415
-
3416
- //#endregion
3417
- //#region src/frontend/vmap.ts
3418
- function mappedAval(batchDim, aval) {
3419
- const shape$1 = [...aval.shape];
3420
- shape$1.splice(batchDim, 1);
3421
- return new ShapedArray(shape$1, aval.dtype, aval.weakType);
3422
- }
3423
- /** Move one axis to a different index. */
3424
- function moveaxis(x, src, dst) {
3425
- const t = pureArray(x);
3426
- src = checkAxis(src, t.ndim);
3427
- dst = checkAxis(dst, t.ndim);
3428
- if (src === dst) return t;
3429
- const perm = range(t.ndim);
3430
- perm.splice(src, 1);
3431
- perm.splice(dst, 0, src);
3432
- return transpose$1(t, perm);
3433
- }
3434
- function moveBatchAxis(axisSize, src, dst, x) {
3435
- if (src === null) {
3436
- const targetShape = [...x.shape];
3437
- targetShape.splice(dst, 0, axisSize);
3438
- return broadcast(x, targetShape, [dst]);
3439
- } else if (src === dst) return x;
3440
- else return moveaxis(x, src, dst);
3441
- }
3442
- var BatchTracer = class extends Tracer {
3443
- constructor(trace$1, val, batchDim) {
3444
- super(trace$1);
3445
- this.val = val;
3446
- this.batchDim = batchDim;
3447
- }
3448
- get aval() {
3449
- if (this.batchDim === null) return this.val.aval;
3450
- else return mappedAval(this.batchDim, this.val.aval);
3451
- }
3452
- toString() {
3453
- return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
3454
- }
3455
- get ref() {
3456
- this.val.ref;
3457
- return this;
3458
- }
3459
- dispose() {
3460
- this.val.dispose();
3461
- }
3462
- fullLower() {
3463
- if (this.batchDim === null) return this.val.fullLower();
3464
- else return this;
3465
- }
3466
- };
3467
- var BatchTrace = class extends Trace {
3468
- pure(val) {
3469
- return this.lift(pureArray(val));
3470
- }
3471
- lift(val) {
3472
- return new BatchTracer(this, val, null);
3473
- }
3474
- processPrimitive(primitive, tracers, params) {
3475
- const [valsIn, bdimsIn] = unzip2(tracers.map((t) => [t.val, t.batchDim]));
3476
- const vmapRule = vmapRules[primitive];
3477
- if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
3478
- if (bdimsIn.every((d) => d === null)) {
3479
- const valOuts$1 = bind(primitive, valsIn, params);
3480
- return valOuts$1.map((x) => new BatchTracer(this, x, null));
3574
+ [Primitive.Where]: broadcastBatcher(where$1),
3575
+ [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3576
+ if (indicesBdim.every((d) => d === null)) {
3577
+ assertNonNull(xBdim);
3578
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3579
+ let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3580
+ let newOutDim = outDim;
3581
+ if (newOutDim < newBdim) newBdim += axis.length;
3582
+ else newOutDim += 1;
3583
+ return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3481
3584
  }
3482
- const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3483
- return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3484
- }
3485
- get axisSize() {
3486
- return this.main.globalData;
3487
- }
3488
- };
3489
- /**
3490
- * Process a primitive with built-in broadcasting.
3491
- *
3492
- * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3493
- */
3494
- function broadcastBatcher(op) {
3495
- return (axisSize, args, dims) => {
3496
- if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3497
- const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3498
- const firstIdx = dims.findIndex((d) => d !== null);
3499
- const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3500
- if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3501
- args = args.map((x, i) => {
3502
- if (dims[i] === null) return x;
3503
- x = moveBatchAxis(axisSize, dims[i], 0, x);
3504
- if (x.ndim < nd) x = x.reshape([
3505
- x.shape[0],
3506
- ...rep(nd - x.ndim, 1),
3507
- ...x.shape.slice(1)
3585
+ const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3586
+ indices = indices.map((m, i) => {
3587
+ if (indicesBdim[i] === null) return m;
3588
+ m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3589
+ if (m.ndim < nd) m = m.reshape([
3590
+ m.shape[0],
3591
+ ...rep(nd - m.ndim, 1),
3592
+ ...m.shape.slice(1)
3508
3593
  ]);
3509
- return x;
3510
- });
3511
- return [[op(...args)], [0]];
3512
- };
3513
- }
3514
- function unopBatcher(op) {
3515
- return (axisSize, [x], [xBdim], params) => {
3516
- return [[op(x, params)], [xBdim]];
3517
- };
3518
- }
3519
- const vmapRules = {
3520
- [Primitive.Add]: broadcastBatcher(add$1),
3521
- [Primitive.Mul]: broadcastBatcher(mul),
3522
- [Primitive.Idiv]: broadcastBatcher(idiv),
3523
- [Primitive.Mod]: broadcastBatcher(mod),
3524
- [Primitive.Neg]: unopBatcher(neg),
3525
- [Primitive.Reciprocal]: unopBatcher(reciprocal$1),
3526
- [Primitive.Floor]: unopBatcher(floor$1),
3527
- [Primitive.Ceil]: unopBatcher(ceil$1),
3528
- [Primitive.StopGradient]: unopBatcher(stopGradient),
3529
- [Primitive.Cast]: unopBatcher((x, { dtype }) => cast(x, dtype)),
3530
- [Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
3531
- [Primitive.Sin]: unopBatcher(sin$1),
3532
- [Primitive.Cos]: unopBatcher(cos$1),
3533
- [Primitive.Asin]: unopBatcher(asin$1),
3534
- [Primitive.Atan]: unopBatcher(atan$1),
3535
- [Primitive.Exp]: unopBatcher(exp$1),
3536
- [Primitive.Log]: unopBatcher(log$1),
3537
- [Primitive.Erf]: unopBatcher(erf$1),
3538
- [Primitive.Erfc]: unopBatcher(erfc$1),
3539
- [Primitive.Sqrt]: unopBatcher(sqrt$1),
3540
- [Primitive.Min]: broadcastBatcher(min$1),
3541
- [Primitive.Max]: broadcastBatcher(max$1),
3542
- [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3543
- assertNonNull(xBdim);
3544
- const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3545
- const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3546
- return [[reduce(x, op, newAxis)], [outBdim]];
3547
- },
3548
- [Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
3549
- x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
3550
- y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
3551
- const z = dot$2(x, y);
3552
- return [[z], [z.ndim - 1]];
3553
- },
3554
- [Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
3555
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3556
- y = moveBatchAxis(axisSize, yBdim, 0, y);
3557
- const z = conv$1(x, y, {
3558
- ...params,
3559
- vmapDims: params.vmapDims + 1
3594
+ return m;
3560
3595
  });
3561
- return [[z], [0]];
3562
- },
3563
- [Primitive.Compare](axisSize, args, dims, { op }) {
3564
- return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3596
+ if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3597
+ else {
3598
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3599
+ const newAxis = [0, ...axis.map((ax) => ax + 1)];
3600
+ const extraBatchIndex = arange(axisSize).reshape([-1, ...rep(nd - 1, 1)]);
3601
+ indices.splice(0, 0, extraBatchIndex);
3602
+ return [[gather(x, indices, newAxis, outDim)], [outDim]];
3603
+ }
3565
3604
  },
3566
- [Primitive.Where]: broadcastBatcher(where$1),
3567
3605
  [Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
3568
3606
  assertNonNull(xBdim);
3569
3607
  const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
@@ -3595,42 +3633,53 @@ const vmapRules = {
3595
3633
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3596
3634
  return [[pad$1(x, newWidth)], [xBdim]];
3597
3635
  },
3598
- [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3599
- if (indicesBdim.every((d) => d === null)) {
3600
- assertNonNull(xBdim);
3601
- const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3602
- let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3603
- let newOutDim = outDim;
3604
- if (newOutDim < newBdim) newBdim += axis.length;
3605
- else newOutDim += 1;
3606
- return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3607
- }
3608
- const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3609
- indices = indices.map((m, i) => {
3610
- if (indicesBdim[i] === null) return m;
3611
- m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3612
- if (m.ndim < nd) m = m.reshape([
3613
- m.shape[0],
3614
- ...rep(nd - m.ndim, 1),
3615
- ...m.shape.slice(1)
3636
+ [Primitive.Sort](axisSize, [x], [xBdim]) {
3637
+ assertNonNull(xBdim);
3638
+ if (xBdim !== x.ndim - 1) return [[sort$1(x)], [xBdim]];
3639
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3640
+ return [[sort$1(x)], [0]];
3641
+ },
3642
+ [Primitive.Argsort](axisSize, [x], [xBdim]) {
3643
+ assertNonNull(xBdim);
3644
+ if (xBdim !== x.ndim - 1) return [argsort$1(x), [xBdim, xBdim]];
3645
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3646
+ return [argsort$1(x), [0, 0]];
3647
+ },
3648
+ [Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
3649
+ if (aBdim === null) {
3650
+ b = moveBatchAxis(axisSize, bBdim, -3, b);
3651
+ const [s, m, n] = b.shape.slice(-3);
3652
+ b = b.reshape([
3653
+ ...b.shape.slice(0, -3),
3654
+ s * m,
3655
+ n
3616
3656
  ]);
3617
- return m;
3618
- });
3619
- if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3620
- else {
3621
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3622
- const newAxis = [0, ...axis.map((ax) => ax + 1)];
3623
- const extraBatchIndex = arange(axisSize).reshape([-1, ...rep(nd - 1, 1)]);
3624
- indices.splice(0, 0, extraBatchIndex);
3625
- return [[gather(x, indices, newAxis, outDim)], [outDim]];
3657
+ let x$1 = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
3658
+ x$1 = x$1.reshape([
3659
+ ...b.shape.slice(0, -2),
3660
+ s,
3661
+ m,
3662
+ n
3663
+ ]);
3664
+ return [[x$1], [x$1.ndim - 3]];
3626
3665
  }
3666
+ a = moveBatchAxis(axisSize, aBdim, 0, a);
3667
+ b = moveBatchAxis(axisSize, bBdim, 0, b);
3668
+ const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
3669
+ return [[x], [0]];
3670
+ },
3671
+ [Primitive.Cholesky](axisSize, [x], [xBdim]) {
3672
+ assertNonNull(xBdim);
3673
+ if (xBdim < x.ndim - 2) return [[cholesky$2(x)], [xBdim]];
3674
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3675
+ return [[cholesky$2(x)], [0]];
3627
3676
  },
3628
- [Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
3629
- const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
3630
- const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
3677
+ [Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
3678
+ const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
3679
+ const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
3631
3680
  name: `${name}_vmap`,
3632
- jaxpr: newJaxpr,
3633
- numConsts: newConsts.length
3681
+ jaxpr: newJaxpr.jaxpr,
3682
+ numConsts: newJaxpr.consts.length
3634
3683
  });
3635
3684
  return [outs, rep(outs.length, 0)];
3636
3685
  }
@@ -3646,14 +3695,10 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
3646
3695
  shape$1.splice(dims[i], 0, axisSize);
3647
3696
  return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
3648
3697
  });
3649
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3650
- const result = {
3651
- newJaxpr,
3652
- newConsts
3653
- };
3698
+ const { jaxpr: newJaxpr } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3654
3699
  if (!vmapJaxprCache.has(jaxpr)) vmapJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
3655
- vmapJaxprCache.get(jaxpr).set(cacheKey, result);
3656
- return result;
3700
+ vmapJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
3701
+ return newJaxpr;
3657
3702
  }
3658
3703
  function vmapFlat(f, inAxes, args) {
3659
3704
  let axisSize = void 0;
@@ -3708,6 +3753,260 @@ function jacfwd$1(f) {
3708
3753
  };
3709
3754
  }
3710
3755
 
3756
+ //#endregion
3757
+ //#region src/frontend/jvp.ts
3758
/**
 * Tracer for forward-mode differentiation: pairs a primal value with the
 * tangent carried alongside it. The abstract value is that of the primal.
 */
var JVPTracer = class extends Tracer {
  constructor(trace$1, primal, tangent) {
    super(trace$1);
    this.primal = primal;
    this.tangent = tangent;
  }

  /** Abstract value, delegated to the primal. */
  get aval() {
    const { primal } = this;
    return primal.aval;
  }

  toString() {
    const primalText = this.primal.toString();
    const tangentText = this.tangent.toString();
    return `JVPTracer(${primalText}, ${tangentText})`;
  }

  /** Touches `.ref` on both components, then returns this tracer. */
  get ref() {
    this.primal.ref;
    this.tangent.ref;
    return this;
  }

  /** Disposes the primal first, then the tangent. */
  dispose() {
    for (const part of [this.primal, this.tangent]) part.dispose();
  }
};
3779
/**
 * Trace that evaluates primitives through their JVP rules, producing
 * `JVPTracer` values.
 */
var JVPTrace = class extends Trace {
  /** Converts `val` with `pureArray` and lifts the result. */
  pure(val) {
    const arr = pureArray(val);
    return this.lift(arr);
  }

  /** Wraps `val` with a zero tangent built from `val.ref`. */
  lift(val) {
    const zeroTangent = zerosLike$1(val.ref);
    return new JVPTracer(this, val, zeroTangent);
  }

  /**
   * Splits the tracers into primals and tangents, runs the JVP rule
   * registered for `primitive`, and wraps each (primal, tangent) output pair.
   * Throws if no rule is registered.
   */
  processPrimitive(primitive, tracers, params) {
    const pairs = tracers.map((tracer) => [tracer.primal, tracer.tangent]);
    const [primalsIn, tangentsIn] = unzip2(pairs);
    const rule = jvpRules[primitive];
    if (rule === void 0) throw new Error(`No JVP rule for: ${primitive}`);
    const [primalsOut, tangentsOut] = rule(primalsIn, tangentsIn, params);
    const wrap = ([primal, tangent]) => new JVPTracer(this, primal, tangent);
    return zip(primalsOut, tangentsOut).map((pair) => wrap(pair));
  }
};
3794
/**
 * Builds a JVP rule for a linear primitive: the primitive is bound once on
 * the primals and once, with the same params, on the tangents.
 */
function linearTangentsJvp(primitive) {
  const applyTo = (values, params) => bind(primitive, values, params);
  // Array elements evaluate left to right: primals are bound before tangents.
  return (primals, tangents, params) => [applyTo(primals, params), applyTo(tangents, params)];
}
3802
/**
 * Builds a JVP rule for a bilinear two-argument primitive using the product
 * rule: tangent = op(x, dy) + op(dx, y).
 */
function bilinearTangentsJvp(primitive) {
  return ([lhs, rhs], [dLhs, dRhs], params) => {
    // `.ref` is taken here because both inputs are used again below.
    const primalOut = bind1(primitive, [lhs.ref, rhs.ref], params);
    const lhsTerm = bind1(primitive, [lhs, dRhs], params);
    const rhsTerm = bind1(primitive, [dLhs, rhs], params);
    return [[primalOut], [lhsTerm.add(rhsTerm)]];
  };
}
3810
/**
 * Builds a JVP rule whose outputs carry zero tangents: incoming tangents are
 * disposed, the primitive is bound on the primals, and each output is paired
 * with `zerosLike$1` of itself.
 */
function zeroTangentsJvp(primitive) {
  return (primals, tangents, params) => {
    tangents.forEach((tangent) => {
      tangent.dispose();
    });
    const outputs = bind(primitive, primals, params);
    const zeroTangents = outputs.map((out) => zerosLike$1(out.ref));
    return [outputs, zeroTangents];
  };
}
3818
/**
 * Compute `a @ b.T`, batched to the last two axes.
 *
 * A size-1 axis is inserted before the last axis of `a` and before the
 * second-to-last axis of `b`; the reshaped operands are then passed to
 * `dot$2`.
 */
function batchMatmulT(a, b) {
  const lhsShape = a.shape.toSpliced(-1, 0, 1);
  const lhs = a.reshape(lhsShape);
  const rhsShape = b.shape.toSpliced(-2, 0, 1);
  const rhs = b.reshape(rhsShape);
  return dot$2(lhs, rhs);
}
3822
/** Batch matrix transpose: moves axis -2 of `a` to position -1. */
function mT(a) {
  const fromAxis = -2;
  const toAxis = -1;
  return moveaxis(a, fromAxis, toAxis);
}
3826
/**
 * Forward-mode (JVP) rules, keyed by primitive.
 *
 * Each rule is called as `rule(primals, tangents, params)` and returns
 * `[primalsOut, tangentsOut]`. The two output lists are zipped together by
 * the caller, so they must have the same length: one tangent per output.
 *
 * Ownership convention used throughout (presumably move semantics — confirm
 * against the array implementation): a value passed to an operation is
 * consumed, and `.ref` is taken whenever the same value is needed again.
 */
const jvpRules = {
  [Primitive.Add]: linearTangentsJvp(Primitive.Add),
  [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
  [Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
  [Primitive.Mod]([x, y], [dx, dy]) {
    if (!isFloatDtype(x.dtype) && !isFloatDtype(y.dtype)) {
      // Neither operand is floating point: the result carries a zero tangent.
      // Mod has a single output, so return mod(x, y) itself together with one
      // zero tangent (returning the two inputs here would make the JVP of an
      // integer `mod` evaluate to `x`).
      dx.dispose();
      dy.dispose();
      const z = mod(x, y);
      return [[z.ref], [zerosLike$1(z)]];
    }
    // x mod y = x - y * q with q = idiv(x, y), treating q as locally constant.
    const q = idiv(x.ref, y.ref);
    return [[mod(x, y)], [dx.sub(dy.mul(q))]];
  },
  [Primitive.Min]([x, y], [dx, dy]) {
    // Select dy where y < x, otherwise dx.
    return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
  },
  [Primitive.Max]([x, y], [dx, dy]) {
    // Select dy where x < y, otherwise dx.
    return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
  },
  [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
  [Primitive.Reciprocal]([x], [dx]) {
    // d(1/x) = -(1/x)^2 * dx
    // NOTE(review): `x` is only used via `.ref` and never consumed here,
    // unlike the other rules — possible reference leak; confirm.
    const xRecip = reciprocal$1(x.ref);
    return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
  },
  [Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
  [Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
  [Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
  [Primitive.Cast]([x], [dx], { dtype }) {
    // Same dtype: nothing to do, pass both values through.
    if (x.dtype === dtype) return [[x], [dx]];
    // Float -> float: cast the tangent the same way as the primal.
    if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
    else {
      // Any cast involving a non-float dtype drops the tangent.
      dx.dispose();
      return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
    }
  },
  [Primitive.Bitcast]([x], [dx], { dtype }) {
    if (x.dtype === dtype) return [[x], [dx]];
    // Reinterpreting bits is not differentiable: drop the tangent.
    dx.dispose();
    return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
  },
  [Primitive.Sin]([x], [dx]) {
    return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
  },
  [Primitive.Cos]([x], [dx]) {
    return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
  },
  [Primitive.Asin]([x], [dx]) {
    // d(asin x) = dx / sqrt(1 - x^2), computed as sqrt(1 / (1 - x^2)) * dx.
    const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
    return [[asin$1(x)], [denom.mul(dx)]];
  },
  [Primitive.Atan]([x], [dx]) {
    // d(atan x) = dx / (1 + x^2)
    const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
    return [[atan$1(x)], [dx.div(denom)]];
  },
  [Primitive.Exp]([x], [dx]) {
    // Reuse the primal output: d(exp x) = exp(x) * dx.
    const z = exp$1(x);
    return [[z.ref], [z.mul(dx)]];
  },
  [Primitive.Log]([x], [dx]) {
    return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
  },
  [Primitive.Erf]([x], [dx]) {
    // d(erf x) = 2/sqrt(pi) * exp(-x^2) * dx
    const coeff = 2 / Math.sqrt(Math.PI);
    const expTerm = exp$1(neg(x.ref.mul(x.ref)));
    return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
  },
  [Primitive.Erfc]([x], [dx]) {
    // d(erfc x) = -2/sqrt(pi) * exp(-x^2) * dx
    const coeff = -2 / Math.sqrt(Math.PI);
    const expTerm = exp$1(neg(x.ref.mul(x.ref)));
    return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
  },
  [Primitive.Sqrt]([x], [dx]) {
    // Reuse the primal output: d(sqrt x) = dx / (2 * sqrt(x)).
    const z = sqrt$1(x);
    return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
  },
  [Primitive.Reduce]([x], [dx], { op, axis }) {
    if (op === AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
    else if (op === AluOp.Mul) {
      // d(prod x) = sum(prod(x) / x_i * dx_i)
      const primal = reduce(x.ref, op, axis);
      const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
      return [[primal], [tangent]];
    } else if (op === AluOp.Min || op === AluOp.Max) {
      // Average the tangents of all elements equal to the reduced value.
      const primal = reduce(x.ref, op, axis);
      const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
      const minCount = where$1(notMin.ref, 0, 1).sum(axis);
      const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
      return [[primal], [tangent]];
    } else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
  },
  [Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
  [Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
  [Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
  [Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
  [Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
  [Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
    // The condition has no tangent; select tangents with the same condition.
    dcond.dispose();
    return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
  },
  [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
  [Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
    // Index tangents are ignored; the same indices gather from x and dx.
    const indicesRef = indices.map((t) => t.ref);
    return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
  },
  [Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
  [Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
  [Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
  [Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
  [Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
  [Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
  [Primitive.Sort]([x], [dx]) {
    // Permute the tangent along the last axis with the sorting indices.
    const [y, idx] = argsort$1(x);
    return [[y], [gather(dx, [idx], [-1], -1)]];
  },
  [Primitive.Argsort]([x], [dx]) {
    // Two outputs: sorted values (permuted tangent) and indices (zero tangent).
    const [y, idx] = argsort$1(x);
    return [[y, idx.ref], [gather(dx, [idx.ref], [-1], -1), zerosLike$1(idx)]];
  },
  [Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
    // Solve for the primal, then solve again with the right-hand side
    // db - mT(da @ x^T) to obtain the tangent.
    const x = triangularSolve$1(a.ref, b, { unitDiagonal });
    const dax = batchMatmulT(da, x.ref);
    const rhsT = db.sub(mT(dax));
    const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
    return [[x], [dx]];
  },
  [Primitive.Cholesky]([a], [da]) {
    // NOTE(review): `a` is only used via `.ref` and never consumed here,
    // unlike the other rules — possible reference leak; confirm.
    const L = cholesky$2(a.ref);
    // Symmetrize the incoming tangent.
    da = da.ref.add(mT(da)).mul(.5);
    const W = triangularSolve$1(L.ref, da, { lower: true });
    const ST = triangularSolve$1(L.ref, mT(W), { lower: true });
    // triu(ST, 1) + triu(ST), halved: upper triangle with the diagonal
    // weighted by one half relative to the strict upper part.
    const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
    return [[L], [dL]];
  },
  [Primitive.Jit](primals, tangents, { name, jaxpr }) {
    // Differentiate the inner jaxpr, then call it with consts, primals and
    // tangents; the outputs are the primal outputs followed by the tangents.
    const newJaxpr = jvpJaxpr(jaxpr);
    const outs = bind(Primitive.Jit, [
      ...newJaxpr.consts.map((c) => c.ref),
      ...primals,
      ...tangents
    ], {
      name: `${name}_jvp`,
      jaxpr: newJaxpr.jaxpr,
      numConsts: newJaxpr.consts.length
    });
    const n = outs.length / 2;
    if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
    const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
    return [primalsOut, tangentsOut];
  }
};
3975
/** Memoized results of `jvpJaxpr`, keyed by the source jaxpr object. */
const jvpJaxprCache = /* @__PURE__ */ new Map();

/**
 * Returns the JVP-transformed version of `jaxpr`, tracing it on first use
 * and serving later requests for the same jaxpr object from the cache.
 *
 * The traced function takes `(primals, tangents)`, both described by the
 * avals of the original input binders.
 */
function jvpJaxpr(jaxpr) {
  if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
  const binderAvals = jaxpr.inBinders.map((binder) => binder.aval);
  const differentiated = (primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents);
  const traced = makeJaxpr$1(differentiated)(binderAvals, binderAvals);
  const transformed = traced.jaxpr;
  jvpJaxprCache.set(jaxpr, transformed);
  return transformed;
}
3983
/**
 * Runs `f` under a fresh JVP trace on flat lists of primals and tangents and
 * returns the outputs split by `unzip2` into primals and tangents.
 *
 * The try/catch/finally shape is the compiled form of a scoped-resource
 * declaration: the main trace is registered with the context helper, any
 * thrown error is recorded on it, and the helper's `d()` always runs last.
 */
function jvpFlat(f, primals, tangents) {
  try {
    var scope = _usingCtx();
    const mainTrace = scope.u(newMain(JVPTrace));
    const jvpTrace = new JVPTrace(mainTrace);
    const inputs = zip(primals, tangents).map(([primal, tangent]) => {
      return new JVPTracer(jvpTrace, pureArray(primal), pureArray(tangent));
    });
    const results = f(...inputs);
    const raised = results.map((result) => fullRaise(jvpTrace, result));
    const pairs = raised.map((tracer) => [tracer.primal, tracer.tangent]);
    return unzip2(pairs);
  } catch (err) {
    scope.e = err;
  } finally {
    scope.d();
  }
}
3998
/**
 * Jacobian-vector product over tree-structured arguments.
 *
 * `primals` and `tangents` are flattened and must share the same tree
 * structure, otherwise a `TreeMismatchError` is thrown. The flat outputs of
 * `jvpFlat` are rebuilt with the output tree recorded while running `f`.
 *
 * Returns `[primalsOut, tangentsOut]`.
 */
function jvp$1(f, primals, tangents) {
  const [primalLeaves, primalTree] = flatten(primals);
  const [tangentLeaves, tangentTree] = flatten(tangents);
  if (!primalTree.equals(tangentTree)) {
    throw new TreeMismatchError("jvp", primalTree, tangentTree);
  }
  const [flatFun, outTree] = flattenFun(f, primalTree);
  const [primalOutLeaves, tangentOutLeaves] = jvpFlat(flatFun, primalLeaves, tangentLeaves);
  if (outTree.value === void 0) {
    throw new Error("outTree was not set in jvp");
  }
  const rebuild = (leaves) => unflatten(outTree.value, leaves);
  return [rebuild(primalOutLeaves), rebuild(tangentOutLeaves)];
}
4009
+
3711
4010
  //#endregion
3712
4011
  //#region src/frontend/linearize.ts
3713
4012
  /** Array value that can either be known or unknown. */
@@ -3738,11 +4037,10 @@ function partialEvalFlat(f, pvalsIn) {
3738
4037
  const tracersOut = outs.map((out) => fullRaise(trace$1, out));
3739
4038
  const pvalsOut = tracersOut.map((t) => t.pval);
3740
4039
  const unknownTracersOut = tracersOut.filter((t) => !t.pval.isKnown);
3741
- const { jaxpr, consts } = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
4040
+ const jaxpr = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
3742
4041
  return {
3743
4042
  jaxpr,
3744
- pvalsOut,
3745
- consts
4043
+ pvalsOut
3746
4044
  };
3747
4045
  }
3748
4046
  /**
@@ -3759,22 +4057,19 @@ function linearizeFlatUtil(f, primalsIn) {
3759
4057
  const [primalsOut$1, tangentsOut] = jvp$1(f, x.slice(0, k), x.slice(k, 2 * k));
3760
4058
  return [...primalsOut$1, ...tangentsOut];
3761
4059
  };
3762
- const { jaxpr, pvalsOut, consts } = partialEvalFlat(fJvp, pvalsIn);
4060
+ const { jaxpr, pvalsOut } = partialEvalFlat(fJvp, pvalsIn);
3763
4061
  const primalPvals = pvalsOut.slice(0, pvalsOut.length / 2);
3764
4062
  if (!primalPvals.every((pval) => pval.isKnown)) throw new Error("Not all primal values are known after partial evaluation");
3765
4063
  const primalsOut = primalPvals.map((pval) => pval.val);
3766
4064
  return {
3767
4065
  primalsOut,
3768
- jaxpr,
3769
- consts
4066
+ jaxpr
3770
4067
  };
3771
4068
  }
3772
4069
  function linearizeFlat(f, primalsIn) {
3773
- const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
3774
- const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
3775
- const dispose$1 = () => {
3776
- for (const c of consts) c.dispose();
3777
- };
4070
+ const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
4071
+ const fLin = (...tangents) => evalJaxpr(jaxpr.jaxpr, [...jaxpr.consts.map((c) => c.ref), ...tangents]);
4072
+ const dispose$1 = () => jaxpr.dispose();
3778
4073
  return [
3779
4074
  primalsOut,
3780
4075
  fLin,
@@ -3858,7 +4153,7 @@ var PartialEvalTrace = class extends Trace {
3858
4153
  }
3859
4154
  processPrimitive(primitive, tracers, params) {
3860
4155
  if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
3861
- if (primitive === Primitive.JitCall) {
4156
+ if (primitive === Primitive.Jit) {
3862
4157
  const { name, jaxpr, numConsts } = params;
3863
4158
  return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
3864
4159
  }
@@ -3884,14 +4179,14 @@ var PartialEvalTrace = class extends Trace {
3884
4179
  * Evaluate a Jaxpr on a set of PartialEvalTracers, computing as many known
3885
4180
  * values as possible (with JIT) and forwarding the unknown ones.
3886
4181
  *
3887
- * Used when encountering a JitCall rule during the trace.
4182
+ * Used when encountering a Jit rule during the trace.
3888
4183
  */
3889
4184
  #partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
3890
4185
  jaxpr = jaxpr.flatten();
3891
4186
  const inUnknowns = tracers.map((t) => !t.pval.isKnown);
3892
4187
  const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
3893
4188
  const [knownTracers, unknownTracers] = partitionList(inUnknowns, tracers);
3894
- const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
4189
+ const outs1Res = bind(Primitive.Jit, knownTracers.map((t) => t.ref.fullLower()), {
3895
4190
  name: `${name}_peval`,
3896
4191
  jaxpr: jaxpr1,
3897
4192
  numConsts: 0
@@ -3901,7 +4196,7 @@ var PartialEvalTrace = class extends Trace {
3901
4196
  const resTracers = res.map((x) => this.instantiateConst(fullRaise(this, x)));
3902
4197
  const recipe = {
3903
4198
  type: "JaxprEqn",
3904
- prim: Primitive.JitCall,
4199
+ prim: Primitive.Jit,
3905
4200
  tracersIn: resTracers.concat(unknownTracers),
3906
4201
  params: {
3907
4202
  name: `${name}_resid`,
@@ -3930,7 +4225,7 @@ function partialEvalJaxpr(jaxpr, inUnknowns, instantiate) {
3930
4225
  const eqns1 = [];
3931
4226
  const eqns2 = [];
3932
4227
  for (const eqn of jaxpr.eqns) {
3933
- if (eqn.primitive === Primitive.JitCall) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
4228
+ if (eqn.primitive === Primitive.Jit) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
3934
4229
  const hasUnknowns = eqn.inputs.some((x) => x instanceof Var && !knownVars.has(x));
3935
4230
  if (hasUnknowns) {
3936
4231
  for (const x of eqn.inputs) if (x instanceof Var && knownVars.has(x)) residuals.add(x);
@@ -4005,10 +4300,7 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
4005
4300
  for (const t of tracersOut) t.dispose();
4006
4301
  jaxpr = jaxpr.simplify();
4007
4302
  if (DEBUG >= 5) console.info("jaxpr from partial evaluation:\n" + jaxpr.toString());
4008
- return {
4009
- jaxpr,
4010
- consts
4011
- };
4303
+ return new ClosedJaxpr(jaxpr, consts);
4012
4304
  }
4013
4305
  /** Marker type for pullback, used by transpose rules. */
4014
4306
  var UndefPrimal = class {
@@ -4200,317 +4492,142 @@ const transposeRules = {
4200
4492
  cond.dispose();
4201
4493
  return cts;
4202
4494
  },
4203
- [Primitive.Transpose]([ct], [x], { perm }) {
4204
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
4205
- return [transpose$1(ct, invertPermutation(perm))];
4206
- },
4207
- [Primitive.Broadcast]([ct], [x], { axis }) {
4208
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Broadcast);
4209
- return [reduce(ct, AluOp.Add, axis)];
4210
- },
4211
- [Primitive.Reshape]([ct], [x], _) {
4212
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Reshape);
4213
- return [reshape$1(ct, x.aval.shape)];
4214
- },
4215
- [Primitive.Flip]([ct], [x], { axis }) {
4216
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Flip);
4217
- return [flip$1(ct, axis)];
4218
- },
4219
- [Primitive.Shrink]([ct], [x], { slice }) {
4220
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Shrink);
4221
- const width = slice.map(([s, e$1], i) => [s, x.aval.shape[i] - e$1]);
4222
- return [pad$1(ct, width)];
4223
- },
4224
- [Primitive.Pad]([ct], [x], { width }) {
4225
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pad);
4226
- const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
4227
- return [shrink(ct, slice)];
4228
- },
4229
4495
  [Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
4230
4496
  if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4231
4497
  if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4232
4498
  throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
4233
4499
  },
4234
- [Primitive.JitCall](cts, args, { name, jaxpr }) {
4235
- const undefPrimals = args.map((x) => x instanceof UndefPrimal);
4236
- const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
4237
- const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
4238
- const outs = bind(Primitive.JitCall, [
4239
- ...newConsts.map((c) => c.ref),
4240
- ...residuals,
4241
- ...cts
4242
- ], {
4243
- name: `${name}_t`,
4244
- jaxpr: newJaxpr,
4245
- numConsts: newConsts.length
4246
- });
4247
- let i = 0;
4248
- return undefPrimals.map((isUndef) => isUndef ? outs[i++] : null);
4249
- }
4250
- };
4251
- const transposeJaxprCache = /* @__PURE__ */ new Map();
4252
- function transposeJaxpr(jaxpr, undefPrimals) {
4253
- const cacheKey = JSON.stringify(undefPrimals);
4254
- const prevResult = transposeJaxprCache.get(jaxpr)?.get(cacheKey);
4255
- if (prevResult) return prevResult;
4256
- const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
4257
- const forwardInTypes = inTypes.filter((_, i) => !undefPrimals[i]);
4258
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((forwardIn, cotangents) => {
4259
- const args = [];
4260
- let forwardInIdx = 0;
4261
- for (let i = 0; i < undefPrimals.length; i++) if (undefPrimals[i]) args.push(new UndefPrimal(inTypes[i]));
4262
- else args.push(forwardIn[forwardInIdx++]);
4263
- return evalJaxprTransposed(jaxpr, args, cotangents);
4264
- })(forwardInTypes, outTypes);
4265
- typecheckJaxpr(newJaxpr);
4266
- const result = {
4267
- newJaxpr,
4268
- newConsts
4269
- };
4270
- if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
4271
- transposeJaxprCache.get(jaxpr).set(cacheKey, result);
4272
- return result;
4273
- }
4274
- function vjpFlat(f, primalsIn) {
4275
- const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
4276
- const fVjp = (...cotangents) => {
4277
- const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
4278
- return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
4279
- };
4280
- const dispose$1 = () => {
4281
- for (const c of consts) c.dispose();
4282
- };
4283
- return [
4284
- primalsOut,
4285
- fVjp,
4286
- dispose$1
4287
- ];
4288
- }
4289
- function vjp$1(f, ...primalsIn) {
4290
- const [primalsInFlat, inTree] = flatten(primalsIn);
4291
- const [fFlat, outTree] = flattenFun(f, inTree);
4292
- const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
4293
- if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
4294
- const primalsOut = unflatten(outTree.value, primalsOutFlat);
4295
- const fVjp = ((cotangentsOut) => {
4296
- const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
4297
- if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
4298
- const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
4299
- return unflatten(inTree, cotangentsInFlat);
4300
- });
4301
- fVjp.dispose = dispose$1;
4302
- return [primalsOut, fVjp];
4303
- }
4304
- function grad$1(f) {
4305
- const valueAndGradFn = valueAndGrad$1(f);
4306
- return (...x) => {
4307
- const [y, dx] = valueAndGradFn(...x);
4308
- y.dispose();
4309
- return dx;
4310
- };
4311
- }
4312
- function valueAndGrad$1(f) {
4313
- return (...x) => {
4314
- if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
4315
- const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
4316
- if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
4317
- if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
4318
- const [ct, ...rest] = fVjp(onesLike$1(y.ref));
4319
- for (const r of rest) dispose(r);
4320
- fVjp.dispose();
4321
- return [y, ct];
4322
- };
4323
- }
4324
- function jacrev$1(f) {
4325
- return function jacobianReverse(x) {
4326
- if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
4327
- const [size$1] = x.shape;
4328
- const pullback = (ct) => {
4329
- const [y, fVjp] = vjp$1(f, x);
4330
- y.dispose();
4331
- const [ret] = fVjp(ct);
4332
- fVjp.dispose();
4333
- return ret;
4334
- };
4335
- return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
4336
- };
4337
- }
4338
-
4339
- //#endregion
4340
- //#region src/library/lax.ts
4341
- var lax_exports = {};
4342
- __export(lax_exports, {
4343
- conv: () => conv,
4344
- convGeneralDilated: () => convGeneralDilated,
4345
- convWithGeneralPadding: () => convWithGeneralPadding,
4346
- dot: () => dot$1,
4347
- erf: () => erf,
4348
- erfc: () => erfc,
4349
- reduceWindow: () => reduceWindow,
4350
- stopGradient: () => stopGradient$1
4351
- });
4352
- /**
4353
- * General dot product/contraction operator.
4354
- *
4355
- * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
4356
- * `jax.numpy.tensordot()`, and `jax.numpy.einsum()` where possible.
4357
- */
4358
- function dot$1(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
4359
- if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
4360
- else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
4361
- lc = lc.map((a) => checkAxis(a, lhs.ndim));
4362
- rc = rc.map((a) => checkAxis(a, rhs.ndim));
4363
- lb = lb.map((a) => checkAxis(a, lhs.ndim));
4364
- rb = rb.map((a) => checkAxis(a, rhs.ndim));
4365
- if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
4366
- else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
4367
- const lf = range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
4368
- const rf = range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
4369
- const lhs2 = lhs.transpose([
4370
- ...lb,
4371
- ...lf,
4372
- ...lc
4373
- ]);
4374
- const rhs2 = rhs.transpose([
4375
- ...rb,
4376
- ...rf,
4377
- ...rc
4378
- ]);
4379
- if (lc.length === 0) return mul(lhs2.reshape([
4380
- ...lb.map((a) => lhs.shape[a]),
4381
- ...lf.map((a) => lhs.shape[a]),
4382
- ...rep(rf.length, 1)
4383
- ]), rhs2.reshape([
4384
- ...rb.map((a) => rhs.shape[a]),
4385
- ...rep(lf.length, 1),
4386
- ...rf.map((a) => rhs.shape[a])
4387
- ]));
4388
- const dotShapeX = lc.map((a) => lhs.shape[a]);
4389
- const dotShapeY = rc.map((a) => rhs.shape[a]);
4390
- if (!deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
4391
- return dot$2(lhs2.reshape([
4392
- ...lb.map((a) => lhs.shape[a]),
4393
- ...lf.map((a) => lhs.shape[a]),
4394
- ...rep(rf.length, 1),
4395
- prod(dotShapeX)
4396
- ]), rhs2.reshape([
4397
- ...rb.map((a) => rhs.shape[a]),
4398
- ...rep(lf.length, 1),
4399
- ...rf.map((a) => rhs.shape[a]),
4400
- prod(dotShapeY)
4401
- ]));
4402
- }
4403
- function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
4404
- const padType = padding.toUpperCase();
4405
- switch (padType) {
4406
- case "VALID": return rep(inShape.length, [0, 0]);
4407
- case "SAME":
4408
- case "SAME_LOWER": {
4409
- const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
4410
- const padSizes = zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
4411
- if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
4412
- else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
4413
- }
4414
- default: throw new Error(`Unknown padding type: ${padType}`);
4415
- }
4416
- }
4417
- /**
4418
- * General n-dimensional convolution operator, with optional dilation.
4419
- *
4420
- * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
4421
- * function in JAX, which wraps XLA's general convolution operator.
4422
- *
4423
- * Grouped convolutions are not supported right now.
4424
- */
4425
- function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
4426
- if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
4427
- if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
4428
- if (typeof padding === "string") {
4429
- if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
4430
- padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? rep(rhs.ndim - 2, 1), padding);
4431
- }
4432
- if (featureGroupCount !== 1) {
4433
- const G = featureGroupCount;
4434
- const [N, C_in, ...xs] = lhs.shape;
4435
- const [C_out, C_in_per_group, ...ks] = rhs.shape;
4436
- if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
4437
- if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
4438
- if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
4439
- const lhsGrouped = moveaxis(lhs.reshape([
4440
- N,
4441
- G,
4442
- C_in / G,
4443
- ...xs
4444
- ]), 1, 0);
4445
- const rhsGrouped = rhs.reshape([
4446
- G,
4447
- C_out / G,
4448
- C_in_per_group,
4449
- ...ks
4450
- ]);
4451
- const result = conv$1(lhsGrouped, rhsGrouped, {
4452
- vmapDims: 1,
4453
- strides: windowStrides,
4454
- padding,
4455
- lhsDilation,
4456
- rhsDilation
4500
+ [Primitive.Transpose]([ct], [x], { perm }) {
4501
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
4502
+ return [transpose$1(ct, invertPermutation(perm))];
4503
+ },
4504
+ [Primitive.Broadcast]([ct], [x], { axis }) {
4505
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Broadcast);
4506
+ return [reduce(ct, AluOp.Add, axis)];
4507
+ },
4508
+ [Primitive.Reshape]([ct], [x], _) {
4509
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Reshape);
4510
+ return [reshape$1(ct, x.aval.shape)];
4511
+ },
4512
+ [Primitive.Flip]([ct], [x], { axis }) {
4513
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Flip);
4514
+ return [flip$1(ct, axis)];
4515
+ },
4516
+ [Primitive.Shrink]([ct], [x], { slice }) {
4517
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Shrink);
4518
+ const width = slice.map(([s, e$1], i) => [s, x.aval.shape[i] - e$1]);
4519
+ return [pad$1(ct, width)];
4520
+ },
4521
+ [Primitive.Pad]([ct], [x], { width }) {
4522
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pad);
4523
+ const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
4524
+ return [shrink(ct, slice)];
4525
+ },
4526
+ [Primitive.TriangularSolve]([ct], [a, b], { unitDiagonal }) {
4527
+ if (a instanceof UndefPrimal || !(b instanceof UndefPrimal)) throw new NonlinearError(Primitive.TriangularSolve);
4528
+ const ctB = triangularSolve$1(moveaxis(a, -2, -1), ct, {
4529
+ lower: true,
4530
+ unitDiagonal
4457
4531
  });
4458
- const ys = result.shape.slice(3);
4459
- return moveaxis(result, 0, 1).reshape([
4460
- N,
4461
- C_out,
4462
- ...ys
4463
- ]);
4532
+ return [null, ctB];
4533
+ },
4534
+ [Primitive.Jit](cts, args, { name, jaxpr }) {
4535
+ const undefPrimals = args.map((x) => x instanceof UndefPrimal);
4536
+ const newJaxpr = transposeJaxpr(jaxpr, undefPrimals);
4537
+ const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
4538
+ const outs = bind(Primitive.Jit, [
4539
+ ...newJaxpr.consts.map((c) => c.ref),
4540
+ ...residuals,
4541
+ ...cts
4542
+ ], {
4543
+ name: `${name}_t`,
4544
+ jaxpr: newJaxpr.jaxpr,
4545
+ numConsts: newJaxpr.consts.length
4546
+ });
4547
+ let i = 0;
4548
+ return undefPrimals.map((isUndef) => isUndef ? outs[i++] : null);
4464
4549
  }
4465
- return conv$1(lhs, rhs, {
4466
- strides: windowStrides,
4467
- padding,
4468
- lhsDilation,
4469
- rhsDilation
4470
- });
4471
- }
4472
- /** Convenience wrapper around `convGeneralDilated`. */
4473
- function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
4474
- return convGeneralDilated(lhs, rhs, windowStrides, padding, {
4475
- lhsDilation,
4476
- rhsDilation
4477
- });
4550
+ };
4551
+ const transposeJaxprCache = /* @__PURE__ */ new Map();
4552
+ function transposeJaxpr(jaxpr, undefPrimals) {
4553
+ const cacheKey = JSON.stringify(undefPrimals);
4554
+ const prevResult = transposeJaxprCache.get(jaxpr)?.get(cacheKey);
4555
+ if (prevResult) return prevResult;
4556
+ const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
4557
+ const forwardInTypes = inTypes.filter((_, i) => !undefPrimals[i]);
4558
+ const { jaxpr: newJaxpr } = makeJaxpr$1((forwardIn, cotangents) => {
4559
+ const args = [];
4560
+ let forwardInIdx = 0;
4561
+ for (let i = 0; i < undefPrimals.length; i++) if (undefPrimals[i]) args.push(new UndefPrimal(inTypes[i]));
4562
+ else args.push(forwardIn[forwardInIdx++]);
4563
+ return evalJaxprTransposed(jaxpr, args, cotangents);
4564
+ })(forwardInTypes, outTypes);
4565
+ typecheckJaxpr(newJaxpr.jaxpr);
4566
+ if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
4567
+ transposeJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
4568
+ return newJaxpr;
4478
4569
  }
4479
- /** Convenience wrapper around `convGeneralDilated`. */
4480
- function conv(lhs, rhs, windowStrides, padding) {
4481
- return convGeneralDilated(lhs, rhs, windowStrides, padding);
4570
+ function vjpFlat(f, primalsIn) {
4571
+ const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
4572
+ const fVjp = (...cotangents) => {
4573
+ const transposeInputs = [...jaxpr.consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
4574
+ return evalJaxprTransposed(jaxpr.jaxpr, transposeInputs, cotangents);
4575
+ };
4576
+ const dispose$1 = () => jaxpr.dispose();
4577
+ return [
4578
+ primalsOut,
4579
+ fVjp,
4580
+ dispose$1
4581
+ ];
4482
4582
  }
4483
- /** Reduce a computation over padded windows. */
4484
- function reduceWindow(operand, computation, windowDimensions, windowStrides) {
4485
- if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
4486
- if (!windowStrides) windowStrides = rep(windowDimensions.length, 1);
4487
- for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
4488
- return computation(bind1(Primitive.Pool, [operand], {
4489
- window: windowDimensions,
4490
- strides: windowStrides
4491
- }));
4583
+ function vjp$1(f, ...primalsIn) {
4584
+ const [primalsInFlat, inTree] = flatten(primalsIn);
4585
+ const [fFlat, outTree] = flattenFun(f, inTree);
4586
+ const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
4587
+ if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
4588
+ const primalsOut = unflatten(outTree.value, primalsOutFlat);
4589
+ const fVjp = ((cotangentsOut) => {
4590
+ const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
4591
+ if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
4592
+ const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
4593
+ return unflatten(inTree, cotangentsInFlat);
4594
+ });
4595
+ fVjp.dispose = dispose$1;
4596
+ return [primalsOut, fVjp];
4492
4597
  }
4493
- /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
4494
- function erf(x) {
4495
- return erf$1(x);
4598
+ function grad$1(f) {
4599
+ const valueAndGradFn = valueAndGrad$1(f);
4600
+ return (...x) => {
4601
+ const [y, dx] = valueAndGradFn(...x);
4602
+ y.dispose();
4603
+ return dx;
4604
+ };
4496
4605
  }
4497
- /**
4498
- * The complementary error function: `erfc(x) = 1 - erf(x)`.
4499
- *
4500
- * This function is more accurate than `1 - erf(x)` for large values of `x`,
4501
- * where `erf(x)` is very close to 1.
4502
- */
4503
- function erfc(x) {
4504
- return erfc$1(x);
4606
+ function valueAndGrad$1(f) {
4607
+ return (...x) => {
4608
+ if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
4609
+ const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
4610
+ if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
4611
+ if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
4612
+ const [ct, ...rest] = fVjp(onesLike$1(y.ref));
4613
+ for (const r of rest) dispose(r);
4614
+ fVjp.dispose();
4615
+ return [y, ct];
4616
+ };
4505
4617
  }
4506
- /**
4507
- * Stops gradient computation.
4508
- *
4509
- * Behaves as the identity function but prevents the flow of gradients during
4510
- * forward or reverse-mode automatic differentiation.
4511
- */
4512
- function stopGradient$1(x) {
4513
- return stopGradient(x);
4618
+ function jacrev$1(f) {
4619
+ return function jacobianReverse(x) {
4620
+ if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
4621
+ const [size$1] = x.shape;
4622
+ const pullback = (ct) => {
4623
+ const [y, fVjp] = vjp$1(f, x);
4624
+ y.dispose();
4625
+ const [ret] = fVjp(ct);
4626
+ fVjp.dispose();
4627
+ return ret;
4628
+ };
4629
+ return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
4630
+ };
4514
4631
  }
4515
4632
 
4516
4633
  //#endregion
@@ -4708,34 +4825,207 @@ function* allPaths(tensors, next) {
4708
4825
  }
4709
4826
  }
4710
4827
 
4828
+ //#endregion
4829
+ //#region src/library/numpy-fft.ts
4830
+ var numpy_fft_exports = {};
4831
+ __export(numpy_fft_exports, {
4832
+ fft: () => fft,
4833
+ ifft: () => ifft
4834
+ });
4835
+ function checkPairInput(name, a) {
4836
+ const fullName = `jax.numpy.fft.${name}`;
4837
+ if (!deepEqual(a.real.shape, a.imag.shape)) throw new Error(`${fullName}: real and imaginary parts must have the same shape, got ${JSON.stringify(a.real.shape)} and ${JSON.stringify(a.imag.shape)}`);
4838
+ if (a.real.dtype !== a.imag.dtype) throw new Error(`${fullName}: real and imaginary parts must have the same dtype, got ${a.real.dtype} and ${a.imag.dtype}`);
4839
+ if (!isFloatDtype(a.real.dtype)) throw new Error(`${fullName}: input must have a float dtype, got ${a.real.dtype}`);
4840
+ }
4841
+ function checkPowerOfTwo(name, n) {
4842
+ if ((n & n - 1) !== 0) throw new Error(`jax.numpy.fft.${name}: size must be a power of two, got ${n}`);
4843
+ }
4844
+ const fftUpdate = jit$1(function fftUpdate$1(i, { real, imag }) {
4845
+ const half = 2 ** i;
4846
+ real = real.reshape([-1, 2 * half]);
4847
+ imag = imag.reshape([-1, 2 * half]);
4848
+ const k = arange(0, half, 1, { dtype: real.dtype });
4849
+ const theta = k.mul(-Math.PI / half);
4850
+ const wr = cos(theta.ref);
4851
+ const wi = sin(theta);
4852
+ const ur = real.ref.slice([], [0, half]);
4853
+ const ui = imag.ref.slice([], [0, half]);
4854
+ const vr = real.slice([], [half, 2 * half]);
4855
+ const vi = imag.slice([], [half, 2 * half]);
4856
+ const tr = vr.ref.mul(wr.ref).sub(vi.ref.mul(wi.ref));
4857
+ const ti = vr.mul(wi).add(vi.mul(wr));
4858
+ return {
4859
+ real: concatenate([ur.ref.add(tr.ref), ur.sub(tr)], -1),
4860
+ imag: concatenate([ui.ref.add(ti.ref), ui.sub(ti)], -1)
4861
+ };
4862
+ }, { staticArgnums: [0] });
4863
+ /**
4864
+ * Compute a one-dimensional discrete Fourier transform.
4865
+ *
4866
+ * Currently, the size of the axis must be a power of two.
4867
+ */
4868
+ function fft(a, axis = -1) {
4869
+ checkPairInput("fft", a);
4870
+ let { real, imag } = a;
4871
+ axis = checkAxis(axis, real.ndim);
4872
+ const n = real.shape[axis];
4873
+ checkPowerOfTwo("fft", n);
4874
+ const logN = Math.log2(n);
4875
+ let perm = null;
4876
+ if (axis !== real.ndim - 1) {
4877
+ perm = range(real.ndim);
4878
+ perm.splice(axis, 1);
4879
+ perm.push(axis);
4880
+ real = real.transpose(perm);
4881
+ imag = imag.transpose(perm);
4882
+ }
4883
+ const originalShape = real.shape;
4884
+ real = real.reshape([-1, ...rep(logN, 2)]).transpose([0, ...range(1, logN + 1).reverse()]).flatten();
4885
+ imag = imag.reshape([-1, ...rep(logN, 2)]).transpose([0, ...range(1, logN + 1).reverse()]).flatten();
4886
+ for (let i = 0; i < logN; i++) ({real, imag} = fftUpdate(i, {
4887
+ real,
4888
+ imag
4889
+ }));
4890
+ real = real.reshape(originalShape);
4891
+ imag = imag.reshape(originalShape);
4892
+ if (perm !== null) {
4893
+ real = real.transpose(invertPermutation(perm));
4894
+ imag = imag.transpose(invertPermutation(perm));
4895
+ }
4896
+ return {
4897
+ real,
4898
+ imag
4899
+ };
4900
+ }
4901
+ /**
4902
+ * Compute a one-dimensional inverse discrete Fourier transform.
4903
+ *
4904
+ * Currently, the size of the axis must be a power of two.
4905
+ */
4906
+ function ifft(a, axis = -1) {
4907
+ checkPairInput("ifft", a);
4908
+ let { real, imag } = a;
4909
+ axis = checkAxis(axis, real.ndim);
4910
+ const n = real.shape[axis];
4911
+ checkPowerOfTwo("ifft", n);
4912
+ imag = imag.mul(-1);
4913
+ const result = fft({
4914
+ real,
4915
+ imag
4916
+ }, axis);
4917
+ return {
4918
+ real: result.real.div(n),
4919
+ imag: result.imag.mul(-1).div(n)
4920
+ };
4921
+ }
4922
+
4923
+ //#endregion
4924
+ //#region src/library/numpy-linalg.ts
4925
+ var numpy_linalg_exports = {};
4926
+ __export(numpy_linalg_exports, {
4927
+ cholesky: () => cholesky$1,
4928
+ diagonal: () => diagonal,
4929
+ lstsq: () => lstsq,
4930
+ matmul: () => matmul,
4931
+ matrixTranspose: () => matrixTranspose,
4932
+ outer: () => outer,
4933
+ tensordot: () => tensordot,
4934
+ trace: () => trace,
4935
+ vecdot: () => vecdot
4936
+ });
4937
+ /**
4938
+ * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
4939
+ *
4940
+ * This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
4941
+ * the input matrix, which is on by default.
4942
+ */
4943
+ function cholesky$1(a, { upper = false, symmetrizeInput = true } = {}) {
4944
+ a = fudgeArray(a);
4945
+ if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`cholesky: input must be at least 2D square matrix, got ${a.aval}`);
4946
+ if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
4947
+ return cholesky(a, { upper });
4948
+ }
4949
+ /**
4950
+ * Return the least-squares solution to a linear equation.
4951
+ *
4952
+ * For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
4953
+ * For underdetermined systems, this finds the minimum-norm solution for `x`.
4954
+ *
4955
+ * This currently uses Cholesky decomposition to solve the normal equations,
4956
+ * under the hood. The method is not as robust as QR or SVD.
4957
+ *
4958
+ * @param a coefficient matrix of shape `(M, N)`
4959
+ * @param b right-hand side of shape `(M,)` or `(M, K)`
4960
+ * @return least-squares solution of shape `(N,)` or `(N, K)`
4961
+ */
4962
+ function lstsq(a, b) {
4963
+ a = fudgeArray(a);
4964
+ b = fudgeArray(b);
4965
+ if (a.ndim !== 2) throw new Error(`lstsq: 'a' must be a 2D array, got ${a.aval}`);
4966
+ const [m, n] = a.shape;
4967
+ if (b.shape[0] !== m) throw new Error(`lstsq: leading dimension of 'b' must match number of rows of 'a', got ${b.aval}`);
4968
+ const at = matrixTranspose(a.ref);
4969
+ if (m <= n) {
4970
+ const aat = matmul(a, at.ref);
4971
+ const l = cholesky$1(aat, { symmetrizeInput: false });
4972
+ const lb = triangularSolve(l.ref, b, {
4973
+ leftSide: true,
4974
+ lower: true
4975
+ });
4976
+ const llb = triangularSolve(l, lb, {
4977
+ leftSide: true,
4978
+ transposeA: true
4979
+ });
4980
+ return matmul(at, llb.ref);
4981
+ } else {
4982
+ const ata = matmul(at.ref, a);
4983
+ const l = cholesky$1(ata, { symmetrizeInput: false });
4984
+ const atb = matmul(at, b);
4985
+ const lb = triangularSolve(l.ref, atb, {
4986
+ leftSide: true,
4987
+ lower: true
4988
+ });
4989
+ const llb = triangularSolve(l, lb, {
4990
+ leftSide: true,
4991
+ transposeA: true
4992
+ });
4993
+ return llb;
4994
+ }
4995
+ }
4996
+
4711
4997
  //#endregion
4712
4998
  //#region src/library/numpy.ts
4713
4999
  var numpy_exports = {};
4714
5000
  __export(numpy_exports, {
4715
5001
  Array: () => Array$1,
4716
5002
  DType: () => DType,
4717
- abs: () => abs,
5003
+ abs: () => absolute,
4718
5004
  absolute: () => absolute,
4719
5005
  acos: () => acos,
4720
- acosh: () => acosh,
5006
+ acosh: () => arccosh,
4721
5007
  add: () => add,
5008
+ all: () => all,
4722
5009
  allclose: () => allclose,
5010
+ any: () => any,
4723
5011
  arange: () => arange,
4724
- arccos: () => arccos,
5012
+ arccos: () => acos,
4725
5013
  arccosh: () => arccosh,
5014
+ arcsin: () => asin,
4726
5015
  arcsinh: () => arcsinh,
4727
- arctan: () => arctan,
4728
- arctan2: () => arctan2,
5016
+ arctan: () => atan,
5017
+ arctan2: () => atan2,
4729
5018
  arctanh: () => arctanh,
4730
5019
  argmax: () => argmax,
4731
5020
  argmin: () => argmin,
5021
+ argsort: () => argsort,
4732
5022
  array: () => array,
4733
5023
  asin: () => asin,
4734
- asinh: () => asinh,
5024
+ asinh: () => arcsinh,
4735
5025
  astype: () => astype,
4736
5026
  atan: () => atan,
4737
5027
  atan2: () => atan2,
4738
- atanh: () => atanh,
5028
+ atanh: () => arctanh,
4739
5029
  bool: () => bool,
4740
5030
  broadcastArrays: () => broadcastArrays,
4741
5031
  broadcastShapes: () => broadcastShapes,
@@ -4745,16 +5035,20 @@ __export(numpy_exports, {
4745
5035
  clip: () => clip,
4746
5036
  columnStack: () => columnStack,
4747
5037
  concatenate: () => concatenate,
5038
+ convolve: () => convolve,
5039
+ corrcoef: () => corrcoef,
5040
+ correlate: () => correlate,
4748
5041
  cos: () => cos,
4749
5042
  cosh: () => cosh,
5043
+ cov: () => cov,
4750
5044
  cumsum: () => cumsum,
4751
- cumulativeSum: () => cumulativeSum,
5045
+ cumulativeSum: () => cumsum,
4752
5046
  deg2rad: () => deg2rad,
4753
5047
  degrees: () => degrees,
4754
5048
  diag: () => diag,
4755
5049
  diagonal: () => diagonal,
4756
- divide: () => divide,
4757
- dot: () => dot,
5050
+ divide: () => trueDivide,
5051
+ dot: () => dot$1,
4758
5052
  dstack: () => dstack,
4759
5053
  e: () => e,
4760
5054
  einsum: () => einsum,
@@ -4762,8 +5056,10 @@ __export(numpy_exports, {
4762
5056
  eulerGamma: () => eulerGamma,
4763
5057
  exp: () => exp,
4764
5058
  exp2: () => exp2,
5059
+ expandDims: () => expandDims,
4765
5060
  expm1: () => expm1,
4766
5061
  eye: () => eye,
5062
+ fft: () => numpy_fft_exports,
4767
5063
  flip: () => flip,
4768
5064
  fliplr: () => fliplr,
4769
5065
  flipud: () => flipud,
@@ -4794,12 +5090,14 @@ __export(numpy_exports, {
4794
5090
  ldexp: () => ldexp,
4795
5091
  less: () => less,
4796
5092
  lessEqual: () => lessEqual,
5093
+ linalg: () => numpy_linalg_exports,
4797
5094
  linspace: () => linspace,
4798
5095
  log: () => log,
4799
5096
  log10: () => log10,
4800
5097
  log1p: () => log1p,
4801
5098
  log2: () => log2,
4802
5099
  matmul: () => matmul,
5100
+ matrixTranspose: () => matrixTranspose,
4803
5101
  max: () => max,
4804
5102
  maximum: () => maximum,
4805
5103
  mean: () => mean,
@@ -4816,10 +5114,10 @@ __export(numpy_exports, {
4816
5114
  onesLike: () => onesLike,
4817
5115
  outer: () => outer,
4818
5116
  pad: () => pad,
4819
- permuteDims: () => permuteDims,
5117
+ permuteDims: () => transpose,
4820
5118
  pi: () => pi,
4821
5119
  positive: () => positive,
4822
- pow: () => pow,
5120
+ pow: () => power,
4823
5121
  power: () => power,
4824
5122
  prod: () => prod$1,
4825
5123
  promoteTypes: () => promoteTypes,
@@ -4836,6 +5134,7 @@ __export(numpy_exports, {
4836
5134
  sin: () => sin,
4837
5135
  sinh: () => sinh,
4838
5136
  size: () => size,
5137
+ sort: () => sort,
4839
5138
  sqrt: () => sqrt,
4840
5139
  square: () => square,
4841
5140
  squeeze: () => squeeze,
@@ -5000,6 +5299,26 @@ function min(a, axis = null, opts) {
5000
5299
  function max(a, axis = null, opts) {
5001
5300
  return reduce(a, AluOp.Max, axis, opts);
5002
5301
  }
5302
+ /**
5303
+ * Test whether all array elements along a given axis evaluate to True.
5304
+ *
5305
+ * Returns a boolean array with the same shape as `a` with the specified axis
5306
+ * removed. If axis is `null`, returns a scalar.
5307
+ */
5308
+ function all(a, axis = null, opts) {
5309
+ a = fudgeArray(a).astype(DType.Bool);
5310
+ return min(a, axis, opts);
5311
+ }
5312
+ /**
5313
+ * Test whether any array element along a given axis evaluates to True.
5314
+ *
5315
+ * Returns a boolean array with the same shape as `a` with the specified axis
5316
+ * removed. If axis is `null`, returns a scalar.
5317
+ */
5318
+ function any(a, axis = null, opts) {
5319
+ a = fudgeArray(a).astype(DType.Bool);
5320
+ return max(a, axis, opts);
5321
+ }
5003
5322
  /** Return the peak-to-peak range along a given axis (`max - min`). */
5004
5323
  function ptp(a, axis = null, opts) {
5005
5324
  a = fudgeArray(a);
@@ -5074,8 +5393,6 @@ function cumsum(a, axis) {
5074
5393
  a = broadcast(a, a.shape.concat(n), [-2]);
5075
5394
  return moveaxis$1(tril(a).sum(-1), -1, axis);
5076
5395
  }
5077
- /** @function Alternative name for `jax.numpy.cumsum()`. */
5078
- const cumulativeSum = cumsum;
5079
5396
  /** Reverse the elements in an array along the given axes. */
5080
5397
  function flip(x, axis = null) {
5081
5398
  const nd = ndim(x);
@@ -5185,8 +5502,11 @@ function flipud(x) {
5185
5502
  function fliplr(x) {
5186
5503
  return flip(x, 1);
5187
5504
  }
5188
- /** @function Alternative name for `numpy.transpose()`. */
5189
- const permuteDims = transpose;
5505
+ /** Transpose the last two dimensions of an array. */
5506
+ function matrixTranspose(a) {
5507
+ if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
5508
+ return moveaxis$1(a, -1, -2);
5509
+ }
5190
5510
  /** Return a 1-D flattened array containing the elements of the input. */
5191
5511
  function ravel(a) {
5192
5512
  return fudgeArray(a).ravel();
@@ -5202,6 +5522,32 @@ function squeeze(a, axis = null) {
5202
5522
  return reshape(a, newShape);
5203
5523
  }
5204
5524
  /**
5525
+ * Expand the shape of an array by inserting new axes of length 1.
5526
+ *
5527
+ * @param a - Input array.
5528
+ * @param axis - Position(s) in the expanded axes where the new axis (or axes)
5529
+ * is placed. Can be a single integer or an array of integers.
5530
+ * @returns Array with the number of dimensions increased.
5531
+ *
5532
+ * @example
5533
+ * ```ts
5534
+ * const x = np.array([1, 2]);
5535
+ * np.expandDims(x, 0); // Shape [1, 2]
5536
+ * np.expandDims(x, 1); // Shape [2, 1]
5537
+ * np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
5538
+ * ```
5539
+ */
5540
+ function expandDims(a, axis) {
5541
+ const as = shape(a);
5542
+ axis = typeof axis === "number" ? [axis] : axis;
5543
+ axis = normalizeAxis(axis, as.length + axis.length);
5544
+ const newShape = [];
5545
+ let srcIdx = 0;
5546
+ for (let i = 0; i < as.length + axis.length; i++) if (axis.includes(i)) newShape.push(1);
5547
+ else newShape.push(as[srcIdx++]);
5548
+ return reshape(a, newShape);
5549
+ }
5550
+ /**
5205
5551
  * Repeat each element of an array after themselves.
5206
5552
  *
5207
5553
  * If no axis is provided, use the flattened input array, and return a flat
@@ -5289,7 +5635,7 @@ function diagonal(a, offset, axis1, axis2) {
5289
5635
  */
5290
5636
  function diag(v, k = 0) {
5291
5637
  const a = fudgeArray(v);
5292
- if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
5638
+ if (!Number.isInteger(k)) throw new Error(`k must be an integer, got ${k}`);
5293
5639
  if (a.ndim === 1) {
5294
5640
  const n = a.shape[0];
5295
5641
  const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
@@ -5297,12 +5643,32 @@ function diag(v, k = 0) {
5297
5643
  else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
5298
5644
  else return ret;
5299
5645
  } else if (a.ndim === 2) return diagonal(a, k);
5300
- else throw new TypeError("numpy.diag only supports 1D and 2D arrays");
5646
+ else throw new Error("numpy.diag only supports 1D and 2D arrays");
5301
5647
  }
5302
5648
  /** Calculate the sum of the diagonal of an array along the given axes. */
5303
5649
  function trace(a, offset = 0, axis1 = 0, axis2 = 1) {
5304
5650
  return diagonal(a, offset, axis1, axis2).sum(-1);
5305
5651
  }
5652
+ /**
5653
+ * Return a sorted copy of an array.
5654
+ *
5655
+ * The array is sorted along a specified axis (the last by default). This may be
5656
+ * an unstable sort, and it dispatches to a device-specific implementation.
5657
+ */
5658
+ function sort(a, axis = -1) {
5659
+ return fudgeArray(a).sort(axis);
5660
+ }
5661
+ /**
5662
+ * Return indices that would sort an array. This may be an unstable sorting
5663
+ * algorithm; it need not preserve order of indices in ties.
5664
+ *
5665
+ * Returns an array of `int32` indices.
5666
+ *
5667
+ * The array is sorted along a specified axis (the last by default).
5668
+ */
5669
+ function argsort(a, axis = -1) {
5670
+ return fudgeArray(a).argsort(axis);
5671
+ }
5306
5672
  /** Return if two arrays are element-wise equal within a tolerance. */
5307
5673
  function allclose(actual, expected, options) {
5308
5674
  const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
@@ -5319,11 +5685,11 @@ function allclose(actual, expected, options) {
5319
5685
  }
5320
5686
  /** Matrix product of two arrays. */
5321
5687
  function matmul(x, y) {
5322
- if (ndim(x) === 0 || ndim(y) === 0) throw new TypeError("matmul: x and y must be at least 1D");
5688
+ if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
5323
5689
  x = x, y = y;
5324
5690
  if (y.ndim === 1) return dot$2(x, y);
5325
5691
  const numBatchDims = Math.min(Math.max(x.ndim, 2), y.ndim) - 2;
5326
- return dot$1(x, y, {
5692
+ return dot(x, y, {
5327
5693
  lhsContractingDims: [-1],
5328
5694
  rhsContractingDims: [-2],
5329
5695
  lhsBatchDims: range(-2 - numBatchDims, -2),
@@ -5331,11 +5697,11 @@ function matmul(x, y) {
5331
5697
  });
5332
5698
  }
5333
5699
  /** Dot product of two arrays. */
5334
- function dot(x, y) {
5700
+ function dot$1(x, y) {
5335
5701
  if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
5336
5702
  x = x, y = y;
5337
5703
  if (y.ndim === 1) return dot$2(x, y);
5338
- return dot$1(x, y, {
5704
+ return dot(x, y, {
5339
5705
  lhsContractingDims: [-1],
5340
5706
  rhsContractingDims: [-2]
5341
5707
  });
@@ -5351,7 +5717,7 @@ function tensordot(x, y, axes = 2) {
5351
5717
  x = fudgeArray(x);
5352
5718
  y = fudgeArray(y);
5353
5719
  if (typeof axes === "number") axes = [range(-axes, 0), range(axes)];
5354
- return dot$1(x, y, {
5720
+ return dot(x, y, {
5355
5721
  lhsContractingDims: axes[0],
5356
5722
  rhsContractingDims: axes[1]
5357
5723
  });
@@ -5444,7 +5810,7 @@ function einsum(...args) {
5444
5810
  const [b, bidx] = processSingleTensor(operands[j], indices[j], indices[i]);
5445
5811
  indexReduced = indexReduced.filter((idx) => aidx.includes(idx));
5446
5812
  const indexBatch = aidx.filter((idx) => bidx.includes(idx) && !indexReduced.includes(idx));
5447
- const result = dot$1(a, b, {
5813
+ const result = dot(a, b, {
5448
5814
  lhsContractingDims: indexReduced.map((idx) => aidx.indexOf(idx)),
5449
5815
  rhsContractingDims: indexReduced.map((idx) => bidx.indexOf(idx)),
5450
5816
  lhsBatchDims: indexBatch.map((idx) => aidx.indexOf(idx)),
@@ -5472,7 +5838,7 @@ function einsum(...args) {
5472
5838
  * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
5473
5839
  */
5474
5840
  function inner(x, y) {
5475
- return dot$1(fudgeArray(x), fudgeArray(y), {
5841
+ return dot(fudgeArray(x), fudgeArray(y), {
5476
5842
  lhsContractingDims: [-1],
5477
5843
  rhsContractingDims: [-1]
5478
5844
  });
@@ -5505,6 +5871,30 @@ function vecdot(x, y, { axis } = {}) {
5505
5871
  function vdot(x, y) {
5506
5872
  return dot$2(ravel(x), ravel(y));
5507
5873
  }
5874
+ function _convImpl(name, x, y, mode) {
5875
+ if (x.ndim !== 1 || y.ndim !== 1) throw new Error(`${name}: both inputs must be 1D arrays, got ${x.ndim}D and ${y.ndim}D`);
5876
+ let flipOutput = false;
5877
+ if (x.shape[0] < y.shape[0]) {
5878
+ [x, y] = [y, x];
5879
+ if (name === "correlate") flipOutput = true;
5880
+ }
5881
+ if (name === "convolve") y = flip(y);
5882
+ let padding;
5883
+ if (mode === "valid") padding = "VALID";
5884
+ else if (mode === "same") padding = "SAME_LOWER";
5885
+ else if (mode === "full") padding = [[y.shape[0] - 1, y.shape[0] - 1]];
5886
+ else throw new Error(`${name}: invalid mode ${mode}, expected "full", "same", or "valid"`);
5887
+ const z = conv(x.slice(null, null), y.slice(null, null), [1], padding).slice(0, 0);
5888
+ return flipOutput ? flip(z) : z;
5889
+ }
5890
+ /** Convolution of two one-dimensional arrays. */
5891
+ function convolve(x, y, mode = "full") {
5892
+ return _convImpl("convolve", x, y, mode);
5893
+ }
5894
+ /** Correlation of two one-dimensional arrays. */
5895
+ function correlate(x, y, mode = "valid") {
5896
+ return _convImpl("correlate", x, y, mode);
5897
+ }
5508
5898
  /**
5509
5899
  * Return a tuple of coordinate matrices from coordinate vectors.
5510
5900
  *
@@ -5513,7 +5903,7 @@ function vdot(x, y) {
5513
5903
  */
5514
5904
  function meshgrid(xs, { indexing } = {}) {
5515
5905
  indexing ??= "xy";
5516
- for (const x of xs) if (x.ndim !== 1) throw new TypeError(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
5906
+ for (const x of xs) if (x.ndim !== 1) throw new Error(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
5517
5907
  if (xs.length <= 1) return xs;
5518
5908
  if (indexing === "xy") {
5519
5909
  const [a, b, ...rest] = xs;
@@ -5529,44 +5919,7 @@ function meshgrid(xs, { indexing } = {}) {
5529
5919
  ];
5530
5920
  }
5531
5921
  const shape$1 = xs.map((x) => x.shape[0]);
5532
- return xs.map((x, i) => broadcast(x, shape$1, [...range(i), ...range(i + 1, xs.length)]));
5533
- }
5534
- /**
5535
- * Return an array with ones on and below the diagonal and zeros elsewhere.
5536
- *
5537
- * If `k` is provided, it specifies the sub-diagonal on and below which the
5538
- * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
5539
- * `k>0` is above it.
5540
- */
5541
- function tri(n, m, k = 0, { dtype, device } = {}) {
5542
- m ??= n;
5543
- dtype ??= DType.Float32;
5544
- if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
5545
- if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
5546
- if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
5547
- const rows = arange(k, n + k, 1, {
5548
- dtype: DType.Int32,
5549
- device
5550
- });
5551
- const cols = arange(0, m, 1, {
5552
- dtype: DType.Int32,
5553
- device
5554
- });
5555
- return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
5556
- }
5557
- /** Return the lower triangle of an array. Must be of dimension >= 2. */
5558
- function tril(a, k = 0) {
5559
- if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
5560
- a = fudgeArray(a);
5561
- const [n, m] = a.shape.slice(-2);
5562
- return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
5563
- }
5564
- /** Return the upper triangle of an array. Must be of dimension >= 2. */
5565
- function triu(a, k = 0) {
5566
- if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
5567
- a = fudgeArray(a);
5568
- const [n, m] = a.shape.slice(-2);
5569
- return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
5922
+ return xs.map((x, i) => broadcast(x, shape$1, [...range(i), ...range(i + 1, xs.length)]));
5570
5923
  }
5571
5924
  /**
5572
5925
  * Clip (limit) the values in an array.
@@ -5592,8 +5945,6 @@ function absolute(x) {
5592
5945
  x = fudgeArray(x);
5593
5946
  return where(less(x.ref, 0), x.ref.mul(-1), x);
5594
5947
  }
5595
- /** @function Alias of `jax.numpy.absolute()`. */
5596
- const abs = absolute;
5597
5948
  /** Return an element-wise indication of sign of the input. */
5598
5949
  function sign(x) {
5599
5950
  x = fudgeArray(x);
@@ -5672,12 +6023,6 @@ const atan2 = jit$1(function atan2$1(y, x) {
5672
6023
  const denom = where(xNeg, y, r.add(x));
5673
6024
  return atan(numer.div(denom)).mul(2);
5674
6025
  });
5675
- /** @function Alias of `jax.numpy.acos()`. */
5676
- const arccos = acos;
5677
- /** @function Alias of `jax.numpy.atan()`. */
5678
- const arctan = atan;
5679
- /** @function Alias of `jax.numpy.atan2()`. */
5680
- const arctan2 = atan2;
5681
6026
  /** Element-wise subtraction, with broadcasting. */
5682
6027
  function subtract(x, y) {
5683
6028
  x = fudgeArray(x);
@@ -5708,8 +6053,6 @@ const fmod = jit$1(function fmod$1(x, y) {
5708
6053
  const remainder = jit$1(function remainder$1(x, y) {
5709
6054
  return mod(mod(x, y.ref).add(y.ref), y);
5710
6055
  });
5711
- /** @function Alias of `jax.numpy.trueDivide()`. */
5712
- const divide = trueDivide;
5713
6056
  /** Round input to the nearest integer towards zero. */
5714
6057
  function trunc(x) {
5715
6058
  return idiv(x, 1);
@@ -5731,9 +6074,9 @@ function ldexp(x1, x2) {
5731
6074
  */
5732
6075
  function frexp(x) {
5733
6076
  x = fudgeArray(x);
5734
- const absx = abs(x.ref);
6077
+ const absx = absolute(x.ref);
5735
6078
  const exponent = where(equal(x.ref, 0), 0, floor(log2(absx)).add(1).astype(DType.Int32));
5736
- const mantissa = divide(x, exp2(exponent.ref.astype(x.dtype)));
6079
+ const mantissa = x.div(exp2(exponent.ref.astype(x.dtype)));
5737
6080
  return [mantissa, exponent];
5738
6081
  }
5739
6082
  /** Calculate `2**p` for all p in the input array. */
@@ -5776,10 +6119,8 @@ const power = jit$1(function power$1(x1, x2) {
5776
6119
  const x2i = trunc(x2.ref);
5777
6120
  const shouldBeNaN = multiply(x2.ref.notEqual(x2i.ref), x1.ref.less(0));
5778
6121
  const resultSign = where(mod(x2i, 2).notEqual(0), where(x1.ref.less(0), -1, 1), 1);
5779
- return where(shouldBeNaN, nan, exp(log(abs(x1)).mul(x2)).mul(resultSign));
6122
+ return where(shouldBeNaN, nan, exp(log(absolute(x1)).mul(x2)).mul(resultSign));
5780
6123
  });
5781
- /** @function Alias of `jax.numpy.power()`. */
5782
- const pow = power;
5783
6124
  /** @function Calculate the element-wise cube root of the input array. */
5784
6125
  const cbrt = jit$1(function cbrt$1(x) {
5785
6126
  const sgn = where(less(x.ref, 0), -1, 1);
@@ -5845,12 +6186,6 @@ const arccosh = jit$1(function arccosh$1(x) {
5845
6186
  const arctanh = jit$1(function arctanh$1(x) {
5846
6187
  return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
5847
6188
  });
5848
- /** @function Alias of `jax.numpy.arcsinh()`. */
5849
- const asinh = arcsinh;
5850
- /** @function Alias of `jax.numpy.arccosh()`. */
5851
- const acosh = arccosh;
5852
- /** @function Alias of `jax.numpy.arctanh()`. */
5853
- const atanh = arctanh;
5854
6189
  /**
5855
6190
  * Compute the variance of an array.
5856
6191
  *
@@ -5880,6 +6215,26 @@ function var_(x, axis = null, opts) {
5880
6215
  function std(x, axis = null, opts) {
5881
6216
  return sqrt(var_(x, axis, opts));
5882
6217
  }
6218
+ /** Estimate the sample covariance of a set of variables. */
6219
+ function cov(x, y) {
6220
+ x = fudgeArray(x);
6221
+ if (x.ndim === 1) x = x.reshape([1, x.shape[0]]);
6222
+ if (y !== void 0) {
6223
+ y = fudgeArray(y);
6224
+ if (y.ndim === 1) y = y.reshape([1, y.shape[0]]);
6225
+ x = vstack([x, y]);
6226
+ }
6227
+ const [_M, N] = x.shape;
6228
+ x = x.ref.sub(x.mean(1, { keepdims: true }));
6229
+ return dot$1(x.ref, x.transpose()).div(N - 1);
6230
+ }
6231
+ /** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
6232
+ function corrcoef(x, y) {
6233
+ const c = cov(x, y);
6234
+ const variances = diag(c.ref);
6235
+ const norm = sqrt(outer(variances.ref, variances));
6236
+ return c.div(norm);
6237
+ }
5883
6238
  /** Test element-wise for positive or negative infinity, return bool array. */
5884
6239
  function isinf(x) {
5885
6240
  x = fudgeArray(x);
@@ -5909,6 +6264,253 @@ const isfinite = jit$1(function isfinite$1(x) {
5909
6264
  return isnan(x.ref).add(isinf(x)).notEqual(true);
5910
6265
  });
5911
6266
 
6267
+ //#endregion
6268
+ //#region src/library/lax-linalg.ts
6269
+ var lax_linalg_exports = {};
6270
+ __export(lax_linalg_exports, {
6271
+ cholesky: () => cholesky,
6272
+ triangularSolve: () => triangularSolve
6273
+ });
6274
+ /**
6275
+ * Compute the Cholesky decomposition of a symmetric positive-definite matrix.
6276
+ *
6277
+ * The Cholesky decomposition of a matrix `A` is:
6278
+ *
6279
+ * - A = L @ L^T (for upper=false, default)
6280
+ * - A = U^T @ U (for upper=true)
6281
+ *
6282
+ * where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
6283
+ * The input matrix must be symmetric and positive-definite.
6284
+ *
6285
+ * @example
6286
+ * ```ts
6287
+ * import { lax, numpy as np } from "@jax-js/jax";
6288
+ *
6289
+ * const x = np.array([[2., 1.], [1., 2.]]);
6290
+ *
6291
+ * // Lower Cholesky factorization (default):
6292
+ * const L = lax.linalg.cholesky(x);
6293
+ * // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
6294
+ *
6295
+ * // Upper Cholesky factorization:
6296
+ * const U = lax.linalg.cholesky(x, { upper: true });
6297
+ * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
6298
+ * ```
6299
+ */
6300
+ function cholesky(a, { upper = false } = {}) {
6301
+ const L = cholesky$2(a);
6302
+ return upper ? moveaxis$1(L, -2, -1) : L;
6303
+ }
6304
+ /**
6305
+ * Solve a triangular linear system.
6306
+ *
6307
+ * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
6308
+ * where `a` is a triangular matrix.
6309
+ *
6310
+ * @example
6311
+ * ```ts
6312
+ * import { lax, numpy as np } from "@jax-js/jax";
6313
+ *
6314
+ * const L = np.array([[2., 0.], [1., 3.]]);
6315
+ * const b = np.array([4., 7.]).reshape([2, 1]);
6316
+ *
6317
+ * // Solve L @ x = b
6318
+ * const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
6319
+ * // x = [[2.], [5./3.]]
6320
+ * ```
6321
+ */
6322
+ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = false, unitDiagonal = false } = {}) {
6323
+ a = fudgeArray(a);
6324
+ b = fudgeArray(b);
6325
+ if (!leftSide) transposeA = !transposeA;
6326
+ else b = moveaxis$1(b, -2, -1);
6327
+ if (transposeA) a = moveaxis$1(a, -2, -1);
6328
+ let x = triangularSolve$1(a, b, {
6329
+ lower,
6330
+ unitDiagonal
6331
+ });
6332
+ if (leftSide) x = moveaxis$1(x, -2, -1);
6333
+ return x;
6334
+ }
6335
+
6336
+ //#endregion
6337
+ //#region src/library/lax.ts
6338
+ var lax_exports = {};
6339
+ __export(lax_exports, {
6340
+ conv: () => conv,
6341
+ convGeneralDilated: () => convGeneralDilated,
6342
+ convWithGeneralPadding: () => convWithGeneralPadding,
6343
+ dot: () => dot,
6344
+ erf: () => erf,
6345
+ erfc: () => erfc,
6346
+ linalg: () => lax_linalg_exports,
6347
+ reduceWindow: () => reduceWindow,
6348
+ stopGradient: () => stopGradient$1
6349
+ });
6350
+ /**
6351
+ * General dot product/contraction operator.
6352
+ *
6353
+ * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
6354
+ * `jax.numpy.tensordot()`, and `jax.numpy.einsum()` where possible.
6355
+ */
6356
+ function dot(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
6357
+ if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
6358
+ else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
6359
+ lc = lc.map((a) => checkAxis(a, lhs.ndim));
6360
+ rc = rc.map((a) => checkAxis(a, rhs.ndim));
6361
+ lb = lb.map((a) => checkAxis(a, lhs.ndim));
6362
+ rb = rb.map((a) => checkAxis(a, rhs.ndim));
6363
+ if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
6364
+ else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
6365
+ const lf = range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
6366
+ const rf = range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
6367
+ const lhs2 = lhs.transpose([
6368
+ ...lb,
6369
+ ...lf,
6370
+ ...lc
6371
+ ]);
6372
+ const rhs2 = rhs.transpose([
6373
+ ...rb,
6374
+ ...rf,
6375
+ ...rc
6376
+ ]);
6377
+ if (lc.length === 0) return mul(lhs2.reshape([
6378
+ ...lb.map((a) => lhs.shape[a]),
6379
+ ...lf.map((a) => lhs.shape[a]),
6380
+ ...rep(rf.length, 1)
6381
+ ]), rhs2.reshape([
6382
+ ...rb.map((a) => rhs.shape[a]),
6383
+ ...rep(lf.length, 1),
6384
+ ...rf.map((a) => rhs.shape[a])
6385
+ ]));
6386
+ const dotShapeX = lc.map((a) => lhs.shape[a]);
6387
+ const dotShapeY = rc.map((a) => rhs.shape[a]);
6388
+ if (!deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
6389
+ return dot$2(lhs2.reshape([
6390
+ ...lb.map((a) => lhs.shape[a]),
6391
+ ...lf.map((a) => lhs.shape[a]),
6392
+ ...rep(rf.length, 1),
6393
+ prod(dotShapeX)
6394
+ ]), rhs2.reshape([
6395
+ ...rb.map((a) => rhs.shape[a]),
6396
+ ...rep(lf.length, 1),
6397
+ ...rf.map((a) => rhs.shape[a]),
6398
+ prod(dotShapeY)
6399
+ ]));
6400
+ }
6401
+ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
6402
+ const padType = padding.toUpperCase();
6403
+ switch (padType) {
6404
+ case "VALID": return rep(inShape.length, [0, 0]);
6405
+ case "SAME":
6406
+ case "SAME_LOWER": {
6407
+ const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
6408
+ const padSizes = zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
6409
+ if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
6410
+ else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
6411
+ }
6412
+ default: throw new Error(`Unknown padding type: ${padType}`);
6413
+ }
6414
+ }
6415
+ /**
6416
+ * General n-dimensional convolution operator, with optional dilation.
6417
+ *
6418
+ * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
6419
+ * function in JAX, which wraps XLA's general convolution operator.
6420
+ *
6421
+ * Grouped convolutions are supported via the `featureGroupCount` option.
6422
+ */
6423
+ function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
6424
+ if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
6425
+ if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
6426
+ if (typeof padding === "string") {
6427
+ if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
6428
+ padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? rep(rhs.ndim - 2, 1), padding);
6429
+ }
6430
+ if (featureGroupCount !== 1) {
6431
+ const G = featureGroupCount;
6432
+ const [N, C_in, ...xs] = lhs.shape;
6433
+ const [C_out, C_in_per_group, ...ks] = rhs.shape;
6434
+ if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
6435
+ if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
6436
+ if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
6437
+ const lhsGrouped = moveaxis(lhs.reshape([
6438
+ N,
6439
+ G,
6440
+ C_in / G,
6441
+ ...xs
6442
+ ]), 1, 0);
6443
+ const rhsGrouped = rhs.reshape([
6444
+ G,
6445
+ C_out / G,
6446
+ C_in_per_group,
6447
+ ...ks
6448
+ ]);
6449
+ const result = conv$1(lhsGrouped, rhsGrouped, {
6450
+ vmapDims: 1,
6451
+ strides: windowStrides,
6452
+ padding,
6453
+ lhsDilation,
6454
+ rhsDilation
6455
+ });
6456
+ const ys = result.shape.slice(3);
6457
+ return moveaxis(result, 0, 1).reshape([
6458
+ N,
6459
+ C_out,
6460
+ ...ys
6461
+ ]);
6462
+ }
6463
+ return conv$1(lhs, rhs, {
6464
+ strides: windowStrides,
6465
+ padding,
6466
+ lhsDilation,
6467
+ rhsDilation
6468
+ });
6469
+ }
6470
+ /** Convenience wrapper around `convGeneralDilated`. */
6471
+ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
6472
+ return convGeneralDilated(lhs, rhs, windowStrides, padding, {
6473
+ lhsDilation,
6474
+ rhsDilation
6475
+ });
6476
+ }
6477
+ /** Convenience wrapper around `convGeneralDilated`. */
6478
+ function conv(lhs, rhs, windowStrides, padding) {
6479
+ return convGeneralDilated(lhs, rhs, windowStrides, padding);
6480
+ }
6481
+ /** Reduce a computation over padded windows. */
6482
+ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
6483
+ if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
6484
+ if (!windowStrides) windowStrides = rep(windowDimensions.length, 1);
6485
+ for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
6486
+ return computation(bind1(Primitive.Pool, [operand], {
6487
+ window: windowDimensions,
6488
+ strides: windowStrides
6489
+ }));
6490
+ }
6491
+ /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
6492
+ function erf(x) {
6493
+ return erf$1(x);
6494
+ }
6495
+ /**
6496
+ * The complementary error function: `erfc(x) = 1 - erf(x)`.
6497
+ *
6498
+ * This function is more accurate than `1 - erf(x)` for large values of `x`,
6499
+ * where `erf(x)` is very close to 1.
6500
+ */
6501
+ function erfc(x) {
6502
+ return erfc$1(x);
6503
+ }
6504
+ /**
6505
+ * Stops gradient computation.
6506
+ *
6507
+ * Behaves as the identity function but prevents the flow of gradients during
6508
+ * forward or reverse-mode automatic differentiation.
6509
+ */
6510
+ function stopGradient$1(x) {
6511
+ return stopGradient(x);
6512
+ }
6513
+
5912
6514
  //#endregion
5913
6515
  //#region src/library/nn.ts
5914
6516
  var nn_exports = {};
@@ -5917,6 +6519,10 @@ __export(nn_exports, {
5917
6519
  elu: () => elu,
5918
6520
  gelu: () => gelu,
5919
6521
  glu: () => glu,
6522
+ hardSigmoid: () => hardSigmoid,
6523
+ hardSilu: () => hardSilu,
6524
+ hardSwish: () => hardSilu,
6525
+ hardTanh: () => hardTanh,
5920
6526
  identity: () => identity,
5921
6527
  leakyRelu: () => leakyRelu,
5922
6528
  logSigmoid: () => logSigmoid,
@@ -5927,14 +6533,17 @@ __export(nn_exports, {
5927
6533
  oneHot: () => oneHot,
5928
6534
  relu: () => relu,
5929
6535
  relu6: () => relu6,
6536
+ selu: () => selu,
5930
6537
  sigmoid: () => sigmoid,
5931
6538
  silu: () => silu,
5932
6539
  softSign: () => softSign,
5933
6540
  softmax: () => softmax,
5934
6541
  softplus: () => softplus,
6542
+ sparsePlus: () => sparsePlus,
6543
+ sparseSigmoid: () => sparseSigmoid,
5935
6544
  squareplus: () => squareplus,
5936
6545
  standardize: () => standardize,
5937
- swish: () => swish
6546
+ swish: () => silu
5938
6547
  });
5939
6548
  /**
5940
6549
  * Rectified Linear Unit (ReLU) activation function:
@@ -5969,6 +6578,28 @@ function softplus(x) {
5969
6578
  return log(exp(x).add(1));
5970
6579
  }
5971
6580
  /**
6581
+ * @function
6582
+ * Sparse plus function:
6583
+ *
6584
+ * - When `x <= -1`: `0`
6585
+ * - When `-1 < x < 1`: `(x+1)**2 / 4`
6586
+ * - When `x >= 1`: `x`
6587
+ */
6588
+ const sparsePlus = jit$1((x) => {
6589
+ return where(x.ref.lessEqual(-1), 0, where(x.ref.less(1), square(x.ref.add(1)).mul(.25), x));
6590
+ });
6591
+ /**
6592
+ * @function
6593
+ * Sparse sigmoid activation function.
6594
+ *
6595
+ * - When `x <= -1`: `0`
6596
+ * - When `-1 < x < 1`: `(x + 1) / 2`
6597
+ * - When `x >= 1`: `1`
6598
+ */
6599
+ const sparseSigmoid = jit$1((x) => {
6600
+ return clip(x.add(1).mul(.5), 0, 1);
6601
+ });
6602
+ /**
5972
6603
  * Soft-sign activation function, computed element-wise:
5973
6604
  * `softsign(x) = x / (|x| + 1)`.
5974
6605
  */
@@ -5990,17 +6621,6 @@ const silu = jit$1(function silu$1(x) {
5990
6621
  return x.ref.mul(sigmoid(x));
5991
6622
  });
5992
6623
  /**
5993
- * @function
5994
- * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
5995
- * Swish, computed element-wise:
5996
- * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
5997
- *
5998
- * `swish()` and `silu()` are both aliases for the same function.
5999
- *
6000
- * Reference: https://en.wikipedia.org/wiki/Swish_function
6001
- */
6002
- const swish = silu;
6003
- /**
6004
6624
  * Log-sigmoid activation function, computed element-wise:
6005
6625
  * `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
6006
6626
  */
@@ -6017,6 +6637,19 @@ function leakyRelu(x, negativeSlope = .01) {
6017
6637
  x = fudgeArray(x);
6018
6638
  return where(less(x.ref, 0), x.ref.mul(negativeSlope), x);
6019
6639
  }
6640
+ /** Hard sigmoid activation function: `relu6(x+3)/6`. */
6641
+ function hardSigmoid(x) {
6642
+ return relu6(add(x, 3)).mul(1 / 6);
6643
+ }
6644
+ /** Hard SiLU (swish) activation function: `x * hardSigmoid(x)`. */
6645
+ function hardSilu(x) {
6646
+ x = fudgeArray(x);
6647
+ return x.ref.mul(hardSigmoid(x));
6648
+ }
6649
+ /** Hard tanh activation function: `clip(x, -1, 1)`. */
6650
+ function hardTanh(x) {
6651
+ return clip(x, -1, 1);
6652
+ }
6020
6653
  /**
6021
6654
  * Exponential linear unit activation function.
6022
6655
  *
@@ -6039,6 +6672,20 @@ function celu(x, alpha = 1) {
6039
6672
  }
6040
6673
  /**
6041
6674
  * @function
6675
+ * Scaled exponential linear unit activation.
6676
+ *
6677
+ * Computes the element-wise function:
6678
+ * `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
6679
+ *
6680
+ * Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
6681
+ */
6682
+ const selu = jit$1(function selu$1(x) {
6683
+ const alpha = 1.6732632423543772;
6684
+ const lambda = 1.0507009873554805;
6685
+ return where(x.ref.less(0), expm1(x.ref).mul(alpha), x).mul(lambda);
6686
+ });
6687
+ /**
6688
+ * @function
6042
6689
  * Gaussian error linear unit (GELU) activation function.
6043
6690
  *
6044
6691
  * This is computed element-wise. There are two variants depending on whether
@@ -6192,8 +6839,11 @@ var random_exports = {};
6192
6839
  __export(random_exports, {
6193
6840
  bernoulli: () => bernoulli,
6194
6841
  bits: () => bits,
6842
+ cauchy: () => cauchy,
6195
6843
  exponential: () => exponential,
6844
+ gumbel: () => gumbel,
6196
6845
  key: () => key,
6846
+ laplace: () => laplace,
6197
6847
  normal: () => normal,
6198
6848
  split: () => split,
6199
6849
  uniform: () => uniform
@@ -6252,6 +6902,16 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
6252
6902
  }
6253
6903
  /**
6254
6904
  * @function
6905
+ * Sample from a Cauchy distribution with location 0 and scale 1.
6906
+ *
6907
+ * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
6908
+ */
6909
+ const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
6910
+ const u = uniform(key$1, shape$1);
6911
+ return tan(u.sub(.5).mul(Math.PI));
6912
+ }, { staticArgnums: [1] });
6913
+ /**
6914
+ * @function
6255
6915
  * Sample exponential random values according to `p(x) = exp(-x)`.
6256
6916
  */
6257
6917
  const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
@@ -6260,6 +6920,30 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
6260
6920
  }, { staticArgnums: [1] });
6261
6921
  /**
6262
6922
  * @function
6923
+ * Sample from a Gumbel distribution with location 0 and scale 1.
6924
+ *
6925
+ * Uses inverse transform sampling: `x = -log(-log(1 - u))` where u ~ Uniform(0, 1).
6926
+ */
6927
+ const gumbel = jit$1(function gumbel$1(key$1, shape$1 = []) {
6928
+ const u = uniform(key$1, shape$1);
6929
+ return negative(log(negative(log1p(negative(u)))));
6930
+ }, { staticArgnums: [1] });
6931
+ /**
6932
+ * @function
6933
+ * Sample from a Laplace distribution with location 0 and scale 1.
6934
+ *
6935
+ * Uses inverse transform sampling: the CDF is `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`.
6936
+ * Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
6937
+ */
6938
+ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
6939
+ const u = uniform(key$1, shape$1);
6940
+ const centered = u.sub(.5);
6941
+ const s = sign(centered.ref);
6942
+ const absVal = absolute(centered);
6943
+ return s.mul(log1p(absVal.mul(-2)).mul(-1));
6944
+ }, { staticArgnums: [1] });
6945
+ /**
6946
+ * @function
6263
6947
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
6264
6948
  *
6265
6949
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
@@ -6368,11 +7052,6 @@ const valueAndGrad = valueAndGrad$1;
6368
7052
  */
6369
7053
  const jacrev = jacrev$1;
6370
7054
  /**
6371
- * @function
6372
- * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
6373
- */
6374
- const jacobian = jacrev;
6375
- /**
6376
7055
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
6377
7056
  *
6378
7057
  * This can be used to wait for the results of an intermediate computation to
@@ -6407,5 +7086,4 @@ async function devicePut(x, device) {
6407
7086
  }
6408
7087
 
6409
7088
  //#endregion
6410
- export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
6411
- //# sourceMappingURL=index.js.map
7089
+ export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };