@jax-js/jax 0.1.3 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -8,9 +8,9 @@ var __hasOwnProp = Object.prototype.hasOwnProperty;
8
8
  var __commonJS = (cb, mod$1) => function() {
9
9
  return mod$1 || (0, cb[__getOwnPropNames(cb)[0]])((mod$1 = { exports: {} }).exports, mod$1), mod$1.exports;
10
10
  };
11
- var __export = (target, all) => {
12
- for (var name in all) __defProp(target, name, {
13
- get: all[name],
11
+ var __export = (target, all$1) => {
12
+ for (var name in all$1) __defProp(target, name, {
13
+ get: all$1[name],
14
14
  enumerable: true
15
15
  });
16
16
  };
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-CmaidnkQ.cjs');
33
+ const require_backend = require('./backend-DziQSaoQ.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
@@ -362,6 +362,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
362
362
  Primitive$1["Mul"] = "mul";
363
363
  Primitive$1["Idiv"] = "idiv";
364
364
  Primitive$1["Mod"] = "mod";
365
+ Primitive$1["Min"] = "min";
366
+ Primitive$1["Max"] = "max";
365
367
  Primitive$1["Neg"] = "neg";
366
368
  Primitive$1["Reciprocal"] = "reciprocal";
367
369
  Primitive$1["Floor"] = "floor";
@@ -369,7 +371,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
369
371
  Primitive$1["StopGradient"] = "stop_gradient";
370
372
  Primitive$1["Cast"] = "cast";
371
373
  Primitive$1["Bitcast"] = "bitcast";
372
- Primitive$1["RandomBits"] = "random_bits";
373
374
  Primitive$1["Sin"] = "sin";
374
375
  Primitive$1["Cos"] = "cos";
375
376
  Primitive$1["Asin"] = "asin";
@@ -379,8 +380,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
379
380
  Primitive$1["Erf"] = "erf";
380
381
  Primitive$1["Erfc"] = "erfc";
381
382
  Primitive$1["Sqrt"] = "sqrt";
382
- Primitive$1["Min"] = "min";
383
- Primitive$1["Max"] = "max";
384
383
  Primitive$1["Reduce"] = "reduce";
385
384
  Primitive$1["Dot"] = "dot";
386
385
  Primitive$1["Conv"] = "conv";
@@ -388,14 +387,22 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
388
387
  Primitive$1["PoolTranspose"] = "pool_transpose";
389
388
  Primitive$1["Compare"] = "compare";
390
389
  Primitive$1["Where"] = "where";
390
+ Primitive$1["Concatenate"] = "concatenate";
391
+ Primitive$1["Split"] = "split";
392
+ Primitive$1["RandomBits"] = "random_bits";
393
+ Primitive$1["Gather"] = "gather";
391
394
  Primitive$1["Transpose"] = "transpose";
392
395
  Primitive$1["Broadcast"] = "broadcast";
393
396
  Primitive$1["Reshape"] = "reshape";
394
397
  Primitive$1["Flip"] = "flip";
395
398
  Primitive$1["Shrink"] = "shrink";
396
399
  Primitive$1["Pad"] = "pad";
397
- Primitive$1["Gather"] = "gather";
398
- Primitive$1["JitCall"] = "jit_call";
400
+ Primitive$1["Sort"] = "sort";
401
+ Primitive$1["Argsort"] = "argsort";
402
+ Primitive$1["TriangularSolve"] = "triangular_solve";
403
+ Primitive$1["Cholesky"] = "cholesky";
404
+ Primitive$1["LU"] = "lu";
405
+ Primitive$1["Jit"] = "jit";
399
406
  return Primitive$1;
400
407
  }({});
401
408
  let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
@@ -417,6 +424,12 @@ function idiv(x, y) {
417
424
  function mod(x, y) {
418
425
  return bind1(Primitive.Mod, [x, y]);
419
426
  }
427
+ function min$1(x, y) {
428
+ return bind1(Primitive.Min, [x, y]);
429
+ }
430
+ function max$1(x, y) {
431
+ return bind1(Primitive.Max, [x, y]);
432
+ }
420
433
  function neg(x) {
421
434
  return bind1(Primitive.Neg, [x]);
422
435
  }
@@ -438,12 +451,6 @@ function cast(x, dtype) {
438
451
  function bitcast(x, dtype) {
439
452
  return bind1(Primitive.Bitcast, [x], { dtype });
440
453
  }
441
- function randomBits(k0, k1, shape$1, mode = "xor") {
442
- return bind1(Primitive.RandomBits, [k0, k1], {
443
- shape: shape$1,
444
- mode
445
- });
446
- }
447
454
  function sin$1(x) {
448
455
  return bind1(Primitive.Sin, [x]);
449
456
  }
@@ -471,12 +478,6 @@ function erfc$1(x) {
471
478
  function sqrt$1(x) {
472
479
  return bind1(Primitive.Sqrt, [x]);
473
480
  }
474
- function min$1(x, y) {
475
- return bind1(Primitive.Min, [x, y]);
476
- }
477
- function max$1(x, y) {
478
- return bind1(Primitive.Max, [x, y]);
479
- }
480
481
  function reduce(x, op, axis = null, opts) {
481
482
  if (!require_backend.AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
482
483
  axis = require_backend.normalizeAxis(axis, ndim$1(x));
@@ -532,6 +533,41 @@ function where$1(cond, x, y) {
532
533
  y
533
534
  ]);
534
535
  }
536
+ function concatenate$1(xs, axis) {
537
+ if (xs.length === 0) throw new Error("concatenate requires at least one input");
538
+ const avals = xs.map((x) => ShapedArray.fromAval(getAval(x)));
539
+ axis = require_backend.checkAxis(axis, avals[0].ndim);
540
+ for (const x of avals) if (x.ndim !== avals[0].ndim || !x.shape.every((s, i) => i === axis || s === avals[0].shape[i])) throw new Error(`Concatenate: inputs ${avals[0]} and ${x} must match shapes except on axis ${axis}`);
541
+ return bind1(Primitive.Concatenate, xs, { axis });
542
+ }
543
+ function split$2(x, axis, sizes) {
544
+ axis = require_backend.checkAxis(axis, ndim$1(x));
545
+ if (sizes.some((s) => s < 0 || !Number.isInteger(s))) throw new Error(`split: sizes must be nonnegative integers, got ${JSON.stringify(sizes)}`);
546
+ const totalSize = sizes.reduce((a, b) => a + b, 0);
547
+ if (totalSize !== getShape(x)[axis]) throw new Error(`split: sizes must sum to the size of the axis ${axis}, got ${totalSize}`);
548
+ return bind(Primitive.Split, [x], {
549
+ axis,
550
+ sizes
551
+ });
552
+ }
553
+ function randomBits(k0, k1, shape$1, mode = "xor") {
554
+ if (!require_backend.deepEqual(k0.shape, k1.shape) || k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new Error(`randomBits: key parts must be uint32 with the same shape, got ${ShapedArray.fromAval(k0.aval)} and ${ShapedArray.fromAval(k1.aval)}`);
555
+ return bind1(Primitive.RandomBits, [k0, k1], {
556
+ shape: shape$1,
557
+ mode
558
+ });
559
+ }
560
+ function gather(x, indices, axis, outDim) {
561
+ if (indices.length === 0) throw new Error("gather() requires at least one index");
562
+ if (!Array.isArray(axis) || axis.length !== indices.length) throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
563
+ axis = axis.map((a) => require_backend.checkAxis(a, ndim$1(x)));
564
+ if (new Set(axis).size !== axis.length) throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
565
+ outDim = require_backend.checkAxis(outDim, ndim$1(x) - axis.length + 1);
566
+ return bind1(Primitive.Gather, [x, ...indices], {
567
+ axis,
568
+ outDim
569
+ });
570
+ }
535
571
  function transpose$1(x, perm) {
536
572
  perm = perm ? perm.map((a) => require_backend.checkAxis(a, ndim$1(x))) : require_backend.range(ndim$1(x)).reverse();
537
573
  if (!require_backend.isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
@@ -581,16 +617,39 @@ function pad$1(x, width) {
581
617
  } else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
582
618
  return bind1(Primitive.Pad, [x], { width });
583
619
  }
584
- function gather(x, indices, axis, outDim) {
585
- if (indices.length === 0) throw new Error("gather() requires at least one index");
586
- if (!Array.isArray(axis) || axis.length !== indices.length) throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
587
- axis = axis.map((a) => require_backend.checkAxis(a, ndim$1(x)));
588
- if (new Set(axis).size !== axis.length) throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
589
- outDim = require_backend.checkAxis(outDim, ndim$1(x) - axis.length + 1);
590
- return bind1(Primitive.Gather, [x, ...indices], {
591
- axis,
592
- outDim
593
- });
620
+ function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
621
+ const as = getShape(a);
622
+ const bs = getShape(b);
623
+ if (as.length < 2 || bs.length < 2) throw new Error(`triangular_solve: must be >=2D, got a=${as}, b=${bs}`);
624
+ const n = as[as.length - 2];
625
+ if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
626
+ if (lower) {
627
+ a = flip$1(a, [-2, -1]);
628
+ b = flip$1(b, [-1]);
629
+ }
630
+ let x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
631
+ if (lower) x = flip$1(x, [-1]);
632
+ return x;
633
+ }
634
+ function cholesky$2(x) {
635
+ const aval = ShapedArray.fromAval(getAval(x));
636
+ if (aval.ndim < 2 || aval.shape[aval.ndim - 1] !== aval.shape[aval.ndim - 2]) throw new Error(`cholesky: expected batch of square matrices, got ${aval}`);
637
+ return bind1(Primitive.Cholesky, [x]);
638
+ }
639
+ function lu$1(x) {
640
+ const aval = ShapedArray.fromAval(getAval(x));
641
+ if (aval.ndim < 2) throw new Error(`lu: expected batch of matrices, got ${aval}`);
642
+ return bind(Primitive.LU, [x]);
643
+ }
644
+ function sort$1(x) {
645
+ const nd = ndim$1(x);
646
+ if (nd === 0) throw new Error("sort: requires at least 1D input");
647
+ return bind1(Primitive.Sort, [x]);
648
+ }
649
+ function argsort$1(x) {
650
+ const nd = ndim$1(x);
651
+ if (nd === 0) throw new Error("argsort: requires at least 1D input");
652
+ return bind(Primitive.Argsort, [x]);
594
653
  }
595
654
  function bind1(prim, args, params = {}) {
596
655
  const [results] = bind(prim, args, params);
@@ -690,6 +749,9 @@ var Tracer = class Tracer {
690
749
  mul(other) {
691
750
  return mul(this, other);
692
751
  }
752
+ mod(other) {
753
+ return mod(this, other);
754
+ }
693
755
  greater(other) {
694
756
  return greater$1(this, other);
695
757
  }
@@ -753,7 +815,7 @@ var Tracer = class Tracer {
753
815
  if (require_backend.isFloatDtype(this.dtype)) return this.mul(reciprocal$1(other));
754
816
  return idiv(this, other);
755
817
  }
756
- /** Return specified diagonals. See `numpy.diagonal` for full docs. */
818
+ /** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
757
819
  diagonal(offset = 0, axis1 = 0, axis2 = 1) {
758
820
  if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
759
821
  if (offset < 0) return this.diagonal(-offset, axis2, axis1);
@@ -802,8 +864,42 @@ var Tracer = class Tracer {
802
864
  */
803
865
  *[Symbol.iterator]() {
804
866
  if (this.ndim === 0) throw new Error("Cannot iterate over a scalar array");
805
- for (let i = 0; i < this.shape[0]; i++) yield this.ref.slice(i);
806
- this.dispose();
867
+ let residual = this;
868
+ const subarrayShape = this.shape.slice(1);
869
+ for (let i = 0; i < this.shape[0]; i++) {
870
+ const lr = split$2(residual, 0, [1, residual.shape[0] - 1]);
871
+ yield lr[0].reshape(subarrayShape);
872
+ residual = lr[1];
873
+ }
874
+ residual.dispose();
875
+ }
876
+ /**
877
+ * Return a sorted copy of an array in ascending order.
878
+ *
879
+ * See `jax.numpy.sort` for full docs.
880
+ */
881
+ sort(axis = -1) {
882
+ axis = require_backend.checkAxis(axis, this.ndim);
883
+ if (this.shape[axis] <= 1) return this;
884
+ if (axis === this.ndim - 1) return sort$1(this);
885
+ const perm = require_backend.range(this.ndim);
886
+ perm.splice(axis, 1);
887
+ perm.push(axis);
888
+ return sort$1(this.transpose(perm)).transpose(require_backend.invertPermutation(perm));
889
+ }
890
+ /**
891
+ * Return the indices that would sort an array. This may not be a stable
892
+ * sorting algorithm; it need not preserve order of indices in ties.
893
+ *
894
+ * See `jax.numpy.argsort` for full docs.
895
+ */
896
+ argsort(axis = -1) {
897
+ axis = require_backend.checkAxis(axis, this.ndim);
898
+ if (axis === this.ndim - 1) return argsort$1(this)[1];
899
+ const perm = require_backend.range(this.ndim);
900
+ perm.splice(axis, 1);
901
+ perm.push(axis);
902
+ return argsort$1(this.transpose(perm))[1].transpose(require_backend.invertPermutation(perm));
807
903
  }
808
904
  /**
809
905
  * Slice an array along one or more axes.
@@ -922,6 +1018,12 @@ var ShapedArray = class ShapedArray {
922
1018
  get ndim() {
923
1019
  return this.shape.length;
924
1020
  }
1021
+ get size() {
1022
+ return require_backend.prod(this.shape);
1023
+ }
1024
+ scalar() {
1025
+ return new ShapedArray([], this.dtype, this.weakType);
1026
+ }
925
1027
  toString() {
926
1028
  return `${this.dtype}[${this.shape.join(",")}]`;
927
1029
  }
@@ -1221,13 +1323,13 @@ var Jaxpr = class Jaxpr {
1221
1323
  }
1222
1324
  return new Jaxpr(this.inBinders, liveEqns.reverse(), outs);
1223
1325
  }
1224
- /** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
1326
+ /** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
1225
1327
  flatten() {
1226
- if (!this.eqns.some((eqn) => eqn.primitive === Primitive.JitCall)) return this;
1328
+ if (!this.eqns.some((eqn) => eqn.primitive === Primitive.Jit)) return this;
1227
1329
  const newEqns = [];
1228
1330
  const varMap = /* @__PURE__ */ new Map();
1229
1331
  const varMapF = (x) => x instanceof Var ? varMap.get(x) ?? x : x;
1230
- for (const eqn of this.eqns) if (eqn.primitive === Primitive.JitCall) {
1332
+ for (const eqn of this.eqns) if (eqn.primitive === Primitive.Jit) {
1231
1333
  const jaxpr = eqn.params.jaxpr.flatten();
1232
1334
  const translation = /* @__PURE__ */ new Map();
1233
1335
  const translationF = (x) => x instanceof Var ? translation.get(x) : x;
@@ -1328,19 +1430,48 @@ function evalJaxpr(jaxpr, args) {
1328
1430
  function jaxprAsFun(jaxpr) {
1329
1431
  return (...args) => evalJaxpr(jaxpr, args);
1330
1432
  }
1433
+ /** Jaxpr with a collection of associated, traced constants. */
1434
+ var ClosedJaxpr = class ClosedJaxpr {
1435
+ constructor(jaxpr, consts) {
1436
+ this.jaxpr = jaxpr;
1437
+ this.consts = consts;
1438
+ }
1439
+ /** String representation of this Jaxpr. */
1440
+ toString() {
1441
+ return this.jaxpr.toString();
1442
+ }
1443
+ /** Apply a function to the underlying Jaxpr. */
1444
+ mapJaxpr(f) {
1445
+ return new ClosedJaxpr(f(this.jaxpr), this.consts);
1446
+ }
1447
+ /** Dispose of the constants in this Jaxpr. */
1448
+ dispose() {
1449
+ for (const c of this.consts) c.dispose();
1450
+ }
1451
+ };
1331
1452
  /** Tracer that records its operations to dynamically construct a Jaxpr. */
1332
1453
  var JaxprTracer = class extends Tracer {
1454
+ #rc;
1333
1455
  constructor(trace$1, aval) {
1334
1456
  super(trace$1);
1335
1457
  this.aval = aval;
1458
+ this.#rc = 1;
1336
1459
  }
1337
1460
  toString() {
1338
1461
  return `JaxprTracer(${this.aval.toString()})`;
1339
1462
  }
1340
1463
  get ref() {
1464
+ if (this.#rc <= 0) throw new UseAfterFreeError(this);
1465
+ this.#rc++;
1341
1466
  return this;
1342
1467
  }
1343
- dispose() {}
1468
+ dispose() {
1469
+ if (this.#rc <= 0) throw new UseAfterFreeError(this);
1470
+ this.#rc--;
1471
+ }
1472
+ trackLiftedConstant() {
1473
+ this.#rc++;
1474
+ }
1344
1475
  };
1345
1476
  /** Analogous to the 'DynamicJaxprTrace' class in JAX. */
1346
1477
  var JaxprTrace = class extends Trace {
@@ -1353,17 +1484,24 @@ var JaxprTrace = class extends Trace {
1353
1484
  }
1354
1485
  /** Register a constant / literal in this Jaxpr. */
1355
1486
  getOrMakeConstTracer(val) {
1487
+ if (!(val instanceof Tracer)) val = pureArray(val);
1356
1488
  let tracer = this.builder.constTracers.get(val);
1357
1489
  if (tracer === void 0) {
1358
1490
  tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
1359
- this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
1491
+ this.builder.addConst(tracer, val);
1492
+ } else {
1493
+ val.dispose();
1494
+ tracer.trackLiftedConstant();
1360
1495
  }
1361
1496
  return tracer;
1362
1497
  }
1363
1498
  pure = this.getOrMakeConstTracer;
1364
1499
  lift = this.getOrMakeConstTracer;
1365
1500
  processPrimitive(primitive, tracers, params) {
1366
- const avalsIn = tracers.map((t) => t.aval);
1501
+ const avalsIn = tracers.map((t) => {
1502
+ t.dispose();
1503
+ return t.aval;
1504
+ });
1367
1505
  const avalsOut = abstractEvalRules[primitive](avalsIn, params);
1368
1506
  const outTracers = avalsOut.map((aval) => this.builder.newTracer(this, aval));
1369
1507
  this.builder.addEqn(new JaxprEqn(primitive, tracers.map((t) => this.builder.getVar(t)), params, outTracers.map((t) => this.builder.addVar(t))));
@@ -1406,20 +1544,17 @@ var JaxprBuilder = class {
1406
1544
  return v;
1407
1545
  }
1408
1546
  build(inTracers, outTracers) {
1409
- let [constVars, consts] = require_backend.unzip2(this.constVals.entries());
1547
+ const [constVars, consts] = require_backend.unzip2(this.constVals.entries());
1410
1548
  const t2v = this.getVar.bind(this);
1411
1549
  const inBinders = [...constVars, ...inTracers.map(t2v)];
1412
1550
  const outVars = outTracers.map(t2v);
1413
- let jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
1551
+ const jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
1414
1552
  typecheckJaxpr(jaxpr);
1415
- [jaxpr, consts] = _inlineLiterals(jaxpr, consts);
1416
- return {
1417
- jaxpr,
1418
- consts
1419
- };
1553
+ const cjaxpr = new ClosedJaxpr(jaxpr, consts);
1554
+ return _inlineLiterals(cjaxpr);
1420
1555
  }
1421
1556
  };
1422
- function _inlineLiterals(jaxpr, consts) {
1557
+ function _inlineLiterals({ jaxpr, consts }) {
1423
1558
  const literals = /* @__PURE__ */ new Map();
1424
1559
  const constBinders = [];
1425
1560
  const newConsts = [];
@@ -1434,7 +1569,7 @@ function _inlineLiterals(jaxpr, consts) {
1434
1569
  const newOuts = jaxpr.outs.map((x) => literals.get(x) ?? x);
1435
1570
  const newJaxpr = new Jaxpr([...constBinders, ...jaxpr.inBinders.slice(consts.length)], newEqns, newOuts);
1436
1571
  typecheckJaxpr(newJaxpr);
1437
- return [newJaxpr, newConsts];
1572
+ return new ClosedJaxpr(newJaxpr, newConsts);
1438
1573
  }
1439
1574
  function binopAbstractEval([x, y]) {
1440
1575
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
@@ -1453,6 +1588,8 @@ const abstractEvalRules = {
1453
1588
  [Primitive.Mul]: binopAbstractEval,
1454
1589
  [Primitive.Idiv]: binopAbstractEval,
1455
1590
  [Primitive.Mod]: binopAbstractEval,
1591
+ [Primitive.Min]: binopAbstractEval,
1592
+ [Primitive.Max]: binopAbstractEval,
1456
1593
  [Primitive.Neg]: vectorizedUnopAbstractEval,
1457
1594
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
1458
1595
  [Primitive.Floor]: vectorizedUnopAbstractEval,
@@ -1466,12 +1603,6 @@ const abstractEvalRules = {
1466
1603
  if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
1467
1604
  return [new ShapedArray(x.shape, dtype, false)];
1468
1605
  },
1469
- [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
1470
- if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
1471
- const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
1472
- if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1473
- return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
1474
- },
1475
1606
  [Primitive.Sin]: vectorizedUnopAbstractEval,
1476
1607
  [Primitive.Cos]: vectorizedUnopAbstractEval,
1477
1608
  [Primitive.Asin]: vectorizedUnopAbstractEval,
@@ -1481,8 +1612,6 @@ const abstractEvalRules = {
1481
1612
  [Primitive.Erf]: vectorizedUnopAbstractEval,
1482
1613
  [Primitive.Erfc]: vectorizedUnopAbstractEval,
1483
1614
  [Primitive.Sqrt]: vectorizedUnopAbstractEval,
1484
- [Primitive.Min]: binopAbstractEval,
1485
- [Primitive.Max]: binopAbstractEval,
1486
1615
  [Primitive.Reduce]([x], { axis }) {
1487
1616
  const axisSet = new Set(axis);
1488
1617
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
@@ -1504,7 +1633,7 @@ const abstractEvalRules = {
1504
1633
  return [new ShapedArray(shape$1, dtype, weakType)];
1505
1634
  },
1506
1635
  [Primitive.Conv]([lhs, rhs], params) {
1507
- const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
1636
+ const { dtype, weakType } = promoteAvals(lhs.scalar(), rhs.scalar());
1508
1637
  const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
1509
1638
  return [new ShapedArray(shape$1, dtype, weakType)];
1510
1639
  },
@@ -1515,6 +1644,40 @@ const abstractEvalRules = {
1515
1644
  const shape$1 = require_backend.generalBroadcast(cond.shape, xy.shape);
1516
1645
  return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
1517
1646
  },
1647
+ [Primitive.Concatenate](xs, { axis }) {
1648
+ if (xs.length === 0) throw new TypeError("Concatenate requires at least one input");
1649
+ for (const x of xs) if (x.ndim !== xs[0].ndim || !x.shape.every((s, i) => i === axis || s === xs[0].shape[i])) throw new TypeError(`Concatenate: inputs ${xs[0]} and ${x} must match shapes except on axis ${axis}`);
1650
+ const shape$1 = xs[0].shape.slice();
1651
+ shape$1[axis] = xs.reduce((sum$1, x) => sum$1 + x.shape[axis], 0);
1652
+ const { dtype, weakType } = xs.map((x) => x.scalar()).reduce(promoteAvals);
1653
+ return [new ShapedArray(shape$1, dtype, weakType)];
1654
+ },
1655
+ [Primitive.Split]([x], { axis, sizes }) {
1656
+ const totalSize = sizes.reduce((a, b) => a + b, 0);
1657
+ if (x.shape[axis] !== totalSize) throw new TypeError(`Split: sizes ${sizes} do not sum to dimension ${x.shape[axis]} on axis ${axis}`);
1658
+ return sizes.map((size$1) => {
1659
+ return new ShapedArray(x.shape.toSpliced(axis, 1, size$1), x.dtype, x.weakType);
1660
+ });
1661
+ },
1662
+ [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
1663
+ if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
1664
+ if (!require_backend.deepEqual(k0.shape, k1.shape)) throw new TypeError(`RandomBits: Keys have different shapes ${k0.shape} and ${k1.shape}`);
1665
+ if (!require_backend.deepEqual(shape$1.slice(0, k0.ndim), k0.shape)) throw new TypeError(`RandomBits: generated shape ${shape$1} must match key shape ${k0.shape}`);
1666
+ return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
1667
+ },
1668
+ [Primitive.Gather]([x, ...indices], { axis, outDim }) {
1669
+ for (const a of indices) if (a.dtype !== require_backend.DType.Int32 && a.dtype !== require_backend.DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
1670
+ if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
1671
+ if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
1672
+ if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
1673
+ if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
1674
+ const axisSet = new Set(axis);
1675
+ if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
1676
+ const gatherShape = indices.reduce((shape$1, a) => require_backend.generalBroadcast(shape$1, a.shape), []);
1677
+ const newShape = x.shape.filter((_, i) => !axisSet.has(i));
1678
+ newShape.splice(outDim, 0, ...gatherShape);
1679
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
1680
+ },
1518
1681
  [Primitive.Transpose]([x], { perm }) {
1519
1682
  return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
1520
1683
  },
@@ -1535,23 +1698,41 @@ const abstractEvalRules = {
1535
1698
  const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
1536
1699
  return [new ShapedArray(newShape, x.dtype, x.weakType)];
1537
1700
  },
1538
- [Primitive.Gather]([x, ...indices], { axis, outDim }) {
1539
- for (const a of indices) if (a.dtype !== require_backend.DType.Int32 && a.dtype !== require_backend.DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
1540
- if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
1541
- if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
1542
- if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
1543
- if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
1544
- const axisSet = new Set(axis);
1545
- if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
1546
- const gatherShape = indices.reduce((shape$1, a) => require_backend.generalBroadcast(shape$1, a.shape), []);
1547
- const newShape = x.shape.filter((_, i) => !axisSet.has(i));
1548
- newShape.splice(outDim, 0, ...gatherShape);
1549
- return [new ShapedArray(newShape, x.dtype, x.weakType)];
1701
+ [Primitive.Sort]([x]) {
1702
+ if (x.ndim === 0) throw new TypeError("sort: requires at least 1D input");
1703
+ return [ShapedArray.fromAval(x)];
1704
+ },
1705
+ [Primitive.Argsort]([x]) {
1706
+ if (x.ndim === 0) throw new TypeError("argsort: requires at least 1D input");
1707
+ return [ShapedArray.fromAval(x), new ShapedArray(x.shape, require_backend.DType.Int32, false)];
1708
+ },
1709
+ [Primitive.TriangularSolve]([a, b]) {
1710
+ if (a.ndim < 2) throw new TypeError(`triangular_solve: a must be at least 2D, got ${a}`);
1711
+ if (b.ndim < 2) throw new TypeError(`triangular_solve: b must be at least 2D, got ${b}`);
1712
+ const [m, n] = a.shape.slice(-2);
1713
+ const [_batch, q] = b.shape.slice(-2);
1714
+ if (!require_backend.deepEqual(a.shape.slice(0, -2), b.shape.slice(0, -2)) || a.dtype !== b.dtype || m !== n || n !== q) throw new TypeError(`triangular_solve: mismatch ${a} vs ${b}`);
1715
+ return [new ShapedArray(b.shape, b.dtype, a.weakType && b.weakType)];
1716
+ },
1717
+ [Primitive.Cholesky]([a]) {
1718
+ if (a.ndim < 2) throw new TypeError(`cholesky: requires at least 2D input, got ${a}`);
1719
+ if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
1720
+ return [ShapedArray.fromAval(a)];
1721
+ },
1722
+ [Primitive.LU]([a]) {
1723
+ if (a.ndim < 2) throw new TypeError(`lu: requires at least 2D input, got ${a}`);
1724
+ const batch = a.shape.slice(0, -2);
1725
+ const [m, n] = a.shape.slice(-2);
1726
+ return [
1727
+ ShapedArray.fromAval(a),
1728
+ new ShapedArray([...batch, Math.min(m, n)], require_backend.DType.Int32, false),
1729
+ new ShapedArray([...batch, m], require_backend.DType.Int32, false)
1730
+ ];
1550
1731
  },
1551
- [Primitive.JitCall](args, { jaxpr }) {
1732
+ [Primitive.Jit](args, { jaxpr }) {
1552
1733
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
1553
- if (args.length !== inTypes.length) throw new TypeError(`jit_call expected ${inTypes.length} arguments, got ${args.length}`);
1554
- for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit_call argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
1734
+ if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
1735
+ for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
1555
1736
  return outTypes;
1556
1737
  }
1557
1738
  };
@@ -1587,11 +1768,10 @@ function makeJaxpr$1(f, opts) {
1587
1768
  const tracersIn = avalsIn.map((aval) => trace$1.newArg(typeof aval === "object" ? aval : pureArray(aval)));
1588
1769
  const outs = fFlat(...tracersIn);
1589
1770
  const tracersOut = outs.map((out) => fullRaise(trace$1, out));
1590
- const { jaxpr, consts } = builder.build(tracersIn, tracersOut);
1771
+ const jaxpr = builder.build(tracersIn, tracersOut);
1591
1772
  if (outTree.value === void 0) throw new Error("outTree was not set in makeJaxpr");
1592
1773
  return {
1593
- jaxpr: jaxpr.simplify(),
1594
- consts,
1774
+ jaxpr: jaxpr.mapJaxpr((j) => j.simplify()),
1595
1775
  treedef: outTree.value
1596
1776
  };
1597
1777
  } catch (_) {
@@ -1610,22 +1790,29 @@ function jit$1(f, opts) {
1610
1790
  const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
1611
1791
  const avalsIn = unflatten(inTree, avalsInFlat);
1612
1792
  const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
1613
- const { jaxpr, consts, treedef: outTree } = require_backend.runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
1614
- const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
1793
+ const { jaxpr, treedef: outTree } = require_backend.runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
1794
+ const outs = bind(Primitive.Jit, [...jaxpr.consts.map((c) => c.ref), ...argsFlat], {
1615
1795
  name: f.name || "closure",
1616
- jaxpr,
1617
- numConsts: consts.length
1796
+ jaxpr: jaxpr.jaxpr,
1797
+ numConsts: jaxpr.consts.length
1618
1798
  });
1619
1799
  return unflatten(outTree, outs);
1620
1800
  });
1621
1801
  result.dispose = () => {
1622
- for (const { consts } of cache.values()) for (const c of consts) c.dispose();
1802
+ for (const { jaxpr } of cache.values()) jaxpr.dispose();
1623
1803
  };
1624
1804
  return result;
1625
1805
  }
1626
1806
 
1627
1807
  //#endregion
1628
1808
  //#region src/frontend/jit.ts
1809
+ const routinePrimitives = new Map([
1810
+ [Primitive.Sort, require_backend.Routines.Sort],
1811
+ [Primitive.Argsort, require_backend.Routines.Argsort],
1812
+ [Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
1813
+ [Primitive.Cholesky, require_backend.Routines.Cholesky],
1814
+ [Primitive.LU, require_backend.Routines.LU]
1815
+ ]);
1629
1816
  /** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
1630
1817
  var JitProgram = class {
1631
1818
  constructor(backend, steps, inputs, outputs) {
@@ -1640,9 +1827,14 @@ var JitProgram = class {
1640
1827
  case "execute": {
1641
1828
  const inputsNice = step.inputs.map((id, i) => `${i}: %${id}`).join(", ");
1642
1829
  const outputsNice = step.outputs.map((id) => `%${id}`).join(", ");
1643
- return require_backend.PPrint.pp(`execute (${inputsNice}) -> ${outputsNice}, kernel`).concat(step.kernel.pprint().indent(2));
1830
+ const executeText = `execute (${inputsNice}) -> ${outputsNice}`;
1831
+ if (step.source instanceof require_backend.Kernel) return require_backend.PPrint.pp(`${executeText}, kernel`).concat(step.source.pprint().indent(2));
1832
+ else if (step.source instanceof require_backend.Routine) return require_backend.PPrint.pp(`${executeText}, routine ${step.source.name}`);
1833
+ else {
1834
+ step.source;
1835
+ return require_backend.PPrint.pp(executeText);
1836
+ }
1644
1837
  }
1645
- case "const": return require_backend.PPrint.pp(`%${step.output} = const <Slot ${step.slot}>`);
1646
1838
  case "malloc": return require_backend.PPrint.pp(`%${step.output} = malloc <${step.size} bytes>`);
1647
1839
  case "incref": return require_backend.PPrint.pp(`incref ${step.input}`);
1648
1840
  case "free": return require_backend.PPrint.pp(`free ${step.input}`);
@@ -1665,12 +1857,9 @@ var JitProgram = class {
1665
1857
  const inputs$1 = step.inputs.map((id) => scope.get(id));
1666
1858
  const outputs = step.outputs.map((id) => scope.get(id));
1667
1859
  if (inputs$1.some((s) => s === void 0) || outputs.some((s) => s === void 0)) throw new Error(`internal: JitProgram scope undefined`);
1668
- pending.push(new PendingExecute(this.backend, step.kernel, inputs$1, outputs));
1860
+ pending.push(new PendingExecute(this.backend, step.source, inputs$1, outputs));
1669
1861
  break;
1670
1862
  }
1671
- case "const":
1672
- scope.set(step.output, step.slot);
1673
- break;
1674
1863
  case "malloc": {
1675
1864
  const slot = this.backend.malloc(step.size);
1676
1865
  scope.set(step.output, slot);
@@ -1704,34 +1893,37 @@ var JitProgramBuilder = class {
1704
1893
  this.#nextId = nargs;
1705
1894
  this.steps = [];
1706
1895
  }
1707
- pushConst(slot) {
1708
- const id = this.#nextId++;
1709
- this.steps.push({
1710
- type: "const",
1711
- slot,
1712
- output: id
1713
- });
1714
- return id;
1715
- }
1716
1896
  pushLit(lit) {
1717
- const kernel = new require_backend.Kernel(0, require_backend.prod(lit.aval.shape), require_backend.AluExp.const(lit.dtype, lit.value));
1897
+ const kernel = new require_backend.Kernel(0, lit.aval.size, require_backend.AluExp.const(lit.dtype, lit.value));
1718
1898
  return this.pushKernel(kernel, []);
1719
1899
  }
1720
- pushKernel(kernel, inputs) {
1900
+ pushBuffer(size$1) {
1721
1901
  const id = this.#nextId++;
1722
1902
  this.steps.push({
1723
1903
  type: "malloc",
1724
- size: kernel.bytes,
1904
+ size: size$1,
1725
1905
  output: id
1726
1906
  });
1907
+ return id;
1908
+ }
1909
+ pushKernel(kernel, inputs) {
1910
+ const id = this.pushBuffer(kernel.bytes);
1727
1911
  this.steps.push({
1728
1912
  type: "execute",
1729
- kernel,
1913
+ source: kernel,
1730
1914
  inputs,
1731
1915
  outputs: [id]
1732
1916
  });
1733
1917
  return id;
1734
1918
  }
1919
+ pushRoutine(routine, inputs, outputs) {
1920
+ this.steps.push({
1921
+ type: "execute",
1922
+ source: routine,
1923
+ inputs,
1924
+ outputs
1925
+ });
1926
+ }
1735
1927
  pushIncref(id) {
1736
1928
  this.steps.push({
1737
1929
  type: "incref",
@@ -1757,28 +1949,18 @@ var JitProgramBuilder = class {
1757
1949
  }
1758
1950
  };
1759
1951
  const jitCompileCache = /* @__PURE__ */ new Map();
1760
- function jitCompile(backend, jaxpr, consts) {
1761
- if (jaxpr.inBinders.length < consts.length) throw new TypeError(`Jaxpr has ${jaxpr.inBinders.length} inputs, but ${consts.length} consts were provided`);
1762
- for (let i = 0; i < consts.length; i++) if (consts[i].device !== backend.type) throw new TypeError(`Const ${i} has device ${consts[i].device}, but expected ${backend.type}`);
1763
- const cacheKey = backend.type + require_backend.FpHash.hash(jaxpr, ...consts.map((c) => c.id));
1952
+ function jitCompile(backend, jaxpr) {
1953
+ const cacheKey = backend.type + "," + require_backend.FpHash.hash(jaxpr);
1764
1954
  const cached = jitCompileCache.get(cacheKey);
1765
1955
  if (cached) return cached;
1766
1956
  if (require_backend.DEBUG >= 1) console.info("=========== JIT Compile ===========\n" + jaxpr.toString());
1767
1957
  jaxpr = jaxpr.flatten().simplify();
1768
- const nargs = jaxpr.inBinders.length - consts.length;
1958
+ const nargs = jaxpr.inBinders.length;
1769
1959
  const builder = new JitProgramBuilder(backend, nargs);
1770
1960
  const blackNodes = splitGraphDataflow(backend, jaxpr);
1771
1961
  const ctx = /* @__PURE__ */ new Map();
1772
- for (let i = 0; i < consts.length; i++) {
1773
- const v = jaxpr.inBinders[i];
1774
- const slot = consts[i]._realizeSource();
1775
- ctx.set(v, {
1776
- type: "imm",
1777
- arg: builder.pushConst(slot)
1778
- });
1779
- }
1780
1962
  for (let i = 0; i < nargs; i++) {
1781
- const v = jaxpr.inBinders[consts.length + i];
1963
+ const v = jaxpr.inBinders[i];
1782
1964
  ctx.set(v, {
1783
1965
  type: "imm",
1784
1966
  arg: i
@@ -1786,6 +1968,31 @@ function jitCompile(backend, jaxpr, consts) {
1786
1968
  }
1787
1969
  for (let i = 0; i < jaxpr.eqns.length; i++) {
1788
1970
  const eqn = jaxpr.eqns[i];
1971
+ if (routinePrimitives.has(eqn.primitive)) {
1972
+ const routine = new require_backend.Routine(routinePrimitives.get(eqn.primitive), {
1973
+ inputShapes: eqn.inputs.map((x) => x.aval.shape),
1974
+ inputDtypes: eqn.inputs.map((x) => x.aval.dtype),
1975
+ outputShapes: eqn.outBinders.map((x) => x.aval.shape),
1976
+ outputDtypes: eqn.outBinders.map((x) => x.aval.dtype)
1977
+ }, eqn.params);
1978
+ const inputs = [];
1979
+ for (const input of eqn.inputs) if (input instanceof Var) {
1980
+ const jv = ctx.get(input);
1981
+ if (jv.type !== "imm") throw new Error(`jit: routine primitive ${eqn.primitive} input is not imm`);
1982
+ inputs.push(jv.arg);
1983
+ } else if (input instanceof Lit) inputs.push(builder.pushLit(input));
1984
+ const outputs = [];
1985
+ for (const outVar of eqn.outBinders) {
1986
+ const outId = builder.pushBuffer(outVar.aval.size * require_backend.byteWidth(outVar.aval.dtype));
1987
+ outputs.push(outId);
1988
+ ctx.set(outVar, {
1989
+ type: "imm",
1990
+ arg: outId
1991
+ });
1992
+ }
1993
+ builder.pushRoutine(routine, inputs, outputs);
1994
+ continue;
1995
+ }
1789
1996
  const inputExps = [];
1790
1997
  const inputAvals = [];
1791
1998
  const inputArgs = [];
@@ -1829,35 +2036,37 @@ function jitCompile(backend, jaxpr, consts) {
1829
2036
  let reduction;
1830
2037
  if (inputReduction) {
1831
2038
  const jv = inputReduction;
1832
- const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
1833
- exp$2 = jv.exp.reindexGids(addArgs(jv.args));
2039
+ const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp[0];
2040
+ exp$2 = [jv.exp.reindexGids(addArgs(jv.args))];
1834
2041
  reduction = new require_backend.Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
1835
2042
  } else {
1836
2043
  const ruleOutput = rule(inputExps, inputAvals, eqn.params);
1837
2044
  exp$2 = ruleOutput.exp;
1838
2045
  reduction = ruleOutput.reduction;
1839
2046
  }
1840
- const outVar = eqn.outBinders[0];
1841
- if (blackNodes.has(outVar)) {
1842
- const nargs$1 = inputArgs.length;
1843
- const size$1 = require_backend.prod(outVar.aval.shape);
1844
- const kernel = new require_backend.Kernel(nargs$1, size$1, exp$2, reduction);
1845
- const outId = builder.pushKernel(kernel, inputArgs);
1846
- ctx.set(outVar, {
1847
- type: "imm",
1848
- arg: outId
2047
+ for (let i$1 = 0; i$1 < eqn.outBinders.length; i$1++) {
2048
+ const outVar = eqn.outBinders[i$1];
2049
+ if (blackNodes.has(outVar)) {
2050
+ const nargs$1 = inputArgs.length;
2051
+ const size$1 = outVar.aval.size;
2052
+ const kernel = new require_backend.Kernel(nargs$1, size$1, exp$2[i$1], reduction);
2053
+ const outId = builder.pushKernel(kernel, inputArgs);
2054
+ ctx.set(outVar, {
2055
+ type: "imm",
2056
+ arg: outId
2057
+ });
2058
+ } else if (reduction) ctx.set(outVar, {
2059
+ type: "red",
2060
+ exp: exp$2[i$1],
2061
+ reduction,
2062
+ args: inputArgs
1849
2063
  });
1850
- } else if (reduction) ctx.set(outVar, {
1851
- type: "red",
1852
- exp: exp$2,
1853
- reduction,
1854
- args: inputArgs
1855
- });
1856
- else ctx.set(outVar, {
1857
- type: "exp",
1858
- exp: exp$2,
1859
- args: inputArgs
1860
- });
2064
+ else ctx.set(outVar, {
2065
+ type: "exp",
2066
+ exp: exp$2[i$1],
2067
+ args: inputArgs
2068
+ });
2069
+ }
1861
2070
  }
1862
2071
  const outputIds = [];
1863
2072
  for (const out of jaxpr.outs) if (out instanceof Var) {
@@ -1865,7 +2074,7 @@ function jitCompile(backend, jaxpr, consts) {
1865
2074
  if (jitValue.type !== "imm") throw new Error("internal: Expected imm, since outs are black nodes");
1866
2075
  outputIds.push(jitValue.arg);
1867
2076
  } else if (out instanceof Lit) outputIds.push(builder.pushLit(out));
1868
- const outputNeedsRef = new Set([...require_backend.range(nargs), ...builder.steps.filter((s) => s.type === "const").map((s) => s.output)]);
2077
+ const outputNeedsRef = new Set(require_backend.range(nargs));
1869
2078
  for (const outputId of outputIds) if (outputNeedsRef.has(outputId)) builder.pushIncref(outputId);
1870
2079
  else outputNeedsRef.add(outputId);
1871
2080
  builder.insertFreeSteps(outputIds);
@@ -1898,17 +2107,22 @@ function broadcastedJit(fn, opts) {
1898
2107
  if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = require_backend.AluExp.cast(newDtype, exp$2);
1899
2108
  return exp$2;
1900
2109
  });
1901
- return { exp: fn(exps, params) };
2110
+ return { exp: [fn(exps, params)] };
1902
2111
  };
1903
2112
  }
1904
2113
  function unopJit(fn) {
1905
2114
  return ([a], [_as], params) => {
1906
- return { exp: fn(a, params) };
2115
+ return { exp: [fn(a, params)] };
1907
2116
  };
1908
2117
  }
1909
2118
  function reshapeJit(fn) {
1910
2119
  return ([a], [_as], params) => {
1911
- return { exp: reshapeViews(a, (st) => fn(st, params)) };
2120
+ return { exp: [reshapeViews(a, (st) => fn(st, params))] };
2121
+ };
2122
+ }
2123
+ function routineNoJit() {
2124
+ return () => {
2125
+ throw new Error("jit: rule is not implemented for routines");
1912
2126
  };
1913
2127
  }
1914
2128
  const jitRules = {
@@ -1916,6 +2130,8 @@ const jitRules = {
1916
2130
  [Primitive.Mul]: broadcastedJit(([a, b]) => require_backend.AluExp.mul(a, b)),
1917
2131
  [Primitive.Idiv]: broadcastedJit(([a, b]) => require_backend.AluExp.idiv(a, b)),
1918
2132
  [Primitive.Mod]: broadcastedJit(([a, b]) => require_backend.AluExp.mod(a, b)),
2133
+ [Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
2134
+ [Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
1919
2135
  [Primitive.Neg]: unopJit((a) => require_backend.AluExp.sub(require_backend.AluExp.const(a.dtype, 0), a)),
1920
2136
  [Primitive.Reciprocal]: unopJit(require_backend.AluExp.reciprocal),
1921
2137
  [Primitive.Floor]: unopJit(require_backend.AluExp.floor),
@@ -1923,17 +2139,6 @@ const jitRules = {
1923
2139
  [Primitive.StopGradient]: unopJit((a) => a),
1924
2140
  [Primitive.Cast]: unopJit((a, { dtype }) => require_backend.AluExp.cast(dtype, a)),
1925
2141
  [Primitive.Bitcast]: unopJit((a, { dtype }) => require_backend.AluExp.bitcast(dtype, a)),
1926
- [Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
1927
- const mapping = (st) => {
1928
- if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(shape$1.length - st.shape.length));
1929
- };
1930
- const k0 = reshapeViews(keys[0], mapping);
1931
- const k1 = reshapeViews(keys[1], mapping);
1932
- const c0 = require_backend.AluExp.u32(0);
1933
- const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
1934
- const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
1935
- return { exp: exp$2 };
1936
- },
1937
2142
  [Primitive.Sin]: unopJit(require_backend.AluExp.sin),
1938
2143
  [Primitive.Cos]: unopJit(require_backend.AluExp.cos),
1939
2144
  [Primitive.Asin]: unopJit(require_backend.AluExp.asin),
@@ -1943,8 +2148,6 @@ const jitRules = {
1943
2148
  [Primitive.Erf]: unopJit(require_backend.AluExp.erf),
1944
2149
  [Primitive.Erfc]: unopJit(require_backend.AluExp.erfc),
1945
2150
  [Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
1946
- [Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
1947
- [Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
1948
2151
  [Primitive.Reduce]([a], [as], { op, axis }) {
1949
2152
  const keptAxes = [];
1950
2153
  const shiftedAxes = [];
@@ -1960,7 +2163,7 @@ const jitRules = {
1960
2163
  a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
1961
2164
  const reduction = new require_backend.Reduction(a.dtype, op, reductionSize);
1962
2165
  return {
1963
- exp: a,
2166
+ exp: [a],
1964
2167
  reduction
1965
2168
  };
1966
2169
  },
@@ -1971,13 +2174,13 @@ const jitRules = {
1971
2174
  a = reshapeViews(a, (st) => st.compose(stX), true);
1972
2175
  const reduction = new require_backend.Reduction(a.dtype, require_backend.AluOp.Add, stX.shape[stX.shape.length - 1]);
1973
2176
  return {
1974
- exp: a,
2177
+ exp: [a],
1975
2178
  reduction
1976
2179
  };
1977
2180
  },
1978
2181
  [Primitive.Dot]([a, b], [as, bs]) {
1979
2182
  const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
1980
- const c = k1.exp;
2183
+ const [c] = k1.exp;
1981
2184
  const cs = promoteAvals(as, bs);
1982
2185
  return jitRules[Primitive.Reduce]([c], [cs], {
1983
2186
  op: require_backend.AluOp.Add,
@@ -1994,16 +2197,42 @@ const jitRules = {
1994
2197
  },
1995
2198
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
1996
2199
  [Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
1997
- [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
1998
- [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
1999
- [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
2000
- [Primitive.Flip]: reshapeJit((st, { axis }) => {
2001
- const arg = require_backend.rep(st.shape.length, false);
2002
- for (const ax of axis) arg[ax] = true;
2003
- return st.flip(arg);
2004
- }),
2005
- [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
2006
- [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
2200
+ [Primitive.Concatenate](exps, avals, { axis }) {
2201
+ const ndim$2 = avals[0].ndim;
2202
+ const sizes = avals.map((x) => x.shape[axis]);
2203
+ const finalSize = sizes.reduce((a, b) => a + b, 0);
2204
+ const makePadAxis = (start, end) => require_backend.range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
2205
+ let cum = 0;
2206
+ const src = [];
2207
+ for (let i = 0; i < exps.length; i++) {
2208
+ const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
2209
+ src.push(reshapeViews(exps[i], (st) => st.pad(padding)));
2210
+ cum += sizes[i];
2211
+ }
2212
+ return { exp: [src.reduce(require_backend.AluExp.add)] };
2213
+ },
2214
+ [Primitive.Split]([a], [as], { axis, sizes }) {
2215
+ const exp$2 = [];
2216
+ let start = 0;
2217
+ for (const size$1 of sizes) {
2218
+ const slice = require_backend.range(as.ndim).map((d) => d === axis ? [start, start + size$1] : [0, as.shape[d]]);
2219
+ exp$2.push(reshapeViews(a, (st) => st.shrink(slice)));
2220
+ start += size$1;
2221
+ }
2222
+ return { exp: exp$2 };
2223
+ },
2224
+ [Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
2225
+ const keyShape = keyShapes[0].shape;
2226
+ const mapping = (st) => {
2227
+ if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(st.shape.length, shape$1.length));
2228
+ };
2229
+ const k0 = reshapeViews(keys[0], mapping);
2230
+ const k1 = reshapeViews(keys[1], mapping);
2231
+ const c0 = require_backend.AluExp.u32(0);
2232
+ const c1 = require_backend.AluExp.mod(require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx), require_backend.AluExp.u32(Math.max(require_backend.prod(shape$1.slice(keyShape.length)), 1)));
2233
+ const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
2234
+ return { exp: [exp$2] };
2235
+ },
2007
2236
  [Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
2008
2237
  const axisSet = new Set(axis);
2009
2238
  const indexShape = indicesShapes.map((c) => c.shape).reduce(require_backend.generalBroadcast);
@@ -2017,10 +2246,25 @@ const jitRules = {
2017
2246
  for (const [i, iexp] of indices.entries()) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...require_backend.range(outDim + indexShape.length - st.shape.length), ...require_backend.range(outDim + indexShape.length, finalShape.length)])));
2018
2247
  const [index, valid] = require_backend.ShapeTracker.fromShape(xs.shape).toAluExp(src);
2019
2248
  if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
2020
- return { exp: x.substitute({ gidx: index }) };
2249
+ return { exp: [x.substitute({ gidx: index })] };
2021
2250
  },
2022
- [Primitive.JitCall]() {
2023
- throw new Error("internal: JitCall should have been flattened before JIT compilation");
2251
+ [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
2252
+ [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
2253
+ [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
2254
+ [Primitive.Flip]: reshapeJit((st, { axis }) => {
2255
+ const arg = require_backend.rep(st.shape.length, false);
2256
+ for (const ax of axis) arg[ax] = true;
2257
+ return st.flip(arg);
2258
+ }),
2259
+ [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
2260
+ [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
2261
+ [Primitive.Sort]: routineNoJit(),
2262
+ [Primitive.Argsort]: routineNoJit(),
2263
+ [Primitive.TriangularSolve]: routineNoJit(),
2264
+ [Primitive.Cholesky]: routineNoJit(),
2265
+ [Primitive.LU]: routineNoJit(),
2266
+ [Primitive.Jit]() {
2267
+ throw new Error("internal: Jit should have been flattened before JIT compilation");
2024
2268
  }
2025
2269
  };
2026
2270
  /** Determines how to split the Jaxpr into kernels via dataflow analysis. */
@@ -2078,8 +2322,8 @@ function splitGraphDataflow(backend, jaxpr) {
2078
2322
  case Primitive.Mul:
2079
2323
  case Primitive.Idiv:
2080
2324
  case Primitive.Mod:
2081
- case Primitive.Max:
2082
- case Primitive.Min: {
2325
+ case Primitive.Min:
2326
+ case Primitive.Max: {
2083
2327
  const otherInput = nextEqn.inputs.find((v) => v !== outVar);
2084
2328
  if (otherInput instanceof Lit || require_backend.deepEqual(require_backend.generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
2085
2329
  head = usages[0];
@@ -2099,11 +2343,11 @@ function splitGraphDataflow(backend, jaxpr) {
2099
2343
  blackNodes.add(v);
2100
2344
  p1NextBlack.set(v, v);
2101
2345
  }
2102
- const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
2346
+ const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
2103
2347
  const needsCleanShapePrimitives = [Primitive.Pad];
2104
2348
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
2105
2349
  const eqn = jaxpr.eqns[i];
2106
- if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
2350
+ if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
2107
2351
  for (const v of eqn.outBinders) {
2108
2352
  blackNodes.add(v);
2109
2353
  p1NextBlack.set(v, v);
@@ -2113,7 +2357,7 @@ function splitGraphDataflow(backend, jaxpr) {
2113
2357
  const reach = /* @__PURE__ */ new Set();
2114
2358
  let needsCleanOutput = false;
2115
2359
  outer: for (const v of eqn.outBinders) for (const j of varToUsages.get(v) ?? []) {
2116
- if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive)) {
2360
+ if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive) || routinePrimitives.has(jaxpr.eqns[j].primitive)) {
2117
2361
  needsCleanOutput = true;
2118
2362
  break outer;
2119
2363
  }
@@ -2137,7 +2381,6 @@ function splitGraphDataflow(backend, jaxpr) {
2137
2381
  while (p2idx < jaxpr.eqns.length) {
2138
2382
  const eqn = jaxpr.eqns[p2idx++];
2139
2383
  const deps = [];
2140
- if (eqn.outBinders.some((v) => blackNodes.has(v))) continue;
2141
2384
  for (const input of eqn.inputs) if (input instanceof Var) if (blackNodes.has(input)) deps.push(new Set([input]));
2142
2385
  else deps.push(p2Deps.get(input));
2143
2386
  else deps.push(/* @__PURE__ */ new Set());
@@ -2160,7 +2403,7 @@ function splitGraphDataflow(backend, jaxpr) {
2160
2403
  if (assocInput === -1) throw new Error(`internal: maxArgs, no input found to mark as black in Jaxpr equation ${eqn}`);
2161
2404
  const assocVar = eqn.inputs[assocInput];
2162
2405
  p2idx = varToDefn.get(assocVar);
2163
- for (const out of jaxpr.eqns[p2idx].outBinders) blackNodes.add(out);
2406
+ for (const out of jaxpr.eqns[p2idx++].outBinders) blackNodes.add(out);
2164
2407
  } else {
2165
2408
  const s = new Set(depCounter.keys());
2166
2409
  for (const out of eqn.outBinders) p2Deps.set(out, s);
@@ -2186,9 +2429,9 @@ var PendingExecute = class {
2186
2429
  submitted = false;
2187
2430
  #promise = null;
2188
2431
  #rc = 1;
2189
- constructor(backend, kernel, inputs, outputs) {
2432
+ constructor(backend, source, inputs, outputs) {
2190
2433
  this.backend = backend;
2191
- this.kernel = kernel;
2434
+ this.source = source;
2192
2435
  this.inputs = inputs;
2193
2436
  this.outputs = outputs;
2194
2437
  for (const slot of inputs) this.backend.incRef(slot);
@@ -2209,13 +2452,15 @@ var PendingExecute = class {
2209
2452
  return;
2210
2453
  }
2211
2454
  this.#promise = (async () => {
2212
- this.prepared = await this.backend.prepare(this.kernel);
2455
+ if (this.source instanceof require_backend.Kernel) this.prepared = await this.backend.prepareKernel(this.source);
2456
+ else this.prepared = await this.backend.prepareRoutine(this.source);
2213
2457
  })();
2214
2458
  await this.#promise;
2215
2459
  }
2216
2460
  prepareSync() {
2217
2461
  if (this.prepared) return;
2218
- this.prepared = this.backend.prepareSync(this.kernel);
2462
+ if (this.source instanceof require_backend.Kernel) this.prepared = this.backend.prepareKernelSync(this.source);
2463
+ else this.prepared = this.backend.prepareRoutineSync(this.source);
2219
2464
  }
2220
2465
  submit() {
2221
2466
  if (this.submitted) return;
@@ -2238,8 +2483,6 @@ var PendingExecute = class {
2238
2483
  * "Array" type by name.
2239
2484
  */
2240
2485
  var Array$1 = class Array$1 extends Tracer {
2241
- static #nextId = 1001;
2242
- id;
2243
2486
  #dtype;
2244
2487
  #weakType;
2245
2488
  #source;
@@ -2256,7 +2499,6 @@ var Array$1 = class Array$1 extends Tracer {
2256
2499
  */
2257
2500
  constructor(args) {
2258
2501
  super(baseArrayTrace);
2259
- this.id = Array$1.#nextId++;
2260
2502
  this.#dtype = args.dtype;
2261
2503
  this.#weakType = args.weakType;
2262
2504
  this.#source = args.source;
@@ -2299,6 +2541,10 @@ var Array$1 = class Array$1 extends Tracer {
2299
2541
  this.#rc++;
2300
2542
  return this;
2301
2543
  }
2544
+ /** Get the current reference count (for debugging memory management). */
2545
+ get refCount() {
2546
+ return this.#rc;
2547
+ }
2302
2548
  dispose() {
2303
2549
  this.#check();
2304
2550
  if (--this.#rc === 0) {
@@ -2456,7 +2702,7 @@ var Array$1 = class Array$1 extends Tracer {
2456
2702
  } else if (castDtype === void 0) {
2457
2703
  castDtype = arrays[i].#dtype;
2458
2704
  castWeakType = arrays[i].#weakType;
2459
- } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
2705
+ } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), arrays[i].aval.scalar()));
2460
2706
  const weakType = castWeakType && !strongTypeOutput;
2461
2707
  const { backend, committed } = Array$1.#computeBackend(name, arrays);
2462
2708
  arrays = arrays.map((ar) => ar._putSync(backend));
@@ -2565,6 +2811,27 @@ var Array$1 = class Array$1 extends Tracer {
2565
2811
  pending
2566
2812
  });
2567
2813
  }
2814
+ /** Apply an operation with custom lowering to this array. */
2815
+ static #routine(routine, arrays, outputWeakType) {
2816
+ const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
2817
+ for (const ar of arrays) ar.#realize();
2818
+ const inputs = arrays.map((ar) => ar.#source);
2819
+ const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(require_backend.byteWidth(dtype) * require_backend.prod(routine.type.outputShapes[i])));
2820
+ const pending = arrays.flatMap((ar) => ar.#pending);
2821
+ for (const exe of pending) exe.updateRc(+outputs.length);
2822
+ pending.push(new PendingExecute(backend, routine, inputs, outputs));
2823
+ pending[pending.length - 1].updateRc(+outputs.length - 1);
2824
+ arrays.forEach((ar) => ar.dispose());
2825
+ return outputs.map((output, i) => new Array$1({
2826
+ source: output,
2827
+ st: require_backend.ShapeTracker.fromShape(routine.type.outputShapes[i]),
2828
+ dtype: routine.type.outputDtypes[i],
2829
+ weakType: outputWeakType[i],
2830
+ backend,
2831
+ committed,
2832
+ pending
2833
+ }));
2834
+ }
2568
2835
  /**
2569
2836
  * Normalizes this array into one backed by a `Slot`.
2570
2837
  *
@@ -2725,6 +2992,12 @@ var Array$1 = class Array$1 extends Tracer {
2725
2992
  [Primitive.Mod]([x, y]) {
2726
2993
  return [x.#binary(require_backend.AluOp.Mod, y)];
2727
2994
  },
2995
+ [Primitive.Min]([x, y]) {
2996
+ return [x.#binary(require_backend.AluOp.Min, y)];
2997
+ },
2998
+ [Primitive.Max]([x, y]) {
2999
+ return [x.#binary(require_backend.AluOp.Max, y)];
3000
+ },
2728
3001
  [Primitive.Neg]([x]) {
2729
3002
  return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
2730
3003
  },
@@ -2761,25 +3034,6 @@ var Array$1 = class Array$1 extends Tracer {
2761
3034
  return [y];
2762
3035
  }
2763
3036
  },
2764
- [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
2765
- const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
2766
- if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2767
- const c0 = zeros(shape$1, {
2768
- dtype: require_backend.DType.Uint32,
2769
- device: k0.device
2770
- });
2771
- const c1 = arange(0, require_backend.prod(shape$1), 1, {
2772
- dtype: require_backend.DType.Uint32,
2773
- device: k0.device
2774
- }).reshape(shape$1);
2775
- const custom = ([k0$1, k1$1, c0$1, c1$1]) => require_backend.AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
2776
- return [Array$1.#naryCustom("random_bits", custom, [
2777
- k0,
2778
- k1,
2779
- c0,
2780
- c1
2781
- ])];
2782
- },
2783
3037
  [Primitive.Sin]([x]) {
2784
3038
  return [x.#unary(require_backend.AluOp.Sin)];
2785
3039
  },
@@ -2807,12 +3061,6 @@ var Array$1 = class Array$1 extends Tracer {
2807
3061
  [Primitive.Sqrt]([x]) {
2808
3062
  return [x.#unary(require_backend.AluOp.Sqrt)];
2809
3063
  },
2810
- [Primitive.Min]([x, y]) {
2811
- return [x.#binary(require_backend.AluOp.Min, y)];
2812
- },
2813
- [Primitive.Max]([x, y]) {
2814
- return [x.#binary(require_backend.AluOp.Max, y)];
2815
- },
2816
3064
  [Primitive.Reduce]([x], { op, axis }) {
2817
3065
  if (axis.length === 0) return [x];
2818
3066
  return [x.#moveAxesDown(axis).#reduce(op)];
@@ -2847,6 +3095,55 @@ var Array$1 = class Array$1 extends Tracer {
2847
3095
  y
2848
3096
  ], { dtypeOverride: [require_backend.DType.Bool] })];
2849
3097
  },
3098
+ [Primitive.Concatenate](xs, { axis }) {
3099
+ const ndim$2 = xs[0].ndim;
3100
+ const sizes = xs.map((x) => x.shape[axis]);
3101
+ const finalSize = sizes.reduce((a, b) => a + b, 0);
3102
+ const makePadAxis = (start, end) => require_backend.range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
3103
+ let cum = 0;
3104
+ const xsPadded = [];
3105
+ for (let i = 0; i < xs.length; i++) {
3106
+ const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
3107
+ xsPadded.push(xs[i].#reshape(xs[i].#st.pad(padding)));
3108
+ cum += sizes[i];
3109
+ }
3110
+ const custom = (exps) => exps.reduce(require_backend.AluExp.add);
3111
+ return [Array$1.#naryCustom("concatenate", custom, xsPadded)];
3112
+ },
3113
+ [Primitive.Split]([x], { axis, sizes }) {
3114
+ const outputs = [];
3115
+ for (let i = 0, start = 0; i < sizes.length; i++) {
3116
+ const slice = require_backend.range(x.ndim).map((d) => d === axis ? [start, start + sizes[i]] : [0, x.shape[d]]);
3117
+ outputs.push(x.ref.#reshape(x.#st.shrink(slice)));
3118
+ start += sizes[i];
3119
+ }
3120
+ x.dispose();
3121
+ return outputs;
3122
+ },
3123
+ [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
3124
+ const keyShape = k0.shape;
3125
+ const genShape = shape$1.slice(keyShape.length);
3126
+ const c0 = zeros(genShape, {
3127
+ dtype: require_backend.DType.Uint32,
3128
+ device: k0.device
3129
+ });
3130
+ const c1 = arange(0, require_backend.prod(genShape), 1, {
3131
+ dtype: require_backend.DType.Uint32,
3132
+ device: k0.device
3133
+ }).reshape(genShape);
3134
+ k0 = k0.#reshape(k0.#st.reshape(keyShape.concat(require_backend.rep(genShape.length, 1))));
3135
+ k1 = k1.#reshape(k1.#st.reshape(keyShape.concat(require_backend.rep(genShape.length, 1))));
3136
+ const custom = ([k0$1, k1$1, c0$1, c1$1]) => require_backend.AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
3137
+ return [Array$1.#naryCustom("random_bits", custom, [
3138
+ k0,
3139
+ k1,
3140
+ c0,
3141
+ c1
3142
+ ])];
3143
+ },
3144
+ [Primitive.Gather]([x, ...indices], { axis, outDim }) {
3145
+ return [x.#gather(indices, axis, outDim)];
3146
+ },
2850
3147
  [Primitive.Transpose]([x], { perm }) {
2851
3148
  return [x.#transpose(perm)];
2852
3149
  },
@@ -2867,17 +3164,71 @@ var Array$1 = class Array$1 extends Tracer {
2867
3164
  [Primitive.Pad]([x], { width }) {
2868
3165
  return [x.#reshape(x.#st.pad(width))];
2869
3166
  },
2870
- [Primitive.Gather]([x, ...indices], { axis, outDim }) {
2871
- return [x.#gather(indices, axis, outDim)];
3167
+ [Primitive.Sort]([x]) {
3168
+ const routine = new require_backend.Routine(require_backend.Routines.Sort, {
3169
+ inputShapes: [x.shape],
3170
+ inputDtypes: [x.dtype],
3171
+ outputShapes: [x.shape],
3172
+ outputDtypes: [x.dtype]
3173
+ });
3174
+ return Array$1.#routine(routine, [x], [x.#weakType]);
3175
+ },
3176
+ [Primitive.Argsort]([x]) {
3177
+ const routine = new require_backend.Routine(require_backend.Routines.Argsort, {
3178
+ inputShapes: [x.shape],
3179
+ inputDtypes: [x.dtype],
3180
+ outputShapes: [x.shape, x.shape],
3181
+ outputDtypes: [x.dtype, require_backend.DType.Int32]
3182
+ });
3183
+ return Array$1.#routine(routine, [x], [x.#weakType, false]);
3184
+ },
3185
+ [Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
3186
+ const routine = new require_backend.Routine(require_backend.Routines.TriangularSolve, {
3187
+ inputShapes: [a.shape, b.shape],
3188
+ inputDtypes: [a.dtype, b.dtype],
3189
+ outputShapes: [b.shape],
3190
+ outputDtypes: [b.dtype]
3191
+ }, { unitDiagonal });
3192
+ return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
2872
3193
  },
2873
- [Primitive.JitCall](args, { jaxpr, numConsts }) {
2874
- if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
2875
- const { backend, committed } = Array$1.#computeBackend("jit_call", args);
3194
+ [Primitive.Cholesky]([a]) {
3195
+ const routine = new require_backend.Routine(require_backend.Routines.Cholesky, {
3196
+ inputShapes: [a.shape],
3197
+ inputDtypes: [a.dtype],
3198
+ outputShapes: [a.shape],
3199
+ outputDtypes: [a.dtype]
3200
+ });
3201
+ return Array$1.#routine(routine, [a], [a.#weakType]);
3202
+ },
3203
+ [Primitive.LU]([a]) {
3204
+ const batch = a.shape.slice(0, -2);
3205
+ const [m, n] = a.shape.slice(-2);
3206
+ const routine = new require_backend.Routine(require_backend.Routines.LU, {
3207
+ inputShapes: [a.shape],
3208
+ inputDtypes: [a.dtype],
3209
+ outputShapes: [
3210
+ a.shape,
3211
+ [...batch, Math.min(m, n)],
3212
+ [...batch, m]
3213
+ ],
3214
+ outputDtypes: [
3215
+ a.dtype,
3216
+ require_backend.DType.Int32,
3217
+ require_backend.DType.Int32
3218
+ ]
3219
+ });
3220
+ return Array$1.#routine(routine, [a], [
3221
+ a.#weakType,
3222
+ false,
3223
+ false
3224
+ ]);
3225
+ },
3226
+ [Primitive.Jit](args, { jaxpr }) {
3227
+ if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
3228
+ const { backend, committed } = Array$1.#computeBackend("jit", args);
2876
3229
  args = args.map((ar) => ar._putSync(backend));
2877
- const consts = args.slice(0, numConsts);
2878
- const tracers = args.slice(numConsts);
2879
- const jp = jitCompile(backend, jaxpr, consts);
2880
- const { outputs, pending } = jp.execute(tracers.map((x) => x._realizeSource()));
3230
+ const jp = jitCompile(backend, jaxpr);
3231
+ const { outputs, pending } = jp.execute(args.map((x) => x._realizeSource()));
2881
3232
  for (const exe of pending) exe.updateRc(+outputs.length - 1);
2882
3233
  const prevPending = [...new Set(args.flatMap((x) => x.#pending))];
2883
3234
  for (const exe of prevPending) exe.updateRc(+outputs.length);
@@ -2977,7 +3328,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
2977
3328
  device
2978
3329
  });
2979
3330
  } else {
2980
- const weakType = dtype == void 0;
3331
+ const weakType = dtype == void 0 && shape$1.length === 0;
2981
3332
  dtype = dtype ?? require_backend.DType.Float32;
2982
3333
  const data = require_backend.dtypedJsArray(dtype, flat);
2983
3334
  return arrayFromData(data, shape$1, {
@@ -3091,7 +3442,7 @@ function ones(shape$1, { dtype, device } = {}) {
3091
3442
  }
3092
3443
  /** Return a new array of given shape and type, filled with `fill_value`. */
3093
3444
  function full(shape$1, fillValue, { dtype, device } = {}) {
3094
- let weakType = dtype == void 0;
3445
+ let weakType = dtype == void 0 && shape$1.length === 0;
3095
3446
  if (typeof fillValue === "number") dtype = dtype ?? require_backend.DType.Float32;
3096
3447
  else if (typeof fillValue === "boolean") {
3097
3448
  dtype = dtype ?? require_backend.DType.Bool;
@@ -3176,6 +3527,43 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
3176
3527
  });
3177
3528
  }
3178
3529
  /**
3530
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
3531
+ *
3532
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
3533
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
3534
+ * `k>0` is above it.
3535
+ */
3536
function tri(n, m, k = 0, { dtype, device } = {}) {
  // Resolve defaults first: a square matrix and Float32 output.
  const numCols = m ?? n;
  const outDtype = dtype ?? require_backend.DType.Float32;

  if (!Number.isInteger(n) || n < 0) throw new Error(`tri: n must be a non-negative integer, got ${n}`);
  if (!Number.isInteger(numCols) || numCols < 0) throw new Error(`tri: m must be a non-negative integer, got ${numCols}`);
  if (!Number.isInteger(k)) throw new Error(`tri: k must be an integer, got ${k}`);

  // Row i carries the value i + k, column j carries j; the entry is set where
  // i + k >= j, which is exactly "on or below the k-th diagonal".
  const rowIndex = arange(k, n + k, 1, {
    dtype: require_backend.DType.Int32,
    device
  });
  const colIndex = arange(0, numCols, 1, {
    dtype: require_backend.DType.Int32,
    device
  });
  const mask = rowIndex.reshape([n, 1]).greaterEqual(colIndex);
  return mask.astype(outDtype);
}
3552
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
3553
function tril(a, k = 0) {
  const rank = ndim$1(a);
  if (rank < 2) throw new Error(`tril: input array must be at least 2D, got ${rank}D`);
  const arr = fudgeArray(a);
  const [n, m] = arr.shape.slice(-2);
  // Keep entries on/below the k-th diagonal, zero the rest. The operands are
  // built in the same order as the call arguments (mask, kept values, zeros).
  const lowerMask = tri(n, m, k, { dtype: require_backend.DType.Bool });
  const kept = arr.ref;
  const zeros = zerosLike$1(arr);
  return where$1(lowerMask, kept, zeros);
}
3559
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
3560
function triu(a, k = 0) {
  // The message must name `triu`: it previously said `tril`, which pointed
  // callers at the wrong function when the rank check failed.
  if (ndim$1(a) < 2) throw new Error(`triu: input array must be at least 2D, got ${ndim$1(a)}D`);
  a = fudgeArray(a);
  const [n, m] = a.shape.slice(-2);
  // `tri(n, m, k - 1)` marks everything strictly below the k-th diagonal;
  // those entries are replaced by zeros and the rest of `a` is kept.
  return where$1(tri(n, m, k - 1, { dtype: require_backend.DType.Bool }), zerosLike$1(a.ref), a);
}
3566
+ /**
3179
3567
  * Return evenly spaced numbers over a specified interval.
3180
3568
  *
3181
3569
  * Returns _num_ evenly spaced samples, calculated over the interval
@@ -3212,6 +3600,27 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
3212
3600
  committed: device != void 0
3213
3601
  });
3214
3602
  }
3603
+ /**
3604
+ * Return numbers spaced evenly on a log scale.
3605
+ *
3606
+ * In linear space, the sequence starts at `base ** start` and ends at
3607
+ * `base ** stop` (see `endpoint` below).
3608
+ *
3609
+ * @param start - `base ** start` is the starting value of the sequence.
3610
+ * @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
3611
+ * @param num - Number of samples to generate. Default is 50.
3612
+ * @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
3613
+ * @param base - The base of the log space. Default is 10.
3614
+ * @returns Array of evenly spaced values on a log scale.
3615
+ */
3616
function logspace(start, stop, num = 50, endpoint = true, base = 10, { dtype, device } = {}) {
  // base ** y is evaluated as exp(y * ln(base)) over linearly spaced exponents.
  const exponents = linspace(start, stop, num, endpoint, { dtype, device });
  const scaled = mul(exponents, Math.log(base));
  return exp$1(scaled);
}
3215
3624
  function aluCompare(a, b, op) {
3216
3625
  switch (op) {
3217
3626
  case CompareOp.Less: return require_backend.AluExp.cmplt(a, b);
@@ -3222,385 +3631,211 @@ function aluCompare(a, b, op) {
3222
3631
  }
3223
3632
 
3224
3633
  //#endregion
3225
- //#region src/frontend/jvp.ts
3634
+ //#region src/frontend/vmap.ts
3226
3635
  var import_usingCtx$1 = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
3227
- var JVPTracer = class extends Tracer {
3228
- constructor(trace$1, primal, tangent) {
3636
/**
 * Return a new `ShapedArray` describing `aval` with the axis at index
 * `batchDim` removed; dtype and weak-type flag are carried over. The shape of
 * `aval` itself is left untouched (a copy is edited).
 */
function mappedAval(batchDim, aval) {
  const remaining = Array.from(aval.shape);
  remaining.splice(batchDim, 1);
  const { dtype, weakType } = aval;
  return new ShapedArray(remaining, dtype, weakType);
}
3641
+ /** Move one axis to a different index. */
3642
function moveaxis(x, src, dst) {
  const t = pureArray(x);
  // Both axes are passed through checkAxis against the array's rank before use.
  const from = require_backend.checkAxis(src, t.ndim);
  const to = require_backend.checkAxis(dst, t.ndim);
  if (from === to) return t;
  // Take the axis out of the axis list, then re-insert it at the destination
  // slot; the result is the permutation handed to transpose.
  const order = require_backend.range(t.ndim);
  order.splice(from, 1);
  order.splice(to, 0, from);
  return transpose$1(t, order);
}
3652
/**
 * Place the batch axis of `x` at index `dst`.
 *
 * - `src === null` (no batch axis yet): broadcast `x` to a shape with a new
 *   axis of length `axisSize` inserted at `dst`.
 * - `src === dst`: return `x` unchanged.
 * - otherwise: move axis `src` to `dst` with `moveaxis`.
 */
function moveBatchAxis(axisSize, src, dst, x) {
  if (src === dst) return x;
  if (src !== null) return moveaxis(x, src, dst);
  const batchedShape = Array.from(x.shape);
  batchedShape.splice(dst, 0, axisSize);
  return broadcast(x, batchedShape, [dst]);
}
3660
/**
 * Tracer used by `BatchTrace`: wraps a value `val` together with `batchDim`,
 * the index of its batch axis, or `null` when the value carries no batch axis.
 */
var BatchTracer = class extends Tracer {
  constructor(trace$1, val, batchDim) {
    super(trace$1);
    this.val = val;
    this.batchDim = batchDim;
  }
  /** Abstract value as seen per batch element: the batch axis is dropped. */
  get aval() {
    const inner = this.val.aval;
    return this.batchDim === null ? inner : mappedAval(this.batchDim, inner);
  }
  toString() {
    return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
  }
  /**
   * Reads `val.ref` for its side effect and returns this tracer.
   * NOTE(review): presumably takes an extra reference on the wrapped value;
   * confirm against `Tracer`.
   */
  get ref() {
    this.val.ref;
    return this;
  }
  dispose() {
    this.val.dispose();
  }
  /** An unbatched tracer lowers to its wrapped value; a batched one stays. */
  fullLower() {
    return this.batchDim === null ? this.val.fullLower() : this;
  }
};
3248
- var JVPTrace = class extends Trace {
3685
/**
 * Trace that applies per-primitive batching rules from `vmapRules`.
 * The batch size is read from `main.globalData`.
 */
var BatchTrace = class extends Trace {
  pure(val) {
    return this.lift(pureArray(val));
  }
  /** Wrap a value as an unbatched tracer (`batchDim === null`). */
  lift(val) {
    return new BatchTracer(this, val, null);
  }
  /**
   * Evaluate `primitive` on batch tracers.
   *
   * Throws if no rule is registered. When no input is batched, the primitive
   * is bound directly and every output is unbatched; otherwise the rule is
   * called and must return one batch dimension per output value.
   */
  processPrimitive(primitive, tracers, params) {
    const pairs = tracers.map((t) => [t.val, t.batchDim]);
    const [valsIn, bdimsIn] = require_backend.unzip2(pairs);
    const vmapRule = vmapRules[primitive];
    if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
    const anyBatched = bdimsIn.some((d) => d !== null);
    if (!anyBatched) {
      const unbatchedOuts = bind(primitive, valsIn, params);
      return unbatchedOuts.map((x) => new BatchTracer(this, x, null));
    }
    const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
    if (valOuts.length !== bdimOuts.length) throw new Error(`vmap rule for ${primitive} returned mismatched lengths: ${valOuts.length} vs ${bdimOuts.length}`);
    return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
  }
  get axisSize() {
    return this.main.globalData;
  }
};
3263
- /** Rule that applies the same operation to primals and tangents. */
3264
- function linearTangentsJvp(primitive) {
3265
- return (primals, tangents, params) => {
3266
- const ys = bind(primitive, primals, params);
3267
- const dys = bind(primitive, tangents, params);
3268
- return [ys, dys];
3708
+ /**
3709
+ * Process a primitive with built-in broadcasting.
3710
+ *
3711
+ * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3712
+ */
3713
function broadcastBatcher(prim) {
  return (axisSize, args, dims, params) => {
    if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
    // Rank of the batched result: unbatched operands count one extra axis.
    const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
    // Batch-axis position of the first batched operand, measured from the end.
    const firstIdx = dims.findIndex((d) => d !== null);
    const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
    // Fast path: every batched operand has its batch axis at the same
    // from-the-end position, and every unbatched operand has rank below
    // -firstBdim, so the primitive is bound on the operands as they are.
    const alreadyAligned = require_backend.zip(args, dims).every(([x, d]) => {
      return d === null ? ndim$1(x) < -firstBdim : d - x.ndim === firstBdim;
    });
    if (alreadyAligned) return [[bind1(prim, args, params)], [nd + firstBdim]];
    // Slow path: move each batch axis to the front and pad with size-1 axes
    // right after it so all batched operands reach rank `nd`.
    const aligned = args.map((x, i) => {
      if (dims[i] === null) return x;
      let moved = moveBatchAxis(axisSize, dims[i], 0, x);
      if (moved.ndim < nd) {
        moved = moved.reshape([
          moved.shape[0],
          ...require_backend.rep(nd - moved.ndim, 1),
          ...moved.shape.slice(1)
        ]);
      }
      return moved;
    });
    return [[bind1(prim, aligned, params)], [0]];
  };
}
3271
- /** Rule for product of gradients in bilinear operations. */
3272
- function bilinearTangentsJvp(primitive) {
3273
- return ([x, y], [dx, dy], params) => {
3274
- const primal = bind1(primitive, [x.ref, y.ref], params);
3275
- const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
3276
- return [[primal], [tangent]];
3733
/**
 * Batching rule factory for single-input primitives bound as-is: the
 * primitive is applied to `x` unchanged and the output keeps the input's
 * batch dimension.
 */
function unopBatcher(prim) {
  return (axisSize, [x], [xBdim], params) => [[bind1(prim, [x], params)], [xBdim]];
}
3279
- /** Rule that zeros out any tangents. */
3280
- function zeroTangentsJvp(primitive) {
3281
- return (primals, tangents, params) => {
3282
- for (const t of tangents) t.dispose();
3283
- const ys = bind(primitive, primals, params);
3284
- return [ys, ys.map((y) => zerosLike$1(y.ref))];
3738
/**
 * Batching rule factory for single-input primitives that operate on the last
 * `inputDims` axes and return `numOutputs` outputs.
 *
 * If the batch axis lies before those trailing axes, the input is used as-is
 * and every output reports the same batch dimension. Otherwise the batch axis
 * is first moved to index 0 and every output reports batch dimension 0.
 */
function lastDimsBatcher(prim, inputDims, numOutputs = 1) {
  return (axisSize, [x], [xBdim], params) => {
    require_backend.assertNonNull(xBdim);
    const outsideCore = xBdim < x.ndim - inputDims;
    const operand = outsideCore ? x : moveBatchAxis(axisSize, xBdim, 0, x);
    const outBdim = outsideCore ? xBdim : 0;
    return [bind(prim, [operand], params), require_backend.rep(numOutputs, outBdim)];
  };
}
3287
- const jvpRules = {
3288
- [Primitive.Add]: linearTangentsJvp(Primitive.Add),
3289
- [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
3290
- [Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
3291
- [Primitive.Mod]([x, y], [dx, dy]) {
3292
- if (!require_backend.isFloatDtype(x.dtype) && !require_backend.isFloatDtype(y.dtype)) {
3293
- dx.dispose();
3294
- dy.dispose();
3295
- return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
3296
- }
3297
- const q = idiv(x.ref, y.ref);
3298
- return [[mod(x, y)], [dx.sub(dy.mul(q))]];
3299
- },
3300
- [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
3301
- [Primitive.Reciprocal]([x], [dx]) {
3302
- const xRecip = reciprocal$1(x.ref);
3303
- return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
3304
- },
3305
- [Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
3306
- [Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
3307
- [Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
3308
- [Primitive.Cast]([x], [dx], { dtype }) {
3309
- if (x.dtype === dtype) return [[x], [dx]];
3310
- if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
3311
- else {
3312
- dx.dispose();
3313
- return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
3314
- }
3315
- },
3316
- [Primitive.Bitcast]([x], [dx], { dtype }) {
3317
- if (x.dtype === dtype) return [[x], [dx]];
3318
- dx.dispose();
3319
- return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
3320
- },
3321
- [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
3322
- [Primitive.Sin]([x], [dx]) {
3323
- return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
3746
+ const vmapRules = {
3747
+ [Primitive.Add]: broadcastBatcher(Primitive.Add),
3748
+ [Primitive.Mul]: broadcastBatcher(Primitive.Mul),
3749
+ [Primitive.Idiv]: broadcastBatcher(Primitive.Idiv),
3750
+ [Primitive.Mod]: broadcastBatcher(Primitive.Mod),
3751
+ [Primitive.Min]: broadcastBatcher(Primitive.Min),
3752
+ [Primitive.Max]: broadcastBatcher(Primitive.Max),
3753
+ [Primitive.Neg]: unopBatcher(Primitive.Neg),
3754
+ [Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
3755
+ [Primitive.Floor]: unopBatcher(Primitive.Floor),
3756
+ [Primitive.Ceil]: unopBatcher(Primitive.Ceil),
3757
+ [Primitive.StopGradient]: unopBatcher(Primitive.StopGradient),
3758
+ [Primitive.Cast]: unopBatcher(Primitive.Cast),
3759
+ [Primitive.Bitcast]: unopBatcher(Primitive.Bitcast),
3760
+ [Primitive.Sin]: unopBatcher(Primitive.Sin),
3761
+ [Primitive.Cos]: unopBatcher(Primitive.Cos),
3762
+ [Primitive.Asin]: unopBatcher(Primitive.Asin),
3763
+ [Primitive.Atan]: unopBatcher(Primitive.Atan),
3764
+ [Primitive.Exp]: unopBatcher(Primitive.Exp),
3765
+ [Primitive.Log]: unopBatcher(Primitive.Log),
3766
+ [Primitive.Erf]: unopBatcher(Primitive.Erf),
3767
+ [Primitive.Erfc]: unopBatcher(Primitive.Erfc),
3768
+ [Primitive.Sqrt]: unopBatcher(Primitive.Sqrt),
3769
+ [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3770
+ require_backend.assertNonNull(xBdim);
3771
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3772
+ const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3773
+ return [[reduce(x, op, newAxis)], [outBdim]];
3324
3774
  },
3325
- [Primitive.Cos]([x], [dx]) {
3326
- return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
3775
+ [Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
3776
+ x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
3777
+ y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
3778
+ const z = dot$2(x, y);
3779
+ return [[z], [z.ndim - 1]];
3327
3780
  },
3328
- [Primitive.Asin]([x], [dx]) {
3329
- const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
3330
- return [[asin$1(x)], [denom.mul(dx)]];
3781
+ [Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
3782
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3783
+ y = moveBatchAxis(axisSize, yBdim, 0, y);
3784
+ const z = conv$1(x, y, {
3785
+ ...params,
3786
+ vmapDims: params.vmapDims + 1
3787
+ });
3788
+ return [[z], [0]];
3331
3789
  },
3332
- [Primitive.Atan]([x], [dx]) {
3333
- const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
3334
- return [[atan$1(x)], [dx.div(denom)]];
3790
+ [Primitive.Compare]: broadcastBatcher(Primitive.Compare),
3791
+ [Primitive.Where]: broadcastBatcher(Primitive.Where),
3792
+ [Primitive.Concatenate](axisSize, xs, xBdims, { axis }) {
3793
+ const minBdim = Math.min(...xBdims.filter((d) => d !== null));
3794
+ xs = xs.map((x, i) => moveBatchAxis(axisSize, xBdims[i], minBdim, x));
3795
+ const newAxis = axis + (minBdim <= axis ? 1 : 0);
3796
+ return [[concatenate$1(xs, newAxis)], [minBdim]];
3335
3797
  },
3336
- [Primitive.Exp]([x], [dx]) {
3337
- const z = exp$1(x);
3338
- return [[z.ref], [z.mul(dx)]];
3798
+ [Primitive.Split](axisSize, [x], [xBdim], { axis, sizes }) {
3799
+ require_backend.assertNonNull(xBdim);
3800
+ const newAxis = axis + (xBdim <= axis ? 1 : 0);
3801
+ const outs = split$2(x, newAxis, sizes);
3802
+ return [outs, require_backend.rep(outs.length, xBdim)];
3339
3803
  },
3340
- [Primitive.Log]([x], [dx]) {
3341
- return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
3804
+ [Primitive.RandomBits](axisSize, [k0, k1], [bdim0, bdim1], { shape: shape$1, mode }) {
3805
+ k0 = moveBatchAxis(axisSize, bdim0, 0, k0);
3806
+ k1 = moveBatchAxis(axisSize, bdim1, 0, k1);
3807
+ return [[randomBits(k0, k1, [axisSize, ...shape$1], mode)], [0]];
3342
3808
  },
3343
- [Primitive.Erf]([x], [dx]) {
3344
- const coeff = 2 / Math.sqrt(Math.PI);
3345
- const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3346
- return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
3347
- },
3348
- [Primitive.Erfc]([x], [dx]) {
3349
- const coeff = -2 / Math.sqrt(Math.PI);
3350
- const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3351
- return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
3352
- },
3353
- [Primitive.Sqrt]([x], [dx]) {
3354
- const z = sqrt$1(x);
3355
- return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
3356
- },
3357
- [Primitive.Min]([x, y], [dx, dy]) {
3358
- return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
3359
- },
3360
- [Primitive.Max]([x, y], [dx, dy]) {
3361
- return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
3362
- },
3363
- [Primitive.Reduce]([x], [dx], { op, axis }) {
3364
- if (op === require_backend.AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
3365
- else if (op === require_backend.AluOp.Mul) {
3366
- const primal = reduce(x.ref, op, axis);
3367
- const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
3368
- return [[primal], [tangent]];
3369
- } else if (op === require_backend.AluOp.Min || op === require_backend.AluOp.Max) {
3370
- const primal = reduce(x.ref, op, axis);
3371
- const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
3372
- const minCount = where$1(notMin.ref, 0, 1).sum(axis);
3373
- const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
3374
- return [[primal], [tangent]];
3375
- } else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
3376
- },
3377
- [Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
3378
- [Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
3379
- [Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
3380
- [Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
3381
- [Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
3382
- [Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
3383
- dcond.dispose();
3384
- return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
3385
- },
3386
- [Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
3387
- [Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
3388
- [Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
3389
- [Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
3390
- [Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
3391
- [Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
3392
- [Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
3393
- const indicesRef = indices.map((t) => t.ref);
3394
- return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
3395
- },
3396
- [Primitive.JitCall](primals, tangents, { name, jaxpr }) {
3397
- const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
3398
- const outs = bind(Primitive.JitCall, [
3399
- ...newConsts.map((c) => c.ref),
3400
- ...primals,
3401
- ...tangents
3402
- ], {
3403
- name: `${name}_jvp`,
3404
- jaxpr: newJaxpr,
3405
- numConsts: newConsts.length
3406
- });
3407
- const n = outs.length / 2;
3408
- if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
3409
- const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
3410
- return [primalsOut, tangentsOut];
3411
- }
3412
- };
3413
- const jvpJaxprCache = /* @__PURE__ */ new Map();
3414
- function jvpJaxpr(jaxpr) {
3415
- if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
3416
- const inAvals = jaxpr.inBinders.map((v) => v.aval);
3417
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
3418
- const result = {
3419
- newJaxpr,
3420
- newConsts
3421
- };
3422
- jvpJaxprCache.set(jaxpr, result);
3423
- return result;
3424
- }
3425
- function jvpFlat(f, primals, tangents) {
3426
- try {
3427
- var _usingCtx$1 = (0, import_usingCtx$1.default)();
3428
- const main = _usingCtx$1.u(newMain(JVPTrace));
3429
- const trace$1 = new JVPTrace(main);
3430
- const tracersIn = require_backend.zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
3431
- const outs = f(...tracersIn);
3432
- const tracersOut = outs.map((out) => fullRaise(trace$1, out));
3433
- return require_backend.unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
3434
- } catch (_) {
3435
- _usingCtx$1.e = _;
3436
- } finally {
3437
- _usingCtx$1.d();
3438
- }
3439
- }
3440
- function jvp$1(f, primals, tangents) {
3441
- const [primalsFlat, inTree] = flatten(primals);
3442
- const [tangentsFlat, inTree2] = flatten(tangents);
3443
- if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
3444
- const [flatFun, outTree] = flattenFun(f, inTree);
3445
- const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
3446
- if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
3447
- const primalsOut = unflatten(outTree.value, primalsOutFlat);
3448
- const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
3449
- return [primalsOut, tangentsOut];
3450
- }
3451
-
3452
- //#endregion
3453
- //#region src/frontend/vmap.ts
3454
- var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
3455
- function mappedAval(batchDim, aval) {
3456
- const shape$1 = [...aval.shape];
3457
- shape$1.splice(batchDim, 1);
3458
- return new ShapedArray(shape$1, aval.dtype, aval.weakType);
3459
- }
3460
- /** Move one axis to a different index. */
3461
- function moveaxis(x, src, dst) {
3462
- const t = pureArray(x);
3463
- src = require_backend.checkAxis(src, t.ndim);
3464
- dst = require_backend.checkAxis(dst, t.ndim);
3465
- if (src === dst) return t;
3466
- const perm = require_backend.range(t.ndim);
3467
- perm.splice(src, 1);
3468
- perm.splice(dst, 0, src);
3469
- return transpose$1(t, perm);
3470
- }
3471
- function moveBatchAxis(axisSize, src, dst, x) {
3472
- if (src === null) {
3473
- const targetShape = [...x.shape];
3474
- targetShape.splice(dst, 0, axisSize);
3475
- return broadcast(x, targetShape, [dst]);
3476
- } else if (src === dst) return x;
3477
- else return moveaxis(x, src, dst);
3478
- }
3479
- var BatchTracer = class extends Tracer {
3480
- constructor(trace$1, val, batchDim) {
3481
- super(trace$1);
3482
- this.val = val;
3483
- this.batchDim = batchDim;
3484
- }
3485
- get aval() {
3486
- if (this.batchDim === null) return this.val.aval;
3487
- else return mappedAval(this.batchDim, this.val.aval);
3488
- }
3489
- toString() {
3490
- return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
3491
- }
3492
- get ref() {
3493
- this.val.ref;
3494
- return this;
3495
- }
3496
- dispose() {
3497
- this.val.dispose();
3498
- }
3499
- fullLower() {
3500
- if (this.batchDim === null) return this.val.fullLower();
3501
- else return this;
3502
- }
3503
- };
3504
- var BatchTrace = class extends Trace {
3505
- pure(val) {
3506
- return this.lift(pureArray(val));
3507
- }
3508
- lift(val) {
3509
- return new BatchTracer(this, val, null);
3510
- }
3511
- processPrimitive(primitive, tracers, params) {
3512
- const [valsIn, bdimsIn] = require_backend.unzip2(tracers.map((t) => [t.val, t.batchDim]));
3513
- const vmapRule = vmapRules[primitive];
3514
- if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
3515
- if (bdimsIn.every((d) => d === null)) {
3516
- const valOuts$1 = bind(primitive, valsIn, params);
3517
- return valOuts$1.map((x) => new BatchTracer(this, x, null));
3809
+ [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3810
+ if (indicesBdim.every((d) => d === null)) {
3811
+ require_backend.assertNonNull(xBdim);
3812
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3813
+ let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3814
+ let newOutDim = outDim;
3815
+ if (newOutDim < newBdim) newBdim += axis.length;
3816
+ else newOutDim += 1;
3817
+ return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3518
3818
  }
3519
- const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3520
- return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3521
- }
3522
- get axisSize() {
3523
- return this.main.globalData;
3524
- }
3525
- };
3526
- /**
3527
- * Process a primitive with built-in broadcasting.
3528
- *
3529
- * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3530
- */
3531
- function broadcastBatcher(op) {
3532
- return (axisSize, args, dims) => {
3533
- if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3534
- const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3535
- const firstIdx = dims.findIndex((d) => d !== null);
3536
- const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3537
- if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3538
- args = args.map((x, i) => {
3539
- if (dims[i] === null) return x;
3540
- x = moveBatchAxis(axisSize, dims[i], 0, x);
3541
- if (x.ndim < nd) x = x.reshape([
3542
- x.shape[0],
3543
- ...require_backend.rep(nd - x.ndim, 1),
3544
- ...x.shape.slice(1)
3819
+ const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3820
+ indices = indices.map((m, i) => {
3821
+ if (indicesBdim[i] === null) return m;
3822
+ m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3823
+ if (m.ndim < nd) m = m.reshape([
3824
+ m.shape[0],
3825
+ ...require_backend.rep(nd - m.ndim, 1),
3826
+ ...m.shape.slice(1)
3545
3827
  ]);
3546
- return x;
3547
- });
3548
- return [[op(...args)], [0]];
3549
- };
3550
- }
3551
- function unopBatcher(op) {
3552
- return (axisSize, [x], [xBdim], params) => {
3553
- return [[op(x, params)], [xBdim]];
3554
- };
3555
- }
3556
- const vmapRules = {
3557
- [Primitive.Add]: broadcastBatcher(add$1),
3558
- [Primitive.Mul]: broadcastBatcher(mul),
3559
- [Primitive.Idiv]: broadcastBatcher(idiv),
3560
- [Primitive.Mod]: broadcastBatcher(mod),
3561
- [Primitive.Neg]: unopBatcher(neg),
3562
- [Primitive.Reciprocal]: unopBatcher(reciprocal$1),
3563
- [Primitive.Floor]: unopBatcher(floor$1),
3564
- [Primitive.Ceil]: unopBatcher(ceil$1),
3565
- [Primitive.StopGradient]: unopBatcher(stopGradient),
3566
- [Primitive.Cast]: unopBatcher((x, { dtype }) => cast(x, dtype)),
3567
- [Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
3568
- [Primitive.Sin]: unopBatcher(sin$1),
3569
- [Primitive.Cos]: unopBatcher(cos$1),
3570
- [Primitive.Asin]: unopBatcher(asin$1),
3571
- [Primitive.Atan]: unopBatcher(atan$1),
3572
- [Primitive.Exp]: unopBatcher(exp$1),
3573
- [Primitive.Log]: unopBatcher(log$1),
3574
- [Primitive.Erf]: unopBatcher(erf$1),
3575
- [Primitive.Erfc]: unopBatcher(erfc$1),
3576
- [Primitive.Sqrt]: unopBatcher(sqrt$1),
3577
- [Primitive.Min]: broadcastBatcher(min$1),
3578
- [Primitive.Max]: broadcastBatcher(max$1),
3579
- [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3580
- require_backend.assertNonNull(xBdim);
3581
- const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3582
- const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3583
- return [[reduce(x, op, newAxis)], [outBdim]];
3584
- },
3585
- [Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
3586
- x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
3587
- y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
3588
- const z = dot$2(x, y);
3589
- return [[z], [z.ndim - 1]];
3590
- },
3591
- [Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
3592
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3593
- y = moveBatchAxis(axisSize, yBdim, 0, y);
3594
- const z = conv$1(x, y, {
3595
- ...params,
3596
- vmapDims: params.vmapDims + 1
3828
+ return m;
3597
3829
  });
3598
- return [[z], [0]];
3599
- },
3600
- [Primitive.Compare](axisSize, args, dims, { op }) {
3601
- return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3830
+ if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3831
+ else {
3832
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3833
+ const newAxis = [0, ...axis.map((ax) => ax + 1)];
3834
+ const extraBatchIndex = arange(axisSize).reshape([-1, ...require_backend.rep(nd - 1, 1)]);
3835
+ indices.splice(0, 0, extraBatchIndex);
3836
+ return [[gather(x, indices, newAxis, outDim)], [outDim]];
3837
+ }
3602
3838
  },
3603
- [Primitive.Where]: broadcastBatcher(where$1),
3604
3839
  [Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
3605
3840
  require_backend.assertNonNull(xBdim);
3606
3841
  const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
@@ -3632,42 +3867,39 @@ const vmapRules = {
3632
3867
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3633
3868
  return [[pad$1(x, newWidth)], [xBdim]];
3634
3869
  },
3635
- [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3636
- if (indicesBdim.every((d) => d === null)) {
3637
- require_backend.assertNonNull(xBdim);
3638
- const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3639
- let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3640
- let newOutDim = outDim;
3641
- if (newOutDim < newBdim) newBdim += axis.length;
3642
- else newOutDim += 1;
3643
- return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3644
- }
3645
- const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3646
- indices = indices.map((m, i) => {
3647
- if (indicesBdim[i] === null) return m;
3648
- m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3649
- if (m.ndim < nd) m = m.reshape([
3650
- m.shape[0],
3651
- ...require_backend.rep(nd - m.ndim, 1),
3652
- ...m.shape.slice(1)
3870
+ [Primitive.Sort]: lastDimsBatcher(Primitive.Sort, 1),
3871
+ [Primitive.Argsort]: lastDimsBatcher(Primitive.Argsort, 1, 2),
3872
+ [Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
3873
+ if (aBdim === null) {
3874
+ b = moveBatchAxis(axisSize, bBdim, -3, b);
3875
+ const [s, m, n] = b.shape.slice(-3);
3876
+ b = b.reshape([
3877
+ ...b.shape.slice(0, -3),
3878
+ s * m,
3879
+ n
3653
3880
  ]);
3654
- return m;
3655
- });
3656
- if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3657
- else {
3658
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3659
- const newAxis = [0, ...axis.map((ax) => ax + 1)];
3660
- const extraBatchIndex = arange(axisSize).reshape([-1, ...require_backend.rep(nd - 1, 1)]);
3661
- indices.splice(0, 0, extraBatchIndex);
3662
- return [[gather(x, indices, newAxis, outDim)], [outDim]];
3881
+ let x$1 = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
3882
+ x$1 = x$1.reshape([
3883
+ ...b.shape.slice(0, -2),
3884
+ s,
3885
+ m,
3886
+ n
3887
+ ]);
3888
+ return [[x$1], [x$1.ndim - 3]];
3663
3889
  }
3890
+ a = moveBatchAxis(axisSize, aBdim, 0, a);
3891
+ b = moveBatchAxis(axisSize, bBdim, 0, b);
3892
+ const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
3893
+ return [[x], [0]];
3664
3894
  },
3665
- [Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
3666
- const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
3667
- const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
3895
+ [Primitive.Cholesky]: lastDimsBatcher(Primitive.Cholesky, 2),
3896
+ [Primitive.LU]: lastDimsBatcher(Primitive.LU, 2, 3),
3897
+ [Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
3898
+ const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
3899
+ const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
3668
3900
  name: `${name}_vmap`,
3669
- jaxpr: newJaxpr,
3670
- numConsts: newConsts.length
3901
+ jaxpr: newJaxpr.jaxpr,
3902
+ numConsts: newJaxpr.consts.length
3671
3903
  });
3672
3904
  return [outs, require_backend.rep(outs.length, 0)];
3673
3905
  }
@@ -3683,14 +3915,10 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
3683
3915
  shape$1.splice(dims[i], 0, axisSize);
3684
3916
  return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
3685
3917
  });
3686
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3687
- const result = {
3688
- newJaxpr,
3689
- newConsts
3690
- };
3918
+ const { jaxpr: newJaxpr } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3691
3919
  if (!vmapJaxprCache.has(jaxpr)) vmapJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
3692
- vmapJaxprCache.get(jaxpr).set(cacheKey, result);
3693
- return result;
3920
+ vmapJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
3921
+ return newJaxpr;
3694
3922
  }
3695
3923
  function vmapFlat(f, inAxes, args) {
3696
3924
  let axisSize = void 0;
@@ -3704,7 +3932,7 @@ function vmapFlat(f, inAxes, args) {
3704
3932
  if (axisSize === void 0) throw new TypeError("vmap requires at least one mapped axis");
3705
3933
  let valsOut, bdimsOut;
3706
3934
  try {
3707
- var _usingCtx$1 = (0, import_usingCtx.default)();
3935
+ var _usingCtx$1 = (0, import_usingCtx$1.default)();
3708
3936
  const main = _usingCtx$1.u(newMain(BatchTrace, axisSize));
3709
3937
  const trace$1 = new BatchTrace(main);
3710
3938
  const tracersIn = args.map((x, i) => inAxes[i] === null ? pureArray(x) : new BatchTracer(trace$1, pureArray(x), inAxes[i]));
@@ -3736,13 +3964,312 @@ function vmap$1(f, inAxes = 0) {
3736
3964
  return unflatten(outTree.value, outsFlat);
3737
3965
  };
3738
3966
  }
3739
- function jacfwd$1(f) {
3740
- return function jacobianForward(x) {
3741
- if (x.shape.length !== 1) throw new TypeError("jacfwd only supports 1D inputs");
3742
- const [size$1] = x.shape;
3743
- const pushfwd = (v) => jvp$1(f, [x], [v])[1];
3744
- return vmap$1(pushfwd, [0])(eye(size$1, void 0, { dtype: x.dtype }));
3745
- };
3967
+ function jacfwd$1(f) {
3968
+ return function jacobianForward(x) {
3969
+ if (x.shape.length !== 1) throw new TypeError("jacfwd only supports 1D inputs");
3970
+ const [size$1] = x.shape;
3971
+ const pushfwd = (v) => jvp$1(f, [x], [v])[1];
3972
+ return vmap$1(pushfwd, [0])(eye(size$1, void 0, { dtype: x.dtype }));
3973
+ };
3974
+ }
3975
+
3976
+ //#endregion
3977
+ //#region src/frontend/jvp.ts
3978
+ var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
3979
+ var JVPTracer = class extends Tracer {
3980
+ constructor(trace$1, primal, tangent) {
3981
+ super(trace$1);
3982
+ this.primal = primal;
3983
+ this.tangent = tangent;
3984
+ }
3985
+ get aval() {
3986
+ return this.primal.aval;
3987
+ }
3988
+ toString() {
3989
+ return `JVPTracer(${this.primal.toString()}, ${this.tangent.toString()})`;
3990
+ }
3991
+ get ref() {
3992
+ this.primal.ref, this.tangent.ref;
3993
+ return this;
3994
+ }
3995
+ dispose() {
3996
+ this.primal.dispose();
3997
+ this.tangent.dispose();
3998
+ }
3999
+ };
4000
+ var JVPTrace = class extends Trace {
4001
+ pure(val) {
4002
+ return this.lift(pureArray(val));
4003
+ }
4004
+ lift(val) {
4005
+ return new JVPTracer(this, val, zerosLike$1(val.ref));
4006
+ }
4007
+ processPrimitive(primitive, tracers, params) {
4008
+ const [primalsIn, tangentsIn] = require_backend.unzip2(tracers.map((x) => [x.primal, x.tangent]));
4009
+ const jvpRule = jvpRules[primitive];
4010
+ if (jvpRule === void 0) throw new Error(`No JVP rule for: ${primitive}`);
4011
+ const [primalsOut, tangentsOut] = jvpRule(primalsIn, tangentsIn, params);
4012
+ return require_backend.zip(primalsOut, tangentsOut).map(([x, t]) => new JVPTracer(this, x, t));
4013
+ }
4014
+ };
4015
+ /** Rule that applies the same operation to primals and tangents. */
4016
+ function linearTangentsJvp(primitive) {
4017
+ return (primals, tangents, params) => {
4018
+ const ys = bind(primitive, primals, params);
4019
+ const dys = bind(primitive, tangents, params);
4020
+ return [ys, dys];
4021
+ };
4022
+ }
4023
+ /** Rule for product of gradients in bilinear operations. */
4024
+ function bilinearTangentsJvp(primitive) {
4025
+ return ([x, y], [dx, dy], params) => {
4026
+ const primal = bind1(primitive, [x.ref, y.ref], params);
4027
+ const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
4028
+ return [[primal], [tangent]];
4029
+ };
4030
+ }
4031
+ /** Rule that zeros out any tangents. */
4032
+ function zeroTangentsJvp(primitive) {
4033
+ return (primals, tangents, params) => {
4034
+ for (const t of tangents) t.dispose();
4035
+ const ys = bind(primitive, primals, params);
4036
+ return [ys, ys.map((y) => zerosLike$1(y.ref))];
4037
+ };
4038
+ }
4039
+ /** Compute `a @ b.T`, batched to last two axes. */
4040
+ function batchMatmulT(a, b) {
4041
+ return dot$2(a.reshape(a.shape.toSpliced(-1, 0, 1)), b.reshape(b.shape.toSpliced(-2, 0, 1)));
4042
+ }
4043
+ /** Batch matrix transpose. */
4044
+ function mT(a) {
4045
+ return moveaxis(a, -2, -1);
4046
+ }
4047
+ function sliceAxis(a, axis, p) {
4048
+ const slices = Array(a.shape.length).fill([]);
4049
+ slices[require_backend.checkAxis(axis, a.ndim)] = p;
4050
+ return a.slice(...slices);
4051
+ }
4052
+ function padAxis(a, axis, p) {
4053
+ const pads = Array(a.shape.length).fill([0, 0]);
4054
+ pads[require_backend.checkAxis(axis, a.ndim)] = p;
4055
+ return pad$1(a, pads);
4056
+ }
4057
+ const jvpRules = {
4058
+ [Primitive.Add]: linearTangentsJvp(Primitive.Add),
4059
+ [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
4060
+ [Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
4061
+ [Primitive.Mod]([x, y], [dx, dy]) {
4062
+ if (!require_backend.isFloatDtype(x.dtype) && !require_backend.isFloatDtype(y.dtype)) {
4063
+ dx.dispose();
4064
+ dy.dispose();
4065
+ return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
4066
+ }
4067
+ const q = idiv(x.ref, y.ref);
4068
+ return [[mod(x, y)], [dx.sub(dy.mul(q))]];
4069
+ },
4070
+ [Primitive.Min]([x, y], [dx, dy]) {
4071
+ return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
4072
+ },
4073
+ [Primitive.Max]([x, y], [dx, dy]) {
4074
+ return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
4075
+ },
4076
+ [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
4077
+ [Primitive.Reciprocal]([x], [dx]) {
4078
+ const xRecip = reciprocal$1(x.ref);
4079
+ return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
4080
+ },
4081
+ [Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
4082
+ [Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
4083
+ [Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
4084
+ [Primitive.Cast]([x], [dx], { dtype }) {
4085
+ if (x.dtype === dtype) return [[x], [dx]];
4086
+ if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
4087
+ else {
4088
+ dx.dispose();
4089
+ return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
4090
+ }
4091
+ },
4092
+ [Primitive.Bitcast]([x], [dx], { dtype }) {
4093
+ if (x.dtype === dtype) return [[x], [dx]];
4094
+ dx.dispose();
4095
+ return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
4096
+ },
4097
+ [Primitive.Sin]([x], [dx]) {
4098
+ return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
4099
+ },
4100
+ [Primitive.Cos]([x], [dx]) {
4101
+ return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
4102
+ },
4103
+ [Primitive.Asin]([x], [dx]) {
4104
+ const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
4105
+ return [[asin$1(x)], [denom.mul(dx)]];
4106
+ },
4107
+ [Primitive.Atan]([x], [dx]) {
4108
+ const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
4109
+ return [[atan$1(x)], [dx.div(denom)]];
4110
+ },
4111
+ [Primitive.Exp]([x], [dx]) {
4112
+ const z = exp$1(x);
4113
+ return [[z.ref], [z.mul(dx)]];
4114
+ },
4115
+ [Primitive.Log]([x], [dx]) {
4116
+ return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
4117
+ },
4118
+ [Primitive.Erf]([x], [dx]) {
4119
+ const coeff = 2 / Math.sqrt(Math.PI);
4120
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
4121
+ return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
4122
+ },
4123
+ [Primitive.Erfc]([x], [dx]) {
4124
+ const coeff = -2 / Math.sqrt(Math.PI);
4125
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
4126
+ return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
4127
+ },
4128
+ [Primitive.Sqrt]([x], [dx]) {
4129
+ const z = sqrt$1(x);
4130
+ return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
4131
+ },
4132
+ [Primitive.Reduce]([x], [dx], { op, axis }) {
4133
+ if (op === require_backend.AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
4134
+ else if (op === require_backend.AluOp.Mul) {
4135
+ const primal = reduce(x.ref, op, axis);
4136
+ const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
4137
+ return [[primal], [tangent]];
4138
+ } else if (op === require_backend.AluOp.Min || op === require_backend.AluOp.Max) {
4139
+ const primal = reduce(x.ref, op, axis);
4140
+ const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
4141
+ const minCount = where$1(notMin.ref, 0, 1).sum(axis);
4142
+ const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
4143
+ return [[primal], [tangent]];
4144
+ } else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
4145
+ },
4146
+ [Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
4147
+ [Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
4148
+ [Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
4149
+ [Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
4150
+ [Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
4151
+ [Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
4152
+ dcond.dispose();
4153
+ return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
4154
+ },
4155
+ [Primitive.Concatenate]: linearTangentsJvp(Primitive.Concatenate),
4156
+ [Primitive.Split]: linearTangentsJvp(Primitive.Split),
4157
+ [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
4158
+ [Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
4159
+ const indicesRef = indices.map((t) => t.ref);
4160
+ return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
4161
+ },
4162
+ [Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
4163
+ [Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
4164
+ [Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
4165
+ [Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
4166
+ [Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
4167
+ [Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
4168
+ [Primitive.Sort]([x], [dx]) {
4169
+ const [y, idx] = argsort$1(x);
4170
+ return [[y], [gather(dx, [idx], [-1], -1)]];
4171
+ },
4172
+ [Primitive.Argsort]([x], [dx]) {
4173
+ const [y, idx] = argsort$1(x);
4174
+ return [[y, idx.ref], [gather(dx, [idx.ref], [-1], -1), zerosLike$1(idx)]];
4175
+ },
4176
+ [Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
4177
+ const x = triangularSolve$1(a.ref, b, { unitDiagonal });
4178
+ const dax = batchMatmulT(da, x.ref);
4179
+ const rhsT = db.sub(mT(dax));
4180
+ const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
4181
+ return [[x], [dx]];
4182
+ },
4183
+ [Primitive.Cholesky]([a], [da]) {
4184
+ const L = cholesky$2(a.ref);
4185
+ da = da.ref.add(mT(da)).mul(.5);
4186
+ const W = triangularSolve$1(L.ref, da, { lower: true });
4187
+ const ST = triangularSolve$1(L.ref, mT(W), { lower: true });
4188
+ const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
4189
+ return [[L], [dL]];
4190
+ },
4191
+ [Primitive.LU]([a], [da]) {
4192
+ const [luMatrix, pivots, permutation] = lu$1(a);
4193
+ const [m, n] = a.shape.slice(-2);
4194
+ const k = Math.min(m, n);
4195
+ const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
4196
+ const lLower = tril(luSliceL, -1);
4197
+ const lPadded = m > k ? padAxis(lLower, -1, [0, m - k]) : lLower;
4198
+ const L = lPadded.add(eye(m));
4199
+ const luSliceU = sliceAxis(luMatrix.ref, -2, [0, k]);
4200
+ const uUpper = triu(luSliceU);
4201
+ const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
4202
+ const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
4203
+ const U = uPadded.add(uEye);
4204
+ const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
4205
+ const pda = batchMatmulT(P, mT(da));
4206
+ const la = mT(triangularSolve$1(L.ref, mT(pda), {
4207
+ lower: true,
4208
+ unitDiagonal: true
4209
+ }));
4210
+ const lau = triangularSolve$1(mT(U.ref), la, { lower: true });
4211
+ const lDot = batchMatmulT(L, mT(tril(lau.ref, -1)));
4212
+ const uDot = batchMatmulT(triu(lau), mT(U));
4213
+ return [[
4214
+ luMatrix,
4215
+ pivots,
4216
+ permutation
4217
+ ], [
4218
+ lDot.add(uDot),
4219
+ zerosLike$1(pivots.ref),
4220
+ zerosLike$1(permutation.ref)
4221
+ ]];
4222
+ },
4223
+ [Primitive.Jit](primals, tangents, { name, jaxpr }) {
4224
+ const newJaxpr = jvpJaxpr(jaxpr);
4225
+ const outs = bind(Primitive.Jit, [
4226
+ ...newJaxpr.consts.map((c) => c.ref),
4227
+ ...primals,
4228
+ ...tangents
4229
+ ], {
4230
+ name: `${name}_jvp`,
4231
+ jaxpr: newJaxpr.jaxpr,
4232
+ numConsts: newJaxpr.consts.length
4233
+ });
4234
+ const n = outs.length / 2;
4235
+ if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
4236
+ const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
4237
+ return [primalsOut, tangentsOut];
4238
+ }
4239
+ };
4240
+ const jvpJaxprCache = /* @__PURE__ */ new Map();
4241
+ function jvpJaxpr(jaxpr) {
4242
+ if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
4243
+ const inAvals = jaxpr.inBinders.map((v) => v.aval);
4244
+ const { jaxpr: newJaxpr } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
4245
+ jvpJaxprCache.set(jaxpr, newJaxpr);
4246
+ return newJaxpr;
4247
+ }
4248
+ function jvpFlat(f, primals, tangents) {
4249
+ try {
4250
+ var _usingCtx$1 = (0, import_usingCtx.default)();
4251
+ const main = _usingCtx$1.u(newMain(JVPTrace));
4252
+ const trace$1 = new JVPTrace(main);
4253
+ const tracersIn = require_backend.zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
4254
+ const outs = f(...tracersIn);
4255
+ const tracersOut = outs.map((out) => fullRaise(trace$1, out));
4256
+ return require_backend.unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
4257
+ } catch (_) {
4258
+ _usingCtx$1.e = _;
4259
+ } finally {
4260
+ _usingCtx$1.d();
4261
+ }
4262
+ }
4263
+ function jvp$1(f, primals, tangents) {
4264
+ const [primalsFlat, inTree] = flatten(primals);
4265
+ const [tangentsFlat, inTree2] = flatten(tangents);
4266
+ if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
4267
+ const [flatFun, outTree] = flattenFun(f, inTree);
4268
+ const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
4269
+ if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
4270
+ const primalsOut = unflatten(outTree.value, primalsOutFlat);
4271
+ const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
4272
+ return [primalsOut, tangentsOut];
3746
4273
  }
3747
4274
 
3748
4275
  //#endregion
@@ -3775,11 +4302,10 @@ function partialEvalFlat(f, pvalsIn) {
3775
4302
  const tracersOut = outs.map((out) => fullRaise(trace$1, out));
3776
4303
  const pvalsOut = tracersOut.map((t) => t.pval);
3777
4304
  const unknownTracersOut = tracersOut.filter((t) => !t.pval.isKnown);
3778
- const { jaxpr, consts } = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
4305
+ const jaxpr = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
3779
4306
  return {
3780
4307
  jaxpr,
3781
- pvalsOut,
3782
- consts
4308
+ pvalsOut
3783
4309
  };
3784
4310
  }
3785
4311
  /**
@@ -3796,22 +4322,19 @@ function linearizeFlatUtil(f, primalsIn) {
3796
4322
  const [primalsOut$1, tangentsOut] = jvp$1(f, x.slice(0, k), x.slice(k, 2 * k));
3797
4323
  return [...primalsOut$1, ...tangentsOut];
3798
4324
  };
3799
- const { jaxpr, pvalsOut, consts } = partialEvalFlat(fJvp, pvalsIn);
4325
+ const { jaxpr, pvalsOut } = partialEvalFlat(fJvp, pvalsIn);
3800
4326
  const primalPvals = pvalsOut.slice(0, pvalsOut.length / 2);
3801
4327
  if (!primalPvals.every((pval) => pval.isKnown)) throw new Error("Not all primal values are known after partial evaluation");
3802
4328
  const primalsOut = primalPvals.map((pval) => pval.val);
3803
4329
  return {
3804
4330
  primalsOut,
3805
- jaxpr,
3806
- consts
4331
+ jaxpr
3807
4332
  };
3808
4333
  }
3809
4334
  function linearizeFlat(f, primalsIn) {
3810
- const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
3811
- const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
3812
- const dispose$1 = () => {
3813
- for (const c of consts) c.dispose();
3814
- };
4335
+ const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
4336
+ const fLin = (...tangents) => evalJaxpr(jaxpr.jaxpr, [...jaxpr.consts.map((c) => c.ref), ...tangents]);
4337
+ const dispose$1 = () => jaxpr.dispose();
3815
4338
  return [
3816
4339
  primalsOut,
3817
4340
  fLin,
@@ -3895,7 +4418,7 @@ var PartialEvalTrace = class extends Trace {
3895
4418
  }
3896
4419
  processPrimitive(primitive, tracers, params) {
3897
4420
  if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
3898
- if (primitive === Primitive.JitCall) {
4421
+ if (primitive === Primitive.Jit) {
3899
4422
  const { name, jaxpr, numConsts } = params;
3900
4423
  return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
3901
4424
  }
@@ -3921,14 +4444,14 @@ var PartialEvalTrace = class extends Trace {
3921
4444
  * Evaluate a Jaxpr on a set of PartialEvalTracers, computing as many known
3922
4445
  * values as possible (with JIT) and forwarding the unknown ones.
3923
4446
  *
3924
- * Used when encountering a JitCall rule during the trace.
4447
+ * Used when encountering a Jit rule during the trace.
3925
4448
  */
3926
4449
  #partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
3927
4450
  jaxpr = jaxpr.flatten();
3928
4451
  const inUnknowns = tracers.map((t) => !t.pval.isKnown);
3929
4452
  const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
3930
4453
  const [knownTracers, unknownTracers] = require_backend.partitionList(inUnknowns, tracers);
3931
- const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
4454
+ const outs1Res = bind(Primitive.Jit, knownTracers.map((t) => t.ref.fullLower()), {
3932
4455
  name: `${name}_peval`,
3933
4456
  jaxpr: jaxpr1,
3934
4457
  numConsts: 0
@@ -3938,7 +4461,7 @@ var PartialEvalTrace = class extends Trace {
3938
4461
  const resTracers = res.map((x) => this.instantiateConst(fullRaise(this, x)));
3939
4462
  const recipe = {
3940
4463
  type: "JaxprEqn",
3941
- prim: Primitive.JitCall,
4464
+ prim: Primitive.Jit,
3942
4465
  tracersIn: resTracers.concat(unknownTracers),
3943
4466
  params: {
3944
4467
  name: `${name}_resid`,
@@ -3967,7 +4490,7 @@ function partialEvalJaxpr(jaxpr, inUnknowns, instantiate) {
3967
4490
  const eqns1 = [];
3968
4491
  const eqns2 = [];
3969
4492
  for (const eqn of jaxpr.eqns) {
3970
- if (eqn.primitive === Primitive.JitCall) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
4493
+ if (eqn.primitive === Primitive.Jit) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
3971
4494
  const hasUnknowns = eqn.inputs.some((x) => x instanceof Var && !knownVars.has(x));
3972
4495
  if (hasUnknowns) {
3973
4496
  for (const x of eqn.inputs) if (x instanceof Var && knownVars.has(x)) residuals.add(x);
@@ -4042,10 +4565,7 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
4042
4565
  for (const t of tracersOut) t.dispose();
4043
4566
  jaxpr = jaxpr.simplify();
4044
4567
  if (require_backend.DEBUG >= 5) console.info("jaxpr from partial evaluation:\n" + jaxpr.toString());
4045
- return {
4046
- jaxpr,
4047
- consts
4048
- };
4568
+ return new ClosedJaxpr(jaxpr, consts);
4049
4569
  }
4050
4570
  /** Marker type for pullback, used by transpose rules. */
4051
4571
  var UndefPrimal = class {
@@ -4237,317 +4757,151 @@ const transposeRules = {
4237
4757
  cond.dispose();
4238
4758
  return cts;
4239
4759
  },
4240
- [Primitive.Transpose]([ct], [x], { perm }) {
4241
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
4242
- return [transpose$1(ct, require_backend.invertPermutation(perm))];
4243
- },
4244
- [Primitive.Broadcast]([ct], [x], { axis }) {
4245
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Broadcast);
4246
- return [reduce(ct, require_backend.AluOp.Add, axis)];
4247
- },
4248
- [Primitive.Reshape]([ct], [x], _) {
4249
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Reshape);
4250
- return [reshape$1(ct, x.aval.shape)];
4251
- },
4252
- [Primitive.Flip]([ct], [x], { axis }) {
4253
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Flip);
4254
- return [flip$1(ct, axis)];
4255
- },
4256
- [Primitive.Shrink]([ct], [x], { slice }) {
4257
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Shrink);
4258
- const width = slice.map(([s, e$1], i) => [s, x.aval.shape[i] - e$1]);
4259
- return [pad$1(ct, width)];
4760
+ [Primitive.Concatenate]([ct], inputs, { axis }) {
4761
+ if (inputs.some((x) => !(x instanceof UndefPrimal))) throw new NonlinearError(Primitive.Concatenate);
4762
+ const sizes = inputs.map((x) => x.aval.shape[axis]);
4763
+ return split$2(ct, axis, sizes);
4260
4764
  },
4261
- [Primitive.Pad]([ct], [x], { width }) {
4262
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pad);
4263
- const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
4264
- return [shrink(ct, slice)];
4765
+ [Primitive.Split](cts, [x], { axis }) {
4766
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Split);
4767
+ return [concatenate$1(cts, axis)];
4265
4768
  },
4266
4769
  [Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
4267
4770
  if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4268
4771
  if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4269
- throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
4270
- },
4271
- [Primitive.JitCall](cts, args, { name, jaxpr }) {
4272
- const undefPrimals = args.map((x) => x instanceof UndefPrimal);
4273
- const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
4274
- const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
4275
- const outs = bind(Primitive.JitCall, [
4276
- ...newConsts.map((c) => c.ref),
4277
- ...residuals,
4278
- ...cts
4279
- ], {
4280
- name: `${name}_t`,
4281
- jaxpr: newJaxpr,
4282
- numConsts: newConsts.length
4283
- });
4284
- let i = 0;
4285
- return undefPrimals.map((isUndef) => isUndef ? outs[i++] : null);
4286
- }
4287
- };
4288
- const transposeJaxprCache = /* @__PURE__ */ new Map();
4289
- function transposeJaxpr(jaxpr, undefPrimals) {
4290
- const cacheKey = JSON.stringify(undefPrimals);
4291
- const prevResult = transposeJaxprCache.get(jaxpr)?.get(cacheKey);
4292
- if (prevResult) return prevResult;
4293
- const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
4294
- const forwardInTypes = inTypes.filter((_, i) => !undefPrimals[i]);
4295
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((forwardIn, cotangents) => {
4296
- const args = [];
4297
- let forwardInIdx = 0;
4298
- for (let i = 0; i < undefPrimals.length; i++) if (undefPrimals[i]) args.push(new UndefPrimal(inTypes[i]));
4299
- else args.push(forwardIn[forwardInIdx++]);
4300
- return evalJaxprTransposed(jaxpr, args, cotangents);
4301
- })(forwardInTypes, outTypes);
4302
- typecheckJaxpr(newJaxpr);
4303
- const result = {
4304
- newJaxpr,
4305
- newConsts
4306
- };
4307
- if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
4308
- transposeJaxprCache.get(jaxpr).set(cacheKey, result);
4309
- return result;
4310
- }
4311
- function vjpFlat(f, primalsIn) {
4312
- const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
4313
- const fVjp = (...cotangents) => {
4314
- const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
4315
- return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
4316
- };
4317
- const dispose$1 = () => {
4318
- for (const c of consts) c.dispose();
4319
- };
4320
- return [
4321
- primalsOut,
4322
- fVjp,
4323
- dispose$1
4324
- ];
4325
- }
4326
- function vjp$1(f, ...primalsIn) {
4327
- const [primalsInFlat, inTree] = flatten(primalsIn);
4328
- const [fFlat, outTree] = flattenFun(f, inTree);
4329
- const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
4330
- if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
4331
- const primalsOut = unflatten(outTree.value, primalsOutFlat);
4332
- const fVjp = ((cotangentsOut) => {
4333
- const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
4334
- if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
4335
- const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
4336
- return unflatten(inTree, cotangentsInFlat);
4337
- });
4338
- fVjp.dispose = dispose$1;
4339
- return [primalsOut, fVjp];
4340
- }
4341
- function grad$1(f) {
4342
- const valueAndGradFn = valueAndGrad$1(f);
4343
- return (...x) => {
4344
- const [y, dx] = valueAndGradFn(...x);
4345
- y.dispose();
4346
- return dx;
4347
- };
4348
- }
4349
- function valueAndGrad$1(f) {
4350
- return (...x) => {
4351
- if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
4352
- const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
4353
- if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
4354
- if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
4355
- const [ct, ...rest] = fVjp(onesLike$1(y.ref));
4356
- for (const r of rest) dispose(r);
4357
- fVjp.dispose();
4358
- return [y, ct];
4359
- };
4360
- }
4361
- function jacrev$1(f) {
4362
- return function jacobianReverse(x) {
4363
- if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
4364
- const [size$1] = x.shape;
4365
- const pullback = (ct) => {
4366
- const [y, fVjp] = vjp$1(f, x);
4367
- y.dispose();
4368
- const [ret] = fVjp(ct);
4369
- fVjp.dispose();
4370
- return ret;
4371
- };
4372
- return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
4373
- };
4374
- }
4375
-
4376
- //#endregion
4377
- //#region src/library/lax.ts
4378
- var lax_exports = {};
4379
- __export(lax_exports, {
4380
- conv: () => conv,
4381
- convGeneralDilated: () => convGeneralDilated,
4382
- convWithGeneralPadding: () => convWithGeneralPadding,
4383
- dot: () => dot$1,
4384
- erf: () => erf,
4385
- erfc: () => erfc,
4386
- reduceWindow: () => reduceWindow,
4387
- stopGradient: () => stopGradient$1
4388
- });
4389
- /**
4390
- * General dot product/contraction operator.
4391
- *
4392
- * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
4393
- * `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
4394
- */
4395
- function dot$1(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
4396
- if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
4397
- else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
4398
- lc = lc.map((a) => require_backend.checkAxis(a, lhs.ndim));
4399
- rc = rc.map((a) => require_backend.checkAxis(a, rhs.ndim));
4400
- lb = lb.map((a) => require_backend.checkAxis(a, lhs.ndim));
4401
- rb = rb.map((a) => require_backend.checkAxis(a, rhs.ndim));
4402
- if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
4403
- else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
4404
- const lf = require_backend.range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
4405
- const rf = require_backend.range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
4406
- const lhs2 = lhs.transpose([
4407
- ...lb,
4408
- ...lf,
4409
- ...lc
4410
- ]);
4411
- const rhs2 = rhs.transpose([
4412
- ...rb,
4413
- ...rf,
4414
- ...rc
4415
- ]);
4416
- if (lc.length === 0) return mul(lhs2.reshape([
4417
- ...lb.map((a) => lhs.shape[a]),
4418
- ...lf.map((a) => lhs.shape[a]),
4419
- ...require_backend.rep(rf.length, 1)
4420
- ]), rhs2.reshape([
4421
- ...rb.map((a) => rhs.shape[a]),
4422
- ...require_backend.rep(lf.length, 1),
4423
- ...rf.map((a) => rhs.shape[a])
4424
- ]));
4425
- const dotShapeX = lc.map((a) => lhs.shape[a]);
4426
- const dotShapeY = rc.map((a) => rhs.shape[a]);
4427
- if (!require_backend.deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
4428
- return dot$2(lhs2.reshape([
4429
- ...lb.map((a) => lhs.shape[a]),
4430
- ...lf.map((a) => lhs.shape[a]),
4431
- ...require_backend.rep(rf.length, 1),
4432
- require_backend.prod(dotShapeX)
4433
- ]), rhs2.reshape([
4434
- ...rb.map((a) => rhs.shape[a]),
4435
- ...require_backend.rep(lf.length, 1),
4436
- ...rf.map((a) => rhs.shape[a]),
4437
- require_backend.prod(dotShapeY)
4438
- ]));
4439
- }
4440
- function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
4441
- const padType = padding.toUpperCase();
4442
- switch (padType) {
4443
- case "VALID": return require_backend.rep(inShape.length, [0, 0]);
4444
- case "SAME":
4445
- case "SAME_LOWER": {
4446
- const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
4447
- const padSizes = require_backend.zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
4448
- if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
4449
- else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
4450
- }
4451
- default: throw new Error(`Unknown padding type: ${padType}`);
4452
- }
4453
- }
4454
- /**
4455
- * General n-dimensional convolution operator, with optional dilation.
4456
- *
4457
- * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
4458
- * function in JAX, which wraps XLA's general convolution operator.
4459
- *
4460
- * Grouped convolutions are not supported right now.
4461
- */
4462
- function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
4463
- if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
4464
- if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
4465
- if (typeof padding === "string") {
4466
- if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
4467
- padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1), padding);
4468
- }
4469
- if (featureGroupCount !== 1) {
4470
- const G = featureGroupCount;
4471
- const [N, C_in, ...xs] = lhs.shape;
4472
- const [C_out, C_in_per_group, ...ks] = rhs.shape;
4473
- if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
4474
- if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
4475
- if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
4476
- const lhsGrouped = moveaxis(lhs.reshape([
4477
- N,
4478
- G,
4479
- C_in / G,
4480
- ...xs
4481
- ]), 1, 0);
4482
- const rhsGrouped = rhs.reshape([
4483
- G,
4484
- C_out / G,
4485
- C_in_per_group,
4486
- ...ks
4487
- ]);
4488
- const result = conv$1(lhsGrouped, rhsGrouped, {
4489
- vmapDims: 1,
4490
- strides: windowStrides,
4491
- padding,
4492
- lhsDilation,
4493
- rhsDilation
4772
+ throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
4773
+ },
4774
+ [Primitive.Transpose]([ct], [x], { perm }) {
4775
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
4776
+ return [transpose$1(ct, require_backend.invertPermutation(perm))];
4777
+ },
4778
+ [Primitive.Broadcast]([ct], [x], { axis }) {
4779
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Broadcast);
4780
+ return [reduce(ct, require_backend.AluOp.Add, axis)];
4781
+ },
4782
+ [Primitive.Reshape]([ct], [x], _) {
4783
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Reshape);
4784
+ return [reshape$1(ct, x.aval.shape)];
4785
+ },
4786
+ [Primitive.Flip]([ct], [x], { axis }) {
4787
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Flip);
4788
+ return [flip$1(ct, axis)];
4789
+ },
4790
+ [Primitive.Shrink]([ct], [x], { slice }) {
4791
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Shrink);
4792
+ const width = slice.map(([s, e$1], i) => [s, x.aval.shape[i] - e$1]);
4793
+ return [pad$1(ct, width)];
4794
+ },
4795
+ [Primitive.Pad]([ct], [x], { width }) {
4796
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pad);
4797
+ const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
4798
+ return [shrink(ct, slice)];
4799
+ },
4800
+ [Primitive.TriangularSolve]([ct], [a, b], { unitDiagonal }) {
4801
+ if (a instanceof UndefPrimal || !(b instanceof UndefPrimal)) throw new NonlinearError(Primitive.TriangularSolve);
4802
+ const ctB = triangularSolve$1(moveaxis(a, -2, -1), ct, {
4803
+ lower: true,
4804
+ unitDiagonal
4494
4805
  });
4495
- const ys = result.shape.slice(3);
4496
- return moveaxis(result, 0, 1).reshape([
4497
- N,
4498
- C_out,
4499
- ...ys
4500
- ]);
4806
+ return [null, ctB];
4807
+ },
4808
+ [Primitive.Jit](cts, args, { name, jaxpr }) {
4809
+ const undefPrimals = args.map((x) => x instanceof UndefPrimal);
4810
+ const newJaxpr = transposeJaxpr(jaxpr, undefPrimals);
4811
+ const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
4812
+ const outs = bind(Primitive.Jit, [
4813
+ ...newJaxpr.consts.map((c) => c.ref),
4814
+ ...residuals,
4815
+ ...cts
4816
+ ], {
4817
+ name: `${name}_t`,
4818
+ jaxpr: newJaxpr.jaxpr,
4819
+ numConsts: newJaxpr.consts.length
4820
+ });
4821
+ let i = 0;
4822
+ return undefPrimals.map((isUndef) => isUndef ? outs[i++] : null);
4501
4823
  }
4502
- return conv$1(lhs, rhs, {
4503
- strides: windowStrides,
4504
- padding,
4505
- lhsDilation,
4506
- rhsDilation
4507
- });
4508
- }
4509
- /** Convenience wrapper around `convGeneralDilated`. */
4510
- function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
4511
- return convGeneralDilated(lhs, rhs, windowStrides, padding, {
4512
- lhsDilation,
4513
- rhsDilation
4514
- });
4824
+ };
4825
+ const transposeJaxprCache = /* @__PURE__ */ new Map();
4826
+ function transposeJaxpr(jaxpr, undefPrimals) {
4827
+ const cacheKey = JSON.stringify(undefPrimals);
4828
+ const prevResult = transposeJaxprCache.get(jaxpr)?.get(cacheKey);
4829
+ if (prevResult) return prevResult;
4830
+ const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
4831
+ const forwardInTypes = inTypes.filter((_, i) => !undefPrimals[i]);
4832
+ const { jaxpr: newJaxpr } = makeJaxpr$1((forwardIn, cotangents) => {
4833
+ const args = [];
4834
+ let forwardInIdx = 0;
4835
+ for (let i = 0; i < undefPrimals.length; i++) if (undefPrimals[i]) args.push(new UndefPrimal(inTypes[i]));
4836
+ else args.push(forwardIn[forwardInIdx++]);
4837
+ return evalJaxprTransposed(jaxpr, args, cotangents);
4838
+ })(forwardInTypes, outTypes);
4839
+ typecheckJaxpr(newJaxpr.jaxpr);
4840
+ if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
4841
+ transposeJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
4842
+ return newJaxpr;
4515
4843
  }
4516
- /** Convenience wrapper around `convGeneralDilated`. */
4517
- function conv(lhs, rhs, windowStrides, padding) {
4518
- return convGeneralDilated(lhs, rhs, windowStrides, padding);
4844
+ function vjpFlat(f, primalsIn) {
4845
+ const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
4846
+ const fVjp = (...cotangents) => {
4847
+ const transposeInputs = [...jaxpr.consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
4848
+ return evalJaxprTransposed(jaxpr.jaxpr, transposeInputs, cotangents);
4849
+ };
4850
+ const dispose$1 = () => jaxpr.dispose();
4851
+ return [
4852
+ primalsOut,
4853
+ fVjp,
4854
+ dispose$1
4855
+ ];
4519
4856
  }
4520
- /** Reduce a computation over padded windows. */
4521
- function reduceWindow(operand, computation, windowDimensions, windowStrides) {
4522
- if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
4523
- if (!windowStrides) windowStrides = require_backend.rep(windowDimensions.length, 1);
4524
- for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
4525
- return computation(bind1(Primitive.Pool, [operand], {
4526
- window: windowDimensions,
4527
- strides: windowStrides
4528
- }));
4857
+ function vjp$1(f, ...primalsIn) {
4858
+ const [primalsInFlat, inTree] = flatten(primalsIn);
4859
+ const [fFlat, outTree] = flattenFun(f, inTree);
4860
+ const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
4861
+ if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
4862
+ const primalsOut = unflatten(outTree.value, primalsOutFlat);
4863
+ const fVjp = ((cotangentsOut) => {
4864
+ const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
4865
+ if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
4866
+ const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
4867
+ return unflatten(inTree, cotangentsInFlat);
4868
+ });
4869
+ fVjp.dispose = dispose$1;
4870
+ return [primalsOut, fVjp];
4529
4871
  }
4530
- /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
4531
- function erf(x) {
4532
- return erf$1(x);
4872
+ function grad$1(f) {
4873
+ const valueAndGradFn = valueAndGrad$1(f);
4874
+ return (...x) => {
4875
+ const [y, dx] = valueAndGradFn(...x);
4876
+ y.dispose();
4877
+ return dx;
4878
+ };
4533
4879
  }
4534
- /**
4535
- * The complementary error function: `erfc(x) = 1 - erf(x)`.
4536
- *
4537
- * This function is more accurate than `1 - erf(x)` for large values of `x`,
4538
- * where `erf(x)` is very close to 1.
4539
- */
4540
- function erfc(x) {
4541
- return erfc$1(x);
4880
+ function valueAndGrad$1(f) {
4881
+ return (...x) => {
4882
+ if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
4883
+ const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
4884
+ if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
4885
+ if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
4886
+ const [ct, ...rest] = fVjp(onesLike$1(y.ref));
4887
+ for (const r of rest) dispose(r);
4888
+ fVjp.dispose();
4889
+ return [y, ct];
4890
+ };
4542
4891
  }
4543
- /**
4544
- * Stops gradient computation.
4545
- *
4546
- * Behaves as the identity function but prevents the flow of gradients during
4547
- * forward or reverse-mode automatic differentiation.
4548
- */
4549
- function stopGradient$1(x) {
4550
- return stopGradient(x);
4892
+ function jacrev$1(f) {
4893
+ return function jacobianReverse(x) {
4894
+ if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
4895
+ const [size$1] = x.shape;
4896
+ const pullback = (ct) => {
4897
+ const [y, fVjp] = vjp$1(f, x);
4898
+ y.dispose();
4899
+ const [ret] = fVjp(ct);
4900
+ fVjp.dispose();
4901
+ return ret;
4902
+ };
4903
+ return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
4904
+ };
4551
4905
  }
4552
4906
 
4553
4907
  //#endregion
@@ -4687,8 +5041,8 @@ function computeSizeMap({ shapes, lhsIndices, rhsIndex }) {
4687
5041
  const idx = lhsIndex[j];
4688
5042
  const dim = shape$1[j];
4689
5043
  const existing = sizeMap.get(idx);
4690
- if (existing === void 0) sizeMap.set(idx, dim);
4691
- else if (existing !== dim) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
5044
+ if (existing === void 0 || existing === 1) sizeMap.set(idx, dim);
5045
+ else if (existing !== dim && dim !== 1) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
4692
5046
  }
4693
5047
  }
4694
5048
  for (const [idx, size$1] of sizeMap) if (!Number.isInteger(idx) || idx < 0) throw new Error(`Invalid index ${idx} in einsum expression, must be non-negative integer`);
@@ -4696,52 +5050,410 @@ function computeSizeMap({ shapes, lhsIndices, rhsIndex }) {
4696
5050
  for (const idx of rhsIndex) if (!sizeMap.has(idx)) throw new Error(`Output index ${idx} not present in einsum inputs`);
4697
5051
  return sizeMap;
4698
5052
  }
4699
- const einsumPathCache = /* @__PURE__ */ new Map();
4700
- function computeEinsumPath(input, method) {
4701
- if (!method) method = input.shapes.length <= 5 ? "optimal" : "naive";
4702
- return require_backend.runWithCache(einsumPathCache, [input, method], () => {
4703
- const sizeMap = computeSizeMap(input);
4704
- if (input.shapes.length === 1) return new EinsumPath(input, sizeMap, []);
4705
- switch (method) {
4706
- case "naive": return computePathNaive(input, sizeMap);
4707
- case "optimal": return computePathOptimal(input, sizeMap);
4708
- default: throw new Error(`Unknown computePath method: ${method}`);
4709
- }
4710
- });
5053
+ const einsumPathCache = /* @__PURE__ */ new Map();
5054
+ function computeEinsumPath(input, method) {
5055
+ if (!method) method = input.shapes.length <= 5 ? "optimal" : "naive";
5056
+ return require_backend.runWithCache(einsumPathCache, [input, method], () => {
5057
+ const sizeMap = computeSizeMap(input);
5058
+ if (input.shapes.length === 1) return new EinsumPath(input, sizeMap, []);
5059
+ switch (method) {
5060
+ case "naive": return computePathNaive(input, sizeMap);
5061
+ case "optimal": return computePathOptimal(input, sizeMap);
5062
+ default: throw new Error(`Unknown computePath method: ${method}`);
5063
+ }
5064
+ });
5065
+ }
5066
+ function computePathNaive(input, sizeMap) {
5067
+ const n = input.shapes.length;
5068
+ const path = [];
5069
+ let lastTensorIndex = 0;
5070
+ for (let i = 1; i < n; i++) {
5071
+ path.push([lastTensorIndex, i]);
5072
+ lastTensorIndex = n + i - 1;
5073
+ }
5074
+ return new EinsumPath(input, sizeMap, path);
5075
+ }
5076
+ function computePathOptimal(input, sizeMap) {
5077
+ const n = input.shapes.length;
5078
+ let bestPath = null;
5079
+ let bestFlops = null;
5080
+ for (const path of allPaths(require_backend.range(n), n)) {
5081
+ const flops = approximatePathFlops(input, sizeMap, path);
5082
+ if (bestFlops === null || flops < bestFlops) {
5083
+ bestPath = path;
5084
+ bestFlops = flops;
5085
+ }
5086
+ }
5087
+ return new EinsumPath(input, sizeMap, bestPath);
5088
+ }
5089
+ function* allPaths(tensors, next) {
5090
+ if (tensors.length === 2) {
5091
+ yield [[tensors[0], tensors[1]]];
5092
+ return;
5093
+ }
5094
+ for (let i = 0; i < tensors.length; i++) for (let j = i + 1; j < tensors.length; j++) {
5095
+ const pair = [tensors[i], tensors[j]];
5096
+ const newTensors = tensors.filter((t) => t !== pair[0] && t !== pair[1]);
5097
+ newTensors.push(next);
5098
+ for (const subpath of allPaths(newTensors, next + 1)) yield [pair, ...subpath];
5099
+ }
5100
+ }
5101
+
5102
+ //#endregion
5103
+ //#region src/library/numpy-fft.ts
5104
+ var numpy_fft_exports = {};
5105
+ __export(numpy_fft_exports, {
5106
+ fft: () => fft,
5107
+ ifft: () => ifft
5108
+ });
5109
+ function checkPairInput(name, a) {
5110
+ const fullName = `jax.numpy.fft.${name}`;
5111
+ if (!require_backend.deepEqual(a.real.shape, a.imag.shape)) throw new Error(`${fullName}: real and imaginary parts must have the same shape, got ${JSON.stringify(a.real.shape)} and ${JSON.stringify(a.imag.shape)}`);
5112
+ if (a.real.dtype !== a.imag.dtype) throw new Error(`${fullName}: real and imaginary parts must have the same dtype, got ${a.real.dtype} and ${a.imag.dtype}`);
5113
+ if (!require_backend.isFloatDtype(a.real.dtype)) throw new Error(`${fullName}: input must have a float dtype, got ${a.real.dtype}`);
5114
+ }
5115
+ function checkPowerOfTwo(name, n) {
5116
+ if ((n & n - 1) !== 0) throw new Error(`jax.numpy.fft.${name}: size must be a power of two, got ${n}`);
5117
+ }
5118
+ const fftUpdate = jit$1(function fftUpdate$1(i, { real, imag }) {
5119
+ const half = 2 ** i;
5120
+ real = real.reshape([-1, 2 * half]);
5121
+ imag = imag.reshape([-1, 2 * half]);
5122
+ const k = arange(0, half, 1, { dtype: real.dtype });
5123
+ const theta = k.mul(-Math.PI / half);
5124
+ const wr = cos(theta.ref);
5125
+ const wi = sin(theta);
5126
+ const ur = real.ref.slice([], [0, half]);
5127
+ const ui = imag.ref.slice([], [0, half]);
5128
+ const vr = real.slice([], [half, 2 * half]);
5129
+ const vi = imag.slice([], [half, 2 * half]);
5130
+ const tr = vr.ref.mul(wr.ref).sub(vi.ref.mul(wi.ref));
5131
+ const ti = vr.mul(wi).add(vi.mul(wr));
5132
+ return {
5133
+ real: concatenate([ur.ref.add(tr.ref), ur.sub(tr)], -1),
5134
+ imag: concatenate([ui.ref.add(ti.ref), ui.sub(ti)], -1)
5135
+ };
5136
+ }, { staticArgnums: [0] });
5137
+ /**
5138
+ * Compute a one-dimensional discrete Fourier transform.
5139
+ *
5140
+ * Currently, the size of the axis must be a power of two.
5141
+ */
5142
+ function fft(a, axis = -1) {
5143
+ checkPairInput("fft", a);
5144
+ let { real, imag } = a;
5145
+ axis = require_backend.checkAxis(axis, real.ndim);
5146
+ const n = real.shape[axis];
5147
+ checkPowerOfTwo("fft", n);
5148
+ const logN = Math.log2(n);
5149
+ let perm = null;
5150
+ if (axis !== real.ndim - 1) {
5151
+ perm = require_backend.range(real.ndim);
5152
+ perm.splice(axis, 1);
5153
+ perm.push(axis);
5154
+ real = real.transpose(perm);
5155
+ imag = imag.transpose(perm);
5156
+ }
5157
+ const originalShape = real.shape;
5158
+ real = real.reshape([-1, ...require_backend.rep(logN, 2)]).transpose([0, ...require_backend.range(1, logN + 1).reverse()]).flatten();
5159
+ imag = imag.reshape([-1, ...require_backend.rep(logN, 2)]).transpose([0, ...require_backend.range(1, logN + 1).reverse()]).flatten();
5160
+ for (let i = 0; i < logN; i++) ({real, imag} = fftUpdate(i, {
5161
+ real,
5162
+ imag
5163
+ }));
5164
+ real = real.reshape(originalShape);
5165
+ imag = imag.reshape(originalShape);
5166
+ if (perm !== null) {
5167
+ real = real.transpose(require_backend.invertPermutation(perm));
5168
+ imag = imag.transpose(require_backend.invertPermutation(perm));
5169
+ }
5170
+ return {
5171
+ real,
5172
+ imag
5173
+ };
5174
+ }
5175
+ /**
5176
+ * Compute a one-dimensional inverse discrete Fourier transform.
5177
+ *
5178
+ * Currently, the size of the axis must be a power of two.
5179
+ */
5180
+ function ifft(a, axis = -1) {
5181
+ checkPairInput("ifft", a);
5182
+ let { real, imag } = a;
5183
+ axis = require_backend.checkAxis(axis, real.ndim);
5184
+ const n = real.shape[axis];
5185
+ checkPowerOfTwo("ifft", n);
5186
+ imag = imag.mul(-1);
5187
+ const result = fft({
5188
+ real,
5189
+ imag
5190
+ }, axis);
5191
+ return {
5192
+ real: result.real.div(n),
5193
+ imag: result.imag.mul(-1).div(n)
5194
+ };
5195
+ }
5196
+
5197
+ //#endregion
5198
+ //#region src/library/numpy-linalg.ts
5199
+ var numpy_linalg_exports = {};
5200
+ __export(numpy_linalg_exports, {
5201
+ cholesky: () => cholesky,
5202
+ det: () => det,
5203
+ diagonal: () => diagonal,
5204
+ inv: () => inv,
5205
+ lstsq: () => lstsq,
5206
+ matmul: () => matmul,
5207
+ matrixPower: () => matrixPower,
5208
+ matrixTranspose: () => matrixTranspose,
5209
+ outer: () => outer,
5210
+ slogdet: () => slogdet,
5211
+ solve: () => solve,
5212
+ tensordot: () => tensordot,
5213
+ trace: () => trace,
5214
+ vecdot: () => vecdot
5215
+ });
5216
+ function checkSquare(name, a) {
5217
+ if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
5218
+ return a.shape[a.ndim - 1];
5219
+ }
5220
+ /**
5221
+ * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
5222
+ *
5223
+ * This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
5224
+ * the input matrix, which is on by default.
5225
+ */
5226
+ function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
5227
+ a = fudgeArray(a);
5228
+ checkSquare("cholesky", a);
5229
+ if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
5230
+ return cholesky$1(a, { upper });
5231
+ }
5232
+ /** Compute the determinant of a square matrix (batched). */
5233
+ function det(a) {
5234
+ a = fudgeArray(a);
5235
+ const n = checkSquare("det", a);
5236
+ const [lu$2, pivots, permutation] = lu(a);
5237
+ permutation.dispose();
5238
+ const parity = pivots.notEqual(arange(n)).astype(int32).sum(-1).mod(2);
5239
+ const sign$1 = parity.mul(-2).add(1);
5240
+ const diag$1 = lu$2.diagonal(0, -1, -2);
5241
+ return prod$1(diag$1, -1).mul(sign$1);
5242
+ }
5243
+ /** Compute the inverse of a square matrix (batched). */
5244
+ function inv(a) {
5245
+ a = fudgeArray(a);
5246
+ const n = checkSquare("inv", a);
5247
+ return solve(a, eye(n));
5248
+ }
5249
+ /**
5250
+ * Return the least-squares solution to a linear equation.
5251
+ *
5252
+ * For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
5253
+ * For underdetermined systems, this finds the minimum-norm solution for `x`.
5254
+ *
5255
+ * This currently uses Cholesky decomposition to solve the normal equations,
5256
+ * under the hood. The method is not as robust as QR or SVD.
5257
+ *
5258
+ * @param a coefficient matrix of shape `(M, N)`
5259
+ * @param b right-hand side of shape `(M,)` or `(M, K)`
5260
+ * @return least-squares solution of shape `(N,)` or `(N, K)`
5261
+ */
5262
+ function lstsq(a, b) {
5263
+ a = fudgeArray(a);
5264
+ b = fudgeArray(b);
5265
+ if (a.ndim !== 2) throw new Error(`lstsq: 'a' must be a 2D array, got ${a.aval}`);
5266
+ const [m, n] = a.shape;
5267
+ if (b.shape[0] !== m) throw new Error(`lstsq: leading dimension of 'b' must match number of rows of 'a', got ${b.aval}`);
5268
+ const at = matrixTranspose(a.ref);
5269
+ if (m <= n) {
5270
+ const aat = matmul(a, at.ref);
5271
+ const l = cholesky(aat, { symmetrizeInput: false });
5272
+ const lb = triangularSolve(l.ref, b, {
5273
+ leftSide: true,
5274
+ lower: true
5275
+ });
5276
+ const llb = triangularSolve(l, lb, {
5277
+ leftSide: true,
5278
+ transposeA: true
5279
+ });
5280
+ return matmul(at, llb.ref);
5281
+ } else {
5282
+ const ata = matmul(at.ref, a);
5283
+ const l = cholesky(ata, { symmetrizeInput: false });
5284
+ const atb = matmul(at, b);
5285
+ const lb = triangularSolve(l.ref, atb, {
5286
+ leftSide: true,
5287
+ lower: true
5288
+ });
5289
+ const llb = triangularSolve(l, lb, {
5290
+ leftSide: true,
5291
+ transposeA: true
5292
+ });
5293
+ return llb;
5294
+ }
5295
+ }
5296
+ /** Raise a square matrix to an integer power, via repeated squarings. */
5297
+ function matrixPower(a, n) {
5298
+ if (!Number.isInteger(n)) throw new Error(`matrixPower: exponent must be an integer, got ${n}`);
5299
+ a = fudgeArray(a);
5300
+ const m = checkSquare("matrixPower", a);
5301
+ if (n === 0) {
5302
+ a.dispose();
5303
+ return broadcastTo(eye(m), a.shape);
5304
+ }
5305
+ if (n < 0) {
5306
+ a = inv(a);
5307
+ n = -n;
5308
+ }
5309
+ let result = null;
5310
+ let a2k = a;
5311
+ for (let k = 0; n; k++) {
5312
+ if (k > 0) a2k = matmul(a2k.ref, a2k);
5313
+ if (n % 2 === 1) result = result === null ? a2k.ref : matmul(result, a2k.ref);
5314
+ n = Math.floor(n / 2);
5315
+ }
5316
+ a2k.dispose();
5317
+ return result;
5318
+ }
5319
+ /** Return sign and natural logarithm of the determinant of `a`. */
5320
+ function slogdet(a) {
5321
+ a = fudgeArray(a);
5322
+ const n = checkSquare("slogdet", a);
5323
+ const [lu$2, pivots, permutation] = lu(a);
5324
+ permutation.dispose();
5325
+ let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
5326
+ const diag$1 = lu$2.diagonal(0, -1, -2);
5327
+ parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
5328
+ const logabsdet = log(absolute(diag$1)).sum(-1);
5329
+ const sign$1 = parity.mul(-2).add(1);
5330
+ return [sign$1, logabsdet];
4711
5331
  }
4712
- function computePathNaive(input, sizeMap) {
4713
- const n = input.shapes.length;
4714
- const path = [];
4715
- let lastTensorIndex = 0;
4716
- for (let i = 1; i < n; i++) {
4717
- path.push([lastTensorIndex, i]);
4718
- lastTensorIndex = n + i - 1;
4719
- }
4720
- return new EinsumPath(input, sizeMap, path);
5332
+ /**
5333
+ * Solve a linear system of equations.
5334
+ *
5335
+ * This solves a (batched) linear system of equations `a @ x = b` for `x` given
5336
+ * `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
5337
+ *
5338
+ * @param a - Coefficient matrix of shape `(..., N, N)`.
5339
+ * @param b - Values of shape `(N,)` or `(..., N, M)`.
5340
+ * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
5341
+ */
5342
+ function solve(a, b) {
5343
+ a = fudgeArray(a);
5344
+ b = fudgeArray(b);
5345
+ const n = checkSquare("solve", a);
5346
+ if (b.ndim === 0) throw new Error(`solve: b cannot be scalar`);
5347
+ const bIs1d = b.ndim === 1;
5348
+ if (bIs1d) b = b.reshape([...b.shape, 1]);
5349
+ if (b.shape[b.ndim - 2] !== n) throw new Error(`solve: leading dimension of b must match size of a, got a=${a.aval}, b=${b.aval}`);
5350
+ const m = b.shape[b.ndim - 1];
5351
+ const batchDims = require_backend.generalBroadcast(a.shape.slice(0, -2), b.shape.slice(0, -2));
5352
+ a = broadcastTo(a, [
5353
+ ...batchDims,
5354
+ n,
5355
+ n
5356
+ ]);
5357
+ b = broadcastTo(b, [
5358
+ ...batchDims,
5359
+ n,
5360
+ m
5361
+ ]);
5362
+ const [lu$2, pivots, permutation] = lu(a);
5363
+ pivots.dispose();
5364
+ const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
5365
+ const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
5366
+ leftSide: true,
5367
+ lower: true,
5368
+ unitDiagonal: true
5369
+ });
5370
+ let x = triangularSolve(lu$2, LPb.ref, {
5371
+ leftSide: true,
5372
+ lower: false
5373
+ });
5374
+ if (bIs1d) x = squeeze(x, -1);
5375
+ return x;
4721
5376
  }
4722
- function computePathOptimal(input, sizeMap) {
4723
- const n = input.shapes.length;
4724
- let bestPath = null;
4725
- let bestFlops = null;
4726
- for (const path of allPaths(require_backend.range(n), n)) {
4727
- const flops = approximatePathFlops(input, sizeMap, path);
4728
- if (bestFlops === null || flops < bestFlops) {
4729
- bestPath = path;
4730
- bestFlops = flops;
4731
- }
5377
+
5378
+ //#endregion
5379
+ //#region src/library/numpy/dtype-info.ts
5380
+ /** Machine limits for floating-point types. */
5381
+ function finfo(dtype) {
5382
+ if (!require_backend.isFloatDtype(dtype)) throw new Error(`finfo: received ${dtype}, must be a floating-point type`);
5383
+ switch (dtype) {
5384
+ case require_backend.DType.Float16: return Object.freeze({
5385
+ bits: 16,
5386
+ dtype: require_backend.DType.Float16,
5387
+ eps: 2 ** -10,
5388
+ epsneg: 2 ** -11,
5389
+ machep: -10,
5390
+ max: 65504,
5391
+ maxexp: 16,
5392
+ min: -65504,
5393
+ minexp: -14,
5394
+ negep: -24,
5395
+ nexp: 5,
5396
+ nmant: 10,
5397
+ precision: 3,
5398
+ resolution: .001,
5399
+ smallestNormal: 2 ** -14,
5400
+ smallestSubnormal: 2 ** -24
5401
+ });
5402
+ case require_backend.DType.Float32: return Object.freeze({
5403
+ bits: 32,
5404
+ dtype: require_backend.DType.Float32,
5405
+ eps: 2 ** -23,
5406
+ epsneg: 2 ** -24,
5407
+ machep: -23,
5408
+ max: 34028234663852886e22,
5409
+ maxexp: 128,
5410
+ min: -34028234663852886e22,
5411
+ minexp: -126,
5412
+ negep: -24,
5413
+ nexp: 8,
5414
+ nmant: 23,
5415
+ precision: 6,
5416
+ resolution: 1e-6,
5417
+ smallestNormal: 2 ** -126,
5418
+ smallestSubnormal: 2 ** -149
5419
+ });
5420
+ case require_backend.DType.Float64: return Object.freeze({
5421
+ bits: 64,
5422
+ dtype: require_backend.DType.Float64,
5423
+ eps: 2 ** -52,
5424
+ epsneg: 2 ** -53,
5425
+ machep: -52,
5426
+ max: Number.MAX_VALUE,
5427
+ maxexp: 1024,
5428
+ min: -Number.MAX_VALUE,
5429
+ minexp: -1022,
5430
+ negep: -53,
5431
+ nexp: 11,
5432
+ nmant: 52,
5433
+ precision: 15,
5434
+ resolution: 1e-15,
5435
+ smallestNormal: 2 ** -1022,
5436
+ smallestSubnormal: 2 ** -1074
5437
+ });
5438
+ default: throw new Error(`finfo: unsupported dtype ${dtype}`);
4732
5439
  }
4733
- return new EinsumPath(input, sizeMap, bestPath);
4734
5440
  }
4735
- function* allPaths(tensors, next) {
4736
- if (tensors.length === 2) {
4737
- yield [[tensors[0], tensors[1]]];
4738
- return;
4739
- }
4740
- for (let i = 0; i < tensors.length; i++) for (let j = i + 1; j < tensors.length; j++) {
4741
- const pair = [tensors[i], tensors[j]];
4742
- const newTensors = tensors.filter((t) => t !== pair[0] && t !== pair[1]);
4743
- newTensors.push(next);
4744
- for (const subpath of allPaths(newTensors, next + 1)) yield [pair, ...subpath];
5441
+ /** Machine limits for integer types. */
5442
+ function iinfo(dtype) {
5443
+ switch (dtype) {
5444
+ case require_backend.DType.Int32: return Object.freeze({
5445
+ bits: 32,
5446
+ dtype: require_backend.DType.Int32,
5447
+ max: 2147483647,
5448
+ min: -2147483648
5449
+ });
5450
+ case require_backend.DType.Uint32: return Object.freeze({
5451
+ bits: 32,
5452
+ dtype: require_backend.DType.Uint32,
5453
+ max: 4294967295,
5454
+ min: 0
5455
+ });
5456
+ default: throw new Error(`iinfo: unsupported dtype ${dtype}`);
4745
5457
  }
4746
5458
  }
4747
5459
 
@@ -4751,28 +5463,32 @@ var numpy_exports = {};
4751
5463
  __export(numpy_exports, {
4752
5464
  Array: () => Array$1,
4753
5465
  DType: () => require_backend.DType,
4754
- abs: () => abs,
5466
+ abs: () => absolute,
4755
5467
  absolute: () => absolute,
4756
5468
  acos: () => acos,
4757
- acosh: () => acosh,
5469
+ acosh: () => arccosh,
4758
5470
  add: () => add,
5471
+ all: () => all,
4759
5472
  allclose: () => allclose,
5473
+ any: () => any,
4760
5474
  arange: () => arange,
4761
- arccos: () => arccos,
5475
+ arccos: () => acos,
4762
5476
  arccosh: () => arccosh,
5477
+ arcsin: () => asin,
4763
5478
  arcsinh: () => arcsinh,
4764
- arctan: () => arctan,
4765
- arctan2: () => arctan2,
5479
+ arctan: () => atan,
5480
+ arctan2: () => atan2,
4766
5481
  arctanh: () => arctanh,
4767
5482
  argmax: () => argmax,
4768
5483
  argmin: () => argmin,
5484
+ argsort: () => argsort,
4769
5485
  array: () => array,
4770
5486
  asin: () => asin,
4771
- asinh: () => asinh,
5487
+ asinh: () => arcsinh,
4772
5488
  astype: () => astype,
4773
5489
  atan: () => atan,
4774
5490
  atan2: () => atan2,
4775
- atanh: () => atanh,
5491
+ atanh: () => arctanh,
4776
5492
  bool: () => bool,
4777
5493
  broadcastArrays: () => broadcastArrays,
4778
5494
  broadcastShapes: () => broadcastShapes,
@@ -4782,16 +5498,21 @@ __export(numpy_exports, {
4782
5498
  clip: () => clip,
4783
5499
  columnStack: () => columnStack,
4784
5500
  concatenate: () => concatenate,
5501
+ convolve: () => convolve,
5502
+ corrcoef: () => corrcoef,
5503
+ correlate: () => correlate,
4785
5504
  cos: () => cos,
4786
5505
  cosh: () => cosh,
5506
+ cov: () => cov,
4787
5507
  cumsum: () => cumsum,
4788
- cumulativeSum: () => cumulativeSum,
5508
+ cumulativeSum: () => cumsum,
4789
5509
  deg2rad: () => deg2rad,
4790
5510
  degrees: () => degrees,
4791
5511
  diag: () => diag,
4792
5512
  diagonal: () => diagonal,
4793
- divide: () => divide,
4794
- dot: () => dot,
5513
+ divide: () => trueDivide,
5514
+ divmod: () => divmod,
5515
+ dot: () => dot$1,
4795
5516
  dstack: () => dstack,
4796
5517
  e: () => e,
4797
5518
  einsum: () => einsum,
@@ -4799,8 +5520,11 @@ __export(numpy_exports, {
4799
5520
  eulerGamma: () => eulerGamma,
4800
5521
  exp: () => exp,
4801
5522
  exp2: () => exp2,
5523
+ expandDims: () => expandDims,
4802
5524
  expm1: () => expm1,
4803
5525
  eye: () => eye,
5526
+ fft: () => numpy_fft_exports,
5527
+ finfo: () => finfo,
4804
5528
  flip: () => flip,
4805
5529
  fliplr: () => fliplr,
4806
5530
  flipud: () => flipud,
@@ -4808,6 +5532,7 @@ __export(numpy_exports, {
4808
5532
  float32: () => float32,
4809
5533
  float64: () => float64,
4810
5534
  floor: () => floor,
5535
+ floorDivide: () => floorDivide,
4811
5536
  fmod: () => fmod,
4812
5537
  frexp: () => frexp,
4813
5538
  full: () => full,
@@ -4820,6 +5545,7 @@ __export(numpy_exports, {
4820
5545
  hstack: () => hstack,
4821
5546
  hypot: () => hypot,
4822
5547
  identity: () => identity$1,
5548
+ iinfo: () => iinfo,
4823
5549
  inf: () => inf,
4824
5550
  inner: () => inner,
4825
5551
  int32: () => int32,
@@ -4831,12 +5557,15 @@ __export(numpy_exports, {
4831
5557
  ldexp: () => ldexp,
4832
5558
  less: () => less,
4833
5559
  lessEqual: () => lessEqual,
5560
+ linalg: () => numpy_linalg_exports,
4834
5561
  linspace: () => linspace,
4835
5562
  log: () => log,
4836
5563
  log10: () => log10,
4837
5564
  log1p: () => log1p,
4838
5565
  log2: () => log2,
5566
+ logspace: () => logspace,
4839
5567
  matmul: () => matmul,
5568
+ matrixTranspose: () => matrixTranspose,
4840
5569
  max: () => max,
4841
5570
  maximum: () => maximum,
4842
5571
  mean: () => mean,
@@ -4853,10 +5582,10 @@ __export(numpy_exports, {
4853
5582
  onesLike: () => onesLike,
4854
5583
  outer: () => outer,
4855
5584
  pad: () => pad,
4856
- permuteDims: () => permuteDims,
5585
+ permuteDims: () => transpose,
4857
5586
  pi: () => pi,
4858
5587
  positive: () => positive,
4859
- pow: () => pow,
5588
+ pow: () => power,
4860
5589
  power: () => power,
4861
5590
  prod: () => prod$1,
4862
5591
  promoteTypes: () => require_backend.promoteTypes,
@@ -4871,8 +5600,11 @@ __export(numpy_exports, {
4871
5600
  shape: () => shape,
4872
5601
  sign: () => sign,
4873
5602
  sin: () => sin,
5603
+ sinc: () => sinc,
4874
5604
  sinh: () => sinh,
4875
5605
  size: () => size,
5606
+ sort: () => sort,
5607
+ split: () => split$1,
4876
5608
  sqrt: () => sqrt,
4877
5609
  square: () => square,
4878
5610
  squeeze: () => squeeze,
@@ -4880,6 +5612,7 @@ __export(numpy_exports, {
4880
5612
  std: () => std,
4881
5613
  subtract: () => subtract,
4882
5614
  sum: () => sum,
5615
+ take: () => take,
4883
5616
  tan: () => tan,
4884
5617
  tanh: () => tanh,
4885
5618
  tensordot: () => tensordot,
@@ -5037,6 +5770,26 @@ function min(a, axis = null, opts) {
5037
5770
  function max(a, axis = null, opts) {
5038
5771
  return reduce(a, require_backend.AluOp.Max, axis, opts);
5039
5772
  }
5773
+ /**
5774
+ * Test whether all array elements along a given axis evaluate to True.
5775
+ *
5776
+ * Returns a boolean array with the same shape as `a` with the specified axis
5777
+ * removed. If `axis` is `null`, returns a scalar.
5778
+ */
5779
+ function all(a, axis = null, opts) {
5780
+ a = fudgeArray(a).astype(require_backend.DType.Bool);
5781
+ return min(a, axis, opts);
5782
+ }
5783
+ /**
5784
+ * Test whether any array element along a given axis evaluates to True.
5785
+ *
5786
+ * Returns a boolean array with the same shape as `a` with the specified axis
5787
+ * removed. If `axis` is `null`, returns a scalar.
5788
+ */
5789
+ function any(a, axis = null, opts) {
5790
+ a = fudgeArray(a).astype(require_backend.DType.Bool);
5791
+ return max(a, axis, opts);
5792
+ }
5040
5793
  /** Return the peak-to-peak range along a given axis (`max - min`). */
5041
5794
  function ptp(a, axis = null, opts) {
5042
5795
  a = fudgeArray(a);
@@ -5111,8 +5864,6 @@ function cumsum(a, axis) {
5111
5864
  a = broadcast(a, a.shape.concat(n), [-2]);
5112
5865
  return moveaxis$1(tril(a).sum(-1), -1, axis);
5113
5866
  }
5114
- /** @function Alternative name for `jax.numpy.cumsum()`. */
5115
- const cumulativeSum = cumsum;
5116
5867
  /** Reverse the elements in an array along the given axes. */
5117
5868
  function flip(x, axis = null) {
5118
5869
  const nd = ndim(x);
@@ -5120,6 +5871,45 @@ function flip(x, axis = null) {
5120
5871
  return flip$1(x, axis);
5121
5872
  }
5122
5873
  /**
5874
+ * Split an array into multiple sub-arrays along an axis.
5875
+ *
5876
+ * @param a - The input array to split.
5877
+ * @param indicesOrSections - If an integer, it indicates the number of equal
5878
+ * sections to create along the specified axis. If a list of integers, it
5879
+ * specifies the indices at which to split the array.
5880
+ * @param axis - The axis along which to split the array. Default is 0.
5881
+ */
5882
+ function split$1(a, indicesOrSections, axis = 0) {
5883
+ a = fudgeArray(a);
5884
+ axis = require_backend.checkAxis(axis, a.ndim);
5885
+ const size$1 = a.shape[axis];
5886
+ let sizes;
5887
+ if (typeof indicesOrSections === "number") {
5888
+ if (size$1 % indicesOrSections !== 0) throw new Error(`Array of size ${size$1} cannot be split into ${indicesOrSections} equal parts`);
5889
+ const partSize = size$1 / indicesOrSections;
5890
+ sizes = require_backend.rep(indicesOrSections, partSize);
5891
+ } else {
5892
+ const indices = indicesOrSections;
5893
+ sizes = [indices[0]];
5894
+ for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
5895
+ sizes.push(size$1 - indices[indices.length - 1]);
5896
+ }
5897
+ const results = [];
5898
+ for (let i = 0; i < sizes.length; i += 7) if (i === sizes.length) {
5899
+ results.push(a);
5900
+ break;
5901
+ } else if (i + 8 >= sizes.length) {
5902
+ results.push(...split$2(a, axis, sizes.slice(i)));
5903
+ break;
5904
+ } else {
5905
+ const groupSizes = [...sizes.slice(i, i + 7), sizes.slice(i + 7).reduce((x, y) => x + y, 0)];
5906
+ const outs = split$2(a, axis, groupSizes);
5907
+ results.push(...outs.slice(0, -1));
5908
+ a = outs[outs.length - 1];
5909
+ }
5910
+ return results;
5911
+ }
5912
+ /**
5123
5913
  * Join a sequence of arrays along an existing axis.
5124
5914
  *
5125
5915
  * The arrays must have the same shape, except in the dimension corresponding to
@@ -5131,13 +5921,11 @@ function concatenate(xs, axis = 0) {
5131
5921
  if (xs.length === 0) throw new Error("Need at least one array to concatenate");
5132
5922
  const shapes = xs.map(shape);
5133
5923
  axis = require_backend.checkAxis(axis, shapes[0].length);
5134
- for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays with shapes ${JSON.stringify(shapes)} along axis ${axis}`);
5135
- const makePadAxis = (start, end) => shapes[0].map((_, i) => i === axis ? [start, end] : [0, 0]);
5924
+ for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays ${xs[0].aval} and ${xs[i].aval} along axis ${axis}`);
5136
5925
  let result = xs[0];
5137
- for (let i = 1; i < xs.length; i++) {
5138
- const len1 = result.shape[axis];
5139
- const len2 = shapes[i][axis];
5140
- result = pad(result, makePadAxis(0, len2)).add(pad(xs[i], makePadAxis(len1, 0)));
5926
+ for (let i = 1; i < xs.length; i += 7) {
5927
+ const group = xs.slice(i, i + 7);
5928
+ result = concatenate$1([result, ...group], axis);
5141
5929
  }
5142
5930
  return result;
5143
5931
  }
@@ -5222,8 +6010,11 @@ function flipud(x) {
5222
6010
  function fliplr(x) {
5223
6011
  return flip(x, 1);
5224
6012
  }
5225
- /** @function Alternative name for `numpy.transpose()`. */
5226
- const permuteDims = transpose;
6013
+ /** Transpose the last two dimensions of an array. */
6014
+ function matrixTranspose(a) {
6015
+ if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
6016
+ return moveaxis$1(a, -1, -2);
6017
+ }
5227
6018
  /** Return a 1-D flattened array containing the elements of the input. */
5228
6019
  function ravel(a) {
5229
6020
  return fudgeArray(a).ravel();
@@ -5239,6 +6030,32 @@ function squeeze(a, axis = null) {
5239
6030
  return reshape(a, newShape);
5240
6031
  }
5241
6032
  /**
6033
+ * Expand the shape of an array by inserting new axes of length 1.
6034
+ *
6035
+ * @param a - Input array.
6036
+ * @param axis - Position(s) in the expanded axes where the new axis (or axes)
6037
+ * is placed. Can be a single integer or an array of integers.
6038
+ * @returns Array with the number of dimensions increased.
6039
+ *
6040
+ * @example
6041
+ * ```ts
6042
+ * const x = np.array([1, 2]);
6043
+ * np.expandDims(x, 0); // Shape [1, 2]
6044
+ * np.expandDims(x, 1); // Shape [2, 1]
6045
+ * np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
6046
+ * ```
6047
+ */
6048
+ function expandDims(a, axis) {
6049
+ const as = shape(a);
6050
+ axis = typeof axis === "number" ? [axis] : axis;
6051
+ axis = require_backend.normalizeAxis(axis, as.length + axis.length);
6052
+ const newShape = [];
6053
+ let srcIdx = 0;
6054
+ for (let i = 0; i < as.length + axis.length; i++) if (axis.includes(i)) newShape.push(1);
6055
+ else newShape.push(as[srcIdx++]);
6056
+ return reshape(a, newShape);
6057
+ }
6058
+ /**
5242
6059
  * Repeat each element of an array after themselves.
5243
6060
  *
5244
6061
  * If no axis is provided, use the flattened input array, and return a flat
@@ -5326,7 +6143,7 @@ function diagonal(a, offset, axis1, axis2) {
5326
6143
  */
5327
6144
  function diag(v, k = 0) {
5328
6145
  const a = fudgeArray(v);
5329
- if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
6146
+ if (!Number.isInteger(k)) throw new Error(`k must be an integer, got ${k}`);
5330
6147
  if (a.ndim === 1) {
5331
6148
  const n = a.shape[0];
5332
6149
  const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
@@ -5334,12 +6151,46 @@ function diag(v, k = 0) {
5334
6151
  else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
5335
6152
  else return ret;
5336
6153
  } else if (a.ndim === 2) return diagonal(a, k);
5337
- else throw new TypeError("numpy.diag only supports 1D and 2D arrays");
6154
+ else throw new Error("numpy.diag only supports 1D and 2D arrays");
5338
6155
  }
5339
6156
  /** Calculate the sum of the diagonal of an array along the given axes. */
5340
6157
  function trace(a, offset = 0, axis1 = 0, axis2 = 1) {
5341
6158
  return diagonal(a, offset, axis1, axis2).sum(-1);
5342
6159
  }
6160
+ /**
6161
+ * Return a sorted copy of an array.
6162
+ *
6163
+ * The array is sorted along a specified axis (the last by default). This may be
6164
+ * an unstable sort, and it dispatches to a device-specific implementation.
6165
+ */
6166
+ function sort(a, axis = -1) {
6167
+ return fudgeArray(a).sort(axis);
6168
+ }
6169
+ /**
6170
+ * Return indices that would sort an array. This may be an unstable sorting
6171
+ * algorithm; it need not preserve order of indices in ties.
6172
+ *
6173
+ * Returns an array of `int32` indices.
6174
+ *
6175
+ * The array is sorted along a specified axis (the last by default).
6176
+ */
6177
+ function argsort(a, axis = -1) {
6178
+ return fudgeArray(a).argsort(axis);
6179
+ }
6180
+ /**
6181
+ * Take elements from an array along an axis.
6182
+ *
6183
+ * This is equivalent to advanced indexing with integer indices over that
6184
+ * numbered axis. By default, the flattened array is used.
6185
+ */
6186
+ function take(a, indices, axis = null) {
6187
+ if (axis === null) {
6188
+ a = ravel(a);
6189
+ axis = 0;
6190
+ }
6191
+ axis = require_backend.checkAxis(axis, ndim(a));
6192
+ return gather(a, [indices], [axis], axis);
6193
+ }
5343
6194
  /** Return if two arrays are element-wise equal within a tolerance. */
5344
6195
  function allclose(actual, expected, options) {
5345
6196
  const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
@@ -5356,11 +6207,11 @@ function allclose(actual, expected, options) {
5356
6207
  }
5357
6208
  /** Matrix product of two arrays. */
5358
6209
  function matmul(x, y) {
5359
- if (ndim(x) === 0 || ndim(y) === 0) throw new TypeError("matmul: x and y must be at least 1D");
6210
+ if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
5360
6211
  x = x, y = y;
5361
6212
  if (y.ndim === 1) return dot$2(x, y);
5362
6213
  const numBatchDims = Math.min(Math.max(x.ndim, 2), y.ndim) - 2;
5363
- return dot$1(x, y, {
6214
+ return dot(x, y, {
5364
6215
  lhsContractingDims: [-1],
5365
6216
  rhsContractingDims: [-2],
5366
6217
  lhsBatchDims: require_backend.range(-2 - numBatchDims, -2),
@@ -5368,11 +6219,11 @@ function matmul(x, y) {
5368
6219
  });
5369
6220
  }
5370
6221
  /** Dot product of two arrays. */
5371
- function dot(x, y) {
6222
+ function dot$1(x, y) {
5372
6223
  if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
5373
6224
  x = x, y = y;
5374
6225
  if (y.ndim === 1) return dot$2(x, y);
5375
- return dot$1(x, y, {
6226
+ return dot(x, y, {
5376
6227
  lhsContractingDims: [-1],
5377
6228
  rhsContractingDims: [-2]
5378
6229
  });
@@ -5388,7 +6239,7 @@ function tensordot(x, y, axes = 2) {
5388
6239
  x = fudgeArray(x);
5389
6240
  y = fudgeArray(y);
5390
6241
  if (typeof axes === "number") axes = [require_backend.range(-axes, 0), require_backend.range(axes)];
5391
- return dot$1(x, y, {
6242
+ return dot(x, y, {
5392
6243
  lhsContractingDims: axes[0],
5393
6244
  rhsContractingDims: axes[1]
5394
6245
  });
@@ -5481,7 +6332,7 @@ function einsum(...args) {
5481
6332
  const [b, bidx] = processSingleTensor(operands[j], indices[j], indices[i]);
5482
6333
  indexReduced = indexReduced.filter((idx) => aidx.includes(idx));
5483
6334
  const indexBatch = aidx.filter((idx) => bidx.includes(idx) && !indexReduced.includes(idx));
5484
- const result = dot$1(a, b, {
6335
+ const result = dot(a, b, {
5485
6336
  lhsContractingDims: indexReduced.map((idx) => aidx.indexOf(idx)),
5486
6337
  rhsContractingDims: indexReduced.map((idx) => bidx.indexOf(idx)),
5487
6338
  lhsBatchDims: indexBatch.map((idx) => aidx.indexOf(idx)),
@@ -5509,7 +6360,7 @@ function einsum(...args) {
5509
6360
  * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
5510
6361
  */
5511
6362
  function inner(x, y) {
5512
- return dot$1(fudgeArray(x), fudgeArray(y), {
6363
+ return dot(fudgeArray(x), fudgeArray(y), {
5513
6364
  lhsContractingDims: [-1],
5514
6365
  rhsContractingDims: [-1]
5515
6366
  });
@@ -5542,6 +6393,30 @@ function vecdot(x, y, { axis } = {}) {
5542
6393
  function vdot(x, y) {
5543
6394
  return dot$2(ravel(x), ravel(y));
5544
6395
  }
6396
+ function _convImpl(name, x, y, mode) {
6397
+ if (x.ndim !== 1 || y.ndim !== 1) throw new Error(`${name}: both inputs must be 1D arrays, got ${x.ndim}D and ${y.ndim}D`);
6398
+ let flipOutput = false;
6399
+ if (x.shape[0] < y.shape[0]) {
6400
+ [x, y] = [y, x];
6401
+ if (name === "correlate") flipOutput = true;
6402
+ }
6403
+ if (name === "convolve") y = flip(y);
6404
+ let padding;
6405
+ if (mode === "valid") padding = "VALID";
6406
+ else if (mode === "same") padding = "SAME_LOWER";
6407
+ else if (mode === "full") padding = [[y.shape[0] - 1, y.shape[0] - 1]];
6408
+ else throw new Error(`${name}: invalid mode ${mode}, expected "full", "same", or "valid"`);
6409
+ const z = conv(x.slice(null, null), y.slice(null, null), [1], padding).slice(0, 0);
6410
+ return flipOutput ? flip(z) : z;
6411
+ }
6412
+ /** Convolution of two one-dimensional arrays. */
6413
+ function convolve(x, y, mode = "full") {
6414
+ return _convImpl("convolve", x, y, mode);
6415
+ }
6416
+ /** Correlation of two one-dimensional arrays. */
6417
+ function correlate(x, y, mode = "valid") {
6418
+ return _convImpl("correlate", x, y, mode);
6419
+ }
5545
6420
  /**
5546
6421
  * Return a tuple of coordinate matrices from coordinate vectors.
5547
6422
  *
@@ -5550,7 +6425,7 @@ function vdot(x, y) {
5550
6425
  */
5551
6426
  function meshgrid(xs, { indexing } = {}) {
5552
6427
  indexing ??= "xy";
5553
- for (const x of xs) if (x.ndim !== 1) throw new TypeError(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
6428
+ for (const x of xs) if (x.ndim !== 1) throw new Error(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
5554
6429
  if (xs.length <= 1) return xs;
5555
6430
  if (indexing === "xy") {
5556
6431
  const [a, b, ...rest] = xs;
@@ -5569,43 +6444,6 @@ function meshgrid(xs, { indexing } = {}) {
5569
6444
  return xs.map((x, i) => broadcast(x, shape$1, [...require_backend.range(i), ...require_backend.range(i + 1, xs.length)]));
5570
6445
  }
5571
6446
  /**
5572
- * Return an array with ones on and below the diagonal and zeros elsewhere.
5573
- *
5574
- * If `k` is provided, it specifies the sub-diagonal on and below which the
5575
- * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
5576
- * `k>0` is above it.
5577
- */
5578
- function tri(n, m, k = 0, { dtype, device } = {}) {
5579
- m ??= n;
5580
- dtype ??= require_backend.DType.Float32;
5581
- if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
5582
- if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
5583
- if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
5584
- const rows = arange(k, n + k, 1, {
5585
- dtype: require_backend.DType.Int32,
5586
- device
5587
- });
5588
- const cols = arange(0, m, 1, {
5589
- dtype: require_backend.DType.Int32,
5590
- device
5591
- });
5592
- return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
5593
- }
5594
- /** Return the lower triangle of an array. Must be of dimension >= 2. */
5595
- function tril(a, k = 0) {
5596
- if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
5597
- a = fudgeArray(a);
5598
- const [n, m] = a.shape.slice(-2);
5599
- return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
5600
- }
5601
- /** Return the upper triangle of an array. Must be of dimension >= 2. */
5602
- function triu(a, k = 0) {
5603
- if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
5604
- a = fudgeArray(a);
5605
- const [n, m] = a.shape.slice(-2);
5606
- return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
5607
- }
5608
- /**
5609
6447
  * Clip (limit) the values in an array.
5610
6448
  *
5611
6449
  * Given an interval, values outside the interval are clipped to the interval
@@ -5629,8 +6467,6 @@ function absolute(x) {
5629
6467
  x = fudgeArray(x);
5630
6468
  return where(less(x.ref, 0), x.ref.mul(-1), x);
5631
6469
  }
5632
- /** @function Alias of `jax.numpy.absolute()`. */
5633
- const abs = absolute;
5634
6470
  /** Return an element-wise indication of sign of the input. */
5635
6471
  function sign(x) {
5636
6472
  x = fudgeArray(x);
@@ -5674,6 +6510,20 @@ function tan(x) {
5674
6510
  x = fudgeArray(x);
5675
6511
  return sin(x.ref).div(cos(x));
5676
6512
  }
6513
+ /**
6514
+ * @function
6515
+ * Return the normalized sinc function.
6516
+ *
6517
+ * The sinc function is defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`.
6518
+ * This is the normalized sinc function commonly used in signal processing.
6519
+ *
6520
+ * **Note:** JVP is not supported at x=0 due to discontinuous derivative. This
6521
+ * requires a custom JVP rule to handle properly (see JAX implementation).
6522
+ */
6523
+ const sinc = jit$1(function sinc$1(x) {
6524
+ const pix = x.ref.mul(Math.PI);
6525
+ return where(equal(x, 0), 1, sin(pix.ref).div(pix));
6526
+ });
5677
6527
  /** Element-wise inverse cosine function (inverse of cos). */
5678
6528
  function acos(x) {
5679
6529
  return subtract(pi / 2, asin(x));
@@ -5709,12 +6559,6 @@ const atan2 = jit$1(function atan2$1(y, x) {
5709
6559
  const denom = where(xNeg, y, r.add(x));
5710
6560
  return atan(numer.div(denom)).mul(2);
5711
6561
  });
5712
- /** @function Alias of `jax.numpy.acos()`. */
5713
- const arccos = acos;
5714
- /** @function Alias of `jax.numpy.atan()`. */
5715
- const arctan = atan;
5716
- /** @function Alias of `jax.numpy.atan2()`. */
5717
- const arctan2 = atan2;
5718
6562
  /** Element-wise subtraction, with broadcasting. */
5719
6563
  function subtract(x, y) {
5720
6564
  x = fudgeArray(x);
@@ -5732,6 +6576,25 @@ function trueDivide(x, y) {
5732
6576
  return x.div(y);
5733
6577
  }
5734
6578
  /**
6579
+ * Return the largest integer less than or equal to the division of the inputs.
6580
+ *
6581
+ * The result is always rounded towards negative infinity.
6582
+ *
6583
+ * For floating-point inputs, this is equivalent to `floor(x / y)`.
6584
+ * For integer inputs, we use `(x - remainder(x, y)) / y` to handle
6585
+ * negative values correctly (note: may overflow near int32 boundaries).
6586
+ *
6587
+ * @param x - Dividend array.
6588
+ * @param y - Divisor array.
6589
+ * @returns Element-wise floor division of x by y.
6590
+ */
6591
+ function floorDivide(x, y) {
6592
+ x = fudgeArray(x);
6593
+ y = fudgeArray(y);
6594
+ if (require_backend.isFloatDtype(x.dtype) || require_backend.isFloatDtype(y.dtype)) return floor(trueDivide(x, y));
6595
+ return subtract(x, remainder(x.ref, y.ref)).div(y);
6596
+ }
6597
+ /**
5735
6598
  * @function
5736
6599
  * Calculate element-wise floating-point modulo operation.
5737
6600
  */
@@ -5745,8 +6608,20 @@ const fmod = jit$1(function fmod$1(x, y) {
5745
6608
  const remainder = jit$1(function remainder$1(x, y) {
5746
6609
  return mod(mod(x, y.ref).add(y.ref), y);
5747
6610
  });
5748
- /** @function Alias of `jax.numpy.trueDivide()`. */
5749
- const divide = trueDivide;
6611
+ /**
6612
+ * Return element-wise quotient and remainder simultaneously.
6613
+ *
6614
+ * Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
6615
+ *
6616
+ * @param x - Dividend array.
6617
+ * @param y - Divisor array.
6618
+ * @returns Tuple of [quotient, remainder].
6619
+ */
6620
+ function divmod(x, y) {
6621
+ const xArr = fudgeArray(x);
6622
+ const yArr = fudgeArray(y);
6623
+ return [floorDivide(xArr.ref, yArr.ref), remainder(xArr, yArr)];
6624
+ }
5750
6625
  /** Round input to the nearest integer towards zero. */
5751
6626
  function trunc(x) {
5752
6627
  return idiv(x, 1);
@@ -5768,9 +6643,9 @@ function ldexp(x1, x2) {
5768
6643
  */
5769
6644
  function frexp(x) {
5770
6645
  x = fudgeArray(x);
5771
- const absx = abs(x.ref);
6646
+ const absx = absolute(x.ref);
5772
6647
  const exponent = where(equal(x.ref, 0), 0, floor(log2(absx)).add(1).astype(require_backend.DType.Int32));
5773
- const mantissa = divide(x, exp2(exponent.ref.astype(x.dtype)));
6648
+ const mantissa = x.div(exp2(exponent.ref.astype(x.dtype)));
5774
6649
  return [mantissa, exponent];
5775
6650
  }
5776
6651
  /** Calculate `2**p` for all p in the input array. */
@@ -5813,10 +6688,8 @@ const power = jit$1(function power$1(x1, x2) {
5813
6688
  const x2i = trunc(x2.ref);
5814
6689
  const shouldBeNaN = multiply(x2.ref.notEqual(x2i.ref), x1.ref.less(0));
5815
6690
  const resultSign = where(mod(x2i, 2).notEqual(0), where(x1.ref.less(0), -1, 1), 1);
5816
- return where(shouldBeNaN, nan, exp(log(abs(x1)).mul(x2)).mul(resultSign));
6691
+ return where(shouldBeNaN, nan, exp(log(absolute(x1)).mul(x2)).mul(resultSign));
5817
6692
  });
5818
- /** @function Alias of `jax.numpy.power()`. */
5819
- const pow = power;
5820
6693
  /** @function Calculate the element-wise cube root of the input array. */
5821
6694
  const cbrt = jit$1(function cbrt$1(x) {
5822
6695
  const sgn = where(less(x.ref, 0), -1, 1);
@@ -5882,69 +6755,360 @@ const arccosh = jit$1(function arccosh$1(x) {
5882
6755
  const arctanh = jit$1(function arctanh$1(x) {
5883
6756
  return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
5884
6757
  });
5885
- /** @function Alias of `jax.numpy.arcsinh()`. */
5886
- const asinh = arcsinh;
5887
- /** @function Alias of `jax.numpy.arccosh()`. */
5888
- const acosh = arccosh;
5889
- /** @function Alias of `jax.numpy.arctanh()`. */
5890
- const atanh = arctanh;
5891
6758
  /**
5892
- * Compute the variance of an array.
5893
- *
5894
- * The variance is computed for the flattened array by default, otherwise over
5895
- * the specified axis.
6759
+ * Compute the variance of an array.
6760
+ *
6761
+ * The variance is computed for the flattened array by default, otherwise over
6762
+ * the specified axis.
6763
+ *
6764
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
6765
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
6766
+ */
6767
+ function var_(x, axis = null, opts) {
6768
+ x = fudgeArray(x);
6769
+ axis = require_backend.normalizeAxis(axis, x.ndim);
6770
+ const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
6771
+ if (n === 0) throw new Error("var: cannot compute variance over zero-length axis");
6772
+ const mu = opts?.mean !== void 0 ? opts.mean : mean(x.ref, axis, { keepdims: true });
6773
+ return square(x.sub(mu)).sum(axis, { keepdims: opts?.keepdims }).mul(1 / (n - (opts?.correction ?? 0)));
6774
+ }
6775
+ /**
6776
+ * Compute the standard deviation of an array.
6777
+ *
6778
+ * The standard deviation is computed for the flattened array by default,
6779
+ * otherwise over the specified axis.
6780
+ *
6781
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
6782
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
6783
+ */
6784
+ function std(x, axis = null, opts) {
6785
+ return sqrt(var_(x, axis, opts));
6786
+ }
6787
+ /** Estimate the sample covariance of a set of variables. */
6788
+ function cov(x, y = null, { rowvar = true } = {}) {
6789
+ x = fudgeArray(x);
6790
+ if (x.ndim === 1) x = x.reshape([1, x.shape[0]]);
6791
+ if (y !== null) {
6792
+ y = fudgeArray(y);
6793
+ if (y.ndim === 1) y = y.reshape([1, y.shape[0]]);
6794
+ x = vstack([x, y]);
6795
+ }
6796
+ if (!rowvar) x = x.transpose();
6797
+ const [_M, N] = x.shape;
6798
+ x = x.ref.sub(x.mean(1, { keepdims: true }));
6799
+ return dot$1(x.ref, x.transpose()).div(N - 1);
6800
+ }
6801
+ /** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
6802
+ function corrcoef(x, y) {
6803
+ const c = cov(x, y);
6804
+ const variances = diag(c.ref);
6805
+ const norm = sqrt(outer(variances.ref, variances));
6806
+ return c.div(norm);
6807
+ }
6808
+ /** Test element-wise for positive or negative infinity, return bool array. */
6809
+ function isinf(x) {
6810
+ x = fudgeArray(x);
6811
+ return require_backend.isFloatDtype(x.dtype) ? x.ref.equal(Infinity).add(x.equal(-Infinity)) : fullLike$1(x, false);
6812
+ }
6813
+ /** Test element-wise for NaN (Not a Number). */
6814
+ function isnan(x) {
6815
+ x = fudgeArray(x);
6816
+ return require_backend.isFloatDtype(x.dtype) ? x.ref.notEqual(x) : fullLike$1(x, false);
6817
+ }
6818
+ /** Test element-wise for negative infinity, return bool array. */
6819
+ function isneginf(x) {
6820
+ x = fudgeArray(x);
6821
+ return require_backend.isFloatDtype(x.dtype) ? x.equal(-Infinity) : fullLike$1(x, false);
6822
+ }
6823
+ /** Test element-wise for positive infinity, return bool array. */
6824
+ function isposinf(x) {
6825
+ x = fudgeArray(x);
6826
+ return require_backend.isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
6827
+ }
6828
+ /**
6829
+ * @function
6830
+ * Test element-wise for finite values (not infinity or NaN).
6831
+ */
6832
+ const isfinite = jit$1(function isfinite$1(x) {
6833
+ if (!require_backend.isFloatDtype(x.dtype)) return fullLike$1(x, true);
6834
+ return isnan(x.ref).add(isinf(x)).notEqual(true);
6835
+ });
6836
+
6837
+ //#endregion
6838
+ //#region src/library/lax-linalg.ts
6839
+ var lax_linalg_exports = {};
6840
+ __export(lax_linalg_exports, {
6841
+ cholesky: () => cholesky$1,
6842
+ lu: () => lu,
6843
+ triangularSolve: () => triangularSolve
6844
+ });
6845
+ /**
6846
+ * Compute the Cholesky decomposition of a symmetric positive-definite matrix.
6847
+ *
6848
+ * The Cholesky decomposition of a matrix `A` is:
6849
+ *
6850
+ * - A = L @ L^T (for upper=false, default)
6851
+ * - A = U^T @ U (for upper=true)
6852
+ *
6853
+ * where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
6854
+ * The input matrix must be symmetric and positive-definite.
6855
+ *
6856
+ * @example
6857
+ * ```ts
6858
+ * import { lax, numpy as np } from "@jax-js/jax";
6859
+ *
6860
+ * const x = np.array([[2., 1.], [1., 2.]]);
6861
+ *
6862
+ * // Lower Cholesky factorization (default):
6863
+ * const L = lax.linalg.cholesky(x);
6864
+ * // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
6865
+ *
6866
+ * // Upper Cholesky factorization:
6867
+ * const U = lax.linalg.cholesky(x, { upper: true });
6868
+ * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
6869
+ * ```
6870
+ */
6871
+ function cholesky$1(a, { upper = false } = {}) {
6872
+ const L = cholesky$2(a);
6873
+ return upper ? moveaxis$1(L, -2, -1) : L;
6874
+ }
6875
+ /**
6876
+ * LU decomposition with partial pivoting.
6877
+ *
6878
+ * Computes the matrix decomposition: `P @ A = L @ U`, where `P` is a
6879
+ * permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
6880
+ * and `U` is upper-triangular.
6881
+ *
6882
+ * @param x - A batch of matrices with shape `[..., m, n]`.
6883
+ *
6884
+ * @returns A tuple `(lu, pivots, permutation)` where:
6885
+ * - `lu`: combined lower and upper triangular matrices.
6886
+ * - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
6887
+ * - `permutation`: the permutation generated by pivots with shape `[..., m]`.
6888
+ *
6889
+ * @example
6890
+ * ```ts
6891
+ * import { lax, numpy as np } from "@jax-js/jax";
6892
+ *
6893
+ * const A = np.array([[4., 3.], [6., 3.]]);
6894
+ * const [lu, pivots, permutation] = lax.linalg.lu(A);
6895
+ * // lu ≈ [[6., 3.], [0.6666667, 1.0]]
6896
+ * // pivots = [1, 1]
6897
+ * // permutation = [1, 0]
6898
+ * ```
6899
+ */
6900
+ function lu(x) {
6901
+ return lu$1(x);
6902
+ }
6903
+ /**
6904
+ * Solve a triangular linear system.
6905
+ *
6906
+ * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
6907
+ * where `a` is a triangular matrix.
6908
+ *
6909
+ * @example
6910
+ * ```ts
6911
+ * import { lax, numpy as np } from "@jax-js/jax";
6912
+ *
6913
+ * const L = np.array([[2., 0.], [1., 3.]]);
6914
+ * const b = np.array([4., 7.]).reshape([2, 1]);
6915
+ *
6916
+ * // Solve L @ x = b
6917
+ * const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
6918
+ * // x = [[2.], [5./3.]]
6919
+ * ```
6920
+ */
6921
+ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = false, unitDiagonal = false } = {}) {
6922
+ a = fudgeArray(a);
6923
+ b = fudgeArray(b);
6924
+ if (!leftSide) transposeA = !transposeA;
6925
+ else b = moveaxis$1(b, -2, -1);
6926
+ if (transposeA) a = moveaxis$1(a, -2, -1);
6927
+ let x = triangularSolve$1(a, b, {
6928
+ lower,
6929
+ unitDiagonal
6930
+ });
6931
+ if (leftSide) x = moveaxis$1(x, -2, -1);
6932
+ return x;
6933
+ }
6934
+
6935
+ //#endregion
6936
+ //#region src/library/lax.ts
6937
+ var lax_exports = {};
6938
+ __export(lax_exports, {
6939
+ conv: () => conv,
6940
+ convGeneralDilated: () => convGeneralDilated,
6941
+ convWithGeneralPadding: () => convWithGeneralPadding,
6942
+ dot: () => dot,
6943
+ erf: () => erf,
6944
+ erfc: () => erfc,
6945
+ linalg: () => lax_linalg_exports,
6946
+ reduceWindow: () => reduceWindow,
6947
+ stopGradient: () => stopGradient$1
6948
+ });
6949
+ /**
6950
+ * General dot product/contraction operator.
5896
6951
  *
5897
- * If `correction` is provided, the divisor in calculation is `N - correction`,
5898
- * where `N` represents the number of elements (e.g., for Bessel's correction).
6952
+ * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
6953
+ * `jax.numpy.tensordot()`, and `jax.numpy.einsum()` where possible.
5899
6954
  */
5900
- function var_(x, axis = null, opts) {
5901
- x = fudgeArray(x);
5902
- axis = require_backend.normalizeAxis(axis, x.ndim);
5903
- const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
5904
- if (n === 0) throw new Error("var: cannot compute variance over zero-length axis");
5905
- const mu = opts?.mean !== void 0 ? opts.mean : mean(x.ref, axis, { keepdims: true });
5906
- return square(x.sub(mu)).sum(axis, { keepdims: opts?.keepdims }).mul(1 / (n - (opts?.correction ?? 0)));
6955
+ function dot(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
6956
+ if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
6957
+ else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
6958
+ lc = lc.map((a) => require_backend.checkAxis(a, lhs.ndim));
6959
+ rc = rc.map((a) => require_backend.checkAxis(a, rhs.ndim));
6960
+ lb = lb.map((a) => require_backend.checkAxis(a, lhs.ndim));
6961
+ rb = rb.map((a) => require_backend.checkAxis(a, rhs.ndim));
6962
+ if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
6963
+ else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
6964
+ const lf = require_backend.range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
6965
+ const rf = require_backend.range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
6966
+ const lhs2 = lhs.transpose([
6967
+ ...lb,
6968
+ ...lf,
6969
+ ...lc
6970
+ ]);
6971
+ const rhs2 = rhs.transpose([
6972
+ ...rb,
6973
+ ...rf,
6974
+ ...rc
6975
+ ]);
6976
+ if (lc.length === 0) return mul(lhs2.reshape([
6977
+ ...lb.map((a) => lhs.shape[a]),
6978
+ ...lf.map((a) => lhs.shape[a]),
6979
+ ...require_backend.rep(rf.length, 1)
6980
+ ]), rhs2.reshape([
6981
+ ...rb.map((a) => rhs.shape[a]),
6982
+ ...require_backend.rep(lf.length, 1),
6983
+ ...rf.map((a) => rhs.shape[a])
6984
+ ]));
6985
+ const dotShapeX = lc.map((a) => lhs.shape[a]);
6986
+ const dotShapeY = rc.map((a) => rhs.shape[a]);
6987
+ if (!require_backend.deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
6988
+ return dot$2(lhs2.reshape([
6989
+ ...lb.map((a) => lhs.shape[a]),
6990
+ ...lf.map((a) => lhs.shape[a]),
6991
+ ...require_backend.rep(rf.length, 1),
6992
+ require_backend.prod(dotShapeX)
6993
+ ]), rhs2.reshape([
6994
+ ...rb.map((a) => rhs.shape[a]),
6995
+ ...require_backend.rep(lf.length, 1),
6996
+ ...rf.map((a) => rhs.shape[a]),
6997
+ require_backend.prod(dotShapeY)
6998
+ ]));
6999
+ }
7000
+ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
7001
+ const padType = padding.toUpperCase();
7002
+ switch (padType) {
7003
+ case "VALID": return require_backend.rep(inShape.length, [0, 0]);
7004
+ case "SAME":
7005
+ case "SAME_LOWER": {
7006
+ const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
7007
+ const padSizes = require_backend.zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
7008
+ if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
7009
+ else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
7010
+ }
7011
+ default: throw new Error(`Unknown padding type: ${padType}`);
7012
+ }
5907
7013
  }
5908
7014
  /**
5909
- * Compute the standard deviation of an array.
7015
+ * General n-dimensional convolution operator, with optional dilation.
5910
7016
  *
5911
- * The standard deviation is computed for the flattened array by default,
5912
- * otherwise over the specified axis.
7017
+ * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
7018
+ * function in JAX, which wraps XLA's general convolution operator.
5913
7019
  *
5914
- * If `correction` is provided, the divisor in calculation is `N - correction`,
5915
- * where `N` represents the number of elements (e.g., for Bessel's correction).
7020
+ * Grouped convolutions are supported through the `featureGroupCount` option.
5916
7021
  */
5917
- function std(x, axis = null, opts) {
5918
- return sqrt(var_(x, axis, opts));
7022
+ function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
7023
+ if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
7024
+ if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
7025
+ if (typeof padding === "string") {
7026
+ if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
7027
+ padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1), padding);
7028
+ }
7029
+ if (featureGroupCount !== 1) {
7030
+ const G = featureGroupCount;
7031
+ const [N, C_in, ...xs] = lhs.shape;
7032
+ const [C_out, C_in_per_group, ...ks] = rhs.shape;
7033
+ if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
7034
+ if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
7035
+ if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
7036
+ const lhsGrouped = moveaxis(lhs.reshape([
7037
+ N,
7038
+ G,
7039
+ C_in / G,
7040
+ ...xs
7041
+ ]), 1, 0);
7042
+ const rhsGrouped = rhs.reshape([
7043
+ G,
7044
+ C_out / G,
7045
+ C_in_per_group,
7046
+ ...ks
7047
+ ]);
7048
+ const result = conv$1(lhsGrouped, rhsGrouped, {
7049
+ vmapDims: 1,
7050
+ strides: windowStrides,
7051
+ padding,
7052
+ lhsDilation,
7053
+ rhsDilation
7054
+ });
7055
+ const ys = result.shape.slice(3);
7056
+ return moveaxis(result, 0, 1).reshape([
7057
+ N,
7058
+ C_out,
7059
+ ...ys
7060
+ ]);
7061
+ }
7062
+ return conv$1(lhs, rhs, {
7063
+ strides: windowStrides,
7064
+ padding,
7065
+ lhsDilation,
7066
+ rhsDilation
7067
+ });
5919
7068
  }
5920
- /** Test element-wise for positive or negative infinity, return bool array. */
5921
- function isinf(x) {
5922
- x = fudgeArray(x);
5923
- return require_backend.isFloatDtype(x.dtype) ? x.ref.equal(Infinity).add(x.equal(-Infinity)) : fullLike$1(x, false);
7069
+ /** Convenience wrapper around `convGeneralDilated`. */
7070
+ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
7071
+ return convGeneralDilated(lhs, rhs, windowStrides, padding, {
7072
+ lhsDilation,
7073
+ rhsDilation
7074
+ });
5924
7075
  }
5925
- /** Test element-wise for NaN (Not a Number). */
5926
- function isnan(x) {
5927
- x = fudgeArray(x);
5928
- return require_backend.isFloatDtype(x.dtype) ? x.ref.notEqual(x) : fullLike$1(x, false);
7076
+ /** Convenience wrapper around `convGeneralDilated`. */
7077
+ function conv(lhs, rhs, windowStrides, padding) {
7078
+ return convGeneralDilated(lhs, rhs, windowStrides, padding);
5929
7079
  }
5930
- /** Test element-wise for negative infinity, return bool array. */
5931
- function isneginf(x) {
5932
- x = fudgeArray(x);
5933
- return require_backend.isFloatDtype(x.dtype) ? x.equal(-Infinity) : fullLike$1(x, false);
7080
+ /** Reduce a computation over padded windows. */
7081
+ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
7082
+ if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
7083
+ if (!windowStrides) windowStrides = require_backend.rep(windowDimensions.length, 1);
7084
+ for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
7085
+ return computation(bind1(Primitive.Pool, [operand], {
7086
+ window: windowDimensions,
7087
+ strides: windowStrides
7088
+ }));
5934
7089
  }
5935
- /** Test element-wise for positive infinity, return bool array. */
5936
- function isposinf(x) {
5937
- x = fudgeArray(x);
5938
- return require_backend.isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
7090
+ /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
7091
+ function erf(x) {
7092
+ return erf$1(x);
5939
7093
  }
5940
7094
  /**
5941
- * @function
5942
- * Test element-wise for finite values (not infinity or NaN).
7095
+ * The complementary error function: `erfc(x) = 1 - erf(x)`.
7096
+ *
7097
+ * This function is more accurate than `1 - erf(x)` for large values of `x`,
7098
+ * where `erf(x)` is very close to 1.
5943
7099
  */
5944
- const isfinite = jit$1(function isfinite$1(x) {
5945
- if (!require_backend.isFloatDtype(x.dtype)) return fullLike$1(x, true);
5946
- return isnan(x.ref).add(isinf(x)).notEqual(true);
5947
- });
7100
+ function erfc(x) {
7101
+ return erfc$1(x);
7102
+ }
7103
+ /**
7104
+ * Stops gradient computation.
7105
+ *
7106
+ * Behaves as the identity function but prevents the flow of gradients during
7107
+ * forward or reverse-mode automatic differentiation.
7108
+ */
7109
+ function stopGradient$1(x) {
7110
+ return stopGradient(x);
7111
+ }
5948
7112
 
5949
7113
  //#endregion
5950
7114
  //#region src/library/nn.ts
@@ -5954,6 +7118,10 @@ __export(nn_exports, {
5954
7118
  elu: () => elu,
5955
7119
  gelu: () => gelu,
5956
7120
  glu: () => glu,
7121
+ hardSigmoid: () => hardSigmoid,
7122
+ hardSilu: () => hardSilu,
7123
+ hardSwish: () => hardSilu,
7124
+ hardTanh: () => hardTanh,
5957
7125
  identity: () => identity,
5958
7126
  leakyRelu: () => leakyRelu,
5959
7127
  logSigmoid: () => logSigmoid,
@@ -5964,14 +7132,17 @@ __export(nn_exports, {
5964
7132
  oneHot: () => oneHot,
5965
7133
  relu: () => relu,
5966
7134
  relu6: () => relu6,
7135
+ selu: () => selu,
5967
7136
  sigmoid: () => sigmoid,
5968
7137
  silu: () => silu,
5969
7138
  softSign: () => softSign,
5970
7139
  softmax: () => softmax,
5971
7140
  softplus: () => softplus,
7141
+ sparsePlus: () => sparsePlus,
7142
+ sparseSigmoid: () => sparseSigmoid,
5972
7143
  squareplus: () => squareplus,
5973
7144
  standardize: () => standardize,
5974
- swish: () => swish
7145
+ swish: () => silu
5975
7146
  });
5976
7147
  /**
5977
7148
  * Rectified Linear Unit (ReLU) activation function:
@@ -6006,6 +7177,28 @@ function softplus(x) {
6006
7177
  return log(exp(x).add(1));
6007
7178
  }
6008
7179
  /**
7180
+ * @function
7181
+ * Sparse plus function:
7182
+ *
7183
+ * - When `x <= -1`: `0`
7184
+ * - When `-1 < x < 1`: `(x+1)**2 / 4`
7185
+ * - When `x >= 1`: `x`
7186
+ */
7187
+ const sparsePlus = jit$1((x) => {
7188
+ return where(x.ref.lessEqual(-1), 0, where(x.ref.less(1), square(x.ref.add(1)).mul(.25), x));
7189
+ });
7190
+ /**
7191
+ * @function
7192
+ * Sparse sigmoid activation function.
7193
+ *
7194
+ * - When `x <= -1`: `0`
7195
+ * - When `-1 < x < 1`: `(x + 1) / 2`
7196
+ * - When `x >= 1`: `1`
7197
+ */
7198
+ const sparseSigmoid = jit$1((x) => {
7199
+ return clip(x.add(1).mul(.5), 0, 1);
7200
+ });
7201
+ /**
6009
7202
  * Soft-sign activation function, computed element-wise:
6010
7203
  * `softsign(x) = x / (|x| + 1)`.
6011
7204
  */
@@ -6027,17 +7220,6 @@ const silu = jit$1(function silu$1(x) {
6027
7220
  return x.ref.mul(sigmoid(x));
6028
7221
  });
6029
7222
  /**
6030
- * @function
6031
- * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
6032
- * Swish, computed element-wise:
6033
- * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
6034
- *
6035
- * `swish()` and `silu()` are both aliases for the same function.
6036
- *
6037
- * Reference: https://en.wikipedia.org/wiki/Swish_function
6038
- */
6039
- const swish = silu;
6040
- /**
6041
7223
  * Log-sigmoid activation function, computed element-wise:
6042
7224
  * `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
6043
7225
  */
@@ -6054,6 +7236,19 @@ function leakyRelu(x, negativeSlope = .01) {
6054
7236
  x = fudgeArray(x);
6055
7237
  return where(less(x.ref, 0), x.ref.mul(negativeSlope), x);
6056
7238
  }
7239
+ /** Hard sigmoid activation function: `relu6(x+3)/6`. */
7240
+ function hardSigmoid(x) {
7241
+ return relu6(add(x, 3)).mul(1 / 6);
7242
+ }
7243
+ /** Hard SiLU (swish) activation function: `x * hardSigmoid(x)`. */
7244
+ function hardSilu(x) {
7245
+ x = fudgeArray(x);
7246
+ return x.ref.mul(hardSigmoid(x));
7247
+ }
7248
+ /** Hard tanh activation function: `clip(x, -1, 1)`. */
7249
+ function hardTanh(x) {
7250
+ return clip(x, -1, 1);
7251
+ }
6057
7252
  /**
6058
7253
  * Exponential linear unit activation function.
6059
7254
  *
@@ -6076,6 +7271,20 @@ function celu(x, alpha = 1) {
6076
7271
  }
6077
7272
  /**
6078
7273
  * @function
7274
+ * Scaled exponential linear unit activation.
7275
+ *
7276
+ * Computes the element-wise function:
7277
+ * `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
7278
+ *
7279
+ * Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
7280
+ */
7281
+ const selu = jit$1(function selu$1(x) {
7282
+ const alpha = 1.6732632423543772;
7283
+ const lambda = 1.0507009873554805;
7284
+ return where(x.ref.less(0), expm1(x.ref).mul(alpha), x).mul(lambda);
7285
+ });
7286
+ /**
7287
+ * @function
6079
7288
  * Gaussian error linear unit (GELU) activation function.
6080
7289
  *
6081
7290
  * This is computed element-wise. There are two variants depending on whether
@@ -6229,35 +7438,46 @@ var random_exports = {};
6229
7438
  __export(random_exports, {
6230
7439
  bernoulli: () => bernoulli,
6231
7440
  bits: () => bits,
7441
+ cauchy: () => cauchy,
6232
7442
  exponential: () => exponential,
7443
+ gumbel: () => gumbel,
6233
7444
  key: () => key,
7445
+ laplace: () => laplace,
7446
+ multivariateNormal: () => multivariateNormal,
6234
7447
  normal: () => normal,
6235
7448
  split: () => split,
6236
7449
  uniform: () => uniform
6237
7450
  });
6238
- function validateKeyShape(key$1) {
7451
+ function validateKeyShape(key$1, scalar = false) {
6239
7452
  if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
6240
7453
  if (key$1.shape[key$1.shape.length - 1] !== 2) throw new Error(`Invalid key shape: ${key$1.shape}. Expected last dimension to be 2.`);
7454
+ if (scalar && key$1.shape.length > 1) throw new Error(`Expected a single PRNG key, but got a batch of keys with shape ${JSON.stringify(key$1.shape)} - use jax.vmap for batching.`);
6241
7455
  return key$1.shape.slice(0, -1);
6242
7456
  }
7457
+ function getK01(key$1) {
7458
+ const keyShape = validateKeyShape(key$1, true);
7459
+ let [k0, k1] = split$2(key$1, -1, [1, 1]);
7460
+ k0 = k0.reshape(keyShape);
7461
+ k1 = k1.reshape(keyShape);
7462
+ return [k0, k1];
7463
+ }
6243
7464
  /** Create a pseudo-random number generator (PRNG) key from a 32-bit integer seed. */
6244
7465
  function key(seed) {
6245
- seed = seed >>> 0;
6246
- return array([0, seed], { dtype: require_backend.DType.Uint32 });
7466
+ seed = array(seed, { dtype: require_backend.DType.Uint32 });
7467
+ if (seed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seed.shape} - use jax.vmap for batching.`);
7468
+ return stack([0, seed]);
6247
7469
  }
6248
7470
  /** Splits a PRNG key into `num` new keys by adding a leading axis. */
6249
7471
  function split(key$1, num = 2) {
6250
7472
  const shape$1 = typeof num === "number" ? [num] : num;
6251
7473
  for (const len of shape$1) if (len <= 0 || !Number.isInteger(len)) throw new Error(`Invalid split length: ${len}. Must be a positive integer.`);
6252
- const keyShape = validateKeyShape(key$1);
6253
- const k0 = key$1.ref.slice(...keyShape.map(() => null), 0);
6254
- const k1 = key$1.slice(...keyShape.map(() => null), 1);
7474
+ const [k0, k1] = getK01(key$1);
6255
7475
  return stack([randomBits(k0.ref, k1.ref, shape$1, 0), randomBits(k0, k1, shape$1, 1)], -1);
6256
7476
  }
6257
7477
  /** Sample uniform bits in the form of unsigned integers. */
6258
7478
  function bits(key$1, shape$1 = []) {
6259
- const keyShape = validateKeyShape(key$1);
6260
- return randomBits(key$1.ref.slice(...keyShape.map(() => null), 0), key$1.slice(...keyShape.map(() => null), 1), shape$1);
7479
+ const [k0, k1] = getK01(key$1);
7480
+ return randomBits(k0, k1, shape$1);
6261
7481
  }
6262
7482
  /**
6263
7483
  * @function
@@ -6289,6 +7509,16 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
6289
7509
  }
6290
7510
  /**
6291
7511
  * @function
7512
+ * Sample from a Cauchy distribution with location 0 and scale 1.
7513
+ *
7514
+ * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
7515
+ */
7516
+ const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
7517
+ const u = uniform(key$1, shape$1);
7518
+ return tan(u.sub(.5).mul(Math.PI));
7519
+ }, { staticArgnums: [1] });
7520
+ /**
7521
+ * @function
6292
7522
  * Sample exponential random values according to `p(x) = exp(-x)`.
6293
7523
  */
6294
7524
  const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
@@ -6297,6 +7527,56 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
6297
7527
  }, { staticArgnums: [1] });
6298
7528
  /**
6299
7529
  * @function
7530
+ * Sample from a Gumbel distribution with location 0 and scale 1.
7531
+ *
7532
+ * Uses inverse transform sampling: `x = -log(-log(1 - u))` where u ~ Uniform(0, 1).
7533
+ */
7534
+ const gumbel = jit$1(function gumbel$1(key$1, shape$1 = []) {
7535
+ const u = uniform(key$1, shape$1);
7536
+ return negative(log(negative(log1p(negative(u)))));
7537
+ }, { staticArgnums: [1] });
7538
+ /**
7539
+ * @function
7540
+ * Sample from a Laplace distribution with location 0 and scale 1.
7541
+ *
7542
+ * Uses inverse transform sampling: the CDF is `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`.
7543
+ * Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
7544
+ */
7545
+ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
7546
+ const u = uniform(key$1, shape$1);
7547
+ const centered = u.sub(.5);
7548
+ const s = sign(centered.ref);
7549
+ const absVal = absolute(centered);
7550
+ return s.mul(log1p(absVal.mul(-2)).mul(-1));
7551
+ }, { staticArgnums: [1] });
7552
+ /**
7553
+ * @function
7554
+ * Sample multivariate normal random values with given mean and covariance.
7555
+ *
7556
+ * The values are returned with the given shape, along with the final dimension
7557
+ * used to represent the n-dimensional multivariate normal factors.
7558
+ *
7559
+ * This uses Cholesky decomposition on the covariance matrix.
7560
+ *
7561
+ * - `key` - PRNG key
7562
+ * - `mean` - Mean vector of shape `[..., n]`
7563
+ * - `cov` - Covariance of shape `[..., n, n]`, must be positive-definite
7564
+ * - `shape` - Result batch shape, must be broadcastable with
7565
+ * `mean.shape[:-1]` and `cov.shape[:-2]`
7566
+ * @returns Random samples of shape `[...shape, n]`
7567
+ */
7568
+ const multivariateNormal = jit$1(function multivariateNormal$1(key$1, mean$1, cov$1, shape$1 = []) {
7569
+ mean$1 = fudgeArray(mean$1);
7570
+ cov$1 = fudgeArray(cov$1);
7571
+ const n = mean$1.shape[mean$1.ndim - 1];
7572
+ if (cov$1.shape[cov$1.ndim - 1] !== n || cov$1.shape[cov$1.ndim - 2] !== n) throw new Error(`Invalid covariance shape: ${cov$1.shape}. Expected last two dimensions to be [${n}, ${n}].`);
7573
+ const outputShape = broadcastShapes(shape$1, mean$1.shape.slice(0, -1), cov$1.shape.slice(0, -2)).concat(n);
7574
+ const L = cholesky(cov$1);
7575
+ const z = normal(key$1, outputShape);
7576
+ return einsum("...ij,...j->...i", L, z).add(mean$1);
7577
+ }, { staticArgnums: [3] });
7578
+ /**
7579
+ * @function
6300
7580
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
6301
7581
  *
6302
7582
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
@@ -6405,11 +7685,6 @@ const valueAndGrad = valueAndGrad$1;
6405
7685
  */
6406
7686
  const jacrev = jacrev$1;
6407
7687
  /**
6408
- * @function
6409
- * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
6410
- */
6411
- const jacobian = jacrev;
6412
- /**
6413
7688
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
6414
7689
  *
6415
7690
  * This can be used to wait for the results of an intermediate computation to
@@ -6445,6 +7720,7 @@ async function devicePut(x, device) {
6445
7720
 
6446
7721
  //#endregion
6447
7722
  exports.Array = Array$1;
7723
+ exports.ClosedJaxpr = ClosedJaxpr;
6448
7724
  exports.DType = require_backend.DType;
6449
7725
  exports.Jaxpr = Jaxpr;
6450
7726
  exports.blockUntilReady = blockUntilReady;
@@ -6454,7 +7730,7 @@ exports.devices = require_backend.devices;
6454
7730
  exports.grad = grad;
6455
7731
  exports.init = require_backend.init;
6456
7732
  exports.jacfwd = jacfwd;
6457
- exports.jacobian = jacobian;
7733
+ exports.jacobian = jacrev;
6458
7734
  exports.jacrev = jacrev;
6459
7735
  exports.jit = jit;
6460
7736
  exports.jvp = jvp;
@@ -6499,5 +7775,4 @@ Object.defineProperty(exports, 'tree', {
6499
7775
  });
6500
7776
  exports.valueAndGrad = valueAndGrad;
6501
7777
  exports.vjp = vjp;
6502
- exports.vmap = vmap;
6503
- //# sourceMappingURL=index.cjs.map
7778
+ exports.vmap = vmap;