@jax-js/jax 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -8,9 +8,9 @@ var __hasOwnProp = Object.prototype.hasOwnProperty;
8
8
  var __commonJS = (cb, mod$1) => function() {
9
9
  return mod$1 || (0, cb[__getOwnPropNames(cb)[0]])((mod$1 = { exports: {} }).exports, mod$1), mod$1.exports;
10
10
  };
11
- var __export = (target, all) => {
12
- for (var name in all) __defProp(target, name, {
13
- get: all[name],
11
+ var __export = (target, all$1) => {
12
+ for (var name in all$1) __defProp(target, name, {
13
+ get: all$1[name],
14
14
  enumerable: true
15
15
  });
16
16
  };
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-CmaidnkQ.cjs');
33
+ const require_backend = require('./backend-Bu9GY6sK.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
@@ -362,6 +362,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
362
362
  Primitive$1["Mul"] = "mul";
363
363
  Primitive$1["Idiv"] = "idiv";
364
364
  Primitive$1["Mod"] = "mod";
365
+ Primitive$1["Min"] = "min";
366
+ Primitive$1["Max"] = "max";
365
367
  Primitive$1["Neg"] = "neg";
366
368
  Primitive$1["Reciprocal"] = "reciprocal";
367
369
  Primitive$1["Floor"] = "floor";
@@ -369,7 +371,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
369
371
  Primitive$1["StopGradient"] = "stop_gradient";
370
372
  Primitive$1["Cast"] = "cast";
371
373
  Primitive$1["Bitcast"] = "bitcast";
372
- Primitive$1["RandomBits"] = "random_bits";
373
374
  Primitive$1["Sin"] = "sin";
374
375
  Primitive$1["Cos"] = "cos";
375
376
  Primitive$1["Asin"] = "asin";
@@ -379,8 +380,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
379
380
  Primitive$1["Erf"] = "erf";
380
381
  Primitive$1["Erfc"] = "erfc";
381
382
  Primitive$1["Sqrt"] = "sqrt";
382
- Primitive$1["Min"] = "min";
383
- Primitive$1["Max"] = "max";
384
383
  Primitive$1["Reduce"] = "reduce";
385
384
  Primitive$1["Dot"] = "dot";
386
385
  Primitive$1["Conv"] = "conv";
@@ -388,14 +387,19 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
388
387
  Primitive$1["PoolTranspose"] = "pool_transpose";
389
388
  Primitive$1["Compare"] = "compare";
390
389
  Primitive$1["Where"] = "where";
390
+ Primitive$1["RandomBits"] = "random_bits";
391
+ Primitive$1["Gather"] = "gather";
391
392
  Primitive$1["Transpose"] = "transpose";
392
393
  Primitive$1["Broadcast"] = "broadcast";
393
394
  Primitive$1["Reshape"] = "reshape";
394
395
  Primitive$1["Flip"] = "flip";
395
396
  Primitive$1["Shrink"] = "shrink";
396
397
  Primitive$1["Pad"] = "pad";
397
- Primitive$1["Gather"] = "gather";
398
- Primitive$1["JitCall"] = "jit_call";
398
+ Primitive$1["Sort"] = "sort";
399
+ Primitive$1["Argsort"] = "argsort";
400
+ Primitive$1["TriangularSolve"] = "triangular_solve";
401
+ Primitive$1["Cholesky"] = "cholesky";
402
+ Primitive$1["Jit"] = "jit";
399
403
  return Primitive$1;
400
404
  }({});
401
405
  let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
@@ -417,6 +421,12 @@ function idiv(x, y) {
417
421
  function mod(x, y) {
418
422
  return bind1(Primitive.Mod, [x, y]);
419
423
  }
424
+ function min$1(x, y) {
425
+ return bind1(Primitive.Min, [x, y]);
426
+ }
427
+ function max$1(x, y) {
428
+ return bind1(Primitive.Max, [x, y]);
429
+ }
420
430
  function neg(x) {
421
431
  return bind1(Primitive.Neg, [x]);
422
432
  }
@@ -438,12 +448,6 @@ function cast(x, dtype) {
438
448
  function bitcast(x, dtype) {
439
449
  return bind1(Primitive.Bitcast, [x], { dtype });
440
450
  }
441
- function randomBits(k0, k1, shape$1, mode = "xor") {
442
- return bind1(Primitive.RandomBits, [k0, k1], {
443
- shape: shape$1,
444
- mode
445
- });
446
- }
447
451
  function sin$1(x) {
448
452
  return bind1(Primitive.Sin, [x]);
449
453
  }
@@ -471,12 +475,6 @@ function erfc$1(x) {
471
475
  function sqrt$1(x) {
472
476
  return bind1(Primitive.Sqrt, [x]);
473
477
  }
474
- function min$1(x, y) {
475
- return bind1(Primitive.Min, [x, y]);
476
- }
477
- function max$1(x, y) {
478
- return bind1(Primitive.Max, [x, y]);
479
- }
480
478
  function reduce(x, op, axis = null, opts) {
481
479
  if (!require_backend.AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
482
480
  axis = require_backend.normalizeAxis(axis, ndim$1(x));
@@ -532,6 +530,23 @@ function where$1(cond, x, y) {
532
530
  y
533
531
  ]);
534
532
  }
533
+ function randomBits(k0, k1, shape$1, mode = "xor") {
534
+ return bind1(Primitive.RandomBits, [k0, k1], {
535
+ shape: shape$1,
536
+ mode
537
+ });
538
+ }
539
+ function gather(x, indices, axis, outDim) {
540
+ if (indices.length === 0) throw new Error("gather() requires at least one index");
541
+ if (!Array.isArray(axis) || axis.length !== indices.length) throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
542
+ axis = axis.map((a) => require_backend.checkAxis(a, ndim$1(x)));
543
+ if (new Set(axis).size !== axis.length) throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
544
+ outDim = require_backend.checkAxis(outDim, ndim$1(x) - axis.length + 1);
545
+ return bind1(Primitive.Gather, [x, ...indices], {
546
+ axis,
547
+ outDim
548
+ });
549
+ }
535
550
  function transpose$1(x, perm) {
536
551
  perm = perm ? perm.map((a) => require_backend.checkAxis(a, ndim$1(x))) : require_backend.range(ndim$1(x)).reverse();
537
552
  if (!require_backend.isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
@@ -581,16 +596,27 @@ function pad$1(x, width) {
581
596
  } else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
582
597
  return bind1(Primitive.Pad, [x], { width });
583
598
  }
584
- function gather(x, indices, axis, outDim) {
585
- if (indices.length === 0) throw new Error("gather() requires at least one index");
586
- if (!Array.isArray(axis) || axis.length !== indices.length) throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
587
- axis = axis.map((a) => require_backend.checkAxis(a, ndim$1(x)));
588
- if (new Set(axis).size !== axis.length) throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
589
- outDim = require_backend.checkAxis(outDim, ndim$1(x) - axis.length + 1);
590
- return bind1(Primitive.Gather, [x, ...indices], {
591
- axis,
592
- outDim
593
- });
599
+ function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
600
+ if (lower) {
601
+ a = flip$1(a, [-2, -1]);
602
+ b = flip$1(b, [-1]);
603
+ }
604
+ let x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
605
+ if (lower) x = flip$1(x, [-1]);
606
+ return x;
607
+ }
608
+ function cholesky$2(x) {
609
+ return bind1(Primitive.Cholesky, [x]);
610
+ }
611
+ function sort$1(x) {
612
+ const nd = ndim$1(x);
613
+ if (nd === 0) throw new Error("sort: requires at least 1D input");
614
+ return bind1(Primitive.Sort, [x]);
615
+ }
616
+ function argsort$1(x) {
617
+ const nd = ndim$1(x);
618
+ if (nd === 0) throw new Error("argsort: requires at least 1D input");
619
+ return bind(Primitive.Argsort, [x]);
594
620
  }
595
621
  function bind1(prim, args, params = {}) {
596
622
  const [results] = bind(prim, args, params);
@@ -753,7 +779,7 @@ var Tracer = class Tracer {
753
779
  if (require_backend.isFloatDtype(this.dtype)) return this.mul(reciprocal$1(other));
754
780
  return idiv(this, other);
755
781
  }
756
- /** Return specified diagonals. See `numpy.diagonal` for full docs. */
782
+ /** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
757
783
  diagonal(offset = 0, axis1 = 0, axis2 = 1) {
758
784
  if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
759
785
  if (offset < 0) return this.diagonal(-offset, axis2, axis1);
@@ -806,6 +832,34 @@ var Tracer = class Tracer {
806
832
  this.dispose();
807
833
  }
808
834
  /**
835
+ * Return a sorted copy of an array in ascending order.
836
+ *
837
+ * See `jax.numpy.sort` for full docs.
838
+ */
839
+ sort(axis = -1) {
840
+ axis = require_backend.checkAxis(axis, this.ndim);
841
+ if (this.shape[axis] <= 1) return this;
842
+ if (axis === this.ndim - 1) return sort$1(this);
843
+ const perm = require_backend.range(this.ndim);
844
+ perm.splice(axis, 1);
845
+ perm.push(axis);
846
+ return sort$1(this.transpose(perm)).transpose(require_backend.invertPermutation(perm));
847
+ }
848
+ /**
849
+ * Return the indices that would sort an array. This may not be a stable
850
+ * sorting algorithm; it need not preserve order of indices in ties.
851
+ *
852
+ * See `jax.numpy.argsort` for full docs.
853
+ */
854
+ argsort(axis = -1) {
855
+ axis = require_backend.checkAxis(axis, this.ndim);
856
+ if (axis === this.ndim - 1) return argsort$1(this)[1];
857
+ const perm = require_backend.range(this.ndim);
858
+ perm.splice(axis, 1);
859
+ perm.push(axis);
860
+ return argsort$1(this.transpose(perm))[1].transpose(require_backend.invertPermutation(perm));
861
+ }
862
+ /**
809
863
  * Slice an array along one or more axes.
810
864
  *
811
865
  * This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
@@ -922,6 +976,9 @@ var ShapedArray = class ShapedArray {
922
976
  get ndim() {
923
977
  return this.shape.length;
924
978
  }
979
+ get size() {
980
+ return require_backend.prod(this.shape);
981
+ }
925
982
  toString() {
926
983
  return `${this.dtype}[${this.shape.join(",")}]`;
927
984
  }
@@ -1221,13 +1278,13 @@ var Jaxpr = class Jaxpr {
1221
1278
  }
1222
1279
  return new Jaxpr(this.inBinders, liveEqns.reverse(), outs);
1223
1280
  }
1224
- /** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
1281
+ /** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
1225
1282
  flatten() {
1226
- if (!this.eqns.some((eqn) => eqn.primitive === Primitive.JitCall)) return this;
1283
+ if (!this.eqns.some((eqn) => eqn.primitive === Primitive.Jit)) return this;
1227
1284
  const newEqns = [];
1228
1285
  const varMap = /* @__PURE__ */ new Map();
1229
1286
  const varMapF = (x) => x instanceof Var ? varMap.get(x) ?? x : x;
1230
- for (const eqn of this.eqns) if (eqn.primitive === Primitive.JitCall) {
1287
+ for (const eqn of this.eqns) if (eqn.primitive === Primitive.Jit) {
1231
1288
  const jaxpr = eqn.params.jaxpr.flatten();
1232
1289
  const translation = /* @__PURE__ */ new Map();
1233
1290
  const translationF = (x) => x instanceof Var ? translation.get(x) : x;
@@ -1328,19 +1385,48 @@ function evalJaxpr(jaxpr, args) {
1328
1385
  function jaxprAsFun(jaxpr) {
1329
1386
  return (...args) => evalJaxpr(jaxpr, args);
1330
1387
  }
1388
+ /** Jaxpr with a collection of associated, traced constants. */
1389
+ var ClosedJaxpr = class ClosedJaxpr {
1390
+ constructor(jaxpr, consts) {
1391
+ this.jaxpr = jaxpr;
1392
+ this.consts = consts;
1393
+ }
1394
+ /** String representation of this Jaxpr. */
1395
+ toString() {
1396
+ return this.jaxpr.toString();
1397
+ }
1398
+ /** Apply a function to the underlying Jaxpr. */
1399
+ mapJaxpr(f) {
1400
+ return new ClosedJaxpr(f(this.jaxpr), this.consts);
1401
+ }
1402
+ /** Dispose of the constants in this Jaxpr. */
1403
+ dispose() {
1404
+ for (const c of this.consts) c.dispose();
1405
+ }
1406
+ };
1331
1407
  /** Tracer that records its operations to dynamically construct a Jaxpr. */
1332
1408
  var JaxprTracer = class extends Tracer {
1409
+ #rc;
1333
1410
  constructor(trace$1, aval) {
1334
1411
  super(trace$1);
1335
1412
  this.aval = aval;
1413
+ this.#rc = 1;
1336
1414
  }
1337
1415
  toString() {
1338
1416
  return `JaxprTracer(${this.aval.toString()})`;
1339
1417
  }
1340
1418
  get ref() {
1419
+ if (this.#rc <= 0) throw new UseAfterFreeError(this);
1420
+ this.#rc++;
1341
1421
  return this;
1342
1422
  }
1343
- dispose() {}
1423
+ dispose() {
1424
+ if (this.#rc <= 0) throw new UseAfterFreeError(this);
1425
+ this.#rc--;
1426
+ }
1427
+ trackLiftedConstant() {
1428
+ this.#rc++;
1429
+ }
1344
1430
  };
1345
1431
  /** Analogous to the 'DynamicJaxprTrace' class in JAX. */
1346
1432
  var JaxprTrace = class extends Trace {
@@ -1353,17 +1439,24 @@ var JaxprTrace = class extends Trace {
1353
1439
  }
1354
1440
  /** Register a constant / literal in this Jaxpr. */
1355
1441
  getOrMakeConstTracer(val) {
1442
+ if (!(val instanceof Tracer)) val = pureArray(val);
1356
1443
  let tracer = this.builder.constTracers.get(val);
1357
1444
  if (tracer === void 0) {
1358
1445
  tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
1359
- this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
1446
+ this.builder.addConst(tracer, val);
1447
+ } else {
1448
+ val.dispose();
1449
+ tracer.trackLiftedConstant();
1360
1450
  }
1361
1451
  return tracer;
1362
1452
  }
1363
1453
  pure = this.getOrMakeConstTracer;
1364
1454
  lift = this.getOrMakeConstTracer;
1365
1455
  processPrimitive(primitive, tracers, params) {
1366
- const avalsIn = tracers.map((t) => t.aval);
1456
+ const avalsIn = tracers.map((t) => {
1457
+ t.dispose();
1458
+ return t.aval;
1459
+ });
1367
1460
  const avalsOut = abstractEvalRules[primitive](avalsIn, params);
1368
1461
  const outTracers = avalsOut.map((aval) => this.builder.newTracer(this, aval));
1369
1462
  this.builder.addEqn(new JaxprEqn(primitive, tracers.map((t) => this.builder.getVar(t)), params, outTracers.map((t) => this.builder.addVar(t))));
@@ -1406,20 +1499,17 @@ var JaxprBuilder = class {
1406
1499
  return v;
1407
1500
  }
1408
1501
  build(inTracers, outTracers) {
1409
- let [constVars, consts] = require_backend.unzip2(this.constVals.entries());
1502
+ const [constVars, consts] = require_backend.unzip2(this.constVals.entries());
1410
1503
  const t2v = this.getVar.bind(this);
1411
1504
  const inBinders = [...constVars, ...inTracers.map(t2v)];
1412
1505
  const outVars = outTracers.map(t2v);
1413
- let jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
1506
+ const jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
1414
1507
  typecheckJaxpr(jaxpr);
1415
- [jaxpr, consts] = _inlineLiterals(jaxpr, consts);
1416
- return {
1417
- jaxpr,
1418
- consts
1419
- };
1508
+ const cjaxpr = new ClosedJaxpr(jaxpr, consts);
1509
+ return _inlineLiterals(cjaxpr);
1420
1510
  }
1421
1511
  };
1422
- function _inlineLiterals(jaxpr, consts) {
1512
+ function _inlineLiterals({ jaxpr, consts }) {
1423
1513
  const literals = /* @__PURE__ */ new Map();
1424
1514
  const constBinders = [];
1425
1515
  const newConsts = [];
@@ -1434,7 +1524,7 @@ function _inlineLiterals(jaxpr, consts) {
1434
1524
  const newOuts = jaxpr.outs.map((x) => literals.get(x) ?? x);
1435
1525
  const newJaxpr = new Jaxpr([...constBinders, ...jaxpr.inBinders.slice(consts.length)], newEqns, newOuts);
1436
1526
  typecheckJaxpr(newJaxpr);
1437
- return [newJaxpr, newConsts];
1527
+ return new ClosedJaxpr(newJaxpr, newConsts);
1438
1528
  }
1439
1529
  function binopAbstractEval([x, y]) {
1440
1530
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
@@ -1453,6 +1543,8 @@ const abstractEvalRules = {
1453
1543
  [Primitive.Mul]: binopAbstractEval,
1454
1544
  [Primitive.Idiv]: binopAbstractEval,
1455
1545
  [Primitive.Mod]: binopAbstractEval,
1546
+ [Primitive.Min]: binopAbstractEval,
1547
+ [Primitive.Max]: binopAbstractEval,
1456
1548
  [Primitive.Neg]: vectorizedUnopAbstractEval,
1457
1549
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
1458
1550
  [Primitive.Floor]: vectorizedUnopAbstractEval,
@@ -1466,12 +1558,6 @@ const abstractEvalRules = {
1466
1558
  if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
1467
1559
  return [new ShapedArray(x.shape, dtype, false)];
1468
1560
  },
1469
- [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
1470
- if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
1471
- const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
1472
- if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1473
- return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
1474
- },
1475
1561
  [Primitive.Sin]: vectorizedUnopAbstractEval,
1476
1562
  [Primitive.Cos]: vectorizedUnopAbstractEval,
1477
1563
  [Primitive.Asin]: vectorizedUnopAbstractEval,
@@ -1481,8 +1567,6 @@ const abstractEvalRules = {
1481
1567
  [Primitive.Erf]: vectorizedUnopAbstractEval,
1482
1568
  [Primitive.Erfc]: vectorizedUnopAbstractEval,
1483
1569
  [Primitive.Sqrt]: vectorizedUnopAbstractEval,
1484
- [Primitive.Min]: binopAbstractEval,
1485
- [Primitive.Max]: binopAbstractEval,
1486
1570
  [Primitive.Reduce]([x], { axis }) {
1487
1571
  const axisSet = new Set(axis);
1488
1572
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
@@ -1515,6 +1599,25 @@ const abstractEvalRules = {
1515
1599
  const shape$1 = require_backend.generalBroadcast(cond.shape, xy.shape);
1516
1600
  return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
1517
1601
  },
1602
+ [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
1603
+ if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
1604
+ const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
1605
+ if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1606
+ return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
1607
+ },
1608
+ [Primitive.Gather]([x, ...indices], { axis, outDim }) {
1609
+ for (const a of indices) if (a.dtype !== require_backend.DType.Int32 && a.dtype !== require_backend.DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
1610
+ if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
1611
+ if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
1612
+ if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
1613
+ if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
1614
+ const axisSet = new Set(axis);
1615
+ if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
1616
+ const gatherShape = indices.reduce((shape$1, a) => require_backend.generalBroadcast(shape$1, a.shape), []);
1617
+ const newShape = x.shape.filter((_, i) => !axisSet.has(i));
1618
+ newShape.splice(outDim, 0, ...gatherShape);
1619
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
1620
+ },
1518
1621
  [Primitive.Transpose]([x], { perm }) {
1519
1622
  return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
1520
1623
  },
@@ -1535,23 +1638,31 @@ const abstractEvalRules = {
1535
1638
  const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
1536
1639
  return [new ShapedArray(newShape, x.dtype, x.weakType)];
1537
1640
  },
1538
- [Primitive.Gather]([x, ...indices], { axis, outDim }) {
1539
- for (const a of indices) if (a.dtype !== require_backend.DType.Int32 && a.dtype !== require_backend.DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
1540
- if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
1541
- if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
1542
- if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
1543
- if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
1544
- const axisSet = new Set(axis);
1545
- if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
1546
- const gatherShape = indices.reduce((shape$1, a) => require_backend.generalBroadcast(shape$1, a.shape), []);
1547
- const newShape = x.shape.filter((_, i) => !axisSet.has(i));
1548
- newShape.splice(outDim, 0, ...gatherShape);
1549
- return [new ShapedArray(newShape, x.dtype, x.weakType)];
1641
+ [Primitive.Sort]([x]) {
1642
+ if (x.ndim === 0) throw new TypeError("sort: requires at least 1D input");
1643
+ return [ShapedArray.fromAval(x)];
1644
+ },
1645
+ [Primitive.Argsort]([x]) {
1646
+ if (x.ndim === 0) throw new TypeError("argsort: requires at least 1D input");
1647
+ return [ShapedArray.fromAval(x), new ShapedArray(x.shape, require_backend.DType.Int32, false)];
1648
+ },
1649
+ [Primitive.TriangularSolve]([a, b]) {
1650
+ if (a.ndim < 2) throw new TypeError(`triangular_solve: a must be at least 2D, got ${a}`);
1651
+ if (b.ndim < 2) throw new TypeError(`triangular_solve: b must be at least 2D, got ${b}`);
1652
+ const [m, n] = a.shape.slice(-2);
1653
+ const [_batch, q] = b.shape.slice(-2);
1654
+ if (!require_backend.deepEqual(a.shape.slice(0, -2), b.shape.slice(0, -2)) || a.dtype !== b.dtype || m !== n || n !== q) throw new TypeError(`triangular_solve: mismatch ${a} vs ${b}`);
1655
+ return [new ShapedArray(b.shape, b.dtype, a.weakType && b.weakType)];
1656
+ },
1657
+ [Primitive.Cholesky]([a]) {
1658
+ if (a.ndim < 2) throw new TypeError(`cholesky: requires at least 2D input, got ${a}`);
1659
+ if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
1660
+ return [ShapedArray.fromAval(a)];
1550
1661
  },
1551
- [Primitive.JitCall](args, { jaxpr }) {
1662
+ [Primitive.Jit](args, { jaxpr }) {
1552
1663
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
1553
- if (args.length !== inTypes.length) throw new TypeError(`jit_call expected ${inTypes.length} arguments, got ${args.length}`);
1554
- for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit_call argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
1664
+ if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
1665
+ for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
1555
1666
  return outTypes;
1556
1667
  }
1557
1668
  };
@@ -1587,11 +1698,10 @@ function makeJaxpr$1(f, opts) {
1587
1698
  const tracersIn = avalsIn.map((aval) => trace$1.newArg(typeof aval === "object" ? aval : pureArray(aval)));
1588
1699
  const outs = fFlat(...tracersIn);
1589
1700
  const tracersOut = outs.map((out) => fullRaise(trace$1, out));
1590
- const { jaxpr, consts } = builder.build(tracersIn, tracersOut);
1701
+ const jaxpr = builder.build(tracersIn, tracersOut);
1591
1702
  if (outTree.value === void 0) throw new Error("outTree was not set in makeJaxpr");
1592
1703
  return {
1593
- jaxpr: jaxpr.simplify(),
1594
- consts,
1704
+ jaxpr: jaxpr.mapJaxpr((j) => j.simplify()),
1595
1705
  treedef: outTree.value
1596
1706
  };
1597
1707
  } catch (_) {
@@ -1610,22 +1720,28 @@ function jit$1(f, opts) {
1610
1720
  const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
1611
1721
  const avalsIn = unflatten(inTree, avalsInFlat);
1612
1722
  const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
1613
- const { jaxpr, consts, treedef: outTree } = require_backend.runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
1614
- const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
1723
+ const { jaxpr, treedef: outTree } = require_backend.runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
1724
+ const outs = bind(Primitive.Jit, [...jaxpr.consts.map((c) => c.ref), ...argsFlat], {
1615
1725
  name: f.name || "closure",
1616
- jaxpr,
1617
- numConsts: consts.length
1726
+ jaxpr: jaxpr.jaxpr,
1727
+ numConsts: jaxpr.consts.length
1618
1728
  });
1619
1729
  return unflatten(outTree, outs);
1620
1730
  });
1621
1731
  result.dispose = () => {
1622
- for (const { consts } of cache.values()) for (const c of consts) c.dispose();
1732
+ for (const { jaxpr } of cache.values()) jaxpr.dispose();
1623
1733
  };
1624
1734
  return result;
1625
1735
  }
1626
1736
 
1627
1737
  //#endregion
1628
1738
  //#region src/frontend/jit.ts
1739
+ const routinePrimitives = new Map([
1740
+ [Primitive.Sort, require_backend.Routines.Sort],
1741
+ [Primitive.Argsort, require_backend.Routines.Argsort],
1742
+ [Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
1743
+ [Primitive.Cholesky, require_backend.Routines.Cholesky]
1744
+ ]);
1629
1745
  /** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
1630
1746
  var JitProgram = class {
1631
1747
  constructor(backend, steps, inputs, outputs) {
@@ -1640,9 +1756,14 @@ var JitProgram = class {
1640
1756
  case "execute": {
1641
1757
  const inputsNice = step.inputs.map((id, i) => `${i}: %${id}`).join(", ");
1642
1758
  const outputsNice = step.outputs.map((id) => `%${id}`).join(", ");
1643
- return require_backend.PPrint.pp(`execute (${inputsNice}) -> ${outputsNice}, kernel`).concat(step.kernel.pprint().indent(2));
1759
+ const executeText = `execute (${inputsNice}) -> ${outputsNice}`;
1760
+ if (step.source instanceof require_backend.Kernel) return require_backend.PPrint.pp(`${executeText}, kernel`).concat(step.source.pprint().indent(2));
1761
+ else if (step.source instanceof require_backend.Routine) return require_backend.PPrint.pp(`${executeText}, routine ${step.source.name}`);
1762
+ else {
1763
+ step.source;
1764
+ return require_backend.PPrint.pp(executeText);
1765
+ }
1644
1766
  }
1645
- case "const": return require_backend.PPrint.pp(`%${step.output} = const <Slot ${step.slot}>`);
1646
1767
  case "malloc": return require_backend.PPrint.pp(`%${step.output} = malloc <${step.size} bytes>`);
1647
1768
  case "incref": return require_backend.PPrint.pp(`incref ${step.input}`);
1648
1769
  case "free": return require_backend.PPrint.pp(`free ${step.input}`);
@@ -1665,12 +1786,9 @@ var JitProgram = class {
1665
1786
  const inputs$1 = step.inputs.map((id) => scope.get(id));
1666
1787
  const outputs = step.outputs.map((id) => scope.get(id));
1667
1788
  if (inputs$1.some((s) => s === void 0) || outputs.some((s) => s === void 0)) throw new Error(`internal: JitProgram scope undefined`);
1668
- pending.push(new PendingExecute(this.backend, step.kernel, inputs$1, outputs));
1789
+ pending.push(new PendingExecute(this.backend, step.source, inputs$1, outputs));
1669
1790
  break;
1670
1791
  }
1671
- case "const":
1672
- scope.set(step.output, step.slot);
1673
- break;
1674
1792
  case "malloc": {
1675
1793
  const slot = this.backend.malloc(step.size);
1676
1794
  scope.set(step.output, slot);
@@ -1704,34 +1822,37 @@ var JitProgramBuilder = class {
1704
1822
  this.#nextId = nargs;
1705
1823
  this.steps = [];
1706
1824
  }
1707
- pushConst(slot) {
1708
- const id = this.#nextId++;
1709
- this.steps.push({
1710
- type: "const",
1711
- slot,
1712
- output: id
1713
- });
1714
- return id;
1715
- }
1716
1825
  pushLit(lit) {
1717
- const kernel = new require_backend.Kernel(0, require_backend.prod(lit.aval.shape), require_backend.AluExp.const(lit.dtype, lit.value));
1826
+ const kernel = new require_backend.Kernel(0, lit.aval.size, require_backend.AluExp.const(lit.dtype, lit.value));
1718
1827
  return this.pushKernel(kernel, []);
1719
1828
  }
1720
- pushKernel(kernel, inputs) {
1829
+ pushBuffer(size$1) {
1721
1830
  const id = this.#nextId++;
1722
1831
  this.steps.push({
1723
1832
  type: "malloc",
1724
- size: kernel.bytes,
1833
+ size: size$1,
1725
1834
  output: id
1726
1835
  });
1836
+ return id;
1837
+ }
1838
+ pushKernel(kernel, inputs) {
1839
+ const id = this.pushBuffer(kernel.bytes);
1727
1840
  this.steps.push({
1728
1841
  type: "execute",
1729
- kernel,
1842
+ source: kernel,
1730
1843
  inputs,
1731
1844
  outputs: [id]
1732
1845
  });
1733
1846
  return id;
1734
1847
  }
1848
+ pushRoutine(routine, inputs, outputs) {
1849
+ this.steps.push({
1850
+ type: "execute",
1851
+ source: routine,
1852
+ inputs,
1853
+ outputs
1854
+ });
1855
+ }
1735
1856
  pushIncref(id) {
1736
1857
  this.steps.push({
1737
1858
  type: "incref",
@@ -1757,28 +1878,18 @@ var JitProgramBuilder = class {
1757
1878
  }
1758
1879
  };
1759
1880
  const jitCompileCache = /* @__PURE__ */ new Map();
1760
- function jitCompile(backend, jaxpr, consts) {
1761
- if (jaxpr.inBinders.length < consts.length) throw new TypeError(`Jaxpr has ${jaxpr.inBinders.length} inputs, but ${consts.length} consts were provided`);
1762
- for (let i = 0; i < consts.length; i++) if (consts[i].device !== backend.type) throw new TypeError(`Const ${i} has device ${consts[i].device}, but expected ${backend.type}`);
1763
- const cacheKey = backend.type + require_backend.FpHash.hash(jaxpr, ...consts.map((c) => c.id));
1881
+ function jitCompile(backend, jaxpr) {
1882
+ const cacheKey = backend.type + "," + require_backend.FpHash.hash(jaxpr);
1764
1883
  const cached = jitCompileCache.get(cacheKey);
1765
1884
  if (cached) return cached;
1766
1885
  if (require_backend.DEBUG >= 1) console.info("=========== JIT Compile ===========\n" + jaxpr.toString());
1767
1886
  jaxpr = jaxpr.flatten().simplify();
1768
- const nargs = jaxpr.inBinders.length - consts.length;
1887
+ const nargs = jaxpr.inBinders.length;
1769
1888
  const builder = new JitProgramBuilder(backend, nargs);
1770
1889
  const blackNodes = splitGraphDataflow(backend, jaxpr);
1771
1890
  const ctx = /* @__PURE__ */ new Map();
1772
- for (let i = 0; i < consts.length; i++) {
1773
- const v = jaxpr.inBinders[i];
1774
- const slot = consts[i]._realizeSource();
1775
- ctx.set(v, {
1776
- type: "imm",
1777
- arg: builder.pushConst(slot)
1778
- });
1779
- }
1780
1891
  for (let i = 0; i < nargs; i++) {
1781
- const v = jaxpr.inBinders[consts.length + i];
1892
+ const v = jaxpr.inBinders[i];
1782
1893
  ctx.set(v, {
1783
1894
  type: "imm",
1784
1895
  arg: i
@@ -1786,6 +1897,31 @@ function jitCompile(backend, jaxpr, consts) {
1786
1897
  }
1787
1898
  for (let i = 0; i < jaxpr.eqns.length; i++) {
1788
1899
  const eqn = jaxpr.eqns[i];
1900
+ if (routinePrimitives.has(eqn.primitive)) {
1901
+ const routine = new require_backend.Routine(routinePrimitives.get(eqn.primitive), {
1902
+ inputShapes: eqn.inputs.map((x) => x.aval.shape),
1903
+ inputDtypes: eqn.inputs.map((x) => x.aval.dtype),
1904
+ outputShapes: eqn.outBinders.map((x) => x.aval.shape),
1905
+ outputDtypes: eqn.outBinders.map((x) => x.aval.dtype)
1906
+ }, eqn.params);
1907
+ const inputs = [];
1908
+ for (const input of eqn.inputs) if (input instanceof Var) {
1909
+ const jv = ctx.get(input);
1910
+ if (jv.type !== "imm") throw new Error(`jit: routine primitive ${eqn.primitive} input is not imm`);
1911
+ inputs.push(jv.arg);
1912
+ } else if (input instanceof Lit) inputs.push(builder.pushLit(input));
1913
+ const outputs = [];
1914
+ for (const outVar$1 of eqn.outBinders) {
1915
+ const outId = builder.pushBuffer(outVar$1.aval.size * require_backend.byteWidth(outVar$1.aval.dtype));
1916
+ outputs.push(outId);
1917
+ ctx.set(outVar$1, {
1918
+ type: "imm",
1919
+ arg: outId
1920
+ });
1921
+ }
1922
+ builder.pushRoutine(routine, inputs, outputs);
1923
+ continue;
1924
+ }
1789
1925
  const inputExps = [];
1790
1926
  const inputAvals = [];
1791
1927
  const inputArgs = [];
@@ -1840,7 +1976,7 @@ function jitCompile(backend, jaxpr, consts) {
1840
1976
  const outVar = eqn.outBinders[0];
1841
1977
  if (blackNodes.has(outVar)) {
1842
1978
  const nargs$1 = inputArgs.length;
1843
- const size$1 = require_backend.prod(outVar.aval.shape);
1979
+ const size$1 = outVar.aval.size;
1844
1980
  const kernel = new require_backend.Kernel(nargs$1, size$1, exp$2, reduction);
1845
1981
  const outId = builder.pushKernel(kernel, inputArgs);
1846
1982
  ctx.set(outVar, {
@@ -1865,7 +2001,7 @@ function jitCompile(backend, jaxpr, consts) {
1865
2001
  if (jitValue.type !== "imm") throw new Error("internal: Expected imm, since outs are black nodes");
1866
2002
  outputIds.push(jitValue.arg);
1867
2003
  } else if (out instanceof Lit) outputIds.push(builder.pushLit(out));
1868
- const outputNeedsRef = new Set([...require_backend.range(nargs), ...builder.steps.filter((s) => s.type === "const").map((s) => s.output)]);
2004
+ const outputNeedsRef = new Set(require_backend.range(nargs));
1869
2005
  for (const outputId of outputIds) if (outputNeedsRef.has(outputId)) builder.pushIncref(outputId);
1870
2006
  else outputNeedsRef.add(outputId);
1871
2007
  builder.insertFreeSteps(outputIds);
@@ -1911,11 +2047,18 @@ function reshapeJit(fn) {
1911
2047
  return { exp: reshapeViews(a, (st) => fn(st, params)) };
1912
2048
  };
1913
2049
  }
2050
+ function routineNoJit() {
2051
+ return () => {
2052
+ throw new Error("jit: rule is not implemented for routines");
2053
+ };
2054
+ }
1914
2055
  const jitRules = {
1915
2056
  [Primitive.Add]: broadcastedJit(([a, b]) => require_backend.AluExp.add(a, b)),
1916
2057
  [Primitive.Mul]: broadcastedJit(([a, b]) => require_backend.AluExp.mul(a, b)),
1917
2058
  [Primitive.Idiv]: broadcastedJit(([a, b]) => require_backend.AluExp.idiv(a, b)),
1918
2059
  [Primitive.Mod]: broadcastedJit(([a, b]) => require_backend.AluExp.mod(a, b)),
2060
+ [Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
2061
+ [Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
1919
2062
  [Primitive.Neg]: unopJit((a) => require_backend.AluExp.sub(require_backend.AluExp.const(a.dtype, 0), a)),
1920
2063
  [Primitive.Reciprocal]: unopJit(require_backend.AluExp.reciprocal),
1921
2064
  [Primitive.Floor]: unopJit(require_backend.AluExp.floor),
@@ -1923,17 +2066,6 @@ const jitRules = {
1923
2066
  [Primitive.StopGradient]: unopJit((a) => a),
1924
2067
  [Primitive.Cast]: unopJit((a, { dtype }) => require_backend.AluExp.cast(dtype, a)),
1925
2068
  [Primitive.Bitcast]: unopJit((a, { dtype }) => require_backend.AluExp.bitcast(dtype, a)),
1926
- [Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
1927
- const mapping = (st) => {
1928
- if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(shape$1.length - st.shape.length));
1929
- };
1930
- const k0 = reshapeViews(keys[0], mapping);
1931
- const k1 = reshapeViews(keys[1], mapping);
1932
- const c0 = require_backend.AluExp.u32(0);
1933
- const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
1934
- const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
1935
- return { exp: exp$2 };
1936
- },
1937
2069
  [Primitive.Sin]: unopJit(require_backend.AluExp.sin),
1938
2070
  [Primitive.Cos]: unopJit(require_backend.AluExp.cos),
1939
2071
  [Primitive.Asin]: unopJit(require_backend.AluExp.asin),
@@ -1943,8 +2075,6 @@ const jitRules = {
1943
2075
  [Primitive.Erf]: unopJit(require_backend.AluExp.erf),
1944
2076
  [Primitive.Erfc]: unopJit(require_backend.AluExp.erfc),
1945
2077
  [Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
1946
- [Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
1947
- [Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
1948
2078
  [Primitive.Reduce]([a], [as], { op, axis }) {
1949
2079
  const keptAxes = [];
1950
2080
  const shiftedAxes = [];
@@ -1994,16 +2124,17 @@ const jitRules = {
1994
2124
  },
1995
2125
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
1996
2126
  [Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
1997
- [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
1998
- [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
1999
- [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
2000
- [Primitive.Flip]: reshapeJit((st, { axis }) => {
2001
- const arg = require_backend.rep(st.shape.length, false);
2002
- for (const ax of axis) arg[ax] = true;
2003
- return st.flip(arg);
2004
- }),
2005
- [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
2006
- [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
2127
+ [Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
2128
+ const mapping = (st) => {
2129
+ if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(shape$1.length - st.shape.length));
2130
+ };
2131
+ const k0 = reshapeViews(keys[0], mapping);
2132
+ const k1 = reshapeViews(keys[1], mapping);
2133
+ const c0 = require_backend.AluExp.u32(0);
2134
+ const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
2135
+ const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
2136
+ return { exp: exp$2 };
2137
+ },
2007
2138
  [Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
2008
2139
  const axisSet = new Set(axis);
2009
2140
  const indexShape = indicesShapes.map((c) => c.shape).reduce(require_backend.generalBroadcast);
@@ -2019,8 +2150,22 @@ const jitRules = {
2019
2150
  if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
2020
2151
  return { exp: x.substitute({ gidx: index }) };
2021
2152
  },
2022
- [Primitive.JitCall]() {
2023
- throw new Error("internal: JitCall should have been flattened before JIT compilation");
2153
+ [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
2154
+ [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
2155
+ [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
2156
+ [Primitive.Flip]: reshapeJit((st, { axis }) => {
2157
+ const arg = require_backend.rep(st.shape.length, false);
2158
+ for (const ax of axis) arg[ax] = true;
2159
+ return st.flip(arg);
2160
+ }),
2161
+ [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
2162
+ [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
2163
+ [Primitive.Sort]: routineNoJit(),
2164
+ [Primitive.Argsort]: routineNoJit(),
2165
+ [Primitive.TriangularSolve]: routineNoJit(),
2166
+ [Primitive.Cholesky]: routineNoJit(),
2167
+ [Primitive.Jit]() {
2168
+ throw new Error("internal: Jit should have been flattened before JIT compilation");
2024
2169
  }
2025
2170
  };
2026
2171
  /** Determines how to split the Jaxpr into kernels via dataflow analysis. */
@@ -2078,8 +2223,8 @@ function splitGraphDataflow(backend, jaxpr) {
2078
2223
  case Primitive.Mul:
2079
2224
  case Primitive.Idiv:
2080
2225
  case Primitive.Mod:
2081
- case Primitive.Max:
2082
- case Primitive.Min: {
2226
+ case Primitive.Min:
2227
+ case Primitive.Max: {
2083
2228
  const otherInput = nextEqn.inputs.find((v) => v !== outVar);
2084
2229
  if (otherInput instanceof Lit || require_backend.deepEqual(require_backend.generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
2085
2230
  head = usages[0];
@@ -2099,11 +2244,11 @@ function splitGraphDataflow(backend, jaxpr) {
2099
2244
  blackNodes.add(v);
2100
2245
  p1NextBlack.set(v, v);
2101
2246
  }
2102
- const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
2247
+ const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
2103
2248
  const needsCleanShapePrimitives = [Primitive.Pad];
2104
2249
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
2105
2250
  const eqn = jaxpr.eqns[i];
2106
- if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
2251
+ if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
2107
2252
  for (const v of eqn.outBinders) {
2108
2253
  blackNodes.add(v);
2109
2254
  p1NextBlack.set(v, v);
@@ -2113,7 +2258,7 @@ function splitGraphDataflow(backend, jaxpr) {
2113
2258
  const reach = /* @__PURE__ */ new Set();
2114
2259
  let needsCleanOutput = false;
2115
2260
  outer: for (const v of eqn.outBinders) for (const j of varToUsages.get(v) ?? []) {
2116
- if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive)) {
2261
+ if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive) || routinePrimitives.has(jaxpr.eqns[j].primitive)) {
2117
2262
  needsCleanOutput = true;
2118
2263
  break outer;
2119
2264
  }
@@ -2137,7 +2282,6 @@ function splitGraphDataflow(backend, jaxpr) {
2137
2282
  while (p2idx < jaxpr.eqns.length) {
2138
2283
  const eqn = jaxpr.eqns[p2idx++];
2139
2284
  const deps = [];
2140
- if (eqn.outBinders.some((v) => blackNodes.has(v))) continue;
2141
2285
  for (const input of eqn.inputs) if (input instanceof Var) if (blackNodes.has(input)) deps.push(new Set([input]));
2142
2286
  else deps.push(p2Deps.get(input));
2143
2287
  else deps.push(/* @__PURE__ */ new Set());
@@ -2160,7 +2304,7 @@ function splitGraphDataflow(backend, jaxpr) {
2160
2304
  if (assocInput === -1) throw new Error(`internal: maxArgs, no input found to mark as black in Jaxpr equation ${eqn}`);
2161
2305
  const assocVar = eqn.inputs[assocInput];
2162
2306
  p2idx = varToDefn.get(assocVar);
2163
- for (const out of jaxpr.eqns[p2idx].outBinders) blackNodes.add(out);
2307
+ for (const out of jaxpr.eqns[p2idx++].outBinders) blackNodes.add(out);
2164
2308
  } else {
2165
2309
  const s = new Set(depCounter.keys());
2166
2310
  for (const out of eqn.outBinders) p2Deps.set(out, s);
@@ -2186,9 +2330,9 @@ var PendingExecute = class {
2186
2330
  submitted = false;
2187
2331
  #promise = null;
2188
2332
  #rc = 1;
2189
- constructor(backend, kernel, inputs, outputs) {
2333
+ constructor(backend, source, inputs, outputs) {
2190
2334
  this.backend = backend;
2191
- this.kernel = kernel;
2335
+ this.source = source;
2192
2336
  this.inputs = inputs;
2193
2337
  this.outputs = outputs;
2194
2338
  for (const slot of inputs) this.backend.incRef(slot);
@@ -2209,13 +2353,15 @@ var PendingExecute = class {
2209
2353
  return;
2210
2354
  }
2211
2355
  this.#promise = (async () => {
2212
- this.prepared = await this.backend.prepare(this.kernel);
2356
+ if (this.source instanceof require_backend.Kernel) this.prepared = await this.backend.prepareKernel(this.source);
2357
+ else this.prepared = await this.backend.prepareRoutine(this.source);
2213
2358
  })();
2214
2359
  await this.#promise;
2215
2360
  }
2216
2361
  prepareSync() {
2217
2362
  if (this.prepared) return;
2218
- this.prepared = this.backend.prepareSync(this.kernel);
2363
+ if (this.source instanceof require_backend.Kernel) this.prepared = this.backend.prepareKernelSync(this.source);
2364
+ else this.prepared = this.backend.prepareRoutineSync(this.source);
2219
2365
  }
2220
2366
  submit() {
2221
2367
  if (this.submitted) return;
@@ -2238,8 +2384,6 @@ var PendingExecute = class {
2238
2384
  * "Array" type by name.
2239
2385
  */
2240
2386
  var Array$1 = class Array$1 extends Tracer {
2241
- static #nextId = 1001;
2242
- id;
2243
2387
  #dtype;
2244
2388
  #weakType;
2245
2389
  #source;
@@ -2256,7 +2400,6 @@ var Array$1 = class Array$1 extends Tracer {
2256
2400
  */
2257
2401
  constructor(args) {
2258
2402
  super(baseArrayTrace);
2259
- this.id = Array$1.#nextId++;
2260
2403
  this.#dtype = args.dtype;
2261
2404
  this.#weakType = args.weakType;
2262
2405
  this.#source = args.source;
@@ -2565,6 +2708,27 @@ var Array$1 = class Array$1 extends Tracer {
2565
2708
  pending
2566
2709
  });
2567
2710
  }
2711
+ /** Apply an operation with custom lowering to this array. */
2712
+ static #routine(routine, arrays, outputWeakType) {
2713
+ const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
2714
+ for (const ar of arrays) ar.#realize();
2715
+ const inputs = arrays.map((ar) => ar.#source);
2716
+ const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(require_backend.byteWidth(dtype) * require_backend.prod(routine.type.outputShapes[i])));
2717
+ const pending = arrays.flatMap((ar) => ar.#pending);
2718
+ for (const exe of pending) exe.updateRc(+outputs.length);
2719
+ pending.push(new PendingExecute(backend, routine, inputs, outputs));
2720
+ pending[pending.length - 1].updateRc(+outputs.length - 1);
2721
+ arrays.forEach((ar) => ar.dispose());
2722
+ return outputs.map((output, i) => new Array$1({
2723
+ source: output,
2724
+ st: require_backend.ShapeTracker.fromShape(routine.type.outputShapes[i]),
2725
+ dtype: routine.type.outputDtypes[i],
2726
+ weakType: outputWeakType[i],
2727
+ backend,
2728
+ committed,
2729
+ pending
2730
+ }));
2731
+ }
2568
2732
  /**
2569
2733
  * Normalizes this array into one backed by a `Slot`.
2570
2734
  *
@@ -2725,6 +2889,12 @@ var Array$1 = class Array$1 extends Tracer {
2725
2889
  [Primitive.Mod]([x, y]) {
2726
2890
  return [x.#binary(require_backend.AluOp.Mod, y)];
2727
2891
  },
2892
+ [Primitive.Min]([x, y]) {
2893
+ return [x.#binary(require_backend.AluOp.Min, y)];
2894
+ },
2895
+ [Primitive.Max]([x, y]) {
2896
+ return [x.#binary(require_backend.AluOp.Max, y)];
2897
+ },
2728
2898
  [Primitive.Neg]([x]) {
2729
2899
  return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
2730
2900
  },
@@ -2761,25 +2931,6 @@ var Array$1 = class Array$1 extends Tracer {
2761
2931
  return [y];
2762
2932
  }
2763
2933
  },
2764
- [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
2765
- const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
2766
- if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2767
- const c0 = zeros(shape$1, {
2768
- dtype: require_backend.DType.Uint32,
2769
- device: k0.device
2770
- });
2771
- const c1 = arange(0, require_backend.prod(shape$1), 1, {
2772
- dtype: require_backend.DType.Uint32,
2773
- device: k0.device
2774
- }).reshape(shape$1);
2775
- const custom = ([k0$1, k1$1, c0$1, c1$1]) => require_backend.AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
2776
- return [Array$1.#naryCustom("random_bits", custom, [
2777
- k0,
2778
- k1,
2779
- c0,
2780
- c1
2781
- ])];
2782
- },
2783
2934
  [Primitive.Sin]([x]) {
2784
2935
  return [x.#unary(require_backend.AluOp.Sin)];
2785
2936
  },
@@ -2807,12 +2958,6 @@ var Array$1 = class Array$1 extends Tracer {
2807
2958
  [Primitive.Sqrt]([x]) {
2808
2959
  return [x.#unary(require_backend.AluOp.Sqrt)];
2809
2960
  },
2810
- [Primitive.Min]([x, y]) {
2811
- return [x.#binary(require_backend.AluOp.Min, y)];
2812
- },
2813
- [Primitive.Max]([x, y]) {
2814
- return [x.#binary(require_backend.AluOp.Max, y)];
2815
- },
2816
2961
  [Primitive.Reduce]([x], { op, axis }) {
2817
2962
  if (axis.length === 0) return [x];
2818
2963
  return [x.#moveAxesDown(axis).#reduce(op)];
@@ -2847,13 +2992,35 @@ var Array$1 = class Array$1 extends Tracer {
2847
2992
  y
2848
2993
  ], { dtypeOverride: [require_backend.DType.Bool] })];
2849
2994
  },
2850
- [Primitive.Transpose]([x], { perm }) {
2851
- return [x.#transpose(perm)];
2852
- },
2853
- [Primitive.Broadcast]([x], { shape: shape$1, axis }) {
2854
- return [x.#reshape(x.#st.broadcast(shape$1, axis))];
2855
- },
2856
- [Primitive.Reshape]([x], { shape: shape$1 }) {
2995
+ [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
2996
+ const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
2997
+ if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2998
+ const c0 = zeros(shape$1, {
2999
+ dtype: require_backend.DType.Uint32,
3000
+ device: k0.device
3001
+ });
3002
+ const c1 = arange(0, require_backend.prod(shape$1), 1, {
3003
+ dtype: require_backend.DType.Uint32,
3004
+ device: k0.device
3005
+ }).reshape(shape$1);
3006
+ const custom = ([k0$1, k1$1, c0$1, c1$1]) => require_backend.AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
3007
+ return [Array$1.#naryCustom("random_bits", custom, [
3008
+ k0,
3009
+ k1,
3010
+ c0,
3011
+ c1
3012
+ ])];
3013
+ },
3014
+ [Primitive.Gather]([x, ...indices], { axis, outDim }) {
3015
+ return [x.#gather(indices, axis, outDim)];
3016
+ },
3017
+ [Primitive.Transpose]([x], { perm }) {
3018
+ return [x.#transpose(perm)];
3019
+ },
3020
+ [Primitive.Broadcast]([x], { shape: shape$1, axis }) {
3021
+ return [x.#reshape(x.#st.broadcast(shape$1, axis))];
3022
+ },
3023
+ [Primitive.Reshape]([x], { shape: shape$1 }) {
2857
3024
  return [x.#reshape(x.#st.reshape(shape$1))];
2858
3025
  },
2859
3026
  [Primitive.Flip]([x], { axis }) {
@@ -2867,17 +3034,48 @@ var Array$1 = class Array$1 extends Tracer {
2867
3034
  [Primitive.Pad]([x], { width }) {
2868
3035
  return [x.#reshape(x.#st.pad(width))];
2869
3036
  },
2870
- [Primitive.Gather]([x, ...indices], { axis, outDim }) {
2871
- return [x.#gather(indices, axis, outDim)];
3037
+ [Primitive.Sort]([x]) {
3038
+ const routine = new require_backend.Routine(require_backend.Routines.Sort, {
3039
+ inputShapes: [x.aval.shape],
3040
+ inputDtypes: [x.aval.dtype],
3041
+ outputShapes: [x.aval.shape],
3042
+ outputDtypes: [x.aval.dtype]
3043
+ });
3044
+ return Array$1.#routine(routine, [x], [x.#weakType]);
2872
3045
  },
2873
- [Primitive.JitCall](args, { jaxpr, numConsts }) {
2874
- if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
2875
- const { backend, committed } = Array$1.#computeBackend("jit_call", args);
3046
+ [Primitive.Argsort]([x]) {
3047
+ const routine = new require_backend.Routine(require_backend.Routines.Argsort, {
3048
+ inputShapes: [x.aval.shape],
3049
+ inputDtypes: [x.aval.dtype],
3050
+ outputShapes: [x.aval.shape, x.aval.shape],
3051
+ outputDtypes: [x.aval.dtype, require_backend.DType.Int32]
3052
+ });
3053
+ return Array$1.#routine(routine, [x], [x.#weakType, false]);
3054
+ },
3055
+ [Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
3056
+ const routine = new require_backend.Routine(require_backend.Routines.TriangularSolve, {
3057
+ inputShapes: [a.aval.shape, b.aval.shape],
3058
+ inputDtypes: [a.aval.dtype, b.aval.dtype],
3059
+ outputShapes: [b.aval.shape],
3060
+ outputDtypes: [b.aval.dtype]
3061
+ }, { unitDiagonal });
3062
+ return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
3063
+ },
3064
+ [Primitive.Cholesky]([a]) {
3065
+ const routine = new require_backend.Routine(require_backend.Routines.Cholesky, {
3066
+ inputShapes: [a.aval.shape],
3067
+ inputDtypes: [a.aval.dtype],
3068
+ outputShapes: [a.aval.shape],
3069
+ outputDtypes: [a.aval.dtype]
3070
+ });
3071
+ return Array$1.#routine(routine, [a], [a.#weakType]);
3072
+ },
3073
+ [Primitive.Jit](args, { jaxpr }) {
3074
+ if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
3075
+ const { backend, committed } = Array$1.#computeBackend("jit", args);
2876
3076
  args = args.map((ar) => ar._putSync(backend));
2877
- const consts = args.slice(0, numConsts);
2878
- const tracers = args.slice(numConsts);
2879
- const jp = jitCompile(backend, jaxpr, consts);
2880
- const { outputs, pending } = jp.execute(tracers.map((x) => x._realizeSource()));
3077
+ const jp = jitCompile(backend, jaxpr);
3078
+ const { outputs, pending } = jp.execute(args.map((x) => x._realizeSource()));
2881
3079
  for (const exe of pending) exe.updateRc(+outputs.length - 1);
2882
3080
  const prevPending = [...new Set(args.flatMap((x) => x.#pending))];
2883
3081
  for (const exe of prevPending) exe.updateRc(+outputs.length);
@@ -3176,6 +3374,43 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
3176
3374
  });
3177
3375
  }
3178
3376
  /**
3377
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
3378
+ *
3379
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
3380
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
3381
+ * `k>0` is above it.
3382
+ */
3383
+ function tri(n, m, k = 0, { dtype, device } = {}) {
3384
+ m ??= n;
3385
+ dtype ??= require_backend.DType.Float32;
3386
+ if (!Number.isInteger(n) || n < 0) throw new Error(`tri: n must be a non-negative integer, got ${n}`);
3387
+ if (!Number.isInteger(m) || m < 0) throw new Error(`tri: m must be a non-negative integer, got ${m}`);
3388
+ if (!Number.isInteger(k)) throw new Error(`tri: k must be an integer, got ${k}`);
3389
+ const rows = arange(k, n + k, 1, {
3390
+ dtype: require_backend.DType.Int32,
3391
+ device
3392
+ });
3393
+ const cols = arange(0, m, 1, {
3394
+ dtype: require_backend.DType.Int32,
3395
+ device
3396
+ });
3397
+ return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
3398
+ }
3399
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
3400
+ function tril(a, k = 0) {
3401
+ if (ndim$1(a) < 2) throw new Error(`tril: input array must be at least 2D, got ${ndim$1(a)}D`);
3402
+ a = fudgeArray(a);
3403
+ const [n, m] = a.shape.slice(-2);
3404
+ return where$1(tri(n, m, k, { dtype: require_backend.DType.Bool }), a.ref, zerosLike$1(a));
3405
+ }
3406
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
3407
+ function triu(a, k = 0) {
3408
+ if (ndim$1(a) < 2) throw new Error(`tril: input array must be at least 2D, got ${ndim$1(a)}D`);
3409
+ a = fudgeArray(a);
3410
+ const [n, m] = a.shape.slice(-2);
3411
+ return where$1(tri(n, m, k - 1, { dtype: require_backend.DType.Bool }), zerosLike$1(a.ref), a);
3412
+ }
3413
+ /**
3179
3414
  * Return evenly spaced numbers over a specified interval.
3180
3415
  *
3181
3416
  * Returns _num_ evenly spaced samples, calculated over the interval
@@ -3222,385 +3457,187 @@ function aluCompare(a, b, op) {
3222
3457
  }
3223
3458
 
3224
3459
  //#endregion
3225
- //#region src/frontend/jvp.ts
3460
+ //#region src/frontend/vmap.ts
3226
3461
  var import_usingCtx$1 = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
3227
- var JVPTracer = class extends Tracer {
3228
- constructor(trace$1, primal, tangent) {
3462
+ function mappedAval(batchDim, aval) {
3463
+ const shape$1 = [...aval.shape];
3464
+ shape$1.splice(batchDim, 1);
3465
+ return new ShapedArray(shape$1, aval.dtype, aval.weakType);
3466
+ }
3467
+ /** Move one axis to a different index. */
3468
+ function moveaxis(x, src, dst) {
3469
+ const t = pureArray(x);
3470
+ src = require_backend.checkAxis(src, t.ndim);
3471
+ dst = require_backend.checkAxis(dst, t.ndim);
3472
+ if (src === dst) return t;
3473
+ const perm = require_backend.range(t.ndim);
3474
+ perm.splice(src, 1);
3475
+ perm.splice(dst, 0, src);
3476
+ return transpose$1(t, perm);
3477
+ }
3478
+ function moveBatchAxis(axisSize, src, dst, x) {
3479
+ if (src === null) {
3480
+ const targetShape = [...x.shape];
3481
+ targetShape.splice(dst, 0, axisSize);
3482
+ return broadcast(x, targetShape, [dst]);
3483
+ } else if (src === dst) return x;
3484
+ else return moveaxis(x, src, dst);
3485
+ }
3486
+ var BatchTracer = class extends Tracer {
3487
+ constructor(trace$1, val, batchDim) {
3229
3488
  super(trace$1);
3230
- this.primal = primal;
3231
- this.tangent = tangent;
3489
+ this.val = val;
3490
+ this.batchDim = batchDim;
3232
3491
  }
3233
3492
  get aval() {
3234
- return this.primal.aval;
3493
+ if (this.batchDim === null) return this.val.aval;
3494
+ else return mappedAval(this.batchDim, this.val.aval);
3235
3495
  }
3236
3496
  toString() {
3237
- return `JVPTracer(${this.primal.toString()}, ${this.tangent.toString()})`;
3497
+ return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
3238
3498
  }
3239
3499
  get ref() {
3240
- this.primal.ref, this.tangent.ref;
3500
+ this.val.ref;
3241
3501
  return this;
3242
3502
  }
3243
3503
  dispose() {
3244
- this.primal.dispose();
3245
- this.tangent.dispose();
3504
+ this.val.dispose();
3505
+ }
3506
+ fullLower() {
3507
+ if (this.batchDim === null) return this.val.fullLower();
3508
+ else return this;
3246
3509
  }
3247
3510
  };
3248
- var JVPTrace = class extends Trace {
3511
+ var BatchTrace = class extends Trace {
3249
3512
  pure(val) {
3250
3513
  return this.lift(pureArray(val));
3251
3514
  }
3252
3515
  lift(val) {
3253
- return new JVPTracer(this, val, zerosLike$1(val.ref));
3516
+ return new BatchTracer(this, val, null);
3254
3517
  }
3255
3518
  processPrimitive(primitive, tracers, params) {
3256
- const [primalsIn, tangentsIn] = require_backend.unzip2(tracers.map((x) => [x.primal, x.tangent]));
3257
- const jvpRule = jvpRules[primitive];
3258
- if (jvpRule === void 0) throw new Error(`No JVP rule for: ${primitive}`);
3259
- const [primalsOut, tangentsOut] = jvpRule(primalsIn, tangentsIn, params);
3260
- return require_backend.zip(primalsOut, tangentsOut).map(([x, t]) => new JVPTracer(this, x, t));
3519
+ const [valsIn, bdimsIn] = require_backend.unzip2(tracers.map((t) => [t.val, t.batchDim]));
3520
+ const vmapRule = vmapRules[primitive];
3521
+ if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
3522
+ if (bdimsIn.every((d) => d === null)) {
3523
+ const valOuts$1 = bind(primitive, valsIn, params);
3524
+ return valOuts$1.map((x) => new BatchTracer(this, x, null));
3525
+ }
3526
+ const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3527
+ return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3528
+ }
3529
+ get axisSize() {
3530
+ return this.main.globalData;
3261
3531
  }
3262
3532
  };
3263
- /** Rule that applies the same operation to primals and tangents. */
3264
- function linearTangentsJvp(primitive) {
3265
- return (primals, tangents, params) => {
3266
- const ys = bind(primitive, primals, params);
3267
- const dys = bind(primitive, tangents, params);
3268
- return [ys, dys];
3269
- };
3270
- }
3271
- /** Rule for product of gradients in bilinear operations. */
3272
- function bilinearTangentsJvp(primitive) {
3273
- return ([x, y], [dx, dy], params) => {
3274
- const primal = bind1(primitive, [x.ref, y.ref], params);
3275
- const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
3276
- return [[primal], [tangent]];
3533
+ /**
3534
+ * Process a primitive with built-in broadcasting.
3535
+ *
3536
+ * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3537
+ */
3538
+ function broadcastBatcher(op) {
3539
+ return (axisSize, args, dims) => {
3540
+ if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3541
+ const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3542
+ const firstIdx = dims.findIndex((d) => d !== null);
3543
+ const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3544
+ if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3545
+ args = args.map((x, i) => {
3546
+ if (dims[i] === null) return x;
3547
+ x = moveBatchAxis(axisSize, dims[i], 0, x);
3548
+ if (x.ndim < nd) x = x.reshape([
3549
+ x.shape[0],
3550
+ ...require_backend.rep(nd - x.ndim, 1),
3551
+ ...x.shape.slice(1)
3552
+ ]);
3553
+ return x;
3554
+ });
3555
+ return [[op(...args)], [0]];
3277
3556
  };
3278
3557
  }
3279
- /** Rule that zeros out any tangents. */
3280
- function zeroTangentsJvp(primitive) {
3281
- return (primals, tangents, params) => {
3282
- for (const t of tangents) t.dispose();
3283
- const ys = bind(primitive, primals, params);
3284
- return [ys, ys.map((y) => zerosLike$1(y.ref))];
3558
+ function unopBatcher(op) {
3559
+ return (axisSize, [x], [xBdim], params) => {
3560
+ return [[op(x, params)], [xBdim]];
3285
3561
  };
3286
3562
  }
3287
- const jvpRules = {
3288
- [Primitive.Add]: linearTangentsJvp(Primitive.Add),
3289
- [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
3290
- [Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
3291
- [Primitive.Mod]([x, y], [dx, dy]) {
3292
- if (!require_backend.isFloatDtype(x.dtype) && !require_backend.isFloatDtype(y.dtype)) {
3293
- dx.dispose();
3294
- dy.dispose();
3295
- return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
3296
- }
3297
- const q = idiv(x.ref, y.ref);
3298
- return [[mod(x, y)], [dx.sub(dy.mul(q))]];
3563
+ const vmapRules = {
3564
+ [Primitive.Add]: broadcastBatcher(add$1),
3565
+ [Primitive.Mul]: broadcastBatcher(mul),
3566
+ [Primitive.Idiv]: broadcastBatcher(idiv),
3567
+ [Primitive.Mod]: broadcastBatcher(mod),
3568
+ [Primitive.Min]: broadcastBatcher(min$1),
3569
+ [Primitive.Max]: broadcastBatcher(max$1),
3570
+ [Primitive.Neg]: unopBatcher(neg),
3571
+ [Primitive.Reciprocal]: unopBatcher(reciprocal$1),
3572
+ [Primitive.Floor]: unopBatcher(floor$1),
3573
+ [Primitive.Ceil]: unopBatcher(ceil$1),
3574
+ [Primitive.StopGradient]: unopBatcher(stopGradient),
3575
+ [Primitive.Cast]: unopBatcher((x, { dtype }) => cast(x, dtype)),
3576
+ [Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
3577
+ [Primitive.Sin]: unopBatcher(sin$1),
3578
+ [Primitive.Cos]: unopBatcher(cos$1),
3579
+ [Primitive.Asin]: unopBatcher(asin$1),
3580
+ [Primitive.Atan]: unopBatcher(atan$1),
3581
+ [Primitive.Exp]: unopBatcher(exp$1),
3582
+ [Primitive.Log]: unopBatcher(log$1),
3583
+ [Primitive.Erf]: unopBatcher(erf$1),
3584
+ [Primitive.Erfc]: unopBatcher(erfc$1),
3585
+ [Primitive.Sqrt]: unopBatcher(sqrt$1),
3586
+ [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3587
+ require_backend.assertNonNull(xBdim);
3588
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3589
+ const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3590
+ return [[reduce(x, op, newAxis)], [outBdim]];
3299
3591
  },
3300
- [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
3301
- [Primitive.Reciprocal]([x], [dx]) {
3302
- const xRecip = reciprocal$1(x.ref);
3303
- return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
3592
+ [Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
3593
+ x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
3594
+ y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
3595
+ const z = dot$2(x, y);
3596
+ return [[z], [z.ndim - 1]];
3304
3597
  },
3305
- [Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
3306
- [Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
3307
- [Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
3308
- [Primitive.Cast]([x], [dx], { dtype }) {
3309
- if (x.dtype === dtype) return [[x], [dx]];
3310
- if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
3311
- else {
3312
- dx.dispose();
3313
- return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
3314
- }
3598
+ [Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
3599
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3600
+ y = moveBatchAxis(axisSize, yBdim, 0, y);
3601
+ const z = conv$1(x, y, {
3602
+ ...params,
3603
+ vmapDims: params.vmapDims + 1
3604
+ });
3605
+ return [[z], [0]];
3315
3606
  },
3316
- [Primitive.Bitcast]([x], [dx], { dtype }) {
3317
- if (x.dtype === dtype) return [[x], [dx]];
3318
- dx.dispose();
3319
- return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
3607
+ [Primitive.Compare](axisSize, args, dims, { op }) {
3608
+ return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3320
3609
  },
3321
- [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
3322
- [Primitive.Sin]([x], [dx]) {
3323
- return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
3324
- },
3325
- [Primitive.Cos]([x], [dx]) {
3326
- return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
3327
- },
3328
- [Primitive.Asin]([x], [dx]) {
3329
- const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
3330
- return [[asin$1(x)], [denom.mul(dx)]];
3331
- },
3332
- [Primitive.Atan]([x], [dx]) {
3333
- const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
3334
- return [[atan$1(x)], [dx.div(denom)]];
3335
- },
3336
- [Primitive.Exp]([x], [dx]) {
3337
- const z = exp$1(x);
3338
- return [[z.ref], [z.mul(dx)]];
3339
- },
3340
- [Primitive.Log]([x], [dx]) {
3341
- return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
3342
- },
3343
- [Primitive.Erf]([x], [dx]) {
3344
- const coeff = 2 / Math.sqrt(Math.PI);
3345
- const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3346
- return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
3347
- },
3348
- [Primitive.Erfc]([x], [dx]) {
3349
- const coeff = -2 / Math.sqrt(Math.PI);
3350
- const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3351
- return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
3352
- },
3353
- [Primitive.Sqrt]([x], [dx]) {
3354
- const z = sqrt$1(x);
3355
- return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
3356
- },
3357
- [Primitive.Min]([x, y], [dx, dy]) {
3358
- return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
3359
- },
3360
- [Primitive.Max]([x, y], [dx, dy]) {
3361
- return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
3362
- },
3363
- [Primitive.Reduce]([x], [dx], { op, axis }) {
3364
- if (op === require_backend.AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
3365
- else if (op === require_backend.AluOp.Mul) {
3366
- const primal = reduce(x.ref, op, axis);
3367
- const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
3368
- return [[primal], [tangent]];
3369
- } else if (op === require_backend.AluOp.Min || op === require_backend.AluOp.Max) {
3370
- const primal = reduce(x.ref, op, axis);
3371
- const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
3372
- const minCount = where$1(notMin.ref, 0, 1).sum(axis);
3373
- const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
3374
- return [[primal], [tangent]];
3375
- } else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
3376
- },
3377
- [Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
3378
- [Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
3379
- [Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
3380
- [Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
3381
- [Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
3382
- [Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
3383
- dcond.dispose();
3384
- return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
3385
- },
3386
- [Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
3387
- [Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
3388
- [Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
3389
- [Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
3390
- [Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
3391
- [Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
3392
- [Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
3393
- const indicesRef = indices.map((t) => t.ref);
3394
- return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
3395
- },
3396
- [Primitive.JitCall](primals, tangents, { name, jaxpr }) {
3397
- const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
3398
- const outs = bind(Primitive.JitCall, [
3399
- ...newConsts.map((c) => c.ref),
3400
- ...primals,
3401
- ...tangents
3402
- ], {
3403
- name: `${name}_jvp`,
3404
- jaxpr: newJaxpr,
3405
- numConsts: newConsts.length
3406
- });
3407
- const n = outs.length / 2;
3408
- if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
3409
- const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
3410
- return [primalsOut, tangentsOut];
3411
- }
3412
- };
3413
- const jvpJaxprCache = /* @__PURE__ */ new Map();
3414
- function jvpJaxpr(jaxpr) {
3415
- if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
3416
- const inAvals = jaxpr.inBinders.map((v) => v.aval);
3417
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
3418
- const result = {
3419
- newJaxpr,
3420
- newConsts
3421
- };
3422
- jvpJaxprCache.set(jaxpr, result);
3423
- return result;
3424
- }
3425
- function jvpFlat(f, primals, tangents) {
3426
- try {
3427
- var _usingCtx$1 = (0, import_usingCtx$1.default)();
3428
- const main = _usingCtx$1.u(newMain(JVPTrace));
3429
- const trace$1 = new JVPTrace(main);
3430
- const tracersIn = require_backend.zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
3431
- const outs = f(...tracersIn);
3432
- const tracersOut = outs.map((out) => fullRaise(trace$1, out));
3433
- return require_backend.unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
3434
- } catch (_) {
3435
- _usingCtx$1.e = _;
3436
- } finally {
3437
- _usingCtx$1.d();
3438
- }
3439
- }
3440
- function jvp$1(f, primals, tangents) {
3441
- const [primalsFlat, inTree] = flatten(primals);
3442
- const [tangentsFlat, inTree2] = flatten(tangents);
3443
- if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
3444
- const [flatFun, outTree] = flattenFun(f, inTree);
3445
- const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
3446
- if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
3447
- const primalsOut = unflatten(outTree.value, primalsOutFlat);
3448
- const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
3449
- return [primalsOut, tangentsOut];
3450
- }
3451
-
3452
- //#endregion
3453
- //#region src/frontend/vmap.ts
3454
- var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
3455
- function mappedAval(batchDim, aval) {
3456
- const shape$1 = [...aval.shape];
3457
- shape$1.splice(batchDim, 1);
3458
- return new ShapedArray(shape$1, aval.dtype, aval.weakType);
3459
- }
3460
- /** Move one axis to a different index. */
3461
- function moveaxis(x, src, dst) {
3462
- const t = pureArray(x);
3463
- src = require_backend.checkAxis(src, t.ndim);
3464
- dst = require_backend.checkAxis(dst, t.ndim);
3465
- if (src === dst) return t;
3466
- const perm = require_backend.range(t.ndim);
3467
- perm.splice(src, 1);
3468
- perm.splice(dst, 0, src);
3469
- return transpose$1(t, perm);
3470
- }
3471
- function moveBatchAxis(axisSize, src, dst, x) {
3472
- if (src === null) {
3473
- const targetShape = [...x.shape];
3474
- targetShape.splice(dst, 0, axisSize);
3475
- return broadcast(x, targetShape, [dst]);
3476
- } else if (src === dst) return x;
3477
- else return moveaxis(x, src, dst);
3478
- }
3479
- var BatchTracer = class extends Tracer {
3480
- constructor(trace$1, val, batchDim) {
3481
- super(trace$1);
3482
- this.val = val;
3483
- this.batchDim = batchDim;
3484
- }
3485
- get aval() {
3486
- if (this.batchDim === null) return this.val.aval;
3487
- else return mappedAval(this.batchDim, this.val.aval);
3488
- }
3489
- toString() {
3490
- return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
3491
- }
3492
- get ref() {
3493
- this.val.ref;
3494
- return this;
3495
- }
3496
- dispose() {
3497
- this.val.dispose();
3498
- }
3499
- fullLower() {
3500
- if (this.batchDim === null) return this.val.fullLower();
3501
- else return this;
3502
- }
3503
- };
3504
- var BatchTrace = class extends Trace {
3505
- pure(val) {
3506
- return this.lift(pureArray(val));
3507
- }
3508
- lift(val) {
3509
- return new BatchTracer(this, val, null);
3510
- }
3511
- processPrimitive(primitive, tracers, params) {
3512
- const [valsIn, bdimsIn] = require_backend.unzip2(tracers.map((t) => [t.val, t.batchDim]));
3513
- const vmapRule = vmapRules[primitive];
3514
- if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
3515
- if (bdimsIn.every((d) => d === null)) {
3516
- const valOuts$1 = bind(primitive, valsIn, params);
3517
- return valOuts$1.map((x) => new BatchTracer(this, x, null));
3610
+ [Primitive.Where]: broadcastBatcher(where$1),
3611
+ [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3612
+ if (indicesBdim.every((d) => d === null)) {
3613
+ require_backend.assertNonNull(xBdim);
3614
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3615
+ let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3616
+ let newOutDim = outDim;
3617
+ if (newOutDim < newBdim) newBdim += axis.length;
3618
+ else newOutDim += 1;
3619
+ return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3518
3620
  }
3519
- const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3520
- return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3521
- }
3522
- get axisSize() {
3523
- return this.main.globalData;
3524
- }
3525
- };
3526
- /**
3527
- * Process a primitive with built-in broadcasting.
3528
- *
3529
- * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3530
- */
3531
- function broadcastBatcher(op) {
3532
- return (axisSize, args, dims) => {
3533
- if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3534
- const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3535
- const firstIdx = dims.findIndex((d) => d !== null);
3536
- const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3537
- if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3538
- args = args.map((x, i) => {
3539
- if (dims[i] === null) return x;
3540
- x = moveBatchAxis(axisSize, dims[i], 0, x);
3541
- if (x.ndim < nd) x = x.reshape([
3542
- x.shape[0],
3543
- ...require_backend.rep(nd - x.ndim, 1),
3544
- ...x.shape.slice(1)
3621
+ const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3622
+ indices = indices.map((m, i) => {
3623
+ if (indicesBdim[i] === null) return m;
3624
+ m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3625
+ if (m.ndim < nd) m = m.reshape([
3626
+ m.shape[0],
3627
+ ...require_backend.rep(nd - m.ndim, 1),
3628
+ ...m.shape.slice(1)
3545
3629
  ]);
3546
- return x;
3547
- });
3548
- return [[op(...args)], [0]];
3549
- };
3550
- }
3551
- function unopBatcher(op) {
3552
- return (axisSize, [x], [xBdim], params) => {
3553
- return [[op(x, params)], [xBdim]];
3554
- };
3555
- }
3556
- const vmapRules = {
3557
- [Primitive.Add]: broadcastBatcher(add$1),
3558
- [Primitive.Mul]: broadcastBatcher(mul),
3559
- [Primitive.Idiv]: broadcastBatcher(idiv),
3560
- [Primitive.Mod]: broadcastBatcher(mod),
3561
- [Primitive.Neg]: unopBatcher(neg),
3562
- [Primitive.Reciprocal]: unopBatcher(reciprocal$1),
3563
- [Primitive.Floor]: unopBatcher(floor$1),
3564
- [Primitive.Ceil]: unopBatcher(ceil$1),
3565
- [Primitive.StopGradient]: unopBatcher(stopGradient),
3566
- [Primitive.Cast]: unopBatcher((x, { dtype }) => cast(x, dtype)),
3567
- [Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
3568
- [Primitive.Sin]: unopBatcher(sin$1),
3569
- [Primitive.Cos]: unopBatcher(cos$1),
3570
- [Primitive.Asin]: unopBatcher(asin$1),
3571
- [Primitive.Atan]: unopBatcher(atan$1),
3572
- [Primitive.Exp]: unopBatcher(exp$1),
3573
- [Primitive.Log]: unopBatcher(log$1),
3574
- [Primitive.Erf]: unopBatcher(erf$1),
3575
- [Primitive.Erfc]: unopBatcher(erfc$1),
3576
- [Primitive.Sqrt]: unopBatcher(sqrt$1),
3577
- [Primitive.Min]: broadcastBatcher(min$1),
3578
- [Primitive.Max]: broadcastBatcher(max$1),
3579
- [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3580
- require_backend.assertNonNull(xBdim);
3581
- const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3582
- const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3583
- return [[reduce(x, op, newAxis)], [outBdim]];
3584
- },
3585
- [Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
3586
- x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
3587
- y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
3588
- const z = dot$2(x, y);
3589
- return [[z], [z.ndim - 1]];
3590
- },
3591
- [Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
3592
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3593
- y = moveBatchAxis(axisSize, yBdim, 0, y);
3594
- const z = conv$1(x, y, {
3595
- ...params,
3596
- vmapDims: params.vmapDims + 1
3597
- });
3598
- return [[z], [0]];
3599
- },
3600
- [Primitive.Compare](axisSize, args, dims, { op }) {
3601
- return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3630
+ return m;
3631
+ });
3632
+ if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3633
+ else {
3634
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3635
+ const newAxis = [0, ...axis.map((ax) => ax + 1)];
3636
+ const extraBatchIndex = arange(axisSize).reshape([-1, ...require_backend.rep(nd - 1, 1)]);
3637
+ indices.splice(0, 0, extraBatchIndex);
3638
+ return [[gather(x, indices, newAxis, outDim)], [outDim]];
3639
+ }
3602
3640
  },
3603
- [Primitive.Where]: broadcastBatcher(where$1),
3604
3641
  [Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
3605
3642
  require_backend.assertNonNull(xBdim);
3606
3643
  const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
@@ -3632,42 +3669,53 @@ const vmapRules = {
3632
3669
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3633
3670
  return [[pad$1(x, newWidth)], [xBdim]];
3634
3671
  },
3635
- [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3636
- if (indicesBdim.every((d) => d === null)) {
3637
- require_backend.assertNonNull(xBdim);
3638
- const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3639
- let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3640
- let newOutDim = outDim;
3641
- if (newOutDim < newBdim) newBdim += axis.length;
3642
- else newOutDim += 1;
3643
- return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3644
- }
3645
- const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3646
- indices = indices.map((m, i) => {
3647
- if (indicesBdim[i] === null) return m;
3648
- m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3649
- if (m.ndim < nd) m = m.reshape([
3650
- m.shape[0],
3651
- ...require_backend.rep(nd - m.ndim, 1),
3652
- ...m.shape.slice(1)
3672
+ [Primitive.Sort](axisSize, [x], [xBdim]) {
3673
+ require_backend.assertNonNull(xBdim);
3674
+ if (xBdim !== x.ndim - 1) return [[sort$1(x)], [xBdim]];
3675
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3676
+ return [[sort$1(x)], [0]];
3677
+ },
3678
+ [Primitive.Argsort](axisSize, [x], [xBdim]) {
3679
+ require_backend.assertNonNull(xBdim);
3680
+ if (xBdim !== x.ndim - 1) return [argsort$1(x), [xBdim, xBdim]];
3681
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3682
+ return [argsort$1(x), [0, 0]];
3683
+ },
3684
+ [Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
3685
+ if (aBdim === null) {
3686
+ b = moveBatchAxis(axisSize, bBdim, -3, b);
3687
+ const [s, m, n] = b.shape.slice(-3);
3688
+ b = b.reshape([
3689
+ ...b.shape.slice(0, -3),
3690
+ s * m,
3691
+ n
3653
3692
  ]);
3654
- return m;
3655
- });
3656
- if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3657
- else {
3658
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3659
- const newAxis = [0, ...axis.map((ax) => ax + 1)];
3660
- const extraBatchIndex = arange(axisSize).reshape([-1, ...require_backend.rep(nd - 1, 1)]);
3661
- indices.splice(0, 0, extraBatchIndex);
3662
- return [[gather(x, indices, newAxis, outDim)], [outDim]];
3693
+ let x$1 = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
3694
+ x$1 = x$1.reshape([
3695
+ ...b.shape.slice(0, -2),
3696
+ s,
3697
+ m,
3698
+ n
3699
+ ]);
3700
+ return [[x$1], [x$1.ndim - 3]];
3663
3701
  }
3702
+ a = moveBatchAxis(axisSize, aBdim, 0, a);
3703
+ b = moveBatchAxis(axisSize, bBdim, 0, b);
3704
+ const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
3705
+ return [[x], [0]];
3664
3706
  },
3665
- [Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
3666
- const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
3667
- const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
3707
+ [Primitive.Cholesky](axisSize, [x], [xBdim]) {
3708
+ require_backend.assertNonNull(xBdim);
3709
+ if (xBdim < x.ndim - 2) return [[cholesky$2(x)], [xBdim]];
3710
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3711
+ return [[cholesky$2(x)], [0]];
3712
+ },
3713
+ [Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
3714
+ const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
3715
+ const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
3668
3716
  name: `${name}_vmap`,
3669
- jaxpr: newJaxpr,
3670
- numConsts: newConsts.length
3717
+ jaxpr: newJaxpr.jaxpr,
3718
+ numConsts: newJaxpr.consts.length
3671
3719
  });
3672
3720
  return [outs, require_backend.rep(outs.length, 0)];
3673
3721
  }
@@ -3683,14 +3731,10 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
3683
3731
  shape$1.splice(dims[i], 0, axisSize);
3684
3732
  return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
3685
3733
  });
3686
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3687
- const result = {
3688
- newJaxpr,
3689
- newConsts
3690
- };
3734
+ const { jaxpr: newJaxpr } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3691
3735
  if (!vmapJaxprCache.has(jaxpr)) vmapJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
3692
- vmapJaxprCache.get(jaxpr).set(cacheKey, result);
3693
- return result;
3736
+ vmapJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
3737
+ return newJaxpr;
3694
3738
  }
3695
3739
  function vmapFlat(f, inAxes, args) {
3696
3740
  let axisSize = void 0;
@@ -3704,7 +3748,7 @@ function vmapFlat(f, inAxes, args) {
3704
3748
  if (axisSize === void 0) throw new TypeError("vmap requires at least one mapped axis");
3705
3749
  let valsOut, bdimsOut;
3706
3750
  try {
3707
- var _usingCtx$1 = (0, import_usingCtx.default)();
3751
+ var _usingCtx$1 = (0, import_usingCtx$1.default)();
3708
3752
  const main = _usingCtx$1.u(newMain(BatchTrace, axisSize));
3709
3753
  const trace$1 = new BatchTrace(main);
3710
3754
  const tracersIn = args.map((x, i) => inAxes[i] === null ? pureArray(x) : new BatchTracer(trace$1, pureArray(x), inAxes[i]));
@@ -3745,6 +3789,261 @@ function jacfwd$1(f) {
3745
3789
  };
3746
3790
  }
3747
3791
 
3792
+ //#endregion
3793
+ //#region src/frontend/jvp.ts
3794
+ var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
3795
+ var JVPTracer = class extends Tracer {
3796
+ constructor(trace$1, primal, tangent) {
3797
+ super(trace$1);
3798
+ this.primal = primal;
3799
+ this.tangent = tangent;
3800
+ }
3801
+ get aval() {
3802
+ return this.primal.aval;
3803
+ }
3804
+ toString() {
3805
+ return `JVPTracer(${this.primal.toString()}, ${this.tangent.toString()})`;
3806
+ }
3807
+ get ref() {
3808
+ this.primal.ref, this.tangent.ref;
3809
+ return this;
3810
+ }
3811
+ dispose() {
3812
+ this.primal.dispose();
3813
+ this.tangent.dispose();
3814
+ }
3815
+ };
3816
+ var JVPTrace = class extends Trace {
3817
+ pure(val) {
3818
+ return this.lift(pureArray(val));
3819
+ }
3820
+ lift(val) {
3821
+ return new JVPTracer(this, val, zerosLike$1(val.ref));
3822
+ }
3823
+ processPrimitive(primitive, tracers, params) {
3824
+ const [primalsIn, tangentsIn] = require_backend.unzip2(tracers.map((x) => [x.primal, x.tangent]));
3825
+ const jvpRule = jvpRules[primitive];
3826
+ if (jvpRule === void 0) throw new Error(`No JVP rule for: ${primitive}`);
3827
+ const [primalsOut, tangentsOut] = jvpRule(primalsIn, tangentsIn, params);
3828
+ return require_backend.zip(primalsOut, tangentsOut).map(([x, t]) => new JVPTracer(this, x, t));
3829
+ }
3830
+ };
3831
+ /** Rule that applies the same operation to primals and tangents. */
3832
+ function linearTangentsJvp(primitive) {
3833
+ return (primals, tangents, params) => {
3834
+ const ys = bind(primitive, primals, params);
3835
+ const dys = bind(primitive, tangents, params);
3836
+ return [ys, dys];
3837
+ };
3838
+ }
3839
+ /** Rule for product of gradients in bilinear operations. */
3840
+ function bilinearTangentsJvp(primitive) {
3841
+ return ([x, y], [dx, dy], params) => {
3842
+ const primal = bind1(primitive, [x.ref, y.ref], params);
3843
+ const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
3844
+ return [[primal], [tangent]];
3845
+ };
3846
+ }
3847
+ /** Rule that zeros out any tangents. */
3848
+ function zeroTangentsJvp(primitive) {
3849
+ return (primals, tangents, params) => {
3850
+ for (const t of tangents) t.dispose();
3851
+ const ys = bind(primitive, primals, params);
3852
+ return [ys, ys.map((y) => zerosLike$1(y.ref))];
3853
+ };
3854
+ }
3855
+ /** Compute `a @ b.T`, batched to last two axes. */
3856
+ function batchMatmulT(a, b) {
3857
+ return dot$2(a.reshape(a.shape.toSpliced(-1, 0, 1)), b.reshape(b.shape.toSpliced(-2, 0, 1)));
3858
+ }
3859
+ /** Batch matrix transpose. */
3860
+ function mT(a) {
3861
+ return moveaxis(a, -2, -1);
3862
+ }
3863
+ const jvpRules = {
3864
+ [Primitive.Add]: linearTangentsJvp(Primitive.Add),
3865
+ [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
3866
+ [Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
3867
+ [Primitive.Mod]([x, y], [dx, dy]) {
3868
+ if (!require_backend.isFloatDtype(x.dtype) && !require_backend.isFloatDtype(y.dtype)) {
3869
+ dx.dispose();
3870
+ dy.dispose();
3871
+ return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
3872
+ }
3873
+ const q = idiv(x.ref, y.ref);
3874
+ return [[mod(x, y)], [dx.sub(dy.mul(q))]];
3875
+ },
3876
+ [Primitive.Min]([x, y], [dx, dy]) {
3877
+ return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
3878
+ },
3879
+ [Primitive.Max]([x, y], [dx, dy]) {
3880
+ return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
3881
+ },
3882
+ [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
3883
+ [Primitive.Reciprocal]([x], [dx]) {
3884
+ const xRecip = reciprocal$1(x.ref);
3885
+ return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
3886
+ },
3887
+ [Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
3888
+ [Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
3889
+ [Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
3890
+ [Primitive.Cast]([x], [dx], { dtype }) {
3891
+ if (x.dtype === dtype) return [[x], [dx]];
3892
+ if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
3893
+ else {
3894
+ dx.dispose();
3895
+ return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
3896
+ }
3897
+ },
3898
+ [Primitive.Bitcast]([x], [dx], { dtype }) {
3899
+ if (x.dtype === dtype) return [[x], [dx]];
3900
+ dx.dispose();
3901
+ return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
3902
+ },
3903
+ [Primitive.Sin]([x], [dx]) {
3904
+ return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
3905
+ },
3906
+ [Primitive.Cos]([x], [dx]) {
3907
+ return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
3908
+ },
3909
+ [Primitive.Asin]([x], [dx]) {
3910
+ const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
3911
+ return [[asin$1(x)], [denom.mul(dx)]];
3912
+ },
3913
+ [Primitive.Atan]([x], [dx]) {
3914
+ const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
3915
+ return [[atan$1(x)], [dx.div(denom)]];
3916
+ },
3917
+ [Primitive.Exp]([x], [dx]) {
3918
+ const z = exp$1(x);
3919
+ return [[z.ref], [z.mul(dx)]];
3920
+ },
3921
+ [Primitive.Log]([x], [dx]) {
3922
+ return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
3923
+ },
3924
+ [Primitive.Erf]([x], [dx]) {
3925
+ const coeff = 2 / Math.sqrt(Math.PI);
3926
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3927
+ return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
3928
+ },
3929
+ [Primitive.Erfc]([x], [dx]) {
3930
+ const coeff = -2 / Math.sqrt(Math.PI);
3931
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3932
+ return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
3933
+ },
3934
+ [Primitive.Sqrt]([x], [dx]) {
3935
+ const z = sqrt$1(x);
3936
+ return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
3937
+ },
3938
+ [Primitive.Reduce]([x], [dx], { op, axis }) {
3939
+ if (op === require_backend.AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
3940
+ else if (op === require_backend.AluOp.Mul) {
3941
+ const primal = reduce(x.ref, op, axis);
3942
+ const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
3943
+ return [[primal], [tangent]];
3944
+ } else if (op === require_backend.AluOp.Min || op === require_backend.AluOp.Max) {
3945
+ const primal = reduce(x.ref, op, axis);
3946
+ const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
3947
+ const minCount = where$1(notMin.ref, 0, 1).sum(axis);
3948
+ const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
3949
+ return [[primal], [tangent]];
3950
+ } else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
3951
+ },
3952
+ [Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
3953
+ [Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
3954
+ [Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
3955
+ [Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
3956
+ [Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
3957
+ [Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
3958
+ dcond.dispose();
3959
+ return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
3960
+ },
3961
+ [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
3962
+ [Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
3963
+ const indicesRef = indices.map((t) => t.ref);
3964
+ return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
3965
+ },
3966
+ [Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
3967
+ [Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
3968
+ [Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
3969
+ [Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
3970
+ [Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
3971
+ [Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
3972
+ [Primitive.Sort]([x], [dx]) {
3973
+ const [y, idx] = argsort$1(x);
3974
+ return [[y], [gather(dx, [idx], [-1], -1)]];
3975
+ },
3976
+ [Primitive.Argsort]([x], [dx]) {
3977
+ const [y, idx] = argsort$1(x);
3978
+ return [[y, idx.ref], [gather(dx, [idx.ref], [-1], -1), zerosLike$1(idx)]];
3979
+ },
3980
+ [Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
3981
+ const x = triangularSolve$1(a.ref, b, { unitDiagonal });
3982
+ const dax = batchMatmulT(da, x.ref);
3983
+ const rhsT = db.sub(mT(dax));
3984
+ const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
3985
+ return [[x], [dx]];
3986
+ },
3987
+ [Primitive.Cholesky]([a], [da]) {
3988
+ const L = cholesky$2(a.ref);
3989
+ da = da.ref.add(mT(da)).mul(.5);
3990
+ const W = triangularSolve$1(L.ref, da, { lower: true });
3991
+ const ST = triangularSolve$1(L.ref, mT(W), { lower: true });
3992
+ const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
3993
+ return [[L], [dL]];
3994
+ },
3995
+ [Primitive.Jit](primals, tangents, { name, jaxpr }) {
3996
+ const newJaxpr = jvpJaxpr(jaxpr);
3997
+ const outs = bind(Primitive.Jit, [
3998
+ ...newJaxpr.consts.map((c) => c.ref),
3999
+ ...primals,
4000
+ ...tangents
4001
+ ], {
4002
+ name: `${name}_jvp`,
4003
+ jaxpr: newJaxpr.jaxpr,
4004
+ numConsts: newJaxpr.consts.length
4005
+ });
4006
+ const n = outs.length / 2;
4007
+ if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
4008
+ const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
4009
+ return [primalsOut, tangentsOut];
4010
+ }
4011
+ };
4012
+ const jvpJaxprCache = /* @__PURE__ */ new Map();
4013
+ function jvpJaxpr(jaxpr) {
4014
+ if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
4015
+ const inAvals = jaxpr.inBinders.map((v) => v.aval);
4016
+ const { jaxpr: newJaxpr } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
4017
+ jvpJaxprCache.set(jaxpr, newJaxpr);
4018
+ return newJaxpr;
4019
+ }
4020
+ function jvpFlat(f, primals, tangents) {
4021
+ try {
4022
+ var _usingCtx$1 = (0, import_usingCtx.default)();
4023
+ const main = _usingCtx$1.u(newMain(JVPTrace));
4024
+ const trace$1 = new JVPTrace(main);
4025
+ const tracersIn = require_backend.zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
4026
+ const outs = f(...tracersIn);
4027
+ const tracersOut = outs.map((out) => fullRaise(trace$1, out));
4028
+ return require_backend.unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
4029
+ } catch (_) {
4030
+ _usingCtx$1.e = _;
4031
+ } finally {
4032
+ _usingCtx$1.d();
4033
+ }
4034
+ }
4035
+ function jvp$1(f, primals, tangents) {
4036
+ const [primalsFlat, inTree] = flatten(primals);
4037
+ const [tangentsFlat, inTree2] = flatten(tangents);
4038
+ if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
4039
+ const [flatFun, outTree] = flattenFun(f, inTree);
4040
+ const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
4041
+ if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
4042
+ const primalsOut = unflatten(outTree.value, primalsOutFlat);
4043
+ const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
4044
+ return [primalsOut, tangentsOut];
4045
+ }
4046
+
3748
4047
  //#endregion
3749
4048
  //#region src/frontend/linearize.ts
3750
4049
  /** Array value that can either be known or unknown. */
@@ -3775,11 +4074,10 @@ function partialEvalFlat(f, pvalsIn) {
3775
4074
  const tracersOut = outs.map((out) => fullRaise(trace$1, out));
3776
4075
  const pvalsOut = tracersOut.map((t) => t.pval);
3777
4076
  const unknownTracersOut = tracersOut.filter((t) => !t.pval.isKnown);
3778
- const { jaxpr, consts } = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
4077
+ const jaxpr = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
3779
4078
  return {
3780
4079
  jaxpr,
3781
- pvalsOut,
3782
- consts
4080
+ pvalsOut
3783
4081
  };
3784
4082
  }
3785
4083
  /**
@@ -3796,22 +4094,19 @@ function linearizeFlatUtil(f, primalsIn) {
3796
4094
  const [primalsOut$1, tangentsOut] = jvp$1(f, x.slice(0, k), x.slice(k, 2 * k));
3797
4095
  return [...primalsOut$1, ...tangentsOut];
3798
4096
  };
3799
- const { jaxpr, pvalsOut, consts } = partialEvalFlat(fJvp, pvalsIn);
4097
+ const { jaxpr, pvalsOut } = partialEvalFlat(fJvp, pvalsIn);
3800
4098
  const primalPvals = pvalsOut.slice(0, pvalsOut.length / 2);
3801
4099
  if (!primalPvals.every((pval) => pval.isKnown)) throw new Error("Not all primal values are known after partial evaluation");
3802
4100
  const primalsOut = primalPvals.map((pval) => pval.val);
3803
4101
  return {
3804
4102
  primalsOut,
3805
- jaxpr,
3806
- consts
4103
+ jaxpr
3807
4104
  };
3808
4105
  }
3809
4106
  function linearizeFlat(f, primalsIn) {
3810
- const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
3811
- const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
3812
- const dispose$1 = () => {
3813
- for (const c of consts) c.dispose();
3814
- };
4107
+ const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
4108
+ const fLin = (...tangents) => evalJaxpr(jaxpr.jaxpr, [...jaxpr.consts.map((c) => c.ref), ...tangents]);
4109
+ const dispose$1 = () => jaxpr.dispose();
3815
4110
  return [
3816
4111
  primalsOut,
3817
4112
  fLin,
@@ -3895,7 +4190,7 @@ var PartialEvalTrace = class extends Trace {
3895
4190
  }
3896
4191
  processPrimitive(primitive, tracers, params) {
3897
4192
  if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
3898
- if (primitive === Primitive.JitCall) {
4193
+ if (primitive === Primitive.Jit) {
3899
4194
  const { name, jaxpr, numConsts } = params;
3900
4195
  return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
3901
4196
  }
@@ -3921,14 +4216,14 @@ var PartialEvalTrace = class extends Trace {
3921
4216
  * Evaluate a Jaxpr on a set of PartialEvalTracers, computing as many known
3922
4217
  * values as possible (with JIT) and forwarding the unknown ones.
3923
4218
  *
3924
- * Used when encountering a JitCall rule during the trace.
4219
+ * Used when encountering a Jit rule during the trace.
3925
4220
  */
3926
4221
  #partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
3927
4222
  jaxpr = jaxpr.flatten();
3928
4223
  const inUnknowns = tracers.map((t) => !t.pval.isKnown);
3929
4224
  const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
3930
4225
  const [knownTracers, unknownTracers] = require_backend.partitionList(inUnknowns, tracers);
3931
- const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
4226
+ const outs1Res = bind(Primitive.Jit, knownTracers.map((t) => t.ref.fullLower()), {
3932
4227
  name: `${name}_peval`,
3933
4228
  jaxpr: jaxpr1,
3934
4229
  numConsts: 0
@@ -3938,7 +4233,7 @@ var PartialEvalTrace = class extends Trace {
3938
4233
  const resTracers = res.map((x) => this.instantiateConst(fullRaise(this, x)));
3939
4234
  const recipe = {
3940
4235
  type: "JaxprEqn",
3941
- prim: Primitive.JitCall,
4236
+ prim: Primitive.Jit,
3942
4237
  tracersIn: resTracers.concat(unknownTracers),
3943
4238
  params: {
3944
4239
  name: `${name}_resid`,
@@ -3967,7 +4262,7 @@ function partialEvalJaxpr(jaxpr, inUnknowns, instantiate) {
3967
4262
  const eqns1 = [];
3968
4263
  const eqns2 = [];
3969
4264
  for (const eqn of jaxpr.eqns) {
3970
- if (eqn.primitive === Primitive.JitCall) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
4265
+ if (eqn.primitive === Primitive.Jit) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
3971
4266
  const hasUnknowns = eqn.inputs.some((x) => x instanceof Var && !knownVars.has(x));
3972
4267
  if (hasUnknowns) {
3973
4268
  for (const x of eqn.inputs) if (x instanceof Var && knownVars.has(x)) residuals.add(x);
@@ -4042,10 +4337,7 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
4042
4337
  for (const t of tracersOut) t.dispose();
4043
4338
  jaxpr = jaxpr.simplify();
4044
4339
  if (require_backend.DEBUG >= 5) console.info("jaxpr from partial evaluation:\n" + jaxpr.toString());
4045
- return {
4046
- jaxpr,
4047
- consts
4048
- };
4340
+ return new ClosedJaxpr(jaxpr, consts);
4049
4341
  }
4050
4342
  /** Marker type for pullback, used by transpose rules. */
4051
4343
  var UndefPrimal = class {
@@ -4237,317 +4529,142 @@ const transposeRules = {
4237
4529
  cond.dispose();
4238
4530
  return cts;
4239
4531
  },
4240
- [Primitive.Transpose]([ct], [x], { perm }) {
4241
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
4242
- return [transpose$1(ct, require_backend.invertPermutation(perm))];
4243
- },
4244
- [Primitive.Broadcast]([ct], [x], { axis }) {
4245
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Broadcast);
4246
- return [reduce(ct, require_backend.AluOp.Add, axis)];
4247
- },
4248
- [Primitive.Reshape]([ct], [x], _) {
4249
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Reshape);
4250
- return [reshape$1(ct, x.aval.shape)];
4251
- },
4252
- [Primitive.Flip]([ct], [x], { axis }) {
4253
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Flip);
4254
- return [flip$1(ct, axis)];
4255
- },
4256
- [Primitive.Shrink]([ct], [x], { slice }) {
4257
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Shrink);
4258
- const width = slice.map(([s, e$1], i) => [s, x.aval.shape[i] - e$1]);
4259
- return [pad$1(ct, width)];
4260
- },
4261
- [Primitive.Pad]([ct], [x], { width }) {
4262
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pad);
4263
- const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
4264
- return [shrink(ct, slice)];
4265
- },
4266
4532
  [Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
4267
4533
  if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4268
4534
  if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4269
4535
  throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
4270
4536
  },
4271
- [Primitive.JitCall](cts, args, { name, jaxpr }) {
4272
- const undefPrimals = args.map((x) => x instanceof UndefPrimal);
4273
- const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
4274
- const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
4275
- const outs = bind(Primitive.JitCall, [
4276
- ...newConsts.map((c) => c.ref),
4277
- ...residuals,
4278
- ...cts
4279
- ], {
4280
- name: `${name}_t`,
4281
- jaxpr: newJaxpr,
4282
- numConsts: newConsts.length
4283
- });
4284
- let i = 0;
4285
- return undefPrimals.map((isUndef) => isUndef ? outs[i++] : null);
4286
- }
4287
- };
4288
- const transposeJaxprCache = /* @__PURE__ */ new Map();
4289
- function transposeJaxpr(jaxpr, undefPrimals) {
4290
- const cacheKey = JSON.stringify(undefPrimals);
4291
- const prevResult = transposeJaxprCache.get(jaxpr)?.get(cacheKey);
4292
- if (prevResult) return prevResult;
4293
- const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
4294
- const forwardInTypes = inTypes.filter((_, i) => !undefPrimals[i]);
4295
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((forwardIn, cotangents) => {
4296
- const args = [];
4297
- let forwardInIdx = 0;
4298
- for (let i = 0; i < undefPrimals.length; i++) if (undefPrimals[i]) args.push(new UndefPrimal(inTypes[i]));
4299
- else args.push(forwardIn[forwardInIdx++]);
4300
- return evalJaxprTransposed(jaxpr, args, cotangents);
4301
- })(forwardInTypes, outTypes);
4302
- typecheckJaxpr(newJaxpr);
4303
- const result = {
4304
- newJaxpr,
4305
- newConsts
4306
- };
4307
- if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
4308
- transposeJaxprCache.get(jaxpr).set(cacheKey, result);
4309
- return result;
4310
- }
4311
- function vjpFlat(f, primalsIn) {
4312
- const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
4313
- const fVjp = (...cotangents) => {
4314
- const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
4315
- return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
4316
- };
4317
- const dispose$1 = () => {
4318
- for (const c of consts) c.dispose();
4319
- };
4320
- return [
4321
- primalsOut,
4322
- fVjp,
4323
- dispose$1
4324
- ];
4325
- }
4326
- function vjp$1(f, ...primalsIn) {
4327
- const [primalsInFlat, inTree] = flatten(primalsIn);
4328
- const [fFlat, outTree] = flattenFun(f, inTree);
4329
- const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
4330
- if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
4331
- const primalsOut = unflatten(outTree.value, primalsOutFlat);
4332
- const fVjp = ((cotangentsOut) => {
4333
- const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
4334
- if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
4335
- const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
4336
- return unflatten(inTree, cotangentsInFlat);
4337
- });
4338
- fVjp.dispose = dispose$1;
4339
- return [primalsOut, fVjp];
4340
- }
4341
- function grad$1(f) {
4342
- const valueAndGradFn = valueAndGrad$1(f);
4343
- return (...x) => {
4344
- const [y, dx] = valueAndGradFn(...x);
4345
- y.dispose();
4346
- return dx;
4347
- };
4348
- }
4349
- function valueAndGrad$1(f) {
4350
- return (...x) => {
4351
- if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
4352
- const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
4353
- if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
4354
- if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
4355
- const [ct, ...rest] = fVjp(onesLike$1(y.ref));
4356
- for (const r of rest) dispose(r);
4357
- fVjp.dispose();
4358
- return [y, ct];
4359
- };
4360
- }
4361
- function jacrev$1(f) {
4362
- return function jacobianReverse(x) {
4363
- if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
4364
- const [size$1] = x.shape;
4365
- const pullback = (ct) => {
4366
- const [y, fVjp] = vjp$1(f, x);
4367
- y.dispose();
4368
- const [ret] = fVjp(ct);
4369
- fVjp.dispose();
4370
- return ret;
4371
- };
4372
- return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
4373
- };
4374
- }
4375
-
4376
- //#endregion
4377
- //#region src/library/lax.ts
4378
- var lax_exports = {};
4379
- __export(lax_exports, {
4380
- conv: () => conv,
4381
- convGeneralDilated: () => convGeneralDilated,
4382
- convWithGeneralPadding: () => convWithGeneralPadding,
4383
- dot: () => dot$1,
4384
- erf: () => erf,
4385
- erfc: () => erfc,
4386
- reduceWindow: () => reduceWindow,
4387
- stopGradient: () => stopGradient$1
4388
- });
4389
- /**
4390
- * General dot product/contraction operator.
4391
- *
4392
- * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
4393
- * `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
4394
- */
4395
- function dot$1(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
4396
- if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
4397
- else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
4398
- lc = lc.map((a) => require_backend.checkAxis(a, lhs.ndim));
4399
- rc = rc.map((a) => require_backend.checkAxis(a, rhs.ndim));
4400
- lb = lb.map((a) => require_backend.checkAxis(a, lhs.ndim));
4401
- rb = rb.map((a) => require_backend.checkAxis(a, rhs.ndim));
4402
- if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
4403
- else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
4404
- const lf = require_backend.range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
4405
- const rf = require_backend.range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
4406
- const lhs2 = lhs.transpose([
4407
- ...lb,
4408
- ...lf,
4409
- ...lc
4410
- ]);
4411
- const rhs2 = rhs.transpose([
4412
- ...rb,
4413
- ...rf,
4414
- ...rc
4415
- ]);
4416
- if (lc.length === 0) return mul(lhs2.reshape([
4417
- ...lb.map((a) => lhs.shape[a]),
4418
- ...lf.map((a) => lhs.shape[a]),
4419
- ...require_backend.rep(rf.length, 1)
4420
- ]), rhs2.reshape([
4421
- ...rb.map((a) => rhs.shape[a]),
4422
- ...require_backend.rep(lf.length, 1),
4423
- ...rf.map((a) => rhs.shape[a])
4424
- ]));
4425
- const dotShapeX = lc.map((a) => lhs.shape[a]);
4426
- const dotShapeY = rc.map((a) => rhs.shape[a]);
4427
- if (!require_backend.deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
4428
- return dot$2(lhs2.reshape([
4429
- ...lb.map((a) => lhs.shape[a]),
4430
- ...lf.map((a) => lhs.shape[a]),
4431
- ...require_backend.rep(rf.length, 1),
4432
- require_backend.prod(dotShapeX)
4433
- ]), rhs2.reshape([
4434
- ...rb.map((a) => rhs.shape[a]),
4435
- ...require_backend.rep(lf.length, 1),
4436
- ...rf.map((a) => rhs.shape[a]),
4437
- require_backend.prod(dotShapeY)
4438
- ]));
4439
- }
4440
- function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
4441
- const padType = padding.toUpperCase();
4442
- switch (padType) {
4443
- case "VALID": return require_backend.rep(inShape.length, [0, 0]);
4444
- case "SAME":
4445
- case "SAME_LOWER": {
4446
- const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
4447
- const padSizes = require_backend.zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
4448
- if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
4449
- else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
4450
- }
4451
- default: throw new Error(`Unknown padding type: ${padType}`);
4452
- }
4453
- }
4454
- /**
4455
- * General n-dimensional convolution operator, with optional dilation.
4456
- *
4457
- * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
4458
- * function in JAX, which wraps XLA's general convolution operator.
4459
- *
4460
- * Grouped convolutions are not supported right now.
4461
- */
4462
- function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
4463
- if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
4464
- if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
4465
- if (typeof padding === "string") {
4466
- if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
4467
- padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1), padding);
4468
- }
4469
- if (featureGroupCount !== 1) {
4470
- const G = featureGroupCount;
4471
- const [N, C_in, ...xs] = lhs.shape;
4472
- const [C_out, C_in_per_group, ...ks] = rhs.shape;
4473
- if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
4474
- if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
4475
- if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
4476
- const lhsGrouped = moveaxis(lhs.reshape([
4477
- N,
4478
- G,
4479
- C_in / G,
4480
- ...xs
4481
- ]), 1, 0);
4482
- const rhsGrouped = rhs.reshape([
4483
- G,
4484
- C_out / G,
4485
- C_in_per_group,
4486
- ...ks
4487
- ]);
4488
- const result = conv$1(lhsGrouped, rhsGrouped, {
4489
- vmapDims: 1,
4490
- strides: windowStrides,
4491
- padding,
4492
- lhsDilation,
4493
- rhsDilation
4537
+ [Primitive.Transpose]([ct], [x], { perm }) {
4538
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
4539
+ return [transpose$1(ct, require_backend.invertPermutation(perm))];
4540
+ },
4541
+ [Primitive.Broadcast]([ct], [x], { axis }) {
4542
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Broadcast);
4543
+ return [reduce(ct, require_backend.AluOp.Add, axis)];
4544
+ },
4545
+ [Primitive.Reshape]([ct], [x], _) {
4546
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Reshape);
4547
+ return [reshape$1(ct, x.aval.shape)];
4548
+ },
4549
+ [Primitive.Flip]([ct], [x], { axis }) {
4550
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Flip);
4551
+ return [flip$1(ct, axis)];
4552
+ },
4553
+ [Primitive.Shrink]([ct], [x], { slice }) {
4554
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Shrink);
4555
+ const width = slice.map(([s, e$1], i) => [s, x.aval.shape[i] - e$1]);
4556
+ return [pad$1(ct, width)];
4557
+ },
4558
+ [Primitive.Pad]([ct], [x], { width }) {
4559
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pad);
4560
+ const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
4561
+ return [shrink(ct, slice)];
4562
+ },
4563
+ [Primitive.TriangularSolve]([ct], [a, b], { unitDiagonal }) {
4564
+ if (a instanceof UndefPrimal || !(b instanceof UndefPrimal)) throw new NonlinearError(Primitive.TriangularSolve);
4565
+ const ctB = triangularSolve$1(moveaxis(a, -2, -1), ct, {
4566
+ lower: true,
4567
+ unitDiagonal
4494
4568
  });
4495
- const ys = result.shape.slice(3);
4496
- return moveaxis(result, 0, 1).reshape([
4497
- N,
4498
- C_out,
4499
- ...ys
4500
- ]);
4569
+ return [null, ctB];
4570
+ },
4571
+ [Primitive.Jit](cts, args, { name, jaxpr }) {
4572
+ const undefPrimals = args.map((x) => x instanceof UndefPrimal);
4573
+ const newJaxpr = transposeJaxpr(jaxpr, undefPrimals);
4574
+ const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
4575
+ const outs = bind(Primitive.Jit, [
4576
+ ...newJaxpr.consts.map((c) => c.ref),
4577
+ ...residuals,
4578
+ ...cts
4579
+ ], {
4580
+ name: `${name}_t`,
4581
+ jaxpr: newJaxpr.jaxpr,
4582
+ numConsts: newJaxpr.consts.length
4583
+ });
4584
+ let i = 0;
4585
+ return undefPrimals.map((isUndef) => isUndef ? outs[i++] : null);
4501
4586
  }
4502
- return conv$1(lhs, rhs, {
4503
- strides: windowStrides,
4504
- padding,
4505
- lhsDilation,
4506
- rhsDilation
4507
- });
4508
- }
4509
- /** Convenience wrapper around `convGeneralDilated`. */
4510
- function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
4511
- return convGeneralDilated(lhs, rhs, windowStrides, padding, {
4512
- lhsDilation,
4513
- rhsDilation
4514
- });
4587
+ };
4588
+ const transposeJaxprCache = /* @__PURE__ */ new Map();
4589
+ function transposeJaxpr(jaxpr, undefPrimals) {
4590
+ const cacheKey = JSON.stringify(undefPrimals);
4591
+ const prevResult = transposeJaxprCache.get(jaxpr)?.get(cacheKey);
4592
+ if (prevResult) return prevResult;
4593
+ const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
4594
+ const forwardInTypes = inTypes.filter((_, i) => !undefPrimals[i]);
4595
+ const { jaxpr: newJaxpr } = makeJaxpr$1((forwardIn, cotangents) => {
4596
+ const args = [];
4597
+ let forwardInIdx = 0;
4598
+ for (let i = 0; i < undefPrimals.length; i++) if (undefPrimals[i]) args.push(new UndefPrimal(inTypes[i]));
4599
+ else args.push(forwardIn[forwardInIdx++]);
4600
+ return evalJaxprTransposed(jaxpr, args, cotangents);
4601
+ })(forwardInTypes, outTypes);
4602
+ typecheckJaxpr(newJaxpr.jaxpr);
4603
+ if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
4604
+ transposeJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
4605
+ return newJaxpr;
4515
4606
  }
4516
- /** Convenience wrapper around `convGeneralDilated`. */
4517
- function conv(lhs, rhs, windowStrides, padding) {
4518
- return convGeneralDilated(lhs, rhs, windowStrides, padding);
4607
+ function vjpFlat(f, primalsIn) {
4608
+ const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
4609
+ const fVjp = (...cotangents) => {
4610
+ const transposeInputs = [...jaxpr.consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
4611
+ return evalJaxprTransposed(jaxpr.jaxpr, transposeInputs, cotangents);
4612
+ };
4613
+ const dispose$1 = () => jaxpr.dispose();
4614
+ return [
4615
+ primalsOut,
4616
+ fVjp,
4617
+ dispose$1
4618
+ ];
4519
4619
  }
4520
- /** Reduce a computation over padded windows. */
4521
- function reduceWindow(operand, computation, windowDimensions, windowStrides) {
4522
- if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
4523
- if (!windowStrides) windowStrides = require_backend.rep(windowDimensions.length, 1);
4524
- for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
4525
- return computation(bind1(Primitive.Pool, [operand], {
4526
- window: windowDimensions,
4527
- strides: windowStrides
4528
- }));
4620
+ function vjp$1(f, ...primalsIn) {
4621
+ const [primalsInFlat, inTree] = flatten(primalsIn);
4622
+ const [fFlat, outTree] = flattenFun(f, inTree);
4623
+ const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
4624
+ if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
4625
+ const primalsOut = unflatten(outTree.value, primalsOutFlat);
4626
+ const fVjp = ((cotangentsOut) => {
4627
+ const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
4628
+ if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
4629
+ const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
4630
+ return unflatten(inTree, cotangentsInFlat);
4631
+ });
4632
+ fVjp.dispose = dispose$1;
4633
+ return [primalsOut, fVjp];
4529
4634
  }
4530
- /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
4531
- function erf(x) {
4532
- return erf$1(x);
4635
+ function grad$1(f) {
4636
+ const valueAndGradFn = valueAndGrad$1(f);
4637
+ return (...x) => {
4638
+ const [y, dx] = valueAndGradFn(...x);
4639
+ y.dispose();
4640
+ return dx;
4641
+ };
4533
4642
  }
4534
- /**
4535
- * The complementary error function: `erfc(x) = 1 - erf(x)`.
4536
- *
4537
- * This function is more accurate than `1 - erf(x)` for large values of `x`,
4538
- * where `erf(x)` is very close to 1.
4539
- */
4540
- function erfc(x) {
4541
- return erfc$1(x);
4643
+ function valueAndGrad$1(f) {
4644
+ return (...x) => {
4645
+ if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
4646
+ const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
4647
+ if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
4648
+ if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
4649
+ const [ct, ...rest] = fVjp(onesLike$1(y.ref));
4650
+ for (const r of rest) dispose(r);
4651
+ fVjp.dispose();
4652
+ return [y, ct];
4653
+ };
4542
4654
  }
4543
- /**
4544
- * Stops gradient computation.
4545
- *
4546
- * Behaves as the identity function but prevents the flow of gradients during
4547
- * forward or reverse-mode automatic differentiation.
4548
- */
4549
- function stopGradient$1(x) {
4550
- return stopGradient(x);
4655
+ function jacrev$1(f) {
4656
+ return function jacobianReverse(x) {
4657
+ if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
4658
+ const [size$1] = x.shape;
4659
+ const pullback = (ct) => {
4660
+ const [y, fVjp] = vjp$1(f, x);
4661
+ y.dispose();
4662
+ const [ret] = fVjp(ct);
4663
+ fVjp.dispose();
4664
+ return ret;
4665
+ };
4666
+ return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
4667
+ };
4551
4668
  }
4552
4669
 
4553
4670
  //#endregion
@@ -4745,34 +4862,207 @@ function* allPaths(tensors, next) {
4745
4862
  }
4746
4863
  }
4747
4864
 
4865
+ //#endregion
4866
+ //#region src/library/numpy-fft.ts
4867
+ var numpy_fft_exports = {};
4868
+ __export(numpy_fft_exports, {
4869
+ fft: () => fft,
4870
+ ifft: () => ifft
4871
+ });
4872
+ function checkPairInput(name, a) {
4873
+ const fullName = `jax.numpy.fft.${name}`;
4874
+ if (!require_backend.deepEqual(a.real.shape, a.imag.shape)) throw new Error(`${fullName}: real and imaginary parts must have the same shape, got ${JSON.stringify(a.real.shape)} and ${JSON.stringify(a.imag.shape)}`);
4875
+ if (a.real.dtype !== a.imag.dtype) throw new Error(`${fullName}: real and imaginary parts must have the same dtype, got ${a.real.dtype} and ${a.imag.dtype}`);
4876
+ if (!require_backend.isFloatDtype(a.real.dtype)) throw new Error(`${fullName}: input must have a float dtype, got ${a.real.dtype}`);
4877
+ }
4878
+ function checkPowerOfTwo(name, n) {
4879
+ if ((n & n - 1) !== 0) throw new Error(`jax.numpy.fft.${name}: size must be a power of two, got ${n}`);
4880
+ }
4881
+ const fftUpdate = jit$1(function fftUpdate$1(i, { real, imag }) {
4882
+ const half = 2 ** i;
4883
+ real = real.reshape([-1, 2 * half]);
4884
+ imag = imag.reshape([-1, 2 * half]);
4885
+ const k = arange(0, half, 1, { dtype: real.dtype });
4886
+ const theta = k.mul(-Math.PI / half);
4887
+ const wr = cos(theta.ref);
4888
+ const wi = sin(theta);
4889
+ const ur = real.ref.slice([], [0, half]);
4890
+ const ui = imag.ref.slice([], [0, half]);
4891
+ const vr = real.slice([], [half, 2 * half]);
4892
+ const vi = imag.slice([], [half, 2 * half]);
4893
+ const tr = vr.ref.mul(wr.ref).sub(vi.ref.mul(wi.ref));
4894
+ const ti = vr.mul(wi).add(vi.mul(wr));
4895
+ return {
4896
+ real: concatenate([ur.ref.add(tr.ref), ur.sub(tr)], -1),
4897
+ imag: concatenate([ui.ref.add(ti.ref), ui.sub(ti)], -1)
4898
+ };
4899
+ }, { staticArgnums: [0] });
4900
+ /**
4901
+ * Compute a one-dimensional discrete Fourier transform.
4902
+ *
4903
+ * Currently, the size of the axis must be a power of two.
4904
+ */
4905
+ function fft(a, axis = -1) {
4906
+ checkPairInput("fft", a);
4907
+ let { real, imag } = a;
4908
+ axis = require_backend.checkAxis(axis, real.ndim);
4909
+ const n = real.shape[axis];
4910
+ checkPowerOfTwo("fft", n);
4911
+ const logN = Math.log2(n);
4912
+ let perm = null;
4913
+ if (axis !== real.ndim - 1) {
4914
+ perm = require_backend.range(real.ndim);
4915
+ perm.splice(axis, 1);
4916
+ perm.push(axis);
4917
+ real = real.transpose(perm);
4918
+ imag = imag.transpose(perm);
4919
+ }
4920
+ const originalShape = real.shape;
4921
+ real = real.reshape([-1, ...require_backend.rep(logN, 2)]).transpose([0, ...require_backend.range(1, logN + 1).reverse()]).flatten();
4922
+ imag = imag.reshape([-1, ...require_backend.rep(logN, 2)]).transpose([0, ...require_backend.range(1, logN + 1).reverse()]).flatten();
4923
+ for (let i = 0; i < logN; i++) ({real, imag} = fftUpdate(i, {
4924
+ real,
4925
+ imag
4926
+ }));
4927
+ real = real.reshape(originalShape);
4928
+ imag = imag.reshape(originalShape);
4929
+ if (perm !== null) {
4930
+ real = real.transpose(require_backend.invertPermutation(perm));
4931
+ imag = imag.transpose(require_backend.invertPermutation(perm));
4932
+ }
4933
+ return {
4934
+ real,
4935
+ imag
4936
+ };
4937
+ }
4938
+ /**
4939
+ * Compute a one-dimensional inverse discrete Fourier transform.
4940
+ *
4941
+ * Currently, the size of the axis must be a power of two.
4942
+ */
4943
+ function ifft(a, axis = -1) {
4944
+ checkPairInput("ifft", a);
4945
+ let { real, imag } = a;
4946
+ axis = require_backend.checkAxis(axis, real.ndim);
4947
+ const n = real.shape[axis];
4948
+ checkPowerOfTwo("ifft", n);
4949
+ imag = imag.mul(-1);
4950
+ const result = fft({
4951
+ real,
4952
+ imag
4953
+ }, axis);
4954
+ return {
4955
+ real: result.real.div(n),
4956
+ imag: result.imag.mul(-1).div(n)
4957
+ };
4958
+ }
4959
+
4960
+ //#endregion
4961
+ //#region src/library/numpy-linalg.ts
4962
+ var numpy_linalg_exports = {};
4963
+ __export(numpy_linalg_exports, {
4964
+ cholesky: () => cholesky$1,
4965
+ diagonal: () => diagonal,
4966
+ lstsq: () => lstsq,
4967
+ matmul: () => matmul,
4968
+ matrixTranspose: () => matrixTranspose,
4969
+ outer: () => outer,
4970
+ tensordot: () => tensordot,
4971
+ trace: () => trace,
4972
+ vecdot: () => vecdot
4973
+ });
4974
+ /**
4975
+ * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
4976
+ *
4977
+ * This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
4978
+ * the input matrix, which is on by default.
4979
+ */
4980
+ function cholesky$1(a, { upper = false, symmetrizeInput = true } = {}) {
4981
+ a = fudgeArray(a);
4982
+ if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`cholesky: input must be at least 2D square matrix, got ${a.aval}`);
4983
+ if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
4984
+ return cholesky(a, { upper });
4985
+ }
4986
+ /**
4987
+ * Return the least-squares solution to a linear equation.
4988
+ *
4989
+ * For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
4990
+ * For underdetermined systems, this finds the minimum-norm solution for `x`.
4991
+ *
4992
+ * This currently uses Cholesky decomposition to solve the normal equations,
4993
+ * under the hood. The method is not as robust as QR or SVD.
4994
+ *
4995
+ * @param a coefficient matrix of shape `(M, N)`
4996
+ * @param b right-hand side of shape `(M,)` or `(M, K)`
4997
+ * @return least-squares solution of shape `(N,)` or `(N, K)`
4998
+ */
4999
+ function lstsq(a, b) {
5000
+ a = fudgeArray(a);
5001
+ b = fudgeArray(b);
5002
+ if (a.ndim !== 2) throw new Error(`lstsq: 'a' must be a 2D array, got ${a.aval}`);
5003
+ const [m, n] = a.shape;
5004
+ if (b.shape[0] !== m) throw new Error(`lstsq: leading dimension of 'b' must match number of rows of 'a', got ${b.aval}`);
5005
+ const at = matrixTranspose(a.ref);
5006
+ if (m <= n) {
5007
+ const aat = matmul(a, at.ref);
5008
+ const l = cholesky$1(aat, { symmetrizeInput: false });
5009
+ const lb = triangularSolve(l.ref, b, {
5010
+ leftSide: true,
5011
+ lower: true
5012
+ });
5013
+ const llb = triangularSolve(l, lb, {
5014
+ leftSide: true,
5015
+ transposeA: true
5016
+ });
5017
+ return matmul(at, llb.ref);
5018
+ } else {
5019
+ const ata = matmul(at.ref, a);
5020
+ const l = cholesky$1(ata, { symmetrizeInput: false });
5021
+ const atb = matmul(at, b);
5022
+ const lb = triangularSolve(l.ref, atb, {
5023
+ leftSide: true,
5024
+ lower: true
5025
+ });
5026
+ const llb = triangularSolve(l, lb, {
5027
+ leftSide: true,
5028
+ transposeA: true
5029
+ });
5030
+ return llb;
5031
+ }
5032
+ }
5033
+
4748
5034
  //#endregion
4749
5035
  //#region src/library/numpy.ts
4750
5036
  var numpy_exports = {};
4751
5037
  __export(numpy_exports, {
4752
5038
  Array: () => Array$1,
4753
5039
  DType: () => require_backend.DType,
4754
- abs: () => abs,
5040
+ abs: () => absolute,
4755
5041
  absolute: () => absolute,
4756
5042
  acos: () => acos,
4757
- acosh: () => acosh,
5043
+ acosh: () => arccosh,
4758
5044
  add: () => add,
5045
+ all: () => all,
4759
5046
  allclose: () => allclose,
5047
+ any: () => any,
4760
5048
  arange: () => arange,
4761
- arccos: () => arccos,
5049
+ arccos: () => acos,
4762
5050
  arccosh: () => arccosh,
5051
+ arcsin: () => asin,
4763
5052
  arcsinh: () => arcsinh,
4764
- arctan: () => arctan,
4765
- arctan2: () => arctan2,
5053
+ arctan: () => atan,
5054
+ arctan2: () => atan2,
4766
5055
  arctanh: () => arctanh,
4767
5056
  argmax: () => argmax,
4768
5057
  argmin: () => argmin,
5058
+ argsort: () => argsort,
4769
5059
  array: () => array,
4770
5060
  asin: () => asin,
4771
- asinh: () => asinh,
5061
+ asinh: () => arcsinh,
4772
5062
  astype: () => astype,
4773
5063
  atan: () => atan,
4774
5064
  atan2: () => atan2,
4775
- atanh: () => atanh,
5065
+ atanh: () => arctanh,
4776
5066
  bool: () => bool,
4777
5067
  broadcastArrays: () => broadcastArrays,
4778
5068
  broadcastShapes: () => broadcastShapes,
@@ -4782,16 +5072,20 @@ __export(numpy_exports, {
4782
5072
  clip: () => clip,
4783
5073
  columnStack: () => columnStack,
4784
5074
  concatenate: () => concatenate,
5075
+ convolve: () => convolve,
5076
+ corrcoef: () => corrcoef,
5077
+ correlate: () => correlate,
4785
5078
  cos: () => cos,
4786
5079
  cosh: () => cosh,
5080
+ cov: () => cov,
4787
5081
  cumsum: () => cumsum,
4788
- cumulativeSum: () => cumulativeSum,
5082
+ cumulativeSum: () => cumsum,
4789
5083
  deg2rad: () => deg2rad,
4790
5084
  degrees: () => degrees,
4791
5085
  diag: () => diag,
4792
5086
  diagonal: () => diagonal,
4793
- divide: () => divide,
4794
- dot: () => dot,
5087
+ divide: () => trueDivide,
5088
+ dot: () => dot$1,
4795
5089
  dstack: () => dstack,
4796
5090
  e: () => e,
4797
5091
  einsum: () => einsum,
@@ -4799,8 +5093,10 @@ __export(numpy_exports, {
4799
5093
  eulerGamma: () => eulerGamma,
4800
5094
  exp: () => exp,
4801
5095
  exp2: () => exp2,
5096
+ expandDims: () => expandDims,
4802
5097
  expm1: () => expm1,
4803
5098
  eye: () => eye,
5099
+ fft: () => numpy_fft_exports,
4804
5100
  flip: () => flip,
4805
5101
  fliplr: () => fliplr,
4806
5102
  flipud: () => flipud,
@@ -4831,12 +5127,14 @@ __export(numpy_exports, {
4831
5127
  ldexp: () => ldexp,
4832
5128
  less: () => less,
4833
5129
  lessEqual: () => lessEqual,
5130
+ linalg: () => numpy_linalg_exports,
4834
5131
  linspace: () => linspace,
4835
5132
  log: () => log,
4836
5133
  log10: () => log10,
4837
5134
  log1p: () => log1p,
4838
5135
  log2: () => log2,
4839
5136
  matmul: () => matmul,
5137
+ matrixTranspose: () => matrixTranspose,
4840
5138
  max: () => max,
4841
5139
  maximum: () => maximum,
4842
5140
  mean: () => mean,
@@ -4853,10 +5151,10 @@ __export(numpy_exports, {
4853
5151
  onesLike: () => onesLike,
4854
5152
  outer: () => outer,
4855
5153
  pad: () => pad,
4856
- permuteDims: () => permuteDims,
5154
+ permuteDims: () => transpose,
4857
5155
  pi: () => pi,
4858
5156
  positive: () => positive,
4859
- pow: () => pow,
5157
+ pow: () => power,
4860
5158
  power: () => power,
4861
5159
  prod: () => prod$1,
4862
5160
  promoteTypes: () => require_backend.promoteTypes,
@@ -4873,6 +5171,7 @@ __export(numpy_exports, {
4873
5171
  sin: () => sin,
4874
5172
  sinh: () => sinh,
4875
5173
  size: () => size,
5174
+ sort: () => sort,
4876
5175
  sqrt: () => sqrt,
4877
5176
  square: () => square,
4878
5177
  squeeze: () => squeeze,
@@ -5037,6 +5336,26 @@ function min(a, axis = null, opts) {
5037
5336
  function max(a, axis = null, opts) {
5038
5337
  return reduce(a, require_backend.AluOp.Max, axis, opts);
5039
5338
  }
5339
+ /**
5340
+ * Test whether all array elements along a given axis evaluate to True.
5341
+ *
5342
+ * Returns a boolean array with the same shape as `a` with the specified axis
5343
+ * removed. If axis is None, returns a scalar.
5344
+ */
5345
+ function all(a, axis = null, opts) {
5346
+ a = fudgeArray(a).astype(require_backend.DType.Bool);
5347
+ return min(a, axis, opts);
5348
+ }
5349
+ /**
5350
+ * Test whether any array element along a given axis evaluates to True.
5351
+ *
5352
+ * Returns a boolean array with the same shape as `a` with the specified axis
5353
+ * removed. If axis is None, returns a scalar.
5354
+ */
5355
+ function any(a, axis = null, opts) {
5356
+ a = fudgeArray(a).astype(require_backend.DType.Bool);
5357
+ return max(a, axis, opts);
5358
+ }
5040
5359
  /** Return the peak-to-peak range along a given axis (`max - min`). */
5041
5360
  function ptp(a, axis = null, opts) {
5042
5361
  a = fudgeArray(a);
@@ -5111,8 +5430,6 @@ function cumsum(a, axis) {
5111
5430
  a = broadcast(a, a.shape.concat(n), [-2]);
5112
5431
  return moveaxis$1(tril(a).sum(-1), -1, axis);
5113
5432
  }
5114
- /** @function Alternative name for `jax.numpy.cumsum()`. */
5115
- const cumulativeSum = cumsum;
5116
5433
  /** Reverse the elements in an array along the given axes. */
5117
5434
  function flip(x, axis = null) {
5118
5435
  const nd = ndim(x);
@@ -5222,8 +5539,11 @@ function flipud(x) {
5222
5539
  function fliplr(x) {
5223
5540
  return flip(x, 1);
5224
5541
  }
5225
- /** @function Alternative name for `numpy.transpose()`. */
5226
- const permuteDims = transpose;
5542
+ /** Transpose the last two dimensions of an array. */
5543
+ function matrixTranspose(a) {
5544
+ if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
5545
+ return moveaxis$1(a, -1, -2);
5546
+ }
5227
5547
  /** Return a 1-D flattened array containing the elements of the input. */
5228
5548
  function ravel(a) {
5229
5549
  return fudgeArray(a).ravel();
@@ -5239,6 +5559,32 @@ function squeeze(a, axis = null) {
5239
5559
  return reshape(a, newShape);
5240
5560
  }
5241
5561
  /**
5562
+ * Expand the shape of an array by inserting new axes of length 1.
5563
+ *
5564
+ * @param a - Input array.
5565
+ * @param axis - Position(s) in the expanded axes where the new axis (or axes)
5566
+ * is placed. Can be a single integer or an array of integers.
5567
+ * @returns Array with the number of dimensions increased.
5568
+ *
5569
+ * @example
5570
+ * ```ts
5571
+ * const x = np.array([1, 2]);
5572
+ * np.expandDims(x, 0); // Shape [1, 2]
5573
+ * np.expandDims(x, 1); // Shape [2, 1]
5574
+ * np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
5575
+ * ```
5576
+ */
5577
+ function expandDims(a, axis) {
5578
+ const as = shape(a);
5579
+ axis = typeof axis === "number" ? [axis] : axis;
5580
+ axis = require_backend.normalizeAxis(axis, as.length + axis.length);
5581
+ const newShape = [];
5582
+ let srcIdx = 0;
5583
+ for (let i = 0; i < as.length + axis.length; i++) if (axis.includes(i)) newShape.push(1);
5584
+ else newShape.push(as[srcIdx++]);
5585
+ return reshape(a, newShape);
5586
+ }
5587
+ /**
5242
5588
  * Repeat each element of an array after themselves.
5243
5589
  *
5244
5590
  * If no axis is provided, use the flattened input array, and return a flat
@@ -5326,7 +5672,7 @@ function diagonal(a, offset, axis1, axis2) {
5326
5672
  */
5327
5673
  function diag(v, k = 0) {
5328
5674
  const a = fudgeArray(v);
5329
- if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
5675
+ if (!Number.isInteger(k)) throw new Error(`k must be an integer, got ${k}`);
5330
5676
  if (a.ndim === 1) {
5331
5677
  const n = a.shape[0];
5332
5678
  const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
@@ -5334,12 +5680,32 @@ function diag(v, k = 0) {
5334
5680
  else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
5335
5681
  else return ret;
5336
5682
  } else if (a.ndim === 2) return diagonal(a, k);
5337
- else throw new TypeError("numpy.diag only supports 1D and 2D arrays");
5683
+ else throw new Error("numpy.diag only supports 1D and 2D arrays");
5338
5684
  }
5339
5685
  /** Calculate the sum of the diagonal of an array along the given axes. */
5340
5686
  function trace(a, offset = 0, axis1 = 0, axis2 = 1) {
5341
5687
  return diagonal(a, offset, axis1, axis2).sum(-1);
5342
5688
  }
5689
+ /**
5690
+ * Return a sorted copy of an array.
5691
+ *
5692
+ * The array is sorted along a specified axis (the last by default). This may be
5693
+ * an unstable sort, and it dispatches to device-specific implementation.
5694
+ */
5695
+ function sort(a, axis = -1) {
5696
+ return fudgeArray(a).sort(axis);
5697
+ }
5698
+ /**
5699
+ * Return indices that would sort an array. This may be an unstable sorting
5700
+ * algorithm; it need not preserve order of indices in ties.
5701
+ *
5702
+ * Returns an array of `int32` indices.
5703
+ *
5704
+ * The array is sorted along a specified axis (the last by default).
5705
+ */
5706
+ function argsort(a, axis = -1) {
5707
+ return fudgeArray(a).argsort(axis);
5708
+ }
5343
5709
  /** Return if two arrays are element-wise equal within a tolerance. */
5344
5710
  function allclose(actual, expected, options) {
5345
5711
  const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
@@ -5356,11 +5722,11 @@ function allclose(actual, expected, options) {
5356
5722
  }
5357
5723
  /** Matrix product of two arrays. */
5358
5724
  function matmul(x, y) {
5359
- if (ndim(x) === 0 || ndim(y) === 0) throw new TypeError("matmul: x and y must be at least 1D");
5725
+ if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
5360
5726
  x = x, y = y;
5361
5727
  if (y.ndim === 1) return dot$2(x, y);
5362
5728
  const numBatchDims = Math.min(Math.max(x.ndim, 2), y.ndim) - 2;
5363
- return dot$1(x, y, {
5729
+ return dot(x, y, {
5364
5730
  lhsContractingDims: [-1],
5365
5731
  rhsContractingDims: [-2],
5366
5732
  lhsBatchDims: require_backend.range(-2 - numBatchDims, -2),
@@ -5368,11 +5734,11 @@ function matmul(x, y) {
5368
5734
  });
5369
5735
  }
5370
5736
  /** Dot product of two arrays. */
5371
- function dot(x, y) {
5737
+ function dot$1(x, y) {
5372
5738
  if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
5373
5739
  x = x, y = y;
5374
5740
  if (y.ndim === 1) return dot$2(x, y);
5375
- return dot$1(x, y, {
5741
+ return dot(x, y, {
5376
5742
  lhsContractingDims: [-1],
5377
5743
  rhsContractingDims: [-2]
5378
5744
  });
@@ -5388,7 +5754,7 @@ function tensordot(x, y, axes = 2) {
5388
5754
  x = fudgeArray(x);
5389
5755
  y = fudgeArray(y);
5390
5756
  if (typeof axes === "number") axes = [require_backend.range(-axes, 0), require_backend.range(axes)];
5391
- return dot$1(x, y, {
5757
+ return dot(x, y, {
5392
5758
  lhsContractingDims: axes[0],
5393
5759
  rhsContractingDims: axes[1]
5394
5760
  });
@@ -5481,7 +5847,7 @@ function einsum(...args) {
5481
5847
  const [b, bidx] = processSingleTensor(operands[j], indices[j], indices[i]);
5482
5848
  indexReduced = indexReduced.filter((idx) => aidx.includes(idx));
5483
5849
  const indexBatch = aidx.filter((idx) => bidx.includes(idx) && !indexReduced.includes(idx));
5484
- const result = dot$1(a, b, {
5850
+ const result = dot(a, b, {
5485
5851
  lhsContractingDims: indexReduced.map((idx) => aidx.indexOf(idx)),
5486
5852
  rhsContractingDims: indexReduced.map((idx) => bidx.indexOf(idx)),
5487
5853
  lhsBatchDims: indexBatch.map((idx) => aidx.indexOf(idx)),
@@ -5509,7 +5875,7 @@ function einsum(...args) {
5509
5875
  * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
5510
5876
  */
5511
5877
  function inner(x, y) {
5512
- return dot$1(fudgeArray(x), fudgeArray(y), {
5878
+ return dot(fudgeArray(x), fudgeArray(y), {
5513
5879
  lhsContractingDims: [-1],
5514
5880
  rhsContractingDims: [-1]
5515
5881
  });
@@ -5542,6 +5908,30 @@ function vecdot(x, y, { axis } = {}) {
5542
5908
  function vdot(x, y) {
5543
5909
  return dot$2(ravel(x), ravel(y));
5544
5910
  }
5911
+ function _convImpl(name, x, y, mode) {
5912
+ if (x.ndim !== 1 || y.ndim !== 1) throw new Error(`${name}: both inputs must be 1D arrays, got ${x.ndim}D and ${y.ndim}D`);
5913
+ let flipOutput = false;
5914
+ if (x.shape[0] < y.shape[0]) {
5915
+ [x, y] = [y, x];
5916
+ if (name === "correlate") flipOutput = true;
5917
+ }
5918
+ if (name === "convolve") y = flip(y);
5919
+ let padding;
5920
+ if (mode === "valid") padding = "VALID";
5921
+ else if (mode === "same") padding = "SAME_LOWER";
5922
+ else if (mode === "full") padding = [[y.shape[0] - 1, y.shape[0] - 1]];
5923
+ else throw new Error(`${name}: invalid mode ${mode}, expected "full", "same", or "valid"`);
5924
+ const z = conv(x.slice(null, null), y.slice(null, null), [1], padding).slice(0, 0);
5925
+ return flipOutput ? flip(z) : z;
5926
+ }
5927
+ /** Convolution of two one-dimensional arrays. */
5928
+ function convolve(x, y, mode = "full") {
5929
+ return _convImpl("convolve", x, y, mode);
5930
+ }
5931
+ /** Correlation of two one dimensional arrays. */
5932
+ function correlate(x, y, mode = "valid") {
5933
+ return _convImpl("correlate", x, y, mode);
5934
+ }
5545
5935
  /**
5546
5936
  * Return a tuple of coordinate matrices from coordinate vectors.
5547
5937
  *
@@ -5550,7 +5940,7 @@ function vdot(x, y) {
5550
5940
  */
5551
5941
  function meshgrid(xs, { indexing } = {}) {
5552
5942
  indexing ??= "xy";
5553
- for (const x of xs) if (x.ndim !== 1) throw new TypeError(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
5943
+ for (const x of xs) if (x.ndim !== 1) throw new Error(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
5554
5944
  if (xs.length <= 1) return xs;
5555
5945
  if (indexing === "xy") {
5556
5946
  const [a, b, ...rest] = xs;
@@ -5566,44 +5956,7 @@ function meshgrid(xs, { indexing } = {}) {
5566
5956
  ];
5567
5957
  }
5568
5958
  const shape$1 = xs.map((x) => x.shape[0]);
5569
- return xs.map((x, i) => broadcast(x, shape$1, [...require_backend.range(i), ...require_backend.range(i + 1, xs.length)]));
5570
- }
5571
- /**
5572
- * Return an array with ones on and below the diagonal and zeros elsewhere.
5573
- *
5574
- * If `k` is provided, it specifies the sub-diagonal on and below which the
5575
- * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
5576
- * `k>0` is above it.
5577
- */
5578
- function tri(n, m, k = 0, { dtype, device } = {}) {
5579
- m ??= n;
5580
- dtype ??= require_backend.DType.Float32;
5581
- if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
5582
- if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
5583
- if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
5584
- const rows = arange(k, n + k, 1, {
5585
- dtype: require_backend.DType.Int32,
5586
- device
5587
- });
5588
- const cols = arange(0, m, 1, {
5589
- dtype: require_backend.DType.Int32,
5590
- device
5591
- });
5592
- return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
5593
- }
5594
- /** Return the lower triangle of an array. Must be of dimension >= 2. */
5595
- function tril(a, k = 0) {
5596
- if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
5597
- a = fudgeArray(a);
5598
- const [n, m] = a.shape.slice(-2);
5599
- return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
5600
- }
5601
- /** Return the upper triangle of an array. Must be of dimension >= 2. */
5602
- function triu(a, k = 0) {
5603
- if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
5604
- a = fudgeArray(a);
5605
- const [n, m] = a.shape.slice(-2);
5606
- return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
5959
+ return xs.map((x, i) => broadcast(x, shape$1, [...require_backend.range(i), ...require_backend.range(i + 1, xs.length)]));
5607
5960
  }
5608
5961
  /**
5609
5962
  * Clip (limit) the values in an array.
@@ -5629,8 +5982,6 @@ function absolute(x) {
5629
5982
  x = fudgeArray(x);
5630
5983
  return where(less(x.ref, 0), x.ref.mul(-1), x);
5631
5984
  }
5632
- /** @function Alias of `jax.numpy.absolute()`. */
5633
- const abs = absolute;
5634
5985
  /** Return an element-wise indication of sign of the input. */
5635
5986
  function sign(x) {
5636
5987
  x = fudgeArray(x);
@@ -5709,12 +6060,6 @@ const atan2 = jit$1(function atan2$1(y, x) {
5709
6060
  const denom = where(xNeg, y, r.add(x));
5710
6061
  return atan(numer.div(denom)).mul(2);
5711
6062
  });
5712
- /** @function Alias of `jax.numpy.acos()`. */
5713
- const arccos = acos;
5714
- /** @function Alias of `jax.numpy.atan()`. */
5715
- const arctan = atan;
5716
- /** @function Alias of `jax.numpy.atan2()`. */
5717
- const arctan2 = atan2;
5718
6063
  /** Element-wise subtraction, with broadcasting. */
5719
6064
  function subtract(x, y) {
5720
6065
  x = fudgeArray(x);
@@ -5745,8 +6090,6 @@ const fmod = jit$1(function fmod$1(x, y) {
5745
6090
  const remainder = jit$1(function remainder$1(x, y) {
5746
6091
  return mod(mod(x, y.ref).add(y.ref), y);
5747
6092
  });
5748
- /** @function Alias of `jax.numpy.trueDivide()`. */
5749
- const divide = trueDivide;
5750
6093
  /** Round input to the nearest integer towards zero. */
5751
6094
  function trunc(x) {
5752
6095
  return idiv(x, 1);
@@ -5768,9 +6111,9 @@ function ldexp(x1, x2) {
5768
6111
  */
5769
6112
  function frexp(x) {
5770
6113
  x = fudgeArray(x);
5771
- const absx = abs(x.ref);
6114
+ const absx = absolute(x.ref);
5772
6115
  const exponent = where(equal(x.ref, 0), 0, floor(log2(absx)).add(1).astype(require_backend.DType.Int32));
5773
- const mantissa = divide(x, exp2(exponent.ref.astype(x.dtype)));
6116
+ const mantissa = x.div(exp2(exponent.ref.astype(x.dtype)));
5774
6117
  return [mantissa, exponent];
5775
6118
  }
5776
6119
  /** Calculate `2**p` for all p in the input array. */
@@ -5813,10 +6156,8 @@ const power = jit$1(function power$1(x1, x2) {
5813
6156
  const x2i = trunc(x2.ref);
5814
6157
  const shouldBeNaN = multiply(x2.ref.notEqual(x2i.ref), x1.ref.less(0));
5815
6158
  const resultSign = where(mod(x2i, 2).notEqual(0), where(x1.ref.less(0), -1, 1), 1);
5816
- return where(shouldBeNaN, nan, exp(log(abs(x1)).mul(x2)).mul(resultSign));
6159
+ return where(shouldBeNaN, nan, exp(log(absolute(x1)).mul(x2)).mul(resultSign));
5817
6160
  });
5818
- /** @function Alias of `jax.numpy.power()`. */
5819
- const pow = power;
5820
6161
  /** @function Calculate the element-wise cube root of the input array. */
5821
6162
  const cbrt = jit$1(function cbrt$1(x) {
5822
6163
  const sgn = where(less(x.ref, 0), -1, 1);
@@ -5882,12 +6223,6 @@ const arccosh = jit$1(function arccosh$1(x) {
5882
6223
  const arctanh = jit$1(function arctanh$1(x) {
5883
6224
  return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
5884
6225
  });
5885
- /** @function Alias of `jax.numpy.arcsinh()`. */
5886
- const asinh = arcsinh;
5887
- /** @function Alias of `jax.numpy.arccosh()`. */
5888
- const acosh = arccosh;
5889
- /** @function Alias of `jax.numpy.arctanh()`. */
5890
- const atanh = arctanh;
5891
6226
  /**
5892
6227
  * Compute the variance of an array.
5893
6228
  *
@@ -5917,6 +6252,26 @@ function var_(x, axis = null, opts) {
5917
6252
  function std(x, axis = null, opts) {
5918
6253
  return sqrt(var_(x, axis, opts));
5919
6254
  }
6255
+ /** Estimate the sample covariance of a set of variables. */
6256
+ function cov(x, y) {
6257
+ x = fudgeArray(x);
6258
+ if (x.ndim === 1) x = x.reshape([1, x.shape[0]]);
6259
+ if (y !== void 0) {
6260
+ y = fudgeArray(y);
6261
+ if (y.ndim === 1) y = y.reshape([1, y.shape[0]]);
6262
+ x = vstack([x, y]);
6263
+ }
6264
+ const [_M, N] = x.shape;
6265
+ x = x.ref.sub(x.mean(1, { keepdims: true }));
6266
+ return dot$1(x.ref, x.transpose()).div(N - 1);
6267
+ }
6268
+ /** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
6269
+ function corrcoef(x, y) {
6270
+ const c = cov(x, y);
6271
+ const variances = diag(c.ref);
6272
+ const norm = sqrt(outer(variances.ref, variances));
6273
+ return c.div(norm);
6274
+ }
5920
6275
  /** Test element-wise for positive or negative infinity, return bool array. */
5921
6276
  function isinf(x) {
5922
6277
  x = fudgeArray(x);
@@ -5946,6 +6301,253 @@ const isfinite = jit$1(function isfinite$1(x) {
5946
6301
  return isnan(x.ref).add(isinf(x)).notEqual(true);
5947
6302
  });
5948
6303
 
6304
+ //#endregion
6305
+ //#region src/library/lax-linalg.ts
6306
+ var lax_linalg_exports = {};
6307
+ __export(lax_linalg_exports, {
6308
+ cholesky: () => cholesky,
6309
+ triangularSolve: () => triangularSolve
6310
+ });
6311
+ /**
6312
+ * Compute the Cholesky decomposition of a symmetric positive-definite matrix.
6313
+ *
6314
+ * The Cholesky decomposition of a matrix `A` is:
6315
+ *
6316
+ * - A = L @ L^T (for upper=false, default)
6317
+ * - A = U^T @ U (for upper=true)
6318
+ *
6319
+ * where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
6320
+ * The input matrix must be symmetric and positive-definite.
6321
+ *
6322
+ * @example
6323
+ * ```ts
6324
+ * import { lax, numpy as np } from "@jax-js/jax";
6325
+ *
6326
+ * const x = np.array([[2., 1.], [1., 2.]]);
6327
+ *
6328
+ * // Lower Cholesky factorization (default):
6329
+ * const L = lax.linalg.cholesky(x);
6330
+ * // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
6331
+ *
6332
+ * // Upper Cholesky factorization:
6333
+ * const U = lax.linalg.cholesky(x, { upper: true });
6334
+ * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
6335
+ * ```
6336
+ */
6337
+ function cholesky(a, { upper = false } = {}) {
6338
+ const L = cholesky$2(a);
6339
+ return upper ? moveaxis$1(L, -2, -1) : L;
6340
+ }
6341
+ /**
6342
+ * Solve a triangular linear system.
6343
+ *
6344
+ * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
6345
+ * where `a` is a triangular matrix.
6346
+ *
6347
+ * @example
6348
+ * ```ts
6349
+ * import { lax, numpy as np } from "@jax-js/jax";
6350
+ *
6351
+ * const L = np.array([[2., 0.], [1., 3.]]);
6352
+ * const b = np.array([4., 7.]).reshape([2, 1]);
6353
+ *
6354
+ * // Solve L @ x = b
6355
+ * const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
6356
+ * // x = [[2.], [5./3.]]
6357
+ * ```
6358
+ */
6359
+ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = false, unitDiagonal = false } = {}) {
6360
+ a = fudgeArray(a);
6361
+ b = fudgeArray(b);
6362
+ if (!leftSide) transposeA = !transposeA;
6363
+ else b = moveaxis$1(b, -2, -1);
6364
+ if (transposeA) a = moveaxis$1(a, -2, -1);
6365
+ let x = triangularSolve$1(a, b, {
6366
+ lower,
6367
+ unitDiagonal
6368
+ });
6369
+ if (leftSide) x = moveaxis$1(x, -2, -1);
6370
+ return x;
6371
+ }
6372
+
6373
+ //#endregion
6374
+ //#region src/library/lax.ts
6375
+ var lax_exports = {};
6376
+ __export(lax_exports, {
6377
+ conv: () => conv,
6378
+ convGeneralDilated: () => convGeneralDilated,
6379
+ convWithGeneralPadding: () => convWithGeneralPadding,
6380
+ dot: () => dot,
6381
+ erf: () => erf,
6382
+ erfc: () => erfc,
6383
+ linalg: () => lax_linalg_exports,
6384
+ reduceWindow: () => reduceWindow,
6385
+ stopGradient: () => stopGradient$1
6386
+ });
6387
+ /**
6388
+ * General dot product/contraction operator.
6389
+ *
6390
+ * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
6391
+ * `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
6392
+ */
6393
+ function dot(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
6394
+ if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
6395
+ else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
6396
+ lc = lc.map((a) => require_backend.checkAxis(a, lhs.ndim));
6397
+ rc = rc.map((a) => require_backend.checkAxis(a, rhs.ndim));
6398
+ lb = lb.map((a) => require_backend.checkAxis(a, lhs.ndim));
6399
+ rb = rb.map((a) => require_backend.checkAxis(a, rhs.ndim));
6400
+ if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
6401
+ else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
6402
+ const lf = require_backend.range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
6403
+ const rf = require_backend.range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
6404
+ const lhs2 = lhs.transpose([
6405
+ ...lb,
6406
+ ...lf,
6407
+ ...lc
6408
+ ]);
6409
+ const rhs2 = rhs.transpose([
6410
+ ...rb,
6411
+ ...rf,
6412
+ ...rc
6413
+ ]);
6414
+ if (lc.length === 0) return mul(lhs2.reshape([
6415
+ ...lb.map((a) => lhs.shape[a]),
6416
+ ...lf.map((a) => lhs.shape[a]),
6417
+ ...require_backend.rep(rf.length, 1)
6418
+ ]), rhs2.reshape([
6419
+ ...rb.map((a) => rhs.shape[a]),
6420
+ ...require_backend.rep(lf.length, 1),
6421
+ ...rf.map((a) => rhs.shape[a])
6422
+ ]));
6423
+ const dotShapeX = lc.map((a) => lhs.shape[a]);
6424
+ const dotShapeY = rc.map((a) => rhs.shape[a]);
6425
+ if (!require_backend.deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
6426
+ return dot$2(lhs2.reshape([
6427
+ ...lb.map((a) => lhs.shape[a]),
6428
+ ...lf.map((a) => lhs.shape[a]),
6429
+ ...require_backend.rep(rf.length, 1),
6430
+ require_backend.prod(dotShapeX)
6431
+ ]), rhs2.reshape([
6432
+ ...rb.map((a) => rhs.shape[a]),
6433
+ ...require_backend.rep(lf.length, 1),
6434
+ ...rf.map((a) => rhs.shape[a]),
6435
+ require_backend.prod(dotShapeY)
6436
+ ]));
6437
+ }
6438
+ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
6439
+ const padType = padding.toUpperCase();
6440
+ switch (padType) {
6441
+ case "VALID": return require_backend.rep(inShape.length, [0, 0]);
6442
+ case "SAME":
6443
+ case "SAME_LOWER": {
6444
+ const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
6445
+ const padSizes = require_backend.zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
6446
+ if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
6447
+ else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
6448
+ }
6449
+ default: throw new Error(`Unknown padding type: ${padType}`);
6450
+ }
6451
+ }
6452
+ /**
6453
+ * General n-dimensional convolution operator, with optional dilation.
6454
+ *
6455
+ * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
6456
+ * function in JAX, which wraps XLA's general convolution operator.
6457
+ *
6458
+ * Grouped convolutions are not supported right now.
6459
+ */
6460
+ function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
6461
+ if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
6462
+ if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
6463
+ if (typeof padding === "string") {
6464
+ if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
6465
+ padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1), padding);
6466
+ }
6467
+ if (featureGroupCount !== 1) {
6468
+ const G = featureGroupCount;
6469
+ const [N, C_in, ...xs] = lhs.shape;
6470
+ const [C_out, C_in_per_group, ...ks] = rhs.shape;
6471
+ if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
6472
+ if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
6473
+ if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
6474
+ const lhsGrouped = moveaxis(lhs.reshape([
6475
+ N,
6476
+ G,
6477
+ C_in / G,
6478
+ ...xs
6479
+ ]), 1, 0);
6480
+ const rhsGrouped = rhs.reshape([
6481
+ G,
6482
+ C_out / G,
6483
+ C_in_per_group,
6484
+ ...ks
6485
+ ]);
6486
+ const result = conv$1(lhsGrouped, rhsGrouped, {
6487
+ vmapDims: 1,
6488
+ strides: windowStrides,
6489
+ padding,
6490
+ lhsDilation,
6491
+ rhsDilation
6492
+ });
6493
+ const ys = result.shape.slice(3);
6494
+ return moveaxis(result, 0, 1).reshape([
6495
+ N,
6496
+ C_out,
6497
+ ...ys
6498
+ ]);
6499
+ }
6500
+ return conv$1(lhs, rhs, {
6501
+ strides: windowStrides,
6502
+ padding,
6503
+ lhsDilation,
6504
+ rhsDilation
6505
+ });
6506
+ }
6507
+ /** Convenience wrapper around `convGeneralDilated`. */
6508
+ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
6509
+ return convGeneralDilated(lhs, rhs, windowStrides, padding, {
6510
+ lhsDilation,
6511
+ rhsDilation
6512
+ });
6513
+ }
6514
+ /** Convenience wrapper around `convGeneralDilated`. */
6515
+ function conv(lhs, rhs, windowStrides, padding) {
6516
+ return convGeneralDilated(lhs, rhs, windowStrides, padding);
6517
+ }
6518
+ /** Reduce a computation over padded windows. */
6519
+ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
6520
+ if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
6521
+ if (!windowStrides) windowStrides = require_backend.rep(windowDimensions.length, 1);
6522
+ for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
6523
+ return computation(bind1(Primitive.Pool, [operand], {
6524
+ window: windowDimensions,
6525
+ strides: windowStrides
6526
+ }));
6527
+ }
6528
+ /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
6529
+ function erf(x) {
6530
+ return erf$1(x);
6531
+ }
6532
+ /**
6533
+ * The complementary error function: `erfc(x) = 1 - erf(x)`.
6534
+ *
6535
+ * This function is more accurate than `1 - erf(x)` for large values of `x`,
6536
+ * where `erf(x)` is very close to 1.
6537
+ */
6538
+ function erfc(x) {
6539
+ return erfc$1(x);
6540
+ }
6541
+ /**
6542
+ * Stops gradient computation.
6543
+ *
6544
+ * Behaves as the identity function but prevents the flow of gradients during
6545
+ * forward or reverse-mode automatic differentiation.
6546
+ */
6547
+ function stopGradient$1(x) {
6548
+ return stopGradient(x);
6549
+ }
6550
+
5949
6551
  //#endregion
5950
6552
  //#region src/library/nn.ts
5951
6553
  var nn_exports = {};
@@ -5954,6 +6556,10 @@ __export(nn_exports, {
5954
6556
  elu: () => elu,
5955
6557
  gelu: () => gelu,
5956
6558
  glu: () => glu,
6559
+ hardSigmoid: () => hardSigmoid,
6560
+ hardSilu: () => hardSilu,
6561
+ hardSwish: () => hardSilu,
6562
+ hardTanh: () => hardTanh,
5957
6563
  identity: () => identity,
5958
6564
  leakyRelu: () => leakyRelu,
5959
6565
  logSigmoid: () => logSigmoid,
@@ -5964,14 +6570,17 @@ __export(nn_exports, {
5964
6570
  oneHot: () => oneHot,
5965
6571
  relu: () => relu,
5966
6572
  relu6: () => relu6,
6573
+ selu: () => selu,
5967
6574
  sigmoid: () => sigmoid,
5968
6575
  silu: () => silu,
5969
6576
  softSign: () => softSign,
5970
6577
  softmax: () => softmax,
5971
6578
  softplus: () => softplus,
6579
+ sparsePlus: () => sparsePlus,
6580
+ sparseSigmoid: () => sparseSigmoid,
5972
6581
  squareplus: () => squareplus,
5973
6582
  standardize: () => standardize,
5974
- swish: () => swish
6583
+ swish: () => silu
5975
6584
  });
5976
6585
  /**
5977
6586
  * Rectified Linear Unit (ReLU) activation function:
@@ -6006,6 +6615,28 @@ function softplus(x) {
6006
6615
  return log(exp(x).add(1));
6007
6616
  }
6008
6617
  /**
6618
+ * @function
6619
+ * Sparse plus function:
6620
+ *
6621
+ * - When `x <= -1`: `0`
6622
+ * - When `-1 < x < 1`: `(x+1)**2 / 4`
6623
+ * - When `x >= 1`: `x`
6624
+ */
6625
+ const sparsePlus = jit$1((x) => {
6626
+ return where(x.ref.lessEqual(-1), 0, where(x.ref.less(1), square(x.ref.add(1)).mul(.25), x));
6627
+ });
6628
+ /**
6629
+ * @function
6630
+ * Sparse sigmoid activation function.
6631
+ *
6632
+ * - When `x <= -1`: `0`
6633
+ * - When `-1 < x < 1`: `(x + 1) / 2`
6634
+ * - When `x >= 1`: `1`
6635
+ */
6636
+ const sparseSigmoid = jit$1((x) => {
6637
+ return clip(x.add(1).mul(.5), 0, 1);
6638
+ });
6639
+ /**
6009
6640
  * Soft-sign activation function, computed element-wise:
6010
6641
  * `softsign(x) = x / (|x| + 1)`.
6011
6642
  */
@@ -6027,17 +6658,6 @@ const silu = jit$1(function silu$1(x) {
6027
6658
  return x.ref.mul(sigmoid(x));
6028
6659
  });
6029
6660
  /**
6030
- * @function
6031
- * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
6032
- * Swish, computed element-wise:
6033
- * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
6034
- *
6035
- * `swish()` and `silu()` are both aliases for the same function.
6036
- *
6037
- * Reference: https://en.wikipedia.org/wiki/Swish_function
6038
- */
6039
- const swish = silu;
6040
- /**
6041
6661
  * Log-sigmoid activation function, computed element-wise:
6042
6662
  * `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
6043
6663
  */
@@ -6054,6 +6674,19 @@ function leakyRelu(x, negativeSlope = .01) {
6054
6674
  x = fudgeArray(x);
6055
6675
  return where(less(x.ref, 0), x.ref.mul(negativeSlope), x);
6056
6676
  }
6677
+ /** Hard sigmoid activation function: `relu6(x+3)/6`. */
6678
+ function hardSigmoid(x) {
6679
+ return relu6(add(x, 3)).mul(1 / 6);
6680
+ }
6681
+ /** Hard SiLU (swish) activation function: `x * hardSigmoid(x)`. */
6682
+ function hardSilu(x) {
6683
+ x = fudgeArray(x);
6684
+ return x.ref.mul(hardSigmoid(x));
6685
+ }
6686
+ /** Hard tanh activation function: `clip(x, -1, 1)`. */
6687
+ function hardTanh(x) {
6688
+ return clip(x, -1, 1);
6689
+ }
6057
6690
  /**
6058
6691
  * Exponential linear unit activation function.
6059
6692
  *
@@ -6076,6 +6709,20 @@ function celu(x, alpha = 1) {
6076
6709
  }
6077
6710
  /**
6078
6711
  * @function
6712
+ * Scaled exponential linear unit activation.
6713
+ *
6714
+ * Computes the element-wise function:
6715
+ * `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
6716
+ *
6717
+ * Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
6718
+ */
6719
+ const selu = jit$1(function selu$1(x) {
6720
+ const alpha = 1.6732632423543772;
6721
+ const lambda = 1.0507009873554805;
6722
+ return where(x.ref.less(0), expm1(x.ref).mul(alpha), x).mul(lambda);
6723
+ });
6724
+ /**
6725
+ * @function
6079
6726
  * Gaussion error linear unit (GELU) activation function.
6080
6727
  *
6081
6728
  * This is computed element-wise. There are two variants depending on whether
@@ -6229,8 +6876,11 @@ var random_exports = {};
6229
6876
  __export(random_exports, {
6230
6877
  bernoulli: () => bernoulli,
6231
6878
  bits: () => bits,
6879
+ cauchy: () => cauchy,
6232
6880
  exponential: () => exponential,
6881
+ gumbel: () => gumbel,
6233
6882
  key: () => key,
6883
+ laplace: () => laplace,
6234
6884
  normal: () => normal,
6235
6885
  split: () => split,
6236
6886
  uniform: () => uniform
@@ -6289,6 +6939,16 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
6289
6939
  }
6290
6940
  /**
6291
6941
  * @function
6942
+ * Sample from a Cauchy distribution with location 0 and scale 1.
6943
+ *
6944
+ * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
6945
+ */
6946
+ const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
6947
+ const u = uniform(key$1, shape$1);
6948
+ return tan(u.sub(.5).mul(Math.PI));
6949
+ }, { staticArgnums: [1] });
6950
+ /**
6951
+ * @function
6292
6952
  * Sample exponential random values according to `p(x) = exp(-x)`.
6293
6953
  */
6294
6954
  const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
@@ -6297,6 +6957,30 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
6297
6957
  }, { staticArgnums: [1] });
6298
6958
  /**
6299
6959
  * @function
6960
+ * Sample from a Gumbel distribution with location 0 and scale 1.
6961
+ *
6962
+ * Uses inverse transform sampling: `x = -log(-log(u))` where u ~ Uniform(0, 1).
6963
+ */
6964
+ const gumbel = jit$1(function gumbel$1(key$1, shape$1 = []) {
6965
+ const u = uniform(key$1, shape$1);
6966
+ return negative(log(negative(log1p(negative(u)))));
6967
+ }, { staticArgnums: [1] });
6968
+ /**
6969
+ * @function
6970
+ * Sample from a Laplace distribution with location 0 and scale 1.
6971
+ *
6972
+ * Uses inverse transform sampling: the CDF is `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`.
6973
+ * Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
6974
+ */
6975
+ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
6976
+ const u = uniform(key$1, shape$1);
6977
+ const centered = u.sub(.5);
6978
+ const s = sign(centered.ref);
6979
+ const absVal = absolute(centered);
6980
+ return s.mul(log1p(absVal.mul(-2)).mul(-1));
6981
+ }, { staticArgnums: [1] });
6982
+ /**
6983
+ * @function
6300
6984
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
6301
6985
  *
6302
6986
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
@@ -6405,11 +7089,6 @@ const valueAndGrad = valueAndGrad$1;
6405
7089
  */
6406
7090
  const jacrev = jacrev$1;
6407
7091
  /**
6408
- * @function
6409
- * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
6410
- */
6411
- const jacobian = jacrev;
6412
- /**
6413
7092
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
6414
7093
  *
6415
7094
  * This can be used to wait for the results of an intermediate computation to
@@ -6445,6 +7124,7 @@ async function devicePut(x, device) {
6445
7124
 
6446
7125
  //#endregion
6447
7126
  exports.Array = Array$1;
7127
+ exports.ClosedJaxpr = ClosedJaxpr;
6448
7128
  exports.DType = require_backend.DType;
6449
7129
  exports.Jaxpr = Jaxpr;
6450
7130
  exports.blockUntilReady = blockUntilReady;
@@ -6454,7 +7134,7 @@ exports.devices = require_backend.devices;
6454
7134
  exports.grad = grad;
6455
7135
  exports.init = require_backend.init;
6456
7136
  exports.jacfwd = jacfwd;
6457
- exports.jacobian = jacobian;
7137
+ exports.jacobian = jacrev;
6458
7138
  exports.jacrev = jacrev;
6459
7139
  exports.jit = jit;
6460
7140
  exports.jvp = jvp;
@@ -6499,5 +7179,4 @@ Object.defineProperty(exports, 'tree', {
6499
7179
  });
6500
7180
  exports.valueAndGrad = valueAndGrad;
6501
7181
  exports.vjp = vjp;
6502
- exports.vmap = vmap;
6503
- //# sourceMappingURL=index.cjs.map
7182
+ exports.vmap = vmap;