@jax-js/jax 0.1.3 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +15 -9
- package/dist/{backend-BY8wlLEl.js → backend-DaqL-MNz.js} +240 -21
- package/dist/{backend-CmaidnkQ.cjs → backend-DziQSaoQ.cjs} +264 -21
- package/dist/index.cjs +2407 -1132
- package/dist/index.d.cts +596 -97
- package/dist/index.d.ts +596 -97
- package/dist/index.js +2400 -1126
- package/dist/webgl-ClIYb8jP.cjs +522 -0
- package/dist/webgl-RSuZKvgc.js +522 -0
- package/dist/webgpu-Db2JrNBr.cjs +1261 -0
- package/dist/webgpu-Dh7k9io0.js +1261 -0
- package/package.json +1 -1
- package/dist/webgpu-BVns4DbI.cjs +0 -663
- package/dist/webgpu-C9iAP5h5.js +0 -663
package/dist/index.d.ts
CHANGED
|
@@ -121,7 +121,6 @@ declare class ShapeTracker {
|
|
|
121
121
|
/** Like pad(), but allows for negative values. */
|
|
122
122
|
padOrShrink(arg: Pair[]): ShapeTracker;
|
|
123
123
|
}
|
|
124
|
-
//# sourceMappingURL=shape.d.ts.map
|
|
125
124
|
//#endregion
|
|
126
125
|
//#region src/utils.d.ts
|
|
127
126
|
/**
|
|
@@ -404,8 +403,81 @@ declare class Reduction implements FpHashable {
|
|
|
404
403
|
}
|
|
405
404
|
/** Expression for accessing `indices` in input array with the given shape. */
|
|
406
405
|
//#endregion
|
|
406
|
+
//#region src/routine.d.ts
|
|
407
|
+
/**
|
|
408
|
+
* Advanced operations that don't fit into the `AluExp` compiler representation.
|
|
409
|
+
*
|
|
410
|
+
* Some routines like iterative matrix algorithms, FFTs, or sorting may not be
|
|
411
|
+
* easy to express efficiently as a `Kernel` object. These also tend to be
|
|
412
|
+
* somewhat expensive, so the benefit of kernel fusion and inlining is less
|
|
413
|
+
* relevant.
|
|
414
|
+
*
|
|
415
|
+
* For these operations, we dispatch them as a custom operation on the backend,
|
|
416
|
+
* which each backend implements in a specific way. These are listed in the
|
|
417
|
+
* `Routines` enum below.
|
|
418
|
+
*
|
|
419
|
+
* Routines cannot be fused into other kernels and always operate on contiguous
|
|
420
|
+
* arrays (default `ShapeTracker`).
|
|
421
|
+
*/
|
|
422
|
+
declare class Routine {
|
|
423
|
+
/** The name of the routine. */
|
|
424
|
+
readonly name: Routines;
|
|
425
|
+
/** Dtype and shape of the inputs and outputs. */
|
|
426
|
+
readonly type: RoutineType;
|
|
427
|
+
/** Extra parameters specific to the routine. */
|
|
428
|
+
readonly params?: any | undefined;
|
|
429
|
+
constructor(/** The name of the routine. */
|
|
430
|
+
name: Routines, /** Dtype and shape of the inputs and outputs. */
|
|
431
|
+
type: RoutineType, /** Extra parameters specific to the routine. */
|
|
432
|
+
params?: any | undefined);
|
|
433
|
+
}
|
|
434
|
+
/** One of the valid `Routine` names that can be dispatched to a backend. */
|
|
435
|
+
declare enum Routines {
|
|
436
|
+
/** Stable sorting algorithm along the last axis. */
|
|
437
|
+
Sort = "Sort",
|
|
438
|
+
/** Returns `int32` indices of the stably sorted array. */
|
|
439
|
+
Argsort = "Argsort",
|
|
440
|
+
/**
|
|
441
|
+
* Solve a triangular system of equations.
|
|
442
|
+
*
|
|
443
|
+
* The first batch of inputs `A` should be of shape `[..., N, N]` and upper
|
|
444
|
+
* triangular, while the second batch `B` should be of shape `[..., M, N]`.
|
|
445
|
+
*
|
|
446
|
+
* Solves for `X` in the equation `A @ X.T = B.T`, where `A` is the
|
|
447
|
+
* triangular matrix. This is equivalent to `X = B @ A^-T`.
|
|
448
|
+
*/
|
|
449
|
+
TriangularSolve = "TriangularSolve",
|
|
450
|
+
/**
|
|
451
|
+
* Cholesky decomposition of 2D positive semi-definite matrices.
|
|
452
|
+
*
|
|
453
|
+
* The input batch should be of shape `[..., N, N]`, and the output batch is
|
|
454
|
+
* of the same shape, containing the lower-triangular matrix `L` such that
|
|
455
|
+
* `A = L @ L.T`. Behavior is unspecified if A is not positive semi-definite.
|
|
456
|
+
*/
|
|
457
|
+
Cholesky = "Cholesky",
|
|
458
|
+
/**
|
|
459
|
+
* LU decomposition of 2D rectangular matrices.
|
|
460
|
+
*
|
|
461
|
+
* The input is a batch of shape `[..., M, N]`, and the output is a tuple of
|
|
462
|
+
* three arrays: `LU, Pivots, Permutation`.
|
|
463
|
+
*
|
|
464
|
+
* - `LU` is of shape `[..., M, N]`, containing the combined lower and upper
|
|
465
|
+
* triangular matrices. (lower triangular = implicit unit diagonal)
|
|
466
|
+
* - `Pivots` is of shape `[..., min(M, N)]`, containing the row swaps.
|
|
467
|
+
* - `Permutation` is of shape `[..., M]`, containing the permutation vector
|
|
468
|
+
* such that `P = eye(M).slice(Permutation)` -> `P @ A = L @ U`.
|
|
469
|
+
*/
|
|
470
|
+
LU = "LU",
|
|
471
|
+
}
|
|
472
|
+
interface RoutineType {
|
|
473
|
+
inputShapes: number[][];
|
|
474
|
+
inputDtypes: DType[];
|
|
475
|
+
outputShapes: number[][];
|
|
476
|
+
outputDtypes: DType[];
|
|
477
|
+
}
|
|
478
|
+
//#endregion
|
|
407
479
|
//#region src/backend.d.ts
|
|
408
|
-
type Device = "cpu" | "wasm" | "webgpu";
|
|
480
|
+
type Device = "cpu" | "wasm" | "webgpu" | "webgl";
|
|
409
481
|
declare const devices: Device[];
|
|
410
482
|
/** Configure the default device for arrays. */
|
|
411
483
|
declare function defaultDevice(device?: Device): Device;
|
|
@@ -441,9 +513,13 @@ interface Backend {
|
|
|
441
513
|
/** Read a range of bytes from a buffer, blocking variant. */
|
|
442
514
|
readSync(slot: Slot, start?: number, count?: number): Uint8Array<ArrayBuffer>;
|
|
443
515
|
/** Prepare an expression to be executed later. */
|
|
444
|
-
|
|
516
|
+
prepareKernel(kernel: Kernel): Promise<Executable>;
|
|
445
517
|
/** Prepare an expression to be executed later, blocking variant. */
|
|
446
|
-
|
|
518
|
+
prepareKernelSync(kernel: Kernel): Executable;
|
|
519
|
+
/** Prepare an advanced routine to be executed later. */
|
|
520
|
+
prepareRoutine(routine: Routine): Promise<Executable>;
|
|
521
|
+
/** Prepare an advanced routine to be executed later, blocking variant. */
|
|
522
|
+
prepareRoutineSync(routine: Routine): Executable;
|
|
447
523
|
/**
|
|
448
524
|
* Run a backend operation that was previously prepared.
|
|
449
525
|
*
|
|
@@ -454,14 +530,140 @@ interface Backend {
|
|
|
454
530
|
dispatch(exe: Executable, inputs: Slot[], outputs: Slot[]): void;
|
|
455
531
|
}
|
|
456
532
|
declare class Executable<T = any> {
|
|
457
|
-
|
|
458
|
-
|
|
533
|
+
/** The `Kernel` or `Routine` that was prepared. */
|
|
534
|
+
readonly source: Kernel | Routine;
|
|
535
|
+
/** Extra data specific to the backend running this executable. */
|
|
459
536
|
readonly data: T;
|
|
460
|
-
constructor(
|
|
537
|
+
constructor(/** The `Kernel` or `Routine` that was prepared. */
|
|
538
|
+
source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
|
|
461
539
|
data: T);
|
|
462
540
|
}
|
|
541
|
+
declare namespace numpy_fft_d_exports {
|
|
542
|
+
export { ComplexPair, fft, ifft };
|
|
543
|
+
}
|
|
544
|
+
/**
|
|
545
|
+
* A pair of arrays representing real and imaginary part `a + bj`. Both arrays
|
|
546
|
+
* must have the same shape.
|
|
547
|
+
*/
|
|
548
|
+
type ComplexPair = {
|
|
549
|
+
real: Array;
|
|
550
|
+
imag: Array;
|
|
551
|
+
};
|
|
552
|
+
/**
|
|
553
|
+
* Compute a one-dimensional discrete Fourier transform.
|
|
554
|
+
*
|
|
555
|
+
* Currently, the size of the axis must be a power of two.
|
|
556
|
+
*/
|
|
557
|
+
declare function fft(a: ComplexPair, axis?: number): ComplexPair;
|
|
558
|
+
/**
|
|
559
|
+
* Compute a one-dimensional inverse discrete Fourier transform.
|
|
560
|
+
*
|
|
561
|
+
* Currently, the size of the axis must be a power of two.
|
|
562
|
+
*/
|
|
563
|
+
declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
|
|
564
|
+
declare namespace numpy_linalg_d_exports {
|
|
565
|
+
export { cholesky$1 as cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
|
|
566
|
+
}
|
|
567
|
+
/**
|
|
568
|
+
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
569
|
+
*
|
|
570
|
+
* This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
|
|
571
|
+
* the input matrix, which is on by default.
|
|
572
|
+
*/
|
|
573
|
+
declare function cholesky$1(a: ArrayLike, {
|
|
574
|
+
upper,
|
|
575
|
+
symmetrizeInput
|
|
576
|
+
}?: {
|
|
577
|
+
upper?: boolean;
|
|
578
|
+
symmetrizeInput?: boolean;
|
|
579
|
+
}): Array;
|
|
580
|
+
/** Compute the determinant of a square matrix (batched). */
|
|
581
|
+
declare function det(a: ArrayLike): Array;
|
|
582
|
+
/** Compute the inverse of a square matrix (batched). */
|
|
583
|
+
declare function inv(a: ArrayLike): Array;
|
|
584
|
+
/**
|
|
585
|
+
* Return the least-squares solution to a linear equation.
|
|
586
|
+
*
|
|
587
|
+
* For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
|
|
588
|
+
* For underdetermined systems, this finds the minimum-norm solution for `x`.
|
|
589
|
+
*
|
|
590
|
+
* This currently uses Cholesky decomposition to solve the normal equations,
|
|
591
|
+
* under the hood. The method is not as robust as QR or SVD.
|
|
592
|
+
*
|
|
593
|
+
* @param a coefficient matrix of shape `(M, N)`
|
|
594
|
+
* @param b right-hand side of shape `(M,)` or `(M, K)`
|
|
595
|
+
* @returns least-squares solution of shape `(N,)` or `(N, K)`
|
|
596
|
+
*/
|
|
597
|
+
declare function lstsq(a: ArrayLike, b: ArrayLike): Array;
|
|
598
|
+
/** Raise a square matrix to an integer power, via repeated squarings. */
|
|
599
|
+
declare function matrixPower(a: ArrayLike, n: number): Array;
|
|
600
|
+
/** Return sign and natural logarithm of the determinant of `a`. */
|
|
601
|
+
declare function slogdet(a: ArrayLike): [Array, Array];
|
|
602
|
+
/**
|
|
603
|
+
* Solve a linear system of equations.
|
|
604
|
+
*
|
|
605
|
+
* This solves a (batched) linear system of equations `a @ x = b` for `x` given
|
|
606
|
+
* `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
|
|
607
|
+
*
|
|
608
|
+
* @param a - Coefficient matrix of shape `(..., N, N)`.
|
|
609
|
+
* @param b - Values of shape `(N,)` or `(..., N, M)`.
|
|
610
|
+
* @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
|
|
611
|
+
*/
|
|
612
|
+
declare function solve(a: ArrayLike, b: ArrayLike): Array;
|
|
613
|
+
//#endregion
|
|
614
|
+
//#region src/library/numpy/dtype-info.d.ts
|
|
615
|
+
/** @inline */
|
|
616
|
+
type FInfo = Readonly<{
|
|
617
|
+
/** The number of bits occupied by the type. */
|
|
618
|
+
bits: number;
|
|
619
|
+
/** Returns the _dtype_ for which finfo returns information. */
|
|
620
|
+
dtype: DType;
|
|
621
|
+
/** The difference between 1.0 and the next smallest representable float larger than 1.0. */
|
|
622
|
+
eps: number;
|
|
623
|
+
/** The difference between 1.0 and the next largest representable float smaller than 1.0. */
|
|
624
|
+
epsneg: number;
|
|
625
|
+
/** The exponent that yields `eps`. */
|
|
626
|
+
machep: number;
|
|
627
|
+
/** The largest representable finite number. */
|
|
628
|
+
max: number;
|
|
629
|
+
/** The smallest positive power of the base (2) that causes overflow. */
|
|
630
|
+
maxexp: number;
|
|
631
|
+
/** The smallest representable (most negative) finite number. */
|
|
632
|
+
min: number;
|
|
633
|
+
/** The largest negative power of the base (2) without leading zeros in mantissa. */
|
|
634
|
+
minexp: number;
|
|
635
|
+
/** The exponent that yields `epsneg`. */
|
|
636
|
+
negep: number;
|
|
637
|
+
/** Number of bits in the exponent portion. */
|
|
638
|
+
nexp: number;
|
|
639
|
+
/** Number of bits in the mantissa portion. */
|
|
640
|
+
nmant: number;
|
|
641
|
+
/** The approximate number of decimal digits to which this kind of float is precise. */
|
|
642
|
+
precision: number;
|
|
643
|
+
/** The approximate decimal resolution, i.e., `10 ** -precision`. */
|
|
644
|
+
resolution: number;
|
|
645
|
+
/** The smallest positive normal number. */
|
|
646
|
+
smallestNormal: number;
|
|
647
|
+
/** The smallest positive subnormal number. */
|
|
648
|
+
smallestSubnormal: number;
|
|
649
|
+
}>;
|
|
650
|
+
/** Machine limits for floating-point types. */
|
|
651
|
+
declare function finfo(dtype: DType): FInfo;
|
|
652
|
+
/** @inline */
|
|
653
|
+
type IInfo = Readonly<{
|
|
654
|
+
/** The number of bits occupied by the type. */
|
|
655
|
+
bits: number;
|
|
656
|
+
/** Returns the _dtype_ for which iinfo returns information. */
|
|
657
|
+
dtype: DType;
|
|
658
|
+
/** The largest representable integer. */
|
|
659
|
+
max: number;
|
|
660
|
+
/** The smallest representable integer. */
|
|
661
|
+
min: number;
|
|
662
|
+
}>;
|
|
663
|
+
/** Machine limits for integer types. */
|
|
664
|
+
declare function iinfo(dtype: DType): IInfo;
|
|
463
665
|
declare namespace numpy_d_exports {
|
|
464
|
-
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, cos, cosh, cumsum, cumulativeSum, deg2rad, degrees, diag, diagonal, divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, positive, pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
666
|
+
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
465
667
|
}
|
|
466
668
|
declare const float32 = DType.Float32;
|
|
467
669
|
declare const int32 = DType.Int32;
|
|
@@ -587,6 +789,20 @@ declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
|
587
789
|
declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
588
790
|
/** Return the maximum of array elements along a given axis. */
|
|
589
791
|
declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
792
|
+
/**
|
|
793
|
+
* Test whether all array elements along a given axis evaluate to True.
|
|
794
|
+
*
|
|
795
|
+
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
796
|
+
* removed. If axis is None, returns a scalar.
|
|
797
|
+
*/
|
|
798
|
+
declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
799
|
+
/**
|
|
800
|
+
* Test whether any array element along a given axis evaluates to True.
|
|
801
|
+
*
|
|
802
|
+
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
803
|
+
* removed. If axis is None, returns a scalar.
|
|
804
|
+
*/
|
|
805
|
+
declare function any(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
590
806
|
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
591
807
|
declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
592
808
|
/** Compute the average of the array elements along the specified axis. */
|
|
@@ -612,10 +828,18 @@ declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
|
612
828
|
* two-phase parallel reduction algorithm.
|
|
613
829
|
*/
|
|
614
830
|
declare function cumsum(a: ArrayLike, axis?: number): Array;
|
|
615
|
-
/** @function Alternative name for `jax.numpy.cumsum()`. */
|
|
616
|
-
declare const cumulativeSum: typeof cumsum;
|
|
617
831
|
/** Reverse the elements in an array along the given axes. */
|
|
618
832
|
declare function flip(x: ArrayLike, axis?: Axis): Array;
|
|
833
|
+
/**
|
|
834
|
+
* Split an array into multiple sub-arrays along an axis.
|
|
835
|
+
*
|
|
836
|
+
* @param a - The input array to split.
|
|
837
|
+
* @param indicesOrSections - If an integer, it indicates the number of equal
|
|
838
|
+
* sections to create along the specified axis. If a list of integers, it
|
|
839
|
+
* specifies the indices at which to split the array.
|
|
840
|
+
* @param axis - The axis along which to split the array. Default is 0.
|
|
841
|
+
*/
|
|
842
|
+
declare function split$1(a: ArrayLike, indicesOrSections: number | number[], axis?: number): Array[];
|
|
619
843
|
/**
|
|
620
844
|
* Join a sequence of arrays along an existing axis.
|
|
621
845
|
*
|
|
@@ -659,12 +883,29 @@ declare function columnStack(xs: ArrayLike[]): Array;
|
|
|
659
883
|
declare function flipud(x: ArrayLike): Array;
|
|
660
884
|
/** Flip an array horizontally (axis=1). */
|
|
661
885
|
declare function fliplr(x: ArrayLike): Array;
|
|
662
|
-
/**
|
|
663
|
-
declare
|
|
886
|
+
/** Transpose the last two dimensions of an array. */
|
|
887
|
+
declare function matrixTranspose(a: ArrayLike): Array;
|
|
664
888
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
665
889
|
declare function ravel(a: ArrayLike): Array;
|
|
666
890
|
/** Remove one or more length-1 axes from an array. */
|
|
667
891
|
declare function squeeze(a: ArrayLike, axis?: Axis): Array;
|
|
892
|
+
/**
|
|
893
|
+
* Expand the shape of an array by inserting new axes of length 1.
|
|
894
|
+
*
|
|
895
|
+
* @param a - Input array.
|
|
896
|
+
* @param axis - Position(s) in the expanded axes where the new axis (or axes)
|
|
897
|
+
* is placed. Can be a single integer or an array of integers.
|
|
898
|
+
* @returns Array with the number of dimensions increased.
|
|
899
|
+
*
|
|
900
|
+
* @example
|
|
901
|
+
* ```ts
|
|
902
|
+
* const x = np.array([1, 2]);
|
|
903
|
+
* np.expandDims(x, 0); // Shape [1, 2]
|
|
904
|
+
* np.expandDims(x, 1); // Shape [2, 1]
|
|
905
|
+
* np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
|
|
906
|
+
* ```
|
|
907
|
+
*/
|
|
908
|
+
declare function expandDims(a: ArrayLike, axis: number | number[]): Array;
|
|
668
909
|
/**
|
|
669
910
|
* Repeat each element of an array after themselves.
|
|
670
911
|
*
|
|
@@ -711,6 +952,29 @@ declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?:
|
|
|
711
952
|
declare function diag(v: ArrayLike, k?: number): Array;
|
|
712
953
|
/** Calculate the sum of the diagonal of an array along the given axes. */
|
|
713
954
|
declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
|
|
955
|
+
/**
|
|
956
|
+
* Return a sorted copy of an array.
|
|
957
|
+
*
|
|
958
|
+
* The array is sorted along a specified axis (the last by default). This may be
|
|
959
|
+
* an unstable sort, and it dispatches to device-specific implementation.
|
|
960
|
+
*/
|
|
961
|
+
declare function sort(a: ArrayLike, axis?: number): Array;
|
|
962
|
+
/**
|
|
963
|
+
* Return indices that would sort an array. This may be an unstable sorting
|
|
964
|
+
* algorithm; it need not preserve order of indices in ties.
|
|
965
|
+
*
|
|
966
|
+
* Returns an array of `int32` indices.
|
|
967
|
+
*
|
|
968
|
+
* The array is sorted along a specified axis (the last by default).
|
|
969
|
+
*/
|
|
970
|
+
declare function argsort(a: ArrayLike, axis?: number): Array;
|
|
971
|
+
/**
|
|
972
|
+
* Take elements from an array along an axis.
|
|
973
|
+
*
|
|
974
|
+
* This is equivalent to advanced indexing with integer indices over that
|
|
975
|
+
* numbered axis. By default, the flattened array is used.
|
|
976
|
+
*/
|
|
977
|
+
declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
|
|
714
978
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
715
979
|
declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
|
|
716
980
|
rtol?: number;
|
|
@@ -782,6 +1046,10 @@ declare function vecdot(x: ArrayLike, y: ArrayLike, {
|
|
|
782
1046
|
* Like vecdot() but flattens the arguments first into vectors.
|
|
783
1047
|
*/
|
|
784
1048
|
declare function vdot(x: ArrayLike, y: ArrayLike): Array;
|
|
1049
|
+
/** Convolution of two one-dimensional arrays. */
|
|
1050
|
+
declare function convolve(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
|
|
1051
|
+
/** Correlation of two one-dimensional arrays. */
|
|
1052
|
+
declare function correlate(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
|
|
785
1053
|
/**
|
|
786
1054
|
* Return a tuple of coordinate matrices from coordinate vectors.
|
|
787
1055
|
*
|
|
@@ -793,21 +1061,6 @@ declare function meshgrid(xs: Array[], {
|
|
|
793
1061
|
}?: {
|
|
794
1062
|
indexing?: "xy" | "ij";
|
|
795
1063
|
}): Array[];
|
|
796
|
-
/**
|
|
797
|
-
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
798
|
-
*
|
|
799
|
-
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
800
|
-
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
801
|
-
* `k>0` is above it.
|
|
802
|
-
*/
|
|
803
|
-
declare function tri(n: number, m?: number, k?: number, {
|
|
804
|
-
dtype,
|
|
805
|
-
device
|
|
806
|
-
}?: DTypeAndDevice): Array;
|
|
807
|
-
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
808
|
-
declare function tril(a: ArrayLike, k?: number): Array;
|
|
809
|
-
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
810
|
-
declare function triu(a: ArrayLike, k?: number): Array;
|
|
811
1064
|
/**
|
|
812
1065
|
* Clip (limit) the values in an array.
|
|
813
1066
|
*
|
|
@@ -824,8 +1077,6 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
|
|
|
824
1077
|
* This is the same function as `jax.numpy.abs()`.
|
|
825
1078
|
*/
|
|
826
1079
|
declare function absolute(x: ArrayLike): Array;
|
|
827
|
-
/** @function Alias of `jax.numpy.absolute()`. */
|
|
828
|
-
declare const abs: typeof absolute;
|
|
829
1080
|
/** Return an element-wise indication of sign of the input. */
|
|
830
1081
|
declare function sign(x: ArrayLike): Array;
|
|
831
1082
|
/** @function Return element-wise positive values of the input (no-op). */
|
|
@@ -854,6 +1105,17 @@ declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
|
854
1105
|
declare function square(x: ArrayLike): Array;
|
|
855
1106
|
/** Element-wise tangent function (takes radians). */
|
|
856
1107
|
declare function tan(x: ArrayLike): Array;
|
|
1108
|
+
/**
|
|
1109
|
+
* @function
|
|
1110
|
+
* Return the normalized sinc function.
|
|
1111
|
+
*
|
|
1112
|
+
* The sinc function is defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`.
|
|
1113
|
+
* This is the normalized sinc function commonly used in signal processing.
|
|
1114
|
+
*
|
|
1115
|
+
* **Note:** JVP is not supported at x=0 due to discontinuous derivative. This
|
|
1116
|
+
* requires a custom JVP rule to handle properly (see JAX implementation).
|
|
1117
|
+
*/
|
|
1118
|
+
declare const sinc: OwnedFunction<(x: ArrayLike) => Array>;
|
|
857
1119
|
/** Element-wise inverse cosine function (inverse of cos). */
|
|
858
1120
|
declare function acos(x: ArrayLike): Array;
|
|
859
1121
|
/**
|
|
@@ -879,16 +1141,24 @@ declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
|
879
1141
|
* The output is ill-defined when both x and y are zero.
|
|
880
1142
|
*/
|
|
881
1143
|
declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
882
|
-
/** @function Alias of `jax.numpy.acos()`. */
|
|
883
|
-
declare const arccos: typeof acos;
|
|
884
|
-
/** @function Alias of `jax.numpy.atan()`. */
|
|
885
|
-
declare const arctan: (x: ArrayLike) => Array;
|
|
886
|
-
/** @function Alias of `jax.numpy.atan2()`. */
|
|
887
|
-
declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
888
1144
|
/** Element-wise subtraction, with broadcasting. */
|
|
889
1145
|
declare function subtract(x: ArrayLike, y: ArrayLike): Array;
|
|
890
1146
|
/** Calculates the floating-point division of x by y element-wise. */
|
|
891
1147
|
declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
|
|
1148
|
+
/**
|
|
1149
|
+
* Return the largest integer smaller than or equal to the division of the inputs.
|
|
1150
|
+
*
|
|
1151
|
+
* The result is always rounded towards negative infinity.
|
|
1152
|
+
*
|
|
1153
|
+
* For floating-point inputs, this is equivalent to `floor(x / y)`.
|
|
1154
|
+
* For integer inputs, we use `(x - remainder(x, y)) / y` to handle
|
|
1155
|
+
* negative values correctly (note: may overflow near int32 boundaries).
|
|
1156
|
+
*
|
|
1157
|
+
* @param x - Dividend array.
|
|
1158
|
+
* @param y - Divisor array.
|
|
1159
|
+
* @returns Element-wise floor division of x by y.
|
|
1160
|
+
*/
|
|
1161
|
+
declare function floorDivide(x: ArrayLike, y: ArrayLike): Array;
|
|
892
1162
|
/**
|
|
893
1163
|
* @function
|
|
894
1164
|
* Calculate element-wise floating-point modulo operation.
|
|
@@ -899,8 +1169,16 @@ declare const fmod: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
|
899
1169
|
* Calculate element-wise remainder of the division (matches sign of y).
|
|
900
1170
|
*/
|
|
901
1171
|
declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
902
|
-
/**
|
|
903
|
-
|
|
1172
|
+
/**
|
|
1173
|
+
* Return element-wise quotient and remainder simultaneously.
|
|
1174
|
+
*
|
|
1175
|
+
* Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
|
|
1176
|
+
*
|
|
1177
|
+
* @param x - Dividend array.
|
|
1178
|
+
* @param y - Divisor array.
|
|
1179
|
+
* @returns Tuple of [quotient, remainder].
|
|
1180
|
+
*/
|
|
1181
|
+
declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
|
|
904
1182
|
/** Round input to the nearest integer towards zero. */
|
|
905
1183
|
declare function trunc(x: ArrayLike): Array;
|
|
906
1184
|
/**
|
|
@@ -940,8 +1218,6 @@ declare const degrees: typeof rad2deg;
|
|
|
940
1218
|
* Computes first array raised to power of second array, element-wise.
|
|
941
1219
|
*/
|
|
942
1220
|
declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
943
|
-
/** @function Alias of `jax.numpy.power()`. */
|
|
944
|
-
declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
945
1221
|
/** @function Calculate the element-wise cube root of the input array. */
|
|
946
1222
|
declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
|
|
947
1223
|
/**
|
|
@@ -986,12 +1262,6 @@ declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
|
986
1262
|
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
987
1263
|
*/
|
|
988
1264
|
declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
989
|
-
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
990
|
-
declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
991
|
-
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
992
|
-
declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
993
|
-
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
994
|
-
declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
995
1265
|
/**
|
|
996
1266
|
* Compute the variance of an array.
|
|
997
1267
|
*
|
|
@@ -1018,6 +1288,14 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1018
1288
|
mean?: ArrayLike;
|
|
1019
1289
|
correction?: number;
|
|
1020
1290
|
} & ReduceOpts): Array;
|
|
1291
|
+
/** Estimate the sample covariance of a set of variables. */
|
|
1292
|
+
declare function cov(x: ArrayLike, y?: ArrayLike | null, {
|
|
1293
|
+
rowvar
|
|
1294
|
+
}?: {
|
|
1295
|
+
rowvar?: boolean;
|
|
1296
|
+
}): Array;
|
|
1297
|
+
/** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
|
|
1298
|
+
declare function corrcoef(x: ArrayLike, y?: ArrayLike): Array;
|
|
1021
1299
|
/** Test element-wise for positive or negative infinity, return bool array. */
|
|
1022
1300
|
declare function isinf(x: ArrayLike): Array;
|
|
1023
1301
|
/** Test element-wise for NaN (Not a Number). */
|
|
@@ -1031,7 +1309,6 @@ declare function isposinf(x: ArrayLike): Array;
|
|
|
1031
1309
|
* Test element-wise for finite values (not infinity or NaN).
|
|
1032
1310
|
*/
|
|
1033
1311
|
declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1034
|
-
//# sourceMappingURL=numpy.d.ts.map
|
|
1035
1312
|
declare namespace tree_d_exports {
|
|
1036
1313
|
export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
|
|
1037
1314
|
}
|
|
@@ -1084,7 +1361,7 @@ declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefin
|
|
|
1084
1361
|
interface ConvParams {
|
|
1085
1362
|
vmapDims: number;
|
|
1086
1363
|
strides: number[];
|
|
1087
|
-
padding: [
|
|
1364
|
+
padding: Pair[];
|
|
1088
1365
|
lhsDilation: number[];
|
|
1089
1366
|
rhsDilation: number[];
|
|
1090
1367
|
}
|
|
@@ -1165,9 +1442,21 @@ declare class Jaxpr implements FpHashable {
|
|
|
1165
1442
|
* - Remove no-op movement operations.
|
|
1166
1443
|
*/
|
|
1167
1444
|
simplify(): Jaxpr;
|
|
1168
|
-
/** Flattens nested
|
|
1445
|
+
/** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1169
1446
|
flatten(): Jaxpr;
|
|
1170
1447
|
}
|
|
1448
|
+
/** Jaxpr with a collection of associated, traced constants. */
|
|
1449
|
+
declare class ClosedJaxpr {
|
|
1450
|
+
readonly jaxpr: Jaxpr;
|
|
1451
|
+
readonly consts: Tracer[];
|
|
1452
|
+
constructor(jaxpr: Jaxpr, consts: Tracer[]);
|
|
1453
|
+
/** String representation of this Jaxpr. */
|
|
1454
|
+
toString(): string;
|
|
1455
|
+
/** Apply a function to the underlying Jaxpr. */
|
|
1456
|
+
mapJaxpr(f: (jaxpr: Jaxpr) => Jaxpr): ClosedJaxpr;
|
|
1457
|
+
/** Dispose of the constants in this Jaxpr. */
|
|
1458
|
+
dispose(): void;
|
|
1459
|
+
}
|
|
1171
1460
|
/** @inline */
|
|
1172
1461
|
type JitOpts = {
|
|
1173
1462
|
staticArgnums?: number[];
|
|
@@ -1190,7 +1479,9 @@ declare enum Primitive {
|
|
|
1190
1479
|
Mul = "mul",
|
|
1191
1480
|
Idiv = "idiv",
|
|
1192
1481
|
Mod = "mod",
|
|
1193
|
-
// uses sign of
|
|
1482
|
+
// uses sign of numerator, C-style, matches JS but not Python
|
|
1483
|
+
Min = "min",
|
|
1484
|
+
Max = "max",
|
|
1194
1485
|
Neg = "neg",
|
|
1195
1486
|
Reciprocal = "reciprocal",
|
|
1196
1487
|
Floor = "floor",
|
|
@@ -1198,7 +1489,6 @@ declare enum Primitive {
|
|
|
1198
1489
|
StopGradient = "stop_gradient",
|
|
1199
1490
|
Cast = "cast",
|
|
1200
1491
|
Bitcast = "bitcast",
|
|
1201
|
-
RandomBits = "random_bits",
|
|
1202
1492
|
Sin = "sin",
|
|
1203
1493
|
Cos = "cos",
|
|
1204
1494
|
Asin = "asin",
|
|
@@ -1208,8 +1498,6 @@ declare enum Primitive {
|
|
|
1208
1498
|
Erf = "erf",
|
|
1209
1499
|
Erfc = "erfc",
|
|
1210
1500
|
Sqrt = "sqrt",
|
|
1211
|
-
Min = "min",
|
|
1212
|
-
Max = "max",
|
|
1213
1501
|
Reduce = "reduce",
|
|
1214
1502
|
Dot = "dot",
|
|
1215
1503
|
// sum(x*y, axis=-1)
|
|
@@ -1219,14 +1507,27 @@ declare enum Primitive {
|
|
|
1219
1507
|
PoolTranspose = "pool_transpose",
|
|
1220
1508
|
Compare = "compare",
|
|
1221
1509
|
Where = "where",
|
|
1510
|
+
Concatenate = "concatenate",
|
|
1511
|
+
Split = "split",
|
|
1512
|
+
RandomBits = "random_bits",
|
|
1513
|
+
Gather = "gather",
|
|
1222
1514
|
Transpose = "transpose",
|
|
1223
1515
|
Broadcast = "broadcast",
|
|
1224
1516
|
Reshape = "reshape",
|
|
1225
1517
|
Flip = "flip",
|
|
1226
1518
|
Shrink = "shrink",
|
|
1227
1519
|
Pad = "pad",
|
|
1228
|
-
|
|
1229
|
-
|
|
1520
|
+
Sort = "sort",
|
|
1521
|
+
// sort(x, axis=-1)
|
|
1522
|
+
Argsort = "argsort",
|
|
1523
|
+
// argsort(x, axis=-1)
|
|
1524
|
+
TriangularSolve = "triangular_solve",
|
|
1525
|
+
// A is upper triangular, A @ X.T = B.T
|
|
1526
|
+
Cholesky = "cholesky",
|
|
1527
|
+
// A is positive-definite, A = L @ L^T
|
|
1528
|
+
LU = "lu",
|
|
1529
|
+
// LU decomposition with partial pivoting
|
|
1530
|
+
Jit = "jit",
|
|
1230
1531
|
}
|
|
1231
1532
|
interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
1232
1533
|
[Primitive.Cast]: {
|
|
@@ -1252,6 +1553,21 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
1252
1553
|
[Primitive.Compare]: {
|
|
1253
1554
|
op: CompareOp;
|
|
1254
1555
|
};
|
|
1556
|
+
[Primitive.Concatenate]: {
|
|
1557
|
+
axis: number;
|
|
1558
|
+
};
|
|
1559
|
+
[Primitive.Split]: {
|
|
1560
|
+
axis: number;
|
|
1561
|
+
sizes: number[];
|
|
1562
|
+
};
|
|
1563
|
+
[Primitive.RandomBits]: {
|
|
1564
|
+
shape: number[];
|
|
1565
|
+
mode: "xor" | 0 | 1;
|
|
1566
|
+
};
|
|
1567
|
+
[Primitive.Gather]: {
|
|
1568
|
+
axis: number[];
|
|
1569
|
+
outDim: number;
|
|
1570
|
+
};
|
|
1255
1571
|
[Primitive.Transpose]: {
|
|
1256
1572
|
perm: number[];
|
|
1257
1573
|
};
|
|
@@ -1259,10 +1575,6 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
1259
1575
|
shape: number[];
|
|
1260
1576
|
axis: number[];
|
|
1261
1577
|
};
|
|
1262
|
-
[Primitive.RandomBits]: {
|
|
1263
|
-
shape: number[];
|
|
1264
|
-
mode: "xor" | 0 | 1;
|
|
1265
|
-
};
|
|
1266
1578
|
[Primitive.Reshape]: {
|
|
1267
1579
|
shape: number[];
|
|
1268
1580
|
};
|
|
@@ -1275,15 +1587,14 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
1275
1587
|
[Primitive.Pad]: {
|
|
1276
1588
|
width: Pair[];
|
|
1277
1589
|
};
|
|
1278
|
-
[Primitive.
|
|
1279
|
-
axis: number[];
|
|
1280
|
-
outDim: number;
|
|
1281
|
-
};
|
|
1282
|
-
[Primitive.JitCall]: {
|
|
1590
|
+
[Primitive.Jit]: {
|
|
1283
1591
|
name: string;
|
|
1284
1592
|
jaxpr: Jaxpr;
|
|
1285
1593
|
numConsts: number;
|
|
1286
1594
|
};
|
|
1595
|
+
[Primitive.TriangularSolve]: {
|
|
1596
|
+
unitDiagonal: boolean;
|
|
1597
|
+
};
|
|
1287
1598
|
}
|
|
1288
1599
|
/** Type of parameters taken by each primitive. */
|
|
1289
1600
|
type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
|
|
@@ -1424,6 +1735,7 @@ declare abstract class Tracer {
|
|
|
1424
1735
|
neg(): this;
|
|
1425
1736
|
add(other: this | TracerValue): this;
|
|
1426
1737
|
mul(other: this | TracerValue): this;
|
|
1738
|
+
mod(other: this | TracerValue): this;
|
|
1427
1739
|
greater(other: this | TracerValue): this;
|
|
1428
1740
|
less(other: this | TracerValue): this;
|
|
1429
1741
|
equal(other: this | TracerValue): this;
|
|
@@ -1451,7 +1763,7 @@ declare abstract class Tracer {
|
|
|
1451
1763
|
sub(other: this | TracerValue): this;
|
|
1452
1764
|
/** Divide an array by this one. */
|
|
1453
1765
|
div(other: this | TracerValue): this;
|
|
1454
|
-
/** Return specified diagonals. See `numpy.diagonal` for full docs. */
|
|
1766
|
+
/** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
|
|
1455
1767
|
diagonal(offset?: number, axis1?: number, axis2?: number): this;
|
|
1456
1768
|
/** Flatten the array without changing its data. */
|
|
1457
1769
|
flatten(): this;
|
|
@@ -1470,6 +1782,19 @@ declare abstract class Tracer {
|
|
|
1470
1782
|
* ```
|
|
1471
1783
|
*/
|
|
1472
1784
|
[Symbol.iterator](): IterableIterator<this>;
|
|
1785
|
+
/**
|
|
1786
|
+
* Return a sorted copy of an array in ascending order.
|
|
1787
|
+
*
|
|
1788
|
+
* See `jax.numpy.sort` for full docs.
|
|
1789
|
+
*/
|
|
1790
|
+
sort(axis?: number): this;
|
|
1791
|
+
/**
|
|
1792
|
+
* Return the indices that would sort an array. This may not be a stable
|
|
1793
|
+
* sorting algorithm; it need not preserve order of indices in ties.
|
|
1794
|
+
*
|
|
1795
|
+
* See `jax.numpy.argsort` for full docs.
|
|
1796
|
+
*/
|
|
1797
|
+
argsort(axis?: number): this;
|
|
1473
1798
|
/**
|
|
1474
1799
|
* Slice an array along one or more axes.
|
|
1475
1800
|
*
|
|
@@ -1512,6 +1837,8 @@ declare class ShapedArray implements AbstractValue {
|
|
|
1512
1837
|
constructor(shape: number[], dtype: DType, weakType: boolean);
|
|
1513
1838
|
static fromAval(aval: AbstractValue): ShapedArray;
|
|
1514
1839
|
get ndim(): number;
|
|
1840
|
+
get size(): number;
|
|
1841
|
+
scalar(): ShapedArray;
|
|
1515
1842
|
toString(): string;
|
|
1516
1843
|
equals(other: ShapedArray): boolean;
|
|
1517
1844
|
}
|
|
@@ -1529,12 +1856,12 @@ type ArrayLike = Array | number | boolean;
|
|
|
1529
1856
|
declare class PendingExecute {
|
|
1530
1857
|
#private;
|
|
1531
1858
|
readonly backend: Backend;
|
|
1532
|
-
readonly
|
|
1859
|
+
readonly source: Kernel | Routine;
|
|
1533
1860
|
readonly inputs: Slot[];
|
|
1534
1861
|
readonly outputs: Slot[];
|
|
1535
1862
|
prepared: Executable | null;
|
|
1536
1863
|
submitted: boolean;
|
|
1537
|
-
constructor(backend: Backend,
|
|
1864
|
+
constructor(backend: Backend, source: Kernel | Routine, inputs: Slot[], outputs: Slot[]);
|
|
1538
1865
|
updateRc(delta: number): void;
|
|
1539
1866
|
prepare(): Promise<void>;
|
|
1540
1867
|
prepareSync(): void;
|
|
@@ -1566,7 +1893,6 @@ type ArrayConstructorArgs = {
|
|
|
1566
1893
|
*/
|
|
1567
1894
|
declare class Array extends Tracer {
|
|
1568
1895
|
#private;
|
|
1569
|
-
id: number;
|
|
1570
1896
|
/**
|
|
1571
1897
|
* @ignore
|
|
1572
1898
|
* Constructs an array from source, shape and backend. Note that if the source
|
|
@@ -1580,6 +1906,8 @@ declare class Array extends Tracer {
|
|
|
1580
1906
|
toString(): string;
|
|
1581
1907
|
get device(): Device;
|
|
1582
1908
|
get ref(): this;
|
|
1909
|
+
/** Get the current reference count (for debugging memory management). */
|
|
1910
|
+
get refCount(): number;
|
|
1583
1911
|
dispose(): void;
|
|
1584
1912
|
/**
|
|
1585
1913
|
* Convert this array into a primitive value.
|
|
@@ -1700,6 +2028,21 @@ declare function arange(start: number, stop?: number, step?: number, {
|
|
|
1700
2028
|
dtype,
|
|
1701
2029
|
device
|
|
1702
2030
|
}?: DTypeAndDevice): Array;
|
|
2031
|
+
/**
|
|
2032
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
2033
|
+
*
|
|
2034
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
2035
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
2036
|
+
* `k>0` is above it.
|
|
2037
|
+
*/
|
|
2038
|
+
declare function tri(n: number, m?: number, k?: number, {
|
|
2039
|
+
dtype,
|
|
2040
|
+
device
|
|
2041
|
+
}?: DTypeAndDevice): Array;
|
|
2042
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
2043
|
+
declare function tril(a: ArrayLike, k?: number): Array;
|
|
2044
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
2045
|
+
declare function triu(a: ArrayLike, k?: number): Array;
|
|
1703
2046
|
/**
|
|
1704
2047
|
* Return evenly spaced numbers over a specified interval.
|
|
1705
2048
|
*
|
|
@@ -1713,8 +2056,114 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
1713
2056
|
dtype,
|
|
1714
2057
|
device
|
|
1715
2058
|
}?: DTypeAndDevice): Array;
|
|
2059
|
+
/**
|
|
2060
|
+
* Return numbers spaced evenly on a log scale.
|
|
2061
|
+
*
|
|
2062
|
+
* In linear space, the sequence starts at `base ** start` and ends at
|
|
2063
|
+
* `base ** stop` (see `endpoint` below).
|
|
2064
|
+
*
|
|
2065
|
+
* @param start - `base ** start` is the starting value of the sequence.
|
|
2066
|
+
* @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
|
|
2067
|
+
* @param num - Number of samples to generate. Default is 50.
|
|
2068
|
+
* @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
|
|
2069
|
+
* @param base - The base of the log space. Default is 10.
|
|
2070
|
+
* @returns Array of evenly spaced values on a log scale.
|
|
2071
|
+
*/
|
|
2072
|
+
declare function logspace(start: number, stop: number, num?: number, endpoint?: boolean, base?: number, {
|
|
2073
|
+
dtype,
|
|
2074
|
+
device
|
|
2075
|
+
}?: DTypeAndDevice): Array;
|
|
2076
|
+
declare namespace lax_linalg_d_exports {
|
|
2077
|
+
export { cholesky, lu, triangularSolve };
|
|
2078
|
+
}
|
|
2079
|
+
/**
|
|
2080
|
+
* Compute the Cholesky decomposition of a symmetric positive-definite matrix.
|
|
2081
|
+
*
|
|
2082
|
+
* The Cholesky decomposition of a matrix `A` is:
|
|
2083
|
+
*
|
|
2084
|
+
* - A = L @ L^T (for upper=false, default)
|
|
2085
|
+
* - A = U^T @ U (for upper=true)
|
|
2086
|
+
*
|
|
2087
|
+
* where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
|
|
2088
|
+
* The input matrix must be symmetric and positive-definite.
|
|
2089
|
+
*
|
|
2090
|
+
* @example
|
|
2091
|
+
* ```ts
|
|
2092
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
2093
|
+
*
|
|
2094
|
+
* const x = np.array([[2., 1.], [1., 2.]]);
|
|
2095
|
+
*
|
|
2096
|
+
* // Lower Cholesky factorization (default):
|
|
2097
|
+
* const L = lax.linalg.cholesky(x);
|
|
2098
|
+
* // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
|
|
2099
|
+
*
|
|
2100
|
+
* // Upper Cholesky factorization:
|
|
2101
|
+
* const U = lax.linalg.cholesky(x, { upper: true });
|
|
2102
|
+
* // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
|
|
2103
|
+
* ```
|
|
2104
|
+
*/
|
|
2105
|
+
declare function cholesky(a: ArrayLike, {
|
|
2106
|
+
upper
|
|
2107
|
+
}?: {
|
|
2108
|
+
upper?: boolean;
|
|
2109
|
+
}): Array;
|
|
2110
|
+
/**
|
|
2111
|
+
* LU decomposition with partial pivoting.
|
|
2112
|
+
*
|
|
2113
|
+
* Computes the matrix decomposition: `P @ A = L @ U`, where `P` is a
|
|
2114
|
+
* permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
|
|
2115
|
+
* and `U` is upper-triangular.
|
|
2116
|
+
*
|
|
2117
|
+
* @param x - A batch of matrices with shape `[..., m, n]`.
|
|
2118
|
+
*
|
|
2119
|
+
* @returns A tuple `(lu, pivots, permutation)` where:
|
|
2120
|
+
* - `lu`: combined lower and upper triangular matrices.
|
|
2121
|
+
* - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
|
|
2122
|
+
* - `permutation`: the permutation generated by pivots with shape `[..., m]`.
|
|
2123
|
+
*
|
|
2124
|
+
* @example
|
|
2125
|
+
* ```ts
|
|
2126
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
2127
|
+
*
|
|
2128
|
+
* const A = np.array([[4., 3.], [6., 3.]]);
|
|
2129
|
+
* const [lu, pivots, permutation] = lax.linalg.lu(A);
|
|
2130
|
+
* // lu ≈ [[6., 3.], [0.6666667, 1.0]]
|
|
2131
|
+
* // pivots = [1, 1]
|
|
2132
|
+
* // permutation = [1, 0]
|
|
2133
|
+
* ```
|
|
2134
|
+
*/
|
|
2135
|
+
declare function lu(x: ArrayLike): [Array, Array, Array];
|
|
2136
|
+
/**
|
|
2137
|
+
* Solve a triangular linear system.
|
|
2138
|
+
*
|
|
2139
|
+
* Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
|
|
2140
|
+
* where `a` is a triangular matrix.
|
|
2141
|
+
*
|
|
2142
|
+
* @example
|
|
2143
|
+
* ```ts
|
|
2144
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
2145
|
+
*
|
|
2146
|
+
* const L = np.array([[2., 0.], [1., 3.]]);
|
|
2147
|
+
* const b = np.array([4., 7.]).reshape([2, 1]);
|
|
2148
|
+
*
|
|
2149
|
+
* // Solve L @ x = b
|
|
2150
|
+
* const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
|
|
2151
|
+
* // x = [[2.], [5./3.]]
|
|
2152
|
+
* ```
|
|
2153
|
+
*/
|
|
2154
|
+
declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
|
|
2155
|
+
leftSide,
|
|
2156
|
+
lower,
|
|
2157
|
+
transposeA,
|
|
2158
|
+
unitDiagonal
|
|
2159
|
+
}?: {
|
|
2160
|
+
leftSide?: boolean;
|
|
2161
|
+
lower?: boolean;
|
|
2162
|
+
transposeA?: boolean;
|
|
2163
|
+
unitDiagonal?: boolean;
|
|
2164
|
+
}): Array;
|
|
1716
2165
|
declare namespace lax_d_exports {
|
|
1717
|
-
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, reduceWindow, stopGradient };
|
|
2166
|
+
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
|
|
1718
2167
|
}
|
|
1719
2168
|
/**
|
|
1720
2169
|
* Dimension numbers for general `dot()` primitive.
|
|
@@ -1745,7 +2194,7 @@ declare function dot(lhs: Array, rhs: Array, {
|
|
|
1745
2194
|
lhsBatchDims: lb,
|
|
1746
2195
|
rhsBatchDims: rb
|
|
1747
2196
|
}?: DotDimensionNumbers): Array;
|
|
1748
|
-
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [
|
|
2197
|
+
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
|
|
1749
2198
|
/**
|
|
1750
2199
|
* General n-dimensional convolution operator, with optional dilation.
|
|
1751
2200
|
*
|
|
@@ -1785,9 +2234,8 @@ declare function erfc(x: ArrayLike): Array;
|
|
|
1785
2234
|
* forward or reverse-mode automatic differentiation.
|
|
1786
2235
|
*/
|
|
1787
2236
|
declare function stopGradient(x: ArrayLike): Array;
|
|
1788
|
-
//# sourceMappingURL=lax.d.ts.map
|
|
1789
2237
|
declare namespace nn_d_exports {
|
|
1790
|
-
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
|
|
2238
|
+
export { celu, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
|
|
1791
2239
|
}
|
|
1792
2240
|
/**
|
|
1793
2241
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -1814,21 +2262,28 @@ declare function sigmoid(x: ArrayLike): Array;
|
|
|
1814
2262
|
*/
|
|
1815
2263
|
declare function softplus(x: ArrayLike): Array;
|
|
1816
2264
|
/**
|
|
1817
|
-
*
|
|
1818
|
-
*
|
|
2265
|
+
* @function
|
|
2266
|
+
* Sparse plus function:
|
|
2267
|
+
*
|
|
2268
|
+
* - When `x <= -1`: `0`
|
|
2269
|
+
* - When `-1 < x < 1`: `(x+1)**2 / 4`
|
|
2270
|
+
* - When `x >= 1`: `x`
|
|
1819
2271
|
*/
|
|
1820
|
-
declare
|
|
2272
|
+
declare const sparsePlus: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1821
2273
|
/**
|
|
1822
2274
|
* @function
|
|
1823
|
-
*
|
|
1824
|
-
* Swish, computed element-wise:
|
|
1825
|
-
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
2275
|
+
* Sparse sigmoid activation function.
|
|
1826
2276
|
*
|
|
1827
|
-
*
|
|
1828
|
-
*
|
|
1829
|
-
*
|
|
2277
|
+
* - When `x <= -1`: `0`
|
|
2278
|
+
* - When `-1 < x < 1`: `(x + 1) / 2`
|
|
2279
|
+
* - When `x >= 1`: `1`
|
|
1830
2280
|
*/
|
|
1831
|
-
declare const
|
|
2281
|
+
declare const sparseSigmoid: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2282
|
+
/**
|
|
2283
|
+
* Soft-sign activation function, computed element-wise:
|
|
2284
|
+
* `softsign(x) = x / (|x| + 1)`.
|
|
2285
|
+
*/
|
|
2286
|
+
declare function softSign(x: ArrayLike): Array;
|
|
1832
2287
|
/**
|
|
1833
2288
|
* @function
|
|
1834
2289
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
@@ -1839,7 +2294,7 @@ declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
|
1839
2294
|
*
|
|
1840
2295
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1841
2296
|
*/
|
|
1842
|
-
declare const
|
|
2297
|
+
declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1843
2298
|
/**
|
|
1844
2299
|
* Log-sigmoid activation function, computed element-wise:
|
|
1845
2300
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
@@ -1852,6 +2307,12 @@ declare function logSigmoid(x: ArrayLike): Array;
|
|
|
1852
2307
|
declare const identity: (x: ArrayLike) => Array;
|
|
1853
2308
|
/** Leaky rectified linear (ReLU) activation function */
|
|
1854
2309
|
declare function leakyRelu(x: ArrayLike, negativeSlope?: ArrayLike): Array;
|
|
2310
|
+
/** Hard sigmoid activation function: `relu6(x+3)/6`. */
|
|
2311
|
+
declare function hardSigmoid(x: ArrayLike): Array;
|
|
2312
|
+
/** Hard SiLU (swish) activation function: `x * hardSigmoid(x)`. */
|
|
2313
|
+
declare function hardSilu(x: ArrayLike): Array;
|
|
2314
|
+
/** Hard tanh activation function: `clip(x, -1, 1)`. */
|
|
2315
|
+
declare function hardTanh(x: ArrayLike): Array;
|
|
1855
2316
|
/**
|
|
1856
2317
|
* Exponential linear unit activation function.
|
|
1857
2318
|
*
|
|
@@ -1866,6 +2327,16 @@ declare function elu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
|
1866
2327
|
* `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
|
|
1867
2328
|
*/
|
|
1868
2329
|
declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
2330
|
+
/**
|
|
2331
|
+
* @function
|
|
2332
|
+
* Scaled exponential linear unit activation.
|
|
2333
|
+
*
|
|
2334
|
+
* Computes the element-wise function:
|
|
2335
|
+
* `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
|
|
2336
|
+
*
|
|
2337
|
+
* Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
|
|
2338
|
+
*/
|
|
2339
|
+
declare const selu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1869
2340
|
/**
|
|
1870
2341
|
* @function
|
|
1871
2342
|
* Gaussion error linear unit (GELU) activation function.
|
|
@@ -1961,12 +2432,11 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1961
2432
|
* ```
|
|
1962
2433
|
*/
|
|
1963
2434
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
1964
|
-
//# sourceMappingURL=nn.d.ts.map
|
|
1965
2435
|
declare namespace random_d_exports {
|
|
1966
|
-
export { bernoulli, bits, exponential, key, normal, split, uniform };
|
|
2436
|
+
export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
1967
2437
|
}
|
|
1968
2438
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
1969
|
-
declare function key(seed:
|
|
2439
|
+
declare function key(seed: ArrayLike): Array;
|
|
1970
2440
|
/** Splits a PRNG key into `num` new keys by adding a leading axis. */
|
|
1971
2441
|
declare function split(key: Array, num?: number | number[]): Array;
|
|
1972
2442
|
/** Sample uniform bits in the form of unsigned integers. */
|
|
@@ -1986,11 +2456,50 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
|
|
|
1986
2456
|
* and must be broadcastable to `shape`.
|
|
1987
2457
|
*/
|
|
1988
2458
|
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
2459
|
+
/**
|
|
2460
|
+
* @function
|
|
2461
|
+
* Sample from a Cauchy distribution with location 0 and scale 1.
|
|
2462
|
+
*
|
|
2463
|
+
* Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
|
|
2464
|
+
*/
|
|
2465
|
+
declare const cauchy: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1989
2466
|
/**
|
|
1990
2467
|
* @function
|
|
1991
2468
|
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
1992
2469
|
*/
|
|
1993
2470
|
declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2471
|
+
/**
|
|
2472
|
+
* @function
|
|
2473
|
+
* Sample from a Gumbel distribution with location 0 and scale 1.
|
|
2474
|
+
*
|
|
2475
|
+
* Uses inverse transform sampling: `x = -log(-log(u))` where u ~ Uniform(0, 1).
|
|
2476
|
+
*/
|
|
2477
|
+
declare const gumbel: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2478
|
+
/**
|
|
2479
|
+
* @function
|
|
2480
|
+
* Sample from a Laplace distribution with location 0 and scale 1.
|
|
2481
|
+
*
|
|
2482
|
+
* Uses inverse transform sampling: the CDF is `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`.
|
|
2483
|
+
* Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
|
|
2484
|
+
*/
|
|
2485
|
+
declare const laplace: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2486
|
+
/**
|
|
2487
|
+
* @function
|
|
2488
|
+
* Sample multivariate normal random values with given mean and covariance.
|
|
2489
|
+
*
|
|
2490
|
+
* The values are returned with the given shape, along with the final dimension
|
|
2491
|
+
* used to represent the n-dimensional multivariate normal factors.
|
|
2492
|
+
*
|
|
2493
|
+
* This uses Cholesky decomposition on the covariance matrix.
|
|
2494
|
+
*
|
|
2495
|
+
* - `key` - PRNG key
|
|
2496
|
+
* - `mean` - Mean vector of shape `[..., n]`
|
|
2497
|
+
* - `cov` - Covariance of shape `[..., n, n]`, must be positive-definite
|
|
2498
|
+
* - `shape` - Result batch shape, must be broadcastable with
|
|
2499
|
+
* `mean.shape[:-1]` and `cov.shape[:-2]`
|
|
2500
|
+
* @returns Random samples of shape `[...shape, n]`
|
|
2501
|
+
*/
|
|
2502
|
+
declare const multivariateNormal: OwnedFunction<(key: ArrayLike, mean: ArrayLike, cov: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1994
2503
|
/**
|
|
1995
2504
|
* @function
|
|
1996
2505
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
@@ -2000,7 +2509,6 @@ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | und
|
|
|
2000
2509
|
* bitwise identical to JAX.
|
|
2001
2510
|
*/
|
|
2002
2511
|
declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2003
|
-
//# sourceMappingURL=random.d.ts.map
|
|
2004
2512
|
declare namespace scipy_special_d_exports {
|
|
2005
2513
|
export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
|
|
2006
2514
|
}
|
|
@@ -2031,8 +2539,7 @@ declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTr
|
|
|
2031
2539
|
* Construct a Jaxpr by dynamically tracing a function with example inputs.
|
|
2032
2540
|
*/
|
|
2033
2541
|
declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
|
|
2034
|
-
jaxpr:
|
|
2035
|
-
consts: Array[];
|
|
2542
|
+
jaxpr: ClosedJaxpr;
|
|
2036
2543
|
treedef: JsTreeDef;
|
|
2037
2544
|
};
|
|
2038
2545
|
/**
|
|
@@ -2080,11 +2587,6 @@ declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F)
|
|
|
2080
2587
|
* Compute the Jacobian evaluated row-by-row by reverse-mode AD.
|
|
2081
2588
|
*/
|
|
2082
2589
|
declare const jacrev: typeof jacfwd;
|
|
2083
|
-
/**
|
|
2084
|
-
* @function
|
|
2085
|
-
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
2086
|
-
*/
|
|
2087
|
-
declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
2088
2590
|
/**
|
|
2089
2591
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
2090
2592
|
*
|
|
@@ -2106,8 +2608,5 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
|
2106
2608
|
* default device.
|
|
2107
2609
|
*/
|
|
2108
2610
|
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
2109
|
-
//# sourceMappingURL=index.d.ts.map
|
|
2110
|
-
|
|
2111
2611
|
//#endregion
|
|
2112
|
-
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
2113
|
-
//# sourceMappingURL=index.d.ts.map
|
|
2612
|
+
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|