@jax-js/jax 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.cts CHANGED
@@ -440,10 +440,37 @@ declare enum Routines {
440
440
  Sort = "Sort",
441
441
  /** Returns `int32` indices of the stably sorted array. */
442
442
  Argsort = "Argsort",
443
- /** Solve a triangular system of questions. */
443
+ /**
444
+ * Solve a triangular system of equations.
445
+ *
446
+ * The first batch of inputs `A` should be of shape `[..., N, N]` and upper
447
+ * triangular, while the second batch `B` should be of shape `[..., M, N]`.
448
+ *
449
+ * Solves for `X` in the equation `A @ X.T = B.T`, where `A` is the
450
+ * triangular matrix. This is equivalent to `X = B @ A^-T`.
451
+ */
444
452
  TriangularSolve = "TriangularSolve",
445
- /** Cholesky decomposition of 2D positive semi-definite matrices. */
453
+ /**
454
+ * Cholesky decomposition of 2D positive semi-definite matrices.
455
+ *
456
+ * The input batch should be of shape `[..., N, N]`, and the output batch is
457
+ * of the same shape, containing the lower-triangular matrix `L` such that
458
+ * `A = L @ L.T`. Behavior is unspecified if A is not positive semi-definite.
459
+ */
446
460
  Cholesky = "Cholesky",
461
+ /**
462
+ * LU decomposition of 2D rectangular matrices.
463
+ *
464
+ * The input is a batch of shape `[..., M, N]`, and the output is a tuple of
465
+ * three arrays: `LU, Pivots, Permutation`.
466
+ *
467
+ * - `LU` is of shape `[..., M, N]`, containing the combined lower and upper
468
+ * triangular matrices; the lower-triangular part has an implicit unit diagonal.
469
+ * - `Pivots` is of shape `[..., min(M, N)]`, containing the row swaps.
470
+ * - `Permutation` is of shape `[..., M]`, containing the permutation vector
471
+ * such that, with `P = eye(M).slice(Permutation)`, `P @ A = L @ U`.
472
+ */
473
+ LU = "LU",
447
474
  }
448
475
  interface RoutineType {
449
476
  inputShapes: number[][];
@@ -453,7 +480,7 @@ interface RoutineType {
453
480
  }
454
481
  //#endregion
455
482
  //#region src/backend.d.ts
456
- type Device = "cpu" | "wasm" | "webgpu";
483
+ type Device = "cpu" | "wasm" | "webgpu" | "webgl";
457
484
  declare const devices: Device[];
458
485
  /** Configure the default device for arrays. */
459
486
  declare function defaultDevice(device?: Device): Device;
@@ -538,7 +565,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
538
565
  */
539
566
  declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
540
567
  declare namespace numpy_linalg_d_exports {
541
- export { cholesky$1 as cholesky, diagonal, lstsq, matmul, matrixTranspose, outer, tensordot, trace, vecdot };
568
+ export { cholesky$1 as cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
542
569
  }
543
570
  /**
544
571
  * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
@@ -553,6 +580,10 @@ declare function cholesky$1(a: ArrayLike, {
553
580
  upper?: boolean;
554
581
  symmetrizeInput?: boolean;
555
582
  }): Array;
583
+ /** Compute the determinant of a square matrix (batched). */
584
+ declare function det(a: ArrayLike): Array;
585
+ /** Compute the inverse of a square matrix (batched). */
586
+ declare function inv(a: ArrayLike): Array;
556
587
  /**
557
588
  * Return the least-squares solution to a linear equation.
558
589
  *
@@ -567,8 +598,75 @@ declare function cholesky$1(a: ArrayLike, {
567
598
  * @return least-squares solution of shape `(N,)` or `(N, K)`
568
599
  */
569
600
  declare function lstsq(a: ArrayLike, b: ArrayLike): Array;
601
+ /** Raise a square matrix to an integer power, via repeated squarings. */
602
+ declare function matrixPower(a: ArrayLike, n: number): Array;
603
+ /** Return sign and natural logarithm of the determinant of `a`. */
604
+ declare function slogdet(a: ArrayLike): [Array, Array];
605
+ /**
606
+ * Solve a linear system of equations.
607
+ *
608
+ * This solves a (batched) linear system of equations `a @ x = b` for `x` given
609
+ * `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
610
+ *
611
+ * @param a - Coefficient matrix of shape `(..., N, N)`.
612
+ * @param b - Values of shape `(N,)` or `(..., N, M)`.
613
+ * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
614
+ */
615
+ declare function solve(a: ArrayLike, b: ArrayLike): Array;
616
+ //#endregion
617
+ //#region src/library/numpy/dtype-info.d.ts
618
+ /** @inline */
619
+ type FInfo = Readonly<{
620
+ /** The number of bits occupied by the type. */
621
+ bits: number;
622
+ /** Returns the _dtype_ for which finfo returns information. */
623
+ dtype: DType;
624
+ /** The difference between 1.0 and the next smallest representable float larger than 1.0. */
625
+ eps: number;
626
+ /** The difference between 1.0 and the next largest representable float smaller than 1.0. */
627
+ epsneg: number;
628
+ /** The exponent that yields `eps`. */
629
+ machep: number;
630
+ /** The largest representable finite number. */
631
+ max: number;
632
+ /** The smallest positive power of the base (2) that causes overflow. */
633
+ maxexp: number;
634
+ /** The smallest representable (most negative) finite number. */
635
+ min: number;
636
+ /** The most negative power of the base (2) without leading zeros in the mantissa. */
637
+ minexp: number;
638
+ /** The exponent that yields `epsneg`. */
639
+ negep: number;
640
+ /** Number of bits in the exponent portion. */
641
+ nexp: number;
642
+ /** Number of bits in the mantissa portion. */
643
+ nmant: number;
644
+ /** The approximate number of decimal digits to which this kind of float is precise. */
645
+ precision: number;
646
+ /** The approximate decimal resolution, i.e., `10 ** -precision`. */
647
+ resolution: number;
648
+ /** The smallest positive normal number. */
649
+ smallestNormal: number;
650
+ /** The smallest positive subnormal number. */
651
+ smallestSubnormal: number;
652
+ }>;
653
+ /** Machine limits for floating-point types. */
654
+ declare function finfo(dtype: DType): FInfo;
655
+ /** @inline */
656
+ type IInfo = Readonly<{
657
+ /** The number of bits occupied by the type. */
658
+ bits: number;
659
+ /** Returns the _dtype_ for which iinfo returns information. */
660
+ dtype: DType;
661
+ /** The largest representable integer. */
662
+ max: number;
663
+ /** The smallest representable integer. */
664
+ min: number;
665
+ }>;
666
+ /** Machine limits for integer types. */
667
+ declare function iinfo(dtype: DType): IInfo;
570
668
  declare namespace numpy_d_exports {
571
- export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sort, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
669
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
572
670
  }
573
671
  declare const float32 = DType.Float32;
574
672
  declare const int32 = DType.Int32;
@@ -735,6 +833,16 @@ declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
735
833
  declare function cumsum(a: ArrayLike, axis?: number): Array;
736
834
  /** Reverse the elements in an array along the given axes. */
737
835
  declare function flip(x: ArrayLike, axis?: Axis): Array;
836
+ /**
837
+ * Split an array into multiple sub-arrays along an axis.
838
+ *
839
+ * @param a - The input array to split.
840
+ * @param indicesOrSections - If an integer, it indicates the number of equal
841
+ * sections to create along the specified axis. If a list of integers, it
842
+ * specifies the indices at which to split the array.
843
+ * @param axis - The axis along which to split the array. Default is 0.
844
+ */
845
+ declare function split$1(a: ArrayLike, indicesOrSections: number | number[], axis?: number): Array[];
738
846
  /**
739
847
  * Join a sequence of arrays along an existing axis.
740
848
  *
@@ -778,6 +886,8 @@ declare function columnStack(xs: ArrayLike[]): Array;
778
886
  declare function flipud(x: ArrayLike): Array;
779
887
  /** Flip an array horizontally (axis=1). */
780
888
  declare function fliplr(x: ArrayLike): Array;
889
+ /** Interchange two axes of an array. */
890
+ declare function swapaxes(a: ArrayLike, axis1: number, axis2: number): Array;
781
891
  /** Transpose the last two dimensions of an array. */
782
892
  declare function matrixTranspose(a: ArrayLike): Array;
783
893
  /** Return a 1-D flattened array containing the elements of the input. */
@@ -863,6 +973,13 @@ declare function sort(a: ArrayLike, axis?: number): Array;
863
973
  * The array is sorted along a specified axis (the last by default).
864
974
  */
865
975
  declare function argsort(a: ArrayLike, axis?: number): Array;
976
+ /**
977
+ * Take elements from an array along an axis.
978
+ *
979
+ * This is equivalent to advanced indexing with integer indices over that
980
+ * numbered axis. By default, the flattened array is used.
981
+ */
982
+ declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
866
983
  /** Return if two arrays are element-wise equal within a tolerance. */
867
984
  declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
868
985
  rtol?: number;
@@ -993,6 +1110,17 @@ declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
993
1110
  declare function square(x: ArrayLike): Array;
994
1111
  /** Element-wise tangent function (takes radians). */
995
1112
  declare function tan(x: ArrayLike): Array;
1113
+ /**
1114
+ * @function
1115
+ * Return the normalized sinc function.
1116
+ *
1117
+ * The sinc function is defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`.
1118
+ * This is the normalized sinc function commonly used in signal processing.
1119
+ *
1120
+ * **Note:** JVP is not supported at x=0, where the naive derivative formula is 0/0. This
1121
+ * requires a custom JVP rule to handle properly (see JAX implementation).
1122
+ */
1123
+ declare const sinc: OwnedFunction<(x: ArrayLike) => Array>;
996
1124
  /** Element-wise inverse cosine function (inverse of cos). */
997
1125
  declare function acos(x: ArrayLike): Array;
998
1126
  /**
@@ -1022,6 +1150,20 @@ declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
1022
1150
  declare function subtract(x: ArrayLike, y: ArrayLike): Array;
1023
1151
  /** Calculates the floating-point division of x by y element-wise. */
1024
1152
  declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
1153
+ /**
1154
+ * Return the largest integer smaller or equal to the division of the inputs.
1155
+ *
1156
+ * The result is always rounded towards negative infinity.
1157
+ *
1158
+ * For floating-point inputs, this is equivalent to `floor(x / y)`.
1159
+ * For integer inputs, we use `(x - remainder(x, y)) / y` to handle
1160
+ * negative values correctly (note: may overflow near int32 boundaries).
1161
+ *
1162
+ * @param x - Dividend array.
1163
+ * @param y - Divisor array.
1164
+ * @returns Element-wise floor division of x by y.
1165
+ */
1166
+ declare function floorDivide(x: ArrayLike, y: ArrayLike): Array;
1025
1167
  /**
1026
1168
  * @function
1027
1169
  * Calculate element-wise floating-point modulo operation.
@@ -1032,6 +1174,16 @@ declare const fmod: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
1032
1174
  * Calculate element-wise remainder of the division (matches sign of y).
1033
1175
  */
1034
1176
  declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
1177
+ /**
1178
+ * Return element-wise quotient and remainder simultaneously.
1179
+ *
1180
+ * Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
1181
+ *
1182
+ * @param x - Dividend array.
1183
+ * @param y - Divisor array.
1184
+ * @returns Tuple of [quotient, remainder].
1185
+ */
1186
+ declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
1035
1187
  /** Round input to the nearest integer towards zero. */
1036
1188
  declare function trunc(x: ArrayLike): Array;
1037
1189
  /**
@@ -1142,7 +1294,11 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
1142
1294
  correction?: number;
1143
1295
  } & ReduceOpts): Array;
1144
1296
  /** Estimate the sample covariance of a set of variables. */
1145
- declare function cov(x: ArrayLike, y?: ArrayLike): Array;
1297
+ declare function cov(x: ArrayLike, y?: ArrayLike | null, {
1298
+ rowvar
1299
+ }?: {
1300
+ rowvar?: boolean;
1301
+ }): Array;
1146
1302
  /** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
1147
1303
  declare function corrcoef(x: ArrayLike, y?: ArrayLike): Array;
1148
1304
  /** Test element-wise for positive or negative infinity, return bool array. */
@@ -1356,6 +1512,8 @@ declare enum Primitive {
1356
1512
  PoolTranspose = "pool_transpose",
1357
1513
  Compare = "compare",
1358
1514
  Where = "where",
1515
+ Concatenate = "concatenate",
1516
+ Split = "split",
1359
1517
  RandomBits = "random_bits",
1360
1518
  Gather = "gather",
1361
1519
  Transpose = "transpose",
@@ -1372,6 +1530,8 @@ declare enum Primitive {
1372
1530
  // A is upper triangular, A @ X.T = B.T
1373
1531
  Cholesky = "cholesky",
1374
1532
  // A is positive-definite, A = L @ L^T
1533
+ LU = "lu",
1534
+ // LU decomposition with partial pivoting
1375
1535
  Jit = "jit",
1376
1536
  }
1377
1537
  interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
@@ -1398,6 +1558,13 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
1398
1558
  [Primitive.Compare]: {
1399
1559
  op: CompareOp;
1400
1560
  };
1561
+ [Primitive.Concatenate]: {
1562
+ axis: number;
1563
+ };
1564
+ [Primitive.Split]: {
1565
+ axis: number;
1566
+ sizes: number[];
1567
+ };
1401
1568
  [Primitive.RandomBits]: {
1402
1569
  shape: number[];
1403
1570
  mode: "xor" | 0 | 1;
@@ -1425,14 +1592,14 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
1425
1592
  [Primitive.Pad]: {
1426
1593
  width: Pair[];
1427
1594
  };
1595
+ [Primitive.TriangularSolve]: {
1596
+ unitDiagonal: boolean;
1597
+ };
1428
1598
  [Primitive.Jit]: {
1429
1599
  name: string;
1430
1600
  jaxpr: Jaxpr;
1431
1601
  numConsts: number;
1432
1602
  };
1433
- [Primitive.TriangularSolve]: {
1434
- unitDiagonal: boolean;
1435
- };
1436
1603
  }
1437
1604
  /** Type of parameters taken by each primitive. */
1438
1605
  type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
@@ -1573,6 +1740,7 @@ declare abstract class Tracer {
1573
1740
  neg(): this;
1574
1741
  add(other: this | TracerValue): this;
1575
1742
  mul(other: this | TracerValue): this;
1743
+ mod(other: this | TracerValue): this;
1576
1744
  greater(other: this | TracerValue): this;
1577
1745
  less(other: this | TracerValue): this;
1578
1746
  equal(other: this | TracerValue): this;
@@ -1675,6 +1843,7 @@ declare class ShapedArray implements AbstractValue {
1675
1843
  static fromAval(aval: AbstractValue): ShapedArray;
1676
1844
  get ndim(): number;
1677
1845
  get size(): number;
1846
+ scalar(): ShapedArray;
1678
1847
  toString(): string;
1679
1848
  equals(other: ShapedArray): boolean;
1680
1849
  }
@@ -1742,6 +1911,8 @@ declare class Array extends Tracer {
1742
1911
  toString(): string;
1743
1912
  get device(): Device;
1744
1913
  get ref(): this;
1914
+ /** Get the current reference count (for debugging memory management). */
1915
+ get refCount(): number;
1745
1916
  dispose(): void;
1746
1917
  /**
1747
1918
  * Convert this array into a primitive value.
@@ -1890,8 +2061,43 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
1890
2061
  dtype,
1891
2062
  device
1892
2063
  }?: DTypeAndDevice): Array;
2064
+ /**
2065
+ * Return numbers spaced evenly on a log scale.
2066
+ *
2067
+ * In linear space, the sequence starts at `base ** start` and ends at
2068
+ * `base ** stop` (see `endpoint` below).
2069
+ *
2070
+ * @param start - `base ** start` is the starting value of the sequence.
2071
+ * @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
2072
+ * @param num - Number of samples to generate. Default is 50.
2073
+ * @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
2074
+ * @param base - The base of the log space. Default is 10.
2075
+ * @returns Array of evenly spaced values on a log scale.
2076
+ */
2077
+ declare function logspace(start: number, stop: number, num?: number, endpoint?: boolean, base?: number, {
2078
+ dtype,
2079
+ device
2080
+ }?: DTypeAndDevice): Array;
2081
+ //#endregion
2082
+ //#region src/frontend/linearize.d.ts
2083
+ /** @inline */
2084
+ type GradOpts = {
2085
+ /**
2086
+ * Integer or sequence of integers. Specifies which positional argument(s) to
2087
+ * differentiate with respect to.
2088
+ *
2089
+ * Defaults to `0` (the first argument).
2090
+ */
2091
+ argnums?: number | number[];
2092
+ /**
2093
+ * The input function returns a pair of `[out, aux]` including an auxiliary
2094
+ * value. This `aux` is not differentiated, but is returned alongside the
2095
+ * gradient when evaluating the function.
2096
+ */
2097
+ hasAux?: boolean;
2098
+ };
1893
2099
  declare namespace lax_linalg_d_exports {
1894
- export { cholesky, triangularSolve };
2100
+ export { cholesky, lu, triangularSolve };
1895
2101
  }
1896
2102
  /**
1897
2103
  * Compute the Cholesky decomposition of a symmetric positive-definite matrix.
@@ -1924,6 +2130,32 @@ declare function cholesky(a: ArrayLike, {
1924
2130
  }?: {
1925
2131
  upper?: boolean;
1926
2132
  }): Array;
2133
+ /**
2134
+ * LU decomposition with partial pivoting.
2135
+ *
2136
+ * Computes the matrix decomposition: `P @ A = L @ U`, where `P` is a
2137
+ * permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
2138
+ * and `U` is upper-triangular.
2139
+ *
2140
+ * @param x - A batch of matrices with shape `[..., m, n]`.
2141
+ *
2142
+ * @returns A tuple `(lu, pivots, permutation)` where:
2143
+ * - `lu`: combined lower and upper triangular matrices.
2144
+ * - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
2145
+ * - `permutation`: the permutation generated by pivots with shape `[..., m]`.
2146
+ *
2147
+ * @example
2148
+ * ```ts
2149
+ * import { lax, numpy as np } from "@jax-js/jax";
2150
+ *
2151
+ * const A = np.array([[4., 3.], [6., 3.]]);
2152
+ * const [lu, pivots, permutation] = lax.linalg.lu(A);
2153
+ * // lu ≈ [[6., 3.], [0.6666667, 1.0]]
2154
+ * // pivots = [1, 1]
2155
+ * // permutation = [1, 0]
2156
+ * ```
2157
+ */
2158
+ declare function lu(x: ArrayLike): [Array, Array, Array];
1927
2159
  /**
1928
2160
  * Solve a triangular linear system.
1929
2161
  *
@@ -1954,7 +2186,7 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
1954
2186
  unitDiagonal?: boolean;
1955
2187
  }): Array;
1956
2188
  declare namespace lax_d_exports {
1957
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
2189
+ export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
1958
2190
  }
1959
2191
  /**
1960
2192
  * Dimension numbers for general `dot()` primitive.
@@ -1992,7 +2224,11 @@ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
1992
2224
  * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
1993
2225
  * function in JAX, which wraps XLA's general convolution operator.
1994
2226
  *
1995
- * Grouped convolutions are not supported right now.
2227
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
2228
+ * @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
2229
+ * @param windowStrides - Strides for each spatial dimension
2230
+ * @param padding - Padding for each spatial dimension, or a string
2231
+ * (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
1996
2232
  */
1997
2233
  declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
1998
2234
  lhsDilation,
@@ -2007,6 +2243,37 @@ declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: numbe
2007
2243
  declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
2008
2244
  /** Convenience wrapper around `convGeneralDilated`. */
2009
2245
  declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
2246
+ /**
2247
+ * Convenience wrapper for calculating the N-d convolution "transpose".
2248
+ *
2249
+ * This function directly calculates a fractionally strided conv rather than
2250
+ * indirectly calculating the gradient (transpose) of a forward convolution.
2251
+ * It is equivalent to the JAX version, except:
2252
+ *
2253
+ * - The `use_consistent_padding` option is not available. We only have the
2254
+ * consistent padding case (JAX version >0.8.4).
2255
+ * - The order of dimensions matches `lax.conv_general_dilated`.
2256
+ *
2257
+ * Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
2258
+ * dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
2259
+ * `transposeKernel` to true.
2260
+ *
2261
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
2262
+ * @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
2263
+ * @param strides - Sequence of n integers, sets fractional stride
2264
+ * @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
2265
+ * each side of the input, so it acts like gradient of `conv()`
2266
+ * @param rhsDilation - Atrous dilation for the kernel
2267
+ * @param transposeKernel - Flip spatial axes and swap the input/output channels
2268
+ * of the kernel; its shape should be `[C_in, C_out, ...ks]`
2269
+ */
2270
+ declare function convTranspose(lhs: Array, rhs: Array, strides: number[], padding: PaddingType, {
2271
+ rhsDilation,
2272
+ transposeKernel
2273
+ }?: {
2274
+ rhsDilation?: number[];
2275
+ transposeKernel?: boolean;
2276
+ }): Array;
2010
2277
  /** Reduce a computation over padded windows. */
2011
2278
  declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
2012
2279
  /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
@@ -2026,7 +2293,7 @@ declare function erfc(x: ArrayLike): Array;
2026
2293
  */
2027
2294
  declare function stopGradient(x: ArrayLike): Array;
2028
2295
  declare namespace nn_d_exports {
2029
- export { celu, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
2296
+ export { celu, dotProductAttention, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
2030
2297
  }
2031
2298
  /**
2032
2299
  * Rectified Linear Unit (ReLU) activation function:
@@ -2223,11 +2490,61 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
2223
2490
  * ```
2224
2491
  */
2225
2492
  declare function oneHot(x: Array, numClasses: number): Array;
2493
+ /**
2494
+ * Scaled dot product attention (SDPA).
2495
+ *
2496
+ * Computes `softmax((Q @ K^T) / sqrt(d) + bias) @ V`, where `Q` is the query,
2497
+ * `K` is the key, `V` is the value, and `d` is the dimensionality of each key
2498
+ * and query vector.
2499
+ *
2500
+ * Multi-query attention is applied when input `key` and `value` tensors have
2501
+ * fewer heads than `query`.
2502
+ *
2503
+ * We use the following uppercase letters to denote array shapes:
2504
+ * - `B` = batch size
2505
+ * - `S` = length of key/value sequences (source)
2506
+ * - `L` = length of query sequences
2507
+ * - `N` = number of attention heads
2508
+ * - `H` = dimensionality of each attention head
2509
+ * - `K` = number of key/value heads (for grouped-query attention)
2510
+ *
2511
+ * The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
2512
+ * case it must be omitted from all inputs.
2513
+ *
2514
+ * @param query - Query array; shape `[B, L, N, H]`
2515
+ * @param key - Key array; shape `[B, S, K, H]`
2516
+ * @param value - Value array; same shape as `key`
2517
+ * @param opts.bias - Optional bias to add to the attention logits; shape
2518
+ * `[B, N, L, S]` or broadcastable to it.
2519
+ * @param opts.mask - Optional mask to apply to the attention logits; should be
2520
+ * a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
2521
+ * the element should take part in attention.
2522
+ * @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
2523
+ * @param opts.isCausal - If true, applies a causal mask.
2524
+ * @param opts.querySeqLengths - Optional sequence lengths for the queries;
2525
+ * shape `(B,)`. Taken from the beginning of the tensor.
2526
+ * @param opts.keyValueSeqLengths - Optional sequence lengths for the keys and
2527
+ * values; shape `(B,)`. Taken from the beginning of the tensor.
2528
+ * @param opts.localWindowSize - If specified, applies a local attention window
2529
+ * of the given size. Can be a single number or a tuple `[left, right]`.
2530
+ *
2531
+ * @returns The result of the attention operation; shape is the same as query
2532
+ * `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
2533
+ */
2534
+ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: ArrayLike, opts?: {
2535
+ bias?: ArrayLike;
2536
+ mask?: ArrayLike;
2537
+ scale?: number;
2538
+ isCausal?: boolean;
2539
+ querySeqLengths?: ArrayLike;
2540
+ keyValueSeqLengths?: ArrayLike;
2541
+ localWindowSize?: number | [number, number];
2542
+ }): Array;
2226
2543
  declare namespace random_d_exports {
2227
- export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, normal, split, uniform };
2544
+ export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2228
2545
  }
2229
2546
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
2230
- declare function key(seed: number): Array;
2547
+ declare function key(seed: ArrayLike): Array;
2231
2548
  /** Splits a PRNG key into `num` new keys by adding a leading axis. */
2232
2549
  declare function split(key: Array, num?: number | number[]): Array;
2233
2550
  /** Sample uniform bits in the form of unsigned integers. */
@@ -2274,6 +2591,23 @@ declare const gumbel: OwnedFunction<(key: ArrayLike, shape?: number[] | undefine
2274
2591
  * Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
2275
2592
  */
2276
2593
  declare const laplace: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2594
+ /**
2595
+ * @function
2596
+ * Sample multivariate normal random values with given mean and covariance.
2597
+ *
2598
+ * The values are returned with the given shape, along with the final dimension
2599
+ * used to represent the n-dimensional multivariate normal factors.
2600
+ *
2601
+ * This uses Cholesky decomposition on the covariance matrix.
2602
+ *
2603
+ * - `key` - PRNG key
2604
+ * - `mean` - Mean vector of shape `[..., n]`
2605
+ * - `cov` - Covariance of shape `[..., n, n]`, must be positive-definite
2606
+ * - `shape` - Result batch shape, must be broadcastable with
2607
+ * `mean.shape[:-1]` and `cov.shape[:-2]`
2608
+ * @returns Random samples of shape `[...shape, n]`
2609
+ */
2610
+ declare const multivariateNormal: OwnedFunction<(key: ArrayLike, mean: ArrayLike, cov: ArrayLike, shape?: number[] | undefined) => Array>;
2277
2611
  /**
2278
2612
  * @function
2279
2613
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
@@ -2297,7 +2631,9 @@ declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
2297
2631
  * @function
2298
2632
  * Compute the forward-mode Jacobian-vector product for a function.
2299
2633
  */
2300
- declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
2634
+ declare const jvp: <F extends (...args: any[]) => JsTree<Array>, HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
2635
+ hasAux?: HA;
2636
+ }) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, Out, Aux] : never : [ReturnType<F>, ReturnType<F>];
2301
2637
  /**
2302
2638
  * @function
2303
2639
  * Vectorize an operation on a batched axis for one or more inputs.
@@ -2339,28 +2675,100 @@ declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: Ji
2339
2675
  * Produce a local linear approximation to a function at a point using jvp() and
2340
2676
  * partial evaluation.
2341
2677
  */
2342
- declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
2678
+ declare const linearize: <F extends (...args: any[]) => JsTree<Array>, HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
2679
+ hasAux?: HA;
2680
+ }) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, OwnedFunction<(...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => Out>, Aux] : never : [ReturnType<F>, OwnedFunction<(...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>>];
2343
2681
  /**
2344
2682
  * @function
2345
2683
  * Calculate the reverse-mode vector-Jacobian product for a function.
2684
+ *
2685
+ * The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
2686
+ * `f(primals)`, and `vjpFn` is a function that takes in cotangents for each
2687
+ * output and returns the cotangents for each input.
2688
+ *
2689
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
2690
+ * `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
2691
+ *
2692
+ * @example
2693
+ * ```ts
2694
+ * const [y, vjpFn] = vjp(f, [x]);
2695
+ *
2696
+ * // With hasAux
2697
+ * const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
2698
+ * ```
2346
2699
  */
2347
- declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
2700
+ declare const vjp: <F extends (...args: any[]) => JsTree<Array>, const HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
2701
+ hasAux?: HA;
2702
+ }) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, OwnedFunction<(cotangents: MapJsTree<Out, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>>, Aux] : never : [ReturnType<F>, OwnedFunction<(cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>>];
2703
+ /** @inline */
2704
+ type GradOutputType<I, F extends (...args: any[]) => any> = MapJsTree<I extends undefined ? Parameters<F>[0] : I extends number ? Parameters<F>[I] : I extends number[] ? { [K in keyof I]: I[K] extends number ? Parameters<F>[I[K]] : never } : never, ArrayLike, Array>;
2348
2705
  /**
2349
2706
  * @function
2350
2707
  * Compute the gradient of a scalar-valued function `f` with respect to its
2351
2708
  * first argument.
2709
+ *
2710
+ * Pass in different `argnums` to differentiate with respect to other
2711
+ * arguments. If a tuple is provided, the return value will be a tuple of
2712
+ * gradients corresponding to each argument index.
2713
+ *
2714
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return a
2715
+ * `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
2716
+ *
2717
+ * @example
2718
+ * ```ts
2719
+ * const gradient = grad(f)(x);
2720
+ *
2721
+ * // With `argnums`
2722
+ * const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
2723
+ *
2724
+ * // With `hasAux`
2725
+ * const [gradient, aux] = grad(f, { hasAux: true })(x);
2726
+ * ```
2352
2727
  */
2353
- declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
2728
+ declare const grad: <F extends (...args: any[]) => JsTree<Array>, const I extends undefined | number | number[] = undefined, const HA extends boolean = false>(f: F, opts?: Omit<GradOpts, "argnums" | "hasAux"> & {
2729
+ argnums?: I;
2730
+ hasAux?: HA;
2731
+ }) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => HA extends true ? ReturnType<F> extends [any, infer Aux] ? [GradOutputType<I, F>, Aux] : never : GradOutputType<I, F>;
2354
2732
  /**
2355
2733
  * @function
2356
2734
  * Create a function that evaluates both `f` and the gradient of `f`.
2735
+ *
2736
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
2737
+ * `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
2738
+ *
2739
+ * @example
2740
+ * ```ts
2741
+ * // Without hasAux
2742
+ * const [value, gradient] = valueAndGrad(f)(x);
2743
+ *
2744
+ * // With hasAux
2745
+ * const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
2746
+ * ```
2357
2747
  */
2358
- declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
2748
+ declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>, const I extends undefined | number | number[] = undefined, const HA extends boolean = false>(f: F, opts?: Omit<GradOpts, "argnums"> & {
2749
+ argnums?: I;
2750
+ hasAux?: HA;
2751
+ }) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, GradOutputType<I, F>];
2359
2752
  /**
2360
2753
  * @function
2361
2754
  * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
2362
2755
  */
2363
2756
  declare const jacrev: typeof jacfwd;
2757
+ /**
2758
+ * @function
2759
+ * Compute the Hessian matrix of a scalar-valued function.
2760
+ *
2761
+ * The Hessian is the matrix of second-order partial derivatives of a function.
2762
+ * This is implemented as `jacfwd(grad(f))`.
2763
+ *
2764
+ * @example
2765
+ * ```ts
2766
+ * const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // x^3
2767
+ * const H = hessian(f)(np.array([1, 2, 3]));
2768
+ * // H[i,j] = d^2f / dx_i dx_j
2769
+ * ```
2770
+ */
2771
+ declare const hessian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
2364
2772
  /**
2365
2773
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
2366
2774
  *
@@ -2383,4 +2791,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
2383
2791
  */
2384
2792
  declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
2385
2793
  //#endregion
2386
- export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
2794
+ export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };