@jax-js/jax 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.ts CHANGED
@@ -437,10 +437,37 @@ declare enum Routines {
437
437
  Sort = "Sort",
438
438
  /** Returns `int32` indices of the stably sorted array. */
439
439
  Argsort = "Argsort",
440
- /** Solve a triangular system of questions. */
440
+ /**
441
+ * Solve a triangular system of equations.
442
+ *
443
+ * The first batch of inputs `A` should be of shape `[..., N, N]` and upper
444
+ * triangular, while the second batch `B` should be of shape `[..., M, N]`.
445
+ *
446
+ * Solves for `X` in the equation `A @ X.T = B.T`, where `A` is the
447
+ * triangular matrix. This is equivalent to `X = B @ A^-T`.
448
+ */
441
449
  TriangularSolve = "TriangularSolve",
442
- /** Cholesky decomposition of 2D positive semi-definite matrices. */
450
+ /**
451
+ * Cholesky decomposition of 2D positive semi-definite matrices.
452
+ *
453
+ * The input batch should be of shape `[..., N, N]`, and the output batch is
454
+ * of the same shape, containing the lower-triangular matrix `L` such that
455
+ * `A = L @ L.T`. Behavior is unspecified if A is not positive semi-definite.
456
+ */
443
457
  Cholesky = "Cholesky",
458
+ /**
459
+ * LU decomposition of 2D rectangular matrices.
460
+ *
461
+ * The input is a batch of shape `[..., M, N]`, and the output is a tuple of
462
+ * three arrays: `LU, Pivots, Permutation`.
463
+ *
464
+ * - `LU` is of shape `[..., M, N]`, containing the combined lower and upper
465
+ * triangular matrices. (lower triangular = implicit unit diagonal)
466
+ * - `Pivots` is of shape `[..., min(M, N)]`, containing the row swaps.
467
+ * - `Permutation` is of shape `[..., M]`, containing the permutation vector
468
+ * such that `P = eye(M).slice(Permutation)` -> `P @ A = L @ U`.
469
+ */
470
+ LU = "LU",
444
471
  }
