@jax-js/jax 0.1.2 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +16 -34
- package/dist/{backend-DeVfWEFS.cjs → backend-Bu9GY6sK.cjs} +222 -36
- package/dist/{backend-BqymqzuU.js → backend-tngXtWe4.js} +204 -36
- package/dist/index.cjs +1798 -955
- package/dist/index.d.cts +383 -97
- package/dist/index.d.ts +383 -97
- package/dist/index.js +1791 -949
- package/dist/{webgpu-BGuG58KZ.js → webgpu-ChVgx3b6.js} +410 -97
- package/dist/{webgpu-CcGP160M.cjs → webgpu-Oj3Kd-kd.cjs} +410 -97
- package/package.json +1 -1
package/dist/index.d.ts
CHANGED
|
@@ -121,7 +121,6 @@ declare class ShapeTracker {
|
|
|
121
121
|
/** Like pad(), but allows for negative values. */
|
|
122
122
|
padOrShrink(arg: Pair[]): ShapeTracker;
|
|
123
123
|
}
|
|
124
|
-
//# sourceMappingURL=shape.d.ts.map
|
|
125
124
|
//#endregion
|
|
126
125
|
//#region src/utils.d.ts
|
|
127
126
|
/**
|
|
@@ -251,7 +250,7 @@ declare class AluExp implements FpHashable {
|
|
|
251
250
|
/** Substitute variables in this AluExp to values. */
|
|
252
251
|
substitute(variables: Record<string, AluExp>): AluExp;
|
|
253
252
|
/** Reindex gid values in this expression as needed. */
|
|
254
|
-
reindexGids(
|
|
253
|
+
reindexGids(newGids: number[]): AluExp;
|
|
255
254
|
get min(): number;
|
|
256
255
|
get max(): number;
|
|
257
256
|
/** Largest known integer that divides self. */
|
|
@@ -404,6 +403,52 @@ declare class Reduction implements FpHashable {
|
|
|
404
403
|
}
|
|
405
404
|
/** Expression for accessing `indices` in input array with the given shape. */
|
|
406
405
|
//#endregion
|
|
406
|
+
//#region src/routine.d.ts
|
|
407
|
+
/**
|
|
408
|
+
* Advanced operations that don't fit into the `AluExp` compiler representation.
|
|
409
|
+
*
|
|
410
|
+
* Some routines like iterative matrix algorithms, FFTs, or sorting may not be
|
|
411
|
+
* easy to express efficiently as a `Kernel` object. These also tend to be
|
|
412
|
+
* somewhat expensive, so the benefit of kernel fusion and inlining is less
|
|
413
|
+
* relevant.
|
|
414
|
+
*
|
|
415
|
+
* For these operations, we dispatch them as a custom operation on the backend,
|
|
416
|
+
* which each backend implements in a specific way. These are listed in the
|
|
417
|
+
* `Routines` enum below.
|
|
418
|
+
*
|
|
419
|
+
* Routines cannot be fused into other kernels and always operate on contiguous
|
|
420
|
+
* arrays (default `ShapeTracker`).
|
|
421
|
+
*/
|
|
422
|
+
declare class Routine {
|
|
423
|
+
/** The name of the routine. */
|
|
424
|
+
readonly name: Routines;
|
|
425
|
+
/** Dtype and shape of the inputs and outputs. */
|
|
426
|
+
readonly type: RoutineType;
|
|
427
|
+
/** Extra parameters specific to the routine. */
|
|
428
|
+
readonly params?: any | undefined;
|
|
429
|
+
constructor(/** The name of the routine. */
|
|
430
|
+
name: Routines, /** Dtype and shape of the inputs and outputs. */
|
|
431
|
+
type: RoutineType, /** Extra parameters specific to the routine. */
|
|
432
|
+
params?: any | undefined);
|
|
433
|
+
}
|
|
434
|
+
/** One of the valid `Routine` that can be dispatched to backend. */
|
|
435
|
+
declare enum Routines {
|
|
436
|
+
/** Stable sorting algorithm along the last axis. */
|
|
437
|
+
Sort = "Sort",
|
|
438
|
+
/** Returns `int32` indices of the stably sorted array. */
|
|
439
|
+
Argsort = "Argsort",
|
|
440
|
+
/** Solve a triangular system of equations. */
|
|
441
|
+
TriangularSolve = "TriangularSolve",
|
|
442
|
+
/** Cholesky decomposition of 2D positive-definite matrices. */
|
|
443
|
+
Cholesky = "Cholesky",
|
|
444
|
+
}
|
|
445
|
+
interface RoutineType {
|
|
446
|
+
inputShapes: number[][];
|
|
447
|
+
inputDtypes: DType[];
|
|
448
|
+
outputShapes: number[][];
|
|
449
|
+
outputDtypes: DType[];
|
|
450
|
+
}
|
|
451
|
+
//#endregion
|
|
407
452
|
//#region src/backend.d.ts
|
|
408
453
|
type Device = "cpu" | "wasm" | "webgpu";
|
|
409
454
|
declare const devices: Device[];
|
|
@@ -441,9 +486,13 @@ interface Backend {
|
|
|
441
486
|
/** Read a range of bytes from a buffer, blocking variant. */
|
|
442
487
|
readSync(slot: Slot, start?: number, count?: number): Uint8Array<ArrayBuffer>;
|
|
443
488
|
/** Prepare an expression to be executed later. */
|
|
444
|
-
|
|
489
|
+
prepareKernel(kernel: Kernel): Promise<Executable>;
|
|
445
490
|
/** Prepare an expression to be executed later, blocking variant. */
|
|
446
|
-
|
|
491
|
+
prepareKernelSync(kernel: Kernel): Executable;
|
|
492
|
+
/** Prepare an advanced routine to be executed later. */
|
|
493
|
+
prepareRoutine(routine: Routine): Promise<Executable>;
|
|
494
|
+
/** Prepare an advanced routine to be executed later, blocking variant. */
|
|
495
|
+
prepareRoutineSync(routine: Routine): Executable;
|
|
447
496
|
/**
|
|
448
497
|
* Run a backend operation that was previously prepared.
|
|
449
498
|
*
|
|
@@ -454,14 +503,69 @@ interface Backend {
|
|
|
454
503
|
dispatch(exe: Executable, inputs: Slot[], outputs: Slot[]): void;
|
|
455
504
|
}
|
|
456
505
|
declare class Executable<T = any> {
|
|
457
|
-
|
|
458
|
-
|
|
506
|
+
/** The `Kernel` or `Routine` that was prepared. */
|
|
507
|
+
readonly source: Kernel | Routine;
|
|
508
|
+
/** Extra data specific to the backend running this executable. */
|
|
459
509
|
readonly data: T;
|
|
460
|
-
constructor(
|
|
510
|
+
constructor(/** The `Kernel` or `Routine` that was prepared. */
|
|
511
|
+
source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
|
|
461
512
|
data: T);
|
|
462
513
|
}
|
|
514
|
+
declare namespace numpy_fft_d_exports {
|
|
515
|
+
export { ComplexPair, fft, ifft };
|
|
516
|
+
}
|
|
517
|
+
/**
|
|
518
|
+
* A pair of arrays representing real and imaginary part `a + bj`. Both arrays
|
|
519
|
+
* must have the same shape.
|
|
520
|
+
*/
|
|
521
|
+
type ComplexPair = {
|
|
522
|
+
real: Array;
|
|
523
|
+
imag: Array;
|
|
524
|
+
};
|
|
525
|
+
/**
|
|
526
|
+
* Compute a one-dimensional discrete Fourier transform.
|
|
527
|
+
*
|
|
528
|
+
* Currently, the size of the axis must be a power of two.
|
|
529
|
+
*/
|
|
530
|
+
declare function fft(a: ComplexPair, axis?: number): ComplexPair;
|
|
531
|
+
/**
|
|
532
|
+
* Compute a one-dimensional inverse discrete Fourier transform.
|
|
533
|
+
*
|
|
534
|
+
* Currently, the size of the axis must be a power of two.
|
|
535
|
+
*/
|
|
536
|
+
declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
|
|
537
|
+
declare namespace numpy_linalg_d_exports {
|
|
538
|
+
export { cholesky$1 as cholesky, diagonal, lstsq, matmul, matrixTranspose, outer, tensordot, trace, vecdot };
|
|
539
|
+
}
|
|
540
|
+
/**
|
|
541
|
+
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
542
|
+
*
|
|
543
|
+
* This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
|
|
544
|
+
* the input matrix, which is on by default.
|
|
545
|
+
*/
|
|
546
|
+
declare function cholesky$1(a: ArrayLike, {
|
|
547
|
+
upper,
|
|
548
|
+
symmetrizeInput
|
|
549
|
+
}?: {
|
|
550
|
+
upper?: boolean;
|
|
551
|
+
symmetrizeInput?: boolean;
|
|
552
|
+
}): Array;
|
|
553
|
+
/**
|
|
554
|
+
* Return the least-squares solution to a linear equation.
|
|
555
|
+
*
|
|
556
|
+
* For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
|
|
557
|
+
* For underdetermined systems, this finds the minimum-norm solution for `x`.
|
|
558
|
+
*
|
|
559
|
+
* This currently uses Cholesky decomposition to solve the normal equations
|
|
560
|
+
* under the hood. The method is not as robust as QR or SVD.
|
|
561
|
+
*
|
|
562
|
+
* @param a coefficient matrix of shape `(M, N)`
|
|
563
|
+
* @param b right-hand side of shape `(M,)` or `(M, K)`
|
|
564
|
+
* @return least-squares solution of shape `(N,)` or `(N, K)`
|
|
565
|
+
*/
|
|
566
|
+
declare function lstsq(a: ArrayLike, b: ArrayLike): Array;
|
|
463
567
|
declare namespace numpy_d_exports {
|
|
464
|
-
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, positive, pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
568
|
+
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sort, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
465
569
|
}
|
|
466
570
|
declare const float32 = DType.Float32;
|
|
467
571
|
declare const int32 = DType.Int32;
|
|
@@ -587,6 +691,20 @@ declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
|
587
691
|
declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
588
692
|
/** Return the maximum of array elements along a given axis. */
|
|
589
693
|
declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
694
|
+
/**
|
|
695
|
+
* Test whether all array elements along a given axis evaluate to True.
|
|
696
|
+
*
|
|
697
|
+
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
698
|
+
* removed. If axis is None, returns a scalar.
|
|
699
|
+
*/
|
|
700
|
+
declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
701
|
+
/**
|
|
702
|
+
* Test whether any array element along a given axis evaluates to True.
|
|
703
|
+
*
|
|
704
|
+
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
705
|
+
* removed. If axis is None, returns a scalar.
|
|
706
|
+
*/
|
|
707
|
+
declare function any(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
590
708
|
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
591
709
|
declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
592
710
|
/** Compute the average of the array elements along the specified axis. */
|
|
@@ -605,6 +723,13 @@ declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
|
605
723
|
* specified axis.
|
|
606
724
|
*/
|
|
607
725
|
declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
726
|
+
/**
|
|
727
|
+
* Cumulative sum of elements along an axis.
|
|
728
|
+
*
|
|
729
|
+
* Currently this function is `O(n^2)`; we'll improve this later on with a
|
|
730
|
+
* two-phase parallel reduction algorithm.
|
|
731
|
+
*/
|
|
732
|
+
declare function cumsum(a: ArrayLike, axis?: number): Array;
|
|
608
733
|
/** Reverse the elements in an array along the given axes. */
|
|
609
734
|
declare function flip(x: ArrayLike, axis?: Axis): Array;
|
|
610
735
|
/**
|
|
@@ -650,12 +775,29 @@ declare function columnStack(xs: ArrayLike[]): Array;
|
|
|
650
775
|
declare function flipud(x: ArrayLike): Array;
|
|
651
776
|
/** Flip an array horizontally (axis=1). */
|
|
652
777
|
declare function fliplr(x: ArrayLike): Array;
|
|
653
|
-
/**
|
|
654
|
-
declare
|
|
778
|
+
/** Transpose the last two dimensions of an array. */
|
|
779
|
+
declare function matrixTranspose(a: ArrayLike): Array;
|
|
655
780
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
656
781
|
declare function ravel(a: ArrayLike): Array;
|
|
657
782
|
/** Remove one or more length-1 axes from an array. */
|
|
658
783
|
declare function squeeze(a: ArrayLike, axis?: Axis): Array;
|
|
784
|
+
/**
|
|
785
|
+
* Expand the shape of an array by inserting new axes of length 1.
|
|
786
|
+
*
|
|
787
|
+
* @param a - Input array.
|
|
788
|
+
* @param axis - Position(s) in the expanded axes where the new axis (or axes)
|
|
789
|
+
* is placed. Can be a single integer or an array of integers.
|
|
790
|
+
* @returns Array with the number of dimensions increased.
|
|
791
|
+
*
|
|
792
|
+
* @example
|
|
793
|
+
* ```ts
|
|
794
|
+
* const x = np.array([1, 2]);
|
|
795
|
+
* np.expandDims(x, 0); // Shape [1, 2]
|
|
796
|
+
* np.expandDims(x, 1); // Shape [2, 1]
|
|
797
|
+
* np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
|
|
798
|
+
* ```
|
|
799
|
+
*/
|
|
800
|
+
declare function expandDims(a: ArrayLike, axis: number | number[]): Array;
|
|
659
801
|
/**
|
|
660
802
|
* Repeat each element of an array after themselves.
|
|
661
803
|
*
|
|
@@ -702,6 +844,22 @@ declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?:
|
|
|
702
844
|
declare function diag(v: ArrayLike, k?: number): Array;
|
|
703
845
|
/** Calculate the sum of the diagonal of an array along the given axes. */
|
|
704
846
|
declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
|
|
847
|
+
/**
|
|
848
|
+
* Return a sorted copy of an array.
|
|
849
|
+
*
|
|
850
|
+
* The array is sorted along a specified axis (the last by default). This may be
|
|
851
|
+
* an unstable sort, and it dispatches to a device-specific implementation.
|
|
852
|
+
*/
|
|
853
|
+
declare function sort(a: ArrayLike, axis?: number): Array;
|
|
854
|
+
/**
|
|
855
|
+
* Return indices that would sort an array. This may be an unstable sorting
|
|
856
|
+
* algorithm; it need not preserve order of indices in ties.
|
|
857
|
+
*
|
|
858
|
+
* Returns an array of `int32` indices.
|
|
859
|
+
*
|
|
860
|
+
* The array is sorted along a specified axis (the last by default).
|
|
861
|
+
*/
|
|
862
|
+
declare function argsort(a: ArrayLike, axis?: number): Array;
|
|
705
863
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
706
864
|
declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
|
|
707
865
|
rtol?: number;
|
|
@@ -773,6 +931,10 @@ declare function vecdot(x: ArrayLike, y: ArrayLike, {
|
|
|
773
931
|
* Like vecdot() but flattens the arguments first into vectors.
|
|
774
932
|
*/
|
|
775
933
|
declare function vdot(x: ArrayLike, y: ArrayLike): Array;
|
|
934
|
+
/** Convolution of two one-dimensional arrays. */
|
|
935
|
+
declare function convolve(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
|
|
936
|
+
/** Correlation of two one-dimensional arrays. */
|
|
937
|
+
declare function correlate(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
|
|
776
938
|
/**
|
|
777
939
|
* Return a tuple of coordinate matrices from coordinate vectors.
|
|
778
940
|
*
|
|
@@ -784,21 +946,6 @@ declare function meshgrid(xs: Array[], {
|
|
|
784
946
|
}?: {
|
|
785
947
|
indexing?: "xy" | "ij";
|
|
786
948
|
}): Array[];
|
|
787
|
-
/**
|
|
788
|
-
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
789
|
-
*
|
|
790
|
-
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
791
|
-
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
792
|
-
* `k>0` is above it.
|
|
793
|
-
*/
|
|
794
|
-
declare function tri(n: number, m?: number, k?: number, {
|
|
795
|
-
dtype,
|
|
796
|
-
device
|
|
797
|
-
}?: DTypeAndDevice): Array;
|
|
798
|
-
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
799
|
-
declare function tril(a: ArrayLike, k?: number): Array;
|
|
800
|
-
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
801
|
-
declare function triu(a: ArrayLike, k?: number): Array;
|
|
802
949
|
/**
|
|
803
950
|
* Clip (limit) the values in an array.
|
|
804
951
|
*
|
|
@@ -815,8 +962,6 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
|
|
|
815
962
|
* This is the same function as `jax.numpy.abs()`.
|
|
816
963
|
*/
|
|
817
964
|
declare function absolute(x: ArrayLike): Array;
|
|
818
|
-
/** @function Alias of `jax.numpy.absolute()`. */
|
|
819
|
-
declare const abs: typeof absolute;
|
|
820
965
|
/** Return an element-wise indication of sign of the input. */
|
|
821
966
|
declare function sign(x: ArrayLike): Array;
|
|
822
967
|
/** @function Return element-wise positive values of the input (no-op). */
|
|
@@ -870,12 +1015,6 @@ declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
|
870
1015
|
* The output is ill-defined when both x and y are zero.
|
|
871
1016
|
*/
|
|
872
1017
|
declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
873
|
-
/** @function Alias of `jax.numpy.acos()`. */
|
|
874
|
-
declare const arccos: typeof acos;
|
|
875
|
-
/** @function Alias of `jax.numpy.atan()`. */
|
|
876
|
-
declare const arctan: (x: ArrayLike) => Array;
|
|
877
|
-
/** @function Alias of `jax.numpy.atan2()`. */
|
|
878
|
-
declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
879
1018
|
/** Element-wise subtraction, with broadcasting. */
|
|
880
1019
|
declare function subtract(x: ArrayLike, y: ArrayLike): Array;
|
|
881
1020
|
/** Calculates the floating-point division of x by y element-wise. */
|
|
@@ -890,8 +1029,6 @@ declare const fmod: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
|
890
1029
|
* Calculate element-wise remainder of the division (matches sign of y).
|
|
891
1030
|
*/
|
|
892
1031
|
declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
893
|
-
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
894
|
-
declare const divide: typeof trueDivide;
|
|
895
1032
|
/** Round input to the nearest integer towards zero. */
|
|
896
1033
|
declare function trunc(x: ArrayLike): Array;
|
|
897
1034
|
/**
|
|
@@ -931,8 +1068,6 @@ declare const degrees: typeof rad2deg;
|
|
|
931
1068
|
* Computes first array raised to power of second array, element-wise.
|
|
932
1069
|
*/
|
|
933
1070
|
declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
934
|
-
/** @function Alias of `jax.numpy.power()`. */
|
|
935
|
-
declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
936
1071
|
/** @function Calculate the element-wise cube root of the input array. */
|
|
937
1072
|
declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
|
|
938
1073
|
/**
|
|
@@ -977,12 +1112,6 @@ declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
|
977
1112
|
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
978
1113
|
*/
|
|
979
1114
|
declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
980
|
-
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
981
|
-
declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
982
|
-
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
983
|
-
declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
984
|
-
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
985
|
-
declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
986
1115
|
/**
|
|
987
1116
|
* Compute the variance of an array.
|
|
988
1117
|
*
|
|
@@ -1009,6 +1138,10 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1009
1138
|
mean?: ArrayLike;
|
|
1010
1139
|
correction?: number;
|
|
1011
1140
|
} & ReduceOpts): Array;
|
|
1141
|
+
/** Estimate the sample covariance of a set of variables. */
|
|
1142
|
+
declare function cov(x: ArrayLike, y?: ArrayLike): Array;
|
|
1143
|
+
/** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
|
|
1144
|
+
declare function corrcoef(x: ArrayLike, y?: ArrayLike): Array;
|
|
1012
1145
|
/** Test element-wise for positive or negative infinity, return bool array. */
|
|
1013
1146
|
declare function isinf(x: ArrayLike): Array;
|
|
1014
1147
|
/** Test element-wise for NaN (Not a Number). */
|
|
@@ -1022,7 +1155,6 @@ declare function isposinf(x: ArrayLike): Array;
|
|
|
1022
1155
|
* Test element-wise for finite values (not infinity or NaN).
|
|
1023
1156
|
*/
|
|
1024
1157
|
declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1025
|
-
//# sourceMappingURL=numpy.d.ts.map
|
|
1026
1158
|
declare namespace tree_d_exports {
|
|
1027
1159
|
export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
|
|
1028
1160
|
}
|
|
@@ -1073,13 +1205,18 @@ declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefin
|
|
|
1073
1205
|
//#region src/frontend/convolution.d.ts
|
|
1074
1206
|
/** Definition of a general dilated convolution. Should be valid on creation. */
|
|
1075
1207
|
interface ConvParams {
|
|
1208
|
+
vmapDims: number;
|
|
1076
1209
|
strides: number[];
|
|
1077
|
-
padding: [
|
|
1210
|
+
padding: Pair[];
|
|
1078
1211
|
lhsDilation: number[];
|
|
1079
1212
|
rhsDilation: number[];
|
|
1080
1213
|
}
|
|
1081
1214
|
/**
|
|
1082
1215
|
* Check that the shapes and parameters passed to convolution are valid.
|
|
1216
|
+
* Expected shapes of the lhs and rhs of the convolution are:
|
|
1217
|
+
*
|
|
1218
|
+
* - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
|
|
1219
|
+
* - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
|
|
1083
1220
|
*
|
|
1084
1221
|
* If the check succeeds, returns the output shape.
|
|
1085
1222
|
*/
|
|
@@ -1151,9 +1288,21 @@ declare class Jaxpr implements FpHashable {
|
|
|
1151
1288
|
* - Remove no-op movement operations.
|
|
1152
1289
|
*/
|
|
1153
1290
|
simplify(): Jaxpr;
|
|
1154
|
-
/** Flattens nested
|
|
1291
|
+
/** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1155
1292
|
flatten(): Jaxpr;
|
|
1156
1293
|
}
|
|
1294
|
+
/** Jaxpr with a collection of associated, traced constants. */
|
|
1295
|
+
declare class ClosedJaxpr {
|
|
1296
|
+
readonly jaxpr: Jaxpr;
|
|
1297
|
+
readonly consts: Tracer[];
|
|
1298
|
+
constructor(jaxpr: Jaxpr, consts: Tracer[]);
|
|
1299
|
+
/** String representation of this Jaxpr. */
|
|
1300
|
+
toString(): string;
|
|
1301
|
+
/** Apply a function to the underlying Jaxpr. */
|
|
1302
|
+
mapJaxpr(f: (jaxpr: Jaxpr) => Jaxpr): ClosedJaxpr;
|
|
1303
|
+
/** Dispose of the constants in this Jaxpr. */
|
|
1304
|
+
dispose(): void;
|
|
1305
|
+
}
|
|
1157
1306
|
/** @inline */
|
|
1158
1307
|
type JitOpts = {
|
|
1159
1308
|
staticArgnums?: number[];
|
|
@@ -1176,7 +1325,9 @@ declare enum Primitive {
|
|
|
1176
1325
|
Mul = "mul",
|
|
1177
1326
|
Idiv = "idiv",
|
|
1178
1327
|
Mod = "mod",
|
|
1179
|
-
// uses sign of
|
|
1328
|
+
// uses sign of numerator, C-style, matches JS but not Python
|
|
1329
|
+
Min = "min",
|
|
1330
|
+
Max = "max",
|
|
1180
1331
|
Neg = "neg",
|
|
1181
1332
|
Reciprocal = "reciprocal",
|
|
1182
1333
|
Floor = "floor",
|
|
@@ -1184,7 +1335,6 @@ declare enum Primitive {
|
|
|
1184
1335
|
StopGradient = "stop_gradient",
|
|
1185
1336
|
Cast = "cast",
|
|
1186
1337
|
Bitcast = "bitcast",
|
|
1187
|
-
RandomBits = "random_bits",
|
|
1188
1338
|
Sin = "sin",
|
|
1189
1339
|
Cos = "cos",
|
|
1190
1340
|
Asin = "asin",
|
|
@@ -1194,8 +1344,6 @@ declare enum Primitive {
|
|
|
1194
1344
|
Erf = "erf",
|
|
1195
1345
|
Erfc = "erfc",
|
|
1196
1346
|
Sqrt = "sqrt",
|
|
1197
|
-
Min = "min",
|
|
1198
|
-
Max = "max",
|
|
1199
1347
|
Reduce = "reduce",
|
|
1200
1348
|
Dot = "dot",
|
|
1201
1349
|
// sum(x*y, axis=-1)
|
|
@@ -1205,14 +1353,23 @@ declare enum Primitive {
|
|
|
1205
1353
|
PoolTranspose = "pool_transpose",
|
|
1206
1354
|
Compare = "compare",
|
|
1207
1355
|
Where = "where",
|
|
1356
|
+
RandomBits = "random_bits",
|
|
1357
|
+
Gather = "gather",
|
|
1208
1358
|
Transpose = "transpose",
|
|
1209
1359
|
Broadcast = "broadcast",
|
|
1210
1360
|
Reshape = "reshape",
|
|
1211
1361
|
Flip = "flip",
|
|
1212
1362
|
Shrink = "shrink",
|
|
1213
1363
|
Pad = "pad",
|
|
1214
|
-
|
|
1215
|
-
|
|
1364
|
+
Sort = "sort",
|
|
1365
|
+
// sort(x, axis=-1)
|
|
1366
|
+
Argsort = "argsort",
|
|
1367
|
+
// argsort(x, axis=-1)
|
|
1368
|
+
TriangularSolve = "triangular_solve",
|
|
1369
|
+
// A is upper triangular, A @ X.T = B.T
|
|
1370
|
+
Cholesky = "cholesky",
|
|
1371
|
+
// A is positive-definite, A = L @ L^T
|
|
1372
|
+
Jit = "jit",
|
|
1216
1373
|
}
|
|
1217
1374
|
interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
1218
1375
|
[Primitive.Cast]: {
|
|
@@ -1238,6 +1395,14 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
1238
1395
|
[Primitive.Compare]: {
|
|
1239
1396
|
op: CompareOp;
|
|
1240
1397
|
};
|
|
1398
|
+
[Primitive.RandomBits]: {
|
|
1399
|
+
shape: number[];
|
|
1400
|
+
mode: "xor" | 0 | 1;
|
|
1401
|
+
};
|
|
1402
|
+
[Primitive.Gather]: {
|
|
1403
|
+
axis: number[];
|
|
1404
|
+
outDim: number;
|
|
1405
|
+
};
|
|
1241
1406
|
[Primitive.Transpose]: {
|
|
1242
1407
|
perm: number[];
|
|
1243
1408
|
};
|
|
@@ -1245,10 +1410,6 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
1245
1410
|
shape: number[];
|
|
1246
1411
|
axis: number[];
|
|
1247
1412
|
};
|
|
1248
|
-
[Primitive.RandomBits]: {
|
|
1249
|
-
shape: number[];
|
|
1250
|
-
mode: "xor" | 0 | 1;
|
|
1251
|
-
};
|
|
1252
1413
|
[Primitive.Reshape]: {
|
|
1253
1414
|
shape: number[];
|
|
1254
1415
|
};
|
|
@@ -1261,15 +1422,14 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
1261
1422
|
[Primitive.Pad]: {
|
|
1262
1423
|
width: Pair[];
|
|
1263
1424
|
};
|
|
1264
|
-
[Primitive.
|
|
1265
|
-
axis: number[];
|
|
1266
|
-
outDim: number;
|
|
1267
|
-
};
|
|
1268
|
-
[Primitive.JitCall]: {
|
|
1425
|
+
[Primitive.Jit]: {
|
|
1269
1426
|
name: string;
|
|
1270
1427
|
jaxpr: Jaxpr;
|
|
1271
1428
|
numConsts: number;
|
|
1272
1429
|
};
|
|
1430
|
+
[Primitive.TriangularSolve]: {
|
|
1431
|
+
unitDiagonal: boolean;
|
|
1432
|
+
};
|
|
1273
1433
|
}
|
|
1274
1434
|
/** Type of parameters taken by each primitive. */
|
|
1275
1435
|
type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
|
|
@@ -1437,7 +1597,7 @@ declare abstract class Tracer {
|
|
|
1437
1597
|
sub(other: this | TracerValue): this;
|
|
1438
1598
|
/** Divide an array by this one. */
|
|
1439
1599
|
div(other: this | TracerValue): this;
|
|
1440
|
-
/** Return specified diagonals. See `numpy.diagonal` for full docs. */
|
|
1600
|
+
/** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
|
|
1441
1601
|
diagonal(offset?: number, axis1?: number, axis2?: number): this;
|
|
1442
1602
|
/** Flatten the array without changing its data. */
|
|
1443
1603
|
flatten(): this;
|
|
@@ -1456,6 +1616,19 @@ declare abstract class Tracer {
|
|
|
1456
1616
|
* ```
|
|
1457
1617
|
*/
|
|
1458
1618
|
[Symbol.iterator](): IterableIterator<this>;
|
|
1619
|
+
/**
|
|
1620
|
+
* Return a sorted copy of an array in ascending order.
|
|
1621
|
+
*
|
|
1622
|
+
* See `jax.numpy.sort` for full docs.
|
|
1623
|
+
*/
|
|
1624
|
+
sort(axis?: number): this;
|
|
1625
|
+
/**
|
|
1626
|
+
* Return the indices that would sort an array. This may not be a stable
|
|
1627
|
+
* sorting algorithm; it need not preserve order of indices in ties.
|
|
1628
|
+
*
|
|
1629
|
+
* See `jax.numpy.argsort` for full docs.
|
|
1630
|
+
*/
|
|
1631
|
+
argsort(axis?: number): this;
|
|
1459
1632
|
/**
|
|
1460
1633
|
* Slice an array along one or more axes.
|
|
1461
1634
|
*
|
|
@@ -1498,6 +1671,7 @@ declare class ShapedArray implements AbstractValue {
|
|
|
1498
1671
|
constructor(shape: number[], dtype: DType, weakType: boolean);
|
|
1499
1672
|
static fromAval(aval: AbstractValue): ShapedArray;
|
|
1500
1673
|
get ndim(): number;
|
|
1674
|
+
get size(): number;
|
|
1501
1675
|
toString(): string;
|
|
1502
1676
|
equals(other: ShapedArray): boolean;
|
|
1503
1677
|
}
|
|
@@ -1515,12 +1689,12 @@ type ArrayLike = Array | number | boolean;
|
|
|
1515
1689
|
declare class PendingExecute {
|
|
1516
1690
|
#private;
|
|
1517
1691
|
readonly backend: Backend;
|
|
1518
|
-
readonly
|
|
1692
|
+
readonly source: Kernel | Routine;
|
|
1519
1693
|
readonly inputs: Slot[];
|
|
1520
1694
|
readonly outputs: Slot[];
|
|
1521
1695
|
prepared: Executable | null;
|
|
1522
1696
|
submitted: boolean;
|
|
1523
|
-
constructor(backend: Backend,
|
|
1697
|
+
constructor(backend: Backend, source: Kernel | Routine, inputs: Slot[], outputs: Slot[]);
|
|
1524
1698
|
updateRc(delta: number): void;
|
|
1525
1699
|
prepare(): Promise<void>;
|
|
1526
1700
|
prepareSync(): void;
|
|
@@ -1552,7 +1726,6 @@ type ArrayConstructorArgs = {
|
|
|
1552
1726
|
*/
|
|
1553
1727
|
declare class Array extends Tracer {
|
|
1554
1728
|
#private;
|
|
1555
|
-
id: number;
|
|
1556
1729
|
/**
|
|
1557
1730
|
* @ignore
|
|
1558
1731
|
* Constructs an array from source, shape and backend. Note that if the source
|
|
@@ -1686,6 +1859,21 @@ declare function arange(start: number, stop?: number, step?: number, {
|
|
|
1686
1859
|
dtype,
|
|
1687
1860
|
device
|
|
1688
1861
|
}?: DTypeAndDevice): Array;
|
|
1862
|
+
/**
|
|
1863
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
1864
|
+
*
|
|
1865
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
1866
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
1867
|
+
* `k>0` is above it.
|
|
1868
|
+
*/
|
|
1869
|
+
declare function tri(n: number, m?: number, k?: number, {
|
|
1870
|
+
dtype,
|
|
1871
|
+
device
|
|
1872
|
+
}?: DTypeAndDevice): Array;
|
|
1873
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
1874
|
+
declare function tril(a: ArrayLike, k?: number): Array;
|
|
1875
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
1876
|
+
declare function triu(a: ArrayLike, k?: number): Array;
|
|
1689
1877
|
/**
|
|
1690
1878
|
* Return evenly spaced numbers over a specified interval.
|
|
1691
1879
|
*
|
|
@@ -1699,8 +1887,71 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
1699
1887
|
dtype,
|
|
1700
1888
|
device
|
|
1701
1889
|
}?: DTypeAndDevice): Array;
|
|
1890
|
+
declare namespace lax_linalg_d_exports {
|
|
1891
|
+
export { cholesky, triangularSolve };
|
|
1892
|
+
}
|
|
1893
|
+
/**
|
|
1894
|
+
* Compute the Cholesky decomposition of a symmetric positive-definite matrix.
|
|
1895
|
+
*
|
|
1896
|
+
* The Cholesky decomposition of a matrix `A` is:
|
|
1897
|
+
*
|
|
1898
|
+
* - A = L @ L^T (for upper=false, default)
|
|
1899
|
+
* - A = U^T @ U (for upper=true)
|
|
1900
|
+
*
|
|
1901
|
+
* where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
|
|
1902
|
+
* The input matrix must be symmetric and positive-definite.
|
|
1903
|
+
*
|
|
1904
|
+
* @example
|
|
1905
|
+
* ```ts
|
|
1906
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
1907
|
+
*
|
|
1908
|
+
* const x = np.array([[2., 1.], [1., 2.]]);
|
|
1909
|
+
*
|
|
1910
|
+
* // Lower Cholesky factorization (default):
|
|
1911
|
+
* const L = lax.linalg.cholesky(x);
|
|
1912
|
+
* // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
|
|
1913
|
+
*
|
|
1914
|
+
* // Upper Cholesky factorization:
|
|
1915
|
+
* const U = lax.linalg.cholesky(x, { upper: true });
|
|
1916
|
+
* // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
|
|
1917
|
+
* ```
|
|
1918
|
+
*/
|
|
1919
|
+
declare function cholesky(a: ArrayLike, {
|
|
1920
|
+
upper
|
|
1921
|
+
}?: {
|
|
1922
|
+
upper?: boolean;
|
|
1923
|
+
}): Array;
|
|
1924
|
+
/**
|
|
1925
|
+
* Solve a triangular linear system.
|
|
1926
|
+
*
|
|
1927
|
+
* Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
|
|
1928
|
+
* where `a` is a triangular matrix.
|
|
1929
|
+
*
|
|
1930
|
+
* @example
|
|
1931
|
+
* ```ts
|
|
1932
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
1933
|
+
*
|
|
1934
|
+
* const L = np.array([[2., 0.], [1., 3.]]);
|
|
1935
|
+
* const b = np.array([4., 7.]).reshape([2, 1]);
|
|
1936
|
+
*
|
|
1937
|
+
* // Solve L @ x = b
|
|
1938
|
+
* const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
|
|
1939
|
+
* // x = [[2.], [5./3.]]
|
|
1940
|
+
* ```
|
|
1941
|
+
*/
|
|
1942
|
+
declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
|
|
1943
|
+
leftSide,
|
|
1944
|
+
lower,
|
|
1945
|
+
transposeA,
|
|
1946
|
+
unitDiagonal
|
|
1947
|
+
}?: {
|
|
1948
|
+
leftSide?: boolean;
|
|
1949
|
+
lower?: boolean;
|
|
1950
|
+
transposeA?: boolean;
|
|
1951
|
+
unitDiagonal?: boolean;
|
|
1952
|
+
}): Array;
|
|
1702
1953
|
declare namespace lax_d_exports {
|
|
1703
|
-
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, reduceWindow, stopGradient };
|
|
1954
|
+
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
|
|
1704
1955
|
}
|
|
1705
1956
|
/**
|
|
1706
1957
|
* Dimension numbers for general `dot()` primitive.
|
|
@@ -1731,7 +1982,7 @@ declare function dot(lhs: Array, rhs: Array, {
|
|
|
1731
1982
|
lhsBatchDims: lb,
|
|
1732
1983
|
rhsBatchDims: rb
|
|
1733
1984
|
}?: DotDimensionNumbers): Array;
|
|
1734
|
-
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [
|
|
1985
|
+
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
|
|
1735
1986
|
/**
|
|
1736
1987
|
* General n-dimensional convolution operator, with optional dilation.
|
|
1737
1988
|
*
|
|
@@ -1742,10 +1993,12 @@ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
|
|
|
1742
1993
|
*/
|
|
1743
1994
|
declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
|
|
1744
1995
|
lhsDilation,
|
|
1745
|
-
rhsDilation
|
|
1996
|
+
rhsDilation,
|
|
1997
|
+
featureGroupCount
|
|
1746
1998
|
}?: {
|
|
1747
1999
|
lhsDilation?: number[];
|
|
1748
2000
|
rhsDilation?: number[];
|
|
2001
|
+
featureGroupCount?: number;
|
|
1749
2002
|
}): Array;
|
|
1750
2003
|
/** Convenience wrapper around `convGeneralDilated`. */
|
|
1751
2004
|
declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
|
|
@@ -1769,9 +2022,8 @@ declare function erfc(x: ArrayLike): Array;
|
|
|
1769
2022
|
* forward or reverse-mode automatic differentiation.
|
|
1770
2023
|
*/
|
|
1771
2024
|
declare function stopGradient(x: ArrayLike): Array;
|
|
1772
|
-
//# sourceMappingURL=lax.d.ts.map
|
|
1773
2025
|
declare namespace nn_d_exports {
|
|
1774
|
-
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
|
|
2026
|
+
export { celu, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
|
|
1775
2027
|
}
|
|
1776
2028
|
/**
|
|
1777
2029
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -1798,21 +2050,28 @@ declare function sigmoid(x: ArrayLike): Array;
|
|
|
1798
2050
|
*/
|
|
1799
2051
|
declare function softplus(x: ArrayLike): Array;
|
|
1800
2052
|
/**
|
|
1801
|
-
*
|
|
1802
|
-
*
|
|
2053
|
+
* @function
|
|
2054
|
+
* Sparse plus function:
|
|
2055
|
+
*
|
|
2056
|
+
* - When `x <= -1`: `0`
|
|
2057
|
+
* - When `-1 < x < 1`: `(x+1)**2 / 4`
|
|
2058
|
+
* - When `x >= 1`: `x`
|
|
1803
2059
|
*/
|
|
1804
|
-
declare
|
|
2060
|
+
declare const sparsePlus: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1805
2061
|
/**
|
|
1806
2062
|
* @function
|
|
1807
|
-
*
|
|
1808
|
-
* Swish, computed element-wise:
|
|
1809
|
-
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
1810
|
-
*
|
|
1811
|
-
* `swish()` and `silu()` are both aliases for the same function.
|
|
2063
|
+
* Sparse sigmoid activation function.
|
|
1812
2064
|
*
|
|
1813
|
-
*
|
|
2065
|
+
* - When `x <= -1`: `0`
|
|
2066
|
+
* - When `-1 < x < 1`: `(x + 1) / 2`
|
|
2067
|
+
* - When `x >= 1`: `1`
|
|
1814
2068
|
*/
|
|
1815
|
-
declare const
|
|
2069
|
+
declare const sparseSigmoid: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2070
|
+
/**
|
|
2071
|
+
* Soft-sign activation function, computed element-wise:
|
|
2072
|
+
* `softsign(x) = x / (|x| + 1)`.
|
|
2073
|
+
*/
|
|
2074
|
+
declare function softSign(x: ArrayLike): Array;
|
|
1816
2075
|
/**
|
|
1817
2076
|
* @function
|
|
1818
2077
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
@@ -1823,7 +2082,7 @@ declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
|
1823
2082
|
*
|
|
1824
2083
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1825
2084
|
*/
|
|
1826
|
-
declare const
|
|
2085
|
+
declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1827
2086
|
/**
|
|
1828
2087
|
* Log-sigmoid activation function, computed element-wise:
|
|
1829
2088
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
@@ -1836,6 +2095,12 @@ declare function logSigmoid(x: ArrayLike): Array;
|
|
|
1836
2095
|
declare const identity: (x: ArrayLike) => Array;
|
|
1837
2096
|
/** Leaky rectified linear (ReLU) activation function */
|
|
1838
2097
|
declare function leakyRelu(x: ArrayLike, negativeSlope?: ArrayLike): Array;
|
|
2098
|
+
/** Hard sigmoid activation function: `relu6(x+3)/6`. */
|
|
2099
|
+
declare function hardSigmoid(x: ArrayLike): Array;
|
|
2100
|
+
/** Hard SiLU (swish) activation function: `x * hardSigmoid(x)`. */
|
|
2101
|
+
declare function hardSilu(x: ArrayLike): Array;
|
|
2102
|
+
/** Hard tanh activation function: `clip(x, -1, 1)`. */
|
|
2103
|
+
declare function hardTanh(x: ArrayLike): Array;
|
|
1839
2104
|
/**
|
|
1840
2105
|
* Exponential linear unit activation function.
|
|
1841
2106
|
*
|
|
@@ -1850,6 +2115,16 @@ declare function elu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
|
1850
2115
|
* `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
|
|
1851
2116
|
*/
|
|
1852
2117
|
declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
2118
|
+
/**
|
|
2119
|
+
* @function
|
|
2120
|
+
* Scaled exponential linear unit activation.
|
|
2121
|
+
*
|
|
2122
|
+
* Computes the element-wise function:
|
|
2123
|
+
* `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
|
|
2124
|
+
*
|
|
2125
|
+
* Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
|
|
2126
|
+
*/
|
|
2127
|
+
declare const selu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1853
2128
|
/**
|
|
1854
2129
|
* @function
|
|
1855
2130
|
* Gaussion error linear unit (GELU) activation function.
|
|
@@ -1912,9 +2187,9 @@ declare function logSoftmax(x: ArrayLike, axis?: Axis): Array;
|
|
|
1912
2187
|
*
|
|
1913
2188
|
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
1914
2189
|
*/
|
|
1915
|
-
declare function logsumexp(x: ArrayLike, axis?: Axis): Array;
|
|
2190
|
+
declare function logsumexp(x: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1916
2191
|
/** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
|
|
1917
|
-
declare function logmeanexp(x: ArrayLike, axis?: Axis): Array;
|
|
2192
|
+
declare function logmeanexp(x: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1918
2193
|
/**
|
|
1919
2194
|
* Standardizes input to zero mean and unit variance.
|
|
1920
2195
|
*
|
|
@@ -1945,9 +2220,8 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1945
2220
|
* ```
|
|
1946
2221
|
*/
|
|
1947
2222
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
1948
|
-
//# sourceMappingURL=nn.d.ts.map
|
|
1949
2223
|
declare namespace random_d_exports {
|
|
1950
|
-
export { bernoulli, bits, exponential, key, normal, split, uniform };
|
|
2224
|
+
export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, normal, split, uniform };
|
|
1951
2225
|
}
|
|
1952
2226
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
1953
2227
|
declare function key(seed: number): Array;
|
|
@@ -1970,11 +2244,33 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
|
|
|
1970
2244
|
* and must be broadcastable to `shape`.
|
|
1971
2245
|
*/
|
|
1972
2246
|
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
2247
|
+
/**
|
|
2248
|
+
* @function
|
|
2249
|
+
* Sample from a Cauchy distribution with location 0 and scale 1.
|
|
2250
|
+
*
|
|
2251
|
+
* Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
|
|
2252
|
+
*/
|
|
2253
|
+
declare const cauchy: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1973
2254
|
/**
|
|
1974
2255
|
* @function
|
|
1975
2256
|
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
1976
2257
|
*/
|
|
1977
2258
|
declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2259
|
+
/**
|
|
2260
|
+
* @function
|
|
2261
|
+
* Sample from a Gumbel distribution with location 0 and scale 1.
|
|
2262
|
+
*
|
|
2263
|
+
* Uses inverse transform sampling: `x = -log(-log(u))` where u ~ Uniform(0, 1).
|
|
2264
|
+
*/
|
|
2265
|
+
declare const gumbel: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2266
|
+
/**
|
|
2267
|
+
* @function
|
|
2268
|
+
* Sample from a Laplace distribution with location 0 and scale 1.
|
|
2269
|
+
*
|
|
2270
|
+
* Uses inverse transform sampling: the CDF is `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`.
|
|
2271
|
+
* Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
|
|
2272
|
+
*/
|
|
2273
|
+
declare const laplace: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1978
2274
|
/**
|
|
1979
2275
|
* @function
|
|
1980
2276
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
@@ -1984,7 +2280,6 @@ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | und
|
|
|
1984
2280
|
* bitwise identical to JAX.
|
|
1985
2281
|
*/
|
|
1986
2282
|
declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1987
|
-
//# sourceMappingURL=random.d.ts.map
|
|
1988
2283
|
declare namespace scipy_special_d_exports {
|
|
1989
2284
|
export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
|
|
1990
2285
|
}
|
|
@@ -2015,8 +2310,7 @@ declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTr
|
|
|
2015
2310
|
* Construct a Jaxpr by dynamically tracing a function with example inputs.
|
|
2016
2311
|
*/
|
|
2017
2312
|
declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
|
|
2018
|
-
jaxpr:
|
|
2019
|
-
consts: Array[];
|
|
2313
|
+
jaxpr: ClosedJaxpr;
|
|
2020
2314
|
treedef: JsTreeDef;
|
|
2021
2315
|
};
|
|
2022
2316
|
/**
|
|
@@ -2064,11 +2358,6 @@ declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F)
|
|
|
2064
2358
|
* Compute the Jacobian evaluated row-by-row by reverse-mode AD.
|
|
2065
2359
|
*/
|
|
2066
2360
|
declare const jacrev: typeof jacfwd;
|
|
2067
|
-
/**
|
|
2068
|
-
* @function
|
|
2069
|
-
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
2070
|
-
*/
|
|
2071
|
-
declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
2072
2361
|
/**
|
|
2073
2362
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
2074
2363
|
*
|
|
@@ -2090,8 +2379,5 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
|
2090
2379
|
* default device.
|
|
2091
2380
|
*/
|
|
2092
2381
|
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
2093
|
-
//# sourceMappingURL=index.d.ts.map
|
|
2094
|
-
|
|
2095
2382
|
//#endregion
|
|
2096
|
-
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
2097
|
-
//# sourceMappingURL=index.d.ts.map
|
|
2383
|
+
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|