@jax-js/jax 0.1.3 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-BY8wlLEl.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DaqL-MNz.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -331,6 +331,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
331
331
  Primitive$1["Mul"] = "mul";
332
332
  Primitive$1["Idiv"] = "idiv";
333
333
  Primitive$1["Mod"] = "mod";
334
+ Primitive$1["Min"] = "min";
335
+ Primitive$1["Max"] = "max";
334
336
  Primitive$1["Neg"] = "neg";
335
337
  Primitive$1["Reciprocal"] = "reciprocal";
336
338
  Primitive$1["Floor"] = "floor";
@@ -338,7 +340,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
338
340
  Primitive$1["StopGradient"] = "stop_gradient";
339
341
  Primitive$1["Cast"] = "cast";
340
342
  Primitive$1["Bitcast"] = "bitcast";
341
- Primitive$1["RandomBits"] = "random_bits";
342
343
  Primitive$1["Sin"] = "sin";
343
344
  Primitive$1["Cos"] = "cos";
344
345
  Primitive$1["Asin"] = "asin";
@@ -348,8 +349,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
348
349
  Primitive$1["Erf"] = "erf";
349
350
  Primitive$1["Erfc"] = "erfc";
350
351
  Primitive$1["Sqrt"] = "sqrt";
351
- Primitive$1["Min"] = "min";
352
- Primitive$1["Max"] = "max";
353
352
  Primitive$1["Reduce"] = "reduce";
354
353
  Primitive$1["Dot"] = "dot";
355
354
  Primitive$1["Conv"] = "conv";
@@ -357,14 +356,22 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
357
356
  Primitive$1["PoolTranspose"] = "pool_transpose";
358
357
  Primitive$1["Compare"] = "compare";
359
358
  Primitive$1["Where"] = "where";
359
+ Primitive$1["Concatenate"] = "concatenate";
360
+ Primitive$1["Split"] = "split";
361
+ Primitive$1["RandomBits"] = "random_bits";
362
+ Primitive$1["Gather"] = "gather";
360
363
  Primitive$1["Transpose"] = "transpose";
361
364
  Primitive$1["Broadcast"] = "broadcast";
362
365
  Primitive$1["Reshape"] = "reshape";
363
366
  Primitive$1["Flip"] = "flip";
364
367
  Primitive$1["Shrink"] = "shrink";
365
368
  Primitive$1["Pad"] = "pad";
366
- Primitive$1["Gather"] = "gather";
367
- Primitive$1["JitCall"] = "jit_call";
369
+ Primitive$1["Sort"] = "sort";
370
+ Primitive$1["Argsort"] = "argsort";
371
+ Primitive$1["TriangularSolve"] = "triangular_solve";
372
+ Primitive$1["Cholesky"] = "cholesky";
373
+ Primitive$1["LU"] = "lu";
374
+ Primitive$1["Jit"] = "jit";
368
375
  return Primitive$1;
369
376
  }({});
370
377
  let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
@@ -386,6 +393,12 @@ function idiv(x, y) {
386
393
  function mod(x, y) {
387
394
  return bind1(Primitive.Mod, [x, y]);
388
395
  }
396
/** Applies the `Min` primitive to the operand pair `(x, y)` through `bind1`. */
function min$1(x, y) {
  const operands = [x, y];
  return bind1(Primitive.Min, operands);
}
399
/** Applies the `Max` primitive to the operand pair `(x, y)` through `bind1`. */
function max$1(x, y) {
  const operands = [x, y];
  return bind1(Primitive.Max, operands);
}
389
402
  function neg(x) {
390
403
  return bind1(Primitive.Neg, [x]);
391
404
  }
@@ -407,12 +420,6 @@ function cast(x, dtype) {
407
420
  function bitcast(x, dtype) {
408
421
  return bind1(Primitive.Bitcast, [x], { dtype });
409
422
  }
410
- function randomBits(k0, k1, shape$1, mode = "xor") {
411
- return bind1(Primitive.RandomBits, [k0, k1], {
412
- shape: shape$1,
413
- mode
414
- });
415
- }
416
423
  function sin$1(x) {
417
424
  return bind1(Primitive.Sin, [x]);
418
425
  }
@@ -440,12 +447,6 @@ function erfc$1(x) {
440
447
  function sqrt$1(x) {
441
448
  return bind1(Primitive.Sqrt, [x]);
442
449
  }
443
- function min$1(x, y) {
444
- return bind1(Primitive.Min, [x, y]);
445
- }
446
- function max$1(x, y) {
447
- return bind1(Primitive.Max, [x, y]);
448
- }
449
450
  function reduce(x, op, axis = null, opts) {
450
451
  if (!AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
451
452
  axis = normalizeAxis(axis, ndim$1(x));
@@ -501,6 +502,41 @@ function where$1(cond, x, y) {
501
502
  y
502
503
  ]);
503
504
  }
505
/**
 * Binds the `Concatenate` primitive over the inputs `xs` along `axis`.
 *
 * Every input must have the same rank as the first input and the same extent
 * on every axis other than `axis`.
 *
 * @param xs Non-empty list of inputs.
 * @param axis Axis to join along; normalized with `checkAxis` against the rank of the first input.
 * @throws {Error} If `xs` is empty or the input shapes are incompatible.
 */
function concatenate$1(xs, axis) {
  if (xs.length === 0) {
    throw new Error("concatenate requires at least one input");
  }
  const avals = xs.map((x) => ShapedArray.fromAval(getAval(x)));
  const [first] = avals;
  axis = checkAxis(axis, first.ndim);
  const isCompatible = (aval) =>
    aval.ndim === first.ndim &&
    aval.shape.every((dim, i) => i === axis || dim === first.shape[i]);
  for (const aval of avals) {
    if (!isCompatible(aval)) {
      throw new Error(`Concatenate: inputs ${first} and ${aval} must match shapes except on axis ${axis}`);
    }
  }
  return bind1(Primitive.Concatenate, xs, { axis });
}
512
/**
 * Binds the `Split` primitive, cutting `x` along `axis` into pieces of the
 * given `sizes`.
 *
 * @param x Input to split.
 * @param axis Axis to split along; normalized with `checkAxis`.
 * @param sizes Nonnegative integer piece sizes that must sum to the extent of `axis`.
 * @returns The array of outputs produced by `bind` (not unwrapped by `bind1`).
 * @throws {Error} If a size is negative or non-integer, or the sizes do not sum to the axis extent.
 */
function split$2(x, axis, sizes) {
  axis = checkAxis(axis, ndim$1(x));
  const isValidSize = (s) => Number.isInteger(s) && s >= 0;
  if (!sizes.every(isValidSize)) {
    throw new Error(`split: sizes must be nonnegative integers, got ${JSON.stringify(sizes)}`);
  }
  const totalSize = sizes.reduce((a, b) => a + b, 0);
  const axisSize = getShape(x)[axis];
  if (totalSize !== axisSize) {
    throw new Error(`split: sizes must sum to the size of the axis ${axis}, got ${totalSize}`);
  }
  const params = { axis, sizes };
  return bind(Primitive.Split, [x], params);
}
522
/**
 * Binds the `RandomBits` primitive on the two key parts `k0` and `k1`.
 *
 * @param k0 First key part; must be uint32.
 * @param k1 Second key part; must be uint32 with the same shape as `k0`.
 * @param shape$1 Forwarded to the primitive as its `shape` parameter.
 * @param mode Forwarded to the primitive as its `mode` parameter (default `"xor"`).
 * @throws {Error} If the key parts differ in shape or are not both uint32.
 */
function randomBits(k0, k1, shape$1, mode = "xor") {
  const keysOk =
    deepEqual(k0.shape, k1.shape) &&
    k0.dtype === DType.Uint32 &&
    k1.dtype === DType.Uint32;
  if (!keysOk) {
    throw new Error(`randomBits: key parts must be uint32 with the same shape, got ${ShapedArray.fromAval(k0.aval)} and ${ShapedArray.fromAval(k1.aval)}`);
  }
  const params = { shape: shape$1, mode };
  return bind1(Primitive.RandomBits, [k0, k1], params);
}
529
/**
 * Binds the `Gather` primitive on `x` with one index operand per gathered axis.
 *
 * @param x Input to gather from.
 * @param indices Non-empty list of index operands.
 * @param axis Array of axes, one per entry of `indices`; each is normalized with `checkAxis` and they must be distinct.
 * @param outDim Output position, normalized against `ndim(x) - axis.length + 1`.
 * @throws {Error} If `indices` is empty, `axis` is not an array of matching length, or `axis` has duplicates.
 */
function gather(x, indices, axis, outDim) {
  if (indices.length === 0) {
    throw new Error("gather() requires at least one index");
  }
  const axisCountOk = Array.isArray(axis) && axis.length === indices.length;
  if (!axisCountOk) {
    throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
  }
  axis = axis.map((a) => checkAxis(a, ndim$1(x)));
  const uniqueAxes = new Set(axis);
  if (uniqueAxes.size !== axis.length) {
    throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
  }
  outDim = checkAxis(outDim, ndim$1(x) - axis.length + 1);
  const operands = [x, ...indices];
  return bind1(Primitive.Gather, operands, { axis, outDim });
}
504
540
  function transpose$1(x, perm) {
505
541
  perm = perm ? perm.map((a) => checkAxis(a, ndim$1(x))) : range(ndim$1(x)).reverse();
506
542
  if (!isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
@@ -550,16 +586,39 @@ function pad$1(x, width) {
550
586
  } else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
551
587
  return bind1(Primitive.Pad, [x], { width });
552
588
  }
553
- function gather(x, indices, axis, outDim) {
554
- if (indices.length === 0) throw new Error("gather() requires at least one index");
555
- if (!Array.isArray(axis) || axis.length !== indices.length) throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
556
- axis = axis.map((a) => checkAxis(a, ndim$1(x)));
557
- if (new Set(axis).size !== axis.length) throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
558
- outDim = checkAxis(outDim, ndim$1(x) - axis.length + 1);
559
- return bind1(Primitive.Gather, [x, ...indices], {
560
- axis,
561
- outDim
562
- });
589
/**
 * Binds the `TriangularSolve` primitive on `a` and `b`.
 *
 * When `lower` is set, `a` is flipped along its last two axes and `b` along
 * its last axis before binding, and the result is flipped back along its last
 * axis, so the primitive itself is only ever bound in the non-lower form.
 *
 * @param a Operand whose last two dimensions must be equal.
 * @param b Operand whose last dimension must equal that size.
 * @param options `lower` (default false) and `unitDiagonal` (default false, forwarded to the primitive).
 * @throws {Error} If either operand has fewer than 2 dimensions or the shapes are incompatible.
 */
function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
  const aShape = getShape(a);
  const bShape = getShape(b);
  if (aShape.length < 2 || bShape.length < 2) {
    throw new Error(`triangular_solve: must be >=2D, got a=${aShape}, b=${bShape}`);
  }
  const n = aShape[aShape.length - 2];
  const shapesAgree = n === aShape[aShape.length - 1] && n === bShape[bShape.length - 1];
  if (!shapesAgree) {
    throw new Error(`triangular_solve: incompatible shapes a=${aShape}, b=${bShape}`);
  }
  const params = { unitDiagonal };
  if (!lower) {
    return bind1(Primitive.TriangularSolve, [a, b], params);
  }
  const flippedA = flip$1(a, [-2, -1]);
  const flippedB = flip$1(b, [-1]);
  const flippedX = bind1(Primitive.TriangularSolve, [flippedA, flippedB], params);
  return flip$1(flippedX, [-1]);
}
603
/**
 * Binds the `Cholesky` primitive on `x`.
 *
 * @throws {Error} If `x` has fewer than 2 dimensions or its last two dimensions differ.
 */
function cholesky$2(x) {
  const aval = ShapedArray.fromAval(getAval(x));
  const { ndim, shape } = aval;
  const isSquareBatch = ndim >= 2 && shape[ndim - 1] === shape[ndim - 2];
  if (!isSquareBatch) {
    throw new Error(`cholesky: expected batch of square matrices, got ${aval}`);
  }
  return bind1(Primitive.Cholesky, [x]);
}
608
/**
 * Binds the `LU` primitive on `x` and returns every output of `bind`
 * (not unwrapped by `bind1`).
 *
 * @throws {Error} If `x` has fewer than 2 dimensions.
 */
function lu$1(x) {
  const aval = ShapedArray.fromAval(getAval(x));
  if (aval.ndim < 2) {
    throw new Error(`lu: expected batch of matrices, got ${aval}`);
  }
  const operands = [x];
  return bind(Primitive.LU, operands);
}
613
/**
 * Binds the `Sort` primitive on `x`.
 *
 * @throws {Error} If `x` is zero-dimensional.
 */
function sort$1(x) {
  if (ndim$1(x) === 0) {
    throw new Error("sort: requires at least 1D input");
  }
  return bind1(Primitive.Sort, [x]);
}
618
/**
 * Binds the `Argsort` primitive on `x` and returns every output of `bind`
 * (not unwrapped by `bind1`).
 *
 * @throws {Error} If `x` is zero-dimensional.
 */
function argsort$1(x) {
  if (ndim$1(x) === 0) {
    throw new Error("argsort: requires at least 1D input");
  }
  return bind(Primitive.Argsort, [x]);
}
564
623
  function bind1(prim, args, params = {}) {
565
624
  const [results] = bind(prim, args, params);
@@ -659,6 +718,9 @@ var Tracer = class Tracer {
659
718
  mul(other) {
660
719
  return mul(this, other);
661
720
  }
721
+ mod(other) {
722
+ return mod(this, other);
723
+ }
662
724
  greater(other) {
663
725
  return greater$1(this, other);
664
726
  }
@@ -722,7 +784,7 @@ var Tracer = class Tracer {
722
784
  if (isFloatDtype(this.dtype)) return this.mul(reciprocal$1(other));
723
785
  return idiv(this, other);
724
786
  }
725
- /** Return specified diagonals. See `numpy.diagonal` for full docs. */
787
+ /** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
726
788
  diagonal(offset = 0, axis1 = 0, axis2 = 1) {
727
789
  if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
728
790
  if (offset < 0) return this.diagonal(-offset, axis2, axis1);
@@ -771,8 +833,42 @@ var Tracer = class Tracer {
771
833
  */
772
834
  *[Symbol.iterator]() {
773
835
  if (this.ndim === 0) throw new Error("Cannot iterate over a scalar array");
774
- for (let i = 0; i < this.shape[0]; i++) yield this.ref.slice(i);
775
- this.dispose();
836
+ let residual = this;
837
+ const subarrayShape = this.shape.slice(1);
838
+ for (let i = 0; i < this.shape[0]; i++) {
839
+ const lr = split$2(residual, 0, [1, residual.shape[0] - 1]);
840
+ yield lr[0].reshape(subarrayShape);
841
+ residual = lr[1];
842
+ }
843
+ residual.dispose();
844
+ }
845
+ /**
846
+ * Return a sorted copy of an array in ascending order.
847
+ *
848
+ * See `jax.numpy.sort` for full docs.
849
+ */
850
+ sort(axis = -1) {
851
+ axis = checkAxis(axis, this.ndim);
852
+ if (this.shape[axis] <= 1) return this;
853
+ if (axis === this.ndim - 1) return sort$1(this);
854
+ const perm = range(this.ndim);
855
+ perm.splice(axis, 1);
856
+ perm.push(axis);
857
+ return sort$1(this.transpose(perm)).transpose(invertPermutation(perm));
858
+ }
859
+ /**
860
+ * Return the indices that would sort an array. This may not be a stable
861
+ * sorting algorithm; it need not preserve order of indices in ties.
862
+ *
863
+ * See `jax.numpy.argsort` for full docs.
864
+ */
865
+ argsort(axis = -1) {
866
+ axis = checkAxis(axis, this.ndim);
867
+ if (axis === this.ndim - 1) return argsort$1(this)[1];
868
+ const perm = range(this.ndim);
869
+ perm.splice(axis, 1);
870
+ perm.push(axis);
871
+ return argsort$1(this.transpose(perm))[1].transpose(invertPermutation(perm));
776
872
  }
777
873
  /**
778
874
  * Slice an array along one or more axes.
@@ -891,6 +987,12 @@ var ShapedArray = class ShapedArray {
891
987
  get ndim() {
892
988
  return this.shape.length;
893
989
  }
990
+ get size() {
991
+ return prod(this.shape);
992
+ }
993
+ scalar() {
994
+ return new ShapedArray([], this.dtype, this.weakType);
995
+ }
894
996
  toString() {
895
997
  return `${this.dtype}[${this.shape.join(",")}]`;
896
998
  }
@@ -1186,13 +1288,13 @@ var Jaxpr = class Jaxpr {
1186
1288
  }
1187
1289
  return new Jaxpr(this.inBinders, liveEqns.reverse(), outs);
1188
1290
  }
1189
- /** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
1291
+ /** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
1190
1292
  flatten() {
1191
- if (!this.eqns.some((eqn) => eqn.primitive === Primitive.JitCall)) return this;
1293
+ if (!this.eqns.some((eqn) => eqn.primitive === Primitive.Jit)) return this;
1192
1294
  const newEqns = [];
1193
1295
  const varMap = /* @__PURE__ */ new Map();
1194
1296
  const varMapF = (x) => x instanceof Var ? varMap.get(x) ?? x : x;
1195
- for (const eqn of this.eqns) if (eqn.primitive === Primitive.JitCall) {
1297
+ for (const eqn of this.eqns) if (eqn.primitive === Primitive.Jit) {
1196
1298
  const jaxpr = eqn.params.jaxpr.flatten();
1197
1299
  const translation = /* @__PURE__ */ new Map();
1198
1300
  const translationF = (x) => x instanceof Var ? translation.get(x) : x;
@@ -1293,19 +1395,48 @@ function evalJaxpr(jaxpr, args) {
1293
1395
  function jaxprAsFun(jaxpr) {
1294
1396
  return (...args) => evalJaxpr(jaxpr, args);
1295
1397
  }
1398
/** Jaxpr with a collection of associated, traced constants. */
var ClosedJaxpr = class ClosedJaxpr {
  /**
   * @param jaxpr The wrapped Jaxpr.
   * @param consts Iterable of constants that belong to `jaxpr`; each must expose `dispose()`.
   */
  constructor(jaxpr, consts) {
    this.jaxpr = jaxpr;
    this.consts = consts;
  }

  /** String representation of this Jaxpr (that of the wrapped Jaxpr). */
  toString() {
    const { jaxpr } = this;
    return jaxpr.toString();
  }

  /** Returns a new ClosedJaxpr holding `f(jaxpr)` and the same `consts` object. */
  mapJaxpr(f) {
    const { jaxpr, consts } = this;
    return new ClosedJaxpr(f(jaxpr), consts);
  }

  /** Dispose of the constants in this Jaxpr, in iteration order. */
  dispose() {
    for (const constant of this.consts) {
      constant.dispose();
    }
  }
};
1296
1417
  /** Tracer that records its operations to dynamically construct a Jaxpr. */
1297
1418
  var JaxprTracer = class extends Tracer {
1419
+ #rc;
1298
1420
  constructor(trace$1, aval) {
1299
1421
  super(trace$1);
1300
1422
  this.aval = aval;
1423
+ this.#rc = 1;
1301
1424
  }
1302
1425
  toString() {
1303
1426
  return `JaxprTracer(${this.aval.toString()})`;
1304
1427
  }
1305
1428
  get ref() {
1429
+ if (this.#rc <= 0) throw new UseAfterFreeError(this);
1430
+ this.#rc++;
1306
1431
  return this;
1307
1432
  }
1308
- dispose() {}
1433
+ dispose() {
1434
+ if (this.#rc <= 0) throw new UseAfterFreeError(this);
1435
+ this.#rc--;
1436
+ }
1437
+ trackLiftedConstant() {
1438
+ this.#rc++;
1439
+ }
1309
1440
  };
1310
1441
  /** Analogous to the 'DynamicJaxprTrace' class in JAX. */
1311
1442
  var JaxprTrace = class extends Trace {
@@ -1318,17 +1449,24 @@ var JaxprTrace = class extends Trace {
1318
1449
  }
1319
1450
  /** Register a constant / literal in this Jaxpr. */
1320
1451
  getOrMakeConstTracer(val) {
1452
+ if (!(val instanceof Tracer)) val = pureArray(val);
1321
1453
  let tracer = this.builder.constTracers.get(val);
1322
1454
  if (tracer === void 0) {
1323
1455
  tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
1324
- this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
1456
+ this.builder.addConst(tracer, val);
1457
+ } else {
1458
+ val.dispose();
1459
+ tracer.trackLiftedConstant();
1325
1460
  }
1326
1461
  return tracer;
1327
1462
  }
1328
1463
  pure = this.getOrMakeConstTracer;
1329
1464
  lift = this.getOrMakeConstTracer;
1330
1465
  processPrimitive(primitive, tracers, params) {
1331
- const avalsIn = tracers.map((t) => t.aval);
1466
+ const avalsIn = tracers.map((t) => {
1467
+ t.dispose();
1468
+ return t.aval;
1469
+ });
1332
1470
  const avalsOut = abstractEvalRules[primitive](avalsIn, params);
1333
1471
  const outTracers = avalsOut.map((aval) => this.builder.newTracer(this, aval));
1334
1472
  this.builder.addEqn(new JaxprEqn(primitive, tracers.map((t) => this.builder.getVar(t)), params, outTracers.map((t) => this.builder.addVar(t))));
@@ -1371,20 +1509,17 @@ var JaxprBuilder = class {
1371
1509
  return v;
1372
1510
  }
1373
1511
  build(inTracers, outTracers) {
1374
- let [constVars, consts] = unzip2(this.constVals.entries());
1512
+ const [constVars, consts] = unzip2(this.constVals.entries());
1375
1513
  const t2v = this.getVar.bind(this);
1376
1514
  const inBinders = [...constVars, ...inTracers.map(t2v)];
1377
1515
  const outVars = outTracers.map(t2v);
1378
- let jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
1516
+ const jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
1379
1517
  typecheckJaxpr(jaxpr);
1380
- [jaxpr, consts] = _inlineLiterals(jaxpr, consts);
1381
- return {
1382
- jaxpr,
1383
- consts
1384
- };
1518
+ const cjaxpr = new ClosedJaxpr(jaxpr, consts);
1519
+ return _inlineLiterals(cjaxpr);
1385
1520
  }
1386
1521
  };
1387
- function _inlineLiterals(jaxpr, consts) {
1522
+ function _inlineLiterals({ jaxpr, consts }) {
1388
1523
  const literals = /* @__PURE__ */ new Map();
1389
1524
  const constBinders = [];
1390
1525
  const newConsts = [];
@@ -1399,7 +1534,7 @@ function _inlineLiterals(jaxpr, consts) {
1399
1534
  const newOuts = jaxpr.outs.map((x) => literals.get(x) ?? x);
1400
1535
  const newJaxpr = new Jaxpr([...constBinders, ...jaxpr.inBinders.slice(consts.length)], newEqns, newOuts);
1401
1536
  typecheckJaxpr(newJaxpr);
1402
- return [newJaxpr, newConsts];
1537
+ return new ClosedJaxpr(newJaxpr, newConsts);
1403
1538
  }
1404
1539
  function binopAbstractEval([x, y]) {
1405
1540
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
@@ -1418,6 +1553,8 @@ const abstractEvalRules = {
1418
1553
  [Primitive.Mul]: binopAbstractEval,
1419
1554
  [Primitive.Idiv]: binopAbstractEval,
1420
1555
  [Primitive.Mod]: binopAbstractEval,
1556
+ [Primitive.Min]: binopAbstractEval,
1557
+ [Primitive.Max]: binopAbstractEval,
1421
1558
  [Primitive.Neg]: vectorizedUnopAbstractEval,
1422
1559
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
1423
1560
  [Primitive.Floor]: vectorizedUnopAbstractEval,
@@ -1431,12 +1568,6 @@ const abstractEvalRules = {
1431
1568
  if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
1432
1569
  return [new ShapedArray(x.shape, dtype, false)];
1433
1570
  },
1434
- [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
1435
- if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
1436
- const keyShape = generalBroadcast(k0.shape, k1.shape);
1437
- if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1438
- return [new ShapedArray(shape$1, DType.Uint32, false)];
1439
- },
1440
1571
  [Primitive.Sin]: vectorizedUnopAbstractEval,
1441
1572
  [Primitive.Cos]: vectorizedUnopAbstractEval,
1442
1573
  [Primitive.Asin]: vectorizedUnopAbstractEval,
@@ -1446,8 +1577,6 @@ const abstractEvalRules = {
1446
1577
  [Primitive.Erf]: vectorizedUnopAbstractEval,
1447
1578
  [Primitive.Erfc]: vectorizedUnopAbstractEval,
1448
1579
  [Primitive.Sqrt]: vectorizedUnopAbstractEval,
1449
- [Primitive.Min]: binopAbstractEval,
1450
- [Primitive.Max]: binopAbstractEval,
1451
1580
  [Primitive.Reduce]([x], { axis }) {
1452
1581
  const axisSet = new Set(axis);
1453
1582
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
@@ -1469,7 +1598,7 @@ const abstractEvalRules = {
1469
1598
  return [new ShapedArray(shape$1, dtype, weakType)];
1470
1599
  },
1471
1600
  [Primitive.Conv]([lhs, rhs], params) {
1472
- const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
1601
+ const { dtype, weakType } = promoteAvals(lhs.scalar(), rhs.scalar());
1473
1602
  const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
1474
1603
  return [new ShapedArray(shape$1, dtype, weakType)];
1475
1604
  },
@@ -1480,6 +1609,40 @@ const abstractEvalRules = {
1480
1609
  const shape$1 = generalBroadcast(cond.shape, xy.shape);
1481
1610
  return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
1482
1611
  },
1612
+ [Primitive.Concatenate](xs, { axis }) {
1613
+ if (xs.length === 0) throw new TypeError("Concatenate requires at least one input");
1614
+ for (const x of xs) if (x.ndim !== xs[0].ndim || !x.shape.every((s, i) => i === axis || s === xs[0].shape[i])) throw new TypeError(`Concatenate: inputs ${xs[0]} and ${x} must match shapes except on axis ${axis}`);
1615
+ const shape$1 = xs[0].shape.slice();
1616
+ shape$1[axis] = xs.reduce((sum$1, x) => sum$1 + x.shape[axis], 0);
1617
+ const { dtype, weakType } = xs.map((x) => x.scalar()).reduce(promoteAvals);
1618
+ return [new ShapedArray(shape$1, dtype, weakType)];
1619
+ },
1620
+ [Primitive.Split]([x], { axis, sizes }) {
1621
+ const totalSize = sizes.reduce((a, b) => a + b, 0);
1622
+ if (x.shape[axis] !== totalSize) throw new TypeError(`Split: sizes ${sizes} do not sum to dimension ${x.shape[axis]} on axis ${axis}`);
1623
+ return sizes.map((size$1) => {
1624
+ return new ShapedArray(x.shape.toSpliced(axis, 1, size$1), x.dtype, x.weakType);
1625
+ });
1626
+ },
1627
+ [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
1628
+ if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
1629
+ if (!deepEqual(k0.shape, k1.shape)) throw new TypeError(`RandomBits: Keys have different shapes ${k0.shape} and ${k1.shape}`);
1630
+ if (!deepEqual(shape$1.slice(0, k0.ndim), k0.shape)) throw new TypeError(`RandomBits: generated shape ${shape$1} must match key shape ${k0.shape}`);
1631
+ return [new ShapedArray(shape$1, DType.Uint32, false)];
1632
+ },
1633
+ [Primitive.Gather]([x, ...indices], { axis, outDim }) {
1634
+ for (const a of indices) if (a.dtype !== DType.Int32 && a.dtype !== DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
1635
+ if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
1636
+ if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
1637
+ if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
1638
+ if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
1639
+ const axisSet = new Set(axis);
1640
+ if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
1641
+ const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
1642
+ const newShape = x.shape.filter((_, i) => !axisSet.has(i));
1643
+ newShape.splice(outDim, 0, ...gatherShape);
1644
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
1645
+ },
1483
1646
  [Primitive.Transpose]([x], { perm }) {
1484
1647
  return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
1485
1648
  },
@@ -1500,23 +1663,41 @@ const abstractEvalRules = {
1500
1663
  const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
1501
1664
  return [new ShapedArray(newShape, x.dtype, x.weakType)];
1502
1665
  },
1503
- [Primitive.Gather]([x, ...indices], { axis, outDim }) {
1504
- for (const a of indices) if (a.dtype !== DType.Int32 && a.dtype !== DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
1505
- if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
1506
- if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
1507
- if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
1508
- if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
1509
- const axisSet = new Set(axis);
1510
- if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
1511
- const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
1512
- const newShape = x.shape.filter((_, i) => !axisSet.has(i));
1513
- newShape.splice(outDim, 0, ...gatherShape);
1514
- return [new ShapedArray(newShape, x.dtype, x.weakType)];
1666
+ [Primitive.Sort]([x]) {
1667
+ if (x.ndim === 0) throw new TypeError("sort: requires at least 1D input");
1668
+ return [ShapedArray.fromAval(x)];
1669
+ },
1670
+ [Primitive.Argsort]([x]) {
1671
+ if (x.ndim === 0) throw new TypeError("argsort: requires at least 1D input");
1672
+ return [ShapedArray.fromAval(x), new ShapedArray(x.shape, DType.Int32, false)];
1673
+ },
1674
+ [Primitive.TriangularSolve]([a, b]) {
1675
+ if (a.ndim < 2) throw new TypeError(`triangular_solve: a must be at least 2D, got ${a}`);
1676
+ if (b.ndim < 2) throw new TypeError(`triangular_solve: b must be at least 2D, got ${b}`);
1677
+ const [m, n] = a.shape.slice(-2);
1678
+ const [_batch, q] = b.shape.slice(-2);
1679
+ if (!deepEqual(a.shape.slice(0, -2), b.shape.slice(0, -2)) || a.dtype !== b.dtype || m !== n || n !== q) throw new TypeError(`triangular_solve: mismatch ${a} vs ${b}`);
1680
+ return [new ShapedArray(b.shape, b.dtype, a.weakType && b.weakType)];
1681
+ },
1682
+ [Primitive.Cholesky]([a]) {
1683
+ if (a.ndim < 2) throw new TypeError(`cholesky: requires at least 2D input, got ${a}`);
1684
+ if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
1685
+ return [ShapedArray.fromAval(a)];
1686
+ },
1687
+ [Primitive.LU]([a]) {
1688
+ if (a.ndim < 2) throw new TypeError(`lu: requires at least 2D input, got ${a}`);
1689
+ const batch = a.shape.slice(0, -2);
1690
+ const [m, n] = a.shape.slice(-2);
1691
+ return [
1692
+ ShapedArray.fromAval(a),
1693
+ new ShapedArray([...batch, Math.min(m, n)], DType.Int32, false),
1694
+ new ShapedArray([...batch, m], DType.Int32, false)
1695
+ ];
1515
1696
  },
1516
- [Primitive.JitCall](args, { jaxpr }) {
1697
+ [Primitive.Jit](args, { jaxpr }) {
1517
1698
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
1518
- if (args.length !== inTypes.length) throw new TypeError(`jit_call expected ${inTypes.length} arguments, got ${args.length}`);
1519
- for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit_call argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
1699
+ if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
1700
+ for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
1520
1701
  return outTypes;
1521
1702
  }
1522
1703
  };
@@ -1552,11 +1733,10 @@ function makeJaxpr$1(f, opts) {
1552
1733
  const tracersIn = avalsIn.map((aval) => trace$1.newArg(typeof aval === "object" ? aval : pureArray(aval)));
1553
1734
  const outs = fFlat(...tracersIn);
1554
1735
  const tracersOut = outs.map((out) => fullRaise(trace$1, out));
1555
- const { jaxpr, consts } = builder.build(tracersIn, tracersOut);
1736
+ const jaxpr = builder.build(tracersIn, tracersOut);
1556
1737
  if (outTree.value === void 0) throw new Error("outTree was not set in makeJaxpr");
1557
1738
  return {
1558
- jaxpr: jaxpr.simplify(),
1559
- consts,
1739
+ jaxpr: jaxpr.mapJaxpr((j) => j.simplify()),
1560
1740
  treedef: outTree.value
1561
1741
  };
1562
1742
  } catch (_) {
@@ -1575,22 +1755,29 @@ function jit$1(f, opts) {
1575
1755
  const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
1576
1756
  const avalsIn = unflatten(inTree, avalsInFlat);
1577
1757
  const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
1578
- const { jaxpr, consts, treedef: outTree } = runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
1579
- const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
1758
+ const { jaxpr, treedef: outTree } = runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
1759
+ const outs = bind(Primitive.Jit, [...jaxpr.consts.map((c) => c.ref), ...argsFlat], {
1580
1760
  name: f.name || "closure",
1581
- jaxpr,
1582
- numConsts: consts.length
1761
+ jaxpr: jaxpr.jaxpr,
1762
+ numConsts: jaxpr.consts.length
1583
1763
  });
1584
1764
  return unflatten(outTree, outs);
1585
1765
  });
1586
1766
  result.dispose = () => {
1587
- for (const { consts } of cache.values()) for (const c of consts) c.dispose();
1767
+ for (const { jaxpr } of cache.values()) jaxpr.dispose();
1588
1768
  };
1589
1769
  return result;
1590
1770
  }
1591
1771
 
1592
1772
  //#endregion
1593
1773
  //#region src/frontend/jit.ts
1774
/** Lookup table pairing each listed primitive with the `Routines` entry of the same name. */
const routinePrimitives = new Map(
  ["Sort", "Argsort", "TriangularSolve", "Cholesky", "LU"].map((name) => [
    Primitive[name],
    Routines[name],
  ]),
);
1594
1781
  /** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
1595
1782
  var JitProgram = class {
1596
1783
  constructor(backend, steps, inputs, outputs) {
@@ -1605,9 +1792,14 @@ var JitProgram = class {
1605
1792
  case "execute": {
1606
1793
  const inputsNice = step.inputs.map((id, i) => `${i}: %${id}`).join(", ");
1607
1794
  const outputsNice = step.outputs.map((id) => `%${id}`).join(", ");
1608
- return PPrint.pp(`execute (${inputsNice}) -> ${outputsNice}, kernel`).concat(step.kernel.pprint().indent(2));
1795
+ const executeText = `execute (${inputsNice}) -> ${outputsNice}`;
1796
+ if (step.source instanceof Kernel) return PPrint.pp(`${executeText}, kernel`).concat(step.source.pprint().indent(2));
1797
+ else if (step.source instanceof Routine) return PPrint.pp(`${executeText}, routine ${step.source.name}`);
1798
+ else {
1799
+ step.source;
1800
+ return PPrint.pp(executeText);
1801
+ }
1609
1802
  }
1610
- case "const": return PPrint.pp(`%${step.output} = const <Slot ${step.slot}>`);
1611
1803
  case "malloc": return PPrint.pp(`%${step.output} = malloc <${step.size} bytes>`);
1612
1804
  case "incref": return PPrint.pp(`incref ${step.input}`);
1613
1805
  case "free": return PPrint.pp(`free ${step.input}`);
@@ -1630,12 +1822,9 @@ var JitProgram = class {
1630
1822
  const inputs$1 = step.inputs.map((id) => scope.get(id));
1631
1823
  const outputs = step.outputs.map((id) => scope.get(id));
1632
1824
  if (inputs$1.some((s) => s === void 0) || outputs.some((s) => s === void 0)) throw new Error(`internal: JitProgram scope undefined`);
1633
- pending.push(new PendingExecute(this.backend, step.kernel, inputs$1, outputs));
1825
+ pending.push(new PendingExecute(this.backend, step.source, inputs$1, outputs));
1634
1826
  break;
1635
1827
  }
1636
- case "const":
1637
- scope.set(step.output, step.slot);
1638
- break;
1639
1828
  case "malloc": {
1640
1829
  const slot = this.backend.malloc(step.size);
1641
1830
  scope.set(step.output, slot);
@@ -1669,34 +1858,37 @@ var JitProgramBuilder = class {
1669
1858
  this.#nextId = nargs;
1670
1859
  this.steps = [];
1671
1860
  }
1672
- pushConst(slot) {
1673
- const id = this.#nextId++;
1674
- this.steps.push({
1675
- type: "const",
1676
- slot,
1677
- output: id
1678
- });
1679
- return id;
1680
- }
1681
1861
  pushLit(lit) {
1682
- const kernel = new Kernel(0, prod(lit.aval.shape), AluExp.const(lit.dtype, lit.value));
1862
+ const kernel = new Kernel(0, lit.aval.size, AluExp.const(lit.dtype, lit.value));
1683
1863
  return this.pushKernel(kernel, []);
1684
1864
  }
1685
- pushKernel(kernel, inputs) {
1865
+ pushBuffer(size$1) {
1686
1866
  const id = this.#nextId++;
1687
1867
  this.steps.push({
1688
1868
  type: "malloc",
1689
- size: kernel.bytes,
1869
+ size: size$1,
1690
1870
  output: id
1691
1871
  });
1872
+ return id;
1873
+ }
1874
+ pushKernel(kernel, inputs) {
1875
+ const id = this.pushBuffer(kernel.bytes);
1692
1876
  this.steps.push({
1693
1877
  type: "execute",
1694
- kernel,
1878
+ source: kernel,
1695
1879
  inputs,
1696
1880
  outputs: [id]
1697
1881
  });
1698
1882
  return id;
1699
1883
  }
1884
+ pushRoutine(routine, inputs, outputs) {
1885
+ this.steps.push({
1886
+ type: "execute",
1887
+ source: routine,
1888
+ inputs,
1889
+ outputs
1890
+ });
1891
+ }
1700
1892
  pushIncref(id) {
1701
1893
  this.steps.push({
1702
1894
  type: "incref",
@@ -1722,28 +1914,18 @@ var JitProgramBuilder = class {
1722
1914
  }
1723
1915
  };
1724
1916
  const jitCompileCache = /* @__PURE__ */ new Map();
1725
- function jitCompile(backend, jaxpr, consts) {
1726
- if (jaxpr.inBinders.length < consts.length) throw new TypeError(`Jaxpr has ${jaxpr.inBinders.length} inputs, but ${consts.length} consts were provided`);
1727
- for (let i = 0; i < consts.length; i++) if (consts[i].device !== backend.type) throw new TypeError(`Const ${i} has device ${consts[i].device}, but expected ${backend.type}`);
1728
- const cacheKey = backend.type + FpHash.hash(jaxpr, ...consts.map((c) => c.id));
1917
+ function jitCompile(backend, jaxpr) {
1918
+ const cacheKey = backend.type + "," + FpHash.hash(jaxpr);
1729
1919
  const cached = jitCompileCache.get(cacheKey);
1730
1920
  if (cached) return cached;
1731
1921
  if (DEBUG >= 1) console.info("=========== JIT Compile ===========\n" + jaxpr.toString());
1732
1922
  jaxpr = jaxpr.flatten().simplify();
1733
- const nargs = jaxpr.inBinders.length - consts.length;
1923
+ const nargs = jaxpr.inBinders.length;
1734
1924
  const builder = new JitProgramBuilder(backend, nargs);
1735
1925
  const blackNodes = splitGraphDataflow(backend, jaxpr);
1736
1926
  const ctx = /* @__PURE__ */ new Map();
1737
- for (let i = 0; i < consts.length; i++) {
1738
- const v = jaxpr.inBinders[i];
1739
- const slot = consts[i]._realizeSource();
1740
- ctx.set(v, {
1741
- type: "imm",
1742
- arg: builder.pushConst(slot)
1743
- });
1744
- }
1745
1927
  for (let i = 0; i < nargs; i++) {
1746
- const v = jaxpr.inBinders[consts.length + i];
1928
+ const v = jaxpr.inBinders[i];
1747
1929
  ctx.set(v, {
1748
1930
  type: "imm",
1749
1931
  arg: i
@@ -1751,6 +1933,31 @@ function jitCompile(backend, jaxpr, consts) {
1751
1933
  }
1752
1934
  for (let i = 0; i < jaxpr.eqns.length; i++) {
1753
1935
  const eqn = jaxpr.eqns[i];
1936
+ if (routinePrimitives.has(eqn.primitive)) {
1937
+ const routine = new Routine(routinePrimitives.get(eqn.primitive), {
1938
+ inputShapes: eqn.inputs.map((x) => x.aval.shape),
1939
+ inputDtypes: eqn.inputs.map((x) => x.aval.dtype),
1940
+ outputShapes: eqn.outBinders.map((x) => x.aval.shape),
1941
+ outputDtypes: eqn.outBinders.map((x) => x.aval.dtype)
1942
+ }, eqn.params);
1943
+ const inputs = [];
1944
+ for (const input of eqn.inputs) if (input instanceof Var) {
1945
+ const jv = ctx.get(input);
1946
+ if (jv.type !== "imm") throw new Error(`jit: routine primitive ${eqn.primitive} input is not imm`);
1947
+ inputs.push(jv.arg);
1948
+ } else if (input instanceof Lit) inputs.push(builder.pushLit(input));
1949
+ const outputs = [];
1950
+ for (const outVar of eqn.outBinders) {
1951
+ const outId = builder.pushBuffer(outVar.aval.size * byteWidth(outVar.aval.dtype));
1952
+ outputs.push(outId);
1953
+ ctx.set(outVar, {
1954
+ type: "imm",
1955
+ arg: outId
1956
+ });
1957
+ }
1958
+ builder.pushRoutine(routine, inputs, outputs);
1959
+ continue;
1960
+ }
1754
1961
  const inputExps = [];
1755
1962
  const inputAvals = [];
1756
1963
  const inputArgs = [];
@@ -1794,35 +2001,37 @@ function jitCompile(backend, jaxpr, consts) {
1794
2001
  let reduction;
1795
2002
  if (inputReduction) {
1796
2003
  const jv = inputReduction;
1797
- const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
1798
- exp$2 = jv.exp.reindexGids(addArgs(jv.args));
2004
+ const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp[0];
2005
+ exp$2 = [jv.exp.reindexGids(addArgs(jv.args))];
1799
2006
  reduction = new Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
1800
2007
  } else {
1801
2008
  const ruleOutput = rule(inputExps, inputAvals, eqn.params);
1802
2009
  exp$2 = ruleOutput.exp;
1803
2010
  reduction = ruleOutput.reduction;
1804
2011
  }
1805
- const outVar = eqn.outBinders[0];
1806
- if (blackNodes.has(outVar)) {
1807
- const nargs$1 = inputArgs.length;
1808
- const size$1 = prod(outVar.aval.shape);
1809
- const kernel = new Kernel(nargs$1, size$1, exp$2, reduction);
1810
- const outId = builder.pushKernel(kernel, inputArgs);
1811
- ctx.set(outVar, {
1812
- type: "imm",
1813
- arg: outId
2012
+ for (let i$1 = 0; i$1 < eqn.outBinders.length; i$1++) {
2013
+ const outVar = eqn.outBinders[i$1];
2014
+ if (blackNodes.has(outVar)) {
2015
+ const nargs$1 = inputArgs.length;
2016
+ const size$1 = outVar.aval.size;
2017
+ const kernel = new Kernel(nargs$1, size$1, exp$2[i$1], reduction);
2018
+ const outId = builder.pushKernel(kernel, inputArgs);
2019
+ ctx.set(outVar, {
2020
+ type: "imm",
2021
+ arg: outId
2022
+ });
2023
+ } else if (reduction) ctx.set(outVar, {
2024
+ type: "red",
2025
+ exp: exp$2[i$1],
2026
+ reduction,
2027
+ args: inputArgs
1814
2028
  });
1815
- } else if (reduction) ctx.set(outVar, {
1816
- type: "red",
1817
- exp: exp$2,
1818
- reduction,
1819
- args: inputArgs
1820
- });
1821
- else ctx.set(outVar, {
1822
- type: "exp",
1823
- exp: exp$2,
1824
- args: inputArgs
1825
- });
2029
+ else ctx.set(outVar, {
2030
+ type: "exp",
2031
+ exp: exp$2[i$1],
2032
+ args: inputArgs
2033
+ });
2034
+ }
1826
2035
  }
1827
2036
  const outputIds = [];
1828
2037
  for (const out of jaxpr.outs) if (out instanceof Var) {
@@ -1830,7 +2039,7 @@ function jitCompile(backend, jaxpr, consts) {
1830
2039
  if (jitValue.type !== "imm") throw new Error("internal: Expected imm, since outs are black nodes");
1831
2040
  outputIds.push(jitValue.arg);
1832
2041
  } else if (out instanceof Lit) outputIds.push(builder.pushLit(out));
1833
- const outputNeedsRef = new Set([...range(nargs), ...builder.steps.filter((s) => s.type === "const").map((s) => s.output)]);
2042
+ const outputNeedsRef = new Set(range(nargs));
1834
2043
  for (const outputId of outputIds) if (outputNeedsRef.has(outputId)) builder.pushIncref(outputId);
1835
2044
  else outputNeedsRef.add(outputId);
1836
2045
  builder.insertFreeSteps(outputIds);
@@ -1863,17 +2072,22 @@ function broadcastedJit(fn, opts) {
1863
2072
  if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = AluExp.cast(newDtype, exp$2);
1864
2073
  return exp$2;
1865
2074
  });
1866
- return { exp: fn(exps, params) };
2075
+ return { exp: [fn(exps, params)] };
1867
2076
  };
1868
2077
  }
1869
2078
  function unopJit(fn) {
1870
2079
  return ([a], [_as], params) => {
1871
- return { exp: fn(a, params) };
2080
+ return { exp: [fn(a, params)] };
1872
2081
  };
1873
2082
  }
1874
2083
  function reshapeJit(fn) {
1875
2084
  return ([a], [_as], params) => {
1876
- return { exp: reshapeViews(a, (st) => fn(st, params)) };
2085
+ return { exp: [reshapeViews(a, (st) => fn(st, params))] };
2086
+ };
2087
+ }
2088
+ function routineNoJit() {
2089
+ return () => {
2090
+ throw new Error("jit: rule is not implemented for routines");
1877
2091
  };
1878
2092
  }
1879
2093
  const jitRules = {
@@ -1881,6 +2095,8 @@ const jitRules = {
1881
2095
  [Primitive.Mul]: broadcastedJit(([a, b]) => AluExp.mul(a, b)),
1882
2096
  [Primitive.Idiv]: broadcastedJit(([a, b]) => AluExp.idiv(a, b)),
1883
2097
  [Primitive.Mod]: broadcastedJit(([a, b]) => AluExp.mod(a, b)),
2098
+ [Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
2099
+ [Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
1884
2100
  [Primitive.Neg]: unopJit((a) => AluExp.sub(AluExp.const(a.dtype, 0), a)),
1885
2101
  [Primitive.Reciprocal]: unopJit(AluExp.reciprocal),
1886
2102
  [Primitive.Floor]: unopJit(AluExp.floor),
@@ -1888,17 +2104,6 @@ const jitRules = {
1888
2104
  [Primitive.StopGradient]: unopJit((a) => a),
1889
2105
  [Primitive.Cast]: unopJit((a, { dtype }) => AluExp.cast(dtype, a)),
1890
2106
  [Primitive.Bitcast]: unopJit((a, { dtype }) => AluExp.bitcast(dtype, a)),
1891
- [Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
1892
- const mapping = (st) => {
1893
- if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(shape$1.length - st.shape.length));
1894
- };
1895
- const k0 = reshapeViews(keys[0], mapping);
1896
- const k1 = reshapeViews(keys[1], mapping);
1897
- const c0 = AluExp.u32(0);
1898
- const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
1899
- const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
1900
- return { exp: exp$2 };
1901
- },
1902
2107
  [Primitive.Sin]: unopJit(AluExp.sin),
1903
2108
  [Primitive.Cos]: unopJit(AluExp.cos),
1904
2109
  [Primitive.Asin]: unopJit(AluExp.asin),
@@ -1908,8 +2113,6 @@ const jitRules = {
1908
2113
  [Primitive.Erf]: unopJit(AluExp.erf),
1909
2114
  [Primitive.Erfc]: unopJit(AluExp.erfc),
1910
2115
  [Primitive.Sqrt]: unopJit(AluExp.sqrt),
1911
- [Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
1912
- [Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
1913
2116
  [Primitive.Reduce]([a], [as], { op, axis }) {
1914
2117
  const keptAxes = [];
1915
2118
  const shiftedAxes = [];
@@ -1925,7 +2128,7 @@ const jitRules = {
1925
2128
  a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
1926
2129
  const reduction = new Reduction(a.dtype, op, reductionSize);
1927
2130
  return {
1928
- exp: a,
2131
+ exp: [a],
1929
2132
  reduction
1930
2133
  };
1931
2134
  },
@@ -1936,13 +2139,13 @@ const jitRules = {
1936
2139
  a = reshapeViews(a, (st) => st.compose(stX), true);
1937
2140
  const reduction = new Reduction(a.dtype, AluOp.Add, stX.shape[stX.shape.length - 1]);
1938
2141
  return {
1939
- exp: a,
2142
+ exp: [a],
1940
2143
  reduction
1941
2144
  };
1942
2145
  },
1943
2146
  [Primitive.Dot]([a, b], [as, bs]) {
1944
2147
  const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
1945
- const c = k1.exp;
2148
+ const [c] = k1.exp;
1946
2149
  const cs = promoteAvals(as, bs);
1947
2150
  return jitRules[Primitive.Reduce]([c], [cs], {
1948
2151
  op: AluOp.Add,
@@ -1959,16 +2162,42 @@ const jitRules = {
1959
2162
  },
1960
2163
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
1961
2164
  [Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
1962
- [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
1963
- [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
1964
- [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
1965
- [Primitive.Flip]: reshapeJit((st, { axis }) => {
1966
- const arg = rep(st.shape.length, false);
1967
- for (const ax of axis) arg[ax] = true;
1968
- return st.flip(arg);
1969
- }),
1970
- [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
1971
- [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
2165
+ [Primitive.Concatenate](exps, avals, { axis }) {
2166
+ const ndim$2 = avals[0].ndim;
2167
+ const sizes = avals.map((x) => x.shape[axis]);
2168
+ const finalSize = sizes.reduce((a, b) => a + b, 0);
2169
+ const makePadAxis = (start, end) => range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
2170
+ let cum = 0;
2171
+ const src = [];
2172
+ for (let i = 0; i < exps.length; i++) {
2173
+ const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
2174
+ src.push(reshapeViews(exps[i], (st) => st.pad(padding)));
2175
+ cum += sizes[i];
2176
+ }
2177
+ return { exp: [src.reduce(AluExp.add)] };
2178
+ },
2179
+ [Primitive.Split]([a], [as], { axis, sizes }) {
2180
+ const exp$2 = [];
2181
+ let start = 0;
2182
+ for (const size$1 of sizes) {
2183
+ const slice = range(as.ndim).map((d) => d === axis ? [start, start + size$1] : [0, as.shape[d]]);
2184
+ exp$2.push(reshapeViews(a, (st) => st.shrink(slice)));
2185
+ start += size$1;
2186
+ }
2187
+ return { exp: exp$2 };
2188
+ },
2189
+ [Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
2190
+ const keyShape = keyShapes[0].shape;
2191
+ const mapping = (st) => {
2192
+ if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(st.shape.length, shape$1.length));
2193
+ };
2194
+ const k0 = reshapeViews(keys[0], mapping);
2195
+ const k1 = reshapeViews(keys[1], mapping);
2196
+ const c0 = AluExp.u32(0);
2197
+ const c1 = AluExp.mod(AluExp.cast(DType.Uint32, AluVar.gidx), AluExp.u32(Math.max(prod(shape$1.slice(keyShape.length)), 1)));
2198
+ const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
2199
+ return { exp: [exp$2] };
2200
+ },
1972
2201
  [Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
1973
2202
  const axisSet = new Set(axis);
1974
2203
  const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
@@ -1982,10 +2211,25 @@ const jitRules = {
1982
2211
  for (const [i, iexp] of indices.entries()) src[axis[i]] = AluExp.cast(DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...range(outDim + indexShape.length - st.shape.length), ...range(outDim + indexShape.length, finalShape.length)])));
1983
2212
  const [index, valid] = ShapeTracker.fromShape(xs.shape).toAluExp(src);
1984
2213
  if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
1985
- return { exp: x.substitute({ gidx: index }) };
2214
+ return { exp: [x.substitute({ gidx: index })] };
1986
2215
  },
1987
- [Primitive.JitCall]() {
1988
- throw new Error("internal: JitCall should have been flattened before JIT compilation");
2216
+ [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
2217
+ [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
2218
+ [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
2219
+ [Primitive.Flip]: reshapeJit((st, { axis }) => {
2220
+ const arg = rep(st.shape.length, false);
2221
+ for (const ax of axis) arg[ax] = true;
2222
+ return st.flip(arg);
2223
+ }),
2224
+ [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
2225
+ [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
2226
+ [Primitive.Sort]: routineNoJit(),
2227
+ [Primitive.Argsort]: routineNoJit(),
2228
+ [Primitive.TriangularSolve]: routineNoJit(),
2229
+ [Primitive.Cholesky]: routineNoJit(),
2230
+ [Primitive.LU]: routineNoJit(),
2231
+ [Primitive.Jit]() {
2232
+ throw new Error("internal: Jit should have been flattened before JIT compilation");
1989
2233
  }
1990
2234
  };
1991
2235
  /** Determines how to split the Jaxpr into kernels via dataflow analysis. */
@@ -2043,8 +2287,8 @@ function splitGraphDataflow(backend, jaxpr) {
2043
2287
  case Primitive.Mul:
2044
2288
  case Primitive.Idiv:
2045
2289
  case Primitive.Mod:
2046
- case Primitive.Max:
2047
- case Primitive.Min: {
2290
+ case Primitive.Min:
2291
+ case Primitive.Max: {
2048
2292
  const otherInput = nextEqn.inputs.find((v) => v !== outVar);
2049
2293
  if (otherInput instanceof Lit || deepEqual(generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
2050
2294
  head = usages[0];
@@ -2064,11 +2308,11 @@ function splitGraphDataflow(backend, jaxpr) {
2064
2308
  blackNodes.add(v);
2065
2309
  p1NextBlack.set(v, v);
2066
2310
  }
2067
- const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
2311
+ const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
2068
2312
  const needsCleanShapePrimitives = [Primitive.Pad];
2069
2313
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
2070
2314
  const eqn = jaxpr.eqns[i];
2071
- if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
2315
+ if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
2072
2316
  for (const v of eqn.outBinders) {
2073
2317
  blackNodes.add(v);
2074
2318
  p1NextBlack.set(v, v);
@@ -2078,7 +2322,7 @@ function splitGraphDataflow(backend, jaxpr) {
2078
2322
  const reach = /* @__PURE__ */ new Set();
2079
2323
  let needsCleanOutput = false;
2080
2324
  outer: for (const v of eqn.outBinders) for (const j of varToUsages.get(v) ?? []) {
2081
- if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive)) {
2325
+ if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive) || routinePrimitives.has(jaxpr.eqns[j].primitive)) {
2082
2326
  needsCleanOutput = true;
2083
2327
  break outer;
2084
2328
  }
@@ -2102,7 +2346,6 @@ function splitGraphDataflow(backend, jaxpr) {
2102
2346
  while (p2idx < jaxpr.eqns.length) {
2103
2347
  const eqn = jaxpr.eqns[p2idx++];
2104
2348
  const deps = [];
2105
- if (eqn.outBinders.some((v) => blackNodes.has(v))) continue;
2106
2349
  for (const input of eqn.inputs) if (input instanceof Var) if (blackNodes.has(input)) deps.push(new Set([input]));
2107
2350
  else deps.push(p2Deps.get(input));
2108
2351
  else deps.push(/* @__PURE__ */ new Set());
@@ -2125,7 +2368,7 @@ function splitGraphDataflow(backend, jaxpr) {
2125
2368
  if (assocInput === -1) throw new Error(`internal: maxArgs, no input found to mark as black in Jaxpr equation ${eqn}`);
2126
2369
  const assocVar = eqn.inputs[assocInput];
2127
2370
  p2idx = varToDefn.get(assocVar);
2128
- for (const out of jaxpr.eqns[p2idx].outBinders) blackNodes.add(out);
2371
+ for (const out of jaxpr.eqns[p2idx++].outBinders) blackNodes.add(out);
2129
2372
  } else {
2130
2373
  const s = new Set(depCounter.keys());
2131
2374
  for (const out of eqn.outBinders) p2Deps.set(out, s);
@@ -2151,9 +2394,9 @@ var PendingExecute = class {
2151
2394
  submitted = false;
2152
2395
  #promise = null;
2153
2396
  #rc = 1;
2154
- constructor(backend, kernel, inputs, outputs) {
2397
+ constructor(backend, source, inputs, outputs) {
2155
2398
  this.backend = backend;
2156
- this.kernel = kernel;
2399
+ this.source = source;
2157
2400
  this.inputs = inputs;
2158
2401
  this.outputs = outputs;
2159
2402
  for (const slot of inputs) this.backend.incRef(slot);
@@ -2174,13 +2417,15 @@ var PendingExecute = class {
2174
2417
  return;
2175
2418
  }
2176
2419
  this.#promise = (async () => {
2177
- this.prepared = await this.backend.prepare(this.kernel);
2420
+ if (this.source instanceof Kernel) this.prepared = await this.backend.prepareKernel(this.source);
2421
+ else this.prepared = await this.backend.prepareRoutine(this.source);
2178
2422
  })();
2179
2423
  await this.#promise;
2180
2424
  }
2181
2425
  prepareSync() {
2182
2426
  if (this.prepared) return;
2183
- this.prepared = this.backend.prepareSync(this.kernel);
2427
+ if (this.source instanceof Kernel) this.prepared = this.backend.prepareKernelSync(this.source);
2428
+ else this.prepared = this.backend.prepareRoutineSync(this.source);
2184
2429
  }
2185
2430
  submit() {
2186
2431
  if (this.submitted) return;
@@ -2203,8 +2448,6 @@ var PendingExecute = class {
2203
2448
  * "Array" type by name.
2204
2449
  */
2205
2450
  var Array$1 = class Array$1 extends Tracer {
2206
- static #nextId = 1001;
2207
- id;
2208
2451
  #dtype;
2209
2452
  #weakType;
2210
2453
  #source;
@@ -2221,7 +2464,6 @@ var Array$1 = class Array$1 extends Tracer {
2221
2464
  */
2222
2465
  constructor(args) {
2223
2466
  super(baseArrayTrace);
2224
- this.id = Array$1.#nextId++;
2225
2467
  this.#dtype = args.dtype;
2226
2468
  this.#weakType = args.weakType;
2227
2469
  this.#source = args.source;
@@ -2264,6 +2506,10 @@ var Array$1 = class Array$1 extends Tracer {
2264
2506
  this.#rc++;
2265
2507
  return this;
2266
2508
  }
2509
+ /** Get the current reference count (for debugging memory management). */
2510
+ get refCount() {
2511
+ return this.#rc;
2512
+ }
2267
2513
  dispose() {
2268
2514
  this.#check();
2269
2515
  if (--this.#rc === 0) {
@@ -2421,7 +2667,7 @@ var Array$1 = class Array$1 extends Tracer {
2421
2667
  } else if (castDtype === void 0) {
2422
2668
  castDtype = arrays[i].#dtype;
2423
2669
  castWeakType = arrays[i].#weakType;
2424
- } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
2670
+ } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), arrays[i].aval.scalar()));
2425
2671
  const weakType = castWeakType && !strongTypeOutput;
2426
2672
  const { backend, committed } = Array$1.#computeBackend(name, arrays);
2427
2673
  arrays = arrays.map((ar) => ar._putSync(backend));
@@ -2530,6 +2776,27 @@ var Array$1 = class Array$1 extends Tracer {
2530
2776
  pending
2531
2777
  });
2532
2778
  }
2779
+ /** Run `routine` on `arrays`: picks the backend, realizes every input, mallocs one buffer per declared output (byteWidth(dtype) * prod(shape)), queues a PendingExecute behind the inputs' pending work, disposes the inputs, and returns one array per output with weak-type flags taken from `outputWeakType`. */
2780
+ static #routine(routine, arrays, outputWeakType) {
2781
+ const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
2782
+ for (const ar of arrays) ar.#realize();
2783
+ const inputs = arrays.map((ar) => ar.#source);
2784
+ const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(byteWidth(dtype) * prod(routine.type.outputShapes[i])));
2785
+ const pending = arrays.flatMap((ar) => ar.#pending);
2786
+ for (const exe of pending) exe.updateRc(+outputs.length);
2787
+ pending.push(new PendingExecute(backend, routine, inputs, outputs));
2788
+ pending[pending.length - 1].updateRc(+outputs.length - 1); // NOTE(review): presumably the new execute starts with one reference, so this brings it to one per returned array — confirm against PendingExecute
2789
+ arrays.forEach((ar) => ar.dispose());
2790
+ return outputs.map((output, i) => new Array$1({
2791
+ source: output,
2792
+ st: ShapeTracker.fromShape(routine.type.outputShapes[i]),
2793
+ dtype: routine.type.outputDtypes[i],
2794
+ weakType: outputWeakType[i],
2795
+ backend,
2796
+ committed,
2797
+ pending
2798
+ }));
2799
+ }
2533
2800
  /**
2534
2801
  * Normalizes this array into one backed by a `Slot`.
2535
2802
  *
@@ -2690,6 +2957,12 @@ var Array$1 = class Array$1 extends Tracer {
2690
2957
  [Primitive.Mod]([x, y]) {
2691
2958
  return [x.#binary(AluOp.Mod, y)];
2692
2959
  },
2960
+ [Primitive.Min]([x, y]) {
2961
+ return [x.#binary(AluOp.Min, y)];
2962
+ },
2963
+ [Primitive.Max]([x, y]) {
2964
+ return [x.#binary(AluOp.Max, y)];
2965
+ },
2693
2966
  [Primitive.Neg]([x]) {
2694
2967
  return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
2695
2968
  },
@@ -2726,25 +2999,6 @@ var Array$1 = class Array$1 extends Tracer {
2726
2999
  return [y];
2727
3000
  }
2728
3001
  },
2729
- [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
2730
- const keyShape = generalBroadcast(k0.shape, k1.shape);
2731
- if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2732
- const c0 = zeros(shape$1, {
2733
- dtype: DType.Uint32,
2734
- device: k0.device
2735
- });
2736
- const c1 = arange(0, prod(shape$1), 1, {
2737
- dtype: DType.Uint32,
2738
- device: k0.device
2739
- }).reshape(shape$1);
2740
- const custom = ([k0$1, k1$1, c0$1, c1$1]) => AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
2741
- return [Array$1.#naryCustom("random_bits", custom, [
2742
- k0,
2743
- k1,
2744
- c0,
2745
- c1
2746
- ])];
2747
- },
2748
3002
  [Primitive.Sin]([x]) {
2749
3003
  return [x.#unary(AluOp.Sin)];
2750
3004
  },
@@ -2772,12 +3026,6 @@ var Array$1 = class Array$1 extends Tracer {
2772
3026
  [Primitive.Sqrt]([x]) {
2773
3027
  return [x.#unary(AluOp.Sqrt)];
2774
3028
  },
2775
- [Primitive.Min]([x, y]) {
2776
- return [x.#binary(AluOp.Min, y)];
2777
- },
2778
- [Primitive.Max]([x, y]) {
2779
- return [x.#binary(AluOp.Max, y)];
2780
- },
2781
3029
  [Primitive.Reduce]([x], { op, axis }) {
2782
3030
  if (axis.length === 0) return [x];
2783
3031
  return [x.#moveAxesDown(axis).#reduce(op)];
@@ -2812,6 +3060,55 @@ var Array$1 = class Array$1 extends Tracer {
2812
3060
  y
2813
3061
  ], { dtypeOverride: [DType.Bool] })];
2814
3062
  },
3063
+ [Primitive.Concatenate](xs, { axis }) {
3064
+ const ndim$2 = xs[0].ndim;
3065
+ const sizes = xs.map((x) => x.shape[axis]);
3066
+ const finalSize = sizes.reduce((a, b) => a + b, 0);
3067
+ const makePadAxis = (start, end) => range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
3068
+ let cum = 0;
3069
+ const xsPadded = [];
3070
+ for (let i = 0; i < xs.length; i++) {
3071
+ const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
3072
+ xsPadded.push(xs[i].#reshape(xs[i].#st.pad(padding)));
3073
+ cum += sizes[i];
3074
+ }
3075
+ const custom = (exps) => exps.reduce(AluExp.add);
3076
+ return [Array$1.#naryCustom("concatenate", custom, xsPadded)];
3077
+ },
3078
+ [Primitive.Split]([x], { axis, sizes }) {
3079
+ const outputs = [];
3080
+ for (let i = 0, start = 0; i < sizes.length; i++) {
3081
+ const slice = range(x.ndim).map((d) => d === axis ? [start, start + sizes[i]] : [0, x.shape[d]]);
3082
+ outputs.push(x.ref.#reshape(x.#st.shrink(slice)));
3083
+ start += sizes[i];
3084
+ }
3085
+ x.dispose();
3086
+ return outputs;
3087
+ },
3088
+ [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
3089
+ const keyShape = k0.shape;
3090
+ const genShape = shape$1.slice(keyShape.length);
3091
+ const c0 = zeros(genShape, {
3092
+ dtype: DType.Uint32,
3093
+ device: k0.device
3094
+ });
3095
+ const c1 = arange(0, prod(genShape), 1, {
3096
+ dtype: DType.Uint32,
3097
+ device: k0.device
3098
+ }).reshape(genShape);
3099
+ k0 = k0.#reshape(k0.#st.reshape(keyShape.concat(rep(genShape.length, 1))));
3100
+ k1 = k1.#reshape(k1.#st.reshape(keyShape.concat(rep(genShape.length, 1))));
3101
+ const custom = ([k0$1, k1$1, c0$1, c1$1]) => AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
3102
+ return [Array$1.#naryCustom("random_bits", custom, [
3103
+ k0,
3104
+ k1,
3105
+ c0,
3106
+ c1
3107
+ ])];
3108
+ },
3109
+ [Primitive.Gather]([x, ...indices], { axis, outDim }) {
3110
+ return [x.#gather(indices, axis, outDim)];
3111
+ },
2815
3112
  [Primitive.Transpose]([x], { perm }) {
2816
3113
  return [x.#transpose(perm)];
2817
3114
  },
@@ -2832,17 +3129,71 @@ var Array$1 = class Array$1 extends Tracer {
2832
3129
  [Primitive.Pad]([x], { width }) {
2833
3130
  return [x.#reshape(x.#st.pad(width))];
2834
3131
  },
2835
- [Primitive.Gather]([x, ...indices], { axis, outDim }) {
2836
- return [x.#gather(indices, axis, outDim)];
3132
+ [Primitive.Sort]([x]) {
3133
+ const routine = new Routine(Routines.Sort, {
3134
+ inputShapes: [x.shape],
3135
+ inputDtypes: [x.dtype],
3136
+ outputShapes: [x.shape],
3137
+ outputDtypes: [x.dtype]
3138
+ });
3139
+ return Array$1.#routine(routine, [x], [x.#weakType]);
3140
+ },
3141
+ [Primitive.Argsort]([x]) {
3142
+ const routine = new Routine(Routines.Argsort, {
3143
+ inputShapes: [x.shape],
3144
+ inputDtypes: [x.dtype],
3145
+ outputShapes: [x.shape, x.shape],
3146
+ outputDtypes: [x.dtype, DType.Int32]
3147
+ });
3148
+ return Array$1.#routine(routine, [x], [x.#weakType, false]);
3149
+ },
3150
+ [Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
3151
+ const routine = new Routine(Routines.TriangularSolve, {
3152
+ inputShapes: [a.shape, b.shape],
3153
+ inputDtypes: [a.dtype, b.dtype],
3154
+ outputShapes: [b.shape],
3155
+ outputDtypes: [b.dtype]
3156
+ }, { unitDiagonal });
3157
+ return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
2837
3158
  },
2838
- [Primitive.JitCall](args, { jaxpr, numConsts }) {
2839
- if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
2840
- const { backend, committed } = Array$1.#computeBackend("jit_call", args);
3159
+ [Primitive.Cholesky]([a]) {
3160
+ const routine = new Routine(Routines.Cholesky, {
3161
+ inputShapes: [a.shape],
3162
+ inputDtypes: [a.dtype],
3163
+ outputShapes: [a.shape],
3164
+ outputDtypes: [a.dtype]
3165
+ });
3166
+ return Array$1.#routine(routine, [a], [a.#weakType]);
3167
+ },
3168
+ [Primitive.LU]([a]) {
3169
+ const batch = a.shape.slice(0, -2);
3170
+ const [m, n] = a.shape.slice(-2);
3171
+ const routine = new Routine(Routines.LU, {
3172
+ inputShapes: [a.shape],
3173
+ inputDtypes: [a.dtype],
3174
+ outputShapes: [
3175
+ a.shape,
3176
+ [...batch, Math.min(m, n)],
3177
+ [...batch, m]
3178
+ ],
3179
+ outputDtypes: [
3180
+ a.dtype,
3181
+ DType.Int32,
3182
+ DType.Int32
3183
+ ]
3184
+ });
3185
+ return Array$1.#routine(routine, [a], [
3186
+ a.#weakType,
3187
+ false,
3188
+ false
3189
+ ]);
3190
+ },
3191
+ [Primitive.Jit](args, { jaxpr }) {
3192
+ if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
3193
+ const { backend, committed } = Array$1.#computeBackend("jit", args);
2841
3194
  args = args.map((ar) => ar._putSync(backend));
2842
- const consts = args.slice(0, numConsts);
2843
- const tracers = args.slice(numConsts);
2844
- const jp = jitCompile(backend, jaxpr, consts);
2845
- const { outputs, pending } = jp.execute(tracers.map((x) => x._realizeSource()));
3195
+ const jp = jitCompile(backend, jaxpr);
3196
+ const { outputs, pending } = jp.execute(args.map((x) => x._realizeSource()));
2846
3197
  for (const exe of pending) exe.updateRc(+outputs.length - 1);
2847
3198
  const prevPending = [...new Set(args.flatMap((x) => x.#pending))];
2848
3199
  for (const exe of prevPending) exe.updateRc(+outputs.length);
@@ -2942,7 +3293,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
2942
3293
  device
2943
3294
  });
2944
3295
  } else {
2945
- const weakType = dtype == void 0;
3296
+ const weakType = dtype == void 0 && shape$1.length === 0;
2946
3297
  dtype = dtype ?? DType.Float32;
2947
3298
  const data = dtypedJsArray(dtype, flat);
2948
3299
  return arrayFromData(data, shape$1, {
@@ -3056,7 +3407,7 @@ function ones(shape$1, { dtype, device } = {}) {
3056
3407
  }
3057
3408
  /** Return a new array of given shape and type, filled with `fill_value`. */
3058
3409
  function full(shape$1, fillValue, { dtype, device } = {}) {
3059
- let weakType = dtype == void 0;
3410
+ let weakType = dtype == void 0 && shape$1.length === 0;
3060
3411
  if (typeof fillValue === "number") dtype = dtype ?? DType.Float32;
3061
3412
  else if (typeof fillValue === "boolean") {
3062
3413
  dtype = dtype ?? DType.Bool;
@@ -3141,6 +3492,43 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
3141
3492
  });
3142
3493
  }
3143
3494
  /**
3495
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
3496
+ *
3497
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
3498
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
3499
+ * `k>0` is above it.
3500
+ */
3501
function tri(n, m, k = 0, { dtype, device } = {}) {
  // A missing column count means a square result; the element type defaults to float32.
  m ??= n;
  dtype ??= DType.Float32;
  for (const [label, value] of [["n", n], ["m", m]]) {
    if (!Number.isInteger(value) || value < 0) throw new Error(`tri: ${label} must be a non-negative integer, got ${value}`);
  }
  if (!Number.isInteger(k)) throw new Error(`tri: k must be an integer, got ${k}`);
  // Row i carries the value i + k, so comparing against the column index j
  // marks entry (i, j) exactly when j <= i + k.
  const rowOffsets = arange(k, n + k, 1, {
    dtype: DType.Int32,
    device
  });
  const colIndex = arange(0, m, 1, {
    dtype: DType.Int32,
    device
  });
  const mask = rowOffsets.reshape([n, 1]).greaterEqual(colIndex);
  return mask.astype(dtype);
}
3517
/**
 * Return the lower triangle of an array: entries above the `k`-th diagonal of
 * the last two axes are replaced with zeros. Input must be of dimension >= 2.
 */
function tril(a, k = 0) {
  const rank = ndim$1(a);
  if (rank < 2) throw new Error(`tril: input array must be at least 2D, got ${rank}D`);
  a = fudgeArray(a);
  const [n, m] = a.shape.slice(-2);
  // Boolean mask over the trailing [n, m] axes; `where` selects `a` under it.
  const keep = tri(n, m, k, { dtype: DType.Bool });
  return where$1(keep, a.ref, zerosLike$1(a));
}
3524
/**
 * Return the upper triangle of an array: entries below the `k`-th diagonal of
 * the last two axes are replaced with zeros. Input must be of dimension >= 2.
 *
 * @param a - Input array (or array-like) with at least two dimensions.
 * @param k - Diagonal below which entries are zeroed. Default is 0.
 * @throws Error if the input has fewer than two dimensions.
 */
function triu(a, k = 0) {
  const rank = ndim$1(a);
  // The message names `triu` so the caller can tell which function rejected the input.
  if (rank < 2) throw new Error(`triu: input array must be at least 2D, got ${rank}D`);
  a = fudgeArray(a);
  const [n, m] = a.shape.slice(-2);
  // tri(..., k - 1) marks everything strictly below diagonal k; those
  // positions take zeros and the rest keep the value from `a`.
  const below = tri(n, m, k - 1, { dtype: DType.Bool });
  return where$1(below, zerosLike$1(a.ref), a);
}
3531
+ /**
3144
3532
  * Return evenly spaced numbers over a specified interval.
3145
3533
  *
3146
3534
  * Returns _num_ evenly spaced samples, calculated over the interval
@@ -3177,6 +3565,27 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
3177
3565
  committed: device != void 0
3178
3566
  });
3179
3567
  }
3568
/**
 * Return numbers spaced evenly on a log scale.
 *
 * In linear space, the sequence starts at `base ** start` and ends at
 * `base ** stop` (see `endpoint` below).
 *
 * The exponents are generated with `linspace` in its default dtype and the
 * requested `dtype` is applied only to the final result, so an integer
 * `dtype` does not truncate the exponents before exponentiation.
 *
 * @param start - `base ** start` is the starting value of the sequence.
 * @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
 * @param num - Number of samples to generate. Default is 50.
 * @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
 * @param base - The base of the log space. Default is 10.
 * @param options.dtype - If given, the result is converted to this dtype.
 * @param options.device - Forwarded to `linspace`.
 * @returns Array of evenly spaced values on a log scale.
 */
function logspace(start, stop, num = 50, endpoint = true, base = 10, { dtype, device } = {}) {
  const exponents = linspace(start, stop, num, endpoint, { device });
  // base ** y is evaluated as exp(y * ln(base)).
  const result = exp$1(mul(exponents, Math.log(base)));
  return dtype === undefined ? result : result.astype(dtype);
}
3180
3589
  function aluCompare(a, b, op) {
3181
3590
  switch (op) {
3182
3591
  case CompareOp.Less: return AluExp.cmplt(a, b);
@@ -3187,383 +3596,210 @@ function aluCompare(a, b, op) {
3187
3596
  }
3188
3597
 
3189
3598
  //#endregion
3190
- //#region src/frontend/jvp.ts
3191
- var JVPTracer = class extends Tracer {
3192
- constructor(trace$1, primal, tangent) {
3599
+ //#region src/frontend/vmap.ts
3600
/** Build the abstract value seen inside the mapped function: `aval` with axis `batchDim` removed. */
function mappedAval(batchDim, aval) {
  const { shape: batchedShape, dtype, weakType } = aval;
  // Work on a copy so the incoming aval's shape is never mutated.
  const perExampleShape = Array.from(batchedShape);
  perExampleShape.splice(batchDim, 1);
  return new ShapedArray(perExampleShape, dtype, weakType);
}
3605
/** Move the axis at index `src` to index `dst`, keeping the other axes in order. */
function moveaxis(x, src, dst) {
  const arr = pureArray(x);
  const from = checkAxis(src, arr.ndim);
  const to = checkAxis(dst, arr.ndim);
  if (from === to) return arr;
  // Start from the identity permutation, pull `from` out, then re-insert it at `to`.
  const order = range(arr.ndim);
  order.splice(from, 1);
  order.splice(to, 0, from);
  return transpose$1(arr, order);
}
3616
/**
 * Put the batch axis of `x` at index `dst`.
 *
 * When `src` is null the value has no batch axis yet, so one of length
 * `axisSize` is created at `dst` by broadcasting; otherwise the existing axis
 * is moved (or left alone when it is already in place).
 */
function moveBatchAxis(axisSize, src, dst, x) {
  if (src !== null) return src === dst ? x : moveaxis(x, src, dst);
  const withBatchAxis = [...x.shape];
  withBatchAxis.splice(dst, 0, axisSize);
  return broadcast(x, withBatchAxis, [dst]);
}
3624
/**
 * Tracer used by the vmap interpreter: pairs an underlying value with the
 * index of its batch axis (`null` when the value is not batched).
 */
var BatchTracer = class extends Tracer {
  constructor(trace$1, val, batchDim) {
    super(trace$1);
    this.val = val;
    this.batchDim = batchDim;
  }
  /** Abstract value as seen by the mapped function, i.e. with the batch axis removed. */
  get aval() {
    const inner = this.batchDim === null ? null : this.batchDim;
    return inner === null ? this.val.aval : mappedAval(inner, this.val.aval);
  }
  toString() {
    const inner = this.val.toString();
    return `BatchTracer(${inner}, ${this.batchDim})`;
  }
  get ref() {
    // The wrapped value's `ref` getter is read only for its effect; its result is discarded.
    this.val.ref;
    return this;
  }
  dispose() {
    this.val.dispose();
  }
  /** An unbatched tracer lowers to its wrapped value; a batched one stays as is. */
  fullLower() {
    return this.batchDim === null ? this.val.fullLower() : this;
  }
};
3212
- var JVPTrace = class extends Trace {
3649
/** Trace implementing vmap: dispatches each primitive to its batching rule in `vmapRules`. */
var BatchTrace = class extends Trace {
  pure(val) {
    return this.lift(pureArray(val));
  }
  /** Wrap a value as an unbatched tracer. */
  lift(val) {
    return new BatchTracer(this, val, null);
  }
  processPrimitive(primitive, tracers, params) {
    const [valsIn, bdimsIn] = unzip2(tracers.map((tracer) => [tracer.val, tracer.batchDim]));
    const rule = vmapRules[primitive];
    if (rule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
    const wrap = (value, bdim) => new BatchTracer(this, value, bdim);
    // No input carries a batch axis: run the primitive directly and keep outputs unbatched.
    const anyBatched = bdimsIn.some((d) => d !== null);
    if (!anyBatched) {
      const plainOuts = bind(primitive, valsIn, params);
      return plainOuts.map((value) => wrap(value, null));
    }
    const [valOuts, bdimOuts] = rule(this.axisSize, valsIn, bdimsIn, params);
    if (valOuts.length !== bdimOuts.length) throw new Error(`vmap rule for ${primitive} returned mismatched lengths: ${valOuts.length} vs ${bdimOuts.length}`);
    return zip(valOuts, bdimOuts).map(([value, bdim]) => wrap(value, bdim));
  }
  /** Size of the mapped axis, read from the main trace's `globalData`. */
  get axisSize() {
    return this.main.globalData;
  }
};
3227
- /** Rule that applies the same operation to primals and tangents. */
3228
- function linearTangentsJvp(primitive) {
3229
- return (primals, tangents, params) => {
3230
- const ys = bind(primitive, primals, params);
3231
- const dys = bind(primitive, tangents, params);
3232
- return [ys, dys];
3672
/**
 * Build a vmap rule for a primitive that has built-in broadcasting.
 *
 * If every batched argument has its batch axis at the same distance from the
 * end, and every unbatched argument has a small enough rank, the primitive is
 * applied to the arguments as they are. Otherwise each batched argument gets
 * its batch axis moved to the front and is padded with singleton axes up to a
 * common rank before the primitive is applied.
 *
 * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
 */
function broadcastBatcher(prim) {
  return (axisSize, args, dims, params) => {
    if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
    // Rank of the batched result: unbatched arguments count one extra axis.
    const ranks = args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0));
    const nd = Math.max(...ranks);
    const firstIdx = dims.findIndex((d) => d !== null);
    // Batch-axis position of the first batched argument, measured from the end (negative).
    const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
    const alreadyAligned = zip(args, dims).every(([x, d]) => {
      if (d === null) return ndim$1(x) < -firstBdim;
      return d - x.ndim === firstBdim;
    });
    if (alreadyAligned) return [[bind1(prim, args, params)], [nd + firstBdim]];
    const moved = args.map((x, i) => {
      const bdim = dims[i];
      if (bdim === null) return x;
      let y = moveBatchAxis(axisSize, bdim, 0, x);
      if (y.ndim < nd) {
        // Insert singleton axes right after the batch axis to reach rank `nd`.
        y = y.reshape([
          y.shape[0],
          ...rep(nd - y.ndim, 1),
          ...y.shape.slice(1)
        ]);
      }
      return y;
    });
    return [[bind1(prim, moved, params)], [0]];
  };
}
3235
- /** Rule for product of gradients in bilinear operations. */
3236
- function bilinearTangentsJvp(primitive) {
3237
- return ([x, y], [dx, dy], params) => {
3238
- const primal = bind1(primitive, [x.ref, y.ref], params);
3239
- const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
3240
- return [[primal], [tangent]];
3697
/** Build a vmap rule for a one-input primitive: apply it unchanged and keep the batch axis where it is. */
function unopBatcher(prim) {
  return (axisSize, [x], [xBdim], params) => {
    const out = bind1(prim, [x], params);
    return [[out], [xBdim]];
  };
}
3243
- /** Rule that zeros out any tangents. */
3244
- function zeroTangentsJvp(primitive) {
3245
- return (primals, tangents, params) => {
3246
- for (const t of tangents) t.dispose();
3247
- const ys = bind(primitive, primals, params);
3248
- return [ys, ys.map((y) => zerosLike$1(y.ref))];
3702
/**
 * Build a vmap rule for a one-input primitive that operates on its trailing
 * `inputDims` axes and yields `numOutputs` outputs.
 *
 * A batch axis that lies before those trailing axes is left in place; one
 * that falls inside them is first moved to the front.
 */
function lastDimsBatcher(prim, inputDims, numOutputs = 1) {
  return (axisSize, [x], [xBdim], params) => {
    assertNonNull(xBdim);
    const batchIsLeading = xBdim < x.ndim - inputDims;
    const input = batchIsLeading ? x : moveBatchAxis(axisSize, xBdim, 0, x);
    const outBdim = batchIsLeading ? xBdim : 0;
    return [bind(prim, [input], params), rep(numOutputs, outBdim)];
  };
}
3251
- const jvpRules = {
3252
- [Primitive.Add]: linearTangentsJvp(Primitive.Add),
3253
- [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
3254
- [Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
3255
- [Primitive.Mod]([x, y], [dx, dy]) {
3256
- if (!isFloatDtype(x.dtype) && !isFloatDtype(y.dtype)) {
3257
- dx.dispose();
3258
- dy.dispose();
3259
- return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
3260
- }
3261
- const q = idiv(x.ref, y.ref);
3262
- return [[mod(x, y)], [dx.sub(dy.mul(q))]];
3263
- },
3264
- [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
3265
- [Primitive.Reciprocal]([x], [dx]) {
3266
- const xRecip = reciprocal$1(x.ref);
3267
- return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
3268
- },
3269
- [Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
3270
- [Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
3271
- [Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
3272
- [Primitive.Cast]([x], [dx], { dtype }) {
3273
- if (x.dtype === dtype) return [[x], [dx]];
3274
- if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
3275
- else {
3276
- dx.dispose();
3277
- return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
3278
- }
3279
- },
3280
- [Primitive.Bitcast]([x], [dx], { dtype }) {
3281
- if (x.dtype === dtype) return [[x], [dx]];
3282
- dx.dispose();
3283
- return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
3284
- },
3285
- [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
3286
- [Primitive.Sin]([x], [dx]) {
3287
- return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
3710
+ const vmapRules = {
3711
+ [Primitive.Add]: broadcastBatcher(Primitive.Add),
3712
+ [Primitive.Mul]: broadcastBatcher(Primitive.Mul),
3713
+ [Primitive.Idiv]: broadcastBatcher(Primitive.Idiv),
3714
+ [Primitive.Mod]: broadcastBatcher(Primitive.Mod),
3715
+ [Primitive.Min]: broadcastBatcher(Primitive.Min),
3716
+ [Primitive.Max]: broadcastBatcher(Primitive.Max),
3717
+ [Primitive.Neg]: unopBatcher(Primitive.Neg),
3718
+ [Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
3719
+ [Primitive.Floor]: unopBatcher(Primitive.Floor),
3720
+ [Primitive.Ceil]: unopBatcher(Primitive.Ceil),
3721
+ [Primitive.StopGradient]: unopBatcher(Primitive.StopGradient),
3722
+ [Primitive.Cast]: unopBatcher(Primitive.Cast),
3723
+ [Primitive.Bitcast]: unopBatcher(Primitive.Bitcast),
3724
+ [Primitive.Sin]: unopBatcher(Primitive.Sin),
3725
+ [Primitive.Cos]: unopBatcher(Primitive.Cos),
3726
+ [Primitive.Asin]: unopBatcher(Primitive.Asin),
3727
+ [Primitive.Atan]: unopBatcher(Primitive.Atan),
3728
+ [Primitive.Exp]: unopBatcher(Primitive.Exp),
3729
+ [Primitive.Log]: unopBatcher(Primitive.Log),
3730
+ [Primitive.Erf]: unopBatcher(Primitive.Erf),
3731
+ [Primitive.Erfc]: unopBatcher(Primitive.Erfc),
3732
+ [Primitive.Sqrt]: unopBatcher(Primitive.Sqrt),
3733
+ [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3734
+ assertNonNull(xBdim);
3735
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3736
+ const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3737
+ return [[reduce(x, op, newAxis)], [outBdim]];
3288
3738
  },
3289
- [Primitive.Cos]([x], [dx]) {
3290
- return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
3739
+ [Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
3740
+ x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
3741
+ y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
3742
+ const z = dot$2(x, y);
3743
+ return [[z], [z.ndim - 1]];
3291
3744
  },
3292
- [Primitive.Asin]([x], [dx]) {
3293
- const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
3294
- return [[asin$1(x)], [denom.mul(dx)]];
3745
+ [Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
3746
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3747
+ y = moveBatchAxis(axisSize, yBdim, 0, y);
3748
+ const z = conv$1(x, y, {
3749
+ ...params,
3750
+ vmapDims: params.vmapDims + 1
3751
+ });
3752
+ return [[z], [0]];
3295
3753
  },
3296
- [Primitive.Atan]([x], [dx]) {
3297
- const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
3298
- return [[atan$1(x)], [dx.div(denom)]];
3754
+ [Primitive.Compare]: broadcastBatcher(Primitive.Compare),
3755
+ [Primitive.Where]: broadcastBatcher(Primitive.Where),
3756
+ [Primitive.Concatenate](axisSize, xs, xBdims, { axis }) {
3757
+ const minBdim = Math.min(...xBdims.filter((d) => d !== null));
3758
+ xs = xs.map((x, i) => moveBatchAxis(axisSize, xBdims[i], minBdim, x));
3759
+ const newAxis = axis + (minBdim <= axis ? 1 : 0);
3760
+ return [[concatenate$1(xs, newAxis)], [minBdim]];
3299
3761
  },
3300
- [Primitive.Exp]([x], [dx]) {
3301
- const z = exp$1(x);
3302
- return [[z.ref], [z.mul(dx)]];
3762
+ [Primitive.Split](axisSize, [x], [xBdim], { axis, sizes }) {
3763
+ assertNonNull(xBdim);
3764
+ const newAxis = axis + (xBdim <= axis ? 1 : 0);
3765
+ const outs = split$2(x, newAxis, sizes);
3766
+ return [outs, rep(outs.length, xBdim)];
3303
3767
  },
3304
- [Primitive.Log]([x], [dx]) {
3305
- return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
3768
+ [Primitive.RandomBits](axisSize, [k0, k1], [bdim0, bdim1], { shape: shape$1, mode }) {
3769
+ k0 = moveBatchAxis(axisSize, bdim0, 0, k0);
3770
+ k1 = moveBatchAxis(axisSize, bdim1, 0, k1);
3771
+ return [[randomBits(k0, k1, [axisSize, ...shape$1], mode)], [0]];
3306
3772
  },
3307
- [Primitive.Erf]([x], [dx]) {
3308
- const coeff = 2 / Math.sqrt(Math.PI);
3309
- const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3310
- return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
3311
- },
3312
- [Primitive.Erfc]([x], [dx]) {
3313
- const coeff = -2 / Math.sqrt(Math.PI);
3314
- const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3315
- return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
3316
- },
3317
- [Primitive.Sqrt]([x], [dx]) {
3318
- const z = sqrt$1(x);
3319
- return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
3320
- },
3321
- [Primitive.Min]([x, y], [dx, dy]) {
3322
- return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
3323
- },
3324
- [Primitive.Max]([x, y], [dx, dy]) {
3325
- return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
3326
- },
3327
- [Primitive.Reduce]([x], [dx], { op, axis }) {
3328
- if (op === AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
3329
- else if (op === AluOp.Mul) {
3330
- const primal = reduce(x.ref, op, axis);
3331
- const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
3332
- return [[primal], [tangent]];
3333
- } else if (op === AluOp.Min || op === AluOp.Max) {
3334
- const primal = reduce(x.ref, op, axis);
3335
- const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
3336
- const minCount = where$1(notMin.ref, 0, 1).sum(axis);
3337
- const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
3338
- return [[primal], [tangent]];
3339
- } else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
3340
- },
3341
- [Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
3342
- [Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
3343
- [Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
3344
- [Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
3345
- [Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
3346
- [Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
3347
- dcond.dispose();
3348
- return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
3349
- },
3350
- [Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
3351
- [Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
3352
- [Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
3353
- [Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
3354
- [Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
3355
- [Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
3356
- [Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
3357
- const indicesRef = indices.map((t) => t.ref);
3358
- return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
3359
- },
3360
- [Primitive.JitCall](primals, tangents, { name, jaxpr }) {
3361
- const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
3362
- const outs = bind(Primitive.JitCall, [
3363
- ...newConsts.map((c) => c.ref),
3364
- ...primals,
3365
- ...tangents
3366
- ], {
3367
- name: `${name}_jvp`,
3368
- jaxpr: newJaxpr,
3369
- numConsts: newConsts.length
3370
- });
3371
- const n = outs.length / 2;
3372
- if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
3373
- const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
3374
- return [primalsOut, tangentsOut];
3375
- }
3376
- };
3377
- const jvpJaxprCache = /* @__PURE__ */ new Map();
3378
- function jvpJaxpr(jaxpr) {
3379
- if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
3380
- const inAvals = jaxpr.inBinders.map((v) => v.aval);
3381
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
3382
- const result = {
3383
- newJaxpr,
3384
- newConsts
3385
- };
3386
- jvpJaxprCache.set(jaxpr, result);
3387
- return result;
3388
- }
3389
- function jvpFlat(f, primals, tangents) {
3390
- try {
3391
- var _usingCtx$1 = _usingCtx();
3392
- const main = _usingCtx$1.u(newMain(JVPTrace));
3393
- const trace$1 = new JVPTrace(main);
3394
- const tracersIn = zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
3395
- const outs = f(...tracersIn);
3396
- const tracersOut = outs.map((out) => fullRaise(trace$1, out));
3397
- return unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
3398
- } catch (_) {
3399
- _usingCtx$1.e = _;
3400
- } finally {
3401
- _usingCtx$1.d();
3402
- }
3403
- }
3404
- function jvp$1(f, primals, tangents) {
3405
- const [primalsFlat, inTree] = flatten(primals);
3406
- const [tangentsFlat, inTree2] = flatten(tangents);
3407
- if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
3408
- const [flatFun, outTree] = flattenFun(f, inTree);
3409
- const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
3410
- if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
3411
- const primalsOut = unflatten(outTree.value, primalsOutFlat);
3412
- const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
3413
- return [primalsOut, tangentsOut];
3414
- }
3415
-
3416
- //#endregion
3417
- //#region src/frontend/vmap.ts
3418
- function mappedAval(batchDim, aval) {
3419
- const shape$1 = [...aval.shape];
3420
- shape$1.splice(batchDim, 1);
3421
- return new ShapedArray(shape$1, aval.dtype, aval.weakType);
3422
- }
3423
- /** Move one axis to a different index. */
3424
- function moveaxis(x, src, dst) {
3425
- const t = pureArray(x);
3426
- src = checkAxis(src, t.ndim);
3427
- dst = checkAxis(dst, t.ndim);
3428
- if (src === dst) return t;
3429
- const perm = range(t.ndim);
3430
- perm.splice(src, 1);
3431
- perm.splice(dst, 0, src);
3432
- return transpose$1(t, perm);
3433
- }
3434
- function moveBatchAxis(axisSize, src, dst, x) {
3435
- if (src === null) {
3436
- const targetShape = [...x.shape];
3437
- targetShape.splice(dst, 0, axisSize);
3438
- return broadcast(x, targetShape, [dst]);
3439
- } else if (src === dst) return x;
3440
- else return moveaxis(x, src, dst);
3441
- }
3442
- var BatchTracer = class extends Tracer {
3443
- constructor(trace$1, val, batchDim) {
3444
- super(trace$1);
3445
- this.val = val;
3446
- this.batchDim = batchDim;
3447
- }
3448
- get aval() {
3449
- if (this.batchDim === null) return this.val.aval;
3450
- else return mappedAval(this.batchDim, this.val.aval);
3451
- }
3452
- toString() {
3453
- return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
3454
- }
3455
- get ref() {
3456
- this.val.ref;
3457
- return this;
3458
- }
3459
- dispose() {
3460
- this.val.dispose();
3461
- }
3462
- fullLower() {
3463
- if (this.batchDim === null) return this.val.fullLower();
3464
- else return this;
3465
- }
3466
- };
3467
- var BatchTrace = class extends Trace {
3468
- pure(val) {
3469
- return this.lift(pureArray(val));
3470
- }
3471
- lift(val) {
3472
- return new BatchTracer(this, val, null);
3473
- }
3474
- processPrimitive(primitive, tracers, params) {
3475
- const [valsIn, bdimsIn] = unzip2(tracers.map((t) => [t.val, t.batchDim]));
3476
- const vmapRule = vmapRules[primitive];
3477
- if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
3478
- if (bdimsIn.every((d) => d === null)) {
3479
- const valOuts$1 = bind(primitive, valsIn, params);
3480
- return valOuts$1.map((x) => new BatchTracer(this, x, null));
3773
+ [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3774
+ if (indicesBdim.every((d) => d === null)) {
3775
+ assertNonNull(xBdim);
3776
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3777
+ let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3778
+ let newOutDim = outDim;
3779
+ if (newOutDim < newBdim) newBdim += axis.length;
3780
+ else newOutDim += 1;
3781
+ return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3481
3782
  }
3482
- const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3483
- return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3484
- }
3485
- get axisSize() {
3486
- return this.main.globalData;
3487
- }
3488
- };
3489
- /**
3490
- * Process a primitive with built-in broadcasting.
3491
- *
3492
- * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3493
- */
3494
- function broadcastBatcher(op) {
3495
- return (axisSize, args, dims) => {
3496
- if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3497
- const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3498
- const firstIdx = dims.findIndex((d) => d !== null);
3499
- const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3500
- if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3501
- args = args.map((x, i) => {
3502
- if (dims[i] === null) return x;
3503
- x = moveBatchAxis(axisSize, dims[i], 0, x);
3504
- if (x.ndim < nd) x = x.reshape([
3505
- x.shape[0],
3506
- ...rep(nd - x.ndim, 1),
3507
- ...x.shape.slice(1)
3783
+ const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3784
+ indices = indices.map((m, i) => {
3785
+ if (indicesBdim[i] === null) return m;
3786
+ m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3787
+ if (m.ndim < nd) m = m.reshape([
3788
+ m.shape[0],
3789
+ ...rep(nd - m.ndim, 1),
3790
+ ...m.shape.slice(1)
3508
3791
  ]);
3509
- return x;
3510
- });
3511
- return [[op(...args)], [0]];
3512
- };
3513
- }
3514
- function unopBatcher(op) {
3515
- return (axisSize, [x], [xBdim], params) => {
3516
- return [[op(x, params)], [xBdim]];
3517
- };
3518
- }
3519
- const vmapRules = {
3520
- [Primitive.Add]: broadcastBatcher(add$1),
3521
- [Primitive.Mul]: broadcastBatcher(mul),
3522
- [Primitive.Idiv]: broadcastBatcher(idiv),
3523
- [Primitive.Mod]: broadcastBatcher(mod),
3524
- [Primitive.Neg]: unopBatcher(neg),
3525
- [Primitive.Reciprocal]: unopBatcher(reciprocal$1),
3526
- [Primitive.Floor]: unopBatcher(floor$1),
3527
- [Primitive.Ceil]: unopBatcher(ceil$1),
3528
- [Primitive.StopGradient]: unopBatcher(stopGradient),
3529
- [Primitive.Cast]: unopBatcher((x, { dtype }) => cast(x, dtype)),
3530
- [Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
3531
- [Primitive.Sin]: unopBatcher(sin$1),
3532
- [Primitive.Cos]: unopBatcher(cos$1),
3533
- [Primitive.Asin]: unopBatcher(asin$1),
3534
- [Primitive.Atan]: unopBatcher(atan$1),
3535
- [Primitive.Exp]: unopBatcher(exp$1),
3536
- [Primitive.Log]: unopBatcher(log$1),
3537
- [Primitive.Erf]: unopBatcher(erf$1),
3538
- [Primitive.Erfc]: unopBatcher(erfc$1),
3539
- [Primitive.Sqrt]: unopBatcher(sqrt$1),
3540
- [Primitive.Min]: broadcastBatcher(min$1),
3541
- [Primitive.Max]: broadcastBatcher(max$1),
3542
- [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3543
- assertNonNull(xBdim);
3544
- const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3545
- const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3546
- return [[reduce(x, op, newAxis)], [outBdim]];
3547
- },
3548
- [Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
3549
- x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
3550
- y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
3551
- const z = dot$2(x, y);
3552
- return [[z], [z.ndim - 1]];
3553
- },
3554
- [Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
3555
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3556
- y = moveBatchAxis(axisSize, yBdim, 0, y);
3557
- const z = conv$1(x, y, {
3558
- ...params,
3559
- vmapDims: params.vmapDims + 1
3792
+ return m;
3560
3793
  });
3561
- return [[z], [0]];
3562
- },
3563
- [Primitive.Compare](axisSize, args, dims, { op }) {
3564
- return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3794
+ if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3795
+ else {
3796
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3797
+ const newAxis = [0, ...axis.map((ax) => ax + 1)];
3798
+ const extraBatchIndex = arange(axisSize).reshape([-1, ...rep(nd - 1, 1)]);
3799
+ indices.splice(0, 0, extraBatchIndex);
3800
+ return [[gather(x, indices, newAxis, outDim)], [outDim]];
3801
+ }
3565
3802
  },
3566
- [Primitive.Where]: broadcastBatcher(where$1),
3567
3803
  [Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
3568
3804
  assertNonNull(xBdim);
3569
3805
  const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
@@ -3595,42 +3831,39 @@ const vmapRules = {
3595
3831
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3596
3832
  return [[pad$1(x, newWidth)], [xBdim]];
3597
3833
  },
3598
- [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3599
- if (indicesBdim.every((d) => d === null)) {
3600
- assertNonNull(xBdim);
3601
- const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3602
- let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3603
- let newOutDim = outDim;
3604
- if (newOutDim < newBdim) newBdim += axis.length;
3605
- else newOutDim += 1;
3606
- return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3607
- }
3608
- const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3609
- indices = indices.map((m, i) => {
3610
- if (indicesBdim[i] === null) return m;
3611
- m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3612
- if (m.ndim < nd) m = m.reshape([
3613
- m.shape[0],
3614
- ...rep(nd - m.ndim, 1),
3615
- ...m.shape.slice(1)
3834
+ [Primitive.Sort]: lastDimsBatcher(Primitive.Sort, 1),
3835
+ [Primitive.Argsort]: lastDimsBatcher(Primitive.Argsort, 1, 2),
3836
+ [Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
3837
+ if (aBdim === null) {
3838
+ b = moveBatchAxis(axisSize, bBdim, -3, b);
3839
+ const [s, m, n] = b.shape.slice(-3);
3840
+ b = b.reshape([
3841
+ ...b.shape.slice(0, -3),
3842
+ s * m,
3843
+ n
3616
3844
  ]);
3617
- return m;
3618
- });
3619
- if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3620
- else {
3621
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3622
- const newAxis = [0, ...axis.map((ax) => ax + 1)];
3623
- const extraBatchIndex = arange(axisSize).reshape([-1, ...rep(nd - 1, 1)]);
3624
- indices.splice(0, 0, extraBatchIndex);
3625
- return [[gather(x, indices, newAxis, outDim)], [outDim]];
3845
+ let x$1 = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
3846
+ x$1 = x$1.reshape([
3847
+ ...b.shape.slice(0, -2),
3848
+ s,
3849
+ m,
3850
+ n
3851
+ ]);
3852
+ return [[x$1], [x$1.ndim - 3]];
3626
3853
  }
3854
+ a = moveBatchAxis(axisSize, aBdim, 0, a);
3855
+ b = moveBatchAxis(axisSize, bBdim, 0, b);
3856
+ const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
3857
+ return [[x], [0]];
3627
3858
  },
3628
- [Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
3629
- const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
3630
- const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
3859
+ [Primitive.Cholesky]: lastDimsBatcher(Primitive.Cholesky, 2),
3860
+ [Primitive.LU]: lastDimsBatcher(Primitive.LU, 2, 3),
3861
+ [Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
3862
+ const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
3863
+ const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
3631
3864
  name: `${name}_vmap`,
3632
- jaxpr: newJaxpr,
3633
- numConsts: newConsts.length
3865
+ jaxpr: newJaxpr.jaxpr,
3866
+ numConsts: newJaxpr.consts.length
3634
3867
  });
3635
3868
  return [outs, rep(outs.length, 0)];
3636
3869
  }
@@ -3646,14 +3879,10 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
3646
3879
  shape$1.splice(dims[i], 0, axisSize);
3647
3880
  return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
3648
3881
  });
3649
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3650
- const result = {
3651
- newJaxpr,
3652
- newConsts
3653
- };
3882
+ const { jaxpr: newJaxpr } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3654
3883
  if (!vmapJaxprCache.has(jaxpr)) vmapJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
3655
- vmapJaxprCache.get(jaxpr).set(cacheKey, result);
3656
- return result;
3884
+ vmapJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
3885
+ return newJaxpr;
3657
3886
  }
3658
3887
  function vmapFlat(f, inAxes, args) {
3659
3888
  let axisSize = void 0;
@@ -3699,13 +3928,311 @@ function vmap$1(f, inAxes = 0) {
3699
3928
  return unflatten(outTree.value, outsFlat);
3700
3929
  };
3701
3930
  }
3702
- function jacfwd$1(f) {
3703
- return function jacobianForward(x) {
3704
- if (x.shape.length !== 1) throw new TypeError("jacfwd only supports 1D inputs");
3705
- const [size$1] = x.shape;
3706
- const pushfwd = (v) => jvp$1(f, [x], [v])[1];
3707
- return vmap$1(pushfwd, [0])(eye(size$1, void 0, { dtype: x.dtype }));
3708
- };
3931
// [review annotation] Forward-mode Jacobian of `f` for a single 1-D input.
// Returns `jacobianForward(x)`, which throws TypeError unless x.shape has
// exactly one entry. It then maps `pushfwd` (the tangent output of jvp$1 at x)
// over axis 0 of eye(size, ..., { dtype: x.dtype }) — presumably the identity
// matrix, i.e. one basis tangent per input element.
+ function jacfwd$1(f) {
3932
+ return function jacobianForward(x) {
3933
+ if (x.shape.length !== 1) throw new TypeError("jacfwd only supports 1D inputs");
3934
+ const [size$1] = x.shape;
3935
// jvp$1 returns [primalsOut, tangentsOut]; only the tangent part is kept.
+ const pushfwd = (v) => jvp$1(f, [x], [v])[1];
3936
+ return vmap$1(pushfwd, [0])(eye(size$1, void 0, { dtype: x.dtype }));
3937
+ };
3938
+ }
3939
+
3940
+ //#endregion
3941
+ //#region src/frontend/jvp.ts
3942
// [review annotation] Tracer that carries a (primal, tangent) pair for
// forward-mode differentiation. Its abstract value is the primal's.
+ var JVPTracer = class extends Tracer {
3943
+ constructor(trace$1, primal, tangent) {
3944
+ super(trace$1);
3945
+ this.primal = primal;
3946
+ this.tangent = tangent;
3947
+ }
3948
+ get aval() {
3949
+ return this.primal.aval;
3950
+ }
3951
+ toString() {
3952
+ return `JVPTracer(${this.primal.toString()}, ${this.tangent.toString()})`;
3953
+ }
3954
// Evaluates the `.ref` getter of both members for its side effect and returns
// this same tracer. Presumably `.ref` bumps a reference count, mirroring
// dispose() below — confirm against Tracer.
+ get ref() {
3955
+ this.primal.ref, this.tangent.ref;
3956
+ return this;
3957
+ }
3958
// Disposes both the primal and the tangent.
+ dispose() {
3959
+ this.primal.dispose();
3960
+ this.tangent.dispose();
3961
+ }
3962
+ };
3963
// [review annotation] Trace that evaluates primitives on JVPTracers by
// dispatching to the per-primitive entries of `jvpRules`.
+ var JVPTrace = class extends Trace {
3964
// Wraps a raw value: converts it with pureArray, then lifts it.
+ pure(val) {
3965
+ return this.lift(pureArray(val));
3966
+ }
3967
// Lifts a value into this trace with a zero tangent (zerosLike$1 of a new ref).
+ lift(val) {
3968
+ return new JVPTracer(this, val, zerosLike$1(val.ref));
3969
+ }
3970
// Splits tracers into primals and tangents, applies the rule registered for
// `primitive`, and wraps each (primalOut, tangentOut) pair in a JVPTracer.
// Throws Error when no rule is registered.
+ processPrimitive(primitive, tracers, params) {
3971
+ const [primalsIn, tangentsIn] = unzip2(tracers.map((x) => [x.primal, x.tangent]));
3972
+ const jvpRule = jvpRules[primitive];
3973
+ if (jvpRule === void 0) throw new Error(`No JVP rule for: ${primitive}`);
3974
+ const [primalsOut, tangentsOut] = jvpRule(primalsIn, tangentsIn, params);
3975
+ return zip(primalsOut, tangentsOut).map(([x, t]) => new JVPTracer(this, x, t));
3976
+ }
3977
+ };
3978
+ /** Rule that applies the same operation to primals and tangents. */
3979
// [review annotation] Builds a JVP rule for a primitive that is linear in its
// inputs: the same primitive (with the same params) is bound once on the
// primals and once on the tangents. Returns [primalOutputs, tangentOutputs].
+ function linearTangentsJvp(primitive) {
3980
+ return (primals, tangents, params) => {
3981
+ const ys = bind(primitive, primals, params);
3982
+ const dys = bind(primitive, tangents, params);
3983
+ return [ys, dys];
3984
+ };
3985
+ }
3986
+ /** Rule for product of gradients in bilinear operations. */
3987
// [review annotation] Builds a JVP rule for a two-input bilinear primitive
// using the product rule: tangent = op(x, dy) + op(dx, y).
// The primal call takes `.ref` of x and y so both stay usable for the two
// tangent calls that follow.
+ function bilinearTangentsJvp(primitive) {
3988
+ return ([x, y], [dx, dy], params) => {
3989
+ const primal = bind1(primitive, [x.ref, y.ref], params);
3990
+ const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
3991
+ return [[primal], [tangent]];
3992
+ };
3993
+ }
3994
+ /** Rule that zeros out any tangents. */
3995
// [review annotation] Builds a JVP rule for a primitive whose outputs carry no
// gradient: incoming tangents are disposed, the primitive is bound on the
// primals, and each output gets a zerosLike$1 tangent (taken from a new ref so
// the output itself is still returned).
+ function zeroTangentsJvp(primitive) {
3996
+ return (primals, tangents, params) => {
3997
+ for (const t of tangents) t.dispose();
3998
+ const ys = bind(primitive, primals, params);
3999
+ return [ys, ys.map((y) => zerosLike$1(y.ref))];
4000
+ };
4001
+ }
4002
+ /** Compute `a @ b.T`, batched to last two axes. */
4003
// [review annotation] `a` gets a size-1 axis inserted before its last axis and
// `b` gets one inserted before its second-to-last axis; the two are then passed
// to dot$2. Presumably dot$2 contracts the last axis with broadcasting, which
// would yield a @ b.T over the last two axes — confirm against dot$2.
+ function batchMatmulT(a, b) {
4004
+ return dot$2(a.reshape(a.shape.toSpliced(-1, 0, 1)), b.reshape(b.shape.toSpliced(-2, 0, 1)));
4005
+ }
4006
+ /** Batch matrix transpose. */
4007
// [review annotation] Transposes the trailing matrix dimensions by moving axis
// -2 to position -1; any leading axes are left where they are.
+ function mT(a) {
4008
+ return moveaxis(a, -2, -1);
4009
+ }
4010
// [review annotation] Slices `a` along one axis only. Every axis gets the
// slice spec `[]` except `axis` (validated via checkAxis), which gets `p`.
// Presumably `[]` means "take the whole axis" in a.slice — confirm.
// All untouched entries share the same `[]` array object (Array.fill semantics);
// nothing in this function mutates it.
+ function sliceAxis(a, axis, p) {
4011
+ const slices = Array(a.shape.length).fill([]);
4012
+ slices[checkAxis(axis, a.ndim)] = p;
4013
+ return a.slice(...slices);
4014
+ }
4015
// [review annotation] Pads `a` along one axis only. Every axis gets the pad
// pair [0, 0] except `axis` (validated via checkAxis), which gets `p`; the list
// is handed to pad$1. Callers in this file pass `p` as a two-element pair.
+ function padAxis(a, axis, p) {
4016
+ const pads = Array(a.shape.length).fill([0, 0]);
4017
+ pads[checkAxis(axis, a.ndim)] = p;
4018
+ return pad$1(a, pads);
4019
+ }
4020
+ const jvpRules = {
4021
+ [Primitive.Add]: linearTangentsJvp(Primitive.Add),
4022
+ [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
4023
+ [Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
4024
+ [Primitive.Mod]([x, y], [dx, dy]) {
4025
+ if (!isFloatDtype(x.dtype) && !isFloatDtype(y.dtype)) {
4026
+ dx.dispose();
4027
+ dy.dispose();
4028
+ return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
4029
+ }
4030
+ const q = idiv(x.ref, y.ref);
4031
+ return [[mod(x, y)], [dx.sub(dy.mul(q))]];
4032
+ },
4033
+ [Primitive.Min]([x, y], [dx, dy]) {
4034
+ return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
4035
+ },
4036
+ [Primitive.Max]([x, y], [dx, dy]) {
4037
+ return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
4038
+ },
4039
+ [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
4040
+ [Primitive.Reciprocal]([x], [dx]) {
4041
+ const xRecip = reciprocal$1(x.ref);
4042
+ return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
4043
+ },
4044
+ [Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
4045
+ [Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
4046
+ [Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
4047
+ [Primitive.Cast]([x], [dx], { dtype }) {
4048
+ if (x.dtype === dtype) return [[x], [dx]];
4049
+ if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
4050
+ else {
4051
+ dx.dispose();
4052
+ return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
4053
+ }
4054
+ },
4055
+ [Primitive.Bitcast]([x], [dx], { dtype }) {
4056
+ if (x.dtype === dtype) return [[x], [dx]];
4057
+ dx.dispose();
4058
+ return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
4059
+ },
4060
+ [Primitive.Sin]([x], [dx]) {
4061
+ return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
4062
+ },
4063
+ [Primitive.Cos]([x], [dx]) {
4064
+ return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
4065
+ },
4066
+ [Primitive.Asin]([x], [dx]) {
4067
+ const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
4068
+ return [[asin$1(x)], [denom.mul(dx)]];
4069
+ },
4070
+ [Primitive.Atan]([x], [dx]) {
4071
+ const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
4072
+ return [[atan$1(x)], [dx.div(denom)]];
4073
+ },
4074
+ [Primitive.Exp]([x], [dx]) {
4075
+ const z = exp$1(x);
4076
+ return [[z.ref], [z.mul(dx)]];
4077
+ },
4078
+ [Primitive.Log]([x], [dx]) {
4079
+ return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
4080
+ },
4081
+ [Primitive.Erf]([x], [dx]) {
4082
+ const coeff = 2 / Math.sqrt(Math.PI);
4083
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
4084
+ return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
4085
+ },
4086
+ [Primitive.Erfc]([x], [dx]) {
4087
+ const coeff = -2 / Math.sqrt(Math.PI);
4088
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
4089
+ return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
4090
+ },
4091
+ [Primitive.Sqrt]([x], [dx]) {
4092
+ const z = sqrt$1(x);
4093
+ return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
4094
+ },
4095
+ [Primitive.Reduce]([x], [dx], { op, axis }) {
4096
+ if (op === AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
4097
+ else if (op === AluOp.Mul) {
4098
+ const primal = reduce(x.ref, op, axis);
4099
+ const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
4100
+ return [[primal], [tangent]];
4101
+ } else if (op === AluOp.Min || op === AluOp.Max) {
4102
+ const primal = reduce(x.ref, op, axis);
4103
+ const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
4104
+ const minCount = where$1(notMin.ref, 0, 1).sum(axis);
4105
+ const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
4106
+ return [[primal], [tangent]];
4107
+ } else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
4108
+ },
4109
+ [Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
4110
+ [Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
4111
+ [Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
4112
+ [Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
4113
+ [Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
4114
+ [Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
4115
+ dcond.dispose();
4116
+ return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
4117
+ },
4118
+ [Primitive.Concatenate]: linearTangentsJvp(Primitive.Concatenate),
4119
+ [Primitive.Split]: linearTangentsJvp(Primitive.Split),
4120
+ [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
4121
+ [Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
4122
+ const indicesRef = indices.map((t) => t.ref);
4123
+ return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
4124
+ },
4125
+ [Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
4126
+ [Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
4127
+ [Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
4128
+ [Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
4129
+ [Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
4130
+ [Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
4131
+ [Primitive.Sort]([x], [dx]) {
4132
+ const [y, idx] = argsort$1(x);
4133
+ return [[y], [gather(dx, [idx], [-1], -1)]];
4134
+ },
4135
+ [Primitive.Argsort]([x], [dx]) {
4136
+ const [y, idx] = argsort$1(x);
4137
+ return [[y, idx.ref], [gather(dx, [idx.ref], [-1], -1), zerosLike$1(idx)]];
4138
+ },
4139
+ [Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
4140
+ const x = triangularSolve$1(a.ref, b, { unitDiagonal });
4141
+ const dax = batchMatmulT(da, x.ref);
4142
+ const rhsT = db.sub(mT(dax));
4143
+ const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
4144
+ return [[x], [dx]];
4145
+ },
4146
+ [Primitive.Cholesky]([a], [da]) {
4147
+ const L = cholesky$2(a.ref);
4148
+ da = da.ref.add(mT(da)).mul(.5);
4149
+ const W = triangularSolve$1(L.ref, da, { lower: true });
4150
+ const ST = triangularSolve$1(L.ref, mT(W), { lower: true });
4151
+ const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
4152
+ return [[L], [dL]];
4153
+ },
4154
+ [Primitive.LU]([a], [da]) {
4155
+ const [luMatrix, pivots, permutation] = lu$1(a);
4156
+ const [m, n] = a.shape.slice(-2);
4157
+ const k = Math.min(m, n);
4158
+ const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
4159
+ const lLower = tril(luSliceL, -1);
4160
+ const lPadded = m > k ? padAxis(lLower, -1, [0, m - k]) : lLower;
4161
+ const L = lPadded.add(eye(m));
4162
+ const luSliceU = sliceAxis(luMatrix.ref, -2, [0, k]);
4163
+ const uUpper = triu(luSliceU);
4164
+ const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
4165
+ const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
4166
+ const U = uPadded.add(uEye);
4167
+ const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
4168
+ const pda = batchMatmulT(P, mT(da));
4169
+ const la = mT(triangularSolve$1(L.ref, mT(pda), {
4170
+ lower: true,
4171
+ unitDiagonal: true
4172
+ }));
4173
+ const lau = triangularSolve$1(mT(U.ref), la, { lower: true });
4174
+ const lDot = batchMatmulT(L, mT(tril(lau.ref, -1)));
4175
+ const uDot = batchMatmulT(triu(lau), mT(U));
4176
+ return [[
4177
+ luMatrix,
4178
+ pivots,
4179
+ permutation
4180
+ ], [
4181
+ lDot.add(uDot),
4182
+ zerosLike$1(pivots.ref),
4183
+ zerosLike$1(permutation.ref)
4184
+ ]];
4185
+ },
4186
+ [Primitive.Jit](primals, tangents, { name, jaxpr }) {
4187
+ const newJaxpr = jvpJaxpr(jaxpr);
4188
+ const outs = bind(Primitive.Jit, [
4189
+ ...newJaxpr.consts.map((c) => c.ref),
4190
+ ...primals,
4191
+ ...tangents
4192
+ ], {
4193
+ name: `${name}_jvp`,
4194
+ jaxpr: newJaxpr.jaxpr,
4195
+ numConsts: newJaxpr.consts.length
4196
+ });
4197
+ const n = outs.length / 2;
4198
+ if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
4199
+ const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
4200
+ return [primalsOut, tangentsOut];
4201
+ }
4202
+ };
4203
// [review annotation] Cache of JVP-transformed jaxprs, keyed by the source
// jaxpr object (Map identity lookup).
+ const jvpJaxprCache = /* @__PURE__ */ new Map();
4204
// Returns the JVP version of `jaxpr`, building it on first use.
// The new jaxpr is traced by makeJaxpr$1 from jvpFlat applied to
// jaxprAsFun(jaxpr). It takes the original input avals twice — once for the
// primals and once for the tangents. Callers read `.jaxpr` and `.consts` off
// the returned value.
+ function jvpJaxpr(jaxpr) {
4205
+ if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
4206
+ const inAvals = jaxpr.inBinders.map((v) => v.aval);
4207
+ const { jaxpr: newJaxpr } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
4208
+ jvpJaxprCache.set(jaxpr, newJaxpr);
4209
+ return newJaxpr;
4210
+ }
4211
// [review annotation] Runs `f` under a fresh JVPTrace on flat argument lists.
// Each (primal, tangent) pair is wrapped in a JVPTracer and f is called with
// them. Every output is raised into the trace with fullRaise, and the result
// is [primalsOut, tangentsOut] (via unzip2).
// The try/catch/finally around `_usingCtx` looks like the transpiled form of a
// scoped-resource (`using`) declaration for the main trace. The catch block
// stores the error on the context and `d()` always runs in finally — presumably
// it releases the trace and rethrows the stored error (helper not visible here).
+ function jvpFlat(f, primals, tangents) {
4212
+ try {
4213
+ var _usingCtx$1 = _usingCtx();
4214
+ const main = _usingCtx$1.u(newMain(JVPTrace));
4215
+ const trace$1 = new JVPTrace(main);
4216
+ const tracersIn = zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
4217
+ const outs = f(...tracersIn);
4218
+ const tracersOut = outs.map((out) => fullRaise(trace$1, out));
4219
+ return unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
4220
+ } catch (_) {
4221
+ _usingCtx$1.e = _;
4222
+ } finally {
4223
+ _usingCtx$1.d();
4224
+ }
4225
+ }
4226
// [review annotation] Tree-aware entry point for forward-mode differentiation.
// Flattens `primals` and `tangents` and requires both to have the same tree
// structure (TreeMismatchError otherwise). It then runs jvpFlat on the
// flattened function and rebuilds both outputs with the output tree recorded
// by flattenFun.
// Returns [primalsOut, tangentsOut]. Throws Error if the output tree was never
// recorded.
+ function jvp$1(f, primals, tangents) {
4227
+ const [primalsFlat, inTree] = flatten(primals);
4228
+ const [tangentsFlat, inTree2] = flatten(tangents);
4229
+ if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
4230
+ const [flatFun, outTree] = flattenFun(f, inTree);
4231
+ const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
4232
+ if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
4233
+ const primalsOut = unflatten(outTree.value, primalsOutFlat);
4234
+ const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
4235
+ return [primalsOut, tangentsOut];
3709
4236
  }
3710
4237
 
3711
4238
  //#endregion
@@ -3738,11 +4265,10 @@ function partialEvalFlat(f, pvalsIn) {
3738
4265
  const tracersOut = outs.map((out) => fullRaise(trace$1, out));
3739
4266
  const pvalsOut = tracersOut.map((t) => t.pval);
3740
4267
  const unknownTracersOut = tracersOut.filter((t) => !t.pval.isKnown);
3741
- const { jaxpr, consts } = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
4268
+ const jaxpr = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
3742
4269
  return {
3743
4270
  jaxpr,
3744
- pvalsOut,
3745
- consts
4271
+ pvalsOut
3746
4272
  };
3747
4273
  }
3748
4274
  /**
@@ -3759,22 +4285,19 @@ function linearizeFlatUtil(f, primalsIn) {
3759
4285
  const [primalsOut$1, tangentsOut] = jvp$1(f, x.slice(0, k), x.slice(k, 2 * k));
3760
4286
  return [...primalsOut$1, ...tangentsOut];
3761
4287
  };
3762
- const { jaxpr, pvalsOut, consts } = partialEvalFlat(fJvp, pvalsIn);
4288
+ const { jaxpr, pvalsOut } = partialEvalFlat(fJvp, pvalsIn);
3763
4289
  const primalPvals = pvalsOut.slice(0, pvalsOut.length / 2);
3764
4290
  if (!primalPvals.every((pval) => pval.isKnown)) throw new Error("Not all primal values are known after partial evaluation");
3765
4291
  const primalsOut = primalPvals.map((pval) => pval.val);
3766
4292
  return {
3767
4293
  primalsOut,
3768
- jaxpr,
3769
- consts
4294
+ jaxpr
3770
4295
  };
3771
4296
  }
3772
4297
  function linearizeFlat(f, primalsIn) {
3773
- const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
3774
- const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
3775
- const dispose$1 = () => {
3776
- for (const c of consts) c.dispose();
3777
- };
4298
+ const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
4299
+ const fLin = (...tangents) => evalJaxpr(jaxpr.jaxpr, [...jaxpr.consts.map((c) => c.ref), ...tangents]);
4300
+ const dispose$1 = () => jaxpr.dispose();
3778
4301
  return [
3779
4302
  primalsOut,
3780
4303
  fLin,
@@ -3858,7 +4381,7 @@ var PartialEvalTrace = class extends Trace {
3858
4381
  }
3859
4382
  processPrimitive(primitive, tracers, params) {
3860
4383
  if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
3861
- if (primitive === Primitive.JitCall) {
4384
+ if (primitive === Primitive.Jit) {
3862
4385
  const { name, jaxpr, numConsts } = params;
3863
4386
  return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
3864
4387
  }
@@ -3884,14 +4407,14 @@ var PartialEvalTrace = class extends Trace {
3884
4407
  * Evaluate a Jaxpr on a set of PartialEvalTracers, computing as many known
3885
4408
  * values as possible (with JIT) and forwarding the unknown ones.
3886
4409
  *
3887
- * Used when encountering a JitCall rule during the trace.
4410
+ * Used when encountering a Jit rule during the trace.
3888
4411
  */
3889
4412
  #partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
3890
4413
  jaxpr = jaxpr.flatten();
3891
4414
  const inUnknowns = tracers.map((t) => !t.pval.isKnown);
3892
4415
  const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
3893
4416
  const [knownTracers, unknownTracers] = partitionList(inUnknowns, tracers);
3894
- const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
4417
+ const outs1Res = bind(Primitive.Jit, knownTracers.map((t) => t.ref.fullLower()), {
3895
4418
  name: `${name}_peval`,
3896
4419
  jaxpr: jaxpr1,
3897
4420
  numConsts: 0
@@ -3901,7 +4424,7 @@ var PartialEvalTrace = class extends Trace {
3901
4424
  const resTracers = res.map((x) => this.instantiateConst(fullRaise(this, x)));
3902
4425
  const recipe = {
3903
4426
  type: "JaxprEqn",
3904
- prim: Primitive.JitCall,
4427
+ prim: Primitive.Jit,
3905
4428
  tracersIn: resTracers.concat(unknownTracers),
3906
4429
  params: {
3907
4430
  name: `${name}_resid`,
@@ -3930,7 +4453,7 @@ function partialEvalJaxpr(jaxpr, inUnknowns, instantiate) {
3930
4453
  const eqns1 = [];
3931
4454
  const eqns2 = [];
3932
4455
  for (const eqn of jaxpr.eqns) {
3933
- if (eqn.primitive === Primitive.JitCall) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
4456
+ if (eqn.primitive === Primitive.Jit) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
3934
4457
  const hasUnknowns = eqn.inputs.some((x) => x instanceof Var && !knownVars.has(x));
3935
4458
  if (hasUnknowns) {
3936
4459
  for (const x of eqn.inputs) if (x instanceof Var && knownVars.has(x)) residuals.add(x);
@@ -4005,10 +4528,7 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
4005
4528
  for (const t of tracersOut) t.dispose();
4006
4529
  jaxpr = jaxpr.simplify();
4007
4530
  if (DEBUG >= 5) console.info("jaxpr from partial evaluation:\n" + jaxpr.toString());
4008
- return {
4009
- jaxpr,
4010
- consts
4011
- };
4531
+ return new ClosedJaxpr(jaxpr, consts);
4012
4532
  }
4013
4533
  /** Marker type for pullback, used by transpose rules. */
4014
4534
  var UndefPrimal = class {
@@ -4200,317 +4720,151 @@ const transposeRules = {
4200
4720
  cond.dispose();
4201
4721
  return cts;
4202
4722
  },
4203
- [Primitive.Transpose]([ct], [x], { perm }) {
4204
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
4205
- return [transpose$1(ct, invertPermutation(perm))];
4206
- },
4207
- [Primitive.Broadcast]([ct], [x], { axis }) {
4208
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Broadcast);
4209
- return [reduce(ct, AluOp.Add, axis)];
4210
- },
4211
- [Primitive.Reshape]([ct], [x], _) {
4212
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Reshape);
4213
- return [reshape$1(ct, x.aval.shape)];
4214
- },
4215
- [Primitive.Flip]([ct], [x], { axis }) {
4216
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Flip);
4217
- return [flip$1(ct, axis)];
4218
- },
4219
- [Primitive.Shrink]([ct], [x], { slice }) {
4220
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Shrink);
4221
- const width = slice.map(([s, e$1], i) => [s, x.aval.shape[i] - e$1]);
4222
- return [pad$1(ct, width)];
4723
// [review annotation] Transpose rule for Concatenate: the cotangent is split
// back along `axis` into pieces whose sizes match each input's extent on that
// axis. Every input must be an UndefPrimal, otherwise NonlinearError is thrown.
+ [Primitive.Concatenate]([ct], inputs, { axis }) {
4724
+ if (inputs.some((x) => !(x instanceof UndefPrimal))) throw new NonlinearError(Primitive.Concatenate);
4725
+ const sizes = inputs.map((x) => x.aval.shape[axis]);
4726
+ return split$2(ct, axis, sizes);
4223
4727
  },
4224
- [Primitive.Pad]([ct], [x], { width }) {
4225
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pad);
4226
- const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
4227
- return [shrink(ct, slice)];
4728
// [review annotation] Transpose rule for Split: the per-output cotangents are
// concatenated back along `axis`. The input must be an UndefPrimal, otherwise
// NonlinearError is thrown.
+ [Primitive.Split](cts, [x], { axis }) {
4729
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Split);
4730
+ return [concatenate$1(cts, axis)];
4228
4731
  },
4229
4732
  [Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
4230
4733
  if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4231
4734
  if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4232
- throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
4233
- },
4234
- [Primitive.JitCall](cts, args, { name, jaxpr }) {
4235
- const undefPrimals = args.map((x) => x instanceof UndefPrimal);
4236
- const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
4237
- const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
4238
- const outs = bind(Primitive.JitCall, [
4239
- ...newConsts.map((c) => c.ref),
4240
- ...residuals,
4241
- ...cts
4242
- ], {
4243
- name: `${name}_t`,
4244
- jaxpr: newJaxpr,
4245
- numConsts: newConsts.length
4246
- });
4247
- let i = 0;
4248
- return undefPrimals.map((isUndef) => isUndef ? outs[i++] : null);
4249
- }
4250
- };
4251
- const transposeJaxprCache = /* @__PURE__ */ new Map();
4252
- function transposeJaxpr(jaxpr, undefPrimals) {
4253
- const cacheKey = JSON.stringify(undefPrimals);
4254
- const prevResult = transposeJaxprCache.get(jaxpr)?.get(cacheKey);
4255
- if (prevResult) return prevResult;
4256
- const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
4257
- const forwardInTypes = inTypes.filter((_, i) => !undefPrimals[i]);
4258
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((forwardIn, cotangents) => {
4259
- const args = [];
4260
- let forwardInIdx = 0;
4261
- for (let i = 0; i < undefPrimals.length; i++) if (undefPrimals[i]) args.push(new UndefPrimal(inTypes[i]));
4262
- else args.push(forwardIn[forwardInIdx++]);
4263
- return evalJaxprTransposed(jaxpr, args, cotangents);
4264
- })(forwardInTypes, outTypes);
4265
- typecheckJaxpr(newJaxpr);
4266
- const result = {
4267
- newJaxpr,
4268
- newConsts
4269
- };
4270
- if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
4271
- transposeJaxprCache.get(jaxpr).set(cacheKey, result);
4272
- return result;
4273
- }
4274
- function vjpFlat(f, primalsIn) {
4275
- const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
4276
- const fVjp = (...cotangents) => {
4277
- const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
4278
- return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
4279
- };
4280
- const dispose$1 = () => {
4281
- for (const c of consts) c.dispose();
4282
- };
4283
- return [
4284
- primalsOut,
4285
- fVjp,
4286
- dispose$1
4287
- ];
4288
- }
4289
- function vjp$1(f, ...primalsIn) {
4290
- const [primalsInFlat, inTree] = flatten(primalsIn);
4291
- const [fFlat, outTree] = flattenFun(f, inTree);
4292
- const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
4293
- if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
4294
- const primalsOut = unflatten(outTree.value, primalsOutFlat);
4295
- const fVjp = ((cotangentsOut) => {
4296
- const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
4297
- if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
4298
- const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
4299
- return unflatten(inTree, cotangentsInFlat);
4300
- });
4301
- fVjp.dispose = dispose$1;
4302
- return [primalsOut, fVjp];
4303
- }
4304
- function grad$1(f) {
4305
- const valueAndGradFn = valueAndGrad$1(f);
4306
- return (...x) => {
4307
- const [y, dx] = valueAndGradFn(...x);
4308
- y.dispose();
4309
- return dx;
4310
- };
4311
- }
4312
- function valueAndGrad$1(f) {
4313
- return (...x) => {
4314
- if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
4315
- const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
4316
- if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
4317
- if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
4318
- const [ct, ...rest] = fVjp(onesLike$1(y.ref));
4319
- for (const r of rest) dispose(r);
4320
- fVjp.dispose();
4321
- return [y, ct];
4322
- };
4323
- }
4324
- function jacrev$1(f) {
4325
- return function jacobianReverse(x) {
4326
- if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
4327
- const [size$1] = x.shape;
4328
- const pullback = (ct) => {
4329
- const [y, fVjp] = vjp$1(f, x);
4330
- y.dispose();
4331
- const [ret] = fVjp(ct);
4332
- fVjp.dispose();
4333
- return ret;
4334
- };
4335
- return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
4336
- };
4337
- }
4338
-
4339
- //#endregion
4340
- //#region src/library/lax.ts
4341
- var lax_exports = {};
4342
- __export(lax_exports, {
4343
- conv: () => conv,
4344
- convGeneralDilated: () => convGeneralDilated,
4345
- convWithGeneralPadding: () => convWithGeneralPadding,
4346
- dot: () => dot$1,
4347
- erf: () => erf,
4348
- erfc: () => erfc,
4349
- reduceWindow: () => reduceWindow,
4350
- stopGradient: () => stopGradient$1
4351
- });
4352
- /**
4353
- * General dot product/contraction operator.
4354
- *
4355
- * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
4356
- * `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
4357
- */
4358
- function dot$1(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
4359
- if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
4360
- else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
4361
- lc = lc.map((a) => checkAxis(a, lhs.ndim));
4362
- rc = rc.map((a) => checkAxis(a, rhs.ndim));
4363
- lb = lb.map((a) => checkAxis(a, lhs.ndim));
4364
- rb = rb.map((a) => checkAxis(a, rhs.ndim));
4365
- if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
4366
- else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
4367
- const lf = range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
4368
- const rf = range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
4369
- const lhs2 = lhs.transpose([
4370
- ...lb,
4371
- ...lf,
4372
- ...lc
4373
- ]);
4374
- const rhs2 = rhs.transpose([
4375
- ...rb,
4376
- ...rf,
4377
- ...rc
4378
- ]);
4379
- if (lc.length === 0) return mul(lhs2.reshape([
4380
- ...lb.map((a) => lhs.shape[a]),
4381
- ...lf.map((a) => lhs.shape[a]),
4382
- ...rep(rf.length, 1)
4383
- ]), rhs2.reshape([
4384
- ...rb.map((a) => rhs.shape[a]),
4385
- ...rep(lf.length, 1),
4386
- ...rf.map((a) => rhs.shape[a])
4387
- ]));
4388
- const dotShapeX = lc.map((a) => lhs.shape[a]);
4389
- const dotShapeY = rc.map((a) => rhs.shape[a]);
4390
- if (!deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
4391
- return dot$2(lhs2.reshape([
4392
- ...lb.map((a) => lhs.shape[a]),
4393
- ...lf.map((a) => lhs.shape[a]),
4394
- ...rep(rf.length, 1),
4395
- prod(dotShapeX)
4396
- ]), rhs2.reshape([
4397
- ...rb.map((a) => rhs.shape[a]),
4398
- ...rep(lf.length, 1),
4399
- ...rf.map((a) => rhs.shape[a]),
4400
- prod(dotShapeY)
4401
- ]));
4402
- }
4403
- function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
4404
- const padType = padding.toUpperCase();
4405
- switch (padType) {
4406
- case "VALID": return rep(inShape.length, [0, 0]);
4407
- case "SAME":
4408
- case "SAME_LOWER": {
4409
- const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
4410
- const padSizes = zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
4411
- if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
4412
- else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
4413
- }
4414
- default: throw new Error(`Unknown padding type: ${padType}`);
4415
- }
4416
- }
4417
- /**
4418
- * General n-dimensional convolution operator, with optional dilation.
4419
- *
4420
- * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
4421
- * function in JAX, which wraps XLA's general convolution operator.
4422
- *
4423
- * Grouped convolutions are not supported right now.
4424
- */
4425
- function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
4426
- if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
4427
- if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
4428
- if (typeof padding === "string") {
4429
- if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
4430
- padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? rep(rhs.ndim - 2, 1), padding);
4431
- }
4432
- if (featureGroupCount !== 1) {
4433
- const G = featureGroupCount;
4434
- const [N, C_in, ...xs] = lhs.shape;
4435
- const [C_out, C_in_per_group, ...ks] = rhs.shape;
4436
- if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
4437
- if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
4438
- if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
4439
- const lhsGrouped = moveaxis(lhs.reshape([
4440
- N,
4441
- G,
4442
- C_in / G,
4443
- ...xs
4444
- ]), 1, 0);
4445
- const rhsGrouped = rhs.reshape([
4446
- G,
4447
- C_out / G,
4448
- C_in_per_group,
4449
- ...ks
4450
- ]);
4451
- const result = conv$1(lhsGrouped, rhsGrouped, {
4452
- vmapDims: 1,
4453
- strides: windowStrides,
4454
- padding,
4455
- lhsDilation,
4456
- rhsDilation
4735
+ throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
4736
+ },
4737
// [review annotation] Transpose rule for Transpose: applies the inverse
// permutation to the cotangent. Throws NonlinearError unless x is an UndefPrimal.
+ [Primitive.Transpose]([ct], [x], { perm }) {
4738
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
4739
+ return [transpose$1(ct, invertPermutation(perm))];
4740
+ },
4741
// [review annotation] Transpose rule for Broadcast: sums the cotangent over
// the broadcast `axis` (reduce with AluOp.Add). Throws NonlinearError unless x
// is an UndefPrimal.
+ [Primitive.Broadcast]([ct], [x], { axis }) {
4742
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Broadcast);
4743
+ return [reduce(ct, AluOp.Add, axis)];
4744
+ },
4745
// [review annotation] Transpose rule for Reshape: reshapes the cotangent back
// to the input's abstract shape. Throws NonlinearError unless x is an UndefPrimal.
+ [Primitive.Reshape]([ct], [x], _) {
4746
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Reshape);
4747
+ return [reshape$1(ct, x.aval.shape)];
4748
+ },
4749
// [review annotation] Transpose rule for Flip: flips the cotangent along the
// same `axis`. Throws NonlinearError unless x is an UndefPrimal.
+ [Primitive.Flip]([ct], [x], { axis }) {
4750
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Flip);
4751
+ return [flip$1(ct, axis)];
4752
+ },
4753
// [review annotation] Transpose rule for Shrink: pads the cotangent back to
// the input shape. For each axis with slice [s, e], the pad is s before and
// (input extent - e) after. Throws NonlinearError unless x is an UndefPrimal.
+ [Primitive.Shrink]([ct], [x], { slice }) {
4754
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Shrink);
4755
+ const width = slice.map(([s, e$1], i) => [s, x.aval.shape[i] - e$1]);
4756
+ return [pad$1(ct, width)];
4757
+ },
4758
// [review annotation] Transpose rule for Pad: shrinks the cotangent to the
// region the input occupied. For each axis with leading pad s, that region is
// [s, s + input extent); the trailing pad amount is not needed. Throws
// NonlinearError unless x is an UndefPrimal.
+ [Primitive.Pad]([ct], [x], { width }) {
4759
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pad);
4760
+ const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
4761
+ return [shrink(ct, slice)];
4762
+ },
4763
+ [Primitive.TriangularSolve]([ct], [a, b], { unitDiagonal }) {
4764
+ if (a instanceof UndefPrimal || !(b instanceof UndefPrimal)) throw new NonlinearError(Primitive.TriangularSolve);
4765
+ const ctB = triangularSolve$1(moveaxis(a, -2, -1), ct, {
4766
+ lower: true,
4767
+ unitDiagonal
4457
4768
  });
4458
- const ys = result.shape.slice(3);
4459
- return moveaxis(result, 0, 1).reshape([
4460
- N,
4461
- C_out,
4462
- ...ys
4463
- ]);
4769
+ return [null, ctB];
4770
+ },
4771
+ [Primitive.Jit](cts, args, { name, jaxpr }) {
4772
+ const undefPrimals = args.map((x) => x instanceof UndefPrimal);
4773
+ const newJaxpr = transposeJaxpr(jaxpr, undefPrimals);
4774
+ const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
4775
+ const outs = bind(Primitive.Jit, [
4776
+ ...newJaxpr.consts.map((c) => c.ref),
4777
+ ...residuals,
4778
+ ...cts
4779
+ ], {
4780
+ name: `${name}_t`,
4781
+ jaxpr: newJaxpr.jaxpr,
4782
+ numConsts: newJaxpr.consts.length
4783
+ });
4784
+ let i = 0;
4785
+ return undefPrimals.map((isUndef) => isUndef ? outs[i++] : null);
4464
4786
  }
4465
- return conv$1(lhs, rhs, {
4466
- strides: windowStrides,
4467
- padding,
4468
- lhsDilation,
4469
- rhsDilation
4470
- });
4471
- }
4472
- /** Convenience wrapper around `convGeneralDilated`. */
4473
- function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
4474
- return convGeneralDilated(lhs, rhs, windowStrides, padding, {
4475
- lhsDilation,
4476
- rhsDilation
4477
- });
4787
+ };
4788
/** Two-level cache: jaxpr -> (JSON-encoded `undefPrimals` mask -> transposed result). */
const transposeJaxprCache = /* @__PURE__ */ new Map();
/**
 * Build (or fetch from cache) the transposed version of `jaxpr`.
 *
 * @param jaxpr - The jaxpr to transpose.
 * @param undefPrimals - Boolean mask, one entry per jaxpr input; `true` marks
 *   inputs that are treated as `UndefPrimal` during transposition.
 * @returns The value produced by `makeJaxpr$1(...)(...).jaxpr`; callers in this
 *   file read `.jaxpr` and `.consts` from it.
 */
function transposeJaxpr(jaxpr, undefPrimals) {
  const cacheKey = JSON.stringify(undefPrimals);
  const cached = transposeJaxprCache.get(jaxpr)?.get(cacheKey);
  if (cached) return cached;
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
  const forwardInTypes = inTypes.filter((_type, idx) => !undefPrimals[idx]);
  const traced = makeJaxpr$1((forwardIn, cotangents) => {
    // Interleave placeholders for undefined primals with the known inputs.
    let forwardInIdx = 0;
    const args = undefPrimals.map((isUndef, idx) => isUndef ? new UndefPrimal(inTypes[idx]) : forwardIn[forwardInIdx++]);
    return evalJaxprTransposed(jaxpr, args, cotangents);
  })(forwardInTypes, outTypes);
  const newJaxpr = traced.jaxpr;
  typecheckJaxpr(newJaxpr.jaxpr);
  let perJaxpr = transposeJaxprCache.get(jaxpr);
  if (perJaxpr === void 0) {
    perJaxpr = /* @__PURE__ */ new Map();
    transposeJaxprCache.set(jaxpr, perJaxpr);
  }
  perJaxpr.set(cacheKey, newJaxpr);
  return newJaxpr;
}
4479
- /** Convenience wrapper around `convGeneralDilated`. */
4480
- function conv(lhs, rhs, windowStrides, padding) {
4481
- return convGeneralDilated(lhs, rhs, windowStrides, padding);
4807
/**
 * Flat (non-pytree) vector-Jacobian product.
 *
 * @param f - Function over flat arguments, linearized via `linearizeFlatUtil`.
 * @param primalsIn - Flat list of primal inputs.
 * @returns `[primalsOut, fVjp, dispose]`, where `fVjp(...cotangents)` evaluates
 *   the linearized jaxpr transposed and `dispose()` disposes that jaxpr.
 */
function vjpFlat(f, primalsIn) {
  const linearized = linearizeFlatUtil(f, primalsIn);
  const { primalsOut, jaxpr } = linearized;
  const fVjp = (...cotangents) => {
    // Constants are passed by reference; every primal input is left undefined.
    const constRefs = jaxpr.consts.map((c) => c.ref);
    const undefInputs = primalsIn.map((t) => new UndefPrimal(t.aval));
    return evalJaxprTransposed(jaxpr.jaxpr, [...constRefs, ...undefInputs], cotangents);
  };
  const dispose$1 = () => jaxpr.dispose();
  return [primalsOut, fVjp, dispose$1];
}
4483
- /** Reduce a computation over padded windows. */
4484
- function reduceWindow(operand, computation, windowDimensions, windowStrides) {
4485
- if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
4486
- if (!windowStrides) windowStrides = rep(windowDimensions.length, 1);
4487
- for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
4488
- return computation(bind1(Primitive.Pool, [operand], {
4489
- window: windowDimensions,
4490
- strides: windowStrides
4491
- }));
4820
/**
 * Pytree-aware vector-Jacobian product.
 *
 * Flattens the inputs, delegates to `vjpFlat`, and rebuilds the output tree.
 *
 * @returns `[primalsOut, fVjp]`. `fVjp(cotangentsOut)` requires a cotangent tree
 *   equal to the output tree (else `TreeMismatchError`) and returns cotangents
 *   shaped like the inputs. `fVjp.dispose` is the disposer returned by `vjpFlat`.
 * @throws {Error} If the output tree was not recorded while running `f`.
 */
function vjp$1(f, ...primalsIn) {
  const [flatPrimals, inTree] = flatten(primalsIn);
  const [fFlat, outTree] = flattenFun(f, inTree);
  const [flatOutputs, fVjpFlat, dispose$1] = vjpFlat(fFlat, flatPrimals.map(pureArray));
  if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
  const primalsOut = unflatten(outTree.value, flatOutputs);
  const fVjp = (cotangentsOut) => {
    const [flatCotangents, cotangentTree] = flatten(cotangentsOut);
    if (!outTree.value.equals(cotangentTree)) throw new TreeMismatchError("vjp", outTree.value, cotangentTree);
    const flatResult = fVjpFlat(...flatCotangents.map(pureArray));
    return unflatten(inTree, flatResult);
  };
  fVjp.dispose = dispose$1;
  return [primalsOut, fVjp];
}
4493
- /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
4494
- function erf(x) {
4495
- return erf$1(x);
4835
/**
 * Wrap `f` so that calling it returns only the gradient.
 *
 * Built on `valueAndGrad$1`; the computed value is disposed before returning.
 */
function grad$1(f) {
  const valueAndGradFn = valueAndGrad$1(f);
  return (...x) => {
    const [value, gradient] = valueAndGradFn(...x);
    // Only the gradient is handed to the caller, so release the value here.
    value.dispose();
    return gradient;
  };
}
4497
- /**
4498
- * The complementary error function: `erfc(x) = 1 - erf(x)`.
4499
- *
4500
- * This function is more accurate than `1 - erf(x)` for large values of `x`,
4501
- * where `erf(x)` is very close to 1.
4502
- */
4503
- function erfc(x) {
4504
- return erfc$1(x);
4843
/**
 * Wrap `f` so that calling it returns `[value, gradient]`.
 *
 * The gradient is taken with respect to the first argument only; all other
 * arguments are passed through `stopGradient`.
 *
 * @throws {Error} If called with no arguments.
 * @throws {TypeError} If the output is not a scalar `Tracer` or is not a
 *   floating-point dtype.
 */
function valueAndGrad$1(f) {
  return (...x) => {
    if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
    const [primary, ...others] = x;
    const [y, fVjp] = vjp$1(f, primary, ...others.map(stopGradient));
    const isScalarTracer = y instanceof Tracer && ndim$1(y) === 0;
    if (!isScalarTracer) throw new TypeError("grad requires a scalar output");
    if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
    // Seed the pullback with ones shaped like the scalar output.
    const cotangents = fVjp(onesLike$1(y.ref));
    const [ct, ...rest] = cotangents;
    rest.forEach((r) => {
      dispose(r);
    });
    fVjp.dispose();
    return [y, ct];
  };
}
4506
- /**
4507
- * Stops gradient computation.
4508
- *
4509
- * Behaves as the identity function but prevents the flow of gradients during
4510
- * forward or reverse-mode automatic differentiation.
4511
- */
4512
- function stopGradient$1(x) {
4513
- return stopGradient(x);
4855
/**
 * Jacobian via reverse mode: vmaps a VJP pullback over the columns of an
 * identity matrix.
 *
 * Only 1D inputs are accepted.
 *
 * NOTE(review): the identity matrix is sized from the *input* length, so the
 * cotangents only match the output when `f` maps length-N to length-N —
 * confirm whether non-square Jacobians are meant to be supported.
 *
 * @throws {TypeError} If the input is not 1D.
 */
function jacrev$1(f) {
  return function jacobianReverse(x) {
    if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
    const size$1 = x.shape[0];
    const pullback = (ct) => {
      const [y, fVjp] = vjp$1(f, x);
      y.dispose();
      const cotangentsIn = fVjp(ct);
      fVjp.dispose();
      return cotangentsIn[0];
    };
    const batchedPullback = vmap$1(pullback, [1]);
    return batchedPullback(eye(size$1, void 0, { dtype: x.dtype }));
  };
}
4515
4869
 
4516
4870
  //#endregion
@@ -4650,8 +5004,8 @@ function computeSizeMap({ shapes, lhsIndices, rhsIndex }) {
4650
5004
  const idx = lhsIndex[j];
4651
5005
  const dim = shape$1[j];
4652
5006
  const existing = sizeMap.get(idx);
4653
- if (existing === void 0) sizeMap.set(idx, dim);
4654
- else if (existing !== dim) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
5007
+ if (existing === void 0 || existing === 1) sizeMap.set(idx, dim);
5008
+ else if (existing !== dim && dim !== 1) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
4655
5009
  }
4656
5010
  }
4657
5011
  for (const [idx, size$1] of sizeMap) if (!Number.isInteger(idx) || idx < 0) throw new Error(`Invalid index ${idx} in einsum expression, must be non-negative integer`);
@@ -4659,52 +5013,410 @@ function computeSizeMap({ shapes, lhsIndices, rhsIndex }) {
4659
5013
  for (const idx of rhsIndex) if (!sizeMap.has(idx)) throw new Error(`Output index ${idx} not present in einsum inputs`);
4660
5014
  return sizeMap;
4661
5015
  }
4662
- const einsumPathCache = /* @__PURE__ */ new Map();
4663
- function computeEinsumPath(input, method) {
4664
- if (!method) method = input.shapes.length <= 5 ? "optimal" : "naive";
4665
- return runWithCache(einsumPathCache, [input, method], () => {
4666
- const sizeMap = computeSizeMap(input);
4667
- if (input.shapes.length === 1) return new EinsumPath(input, sizeMap, []);
4668
- switch (method) {
4669
- case "naive": return computePathNaive(input, sizeMap);
4670
- case "optimal": return computePathOptimal(input, sizeMap);
4671
- default: throw new Error(`Unknown computePath method: ${method}`);
4672
- }
4673
- });
5016
/** Cache passed to `runWithCache`, keyed on `[input, method]`. */
const einsumPathCache = /* @__PURE__ */ new Map();
/**
 * Choose a contraction order for an einsum expression.
 *
 * @param input - Einsum description; `input.shapes` holds the operand shapes.
 * @param method - `"optimal"` or `"naive"`. When falsy, `"optimal"` is used for
 *   up to 5 operands and `"naive"` otherwise.
 * @throws {Error} On an unknown method name.
 */
function computeEinsumPath(input, method) {
  const resolved = method ? method : input.shapes.length <= 5 ? "optimal" : "naive";
  return runWithCache(einsumPathCache, [input, resolved], () => {
    const sizeMap = computeSizeMap(input);
    // A single operand needs no pairwise contractions.
    if (input.shapes.length === 1) return new EinsumPath(input, sizeMap, []);
    if (resolved === "naive") return computePathNaive(input, sizeMap);
    if (resolved === "optimal") return computePathOptimal(input, sizeMap);
    throw new Error(`Unknown computePath method: ${resolved}`);
  });
}
5029
/**
 * Left-to-right contraction order: contract operands 0 and 1, then each
 * intermediate result with the next operand.
 *
 * Intermediate results are numbered from `n` upward, so the i-th step (i >= 2)
 * pairs intermediate `n + i - 2` with operand `i`.
 */
function computePathNaive(input, sizeMap) {
  const n = input.shapes.length;
  const path = [];
  for (let operand = 1; operand < n; operand++) {
    const accumulated = operand === 1 ? 0 : n + operand - 2;
    path.push([accumulated, operand]);
  }
  return new EinsumPath(input, sizeMap, path);
}
5039
/**
 * Exhaustive search: score every path from `allPaths` with
 * `approximatePathFlops` and keep the first one with the lowest score.
 */
function computePathOptimal(input, sizeMap) {
  const n = input.shapes.length;
  let bestPath = null;
  let lowestFlops = null;
  for (const candidate of allPaths(range(n), n)) {
    const cost = approximatePathFlops(input, sizeMap, candidate);
    const improves = lowestFlops === null || cost < lowestFlops;
    if (!improves) continue;
    bestPath = candidate;
    lowestFlops = cost;
  }
  return new EinsumPath(input, sizeMap, bestPath);
}
5052
/**
 * Enumerate every pairwise contraction order for the given tensor ids.
 *
 * @param tensors - Ids of the tensors still to be contracted.
 * @param next - Id assigned to the result of the next contraction.
 * @yields Arrays of `[left, right]` id pairs, one pair per contraction step.
 *   Nothing is yielded when fewer than two tensors are given.
 */
function* allPaths(tensors, next) {
  if (tensors.length === 2) {
    const [first, second] = tensors;
    yield [[first, second]];
    return;
  }
  for (const [i, left] of tensors.entries()) {
    for (const right of tensors.slice(i + 1)) {
      const pair = [left, right];
      // Replace the contracted pair with the id of their result.
      const remaining = tensors.filter((t) => t !== left && t !== right);
      remaining.push(next);
      for (const tail of allPaths(remaining, next + 1)) yield [pair, ...tail];
    }
  }
}
5064
+
5065
+ //#endregion
5066
+ //#region src/library/numpy-fft.ts
5067
+ var numpy_fft_exports = {};
5068
+ __export(numpy_fft_exports, {
5069
+ fft: () => fft,
5070
+ ifft: () => ifft
5071
+ });
5072
/**
 * Validate a `{ real, imag }` pair passed to an FFT routine.
 *
 * @param name - Short function name, used as `jax.numpy.fft.${name}` in messages.
 * @param a - Object with `real` and `imag` arrays.
 * @throws {Error} If the parts differ in shape or dtype, or the dtype is not
 *   a float dtype.
 */
function checkPairInput(name, a) {
  const fullName = `jax.numpy.fft.${name}`;
  const { real, imag } = a;
  if (!deepEqual(real.shape, imag.shape)) {
    throw new Error(`${fullName}: real and imaginary parts must have the same shape, got ${JSON.stringify(real.shape)} and ${JSON.stringify(imag.shape)}`);
  }
  if (real.dtype !== imag.dtype) {
    throw new Error(`${fullName}: real and imaginary parts must have the same dtype, got ${real.dtype} and ${imag.dtype}`);
  }
  if (!isFloatDtype(real.dtype)) {
    throw new Error(`${fullName}: input must have a float dtype, got ${real.dtype}`);
  }
}
5078
/**
 * Validate that an FFT axis length is a positive power of two.
 *
 * @param name - Short function name, used as `jax.numpy.fft.${name}` in the message.
 * @param n - Axis length to check.
 * @throws {Error} If `n` is not a positive integer power of two.
 */
function checkPowerOfTwo(name, n) {
  // `n & (n - 1)` is zero for powers of two, but also for n = 0, which must be
  // rejected separately: log2(0) is -Infinity and cannot drive the FFT stages.
  const isPowerOfTwo = Number.isInteger(n) && n > 0 && (n & n - 1) === 0;
  if (!isPowerOfTwo) throw new Error(`jax.numpy.fft.${name}: size must be a power of two, got ${n}`);
}
5081
/**
 * One butterfly stage of the radix-2 FFT, jitted with the stage index static.
 *
 * The data is viewed as rows of width `2 * half` (`half = 2 ** i`). Each row's
 * first half `u` and second half `v` are combined as `u + w*v` and `u - w*v`,
 * with twiddle factors `w = exp(-i*pi*k/half)` for `k = 0..half-1`.
 */
const fftUpdate = jit$1(function fftUpdate$1(i, { real, imag }) {
  const half = 2 ** i;
  const rowWidth = 2 * half;
  const re = real.reshape([-1, rowWidth]);
  const im = imag.reshape([-1, rowWidth]);
  // Twiddle factors: cos/sin of -pi*k/half.
  const k = arange(0, half, 1, { dtype: re.dtype });
  const theta = k.mul(-Math.PI / half);
  const wr = cos(theta.ref);
  const wi = sin(theta);
  // Split each row into its first (u) and second (v) halves.
  const ur = re.ref.slice([], [0, half]);
  const ui = im.ref.slice([], [0, half]);
  const vr = re.slice([], [half, rowWidth]);
  const vi = im.slice([], [half, rowWidth]);
  // Complex product t = w * v.
  const tr = vr.ref.mul(wr.ref).sub(vi.ref.mul(wi.ref));
  const ti = vr.mul(wi).add(vi.mul(wr));
  const realSum = ur.ref.add(tr.ref);
  const realDiff = ur.sub(tr);
  const realOut = concatenate([realSum, realDiff], -1);
  const imagSum = ui.ref.add(ti.ref);
  const imagDiff = ui.sub(ti);
  const imagOut = concatenate([imagSum, imagDiff], -1);
  return {
    real: realOut,
    imag: imagOut
  };
}, { staticArgnums: [0] });
5100
/**
 * One-dimensional discrete Fourier transform of a `{ real, imag }` pair.
 *
 * The length of the transformed axis must currently be a power of two.
 *
 * @param a - Object with `real` and `imag` arrays of equal shape and float dtype.
 * @param axis - Axis to transform; defaults to the last axis.
 * @returns `{ real, imag }` with the same shape as the input.
 */
function fft(a, axis = -1) {
  checkPairInput("fft", a);
  let { real, imag } = a;
  axis = checkAxis(axis, real.ndim);
  const n = real.shape[axis];
  checkPowerOfTwo("fft", n);
  const logN = Math.log2(n);
  // Move the transformed axis to the end, remembering how to undo it.
  let perm = null;
  if (axis !== real.ndim - 1) {
    perm = range(real.ndim).filter((d) => d !== axis);
    perm.push(axis);
    real = real.transpose(perm);
    imag = imag.transpose(perm);
  }
  const originalShape = real.shape;
  // Bit-reversal reordering: view the axis as logN binary digits and reverse them.
  const bitReverse = (x) => {
    const digits = x.reshape([-1, ...rep(logN, 2)]);
    const reversed = digits.transpose([0, ...range(1, logN + 1).reverse()]);
    return reversed.flatten();
  };
  real = bitReverse(real);
  imag = bitReverse(imag);
  for (let stage = 0; stage < logN; stage++) {
    ({ real, imag } = fftUpdate(stage, {
      real,
      imag
    }));
  }
  real = real.reshape(originalShape);
  imag = imag.reshape(originalShape);
  if (perm !== null) {
    real = real.transpose(invertPermutation(perm));
    imag = imag.transpose(invertPermutation(perm));
  }
  return {
    real,
    imag
  };
}
5138
/**
 * One-dimensional inverse discrete Fourier transform of a `{ real, imag }` pair.
 *
 * Implemented as conjugate -> `fft` -> conjugate -> divide by `n`. The length
 * of the transformed axis must currently be a power of two.
 *
 * @param a - Object with `real` and `imag` arrays of equal shape and float dtype.
 * @param axis - Axis to transform; defaults to the last axis.
 */
function ifft(a, axis = -1) {
  checkPairInput("ifft", a);
  const { real, imag } = a;
  const normalizedAxis = checkAxis(axis, real.ndim);
  const n = real.shape[normalizedAxis];
  checkPowerOfTwo("ifft", n);
  const conjugatedImag = imag.mul(-1);
  const spectrum = fft({
    real,
    imag: conjugatedImag
  }, normalizedAxis);
  const realOut = spectrum.real.div(n);
  const imagOut = spectrum.imag.mul(-1).div(n);
  return {
    real: realOut,
    imag: imagOut
  };
}
5159
+
5160
+ //#endregion
5161
+ //#region src/library/numpy-linalg.ts
5162
+ var numpy_linalg_exports = {};
5163
+ __export(numpy_linalg_exports, {
5164
+ cholesky: () => cholesky,
5165
+ det: () => det,
5166
+ diagonal: () => diagonal,
5167
+ inv: () => inv,
5168
+ lstsq: () => lstsq,
5169
+ matmul: () => matmul,
5170
+ matrixPower: () => matrixPower,
5171
+ matrixTranspose: () => matrixTranspose,
5172
+ outer: () => outer,
5173
+ slogdet: () => slogdet,
5174
+ solve: () => solve,
5175
+ tensordot: () => tensordot,
5176
+ trace: () => trace,
5177
+ vecdot: () => vecdot
5178
+ });
5179
/**
 * Assert that the last two dimensions of `a` are equal and return their size.
 *
 * @param name - Caller name used as the error-message prefix.
 * @param a - Array-like with `ndim`, `shape` and `aval`.
 * @throws {Error} If `a` has fewer than 2 dimensions or is not square.
 */
function checkSquare(name, a) {
  const rank = a.ndim;
  const isSquare = rank >= 2 && a.shape[rank - 1] === a.shape[rank - 2];
  if (!isSquare) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
  return a.shape[rank - 1];
}
5183
/**
 * Cholesky decomposition of a (batched) positive-definite matrix.
 *
 * Wraps `jax.lax.linalg.cholesky()` and, by default, first symmetrizes the
 * input as `(a + a^T) / 2`.
 *
 * @param a - Square matrix (batched).
 * @param options.upper - Forwarded to the underlying decomposition. Default `false`.
 * @param options.symmetrizeInput - Symmetrize before decomposing. Default `true`.
 */
function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
  let matrix = fudgeArray(a);
  checkSquare("cholesky", matrix);
  if (symmetrizeInput) {
    // Take the extra reference before `matrixTranspose` consumes `matrix`.
    const kept = matrix.ref;
    matrix = kept.add(matrixTranspose(matrix)).mul(.5);
  }
  return cholesky$1(matrix, { upper });
}
5195
/**
 * Determinant of a square matrix (batched), from its LU decomposition.
 *
 * The result is the product of the LU diagonal times a sign of `+1` or `-1`,
 * chosen by the parity of the number of pivots that differ from their own index.
 */
function det(a) {
  const matrix = fudgeArray(a);
  const n = checkSquare("det", matrix);
  const [factors, pivots, permutation] = lu(matrix);
  permutation.dispose();
  const swapCount = pivots.notEqual(arange(n)).astype(int32).sum(-1);
  const parity = swapCount.mod(2);
  // Map parity 0 -> +1 and parity 1 -> -1.
  const sign$1 = parity.mul(-2).add(1);
  const diag$1 = factors.diagonal(0, -1, -2);
  const diagProduct = prod$1(diag$1, -1);
  return diagProduct.mul(sign$1);
}
5206
/**
 * Inverse of a square matrix (batched), obtained by solving `a @ x = I`.
 */
function inv(a) {
  const matrix = fudgeArray(a);
  const n = checkSquare("inv", matrix);
  const rhs = eye(n);
  return solve(matrix, rhs);
}
5212
/**
 * Return the least-squares solution to a linear equation.
 *
 * For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
 * For underdetermined systems, this finds the minimum-norm solution for `x`.
 *
 * This currently uses Cholesky decomposition to solve the normal equations,
 * under the hood. The method is not as robust as QR or SVD.
 *
 * @param a coefficient matrix of shape `(M, N)`
 * @param b right-hand side of shape `(M,)` or `(M, K)`
 * @return least-squares solution of shape `(N,)` or `(N, K)`
 * @throws {Error} If `a` is not 2D or the leading dimension of `b` is not `M`.
 */
function lstsq(a, b) {
  a = fudgeArray(a);
  b = fudgeArray(b);
  if (a.ndim !== 2) throw new Error(`lstsq: 'a' must be a 2D array, got ${a.aval}`);
  const [m, n] = a.shape;
  if (b.shape[0] !== m) throw new Error(`lstsq: leading dimension of 'b' must match number of rows of 'a', got ${b.aval}`);
  const at = matrixTranspose(a.ref);
  if (m <= n) {
    // Underdetermined (or square): x = A^T (A A^T)^-1 b, via A A^T = L L^T.
    const aat = matmul(a, at.ref);
    const l = cholesky(aat, { symmetrizeInput: false });
    const lb = triangularSolve(l.ref, b, {
      leftSide: true,
      lower: true
    });
    const llb = triangularSolve(l, lb, {
      leftSide: true,
      transposeA: true
    });
    // `llb` is used exactly once, so hand it over without taking an extra
    // reference; an additional `.ref` here would never be released.
    return matmul(at, llb);
  } else {
    // Overdetermined: x = (A^T A)^-1 A^T b, via A^T A = L L^T.
    const ata = matmul(at.ref, a);
    const l = cholesky(ata, { symmetrizeInput: false });
    const atb = matmul(at, b);
    const lb = triangularSolve(l.ref, atb, {
      leftSide: true,
      lower: true
    });
    const llb = triangularSolve(l, lb, {
      leftSide: true,
      transposeA: true
    });
    return llb;
  }
}
5259
/**
 * Raise a square matrix to an integer power by repeated squaring.
 *
 * A zero exponent returns an identity broadcast to the input's shape; a
 * negative exponent inverts the matrix first.
 *
 * @throws {Error} If `n` is not an integer.
 */
function matrixPower(a, n) {
  if (!Number.isInteger(n)) throw new Error(`matrixPower: exponent must be an integer, got ${n}`);
  a = fudgeArray(a);
  const m = checkSquare("matrixPower", a);
  if (n === 0) {
    a.dispose();
    return broadcastTo(eye(m), a.shape);
  }
  let remaining = n;
  if (remaining < 0) {
    a = inv(a);
    remaining = -remaining;
  }
  let result = null;
  // `square` holds a^(2^k) at the k-th iteration.
  let square = a;
  let firstStep = true;
  while (remaining) {
    if (!firstStep) square = matmul(square.ref, square);
    firstStep = false;
    if (remaining % 2 === 1) {
      result = result === null ? square.ref : matmul(result, square.ref);
    }
    remaining = Math.floor(remaining / 2);
  }
  square.dispose();
  return result;
}
5282
/**
 * Return `[sign, logabsdet]`: the sign and the natural logarithm of the
 * absolute value of the determinant of `a`, from its LU decomposition.
 *
 * The sign is `+1` or `-1`, chosen by the combined parity of pivots that
 * differ from their own index and negative LU diagonal entries.
 *
 * NOTE(review): a zero diagonal entry yields `log(0)` in `logabsdet` while the
 * sign stays +/-1 — confirm whether singular input should report sign 0.
 */
function slogdet(a) {
  const matrix = fudgeArray(a);
  const n = checkSquare("slogdet", matrix);
  const [factors, pivots, permutation] = lu(matrix);
  permutation.dispose();
  const swapCount = pivots.notEqual(arange(n)).astype(int32).sum(-1);
  const diag$1 = factors.diagonal(0, -1, -2);
  const negativeCount = diag$1.ref.less(0).astype(int32).sum(-1);
  const parity = swapCount.add(negativeCount).mod(2);
  const logabsdet = log(absolute(diag$1)).sum(-1);
  // Map parity 0 -> +1 and parity 1 -> -1.
  const sign$1 = parity.mul(-2).add(1);
  return [sign$1, logabsdet];
}
4675
- function computePathNaive(input, sizeMap) {
4676
- const n = input.shapes.length;
4677
- const path = [];
4678
- let lastTensorIndex = 0;
4679
- for (let i = 1; i < n; i++) {
4680
- path.push([lastTensorIndex, i]);
4681
- lastTensorIndex = n + i - 1;
4682
- }
4683
- return new EinsumPath(input, sizeMap, path);
5295
/**
 * Solve a linear system of equations.
 *
 * This solves a (batched) linear system of equations `a @ x = b` for `x` given
 * `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
 *
 * Internally this LU-decomposes `a`, applies the row permutation to `b`, and
 * runs a unit-lower then an upper triangular solve.
 *
 * @param a - Coefficient matrix of shape `(..., N, N)`.
 * @param b - Values of shape `(N,)` or `(..., N, M)`.
 * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
 * @throws {Error} If `a` is not square, `b` is a scalar, or the row count of
 *   `b` does not match `N`.
 */
function solve(a, b) {
  a = fudgeArray(a);
  b = fudgeArray(b);
  const n = checkSquare("solve", a);
  if (b.ndim === 0) throw new Error(`solve: b cannot be scalar`);
  // Treat a vector right-hand side as a single column and squeeze it at the end.
  const bIs1d = b.ndim === 1;
  if (bIs1d) b = b.reshape([...b.shape, 1]);
  if (b.shape[b.ndim - 2] !== n) throw new Error(`solve: leading dimension of b must match size of a, got a=${a.aval}, b=${b.aval}`);
  const m = b.shape[b.ndim - 1];
  const batchDims = generalBroadcast(a.shape.slice(0, -2), b.shape.slice(0, -2));
  a = broadcastTo(a, [
    ...batchDims,
    n,
    n
  ]);
  b = broadcastTo(b, [
    ...batchDims,
    n,
    m
  ]);
  const [lu$2, pivots, permutation] = lu(a);
  pivots.dispose();
  // One-hot permutation matrix: P[i, j] = 1 where permutation[i] === j.
  const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
  const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
    leftSide: true,
    lower: true,
    unitDiagonal: true
  });
  // `LPb` is used exactly once, so hand it over without taking an extra
  // reference; an additional `.ref` here would never be released.
  let x = triangularSolve(lu$2, LPb, {
    leftSide: true,
    lower: false
  });
  if (bIs1d) x = squeeze(x, -1);
  return x;
}
4685
- function computePathOptimal(input, sizeMap) {
4686
- const n = input.shapes.length;
4687
- let bestPath = null;
4688
- let bestFlops = null;
4689
- for (const path of allPaths(range(n), n)) {
4690
- const flops = approximatePathFlops(input, sizeMap, path);
4691
- if (bestFlops === null || flops < bestFlops) {
4692
- bestPath = path;
4693
- bestFlops = flops;
4694
- }
5340
+
5341
+ //#endregion
5342
+ //#region src/library/numpy/dtype-info.ts
5343
/**
 * Machine limits for floating-point types.
 *
 * Returns a frozen record of limits for `Float16`, `Float32` or `Float64`.
 * `machep` and `negep` are the exponents of `eps` and `epsneg` respectively
 * (`eps === 2 ** machep`, `epsneg === 2 ** negep`).
 *
 * @param dtype - A floating-point dtype.
 * @throws {Error} If `dtype` is not a floating-point type or is unsupported.
 */
function finfo(dtype) {
  if (!isFloatDtype(dtype)) throw new Error(`finfo: received ${dtype}, must be a floating-point type`);
  switch (dtype) {
    case DType.Float16: return Object.freeze({
      bits: 16,
      dtype: DType.Float16,
      eps: 2 ** -10,
      epsneg: 2 ** -11,
      machep: -10,
      max: 65504,
      maxexp: 16,
      min: -65504,
      minexp: -14,
      // Must match `epsneg` (2 ** -11); the former value -24 was Float32's.
      negep: -11,
      nexp: 5,
      nmant: 10,
      precision: 3,
      resolution: .001,
      smallestNormal: 2 ** -14,
      smallestSubnormal: 2 ** -24
    });
    case DType.Float32: return Object.freeze({
      bits: 32,
      dtype: DType.Float32,
      eps: 2 ** -23,
      epsneg: 2 ** -24,
      machep: -23,
      max: 34028234663852886e22,
      maxexp: 128,
      min: -34028234663852886e22,
      minexp: -126,
      negep: -24,
      nexp: 8,
      nmant: 23,
      precision: 6,
      resolution: 1e-6,
      smallestNormal: 2 ** -126,
      smallestSubnormal: 2 ** -149
    });
    case DType.Float64: return Object.freeze({
      bits: 64,
      dtype: DType.Float64,
      eps: 2 ** -52,
      epsneg: 2 ** -53,
      machep: -52,
      max: Number.MAX_VALUE,
      maxexp: 1024,
      min: -Number.MAX_VALUE,
      minexp: -1022,
      negep: -53,
      nexp: 11,
      nmant: 52,
      precision: 15,
      resolution: 1e-15,
      smallestNormal: 2 ** -1022,
      smallestSubnormal: 2 ** -1074
    });
    default: throw new Error(`finfo: unsupported dtype ${dtype}`);
  }
}
4698
- function* allPaths(tensors, next) {
4699
- if (tensors.length === 2) {
4700
- yield [[tensors[0], tensors[1]]];
4701
- return;
4702
- }
4703
- for (let i = 0; i < tensors.length; i++) for (let j = i + 1; j < tensors.length; j++) {
4704
- const pair = [tensors[i], tensors[j]];
4705
- const newTensors = tensors.filter((t) => t !== pair[0] && t !== pair[1]);
4706
- newTensors.push(next);
4707
- for (const subpath of allPaths(newTensors, next + 1)) yield [pair, ...subpath];
5404
/**
 * Machine limits for integer types.
 *
 * Returns a frozen `{ bits, dtype, max, min }` record for `Int32` or `Uint32`.
 *
 * @throws {Error} For any other dtype.
 */
function iinfo(dtype) {
  if (dtype === DType.Int32) {
    return Object.freeze({
      bits: 32,
      dtype: DType.Int32,
      max: 2147483647,
      min: -2147483648
    });
  }
  if (dtype === DType.Uint32) {
    return Object.freeze({
      bits: 32,
      dtype: DType.Uint32,
      max: 4294967295,
      min: 0
    });
  }
  throw new Error(`iinfo: unsupported dtype ${dtype}`);
}
4710
5422
 
@@ -4714,28 +5426,32 @@ var numpy_exports = {};
4714
5426
  __export(numpy_exports, {
4715
5427
  Array: () => Array$1,
4716
5428
  DType: () => DType,
4717
- abs: () => abs,
5429
+ abs: () => absolute,
4718
5430
  absolute: () => absolute,
4719
5431
  acos: () => acos,
4720
- acosh: () => acosh,
5432
+ acosh: () => arccosh,
4721
5433
  add: () => add,
5434
+ all: () => all,
4722
5435
  allclose: () => allclose,
5436
+ any: () => any,
4723
5437
  arange: () => arange,
4724
- arccos: () => arccos,
5438
+ arccos: () => acos,
4725
5439
  arccosh: () => arccosh,
5440
+ arcsin: () => asin,
4726
5441
  arcsinh: () => arcsinh,
4727
- arctan: () => arctan,
4728
- arctan2: () => arctan2,
5442
+ arctan: () => atan,
5443
+ arctan2: () => atan2,
4729
5444
  arctanh: () => arctanh,
4730
5445
  argmax: () => argmax,
4731
5446
  argmin: () => argmin,
5447
+ argsort: () => argsort,
4732
5448
  array: () => array,
4733
5449
  asin: () => asin,
4734
- asinh: () => asinh,
5450
+ asinh: () => arcsinh,
4735
5451
  astype: () => astype,
4736
5452
  atan: () => atan,
4737
5453
  atan2: () => atan2,
4738
- atanh: () => atanh,
5454
+ atanh: () => arctanh,
4739
5455
  bool: () => bool,
4740
5456
  broadcastArrays: () => broadcastArrays,
4741
5457
  broadcastShapes: () => broadcastShapes,
@@ -4745,16 +5461,21 @@ __export(numpy_exports, {
4745
5461
  clip: () => clip,
4746
5462
  columnStack: () => columnStack,
4747
5463
  concatenate: () => concatenate,
5464
+ convolve: () => convolve,
5465
+ corrcoef: () => corrcoef,
5466
+ correlate: () => correlate,
4748
5467
  cos: () => cos,
4749
5468
  cosh: () => cosh,
5469
+ cov: () => cov,
4750
5470
  cumsum: () => cumsum,
4751
- cumulativeSum: () => cumulativeSum,
5471
+ cumulativeSum: () => cumsum,
4752
5472
  deg2rad: () => deg2rad,
4753
5473
  degrees: () => degrees,
4754
5474
  diag: () => diag,
4755
5475
  diagonal: () => diagonal,
4756
- divide: () => divide,
4757
- dot: () => dot,
5476
+ divide: () => trueDivide,
5477
+ divmod: () => divmod,
5478
+ dot: () => dot$1,
4758
5479
  dstack: () => dstack,
4759
5480
  e: () => e,
4760
5481
  einsum: () => einsum,
@@ -4762,8 +5483,11 @@ __export(numpy_exports, {
4762
5483
  eulerGamma: () => eulerGamma,
4763
5484
  exp: () => exp,
4764
5485
  exp2: () => exp2,
5486
+ expandDims: () => expandDims,
4765
5487
  expm1: () => expm1,
4766
5488
  eye: () => eye,
5489
+ fft: () => numpy_fft_exports,
5490
+ finfo: () => finfo,
4767
5491
  flip: () => flip,
4768
5492
  fliplr: () => fliplr,
4769
5493
  flipud: () => flipud,
@@ -4771,6 +5495,7 @@ __export(numpy_exports, {
4771
5495
  float32: () => float32,
4772
5496
  float64: () => float64,
4773
5497
  floor: () => floor,
5498
+ floorDivide: () => floorDivide,
4774
5499
  fmod: () => fmod,
4775
5500
  frexp: () => frexp,
4776
5501
  full: () => full,
@@ -4783,6 +5508,7 @@ __export(numpy_exports, {
4783
5508
  hstack: () => hstack,
4784
5509
  hypot: () => hypot,
4785
5510
  identity: () => identity$1,
5511
+ iinfo: () => iinfo,
4786
5512
  inf: () => inf,
4787
5513
  inner: () => inner,
4788
5514
  int32: () => int32,
@@ -4794,12 +5520,15 @@ __export(numpy_exports, {
4794
5520
  ldexp: () => ldexp,
4795
5521
  less: () => less,
4796
5522
  lessEqual: () => lessEqual,
5523
+ linalg: () => numpy_linalg_exports,
4797
5524
  linspace: () => linspace,
4798
5525
  log: () => log,
4799
5526
  log10: () => log10,
4800
5527
  log1p: () => log1p,
4801
5528
  log2: () => log2,
5529
+ logspace: () => logspace,
4802
5530
  matmul: () => matmul,
5531
+ matrixTranspose: () => matrixTranspose,
4803
5532
  max: () => max,
4804
5533
  maximum: () => maximum,
4805
5534
  mean: () => mean,
@@ -4816,10 +5545,10 @@ __export(numpy_exports, {
4816
5545
  onesLike: () => onesLike,
4817
5546
  outer: () => outer,
4818
5547
  pad: () => pad,
4819
- permuteDims: () => permuteDims,
5548
+ permuteDims: () => transpose,
4820
5549
  pi: () => pi,
4821
5550
  positive: () => positive,
4822
- pow: () => pow,
5551
+ pow: () => power,
4823
5552
  power: () => power,
4824
5553
  prod: () => prod$1,
4825
5554
  promoteTypes: () => promoteTypes,
@@ -4834,8 +5563,11 @@ __export(numpy_exports, {
4834
5563
  shape: () => shape,
4835
5564
  sign: () => sign,
4836
5565
  sin: () => sin,
5566
+ sinc: () => sinc,
4837
5567
  sinh: () => sinh,
4838
5568
  size: () => size,
5569
+ sort: () => sort,
5570
+ split: () => split$1,
4839
5571
  sqrt: () => sqrt,
4840
5572
  square: () => square,
4841
5573
  squeeze: () => squeeze,
@@ -4843,6 +5575,7 @@ __export(numpy_exports, {
4843
5575
  std: () => std,
4844
5576
  subtract: () => subtract,
4845
5577
  sum: () => sum,
5578
+ take: () => take,
4846
5579
  tan: () => tan,
4847
5580
  tanh: () => tanh,
4848
5581
  tensordot: () => tensordot,
@@ -5000,6 +5733,26 @@ function min(a, axis = null, opts) {
5000
5733
  function max(a, axis = null, opts) {
5001
5734
  return reduce(a, AluOp.Max, axis, opts);
5002
5735
  }
5736
/**
 * Test whether all array elements along a given axis evaluate to True.
 *
 * The input is cast to boolean and reduced with `min`, so the result drops the
 * given axis; with `axis = null` the reduction covers the whole array.
 */
function all(a, axis = null, opts) {
  const asBool = fudgeArray(a).astype(DType.Bool);
  return min(asBool, axis, opts);
}
5746
/**
 * Test whether any array element along a given axis evaluates to True.
 *
 * The input is cast to boolean and reduced with `max`, so the result drops the
 * given axis; with `axis = null` the reduction covers the whole array.
 */
function any(a, axis = null, opts) {
  const asBool = fudgeArray(a).astype(DType.Bool);
  return max(asBool, axis, opts);
}
5003
5756
  /** Return the peak-to-peak range along a given axis (`max - min`). */
5004
5757
  function ptp(a, axis = null, opts) {
5005
5758
  a = fudgeArray(a);
@@ -5074,8 +5827,6 @@ function cumsum(a, axis) {
5074
5827
  a = broadcast(a, a.shape.concat(n), [-2]);
5075
5828
  return moveaxis$1(tril(a).sum(-1), -1, axis);
5076
5829
  }
5077
- /** @function Alternative name for `jax.numpy.cumsum()`. */
5078
- const cumulativeSum = cumsum;
5079
5830
  /** Reverse the elements in an array along the given axes. */
5080
5831
  function flip(x, axis = null) {
5081
5832
  const nd = ndim(x);
@@ -5083,6 +5834,45 @@ function flip(x, axis = null) {
5083
5834
  return flip$1(x, axis);
5084
5835
  }
5085
5836
  /**
5837
+ * Split an array into multiple sub-arrays along an axis.
5838
+ *
5839
+ * @param a - The input array to split.
5840
+ * @param indicesOrSections - If an integer, it indicates the number of equal
5841
+ * sections to create along the specified axis. If a list of integers, it
5842
+ * specifies the indices at which to split the array.
5843
+ * @param axis - The axis along which to split the array. Default is 0.
5844
+ */
5845
function split$1(a, indicesOrSections, axis = 0) {
  a = fudgeArray(a);
  axis = checkAxis(axis, a.ndim);
  const size$1 = a.shape[axis];
  let sizes;
  if (typeof indicesOrSections === "number") {
    // A section count must be a positive integer; anything else would give a
    // meaningless part size below.
    if (!Number.isInteger(indicesOrSections) || indicesOrSections <= 0) throw new Error(`Number of sections must be a positive integer, got ${indicesOrSections}`);
    if (size$1 % indicesOrSections !== 0) throw new Error(`Array of size ${size$1} cannot be split into ${indicesOrSections} equal parts`);
    sizes = rep(indicesOrSections, size$1 / indicesOrSections);
  } else {
    // Convert split points into consecutive part sizes. Starting from 0 also
    // covers an empty index list, which yields one part spanning the whole axis.
    sizes = [];
    let previous = 0;
    for (const index of indicesOrSections) {
      sizes.push(index - previous);
      previous = index;
    }
    sizes.push(size$1 - previous);
  }
  const results = [];
  let offset = 0;
  // Peel off 7 parts per call while more than 8 remain, carrying the rest as
  // one trailing chunk (presumably because `split$2` is limited to 8 outputs
  // per call — TODO confirm).
  while (sizes.length - offset > 8) {
    const restSize = sizes.slice(offset + 7).reduce((x, y) => x + y, 0);
    const groupSizes = [...sizes.slice(offset, offset + 7), restSize];
    const outs = split$2(a, axis, groupSizes);
    results.push(...outs.slice(0, -1));
    a = outs[outs.length - 1];
    offset += 7;
  }
  // At most 8 parts remain: split them in a single call.
  results.push(...split$2(a, axis, sizes.slice(offset)));
  return results;
}
5875
+ /**
5086
5876
  * Join a sequence of arrays along an existing axis.
5087
5877
  *
5088
5878
  * The arrays must have the same shape, except in the dimension corresponding to
@@ -5094,13 +5884,11 @@ function concatenate(xs, axis = 0) {
5094
5884
  if (xs.length === 0) throw new Error("Need at least one array to concatenate");
5095
5885
  const shapes = xs.map(shape);
5096
5886
  axis = checkAxis(axis, shapes[0].length);
5097
- for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays with shapes ${JSON.stringify(shapes)} along axis ${axis}`);
5098
- const makePadAxis = (start, end) => shapes[0].map((_, i) => i === axis ? [start, end] : [0, 0]);
5887
+ for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays ${xs[0].aval} and ${xs[i].aval} along axis ${axis}`);
5099
5888
  let result = xs[0];
5100
- for (let i = 1; i < xs.length; i++) {
5101
- const len1 = result.shape[axis];
5102
- const len2 = shapes[i][axis];
5103
- result = pad(result, makePadAxis(0, len2)).add(pad(xs[i], makePadAxis(len1, 0)));
5889
+ for (let i = 1; i < xs.length; i += 7) {
5890
+ const group = xs.slice(i, i + 7);
5891
+ result = concatenate$1([result, ...group], axis);
5104
5892
  }
5105
5893
  return result;
5106
5894
  }
@@ -5185,8 +5973,11 @@ function flipud(x) {
5185
5973
  function fliplr(x) {
5186
5974
  return flip(x, 1);
5187
5975
  }
5188
- /** @function Alternative name for `numpy.transpose()`. */
5189
- const permuteDims = transpose;
5976
+ /** Transpose the last two dimensions of an array. */
5977
+ function matrixTranspose(a) {
5978
+ if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
5979
+ return moveaxis$1(a, -1, -2);
5980
+ }
5190
5981
  /** Return a 1-D flattened array containing the elements of the input. */
5191
5982
  function ravel(a) {
5192
5983
  return fudgeArray(a).ravel();
@@ -5202,6 +5993,32 @@ function squeeze(a, axis = null) {
5202
5993
  return reshape(a, newShape);
5203
5994
  }
5204
5995
  /**
5996
+ * Expand the shape of an array by inserting new axes of length 1.
5997
+ *
5998
+ * @param a - Input array.
5999
+ * @param axis - Position(s) in the expanded axes where the new axis (or axes)
6000
+ * is placed. Can be a single integer or an array of integers.
6001
+ * @returns Array with the number of dimensions increased.
6002
+ *
6003
+ * @example
6004
+ * ```ts
6005
+ * const x = np.array([1, 2]);
6006
+ * np.expandDims(x, 0); // Shape [1, 2]
6007
+ * np.expandDims(x, 1); // Shape [2, 1]
6008
+ * np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
6009
+ * ```
6010
+ */
6011
+ function expandDims(a, axis) {
6012
+ const as = shape(a);
6013
+ axis = typeof axis === "number" ? [axis] : axis;
6014
+ axis = normalizeAxis(axis, as.length + axis.length);
6015
+ const newShape = [];
6016
+ let srcIdx = 0;
6017
+ for (let i = 0; i < as.length + axis.length; i++) if (axis.includes(i)) newShape.push(1);
6018
+ else newShape.push(as[srcIdx++]);
6019
+ return reshape(a, newShape);
6020
+ }
6021
+ /**
5205
6022
  * Repeat each element of an array after itself.
5206
6023
  *
5207
6024
  * If no axis is provided, use the flattened input array, and return a flat
@@ -5289,7 +6106,7 @@ function diagonal(a, offset, axis1, axis2) {
5289
6106
  */
5290
6107
  function diag(v, k = 0) {
5291
6108
  const a = fudgeArray(v);
5292
- if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
6109
+ if (!Number.isInteger(k)) throw new Error(`k must be an integer, got ${k}`);
5293
6110
  if (a.ndim === 1) {
5294
6111
  const n = a.shape[0];
5295
6112
  const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
@@ -5297,12 +6114,46 @@ function diag(v, k = 0) {
5297
6114
  else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
5298
6115
  else return ret;
5299
6116
  } else if (a.ndim === 2) return diagonal(a, k);
5300
- else throw new TypeError("numpy.diag only supports 1D and 2D arrays");
6117
+ else throw new Error("numpy.diag only supports 1D and 2D arrays");
5301
6118
  }
5302
6119
  /** Calculate the sum of the diagonal of an array along the given axes. */
5303
6120
  function trace(a, offset = 0, axis1 = 0, axis2 = 1) {
5304
6121
  return diagonal(a, offset, axis1, axis2).sum(-1);
5305
6122
  }
6123
+ /**
6124
+ * Return a sorted copy of an array.
6125
+ *
6126
+ * The array is sorted along a specified axis (the last by default). This may be
6127
+ * an unstable sort, and it dispatches to a device-specific implementation.
6128
+ */
6129
+ function sort(a, axis = -1) {
6130
+ return fudgeArray(a).sort(axis);
6131
+ }
6132
+ /**
6133
+ * Return indices that would sort an array. This may be an unstable sorting
6134
+ * algorithm; it need not preserve order of indices in ties.
6135
+ *
6136
+ * Returns an array of `int32` indices.
6137
+ *
6138
+ * The array is sorted along a specified axis (the last by default).
6139
+ */
6140
+ function argsort(a, axis = -1) {
6141
+ return fudgeArray(a).argsort(axis);
6142
+ }
6143
+ /**
6144
+ * Take elements from an array along an axis.
6145
+ *
6146
+ * This is equivalent to advanced indexing with integer indices over that
6147
+ * numbered axis. By default, the flattened array is used.
6148
+ */
6149
+ function take(a, indices, axis = null) {
6150
+ if (axis === null) {
6151
+ a = ravel(a);
6152
+ axis = 0;
6153
+ }
6154
+ axis = checkAxis(axis, ndim(a));
6155
+ return gather(a, [indices], [axis], axis);
6156
+ }
5306
6157
  /** Return if two arrays are element-wise equal within a tolerance. */
5307
6158
  function allclose(actual, expected, options) {
5308
6159
  const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
@@ -5319,11 +6170,11 @@ function allclose(actual, expected, options) {
5319
6170
  }
5320
6171
  /** Matrix product of two arrays. */
5321
6172
  function matmul(x, y) {
5322
- if (ndim(x) === 0 || ndim(y) === 0) throw new TypeError("matmul: x and y must be at least 1D");
6173
+ if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
5323
6174
  x = x, y = y;
5324
6175
  if (y.ndim === 1) return dot$2(x, y);
5325
6176
  const numBatchDims = Math.min(Math.max(x.ndim, 2), y.ndim) - 2;
5326
- return dot$1(x, y, {
6177
+ return dot(x, y, {
5327
6178
  lhsContractingDims: [-1],
5328
6179
  rhsContractingDims: [-2],
5329
6180
  lhsBatchDims: range(-2 - numBatchDims, -2),
@@ -5331,11 +6182,11 @@ function matmul(x, y) {
5331
6182
  });
5332
6183
  }
5333
6184
  /** Dot product of two arrays. */
5334
- function dot(x, y) {
6185
+ function dot$1(x, y) {
5335
6186
  if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
5336
6187
  x = x, y = y;
5337
6188
  if (y.ndim === 1) return dot$2(x, y);
5338
- return dot$1(x, y, {
6189
+ return dot(x, y, {
5339
6190
  lhsContractingDims: [-1],
5340
6191
  rhsContractingDims: [-2]
5341
6192
  });
@@ -5351,7 +6202,7 @@ function tensordot(x, y, axes = 2) {
5351
6202
  x = fudgeArray(x);
5352
6203
  y = fudgeArray(y);
5353
6204
  if (typeof axes === "number") axes = [range(-axes, 0), range(axes)];
5354
- return dot$1(x, y, {
6205
+ return dot(x, y, {
5355
6206
  lhsContractingDims: axes[0],
5356
6207
  rhsContractingDims: axes[1]
5357
6208
  });
@@ -5444,7 +6295,7 @@ function einsum(...args) {
5444
6295
  const [b, bidx] = processSingleTensor(operands[j], indices[j], indices[i]);
5445
6296
  indexReduced = indexReduced.filter((idx) => aidx.includes(idx));
5446
6297
  const indexBatch = aidx.filter((idx) => bidx.includes(idx) && !indexReduced.includes(idx));
5447
- const result = dot$1(a, b, {
6298
+ const result = dot(a, b, {
5448
6299
  lhsContractingDims: indexReduced.map((idx) => aidx.indexOf(idx)),
5449
6300
  rhsContractingDims: indexReduced.map((idx) => bidx.indexOf(idx)),
5450
6301
  lhsBatchDims: indexBatch.map((idx) => aidx.indexOf(idx)),
@@ -5472,7 +6323,7 @@ function einsum(...args) {
5472
6323
  * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
5473
6324
  */
5474
6325
  function inner(x, y) {
5475
- return dot$1(fudgeArray(x), fudgeArray(y), {
6326
+ return dot(fudgeArray(x), fudgeArray(y), {
5476
6327
  lhsContractingDims: [-1],
5477
6328
  rhsContractingDims: [-1]
5478
6329
  });
@@ -5505,6 +6356,30 @@ function vecdot(x, y, { axis } = {}) {
5505
6356
  function vdot(x, y) {
5506
6357
  return dot$2(ravel(x), ravel(y));
5507
6358
  }
6359
+ function _convImpl(name, x, y, mode) {
6360
+ if (x.ndim !== 1 || y.ndim !== 1) throw new Error(`${name}: both inputs must be 1D arrays, got ${x.ndim}D and ${y.ndim}D`);
6361
+ let flipOutput = false;
6362
+ if (x.shape[0] < y.shape[0]) {
6363
+ [x, y] = [y, x];
6364
+ if (name === "correlate") flipOutput = true;
6365
+ }
6366
+ if (name === "convolve") y = flip(y);
6367
+ let padding;
6368
+ if (mode === "valid") padding = "VALID";
6369
+ else if (mode === "same") padding = "SAME_LOWER";
6370
+ else if (mode === "full") padding = [[y.shape[0] - 1, y.shape[0] - 1]];
6371
+ else throw new Error(`${name}: invalid mode ${mode}, expected "full", "same", or "valid"`);
6372
+ const z = conv(x.slice(null, null), y.slice(null, null), [1], padding).slice(0, 0);
6373
+ return flipOutput ? flip(z) : z;
6374
+ }
6375
+ /** Convolution of two one-dimensional arrays. */
6376
+ function convolve(x, y, mode = "full") {
6377
+ return _convImpl("convolve", x, y, mode);
6378
+ }
6379
+ /** Correlation of two one-dimensional arrays. */
6380
+ function correlate(x, y, mode = "valid") {
6381
+ return _convImpl("correlate", x, y, mode);
6382
+ }
5508
6383
  /**
5509
6384
  * Return a tuple of coordinate matrices from coordinate vectors.
5510
6385
  *
@@ -5513,7 +6388,7 @@ function vdot(x, y) {
5513
6388
  */
5514
6389
  function meshgrid(xs, { indexing } = {}) {
5515
6390
  indexing ??= "xy";
5516
- for (const x of xs) if (x.ndim !== 1) throw new TypeError(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
6391
+ for (const x of xs) if (x.ndim !== 1) throw new Error(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
5517
6392
  if (xs.length <= 1) return xs;
5518
6393
  if (indexing === "xy") {
5519
6394
  const [a, b, ...rest] = xs;
@@ -5532,43 +6407,6 @@ function meshgrid(xs, { indexing } = {}) {
5532
6407
  return xs.map((x, i) => broadcast(x, shape$1, [...range(i), ...range(i + 1, xs.length)]));
5533
6408
  }
5534
6409
  /**
5535
- * Return an array with ones on and below the diagonal and zeros elsewhere.
5536
- *
5537
- * If `k` is provided, it specifies the sub-diagonal on and below which the
5538
- * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
5539
- * `k>0` is above it.
5540
- */
5541
- function tri(n, m, k = 0, { dtype, device } = {}) {
5542
- m ??= n;
5543
- dtype ??= DType.Float32;
5544
- if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
5545
- if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
5546
- if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
5547
- const rows = arange(k, n + k, 1, {
5548
- dtype: DType.Int32,
5549
- device
5550
- });
5551
- const cols = arange(0, m, 1, {
5552
- dtype: DType.Int32,
5553
- device
5554
- });
5555
- return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
5556
- }
5557
- /** Return the lower triangle of an array. Must be of dimension >= 2. */
5558
- function tril(a, k = 0) {
5559
- if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
5560
- a = fudgeArray(a);
5561
- const [n, m] = a.shape.slice(-2);
5562
- return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
5563
- }
5564
- /** Return the upper triangle of an array. Must be of dimension >= 2. */
5565
- function triu(a, k = 0) {
5566
- if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
5567
- a = fudgeArray(a);
5568
- const [n, m] = a.shape.slice(-2);
5569
- return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
5570
- }
5571
- /**
5572
6410
  * Clip (limit) the values in an array.
5573
6411
  *
5574
6412
  * Given an interval, values outside the interval are clipped to the interval
@@ -5592,8 +6430,6 @@ function absolute(x) {
5592
6430
  x = fudgeArray(x);
5593
6431
  return where(less(x.ref, 0), x.ref.mul(-1), x);
5594
6432
  }
5595
- /** @function Alias of `jax.numpy.absolute()`. */
5596
- const abs = absolute;
5597
6433
  /** Return an element-wise indication of sign of the input. */
5598
6434
  function sign(x) {
5599
6435
  x = fudgeArray(x);
@@ -5637,6 +6473,20 @@ function tan(x) {
5637
6473
  x = fudgeArray(x);
5638
6474
  return sin(x.ref).div(cos(x));
5639
6475
  }
6476
+ /**
6477
+ * @function
6478
+ * Return the normalized sinc function.
6479
+ *
6480
+ * The sinc function is defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`.
6481
+ * This is the normalized sinc function commonly used in signal processing.
6482
+ *
6483
+ * **Note:** JVP is not supported at x=0 due to discontinuous derivative. This
6484
+ * requires a custom JVP rule to handle properly (see JAX implementation).
6485
+ */
6486
+ const sinc = jit$1(function sinc$1(x) {
6487
+ const pix = x.ref.mul(Math.PI);
6488
+ return where(equal(x, 0), 1, sin(pix.ref).div(pix));
6489
+ });
5640
6490
  /** Element-wise inverse cosine function (inverse of cos). */
5641
6491
  function acos(x) {
5642
6492
  return subtract(pi / 2, asin(x));
@@ -5672,12 +6522,6 @@ const atan2 = jit$1(function atan2$1(y, x) {
5672
6522
  const denom = where(xNeg, y, r.add(x));
5673
6523
  return atan(numer.div(denom)).mul(2);
5674
6524
  });
5675
- /** @function Alias of `jax.numpy.acos()`. */
5676
- const arccos = acos;
5677
- /** @function Alias of `jax.numpy.atan()`. */
5678
- const arctan = atan;
5679
- /** @function Alias of `jax.numpy.atan2()`. */
5680
- const arctan2 = atan2;
5681
6525
  /** Element-wise subtraction, with broadcasting. */
5682
6526
  function subtract(x, y) {
5683
6527
  x = fudgeArray(x);
@@ -5695,6 +6539,25 @@ function trueDivide(x, y) {
5695
6539
  return x.div(y);
5696
6540
  }
5697
6541
  /**
6542
+ * Return the largest integer smaller than or equal to the division of the inputs.
6543
+ *
6544
+ * The result is always rounded towards negative infinity.
6545
+ *
6546
+ * For floating-point inputs, this is equivalent to `floor(x / y)`.
6547
+ * For integer inputs, we use `(x - remainder(x, y)) / y` to handle
6548
+ * negative values correctly (note: may overflow near int32 boundaries).
6549
+ *
6550
+ * @param x - Dividend array.
6551
+ * @param y - Divisor array.
6552
+ * @returns Element-wise floor division of x by y.
6553
+ */
6554
+ function floorDivide(x, y) {
6555
+ x = fudgeArray(x);
6556
+ y = fudgeArray(y);
6557
+ if (isFloatDtype(x.dtype) || isFloatDtype(y.dtype)) return floor(trueDivide(x, y));
6558
+ return subtract(x, remainder(x.ref, y.ref)).div(y);
6559
+ }
6560
+ /**
5698
6561
  * @function
5699
6562
  * Calculate element-wise floating-point modulo operation.
5700
6563
  */
@@ -5708,8 +6571,20 @@ const fmod = jit$1(function fmod$1(x, y) {
5708
6571
  const remainder = jit$1(function remainder$1(x, y) {
5709
6572
  return mod(mod(x, y.ref).add(y.ref), y);
5710
6573
  });
5711
- /** @function Alias of `jax.numpy.trueDivide()`. */
5712
- const divide = trueDivide;
6574
+ /**
6575
+ * Return element-wise quotient and remainder simultaneously.
6576
+ *
6577
+ * Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
6578
+ *
6579
+ * @param x - Dividend array.
6580
+ * @param y - Divisor array.
6581
+ * @returns Tuple of [quotient, remainder].
6582
+ */
6583
+ function divmod(x, y) {
6584
+ const xArr = fudgeArray(x);
6585
+ const yArr = fudgeArray(y);
6586
+ return [floorDivide(xArr.ref, yArr.ref), remainder(xArr, yArr)];
6587
+ }
5713
6588
  /** Round input to the nearest integer towards zero. */
5714
6589
  function trunc(x) {
5715
6590
  return idiv(x, 1);
@@ -5731,9 +6606,9 @@ function ldexp(x1, x2) {
5731
6606
  */
5732
6607
  function frexp(x) {
5733
6608
  x = fudgeArray(x);
5734
- const absx = abs(x.ref);
6609
+ const absx = absolute(x.ref);
5735
6610
  const exponent = where(equal(x.ref, 0), 0, floor(log2(absx)).add(1).astype(DType.Int32));
5736
- const mantissa = divide(x, exp2(exponent.ref.astype(x.dtype)));
6611
+ const mantissa = x.div(exp2(exponent.ref.astype(x.dtype)));
5737
6612
  return [mantissa, exponent];
5738
6613
  }
5739
6614
  /** Calculate `2**p` for all p in the input array. */
@@ -5776,10 +6651,8 @@ const power = jit$1(function power$1(x1, x2) {
5776
6651
  const x2i = trunc(x2.ref);
5777
6652
  const shouldBeNaN = multiply(x2.ref.notEqual(x2i.ref), x1.ref.less(0));
5778
6653
  const resultSign = where(mod(x2i, 2).notEqual(0), where(x1.ref.less(0), -1, 1), 1);
5779
- return where(shouldBeNaN, nan, exp(log(abs(x1)).mul(x2)).mul(resultSign));
6654
+ return where(shouldBeNaN, nan, exp(log(absolute(x1)).mul(x2)).mul(resultSign));
5780
6655
  });
5781
- /** @function Alias of `jax.numpy.power()`. */
5782
- const pow = power;
5783
6656
  /** @function Calculate the element-wise cube root of the input array. */
5784
6657
  const cbrt = jit$1(function cbrt$1(x) {
5785
6658
  const sgn = where(less(x.ref, 0), -1, 1);
@@ -5845,69 +6718,360 @@ const arccosh = jit$1(function arccosh$1(x) {
5845
6718
  const arctanh = jit$1(function arctanh$1(x) {
5846
6719
  return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
5847
6720
  });
5848
- /** @function Alias of `jax.numpy.arcsinh()`. */
5849
- const asinh = arcsinh;
5850
- /** @function Alias of `jax.numpy.arccosh()`. */
5851
- const acosh = arccosh;
5852
- /** @function Alias of `jax.numpy.arctanh()`. */
5853
- const atanh = arctanh;
5854
6721
  /**
5855
- * Compute the variance of an array.
5856
- *
5857
- * The variance is computed for the flattened array by default, otherwise over
5858
- * the specified axis.
6722
+ * Compute the variance of an array.
6723
+ *
6724
+ * The variance is computed for the flattened array by default, otherwise over
6725
+ * the specified axis.
6726
+ *
6727
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
6728
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
6729
+ */
6730
+ function var_(x, axis = null, opts) {
6731
+ x = fudgeArray(x);
6732
+ axis = normalizeAxis(axis, x.ndim);
6733
+ const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
6734
+ if (n === 0) throw new Error("var: cannot compute variance over zero-length axis");
6735
+ const mu = opts?.mean !== void 0 ? opts.mean : mean(x.ref, axis, { keepdims: true });
6736
+ return square(x.sub(mu)).sum(axis, { keepdims: opts?.keepdims }).mul(1 / (n - (opts?.correction ?? 0)));
6737
+ }
6738
+ /**
6739
+ * Compute the standard deviation of an array.
6740
+ *
6741
+ * The standard deviation is computed for the flattened array by default,
6742
+ * otherwise over the specified axis.
6743
+ *
6744
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
6745
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
6746
+ */
6747
+ function std(x, axis = null, opts) {
6748
+ return sqrt(var_(x, axis, opts));
6749
+ }
6750
+ /** Estimate the sample covariance of a set of variables. */
6751
+ function cov(x, y = null, { rowvar = true } = {}) {
6752
+ x = fudgeArray(x);
6753
+ if (x.ndim === 1) x = x.reshape([1, x.shape[0]]);
6754
+ if (y !== null) {
6755
+ y = fudgeArray(y);
6756
+ if (y.ndim === 1) y = y.reshape([1, y.shape[0]]);
6757
+ x = vstack([x, y]);
6758
+ }
6759
+ if (!rowvar) x = x.transpose();
6760
+ const [_M, N] = x.shape;
6761
+ x = x.ref.sub(x.mean(1, { keepdims: true }));
6762
+ return dot$1(x.ref, x.transpose()).div(N - 1);
6763
+ }
6764
+ /** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
6765
+ function corrcoef(x, y) {
6766
+ const c = cov(x, y);
6767
+ const variances = diag(c.ref);
6768
+ const norm = sqrt(outer(variances.ref, variances));
6769
+ return c.div(norm);
6770
+ }
6771
+ /** Test element-wise for positive or negative infinity, return bool array. */
6772
+ function isinf(x) {
6773
+ x = fudgeArray(x);
6774
+ return isFloatDtype(x.dtype) ? x.ref.equal(Infinity).add(x.equal(-Infinity)) : fullLike$1(x, false);
6775
+ }
6776
+ /** Test element-wise for NaN (Not a Number). */
6777
+ function isnan(x) {
6778
+ x = fudgeArray(x);
6779
+ return isFloatDtype(x.dtype) ? x.ref.notEqual(x) : fullLike$1(x, false);
6780
+ }
6781
+ /** Test element-wise for negative infinity, return bool array. */
6782
+ function isneginf(x) {
6783
+ x = fudgeArray(x);
6784
+ return isFloatDtype(x.dtype) ? x.equal(-Infinity) : fullLike$1(x, false);
6785
+ }
6786
+ /** Test element-wise for positive infinity, return bool array. */
6787
+ function isposinf(x) {
6788
+ x = fudgeArray(x);
6789
+ return isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
6790
+ }
6791
+ /**
6792
+ * @function
6793
+ * Test element-wise for finite values (not infinity or NaN).
6794
+ */
6795
+ const isfinite = jit$1(function isfinite$1(x) {
6796
+ if (!isFloatDtype(x.dtype)) return fullLike$1(x, true);
6797
+ return isnan(x.ref).add(isinf(x)).notEqual(true);
6798
+ });
6799
+
6800
+ //#endregion
6801
+ //#region src/library/lax-linalg.ts
6802
+ var lax_linalg_exports = {};
6803
+ __export(lax_linalg_exports, {
6804
+ cholesky: () => cholesky$1,
6805
+ lu: () => lu,
6806
+ triangularSolve: () => triangularSolve
6807
+ });
6808
+ /**
6809
+ * Compute the Cholesky decomposition of a symmetric positive-definite matrix.
6810
+ *
6811
+ * The Cholesky decomposition of a matrix `A` is:
6812
+ *
6813
+ * - A = L @ L^T (for upper=false, default)
6814
+ * - A = U^T @ U (for upper=true)
6815
+ *
6816
+ * where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
6817
+ * The input matrix must be symmetric and positive-definite.
6818
+ *
6819
+ * @example
6820
+ * ```ts
6821
+ * import { lax, numpy as np } from "@jax-js/jax";
6822
+ *
6823
+ * const x = np.array([[2., 1.], [1., 2.]]);
6824
+ *
6825
+ * // Lower Cholesky factorization (default):
6826
+ * const L = lax.linalg.cholesky(x);
6827
+ * // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
6828
+ *
6829
+ * // Upper Cholesky factorization:
6830
+ * const U = lax.linalg.cholesky(x, { upper: true });
6831
+ * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
6832
+ * ```
6833
+ */
6834
+ function cholesky$1(a, { upper = false } = {}) {
6835
+ const L = cholesky$2(a);
6836
+ return upper ? moveaxis$1(L, -2, -1) : L;
6837
+ }
6838
+ /**
6839
+ * LU decomposition with partial pivoting.
6840
+ *
6841
+ * Computes the matrix decomposition: `P @ A = L @ U`, where `P` is a
6842
+ * permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
6843
+ * and `U` is upper-triangular.
6844
+ *
6845
+ * @param x - A batch of matrices with shape `[..., m, n]`.
6846
+ *
6847
+ * @returns A tuple `(lu, pivots, permutation)` where:
6848
+ * - `lu`: combined lower and upper triangular matrices.
6849
+ * - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
6850
+ * - `permutation`: the permutation generated by pivots with shape `[..., m]`.
6851
+ *
6852
+ * @example
6853
+ * ```ts
6854
+ * import { lax, numpy as np } from "@jax-js/jax";
6855
+ *
6856
+ * const A = np.array([[4., 3.], [6., 3.]]);
6857
+ * const [lu, pivots, permutation] = lax.linalg.lu(A);
6858
+ * // lu ≈ [[6., 3.], [0.6666667, 1.0]]
6859
+ * // pivots = [1, 1]
6860
+ * // permutation = [1, 0]
6861
+ * ```
6862
+ */
6863
+ function lu(x) {
6864
+ return lu$1(x);
6865
+ }
6866
+ /**
6867
+ * Solve a triangular linear system.
6868
+ *
6869
+ * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
6870
+ * where `a` is a triangular matrix.
6871
+ *
6872
+ * @example
6873
+ * ```ts
6874
+ * import { lax, numpy as np } from "@jax-js/jax";
6875
+ *
6876
+ * const L = np.array([[2., 0.], [1., 3.]]);
6877
+ * const b = np.array([4., 7.]).reshape([2, 1]);
6878
+ *
6879
+ * // Solve L @ x = b
6880
+ * const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
6881
+ * // x = [[2.], [5./3.]]
6882
+ * ```
6883
+ */
6884
+ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = false, unitDiagonal = false } = {}) {
6885
+ a = fudgeArray(a);
6886
+ b = fudgeArray(b);
6887
+ if (!leftSide) transposeA = !transposeA;
6888
+ else b = moveaxis$1(b, -2, -1);
6889
+ if (transposeA) a = moveaxis$1(a, -2, -1);
6890
+ let x = triangularSolve$1(a, b, {
6891
+ lower,
6892
+ unitDiagonal
6893
+ });
6894
+ if (leftSide) x = moveaxis$1(x, -2, -1);
6895
+ return x;
6896
+ }
6897
+
6898
+ //#endregion
6899
+ //#region src/library/lax.ts
6900
+ var lax_exports = {};
6901
+ __export(lax_exports, {
6902
+ conv: () => conv,
6903
+ convGeneralDilated: () => convGeneralDilated,
6904
+ convWithGeneralPadding: () => convWithGeneralPadding,
6905
+ dot: () => dot,
6906
+ erf: () => erf,
6907
+ erfc: () => erfc,
6908
+ linalg: () => lax_linalg_exports,
6909
+ reduceWindow: () => reduceWindow,
6910
+ stopGradient: () => stopGradient$1
6911
+ });
6912
+ /**
6913
+ * General dot product/contraction operator.
5859
6914
  *
5860
- * If `correction` is provided, the divisor in calculation is `N - correction`,
5861
- * where `N` represents the number of elements (e.g., for Bessel's correction).
6915
+ * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
6916
+ * `jax.numpy.tensordot()`, and `jax.numpy.einsum()` where possible.
5862
6917
  */
5863
- function var_(x, axis = null, opts) {
5864
- x = fudgeArray(x);
5865
- axis = normalizeAxis(axis, x.ndim);
5866
- const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
5867
- if (n === 0) throw new Error("var: cannot compute variance over zero-length axis");
5868
- const mu = opts?.mean !== void 0 ? opts.mean : mean(x.ref, axis, { keepdims: true });
5869
- return square(x.sub(mu)).sum(axis, { keepdims: opts?.keepdims }).mul(1 / (n - (opts?.correction ?? 0)));
6918
+ function dot(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
6919
+ if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
6920
+ else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
6921
+ lc = lc.map((a) => checkAxis(a, lhs.ndim));
6922
+ rc = rc.map((a) => checkAxis(a, rhs.ndim));
6923
+ lb = lb.map((a) => checkAxis(a, lhs.ndim));
6924
+ rb = rb.map((a) => checkAxis(a, rhs.ndim));
6925
+ if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
6926
+ else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
6927
+ const lf = range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
6928
+ const rf = range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
6929
+ const lhs2 = lhs.transpose([
6930
+ ...lb,
6931
+ ...lf,
6932
+ ...lc
6933
+ ]);
6934
+ const rhs2 = rhs.transpose([
6935
+ ...rb,
6936
+ ...rf,
6937
+ ...rc
6938
+ ]);
6939
+ if (lc.length === 0) return mul(lhs2.reshape([
6940
+ ...lb.map((a) => lhs.shape[a]),
6941
+ ...lf.map((a) => lhs.shape[a]),
6942
+ ...rep(rf.length, 1)
6943
+ ]), rhs2.reshape([
6944
+ ...rb.map((a) => rhs.shape[a]),
6945
+ ...rep(lf.length, 1),
6946
+ ...rf.map((a) => rhs.shape[a])
6947
+ ]));
6948
+ const dotShapeX = lc.map((a) => lhs.shape[a]);
6949
+ const dotShapeY = rc.map((a) => rhs.shape[a]);
6950
+ if (!deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
6951
+ return dot$2(lhs2.reshape([
6952
+ ...lb.map((a) => lhs.shape[a]),
6953
+ ...lf.map((a) => lhs.shape[a]),
6954
+ ...rep(rf.length, 1),
6955
+ prod(dotShapeX)
6956
+ ]), rhs2.reshape([
6957
+ ...rb.map((a) => rhs.shape[a]),
6958
+ ...rep(lf.length, 1),
6959
+ ...rf.map((a) => rhs.shape[a]),
6960
+ prod(dotShapeY)
6961
+ ]));
6962
+ }
6963
+ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
6964
+ const padType = padding.toUpperCase();
6965
+ switch (padType) {
6966
+ case "VALID": return rep(inShape.length, [0, 0]);
6967
+ case "SAME":
6968
+ case "SAME_LOWER": {
6969
+ const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
6970
+ const padSizes = zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
6971
+ if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
6972
+ else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
6973
+ }
6974
+ default: throw new Error(`Unknown padding type: ${padType}`);
6975
+ }
5870
6976
  }
5871
6977
  /**
5872
- * Compute the standard deviation of an array.
6978
+ * General n-dimensional convolution operator, with optional dilation.
5873
6979
  *
5874
- * The standard deviation is computed for the flattened array by default,
5875
- * otherwise over the specified axis.
6980
+ * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
6981
+ * function in JAX, which wraps XLA's general convolution operator.
5876
6982
  *
5877
- * If `correction` is provided, the divisor in calculation is `N - correction`,
5878
- * where `N` represents the number of elements (e.g., for Bessel's correction).
6983
+ * Grouped convolutions are supported via the `featureGroupCount` option.
5879
6984
  */
5880
- function std(x, axis = null, opts) {
5881
- return sqrt(var_(x, axis, opts));
6985
+ function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
6986
+ if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
6987
+ if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
6988
+ if (typeof padding === "string") {
6989
+ if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
6990
+ padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? rep(rhs.ndim - 2, 1), padding);
6991
+ }
6992
+ if (featureGroupCount !== 1) {
6993
+ const G = featureGroupCount;
6994
+ const [N, C_in, ...xs] = lhs.shape;
6995
+ const [C_out, C_in_per_group, ...ks] = rhs.shape;
6996
+ if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
6997
+ if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
6998
+ if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
6999
+ const lhsGrouped = moveaxis(lhs.reshape([
7000
+ N,
7001
+ G,
7002
+ C_in / G,
7003
+ ...xs
7004
+ ]), 1, 0);
7005
+ const rhsGrouped = rhs.reshape([
7006
+ G,
7007
+ C_out / G,
7008
+ C_in_per_group,
7009
+ ...ks
7010
+ ]);
7011
+ const result = conv$1(lhsGrouped, rhsGrouped, {
7012
+ vmapDims: 1,
7013
+ strides: windowStrides,
7014
+ padding,
7015
+ lhsDilation,
7016
+ rhsDilation
7017
+ });
7018
+ const ys = result.shape.slice(3);
7019
+ return moveaxis(result, 0, 1).reshape([
7020
+ N,
7021
+ C_out,
7022
+ ...ys
7023
+ ]);
7024
+ }
7025
+ return conv$1(lhs, rhs, {
7026
+ strides: windowStrides,
7027
+ padding,
7028
+ lhsDilation,
7029
+ rhsDilation
7030
+ });
5882
7031
  }
5883
- /** Test element-wise for positive or negative infinity, return bool array. */
5884
- function isinf(x) {
5885
- x = fudgeArray(x);
5886
- return isFloatDtype(x.dtype) ? x.ref.equal(Infinity).add(x.equal(-Infinity)) : fullLike$1(x, false);
7032
+ /** Convenience wrapper around `convGeneralDilated`. */
7033
+ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
7034
+ return convGeneralDilated(lhs, rhs, windowStrides, padding, {
7035
+ lhsDilation,
7036
+ rhsDilation
7037
+ });
5887
7038
  }
5888
- /** Test element-wise for NaN (Not a Number). */
5889
- function isnan(x) {
5890
- x = fudgeArray(x);
5891
- return isFloatDtype(x.dtype) ? x.ref.notEqual(x) : fullLike$1(x, false);
7039
+ /** Convenience wrapper around `convGeneralDilated`. */
7040
+ function conv(lhs, rhs, windowStrides, padding) {
7041
+ return convGeneralDilated(lhs, rhs, windowStrides, padding);
5892
7042
  }
5893
- /** Test element-wise for negative infinity, return bool array. */
5894
- function isneginf(x) {
5895
- x = fudgeArray(x);
5896
- return isFloatDtype(x.dtype) ? x.equal(-Infinity) : fullLike$1(x, false);
7043
+ /** Reduce a computation over padded windows. */
7044
+ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
7045
+ if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
7046
+ if (!windowStrides) windowStrides = rep(windowDimensions.length, 1);
7047
+ for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
7048
+ return computation(bind1(Primitive.Pool, [operand], {
7049
+ window: windowDimensions,
7050
+ strides: windowStrides
7051
+ }));
5897
7052
  }
5898
- /** Test element-wise for positive infinity, return bool array. */
5899
- function isposinf(x) {
5900
- x = fudgeArray(x);
5901
- return isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
7053
+ /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
7054
+ function erf(x) {
7055
+ return erf$1(x);
5902
7056
  }
5903
7057
  /**
5904
- * @function
5905
- * Test element-wise for finite values (not infinity or NaN).
7058
+ * The complementary error function: `erfc(x) = 1 - erf(x)`.
7059
+ *
7060
+ * This function is more accurate than `1 - erf(x)` for large values of `x`,
7061
+ * where `erf(x)` is very close to 1.
5906
7062
  */
5907
- const isfinite = jit$1(function isfinite$1(x) {
5908
- if (!isFloatDtype(x.dtype)) return fullLike$1(x, true);
5909
- return isnan(x.ref).add(isinf(x)).notEqual(true);
5910
- });
7063
+ function erfc(x) {
7064
+ return erfc$1(x);
7065
+ }
7066
+ /**
7067
+ * Stops gradient computation.
7068
+ *
7069
+ * Behaves as the identity function but prevents the flow of gradients during
7070
+ * forward or reverse-mode automatic differentiation.
7071
+ */
7072
+ function stopGradient$1(x) {
7073
+ return stopGradient(x);
7074
+ }
5911
7075
 
5912
7076
  //#endregion
5913
7077
  //#region src/library/nn.ts
@@ -5917,6 +7081,10 @@ __export(nn_exports, {
5917
7081
  elu: () => elu,
5918
7082
  gelu: () => gelu,
5919
7083
  glu: () => glu,
7084
+ hardSigmoid: () => hardSigmoid,
7085
+ hardSilu: () => hardSilu,
7086
+ hardSwish: () => hardSilu,
7087
+ hardTanh: () => hardTanh,
5920
7088
  identity: () => identity,
5921
7089
  leakyRelu: () => leakyRelu,
5922
7090
  logSigmoid: () => logSigmoid,
@@ -5927,14 +7095,17 @@ __export(nn_exports, {
5927
7095
  oneHot: () => oneHot,
5928
7096
  relu: () => relu,
5929
7097
  relu6: () => relu6,
7098
+ selu: () => selu,
5930
7099
  sigmoid: () => sigmoid,
5931
7100
  silu: () => silu,
5932
7101
  softSign: () => softSign,
5933
7102
  softmax: () => softmax,
5934
7103
  softplus: () => softplus,
7104
+ sparsePlus: () => sparsePlus,
7105
+ sparseSigmoid: () => sparseSigmoid,
5935
7106
  squareplus: () => squareplus,
5936
7107
  standardize: () => standardize,
5937
- swish: () => swish
7108
+ swish: () => silu
5938
7109
  });
5939
7110
  /**
5940
7111
  * Rectified Linear Unit (ReLU) activation function:
@@ -5969,6 +7140,28 @@ function softplus(x) {
5969
7140
  return log(exp(x).add(1));
5970
7141
  }
5971
7142
  /**
7143
+ * @function
7144
+ * Sparse plus function:
7145
+ *
7146
+ * - When `x <= -1`: `0`
7147
+ * - When `-1 < x < 1`: `(x+1)**2 / 4`
7148
+ * - When `x >= 1`: `x`
7149
+ */
7150
+ const sparsePlus = jit$1((x) => {
7151
+ return where(x.ref.lessEqual(-1), 0, where(x.ref.less(1), square(x.ref.add(1)).mul(.25), x));
7152
+ });
7153
+ /**
7154
+ * @function
7155
+ * Sparse sigmoid activation function.
7156
+ *
7157
+ * - When `x <= -1`: `0`
7158
+ * - When `-1 < x < 1`: `(x + 1) / 2`
7159
+ * - When `x >= 1`: `1`
7160
+ */
7161
+ const sparseSigmoid = jit$1((x) => {
7162
+ return clip(x.add(1).mul(.5), 0, 1);
7163
+ });
7164
+ /**
5972
7165
  * Soft-sign activation function, computed element-wise:
5973
7166
  * `softsign(x) = x / (|x| + 1)`.
5974
7167
  */
@@ -5990,17 +7183,6 @@ const silu = jit$1(function silu$1(x) {
5990
7183
  return x.ref.mul(sigmoid(x));
5991
7184
  });
5992
7185
  /**
5993
- * @function
5994
- * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
5995
- * Swish, computed element-wise:
5996
- * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
5997
- *
5998
- * `swish()` and `silu()` are both aliases for the same function.
5999
- *
6000
- * Reference: https://en.wikipedia.org/wiki/Swish_function
6001
- */
6002
- const swish = silu;
6003
- /**
6004
7186
  * Log-sigmoid activation function, computed element-wise:
6005
7187
  * `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
6006
7188
  */
@@ -6017,6 +7199,19 @@ function leakyRelu(x, negativeSlope = .01) {
6017
7199
  x = fudgeArray(x);
6018
7200
  return where(less(x.ref, 0), x.ref.mul(negativeSlope), x);
6019
7201
  }
7202
+ /** Hard sigmoid activation function: `relu6(x+3)/6`. */
7203
+ function hardSigmoid(x) {
7204
+ return relu6(add(x, 3)).mul(1 / 6);
7205
+ }
7206
+ /** Hard SiLU (swish) activation function: `x * hardSigmoid(x)`. */
7207
+ function hardSilu(x) {
7208
+ x = fudgeArray(x);
7209
+ return x.ref.mul(hardSigmoid(x));
7210
+ }
7211
+ /** Hard tanh activation function: `clip(x, -1, 1)`. */
7212
+ function hardTanh(x) {
7213
+ return clip(x, -1, 1);
7214
+ }
6020
7215
  /**
6021
7216
  * Exponential linear unit activation function.
6022
7217
  *
@@ -6039,6 +7234,20 @@ function celu(x, alpha = 1) {
6039
7234
  }
6040
7235
  /**
6041
7236
  * @function
7237
+ * Scaled exponential linear unit activation.
7238
+ *
7239
+ * Computes the element-wise function:
7240
+ * `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
7241
+ *
7242
+ * Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
7243
+ */
7244
+ const selu = jit$1(function selu$1(x) {
7245
+ const alpha = 1.6732632423543772;
7246
+ const lambda = 1.0507009873554805;
7247
+ return where(x.ref.less(0), expm1(x.ref).mul(alpha), x).mul(lambda);
7248
+ });
7249
+ /**
7250
+ * @function
6042
7251
  * Gaussian error linear unit (GELU) activation function.
6043
7252
  *
6044
7253
  * This is computed element-wise. There are two variants depending on whether
@@ -6192,35 +7401,46 @@ var random_exports = {};
6192
7401
  __export(random_exports, {
6193
7402
  bernoulli: () => bernoulli,
6194
7403
  bits: () => bits,
7404
+ cauchy: () => cauchy,
6195
7405
  exponential: () => exponential,
7406
+ gumbel: () => gumbel,
6196
7407
  key: () => key,
7408
+ laplace: () => laplace,
7409
+ multivariateNormal: () => multivariateNormal,
6197
7410
  normal: () => normal,
6198
7411
  split: () => split,
6199
7412
  uniform: () => uniform
6200
7413
  });
6201
- function validateKeyShape(key$1) {
7414
+ function validateKeyShape(key$1, scalar = false) {
6202
7415
  if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
6203
7416
  if (key$1.shape[key$1.shape.length - 1] !== 2) throw new Error(`Invalid key shape: ${key$1.shape}. Expected last dimension to be 2.`);
7417
+ if (scalar && key$1.shape.length > 1) throw new Error(`Expected a single PRNG key, but got a batch of keys with shape ${JSON.stringify(key$1.shape)} - use jax.vmap for batching.`);
6204
7418
  return key$1.shape.slice(0, -1);
6205
7419
  }
7420
+ function getK01(key$1) {
7421
+ const keyShape = validateKeyShape(key$1, true);
7422
+ let [k0, k1] = split$2(key$1, -1, [1, 1]);
7423
+ k0 = k0.reshape(keyShape);
7424
+ k1 = k1.reshape(keyShape);
7425
+ return [k0, k1];
7426
+ }
6206
7427
  /** Create a pseudo-random number generator (PRNG) key from a 32-bit integer seed. */
6207
7428
  function key(seed) {
6208
- seed = seed >>> 0;
6209
- return array([0, seed], { dtype: DType.Uint32 });
7429
+ seed = array(seed, { dtype: DType.Uint32 });
7430
+ if (seed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seed.shape} - use jax.vmap for batching.`);
7431
+ return stack([0, seed]);
6210
7432
  }
6211
7433
  /** Splits a PRNG key into `num` new keys by adding a leading axis. */
6212
7434
  function split(key$1, num = 2) {
6213
7435
  const shape$1 = typeof num === "number" ? [num] : num;
6214
7436
  for (const len of shape$1) if (len <= 0 || !Number.isInteger(len)) throw new Error(`Invalid split length: ${len}. Must be a positive integer.`);
6215
- const keyShape = validateKeyShape(key$1);
6216
- const k0 = key$1.ref.slice(...keyShape.map(() => null), 0);
6217
- const k1 = key$1.slice(...keyShape.map(() => null), 1);
7437
+ const [k0, k1] = getK01(key$1);
6218
7438
  return stack([randomBits(k0.ref, k1.ref, shape$1, 0), randomBits(k0, k1, shape$1, 1)], -1);
6219
7439
  }
6220
7440
  /** Sample uniform bits in the form of unsigned integers. */
6221
7441
  function bits(key$1, shape$1 = []) {
6222
- const keyShape = validateKeyShape(key$1);
6223
- return randomBits(key$1.ref.slice(...keyShape.map(() => null), 0), key$1.slice(...keyShape.map(() => null), 1), shape$1);
7442
+ const [k0, k1] = getK01(key$1);
7443
+ return randomBits(k0, k1, shape$1);
6224
7444
  }
6225
7445
  /**
6226
7446
  * @function
@@ -6252,6 +7472,16 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
6252
7472
  }
6253
7473
  /**
6254
7474
  * @function
7475
+ * Sample from a Cauchy distribution with location 0 and scale 1.
7476
+ *
7477
+ * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
7478
+ */
7479
+ const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
7480
+ const u = uniform(key$1, shape$1);
7481
+ return tan(u.sub(.5).mul(Math.PI));
7482
+ }, { staticArgnums: [1] });
7483
+ /**
7484
+ * @function
6255
7485
  * Sample exponential random values according to `p(x) = exp(-x)`.
6256
7486
  */
6257
7487
  const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
@@ -6260,6 +7490,56 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
6260
7490
  }, { staticArgnums: [1] });
6261
7491
  /**
6262
7492
  * @function
7493
+ * Sample from a Gumbel distribution with location 0 and scale 1.
7494
+ *
7495
+ * Uses inverse transform sampling: `x = -log(-log(1 - u))` where u ~ Uniform(0, 1).
7496
+ */
7497
+ const gumbel = jit$1(function gumbel$1(key$1, shape$1 = []) {
7498
+ const u = uniform(key$1, shape$1);
7499
+ return negative(log(negative(log1p(negative(u)))));
7500
+ }, { staticArgnums: [1] });
7501
+ /**
7502
+ * @function
7503
+ * Sample from a Laplace distribution with location 0 and scale 1.
7504
+ *
7505
+ * Uses inverse transform sampling: the CDF is `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`.
7506
+ * Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
7507
+ */
7508
+ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
7509
+ const u = uniform(key$1, shape$1);
7510
+ const centered = u.sub(.5);
7511
+ const s = sign(centered.ref);
7512
+ const absVal = absolute(centered);
7513
+ return s.mul(log1p(absVal.mul(-2)).mul(-1));
7514
+ }, { staticArgnums: [1] });
7515
+ /**
7516
+ * @function
7517
+ * Sample multivariate normal random values with given mean and covariance.
7518
+ *
7519
+ * The samples take the given batch shape, broadcast against the batch dimensions
7520
+ * of `mean` and `cov`, followed by a final axis of size `n` for the vector components.
7521
+ *
7522
+ * This uses Cholesky decomposition on the covariance matrix.
7523
+ *
7524
+ * - `key` - PRNG key
7525
+ * - `mean` - Mean vector of shape `[..., n]`
7526
+ * - `cov` - Covariance of shape `[..., n, n]`, must be positive-definite
7527
+ * - `shape` - Result batch shape, must be broadcastable with
7528
+ * `mean.shape[:-1]` and `cov.shape[:-2]`
7529
+ * @returns Random samples of shape `[...shape, n]`
7530
+ */
7531
+ const multivariateNormal = jit$1(function multivariateNormal$1(key$1, mean$1, cov$1, shape$1 = []) {
7532
+ mean$1 = fudgeArray(mean$1);
7533
+ cov$1 = fudgeArray(cov$1);
7534
+ const n = mean$1.shape[mean$1.ndim - 1];
7535
+ if (cov$1.shape[cov$1.ndim - 1] !== n || cov$1.shape[cov$1.ndim - 2] !== n) throw new Error(`Invalid covariance shape: ${cov$1.shape}. Expected last two dimensions to be [${n}, ${n}].`);
7536
+ const outputShape = broadcastShapes(shape$1, mean$1.shape.slice(0, -1), cov$1.shape.slice(0, -2)).concat(n);
7537
+ const L = cholesky(cov$1);
7538
+ const z = normal(key$1, outputShape);
7539
+ return einsum("...ij,...j->...i", L, z).add(mean$1);
7540
+ }, { staticArgnums: [3] });
7541
+ /**
7542
+ * @function
6263
7543
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
6264
7544
  *
6265
7545
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
@@ -6368,11 +7648,6 @@ const valueAndGrad = valueAndGrad$1;
6368
7648
  */
6369
7649
  const jacrev = jacrev$1;
6370
7650
  /**
6371
- * @function
6372
- * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
6373
- */
6374
- const jacobian = jacrev;
6375
- /**
6376
7651
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
6377
7652
  *
6378
7653
  * This can be used to wait for the results of an intermediate computation to
@@ -6407,5 +7682,4 @@ async function devicePut(x, device) {
6407
7682
  }
6408
7683
 
6409
7684
  //#endregion
6410
- export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
6411
- //# sourceMappingURL=index.js.map
7685
+ export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };