@jax-js/jax 0.1.9 → 0.1.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.cts CHANGED
@@ -232,6 +232,8 @@ declare class AluExp implements FpHashable {
232
232
  static cast(dtype: DType, a: AluExp): AluExp;
233
233
  static bitcast(dtype: DType, a: AluExp): AluExp;
234
234
  static threefry2x32(k0: AluExp, k1: AluExp, c0: AluExp, c1: AluExp, mode?: "xor" | 0 | 1): AluExp;
235
+ static bitCombine(a: AluExp, b: AluExp, mode: "and" | "or" | "xor"): AluExp;
236
+ static bitShift(a: AluExp, b: AluExp, mode: "shl" | "shr"): AluExp;
235
237
  static cmplt(a: AluExp, b: AluExp): AluExp;
236
238
  static cmpne(a: AluExp, b: AluExp): AluExp;
237
239
  static where(cond: AluExp, a: AluExp, b: AluExp): AluExp;
@@ -323,6 +325,11 @@ declare enum AluOp {
323
325
  Reciprocal = "Reciprocal",
324
326
  Cast = "Cast",
325
327
  Bitcast = "Bitcast",
328
+ BitCombine = "BitCombine",
329
+ // arg = 'or' | 'and' | 'xor'
330
+ BitInvert = "BitInvert",
331
+ BitShift = "BitShift",
332
+ // arg = 'shl' | 'shr'
326
333
  Cmplt = "Cmplt",
327
334
  Cmpne = "Cmpne",
328
335
  Where = "Where",
@@ -546,6 +553,11 @@ declare class Executable<T = any> {
546
553
  source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
547
554
  data: T);
548
555
  }
556
+ /**
557
+ * If the WebGPU backend has been initialized, return the `GPUDevice` that this
558
+ * backend runs on. This is useful for sharing buffers.
559
+ */
560
+ declare function getWebGPUDevice(): GPUDevice;
549
561
  declare namespace tree_d_exports {
550
562
  export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
551
563
  }
@@ -719,6 +731,8 @@ declare enum Primitive {
719
731
  // uses sign of numerator, C-style, matches JS but not Python
720
732
  Min = "min",
721
733
  Max = "max",
734
+ BitCombine = "bit_combine",
735
+ BitShift = "bit_shift",
722
736
  Neg = "neg",
723
737
  Reciprocal = "reciprocal",
724
738
  Floor = "floor",
@@ -767,6 +781,12 @@ declare enum Primitive {
767
781
  Jit = "jit",
768
782
  }
769
783
  interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
784
+ [Primitive.BitCombine]: {
785
+ op: "and" | "or" | "xor";
786
+ };
787
+ [Primitive.BitShift]: {
788
+ op: "shl" | "shr";
789
+ };
770
790
  [Primitive.Cast]: {
771
791
  dtype: DType;
772
792
  };
@@ -1004,6 +1024,8 @@ declare abstract class Tracer {
1004
1024
  reshape(shape: number | number[]): this;
1005
1025
  /** Copy the array and cast to a specified dtype. */
1006
1026
  astype(dtype: DType): this;
1027
+ /** Return a bitwise cast of the array, viewed as a new dtype. */
1028
+ view(dtype?: DType): this;
1007
1029
  /** Subtract an array from this one. */
1008
1030
  sub(other: this | TracerValue): this;
1009
1031
  /** Divide an array by this one. */
@@ -1192,6 +1214,19 @@ declare class Array extends Tracer {
1192
1214
  * recommended for performance reasons, as it will block rendering.
1193
1215
  */
1194
1216
  dataSync(): DataArray;
1217
+ /**
1218
+ * Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
1219
+ *
1220
+ * Only available on the WebGPU backend. The array's memory is still managed
1221
+ * by jax-js, and it will be freed when the buffer is no longer in use. You
1222
+ * _should not_ mutate the buffer's contents.
1223
+ *
1224
+ * Note that the GPU buffer may be slightly larger than the array's size; it
1225
+ * will always be aligned to 4 bytes.
1226
+ */
1227
+ gpuBuffer(): Promise<GPUBuffer>;
1228
+ /** Synchronous version of `Array.gpuBuffer()`. */
1229
+ gpuBufferSync(): GPUBuffer;
1195
1230
  /**
1196
1231
  * Convert this array into a JavaScript object.
1197
1232
  *
@@ -1427,8 +1462,10 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
1427
1462
  unitDiagonal?: boolean;
1428
1463
  }): Array;
1429
1464
  declare namespace lax_d_exports {
1430
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
1465
+ export { DotDimensionNumbers, PaddingType, bitcastConvertType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
1431
1466
  }
1467
+ /** Elementwise bitcast an array into a new dtype. */
1468
+ declare function bitcastConvertType(x: ArrayLike, newDtype: DType): Array;
1432
1469
  /**
1433
1470
  * Dimension numbers for general `dot()` primitive.
1434
1471
  *
@@ -1567,7 +1604,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
1567
1604
  */
1568
1605
  declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
1569
1606
  declare namespace numpy_linalg_d_exports {
1570
- export { cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
1607
+ export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot, vectorNorm };
1571
1608
  }
1572
1609
  /**
1573
1610
  * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
@@ -1582,6 +1619,13 @@ declare function cholesky(a: ArrayLike, {
1582
1619
  upper?: boolean;
1583
1620
  symmetrizeInput?: boolean;
1584
1621
  }): Array;
1622
+ /**
1623
+ * Compute the cross-product of two 3D vectors.
1624
+ *
1625
+ * This is a simpler and less flexible version of `jax.numpy.cross()`.
1626
+ * Both inputs must have size 3 along the specified axis.
1627
+ */
1628
+ declare function cross$1(x1: ArrayLike, x2: ArrayLike, axis?: number): Array;
1585
1629
  /** Compute the determinant of a square matrix (batched). */
1586
1630
  declare function det(a: ArrayLike): Array;
1587
1631
  /** Compute the inverse of a square matrix (batched). */
@@ -1615,6 +1659,24 @@ declare function slogdet(a: ArrayLike): [Array, Array];
1615
1659
  * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
1616
1660
  */
1617
1661
  declare function solve(a: ArrayLike, b: ArrayLike): Array;
1662
+ /**
1663
+ * Compute the vector norm of an array.
1664
+ *
1665
+ * @param x - Input array.
1666
+ * @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
1667
+ * @param axis - Axis/axes to reduce over (default: all axes).
1668
+ * @param keepdims - Whether to keep reduced dimensions as size 1.
1669
+ * @returns The norm of `x`, reduced over the given axes.
1670
+ */
1671
+ declare function vectorNorm(x: ArrayLike, {
1672
+ ord,
1673
+ axis,
1674
+ keepdims
1675
+ }?: {
1676
+ ord?: number;
1677
+ axis?: number | number[] | null;
1678
+ keepdims?: boolean;
1679
+ }): Array;
1618
1680
  //#endregion
1619
1681
  //#region src/library/numpy/dtype-info.d.ts
1620
1682
  /** @inline */
@@ -1668,7 +1730,7 @@ type IInfo = Readonly<{
1668
1730
  /** Machine limits for integer types. */
1669
1731
  declare function iinfo(dtype: DType): IInfo;
1670
1732
  declare namespace numpy_d_exports {
1671
- export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1733
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bitwiseAnd, invert as bitwiseInvert, leftShift as bitwiseLeftShift, invert as bitwiseNot, bitwiseOr, rightShift as bitwiseRightShift, bitwiseXor, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, invert, isfinite, isinf, isnan, isneginf, isposinf, ldexp, leftShift, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rightShift, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
1672
1734
  }
1673
1735
  declare const float32 = DType.Float32;
1674
1736
  declare const int32 = DType.Int32;
@@ -1728,6 +1790,26 @@ declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
1728
1790
  declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
1729
1791
  /** @function Compare two arrays element-wise. */
1730
1792
  declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
1793
+ /** Compute element-wise logical AND. */
1794
+ declare function logicalAnd(x: ArrayLike, y: ArrayLike): Array;
1795
+ /** Compute element-wise logical OR. */
1796
+ declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
1797
+ /** Compute element-wise logical XOR. */
1798
+ declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
1799
+ /** Compute element-wise logical NOT. */
1800
+ declare function logicalNot(x: ArrayLike): Array;
1801
+ /** Compute element-wise bitwise AND. */
1802
+ declare function bitwiseAnd(x: ArrayLike, y: ArrayLike): Array;
1803
+ /** Compute element-wise bitwise OR. */
1804
+ declare function bitwiseOr(x: ArrayLike, y: ArrayLike): Array;
1805
+ /** Compute element-wise bitwise XOR. */
1806
+ declare function bitwiseXor(x: ArrayLike, y: ArrayLike): Array;
1807
+ /** Compute element-wise bitwise NOT (inversion). */
1808
+ declare function invert(x: ArrayLike): Array;
1809
+ /** Compute element-wise left bit shift. */
1810
+ declare function leftShift(x: ArrayLike, y: ArrayLike): Array;
1811
+ /** Compute element-wise right bit shift. */
1812
+ declare function rightShift(x: ArrayLike, y: ArrayLike): Array;
1731
1813
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1732
1814
  declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
1733
1815
  /**
@@ -1812,6 +1894,16 @@ declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1812
1894
  declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1813
1895
  /** Compute the average of the array elements along the specified axis. */
1814
1896
  declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1897
+ /**
1898
+ * Compute the weighted average along the specified axis.
1899
+ *
1900
+ * If no axis is specified, mean is computed along all the axes. The weights
1901
+ * should have shape matching that of `a`, or if an axis is specified, it should
1902
+ * match the shape along those axes.
1903
+ */
1904
+ declare function average(a: ArrayLike, axis?: Axis, opts?: {
1905
+ weights?: ArrayLike;
1906
+ } & ReduceOpts): Array;
1815
1907
  /**
1816
1908
  * Returns the indices of the minimum values along an axis.
1817
1909
  *
@@ -1983,13 +2075,39 @@ declare function argsort(a: ArrayLike, axis?: number): Array;
1983
2075
  * numbered axis. By default, the flattened array is used.
1984
2076
  */
1985
2077
  declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
1986
- /** Return if two arrays are element-wise equal within a tolerance. */
2078
+ /**
2079
+ * Return if two arrays are element-wise equal within a tolerance.
2080
+ *
2081
+ * The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
2082
+ * NaN values comparing equal if `equalNaN` is true.
2083
+ */
1987
2084
  declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
1988
2085
  rtol?: number;
1989
2086
  atol?: number;
2087
+ equalNaN?: boolean;
1990
2088
  }): boolean;
2089
+ /**
2090
+ * Check if two arrays are element-wise equal.
2091
+ *
2092
+ * Returns False if the arrays have different shapes. If `equalNaN` is True,
2093
+ * NaNs in the same position are considered equal.
2094
+ */
2095
+ declare function arrayEqual(a1: ArrayLike, a2: ArrayLike, opts?: {
2096
+ equalNaN?: boolean;
2097
+ }): Array;
2098
+ /**
2099
+ * Check if two arrays are element-wise equal after broadcasting.
2100
+ *
2101
+ * Unlike `arrayEqual`, this allows inputs with different but
2102
+ * broadcast-compatible shapes.
2103
+ */
2104
+ declare function arrayEquiv(a1: ArrayLike, a2: ArrayLike): Array;
1991
2105
  /** Matrix product of two arrays. */
1992
2106
  declare function matmul(x: ArrayLike, y: ArrayLike): Array;
2107
+ /** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
2108
+ declare function matvec(x1: ArrayLike, x2: ArrayLike): Array;
2109
+ /** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
2110
+ declare function vecmat(x1: ArrayLike, x2: ArrayLike): Array;
1993
2111
  /** Dot product of two arrays. */
1994
2112
  declare function dot(x: ArrayLike, y: ArrayLike): Array;
1995
2113
  /**
@@ -2042,6 +2160,18 @@ declare function inner(x: ArrayLike, y: ArrayLike): Array;
2042
2160
  * be of shape `[x.size, y.size]`.
2043
2161
  */
2044
2162
  declare function outer(x: ArrayLike, y: ArrayLike): Array;
2163
+ /**
2164
+ * @function Compute the cross product of two arrays.
2165
+ *
2166
+ * Supports 2D (scalar result) and 3D cross products, with optional axis
2167
+ * arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
2168
+ */
2169
+ declare const cross: OwnedFunction<(a: ArrayLike, b: ArrayLike, args_2?: {
2170
+ axisa?: number | undefined;
2171
+ axisb?: number | undefined;
2172
+ axisc?: number | undefined;
2173
+ axis?: number | undefined;
2174
+ } | undefined) => Array>;
2045
2175
  /** Vector dot product of two arrays along a given axis. */
2046
2176
  declare function vecdot(x: ArrayLike, y: ArrayLike, {
2047
2177
  axis
@@ -2087,14 +2217,13 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
2087
2217
  declare function absolute(x: ArrayLike): Array;
2088
2218
  /** Return an element-wise indication of sign of the input. */
2089
2219
  declare function sign(x: ArrayLike): Array;
2090
- /** @function Return element-wise positive values of the input (no-op). */
2091
- declare const positive: (x: ArrayLike) => Array;
2092
2220
  /**
2093
- * Return the Hamming window of size M, a taper with a weighted cosine bell.
2094
- *
2095
- * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
2221
+ * @function
2222
+ * Return the value with the magnitude of x and the sign of y, element-wise.
2096
2223
  */
2097
- declare function hamming(M: number): Array;
2224
+ declare const copysign: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
2225
+ /** @function Return element-wise positive values of the input (no-op). */
2226
+ declare const positive: (x: ArrayLike) => Array;
2098
2227
  /**
2099
2228
  * Return the Hann window of size M, a taper with a weighted cosine bell.
2100
2229
  *
@@ -2189,6 +2318,18 @@ declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
2189
2318
  declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
2190
2319
  /** Round input to the nearest integer towards zero. */
2191
2320
  declare function trunc(x: ArrayLike): Array;
2321
+ /**
2322
+ * @function
2323
+ * Round to the given number of decimals.
2324
+ *
2325
+ * Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
2326
+ */
2327
+ declare const round: OwnedFunction<(a: ArrayLike, decimals?: number | undefined) => Array>;
2328
+ /**
2329
+ * @function
2330
+ * Round to the nearest integer, with ties going to the nearest even integer.
2331
+ */
2332
+ declare const rint: OwnedFunction<(x: ArrayLike) => Array>;
2192
2333
  /**
2193
2334
  * Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
2194
2335
  *
@@ -2691,8 +2832,31 @@ declare namespace scipy_special_d_exports {
2691
2832
  * The logit function, `logit(p) = log(p / (1-p))`.
2692
2833
  */
2693
2834
  declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
2835
+ //#endregion
2836
+ //#region src/tracing.d.ts
2837
+ /**
2838
+ * Start collecting kernel traces.
2839
+ *
2840
+ * Traces appear in developer tools under the "Performance" tab, and they are
2841
+ * useful for measuring fine-grained kernel execution time.
2842
+ */
2843
+ declare function startTrace(): void;
2844
+ /**
2845
+ * Stop collecting kernel traces.
2846
+ *
2847
+ * Traces appear in developer tools under the "Performance" tab, and they are
2848
+ * useful for measuring fine-grained kernel execution time.
2849
+ */
2850
+ declare function stopTrace(): void;
2851
+ /** Check if tracing is currently enabled. */
2852
+
2694
2853
  //#endregion
2695
2854
  //#region src/index.d.ts
2855
+ /** @namespace */
2856
+ declare const profiler: {
2857
+ startTrace: typeof startTrace;
2858
+ stopTrace: typeof stopTrace;
2859
+ };
2696
2860
  /**
2697
2861
  * @function
2698
2862
  * Compute the forward-mode Jacobian-vector product for a function.
@@ -2857,4 +3021,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
2857
3021
  */
2858
3022
  declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
2859
3023
  //#endregion
2860
- export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
3024
+ export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, getWebGPUDevice, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
package/dist/index.d.ts CHANGED
@@ -229,6 +229,8 @@ declare class AluExp implements FpHashable {
229
229
  static cast(dtype: DType, a: AluExp): AluExp;
230
230
  static bitcast(dtype: DType, a: AluExp): AluExp;
231
231
  static threefry2x32(k0: AluExp, k1: AluExp, c0: AluExp, c1: AluExp, mode?: "xor" | 0 | 1): AluExp;
232
+ static bitCombine(a: AluExp, b: AluExp, mode: "and" | "or" | "xor"): AluExp;
233
+ static bitShift(a: AluExp, b: AluExp, mode: "shl" | "shr"): AluExp;
232
234
  static cmplt(a: AluExp, b: AluExp): AluExp;
233
235
  static cmpne(a: AluExp, b: AluExp): AluExp;
234
236
  static where(cond: AluExp, a: AluExp, b: AluExp): AluExp;
@@ -320,6 +322,11 @@ declare enum AluOp {
320
322
  Reciprocal = "Reciprocal",
321
323
  Cast = "Cast",
322
324
  Bitcast = "Bitcast",
325
+ BitCombine = "BitCombine",
326
+ // arg = 'or' | 'and' | 'xor'
327
+ BitInvert = "BitInvert",
328
+ BitShift = "BitShift",
329
+ // arg = 'shl' | 'shr'
323
330
  Cmplt = "Cmplt",
324
331
  Cmpne = "Cmpne",
325
332
  Where = "Where",
@@ -543,6 +550,11 @@ declare class Executable<T = any> {
543
550
  source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
544
551
  data: T);
545
552
  }
553
+ /**
554
+ * If the WebGPU backend has been initialized, return the `GPUDevice` that this
555
+ * backend runs on. This is useful for sharing buffers.
556
+ */
557
+ declare function getWebGPUDevice(): GPUDevice;
546
558
  declare namespace tree_d_exports {
547
559
  export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
548
560
  }
@@ -716,6 +728,8 @@ declare enum Primitive {
716
728
  // uses sign of numerator, C-style, matches JS but not Python
717
729
  Min = "min",
718
730
  Max = "max",
731
+ BitCombine = "bit_combine",
732
+ BitShift = "bit_shift",
719
733
  Neg = "neg",
720
734
  Reciprocal = "reciprocal",
721
735
  Floor = "floor",
@@ -764,6 +778,12 @@ declare enum Primitive {
764
778
  Jit = "jit",
765
779
  }
766
780
  interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
781
+ [Primitive.BitCombine]: {
782
+ op: "and" | "or" | "xor";
783
+ };
784
+ [Primitive.BitShift]: {
785
+ op: "shl" | "shr";
786
+ };
767
787
  [Primitive.Cast]: {
768
788
  dtype: DType;
769
789
  };
@@ -1001,6 +1021,8 @@ declare abstract class Tracer {
1001
1021
  reshape(shape: number | number[]): this;
1002
1022
  /** Copy the array and cast to a specified dtype. */
1003
1023
  astype(dtype: DType): this;
1024
+ /** Return a bitwise cast of the array, viewed as a new dtype. */
1025
+ view(dtype?: DType): this;
1004
1026
  /** Subtract an array from this one. */
1005
1027
  sub(other: this | TracerValue): this;
1006
1028
  /** Divide an array by this one. */
@@ -1189,6 +1211,19 @@ declare class Array extends Tracer {
1189
1211
  * recommended for performance reasons, as it will block rendering.
1190
1212
  */
1191
1213
  dataSync(): DataArray;
1214
+ /**
1215
+ * Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
1216
+ *
1217
+ * Only available on the WebGPU backend. The array's memory is still managed
1218
+ * by jax-js, and it will be freed when the buffer is no longer in use. You
1219
+ * _should not_ mutate the buffer's contents.
1220
+ *
1221
+ * Note that the GPU buffer may be slightly larger than the array's size; it
1222
+ * will always be aligned to 4 bytes.
1223
+ */
1224
+ gpuBuffer(): Promise<GPUBuffer>;
1225
+ /** Synchronous version of `Array.gpuBuffer()`. */
1226
+ gpuBufferSync(): GPUBuffer;
1192
1227
  /**
1193
1228
  * Convert this array into a JavaScript object.
1194
1229
  *
@@ -1424,8 +1459,10 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
1424
1459
  unitDiagonal?: boolean;
1425
1460
  }): Array;
1426
1461
  declare namespace lax_d_exports {
1427
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
1462
+ export { DotDimensionNumbers, PaddingType, bitcastConvertType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
1428
1463
  }
1464
+ /** Elementwise bitcast an array into a new dtype. */
1465
+ declare function bitcastConvertType(x: ArrayLike, newDtype: DType): Array;
1429
1466
  /**
1430
1467
  * Dimension numbers for general `dot()` primitive.
1431
1468
  *
@@ -1564,7 +1601,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
1564
1601
  */
1565
1602
  declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
1566
1603
  declare namespace numpy_linalg_d_exports {
1567
- export { cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
1604
+ export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot, vectorNorm };
1568
1605
  }
1569
1606
  /**
1570
1607
  * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
@@ -1579,6 +1616,13 @@ declare function cholesky(a: ArrayLike, {
1579
1616
  upper?: boolean;
1580
1617
  symmetrizeInput?: boolean;
1581
1618
  }): Array;
1619
+ /**
1620
+ * Compute the cross-product of two 3D vectors.
1621
+ *
1622
+ * This is a simpler and less flexible version of `jax.numpy.cross()`.
1623
+ * Both inputs must have size 3 along the specified axis.
1624
+ */
1625
+ declare function cross$1(x1: ArrayLike, x2: ArrayLike, axis?: number): Array;
1582
1626
  /** Compute the determinant of a square matrix (batched). */
1583
1627
  declare function det(a: ArrayLike): Array;
1584
1628
  /** Compute the inverse of a square matrix (batched). */
@@ -1612,6 +1656,24 @@ declare function slogdet(a: ArrayLike): [Array, Array];
1612
1656
  * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
1613
1657
  */
1614
1658
  declare function solve(a: ArrayLike, b: ArrayLike): Array;
1659
+ /**
1660
+ * Compute the vector norm of an array.
1661
+ *
1662
+ * @param x - Input array.
1663
+ * @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
1664
+ * @param axis - Axis/axes to reduce over (default: all axes).
1665
+ * @param keepdims - Whether to keep reduced dimensions as size 1.
1666
+ * @returns The norm of `x`, reduced over the given axes.
1667
+ */
1668
+ declare function vectorNorm(x: ArrayLike, {
1669
+ ord,
1670
+ axis,
1671
+ keepdims
1672
+ }?: {
1673
+ ord?: number;
1674
+ axis?: number | number[] | null;
1675
+ keepdims?: boolean;
1676
+ }): Array;
1615
1677
  //#endregion
1616
1678
  //#region src/library/numpy/dtype-info.d.ts
1617
1679
  /** @inline */
@@ -1665,7 +1727,7 @@ type IInfo = Readonly<{
1665
1727
  /** Machine limits for integer types. */
1666
1728
  declare function iinfo(dtype: DType): IInfo;
1667
1729
  declare namespace numpy_d_exports {
1668
- export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1730
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bitwiseAnd, invert as bitwiseInvert, leftShift as bitwiseLeftShift, invert as bitwiseNot, bitwiseOr, rightShift as bitwiseRightShift, bitwiseXor, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, invert, isfinite, isinf, isnan, isneginf, isposinf, ldexp, leftShift, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rightShift, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
1669
1731
  }
1670
1732
  declare const float32 = DType.Float32;
1671
1733
  declare const int32 = DType.Int32;
@@ -1725,6 +1787,26 @@ declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
1725
1787
  declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
1726
1788
  /** @function Compare two arrays element-wise. */
1727
1789
  declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
1790
+ /** Compute element-wise logical AND. */
1791
+ declare function logicalAnd(x: ArrayLike, y: ArrayLike): Array;
1792
+ /** Compute element-wise logical OR. */
1793
+ declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
1794
+ /** Compute element-wise logical XOR. */
1795
+ declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
1796
+ /** Compute element-wise logical NOT. */
1797
+ declare function logicalNot(x: ArrayLike): Array;
1798
+ /** Compute element-wise bitwise AND. */
1799
+ declare function bitwiseAnd(x: ArrayLike, y: ArrayLike): Array;
1800
+ /** Compute element-wise bitwise OR. */
1801
+ declare function bitwiseOr(x: ArrayLike, y: ArrayLike): Array;
1802
+ /** Compute element-wise bitwise XOR. */
1803
+ declare function bitwiseXor(x: ArrayLike, y: ArrayLike): Array;
1804
+ /** Compute element-wise bitwise NOT (inversion). */
1805
+ declare function invert(x: ArrayLike): Array;
1806
+ /** Compute element-wise left bit shift. */
1807
+ declare function leftShift(x: ArrayLike, y: ArrayLike): Array;
1808
+ /** Compute element-wise right bit shift. */
1809
+ declare function rightShift(x: ArrayLike, y: ArrayLike): Array;
1728
1810
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1729
1811
  declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
1730
1812
  /**
@@ -1809,6 +1891,16 @@ declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1809
1891
  declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1810
1892
  /** Compute the average of the array elements along the specified axis. */
1811
1893
  declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1894
+ /**
1895
+ * Compute the weighted average along the specified axis.
1896
+ *
1897
+ * If no axis is specified, mean is computed along all the axes. The weights
1898
+ * should have shape matching that of `a`, or if an axis is specified, it should
1899
+ * match the shape along those axes.
1900
+ */
1901
+ declare function average(a: ArrayLike, axis?: Axis, opts?: {
1902
+ weights?: ArrayLike;
1903
+ } & ReduceOpts): Array;
1812
1904
  /**
1813
1905
  * Returns the indices of the minimum values along an axis.
1814
1906
  *
@@ -1980,13 +2072,39 @@ declare function argsort(a: ArrayLike, axis?: number): Array;
1980
2072
  * numbered axis. By default, the flattened array is used.
1981
2073
  */
1982
2074
  declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
1983
- /** Return if two arrays are element-wise equal within a tolerance. */
2075
+ /**
2076
+ * Return if two arrays are element-wise equal within a tolerance.
2077
+ *
2078
+ * The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
2079
+ * NaN values comparing equal if `equalNaN` is true.
2080
+ */
1984
2081
  declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
1985
2082
  rtol?: number;
1986
2083
  atol?: number;
2084
+ equalNaN?: boolean;
1987
2085
  }): boolean;
2086
+ /**
2087
+ * Check if two arrays are element-wise equal.
2088
+ *
2089
+ * Returns False if the arrays have different shapes. If `equalNaN` is True,
2090
+ * NaNs in the same position are considered equal.
2091
+ */
2092
+ declare function arrayEqual(a1: ArrayLike, a2: ArrayLike, opts?: {
2093
+ equalNaN?: boolean;
2094
+ }): Array;
2095
+ /**
2096
+ * Check if two arrays are element-wise equal after broadcasting.
2097
+ *
2098
+ * Unlike `arrayEqual`, this allows inputs with different but
2099
+ * broadcast-compatible shapes.
2100
+ */
2101
+ declare function arrayEquiv(a1: ArrayLike, a2: ArrayLike): Array;
1988
2102
  /** Matrix product of two arrays. */
1989
2103
  declare function matmul(x: ArrayLike, y: ArrayLike): Array;
2104
+ /** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
2105
+ declare function matvec(x1: ArrayLike, x2: ArrayLike): Array;
2106
+ /** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
2107
+ declare function vecmat(x1: ArrayLike, x2: ArrayLike): Array;
1990
2108
  /** Dot product of two arrays. */
1991
2109
  declare function dot(x: ArrayLike, y: ArrayLike): Array;
1992
2110
  /**
@@ -2039,6 +2157,18 @@ declare function inner(x: ArrayLike, y: ArrayLike): Array;
2039
2157
  * be of shape `[x.size, y.size]`.
2040
2158
  */
2041
2159
  declare function outer(x: ArrayLike, y: ArrayLike): Array;
2160
+ /**
2161
+ * @function Compute the cross product of two arrays.
2162
+ *
2163
+ * Supports 2D (scalar result) and 3D cross products, with optional axis
2164
+ * arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
2165
+ */
2166
+ declare const cross: OwnedFunction<(a: ArrayLike, b: ArrayLike, args_2?: {
2167
+ axisa?: number | undefined;
2168
+ axisb?: number | undefined;
2169
+ axisc?: number | undefined;
2170
+ axis?: number | undefined;
2171
+ } | undefined) => Array>;
2042
2172
  /** Vector dot product of two arrays along a given axis. */
2043
2173
  declare function vecdot(x: ArrayLike, y: ArrayLike, {
2044
2174
  axis
@@ -2084,14 +2214,13 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
2084
2214
  declare function absolute(x: ArrayLike): Array;
2085
2215
  /** Return an element-wise indication of sign of the input. */
2086
2216
  declare function sign(x: ArrayLike): Array;
2087
- /** @function Return element-wise positive values of the input (no-op). */
2088
- declare const positive: (x: ArrayLike) => Array;
2089
2217
  /**
2090
- * Return the Hamming window of size M, a taper with a weighted cosine bell.
2091
- *
2092
- * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
2218
+ * @function
2219
+ * Return the value with the magnitude of x and the sign of y, element-wise.
2093
2220
  */
2094
- declare function hamming(M: number): Array;
2221
+ declare const copysign: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
2222
+ /** @function Return element-wise positive values of the input (no-op). */
2223
+ declare const positive: (x: ArrayLike) => Array;
2095
2224
  /**
2096
2225
  * Return the Hann window of size M, a taper with a weighted cosine bell.
2097
2226
  *
@@ -2186,6 +2315,18 @@ declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
2186
2315
  declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
2187
2316
  /** Round input to the nearest integer towards zero. */
2188
2317
  declare function trunc(x: ArrayLike): Array;
2318
+ /**
2319
+ * @function
2320
+ * Round to the given number of decimals.
2321
+ *
2322
+ * Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
2323
+ */
2324
+ declare const round: OwnedFunction<(a: ArrayLike, decimals?: number | undefined) => Array>;
2325
+ /**
2326
+ * @function
2327
+ * Round to the nearest integer, with ties going to the nearest even integer.
2328
+ */
2329
+ declare const rint: OwnedFunction<(x: ArrayLike) => Array>;
2189
2330
  /**
2190
2331
  * Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
2191
2332
  *
@@ -2688,8 +2829,31 @@ declare namespace scipy_special_d_exports {
2688
2829
  * The logit function, `logit(p) = log(p / (1-p))`.
2689
2830
  */
2690
2831
  declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
2832
+ //#endregion
2833
+ //#region src/tracing.d.ts
2834
+ /**
2835
+ * Start collecting kernel traces.
2836
+ *
2837
+ * Traces appear in developer tools under the "Performance" tab, and they are
2838
+ * useful for measuring fine-grained kernel execution time.
2839
+ */
2840
+ declare function startTrace(): void;
2841
+ /**
2842
+ * Stop collecting kernel traces.
2843
+ *
2844
+ * Traces appear in developer tools under the "Performance" tab, and they are
2845
+ * useful for measuring fine-grained kernel execution time.
2846
+ */
2847
+ declare function stopTrace(): void;
2848
+ /** Check if tracing is currently enabled. */
2849
+
2691
2850
  //#endregion
2692
2851
  //#region src/index.d.ts
2852
+ /** @namespace */
2853
+ declare const profiler: {
2854
+ startTrace: typeof startTrace;
2855
+ stopTrace: typeof stopTrace;
2856
+ };
2693
2857
  /**
2694
2858
  * @function
2695
2859
  * Compute the forward-mode Jacobian-vector product for a function.
@@ -2854,4 +3018,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
2854
3018
  */
2855
3019
  declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
2856
3020
  //#endregion
2857
- export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
3021
+ export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, getWebGPUDevice, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };