@jax-js/jax 0.1.8 → 0.1.10

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
package/dist/index.d.ts CHANGED
@@ -433,9 +433,14 @@ declare class Routine {
433
433
  }
434
434
  /** One of the valid `Routine` that can be dispatched to backend. */
435
435
  declare enum Routines {
436
- /** Stable sorting algorithm along the last axis. */
436
+ /**
437
+ * Sort along the last axis.
438
+ *
439
+ * This may be _unstable_ but it often doesn't matter, sorting numbers is
440
+ * bitwise unique up to signed zeros and NaNs.
441
+ */
437
442
  Sort = "Sort",
438
- /** Returns `int32` indices of the stably sorted array. */
443
+ /** Stable sorting, returns `int32` indices and values of the sorted array. */
439
444
  Argsort = "Argsort",
440
445
  /**
441
446
  * Solve a triangular system of equations.
@@ -747,9 +752,9 @@ declare enum Primitive {
747
752
  Shrink = "shrink",
748
753
  Pad = "pad",
749
754
  Sort = "sort",
750
- // sort(x, axis=-1)
755
+ // sort(x, axis=-1), unstable
751
756
  Argsort = "argsort",
752
- // argsort(x, axis=-1)
757
+ // argsort(x, axis=-1), stable
753
758
  TriangularSolve = "triangular_solve",
754
759
  // A is upper triangular, A @ X.T = B.T
755
760
  Cholesky = "cholesky",
@@ -996,6 +1001,8 @@ declare abstract class Tracer {
996
1001
  reshape(shape: number | number[]): this;
997
1002
  /** Copy the array and cast to a specified dtype. */
998
1003
  astype(dtype: DType): this;
1004
+ /** Return a bitwise cast of the array, viewed as a new dtype. */
1005
+ view(dtype?: DType): this;
999
1006
  /** Subtract an array from this one. */
1000
1007
  sub(other: this | TracerValue): this;
1001
1008
  /** Divide an array by this one. */
@@ -1026,8 +1033,9 @@ declare abstract class Tracer {
1026
1033
  */
1027
1034
  sort(axis?: number): this;
1028
1035
  /**
1029
- * Return the indices that would sort an array. This may not be a stable
1030
- * sorting algorithm; it need not preserve order of indices in ties.
1036
+ * Return the indices that would sort an array. Unlike `sort`, this is
1037
+ * guaranteed to be a stable sorting algorithm; it always returns the smaller
1038
+ * index first in event of ties.
1031
1039
  *
1032
1040
  * See `jax.numpy.argsort` for full docs.
1033
1041
  */
@@ -1109,6 +1117,12 @@ type DTypeAndDevice = {
1109
1117
  dtype?: DType;
1110
1118
  device?: Device;
1111
1119
  };
1120
+ /** @inline */
1121
+ type DTypeShapeAndDevice = {
1122
+ dtype?: DType;
1123
+ shape?: number[];
1124
+ device?: Device;
1125
+ };
1112
1126
  type ArrayConstructorArgs = {
1113
1127
  source: AluExp | Slot;
1114
1128
  st: ShapeTracker;
@@ -1218,15 +1232,9 @@ declare function array(values: Array | DataArray | RecursiveArray<number> | Recu
1218
1232
  type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
1219
1233
  declare const implRules: { [P in Primitive]: ImplRule<P> };
1220
1234
  /** Return a new array of given shape and type, filled with zeros. */
1221
- declare function zeros(shape: number[], {
1222
- dtype,
1223
- device
1224
- }?: DTypeAndDevice): Array;
1235
+ declare function zeros(shape: number[], opts?: DTypeAndDevice): Array;
1225
1236
  /** Return a new array of given shape and type, filled with ones. */
1226
- declare function ones(shape: number[], {
1227
- dtype,
1228
- device
1229
- }?: DTypeAndDevice): Array;
1237
+ declare function ones(shape: number[], opts?: DTypeAndDevice): Array;
1230
1238
  /** Return a new array of given shape and type, filled with `fill_value`. */
1231
1239
  declare function full(shape: number[], fillValue: number | boolean | Array, {
1232
1240
  dtype,
@@ -1418,8 +1426,10 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
1418
1426
  unitDiagonal?: boolean;
1419
1427
  }): Array;
1420
1428
  declare namespace lax_d_exports {
1421
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
1429
+ export { DotDimensionNumbers, PaddingType, bitcastConvertType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
1422
1430
  }
1431
+ /** Elementwise bitcast an array into a new dtype. */
1432
+ declare function bitcastConvertType(x: ArrayLike, newDtype: DType): Array;
1423
1433
  /**
1424
1434
  * Dimension numbers for general `dot()` primitive.
1425
1435
  *
@@ -1524,6 +1534,16 @@ declare function erfc(x: ArrayLike): Array;
1524
1534
  * forward or reverse-mode automatic differentiation.
1525
1535
  */
1526
1536
  declare function stopGradient(x: ArrayLike): Array;
1537
+ /**
1538
+ * Returns top `k` values and their indices along the specified axis of operand.
1539
+ *
1540
+ * This is a _stable_ algorithm: If two elements are equal, the lower-index
1541
+ * element appears first.
1542
+ *
1543
+ * @returns A tuple of `(values, indices)`, where `values` and `indices` have
1544
+ * the same shape as `x`, except along `axis` where they have size `k`.
1545
+ */
1546
+ declare function topK(x: ArrayLike, k: number, axis?: number): [Array, Array];
1527
1547
  declare namespace numpy_fft_d_exports {
1528
1548
  export { ComplexPair, fft, ifft };
1529
1549
  }
@@ -1548,7 +1568,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
1548
1568
  */
1549
1569
  declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
1550
1570
  declare namespace numpy_linalg_d_exports {
1551
- export { cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
1571
+ export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
1552
1572
  }
1553
1573
  /**
1554
1574
  * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
@@ -1563,6 +1583,13 @@ declare function cholesky(a: ArrayLike, {
1563
1583
  upper?: boolean;
1564
1584
  symmetrizeInput?: boolean;
1565
1585
  }): Array;
1586
+ /**
1587
+ * Compute the cross-product of two 3D vectors.
1588
+ *
1589
+ * This is a simpler and less flexible version of `jax.numpy.cross()`.
1590
+ * Both inputs must have size 3 along the specified axis.
1591
+ */
1592
+ declare function cross$1(x1: ArrayLike, x2: ArrayLike, axis?: number): Array;
1566
1593
  /** Compute the determinant of a square matrix (batched). */
1567
1594
  declare function det(a: ArrayLike): Array;
1568
1595
  /** Compute the inverse of a square matrix (batched). */
@@ -1649,7 +1676,7 @@ type IInfo = Readonly<{
1649
1676
  /** Machine limits for integer types. */
1650
1677
  declare function iinfo(dtype: DType): IInfo;
1651
1678
  declare namespace numpy_d_exports {
1652
- export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1679
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
1653
1680
  }
1654
1681
  declare const float32 = DType.Float32;
1655
1682
  declare const int32 = DType.Int32;
@@ -1709,6 +1736,14 @@ declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
1709
1736
  declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
1710
1737
  /** @function Compare two arrays element-wise. */
1711
1738
  declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
1739
+ /** Compute element-wise logical AND. */
1740
+ declare function logicalAnd(x: ArrayLike, y: ArrayLike): Array;
1741
+ /** Compute element-wise logical OR. */
1742
+ declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
1743
+ /** Compute element-wise logical XOR. */
1744
+ declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
1745
+ /** Compute element-wise logical NOT. */
1746
+ declare function logicalNot(x: ArrayLike): Array;
1712
1747
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1713
1748
  declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
1714
1749
  /**
@@ -1749,17 +1784,17 @@ declare const shape$1: (x: ArrayLike) => number[];
1749
1784
  * @function
1750
1785
  * Return an array of zeros with the same shape and type as a given array.
1751
1786
  */
1752
- declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
1787
+ declare const zerosLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
1753
1788
  /**
1754
1789
  * @function
1755
1790
  * Return an array of ones with the same shape and type as a given array.
1756
1791
  */
1757
- declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
1792
+ declare const onesLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
1758
1793
  /**
1759
1794
  * @function
1760
1795
  * Return a full array with the same shape and type as a given array.
1761
1796
  */
1762
- declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
1797
+ declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, opts?: DTypeShapeAndDevice) => Array;
1763
1798
  /**
1764
1799
  * Return the number of elements in an array, optionally along an axis.
1765
1800
  * Does not consume array reference.
@@ -1793,6 +1828,16 @@ declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1793
1828
  declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1794
1829
  /** Compute the average of the array elements along the specified axis. */
1795
1830
  declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1831
+ /**
1832
+ * Compute the weighted average along the specified axis.
1833
+ *
1834
+ * If no axis is specified, mean is computed along all the axes. The weights
1835
+ * should have shape matching that of `a`, or if an axis is specified, it should
1836
+ * match the shape along those axes.
1837
+ */
1838
+ declare function average(a: ArrayLike, axis?: Axis, opts?: {
1839
+ weights?: ArrayLike;
1840
+ } & ReduceOpts): Array;
1796
1841
  /**
1797
1842
  * Returns the indices of the minimum values along an axis.
1798
1843
  *
@@ -1948,8 +1993,9 @@ declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: nu
1948
1993
  */
1949
1994
  declare function sort(a: ArrayLike, axis?: number): Array;
1950
1995
  /**
1951
- * Return indices that would sort an array. This may be an unstable sorting
1952
- * algorithm; it need not preserve order of indices in ties.
1996
+ * Return indices that would sort an array. Unlike `sort`, this is guaranteed to
1997
+ * be a stable sorting algorithm; it always returns the smaller index first in
1998
+ * event of ties.
1953
1999
  *
1954
2000
  * Returns an array of `int32` indices.
1955
2001
  *
@@ -1963,13 +2009,39 @@ declare function argsort(a: ArrayLike, axis?: number): Array;
1963
2009
  * numbered axis. By default, the flattened array is used.
1964
2010
  */
1965
2011
  declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
1966
- /** Return if two arrays are element-wise equal within a tolerance. */
2012
+ /**
2013
+ * Return if two arrays are element-wise equal within a tolerance.
2014
+ *
2015
+ * The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
2016
+ * NaN values comparing equal if `equalNaN` is true.
2017
+ */
1967
2018
  declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
1968
2019
  rtol?: number;
1969
2020
  atol?: number;
2021
+ equalNaN?: boolean;
1970
2022
  }): boolean;
2023
+ /**
2024
+ * Check if two arrays are element-wise equal.
2025
+ *
2026
+ * Returns False if the arrays have different shapes. If `equalNaN` is True,
2027
+ * NaNs in the same position are considered equal.
2028
+ */
2029
+ declare function arrayEqual(a1: ArrayLike, a2: ArrayLike, opts?: {
2030
+ equalNaN?: boolean;
2031
+ }): Array;
2032
+ /**
2033
+ * Check if two arrays are element-wise equal after broadcasting.
2034
+ *
2035
+ * Unlike `arrayEqual`, this allows inputs with different but
2036
+ * broadcast-compatible shapes.
2037
+ */
2038
+ declare function arrayEquiv(a1: ArrayLike, a2: ArrayLike): Array;
1971
2039
  /** Matrix product of two arrays. */
1972
2040
  declare function matmul(x: ArrayLike, y: ArrayLike): Array;
2041
+ /** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
2042
+ declare function matvec(x1: ArrayLike, x2: ArrayLike): Array;
2043
+ /** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
2044
+ declare function vecmat(x1: ArrayLike, x2: ArrayLike): Array;
1973
2045
  /** Dot product of two arrays. */
1974
2046
  declare function dot(x: ArrayLike, y: ArrayLike): Array;
1975
2047
  /**
@@ -2022,6 +2094,18 @@ declare function inner(x: ArrayLike, y: ArrayLike): Array;
2022
2094
  * be of shape `[x.size, y.size]`.
2023
2095
  */
2024
2096
  declare function outer(x: ArrayLike, y: ArrayLike): Array;
2097
+ /**
2098
+ * @function Compute the cross product of two arrays.
2099
+ *
2100
+ * Supports 2D (scalar result) and 3D cross products, with optional axis
2101
+ * arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
2102
+ */
2103
+ declare const cross: OwnedFunction<(a: ArrayLike, b: ArrayLike, args_2?: {
2104
+ axisa?: number | undefined;
2105
+ axisb?: number | undefined;
2106
+ axisc?: number | undefined;
2107
+ axis?: number | undefined;
2108
+ } | undefined) => Array>;
2025
2109
  /** Vector dot product of two arrays along a given axis. */
2026
2110
  declare function vecdot(x: ArrayLike, y: ArrayLike, {
2027
2111
  axis
@@ -2067,14 +2151,13 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
2067
2151
  declare function absolute(x: ArrayLike): Array;
2068
2152
  /** Return an element-wise indication of sign of the input. */
2069
2153
  declare function sign(x: ArrayLike): Array;
2070
- /** @function Return element-wise positive values of the input (no-op). */
2071
- declare const positive: (x: ArrayLike) => Array;
2072
2154
  /**
2073
- * Return the Hamming window of size M, a taper with a weighted cosine bell.
2074
- *
2075
- * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
2155
+ * @function
2156
+ * Return the value with the magnitude of x and the sign of y, element-wise.
2076
2157
  */
2077
- declare function hamming(M: number): Array;
2158
+ declare const copysign: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
2159
+ /** @function Return element-wise positive values of the input (no-op). */
2160
+ declare const positive: (x: ArrayLike) => Array;
2078
2161
  /**
2079
2162
  * Return the Hann window of size M, a taper with a weighted cosine bell.
2080
2163
  *
@@ -2169,6 +2252,18 @@ declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
2169
2252
  declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
2170
2253
  /** Round input to the nearest integer towards zero. */
2171
2254
  declare function trunc(x: ArrayLike): Array;
2255
+ /**
2256
+ * @function
2257
+ * Round to the given number of decimals.
2258
+ *
2259
+ * Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
2260
+ */
2261
+ declare const round: OwnedFunction<(a: ArrayLike, decimals?: number | undefined) => Array>;
2262
+ /**
2263
+ * @function
2264
+ * Round to the nearest integer, with ties going to the nearest even integer.
2265
+ */
2266
+ declare const rint: OwnedFunction<(x: ArrayLike) => Array>;
2172
2267
  /**
2173
2268
  * Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
2174
2269
  *
@@ -2561,7 +2656,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
2561
2656
  localWindowSize?: number | [number, number];
2562
2657
  }): Array;
2563
2658
  declare namespace random_d_exports {
2564
- export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2659
+ export { bernoulli, bits, categorical, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2565
2660
  }
2566
2661
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
2567
2662
  declare function key(seed: ArrayLike): Array;
@@ -2584,6 +2679,32 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
2584
2679
  * and must be broadcastable to `shape`.
2585
2680
  */
2586
2681
  declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
2682
+ /**
2683
+ * @function
2684
+ * Sample random values from categorical distributions.
2685
+ *
2686
+ * Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
2687
+ * trick for sampling without replacement.
2688
+ *
2689
+ * Note: Sampling without replacement currently uses argsort and slices the last
2690
+ * k elements. This should be replaced with a more efficient topK implementation.
2691
+ *
2692
+ * - `key` - PRNG key
2693
+ * - `logits` - Unnormalized log probabilities of the categorical distribution(s).
2694
+ * `softmax(logits, axis)` gives the corresponding probabilities.
2695
+ * - `axis` - Axis along which logits belong to the same categorical distribution.
2696
+ * - `shape` - Result batch shape. Must be broadcast-compatible with
2697
+ * `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
2698
+ * - `replace` - If true (default), sample with replacement. If false, sample
2699
+ * without replacement (each category can only be selected once per batch).
2700
+ * @returns A random array with int dtype and shape given by `shape` if provided,
2701
+ * otherwise `logits.shape` with `axis` removed.
2702
+ */
2703
+ declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, args_2?: {
2704
+ axis?: number | undefined;
2705
+ shape?: number[] | undefined;
2706
+ replace?: boolean | undefined;
2707
+ } | undefined) => Array>;
2587
2708
  /**
2588
2709
  * @function
2589
2710
  * Sample from a Cauchy distribution with location 0 and scale 1.
@@ -2645,8 +2766,31 @@ declare namespace scipy_special_d_exports {
2645
2766
  * The logit function, `logit(p) = log(p / (1-p))`.
2646
2767
  */
2647
2768
  declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
2769
+ //#endregion
2770
+ //#region src/tracing.d.ts
2771
+ /**
2772
+ * Start collecting kernel traces.
2773
+ *
2774
+ * Traces appear in developer tools under the "Performance" tab, and they are
2775
+ * useful for measuring fine-grained kernel execution time.
2776
+ */
2777
+ declare function startTrace(): void;
2778
+ /**
2779
+ * Stop collecting kernel traces.
2780
+ *
2781
+ * Traces appear in developer tools under the "Performance" tab, and they are
2782
+ * useful for measuring fine-grained kernel execution time.
2783
+ */
2784
+ declare function stopTrace(): void;
2785
+ /** Check if tracing is currently enabled. */
2786
+
2648
2787
  //#endregion
2649
2788
  //#region src/index.d.ts
2789
+ /** @namespace */
2790
+ declare const profiler: {
2791
+ startTrace: typeof startTrace;
2792
+ stopTrace: typeof stopTrace;
2793
+ };
2650
2794
  /**
2651
2795
  * @function
2652
2796
  * Compute the forward-mode Jacobian-vector product for a function.
@@ -2811,4 +2955,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
2811
2955
  */
2812
2956
  declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
2813
2957
  //#endregion
2814
- export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
2958
+ export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };