@jax-js/jax 0.1.8 → 0.1.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +46 -29
- package/dist/{backend-nEolvdLv.js → backend-Ctqs8la1.js} +122 -15
- package/dist/{backend-B3foXiV_.cjs → backend-DMauYnfl.cjs} +157 -14
- package/dist/index.cjs +331 -46
- package/dist/index.d.cts +175 -31
- package/dist/index.d.ts +175 -31
- package/dist/index.js +331 -47
- package/dist/{webgl-DweKSWEm.js → webgl-CvQ1QBX1.js} +1 -1
- package/dist/{webgl-DIIbKJ0G.cjs → webgl-kvVt7-T7.cjs} +1 -1
- package/dist/{webgpu-BykvF26B.cjs → webgpu-DMSx7a6M.cjs} +160 -15
- package/dist/{webgpu-B96vzWGE.js → webgpu-v_W_-oKw.js} +160 -15
- package/package.json +5 -16
package/dist/index.d.ts
CHANGED
|
@@ -433,9 +433,14 @@ declare class Routine {
|
|
|
433
433
|
}
|
|
434
434
|
/** One of the valid `Routine` that can be dispatched to backend. */
|
|
435
435
|
declare enum Routines {
|
|
436
|
-
/**
|
|
436
|
+
/**
|
|
437
|
+
* Sort along the last axis.
|
|
438
|
+
*
|
|
439
|
+
* This may be _unstable_ but it often doesn't matter, sorting numbers is
|
|
440
|
+
* bitwise unique up to signed zeros and NaNs.
|
|
441
|
+
*/
|
|
437
442
|
Sort = "Sort",
|
|
438
|
-
/**
|
|
443
|
+
/** Stable sorting, returns `int32` indices and values of the sorted array. */
|
|
439
444
|
Argsort = "Argsort",
|
|
440
445
|
/**
|
|
441
446
|
* Solve a triangular system of equations.
|
|
@@ -747,9 +752,9 @@ declare enum Primitive {
|
|
|
747
752
|
Shrink = "shrink",
|
|
748
753
|
Pad = "pad",
|
|
749
754
|
Sort = "sort",
|
|
750
|
-
// sort(x, axis=-1)
|
|
755
|
+
// sort(x, axis=-1), unstable
|
|
751
756
|
Argsort = "argsort",
|
|
752
|
-
// argsort(x, axis=-1)
|
|
757
|
+
// argsort(x, axis=-1), stable
|
|
753
758
|
TriangularSolve = "triangular_solve",
|
|
754
759
|
// A is upper triangular, A @ X.T = B.T
|
|
755
760
|
Cholesky = "cholesky",
|
|
@@ -996,6 +1001,8 @@ declare abstract class Tracer {
|
|
|
996
1001
|
reshape(shape: number | number[]): this;
|
|
997
1002
|
/** Copy the array and cast to a specified dtype. */
|
|
998
1003
|
astype(dtype: DType): this;
|
|
1004
|
+
/** Return a bitwise cast of the array, viewed as a new dtype. */
|
|
1005
|
+
view(dtype?: DType): this;
|
|
999
1006
|
/** Subtract an array from this one. */
|
|
1000
1007
|
sub(other: this | TracerValue): this;
|
|
1001
1008
|
/** Divide an array by this one. */
|
|
@@ -1026,8 +1033,9 @@ declare abstract class Tracer {
|
|
|
1026
1033
|
*/
|
|
1027
1034
|
sort(axis?: number): this;
|
|
1028
1035
|
/**
|
|
1029
|
-
* Return the indices that would sort an array.
|
|
1030
|
-
* sorting algorithm; it
|
|
1036
|
+
* Return the indices that would sort an array. Unlike `sort`, this is
|
|
1037
|
+
* guaranteed to be a stable sorting algorithm; it always returns the smaller
|
|
1038
|
+
* index first in event of ties.
|
|
1031
1039
|
*
|
|
1032
1040
|
* See `jax.numpy.argsort` for full docs.
|
|
1033
1041
|
*/
|
|
@@ -1109,6 +1117,12 @@ type DTypeAndDevice = {
|
|
|
1109
1117
|
dtype?: DType;
|
|
1110
1118
|
device?: Device;
|
|
1111
1119
|
};
|
|
1120
|
+
/** @inline */
|
|
1121
|
+
type DTypeShapeAndDevice = {
|
|
1122
|
+
dtype?: DType;
|
|
1123
|
+
shape?: number[];
|
|
1124
|
+
device?: Device;
|
|
1125
|
+
};
|
|
1112
1126
|
type ArrayConstructorArgs = {
|
|
1113
1127
|
source: AluExp | Slot;
|
|
1114
1128
|
st: ShapeTracker;
|
|
@@ -1218,15 +1232,9 @@ declare function array(values: Array | DataArray | RecursiveArray<number> | Recu
|
|
|
1218
1232
|
type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
|
|
1219
1233
|
declare const implRules: { [P in Primitive]: ImplRule<P> };
|
|
1220
1234
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
1221
|
-
declare function zeros(shape: number[],
|
|
1222
|
-
dtype,
|
|
1223
|
-
device
|
|
1224
|
-
}?: DTypeAndDevice): Array;
|
|
1235
|
+
declare function zeros(shape: number[], opts?: DTypeAndDevice): Array;
|
|
1225
1236
|
/** Return a new array of given shape and type, filled with ones. */
|
|
1226
|
-
declare function ones(shape: number[],
|
|
1227
|
-
dtype,
|
|
1228
|
-
device
|
|
1229
|
-
}?: DTypeAndDevice): Array;
|
|
1237
|
+
declare function ones(shape: number[], opts?: DTypeAndDevice): Array;
|
|
1230
1238
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
1231
1239
|
declare function full(shape: number[], fillValue: number | boolean | Array, {
|
|
1232
1240
|
dtype,
|
|
@@ -1418,8 +1426,10 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
|
|
|
1418
1426
|
unitDiagonal?: boolean;
|
|
1419
1427
|
}): Array;
|
|
1420
1428
|
declare namespace lax_d_exports {
|
|
1421
|
-
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
|
|
1429
|
+
export { DotDimensionNumbers, PaddingType, bitcastConvertType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
|
|
1422
1430
|
}
|
|
1431
|
+
/** Elementwise bitcast an array into a new dtype. */
|
|
1432
|
+
declare function bitcastConvertType(x: ArrayLike, newDtype: DType): Array;
|
|
1423
1433
|
/**
|
|
1424
1434
|
* Dimension numbers for general `dot()` primitive.
|
|
1425
1435
|
*
|
|
@@ -1524,6 +1534,16 @@ declare function erfc(x: ArrayLike): Array;
|
|
|
1524
1534
|
* forward or reverse-mode automatic differentiation.
|
|
1525
1535
|
*/
|
|
1526
1536
|
declare function stopGradient(x: ArrayLike): Array;
|
|
1537
|
+
/**
|
|
1538
|
+
* Returns top `k` values and their indices along the specified axis of operand.
|
|
1539
|
+
*
|
|
1540
|
+
* This is a _stable_ algorithm: If two elements are equal, the lower-index
|
|
1541
|
+
* element appears first.
|
|
1542
|
+
*
|
|
1543
|
+
* @returns A tuple of `(values, indices)`, where `values` and `indices` have
|
|
1544
|
+
* the same shape as `x`, except along `axis` where they have size `k`.
|
|
1545
|
+
*/
|
|
1546
|
+
declare function topK(x: ArrayLike, k: number, axis?: number): [Array, Array];
|
|
1527
1547
|
declare namespace numpy_fft_d_exports {
|
|
1528
1548
|
export { ComplexPair, fft, ifft };
|
|
1529
1549
|
}
|
|
@@ -1548,7 +1568,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
|
|
|
1548
1568
|
*/
|
|
1549
1569
|
declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
|
|
1550
1570
|
declare namespace numpy_linalg_d_exports {
|
|
1551
|
-
export { cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
|
|
1571
|
+
export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
|
|
1552
1572
|
}
|
|
1553
1573
|
/**
|
|
1554
1574
|
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
@@ -1563,6 +1583,13 @@ declare function cholesky(a: ArrayLike, {
|
|
|
1563
1583
|
upper?: boolean;
|
|
1564
1584
|
symmetrizeInput?: boolean;
|
|
1565
1585
|
}): Array;
|
|
1586
|
+
/**
|
|
1587
|
+
* Compute the cross-product of two 3D vectors.
|
|
1588
|
+
*
|
|
1589
|
+
* This is a simpler and less flexible version of `jax.numpy.cross()`.
|
|
1590
|
+
* Both inputs must have size 3 along the specified axis.
|
|
1591
|
+
*/
|
|
1592
|
+
declare function cross$1(x1: ArrayLike, x2: ArrayLike, axis?: number): Array;
|
|
1566
1593
|
/** Compute the determinant of a square matrix (batched). */
|
|
1567
1594
|
declare function det(a: ArrayLike): Array;
|
|
1568
1595
|
/** Compute the inverse of a square matrix (batched). */
|
|
@@ -1649,7 +1676,7 @@ type IInfo = Readonly<{
|
|
|
1649
1676
|
/** Machine limits for integer types. */
|
|
1650
1677
|
declare function iinfo(dtype: DType): IInfo;
|
|
1651
1678
|
declare namespace numpy_d_exports {
|
|
1652
|
-
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual,
|
|
1679
|
+
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
|
|
1653
1680
|
}
|
|
1654
1681
|
declare const float32 = DType.Float32;
|
|
1655
1682
|
declare const int32 = DType.Int32;
|
|
@@ -1709,6 +1736,14 @@ declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
|
1709
1736
|
declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1710
1737
|
/** @function Compare two arrays element-wise. */
|
|
1711
1738
|
declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1739
|
+
/** Compute element-wise logical AND. */
|
|
1740
|
+
declare function logicalAnd(x: ArrayLike, y: ArrayLike): Array;
|
|
1741
|
+
/** Compute element-wise logical OR. */
|
|
1742
|
+
declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
|
|
1743
|
+
/** Compute element-wise logical XOR. */
|
|
1744
|
+
declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
|
|
1745
|
+
/** Compute element-wise logical NOT. */
|
|
1746
|
+
declare function logicalNot(x: ArrayLike): Array;
|
|
1712
1747
|
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1713
1748
|
declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
|
|
1714
1749
|
/**
|
|
@@ -1749,17 +1784,17 @@ declare const shape$1: (x: ArrayLike) => number[];
|
|
|
1749
1784
|
* @function
|
|
1750
1785
|
* Return an array of zeros with the same shape and type as a given array.
|
|
1751
1786
|
*/
|
|
1752
|
-
declare const zerosLike: (a: ArrayLike,
|
|
1787
|
+
declare const zerosLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
|
|
1753
1788
|
/**
|
|
1754
1789
|
* @function
|
|
1755
1790
|
* Return an array of ones with the same shape and type as a given array.
|
|
1756
1791
|
*/
|
|
1757
|
-
declare const onesLike: (a: ArrayLike,
|
|
1792
|
+
declare const onesLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
|
|
1758
1793
|
/**
|
|
1759
1794
|
* @function
|
|
1760
1795
|
* Return a full array with the same shape and type as a given array.
|
|
1761
1796
|
*/
|
|
1762
|
-
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array,
|
|
1797
|
+
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, opts?: DTypeShapeAndDevice) => Array;
|
|
1763
1798
|
/**
|
|
1764
1799
|
* Return the number of elements in an array, optionally along an axis.
|
|
1765
1800
|
* Does not consume array reference.
|
|
@@ -1793,6 +1828,16 @@ declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
|
1793
1828
|
declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1794
1829
|
/** Compute the average of the array elements along the specified axis. */
|
|
1795
1830
|
declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1831
|
+
/**
|
|
1832
|
+
* Compute the weighted average along the specified axis.
|
|
1833
|
+
*
|
|
1834
|
+
* If no axis is specified, mean is computed along all the axes. The weights
|
|
1835
|
+
* should have shape matching that of `a`, or if an axis is specified, it should
|
|
1836
|
+
* match the shape along those axes.
|
|
1837
|
+
*/
|
|
1838
|
+
declare function average(a: ArrayLike, axis?: Axis, opts?: {
|
|
1839
|
+
weights?: ArrayLike;
|
|
1840
|
+
} & ReduceOpts): Array;
|
|
1796
1841
|
/**
|
|
1797
1842
|
* Returns the indices of the minimum values along an axis.
|
|
1798
1843
|
*
|
|
@@ -1948,8 +1993,9 @@ declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: nu
|
|
|
1948
1993
|
*/
|
|
1949
1994
|
declare function sort(a: ArrayLike, axis?: number): Array;
|
|
1950
1995
|
/**
|
|
1951
|
-
* Return indices that would sort an array.
|
|
1952
|
-
* algorithm; it
|
|
1996
|
+
* Return indices that would sort an array. Unlike `sort`, this is guaranteed to
|
|
1997
|
+
* be a stable sorting algorithm; it always returns the smaller index first in
|
|
1998
|
+
* event of ties.
|
|
1953
1999
|
*
|
|
1954
2000
|
* Returns an array of `int32` indices.
|
|
1955
2001
|
*
|
|
@@ -1963,13 +2009,39 @@ declare function argsort(a: ArrayLike, axis?: number): Array;
|
|
|
1963
2009
|
* numbered axis. By default, the flattened array is used.
|
|
1964
2010
|
*/
|
|
1965
2011
|
declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
|
|
1966
|
-
/**
|
|
2012
|
+
/**
|
|
2013
|
+
* Return if two arrays are element-wise equal within a tolerance.
|
|
2014
|
+
*
|
|
2015
|
+
* The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
|
|
2016
|
+
* NaN values comparing equal if `equalNaN` is true.
|
|
2017
|
+
*/
|
|
1967
2018
|
declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
|
|
1968
2019
|
rtol?: number;
|
|
1969
2020
|
atol?: number;
|
|
2021
|
+
equalNaN?: boolean;
|
|
1970
2022
|
}): boolean;
|
|
2023
|
+
/**
|
|
2024
|
+
* Check if two arrays are element-wise equal.
|
|
2025
|
+
*
|
|
2026
|
+
* Returns False if the arrays have different shapes. If `equalNaN` is True,
|
|
2027
|
+
* NaNs in the same position are considered equal.
|
|
2028
|
+
*/
|
|
2029
|
+
declare function arrayEqual(a1: ArrayLike, a2: ArrayLike, opts?: {
|
|
2030
|
+
equalNaN?: boolean;
|
|
2031
|
+
}): Array;
|
|
2032
|
+
/**
|
|
2033
|
+
* Check if two arrays are element-wise equal after broadcasting.
|
|
2034
|
+
*
|
|
2035
|
+
* Unlike `arrayEqual`, this allows inputs with different but
|
|
2036
|
+
* broadcast-compatible shapes.
|
|
2037
|
+
*/
|
|
2038
|
+
declare function arrayEquiv(a1: ArrayLike, a2: ArrayLike): Array;
|
|
1971
2039
|
/** Matrix product of two arrays. */
|
|
1972
2040
|
declare function matmul(x: ArrayLike, y: ArrayLike): Array;
|
|
2041
|
+
/** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
|
|
2042
|
+
declare function matvec(x1: ArrayLike, x2: ArrayLike): Array;
|
|
2043
|
+
/** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
|
|
2044
|
+
declare function vecmat(x1: ArrayLike, x2: ArrayLike): Array;
|
|
1973
2045
|
/** Dot product of two arrays. */
|
|
1974
2046
|
declare function dot(x: ArrayLike, y: ArrayLike): Array;
|
|
1975
2047
|
/**
|
|
@@ -2022,6 +2094,18 @@ declare function inner(x: ArrayLike, y: ArrayLike): Array;
|
|
|
2022
2094
|
* be of shape `[x.size, y.size]`.
|
|
2023
2095
|
*/
|
|
2024
2096
|
declare function outer(x: ArrayLike, y: ArrayLike): Array;
|
|
2097
|
+
/**
|
|
2098
|
+
* @function Compute the cross product of two arrays.
|
|
2099
|
+
*
|
|
2100
|
+
* Supports 2D (scalar result) and 3D cross products, with optional axis
|
|
2101
|
+
* arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
|
|
2102
|
+
*/
|
|
2103
|
+
declare const cross: OwnedFunction<(a: ArrayLike, b: ArrayLike, args_2?: {
|
|
2104
|
+
axisa?: number | undefined;
|
|
2105
|
+
axisb?: number | undefined;
|
|
2106
|
+
axisc?: number | undefined;
|
|
2107
|
+
axis?: number | undefined;
|
|
2108
|
+
} | undefined) => Array>;
|
|
2025
2109
|
/** Vector dot product of two arrays along a given axis. */
|
|
2026
2110
|
declare function vecdot(x: ArrayLike, y: ArrayLike, {
|
|
2027
2111
|
axis
|
|
@@ -2067,14 +2151,13 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
|
|
|
2067
2151
|
declare function absolute(x: ArrayLike): Array;
|
|
2068
2152
|
/** Return an element-wise indication of sign of the input. */
|
|
2069
2153
|
declare function sign(x: ArrayLike): Array;
|
|
2070
|
-
/** @function Return element-wise positive values of the input (no-op). */
|
|
2071
|
-
declare const positive: (x: ArrayLike) => Array;
|
|
2072
2154
|
/**
|
|
2073
|
-
*
|
|
2074
|
-
*
|
|
2075
|
-
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
2155
|
+
* @function
|
|
2156
|
+
* Return the value with the magnitude of x and the sign of y, element-wise.
|
|
2076
2157
|
*/
|
|
2077
|
-
declare
|
|
2158
|
+
declare const copysign: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
2159
|
+
/** @function Return element-wise positive values of the input (no-op). */
|
|
2160
|
+
declare const positive: (x: ArrayLike) => Array;
|
|
2078
2161
|
/**
|
|
2079
2162
|
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
2080
2163
|
*
|
|
@@ -2169,6 +2252,18 @@ declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
|
2169
2252
|
declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
|
|
2170
2253
|
/** Round input to the nearest integer towards zero. */
|
|
2171
2254
|
declare function trunc(x: ArrayLike): Array;
|
|
2255
|
+
/**
|
|
2256
|
+
* @function
|
|
2257
|
+
* Round to the given number of decimals.
|
|
2258
|
+
*
|
|
2259
|
+
* Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
|
|
2260
|
+
*/
|
|
2261
|
+
declare const round: OwnedFunction<(a: ArrayLike, decimals?: number | undefined) => Array>;
|
|
2262
|
+
/**
|
|
2263
|
+
* @function
|
|
2264
|
+
* Round to the nearest integer, with ties going to the nearest even integer.
|
|
2265
|
+
*/
|
|
2266
|
+
declare const rint: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2172
2267
|
/**
|
|
2173
2268
|
* Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
|
|
2174
2269
|
*
|
|
@@ -2561,7 +2656,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
|
|
|
2561
2656
|
localWindowSize?: number | [number, number];
|
|
2562
2657
|
}): Array;
|
|
2563
2658
|
declare namespace random_d_exports {
|
|
2564
|
-
export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
2659
|
+
export { bernoulli, bits, categorical, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
2565
2660
|
}
|
|
2566
2661
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
2567
2662
|
declare function key(seed: ArrayLike): Array;
|
|
@@ -2584,6 +2679,32 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
|
|
|
2584
2679
|
* and must be broadcastable to `shape`.
|
|
2585
2680
|
*/
|
|
2586
2681
|
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
2682
|
+
/**
|
|
2683
|
+
* @function
|
|
2684
|
+
* Sample random values from categorical distributions.
|
|
2685
|
+
*
|
|
2686
|
+
* Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
|
|
2687
|
+
* trick for sampling without replacement.
|
|
2688
|
+
*
|
|
2689
|
+
* Note: Sampling without replacement currently uses argsort and slices the last
|
|
2690
|
+
* k elements. This should be replaced with a more efficient topK implementation.
|
|
2691
|
+
*
|
|
2692
|
+
* - `key` - PRNG key
|
|
2693
|
+
* - `logits` - Unnormalized log probabilities of the categorical distribution(s).
|
|
2694
|
+
* `softmax(logits, axis)` gives the corresponding probabilities.
|
|
2695
|
+
* - `axis` - Axis along which logits belong to the same categorical distribution.
|
|
2696
|
+
* - `shape` - Result batch shape. Must be broadcast-compatible with
|
|
2697
|
+
* `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
|
|
2698
|
+
* - `replace` - If true (default), sample with replacement. If false, sample
|
|
2699
|
+
* without replacement (each category can only be selected once per batch).
|
|
2700
|
+
* @returns A random array with int dtype and shape given by `shape` if provided,
|
|
2701
|
+
* otherwise `logits.shape` with `axis` removed.
|
|
2702
|
+
*/
|
|
2703
|
+
declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, args_2?: {
|
|
2704
|
+
axis?: number | undefined;
|
|
2705
|
+
shape?: number[] | undefined;
|
|
2706
|
+
replace?: boolean | undefined;
|
|
2707
|
+
} | undefined) => Array>;
|
|
2587
2708
|
/**
|
|
2588
2709
|
* @function
|
|
2589
2710
|
* Sample from a Cauchy distribution with location 0 and scale 1.
|
|
@@ -2645,8 +2766,31 @@ declare namespace scipy_special_d_exports {
|
|
|
2645
2766
|
* The logit function, `logit(p) = log(p / (1-p))`.
|
|
2646
2767
|
*/
|
|
2647
2768
|
declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2769
|
+
//#endregion
|
|
2770
|
+
//#region src/tracing.d.ts
|
|
2771
|
+
/**
|
|
2772
|
+
* Start collecting kernel traces.
|
|
2773
|
+
*
|
|
2774
|
+
* Traces appear in developer tools under the "Performance" tab, and they are
|
|
2775
|
+
* useful for measuring fine-grained kernel execution time.
|
|
2776
|
+
*/
|
|
2777
|
+
declare function startTrace(): void;
|
|
2778
|
+
/**
|
|
2779
|
+
* Stop collecting kernel traces.
|
|
2780
|
+
*
|
|
2781
|
+
* Traces appear in developer tools under the "Performance" tab, and they are
|
|
2782
|
+
* useful for measuring fine-grained kernel execution time.
|
|
2783
|
+
*/
|
|
2784
|
+
declare function stopTrace(): void;
|
|
2785
|
+
/** Check if tracing is currently enabled. */
|
|
2786
|
+
|
|
2648
2787
|
//#endregion
|
|
2649
2788
|
//#region src/index.d.ts
|
|
2789
|
+
/** @namespace */
|
|
2790
|
+
declare const profiler: {
|
|
2791
|
+
startTrace: typeof startTrace;
|
|
2792
|
+
stopTrace: typeof stopTrace;
|
|
2793
|
+
};
|
|
2650
2794
|
/**
|
|
2651
2795
|
* @function
|
|
2652
2796
|
* Compute the forward-mode Jacobian-vector product for a function.
|
|
@@ -2811,4 +2955,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
|
2811
2955
|
*/
|
|
2812
2956
|
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
2813
2957
|
//#endregion
|
|
2814
|
-
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
2958
|
+
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|