@jax-js/jax 0.1.8 → 0.1.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.cts CHANGED
@@ -436,9 +436,14 @@ declare class Routine {
436
436
  }
437
437
  /** One of the valid `Routine` that can be dispatched to backend. */
438
438
  declare enum Routines {
439
- /** Stable sorting algorithm along the last axis. */
439
+ /**
440
+ * Sort along the last axis.
441
+ *
442
+ * This may be _unstable_, but that rarely matters for numbers: elements that
443
+ * tie in sort order are bitwise identical, except for signed zeros and NaNs.
444
+ */
440
445
  Sort = "Sort",
441
- /** Returns `int32` indices of the stably sorted array. */
446
+ /** Stable sort; returns `int32` indices and values of the sorted array. */
442
447
  Argsort = "Argsort",
443
448
  /**
444
449
  * Solve a triangular system of equations.
@@ -750,9 +755,9 @@ declare enum Primitive {
750
755
  Shrink = "shrink",
751
756
  Pad = "pad",
752
757
  Sort = "sort",
753
- // sort(x, axis=-1)
758
+ // sort(x, axis=-1), unstable
754
759
  Argsort = "argsort",
755
- // argsort(x, axis=-1)
760
+ // argsort(x, axis=-1), stable
756
761
  TriangularSolve = "triangular_solve",
757
762
  // A is upper triangular, A @ X.T = B.T
758
763
  Cholesky = "cholesky",
@@ -999,6 +1004,8 @@ declare abstract class Tracer {
999
1004
  reshape(shape: number | number[]): this;
1000
1005
  /** Copy the array and cast to a specified dtype. */
1001
1006
  astype(dtype: DType): this;
1007
+ /** Return a bitwise cast of the array, viewed as a new dtype. */
1008
+ view(dtype?: DType): this;
1002
1009
  /** Subtract an array from this one. */
1003
1010
  sub(other: this | TracerValue): this;
1004
1011
  /** Divide an array by this one. */
@@ -1029,8 +1036,9 @@ declare abstract class Tracer {
1029
1036
  */
1030
1037
  sort(axis?: number): this;
1031
1038
  /**
1032
- * Return the indices that would sort an array. This may not be a stable
1033
- * sorting algorithm; it need not preserve order of indices in ties.
1039
+ * Return the indices that would sort an array. Unlike `sort`, this is
1040
+ * guaranteed to be a stable sorting algorithm; it always returns the smaller
1041
+ * index first in the event of ties.
1034
1042
  *
1035
1043
  * See `jax.numpy.argsort` for full docs.
1036
1044
  */
@@ -1112,6 +1120,12 @@ type DTypeAndDevice = {
1112
1120
  dtype?: DType;
1113
1121
  device?: Device;
1114
1122
  };
1123
+ /** @inline */
1124
+ type DTypeShapeAndDevice = {
1125
+ dtype?: DType;
1126
+ shape?: number[];
1127
+ device?: Device;
1128
+ };
1115
1129
  type ArrayConstructorArgs = {
1116
1130
  source: AluExp | Slot;
1117
1131
  st: ShapeTracker;
@@ -1221,15 +1235,9 @@ declare function array(values: Array | DataArray | RecursiveArray<number> | Recu
1221
1235
  type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
1222
1236
  declare const implRules: { [P in Primitive]: ImplRule<P> };
1223
1237
  /** Return a new array of given shape and type, filled with zeros. */
1224
- declare function zeros(shape: number[], {
1225
- dtype,
1226
- device
1227
- }?: DTypeAndDevice): Array;
1238
+ declare function zeros(shape: number[], opts?: DTypeAndDevice): Array;
1228
1239
  /** Return a new array of given shape and type, filled with ones. */
1229
- declare function ones(shape: number[], {
1230
- dtype,
1231
- device
1232
- }?: DTypeAndDevice): Array;
1240
+ declare function ones(shape: number[], opts?: DTypeAndDevice): Array;
1233
1241
  /** Return a new array of given shape and type, filled with `fill_value`. */
1234
1242
  declare function full(shape: number[], fillValue: number | boolean | Array, {
1235
1243
  dtype,
@@ -1421,8 +1429,10 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
1421
1429
  unitDiagonal?: boolean;
1422
1430
  }): Array;
1423
1431
  declare namespace lax_d_exports {
1424
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
1432
+ export { DotDimensionNumbers, PaddingType, bitcastConvertType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
1425
1433
  }
1434
+ /** Elementwise bitcast an array into a new dtype. */
1435
+ declare function bitcastConvertType(x: ArrayLike, newDtype: DType): Array;
1426
1436
  /**
1427
1437
  * Dimension numbers for general `dot()` primitive.
1428
1438
  *
@@ -1527,6 +1537,16 @@ declare function erfc(x: ArrayLike): Array;
1527
1537
  * forward or reverse-mode automatic differentiation.
1528
1538
  */
1529
1539
  declare function stopGradient(x: ArrayLike): Array;
1540
+ /**
1541
+ * Returns the top `k` values and their indices along the specified `axis` of `x`.
1542
+ *
1543
+ * This is a _stable_ algorithm: If two elements are equal, the lower-index
1544
+ * element appears first.
1545
+ *
1546
+ * @returns A tuple of `(values, indices)`, where `values` and `indices` have
1547
+ * the same shape as `x`, except along `axis` where they have size `k`.
1548
+ */
1549
+ declare function topK(x: ArrayLike, k: number, axis?: number): [Array, Array];
1530
1550
  declare namespace numpy_fft_d_exports {
1531
1551
  export { ComplexPair, fft, ifft };
1532
1552
  }
@@ -1551,7 +1571,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
1551
1571
  */
1552
1572
  declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
1553
1573
  declare namespace numpy_linalg_d_exports {
1554
- export { cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
1574
+ export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
1555
1575
  }
1556
1576
  /**
1557
1577
  * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
@@ -1566,6 +1586,13 @@ declare function cholesky(a: ArrayLike, {
1566
1586
  upper?: boolean;
1567
1587
  symmetrizeInput?: boolean;
1568
1588
  }): Array;
1589
+ /**
1590
+ * Compute the cross-product of two 3D vectors.
1591
+ *
1592
+ * This is a simpler and less flexible version of `jax.numpy.cross()`.
1593
+ * Both inputs must have size 3 along the specified axis.
1594
+ */
1595
+ declare function cross$1(x1: ArrayLike, x2: ArrayLike, axis?: number): Array;
1569
1596
  /** Compute the determinant of a square matrix (batched). */
1570
1597
  declare function det(a: ArrayLike): Array;
1571
1598
  /** Compute the inverse of a square matrix (batched). */
@@ -1652,7 +1679,7 @@ type IInfo = Readonly<{
1652
1679
  /** Machine limits for integer types. */
1653
1680
  declare function iinfo(dtype: DType): IInfo;
1654
1681
  declare namespace numpy_d_exports {
1655
- export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1682
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
1656
1683
  }
1657
1684
  declare const float32 = DType.Float32;
1658
1685
  declare const int32 = DType.Int32;
@@ -1712,6 +1739,14 @@ declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
1712
1739
  declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
1713
1740
  /** @function Compare two arrays element-wise. */
1714
1741
  declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
1742
+ /** Compute element-wise logical AND. */
1743
+ declare function logicalAnd(x: ArrayLike, y: ArrayLike): Array;
1744
+ /** Compute element-wise logical OR. */
1745
+ declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
1746
+ /** Compute element-wise logical XOR. */
1747
+ declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
1748
+ /** Compute element-wise logical NOT. */
1749
+ declare function logicalNot(x: ArrayLike): Array;
1715
1750
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1716
1751
  declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
1717
1752
  /**
@@ -1752,17 +1787,17 @@ declare const shape$1: (x: ArrayLike) => number[];
1752
1787
  * @function
1753
1788
  * Return an array of zeros with the same shape and type as a given array.
1754
1789
  */
1755
- declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
1790
+ declare const zerosLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
1756
1791
  /**
1757
1792
  * @function
1758
1793
  * Return an array of ones with the same shape and type as a given array.
1759
1794
  */
1760
- declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
1795
+ declare const onesLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
1761
1796
  /**
1762
1797
  * @function
1763
1798
  * Return a full array with the same shape and type as a given array.
1764
1799
  */
1765
- declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
1800
+ declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, opts?: DTypeShapeAndDevice) => Array;
1766
1801
  /**
1767
1802
  * Return the number of elements in an array, optionally along an axis.
1768
1803
  * Does not consume array reference.
@@ -1796,6 +1831,16 @@ declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1796
1831
  declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1797
1832
  /** Compute the average of the array elements along the specified axis. */
1798
1833
  declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1834
+ /**
1835
+ * Compute the weighted average along the specified axis.
1836
+ *
1837
+ * If no axis is specified, the average is computed over all axes. `weights`
1838
+ * should have the same shape as `a` or, if an axis is specified, the same
1839
+ * shape as `a` along those axes.
1840
+ */
1841
+ declare function average(a: ArrayLike, axis?: Axis, opts?: {
1842
+ weights?: ArrayLike;
1843
+ } & ReduceOpts): Array;
1799
1844
  /**
1800
1845
  * Returns the indices of the minimum values along an axis.
1801
1846
  *
@@ -1951,8 +1996,9 @@ declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: nu
1951
1996
  */
1952
1997
  declare function sort(a: ArrayLike, axis?: number): Array;
1953
1998
  /**
1954
- * Return indices that would sort an array. This may be an unstable sorting
1955
- * algorithm; it need not preserve order of indices in ties.
1999
+ * Return indices that would sort an array. Unlike `sort`, this is guaranteed to
2000
+ * be a stable sorting algorithm; it always returns the smaller index first in
2001
+ * the event of ties.
1956
2002
  *
1957
2003
  * Returns an array of `int32` indices.
1958
2004
  *
@@ -1966,13 +2012,39 @@ declare function argsort(a: ArrayLike, axis?: number): Array;
1966
2012
  * numbered axis. By default, the flattened array is used.
1967
2013
  */
1968
2014
  declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
1969
- /** Return if two arrays are element-wise equal within a tolerance. */
2015
+ /**
2016
+ * Return if two arrays are element-wise equal within a tolerance.
2017
+ *
2018
+ * The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
2019
+ * NaN values comparing equal if `equalNaN` is true.
2020
+ */
1970
2021
  declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
1971
2022
  rtol?: number;
1972
2023
  atol?: number;
2024
+ equalNaN?: boolean;
1973
2025
  }): boolean;
2026
+ /**
2027
+ * Check if two arrays are element-wise equal.
2028
+ *
2029
+ * The result is `false` if the arrays have different shapes. If `equalNaN` is `true`,
2030
+ * NaNs in the same position are considered equal.
2031
+ */
2032
+ declare function arrayEqual(a1: ArrayLike, a2: ArrayLike, opts?: {
2033
+ equalNaN?: boolean;
2034
+ }): Array;
2035
+ /**
2036
+ * Check if two arrays are element-wise equal after broadcasting.
2037
+ *
2038
+ * Unlike `arrayEqual`, this allows inputs with different but
2039
+ * broadcast-compatible shapes.
2040
+ */
2041
+ declare function arrayEquiv(a1: ArrayLike, a2: ArrayLike): Array;
1974
2042
  /** Matrix product of two arrays. */
1975
2043
  declare function matmul(x: ArrayLike, y: ArrayLike): Array;
2044
+ /** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
2045
+ declare function matvec(x1: ArrayLike, x2: ArrayLike): Array;
2046
+ /** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
2047
+ declare function vecmat(x1: ArrayLike, x2: ArrayLike): Array;
1976
2048
  /** Dot product of two arrays. */
1977
2049
  declare function dot(x: ArrayLike, y: ArrayLike): Array;
1978
2050
  /**
@@ -2025,6 +2097,18 @@ declare function inner(x: ArrayLike, y: ArrayLike): Array;
2025
2097
  * be of shape `[x.size, y.size]`.
2026
2098
  */
2027
2099
  declare function outer(x: ArrayLike, y: ArrayLike): Array;
2100
+ /**
2101
+ * @function Compute the cross product of two arrays.
2102
+ *
2103
+ * Supports 2D (scalar result) and 3D cross products, with optional axis
2104
+ * arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
2105
+ */
2106
+ declare const cross: OwnedFunction<(a: ArrayLike, b: ArrayLike, args_2?: {
2107
+ axisa?: number | undefined;
2108
+ axisb?: number | undefined;
2109
+ axisc?: number | undefined;
2110
+ axis?: number | undefined;
2111
+ } | undefined) => Array>;
2028
2112
  /** Vector dot product of two arrays along a given axis. */
2029
2113
  declare function vecdot(x: ArrayLike, y: ArrayLike, {
2030
2114
  axis
@@ -2070,14 +2154,13 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
2070
2154
  declare function absolute(x: ArrayLike): Array;
2071
2155
  /** Return an element-wise indication of sign of the input. */
2072
2156
  declare function sign(x: ArrayLike): Array;
2073
- /** @function Return element-wise positive values of the input (no-op). */
2074
- declare const positive: (x: ArrayLike) => Array;
2075
2157
  /**
2076
- * Return the Hamming window of size M, a taper with a weighted cosine bell.
2077
- *
2078
- * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
2158
+ * @function
2159
+ * Return the value with the magnitude of x and the sign of y, element-wise.
2079
2160
  */
2080
- declare function hamming(M: number): Array;
2161
+ declare const copysign: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
2162
+ /** @function Return element-wise positive values of the input (no-op). */
2163
+ declare const positive: (x: ArrayLike) => Array;
2081
2164
  /**
2082
2165
  * Return the Hann window of size M, a taper with a weighted cosine bell.
2083
2166
  *
@@ -2172,6 +2255,18 @@ declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
2172
2255
  declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
2173
2256
  /** Round input to the nearest integer towards zero. */
2174
2257
  declare function trunc(x: ArrayLike): Array;
2258
+ /**
2259
+ * @function
2260
+ * Round to the given number of decimals.
2261
+ *
2262
+ * Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
2263
+ */
2264
+ declare const round: OwnedFunction<(a: ArrayLike, decimals?: number | undefined) => Array>;
2265
+ /**
2266
+ * @function
2267
+ * Round to the nearest integer, with ties going to the nearest even integer.
2268
+ */
2269
+ declare const rint: OwnedFunction<(x: ArrayLike) => Array>;
2175
2270
  /**
2176
2271
  * Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
2177
2272
  *
@@ -2564,7 +2659,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
2564
2659
  localWindowSize?: number | [number, number];
2565
2660
  }): Array;
2566
2661
  declare namespace random_d_exports {
2567
- export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2662
+ export { bernoulli, bits, categorical, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2568
2663
  }
2569
2664
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
2570
2665
  declare function key(seed: ArrayLike): Array;
@@ -2587,6 +2682,32 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
2587
2682
  * and must be broadcastable to `shape`.
2588
2683
  */
2589
2684
  declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
2685
+ /**
2686
+ * @function
2687
+ * Sample random values from categorical distributions.
2688
+ *
2689
+ * Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
2690
+ * trick for sampling without replacement.
2691
+ *
2692
+ * Note: Sampling without replacement currently uses argsort and slices the last
2693
+ * k elements. This could be replaced with the more efficient `lax.topK`.
2694
+ *
2695
+ * - `key` - PRNG key
2696
+ * - `logits` - Unnormalized log probabilities of the categorical distribution(s).
2697
+ * `softmax(logits, axis)` gives the corresponding probabilities.
2698
+ * - `axis` - Axis along which logits belong to the same categorical distribution.
2699
+ * - `shape` - Result batch shape. Must be broadcast-compatible with
2700
+ * `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
2701
+ * - `replace` - If true (default), sample with replacement. If false, sample
2702
+ * without replacement (each category can only be selected once per batch).
2703
+ * @returns A random array with int dtype and shape given by `shape` if provided,
2704
+ * otherwise `logits.shape` with `axis` removed.
2705
+ */
2706
+ declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, args_2?: {
2707
+ axis?: number | undefined;
2708
+ shape?: number[] | undefined;
2709
+ replace?: boolean | undefined;
2710
+ } | undefined) => Array>;
2590
2711
  /**
2591
2712
  * @function
2592
2713
  * Sample from a Cauchy distribution with location 0 and scale 1.
@@ -2648,8 +2769,31 @@ declare namespace scipy_special_d_exports {
2648
2769
  * The logit function, `logit(p) = log(p / (1-p))`.
2649
2770
  */
2650
2771
  declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
2772
+ //#endregion
2773
+ //#region src/tracing.d.ts
2774
+ /**
2775
+ * Start collecting kernel traces.
2776
+ *
2777
+ * Traces appear in developer tools under the "Performance" tab, and they are
2778
+ * useful for measuring fine-grained kernel execution time.
2779
+ */
2780
+ declare function startTrace(): void;
2781
+ /**
2782
+ * Stop collecting kernel traces.
2783
+ *
2784
+ * Traces appear in developer tools under the "Performance" tab, and they are
2785
+ * useful for measuring fine-grained kernel execution time.
2786
+ */
2787
+ declare function stopTrace(): void;
2788
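+ // NOTE(review): orphaned doc comment ("Check if tracing is currently enabled.") — the declaration it documented is not emitted here; export it or drop this comment.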
2789
+
2651
2790
  //#endregion
2652
2791
  //#region src/index.d.ts
2792
+ /** @namespace */
2793
+ declare const profiler: {
2794
+ startTrace: typeof startTrace;
2795
+ stopTrace: typeof stopTrace;
2796
+ };
2653
2797
  /**
2654
2798
  * @function
2655
2799
  * Compute the forward-mode Jacobian-vector product for a function.
@@ -2814,4 +2958,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
2814
2958
  */
2815
2959
  declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
2816
2960
  //#endregion
2817
- export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
2961
+ export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };