@jax-js/jax 0.0.3 → 0.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +96 -22
- package/dist/{backend-BqDtPGaR.js → backend-CdcTZEOF.js} +325 -153
- package/dist/{backend-D2C4MJRP.cjs → backend-yEU0L_ig.cjs} +350 -154
- package/dist/index.cjs +977 -354
- package/dist/index.d.cts +479 -88
- package/dist/index.d.ts +479 -88
- package/dist/index.js +964 -345
- package/dist/{webgpu-CNg9JGva.js → webgpu-CM-xNYzW.js} +9 -3
- package/dist/{webgpu-fqhx41TC.cjs → webgpu-CNOpiO5T.cjs} +9 -3
- package/package.json +15 -4
package/dist/index.d.cts
CHANGED
|
@@ -126,6 +126,19 @@ declare class ShapeTracker {
|
|
|
126
126
|
}
|
|
127
127
|
//#endregion
|
|
128
128
|
//#region src/utils.d.ts
|
|
129
|
+
/**
|
|
130
|
+
* Set the debug level for verbose logging.
|
|
131
|
+
*
|
|
132
|
+
* 1. JIT compile logs
|
|
133
|
+
* 2. Shader code
|
|
134
|
+
* 3. Expressions and metadata
|
|
135
|
+
* 4. JIT programs, tuning details
|
|
136
|
+
* 5. Most verbose operation traces
|
|
137
|
+
*
|
|
138
|
+
* This is an experimental API and may change in behavior. Do not rely on this
|
|
139
|
+
* in production.
|
|
140
|
+
*/
|
|
141
|
+
declare function setDebug(level: number): void;
|
|
129
142
|
/** @inline */
|
|
130
143
|
type RecursiveArray<T> = T | RecursiveArray<T>[];
|
|
131
144
|
interface FpHashable {
|
|
@@ -141,7 +154,7 @@ interface FpHashable {
|
|
|
141
154
|
declare class FpHash {
|
|
142
155
|
#private;
|
|
143
156
|
value: bigint;
|
|
144
|
-
update(
|
|
157
|
+
update(x: string | boolean | number | bigint | null | undefined | FpHashable): this;
|
|
145
158
|
static hash(...values: (string | boolean | number | bigint | null | undefined | FpHashable)[]): bigint;
|
|
146
159
|
}
|
|
147
160
|
/** Run a function while caching it inline inside a `Map`. */
|
|
@@ -157,6 +170,31 @@ declare enum DType {
|
|
|
157
170
|
}
|
|
158
171
|
/** @inline */
|
|
159
172
|
type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
|
|
173
|
+
/**
|
|
174
|
+
* Promote two dtypes to their join according to the type lattice.
|
|
175
|
+
*
|
|
176
|
+
* When performing operations between arrays of different types, we need to
|
|
177
|
+
* promote both operands to a common type that can represent values from both
|
|
178
|
+
* input types. This follows JAX's type promotion rules.
|
|
179
|
+
*
|
|
180
|
+
* **Type lattice:**
|
|
181
|
+
* ```text
|
|
182
|
+
* bool -> uint32 -> int32 -> float16 -> float32
|
|
183
|
+
* weakType --^
|
|
184
|
+
* ```
|
|
185
|
+
*
|
|
186
|
+
* `weakType` represents weakly typed arrays. These are created for JS numbers,
|
|
187
|
+
* which default to float32 but "weak" so they cast to the dtype of any array
|
|
188
|
+
* they are first combined with, except `bool`.
|
|
189
|
+
*
|
|
190
|
+
* **Examples:**
|
|
191
|
+
* - `promoteTypes(bool, int32) → int32`
|
|
192
|
+
* - `promoteTypes(uint32, int32) → int32`
|
|
193
|
+
* - `promoteTypes(int32, float16) → float16`
|
|
194
|
+
* - `promoteTypes(float16, float32) → float32`
|
|
195
|
+
* - `promoteTypes(uint32, float32) → float32`
|
|
196
|
+
*/
|
|
197
|
+
declare function promoteTypes(dtype1: DType, dtype2: DType): DType;
|
|
160
198
|
/**
|
|
161
199
|
* Mathematical expression on scalar values.
|
|
162
200
|
*
|
|
@@ -180,6 +218,8 @@ declare class AluExp implements FpHashable {
|
|
|
180
218
|
static max(a: AluExp, b: AluExp): AluExp;
|
|
181
219
|
static sin(a: AluExp): AluExp;
|
|
182
220
|
static cos(a: AluExp): AluExp;
|
|
221
|
+
static asin(a: AluExp): AluExp;
|
|
222
|
+
static atan(a: AluExp): AluExp;
|
|
183
223
|
static exp(a: AluExp): AluExp;
|
|
184
224
|
static log(a: AluExp): AluExp;
|
|
185
225
|
static sqrt(a: AluExp): AluExp;
|
|
@@ -265,6 +305,8 @@ declare enum AluOp {
|
|
|
265
305
|
Max = "Max",
|
|
266
306
|
Sin = "Sin",
|
|
267
307
|
Cos = "Cos",
|
|
308
|
+
Asin = "Asin",
|
|
309
|
+
Atan = "Atan",
|
|
268
310
|
Exp = "Exp",
|
|
269
311
|
Log = "Log",
|
|
270
312
|
Sqrt = "Sqrt",
|
|
@@ -357,8 +399,8 @@ declare class Reduction implements FpHashable {
|
|
|
357
399
|
//#region src/backend.d.ts
|
|
358
400
|
type Device = "cpu" | "wasm" | "webgpu";
|
|
359
401
|
declare const devices: Device[];
|
|
360
|
-
/**
|
|
361
|
-
declare function
|
|
402
|
+
/** Configure the default device for arrays. */
|
|
403
|
+
declare function defaultDevice(device?: Device): Device;
|
|
362
404
|
/**
|
|
363
405
|
* Initialize `jax-js` library backends.
|
|
364
406
|
*
|
|
@@ -494,6 +536,8 @@ declare enum Primitive {
|
|
|
494
536
|
RandomBits = "random_bits",
|
|
495
537
|
Sin = "sin",
|
|
496
538
|
Cos = "cos",
|
|
539
|
+
Asin = "asin",
|
|
540
|
+
Atan = "atan",
|
|
497
541
|
Exp = "exp",
|
|
498
542
|
Log = "log",
|
|
499
543
|
Sqrt = "sqrt",
|
|
@@ -569,6 +613,7 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
569
613
|
outDim: number;
|
|
570
614
|
};
|
|
571
615
|
[Primitive.JitCall]: {
|
|
616
|
+
name: string;
|
|
572
617
|
jaxpr: Jaxpr;
|
|
573
618
|
numConsts: number;
|
|
574
619
|
};
|
|
@@ -584,8 +629,10 @@ declare enum CompareOp {
|
|
|
584
629
|
LessEqual = "less_equal",
|
|
585
630
|
}
|
|
586
631
|
/** @inline */
|
|
632
|
+
type Axis = number | number[] | null;
|
|
633
|
+
/** @inline */
|
|
587
634
|
type ReduceOpts = {
|
|
588
|
-
|
|
635
|
+
keepdims?: boolean;
|
|
589
636
|
};
|
|
590
637
|
type MainTrace = {
|
|
591
638
|
level: number;
|
|
@@ -605,10 +652,40 @@ declare abstract class Trace {
|
|
|
605
652
|
abstract lift(val: Tracer): Tracer;
|
|
606
653
|
abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
|
|
607
654
|
}
|
|
655
|
+
/** Internal representation of an array value. */
|
|
608
656
|
interface AbstractValue {
|
|
657
|
+
/** Shape of the array. Must be a static tuple of non-negative dimensions. */
|
|
609
658
|
shape: number[];
|
|
659
|
+
/** Concrete data type of array elements. */
|
|
610
660
|
dtype: DType;
|
|
661
|
+
/**
|
|
662
|
+
* Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
|
|
663
|
+
* _weakly typed_ unless a dtype is explicitly specified.
|
|
664
|
+
*
|
|
665
|
+
* Weakly typed values will automatically cast to the data type of other
|
|
666
|
+
* arrays when used as an operand in an expression. This property only affects
|
|
667
|
+
* how they promote in type casting; their memory layout is still determined
|
|
668
|
+
* by the actual `dtype` field.
|
|
669
|
+
*
|
|
670
|
+
* ```ts
|
|
671
|
+
* const x = np.array(3); // weakType = true, dtype = float32
|
|
672
|
+
* const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
|
|
673
|
+
* const z = x.add(y); // z has dtype int32 because x is weakly typed
|
|
674
|
+
* ```
|
|
675
|
+
*
|
|
676
|
+
* Weak types are present in JIT programs in their spec (e.g., Jaxpr inputs
|
|
677
|
+
* and outputs can be weakly typed). But they're solely a frontend
|
|
678
|
+
* concept. Backends are not aware of weak types.
|
|
679
|
+
*/
|
|
680
|
+
weakType: boolean;
|
|
611
681
|
}
|
|
682
|
+
/**
|
|
683
|
+
* Broadcast shapes and promote types with casting for two avals.
|
|
684
|
+
*
|
|
685
|
+
* This implements the weak type behavior described in `promoteTypes()`, but not
|
|
686
|
+
* implemented in that function itself, because `weakType` is not passed to it.
|
|
687
|
+
*/
|
|
688
|
+
|
|
612
689
|
declare abstract class Tracer {
|
|
613
690
|
/** @ignore */
|
|
614
691
|
readonly _trace: Trace;
|
|
@@ -662,9 +739,20 @@ declare abstract class Tracer {
|
|
|
662
739
|
* ```
|
|
663
740
|
*/
|
|
664
741
|
abstract dispose(): void;
|
|
742
|
+
/** The shape of the array. */
|
|
665
743
|
get shape(): number[];
|
|
744
|
+
/** The total number of elements in the array. */
|
|
666
745
|
get size(): number;
|
|
746
|
+
/** The dtype of elements stored in the array. */
|
|
667
747
|
get dtype(): DType;
|
|
748
|
+
/**
|
|
749
|
+
* Whether the array is weakly typed.
|
|
750
|
+
*
|
|
751
|
+
* Weakly typed arrays will cast to the dtype of the other operand. See
|
|
752
|
+
* `promoteTypes()` for details.
|
|
753
|
+
*/
|
|
754
|
+
get weakType(): boolean;
|
|
755
|
+
/** The number of dimensions of the array. */
|
|
668
756
|
get ndim(): number;
|
|
669
757
|
/** @ignore */
|
|
670
758
|
fullLower(): Tracer;
|
|
@@ -678,11 +766,11 @@ declare abstract class Tracer {
|
|
|
678
766
|
greaterEqual(other: this | TracerValue): this;
|
|
679
767
|
lessEqual(other: this | TracerValue): this;
|
|
680
768
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
681
|
-
sum(axis?:
|
|
769
|
+
sum(axis?: Axis, opts?: ReduceOpts): this;
|
|
682
770
|
/** Product of the array elements over a given axis. */
|
|
683
|
-
prod(axis?:
|
|
771
|
+
prod(axis?: Axis, opts?: ReduceOpts): this;
|
|
684
772
|
/** Compute the average of the array elements along the specified axis. */
|
|
685
|
-
mean(axis?:
|
|
773
|
+
mean(axis?: Axis, opts?: ReduceOpts): this;
|
|
686
774
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
687
775
|
transpose(perm?: number[]): this;
|
|
688
776
|
/**
|
|
@@ -755,7 +843,8 @@ declare abstract class Tracer {
|
|
|
755
843
|
declare class ShapedArray implements AbstractValue {
|
|
756
844
|
readonly shape: number[];
|
|
757
845
|
readonly dtype: DType;
|
|
758
|
-
|
|
846
|
+
readonly weakType: boolean;
|
|
847
|
+
constructor(shape: number[], dtype: DType, weakType: boolean);
|
|
759
848
|
static fromAval(aval: AbstractValue): ShapedArray;
|
|
760
849
|
get ndim(): number;
|
|
761
850
|
toString(): string;
|
|
@@ -791,6 +880,14 @@ type DTypeAndDevice = {
|
|
|
791
880
|
dtype?: DType;
|
|
792
881
|
device?: Device;
|
|
793
882
|
};
|
|
883
|
+
type ArrayConstructorArgs = {
|
|
884
|
+
source: AluExp | Slot;
|
|
885
|
+
st: ShapeTracker;
|
|
886
|
+
dtype: DType;
|
|
887
|
+
weakType: boolean;
|
|
888
|
+
backend: Backend;
|
|
889
|
+
pending?: Iterable<PendingExecute>;
|
|
890
|
+
};
|
|
794
891
|
/**
|
|
795
892
|
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
796
893
|
*
|
|
@@ -810,7 +907,7 @@ declare class Array extends Tracer {
|
|
|
810
907
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
811
908
|
* will be freed when the array is disposed.
|
|
812
909
|
*/
|
|
813
|
-
constructor(
|
|
910
|
+
constructor(args: ArrayConstructorArgs);
|
|
814
911
|
/** @ignore */
|
|
815
912
|
get aval(): ShapedArray;
|
|
816
913
|
/** Return a simple string representation of the array's dimensions. */
|
|
@@ -839,8 +936,11 @@ declare class Array extends Tracer {
|
|
|
839
936
|
*
|
|
840
937
|
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
841
938
|
* dispatch of operations as well.
|
|
939
|
+
*
|
|
940
|
+
* **Note:** `jax.blockUntilReady()` is a higher-level API; it calls this
|
|
941
|
+
* asynchronously for multiple arrays.
|
|
842
942
|
*/
|
|
843
|
-
|
|
943
|
+
blockUntilReady(): Promise<Array>;
|
|
844
944
|
/**
|
|
845
945
|
* Realize the array and return it as data. This is a sync variant and not
|
|
846
946
|
* recommended for performance reasons, as it will block rendering.
|
|
@@ -869,11 +969,6 @@ declare class Array extends Tracer {
|
|
|
869
969
|
static _implRules(): typeof implRules;
|
|
870
970
|
_realizeSource(): number;
|
|
871
971
|
}
|
|
872
|
-
/** Construct an array from a single scalar constant. */
|
|
873
|
-
declare function scalar(value: number | boolean, {
|
|
874
|
-
dtype,
|
|
875
|
-
device
|
|
876
|
-
}?: DTypeAndDevice): Array;
|
|
877
972
|
/** Constructor for creating a new array from data. */
|
|
878
973
|
declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
879
974
|
shape,
|
|
@@ -911,7 +1006,7 @@ declare function eye(numRows: number, numCols?: number, {
|
|
|
911
1006
|
dtype,
|
|
912
1007
|
device
|
|
913
1008
|
}?: DTypeAndDevice): Array;
|
|
914
|
-
/** Return the identity
|
|
1009
|
+
/** Return the identity matrix, with ones on the main diagonal. */
|
|
915
1010
|
declare function identity$1(n: number, {
|
|
916
1011
|
dtype,
|
|
917
1012
|
device
|
|
@@ -948,7 +1043,7 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
948
1043
|
device
|
|
949
1044
|
}?: DTypeAndDevice): Array;
|
|
950
1045
|
declare namespace numpy_d_exports {
|
|
951
|
-
export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack, concatenate, cos, cosh, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, identity$1 as identity, inf, int32, less, lessEqual, linspace, log, log10, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, pad, permuteDims, pi, prod, ravel, reciprocal,
|
|
1046
|
+
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
952
1047
|
}
|
|
953
1048
|
declare const float32 = DType.Float32;
|
|
954
1049
|
declare const int32 = DType.Int32;
|
|
@@ -965,54 +1060,66 @@ declare const inf: number;
|
|
|
965
1060
|
declare const nan: number;
|
|
966
1061
|
/** This is Pi, `π = 3.14159265358979...` */
|
|
967
1062
|
declare const pi: number;
|
|
968
|
-
/** Element-wise addition, with broadcasting. */
|
|
1063
|
+
/** @function Element-wise addition, with broadcasting. */
|
|
969
1064
|
declare const add: (x: ArrayLike, y: ArrayLike) => Array;
|
|
970
|
-
/** Element-wise multiplication, with broadcasting. */
|
|
1065
|
+
/** @function Element-wise multiplication, with broadcasting. */
|
|
971
1066
|
declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
|
|
972
|
-
/** Numerical negative of every element of an array. */
|
|
1067
|
+
/** @function Numerical negative of every element of an array. */
|
|
973
1068
|
declare const negative: (x: ArrayLike) => Array;
|
|
974
|
-
/** Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
1069
|
+
/** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
975
1070
|
declare const reciprocal: (x: ArrayLike) => Array;
|
|
976
|
-
/** Element-wise sine function (takes radians). */
|
|
1071
|
+
/** @function Element-wise sine function (takes radians). */
|
|
977
1072
|
declare const sin: (x: ArrayLike) => Array;
|
|
978
|
-
/** Element-wise cosine function (takes radians). */
|
|
1073
|
+
/** @function Element-wise cosine function (takes radians). */
|
|
979
1074
|
declare const cos: (x: ArrayLike) => Array;
|
|
980
|
-
/**
|
|
1075
|
+
/** @function Element-wise inverse sine function (inverse of sin). */
|
|
1076
|
+
declare const asin: (x: ArrayLike) => Array;
|
|
1077
|
+
/** @function Element-wise inverse tangent function (inverse of tan). */
|
|
1078
|
+
declare const atan: (x: ArrayLike) => Array;
|
|
1079
|
+
/** @function Calculate the exponential of all elements in the input array. */
|
|
981
1080
|
declare const exp: (x: ArrayLike) => Array;
|
|
982
|
-
/** Calculate the natural logarithm of all elements in the input array. */
|
|
1081
|
+
/** @function Calculate the natural logarithm of all elements in the input array. */
|
|
983
1082
|
declare const log: (x: ArrayLike) => Array;
|
|
984
|
-
/** Calculate the square root of all elements in the input array. */
|
|
1083
|
+
/** @function Calculate the square root of all elements in the input array. */
|
|
985
1084
|
declare const sqrt: (x: ArrayLike) => Array;
|
|
986
|
-
/** Return element-wise minimum of the input arrays. */
|
|
1085
|
+
/** @function Return element-wise minimum of the input arrays. */
|
|
987
1086
|
declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
988
|
-
/** Return element-wise maximum of the input arrays. */
|
|
1087
|
+
/** @function Return element-wise maximum of the input arrays. */
|
|
989
1088
|
declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
990
|
-
/** Compare two arrays element-wise. */
|
|
1089
|
+
/** @function Compare two arrays element-wise. */
|
|
991
1090
|
declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
|
|
992
|
-
/** Compare two arrays element-wise. */
|
|
1091
|
+
/** @function Compare two arrays element-wise. */
|
|
993
1092
|
declare const less: (x: ArrayLike, y: ArrayLike) => Array;
|
|
994
|
-
/** Compare two arrays element-wise. */
|
|
1093
|
+
/** @function Compare two arrays element-wise. */
|
|
995
1094
|
declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
|
|
996
|
-
/** Compare two arrays element-wise. */
|
|
1095
|
+
/** @function Compare two arrays element-wise. */
|
|
997
1096
|
declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
998
|
-
/** Compare two arrays element-wise. */
|
|
1097
|
+
/** @function Compare two arrays element-wise. */
|
|
999
1098
|
declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1000
|
-
/** Compare two arrays element-wise. */
|
|
1099
|
+
/** @function Compare two arrays element-wise. */
|
|
1001
1100
|
declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1002
|
-
/** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1101
|
+
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1003
1102
|
declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
|
|
1004
|
-
/**
|
|
1103
|
+
/**
|
|
1104
|
+
* @function
|
|
1105
|
+
* Permute the dimensions of an array. Defaults to reversing the axis order.
|
|
1106
|
+
*/
|
|
1005
1107
|
declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
|
|
1006
1108
|
/**
|
|
1109
|
+
* @function
|
|
1007
1110
|
* Give a new shape to an array without changing its data.
|
|
1008
1111
|
*
|
|
1009
1112
|
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
1010
1113
|
* length of the array and remaining dimensions.
|
|
1011
1114
|
*/
|
|
1012
1115
|
declare const reshape: (x: ArrayLike, shape: number[]) => Array;
|
|
1013
|
-
/**
|
|
1116
|
+
/**
|
|
1117
|
+
* @function
|
|
1118
|
+
* Move axes of an array to new positions. Other axes retain original order.
|
|
1119
|
+
*/
|
|
1014
1120
|
declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
1015
1121
|
/**
|
|
1122
|
+
* @function
|
|
1016
1123
|
* Add padding (zeros) to an array.
|
|
1017
1124
|
*
|
|
1018
1125
|
* The `width` argument is either an integer or pair of integers, in which case
|
|
@@ -1020,15 +1127,27 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
|
1020
1127
|
* pair specifies the padding for its corresponding axis.
|
|
1021
1128
|
*/
|
|
1022
1129
|
declare const pad: (x: ArrayLike, width: number | [number, number] | [number, number][]) => Array;
|
|
1023
|
-
/**
|
|
1130
|
+
/**
|
|
1131
|
+
* @function
|
|
1132
|
+
* Return the number of dimensions of an array. Does not consume array reference.
|
|
1133
|
+
*/
|
|
1024
1134
|
declare const ndim: (x: ArrayLike) => number;
|
|
1025
|
-
/** Return the shape of an array. Does not consume array reference. */
|
|
1135
|
+
/** @function Return the shape of an array. Does not consume array reference. */
|
|
1026
1136
|
declare const shape$1: (x: ArrayLike) => number[];
|
|
1027
|
-
/**
|
|
1137
|
+
/**
|
|
1138
|
+
* @function
|
|
1139
|
+
* Return an array of zeros with the same shape and type as a given array.
|
|
1140
|
+
*/
|
|
1028
1141
|
declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1029
|
-
/**
|
|
1142
|
+
/**
|
|
1143
|
+
* @function
|
|
1144
|
+
* Return an array of ones with the same shape and type as a given array.
|
|
1145
|
+
*/
|
|
1030
1146
|
declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1031
|
-
/**
|
|
1147
|
+
/**
|
|
1148
|
+
* @function
|
|
1149
|
+
* Return a full array with the same shape and type as a given array.
|
|
1150
|
+
*/
|
|
1032
1151
|
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
|
|
1033
1152
|
/**
|
|
1034
1153
|
* Return the number of elements in an array, optionally along an axis.
|
|
@@ -1038,15 +1157,15 @@ declare function size(a: ArrayLike, axis?: number): number;
|
|
|
1038
1157
|
/** Convert an array to a specified dtype. */
|
|
1039
1158
|
declare function astype(a: ArrayLike, dtype: DType): Array;
|
|
1040
1159
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
1041
|
-
declare function sum(a: ArrayLike, axis?:
|
|
1160
|
+
declare function sum(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1042
1161
|
/** Product of the array elements over a given axis. */
|
|
1043
|
-
declare function prod(a: ArrayLike, axis?:
|
|
1162
|
+
declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1044
1163
|
/** Return the minimum of array elements along a given axis. */
|
|
1045
|
-
declare function min(a: ArrayLike, axis?:
|
|
1164
|
+
declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1046
1165
|
/** Return the maximum of array elements along a given axis. */
|
|
1047
|
-
declare function max(a: ArrayLike, axis?:
|
|
1166
|
+
declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1048
1167
|
/** Compute the average of the array elements along the specified axis. */
|
|
1049
|
-
declare function mean(a: ArrayLike, axis?:
|
|
1168
|
+
declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1050
1169
|
/**
|
|
1051
1170
|
* Returns the indices of the minimum values along an axis.
|
|
1052
1171
|
*
|
|
@@ -1062,7 +1181,7 @@ declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
|
1062
1181
|
*/
|
|
1063
1182
|
declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
1064
1183
|
/** Reverse the elements in an array along the given axes. */
|
|
1065
|
-
declare function flip(x: ArrayLike, axis?:
|
|
1184
|
+
declare function flip(x: ArrayLike, axis?: Axis): Array;
|
|
1066
1185
|
/**
|
|
1067
1186
|
* Join a sequence of arrays along an existing axis.
|
|
1068
1187
|
*
|
|
@@ -1106,9 +1225,36 @@ declare function columnStack(xs: ArrayLike[]): Array;
|
|
|
1106
1225
|
declare function flipud(x: ArrayLike): Array;
|
|
1107
1226
|
/** Flip an array horizontally (axis=1). */
|
|
1108
1227
|
declare function fliplr(x: ArrayLike): Array;
|
|
1228
|
+
/** @function Alternative name for `numpy.transpose()`. */
|
|
1109
1229
|
declare const permuteDims: (x: ArrayLike, perm?: number[]) => Array;
|
|
1110
1230
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
1111
1231
|
declare function ravel(a: ArrayLike): Array;
|
|
1232
|
+
/**
|
|
1233
|
+
* Repeat each element of an array after itself.
|
|
1234
|
+
*
|
|
1235
|
+
* If no axis is provided, use the flattened input array, and return a flat
|
|
1236
|
+
* output array.
|
|
1237
|
+
*/
|
|
1238
|
+
declare function repeat(a: ArrayLike, repeats: number, axis?: number): Array;
|
|
1239
|
+
/**
|
|
1240
|
+
* Construct an array by repeating A the number of times given by reps.
|
|
1241
|
+
*
|
|
1242
|
+
* If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
|
|
1243
|
+
* integers, the resulting array will have a shape of `(reps[0] * d1,
|
|
1244
|
+
* reps[1] * d2, ..., reps[n-1] * dn)`, with `A` tiled along each dimension.
|
|
1245
|
+
*/
|
|
1246
|
+
declare function tile(a: ArrayLike, reps: number | number[]): Array;
|
|
1247
|
+
/**
|
|
1248
|
+
* Broadcast an array to a shape, with NumPy-style broadcasting rules.
|
|
1249
|
+
*
|
|
1250
|
+
* In other words, this lets you append axes to the left, and/or expand
|
|
1251
|
+
* dimensions where the shape is 1.
|
|
1252
|
+
*/
|
|
1253
|
+
declare function broadcastTo(a: ArrayLike, shape: number[]): Array;
|
|
1254
|
+
/** Broadcast input shapes to a common output shape. */
|
|
1255
|
+
declare function broadcastShapes(...shapes: number[][]): number[];
|
|
1256
|
+
/** Broadcast arrays to a common shape. */
|
|
1257
|
+
declare function broadcastArrays(...arrays: ArrayLike[]): Array[];
|
|
1112
1258
|
/**
|
|
1113
1259
|
* Return specified diagonals.
|
|
1114
1260
|
*
|
|
@@ -1136,8 +1282,28 @@ declare function allclose(actual: Parameters<typeof array>[0], expected: Paramet
|
|
|
1136
1282
|
declare function matmul(x: ArrayLike, y: ArrayLike): Array;
|
|
1137
1283
|
/** Dot product of two arrays. */
|
|
1138
1284
|
declare function dot(x: ArrayLike, y: ArrayLike): Array;
|
|
1139
|
-
/**
|
|
1140
|
-
|
|
1285
|
+
/**
|
|
1286
|
+
* Compute the inner product of two arrays.
|
|
1287
|
+
*
|
|
1288
|
+
* Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
|
|
1289
|
+
* contraction on the last axis.
|
|
1290
|
+
*
|
|
1291
|
+
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
1292
|
+
*/
|
|
1293
|
+
declare function inner(x: ArrayLike, y: ArrayLike): Array;
|
|
1294
|
+
/**
|
|
1295
|
+
* Compute the outer product of two arrays.
|
|
1296
|
+
*
|
|
1297
|
+
* If the input arrays are not 1D, they will be flattened. Returned array will
|
|
1298
|
+
* be of shape `[x.size, y.size]`.
|
|
1299
|
+
*/
|
|
1300
|
+
declare function outer(x: ArrayLike, y: ArrayLike): Array;
|
|
1301
|
+
/** Vector dot product of two arrays along a given axis. */
|
|
1302
|
+
declare function vecdot(x: ArrayLike, y: ArrayLike, {
|
|
1303
|
+
axis
|
|
1304
|
+
}?: {
|
|
1305
|
+
axis?: number;
|
|
1306
|
+
}): Array;
|
|
1141
1307
|
/**
|
|
1142
1308
|
* Return the dot product of two vectors.
|
|
1143
1309
|
*
|
|
@@ -1155,6 +1321,21 @@ declare function meshgrid(xs: Array[], {
|
|
|
1155
1321
|
}?: {
|
|
1156
1322
|
indexing?: "xy" | "ij";
|
|
1157
1323
|
}): Array[];
|
|
1324
|
+
/**
|
|
1325
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
1326
|
+
*
|
|
1327
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
1328
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
1329
|
+
* `k>0` is above it.
|
|
1330
|
+
*/
|
|
1331
|
+
declare function tri(n: number, m?: number, k?: number, {
|
|
1332
|
+
dtype,
|
|
1333
|
+
device
|
|
1334
|
+
}?: DTypeAndDevice): Array;
|
|
1335
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
1336
|
+
declare function tril(a: ArrayLike, k?: number): Array;
|
|
1337
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
1338
|
+
declare function triu(a: ArrayLike, k?: number): Array;
|
|
1158
1339
|
/**
|
|
1159
1340
|
* Clip (limit) the values in an array.
|
|
1160
1341
|
*
|
|
@@ -1171,15 +1352,50 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
|
|
|
1171
1352
|
* This is the same function as `jax.numpy.abs()`.
|
|
1172
1353
|
*/
|
|
1173
1354
|
declare function absolute(x: ArrayLike): Array;
|
|
1174
|
-
/** Alias of `jax.numpy.absolute()`. */
|
|
1355
|
+
/** @function Alias of `jax.numpy.absolute()`. */
|
|
1175
1356
|
declare const abs: typeof absolute;
|
|
1357
|
+
/** Return an element-wise indication of sign of the input. */
|
|
1358
|
+
declare function sign(x: ArrayLike): Array;
|
|
1176
1359
|
/** Calculate element-wise square of the input array. */
|
|
1177
1360
|
declare function square(x: ArrayLike): Array;
|
|
1178
|
-
/**
|
|
1361
|
+
/** Element-wise tangent function (takes radians). */
|
|
1179
1362
|
declare function tan(x: ArrayLike): Array;
|
|
1363
|
+
/** Element-wise inverse cosine function (inverse of cos). */
|
|
1364
|
+
declare function acos(x: ArrayLike): Array;
|
|
1365
|
+
/**
|
|
1366
|
+
* @function
|
|
1367
|
+
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
1368
|
+
*
|
|
1369
|
+
* In the original NumPy/JAX implementation, this function is more numerically
|
|
1370
|
+
* stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
|
|
1371
|
+
* improvements.
|
|
1372
|
+
*/
|
|
1373
|
+
declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1374
|
+
/**
|
|
1375
|
+
* @function
|
|
1376
|
+
* Element-wise arc tangent of y/x with correct quadrant.
|
|
1377
|
+
*
|
|
1378
|
+
* Returns the angle in radians between the positive x-axis and the point (x, y).
|
|
1379
|
+
* The result is in the range [-π, π].
|
|
1380
|
+
*
|
|
1381
|
+
* Uses numerically stable formulas:
|
|
1382
|
+
* - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
|
|
1383
|
+
* - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
|
|
1384
|
+
*
|
|
1385
|
+
* The output is ill-defined when both x and y are zero.
|
|
1386
|
+
*/
|
|
1387
|
+
declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
1388
|
+
/** @function Alias of `jax.numpy.acos()`. */
|
|
1389
|
+
declare const arccos: typeof acos;
|
|
1390
|
+
/** @function Alias of `jax.numpy.atan()`. */
|
|
1391
|
+
declare const arctan: (x: ArrayLike) => Array;
|
|
1392
|
+
/** @function Alias of `jax.numpy.atan2()`. */
|
|
1393
|
+
declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
1394
|
+
/** Element-wise subtraction, with broadcasting. */
|
|
1395
|
+
declare function subtract(x: ArrayLike, y: ArrayLike): Array;
|
|
1180
1396
|
/** Calculates the floating-point division of x by y element-wise. */
|
|
1181
1397
|
declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
|
|
1182
|
-
/** Alias of `jax.numpy.trueDivide()`. */
|
|
1398
|
+
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
1183
1399
|
declare const divide: typeof trueDivide;
|
|
1184
1400
|
/** Round input to the nearest integer towards zero. */
|
|
1185
1401
|
declare function trunc(x: ArrayLike): Array;
|
|
@@ -1189,26 +1405,112 @@ declare function exp2(p: ArrayLike): Array;
|
|
|
1189
1405
|
declare function log2(x: ArrayLike): Array;
|
|
1190
1406
|
/** Return the base-10 logarithm of x, element-wise. */
|
|
1191
1407
|
declare function log10(x: ArrayLike): Array;
|
|
1408
|
+
/** Calculate `exp(x) - 1` element-wise. */
|
|
1409
|
+
declare function expm1(x: ArrayLike): Array;
|
|
1410
|
+
/** Calculate the natural logarithm of `1 + x` element-wise. */
|
|
1411
|
+
declare function log1p(x: ArrayLike): Array;
|
|
1412
|
+
/** Convert angles from degrees to radians. */
|
|
1413
|
+
declare function deg2rad(x: ArrayLike): Array;
|
|
1414
|
+
/** @function Alias of `jax.numpy.deg2rad()`. */
|
|
1415
|
+
declare const radians: typeof deg2rad;
|
|
1416
|
+
/** Convert angles from radians to degrees. */
|
|
1417
|
+
declare function rad2deg(x: ArrayLike): Array;
|
|
1418
|
+
/** @function Alias of `jax.numpy.rad2deg()`. */
|
|
1419
|
+
declare const degrees: typeof rad2deg;
|
|
1420
|
+
/**
|
|
1421
|
+
* @function
|
|
1422
|
+
* Computes first array raised to power of second array, element-wise.
|
|
1423
|
+
*/
|
|
1424
|
+
declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1425
|
+
/** @function Alias of `jax.numpy.power()`. */
|
|
1426
|
+
declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1427
|
+
/** @function Calculate the element-wise cube root of the input array. */
|
|
1428
|
+
declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1192
1429
|
/**
|
|
1430
|
+
* @function
|
|
1193
1431
|
* Calculate element-wise hyperbolic sine of input.
|
|
1194
1432
|
*
|
|
1195
1433
|
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
1196
1434
|
*/
|
|
1197
|
-
declare
|
|
1435
|
+
declare const sinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1198
1436
|
/**
|
|
1437
|
+
* @function
|
|
1199
1438
|
* Calculate element-wise hyperbolic cosine of input.
|
|
1200
1439
|
*
|
|
1201
1440
|
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
1202
1441
|
*/
|
|
1203
|
-
declare
|
|
1442
|
+
declare const cosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1204
1443
|
/**
|
|
1444
|
+
* @function
|
|
1205
1445
|
* Calculate element-wise hyperbolic tangent of input.
|
|
1206
1446
|
*
|
|
1207
1447
|
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
1208
1448
|
*/
|
|
1209
|
-
declare
|
|
1449
|
+
declare const tanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1450
|
+
/**
|
|
1451
|
+
* @function
|
|
1452
|
+
* Calculate element-wise inverse hyperbolic sine of input.
|
|
1453
|
+
*
|
|
1454
|
+
* `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
|
|
1455
|
+
*/
|
|
1456
|
+
declare const arcsinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1457
|
+
/**
|
|
1458
|
+
* @function
|
|
1459
|
+
* Calculate element-wise inverse hyperbolic cosine of input.
|
|
1460
|
+
*
|
|
1461
|
+
* `arccosh(x) = ln(x + sqrt(x^2 - 1))`
|
|
1462
|
+
*/
|
|
1463
|
+
declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1464
|
+
/**
|
|
1465
|
+
* @function
|
|
1466
|
+
* Calculate element-wise inverse hyperbolic tangent of input.
|
|
1467
|
+
*
|
|
1468
|
+
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
1469
|
+
*/
|
|
1470
|
+
declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1471
|
+
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
1472
|
+
declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1473
|
+
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
1474
|
+
declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1475
|
+
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
1476
|
+
declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1477
|
+
/**
|
|
1478
|
+
* Compute the variance of an array.
|
|
1479
|
+
*
|
|
1480
|
+
* The variance is computed for the flattened array by default, otherwise over
|
|
1481
|
+
* the specified axis.
|
|
1482
|
+
*
|
|
1483
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
1484
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
1485
|
+
*/
|
|
1486
|
+
declare function var_(x: ArrayLike, axis?: Axis, opts?: {
|
|
1487
|
+
mean?: ArrayLike;
|
|
1488
|
+
correction?: number;
|
|
1489
|
+
} & ReduceOpts): Array;
|
|
1490
|
+
/**
|
|
1491
|
+
* Compute the standard deviation of an array.
|
|
1492
|
+
*
|
|
1493
|
+
* The standard deviation is computed for the flattened array by default,
|
|
1494
|
+
* otherwise over the specified axis.
|
|
1495
|
+
*
|
|
1496
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
1497
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
1498
|
+
*/
|
|
1499
|
+
declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
1500
|
+
mean?: ArrayLike;
|
|
1501
|
+
correction?: number;
|
|
1502
|
+
} & ReduceOpts): Array;
|
|
1210
1503
|
//#endregion
|
|
1211
1504
|
//#region src/frontend/jaxpr.d.ts
|
|
1505
|
+
/**
|
|
1506
|
+
* Function callback with an associated dispose() method.
|
|
1507
|
+
*
|
|
1508
|
+
* The dispose() method should be called to clean up any tracer resources needed
|
|
1509
|
+
* by the function after the last time it is called.
|
|
1510
|
+
*/
|
|
1511
|
+
type OwnedFunction<F extends Function> = F & {
|
|
1512
|
+
dispose: () => void;
|
|
1513
|
+
};
|
|
1212
1514
|
/** Variable in a Jaxpr expression. */
|
|
1213
1515
|
declare class Var {
|
|
1214
1516
|
#private;
|
|
@@ -1219,10 +1521,10 @@ declare class Var {
|
|
|
1219
1521
|
}
|
|
1220
1522
|
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
1221
1523
|
declare class Lit {
|
|
1222
|
-
readonly dtype: DType;
|
|
1223
1524
|
readonly value: number;
|
|
1224
1525
|
readonly aval: ShapedArray;
|
|
1225
|
-
|
|
1526
|
+
get dtype(): DType;
|
|
1527
|
+
constructor(aval: AbstractValue, value: number);
|
|
1226
1528
|
}
|
|
1227
1529
|
type Atom = Var | Lit;
|
|
1228
1530
|
declare class VarPrinter {
|
|
@@ -1300,7 +1602,7 @@ declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding:
|
|
|
1300
1602
|
/** Reduce a computation over padded windows. */
|
|
1301
1603
|
declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
|
|
1302
1604
|
declare namespace nn_d_exports {
|
|
1303
|
-
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
|
|
1605
|
+
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
|
|
1304
1606
|
}
|
|
1305
1607
|
/**
|
|
1306
1608
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -1332,6 +1634,7 @@ declare function softplus(x: ArrayLike): Array;
|
|
|
1332
1634
|
*/
|
|
1333
1635
|
declare function softSign(x: ArrayLike): Array;
|
|
1334
1636
|
/**
|
|
1637
|
+
* @function
|
|
1335
1638
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
1336
1639
|
* Swish, computed element-wise:
|
|
1337
1640
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -1340,8 +1643,9 @@ declare function softSign(x: ArrayLike): Array;
|
|
|
1340
1643
|
*
|
|
1341
1644
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1342
1645
|
*/
|
|
1343
|
-
declare const silu: (x: ArrayLike) => Array
|
|
1646
|
+
declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1344
1647
|
/**
|
|
1648
|
+
* @function
|
|
1345
1649
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
1346
1650
|
* Swish, computed element-wise:
|
|
1347
1651
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -1350,31 +1654,35 @@ declare const silu: (x: ArrayLike) => Array;
|
|
|
1350
1654
|
*
|
|
1351
1655
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1352
1656
|
*/
|
|
1353
|
-
declare const swish: (x: ArrayLike) => Array
|
|
1657
|
+
declare const swish: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1354
1658
|
/**
|
|
1355
1659
|
* Log-sigmoid activation function, computed element-wise:
|
|
1356
1660
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
1357
1661
|
*/
|
|
1358
1662
|
declare function logSigmoid(x: ArrayLike): Array;
|
|
1359
|
-
/**
|
|
1663
|
+
/**
|
|
1664
|
+
* @function
|
|
1665
|
+
* Identity activation function. Returns the argument unmodified.
|
|
1666
|
+
*/
|
|
1360
1667
|
declare const identity: (x: ArrayLike) => Array;
|
|
1361
1668
|
/** Leaky rectified linear (ReLU) activation function */
|
|
1362
|
-
declare function leakyRelu(x: ArrayLike, negativeSlope?:
|
|
1669
|
+
declare function leakyRelu(x: ArrayLike, negativeSlope?: ArrayLike): Array;
|
|
1363
1670
|
/**
|
|
1364
1671
|
* Exponential linear unit activation function.
|
|
1365
1672
|
*
|
|
1366
1673
|
* Computes the element-wise function:
|
|
1367
1674
|
* `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
|
|
1368
1675
|
*/
|
|
1369
|
-
declare function elu(x: ArrayLike, alpha?:
|
|
1676
|
+
declare function elu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
1370
1677
|
/**
|
|
1371
1678
|
* Continuously-differentiable exponential linear unit activation function.
|
|
1372
1679
|
*
|
|
1373
1680
|
* Computes the element-wise function:
|
|
1374
1681
|
* `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
|
|
1375
1682
|
*/
|
|
1376
|
-
declare function celu(x: ArrayLike, alpha?:
|
|
1683
|
+
declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
1377
1684
|
/**
|
|
1685
|
+
* @function
|
|
1378
1686
|
* Gaussion error linear unit (GELU) activation function.
|
|
1379
1687
|
*
|
|
1380
1688
|
* This is computed element-wise. Currently jax-js does not support the erf() or
|
|
@@ -1385,7 +1693,7 @@ declare function celu(x: ArrayLike, alpha?: number): Array;
|
|
|
1385
1693
|
*
|
|
1386
1694
|
* This will be improved in the future.
|
|
1387
1695
|
*/
|
|
1388
|
-
declare const gelu: (x: ArrayLike) => Array
|
|
1696
|
+
declare const gelu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1389
1697
|
/**
|
|
1390
1698
|
* Gated linear unit (GLU) activation function.
|
|
1391
1699
|
*
|
|
@@ -1393,6 +1701,13 @@ declare const gelu: (x: ArrayLike) => Array;
|
|
|
1393
1701
|
* computes `a * sigmoid(b)`.
|
|
1394
1702
|
*/
|
|
1395
1703
|
declare function glu(x: ArrayLike, axis?: number): Array;
|
|
1704
|
+
/**
|
|
1705
|
+
* Squareplus activation function.
|
|
1706
|
+
*
|
|
1707
|
+
* Computes the element-wise function:
|
|
1708
|
+
* `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
|
|
1709
|
+
*/
|
|
1710
|
+
declare function squareplus(x: ArrayLike, b?: ArrayLike): Array;
|
|
1396
1711
|
/**
|
|
1397
1712
|
* Mish activation function.
|
|
1398
1713
|
*
|
|
@@ -1408,7 +1723,7 @@ declare function mish(x: ArrayLike): Array;
|
|
|
1408
1723
|
*
|
|
1409
1724
|
* Reference: https://en.wikipedia.org/wiki/Softmax_function
|
|
1410
1725
|
*/
|
|
1411
|
-
declare function softmax(x: ArrayLike, axis?:
|
|
1726
|
+
declare function softmax(x: ArrayLike, axis?: Axis): Array;
|
|
1412
1727
|
/**
|
|
1413
1728
|
* Log-Softmax function.
|
|
1414
1729
|
*
|
|
@@ -1417,7 +1732,7 @@ declare function softmax(x: ArrayLike, axis?: number | number[]): Array;
|
|
|
1417
1732
|
*
|
|
1418
1733
|
* If `axis` is not specified, it defaults to the last axis.
|
|
1419
1734
|
*/
|
|
1420
|
-
declare function logSoftmax(x: ArrayLike, axis?:
|
|
1735
|
+
declare function logSoftmax(x: ArrayLike, axis?: Axis): Array;
|
|
1421
1736
|
/**
|
|
1422
1737
|
* Log-sum-exp reduction. Also a multivariate version of `softplus`.
|
|
1423
1738
|
*
|
|
@@ -1426,7 +1741,22 @@ declare function logSoftmax(x: ArrayLike, axis?: number | number[]): Array;
|
|
|
1426
1741
|
*
|
|
1427
1742
|
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
1428
1743
|
*/
|
|
1429
|
-
declare function logsumexp(x: ArrayLike, axis?:
|
|
1744
|
+
declare function logsumexp(x: ArrayLike, axis?: Axis): Array;
|
|
1745
|
+
/** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
|
|
1746
|
+
declare function logmeanexp(x: ArrayLike, axis?: Axis): Array;
|
|
1747
|
+
/**
|
|
1748
|
+
* Standardizes input to zero mean and unit variance.
|
|
1749
|
+
*
|
|
1750
|
+
* By default, this is computed over the last axis. You can pass in a different
|
|
1751
|
+
* axis, or `null` to standardize over all elements.
|
|
1752
|
+
*
|
|
1753
|
+
* Epsilon is added to denominator, it defaults to `1e-5` for stability.
|
|
1754
|
+
*/
|
|
1755
|
+
declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
1756
|
+
mean?: ArrayLike;
|
|
1757
|
+
variance?: ArrayLike;
|
|
1758
|
+
epsilon?: ArrayLike;
|
|
1759
|
+
}): Array;
|
|
1430
1760
|
/**
|
|
1431
1761
|
* One-hot encodes the given indices.
|
|
1432
1762
|
*
|
|
@@ -1445,7 +1775,7 @@ declare function logsumexp(x: ArrayLike, axis?: number | number[]): Array;
|
|
|
1445
1775
|
*/
|
|
1446
1776
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
1447
1777
|
declare namespace random_d_exports {
|
|
1448
|
-
export { bits, key, split, uniform };
|
|
1778
|
+
export { bernoulli, bits, exponential, key, normal, split, uniform };
|
|
1449
1779
|
}
|
|
1450
1780
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
1451
1781
|
declare function key(seed: number): Array;
|
|
@@ -1453,34 +1783,71 @@ declare function key(seed: number): Array;
|
|
|
1453
1783
|
declare function split(key: Array, num?: number | number[]): Array;
|
|
1454
1784
|
/** Sample uniform bits in the form of unsigned integers. */
|
|
1455
1785
|
declare function bits(key: Array, shape?: number[]): Array;
|
|
1456
|
-
/**
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
|
|
1460
|
-
|
|
1461
|
-
minval?: number;
|
|
1462
|
-
maxval?: number;
|
|
1463
|
-
})
|
|
1786
|
+
/**
|
|
1787
|
+
* @function
|
|
1788
|
+
* Sample uniform random values in [minval, maxval) with given shape.
|
|
1789
|
+
*/
|
|
1790
|
+
declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined, args_2?: {
|
|
1791
|
+
minval?: number | undefined;
|
|
1792
|
+
maxval?: number | undefined;
|
|
1793
|
+
} | undefined) => Array>;
|
|
1794
|
+
/**
|
|
1795
|
+
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
1796
|
+
*
|
|
1797
|
+
* Returns a random Boolean array with the specified shape. `p` can be an array
|
|
1798
|
+
* and must be broadcastable to `shape`.
|
|
1799
|
+
*/
|
|
1800
|
+
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
1801
|
+
/**
|
|
1802
|
+
* @function
|
|
1803
|
+
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
1804
|
+
*/
|
|
1805
|
+
declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1806
|
+
/**
|
|
1807
|
+
* @function
|
|
1808
|
+
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
1809
|
+
*
|
|
1810
|
+
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
1811
|
+
* directly inverts the CDF, but we don't have support for that yet. Outputs will not be
|
|
1812
|
+
* bitwise identical to JAX.
|
|
1813
|
+
*/
|
|
1814
|
+
declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1464
1815
|
//#endregion
|
|
1465
1816
|
//#region src/index.d.ts
|
|
1466
|
-
/**
|
|
1817
|
+
/**
|
|
1818
|
+
* @function
|
|
1819
|
+
* Compute the forward-mode Jacobian-vector product for a function.
|
|
1820
|
+
*/
|
|
1467
1821
|
declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
|
|
1468
|
-
/**
|
|
1822
|
+
/**
|
|
1823
|
+
* @function
|
|
1824
|
+
* Vectorize an operation on a batched axis for one or more inputs.
|
|
1825
|
+
*/
|
|
1469
1826
|
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1470
|
-
/**
|
|
1827
|
+
/**
|
|
1828
|
+
* @function
|
|
1829
|
+
* Compute the Jacobian evaluated column-by-column by forward-mode AD.
|
|
1830
|
+
*/
|
|
1471
1831
|
declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1472
|
-
/**
|
|
1832
|
+
/**
|
|
1833
|
+
* @function
|
|
1834
|
+
* Construct a Jaxpr by dynamically tracing a function with example inputs.
|
|
1835
|
+
*/
|
|
1473
1836
|
declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
|
|
1474
1837
|
jaxpr: Jaxpr;
|
|
1475
1838
|
consts: Array[];
|
|
1476
1839
|
treedef: JsTreeDef;
|
|
1477
1840
|
};
|
|
1478
1841
|
/**
|
|
1842
|
+
* @function
|
|
1479
1843
|
* Mark a function for automatic JIT compilation, with operator fusion.
|
|
1480
1844
|
*
|
|
1481
1845
|
* The function will be compiled the first time it is called with a set of
|
|
1482
1846
|
* argument shapes.
|
|
1483
1847
|
*
|
|
1848
|
+
* You can call `.dispose()` on the returned, JIT-compiled function after all
|
|
1849
|
+
* calls to free memory associated with array constants.
|
|
1850
|
+
*
|
|
1484
1851
|
* **Options:**
|
|
1485
1852
|
* - `staticArgnums`: An array of argument indices to treat as static
|
|
1486
1853
|
* (compile-time constant). These arguments must be hashable, won't be traced,
|
|
@@ -1488,24 +1855,48 @@ declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) =>
|
|
|
1488
1855
|
* - `device`: The device to place the computation on. If not specified, the
|
|
1489
1856
|
* computation will be placed on the default device.
|
|
1490
1857
|
*/
|
|
1491
|
-
declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F
|
|
1858
|
+
declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => OwnedFunction<(...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>>;
|
|
1492
1859
|
/**
|
|
1860
|
+
* @function
|
|
1493
1861
|
* Produce a local linear approximation to a function at a point using jvp() and
|
|
1494
1862
|
* partial evaluation.
|
|
1495
1863
|
*/
|
|
1496
1864
|
declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
|
|
1497
|
-
/**
|
|
1865
|
+
/**
|
|
1866
|
+
* @function
|
|
1867
|
+
* Calculate the reverse-mode vector-Jacobian product for a function.
|
|
1868
|
+
*/
|
|
1498
1869
|
declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
|
|
1499
1870
|
/**
|
|
1871
|
+
* @function
|
|
1500
1872
|
* Compute the gradient of a scalar-valued function `f` with respect to its
|
|
1501
1873
|
* first argument.
|
|
1502
1874
|
*/
|
|
1503
1875
|
declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
|
|
1504
|
-
/**
|
|
1876
|
+
/**
|
|
1877
|
+
* @function
|
|
1878
|
+
* Create a function that evaluates both `f` and the gradient of `f`.
|
|
1879
|
+
*/
|
|
1505
1880
|
declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
|
|
1506
|
-
/**
|
|
1881
|
+
/**
|
|
1882
|
+
* @function
|
|
1883
|
+
* Compute the Jacobian evaluated row-by-row by reverse-mode AD.
|
|
1884
|
+
*/
|
|
1507
1885
|
declare const jacrev: typeof jacfwd;
|
|
1508
|
-
/**
|
|
1886
|
+
/**
|
|
1887
|
+
* @function
|
|
1888
|
+
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
1889
|
+
*/
|
|
1509
1890
|
declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1891
|
+
/**
|
|
1892
|
+
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
1893
|
+
*
|
|
1894
|
+
* This can be used to wait for the results of an intermediate computation to
|
|
1895
|
+
* finish. It's recommended to call this regularly in an iterative computation
|
|
1896
|
+
* to avoid queueing up too many pending operations.
|
|
1897
|
+
*
|
|
1898
|
+
* Does not consume reference to the arrays.
|
|
1899
|
+
*/
|
|
1900
|
+
declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
1510
1901
|
//#endregion
|
|
1511
|
-
export { DType, type Device, type JsTree, type JsTreeDef, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random,
|
|
1902
|
+
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|