@jax-js/jax 0.1.9 → 0.1.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +35 -19
- package/dist/{backend-BId79r5b.js → backend-DZvR7mZV.js} +831 -26
- package/dist/{backend-DpI0riom.cjs → backend-DlYlOYqN.cjs} +872 -25
- package/dist/index.cjs +364 -20
- package/dist/index.d.cts +175 -11
- package/dist/index.d.ts +175 -11
- package/dist/index.js +363 -21
- package/dist/{webgl-DnGrclTz.js → webgl-D8-14NzA.js} +7 -1
- package/dist/{webgl-C5NjXc1p.cjs → webgl-Ovaaa-Qx.cjs} +7 -1
- package/dist/{webgpu-AN0cG_nB.js → webgpu-Dg8FpYrH.js} +141 -6
- package/dist/{webgpu-CdjiJSa7.cjs → webgpu-uU9nnttc.cjs} +141 -6
- package/package.json +5 -16
package/dist/index.d.cts
CHANGED
|
@@ -232,6 +232,8 @@ declare class AluExp implements FpHashable {
|
|
|
232
232
|
static cast(dtype: DType, a: AluExp): AluExp;
|
|
233
233
|
static bitcast(dtype: DType, a: AluExp): AluExp;
|
|
234
234
|
static threefry2x32(k0: AluExp, k1: AluExp, c0: AluExp, c1: AluExp, mode?: "xor" | 0 | 1): AluExp;
|
|
235
|
+
static bitCombine(a: AluExp, b: AluExp, mode: "and" | "or" | "xor"): AluExp;
|
|
236
|
+
static bitShift(a: AluExp, b: AluExp, mode: "shl" | "shr"): AluExp;
|
|
235
237
|
static cmplt(a: AluExp, b: AluExp): AluExp;
|
|
236
238
|
static cmpne(a: AluExp, b: AluExp): AluExp;
|
|
237
239
|
static where(cond: AluExp, a: AluExp, b: AluExp): AluExp;
|
|
@@ -323,6 +325,11 @@ declare enum AluOp {
|
|
|
323
325
|
Reciprocal = "Reciprocal",
|
|
324
326
|
Cast = "Cast",
|
|
325
327
|
Bitcast = "Bitcast",
|
|
328
|
+
BitCombine = "BitCombine",
|
|
329
|
+
// arg = 'or' | 'and' | 'xor'
|
|
330
|
+
BitInvert = "BitInvert",
|
|
331
|
+
BitShift = "BitShift",
|
|
332
|
+
// arg = 'shl' | 'shr'
|
|
326
333
|
Cmplt = "Cmplt",
|
|
327
334
|
Cmpne = "Cmpne",
|
|
328
335
|
Where = "Where",
|
|
@@ -546,6 +553,11 @@ declare class Executable<T = any> {
|
|
|
546
553
|
source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
|
|
547
554
|
data: T);
|
|
548
555
|
}
|
|
556
|
+
/**
|
|
557
|
+
* If the WebGPU backend has been initialized, return the `GPUDevice` that this
|
|
558
|
+
* backend runs on. This is useful for sharing buffers.
|
|
559
|
+
*/
|
|
560
|
+
declare function getWebGPUDevice(): GPUDevice;
|
|
549
561
|
declare namespace tree_d_exports {
|
|
550
562
|
export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
|
|
551
563
|
}
|
|
@@ -719,6 +731,8 @@ declare enum Primitive {
|
|
|
719
731
|
// uses sign of numerator, C-style, matches JS but not Python
|
|
720
732
|
Min = "min",
|
|
721
733
|
Max = "max",
|
|
734
|
+
BitCombine = "bit_combine",
|
|
735
|
+
BitShift = "bit_shift",
|
|
722
736
|
Neg = "neg",
|
|
723
737
|
Reciprocal = "reciprocal",
|
|
724
738
|
Floor = "floor",
|
|
@@ -767,6 +781,12 @@ declare enum Primitive {
|
|
|
767
781
|
Jit = "jit",
|
|
768
782
|
}
|
|
769
783
|
interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
784
|
+
[Primitive.BitCombine]: {
|
|
785
|
+
op: "and" | "or" | "xor";
|
|
786
|
+
};
|
|
787
|
+
[Primitive.BitShift]: {
|
|
788
|
+
op: "shl" | "shr";
|
|
789
|
+
};
|
|
770
790
|
[Primitive.Cast]: {
|
|
771
791
|
dtype: DType;
|
|
772
792
|
};
|
|
@@ -1004,6 +1024,8 @@ declare abstract class Tracer {
|
|
|
1004
1024
|
reshape(shape: number | number[]): this;
|
|
1005
1025
|
/** Copy the array and cast to a specified dtype. */
|
|
1006
1026
|
astype(dtype: DType): this;
|
|
1027
|
+
/** Return a bitwise cast of the array, viewed as a new dtype. */
|
|
1028
|
+
view(dtype?: DType): this;
|
|
1007
1029
|
/** Subtract an array from this one. */
|
|
1008
1030
|
sub(other: this | TracerValue): this;
|
|
1009
1031
|
/** Divide an array by this one. */
|
|
@@ -1192,6 +1214,19 @@ declare class Array extends Tracer {
|
|
|
1192
1214
|
* recommended for performance reasons, as it will block rendering.
|
|
1193
1215
|
*/
|
|
1194
1216
|
dataSync(): DataArray;
|
|
1217
|
+
/**
|
|
1218
|
+
* Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
|
|
1219
|
+
*
|
|
1220
|
+
* Only available on the WebGPU backend. The array's memory is still managed
|
|
1221
|
+
* by jax-js, and it will be freed when the buffer is no longer in use. You
|
|
1222
|
+
* _should not_ mutate the buffer's contents.
|
|
1223
|
+
*
|
|
1224
|
+
* Note that the GPU buffer may be slightly larger than the array's size; it
|
|
1225
|
+
* will always be aligned to 4 bytes.
|
|
1226
|
+
*/
|
|
1227
|
+
gpuBuffer(): Promise<GPUBuffer>;
|
|
1228
|
+
/** Synchronous version of `Array.gpuBuffer()`. */
|
|
1229
|
+
gpuBufferSync(): GPUBuffer;
|
|
1195
1230
|
/**
|
|
1196
1231
|
* Convert this array into a JavaScript object.
|
|
1197
1232
|
*
|
|
@@ -1427,8 +1462,10 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
|
|
|
1427
1462
|
unitDiagonal?: boolean;
|
|
1428
1463
|
}): Array;
|
|
1429
1464
|
declare namespace lax_d_exports {
|
|
1430
|
-
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
|
|
1465
|
+
export { DotDimensionNumbers, PaddingType, bitcastConvertType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
|
|
1431
1466
|
}
|
|
1467
|
+
/** Elementwise bitcast an array into a new dtype. */
|
|
1468
|
+
declare function bitcastConvertType(x: ArrayLike, newDtype: DType): Array;
|
|
1432
1469
|
/**
|
|
1433
1470
|
* Dimension numbers for general `dot()` primitive.
|
|
1434
1471
|
*
|
|
@@ -1567,7 +1604,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
|
|
|
1567
1604
|
*/
|
|
1568
1605
|
declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
|
|
1569
1606
|
declare namespace numpy_linalg_d_exports {
|
|
1570
|
-
export { cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
|
|
1607
|
+
export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot, vectorNorm };
|
|
1571
1608
|
}
|
|
1572
1609
|
/**
|
|
1573
1610
|
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
@@ -1582,6 +1619,13 @@ declare function cholesky(a: ArrayLike, {
|
|
|
1582
1619
|
upper?: boolean;
|
|
1583
1620
|
symmetrizeInput?: boolean;
|
|
1584
1621
|
}): Array;
|
|
1622
|
+
/**
|
|
1623
|
+
* Compute the cross-product of two 3D vectors.
|
|
1624
|
+
*
|
|
1625
|
+
* This is a simpler and less flexible version of `jax.numpy.cross()`.
|
|
1626
|
+
* Both inputs must have size 3 along the specified axis.
|
|
1627
|
+
*/
|
|
1628
|
+
declare function cross$1(x1: ArrayLike, x2: ArrayLike, axis?: number): Array;
|
|
1585
1629
|
/** Compute the determinant of a square matrix (batched). */
|
|
1586
1630
|
declare function det(a: ArrayLike): Array;
|
|
1587
1631
|
/** Compute the inverse of a square matrix (batched). */
|
|
@@ -1615,6 +1659,24 @@ declare function slogdet(a: ArrayLike): [Array, Array];
|
|
|
1615
1659
|
* @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
|
|
1616
1660
|
*/
|
|
1617
1661
|
declare function solve(a: ArrayLike, b: ArrayLike): Array;
|
|
1662
|
+
/**
|
|
1663
|
+
* Compute the vector norm of an array.
|
|
1664
|
+
*
|
|
1665
|
+
* @param x - Input array.
|
|
1666
|
+
* @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
|
|
1667
|
+
* @param axis - Axis/axes to reduce over (default: all axes).
|
|
1668
|
+
* @param keepdims - Whether to keep reduced dimensions as size 1.
|
|
1669
|
+
* @returns The norm of `x`, reduced over the given axes.
|
|
1670
|
+
*/
|
|
1671
|
+
declare function vectorNorm(x: ArrayLike, {
|
|
1672
|
+
ord,
|
|
1673
|
+
axis,
|
|
1674
|
+
keepdims
|
|
1675
|
+
}?: {
|
|
1676
|
+
ord?: number;
|
|
1677
|
+
axis?: number | number[] | null;
|
|
1678
|
+
keepdims?: boolean;
|
|
1679
|
+
}): Array;
|
|
1618
1680
|
//#endregion
|
|
1619
1681
|
//#region src/library/numpy/dtype-info.d.ts
|
|
1620
1682
|
/** @inline */
|
|
@@ -1668,7 +1730,7 @@ type IInfo = Readonly<{
|
|
|
1668
1730
|
/** Machine limits for integer types. */
|
|
1669
1731
|
declare function iinfo(dtype: DType): IInfo;
|
|
1670
1732
|
declare namespace numpy_d_exports {
|
|
1671
|
-
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual,
|
|
1733
|
+
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bitwiseAnd, invert as bitwiseInvert, leftShift as bitwiseLeftShift, invert as bitwiseNot, bitwiseOr, rightShift as bitwiseRightShift, bitwiseXor, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, invert, isfinite, isinf, isnan, isneginf, isposinf, ldexp, leftShift, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rightShift, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
|
|
1672
1734
|
}
|
|
1673
1735
|
declare const float32 = DType.Float32;
|
|
1674
1736
|
declare const int32 = DType.Int32;
|
|
@@ -1728,6 +1790,26 @@ declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
|
1728
1790
|
declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1729
1791
|
/** @function Compare two arrays element-wise. */
|
|
1730
1792
|
declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1793
|
+
/** Compute element-wise logical AND. */
|
|
1794
|
+
declare function logicalAnd(x: ArrayLike, y: ArrayLike): Array;
|
|
1795
|
+
/** Compute element-wise logical OR. */
|
|
1796
|
+
declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
|
|
1797
|
+
/** Compute element-wise logical XOR. */
|
|
1798
|
+
declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
|
|
1799
|
+
/** Compute element-wise logical NOT. */
|
|
1800
|
+
declare function logicalNot(x: ArrayLike): Array;
|
|
1801
|
+
/** Compute element-wise bitwise AND. */
|
|
1802
|
+
declare function bitwiseAnd(x: ArrayLike, y: ArrayLike): Array;
|
|
1803
|
+
/** Compute element-wise bitwise OR. */
|
|
1804
|
+
declare function bitwiseOr(x: ArrayLike, y: ArrayLike): Array;
|
|
1805
|
+
/** Compute element-wise bitwise XOR. */
|
|
1806
|
+
declare function bitwiseXor(x: ArrayLike, y: ArrayLike): Array;
|
|
1807
|
+
/** Compute element-wise bitwise NOT (inversion). */
|
|
1808
|
+
declare function invert(x: ArrayLike): Array;
|
|
1809
|
+
/** Compute element-wise left bit shift. */
|
|
1810
|
+
declare function leftShift(x: ArrayLike, y: ArrayLike): Array;
|
|
1811
|
+
/** Compute element-wise right bit shift. */
|
|
1812
|
+
declare function rightShift(x: ArrayLike, y: ArrayLike): Array;
|
|
1731
1813
|
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1732
1814
|
declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
|
|
1733
1815
|
/**
|
|
@@ -1812,6 +1894,16 @@ declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
|
1812
1894
|
declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1813
1895
|
/** Compute the average of the array elements along the specified axis. */
|
|
1814
1896
|
declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1897
|
+
/**
|
|
1898
|
+
* Compute the weighted average along the specified axis.
|
|
1899
|
+
*
|
|
1900
|
+
* If no axis is specified, the average is computed over all axes. The weights
|
|
1901
|
+
* should have shape matching that of `a`, or if an axis is specified, it should
|
|
1902
|
+
* match the shape along those axes.
|
|
1903
|
+
*/
|
|
1904
|
+
declare function average(a: ArrayLike, axis?: Axis, opts?: {
|
|
1905
|
+
weights?: ArrayLike;
|
|
1906
|
+
} & ReduceOpts): Array;
|
|
1815
1907
|
/**
|
|
1816
1908
|
* Returns the indices of the minimum values along an axis.
|
|
1817
1909
|
*
|
|
@@ -1983,13 +2075,39 @@ declare function argsort(a: ArrayLike, axis?: number): Array;
|
|
|
1983
2075
|
* numbered axis. By default, the flattened array is used.
|
|
1984
2076
|
*/
|
|
1985
2077
|
declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
|
|
1986
|
-
/**
|
|
2078
|
+
/**
|
|
2079
|
+
* Return whether two arrays are element-wise equal within a tolerance.
|
|
2080
|
+
*
|
|
2081
|
+
* The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
|
|
2082
|
+
* NaN values comparing equal if `equalNaN` is true.
|
|
2083
|
+
*/
|
|
1987
2084
|
declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
|
|
1988
2085
|
rtol?: number;
|
|
1989
2086
|
atol?: number;
|
|
2087
|
+
equalNaN?: boolean;
|
|
1990
2088
|
}): boolean;
|
|
2089
|
+
/**
|
|
2090
|
+
* Check if two arrays are element-wise equal.
|
|
2091
|
+
*
|
|
2092
|
+
* Returns false if the arrays have different shapes. If `equalNaN` is true,
|
|
2093
|
+
* NaNs in the same position are considered equal.
|
|
2094
|
+
*/
|
|
2095
|
+
declare function arrayEqual(a1: ArrayLike, a2: ArrayLike, opts?: {
|
|
2096
|
+
equalNaN?: boolean;
|
|
2097
|
+
}): Array;
|
|
2098
|
+
/**
|
|
2099
|
+
* Check if two arrays are element-wise equal after broadcasting.
|
|
2100
|
+
*
|
|
2101
|
+
* Unlike `arrayEqual`, this allows inputs with different but
|
|
2102
|
+
* broadcast-compatible shapes.
|
|
2103
|
+
*/
|
|
2104
|
+
declare function arrayEquiv(a1: ArrayLike, a2: ArrayLike): Array;
|
|
1991
2105
|
/** Matrix product of two arrays. */
|
|
1992
2106
|
declare function matmul(x: ArrayLike, y: ArrayLike): Array;
|
|
2107
|
+
/** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
|
|
2108
|
+
declare function matvec(x1: ArrayLike, x2: ArrayLike): Array;
|
|
2109
|
+
/** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
|
|
2110
|
+
declare function vecmat(x1: ArrayLike, x2: ArrayLike): Array;
|
|
1993
2111
|
/** Dot product of two arrays. */
|
|
1994
2112
|
declare function dot(x: ArrayLike, y: ArrayLike): Array;
|
|
1995
2113
|
/**
|
|
@@ -2042,6 +2160,18 @@ declare function inner(x: ArrayLike, y: ArrayLike): Array;
|
|
|
2042
2160
|
* be of shape `[x.size, y.size]`.
|
|
2043
2161
|
*/
|
|
2044
2162
|
declare function outer(x: ArrayLike, y: ArrayLike): Array;
|
|
2163
|
+
/**
|
|
2164
|
+
* @function Compute the cross product of two arrays.
|
|
2165
|
+
*
|
|
2166
|
+
* Supports 2D (scalar result) and 3D cross products, with optional axis
|
|
2167
|
+
* arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
|
|
2168
|
+
*/
|
|
2169
|
+
declare const cross: OwnedFunction<(a: ArrayLike, b: ArrayLike, args_2?: {
|
|
2170
|
+
axisa?: number | undefined;
|
|
2171
|
+
axisb?: number | undefined;
|
|
2172
|
+
axisc?: number | undefined;
|
|
2173
|
+
axis?: number | undefined;
|
|
2174
|
+
} | undefined) => Array>;
|
|
2045
2175
|
/** Vector dot product of two arrays along a given axis. */
|
|
2046
2176
|
declare function vecdot(x: ArrayLike, y: ArrayLike, {
|
|
2047
2177
|
axis
|
|
@@ -2087,14 +2217,13 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
|
|
|
2087
2217
|
declare function absolute(x: ArrayLike): Array;
|
|
2088
2218
|
/** Return an element-wise indication of sign of the input. */
|
|
2089
2219
|
declare function sign(x: ArrayLike): Array;
|
|
2090
|
-
/** @function Return element-wise positive values of the input (no-op). */
|
|
2091
|
-
declare const positive: (x: ArrayLike) => Array;
|
|
2092
2220
|
/**
|
|
2093
|
-
*
|
|
2094
|
-
*
|
|
2095
|
-
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
2221
|
+
* @function
|
|
2222
|
+
* Return the value with the magnitude of x and the sign of y, element-wise.
|
|
2096
2223
|
*/
|
|
2097
|
-
declare
|
|
2224
|
+
declare const copysign: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
2225
|
+
/** @function Return element-wise positive values of the input (no-op). */
|
|
2226
|
+
declare const positive: (x: ArrayLike) => Array;
|
|
2098
2227
|
/**
|
|
2099
2228
|
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
2100
2229
|
*
|
|
@@ -2189,6 +2318,18 @@ declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
|
2189
2318
|
declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
|
|
2190
2319
|
/** Round input to the nearest integer towards zero. */
|
|
2191
2320
|
declare function trunc(x: ArrayLike): Array;
|
|
2321
|
+
/**
|
|
2322
|
+
* @function
|
|
2323
|
+
* Round to the given number of decimals.
|
|
2324
|
+
*
|
|
2325
|
+
* Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
|
|
2326
|
+
*/
|
|
2327
|
+
declare const round: OwnedFunction<(a: ArrayLike, decimals?: number | undefined) => Array>;
|
|
2328
|
+
/**
|
|
2329
|
+
* @function
|
|
2330
|
+
* Round to the nearest integer, with ties going to the nearest even integer.
|
|
2331
|
+
*/
|
|
2332
|
+
declare const rint: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2192
2333
|
/**
|
|
2193
2334
|
* Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
|
|
2194
2335
|
*
|
|
@@ -2691,8 +2832,31 @@ declare namespace scipy_special_d_exports {
|
|
|
2691
2832
|
* The logit function, `logit(p) = log(p / (1-p))`.
|
|
2692
2833
|
*/
|
|
2693
2834
|
declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2835
|
+
//#endregion
|
|
2836
|
+
//#region src/tracing.d.ts
|
|
2837
|
+
/**
|
|
2838
|
+
* Start collecting kernel traces.
|
|
2839
|
+
*
|
|
2840
|
+
* Traces appear in developer tools under the "Performance" tab, and they are
|
|
2841
|
+
* useful for measuring fine-grained kernel execution time.
|
|
2842
|
+
*/
|
|
2843
|
+
declare function startTrace(): void;
|
|
2844
|
+
/**
|
|
2845
|
+
* Stop collecting kernel traces.
|
|
2846
|
+
*
|
|
2847
|
+
* Traces appear in developer tools under the "Performance" tab, and they are
|
|
2848
|
+
* useful for measuring fine-grained kernel execution time.
|
|
2849
|
+
*/
|
|
2850
|
+
declare function stopTrace(): void;
|
|
2851
|
+
/** Check if tracing is currently enabled. */
|
|
2852
|
+
|
|
2694
2853
|
//#endregion
|
|
2695
2854
|
//#region src/index.d.ts
|
|
2855
|
+
/** @namespace */
|
|
2856
|
+
declare const profiler: {
|
|
2857
|
+
startTrace: typeof startTrace;
|
|
2858
|
+
stopTrace: typeof stopTrace;
|
|
2859
|
+
};
|
|
2696
2860
|
/**
|
|
2697
2861
|
* @function
|
|
2698
2862
|
* Compute the forward-mode Jacobian-vector product for a function.
|
|
@@ -2857,4 +3021,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
|
2857
3021
|
*/
|
|
2858
3022
|
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
2859
3023
|
//#endregion
|
|
2860
|
-
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
3024
|
+
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, getWebGPUDevice, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
package/dist/index.d.ts
CHANGED
|
@@ -229,6 +229,8 @@ declare class AluExp implements FpHashable {
|
|
|
229
229
|
static cast(dtype: DType, a: AluExp): AluExp;
|
|
230
230
|
static bitcast(dtype: DType, a: AluExp): AluExp;
|
|
231
231
|
static threefry2x32(k0: AluExp, k1: AluExp, c0: AluExp, c1: AluExp, mode?: "xor" | 0 | 1): AluExp;
|
|
232
|
+
static bitCombine(a: AluExp, b: AluExp, mode: "and" | "or" | "xor"): AluExp;
|
|
233
|
+
static bitShift(a: AluExp, b: AluExp, mode: "shl" | "shr"): AluExp;
|
|
232
234
|
static cmplt(a: AluExp, b: AluExp): AluExp;
|
|
233
235
|
static cmpne(a: AluExp, b: AluExp): AluExp;
|
|
234
236
|
static where(cond: AluExp, a: AluExp, b: AluExp): AluExp;
|
|
@@ -320,6 +322,11 @@ declare enum AluOp {
|
|
|
320
322
|
Reciprocal = "Reciprocal",
|
|
321
323
|
Cast = "Cast",
|
|
322
324
|
Bitcast = "Bitcast",
|
|
325
|
+
BitCombine = "BitCombine",
|
|
326
|
+
// arg = 'or' | 'and' | 'xor'
|
|
327
|
+
BitInvert = "BitInvert",
|
|
328
|
+
BitShift = "BitShift",
|
|
329
|
+
// arg = 'shl' | 'shr'
|
|
323
330
|
Cmplt = "Cmplt",
|
|
324
331
|
Cmpne = "Cmpne",
|
|
325
332
|
Where = "Where",
|
|
@@ -543,6 +550,11 @@ declare class Executable<T = any> {
|
|
|
543
550
|
source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
|
|
544
551
|
data: T);
|
|
545
552
|
}
|
|
553
|
+
/**
|
|
554
|
+
* If the WebGPU backend has been initialized, return the `GPUDevice` that this
|
|
555
|
+
* backend runs on. This is useful for sharing buffers.
|
|
556
|
+
*/
|
|
557
|
+
declare function getWebGPUDevice(): GPUDevice;
|
|
546
558
|
declare namespace tree_d_exports {
|
|
547
559
|
export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
|
|
548
560
|
}
|
|
@@ -716,6 +728,8 @@ declare enum Primitive {
|
|
|
716
728
|
// uses sign of numerator, C-style, matches JS but not Python
|
|
717
729
|
Min = "min",
|
|
718
730
|
Max = "max",
|
|
731
|
+
BitCombine = "bit_combine",
|
|
732
|
+
BitShift = "bit_shift",
|
|
719
733
|
Neg = "neg",
|
|
720
734
|
Reciprocal = "reciprocal",
|
|
721
735
|
Floor = "floor",
|
|
@@ -764,6 +778,12 @@ declare enum Primitive {
|
|
|
764
778
|
Jit = "jit",
|
|
765
779
|
}
|
|
766
780
|
interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
781
|
+
[Primitive.BitCombine]: {
|
|
782
|
+
op: "and" | "or" | "xor";
|
|
783
|
+
};
|
|
784
|
+
[Primitive.BitShift]: {
|
|
785
|
+
op: "shl" | "shr";
|
|
786
|
+
};
|
|
767
787
|
[Primitive.Cast]: {
|
|
768
788
|
dtype: DType;
|
|
769
789
|
};
|
|
@@ -1001,6 +1021,8 @@ declare abstract class Tracer {
|
|
|
1001
1021
|
reshape(shape: number | number[]): this;
|
|
1002
1022
|
/** Copy the array and cast to a specified dtype. */
|
|
1003
1023
|
astype(dtype: DType): this;
|
|
1024
|
+
/** Return a bitwise cast of the array, viewed as a new dtype. */
|
|
1025
|
+
view(dtype?: DType): this;
|
|
1004
1026
|
/** Subtract an array from this one. */
|
|
1005
1027
|
sub(other: this | TracerValue): this;
|
|
1006
1028
|
/** Divide an array by this one. */
|
|
@@ -1189,6 +1211,19 @@ declare class Array extends Tracer {
|
|
|
1189
1211
|
* recommended for performance reasons, as it will block rendering.
|
|
1190
1212
|
*/
|
|
1191
1213
|
dataSync(): DataArray;
|
|
1214
|
+
/**
|
|
1215
|
+
* Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
|
|
1216
|
+
*
|
|
1217
|
+
* Only available on the WebGPU backend. The array's memory is still managed
|
|
1218
|
+
* by jax-js, and it will be freed when the buffer is no longer in use. You
|
|
1219
|
+
* _should not_ mutate the buffer's contents.
|
|
1220
|
+
*
|
|
1221
|
+
* Note that the GPU buffer may be slightly larger than the array's size; it
|
|
1222
|
+
* will always be aligned to 4 bytes.
|
|
1223
|
+
*/
|
|
1224
|
+
gpuBuffer(): Promise<GPUBuffer>;
|
|
1225
|
+
/** Synchronous version of `Array.gpuBuffer()`. */
|
|
1226
|
+
gpuBufferSync(): GPUBuffer;
|
|
1192
1227
|
/**
|
|
1193
1228
|
* Convert this array into a JavaScript object.
|
|
1194
1229
|
*
|
|
@@ -1424,8 +1459,10 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
|
|
|
1424
1459
|
unitDiagonal?: boolean;
|
|
1425
1460
|
}): Array;
|
|
1426
1461
|
declare namespace lax_d_exports {
|
|
1427
|
-
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
|
|
1462
|
+
export { DotDimensionNumbers, PaddingType, bitcastConvertType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
|
|
1428
1463
|
}
|
|
1464
|
+
/** Elementwise bitcast an array into a new dtype. */
|
|
1465
|
+
declare function bitcastConvertType(x: ArrayLike, newDtype: DType): Array;
|
|
1429
1466
|
/**
|
|
1430
1467
|
* Dimension numbers for general `dot()` primitive.
|
|
1431
1468
|
*
|
|
@@ -1564,7 +1601,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
|
|
|
1564
1601
|
*/
|
|
1565
1602
|
declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
|
|
1566
1603
|
declare namespace numpy_linalg_d_exports {
|
|
1567
|
-
export { cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
|
|
1604
|
+
export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot, vectorNorm };
|
|
1568
1605
|
}
|
|
1569
1606
|
/**
|
|
1570
1607
|
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
@@ -1579,6 +1616,13 @@ declare function cholesky(a: ArrayLike, {
|
|
|
1579
1616
|
upper?: boolean;
|
|
1580
1617
|
symmetrizeInput?: boolean;
|
|
1581
1618
|
}): Array;
|
|
1619
|
+
/**
|
|
1620
|
+
* Compute the cross-product of two 3D vectors.
|
|
1621
|
+
*
|
|
1622
|
+
* This is a simpler and less flexible version of `jax.numpy.cross()`.
|
|
1623
|
+
* Both inputs must have size 3 along the specified axis.
|
|
1624
|
+
*/
|
|
1625
|
+
declare function cross$1(x1: ArrayLike, x2: ArrayLike, axis?: number): Array;
|
|
1582
1626
|
/** Compute the determinant of a square matrix (batched). */
|
|
1583
1627
|
declare function det(a: ArrayLike): Array;
|
|
1584
1628
|
/** Compute the inverse of a square matrix (batched). */
|
|
@@ -1612,6 +1656,24 @@ declare function slogdet(a: ArrayLike): [Array, Array];
|
|
|
1612
1656
|
* @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
|
|
1613
1657
|
*/
|
|
1614
1658
|
declare function solve(a: ArrayLike, b: ArrayLike): Array;
|
|
1659
|
+
/**
|
|
1660
|
+
* Compute the vector norm of an array.
|
|
1661
|
+
*
|
|
1662
|
+
* @param x - Input array.
|
|
1663
|
+
* @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
|
|
1664
|
+
* @param axis - Axis/axes to reduce over (default: all axes).
|
|
1665
|
+
* @param keepdims - Whether to keep reduced dimensions as size 1.
|
|
1666
|
+
* @returns The norm of `x`, reduced over the given axes.
|
|
1667
|
+
*/
|
|
1668
|
+
declare function vectorNorm(x: ArrayLike, {
|
|
1669
|
+
ord,
|
|
1670
|
+
axis,
|
|
1671
|
+
keepdims
|
|
1672
|
+
}?: {
|
|
1673
|
+
ord?: number;
|
|
1674
|
+
axis?: number | number[] | null;
|
|
1675
|
+
keepdims?: boolean;
|
|
1676
|
+
}): Array;
|
|
1615
1677
|
//#endregion
|
|
1616
1678
|
//#region src/library/numpy/dtype-info.d.ts
|
|
1617
1679
|
/** @inline */
|
|
@@ -1665,7 +1727,7 @@ type IInfo = Readonly<{
|
|
|
1665
1727
|
/** Machine limits for integer types. */
|
|
1666
1728
|
declare function iinfo(dtype: DType): IInfo;
|
|
1667
1729
|
declare namespace numpy_d_exports {
|
|
1668
|
-
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual,
|
|
1730
|
+
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bitwiseAnd, invert as bitwiseInvert, leftShift as bitwiseLeftShift, invert as bitwiseNot, bitwiseOr, rightShift as bitwiseRightShift, bitwiseXor, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, invert, isfinite, isinf, isnan, isneginf, isposinf, ldexp, leftShift, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rightShift, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
|
|
1669
1731
|
}
|
|
1670
1732
|
declare const float32 = DType.Float32;
|
|
1671
1733
|
declare const int32 = DType.Int32;
|
|
@@ -1725,6 +1787,26 @@ declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
|
1725
1787
|
declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1726
1788
|
/** @function Compare two arrays element-wise. */
|
|
1727
1789
|
declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1790
|
+
/** Compute element-wise logical AND. */
|
|
1791
|
+
declare function logicalAnd(x: ArrayLike, y: ArrayLike): Array;
|
|
1792
|
+
/** Compute element-wise logical OR. */
|
|
1793
|
+
declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
|
|
1794
|
+
/** Compute element-wise logical XOR. */
|
|
1795
|
+
declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
|
|
1796
|
+
/** Compute element-wise logical NOT. */
|
|
1797
|
+
declare function logicalNot(x: ArrayLike): Array;
|
|
1798
|
+
/** Compute element-wise bitwise AND. */
|
|
1799
|
+
declare function bitwiseAnd(x: ArrayLike, y: ArrayLike): Array;
|
|
1800
|
+
/** Compute element-wise bitwise OR. */
|
|
1801
|
+
declare function bitwiseOr(x: ArrayLike, y: ArrayLike): Array;
|
|
1802
|
+
/** Compute element-wise bitwise XOR. */
|
|
1803
|
+
declare function bitwiseXor(x: ArrayLike, y: ArrayLike): Array;
|
|
1804
|
+
/** Compute element-wise bitwise NOT (inversion). */
|
|
1805
|
+
declare function invert(x: ArrayLike): Array;
|
|
1806
|
+
/** Compute element-wise left bit shift. */
|
|
1807
|
+
declare function leftShift(x: ArrayLike, y: ArrayLike): Array;
|
|
1808
|
+
/** Compute element-wise right bit shift. */
|
|
1809
|
+
declare function rightShift(x: ArrayLike, y: ArrayLike): Array;
|
|
1728
1810
|
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1729
1811
|
declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
|
|
1730
1812
|
/**
|
|
@@ -1809,6 +1891,16 @@ declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
|
1809
1891
|
declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1810
1892
|
/** Compute the average of the array elements along the specified axis. */
|
|
1811
1893
|
declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1894
|
+
/**
|
|
1895
|
+
* Compute the weighted average along the specified axis.
|
|
1896
|
+
*
|
|
1897
|
+
* If no axis is specified, mean is computed along all the axes. The weights
|
|
1898
|
+
* should have shape matching that of `a`, or if an axis is specified, it should
|
|
1899
|
+
* match the shape along those axes.
|
|
1900
|
+
*/
|
|
1901
|
+
declare function average(a: ArrayLike, axis?: Axis, opts?: {
|
|
1902
|
+
weights?: ArrayLike;
|
|
1903
|
+
} & ReduceOpts): Array;
|
|
1812
1904
|
/**
|
|
1813
1905
|
* Returns the indices of the minimum values along an axis.
|
|
1814
1906
|
*
|
|
@@ -1980,13 +2072,39 @@ declare function argsort(a: ArrayLike, axis?: number): Array;
|
|
|
1980
2072
|
* numbered axis. By default, the flattened array is used.
|
|
1981
2073
|
*/
|
|
1982
2074
|
declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
|
|
1983
|
-
/**
|
|
2075
|
+
/**
|
|
2076
|
+
* Return if two arrays are element-wise equal within a tolerance.
|
|
2077
|
+
*
|
|
2078
|
+
* The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
|
|
2079
|
+
* NaN values comparing equal if `equalNaN` is true.
|
|
2080
|
+
*/
|
|
1984
2081
|
declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
|
|
1985
2082
|
rtol?: number;
|
|
1986
2083
|
atol?: number;
|
|
2084
|
+
equalNaN?: boolean;
|
|
1987
2085
|
}): boolean;
|
|
2086
|
+
/**
|
|
2087
|
+
* Check if two arrays are element-wise equal.
|
|
2088
|
+
*
|
|
2089
|
+
* Returns False if the arrays have different shapes. If `equalNaN` is True,
|
|
2090
|
+
* NaNs in the same position are considered equal.
|
|
2091
|
+
*/
|
|
2092
|
+
declare function arrayEqual(a1: ArrayLike, a2: ArrayLike, opts?: {
|
|
2093
|
+
equalNaN?: boolean;
|
|
2094
|
+
}): Array;
|
|
2095
|
+
/**
|
|
2096
|
+
* Check if two arrays are element-wise equal after broadcasting.
|
|
2097
|
+
*
|
|
2098
|
+
* Unlike `arrayEqual`, this allows inputs with different but
|
|
2099
|
+
* broadcast-compatible shapes.
|
|
2100
|
+
*/
|
|
2101
|
+
declare function arrayEquiv(a1: ArrayLike, a2: ArrayLike): Array;
|
|
1988
2102
|
/** Matrix product of two arrays. */
|
|
1989
2103
|
declare function matmul(x: ArrayLike, y: ArrayLike): Array;
|
|
2104
|
+
/** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
|
|
2105
|
+
declare function matvec(x1: ArrayLike, x2: ArrayLike): Array;
|
|
2106
|
+
/** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
|
|
2107
|
+
declare function vecmat(x1: ArrayLike, x2: ArrayLike): Array;
|
|
1990
2108
|
/** Dot product of two arrays. */
|
|
1991
2109
|
declare function dot(x: ArrayLike, y: ArrayLike): Array;
|
|
1992
2110
|
/**
|
|
@@ -2039,6 +2157,18 @@ declare function inner(x: ArrayLike, y: ArrayLike): Array;
|
|
|
2039
2157
|
* be of shape `[x.size, y.size]`.
|
|
2040
2158
|
*/
|
|
2041
2159
|
declare function outer(x: ArrayLike, y: ArrayLike): Array;
|
|
2160
|
+
/**
|
|
2161
|
+
* @function Compute the cross product of two arrays.
|
|
2162
|
+
*
|
|
2163
|
+
* Supports 2D (scalar result) and 3D cross products, with optional axis
|
|
2164
|
+
* arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
|
|
2165
|
+
*/
|
|
2166
|
+
declare const cross: OwnedFunction<(a: ArrayLike, b: ArrayLike, args_2?: {
|
|
2167
|
+
axisa?: number | undefined;
|
|
2168
|
+
axisb?: number | undefined;
|
|
2169
|
+
axisc?: number | undefined;
|
|
2170
|
+
axis?: number | undefined;
|
|
2171
|
+
} | undefined) => Array>;
|
|
2042
2172
|
/** Vector dot product of two arrays along a given axis. */
|
|
2043
2173
|
declare function vecdot(x: ArrayLike, y: ArrayLike, {
|
|
2044
2174
|
axis
|
|
@@ -2084,14 +2214,13 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
|
|
|
2084
2214
|
declare function absolute(x: ArrayLike): Array;
|
|
2085
2215
|
/** Return an element-wise indication of sign of the input. */
|
|
2086
2216
|
declare function sign(x: ArrayLike): Array;
|
|
2087
|
-
/** @function Return element-wise positive values of the input (no-op). */
|
|
2088
|
-
declare const positive: (x: ArrayLike) => Array;
|
|
2089
2217
|
/**
|
|
2090
|
-
*
|
|
2091
|
-
*
|
|
2092
|
-
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
2218
|
+
* @function
|
|
2219
|
+
* Return the value with the magnitude of x and the sign of y, element-wise.
|
|
2093
2220
|
*/
|
|
2094
|
-
declare
|
|
2221
|
+
declare const copysign: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
2222
|
+
/** @function Return element-wise positive values of the input (no-op). */
|
|
2223
|
+
declare const positive: (x: ArrayLike) => Array;
|
|
2095
2224
|
/**
|
|
2096
2225
|
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
2097
2226
|
*
|
|
@@ -2186,6 +2315,18 @@ declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
|
2186
2315
|
declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
|
|
2187
2316
|
/** Round input to the nearest integer towards zero. */
|
|
2188
2317
|
declare function trunc(x: ArrayLike): Array;
|
|
2318
|
+
/**
|
|
2319
|
+
* @function
|
|
2320
|
+
* Round to the given number of decimals.
|
|
2321
|
+
*
|
|
2322
|
+
* Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
|
|
2323
|
+
*/
|
|
2324
|
+
declare const round: OwnedFunction<(a: ArrayLike, decimals?: number | undefined) => Array>;
|
|
2325
|
+
/**
|
|
2326
|
+
* @function
|
|
2327
|
+
* Round to the nearest integer, with ties going to the nearest even integer.
|
|
2328
|
+
*/
|
|
2329
|
+
declare const rint: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2189
2330
|
/**
|
|
2190
2331
|
* Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
|
|
2191
2332
|
*
|
|
@@ -2688,8 +2829,31 @@ declare namespace scipy_special_d_exports {
|
|
|
2688
2829
|
* The logit function, `logit(p) = log(p / (1-p))`.
|
|
2689
2830
|
*/
|
|
2690
2831
|
declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2832
|
+
//#endregion
|
|
2833
|
+
//#region src/tracing.d.ts
|
|
2834
|
+
/**
|
|
2835
|
+
* Start collecting kernel traces.
|
|
2836
|
+
*
|
|
2837
|
+
* Traces appear in developer tools under the "Performance" tab, and they are
|
|
2838
|
+
* useful for measuring fine-grained kernel execution time.
|
|
2839
|
+
*/
|
|
2840
|
+
declare function startTrace(): void;
|
|
2841
|
+
/**
|
|
2842
|
+
* Stop collecting kernel traces.
|
|
2843
|
+
*
|
|
2844
|
+
* Traces appear in developer tools under the "Performance" tab, and they are
|
|
2845
|
+
* useful for measuring fine-grained kernel execution time.
|
|
2846
|
+
*/
|
|
2847
|
+
declare function stopTrace(): void;
|
|
2848
|
+
/** Check if tracing is currently enabled. */
|
|
2849
|
+
|
|
2691
2850
|
//#endregion
|
|
2692
2851
|
//#region src/index.d.ts
|
|
2852
|
+
/** @namespace */
|
|
2853
|
+
declare const profiler: {
|
|
2854
|
+
startTrace: typeof startTrace;
|
|
2855
|
+
stopTrace: typeof stopTrace;
|
|
2856
|
+
};
|
|
2693
2857
|
/**
|
|
2694
2858
|
* @function
|
|
2695
2859
|
* Compute the forward-mode Jacobian-vector product for a function.
|
|
@@ -2854,4 +3018,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
|
2854
3018
|
*/
|
|
2855
3019
|
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
2856
3020
|
//#endregion
|
|
2857
|
-
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
3021
|
+
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, getWebGPUDevice, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|