@jax-js/jax 0.1.3 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +15 -9
- package/dist/{backend-BY8wlLEl.js → backend-DaqL-MNz.js} +240 -21
- package/dist/{backend-CmaidnkQ.cjs → backend-DziQSaoQ.cjs} +264 -21
- package/dist/index.cjs +2407 -1132
- package/dist/index.d.cts +596 -97
- package/dist/index.d.ts +596 -97
- package/dist/index.js +2400 -1126
- package/dist/webgl-ClIYb8jP.cjs +522 -0
- package/dist/webgl-RSuZKvgc.js +522 -0
- package/dist/webgpu-Db2JrNBr.cjs +1261 -0
- package/dist/webgpu-Dh7k9io0.js +1261 -0
- package/package.json +1 -1
- package/dist/webgpu-BVns4DbI.cjs +0 -663
- package/dist/webgpu-C9iAP5h5.js +0 -663
package/dist/index.d.cts
CHANGED
|
@@ -124,7 +124,6 @@ declare class ShapeTracker {
|
|
|
124
124
|
/** Like pad(), but allows for negative values. */
|
|
125
125
|
padOrShrink(arg: Pair[]): ShapeTracker;
|
|
126
126
|
}
|
|
127
|
-
//# sourceMappingURL=shape.d.ts.map
|
|
128
127
|
//#endregion
|
|
129
128
|
//#region src/utils.d.ts
|
|
130
129
|
/**
|
|
@@ -407,8 +406,81 @@ declare class Reduction implements FpHashable {
|
|
|
407
406
|
}
|
|
408
407
|
/** Expression for accessing `indices` in input array with the given shape. */
|
|
409
408
|
//#endregion
|
|
409
|
+
//#region src/routine.d.ts
|
|
410
|
+
/**
|
|
411
|
+
* Advanced operations that don't fit into the `AluExp` compiler representation.
|
|
412
|
+
*
|
|
413
|
+
* Some routines like iterative matrix algorithms, FFTs, or sorting may not be
|
|
414
|
+
* easy to express efficiently as a `Kernel` object. These also tend to be
|
|
415
|
+
* somewhat expensive, so the benefit of kernel fusion and inlining is less
|
|
416
|
+
* relevant.
|
|
417
|
+
*
|
|
418
|
+
* For these operations, we dispatch them as a custom operation on the backend,
|
|
419
|
+
* which each backend implements in a specific way. These are listed in the
|
|
420
|
+
* `Routines` enum below.
|
|
421
|
+
*
|
|
422
|
+
* Routines cannot be fused into other kernels and always operate on contiguous
|
|
423
|
+
* arrays (default `ShapeTracker`).
|
|
424
|
+
*/
|
|
425
|
+
declare class Routine {
|
|
426
|
+
/** The name of the routine. */
|
|
427
|
+
readonly name: Routines;
|
|
428
|
+
/** Dtype and shape of the inputs and outputs. */
|
|
429
|
+
readonly type: RoutineType;
|
|
430
|
+
/** Extra parameters specific to the routine. */
|
|
431
|
+
readonly params?: any | undefined;
|
|
432
|
+
constructor(/** The name of the routine. */
|
|
433
|
+
name: Routines, /** Dtype and shape of the inputs and outputs. */
|
|
434
|
+
type: RoutineType, /** Extra parameters specific to the routine. */
|
|
435
|
+
params?: any | undefined);
|
|
436
|
+
}
|
|
437
|
+
/** Names of the valid `Routine` operations that can be dispatched to a backend. */
|
|
438
|
+
declare enum Routines {
|
|
439
|
+
/** Stable sorting algorithm along the last axis. */
|
|
440
|
+
Sort = "Sort",
|
|
441
|
+
/** Returns `int32` indices of the stably sorted array. */
|
|
442
|
+
Argsort = "Argsort",
|
|
443
|
+
/**
|
|
444
|
+
* Solve a triangular system of equations.
|
|
445
|
+
*
|
|
446
|
+
* The first batch of inputs `A` should be of shape `[..., N, N]` and upper
|
|
447
|
+
* triangular, while the second batch `B` should be of shape `[..., M, N]`.
|
|
448
|
+
*
|
|
449
|
+
* Solves for `X` in the equation `A @ X.T = B.T`, where `A` is the
|
|
450
|
+
* triangular matrix. This is equivalent to `X = B @ A^-T`.
|
|
451
|
+
*/
|
|
452
|
+
TriangularSolve = "TriangularSolve",
|
|
453
|
+
/**
|
|
454
|
+
* Cholesky decomposition of 2D positive semi-definite matrices.
|
|
455
|
+
*
|
|
456
|
+
* The input batch should be of shape `[..., N, N]`, and the output batch is
|
|
457
|
+
* of the same shape, containing the lower-triangular matrix `L` such that
|
|
458
|
+
* `A = L @ L.T`. Behavior is unspecified if A is not positive semi-definite.
|
|
459
|
+
*/
|
|
460
|
+
Cholesky = "Cholesky",
|
|
461
|
+
/**
|
|
462
|
+
* LU decomposition of 2D rectangular matrices.
|
|
463
|
+
*
|
|
464
|
+
* The input is a batch of shape `[..., M, N]`, and the output is a tuple of
|
|
465
|
+
* three arrays: `LU, Pivots, Permutation`.
|
|
466
|
+
*
|
|
467
|
+
* - `LU` is of shape `[..., M, N]`, containing the combined lower and upper
|
|
468
|
+
* triangular matrices (the lower-triangular factor has an implicit unit diagonal).
|
|
469
|
+
* - `Pivots` is of shape `[..., min(M, N)]`, containing the row swaps.
|
|
470
|
+
* - `Permutation` is of shape `[..., M]`, containing the permutation vector
|
|
471
|
+
* such that `P = eye(M).slice(Permutation)` -> `P @ A = L @ U`.
|
|
472
|
+
*/
|
|
473
|
+
LU = "LU",
|
|
474
|
+
}
|
|
475
|
+
interface RoutineType {
|
|
476
|
+
inputShapes: number[][];
|
|
477
|
+
inputDtypes: DType[];
|
|
478
|
+
outputShapes: number[][];
|
|
479
|
+
outputDtypes: DType[];
|
|
480
|
+
}
|
|
481
|
+
//#endregion
|
|
410
482
|
//#region src/backend.d.ts
|
|
411
|
-
type Device = "cpu" | "wasm" | "webgpu";
|
|
483
|
+
type Device = "cpu" | "wasm" | "webgpu" | "webgl";
|
|
412
484
|
declare const devices: Device[];
|
|
413
485
|
/** Configure the default device for arrays. */
|
|
414
486
|
declare function defaultDevice(device?: Device): Device;
|
|
@@ -444,9 +516,13 @@ interface Backend {
|
|
|
444
516
|
/** Read a range of bytes from a buffer, blocking variant. */
|
|
445
517
|
readSync(slot: Slot, start?: number, count?: number): Uint8Array<ArrayBuffer>;
|
|
446
518
|
/** Prepare an expression to be executed later. */
|
|
447
|
-
|
|
519
|
+
prepareKernel(kernel: Kernel): Promise<Executable>;
|
|
448
520
|
/** Prepare an expression to be executed later, blocking variant. */
|
|
449
|
-
|
|
521
|
+
prepareKernelSync(kernel: Kernel): Executable;
|
|
522
|
+
/** Prepare an advanced routine to be executed later. */
|
|
523
|
+
prepareRoutine(routine: Routine): Promise<Executable>;
|
|
524
|
+
/** Prepare an advanced routine to be executed later, blocking variant. */
|
|
525
|
+
prepareRoutineSync(routine: Routine): Executable;
|
|
450
526
|
/**
|
|
451
527
|
* Run a backend operation that was previously prepared.
|
|
452
528
|
*
|
|
@@ -457,14 +533,140 @@ interface Backend {
|
|
|
457
533
|
dispatch(exe: Executable, inputs: Slot[], outputs: Slot[]): void;
|
|
458
534
|
}
|
|
459
535
|
declare class Executable<T = any> {
|
|
460
|
-
|
|
461
|
-
|
|
536
|
+
/** The `Kernel` or `Routine` that was prepared. */
|
|
537
|
+
readonly source: Kernel | Routine;
|
|
538
|
+
/** Extra data specific to the backend running this executable. */
|
|
462
539
|
readonly data: T;
|
|
463
|
-
constructor(
|
|
540
|
+
constructor(/** The `Kernel` or `Routine` that was prepared. */
|
|
541
|
+
source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
|
|
464
542
|
data: T);
|
|
465
543
|
}
|
|
544
|
+
declare namespace numpy_fft_d_exports {
|
|
545
|
+
export { ComplexPair, fft, ifft };
|
|
546
|
+
}
|
|
547
|
+
/**
|
|
548
|
+
* A pair of arrays representing real and imaginary part `a + bj`. Both arrays
|
|
549
|
+
* must have the same shape.
|
|
550
|
+
*/
|
|
551
|
+
type ComplexPair = {
|
|
552
|
+
real: Array;
|
|
553
|
+
imag: Array;
|
|
554
|
+
};
|
|
555
|
+
/**
|
|
556
|
+
* Compute a one-dimensional discrete Fourier transform.
|
|
557
|
+
*
|
|
558
|
+
* Currently, the size of the axis must be a power of two.
|
|
559
|
+
*/
|
|
560
|
+
declare function fft(a: ComplexPair, axis?: number): ComplexPair;
|
|
561
|
+
/**
|
|
562
|
+
* Compute a one-dimensional inverse discrete Fourier transform.
|
|
563
|
+
*
|
|
564
|
+
* Currently, the size of the axis must be a power of two.
|
|
565
|
+
*/
|
|
566
|
+
declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
|
|
567
|
+
declare namespace numpy_linalg_d_exports {
|
|
568
|
+
export { cholesky$1 as cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
|
|
569
|
+
}
|
|
570
|
+
/**
|
|
571
|
+
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
572
|
+
*
|
|
573
|
+
* This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
|
|
574
|
+
* the input matrix, which is on by default.
|
|
575
|
+
*/
|
|
576
|
+
declare function cholesky$1(a: ArrayLike, {
|
|
577
|
+
upper,
|
|
578
|
+
symmetrizeInput
|
|
579
|
+
}?: {
|
|
580
|
+
upper?: boolean;
|
|
581
|
+
symmetrizeInput?: boolean;
|
|
582
|
+
}): Array;
|
|
583
|
+
/** Compute the determinant of a square matrix (batched). */
|
|
584
|
+
declare function det(a: ArrayLike): Array;
|
|
585
|
+
/** Compute the inverse of a square matrix (batched). */
|
|
586
|
+
declare function inv(a: ArrayLike): Array;
|
|
587
|
+
/**
|
|
588
|
+
* Return the least-squares solution to a linear equation.
|
|
589
|
+
*
|
|
590
|
+
* For overdetermined systems, this finds the `x` that minimizes `norm(a @ x - b)`.
|
|
591
|
+
* For underdetermined systems, this finds the minimum-norm solution for `x`.
|
|
592
|
+
*
|
|
593
|
+
* This currently uses Cholesky decomposition to solve the normal equations,
|
|
594
|
+
* under the hood. The method is not as robust as QR or SVD.
|
|
595
|
+
*
|
|
596
|
+
* @param a - Coefficient matrix of shape `(M, N)`.
|
|
597
|
+
* @param b - Right-hand side of shape `(M,)` or `(M, K)`.
|
|
598
|
+
* @returns Least-squares solution of shape `(N,)` or `(N, K)`.
|
|
599
|
+
*/
|
|
600
|
+
declare function lstsq(a: ArrayLike, b: ArrayLike): Array;
|
|
601
|
+
/** Raise a square matrix to an integer power, via repeated squaring. */
|
|
602
|
+
declare function matrixPower(a: ArrayLike, n: number): Array;
|
|
603
|
+
/** Return the sign and the natural logarithm of the absolute value of the determinant of `a`. */
|
|
604
|
+
declare function slogdet(a: ArrayLike): [Array, Array];
|
|
605
|
+
/**
|
|
606
|
+
* Solve a linear system of equations.
|
|
607
|
+
*
|
|
608
|
+
* This solves a (batched) linear system of equations `a @ x = b` for `x` given
|
|
609
|
+
* `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
|
|
610
|
+
*
|
|
611
|
+
* @param a - Coefficient matrix of shape `(..., N, N)`.
|
|
612
|
+
* @param b - Values of shape `(N,)` or `(..., N, M)`.
|
|
613
|
+
* @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
|
|
614
|
+
*/
|
|
615
|
+
declare function solve(a: ArrayLike, b: ArrayLike): Array;
|
|
616
|
+
//#endregion
|
|
617
|
+
//#region src/library/numpy/dtype-info.d.ts
|
|
618
|
+
/** @inline */
|
|
619
|
+
type FInfo = Readonly<{
|
|
620
|
+
/** The number of bits occupied by the type. */
|
|
621
|
+
bits: number;
|
|
622
|
+
/** Returns the _dtype_ for which finfo returns information. */
|
|
623
|
+
dtype: DType;
|
|
624
|
+
/** The difference between 1.0 and the next smallest representable float larger than 1.0. */
|
|
625
|
+
eps: number;
|
|
626
|
+
/** The difference between 1.0 and the next largest representable float smaller than 1.0. */
|
|
627
|
+
epsneg: number;
|
|
628
|
+
/** The exponent that yields `eps`. */
|
|
629
|
+
machep: number;
|
|
630
|
+
/** The largest representable finite number. */
|
|
631
|
+
max: number;
|
|
632
|
+
/** The smallest positive power of the base (2) that causes overflow. */
|
|
633
|
+
maxexp: number;
|
|
634
|
+
/** The smallest representable (most negative) finite number. */
|
|
635
|
+
min: number;
|
|
636
|
+
/** The largest negative power of the base (2) without leading zeros in mantissa. */
|
|
637
|
+
minexp: number;
|
|
638
|
+
/** The exponent that yields `epsneg`. */
|
|
639
|
+
negep: number;
|
|
640
|
+
/** Number of bits in the exponent portion. */
|
|
641
|
+
nexp: number;
|
|
642
|
+
/** Number of bits in the mantissa portion. */
|
|
643
|
+
nmant: number;
|
|
644
|
+
/** The approximate number of decimal digits to which this kind of float is precise. */
|
|
645
|
+
precision: number;
|
|
646
|
+
/** The approximate decimal resolution, i.e., `10 ** -precision`. */
|
|
647
|
+
resolution: number;
|
|
648
|
+
/** The smallest positive normal number. */
|
|
649
|
+
smallestNormal: number;
|
|
650
|
+
/** The smallest positive subnormal number. */
|
|
651
|
+
smallestSubnormal: number;
|
|
652
|
+
}>;
|
|
653
|
+
/** Machine limits for floating-point types. */
|
|
654
|
+
declare function finfo(dtype: DType): FInfo;
|
|
655
|
+
/** @inline */
|
|
656
|
+
type IInfo = Readonly<{
|
|
657
|
+
/** The number of bits occupied by the type. */
|
|
658
|
+
bits: number;
|
|
659
|
+
/** Returns the _dtype_ for which iinfo returns information. */
|
|
660
|
+
dtype: DType;
|
|
661
|
+
/** The largest representable integer. */
|
|
662
|
+
max: number;
|
|
663
|
+
/** The smallest representable integer. */
|
|
664
|
+
min: number;
|
|
665
|
+
}>;
|
|
666
|
+
/** Machine limits for integer types. */
|
|
667
|
+
declare function iinfo(dtype: DType): IInfo;
|
|
466
668
|
declare namespace numpy_d_exports {
|
|
467
|
-
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, cos, cosh, cumsum, cumulativeSum, deg2rad, degrees, diag, diagonal, divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, positive, pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
669
|
+
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
468
670
|
}
|
|
469
671
|
declare const float32 = DType.Float32;
|
|
470
672
|
declare const int32 = DType.Int32;
|
|
@@ -590,6 +792,20 @@ declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
|
590
792
|
declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
591
793
|
/** Return the maximum of array elements along a given axis. */
|
|
592
794
|
declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
795
|
+
/**
|
|
796
|
+
* Test whether all array elements along a given axis evaluate to True.
|
|
797
|
+
*
|
|
798
|
+
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
799
|
+
* removed. If `axis` is not provided, returns a scalar.
|
|
800
|
+
*/
|
|
801
|
+
declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
802
|
+
/**
|
|
803
|
+
* Test whether any array element along a given axis evaluates to True.
|
|
804
|
+
*
|
|
805
|
+
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
806
|
+
* removed. If `axis` is not provided, returns a scalar.
|
|
807
|
+
*/
|
|
808
|
+
declare function any(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
593
809
|
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
594
810
|
declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
595
811
|
/** Compute the average of the array elements along the specified axis. */
|
|
@@ -615,10 +831,18 @@ declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
|
615
831
|
* two-phase parallel reduction algorithm.
|
|
616
832
|
*/
|
|
617
833
|
declare function cumsum(a: ArrayLike, axis?: number): Array;
|
|
618
|
-
/** @function Alternative name for `jax.numpy.cumsum()`. */
|
|
619
|
-
declare const cumulativeSum: typeof cumsum;
|
|
620
834
|
/** Reverse the elements in an array along the given axes. */
|
|
621
835
|
declare function flip(x: ArrayLike, axis?: Axis): Array;
|
|
836
|
+
/**
|
|
837
|
+
* Split an array into multiple sub-arrays along an axis.
|
|
838
|
+
*
|
|
839
|
+
* @param a - The input array to split.
|
|
840
|
+
* @param indicesOrSections - If an integer, it indicates the number of equal
|
|
841
|
+
* sections to create along the specified axis. If a list of integers, it
|
|
842
|
+
* specifies the indices at which to split the array.
|
|
843
|
+
* @param axis - The axis along which to split the array. Default is 0.
|
|
844
|
+
*/
|
|
845
|
+
declare function split$1(a: ArrayLike, indicesOrSections: number | number[], axis?: number): Array[];
|
|
622
846
|
/**
|
|
623
847
|
* Join a sequence of arrays along an existing axis.
|
|
624
848
|
*
|
|
@@ -662,12 +886,29 @@ declare function columnStack(xs: ArrayLike[]): Array;
|
|
|
662
886
|
declare function flipud(x: ArrayLike): Array;
|
|
663
887
|
/** Flip an array horizontally (axis=1). */
|
|
664
888
|
declare function fliplr(x: ArrayLike): Array;
|
|
665
|
-
/**
|
|
666
|
-
declare
|
|
889
|
+
/** Transpose the last two dimensions of an array. */
|
|
890
|
+
declare function matrixTranspose(a: ArrayLike): Array;
|
|
667
891
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
668
892
|
declare function ravel(a: ArrayLike): Array;
|
|
669
893
|
/** Remove one or more length-1 axes from an array. */
|
|
670
894
|
declare function squeeze(a: ArrayLike, axis?: Axis): Array;
|
|
895
|
+
/**
|
|
896
|
+
* Expand the shape of an array by inserting new axes of length 1.
|
|
897
|
+
*
|
|
898
|
+
* @param a - Input array.
|
|
899
|
+
* @param axis - Position(s) in the expanded axes where the new axis (or axes)
|
|
900
|
+
* is placed. Can be a single integer or an array of integers.
|
|
901
|
+
* @returns Array with the number of dimensions increased.
|
|
902
|
+
*
|
|
903
|
+
* @example
|
|
904
|
+
* ```ts
|
|
905
|
+
* const x = np.array([1, 2]);
|
|
906
|
+
* np.expandDims(x, 0); // Shape [1, 2]
|
|
907
|
+
* np.expandDims(x, 1); // Shape [2, 1]
|
|
908
|
+
* np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
|
|
909
|
+
* ```
|
|
910
|
+
*/
|
|
911
|
+
declare function expandDims(a: ArrayLike, axis: number | number[]): Array;
|
|
671
912
|
/**
|
|
672
913
|
* Repeat each element of an array after themselves.
|
|
673
914
|
*
|
|
@@ -714,6 +955,29 @@ declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?:
|
|
|
714
955
|
declare function diag(v: ArrayLike, k?: number): Array;
|
|
715
956
|
/** Calculate the sum of the diagonal of an array along the given axes. */
|
|
716
957
|
declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
|
|
958
|
+
/**
|
|
959
|
+
* Return a sorted copy of an array.
|
|
960
|
+
*
|
|
961
|
+
* The array is sorted along a specified axis (the last by default). This may be
|
|
962
|
+
* an unstable sort, and it dispatches to a device-specific implementation.
|
|
963
|
+
*/
|
|
964
|
+
declare function sort(a: ArrayLike, axis?: number): Array;
|
|
965
|
+
/**
|
|
966
|
+
* Return indices that would sort an array. This may be an unstable sorting
|
|
967
|
+
* algorithm; it need not preserve order of indices in ties.
|
|
968
|
+
*
|
|
969
|
+
* Returns an array of `int32` indices.
|
|
970
|
+
*
|
|
971
|
+
* The array is sorted along a specified axis (the last by default).
|
|
972
|
+
*/
|
|
973
|
+
declare function argsort(a: ArrayLike, axis?: number): Array;
|
|
974
|
+
/**
|
|
975
|
+
* Take elements from an array along an axis.
|
|
976
|
+
*
|
|
977
|
+
* This is equivalent to advanced indexing with integer indices over that
|
|
978
|
+
* numbered axis. By default, the flattened array is used.
|
|
979
|
+
*/
|
|
980
|
+
declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
|
|
717
981
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
718
982
|
declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
|
|
719
983
|
rtol?: number;
|
|
@@ -785,6 +1049,10 @@ declare function vecdot(x: ArrayLike, y: ArrayLike, {
|
|
|
785
1049
|
* Like vecdot() but flattens the arguments first into vectors.
|
|
786
1050
|
*/
|
|
787
1051
|
declare function vdot(x: ArrayLike, y: ArrayLike): Array;
|
|
1052
|
+
/** Convolution of two one-dimensional arrays. */
|
|
1053
|
+
declare function convolve(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
|
|
1054
|
+
/** Correlation of two one-dimensional arrays. */
|
|
1055
|
+
declare function correlate(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
|
|
788
1056
|
/**
|
|
789
1057
|
* Return a tuple of coordinate matrices from coordinate vectors.
|
|
790
1058
|
*
|
|
@@ -796,21 +1064,6 @@ declare function meshgrid(xs: Array[], {
|
|
|
796
1064
|
}?: {
|
|
797
1065
|
indexing?: "xy" | "ij";
|
|
798
1066
|
}): Array[];
|
|
799
|
-
/**
|
|
800
|
-
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
801
|
-
*
|
|
802
|
-
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
803
|
-
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
804
|
-
* `k>0` is above it.
|
|
805
|
-
*/
|
|
806
|
-
declare function tri(n: number, m?: number, k?: number, {
|
|
807
|
-
dtype,
|
|
808
|
-
device
|
|
809
|
-
}?: DTypeAndDevice): Array;
|
|
810
|
-
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
811
|
-
declare function tril(a: ArrayLike, k?: number): Array;
|
|
812
|
-
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
813
|
-
declare function triu(a: ArrayLike, k?: number): Array;
|
|
814
1067
|
/**
|
|
815
1068
|
* Clip (limit) the values in an array.
|
|
816
1069
|
*
|
|
@@ -827,8 +1080,6 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
|
|
|
827
1080
|
* This is the same function as `jax.numpy.abs()`.
|
|
828
1081
|
*/
|
|
829
1082
|
declare function absolute(x: ArrayLike): Array;
|
|
830
|
-
/** @function Alias of `jax.numpy.absolute()`. */
|
|
831
|
-
declare const abs: typeof absolute;
|
|
832
1083
|
/** Return an element-wise indication of sign of the input. */
|
|
833
1084
|
declare function sign(x: ArrayLike): Array;
|
|
834
1085
|
/** @function Return element-wise positive values of the input (no-op). */
|
|
@@ -857,6 +1108,17 @@ declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
|
857
1108
|
declare function square(x: ArrayLike): Array;
|
|
858
1109
|
/** Element-wise tangent function (takes radians). */
|
|
859
1110
|
declare function tan(x: ArrayLike): Array;
|
|
1111
|
+
/**
|
|
1112
|
+
* @function
|
|
1113
|
+
* Return the normalized sinc function.
|
|
1114
|
+
*
|
|
1115
|
+
* The sinc function is defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`.
|
|
1116
|
+
* This is the normalized sinc function commonly used in signal processing.
|
|
1117
|
+
*
|
|
1118
|
+
* **Note:** JVP is not supported at x=0, where differentiating the naive formula hits a 0/0 removable singularity (sinc itself is smooth). This
|
|
1119
|
+
* requires a custom JVP rule to handle properly (see JAX implementation).
|
|
1120
|
+
*/
|
|
1121
|
+
declare const sinc: OwnedFunction<(x: ArrayLike) => Array>;
|
|
860
1122
|
/** Element-wise inverse cosine function (inverse of cos). */
|
|
861
1123
|
declare function acos(x: ArrayLike): Array;
|
|
862
1124
|
/**
|
|
@@ -882,16 +1144,24 @@ declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
|
882
1144
|
* The output is ill-defined when both x and y are zero.
|
|
883
1145
|
*/
|
|
884
1146
|
declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
885
|
-
/** @function Alias of `jax.numpy.acos()`. */
|
|
886
|
-
declare const arccos: typeof acos;
|
|
887
|
-
/** @function Alias of `jax.numpy.atan()`. */
|
|
888
|
-
declare const arctan: (x: ArrayLike) => Array;
|
|
889
|
-
/** @function Alias of `jax.numpy.atan2()`. */
|
|
890
|
-
declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
891
1147
|
/** Element-wise subtraction, with broadcasting. */
|
|
892
1148
|
declare function subtract(x: ArrayLike, y: ArrayLike): Array;
|
|
893
1149
|
/** Calculates the floating-point division of x by y element-wise. */
|
|
894
1150
|
declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
|
|
1151
|
+
/**
|
|
1152
|
+
* Return the largest integer smaller or equal to the division of the inputs.
|
|
1153
|
+
*
|
|
1154
|
+
* The result is always rounded towards negative infinity.
|
|
1155
|
+
*
|
|
1156
|
+
* For floating-point inputs, this is equivalent to `floor(x / y)`.
|
|
1157
|
+
* For integer inputs, we use `(x - remainder(x, y)) / y` to handle
|
|
1158
|
+
* negative values correctly (note: may overflow near int32 boundaries).
|
|
1159
|
+
*
|
|
1160
|
+
* @param x - Dividend array.
|
|
1161
|
+
* @param y - Divisor array.
|
|
1162
|
+
* @returns Element-wise floor division of x by y.
|
|
1163
|
+
*/
|
|
1164
|
+
declare function floorDivide(x: ArrayLike, y: ArrayLike): Array;
|
|
895
1165
|
/**
|
|
896
1166
|
* @function
|
|
897
1167
|
* Calculate element-wise floating-point modulo operation.
|
|
@@ -902,8 +1172,16 @@ declare const fmod: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
|
902
1172
|
* Calculate element-wise remainder of the division (matches sign of y).
|
|
903
1173
|
*/
|
|
904
1174
|
declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
905
|
-
/**
|
|
906
|
-
|
|
1175
|
+
/**
|
|
1176
|
+
* Return element-wise quotient and remainder simultaneously.
|
|
1177
|
+
*
|
|
1178
|
+
* Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
|
|
1179
|
+
*
|
|
1180
|
+
* @param x - Dividend array.
|
|
1181
|
+
* @param y - Divisor array.
|
|
1182
|
+
* @returns Tuple of [quotient, remainder].
|
|
1183
|
+
*/
|
|
1184
|
+
declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
|
|
907
1185
|
/** Round input to the nearest integer towards zero. */
|
|
908
1186
|
declare function trunc(x: ArrayLike): Array;
|
|
909
1187
|
/**
|
|
@@ -943,8 +1221,6 @@ declare const degrees: typeof rad2deg;
|
|
|
943
1221
|
* Computes first array raised to power of second array, element-wise.
|
|
944
1222
|
*/
|
|
945
1223
|
declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
946
|
-
/** @function Alias of `jax.numpy.power()`. */
|
|
947
|
-
declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
948
1224
|
/** @function Calculate the element-wise cube root of the input array. */
|
|
949
1225
|
declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
|
|
950
1226
|
/**
|
|
@@ -989,12 +1265,6 @@ declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
|
989
1265
|
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
990
1266
|
*/
|
|
991
1267
|
declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
992
|
-
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
993
|
-
declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
994
|
-
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
995
|
-
declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
996
|
-
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
997
|
-
declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
998
1268
|
/**
|
|
999
1269
|
* Compute the variance of an array.
|
|
1000
1270
|
*
|
|
@@ -1021,6 +1291,14 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1021
1291
|
mean?: ArrayLike;
|
|
1022
1292
|
correction?: number;
|
|
1023
1293
|
} & ReduceOpts): Array;
|
|
1294
|
+
/** Estimate the sample covariance of a set of variables. */
|
|
1295
|
+
declare function cov(x: ArrayLike, y?: ArrayLike | null, {
|
|
1296
|
+
rowvar
|
|
1297
|
+
}?: {
|
|
1298
|
+
rowvar?: boolean;
|
|
1299
|
+
}): Array;
|
|
1300
|
+
/** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
|
|
1301
|
+
declare function corrcoef(x: ArrayLike, y?: ArrayLike): Array;
|
|
1024
1302
|
/** Test element-wise for positive or negative infinity, return bool array. */
|
|
1025
1303
|
declare function isinf(x: ArrayLike): Array;
|
|
1026
1304
|
/** Test element-wise for NaN (Not a Number). */
|
|
@@ -1034,7 +1312,6 @@ declare function isposinf(x: ArrayLike): Array;
|
|
|
1034
1312
|
* Test element-wise for finite values (not infinity or NaN).
|
|
1035
1313
|
*/
|
|
1036
1314
|
declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1037
|
-
//# sourceMappingURL=numpy.d.ts.map
|
|
1038
1315
|
declare namespace tree_d_exports {
|
|
1039
1316
|
export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
|
|
1040
1317
|
}
|
|
@@ -1087,7 +1364,7 @@ declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefin
|
|
|
1087
1364
|
interface ConvParams {
|
|
1088
1365
|
vmapDims: number;
|
|
1089
1366
|
strides: number[];
|
|
1090
|
-
padding: [
|
|
1367
|
+
padding: Pair[];
|
|
1091
1368
|
lhsDilation: number[];
|
|
1092
1369
|
rhsDilation: number[];
|
|
1093
1370
|
}
|
|
@@ -1168,9 +1445,21 @@ declare class Jaxpr implements FpHashable {
|
|
|
1168
1445
|
* - Remove no-op movement operations.
|
|
1169
1446
|
*/
|
|
1170
1447
|
simplify(): Jaxpr;
|
|
1171
|
-
/** Flattens nested
|
|
1448
|
+
/** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1172
1449
|
flatten(): Jaxpr;
|
|
1173
1450
|
}
|
|
1451
|
+
/** Jaxpr with a collection of associated, traced constants. */
|
|
1452
|
+
declare class ClosedJaxpr {
|
|
1453
|
+
readonly jaxpr: Jaxpr;
|
|
1454
|
+
readonly consts: Tracer[];
|
|
1455
|
+
constructor(jaxpr: Jaxpr, consts: Tracer[]);
|
|
1456
|
+
/** String representation of this Jaxpr. */
|
|
1457
|
+
toString(): string;
|
|
1458
|
+
/** Apply a function to the underlying Jaxpr. */
|
|
1459
|
+
mapJaxpr(f: (jaxpr: Jaxpr) => Jaxpr): ClosedJaxpr;
|
|
1460
|
+
/** Dispose of the constants in this Jaxpr. */
|
|
1461
|
+
dispose(): void;
|
|
1462
|
+
}
|
|
1174
1463
|
/** @inline */
|
|
1175
1464
|
type JitOpts = {
|
|
1176
1465
|
staticArgnums?: number[];
|
|
@@ -1193,7 +1482,9 @@ declare enum Primitive {
|
|
|
1193
1482
|
Mul = "mul",
|
|
1194
1483
|
Idiv = "idiv",
|
|
1195
1484
|
Mod = "mod",
|
|
1196
|
-
// uses sign of
|
|
1485
|
+
// uses sign of numerator, C-style, matches JS but not Python
|
|
1486
|
+
Min = "min",
|
|
1487
|
+
Max = "max",
|
|
1197
1488
|
Neg = "neg",
|
|
1198
1489
|
Reciprocal = "reciprocal",
|
|
1199
1490
|
Floor = "floor",
|
|
@@ -1201,7 +1492,6 @@ declare enum Primitive {
|
|
|
1201
1492
|
StopGradient = "stop_gradient",
|
|
1202
1493
|
Cast = "cast",
|
|
1203
1494
|
Bitcast = "bitcast",
|
|
1204
|
-
RandomBits = "random_bits",
|
|
1205
1495
|
Sin = "sin",
|
|
1206
1496
|
Cos = "cos",
|
|
1207
1497
|
Asin = "asin",
|
|
@@ -1211,8 +1501,6 @@ declare enum Primitive {
|
|
|
1211
1501
|
Erf = "erf",
|
|
1212
1502
|
Erfc = "erfc",
|
|
1213
1503
|
Sqrt = "sqrt",
|
|
1214
|
-
Min = "min",
|
|
1215
|
-
Max = "max",
|
|
1216
1504
|
Reduce = "reduce",
|
|
1217
1505
|
Dot = "dot",
|
|
1218
1506
|
// sum(x*y, axis=-1)
|
|
@@ -1222,14 +1510,27 @@ declare enum Primitive {
|
|
|
1222
1510
|
PoolTranspose = "pool_transpose",
|
|
1223
1511
|
Compare = "compare",
|
|
1224
1512
|
Where = "where",
|
|
1513
|
+
Concatenate = "concatenate",
|
|
1514
|
+
Split = "split",
|
|
1515
|
+
RandomBits = "random_bits",
|
|
1516
|
+
Gather = "gather",
|
|
1225
1517
|
Transpose = "transpose",
|
|
1226
1518
|
Broadcast = "broadcast",
|
|
1227
1519
|
Reshape = "reshape",
|
|
1228
1520
|
Flip = "flip",
|
|
1229
1521
|
Shrink = "shrink",
|
|
1230
1522
|
Pad = "pad",
|
|
1231
|
-
|
|
1232
|
-
|
|
1523
|
+
Sort = "sort",
|
|
1524
|
+
// sort(x, axis=-1)
|
|
1525
|
+
Argsort = "argsort",
|
|
1526
|
+
// argsort(x, axis=-1)
|
|
1527
|
+
TriangularSolve = "triangular_solve",
|
|
1528
|
+
// A is upper triangular, A @ X.T = B.T
|
|
1529
|
+
Cholesky = "cholesky",
|
|
1530
|
+
// A is positive-definite, A = L @ L^T
|
|
1531
|
+
LU = "lu",
|
|
1532
|
+
// LU decomposition with partial pivoting
|
|
1533
|
+
Jit = "jit",
|
|
1233
1534
|
}
|
|
1234
1535
|
interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
1235
1536
|
[Primitive.Cast]: {
|
|
@@ -1255,6 +1556,21 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
1255
1556
|
[Primitive.Compare]: {
|
|
1256
1557
|
op: CompareOp;
|
|
1257
1558
|
};
|
|
1559
|
+
[Primitive.Concatenate]: {
|
|
1560
|
+
axis: number;
|
|
1561
|
+
};
|
|
1562
|
+
[Primitive.Split]: {
|
|
1563
|
+
axis: number;
|
|
1564
|
+
sizes: number[];
|
|
1565
|
+
};
|
|
1566
|
+
[Primitive.RandomBits]: {
|
|
1567
|
+
shape: number[];
|
|
1568
|
+
mode: "xor" | 0 | 1;
|
|
1569
|
+
};
|
|
1570
|
+
[Primitive.Gather]: {
|
|
1571
|
+
axis: number[];
|
|
1572
|
+
outDim: number;
|
|
1573
|
+
};
|
|
1258
1574
|
[Primitive.Transpose]: {
|
|
1259
1575
|
perm: number[];
|
|
1260
1576
|
};
|
|
@@ -1262,10 +1578,6 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
1262
1578
|
shape: number[];
|
|
1263
1579
|
axis: number[];
|
|
1264
1580
|
};
|
|
1265
|
-
[Primitive.RandomBits]: {
|
|
1266
|
-
shape: number[];
|
|
1267
|
-
mode: "xor" | 0 | 1;
|
|
1268
|
-
};
|
|
1269
1581
|
[Primitive.Reshape]: {
|
|
1270
1582
|
shape: number[];
|
|
1271
1583
|
};
|
|
@@ -1278,15 +1590,14 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
1278
1590
|
[Primitive.Pad]: {
|
|
1279
1591
|
width: Pair[];
|
|
1280
1592
|
};
|
|
1281
|
-
[Primitive.
|
|
1282
|
-
axis: number[];
|
|
1283
|
-
outDim: number;
|
|
1284
|
-
};
|
|
1285
|
-
[Primitive.JitCall]: {
|
|
1593
|
+
[Primitive.Jit]: {
|
|
1286
1594
|
name: string;
|
|
1287
1595
|
jaxpr: Jaxpr;
|
|
1288
1596
|
numConsts: number;
|
|
1289
1597
|
};
|
|
1598
|
+
[Primitive.TriangularSolve]: {
|
|
1599
|
+
unitDiagonal: boolean;
|
|
1600
|
+
};
|
|
1290
1601
|
}
|
|
1291
1602
|
/** Type of parameters taken by each primitive. */
|
|
1292
1603
|
type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
|
|
@@ -1427,6 +1738,7 @@ declare abstract class Tracer {
|
|
|
1427
1738
|
neg(): this;
|
|
1428
1739
|
add(other: this | TracerValue): this;
|
|
1429
1740
|
mul(other: this | TracerValue): this;
|
|
1741
|
+
mod(other: this | TracerValue): this;
|
|
1430
1742
|
greater(other: this | TracerValue): this;
|
|
1431
1743
|
less(other: this | TracerValue): this;
|
|
1432
1744
|
equal(other: this | TracerValue): this;
|
|
@@ -1454,7 +1766,7 @@ declare abstract class Tracer {
|
|
|
1454
1766
|
sub(other: this | TracerValue): this;
|
|
1455
1767
|
/** Divide an array by this one. */
|
|
1456
1768
|
div(other: this | TracerValue): this;
|
|
1457
|
-
/** Return specified diagonals. See `numpy.diagonal` for full docs. */
|
|
1769
|
+
/** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
|
|
1458
1770
|
diagonal(offset?: number, axis1?: number, axis2?: number): this;
|
|
1459
1771
|
/** Flatten the array without changing its data. */
|
|
1460
1772
|
flatten(): this;
|
|
@@ -1473,6 +1785,19 @@ declare abstract class Tracer {
|
|
|
1473
1785
|
* ```
|
|
1474
1786
|
*/
|
|
1475
1787
|
[Symbol.iterator](): IterableIterator<this>;
|
|
1788
|
+
/**
|
|
1789
|
+
* Return a sorted copy of an array in ascending order.
|
|
1790
|
+
*
|
|
1791
|
+
* See `jax.numpy.sort` for full docs.
|
|
1792
|
+
*/
|
|
1793
|
+
sort(axis?: number): this;
|
|
1794
|
+
/**
|
|
1795
|
+
* Return the indices that would sort an array. This may not be a stable
|
|
1796
|
+
* sorting algorithm; it need not preserve order of indices in ties.
|
|
1797
|
+
*
|
|
1798
|
+
* See `jax.numpy.argsort` for full docs.
|
|
1799
|
+
*/
|
|
1800
|
+
argsort(axis?: number): this;
|
|
1476
1801
|
/**
|
|
1477
1802
|
* Slice an array along one or more axes.
|
|
1478
1803
|
*
|
|
@@ -1515,6 +1840,8 @@ declare class ShapedArray implements AbstractValue {
|
|
|
1515
1840
|
constructor(shape: number[], dtype: DType, weakType: boolean);
|
|
1516
1841
|
static fromAval(aval: AbstractValue): ShapedArray;
|
|
1517
1842
|
get ndim(): number;
|
|
1843
|
+
get size(): number;
|
|
1844
|
+
scalar(): ShapedArray;
|
|
1518
1845
|
toString(): string;
|
|
1519
1846
|
equals(other: ShapedArray): boolean;
|
|
1520
1847
|
}
|
|
@@ -1532,12 +1859,12 @@ type ArrayLike = Array | number | boolean;
|
|
|
1532
1859
|
declare class PendingExecute {
|
|
1533
1860
|
#private;
|
|
1534
1861
|
readonly backend: Backend;
|
|
1535
|
-
readonly
|
|
1862
|
+
readonly source: Kernel | Routine;
|
|
1536
1863
|
readonly inputs: Slot[];
|
|
1537
1864
|
readonly outputs: Slot[];
|
|
1538
1865
|
prepared: Executable | null;
|
|
1539
1866
|
submitted: boolean;
|
|
1540
|
-
constructor(backend: Backend,
|
|
1867
|
+
constructor(backend: Backend, source: Kernel | Routine, inputs: Slot[], outputs: Slot[]);
|
|
1541
1868
|
updateRc(delta: number): void;
|
|
1542
1869
|
prepare(): Promise<void>;
|
|
1543
1870
|
prepareSync(): void;
|
|
@@ -1569,7 +1896,6 @@ type ArrayConstructorArgs = {
|
|
|
1569
1896
|
*/
|
|
1570
1897
|
declare class Array extends Tracer {
|
|
1571
1898
|
#private;
|
|
1572
|
-
id: number;
|
|
1573
1899
|
/**
|
|
1574
1900
|
* @ignore
|
|
1575
1901
|
* Constructs an array from source, shape and backend. Note that if the source
|
|
@@ -1583,6 +1909,8 @@ declare class Array extends Tracer {
|
|
|
1583
1909
|
toString(): string;
|
|
1584
1910
|
get device(): Device;
|
|
1585
1911
|
get ref(): this;
|
|
1912
|
+
/** Get the current reference count (for debugging memory management). */
|
|
1913
|
+
get refCount(): number;
|
|
1586
1914
|
dispose(): void;
|
|
1587
1915
|
/**
|
|
1588
1916
|
* Convert this array into a primitive value.
|
|
@@ -1703,6 +2031,21 @@ declare function arange(start: number, stop?: number, step?: number, {
|
|
|
1703
2031
|
dtype,
|
|
1704
2032
|
device
|
|
1705
2033
|
}?: DTypeAndDevice): Array;
|
|
2034
|
+
/**
|
|
2035
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
2036
|
+
*
|
|
2037
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
2038
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
2039
|
+
* `k>0` is above it.
|
|
2040
|
+
*/
|
|
2041
|
+
declare function tri(n: number, m?: number, k?: number, {
|
|
2042
|
+
dtype,
|
|
2043
|
+
device
|
|
2044
|
+
}?: DTypeAndDevice): Array;
|
|
2045
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
2046
|
+
declare function tril(a: ArrayLike, k?: number): Array;
|
|
2047
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
2048
|
+
declare function triu(a: ArrayLike, k?: number): Array;
|
|
1706
2049
|
/**
|
|
1707
2050
|
* Return evenly spaced numbers over a specified interval.
|
|
1708
2051
|
*
|
|
@@ -1716,8 +2059,114 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
1716
2059
|
dtype,
|
|
1717
2060
|
device
|
|
1718
2061
|
}?: DTypeAndDevice): Array;
|
|
2062
|
+
/**
|
|
2063
|
+
* Return numbers spaced evenly on a log scale.
|
|
2064
|
+
*
|
|
2065
|
+
* In linear space, the sequence starts at `base ** start` and ends at
|
|
2066
|
+
* `base ** stop` (see `endpoint` below).
|
|
2067
|
+
*
|
|
2068
|
+
* @param start - `base ** start` is the starting value of the sequence.
|
|
2069
|
+
* @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
|
|
2070
|
+
* @param num - Number of samples to generate. Default is 50.
|
|
2071
|
+
* @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
|
|
2072
|
+
* @param base - The base of the log space. Default is 10.
|
|
2073
|
+
* @returns Array of evenly spaced values on a log scale.
|
|
2074
|
+
*/
|
|
2075
|
+
declare function logspace(start: number, stop: number, num?: number, endpoint?: boolean, base?: number, {
|
|
2076
|
+
dtype,
|
|
2077
|
+
device
|
|
2078
|
+
}?: DTypeAndDevice): Array;
|
|
2079
|
+
declare namespace lax_linalg_d_exports {
|
|
2080
|
+
export { cholesky, lu, triangularSolve };
|
|
2081
|
+
}
|
|
2082
|
+
/**
|
|
2083
|
+
* Compute the Cholesky decomposition of a symmetric positive-definite matrix.
|
|
2084
|
+
*
|
|
2085
|
+
* The Cholesky decomposition of a matrix `A` is:
|
|
2086
|
+
*
|
|
2087
|
+
* - A = L @ L^T (for upper=false, default)
|
|
2088
|
+
* - A = U^T @ U (for upper=true)
|
|
2089
|
+
*
|
|
2090
|
+
* where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
|
|
2091
|
+
* The input matrix must be symmetric and positive-definite.
|
|
2092
|
+
*
|
|
2093
|
+
* @example
|
|
2094
|
+
* ```ts
|
|
2095
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
2096
|
+
*
|
|
2097
|
+
* const x = np.array([[2., 1.], [1., 2.]]);
|
|
2098
|
+
*
|
|
2099
|
+
* // Lower Cholesky factorization (default):
|
|
2100
|
+
* const L = lax.linalg.cholesky(x);
|
|
2101
|
+
* // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
|
|
2102
|
+
*
|
|
2103
|
+
* // Upper Cholesky factorization:
|
|
2104
|
+
* const U = lax.linalg.cholesky(x, { upper: true });
|
|
2105
|
+
* // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
|
|
2106
|
+
* ```
|
|
2107
|
+
*/
|
|
2108
|
+
declare function cholesky(a: ArrayLike, {
|
|
2109
|
+
upper
|
|
2110
|
+
}?: {
|
|
2111
|
+
upper?: boolean;
|
|
2112
|
+
}): Array;
|
|
2113
|
+
/**
|
|
2114
|
+
* LU decomposition with partial pivoting.
|
|
2115
|
+
*
|
|
2116
|
+
* Computes the matrix decomposition: `P @ A = L @ U`, where `P` is a
|
|
2117
|
+
* permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
|
|
2118
|
+
* and `U` is upper-triangular.
|
|
2119
|
+
*
|
|
2120
|
+
* @param x - A batch of matrices with shape `[..., m, n]`.
|
|
2121
|
+
*
|
|
2122
|
+
* @returns A tuple `(lu, pivots, permutation)` where:
|
|
2123
|
+
* - `lu`: combined lower and upper triangular matrices.
|
|
2124
|
+
* - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
|
|
2125
|
+
* - `permutation`: the permutation generated by pivots with shape `[..., m]`.
|
|
2126
|
+
*
|
|
2127
|
+
* @example
|
|
2128
|
+
* ```ts
|
|
2129
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
2130
|
+
*
|
|
2131
|
+
* const A = np.array([[4., 3.], [6., 3.]]);
|
|
2132
|
+
* const [lu, pivots, permutation] = lax.linalg.lu(A);
|
|
2133
|
+
* // lu ≈ [[6., 3.], [0.6666667, 1.0]]
|
|
2134
|
+
* // pivots = [1, 1]
|
|
2135
|
+
* // permutation = [1, 0]
|
|
2136
|
+
* ```
|
|
2137
|
+
*/
|
|
2138
|
+
declare function lu(x: ArrayLike): [Array, Array, Array];
|
|
2139
|
+
/**
|
|
2140
|
+
* Solve a triangular linear system.
|
|
2141
|
+
*
|
|
2142
|
+
* Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
|
|
2143
|
+
* where `a` is a triangular matrix.
|
|
2144
|
+
*
|
|
2145
|
+
* @example
|
|
2146
|
+
* ```ts
|
|
2147
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
2148
|
+
*
|
|
2149
|
+
* const L = np.array([[2., 0.], [1., 3.]]);
|
|
2150
|
+
* const b = np.array([4., 7.]).reshape([2, 1]);
|
|
2151
|
+
*
|
|
2152
|
+
* // Solve L @ x = b
|
|
2153
|
+
* const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
|
|
2154
|
+
* // x = [[2.], [5./3.]]
|
|
2155
|
+
* ```
|
|
2156
|
+
*/
|
|
2157
|
+
declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
|
|
2158
|
+
leftSide,
|
|
2159
|
+
lower,
|
|
2160
|
+
transposeA,
|
|
2161
|
+
unitDiagonal
|
|
2162
|
+
}?: {
|
|
2163
|
+
leftSide?: boolean;
|
|
2164
|
+
lower?: boolean;
|
|
2165
|
+
transposeA?: boolean;
|
|
2166
|
+
unitDiagonal?: boolean;
|
|
2167
|
+
}): Array;
|
|
1719
2168
|
declare namespace lax_d_exports {
|
|
1720
|
-
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, reduceWindow, stopGradient };
|
|
2169
|
+
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
|
|
1721
2170
|
}
|
|
1722
2171
|
/**
|
|
1723
2172
|
* Dimension numbers for general `dot()` primitive.
|
|
@@ -1748,7 +2197,7 @@ declare function dot(lhs: Array, rhs: Array, {
|
|
|
1748
2197
|
lhsBatchDims: lb,
|
|
1749
2198
|
rhsBatchDims: rb
|
|
1750
2199
|
}?: DotDimensionNumbers): Array;
|
|
1751
|
-
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [
|
|
2200
|
+
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
|
|
1752
2201
|
/**
|
|
1753
2202
|
* General n-dimensional convolution operator, with optional dilation.
|
|
1754
2203
|
*
|
|
@@ -1788,9 +2237,8 @@ declare function erfc(x: ArrayLike): Array;
|
|
|
1788
2237
|
* forward or reverse-mode automatic differentiation.
|
|
1789
2238
|
*/
|
|
1790
2239
|
declare function stopGradient(x: ArrayLike): Array;
|
|
1791
|
-
//# sourceMappingURL=lax.d.ts.map
|
|
1792
2240
|
declare namespace nn_d_exports {
|
|
1793
|
-
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
|
|
2241
|
+
export { celu, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
|
|
1794
2242
|
}
|
|
1795
2243
|
/**
|
|
1796
2244
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -1817,21 +2265,28 @@ declare function sigmoid(x: ArrayLike): Array;
|
|
|
1817
2265
|
*/
|
|
1818
2266
|
declare function softplus(x: ArrayLike): Array;
|
|
1819
2267
|
/**
|
|
1820
|
-
*
|
|
1821
|
-
*
|
|
2268
|
+
* @function
|
|
2269
|
+
* Sparse plus function:
|
|
2270
|
+
*
|
|
2271
|
+
* - When `x <= -1`: `0`
|
|
2272
|
+
* - When `-1 < x < 1`: `(x+1)**2 / 4`
|
|
2273
|
+
* - When `x >= 1`: `x`
|
|
1822
2274
|
*/
|
|
1823
|
-
declare
|
|
2275
|
+
declare const sparsePlus: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1824
2276
|
/**
|
|
1825
2277
|
* @function
|
|
1826
|
-
*
|
|
1827
|
-
* Swish, computed element-wise:
|
|
1828
|
-
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
2278
|
+
* Sparse sigmoid activation function.
|
|
1829
2279
|
*
|
|
1830
|
-
*
|
|
1831
|
-
*
|
|
1832
|
-
*
|
|
2280
|
+
* - When `x <= -1`: `0`
|
|
2281
|
+
* - When `-1 < x < 1`: `(x + 1) / 2`
|
|
2282
|
+
* - When `x >= 1`: `1`
|
|
1833
2283
|
*/
|
|
1834
|
-
declare const
|
|
2284
|
+
declare const sparseSigmoid: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2285
|
+
/**
|
|
2286
|
+
* Soft-sign activation function, computed element-wise:
|
|
2287
|
+
* `softsign(x) = x / (|x| + 1)`.
|
|
2288
|
+
*/
|
|
2289
|
+
declare function softSign(x: ArrayLike): Array;
|
|
1835
2290
|
/**
|
|
1836
2291
|
* @function
|
|
1837
2292
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
@@ -1842,7 +2297,7 @@ declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
|
1842
2297
|
*
|
|
1843
2298
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1844
2299
|
*/
|
|
1845
|
-
declare const
|
|
2300
|
+
declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1846
2301
|
/**
|
|
1847
2302
|
* Log-sigmoid activation function, computed element-wise:
|
|
1848
2303
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
@@ -1855,6 +2310,12 @@ declare function logSigmoid(x: ArrayLike): Array;
|
|
|
1855
2310
|
declare const identity: (x: ArrayLike) => Array;
|
|
1856
2311
|
/** Leaky rectified linear (ReLU) activation function */
|
|
1857
2312
|
declare function leakyRelu(x: ArrayLike, negativeSlope?: ArrayLike): Array;
|
|
2313
|
+
/** Hard sigmoid activation function: `relu6(x+3)/6`. */
|
|
2314
|
+
declare function hardSigmoid(x: ArrayLike): Array;
|
|
2315
|
+
/** Hard SiLU (swish) activation function: `x * hardSigmoid(x)`. */
|
|
2316
|
+
declare function hardSilu(x: ArrayLike): Array;
|
|
2317
|
+
/** Hard tanh activation function: `clip(x, -1, 1)`. */
|
|
2318
|
+
declare function hardTanh(x: ArrayLike): Array;
|
|
1858
2319
|
/**
|
|
1859
2320
|
* Exponential linear unit activation function.
|
|
1860
2321
|
*
|
|
@@ -1869,6 +2330,16 @@ declare function elu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
|
1869
2330
|
* `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
|
|
1870
2331
|
*/
|
|
1871
2332
|
declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
2333
|
+
/**
|
|
2334
|
+
* @function
|
|
2335
|
+
* Scaled exponential linear unit activation.
|
|
2336
|
+
*
|
|
2337
|
+
* Computes the element-wise function:
|
|
2338
|
+
* `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
|
|
2339
|
+
*
|
|
2340
|
+
* Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
|
|
2341
|
+
*/
|
|
2342
|
+
declare const selu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1872
2343
|
/**
|
|
1873
2344
|
* @function
|
|
1874
2345
|
* Gaussion error linear unit (GELU) activation function.
|
|
@@ -1964,12 +2435,11 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1964
2435
|
* ```
|
|
1965
2436
|
*/
|
|
1966
2437
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
1967
|
-
//# sourceMappingURL=nn.d.ts.map
|
|
1968
2438
|
declare namespace random_d_exports {
|
|
1969
|
-
export { bernoulli, bits, exponential, key, normal, split, uniform };
|
|
2439
|
+
export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
1970
2440
|
}
|
|
1971
2441
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
1972
|
-
declare function key(seed:
|
|
2442
|
+
declare function key(seed: ArrayLike): Array;
|
|
1973
2443
|
/** Splits a PRNG key into `num` new keys by adding a leading axis. */
|
|
1974
2444
|
declare function split(key: Array, num?: number | number[]): Array;
|
|
1975
2445
|
/** Sample uniform bits in the form of unsigned integers. */
|
|
@@ -1989,11 +2459,50 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
|
|
|
1989
2459
|
* and must be broadcastable to `shape`.
|
|
1990
2460
|
*/
|
|
1991
2461
|
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
2462
|
+
/**
|
|
2463
|
+
* @function
|
|
2464
|
+
* Sample from a Cauchy distribution with location 0 and scale 1.
|
|
2465
|
+
*
|
|
2466
|
+
* Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
|
|
2467
|
+
*/
|
|
2468
|
+
declare const cauchy: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1992
2469
|
/**
|
|
1993
2470
|
* @function
|
|
1994
2471
|
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
1995
2472
|
*/
|
|
1996
2473
|
declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2474
|
+
/**
|
|
2475
|
+
* @function
|
|
2476
|
+
* Sample from a Gumbel distribution with location 0 and scale 1.
|
|
2477
|
+
*
|
|
2478
|
+
* Uses inverse transform sampling: `x = -log(-log(u))` where u ~ Uniform(0, 1).
|
|
2479
|
+
*/
|
|
2480
|
+
declare const gumbel: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2481
|
+
/**
|
|
2482
|
+
* @function
|
|
2483
|
+
* Sample from a Laplace distribution with location 0 and scale 1.
|
|
2484
|
+
*
|
|
2485
|
+
* Uses inverse transform sampling: the CDF is `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`.
|
|
2486
|
+
* Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
|
|
2487
|
+
*/
|
|
2488
|
+
declare const laplace: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2489
|
+
/**
|
|
2490
|
+
* @function
|
|
2491
|
+
* Sample multivariate normal random values with given mean and covariance.
|
|
2492
|
+
*
|
|
2493
|
+
* The values are returned with the given shape, along with the final dimension
|
|
2494
|
+
* used to represent the n-dimensional multivariate normal factors.
|
|
2495
|
+
*
|
|
2496
|
+
* This uses Cholesky decomposition on the covariance matrix.
|
|
2497
|
+
*
|
|
2498
|
+
* - `key` - PRNG key
|
|
2499
|
+
* - `mean` - Mean vector of shape `[..., n]`
|
|
2500
|
+
* - `cov` - Covariance of shape `[..., n, n]`, must be positive-definite
|
|
2501
|
+
* - `shape` - Result batch shape, must be broadcastable with
|
|
2502
|
+
* `mean.shape[:-1]` and `cov.shape[:-2]`
|
|
2503
|
+
* @returns Random samples of shape `[...shape, n]`
|
|
2504
|
+
*/
|
|
2505
|
+
declare const multivariateNormal: OwnedFunction<(key: ArrayLike, mean: ArrayLike, cov: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1997
2506
|
/**
|
|
1998
2507
|
* @function
|
|
1999
2508
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
@@ -2003,7 +2512,6 @@ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | und
|
|
|
2003
2512
|
* bitwise identical to JAX.
|
|
2004
2513
|
*/
|
|
2005
2514
|
declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2006
|
-
//# sourceMappingURL=random.d.ts.map
|
|
2007
2515
|
declare namespace scipy_special_d_exports {
|
|
2008
2516
|
export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
|
|
2009
2517
|
}
|
|
@@ -2034,8 +2542,7 @@ declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTr
|
|
|
2034
2542
|
* Construct a Jaxpr by dynamically tracing a function with example inputs.
|
|
2035
2543
|
*/
|
|
2036
2544
|
declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
|
|
2037
|
-
jaxpr:
|
|
2038
|
-
consts: Array[];
|
|
2545
|
+
jaxpr: ClosedJaxpr;
|
|
2039
2546
|
treedef: JsTreeDef;
|
|
2040
2547
|
};
|
|
2041
2548
|
/**
|
|
@@ -2083,11 +2590,6 @@ declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F)
|
|
|
2083
2590
|
* Compute the Jacobian evaluated row-by-row by reverse-mode AD.
|
|
2084
2591
|
*/
|
|
2085
2592
|
declare const jacrev: typeof jacfwd;
|
|
2086
|
-
/**
|
|
2087
|
-
* @function
|
|
2088
|
-
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
2089
|
-
*/
|
|
2090
|
-
declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
2091
2593
|
/**
|
|
2092
2594
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
2093
2595
|
*
|
|
@@ -2109,8 +2611,5 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
|
2109
2611
|
* default device.
|
|
2110
2612
|
*/
|
|
2111
2613
|
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
2112
|
-
//# sourceMappingURL=index.d.ts.map
|
|
2113
|
-
|
|
2114
2614
|
//#endregion
|
|
2115
|
-
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
2116
|
-
//# sourceMappingURL=index.d.cts.map
|
|
2615
|
+
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|