@jax-js/jax 0.1.10 → 0.1.12

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.cts CHANGED
@@ -232,6 +232,8 @@ declare class AluExp implements FpHashable {
232
232
  static cast(dtype: DType, a: AluExp): AluExp;
233
233
  static bitcast(dtype: DType, a: AluExp): AluExp;
234
234
  static threefry2x32(k0: AluExp, k1: AluExp, c0: AluExp, c1: AluExp, mode?: "xor" | 0 | 1): AluExp;
235
+ static bitCombine(a: AluExp, b: AluExp, mode: "and" | "or" | "xor"): AluExp;
236
+ static bitShift(a: AluExp, b: AluExp, mode: "shl" | "shr"): AluExp;
235
237
  static cmplt(a: AluExp, b: AluExp): AluExp;
236
238
  static cmpne(a: AluExp, b: AluExp): AluExp;
237
239
  static where(cond: AluExp, a: AluExp, b: AluExp): AluExp;
@@ -323,6 +325,11 @@ declare enum AluOp {
323
325
  Reciprocal = "Reciprocal",
324
326
  Cast = "Cast",
325
327
  Bitcast = "Bitcast",
328
+ BitCombine = "BitCombine",
329
+ // arg = 'or' | 'and' | 'xor'
330
+ BitInvert = "BitInvert",
331
+ BitShift = "BitShift",
332
+ // arg = 'shl' | 'shr'
326
333
  Cmplt = "Cmplt",
327
334
  Cmpne = "Cmpne",
328
335
  Where = "Where",
@@ -546,6 +553,11 @@ declare class Executable<T = any> {
546
553
  source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
547
554
  data: T);
548
555
  }
556
+ /**
557
+ * If the WebGPU backend has been initialized, return the `GPUDevice` that this
558
+ * backend runs on. This is useful for sharing buffers.
559
+ */
560
+ declare function getWebGPUDevice(): GPUDevice;
549
561
  declare namespace tree_d_exports {
550
562
  export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
551
563
  }
@@ -719,6 +731,8 @@ declare enum Primitive {
719
731
  // uses sign of numerator, C-style, matches JS but not Python
720
732
  Min = "min",
721
733
  Max = "max",
734
+ BitCombine = "bit_combine",
735
+ BitShift = "bit_shift",
722
736
  Neg = "neg",
723
737
  Reciprocal = "reciprocal",
724
738
  Floor = "floor",
@@ -767,6 +781,12 @@ declare enum Primitive {
767
781
  Jit = "jit",
768
782
  }
769
783
  interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
784
+ [Primitive.BitCombine]: {
785
+ op: "and" | "or" | "xor";
786
+ };
787
+ [Primitive.BitShift]: {
788
+ op: "shl" | "shr";
789
+ };
770
790
  [Primitive.Cast]: {
771
791
  dtype: DType;
772
792
  };
@@ -1194,6 +1214,19 @@ declare class Array extends Tracer {
1194
1214
  * recommended for performance reasons, as it will block rendering.
1195
1215
  */
1196
1216
  dataSync(): DataArray;
1217
+ /**
1218
+ * Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
1219
+ *
1220
+ * Only available on the WebGPU backend. The array's memory is still managed
1221
+ * by jax-js, and it will be freed when the buffer is no longer in use. You
1222
+ * _should not_ mutate the buffer's contents.
1223
+ *
1224
+ * Note that the GPU buffer may be slightly larger than the array's size; it
1225
+ * will always be aligned to 4 bytes.
1226
+ */
1227
+ gpuBuffer(): Promise<GPUBuffer>;
1228
+ /** Synchronous version of `Array.gpuBuffer()`. */
1229
+ gpuBufferSync(): GPUBuffer;
1197
1230
  /**
1198
1231
  * Convert this array into a JavaScript object.
1199
1232
  *
@@ -1571,7 +1604,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
1571
1604
  */
1572
1605
  declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
1573
1606
  declare namespace numpy_linalg_d_exports {
1574
- export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
1607
+ export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot, vectorNorm };
1575
1608
  }
1576
1609
  /**
1577
1610
  * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
@@ -1626,6 +1659,24 @@ declare function slogdet(a: ArrayLike): [Array, Array];
1626
1659
  * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
1627
1660
  */
1628
1661
  declare function solve(a: ArrayLike, b: ArrayLike): Array;
1662
+ /**
1663
+ * Compute the vector norm of an array.
1664
+ *
1665
+ * @param x - Input array.
1666
+ * @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
1667
+ * @param axis - Axis/axes to reduce over (default: all axes).
1668
+ * @param keepdims - Whether to keep reduced dimensions as size 1.
1669
+ * @returns The norm of `x`, reduced over the given axes.
1670
+ */
1671
+ declare function vectorNorm(x: ArrayLike, {
1672
+ ord,
1673
+ axis,
1674
+ keepdims
1675
+ }?: {
1676
+ ord?: number;
1677
+ axis?: number | number[] | null;
1678
+ keepdims?: boolean;
1679
+ }): Array;
1629
1680
  //#endregion
1630
1681
  //#region src/library/numpy/dtype-info.d.ts
1631
1682
  /** @inline */
@@ -1679,7 +1730,7 @@ type IInfo = Readonly<{
1679
1730
  /** Machine limits for integer types. */
1680
1731
  declare function iinfo(dtype: DType): IInfo;
1681
1732
  declare namespace numpy_d_exports {
1682
- export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
1733
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bitwiseAnd, invert as bitwiseInvert, leftShift as bitwiseLeftShift, invert as bitwiseNot, bitwiseOr, rightShift as bitwiseRightShift, bitwiseXor, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, invert, isfinite, isinf, isnan, isneginf, isposinf, ldexp, leftShift, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rightShift, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
1683
1734
  }
1684
1735
  declare const float32 = DType.Float32;
1685
1736
  declare const int32 = DType.Int32;
@@ -1747,6 +1798,18 @@ declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
1747
1798
  declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
1748
1799
  /** Compute element-wise logical NOT. */
1749
1800
  declare function logicalNot(x: ArrayLike): Array;
1801
+ /** Compute element-wise bitwise AND. */
1802
+ declare function bitwiseAnd(x: ArrayLike, y: ArrayLike): Array;
1803
+ /** Compute element-wise bitwise OR. */
1804
+ declare function bitwiseOr(x: ArrayLike, y: ArrayLike): Array;
1805
+ /** Compute element-wise bitwise XOR. */
1806
+ declare function bitwiseXor(x: ArrayLike, y: ArrayLike): Array;
1807
+ /** Compute element-wise bitwise NOT (inversion). */
1808
+ declare function invert(x: ArrayLike): Array;
1809
+ /** Compute element-wise left bit shift. */
1810
+ declare function leftShift(x: ArrayLike, y: ArrayLike): Array;
1811
+ /** Compute element-wise right bit shift. */
1812
+ declare function rightShift(x: ArrayLike, y: ArrayLike): Array;
1750
1813
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1751
1814
  declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
1752
1815
  /**
@@ -2659,7 +2722,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
2659
2722
  localWindowSize?: number | [number, number];
2660
2723
  }): Array;
2661
2724
  declare namespace random_d_exports {
2662
- export { bernoulli, bits, categorical, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2725
+ export { ball, bernoulli, bits, categorical, cauchy, choice, doubleSidedMaxwell, exponential, geometric, gumbel, key, laplace, logistic, lognormal, maxwell, multivariateNormal, normal, pareto, permutation, rademacher, randint, rayleigh, split, triangular, uniform, weibullMin };
2663
2726
  }
2664
2727
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
2665
2728
  declare function key(seed: ArrayLike): Array;
@@ -2675,6 +2738,16 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
2675
2738
  minval?: number | undefined;
2676
2739
  maxval?: number | undefined;
2677
2740
  } | undefined) => Array>;
2741
+ /**
2742
+ * @function
2743
+ * Sample points uniformly from the Euclidean unit ball in `d` dimensions.
2744
+ *
2745
+ * Only the Euclidean `p=2` case is currently supported.
2746
+ */
2747
+ declare const ball: OwnedFunction<(key: ArrayLike, d: number, args_2?: {
2748
+ p?: number | undefined;
2749
+ shape?: number[] | undefined;
2750
+ } | undefined) => Array>;
2678
2751
  /**
2679
2752
  * Sample Bernoulli random variables with given mean (0,1 categorical).
2680
2753
  *
@@ -2715,11 +2788,42 @@ declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, arg
2715
2788
  * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
2716
2789
  */
2717
2790
  declare const cauchy: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2791
+ /**
2792
+ * Sample from a population with optional replacement and optional probabilities.
2793
+ *
2794
+ * This implements the common JAX-compatible cases: integer populations and
2795
+ * array populations along `axis`. Probabilities `p`, if provided, are sampled
2796
+ * via `categorical(log(p))`.
2797
+ */
2798
+ declare function choice(key: Array, a: number | ArrayLike, {
2799
+ shape,
2800
+ replace,
2801
+ p,
2802
+ axis
2803
+ }?: {
2804
+ shape?: number[];
2805
+ replace?: boolean;
2806
+ p?: ArrayLike;
2807
+ axis?: number;
2808
+ }): Array;
2809
+ /**
2810
+ * @function
2811
+ * Sample double-sided Maxwell random values with the provided location and scale.
2812
+ */
2813
+ declare const doubleSidedMaxwell: OwnedFunction<(key: ArrayLike, loc: ArrayLike, scale: ArrayLike, shape?: number[] | undefined) => Array>;
2718
2814
  /**
2719
2815
  * @function
2720
2816
  * Sample exponential random values according to `p(x) = exp(-x)`.
2721
2817
  */
2722
2818
  declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2819
+ /**
2820
+ * @function
2821
+ * Sample geometric random values: the number of trials until first success.
2822
+ */
2823
+ declare const geometric: OwnedFunction<(key: ArrayLike, p: ArrayLike, args_2?: {
2824
+ shape?: number[] | undefined;
2825
+ dtype?: DType | undefined;
2826
+ } | undefined) => Array>;
2723
2827
  /**
2724
2828
  * @function
2725
2829
  * Sample from a Gumbel distribution with location 0 and scale 1.
@@ -2735,6 +2839,23 @@ declare const gumbel: OwnedFunction<(key: ArrayLike, shape?: number[] | undefine
2735
2839
  * Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
2736
2840
  */
2737
2841
  declare const laplace: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2842
+ /**
2843
+ * @function
2844
+ * Sample from a logistic distribution with location 0 and scale 1.
2845
+ *
2846
+ * Uses inverse transform sampling: `x = log(u) - log(1-u)`.
2847
+ */
2848
+ declare const logistic: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2849
+ /**
2850
+ * @function
2851
+ * Sample log-normal random values: `exp(sigma * normal(key, shape))`.
2852
+ */
2853
+ declare const lognormal: OwnedFunction<(key: ArrayLike, sigma?: ArrayLike | undefined, shape?: number[] | undefined) => Array>;
2854
+ /**
2855
+ * @function
2856
+ * Sample Maxwell-distributed random values.
2857
+ */
2858
+ declare const maxwell: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2738
2859
  /**
2739
2860
  * @function
2740
2861
  * Sample multivariate normal random values with given mean and covariance.
@@ -2761,6 +2882,53 @@ declare const multivariateNormal: OwnedFunction<(key: ArrayLike, mean: ArrayLike
2761
2882
  * bitwise identical to JAX.
2762
2883
  */
2763
2884
  declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2885
+ /**
2886
+ * @function
2887
+ * Sample from a Pareto distribution with shape parameter `b` and support [1, ∞).
2888
+ */
2889
+ declare const pareto: OwnedFunction<(key: ArrayLike, b: ArrayLike, shape?: number[] | undefined) => Array>;
2890
+ /**
2891
+ * Return a random permutation of an integer range or of an array along `axis`.
2892
+ */
2893
+ declare function permutation(key: Array, x: number | ArrayLike, axis?: number): Array;
2894
+ /**
2895
+ * @function
2896
+ * Sample Rademacher random values, uniformly from {-1, 1}.
2897
+ */
2898
+ declare const rademacher: OwnedFunction<(key: ArrayLike, args_1?: {
2899
+ shape?: number[] | undefined;
2900
+ dtype?: DType | undefined;
2901
+ } | undefined) => Array>;
2902
+ /**
2903
+ * @function
2904
+ * Sample integer values uniformly from `[minval, maxval)`.
2905
+ *
2906
+ * This uses modulo reduction of uniform 32-bit random bits. For ranges that do
2907
+ * not divide 2^32, this introduces a very small modulo bias.
2908
+ */
2909
+ declare const randint: OwnedFunction<(key: ArrayLike, args_1: {
2910
+ minval: number;
2911
+ maxval: number;
2912
+ shape?: number[] | undefined;
2913
+ dtype?: DType | undefined;
2914
+ }) => Array>;
2915
+ /**
2916
+ * @function
2917
+ * Sample Rayleigh random values with the provided scale parameter.
2918
+ */
2919
+ declare const rayleigh: OwnedFunction<(key: ArrayLike, scale?: ArrayLike | undefined, shape?: number[] | undefined) => Array>;
2920
+ /**
2921
+ * @function
2922
+ * Sample triangular random values on `[left, right]` with the given mode.
2923
+ */
2924
+ declare const triangular: OwnedFunction<(key: ArrayLike, left: ArrayLike, mode: ArrayLike, right: ArrayLike, shape?: number[] | undefined) => Array>;
2925
+ /**
2926
+ * @function
2927
+ * Sample Weibull minimum random values.
2928
+ *
2929
+ * Uses `scale * exponential(key) ** (1 / concentration)`.
2930
+ */
2931
+ declare const weibullMin: OwnedFunction<(key: ArrayLike, scale: ArrayLike, concentration: ArrayLike, shape?: number[] | undefined) => Array>;
2764
2932
  declare namespace scipy_special_d_exports {
2765
2933
  export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
2766
2934
  }
@@ -2958,4 +3126,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
2958
3126
  */
2959
3127
  declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
2960
3128
  //#endregion
2961
- export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
3129
+ export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, getWebGPUDevice, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
package/dist/index.d.ts CHANGED
@@ -229,6 +229,8 @@ declare class AluExp implements FpHashable {
229
229
  static cast(dtype: DType, a: AluExp): AluExp;
230
230
  static bitcast(dtype: DType, a: AluExp): AluExp;
231
231
  static threefry2x32(k0: AluExp, k1: AluExp, c0: AluExp, c1: AluExp, mode?: "xor" | 0 | 1): AluExp;
232
+ static bitCombine(a: AluExp, b: AluExp, mode: "and" | "or" | "xor"): AluExp;
233
+ static bitShift(a: AluExp, b: AluExp, mode: "shl" | "shr"): AluExp;
232
234
  static cmplt(a: AluExp, b: AluExp): AluExp;
233
235
  static cmpne(a: AluExp, b: AluExp): AluExp;
234
236
  static where(cond: AluExp, a: AluExp, b: AluExp): AluExp;
@@ -320,6 +322,11 @@ declare enum AluOp {
320
322
  Reciprocal = "Reciprocal",
321
323
  Cast = "Cast",
322
324
  Bitcast = "Bitcast",
325
+ BitCombine = "BitCombine",
326
+ // arg = 'or' | 'and' | 'xor'
327
+ BitInvert = "BitInvert",
328
+ BitShift = "BitShift",
329
+ // arg = 'shl' | 'shr'
323
330
  Cmplt = "Cmplt",
324
331
  Cmpne = "Cmpne",
325
332
  Where = "Where",
@@ -543,6 +550,11 @@ declare class Executable<T = any> {
543
550
  source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
544
551
  data: T);
545
552
  }
553
+ /**
554
+ * If the WebGPU backend has been initialized, return the `GPUDevice` that this
555
+ * backend runs on. This is useful for sharing buffers.
556
+ */
557
+ declare function getWebGPUDevice(): GPUDevice;
546
558
  declare namespace tree_d_exports {
547
559
  export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
548
560
  }
@@ -716,6 +728,8 @@ declare enum Primitive {
716
728
  // uses sign of numerator, C-style, matches JS but not Python
717
729
  Min = "min",
718
730
  Max = "max",
731
+ BitCombine = "bit_combine",
732
+ BitShift = "bit_shift",
719
733
  Neg = "neg",
720
734
  Reciprocal = "reciprocal",
721
735
  Floor = "floor",
@@ -764,6 +778,12 @@ declare enum Primitive {
764
778
  Jit = "jit",
765
779
  }
766
780
  interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
781
+ [Primitive.BitCombine]: {
782
+ op: "and" | "or" | "xor";
783
+ };
784
+ [Primitive.BitShift]: {
785
+ op: "shl" | "shr";
786
+ };
767
787
  [Primitive.Cast]: {
768
788
  dtype: DType;
769
789
  };
@@ -1191,6 +1211,19 @@ declare class Array extends Tracer {
1191
1211
  * recommended for performance reasons, as it will block rendering.
1192
1212
  */
1193
1213
  dataSync(): DataArray;
1214
+ /**
1215
+ * Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
1216
+ *
1217
+ * Only available on the WebGPU backend. The array's memory is still managed
1218
+ * by jax-js, and it will be freed when the buffer is no longer in use. You
1219
+ * _should not_ mutate the buffer's contents.
1220
+ *
1221
+ * Note that the GPU buffer may be slightly larger than the array's size; it
1222
+ * will always be aligned to 4 bytes.
1223
+ */
1224
+ gpuBuffer(): Promise<GPUBuffer>;
1225
+ /** Synchronous version of `Array.gpuBuffer()`. */
1226
+ gpuBufferSync(): GPUBuffer;
1194
1227
  /**
1195
1228
  * Convert this array into a JavaScript object.
1196
1229
  *
@@ -1568,7 +1601,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
1568
1601
  */
1569
1602
  declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
1570
1603
  declare namespace numpy_linalg_d_exports {
1571
- export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
1604
+ export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot, vectorNorm };
1572
1605
  }
1573
1606
  /**
1574
1607
  * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
@@ -1623,6 +1656,24 @@ declare function slogdet(a: ArrayLike): [Array, Array];
1623
1656
  * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
1624
1657
  */
1625
1658
  declare function solve(a: ArrayLike, b: ArrayLike): Array;
1659
+ /**
1660
+ * Compute the vector norm of an array.
1661
+ *
1662
+ * @param x - Input array.
1663
+ * @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
1664
+ * @param axis - Axis/axes to reduce over (default: all axes).
1665
+ * @param keepdims - Whether to keep reduced dimensions as size 1.
1666
+ * @returns The norm of `x`, reduced over the given axes.
1667
+ */
1668
+ declare function vectorNorm(x: ArrayLike, {
1669
+ ord,
1670
+ axis,
1671
+ keepdims
1672
+ }?: {
1673
+ ord?: number;
1674
+ axis?: number | number[] | null;
1675
+ keepdims?: boolean;
1676
+ }): Array;
1626
1677
  //#endregion
1627
1678
  //#region src/library/numpy/dtype-info.d.ts
1628
1679
  /** @inline */
@@ -1676,7 +1727,7 @@ type IInfo = Readonly<{
1676
1727
  /** Machine limits for integer types. */
1677
1728
  declare function iinfo(dtype: DType): IInfo;
1678
1729
  declare namespace numpy_d_exports {
1679
- export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
1730
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bitwiseAnd, invert as bitwiseInvert, leftShift as bitwiseLeftShift, invert as bitwiseNot, bitwiseOr, rightShift as bitwiseRightShift, bitwiseXor, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, invert, isfinite, isinf, isnan, isneginf, isposinf, ldexp, leftShift, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rightShift, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
1680
1731
  }
1681
1732
  declare const float32 = DType.Float32;
1682
1733
  declare const int32 = DType.Int32;
@@ -1744,6 +1795,18 @@ declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
1744
1795
  declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
1745
1796
  /** Compute element-wise logical NOT. */
1746
1797
  declare function logicalNot(x: ArrayLike): Array;
1798
+ /** Compute element-wise bitwise AND. */
1799
+ declare function bitwiseAnd(x: ArrayLike, y: ArrayLike): Array;
1800
+ /** Compute element-wise bitwise OR. */
1801
+ declare function bitwiseOr(x: ArrayLike, y: ArrayLike): Array;
1802
+ /** Compute element-wise bitwise XOR. */
1803
+ declare function bitwiseXor(x: ArrayLike, y: ArrayLike): Array;
1804
+ /** Compute element-wise bitwise NOT (inversion). */
1805
+ declare function invert(x: ArrayLike): Array;
1806
+ /** Compute element-wise left bit shift. */
1807
+ declare function leftShift(x: ArrayLike, y: ArrayLike): Array;
1808
+ /** Compute element-wise right bit shift. */
1809
+ declare function rightShift(x: ArrayLike, y: ArrayLike): Array;
1747
1810
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1748
1811
  declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
1749
1812
  /**
@@ -2656,7 +2719,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
2656
2719
  localWindowSize?: number | [number, number];
2657
2720
  }): Array;
2658
2721
  declare namespace random_d_exports {
2659
- export { bernoulli, bits, categorical, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2722
+ export { ball, bernoulli, bits, categorical, cauchy, choice, doubleSidedMaxwell, exponential, geometric, gumbel, key, laplace, logistic, lognormal, maxwell, multivariateNormal, normal, pareto, permutation, rademacher, randint, rayleigh, split, triangular, uniform, weibullMin };
2660
2723
  }
2661
2724
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
2662
2725
  declare function key(seed: ArrayLike): Array;
@@ -2672,6 +2735,16 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
2672
2735
  minval?: number | undefined;
2673
2736
  maxval?: number | undefined;
2674
2737
  } | undefined) => Array>;
2738
+ /**
2739
+ * @function
2740
+ * Sample points uniformly from the Euclidean unit ball in `d` dimensions.
2741
+ *
2742
+ * Only the Euclidean `p=2` case is currently supported.
2743
+ */
2744
+ declare const ball: OwnedFunction<(key: ArrayLike, d: number, args_2?: {
2745
+ p?: number | undefined;
2746
+ shape?: number[] | undefined;
2747
+ } | undefined) => Array>;
2675
2748
  /**
2676
2749
  * Sample Bernoulli random variables with given mean (0,1 categorical).
2677
2750
  *
@@ -2712,11 +2785,42 @@ declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, arg
2712
2785
  * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
2713
2786
  */
2714
2787
  declare const cauchy: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2788
+ /**
2789
+ * Sample from a population with optional replacement and optional probabilities.
2790
+ *
2791
+ * This implements the common JAX-compatible cases: integer populations and
2792
+ * array populations along `axis`. Probabilities `p`, if provided, are sampled
2793
+ * via `categorical(log(p))`.
2794
+ */
2795
+ declare function choice(key: Array, a: number | ArrayLike, {
2796
+ shape,
2797
+ replace,
2798
+ p,
2799
+ axis
2800
+ }?: {
2801
+ shape?: number[];
2802
+ replace?: boolean;
2803
+ p?: ArrayLike;
2804
+ axis?: number;
2805
+ }): Array;
2806
+ /**
2807
+ * @function
2808
+ * Sample double-sided Maxwell random values with the provided location and scale.
2809
+ */
2810
+ declare const doubleSidedMaxwell: OwnedFunction<(key: ArrayLike, loc: ArrayLike, scale: ArrayLike, shape?: number[] | undefined) => Array>;
2715
2811
  /**
2716
2812
  * @function
2717
2813
  * Sample exponential random values according to `p(x) = exp(-x)`.
2718
2814
  */
2719
2815
  declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2816
+ /**
2817
+ * @function
2818
+ * Sample geometric random values: the number of trials until first success.
2819
+ */
2820
+ declare const geometric: OwnedFunction<(key: ArrayLike, p: ArrayLike, args_2?: {
2821
+ shape?: number[] | undefined;
2822
+ dtype?: DType | undefined;
2823
+ } | undefined) => Array>;
2720
2824
  /**
2721
2825
  * @function
2722
2826
  * Sample from a Gumbel distribution with location 0 and scale 1.
@@ -2732,6 +2836,23 @@ declare const gumbel: OwnedFunction<(key: ArrayLike, shape?: number[] | undefine
2732
2836
  * Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
2733
2837
  */
2734
2838
  declare const laplace: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2839
+ /**
2840
+ * @function
2841
+ * Sample from a logistic distribution with location 0 and scale 1.
2842
+ *
2843
+ * Uses inverse transform sampling: `x = log(u) - log(1-u)`.
2844
+ */
2845
+ declare const logistic: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2846
+ /**
2847
+ * @function
2848
+ * Sample log-normal random values: `exp(sigma * normal(key, shape))`.
2849
+ */
2850
+ declare const lognormal: OwnedFunction<(key: ArrayLike, sigma?: ArrayLike | undefined, shape?: number[] | undefined) => Array>;
2851
+ /**
2852
+ * @function
2853
+ * Sample Maxwell-distributed random values.
2854
+ */
2855
+ declare const maxwell: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2735
2856
  /**
2736
2857
  * @function
2737
2858
  * Sample multivariate normal random values with given mean and covariance.
@@ -2758,6 +2879,53 @@ declare const multivariateNormal: OwnedFunction<(key: ArrayLike, mean: ArrayLike
2758
2879
  * bitwise identical to JAX.
2759
2880
  */
2760
2881
  declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2882
+ /**
2883
+ * @function
2884
+ * Sample from a Pareto distribution with shape parameter `b` and support [1, ∞).
2885
+ */
2886
+ declare const pareto: OwnedFunction<(key: ArrayLike, b: ArrayLike, shape?: number[] | undefined) => Array>;
2887
+ /**
2888
+ * Return a random permutation of an integer range or of an array along `axis`.
2889
+ */
2890
+ declare function permutation(key: Array, x: number | ArrayLike, axis?: number): Array;
2891
+ /**
2892
+ * @function
2893
+ * Sample Rademacher random values, uniformly from {-1, 1}.
2894
+ */
2895
+ declare const rademacher: OwnedFunction<(key: ArrayLike, args_1?: {
2896
+ shape?: number[] | undefined;
2897
+ dtype?: DType | undefined;
2898
+ } | undefined) => Array>;
2899
+ /**
2900
+ * @function
2901
+ * Sample integer values uniformly from `[minval, maxval)`.
2902
+ *
2903
+ * This uses modulo reduction of uniform 32-bit random bits. For ranges that do
2904
+ * not divide 2^32, this introduces a very small modulo bias.
2905
+ */
2906
+ declare const randint: OwnedFunction<(key: ArrayLike, args_1: {
2907
+ minval: number;
2908
+ maxval: number;
2909
+ shape?: number[] | undefined;
2910
+ dtype?: DType | undefined;
2911
+ }) => Array>;
2912
+ /**
2913
+ * @function
2914
+ * Sample Rayleigh random values with the provided scale parameter.
2915
+ */
2916
+ declare const rayleigh: OwnedFunction<(key: ArrayLike, scale?: ArrayLike | undefined, shape?: number[] | undefined) => Array>;
2917
+ /**
2918
+ * @function
2919
+ * Sample triangular random values on `[left, right]` with the given mode.
2920
+ */
2921
+ declare const triangular: OwnedFunction<(key: ArrayLike, left: ArrayLike, mode: ArrayLike, right: ArrayLike, shape?: number[] | undefined) => Array>;
2922
+ /**
2923
+ * @function
2924
+ * Sample Weibull minimum random values.
2925
+ *
2926
+ * Uses `scale * exponential(key) ** (1 / concentration)`.
2927
+ */
2928
+ declare const weibullMin: OwnedFunction<(key: ArrayLike, scale: ArrayLike, concentration: ArrayLike, shape?: number[] | undefined) => Array>;
2761
2929
  declare namespace scipy_special_d_exports {
2762
2930
  export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
2763
2931
  }
@@ -2955,4 +3123,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
2955
3123
  */
2956
3124
  declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
2957
3125
  //#endregion
2958
- export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
3126
+ export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, getWebGPUDevice, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };