@jax-js/jax 0.1.2 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +16 -34
- package/dist/{backend-DeVfWEFS.cjs → backend-Bu9GY6sK.cjs} +222 -36
- package/dist/{backend-BqymqzuU.js → backend-tngXtWe4.js} +204 -36
- package/dist/index.cjs +1798 -955
- package/dist/index.d.cts +383 -97
- package/dist/index.d.ts +383 -97
- package/dist/index.js +1791 -949
- package/dist/{webgpu-BGuG58KZ.js → webgpu-ChVgx3b6.js} +410 -97
- package/dist/{webgpu-CcGP160M.cjs → webgpu-Oj3Kd-kd.cjs} +410 -97
- package/package.json +1 -1
package/dist/index.d.cts
CHANGED
|
@@ -124,7 +124,6 @@ declare class ShapeTracker {
|
|
|
124
124
|
/** Like pad(), but allows for negative values. */
|
|
125
125
|
padOrShrink(arg: Pair[]): ShapeTracker;
|
|
126
126
|
}
|
|
127
|
-
//# sourceMappingURL=shape.d.ts.map
|
|
128
127
|
//#endregion
|
|
129
128
|
//#region src/utils.d.ts
|
|
130
129
|
/**
|
|
@@ -254,7 +253,7 @@ declare class AluExp implements FpHashable {
|
|
|
254
253
|
/** Substitute variables in this AluExp to values. */
|
|
255
254
|
substitute(variables: Record<string, AluExp>): AluExp;
|
|
256
255
|
/** Reindex gid values in this expression as needed. */
|
|
257
|
-
reindexGids(
|
|
256
|
+
reindexGids(newGids: number[]): AluExp;
|
|
258
257
|
get min(): number;
|
|
259
258
|
get max(): number;
|
|
260
259
|
/** Largest known integer that divides self. */
|
|
@@ -407,6 +406,52 @@ declare class Reduction implements FpHashable {
|
|
|
407
406
|
}
|
|
408
407
|
/** Expression for accessing `indices` in input array with the given shape. */
|
|
409
408
|
//#endregion
|
|
409
|
+
//#region src/routine.d.ts
|
|
410
|
+
/**
|
|
411
|
+
* Advanced operations that don't fit into the `AluExp` compiler representation.
|
|
412
|
+
*
|
|
413
|
+
* Some routines like iterative matrix algorithms, FFTs, or sorting may not be
|
|
414
|
+
* easy to express efficiently as a `Kernel` object. These also tend to be
|
|
415
|
+
* somewhat expensive, so the benefit of kernel fusion and inlining is less
|
|
416
|
+
* relevant.
|
|
417
|
+
*
|
|
418
|
+
* For these operations, we dispatch them as a custom operation on the backend,
|
|
419
|
+
* which each backend implements in a specific way. These are listed in the
|
|
420
|
+
* `Routines` enum below.
|
|
421
|
+
*
|
|
422
|
+
* Routines cannot be fused into other kernels and always operate on contiguous
|
|
423
|
+
* arrays (default `ShapeTracker`).
|
|
424
|
+
*/
|
|
425
|
+
declare class Routine {
|
|
426
|
+
/** The name of the routine. */
|
|
427
|
+
readonly name: Routines;
|
|
428
|
+
/** Dtype and shape of the inputs and outputs. */
|
|
429
|
+
readonly type: RoutineType;
|
|
430
|
+
/** Extra parameters specific to the routine. */
|
|
431
|
+
readonly params?: any | undefined;
|
|
432
|
+
constructor(/** The name of the routine. */
|
|
433
|
+
name: Routines, /** Dtype and shape of the inputs and outputs. */
|
|
434
|
+
type: RoutineType, /** Extra parameters specific to the routine. */
|
|
435
|
+
params?: any | undefined);
|
|
436
|
+
}
|
|
437
|
+
/** One of the valid `Routine` names that can be dispatched to a backend. */
|
|
438
|
+
declare enum Routines {
|
|
439
|
+
/** Stable sorting algorithm along the last axis. */
|
|
440
|
+
Sort = "Sort",
|
|
441
|
+
/** Returns `int32` indices of the stably sorted array. */
|
|
442
|
+
Argsort = "Argsort",
|
|
443
|
+
/** Solve a triangular system of equations. */
|
|
444
|
+
TriangularSolve = "TriangularSolve",
|
|
445
|
+
/** Cholesky decomposition of 2D positive-definite matrices. */
|
|
446
|
+
Cholesky = "Cholesky",
|
|
447
|
+
}
|
|
448
|
+
interface RoutineType {
|
|
449
|
+
inputShapes: number[][];
|
|
450
|
+
inputDtypes: DType[];
|
|
451
|
+
outputShapes: number[][];
|
|
452
|
+
outputDtypes: DType[];
|
|
453
|
+
}
|
|
454
|
+
//#endregion
|
|
410
455
|
//#region src/backend.d.ts
|
|
411
456
|
type Device = "cpu" | "wasm" | "webgpu";
|
|
412
457
|
declare const devices: Device[];
|
|
@@ -444,9 +489,13 @@ interface Backend {
|
|
|
444
489
|
/** Read a range of bytes from a buffer, blocking variant. */
|
|
445
490
|
readSync(slot: Slot, start?: number, count?: number): Uint8Array<ArrayBuffer>;
|
|
446
491
|
/** Prepare an expression to be executed later. */
|
|
447
|
-
|
|
492
|
+
prepareKernel(kernel: Kernel): Promise<Executable>;
|
|
448
493
|
/** Prepare an expression to be executed later, blocking variant. */
|
|
449
|
-
|
|
494
|
+
prepareKernelSync(kernel: Kernel): Executable;
|
|
495
|
+
/** Prepare an advanced routine to be executed later. */
|
|
496
|
+
prepareRoutine(routine: Routine): Promise<Executable>;
|
|
497
|
+
/** Prepare an advanced routine to be executed later, blocking variant. */
|
|
498
|
+
prepareRoutineSync(routine: Routine): Executable;
|
|
450
499
|
/**
|
|
451
500
|
* Run a backend operation that was previously prepared.
|
|
452
501
|
*
|
|
@@ -457,14 +506,69 @@ interface Backend {
|
|
|
457
506
|
dispatch(exe: Executable, inputs: Slot[], outputs: Slot[]): void;
|
|
458
507
|
}
|
|
459
508
|
declare class Executable<T = any> {
|
|
460
|
-
|
|
461
|
-
|
|
509
|
+
/** The `Kernel` or `Routine` that was prepared. */
|
|
510
|
+
readonly source: Kernel | Routine;
|
|
511
|
+
/** Extra data specific to the backend running this executable. */
|
|
462
512
|
readonly data: T;
|
|
463
|
-
constructor(
|
|
513
|
+
constructor(/** The `Kernel` or `Routine` that was prepared. */
|
|
514
|
+
source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
|
|
464
515
|
data: T);
|
|
465
516
|
}
|
|
517
|
+
declare namespace numpy_fft_d_exports {
|
|
518
|
+
export { ComplexPair, fft, ifft };
|
|
519
|
+
}
|
|
520
|
+
/**
|
|
521
|
+
* A pair of arrays representing real and imaginary part `a + bj`. Both arrays
|
|
522
|
+
* must have the same shape.
|
|
523
|
+
*/
|
|
524
|
+
type ComplexPair = {
|
|
525
|
+
real: Array;
|
|
526
|
+
imag: Array;
|
|
527
|
+
};
|
|
528
|
+
/**
|
|
529
|
+
* Compute a one-dimensional discrete Fourier transform.
|
|
530
|
+
*
|
|
531
|
+
* Currently, the size of the axis must be a power of two.
|
|
532
|
+
*/
|
|
533
|
+
declare function fft(a: ComplexPair, axis?: number): ComplexPair;
|
|
534
|
+
/**
|
|
535
|
+
* Compute a one-dimensional inverse discrete Fourier transform.
|
|
536
|
+
*
|
|
537
|
+
* Currently, the size of the axis must be a power of two.
|
|
538
|
+
*/
|
|
539
|
+
declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
|
|
540
|
+
declare namespace numpy_linalg_d_exports {
|
|
541
|
+
export { cholesky$1 as cholesky, diagonal, lstsq, matmul, matrixTranspose, outer, tensordot, trace, vecdot };
|
|
542
|
+
}
|
|
543
|
+
/**
|
|
544
|
+
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
545
|
+
*
|
|
546
|
+
* This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
|
|
547
|
+
* the input matrix, which is on by default.
|
|
548
|
+
*/
|
|
549
|
+
declare function cholesky$1(a: ArrayLike, {
|
|
550
|
+
upper,
|
|
551
|
+
symmetrizeInput
|
|
552
|
+
}?: {
|
|
553
|
+
upper?: boolean;
|
|
554
|
+
symmetrizeInput?: boolean;
|
|
555
|
+
}): Array;
|
|
556
|
+
/**
|
|
557
|
+
* Return the least-squares solution to a linear equation.
|
|
558
|
+
*
|
|
559
|
+
* For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
|
|
560
|
+
* For underdetermined systems, this finds the minimum-norm solution for `x`.
|
|
561
|
+
*
|
|
562
|
+
* This currently uses Cholesky decomposition to solve the normal equations,
|
|
563
|
+
* under the hood. The method is not as robust as QR or SVD.
|
|
564
|
+
*
|
|
565
|
+
* @param a coefficient matrix of shape `(M, N)`
|
|
566
|
+
* @param b right-hand side of shape `(M,)` or `(M, K)`
|
|
567
|
+
* @return least-squares solution of shape `(N,)` or `(N, K)`
|
|
568
|
+
*/
|
|
569
|
+
declare function lstsq(a: ArrayLike, b: ArrayLike): Array;
|
|
466
570
|
declare namespace numpy_d_exports {
|
|
467
|
-
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, positive, pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
571
|
+
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sort, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
468
572
|
}
|
|
469
573
|
declare const float32 = DType.Float32;
|
|
470
574
|
declare const int32 = DType.Int32;
|
|
@@ -590,6 +694,20 @@ declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
|
590
694
|
declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
591
695
|
/** Return the maximum of array elements along a given axis. */
|
|
592
696
|
declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
697
|
+
/**
|
|
698
|
+
* Test whether all array elements along a given axis evaluate to True.
|
|
699
|
+
*
|
|
700
|
+
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
701
|
+
* removed. If axis is None, returns a scalar.
|
|
702
|
+
*/
|
|
703
|
+
declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
704
|
+
/**
|
|
705
|
+
* Test whether any array element along a given axis evaluates to True.
|
|
706
|
+
*
|
|
707
|
+
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
708
|
+
* removed. If axis is None, returns a scalar.
|
|
709
|
+
*/
|
|
710
|
+
declare function any(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
593
711
|
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
594
712
|
declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
595
713
|
/** Compute the average of the array elements along the specified axis. */
|
|
@@ -608,6 +726,13 @@ declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
|
608
726
|
* specified axis.
|
|
609
727
|
*/
|
|
610
728
|
declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
729
|
+
/**
|
|
730
|
+
* Cumulative sum of elements along an axis.
|
|
731
|
+
*
|
|
732
|
+
* Currently this function is `O(n^2)`; we'll improve this later on with a
|
|
733
|
+
* two-phase parallel reduction algorithm.
|
|
734
|
+
*/
|
|
735
|
+
declare function cumsum(a: ArrayLike, axis?: number): Array;
|
|
611
736
|
/** Reverse the elements in an array along the given axes. */
|
|
612
737
|
declare function flip(x: ArrayLike, axis?: Axis): Array;
|
|
613
738
|
/**
|
|
@@ -653,12 +778,29 @@ declare function columnStack(xs: ArrayLike[]): Array;
|
|
|
653
778
|
declare function flipud(x: ArrayLike): Array;
|
|
654
779
|
/** Flip an array horizontally (axis=1). */
|
|
655
780
|
declare function fliplr(x: ArrayLike): Array;
|
|
656
|
-
/**
|
|
657
|
-
declare
|
|
781
|
+
/** Transpose the last two dimensions of an array. */
|
|
782
|
+
declare function matrixTranspose(a: ArrayLike): Array;
|
|
658
783
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
659
784
|
declare function ravel(a: ArrayLike): Array;
|
|
660
785
|
/** Remove one or more length-1 axes from an array. */
|
|
661
786
|
declare function squeeze(a: ArrayLike, axis?: Axis): Array;
|
|
787
|
+
/**
|
|
788
|
+
* Expand the shape of an array by inserting new axes of length 1.
|
|
789
|
+
*
|
|
790
|
+
* @param a - Input array.
|
|
791
|
+
* @param axis - Position(s) in the expanded axes where the new axis (or axes)
|
|
792
|
+
* is placed. Can be a single integer or an array of integers.
|
|
793
|
+
* @returns Array with the number of dimensions increased.
|
|
794
|
+
*
|
|
795
|
+
* @example
|
|
796
|
+
* ```ts
|
|
797
|
+
* const x = np.array([1, 2]);
|
|
798
|
+
* np.expandDims(x, 0); // Shape [1, 2]
|
|
799
|
+
* np.expandDims(x, 1); // Shape [2, 1]
|
|
800
|
+
* np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
|
|
801
|
+
* ```
|
|
802
|
+
*/
|
|
803
|
+
declare function expandDims(a: ArrayLike, axis: number | number[]): Array;
|
|
662
804
|
/**
|
|
663
805
|
* Repeat each element of an array after themselves.
|
|
664
806
|
*
|
|
@@ -705,6 +847,22 @@ declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?:
|
|
|
705
847
|
declare function diag(v: ArrayLike, k?: number): Array;
|
|
706
848
|
/** Calculate the sum of the diagonal of an array along the given axes. */
|
|
707
849
|
declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
|
|
850
|
+
/**
|
|
851
|
+
* Return a sorted copy of an array.
|
|
852
|
+
*
|
|
853
|
+
* The array is sorted along a specified axis (the last by default). This may be
|
|
854
|
+
* an unstable sort, and it dispatches to a device-specific implementation.
|
|
855
|
+
*/
|
|
856
|
+
declare function sort(a: ArrayLike, axis?: number): Array;
|
|
857
|
+
/**
|
|
858
|
+
* Return indices that would sort an array. This may be an unstable sorting
|
|
859
|
+
* algorithm; it need not preserve order of indices in ties.
|
|
860
|
+
*
|
|
861
|
+
* Returns an array of `int32` indices.
|
|
862
|
+
*
|
|
863
|
+
* The array is sorted along a specified axis (the last by default).
|
|
864
|
+
*/
|
|
865
|
+
declare function argsort(a: ArrayLike, axis?: number): Array;
|
|
708
866
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
709
867
|
declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
|
|
710
868
|
rtol?: number;
|
|
@@ -776,6 +934,10 @@ declare function vecdot(x: ArrayLike, y: ArrayLike, {
|
|
|
776
934
|
* Like vecdot() but flattens the arguments first into vectors.
|
|
777
935
|
*/
|
|
778
936
|
declare function vdot(x: ArrayLike, y: ArrayLike): Array;
|
|
937
|
+
/** Convolution of two one-dimensional arrays. */
|
|
938
|
+
declare function convolve(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
|
|
939
|
+
/** Correlation of two one-dimensional arrays. */
|
|
940
|
+
declare function correlate(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
|
|
779
941
|
/**
|
|
780
942
|
* Return a tuple of coordinate matrices from coordinate vectors.
|
|
781
943
|
*
|
|
@@ -787,21 +949,6 @@ declare function meshgrid(xs: Array[], {
|
|
|
787
949
|
}?: {
|
|
788
950
|
indexing?: "xy" | "ij";
|
|
789
951
|
}): Array[];
|
|
790
|
-
/**
|
|
791
|
-
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
792
|
-
*
|
|
793
|
-
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
794
|
-
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
795
|
-
* `k>0` is above it.
|
|
796
|
-
*/
|
|
797
|
-
declare function tri(n: number, m?: number, k?: number, {
|
|
798
|
-
dtype,
|
|
799
|
-
device
|
|
800
|
-
}?: DTypeAndDevice): Array;
|
|
801
|
-
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
802
|
-
declare function tril(a: ArrayLike, k?: number): Array;
|
|
803
|
-
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
804
|
-
declare function triu(a: ArrayLike, k?: number): Array;
|
|
805
952
|
/**
|
|
806
953
|
* Clip (limit) the values in an array.
|
|
807
954
|
*
|
|
@@ -818,8 +965,6 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
|
|
|
818
965
|
* This is the same function as `jax.numpy.abs()`.
|
|
819
966
|
*/
|
|
820
967
|
declare function absolute(x: ArrayLike): Array;
|
|
821
|
-
/** @function Alias of `jax.numpy.absolute()`. */
|
|
822
|
-
declare const abs: typeof absolute;
|
|
823
968
|
/** Return an element-wise indication of sign of the input. */
|
|
824
969
|
declare function sign(x: ArrayLike): Array;
|
|
825
970
|
/** @function Return element-wise positive values of the input (no-op). */
|
|
@@ -873,12 +1018,6 @@ declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
|
873
1018
|
* The output is ill-defined when both x and y are zero.
|
|
874
1019
|
*/
|
|
875
1020
|
declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
876
|
-
/** @function Alias of `jax.numpy.acos()`. */
|
|
877
|
-
declare const arccos: typeof acos;
|
|
878
|
-
/** @function Alias of `jax.numpy.atan()`. */
|
|
879
|
-
declare const arctan: (x: ArrayLike) => Array;
|
|
880
|
-
/** @function Alias of `jax.numpy.atan2()`. */
|
|
881
|
-
declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
882
1021
|
/** Element-wise subtraction, with broadcasting. */
|
|
883
1022
|
declare function subtract(x: ArrayLike, y: ArrayLike): Array;
|
|
884
1023
|
/** Calculates the floating-point division of x by y element-wise. */
|
|
@@ -893,8 +1032,6 @@ declare const fmod: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
|
893
1032
|
* Calculate element-wise remainder of the division (matches sign of y).
|
|
894
1033
|
*/
|
|
895
1034
|
declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
896
|
-
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
897
|
-
declare const divide: typeof trueDivide;
|
|
898
1035
|
/** Round input to the nearest integer towards zero. */
|
|
899
1036
|
declare function trunc(x: ArrayLike): Array;
|
|
900
1037
|
/**
|
|
@@ -934,8 +1071,6 @@ declare const degrees: typeof rad2deg;
|
|
|
934
1071
|
* Computes first array raised to power of second array, element-wise.
|
|
935
1072
|
*/
|
|
936
1073
|
declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
937
|
-
/** @function Alias of `jax.numpy.power()`. */
|
|
938
|
-
declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
939
1074
|
/** @function Calculate the element-wise cube root of the input array. */
|
|
940
1075
|
declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
|
|
941
1076
|
/**
|
|
@@ -980,12 +1115,6 @@ declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
|
980
1115
|
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
981
1116
|
*/
|
|
982
1117
|
declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
983
|
-
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
984
|
-
declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
985
|
-
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
986
|
-
declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
987
|
-
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
988
|
-
declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
989
1118
|
/**
|
|
990
1119
|
* Compute the variance of an array.
|
|
991
1120
|
*
|
|
@@ -1012,6 +1141,10 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1012
1141
|
mean?: ArrayLike;
|
|
1013
1142
|
correction?: number;
|
|
1014
1143
|
} & ReduceOpts): Array;
|
|
1144
|
+
/** Estimate the sample covariance of a set of variables. */
|
|
1145
|
+
declare function cov(x: ArrayLike, y?: ArrayLike): Array;
|
|
1146
|
+
/** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
|
|
1147
|
+
declare function corrcoef(x: ArrayLike, y?: ArrayLike): Array;
|
|
1015
1148
|
/** Test element-wise for positive or negative infinity, return bool array. */
|
|
1016
1149
|
declare function isinf(x: ArrayLike): Array;
|
|
1017
1150
|
/** Test element-wise for NaN (Not a Number). */
|
|
@@ -1025,7 +1158,6 @@ declare function isposinf(x: ArrayLike): Array;
|
|
|
1025
1158
|
* Test element-wise for finite values (not infinity or NaN).
|
|
1026
1159
|
*/
|
|
1027
1160
|
declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1028
|
-
//# sourceMappingURL=numpy.d.ts.map
|
|
1029
1161
|
declare namespace tree_d_exports {
|
|
1030
1162
|
export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
|
|
1031
1163
|
}
|
|
@@ -1076,13 +1208,18 @@ declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefin
|
|
|
1076
1208
|
//#region src/frontend/convolution.d.ts
|
|
1077
1209
|
/** Definition of a general dilated convolution. Should be valid on creation. */
|
|
1078
1210
|
interface ConvParams {
|
|
1211
|
+
vmapDims: number;
|
|
1079
1212
|
strides: number[];
|
|
1080
|
-
padding: [
|
|
1213
|
+
padding: Pair[];
|
|
1081
1214
|
lhsDilation: number[];
|
|
1082
1215
|
rhsDilation: number[];
|
|
1083
1216
|
}
|
|
1084
1217
|
/**
|
|
1085
1218
|
* Check that the shapes and parameters passed to convolution are valid.
|
|
1219
|
+
* Expected shapes of the lhs and rhs of the convolution are:
|
|
1220
|
+
*
|
|
1221
|
+
* - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
|
|
1222
|
+
* - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
|
|
1086
1223
|
*
|
|
1087
1224
|
* If the check succeeds, returns the output shape.
|
|
1088
1225
|
*/
|
|
@@ -1154,9 +1291,21 @@ declare class Jaxpr implements FpHashable {
|
|
|
1154
1291
|
* - Remove no-op movement operations.
|
|
1155
1292
|
*/
|
|
1156
1293
|
simplify(): Jaxpr;
|
|
1157
|
-
/** Flattens nested
|
|
1294
|
+
/** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1158
1295
|
flatten(): Jaxpr;
|
|
1159
1296
|
}
|
|
1297
|
+
/** Jaxpr with a collection of associated, traced constants. */
|
|
1298
|
+
declare class ClosedJaxpr {
|
|
1299
|
+
readonly jaxpr: Jaxpr;
|
|
1300
|
+
readonly consts: Tracer[];
|
|
1301
|
+
constructor(jaxpr: Jaxpr, consts: Tracer[]);
|
|
1302
|
+
/** String representation of this Jaxpr. */
|
|
1303
|
+
toString(): string;
|
|
1304
|
+
/** Apply a function to the underlying Jaxpr. */
|
|
1305
|
+
mapJaxpr(f: (jaxpr: Jaxpr) => Jaxpr): ClosedJaxpr;
|
|
1306
|
+
/** Dispose of the constants in this Jaxpr. */
|
|
1307
|
+
dispose(): void;
|
|
1308
|
+
}
|
|
1160
1309
|
/** @inline */
|
|
1161
1310
|
type JitOpts = {
|
|
1162
1311
|
staticArgnums?: number[];
|
|
@@ -1179,7 +1328,9 @@ declare enum Primitive {
|
|
|
1179
1328
|
Mul = "mul",
|
|
1180
1329
|
Idiv = "idiv",
|
|
1181
1330
|
Mod = "mod",
|
|
1182
|
-
// uses sign of
|
|
1331
|
+
// uses sign of numerator, C-style, matches JS but not Python
|
|
1332
|
+
Min = "min",
|
|
1333
|
+
Max = "max",
|
|
1183
1334
|
Neg = "neg",
|
|
1184
1335
|
Reciprocal = "reciprocal",
|
|
1185
1336
|
Floor = "floor",
|
|
@@ -1187,7 +1338,6 @@ declare enum Primitive {
|
|
|
1187
1338
|
StopGradient = "stop_gradient",
|
|
1188
1339
|
Cast = "cast",
|
|
1189
1340
|
Bitcast = "bitcast",
|
|
1190
|
-
RandomBits = "random_bits",
|
|
1191
1341
|
Sin = "sin",
|
|
1192
1342
|
Cos = "cos",
|
|
1193
1343
|
Asin = "asin",
|
|
@@ -1197,8 +1347,6 @@ declare enum Primitive {
|
|
|
1197
1347
|
Erf = "erf",
|
|
1198
1348
|
Erfc = "erfc",
|
|
1199
1349
|
Sqrt = "sqrt",
|
|
1200
|
-
Min = "min",
|
|
1201
|
-
Max = "max",
|
|
1202
1350
|
Reduce = "reduce",
|
|
1203
1351
|
Dot = "dot",
|
|
1204
1352
|
// sum(x*y, axis=-1)
|
|
@@ -1208,14 +1356,23 @@ declare enum Primitive {
|
|
|
1208
1356
|
PoolTranspose = "pool_transpose",
|
|
1209
1357
|
Compare = "compare",
|
|
1210
1358
|
Where = "where",
|
|
1359
|
+
RandomBits = "random_bits",
|
|
1360
|
+
Gather = "gather",
|
|
1211
1361
|
Transpose = "transpose",
|
|
1212
1362
|
Broadcast = "broadcast",
|
|
1213
1363
|
Reshape = "reshape",
|
|
1214
1364
|
Flip = "flip",
|
|
1215
1365
|
Shrink = "shrink",
|
|
1216
1366
|
Pad = "pad",
|
|
1217
|
-
|
|
1218
|
-
|
|
1367
|
+
Sort = "sort",
|
|
1368
|
+
// sort(x, axis=-1)
|
|
1369
|
+
Argsort = "argsort",
|
|
1370
|
+
// argsort(x, axis=-1)
|
|
1371
|
+
TriangularSolve = "triangular_solve",
|
|
1372
|
+
// A is upper triangular, A @ X.T = B.T
|
|
1373
|
+
Cholesky = "cholesky",
|
|
1374
|
+
// A is positive-definite, A = L @ L^T
|
|
1375
|
+
Jit = "jit",
|
|
1219
1376
|
}
|
|
1220
1377
|
interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
1221
1378
|
[Primitive.Cast]: {
|
|
@@ -1241,6 +1398,14 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
1241
1398
|
[Primitive.Compare]: {
|
|
1242
1399
|
op: CompareOp;
|
|
1243
1400
|
};
|
|
1401
|
+
[Primitive.RandomBits]: {
|
|
1402
|
+
shape: number[];
|
|
1403
|
+
mode: "xor" | 0 | 1;
|
|
1404
|
+
};
|
|
1405
|
+
[Primitive.Gather]: {
|
|
1406
|
+
axis: number[];
|
|
1407
|
+
outDim: number;
|
|
1408
|
+
};
|
|
1244
1409
|
[Primitive.Transpose]: {
|
|
1245
1410
|
perm: number[];
|
|
1246
1411
|
};
|
|
@@ -1248,10 +1413,6 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
1248
1413
|
shape: number[];
|
|
1249
1414
|
axis: number[];
|
|
1250
1415
|
};
|
|
1251
|
-
[Primitive.RandomBits]: {
|
|
1252
|
-
shape: number[];
|
|
1253
|
-
mode: "xor" | 0 | 1;
|
|
1254
|
-
};
|
|
1255
1416
|
[Primitive.Reshape]: {
|
|
1256
1417
|
shape: number[];
|
|
1257
1418
|
};
|
|
@@ -1264,15 +1425,14 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
1264
1425
|
[Primitive.Pad]: {
|
|
1265
1426
|
width: Pair[];
|
|
1266
1427
|
};
|
|
1267
|
-
[Primitive.
|
|
1268
|
-
axis: number[];
|
|
1269
|
-
outDim: number;
|
|
1270
|
-
};
|
|
1271
|
-
[Primitive.JitCall]: {
|
|
1428
|
+
[Primitive.Jit]: {
|
|
1272
1429
|
name: string;
|
|
1273
1430
|
jaxpr: Jaxpr;
|
|
1274
1431
|
numConsts: number;
|
|
1275
1432
|
};
|
|
1433
|
+
[Primitive.TriangularSolve]: {
|
|
1434
|
+
unitDiagonal: boolean;
|
|
1435
|
+
};
|
|
1276
1436
|
}
|
|
1277
1437
|
/** Type of parameters taken by each primitive. */
|
|
1278
1438
|
type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
|
|
@@ -1440,7 +1600,7 @@ declare abstract class Tracer {
|
|
|
1440
1600
|
sub(other: this | TracerValue): this;
|
|
1441
1601
|
/** Divide an array by this one. */
|
|
1442
1602
|
div(other: this | TracerValue): this;
|
|
1443
|
-
/** Return specified diagonals. See `numpy.diagonal` for full docs. */
|
|
1603
|
+
/** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
|
|
1444
1604
|
diagonal(offset?: number, axis1?: number, axis2?: number): this;
|
|
1445
1605
|
/** Flatten the array without changing its data. */
|
|
1446
1606
|
flatten(): this;
|
|
@@ -1459,6 +1619,19 @@ declare abstract class Tracer {
|
|
|
1459
1619
|
* ```
|
|
1460
1620
|
*/
|
|
1461
1621
|
[Symbol.iterator](): IterableIterator<this>;
|
|
1622
|
+
/**
|
|
1623
|
+
* Return a sorted copy of an array in ascending order.
|
|
1624
|
+
*
|
|
1625
|
+
* See `jax.numpy.sort` for full docs.
|
|
1626
|
+
*/
|
|
1627
|
+
sort(axis?: number): this;
|
|
1628
|
+
/**
|
|
1629
|
+
* Return the indices that would sort an array. This may not be a stable
|
|
1630
|
+
* sorting algorithm; it need not preserve order of indices in ties.
|
|
1631
|
+
*
|
|
1632
|
+
* See `jax.numpy.argsort` for full docs.
|
|
1633
|
+
*/
|
|
1634
|
+
argsort(axis?: number): this;
|
|
1462
1635
|
/**
|
|
1463
1636
|
* Slice an array along one or more axes.
|
|
1464
1637
|
*
|
|
@@ -1501,6 +1674,7 @@ declare class ShapedArray implements AbstractValue {
|
|
|
1501
1674
|
constructor(shape: number[], dtype: DType, weakType: boolean);
|
|
1502
1675
|
static fromAval(aval: AbstractValue): ShapedArray;
|
|
1503
1676
|
get ndim(): number;
|
|
1677
|
+
get size(): number;
|
|
1504
1678
|
toString(): string;
|
|
1505
1679
|
equals(other: ShapedArray): boolean;
|
|
1506
1680
|
}
|
|
@@ -1518,12 +1692,12 @@ type ArrayLike = Array | number | boolean;
|
|
|
1518
1692
|
declare class PendingExecute {
|
|
1519
1693
|
#private;
|
|
1520
1694
|
readonly backend: Backend;
|
|
1521
|
-
readonly
|
|
1695
|
+
readonly source: Kernel | Routine;
|
|
1522
1696
|
readonly inputs: Slot[];
|
|
1523
1697
|
readonly outputs: Slot[];
|
|
1524
1698
|
prepared: Executable | null;
|
|
1525
1699
|
submitted: boolean;
|
|
1526
|
-
constructor(backend: Backend,
|
|
1700
|
+
constructor(backend: Backend, source: Kernel | Routine, inputs: Slot[], outputs: Slot[]);
|
|
1527
1701
|
updateRc(delta: number): void;
|
|
1528
1702
|
prepare(): Promise<void>;
|
|
1529
1703
|
prepareSync(): void;
|
|
@@ -1555,7 +1729,6 @@ type ArrayConstructorArgs = {
|
|
|
1555
1729
|
*/
|
|
1556
1730
|
declare class Array extends Tracer {
|
|
1557
1731
|
#private;
|
|
1558
|
-
id: number;
|
|
1559
1732
|
/**
|
|
1560
1733
|
* @ignore
|
|
1561
1734
|
* Constructs an array from source, shape and backend. Note that if the source
|
|
@@ -1689,6 +1862,21 @@ declare function arange(start: number, stop?: number, step?: number, {
|
|
|
1689
1862
|
dtype,
|
|
1690
1863
|
device
|
|
1691
1864
|
}?: DTypeAndDevice): Array;
|
|
1865
|
+
/**
|
|
1866
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
1867
|
+
*
|
|
1868
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
1869
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
1870
|
+
* `k>0` is above it.
|
|
1871
|
+
*/
|
|
1872
|
+
declare function tri(n: number, m?: number, k?: number, {
|
|
1873
|
+
dtype,
|
|
1874
|
+
device
|
|
1875
|
+
}?: DTypeAndDevice): Array;
|
|
1876
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
1877
|
+
declare function tril(a: ArrayLike, k?: number): Array;
|
|
1878
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
1879
|
+
declare function triu(a: ArrayLike, k?: number): Array;
|
|
1692
1880
|
/**
|
|
1693
1881
|
* Return evenly spaced numbers over a specified interval.
|
|
1694
1882
|
*
|
|
@@ -1702,8 +1890,71 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
1702
1890
|
dtype,
|
|
1703
1891
|
device
|
|
1704
1892
|
}?: DTypeAndDevice): Array;
|
|
1893
|
+
declare namespace lax_linalg_d_exports {
|
|
1894
|
+
export { cholesky, triangularSolve };
|
|
1895
|
+
}
|
|
1896
|
+
/**
|
|
1897
|
+
* Compute the Cholesky decomposition of a symmetric positive-definite matrix.
|
|
1898
|
+
*
|
|
1899
|
+
* The Cholesky decomposition of a matrix `A` is:
|
|
1900
|
+
*
|
|
1901
|
+
* - A = L @ L^T (for upper=false, default)
|
|
1902
|
+
* - A = U^T @ U (for upper=true)
|
|
1903
|
+
*
|
|
1904
|
+
* where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
|
|
1905
|
+
* The input matrix must be symmetric and positive-definite.
|
|
1906
|
+
*
|
|
1907
|
+
* @example
|
|
1908
|
+
* ```ts
|
|
1909
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
1910
|
+
*
|
|
1911
|
+
* const x = np.array([[2., 1.], [1., 2.]]);
|
|
1912
|
+
*
|
|
1913
|
+
* // Lower Cholesky factorization (default):
|
|
1914
|
+
* const L = lax.linalg.cholesky(x);
|
|
1915
|
+
* // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
|
|
1916
|
+
*
|
|
1917
|
+
* // Upper Cholesky factorization:
|
|
1918
|
+
* const U = lax.linalg.cholesky(x, { upper: true });
|
|
1919
|
+
* // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
|
|
1920
|
+
* ```
|
|
1921
|
+
*/
|
|
1922
|
+
declare function cholesky(a: ArrayLike, {
|
|
1923
|
+
upper
|
|
1924
|
+
}?: {
|
|
1925
|
+
upper?: boolean;
|
|
1926
|
+
}): Array;
|
|
1927
|
+
/**
|
|
1928
|
+
* Solve a triangular linear system.
|
|
1929
|
+
*
|
|
1930
|
+
* Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
|
|
1931
|
+
* where `a` is a triangular matrix.
|
|
1932
|
+
*
|
|
1933
|
+
* @example
|
|
1934
|
+
* ```ts
|
|
1935
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
1936
|
+
*
|
|
1937
|
+
* const L = np.array([[2., 0.], [1., 3.]]);
|
|
1938
|
+
* const b = np.array([4., 7.]).reshape([2, 1]);
|
|
1939
|
+
*
|
|
1940
|
+
* // Solve L @ x = b
|
|
1941
|
+
* const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
|
|
1942
|
+
* // x = [[2.], [5./3.]]
|
|
1943
|
+
* ```
|
|
1944
|
+
*/
|
|
1945
|
+
declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
|
|
1946
|
+
leftSide,
|
|
1947
|
+
lower,
|
|
1948
|
+
transposeA,
|
|
1949
|
+
unitDiagonal
|
|
1950
|
+
}?: {
|
|
1951
|
+
leftSide?: boolean;
|
|
1952
|
+
lower?: boolean;
|
|
1953
|
+
transposeA?: boolean;
|
|
1954
|
+
unitDiagonal?: boolean;
|
|
1955
|
+
}): Array;
|
|
1705
1956
|
declare namespace lax_d_exports {
|
|
1706
|
-
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, reduceWindow, stopGradient };
|
|
1957
|
+
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
|
|
1707
1958
|
}
|
|
1708
1959
|
/**
|
|
1709
1960
|
* Dimension numbers for general `dot()` primitive.
|
|
@@ -1734,7 +1985,7 @@ declare function dot(lhs: Array, rhs: Array, {
|
|
|
1734
1985
|
lhsBatchDims: lb,
|
|
1735
1986
|
rhsBatchDims: rb
|
|
1736
1987
|
}?: DotDimensionNumbers): Array;
|
|
1737
|
-
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [
|
|
1988
|
+
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
|
|
1738
1989
|
/**
|
|
1739
1990
|
* General n-dimensional convolution operator, with optional dilation.
|
|
1740
1991
|
*
|
|
@@ -1745,10 +1996,12 @@ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
|
|
|
1745
1996
|
*/
|
|
1746
1997
|
declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
|
|
1747
1998
|
lhsDilation,
|
|
1748
|
-
rhsDilation
|
|
1999
|
+
rhsDilation,
|
|
2000
|
+
featureGroupCount
|
|
1749
2001
|
}?: {
|
|
1750
2002
|
lhsDilation?: number[];
|
|
1751
2003
|
rhsDilation?: number[];
|
|
2004
|
+
featureGroupCount?: number;
|
|
1752
2005
|
}): Array;
|
|
1753
2006
|
/** Convenience wrapper around `convGeneralDilated`. */
|
|
1754
2007
|
declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
|
|
@@ -1772,9 +2025,8 @@ declare function erfc(x: ArrayLike): Array;
|
|
|
1772
2025
|
* forward or reverse-mode automatic differentiation.
|
|
1773
2026
|
*/
|
|
1774
2027
|
declare function stopGradient(x: ArrayLike): Array;
|
|
1775
|
-
//# sourceMappingURL=lax.d.ts.map
|
|
1776
2028
|
declare namespace nn_d_exports {
|
|
1777
|
-
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
|
|
2029
|
+
export { celu, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
|
|
1778
2030
|
}
|
|
1779
2031
|
/**
|
|
1780
2032
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -1801,21 +2053,28 @@ declare function sigmoid(x: ArrayLike): Array;
|
|
|
1801
2053
|
*/
|
|
1802
2054
|
declare function softplus(x: ArrayLike): Array;
|
|
1803
2055
|
/**
|
|
1804
|
-
*
|
|
1805
|
-
*
|
|
2056
|
+
* @function
|
|
2057
|
+
* Sparse plus function:
|
|
2058
|
+
*
|
|
2059
|
+
* - When `x <= -1`: `0`
|
|
2060
|
+
* - When `-1 < x < 1`: `(x+1)**2 / 4`
|
|
2061
|
+
* - When `x >= 1`: `x`
|
|
1806
2062
|
*/
|
|
1807
|
-
declare
|
|
2063
|
+
declare const sparsePlus: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1808
2064
|
/**
|
|
1809
2065
|
* @function
|
|
1810
|
-
*
|
|
1811
|
-
* Swish, computed element-wise:
|
|
1812
|
-
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
1813
|
-
*
|
|
1814
|
-
* `swish()` and `silu()` are both aliases for the same function.
|
|
2066
|
+
* Sparse sigmoid activation function.
|
|
1815
2067
|
*
|
|
1816
|
-
*
|
|
2068
|
+
* - When `x <= -1`: `0`
|
|
2069
|
+
* - When `-1 < x < 1`: `(x + 1) / 2`
|
|
2070
|
+
* - When `x >= 1`: `1`
|
|
1817
2071
|
*/
|
|
1818
|
-
declare const
|
|
2072
|
+
declare const sparseSigmoid: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2073
|
+
/**
|
|
2074
|
+
* Soft-sign activation function, computed element-wise:
|
|
2075
|
+
* `softsign(x) = x / (|x| + 1)`.
|
|
2076
|
+
*/
|
|
2077
|
+
declare function softSign(x: ArrayLike): Array;
|
|
1819
2078
|
/**
|
|
1820
2079
|
* @function
|
|
1821
2080
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
@@ -1826,7 +2085,7 @@ declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
|
1826
2085
|
*
|
|
1827
2086
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1828
2087
|
*/
|
|
1829
|
-
declare const
|
|
2088
|
+
declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1830
2089
|
/**
|
|
1831
2090
|
* Log-sigmoid activation function, computed element-wise:
|
|
1832
2091
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
@@ -1839,6 +2098,12 @@ declare function logSigmoid(x: ArrayLike): Array;
|
|
|
1839
2098
|
declare const identity: (x: ArrayLike) => Array;
|
|
1840
2099
|
/** Leaky rectified linear unit (Leaky ReLU) activation function. */
|
|
1841
2100
|
declare function leakyRelu(x: ArrayLike, negativeSlope?: ArrayLike): Array;
|
|
2101
|
+
/** Hard sigmoid activation function: `relu6(x+3)/6`. */
|
|
2102
|
+
declare function hardSigmoid(x: ArrayLike): Array;
|
|
2103
|
+
/** Hard SiLU (swish) activation function: `x * hardSigmoid(x)`. */
|
|
2104
|
+
declare function hardSilu(x: ArrayLike): Array;
|
|
2105
|
+
/** Hard tanh activation function: `clip(x, -1, 1)`. */
|
|
2106
|
+
declare function hardTanh(x: ArrayLike): Array;
|
|
1842
2107
|
/**
|
|
1843
2108
|
* Exponential linear unit activation function.
|
|
1844
2109
|
*
|
|
@@ -1853,6 +2118,16 @@ declare function elu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
|
1853
2118
|
* `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
|
|
1854
2119
|
*/
|
|
1855
2120
|
declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
2121
|
+
/**
|
|
2122
|
+
* @function
|
|
2123
|
+
* Scaled exponential linear unit activation.
|
|
2124
|
+
*
|
|
2125
|
+
* Computes the element-wise function:
|
|
2126
|
+
* `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
|
|
2127
|
+
*
|
|
2128
|
+
* Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
|
|
2129
|
+
*/
|
|
2130
|
+
declare const selu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1856
2131
|
/**
|
|
1857
2132
|
* @function
|
|
1858
2133
|
* Gaussian error linear unit (GELU) activation function.
|
|
@@ -1915,9 +2190,9 @@ declare function logSoftmax(x: ArrayLike, axis?: Axis): Array;
|
|
|
1915
2190
|
*
|
|
1916
2191
|
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
1917
2192
|
*/
|
|
1918
|
-
declare function logsumexp(x: ArrayLike, axis?: Axis): Array;
|
|
2193
|
+
declare function logsumexp(x: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1919
2194
|
/** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
|
|
1920
|
-
declare function logmeanexp(x: ArrayLike, axis?: Axis): Array;
|
|
2195
|
+
declare function logmeanexp(x: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1921
2196
|
/**
|
|
1922
2197
|
* Standardizes input to zero mean and unit variance.
|
|
1923
2198
|
*
|
|
@@ -1948,9 +2223,8 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1948
2223
|
* ```
|
|
1949
2224
|
*/
|
|
1950
2225
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
1951
|
-
//# sourceMappingURL=nn.d.ts.map
|
|
1952
2226
|
declare namespace random_d_exports {
|
|
1953
|
-
export { bernoulli, bits, exponential, key, normal, split, uniform };
|
|
2227
|
+
export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, normal, split, uniform };
|
|
1954
2228
|
}
|
|
1955
2229
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
1956
2230
|
declare function key(seed: number): Array;
|
|
@@ -1973,11 +2247,33 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
|
|
|
1973
2247
|
* and must be broadcastable to `shape`.
|
|
1974
2248
|
*/
|
|
1975
2249
|
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
2250
|
+
/**
|
|
2251
|
+
* @function
|
|
2252
|
+
* Sample from a Cauchy distribution with location 0 and scale 1.
|
|
2253
|
+
*
|
|
2254
|
+
* Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
|
|
2255
|
+
*/
|
|
2256
|
+
declare const cauchy: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1976
2257
|
/**
|
|
1977
2258
|
* @function
|
|
1978
2259
|
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
1979
2260
|
*/
|
|
1980
2261
|
declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2262
|
+
/**
|
|
2263
|
+
* @function
|
|
2264
|
+
* Sample from a Gumbel distribution with location 0 and scale 1.
|
|
2265
|
+
*
|
|
2266
|
+
* Uses inverse transform sampling: `x = -log(-log(u))` where u ~ Uniform(0, 1).
|
|
2267
|
+
*/
|
|
2268
|
+
declare const gumbel: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2269
|
+
/**
|
|
2270
|
+
* @function
|
|
2271
|
+
* Sample from a Laplace distribution with location 0 and scale 1.
|
|
2272
|
+
*
|
|
2273
|
+
* Uses inverse transform sampling: the CDF is `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`.
|
|
2274
|
+
* Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
|
|
2275
|
+
*/
|
|
2276
|
+
declare const laplace: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1981
2277
|
/**
|
|
1982
2278
|
* @function
|
|
1983
2279
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
@@ -1987,7 +2283,6 @@ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | und
|
|
|
1987
2283
|
* bitwise identical to JAX.
|
|
1988
2284
|
*/
|
|
1989
2285
|
declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1990
|
-
//# sourceMappingURL=random.d.ts.map
|
|
1991
2286
|
declare namespace scipy_special_d_exports {
|
|
1992
2287
|
export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
|
|
1993
2288
|
}
|
|
@@ -2018,8 +2313,7 @@ declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTr
|
|
|
2018
2313
|
* Construct a Jaxpr by dynamically tracing a function with example inputs.
|
|
2019
2314
|
*/
|
|
2020
2315
|
declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
|
|
2021
|
-
jaxpr:
|
|
2022
|
-
consts: Array[];
|
|
2316
|
+
jaxpr: ClosedJaxpr;
|
|
2023
2317
|
treedef: JsTreeDef;
|
|
2024
2318
|
};
|
|
2025
2319
|
/**
|
|
@@ -2067,11 +2361,6 @@ declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F)
|
|
|
2067
2361
|
* Compute the Jacobian evaluated row-by-row by reverse-mode AD.
|
|
2068
2362
|
*/
|
|
2069
2363
|
declare const jacrev: typeof jacfwd;
|
|
2070
|
-
/**
|
|
2071
|
-
* @function
|
|
2072
|
-
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
2073
|
-
*/
|
|
2074
|
-
declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
2075
2364
|
/**
|
|
2076
2365
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
2077
2366
|
*
|
|
@@ -2093,8 +2382,5 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
|
2093
2382
|
* default device.
|
|
2094
2383
|
*/
|
|
2095
2384
|
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
2096
|
-
//# sourceMappingURL=index.d.ts.map
|
|
2097
|
-
|
|
2098
2385
|
//#endregion
|
|
2099
|
-
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
2100
|
-
//# sourceMappingURL=index.d.cts.map
|
|
2386
|
+
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|