@jax-js/jax 0.0.3 → 0.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +50 -19
- package/dist/{backend-BqDtPGaR.js → backend-EBRGmEYw.js} +296 -153
- package/dist/{backend-D2C4MJRP.cjs → backend-Ss1Mev_-.cjs} +315 -154
- package/dist/index.cjs +681 -157
- package/dist/index.d.cts +422 -76
- package/dist/index.d.ts +422 -76
- package/dist/index.js +677 -157
- package/dist/{webgpu-fqhx41TC.cjs → webgpu-BVdMaO9T.cjs} +9 -3
- package/dist/{webgpu-CNg9JGva.js → webgpu-ow0Pn_6q.js} +9 -3
- package/package.json +15 -4
package/dist/index.d.ts
CHANGED
|
@@ -123,6 +123,19 @@ declare class ShapeTracker {
|
|
|
123
123
|
}
|
|
124
124
|
//#endregion
|
|
125
125
|
//#region src/utils.d.ts
|
|
126
|
+
/**
|
|
127
|
+
* Set the debug level for verbose logging.
|
|
128
|
+
*
|
|
129
|
+
* 1. JIT compile logs
|
|
130
|
+
* 2. Shader code
|
|
131
|
+
* 3. Expressions and metadata
|
|
132
|
+
* 4. JIT programs, tuning details
|
|
133
|
+
* 5. Most verbose operation traces
|
|
134
|
+
*
|
|
135
|
+
* This is an experimental API and may change in behavior. Do not rely on this
|
|
136
|
+
* in production.
|
|
137
|
+
*/
|
|
138
|
+
declare function setDebug(level: number): void;
|
|
126
139
|
/** @inline */
|
|
127
140
|
type RecursiveArray<T> = T | RecursiveArray<T>[];
|
|
128
141
|
interface FpHashable {
|
|
@@ -138,7 +151,7 @@ interface FpHashable {
|
|
|
138
151
|
declare class FpHash {
|
|
139
152
|
#private;
|
|
140
153
|
value: bigint;
|
|
141
|
-
update(
|
|
154
|
+
update(x: string | boolean | number | bigint | null | undefined | FpHashable): this;
|
|
142
155
|
static hash(...values: (string | boolean | number | bigint | null | undefined | FpHashable)[]): bigint;
|
|
143
156
|
}
|
|
144
157
|
/** Run a function while caching it inline inside a `Map`. */
|
|
@@ -154,6 +167,31 @@ declare enum DType {
|
|
|
154
167
|
}
|
|
155
168
|
/** @inline */
|
|
156
169
|
type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
|
|
170
|
+
/**
|
|
171
|
+
* Promote two dtypes to their join according to the type lattice.
|
|
172
|
+
*
|
|
173
|
+
* When performing operations between arrays of different types, we need to
|
|
174
|
+
* promote both operands to a common type that can represent values from both
|
|
175
|
+
* input types. This follows JAX's type promotion rules.
|
|
176
|
+
*
|
|
177
|
+
* **Type lattice:**
|
|
178
|
+
* ```text
|
|
179
|
+
* bool -> uint32 -> int32 -> float16 -> float32
|
|
180
|
+
* weak f* --^
|
|
181
|
+
* ```
|
|
182
|
+
*
|
|
183
|
+
* The asterisk f* is a weak type used for JS number constants. When creating
|
|
184
|
+
arrays, JS numbers default to float32 but are "weak", so they cast to the dtype of
|
|
185
|
+
* any array they are first combined with.
|
|
186
|
+
*
|
|
187
|
+
* **Examples:**
|
|
188
|
+
* - `promoteTypes(bool, int32) → int32`
|
|
189
|
+
* - `promoteTypes(uint32, int32) → int32`
|
|
190
|
+
* - `promoteTypes(int32, float16) → float16`
|
|
191
|
+
* - `promoteTypes(float16, float32) → float32`
|
|
192
|
+
* - `promoteTypes(uint32, float32) → float32`
|
|
193
|
+
*/
|
|
194
|
+
declare function promoteTypes(dtype1: DType, dtype2: DType): DType;
|
|
157
195
|
/**
|
|
158
196
|
* Mathematical expression on scalar values.
|
|
159
197
|
*
|
|
@@ -177,6 +215,8 @@ declare class AluExp implements FpHashable {
|
|
|
177
215
|
static max(a: AluExp, b: AluExp): AluExp;
|
|
178
216
|
static sin(a: AluExp): AluExp;
|
|
179
217
|
static cos(a: AluExp): AluExp;
|
|
218
|
+
static asin(a: AluExp): AluExp;
|
|
219
|
+
static atan(a: AluExp): AluExp;
|
|
180
220
|
static exp(a: AluExp): AluExp;
|
|
181
221
|
static log(a: AluExp): AluExp;
|
|
182
222
|
static sqrt(a: AluExp): AluExp;
|
|
@@ -262,6 +302,8 @@ declare enum AluOp {
|
|
|
262
302
|
Max = "Max",
|
|
263
303
|
Sin = "Sin",
|
|
264
304
|
Cos = "Cos",
|
|
305
|
+
Asin = "Asin",
|
|
306
|
+
Atan = "Atan",
|
|
265
307
|
Exp = "Exp",
|
|
266
308
|
Log = "Log",
|
|
267
309
|
Sqrt = "Sqrt",
|
|
@@ -354,8 +396,8 @@ declare class Reduction implements FpHashable {
|
|
|
354
396
|
//#region src/backend.d.ts
|
|
355
397
|
type Device = "cpu" | "wasm" | "webgpu";
|
|
356
398
|
declare const devices: Device[];
|
|
357
|
-
/**
|
|
358
|
-
declare function
|
|
399
|
+
/** Configure the default device for arrays. */
|
|
400
|
+
declare function defaultDevice(device?: Device): Device;
|
|
359
401
|
/**
|
|
360
402
|
* Initialize `jax-js` library backends.
|
|
361
403
|
*
|
|
@@ -491,6 +533,8 @@ declare enum Primitive {
|
|
|
491
533
|
RandomBits = "random_bits",
|
|
492
534
|
Sin = "sin",
|
|
493
535
|
Cos = "cos",
|
|
536
|
+
Asin = "asin",
|
|
537
|
+
Atan = "atan",
|
|
494
538
|
Exp = "exp",
|
|
495
539
|
Log = "log",
|
|
496
540
|
Sqrt = "sqrt",
|
|
@@ -581,8 +625,10 @@ declare enum CompareOp {
|
|
|
581
625
|
LessEqual = "less_equal",
|
|
582
626
|
}
|
|
583
627
|
/** @inline */
|
|
628
|
+
type Axis = number | number[] | null;
|
|
629
|
+
/** @inline */
|
|
584
630
|
type ReduceOpts = {
|
|
585
|
-
|
|
631
|
+
keepdims?: boolean;
|
|
586
632
|
};
|
|
587
633
|
type MainTrace = {
|
|
588
634
|
level: number;
|
|
@@ -659,9 +705,13 @@ declare abstract class Tracer {
|
|
|
659
705
|
* ```
|
|
660
706
|
*/
|
|
661
707
|
abstract dispose(): void;
|
|
708
|
+
/** The shape of the array. */
|
|
662
709
|
get shape(): number[];
|
|
710
|
+
/** The total number of elements in the array. */
|
|
663
711
|
get size(): number;
|
|
712
|
+
/** The dtype of the array. */
|
|
664
713
|
get dtype(): DType;
|
|
714
|
+
/** The number of dimensions of the array. */
|
|
665
715
|
get ndim(): number;
|
|
666
716
|
/** @ignore */
|
|
667
717
|
fullLower(): Tracer;
|
|
@@ -675,11 +725,11 @@ declare abstract class Tracer {
|
|
|
675
725
|
greaterEqual(other: this | TracerValue): this;
|
|
676
726
|
lessEqual(other: this | TracerValue): this;
|
|
677
727
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
678
|
-
sum(axis?:
|
|
728
|
+
sum(axis?: Axis, opts?: ReduceOpts): this;
|
|
679
729
|
/** Product of the array elements over a given axis. */
|
|
680
|
-
prod(axis?:
|
|
730
|
+
prod(axis?: Axis, opts?: ReduceOpts): this;
|
|
681
731
|
/** Compute the average of the array elements along the specified axis. */
|
|
682
|
-
mean(axis?:
|
|
732
|
+
mean(axis?: Axis, opts?: ReduceOpts): this;
|
|
683
733
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
684
734
|
transpose(perm?: number[]): this;
|
|
685
735
|
/**
|
|
@@ -807,7 +857,11 @@ declare class Array extends Tracer {
|
|
|
807
857
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
808
858
|
* will be freed when the array is disposed.
|
|
809
859
|
*/
|
|
810
|
-
constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend,
|
|
860
|
+
constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend, {
|
|
861
|
+
pending
|
|
862
|
+
}?: {
|
|
863
|
+
pending?: Iterable<PendingExecute> | null;
|
|
864
|
+
});
|
|
811
865
|
/** @ignore */
|
|
812
866
|
get aval(): ShapedArray;
|
|
813
867
|
/** Return a simple string representation of the array's dimensions. */
|
|
@@ -836,8 +890,11 @@ declare class Array extends Tracer {
|
|
|
836
890
|
*
|
|
837
891
|
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
838
892
|
* dispatch of operations as well.
|
|
893
|
+
*
|
|
894
|
+
* **Note:** `jax.blockUntilReady()` is a higher-level API; it calls this
|
|
895
|
+
* asynchronously for multiple arrays.
|
|
839
896
|
*/
|
|
840
|
-
|
|
897
|
+
blockUntilReady(): Promise<Array>;
|
|
841
898
|
/**
|
|
842
899
|
* Realize the array and return it as data. This is a sync variant and not
|
|
843
900
|
* recommended for performance reasons, as it will block rendering.
|
|
@@ -867,10 +924,7 @@ declare class Array extends Tracer {
|
|
|
867
924
|
_realizeSource(): number;
|
|
868
925
|
}
|
|
869
926
|
/** Construct an array from a single scalar constant. */
|
|
870
|
-
|
|
871
|
-
dtype,
|
|
872
|
-
device
|
|
873
|
-
}?: DTypeAndDevice): Array;
|
|
927
|
+
|
|
874
928
|
/** Constructor for creating a new array from data. */
|
|
875
929
|
declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
876
930
|
shape,
|
|
@@ -908,7 +962,7 @@ declare function eye(numRows: number, numCols?: number, {
|
|
|
908
962
|
dtype,
|
|
909
963
|
device
|
|
910
964
|
}?: DTypeAndDevice): Array;
|
|
911
|
-
/** Return the identity
|
|
965
|
+
/** Return the identity matrix, with ones on the main diagonal. */
|
|
912
966
|
declare function identity$1(n: number, {
|
|
913
967
|
dtype,
|
|
914
968
|
device
|
|
@@ -945,7 +999,7 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
945
999
|
device
|
|
946
1000
|
}?: DTypeAndDevice): Array;
|
|
947
1001
|
declare namespace numpy_d_exports {
|
|
948
|
-
export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack, concatenate, cos, cosh, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, identity$1 as identity, inf, int32, less, lessEqual, linspace, log, log10, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, pad, permuteDims, pi, prod, ravel, reciprocal,
|
|
1002
|
+
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
949
1003
|
}
|
|
950
1004
|
declare const float32 = DType.Float32;
|
|
951
1005
|
declare const int32 = DType.Int32;
|
|
@@ -962,54 +1016,66 @@ declare const inf: number;
|
|
|
962
1016
|
declare const nan: number;
|
|
963
1017
|
/** This is Pi, `π = 3.14159265358979...` */
|
|
964
1018
|
declare const pi: number;
|
|
965
|
-
/** Element-wise addition, with broadcasting. */
|
|
1019
|
+
/** @function Element-wise addition, with broadcasting. */
|
|
966
1020
|
declare const add: (x: ArrayLike, y: ArrayLike) => Array;
|
|
967
|
-
/** Element-wise multiplication, with broadcasting. */
|
|
1021
|
+
/** @function Element-wise multiplication, with broadcasting. */
|
|
968
1022
|
declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
|
|
969
|
-
/** Numerical negative of every element of an array. */
|
|
1023
|
+
/** @function Numerical negative of every element of an array. */
|
|
970
1024
|
declare const negative: (x: ArrayLike) => Array;
|
|
971
|
-
/** Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
1025
|
+
/** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
972
1026
|
declare const reciprocal: (x: ArrayLike) => Array;
|
|
973
|
-
/** Element-wise sine function (takes radians). */
|
|
1027
|
+
/** @function Element-wise sine function (takes radians). */
|
|
974
1028
|
declare const sin: (x: ArrayLike) => Array;
|
|
975
|
-
/** Element-wise cosine function (takes radians). */
|
|
1029
|
+
/** @function Element-wise cosine function (takes radians). */
|
|
976
1030
|
declare const cos: (x: ArrayLike) => Array;
|
|
977
|
-
/**
|
|
1031
|
+
/** @function Element-wise inverse sine function (inverse of sin). */
|
|
1032
|
+
declare const asin: (x: ArrayLike) => Array;
|
|
1033
|
+
/** @function Element-wise inverse tangent function (inverse of tan). */
|
|
1034
|
+
declare const atan: (x: ArrayLike) => Array;
|
|
1035
|
+
/** @function Calculate the exponential of all elements in the input array. */
|
|
978
1036
|
declare const exp: (x: ArrayLike) => Array;
|
|
979
|
-
/** Calculate the natural logarithm of all elements in the input array. */
|
|
1037
|
+
/** @function Calculate the natural logarithm of all elements in the input array. */
|
|
980
1038
|
declare const log: (x: ArrayLike) => Array;
|
|
981
|
-
/** Calculate the square root of all elements in the input array. */
|
|
1039
|
+
/** @function Calculate the square root of all elements in the input array. */
|
|
982
1040
|
declare const sqrt: (x: ArrayLike) => Array;
|
|
983
|
-
/** Return element-wise minimum of the input arrays. */
|
|
1041
|
+
/** @function Return element-wise minimum of the input arrays. */
|
|
984
1042
|
declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
985
|
-
/** Return element-wise maximum of the input arrays. */
|
|
1043
|
+
/** @function Return element-wise maximum of the input arrays. */
|
|
986
1044
|
declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
987
|
-
/** Compare two arrays element-wise. */
|
|
1045
|
+
/** @function Compare two arrays element-wise. */
|
|
988
1046
|
declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
|
|
989
|
-
/** Compare two arrays element-wise. */
|
|
1047
|
+
/** @function Compare two arrays element-wise. */
|
|
990
1048
|
declare const less: (x: ArrayLike, y: ArrayLike) => Array;
|
|
991
|
-
/** Compare two arrays element-wise. */
|
|
1049
|
+
/** @function Compare two arrays element-wise. */
|
|
992
1050
|
declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
|
|
993
|
-
/** Compare two arrays element-wise. */
|
|
1051
|
+
/** @function Compare two arrays element-wise. */
|
|
994
1052
|
declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
995
|
-
/** Compare two arrays element-wise. */
|
|
1053
|
+
/** @function Compare two arrays element-wise. */
|
|
996
1054
|
declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
997
|
-
/** Compare two arrays element-wise. */
|
|
1055
|
+
/** @function Compare two arrays element-wise. */
|
|
998
1056
|
declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
999
|
-
/** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1057
|
+
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1000
1058
|
declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
|
|
1001
|
-
/**
|
|
1059
|
+
/**
|
|
1060
|
+
* @function
|
|
1061
|
+
* Permute the dimensions of an array. Defaults to reversing the axis order.
|
|
1062
|
+
*/
|
|
1002
1063
|
declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
|
|
1003
1064
|
/**
|
|
1065
|
+
* @function
|
|
1004
1066
|
* Give a new shape to an array without changing its data.
|
|
1005
1067
|
*
|
|
1006
1068
|
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
1007
1069
|
* length of the array and remaining dimensions.
|
|
1008
1070
|
*/
|
|
1009
1071
|
declare const reshape: (x: ArrayLike, shape: number[]) => Array;
|
|
1010
|
-
/**
|
|
1072
|
+
/**
|
|
1073
|
+
* @function
|
|
1074
|
+
* Move axes of an array to new positions. Other axes retain original order.
|
|
1075
|
+
*/
|
|
1011
1076
|
declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
1012
1077
|
/**
|
|
1078
|
+
* @function
|
|
1013
1079
|
* Add padding (zeros) to an array.
|
|
1014
1080
|
*
|
|
1015
1081
|
* The `width` argument is either an integer or pair of integers, in which case
|
|
@@ -1017,15 +1083,27 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
|
1017
1083
|
* pair specifies the padding for its corresponding axis.
|
|
1018
1084
|
*/
|
|
1019
1085
|
declare const pad: (x: ArrayLike, width: number | [number, number] | [number, number][]) => Array;
|
|
1020
|
-
/**
|
|
1086
|
+
/**
|
|
1087
|
+
* @function
|
|
1088
|
+
* Return the number of dimensions of an array. Does not consume array reference.
|
|
1089
|
+
*/
|
|
1021
1090
|
declare const ndim: (x: ArrayLike) => number;
|
|
1022
|
-
/** Return the shape of an array. Does not consume array reference. */
|
|
1091
|
+
/** @function Return the shape of an array. Does not consume array reference. */
|
|
1023
1092
|
declare const shape$1: (x: ArrayLike) => number[];
|
|
1024
|
-
/**
|
|
1093
|
+
/**
|
|
1094
|
+
* @function
|
|
1095
|
+
* Return an array of zeros with the same shape and type as a given array.
|
|
1096
|
+
*/
|
|
1025
1097
|
declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1026
|
-
/**
|
|
1098
|
+
/**
|
|
1099
|
+
* @function
|
|
1100
|
+
* Return an array of ones with the same shape and type as a given array.
|
|
1101
|
+
*/
|
|
1027
1102
|
declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1028
|
-
/**
|
|
1103
|
+
/**
|
|
1104
|
+
* @function
|
|
1105
|
+
* Return a full array with the same shape and type as a given array.
|
|
1106
|
+
*/
|
|
1029
1107
|
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
|
|
1030
1108
|
/**
|
|
1031
1109
|
* Return the number of elements in an array, optionally along an axis.
|
|
@@ -1035,15 +1113,15 @@ declare function size(a: ArrayLike, axis?: number): number;
|
|
|
1035
1113
|
/** Convert an array to a specified dtype. */
|
|
1036
1114
|
declare function astype(a: ArrayLike, dtype: DType): Array;
|
|
1037
1115
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
1038
|
-
declare function sum(a: ArrayLike, axis?:
|
|
1116
|
+
declare function sum(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1039
1117
|
/** Product of the array elements over a given axis. */
|
|
1040
|
-
declare function prod(a: ArrayLike, axis?:
|
|
1118
|
+
declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1041
1119
|
/** Return the minimum of array elements along a given axis. */
|
|
1042
|
-
declare function min(a: ArrayLike, axis?:
|
|
1120
|
+
declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1043
1121
|
/** Return the maximum of array elements along a given axis. */
|
|
1044
|
-
declare function max(a: ArrayLike, axis?:
|
|
1122
|
+
declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1045
1123
|
/** Compute the average of the array elements along the specified axis. */
|
|
1046
|
-
declare function mean(a: ArrayLike, axis?:
|
|
1124
|
+
declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1047
1125
|
/**
|
|
1048
1126
|
* Returns the indices of the minimum values along an axis.
|
|
1049
1127
|
*
|
|
@@ -1059,7 +1137,7 @@ declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
|
1059
1137
|
*/
|
|
1060
1138
|
declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
1061
1139
|
/** Reverse the elements in an array along the given axes. */
|
|
1062
|
-
declare function flip(x: ArrayLike, axis?:
|
|
1140
|
+
declare function flip(x: ArrayLike, axis?: Axis): Array;
|
|
1063
1141
|
/**
|
|
1064
1142
|
* Join a sequence of arrays along an existing axis.
|
|
1065
1143
|
*
|
|
@@ -1103,9 +1181,36 @@ declare function columnStack(xs: ArrayLike[]): Array;
|
|
|
1103
1181
|
declare function flipud(x: ArrayLike): Array;
|
|
1104
1182
|
/** Flip an array horizontally (axis=1). */
|
|
1105
1183
|
declare function fliplr(x: ArrayLike): Array;
|
|
1184
|
+
/** @function Alternative name for `numpy.transpose()`. */
|
|
1106
1185
|
declare const permuteDims: (x: ArrayLike, perm?: number[]) => Array;
|
|
1107
1186
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
1108
1187
|
declare function ravel(a: ArrayLike): Array;
|
|
1188
|
+
/**
|
|
1189
|
+
* Repeat each element of an array `repeats` times, placing the copies immediately after the element itself.
|
|
1190
|
+
*
|
|
1191
|
+
* If no axis is provided, use the flattened input array, and return a flat
|
|
1192
|
+
* output array.
|
|
1193
|
+
*/
|
|
1194
|
+
declare function repeat(a: ArrayLike, repeats: number, axis?: number): Array;
|
|
1195
|
+
/**
|
|
1196
|
+
* Construct an array by repeating `A` the number of times given by `reps`.
|
|
1197
|
+
*
|
|
1198
|
+
* If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
|
|
1199
|
+
* integers, the resulting array will have a shape of `(reps[0] * d1,
|
|
1200
|
+
* reps[1] * d2, ..., reps[n-1] * dn)`, with `A` tiled along each dimension.
|
|
1201
|
+
*/
|
|
1202
|
+
declare function tile(a: ArrayLike, reps: number | number[]): Array;
|
|
1203
|
+
/**
|
|
1204
|
+
* Broadcast an array to a shape, with NumPy-style broadcasting rules.
|
|
1205
|
+
*
|
|
1206
|
+
* In other words, this lets you prepend new axes on the left, and/or expand
|
|
1207
|
+
* dimensions where the shape is 1.
|
|
1208
|
+
*/
|
|
1209
|
+
declare function broadcastTo(a: ArrayLike, shape: number[]): Array;
|
|
1210
|
+
/** Broadcast input shapes to a common output shape. */
|
|
1211
|
+
declare function broadcastShapes(...shapes: number[][]): number[];
|
|
1212
|
+
/** Broadcast arrays to a common shape. */
|
|
1213
|
+
declare function broadcastArrays(...arrays: ArrayLike[]): Array[];
|
|
1109
1214
|
/**
|
|
1110
1215
|
* Return specified diagonals.
|
|
1111
1216
|
*
|
|
@@ -1133,8 +1238,28 @@ declare function allclose(actual: Parameters<typeof array>[0], expected: Paramet
|
|
|
1133
1238
|
declare function matmul(x: ArrayLike, y: ArrayLike): Array;
|
|
1134
1239
|
/** Dot product of two arrays. */
|
|
1135
1240
|
declare function dot(x: ArrayLike, y: ArrayLike): Array;
|
|
1136
|
-
/**
|
|
1137
|
-
|
|
1241
|
+
/**
|
|
1242
|
+
* Compute the inner product of two arrays.
|
|
1243
|
+
*
|
|
1244
|
+
* Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
|
|
1245
|
+
* contraction over the last axis of both `x` and `y`.
|
|
1246
|
+
*
|
|
1247
|
+
* Returned array has shape `[...x.shape.slice(0, -1), ...y.shape.slice(0, -1)]`.
|
|
1248
|
+
*/
|
|
1249
|
+
declare function inner(x: ArrayLike, y: ArrayLike): Array;
|
|
1250
|
+
/**
|
|
1251
|
+
* Compute the outer product of two arrays.
|
|
1252
|
+
*
|
|
1253
|
+
* If the input arrays are not 1D, they will be flattened. Returned array will
|
|
1254
|
+
* be of shape `[x.size, y.size]`.
|
|
1255
|
+
*/
|
|
1256
|
+
declare function outer(x: ArrayLike, y: ArrayLike): Array;
|
|
1257
|
+
/** Vector dot product of two arrays along a given axis. */
|
|
1258
|
+
declare function vecdot(x: ArrayLike, y: ArrayLike, {
|
|
1259
|
+
axis
|
|
1260
|
+
}?: {
|
|
1261
|
+
axis?: number;
|
|
1262
|
+
}): Array;
|
|
1138
1263
|
/**
|
|
1139
1264
|
* Return the dot product of two vectors.
|
|
1140
1265
|
*
|
|
@@ -1152,6 +1277,21 @@ declare function meshgrid(xs: Array[], {
|
|
|
1152
1277
|
}?: {
|
|
1153
1278
|
indexing?: "xy" | "ij";
|
|
1154
1279
|
}): Array[];
|
|
1280
|
+
/**
|
|
1281
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
1282
|
+
*
|
|
1283
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
1284
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
1285
|
+
* `k>0` is above it.
|
|
1286
|
+
*/
|
|
1287
|
+
declare function tri(n: number, m?: number, k?: number, {
|
|
1288
|
+
dtype,
|
|
1289
|
+
device
|
|
1290
|
+
}?: DTypeAndDevice): Array;
|
|
1291
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
1292
|
+
declare function tril(a: ArrayLike, k?: number): Array;
|
|
1293
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
1294
|
+
declare function triu(a: ArrayLike, k?: number): Array;
|
|
1155
1295
|
/**
|
|
1156
1296
|
* Clip (limit) the values in an array.
|
|
1157
1297
|
*
|
|
@@ -1168,15 +1308,50 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
|
|
|
1168
1308
|
* This is the same function as `jax.numpy.abs()`.
|
|
1169
1309
|
*/
|
|
1170
1310
|
declare function absolute(x: ArrayLike): Array;
|
|
1171
|
-
/** Alias of `jax.numpy.absolute()`. */
|
|
1311
|
+
/** @function Alias of `jax.numpy.absolute()`. */
|
|
1172
1312
|
declare const abs: typeof absolute;
|
|
1313
|
+
/** Return an element-wise indication of sign of the input. */
|
|
1314
|
+
declare function sign(x: ArrayLike): Array;
|
|
1173
1315
|
/** Calculate element-wise square of the input array. */
|
|
1174
1316
|
declare function square(x: ArrayLike): Array;
|
|
1175
|
-
/**
|
|
1317
|
+
/** Element-wise tangent function (takes radians). */
|
|
1176
1318
|
declare function tan(x: ArrayLike): Array;
|
|
1319
|
+
/** Element-wise inverse cosine function (inverse of cos). */
|
|
1320
|
+
declare function acos(x: ArrayLike): Array;
|
|
1321
|
+
/**
|
|
1322
|
+
* @function
|
|
1323
|
+
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
1324
|
+
*
|
|
1325
|
+
* In the original NumPy/JAX implementation, this function is more numerically
|
|
1326
|
+
* stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
|
|
1327
|
+
* improvements.
|
|
1328
|
+
*/
|
|
1329
|
+
declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1330
|
+
/**
|
|
1331
|
+
* @function
|
|
1332
|
+
* Element-wise arc tangent of y/x with correct quadrant.
|
|
1333
|
+
*
|
|
1334
|
+
* Returns the angle in radians between the positive x-axis and the point (x, y).
|
|
1335
|
+
* The result is in the range [-π, π].
|
|
1336
|
+
*
|
|
1337
|
+
* Uses numerically stable formulas:
|
|
1338
|
+
* - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
|
|
1339
|
+
* - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
|
|
1340
|
+
*
|
|
1341
|
+
* The output is ill-defined when both x and y are zero.
|
|
1342
|
+
*/
|
|
1343
|
+
declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
1344
|
+
/** @function Alias of `jax.numpy.acos()`. */
|
|
1345
|
+
declare const arccos: typeof acos;
|
|
1346
|
+
/** @function Alias of `jax.numpy.atan()`. */
|
|
1347
|
+
declare const arctan: (x: ArrayLike) => Array;
|
|
1348
|
+
/** @function Alias of `jax.numpy.atan2()`. */
|
|
1349
|
+
declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
1350
|
+
/** Element-wise subtraction, with broadcasting. */
|
|
1351
|
+
declare function subtract(x: ArrayLike, y: ArrayLike): Array;
|
|
1177
1352
|
/** Calculates the floating-point division of x by y element-wise. */
|
|
1178
1353
|
declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
|
|
1179
|
-
/** Alias of `jax.numpy.trueDivide()`. */
|
|
1354
|
+
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
1180
1355
|
declare const divide: typeof trueDivide;
|
|
1181
1356
|
/** Round input to the nearest integer towards zero. */
|
|
1182
1357
|
declare function trunc(x: ArrayLike): Array;
|
|
@@ -1186,26 +1361,112 @@ declare function exp2(p: ArrayLike): Array;
|
|
|
1186
1361
|
declare function log2(x: ArrayLike): Array;
|
|
1187
1362
|
/** Return the base-10 logarithm of x, element-wise. */
|
|
1188
1363
|
declare function log10(x: ArrayLike): Array;
|
|
1364
|
+
/** Calculate `exp(x) - 1` element-wise. */
|
|
1365
|
+
declare function expm1(x: ArrayLike): Array;
|
|
1366
|
+
/** Calculate the natural logarithm of `1 + x` element-wise. */
|
|
1367
|
+
declare function log1p(x: ArrayLike): Array;
|
|
1368
|
+
/** Convert angles from degrees to radians. */
|
|
1369
|
+
declare function deg2rad(x: ArrayLike): Array;
|
|
1370
|
+
/** @function Alias of `jax.numpy.deg2rad()`. */
|
|
1371
|
+
declare const radians: typeof deg2rad;
|
|
1372
|
+
/** Convert angles from radians to degrees. */
|
|
1373
|
+
declare function rad2deg(x: ArrayLike): Array;
|
|
1374
|
+
/** @function Alias of `jax.numpy.rad2deg()`. */
|
|
1375
|
+
declare const degrees: typeof rad2deg;
|
|
1189
1376
|
/**
|
|
1377
|
+
* @function
|
|
1378
|
+
* Computes first array raised to power of second array, element-wise.
|
|
1379
|
+
*/
|
|
1380
|
+
declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1381
|
+
/** @function Alias of `jax.numpy.power()`. */
|
|
1382
|
+
declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1383
|
+
/** @function Calculate the element-wise cube root of the input array. */
|
|
1384
|
+
declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1385
|
+
/**
|
|
1386
|
+
* @function
|
|
1190
1387
|
* Calculate element-wise hyperbolic sine of input.
|
|
1191
1388
|
*
|
|
1192
1389
|
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
1193
1390
|
*/
|
|
1194
|
-
declare
|
|
1391
|
+
declare const sinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1195
1392
|
/**
|
|
1393
|
+
* @function
|
|
1196
1394
|
* Calculate element-wise hyperbolic cosine of input.
|
|
1197
1395
|
*
|
|
1198
1396
|
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
1199
1397
|
*/
|
|
1200
|
-
declare
|
|
1398
|
+
declare const cosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1201
1399
|
/**
|
|
1400
|
+
* @function
|
|
1202
1401
|
* Calculate element-wise hyperbolic tangent of input.
|
|
1203
1402
|
*
|
|
1204
1403
|
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
1205
1404
|
*/
|
|
1206
|
-
declare
|
|
1405
|
+
declare const tanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1406
|
+
/**
|
|
1407
|
+
* @function
|
|
1408
|
+
* Calculate element-wise inverse hyperbolic sine of input.
|
|
1409
|
+
*
|
|
1410
|
+
* `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
|
|
1411
|
+
*/
|
|
1412
|
+
declare const arcsinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1413
|
+
/**
|
|
1414
|
+
* @function
|
|
1415
|
+
* Calculate element-wise inverse hyperbolic cosine of input.
|
|
1416
|
+
*
|
|
1417
|
+
* `arccosh(x) = ln(x + sqrt(x^2 - 1))`
|
|
1418
|
+
*/
|
|
1419
|
+
declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1420
|
+
/**
|
|
1421
|
+
* @function
|
|
1422
|
+
* Calculate element-wise inverse hyperbolic tangent of input.
|
|
1423
|
+
*
|
|
1424
|
+
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
1425
|
+
*/
|
|
1426
|
+
declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1427
|
+
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
1428
|
+
declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1429
|
+
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
1430
|
+
declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1431
|
+
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
1432
|
+
declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1433
|
+
/**
|
|
1434
|
+
* Compute the variance of an array.
|
|
1435
|
+
*
|
|
1436
|
+
* The variance is computed for the flattened array by default, otherwise over
|
|
1437
|
+
* the specified axis.
|
|
1438
|
+
*
|
|
1439
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
1440
|
+
* where `N` is the number of elements (e.g., use `correction = 1` for Bessel's correction).
|
|
1441
|
+
*/
|
|
1442
|
+
declare function var_(x: ArrayLike, axis?: Axis, opts?: {
|
|
1443
|
+
mean?: ArrayLike;
|
|
1444
|
+
correction?: number;
|
|
1445
|
+
} & ReduceOpts): Array;
|
|
1446
|
+
/**
|
|
1447
|
+
* Compute the standard deviation of an array.
|
|
1448
|
+
*
|
|
1449
|
+
* The standard deviation is computed for the flattened array by default,
|
|
1450
|
+
* otherwise over the specified axis.
|
|
1451
|
+
*
|
|
1452
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
1453
|
+
* where `N` is the number of elements (e.g., use `correction = 1` for Bessel's correction).
|
|
1454
|
+
*/
|
|
1455
|
+
declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
1456
|
+
mean?: ArrayLike;
|
|
1457
|
+
correction?: number;
|
|
1458
|
+
} & ReduceOpts): Array;
|
|
1207
1459
|
//#endregion
|
|
1208
1460
|
//#region src/frontend/jaxpr.d.ts
|
|
1461
|
+
/**
|
|
1462
|
+
* Function callback with an associated dispose() method.
|
|
1463
|
+
*
|
|
1464
|
+
* The dispose() method should be called to clean up any tracer resources needed
|
|
1465
|
+
* by the function after the last time it is called.
|
|
1466
|
+
*/
|
|
1467
|
+
type OwnedFunction<F extends Function> = F & {
|
|
1468
|
+
dispose: () => void;
|
|
1469
|
+
};
|
|
1209
1470
|
/** Variable in a Jaxpr expression. */
|
|
1210
1471
|
declare class Var {
|
|
1211
1472
|
#private;
|
|
@@ -1297,7 +1558,7 @@ declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding:
|
|
|
1297
1558
|
/** Reduce a computation over padded windows. */
|
|
1298
1559
|
declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
|
|
1299
1560
|
declare namespace nn_d_exports {
|
|
1300
|
-
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
|
|
1561
|
+
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
|
|
1301
1562
|
}
|
|
1302
1563
|
/**
|
|
1303
1564
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -1329,6 +1590,7 @@ declare function softplus(x: ArrayLike): Array;
|
|
|
1329
1590
|
*/
|
|
1330
1591
|
declare function softSign(x: ArrayLike): Array;
|
|
1331
1592
|
/**
|
|
1593
|
+
* @function
|
|
1332
1594
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
1333
1595
|
* Swish, computed element-wise:
|
|
1334
1596
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -1337,8 +1599,9 @@ declare function softSign(x: ArrayLike): Array;
|
|
|
1337
1599
|
*
|
|
1338
1600
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1339
1601
|
*/
|
|
1340
|
-
declare const silu: (x: ArrayLike) => Array
|
|
1602
|
+
declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1341
1603
|
/**
|
|
1604
|
+
* @function
|
|
1342
1605
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
1343
1606
|
* Swish, computed element-wise:
|
|
1344
1607
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -1347,31 +1610,35 @@ declare const silu: (x: ArrayLike) => Array;
|
|
|
1347
1610
|
*
|
|
1348
1611
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1349
1612
|
*/
|
|
1350
|
-
declare const swish: (x: ArrayLike) => Array
|
|
1613
|
+
declare const swish: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1351
1614
|
/**
|
|
1352
1615
|
* Log-sigmoid activation function, computed element-wise:
|
|
1353
1616
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
1354
1617
|
*/
|
|
1355
1618
|
declare function logSigmoid(x: ArrayLike): Array;
|
|
1356
|
-
/**
|
|
1619
|
+
/**
|
|
1620
|
+
* @function
|
|
1621
|
+
* Identity activation function. Returns the argument unmodified.
|
|
1622
|
+
*/
|
|
1357
1623
|
declare const identity: (x: ArrayLike) => Array;
|
|
1358
1624
|
/** Leaky rectified linear (ReLU) activation function */
|
|
1359
|
-
declare function leakyRelu(x: ArrayLike, negativeSlope?:
|
|
1625
|
+
declare function leakyRelu(x: ArrayLike, negativeSlope?: ArrayLike): Array;
|
|
1360
1626
|
/**
|
|
1361
1627
|
* Exponential linear unit activation function.
|
|
1362
1628
|
*
|
|
1363
1629
|
* Computes the element-wise function:
|
|
1364
1630
|
* `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
|
|
1365
1631
|
*/
|
|
1366
|
-
declare function elu(x: ArrayLike, alpha?:
|
|
1632
|
+
declare function elu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
1367
1633
|
/**
|
|
1368
1634
|
* Continuously-differentiable exponential linear unit activation function.
|
|
1369
1635
|
*
|
|
1370
1636
|
* Computes the element-wise function:
|
|
1371
1637
|
* `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
|
|
1372
1638
|
*/
|
|
1373
|
-
declare function celu(x: ArrayLike, alpha?:
|
|
1639
|
+
declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
1374
1640
|
/**
|
|
1641
|
+
* @function
|
|
1375
1642
|
* Gaussion error linear unit (GELU) activation function.
|
|
1376
1643
|
*
|
|
1377
1644
|
* This is computed element-wise. Currently jax-js does not support the erf() or
|
|
@@ -1382,7 +1649,7 @@ declare function celu(x: ArrayLike, alpha?: number): Array;
|
|
|
1382
1649
|
*
|
|
1383
1650
|
* This will be improved in the future.
|
|
1384
1651
|
*/
|
|
1385
|
-
declare const gelu: (x: ArrayLike) => Array
|
|
1652
|
+
declare const gelu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1386
1653
|
/**
|
|
1387
1654
|
* Gated linear unit (GLU) activation function.
|
|
1388
1655
|
*
|
|
@@ -1390,6 +1657,13 @@ declare const gelu: (x: ArrayLike) => Array;
|
|
|
1390
1657
|
* computes `a * sigmoid(b)`.
|
|
1391
1658
|
*/
|
|
1392
1659
|
declare function glu(x: ArrayLike, axis?: number): Array;
|
|
1660
|
+
/**
|
|
1661
|
+
* Squareplus activation function.
|
|
1662
|
+
*
|
|
1663
|
+
* Computes the element-wise function:
|
|
1664
|
+
* `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
|
|
1665
|
+
*/
|
|
1666
|
+
declare function squareplus(x: ArrayLike, b?: ArrayLike): Array;
|
|
1393
1667
|
/**
|
|
1394
1668
|
* Mish activation function.
|
|
1395
1669
|
*
|
|
@@ -1405,7 +1679,7 @@ declare function mish(x: ArrayLike): Array;
|
|
|
1405
1679
|
*
|
|
1406
1680
|
* Reference: https://en.wikipedia.org/wiki/Softmax_function
|
|
1407
1681
|
*/
|
|
1408
|
-
declare function softmax(x: ArrayLike, axis?:
|
|
1682
|
+
declare function softmax(x: ArrayLike, axis?: Axis): Array;
|
|
1409
1683
|
/**
|
|
1410
1684
|
* Log-Softmax function.
|
|
1411
1685
|
*
|
|
@@ -1414,7 +1688,7 @@ declare function softmax(x: ArrayLike, axis?: number | number[]): Array;
|
|
|
1414
1688
|
*
|
|
1415
1689
|
* If `axis` is not specified, it defaults to the last axis.
|
|
1416
1690
|
*/
|
|
1417
|
-
declare function logSoftmax(x: ArrayLike, axis?:
|
|
1691
|
+
declare function logSoftmax(x: ArrayLike, axis?: Axis): Array;
|
|
1418
1692
|
/**
|
|
1419
1693
|
* Log-sum-exp reduction. Also a multivariate version of `softplus`.
|
|
1420
1694
|
*
|
|
@@ -1423,7 +1697,22 @@ declare function logSoftmax(x: ArrayLike, axis?: number | number[]): Array;
|
|
|
1423
1697
|
*
|
|
1424
1698
|
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
1425
1699
|
*/
|
|
1426
|
-
declare function logsumexp(x: ArrayLike, axis?:
|
|
1700
|
+
declare function logsumexp(x: ArrayLike, axis?: Axis): Array;
|
|
1701
|
+
/** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
|
|
1702
|
+
declare function logmeanexp(x: ArrayLike, axis?: Axis): Array;
|
|
1703
|
+
/**
|
|
1704
|
+
* Standardizes input to zero mean and unit variance.
|
|
1705
|
+
*
|
|
1706
|
+
* By default, this is computed over the last axis. You can pass in a different
|
|
1707
|
+
* axis, or `null` to standardize over all elements.
|
|
1708
|
+
*
|
|
1709
|
+
* Epsilon is added to denominator, it defaults to `1e-5` for stability.
|
|
1710
|
+
*/
|
|
1711
|
+
declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
1712
|
+
mean?: ArrayLike;
|
|
1713
|
+
variance?: ArrayLike;
|
|
1714
|
+
epsilon?: ArrayLike;
|
|
1715
|
+
}): Array;
|
|
1427
1716
|
/**
|
|
1428
1717
|
* One-hot encodes the given indices.
|
|
1429
1718
|
*
|
|
@@ -1442,7 +1731,7 @@ declare function logsumexp(x: ArrayLike, axis?: number | number[]): Array;
|
|
|
1442
1731
|
*/
|
|
1443
1732
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
1444
1733
|
declare namespace random_d_exports {
|
|
1445
|
-
export { bits, key, split, uniform };
|
|
1734
|
+
export { bernoulli, bits, exponential, key, normal, split, uniform };
|
|
1446
1735
|
}
|
|
1447
1736
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
1448
1737
|
declare function key(seed: number): Array;
|
|
@@ -1458,26 +1747,59 @@ declare function uniform(key: Array, shape?: number[], {
|
|
|
1458
1747
|
minval?: number;
|
|
1459
1748
|
maxval?: number;
|
|
1460
1749
|
}): Array;
|
|
1750
|
+
/**
|
|
1751
|
+
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
1752
|
+
*
|
|
1753
|
+
* Returns a random Boolean array with the specified shape. `p` can be an array
|
|
1754
|
+
* and must be broadcastable to `shape`.
|
|
1755
|
+
*/
|
|
1756
|
+
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
1757
|
+
/** Sample exponential random values according to `p(x) = exp(-x)`. */
|
|
1758
|
+
declare function exponential(key: Array, shape?: number[]): Array;
|
|
1759
|
+
/**
|
|
1760
|
+
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
1761
|
+
*
|
|
1762
|
+
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
1763
|
+
* directly inverts the CDF, but we don't have support for that yet. Outputs will not be
|
|
1764
|
+
* bitwise identical to JAX.
|
|
1765
|
+
*/
|
|
1766
|
+
declare function normal(key: Array, shape?: number[]): Array;
|
|
1461
1767
|
//#endregion
|
|
1462
1768
|
//#region src/index.d.ts
|
|
1463
|
-
/**
|
|
1769
|
+
/**
|
|
1770
|
+
* @function
|
|
1771
|
+
* Compute the forward-mode Jacobian-vector product for a function.
|
|
1772
|
+
*/
|
|
1464
1773
|
declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
|
|
1465
|
-
/**
|
|
1774
|
+
/**
|
|
1775
|
+
* @function
|
|
1776
|
+
* Vectorize an operation on a batched axis for one or more inputs.
|
|
1777
|
+
*/
|
|
1466
1778
|
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1467
|
-
/**
|
|
1779
|
+
/**
|
|
1780
|
+
* @function
|
|
1781
|
+
* Compute the Jacobian evaluated column-by-column by forward-mode AD.
|
|
1782
|
+
*/
|
|
1468
1783
|
declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1469
|
-
/**
|
|
1784
|
+
/**
|
|
1785
|
+
* @function
|
|
1786
|
+
* Construct a Jaxpr by dynamically tracing a function with example inputs.
|
|
1787
|
+
*/
|
|
1470
1788
|
declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
|
|
1471
1789
|
jaxpr: Jaxpr;
|
|
1472
1790
|
consts: Array[];
|
|
1473
1791
|
treedef: JsTreeDef;
|
|
1474
1792
|
};
|
|
1475
1793
|
/**
|
|
1794
|
+
* @function
|
|
1476
1795
|
* Mark a function for automatic JIT compilation, with operator fusion.
|
|
1477
1796
|
*
|
|
1478
1797
|
* The function will be compiled the first time it is called with a set of
|
|
1479
1798
|
* argument shapes.
|
|
1480
1799
|
*
|
|
1800
|
+
* You can call `.dispose()` on the returned, JIT-compiled function after all
|
|
1801
|
+
* calls to free memory associated with array constants.
|
|
1802
|
+
*
|
|
1481
1803
|
* **Options:**
|
|
1482
1804
|
* - `staticArgnums`: An array of argument indices to treat as static
|
|
1483
1805
|
* (compile-time constant). These arguments must be hashable, won't be traced,
|
|
@@ -1485,24 +1807,48 @@ declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) =>
|
|
|
1485
1807
|
* - `device`: The device to place the computation on. If not specified, the
|
|
1486
1808
|
* computation will be placed on the default device.
|
|
1487
1809
|
*/
|
|
1488
|
-
declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F
|
|
1810
|
+
declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => OwnedFunction<(...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>>;
|
|
1489
1811
|
/**
|
|
1812
|
+
* @function
|
|
1490
1813
|
* Produce a local linear approximation to a function at a point using jvp() and
|
|
1491
1814
|
* partial evaluation.
|
|
1492
1815
|
*/
|
|
1493
1816
|
declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
|
|
1494
|
-
/**
|
|
1817
|
+
/**
|
|
1818
|
+
* @function
|
|
1819
|
+
* Calculate the reverse-mode vector-Jacobian product for a function.
|
|
1820
|
+
*/
|
|
1495
1821
|
declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
|
|
1496
1822
|
/**
|
|
1823
|
+
* @function
|
|
1497
1824
|
* Compute the gradient of a scalar-valued function `f` with respect to its
|
|
1498
1825
|
* first argument.
|
|
1499
1826
|
*/
|
|
1500
1827
|
declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
|
|
1501
|
-
/**
|
|
1828
|
+
/**
|
|
1829
|
+
* @function
|
|
1830
|
+
* Create a function that evaluates both `f` and the gradient of `f`.
|
|
1831
|
+
*/
|
|
1502
1832
|
declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
|
|
1503
|
-
/**
|
|
1833
|
+
/**
|
|
1834
|
+
* @function
|
|
1835
|
+
* Compute the Jacobian evaluated row-by-row by reverse-mode AD.
|
|
1836
|
+
*/
|
|
1504
1837
|
declare const jacrev: typeof jacfwd;
|
|
1505
|
-
/**
|
|
1838
|
+
/**
|
|
1839
|
+
* @function
|
|
1840
|
+
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
1841
|
+
*/
|
|
1506
1842
|
declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1843
|
+
/**
|
|
1844
|
+
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
1845
|
+
*
|
|
1846
|
+
* This can be used to wait for the results of an intermediate computation to
|
|
1847
|
+
* finish. It's recommended to call this regularly in an iterative computation
|
|
1848
|
+
* to avoid queueing up too many pending operations.
|
|
1849
|
+
*
|
|
1850
|
+
* Does not consume reference to the arrays.
|
|
1851
|
+
*/
|
|
1852
|
+
declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
1507
1853
|
//#endregion
|
|
1508
|
-
export { DType, type Device, type JsTree, type JsTreeDef, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random,
|
|
1854
|
+
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|