@jax-js/jax 0.1.8 → 0.1.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +46 -29
- package/dist/{backend-nEolvdLv.js → backend-Ctqs8la1.js} +122 -15
- package/dist/{backend-B3foXiV_.cjs → backend-DMauYnfl.cjs} +157 -14
- package/dist/index.cjs +331 -46
- package/dist/index.d.cts +175 -31
- package/dist/index.d.ts +175 -31
- package/dist/index.js +331 -47
- package/dist/{webgl-DweKSWEm.js → webgl-CvQ1QBX1.js} +1 -1
- package/dist/{webgl-DIIbKJ0G.cjs → webgl-kvVt7-T7.cjs} +1 -1
- package/dist/{webgpu-BykvF26B.cjs → webgpu-DMSx7a6M.cjs} +160 -15
- package/dist/{webgpu-B96vzWGE.js → webgpu-v_W_-oKw.js} +160 -15
- package/package.json +5 -16
package/dist/index.d.cts
CHANGED
|
@@ -436,9 +436,14 @@ declare class Routine {
|
|
|
436
436
|
}
|
|
437
437
|
/** One of the valid `Routine` that can be dispatched to backend. */
|
|
438
438
|
declare enum Routines {
|
|
439
|
-
/**
|
|
439
|
+
/**
|
|
440
|
+
* Sort along the last axis.
|
|
441
|
+
*
|
|
442
|
+
* This may be _unstable_ but it often doesn't matter, sorting numbers is
|
|
443
|
+
* bitwise unique up to signed zeros and NaNs.
|
|
444
|
+
*/
|
|
440
445
|
Sort = "Sort",
|
|
441
|
-
/**
|
|
446
|
+
/** Stable sorting, returns `int32` indices and values of the sorted array. */
|
|
442
447
|
Argsort = "Argsort",
|
|
443
448
|
/**
|
|
444
449
|
* Solve a triangular system of equations.
|
|
@@ -750,9 +755,9 @@ declare enum Primitive {
|
|
|
750
755
|
Shrink = "shrink",
|
|
751
756
|
Pad = "pad",
|
|
752
757
|
Sort = "sort",
|
|
753
|
-
// sort(x, axis=-1)
|
|
758
|
+
// sort(x, axis=-1), unstable
|
|
754
759
|
Argsort = "argsort",
|
|
755
|
-
// argsort(x, axis=-1)
|
|
760
|
+
// argsort(x, axis=-1), stable
|
|
756
761
|
TriangularSolve = "triangular_solve",
|
|
757
762
|
// A is upper triangular, A @ X.T = B.T
|
|
758
763
|
Cholesky = "cholesky",
|
|
@@ -999,6 +1004,8 @@ declare abstract class Tracer {
|
|
|
999
1004
|
reshape(shape: number | number[]): this;
|
|
1000
1005
|
/** Copy the array and cast to a specified dtype. */
|
|
1001
1006
|
astype(dtype: DType): this;
|
|
1007
|
+
/** Return a bitwise cast of the array, viewed as a new dtype. */
|
|
1008
|
+
view(dtype?: DType): this;
|
|
1002
1009
|
/** Subtract an array from this one. */
|
|
1003
1010
|
sub(other: this | TracerValue): this;
|
|
1004
1011
|
/** Divide an array by this one. */
|
|
@@ -1029,8 +1036,9 @@ declare abstract class Tracer {
|
|
|
1029
1036
|
*/
|
|
1030
1037
|
sort(axis?: number): this;
|
|
1031
1038
|
/**
|
|
1032
|
-
* Return the indices that would sort an array.
|
|
1033
|
-
* sorting algorithm; it
|
|
1039
|
+
* Return the indices that would sort an array. Unlike `sort`, this is
|
|
1040
|
+
* guaranteed to be a stable sorting algorithm; it always returns the smaller
|
|
1041
|
+
* index first in event of ties.
|
|
1034
1042
|
*
|
|
1035
1043
|
* See `jax.numpy.argsort` for full docs.
|
|
1036
1044
|
*/
|
|
@@ -1112,6 +1120,12 @@ type DTypeAndDevice = {
|
|
|
1112
1120
|
dtype?: DType;
|
|
1113
1121
|
device?: Device;
|
|
1114
1122
|
};
|
|
1123
|
+
/** @inline */
|
|
1124
|
+
type DTypeShapeAndDevice = {
|
|
1125
|
+
dtype?: DType;
|
|
1126
|
+
shape?: number[];
|
|
1127
|
+
device?: Device;
|
|
1128
|
+
};
|
|
1115
1129
|
type ArrayConstructorArgs = {
|
|
1116
1130
|
source: AluExp | Slot;
|
|
1117
1131
|
st: ShapeTracker;
|
|
@@ -1221,15 +1235,9 @@ declare function array(values: Array | DataArray | RecursiveArray<number> | Recu
|
|
|
1221
1235
|
type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
|
|
1222
1236
|
declare const implRules: { [P in Primitive]: ImplRule<P> };
|
|
1223
1237
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
1224
|
-
declare function zeros(shape: number[],
|
|
1225
|
-
dtype,
|
|
1226
|
-
device
|
|
1227
|
-
}?: DTypeAndDevice): Array;
|
|
1238
|
+
declare function zeros(shape: number[], opts?: DTypeAndDevice): Array;
|
|
1228
1239
|
/** Return a new array of given shape and type, filled with ones. */
|
|
1229
|
-
declare function ones(shape: number[],
|
|
1230
|
-
dtype,
|
|
1231
|
-
device
|
|
1232
|
-
}?: DTypeAndDevice): Array;
|
|
1240
|
+
declare function ones(shape: number[], opts?: DTypeAndDevice): Array;
|
|
1233
1241
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
1234
1242
|
declare function full(shape: number[], fillValue: number | boolean | Array, {
|
|
1235
1243
|
dtype,
|
|
@@ -1421,8 +1429,10 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
|
|
|
1421
1429
|
unitDiagonal?: boolean;
|
|
1422
1430
|
}): Array;
|
|
1423
1431
|
declare namespace lax_d_exports {
|
|
1424
|
-
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
|
|
1432
|
+
export { DotDimensionNumbers, PaddingType, bitcastConvertType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
|
|
1425
1433
|
}
|
|
1434
|
+
/** Elementwise bitcast an array into a new dtype. */
|
|
1435
|
+
declare function bitcastConvertType(x: ArrayLike, newDtype: DType): Array;
|
|
1426
1436
|
/**
|
|
1427
1437
|
* Dimension numbers for general `dot()` primitive.
|
|
1428
1438
|
*
|
|
@@ -1527,6 +1537,16 @@ declare function erfc(x: ArrayLike): Array;
|
|
|
1527
1537
|
* forward or reverse-mode automatic differentiation.
|
|
1528
1538
|
*/
|
|
1529
1539
|
declare function stopGradient(x: ArrayLike): Array;
|
|
1540
|
+
/**
|
|
1541
|
+
* Returns top `k` values and their indices along the specified axis of operand.
|
|
1542
|
+
*
|
|
1543
|
+
* This is a _stable_ algorithm: If two elements are equal, the lower-index
|
|
1544
|
+
* element appears first.
|
|
1545
|
+
*
|
|
1546
|
+
* @returns A tuple of `(values, indices)`, where `values` and `indices` have
|
|
1547
|
+
* the same shape as `x`, except along `axis` where they have size `k`.
|
|
1548
|
+
*/
|
|
1549
|
+
declare function topK(x: ArrayLike, k: number, axis?: number): [Array, Array];
|
|
1530
1550
|
declare namespace numpy_fft_d_exports {
|
|
1531
1551
|
export { ComplexPair, fft, ifft };
|
|
1532
1552
|
}
|
|
@@ -1551,7 +1571,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
|
|
|
1551
1571
|
*/
|
|
1552
1572
|
declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
|
|
1553
1573
|
declare namespace numpy_linalg_d_exports {
|
|
1554
|
-
export { cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
|
|
1574
|
+
export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
|
|
1555
1575
|
}
|
|
1556
1576
|
/**
|
|
1557
1577
|
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
@@ -1566,6 +1586,13 @@ declare function cholesky(a: ArrayLike, {
|
|
|
1566
1586
|
upper?: boolean;
|
|
1567
1587
|
symmetrizeInput?: boolean;
|
|
1568
1588
|
}): Array;
|
|
1589
|
+
/**
|
|
1590
|
+
* Compute the cross-product of two 3D vectors.
|
|
1591
|
+
*
|
|
1592
|
+
* This is a simpler and less flexible version of `jax.numpy.cross()`.
|
|
1593
|
+
* Both inputs must have size 3 along the specified axis.
|
|
1594
|
+
*/
|
|
1595
|
+
declare function cross$1(x1: ArrayLike, x2: ArrayLike, axis?: number): Array;
|
|
1569
1596
|
/** Compute the determinant of a square matrix (batched). */
|
|
1570
1597
|
declare function det(a: ArrayLike): Array;
|
|
1571
1598
|
/** Compute the inverse of a square matrix (batched). */
|
|
@@ -1652,7 +1679,7 @@ type IInfo = Readonly<{
|
|
|
1652
1679
|
/** Machine limits for integer types. */
|
|
1653
1680
|
declare function iinfo(dtype: DType): IInfo;
|
|
1654
1681
|
declare namespace numpy_d_exports {
|
|
1655
|
-
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual,
|
|
1682
|
+
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
|
|
1656
1683
|
}
|
|
1657
1684
|
declare const float32 = DType.Float32;
|
|
1658
1685
|
declare const int32 = DType.Int32;
|
|
@@ -1712,6 +1739,14 @@ declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
|
1712
1739
|
declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1713
1740
|
/** @function Compare two arrays element-wise. */
|
|
1714
1741
|
declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1742
|
+
/** Compute element-wise logical AND. */
|
|
1743
|
+
declare function logicalAnd(x: ArrayLike, y: ArrayLike): Array;
|
|
1744
|
+
/** Compute element-wise logical OR. */
|
|
1745
|
+
declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
|
|
1746
|
+
/** Compute element-wise logical XOR. */
|
|
1747
|
+
declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
|
|
1748
|
+
/** Compute element-wise logical NOT. */
|
|
1749
|
+
declare function logicalNot(x: ArrayLike): Array;
|
|
1715
1750
|
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1716
1751
|
declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
|
|
1717
1752
|
/**
|
|
@@ -1752,17 +1787,17 @@ declare const shape$1: (x: ArrayLike) => number[];
|
|
|
1752
1787
|
* @function
|
|
1753
1788
|
* Return an array of zeros with the same shape and type as a given array.
|
|
1754
1789
|
*/
|
|
1755
|
-
declare const zerosLike: (a: ArrayLike,
|
|
1790
|
+
declare const zerosLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
|
|
1756
1791
|
/**
|
|
1757
1792
|
* @function
|
|
1758
1793
|
* Return an array of ones with the same shape and type as a given array.
|
|
1759
1794
|
*/
|
|
1760
|
-
declare const onesLike: (a: ArrayLike,
|
|
1795
|
+
declare const onesLike: (a: ArrayLike, opts?: DTypeShapeAndDevice) => Array;
|
|
1761
1796
|
/**
|
|
1762
1797
|
* @function
|
|
1763
1798
|
* Return a full array with the same shape and type as a given array.
|
|
1764
1799
|
*/
|
|
1765
|
-
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array,
|
|
1800
|
+
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, opts?: DTypeShapeAndDevice) => Array;
|
|
1766
1801
|
/**
|
|
1767
1802
|
* Return the number of elements in an array, optionally along an axis.
|
|
1768
1803
|
* Does not consume array reference.
|
|
@@ -1796,6 +1831,16 @@ declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
|
1796
1831
|
declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1797
1832
|
/** Compute the average of the array elements along the specified axis. */
|
|
1798
1833
|
declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1834
|
+
/**
|
|
1835
|
+
* Compute the weighted average along the specified axis.
|
|
1836
|
+
*
|
|
1837
|
+
* If no axis is specified, mean is computed along all the axes. The weights
|
|
1838
|
+
* should have shape matching that of `a`, or if an axis is specified, it should
|
|
1839
|
+
* match the shape along those axes.
|
|
1840
|
+
*/
|
|
1841
|
+
declare function average(a: ArrayLike, axis?: Axis, opts?: {
|
|
1842
|
+
weights?: ArrayLike;
|
|
1843
|
+
} & ReduceOpts): Array;
|
|
1799
1844
|
/**
|
|
1800
1845
|
* Returns the indices of the minimum values along an axis.
|
|
1801
1846
|
*
|
|
@@ -1951,8 +1996,9 @@ declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: nu
|
|
|
1951
1996
|
*/
|
|
1952
1997
|
declare function sort(a: ArrayLike, axis?: number): Array;
|
|
1953
1998
|
/**
|
|
1954
|
-
* Return indices that would sort an array.
|
|
1955
|
-
* algorithm; it
|
|
1999
|
+
* Return indices that would sort an array. Unlike `sort`, this is guaranteed to
|
|
2000
|
+
* be a stable sorting algorithm; it always returns the smaller index first in
|
|
2001
|
+
* event of ties.
|
|
1956
2002
|
*
|
|
1957
2003
|
* Returns an array of `int32` indices.
|
|
1958
2004
|
*
|
|
@@ -1966,13 +2012,39 @@ declare function argsort(a: ArrayLike, axis?: number): Array;
|
|
|
1966
2012
|
* numbered axis. By default, the flattened array is used.
|
|
1967
2013
|
*/
|
|
1968
2014
|
declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
|
|
1969
|
-
/**
|
|
2015
|
+
/**
|
|
2016
|
+
* Return if two arrays are element-wise equal within a tolerance.
|
|
2017
|
+
*
|
|
2018
|
+
* The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
|
|
2019
|
+
* NaN values comparing equal if `equalNaN` is true.
|
|
2020
|
+
*/
|
|
1970
2021
|
declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
|
|
1971
2022
|
rtol?: number;
|
|
1972
2023
|
atol?: number;
|
|
2024
|
+
equalNaN?: boolean;
|
|
1973
2025
|
}): boolean;
|
|
2026
|
+
/**
|
|
2027
|
+
* Check if two arrays are element-wise equal.
|
|
2028
|
+
*
|
|
2029
|
+
* Returns False if the arrays have different shapes. If `equalNaN` is True,
|
|
2030
|
+
* NaNs in the same position are considered equal.
|
|
2031
|
+
*/
|
|
2032
|
+
declare function arrayEqual(a1: ArrayLike, a2: ArrayLike, opts?: {
|
|
2033
|
+
equalNaN?: boolean;
|
|
2034
|
+
}): Array;
|
|
2035
|
+
/**
|
|
2036
|
+
* Check if two arrays are element-wise equal after broadcasting.
|
|
2037
|
+
*
|
|
2038
|
+
* Unlike `arrayEqual`, this allows inputs with different but
|
|
2039
|
+
* broadcast-compatible shapes.
|
|
2040
|
+
*/
|
|
2041
|
+
declare function arrayEquiv(a1: ArrayLike, a2: ArrayLike): Array;
|
|
1974
2042
|
/** Matrix product of two arrays. */
|
|
1975
2043
|
declare function matmul(x: ArrayLike, y: ArrayLike): Array;
|
|
2044
|
+
/** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
|
|
2045
|
+
declare function matvec(x1: ArrayLike, x2: ArrayLike): Array;
|
|
2046
|
+
/** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
|
|
2047
|
+
declare function vecmat(x1: ArrayLike, x2: ArrayLike): Array;
|
|
1976
2048
|
/** Dot product of two arrays. */
|
|
1977
2049
|
declare function dot(x: ArrayLike, y: ArrayLike): Array;
|
|
1978
2050
|
/**
|
|
@@ -2025,6 +2097,18 @@ declare function inner(x: ArrayLike, y: ArrayLike): Array;
|
|
|
2025
2097
|
* be of shape `[x.size, y.size]`.
|
|
2026
2098
|
*/
|
|
2027
2099
|
declare function outer(x: ArrayLike, y: ArrayLike): Array;
|
|
2100
|
+
/**
|
|
2101
|
+
* @function Compute the cross product of two arrays.
|
|
2102
|
+
*
|
|
2103
|
+
* Supports 2D (scalar result) and 3D cross products, with optional axis
|
|
2104
|
+
* arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
|
|
2105
|
+
*/
|
|
2106
|
+
declare const cross: OwnedFunction<(a: ArrayLike, b: ArrayLike, args_2?: {
|
|
2107
|
+
axisa?: number | undefined;
|
|
2108
|
+
axisb?: number | undefined;
|
|
2109
|
+
axisc?: number | undefined;
|
|
2110
|
+
axis?: number | undefined;
|
|
2111
|
+
} | undefined) => Array>;
|
|
2028
2112
|
/** Vector dot product of two arrays along a given axis. */
|
|
2029
2113
|
declare function vecdot(x: ArrayLike, y: ArrayLike, {
|
|
2030
2114
|
axis
|
|
@@ -2070,14 +2154,13 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
|
|
|
2070
2154
|
declare function absolute(x: ArrayLike): Array;
|
|
2071
2155
|
/** Return an element-wise indication of sign of the input. */
|
|
2072
2156
|
declare function sign(x: ArrayLike): Array;
|
|
2073
|
-
/** @function Return element-wise positive values of the input (no-op). */
|
|
2074
|
-
declare const positive: (x: ArrayLike) => Array;
|
|
2075
2157
|
/**
|
|
2076
|
-
*
|
|
2077
|
-
*
|
|
2078
|
-
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
2158
|
+
* @function
|
|
2159
|
+
* Return the value with the magnitude of x and the sign of y, element-wise.
|
|
2079
2160
|
*/
|
|
2080
|
-
declare
|
|
2161
|
+
declare const copysign: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
2162
|
+
/** @function Return element-wise positive values of the input (no-op). */
|
|
2163
|
+
declare const positive: (x: ArrayLike) => Array;
|
|
2081
2164
|
/**
|
|
2082
2165
|
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
2083
2166
|
*
|
|
@@ -2172,6 +2255,18 @@ declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
|
2172
2255
|
declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
|
|
2173
2256
|
/** Round input to the nearest integer towards zero. */
|
|
2174
2257
|
declare function trunc(x: ArrayLike): Array;
|
|
2258
|
+
/**
|
|
2259
|
+
* @function
|
|
2260
|
+
* Round to the given number of decimals.
|
|
2261
|
+
*
|
|
2262
|
+
* Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
|
|
2263
|
+
*/
|
|
2264
|
+
declare const round: OwnedFunction<(a: ArrayLike, decimals?: number | undefined) => Array>;
|
|
2265
|
+
/**
|
|
2266
|
+
* @function
|
|
2267
|
+
* Round to the nearest integer, with ties going to the nearest even integer.
|
|
2268
|
+
*/
|
|
2269
|
+
declare const rint: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2175
2270
|
/**
|
|
2176
2271
|
* Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
|
|
2177
2272
|
*
|
|
@@ -2564,7 +2659,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
|
|
|
2564
2659
|
localWindowSize?: number | [number, number];
|
|
2565
2660
|
}): Array;
|
|
2566
2661
|
declare namespace random_d_exports {
|
|
2567
|
-
export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
2662
|
+
export { bernoulli, bits, categorical, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
2568
2663
|
}
|
|
2569
2664
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
2570
2665
|
declare function key(seed: ArrayLike): Array;
|
|
@@ -2587,6 +2682,32 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
|
|
|
2587
2682
|
* and must be broadcastable to `shape`.
|
|
2588
2683
|
*/
|
|
2589
2684
|
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
2685
|
+
/**
|
|
2686
|
+
* @function
|
|
2687
|
+
* Sample random values from categorical distributions.
|
|
2688
|
+
*
|
|
2689
|
+
* Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
|
|
2690
|
+
* trick for sampling without replacement.
|
|
2691
|
+
*
|
|
2692
|
+
* Note: Sampling without replacement currently uses argsort and slices the last
|
|
2693
|
+
* k elements. This should be replaced with a more efficient topK implementation.
|
|
2694
|
+
*
|
|
2695
|
+
* - `key` - PRNG key
|
|
2696
|
+
* - `logits` - Unnormalized log probabilities of the categorical distribution(s).
|
|
2697
|
+
* `softmax(logits, axis)` gives the corresponding probabilities.
|
|
2698
|
+
* - `axis` - Axis along which logits belong to the same categorical distribution.
|
|
2699
|
+
* - `shape` - Result batch shape. Must be broadcast-compatible with
|
|
2700
|
+
* `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
|
|
2701
|
+
* - `replace` - If true (default), sample with replacement. If false, sample
|
|
2702
|
+
* without replacement (each category can only be selected once per batch).
|
|
2703
|
+
* @returns A random array with int dtype and shape given by `shape` if provided,
|
|
2704
|
+
* otherwise `logits.shape` with `axis` removed.
|
|
2705
|
+
*/
|
|
2706
|
+
declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, args_2?: {
|
|
2707
|
+
axis?: number | undefined;
|
|
2708
|
+
shape?: number[] | undefined;
|
|
2709
|
+
replace?: boolean | undefined;
|
|
2710
|
+
} | undefined) => Array>;
|
|
2590
2711
|
/**
|
|
2591
2712
|
* @function
|
|
2592
2713
|
* Sample from a Cauchy distribution with location 0 and scale 1.
|
|
@@ -2648,8 +2769,31 @@ declare namespace scipy_special_d_exports {
|
|
|
2648
2769
|
* The logit function, `logit(p) = log(p / (1-p))`.
|
|
2649
2770
|
*/
|
|
2650
2771
|
declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2772
|
+
//#endregion
|
|
2773
|
+
//#region src/tracing.d.ts
|
|
2774
|
+
/**
|
|
2775
|
+
* Start collecting kernel traces.
|
|
2776
|
+
*
|
|
2777
|
+
* Traces appear in developer tools under the "Performance" tab, and they are
|
|
2778
|
+
* useful for measuring fine-grained kernel execution time.
|
|
2779
|
+
*/
|
|
2780
|
+
declare function startTrace(): void;
|
|
2781
|
+
/**
|
|
2782
|
+
* Stop collecting kernel traces.
|
|
2783
|
+
*
|
|
2784
|
+
* Traces appear in developer tools under the "Performance" tab, and they are
|
|
2785
|
+
* useful for measuring fine-grained kernel execution time.
|
|
2786
|
+
*/
|
|
2787
|
+
declare function stopTrace(): void;
|
|
2788
|
+
/** Check if tracing is currently enabled. */
|
|
2789
|
+
|
|
2651
2790
|
//#endregion
|
|
2652
2791
|
//#region src/index.d.ts
|
|
2792
|
+
/** @namespace */
|
|
2793
|
+
declare const profiler: {
|
|
2794
|
+
startTrace: typeof startTrace;
|
|
2795
|
+
stopTrace: typeof stopTrace;
|
|
2796
|
+
};
|
|
2653
2797
|
/**
|
|
2654
2798
|
* @function
|
|
2655
2799
|
* Compute the forward-mode Jacobian-vector product for a function.
|
|
@@ -2814,4 +2958,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
|
2814
2958
|
*/
|
|
2815
2959
|
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
2816
2960
|
//#endregion
|
|
2817
|
-
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
2961
|
+
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|