445
472
  interface RoutineType {
446
473
  inputShapes: number[][];
@@ -450,7 +477,7 @@ interface RoutineType {
450
477
  }
451
478
  //#endregion
452
479
  //#region src/backend.d.ts
453
- type Device = "cpu" | "wasm" | "webgpu";
480
+ type Device = "cpu" | "wasm" | "webgpu" | "webgl";
454
481
  declare const devices: Device[];
455
482
  /** Configure the default device for arrays. */
456
483
  declare function defaultDevice(device?: Device): Device;
@@ -535,7 +562,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
535
562
  */
536
563
  declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
537
564
  declare namespace numpy_linalg_d_exports {
538
- export { cholesky$1 as cholesky, diagonal, lstsq, matmul, matrixTranspose, outer, tensordot, trace, vecdot };
565
+ export { cholesky$1 as cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
539
566
  }
540
567
  /**
541
568
  * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
@@ -550,6 +577,10 @@ declare function cholesky$1(a: ArrayLike, {
550
577
  upper?: boolean;
551
578
  symmetrizeInput?: boolean;
552
579
  }): Array;
580
+ /** Compute the determinant of a square matrix (batched). */
581
+ declare function det(a: ArrayLike): Array;
582
+ /** Compute the inverse of a square matrix (batched). */
583
+ declare function inv(a: ArrayLike): Array;
553
584
  /**
554
585
  * Return the least-squares solution to a linear equation.
555
586
  *
@@ -564,8 +595,75 @@ declare function cholesky$1(a: ArrayLike, {
564
595
  * @return least-squares solution of shape `(N,)` or `(N, K)`
565
596
  */
566
597
  declare function lstsq(a: ArrayLike, b: ArrayLike): Array;
598
+ /** Raise a square matrix to an integer power, via repeated squarings. */
599
+ declare function matrixPower(a: ArrayLike, n: number): Array;
600
+ /** Return sign and natural logarithm of the determinant of `a`. */
601
+ declare function slogdet(a: ArrayLike): [Array, Array];
602
+ /**
603
+ * Solve a linear system of equations.
604
+ *
605
+ * This solves a (batched) linear system of equations `a @ x = b` for `x` given
606
+ * `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
607
+ *
608
+ * @param a - Coefficient matrix of shape `(..., N, N)`.
609
+ * @param b - Values of shape `(N,)` or `(..., N, M)`.
610
+ * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
611
+ */
612
+ declare function solve(a: ArrayLike, b: ArrayLike): Array;
613
+ //#endregion
614
+ //#region src/library/numpy/dtype-info.d.ts
615
+ /** @inline */
616
+ type FInfo = Readonly<{
617
+ /** The number of bits occupied by the type. */
618
+ bits: number;
619
+ /** Returns the _dtype_ for which finfo returns information. */
620
+ dtype: DType;
621
+ /** The difference between 1.0 and the next smallest representable float larger than 1.0. */
622
+ eps: number;
623
+ /** The difference between 1.0 and the next largest representable float smaller than 1.0. */
624
+ epsneg: number;
625
+ /** The exponent that yields `eps`. */
626
+ machep: number;
627
+ /** The largest representable finite number. */
628
+ max: number;
629
+ /** The smallest positive power of the base (2) that causes overflow. */
630
+ maxexp: number;
631
+ /** The smallest representable (most negative) finite number. */
632
+ min: number;
633
+ /** The largest negative power of the base (2) without leading zeros in mantissa. */
634
+ minexp: number;
635
+ /** The exponent that yields `epsneg`. */
636
+ negep: number;
637
+ /** Number of bits in the exponent portion. */
638
+ nexp: number;
639
+ /** Number of bits in the mantissa portion. */
640
+ nmant: number;
641
+ /** The approximate number of decimal digits to which this kind of float is precise. */
642
+ precision: number;
643
+ /** The approximate decimal resolution, i.e., `10 ** -precision`. */
644
+ resolution: number;
645
+ /** The smallest positive normal number. */
646
+ smallestNormal: number;
647
+ /** The smallest positive subnormal number. */
648
+ smallestSubnormal: number;
649
+ }>;
650
+ /** Machine limits for floating-point types. */
651
+ declare function finfo(dtype: DType): FInfo;
652
+ /** @inline */
653
+ type IInfo = Readonly<{
654
+ /** The number of bits occupied by the type. */
655
+ bits: number;
656
+ /** Returns the _dtype_ for which iinfo returns information. */
657
+ dtype: DType;
658
+ /** The largest representable integer. */
659
+ max: number;
660
+ /** The smallest representable integer. */
661
+ min: number;
662
+ }>;
663
+ /** Machine limits for integer types. */
664
+ declare function iinfo(dtype: DType): IInfo;
567
665
  declare namespace numpy_d_exports {
568
- export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sort, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
666
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
569
667
  }
570
668
  declare const float32 = DType.Float32;
571
669
  declare const int32 = DType.Int32;
@@ -732,6 +830,16 @@ declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
732
830
  declare function cumsum(a: ArrayLike, axis?: number): Array;
733
831
  /** Reverse the elements in an array along the given axes. */
734
832
  declare function flip(x: ArrayLike, axis?: Axis): Array;
833
+ /**
834
+ * Split an array into multiple sub-arrays along an axis.
835
+ *
836
+ * @param a - The input array to split.
837
+ * @param indicesOrSections - If an integer, it indicates the number of equal
838
+ * sections to create along the specified axis. If a list of integers, it
839
+ * specifies the indices at which to split the array.
840
+ * @param axis - The axis along which to split the array. Default is 0.
841
+ */
842
+ declare function split$1(a: ArrayLike, indicesOrSections: number | number[], axis?: number): Array[];
735
843
  /**
736
844
  * Join a sequence of arrays along an existing axis.
737
845
  *
@@ -775,6 +883,8 @@ declare function columnStack(xs: ArrayLike[]): Array;
775
883
  declare function flipud(x: ArrayLike): Array;
776
884
  /** Flip an array horizontally (axis=1). */
777
885
  declare function fliplr(x: ArrayLike): Array;
886
+ /** Interchange two axes of an array. */
887
+ declare function swapaxes(a: ArrayLike, axis1: number, axis2: number): Array;
778
888
  /** Transpose the last two dimensions of an array. */
779
889
  declare function matrixTranspose(a: ArrayLike): Array;
780
890
  /** Return a 1-D flattened array containing the elements of the input. */
@@ -860,6 +970,13 @@ declare function sort(a: ArrayLike, axis?: number): Array;
860
970
  * The array is sorted along a specified axis (the last by default).
861
971
  */
862
972
  declare function argsort(a: ArrayLike, axis?: number): Array;
973
+ /**
974
+ * Take elements from an array along an axis.
975
+ *
976
+ * This is equivalent to advanced indexing with integer indices over that
977
+ * numbered axis. By default, the flattened array is used.
978
+ */
979
+ declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
863
980
  /** Return if two arrays are element-wise equal within a tolerance. */
864
981
  declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
865
982
  rtol?: number;
@@ -990,6 +1107,17 @@ declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
990
1107
  declare function square(x: ArrayLike): Array;
991
1108
  /** Element-wise tangent function (takes radians). */
992
1109
  declare function tan(x: ArrayLike): Array;
1110
+ /**
1111
+ * @function
1112
+ * Return the normalized sinc function.
1113
+ *
1114
+ * The sinc function is defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`.
1115
+ * This is the normalized sinc function commonly used in signal processing.
1116
+ *
1117
+ * **Note:** JVP is not supported at x=0, where the piecewise formula's derivative is an indeterminate 0/0 form (sinc itself is smooth). This
1118
+ * requires a custom JVP rule to handle properly (see JAX implementation).
1119
+ */
1120
+ declare const sinc: OwnedFunction<(x: ArrayLike) => Array>;
993
1121
  /** Element-wise inverse cosine function (inverse of cos). */
994
1122
  declare function acos(x: ArrayLike): Array;
995
1123
  /**
@@ -1019,6 +1147,20 @@ declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
1019
1147
  declare function subtract(x: ArrayLike, y: ArrayLike): Array;
1020
1148
  /** Calculates the floating-point division of x by y element-wise. */
1021
1149
  declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
1150
+ /**
1151
+ * Return the largest integer less than or equal to the division of the inputs.
1152
+ *
1153
+ * The result is always rounded towards negative infinity.
1154
+ *
1155
+ * For floating-point inputs, this is equivalent to `floor(x / y)`.
1156
+ * For integer inputs, we use `(x - remainder(x, y)) / y` to handle
1157
+ * negative values correctly (note: may overflow near int32 boundaries).
1158
+ *
1159
+ * @param x - Dividend array.
1160
+ * @param y - Divisor array.
1161
+ * @returns Element-wise floor division of x by y.
1162
+ */
1163
+ declare function floorDivide(x: ArrayLike, y: ArrayLike): Array;
1022
1164
  /**
1023
1165
  * @function
1024
1166
  * Calculate element-wise floating-point modulo operation.
@@ -1029,6 +1171,16 @@ declare const fmod: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
1029
1171
  * Calculate element-wise remainder of the division (matches sign of y).
1030
1172
  */
1031
1173
  declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
1174
+ /**
1175
+ * Return element-wise quotient and remainder simultaneously.
1176
+ *
1177
+ * Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
1178
+ *
1179
+ * @param x - Dividend array.
1180
+ * @param y - Divisor array.
1181
+ * @returns Tuple of [quotient, remainder].
1182
+ */
1183
+ declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
1032
1184
  /** Round input to the nearest integer towards zero. */
1033
1185
  declare function trunc(x: ArrayLike): Array;
1034
1186
  /**
@@ -1139,7 +1291,11 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
1139
1291
  correction?: number;
1140
1292
  } & ReduceOpts): Array;
1141
1293
  /** Estimate the sample covariance of a set of variables. */
1142
- declare function cov(x: ArrayLike, y?: ArrayLike): Array;
1294
+ declare function cov(x: ArrayLike, y?: ArrayLike | null, {
1295
+ rowvar
1296
+ }?: {
1297
+ rowvar?: boolean;
1298
+ }): Array;
1143
1299
  /** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
1144
1300
  declare function corrcoef(x: ArrayLike, y?: ArrayLike): Array;
1145
1301
  /** Test element-wise for positive or negative infinity, return bool array. */
@@ -1353,6 +1509,8 @@ declare enum Primitive {
1353
1509
  PoolTranspose = "pool_transpose",
1354
1510
  Compare = "compare",
1355
1511
  Where = "where",
1512
+ Concatenate = "concatenate",
1513
+ Split = "split",
1356
1514
  RandomBits = "random_bits",
1357
1515
  Gather = "gather",
1358
1516
  Transpose = "transpose",
@@ -1369,6 +1527,8 @@ declare enum Primitive {
1369
1527
  // A is upper triangular, A @ X.T = B.T
1370
1528
  Cholesky = "cholesky",
1371
1529
  // A is positive-definite, A = L @ L^T
1530
+ LU = "lu",
1531
+ // LU decomposition with partial pivoting
1372
1532
  Jit = "jit",
1373
1533
  }
1374
1534
  interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
@@ -1395,6 +1555,13 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
1395
1555
  [Primitive.Compare]: {
1396
1556
  op: CompareOp;
1397
1557
  };
1558
+ [Primitive.Concatenate]: {
1559
+ axis: number;
1560
+ };
1561
+ [Primitive.Split]: {
1562
+ axis: number;
1563
+ sizes: number[];
1564
+ };
1398
1565
  [Primitive.RandomBits]: {
1399
1566
  shape: number[];
1400
1567
  mode: "xor" | 0 | 1;
@@ -1422,14 +1589,14 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
1422
1589
  [Primitive.Pad]: {
1423
1590
  width: Pair[];
1424
1591
  };
1592
+ [Primitive.TriangularSolve]: {
1593
+ unitDiagonal: boolean;
1594
+ };
1425
1595
  [Primitive.Jit]: {
1426
1596
  name: string;
1427
1597
  jaxpr: Jaxpr;
1428
1598
  numConsts: number;
1429
1599
  };
1430
- [Primitive.TriangularSolve]: {
1431
- unitDiagonal: boolean;
1432
- };
1433
1600
  }
1434
1601
  /** Type of parameters taken by each primitive. */
1435
1602
  type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
@@ -1570,6 +1737,7 @@ declare abstract class Tracer {
1570
1737
  neg(): this;
1571
1738
  add(other: this | TracerValue): this;
1572
1739
  mul(other: this | TracerValue): this;
1740
+ mod(other: this | TracerValue): this;
1573
1741
  greater(other: this | TracerValue): this;
1574
1742
  less(other: this | TracerValue): this;
1575
1743
  equal(other: this | TracerValue): this;
@@ -1672,6 +1840,7 @@ declare class ShapedArray implements AbstractValue {
1672
1840
  static fromAval(aval: AbstractValue): ShapedArray;
1673
1841
  get ndim(): number;
1674
1842
  get size(): number;
1843
+ scalar(): ShapedArray;
1675
1844
  toString(): string;
1676
1845
  equals(other: ShapedArray): boolean;
1677
1846
  }
@@ -1739,6 +1908,8 @@ declare class Array extends Tracer {
1739
1908
  toString(): string;
1740
1909
  get device(): Device;
1741
1910
  get ref(): this;
1911
+ /** Get the current reference count (for debugging memory management). */
1912
+ get refCount(): number;
1742
1913
  dispose(): void;
1743
1914
  /**
1744
1915
  * Convert this array into a primitive value.
@@ -1887,8 +2058,43 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
1887
2058
  dtype,
1888
2059
  device
1889
2060
  }?: DTypeAndDevice): Array;
2061
+ /**
2062
+ * Return numbers spaced evenly on a log scale.
2063
+ *
2064
+ * In linear space, the sequence starts at `base ** start` and ends at
2065
+ * `base ** stop` (see `endpoint` below).
2066
+ *
2067
+ * @param start - `base ** start` is the starting value of the sequence.
2068
+ * @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
2069
+ * @param num - Number of samples to generate. Default is 50.
2070
+ * @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
2071
+ * @param base - The base of the log space. Default is 10.
2072
+ * @returns Array of evenly spaced values on a log scale.
2073
+ */
2074
+ declare function logspace(start: number, stop: number, num?: number, endpoint?: boolean, base?: number, {
2075
+ dtype,
2076
+ device
2077
+ }?: DTypeAndDevice): Array;
2078
+ //#endregion
2079
+ //#region src/frontend/linearize.d.ts
2080
+ /** @inline */
2081
+ type GradOpts = {
2082
+ /**
2083
+ * Integer or sequence of integers. Specifies which positional argument(s) to
2084
+ * differentiate with respect to.
2085
+ *
2086
+ * Defaults to `0` (the first argument).
2087
+ */
2088
+ argnums?: number | number[];
2089
+ /**
2090
+ * The input function returns a pair of `[out, aux]` including an auxiliary
2091
+ * value. This `aux` is not differentiated, but is returned alongside the
2092
+ * gradient when evaluating the function.
2093
+ */
2094
+ hasAux?: boolean;
2095
+ };
1890
2096
  declare namespace lax_linalg_d_exports {
1891
- export { cholesky, triangularSolve };
2097
+ export { cholesky, lu, triangularSolve };
1892
2098
  }
1893
2099
  /**
1894
2100
  * Compute the Cholesky decomposition of a symmetric positive-definite matrix.
@@ -1921,6 +2127,32 @@ declare function cholesky(a: ArrayLike, {
1921
2127
  }?: {
1922
2128
  upper?: boolean;
1923
2129
  }): Array;
2130
+ /**
2131
+ * LU decomposition with partial pivoting.
2132
+ *
2133
+ * Computes the matrix decomposition: `P @ A = L @ U`, where `P` is a
2134
+ * permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
2135
+ * and `U` is upper-triangular.
2136
+ *
2137
+ * @param x - A batch of matrices with shape `[..., m, n]`.
2138
+ *
2139
+ * @returns A tuple `(lu, pivots, permutation)` where:
2140
+ * - `lu`: combined lower and upper triangular matrices.
2141
+ * - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
2142
+ * - `permutation`: the permutation generated by pivots with shape `[..., m]`.
2143
+ *
2144
+ * @example
2145
+ * ```ts
2146
+ * import { lax, numpy as np } from "@jax-js/jax";
2147
+ *
2148
+ * const A = np.array([[4., 3.], [6., 3.]]);
2149
+ * const [lu, pivots, permutation] = lax.linalg.lu(A);
2150
+ * // lu ≈ [[6., 3.], [0.6666667, 1.0]]
2151
+ * // pivots = [1, 1]
2152
+ * // permutation = [1, 0]
2153
+ * ```
2154
+ */
2155
+ declare function lu(x: ArrayLike): [Array, Array, Array];
1924
2156
  /**
1925
2157
  * Solve a triangular linear system.
1926
2158
  *
@@ -1951,7 +2183,7 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
1951
2183
  unitDiagonal?: boolean;
1952
2184
  }): Array;
1953
2185
  declare namespace lax_d_exports {
1954
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
2186
+ export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
1955
2187
  }
1956
2188
  /**
1957
2189
  * Dimension numbers for general `dot()` primitive.
@@ -1989,7 +2221,11 @@ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
1989
2221
  * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
1990
2222
  * function in JAX, which wraps XLA's general convolution operator.
1991
2223
  *
1992
- * Grouped convolutions are not supported right now.
2224
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
2225
+ * @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
2226
+ * @param windowStrides - Strides for each spatial dimension
2227
+ * @param padding - Padding for each spatial dimension, or a string
2228
+ * (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
1993
2229
  */
1994
2230
  declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
1995
2231
  lhsDilation,
@@ -2004,6 +2240,37 @@ declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: numbe
2004
2240
  declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
2005
2241
  /** Convenience wrapper around `convGeneralDilated`. */
2006
2242
  declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
2243
+ /**
2244
+ * Convenience wrapper for calculating the N-d convolution "transpose".
2245
+ *
2246
+ * This function directly calculates a fractionally strided conv rather than
2247
+ * indirectly calculating the gradient (transpose) of a forward convolution.
2248
+ * It is equivalent to the JAX version, except:
2249
+ *
2250
+ * - The `use_consistent_padding` option is not available. We only have the
2251
+ * consistent padding case (JAX version >0.8.4).
2252
+ * - The order of dimensions matches `lax.conv_general_dilated`.
2253
+ *
2254
+ * Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
2255
+ * dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
2256
+ * `transposeKernel` to true.
2257
+ *
2258
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
2259
+ * @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
2260
+ * @param strides - Sequence of n integers, sets fractional stride
2261
+ * @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
2262
+ * each side of the input, so it acts like gradient of `conv()`
2263
+ * @param rhsDilation - Atrous dilation for the kernel
2264
+ * @param transposeKernel - Flip spatial axes and swap the input/output channels
2265
+ * of the kernel; its shape should be `[C_in, C_out, ...ks]`
2266
+ */
2267
+ declare function convTranspose(lhs: Array, rhs: Array, strides: number[], padding: PaddingType, {
2268
+ rhsDilation,
2269
+ transposeKernel
2270
+ }?: {
2271
+ rhsDilation?: number[];
2272
+ transposeKernel?: boolean;
2273
+ }): Array;
2007
2274
  /** Reduce a computation over padded windows. */
2008
2275
  declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
2009
2276
  /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
@@ -2023,7 +2290,7 @@ declare function erfc(x: ArrayLike): Array;
2023
2290
  */
2024
2291
  declare function stopGradient(x: ArrayLike): Array;
2025
2292
  declare namespace nn_d_exports {
2026
- export { celu, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
2293
+ export { celu, dotProductAttention, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
2027
2294
  }
2028
2295
  /**
2029
2296
  * Rectified Linear Unit (ReLU) activation function:
@@ -2220,11 +2487,61 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
2220
2487
  * ```
2221
2488
  */
2222
2489
  declare function oneHot(x: Array, numClasses: number): Array;
2490
+ /**
2491
+ * Scaled dot product attention (SDPA).
2492
+ *
2493
+ * Computes `softmax((Q @ K^T) / sqrt(d) + bias) @ V`, where `Q` is the query,
2494
+ * `K` is the key, `V` is the value, and `d` is the dimensionality of each key
2495
+ * and query vector.
2496
+ *
2497
+ * Multi-query attention is applied when input `key` and `value` tensors have
2498
+ * fewer heads than `query`.
2499
+ *
2500
+ * We use the following uppercase letters to denote array shapes:
2501
+ * - `B` = batch size
2502
+ * - `S` = length of key/value sequences (source)
2503
+ * - `L` = length of query sequences
2504
+ * - `N` = number of attention heads
2505
+ * - `H` = dimensionality of each attention head
2506
+ * - `K` = number of key/value heads (for grouped-query attention)
2507
+ *
2508
+ * The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
2509
+ * case it must be omitted from all inputs.
2510
+ *
2511
+ * @param query - Query array; shape `[B, L, N, H]`
2512
+ * @param key - Key array; shape `[B, S, K, H]`
2513
+ * @param value - Value array; same shape as `key`
2514
+ * @param opts.bias - Optional bias to add to the attention logits; shape
2515
+ * `[B, N, L, S]` or broadcastable to it.
2516
+ * @param opts.mask - Optional mask to apply to the attention logits; should be
2517
+ * a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
2518
+ * the element should take part in attention.
2519
+ * @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
2520
+ * @param opts.isCausal - If true, applies a causal mask.
2521
+ * @param opts.querySeqLengths - Optional sequence lengths for the queries;
2522
+ * shape `(B,)`. Taken from the beginning of the tensor.
2523
+ * @param opts.keyValueSeqLengths - Optional sequence lengths for the keys and
2524
+ * values; shape `(B,)`. Taken from the beginning of the tensor.
2525
+ * @param opts.localWindowSize - If specified, applies a local attention window
2526
+ * of the given size. Can be a single number or a tuple `[left, right]`.
2527
+ *
2528
+ * @returns The result of the attention operation; shape is the same as query
2529
+ * `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
2530
+ */
2531
+ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: ArrayLike, opts?: {
2532
+ bias?: ArrayLike;
2533
+ mask?: ArrayLike;
2534
+ scale?: number;
2535
+ isCausal?: boolean;
2536
+ querySeqLengths?: ArrayLike;
2537
+ keyValueSeqLengths?: ArrayLike;
2538
+ localWindowSize?: number | [number, number];
2539
+ }): Array;
2223
2540
  declare namespace random_d_exports {
2224
- export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, normal, split, uniform };
2541
+ export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2225
2542
  }
2226
2543
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
2227
- declare function key(seed: number): Array;
2544
+ declare function key(seed: ArrayLike): Array;
2228
2545
  /** Splits a PRNG key into `num` new keys by adding a leading axis. */
2229
2546
  declare function split(key: Array, num?: number | number[]): Array;
2230
2547
  /** Sample uniform bits in the form of unsigned integers. */
@@ -2271,6 +2588,23 @@ declare const gumbel: OwnedFunction<(key: ArrayLike, shape?: number[] | undefine
2271
2588
  * Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
2272
2589
  */
2273
2590
  declare const laplace: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2591
+ /**
2592
+ * @function
2593
+ * Sample multivariate normal random values with given mean and covariance.
2594
+ *
2595
+ * The values are returned with the given shape, along with the final dimension
2596
+ * used to represent the n-dimensional multivariate normal factors.
2597
+ *
2598
+ * This uses Cholesky decomposition on the covariance matrix.
2599
+ *
2600
+ * - `key` - PRNG key
2601
+ * - `mean` - Mean vector of shape `[..., n]`
2602
+ * - `cov` - Covariance of shape `[..., n, n]`, must be positive-definite
2603
+ * - `shape` - Result batch shape, must be broadcastable with
2604
+ * `mean.shape[:-1]` and `cov.shape[:-2]`
2605
+ * @returns Random samples of shape `[...shape, n]`
2606
+ */
2607
+ declare const multivariateNormal: OwnedFunction<(key: ArrayLike, mean: ArrayLike, cov: ArrayLike, shape?: number[] | undefined) => Array>;
2274
2608
  /**
2275
2609
  * @function
2276
2610
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
@@ -2294,7 +2628,9 @@ declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
2294
2628
  * @function
2295
2629
  * Compute the forward-mode Jacobian-vector product for a function.
2296
2630
  */
2297
- declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
2631
+ declare const jvp: <F extends (...args: any[]) => JsTree<Array>, HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
2632
+ hasAux?: HA;
2633
+ }) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, Out, Aux] : never : [ReturnType<F>, ReturnType<F>];
2298
2634
  /**
2299
2635
  * @function
2300
2636
  * Vectorize an operation on a batched axis for one or more inputs.
@@ -2336,28 +2672,100 @@ declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: Ji
2336
2672
  * Produce a local linear approximation to a function at a point using jvp() and
2337
2673
  * partial evaluation.
2338
2674
  */
2339
- declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
2675
+ declare const linearize: <F extends (...args: any[]) => JsTree<Array>, HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
2676
+ hasAux?: HA;
2677
+ }) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, OwnedFunction<(...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => Out>, Aux] : never : [ReturnType<F>, OwnedFunction<(...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>>];
2340
2678
  /**
2341
2679
  * @function
2342
2680
  * Calculate the reverse-mode vector-Jacobian product for a function.
2681
+ *
2682
+ * The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
2683
+ * `f(primals)`, and `vjpFn` is a function that takes in cotangents for each
2684
+ * output and returns the cotangents for each input.
2685
+ *
2686
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
2687
+ * `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
2688
+ *
2689
+ * @example
2690
+ * ```ts
2691
+ * const [y, vjpFn] = vjp(f, [x]);
2692
+ *
2693
+ * // With hasAux
2694
+ * const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
2695
+ * ```
2343
2696
  */
2344
- declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
2697
+ declare const vjp: <F extends (...args: any[]) => JsTree<Array>, const HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
2698
+ hasAux?: HA;
2699
+ }) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, OwnedFunction<(cotangents: MapJsTree<Out, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>>, Aux] : never : [ReturnType<F>, OwnedFunction<(cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>>];
2700
+ /** @inline */
2701
+ type GradOutputType<I, F extends (...args: any[]) => any> = MapJsTree<I extends undefined ? Parameters<F>[0] : I extends number ? Parameters<F>[I] : I extends number[] ? { [K in keyof I]: I[K] extends number ? Parameters<F>[I[K]] : never } : never, ArrayLike, Array>;
2345
2702
  /**
2346
2703
  * @function
2347
2704
  * Compute the gradient of a scalar-valued function `f` with respect to its
2348
2705
  * first argument.
2706
+ *
2707
+ * Pass in different `argnums` to differentiate with respect to other
2708
+ * arguments. If a tuple is provided, the return value will be a tuple of
2709
+ * gradients corresponding to each argument index.
2710
+ *
2711
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return a
2712
+ * `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
2713
+ *
2714
+ * @example
2715
+ * ```ts
2716
+ * const gradient = grad(f)(x);
2717
+ *
2718
+ * // With `argnums`
2719
+ * const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
2720
+ *
2721
+ * // With `hasAux`
2722
+ * const [gradient, aux] = grad(f, { hasAux: true })(x);
2723
+ * ```
2349
2724
  */
2350
- declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
2725
+ declare const grad: <F extends (...args: any[]) => JsTree<Array>, const I extends undefined | number | number[] = undefined, const HA extends boolean = false>(f: F, opts?: Omit<GradOpts, "argnums" | "hasAux"> & {
2726
+ argnums?: I;
2727
+ hasAux?: HA;
2728
+ }) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => HA extends true ? ReturnType<F> extends [any, infer Aux] ? [GradOutputType<I, F>, Aux] : never : GradOutputType<I, F>;
2351
2729
  /**
2352
2730
  * @function
2353
2731
  * Create a function that evaluates both `f` and the gradient of `f`.
2732
+ *
2733
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
2734
+ * `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
2735
+ *
2736
+ * @example
2737
+ * ```ts
2738
+ * // Without hasAux
2739
+ * const [value, gradient] = valueAndGrad(f)(x);
2740
+ *
2741
+ * // With hasAux
2742
+ * const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
2743
+ * ```
2354
2744
  */
2355
- declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
2745
+ declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>, const I extends undefined | number | number[] = undefined, const HA extends boolean = false>(f: F, opts?: Omit<GradOpts, "argnums"> & {
2746
+ argnums?: I;
2747
+ hasAux?: HA;
2748
+ }) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, GradOutputType<I, F>];
2356
2749
  /**
2357
2750
  * @function
2358
2751
  * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
2359
2752
  */
2360
2753
  declare const jacrev: typeof jacfwd;
2754
+ /**
2755
+ * @function
2756
+ * Compute the Hessian matrix of a scalar-valued function.
2757
+ *
2758
+ * The Hessian is the matrix of second-order partial derivatives of a function.
2759
+ * This is implemented as `jacfwd(grad(f))`.
2760
+ *
2761
+ * @example
2762
+ * ```ts
2763
+ * const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // x^3
2764
+ * const H = hessian(f)(np.array([1, 2, 3]));
2765
+ * // H[i,j] = d^2f / dx_i dx_j
2766
+ * ```
2767
+ */
2768
+ declare const hessian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
2361
2769
  /**
2362
2770
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
2363
2771
  *
@@ -2380,4 +2788,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
2380
2788
  */
2381
2789
  declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
2382
2790
  //#endregion
2383
- export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
2791
+ export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };