@jax-js/jax 0.0.3 → 0.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +96 -22
- package/dist/{backend-BqDtPGaR.js → backend-CdcTZEOF.js} +325 -153
- package/dist/{backend-D2C4MJRP.cjs → backend-yEU0L_ig.cjs} +350 -154
- package/dist/index.cjs +977 -354
- package/dist/index.d.cts +479 -88
- package/dist/index.d.ts +479 -88
- package/dist/index.js +964 -345
- package/dist/{webgpu-CNg9JGva.js → webgpu-CM-xNYzW.js} +9 -3
- package/dist/{webgpu-fqhx41TC.cjs → webgpu-CNOpiO5T.cjs} +9 -3
- package/package.json +15 -4
package/dist/index.d.ts
CHANGED
|
@@ -123,6 +123,19 @@ declare class ShapeTracker {
|
|
|
123
123
|
}
|
|
124
124
|
//#endregion
|
|
125
125
|
//#region src/utils.d.ts
|
|
126
|
+
/**
|
|
127
|
+
* Set the debug level for verbose logging.
|
|
128
|
+
*
|
|
129
|
+
* 1. JIT compile logs
|
|
130
|
+
* 2. Shader code
|
|
131
|
+
* 3. Expressions and metadata
|
|
132
|
+
* 4. JIT programs, tuning details
|
|
133
|
+
* 5. Most verbose operation traces
|
|
134
|
+
*
|
|
135
|
+
* This is an experimental API and may change in behavior. Do not rely on this
|
|
136
|
+
* in production.
|
|
137
|
+
*/
|
|
138
|
+
declare function setDebug(level: number): void;
|
|
126
139
|
/** @inline */
|
|
127
140
|
type RecursiveArray<T> = T | RecursiveArray<T>[];
|
|
128
141
|
interface FpHashable {
|
|
@@ -138,7 +151,7 @@ interface FpHashable {
|
|
|
138
151
|
declare class FpHash {
|
|
139
152
|
#private;
|
|
140
153
|
value: bigint;
|
|
141
|
-
update(
|
|
154
|
+
update(x: string | boolean | number | bigint | null | undefined | FpHashable): this;
|
|
142
155
|
static hash(...values: (string | boolean | number | bigint | null | undefined | FpHashable)[]): bigint;
|
|
143
156
|
}
|
|
144
157
|
/** Run a function while caching it inline inside a `Map`. */
|
|
@@ -154,6 +167,31 @@ declare enum DType {
|
|
|
154
167
|
}
|
|
155
168
|
/** @inline */
|
|
156
169
|
type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
|
|
170
|
+
/**
|
|
171
|
+
* Promote two dtypes to their join according to the type lattice.
|
|
172
|
+
*
|
|
173
|
+
* When performing operations between arrays of different types, we need to
|
|
174
|
+
* promote both operands to a common type that can represent values from both
|
|
175
|
+
* input types. This follows JAX's type promotion rules.
|
|
176
|
+
*
|
|
177
|
+
* **Type lattice:**
|
|
178
|
+
* ```text
|
|
179
|
+
* bool -> uint32 -> int32 -> float16 -> float32
|
|
180
|
+
* weakType --^
|
|
181
|
+
* ```
|
|
182
|
+
*
|
|
183
|
+
* `weakType` represents weakly typed arrays. These are created for JS numbers,
|
|
184
|
+
* which default to float32 but are "weak", so they cast to the dtype of any array
|
|
185
|
+
* they are first combined with, except `bool`.
|
|
186
|
+
*
|
|
187
|
+
* **Examples:**
|
|
188
|
+
* - `promoteTypes(bool, int32) → int32`
|
|
189
|
+
* - `promoteTypes(uint32, int32) → int32`
|
|
190
|
+
* - `promoteTypes(int32, float16) → float16`
|
|
191
|
+
* - `promoteTypes(float16, float32) → float32`
|
|
192
|
+
* - `promoteTypes(uint32, float32) → float32`
|
|
193
|
+
*/
|
|
194
|
+
declare function promoteTypes(dtype1: DType, dtype2: DType): DType;
|
|
157
195
|
/**
|
|
158
196
|
* Mathematical expression on scalar values.
|
|
159
197
|
*
|
|
@@ -177,6 +215,8 @@ declare class AluExp implements FpHashable {
|
|
|
177
215
|
static max(a: AluExp, b: AluExp): AluExp;
|
|
178
216
|
static sin(a: AluExp): AluExp;
|
|
179
217
|
static cos(a: AluExp): AluExp;
|
|
218
|
+
static asin(a: AluExp): AluExp;
|
|
219
|
+
static atan(a: AluExp): AluExp;
|
|
180
220
|
static exp(a: AluExp): AluExp;
|
|
181
221
|
static log(a: AluExp): AluExp;
|
|
182
222
|
static sqrt(a: AluExp): AluExp;
|
|
@@ -262,6 +302,8 @@ declare enum AluOp {
|
|
|
262
302
|
Max = "Max",
|
|
263
303
|
Sin = "Sin",
|
|
264
304
|
Cos = "Cos",
|
|
305
|
+
Asin = "Asin",
|
|
306
|
+
Atan = "Atan",
|
|
265
307
|
Exp = "Exp",
|
|
266
308
|
Log = "Log",
|
|
267
309
|
Sqrt = "Sqrt",
|
|
@@ -354,8 +396,8 @@ declare class Reduction implements FpHashable {
|
|
|
354
396
|
//#region src/backend.d.ts
|
|
355
397
|
type Device = "cpu" | "wasm" | "webgpu";
|
|
356
398
|
declare const devices: Device[];
|
|
357
|
-
/**
|
|
358
|
-
declare function
|
|
399
|
+
/** Configure the default device for arrays. */
|
|
400
|
+
declare function defaultDevice(device?: Device): Device;
|
|
359
401
|
/**
|
|
360
402
|
* Initialize `jax-js` library backends.
|
|
361
403
|
*
|
|
@@ -491,6 +533,8 @@ declare enum Primitive {
|
|
|
491
533
|
RandomBits = "random_bits",
|
|
492
534
|
Sin = "sin",
|
|
493
535
|
Cos = "cos",
|
|
536
|
+
Asin = "asin",
|
|
537
|
+
Atan = "atan",
|
|
494
538
|
Exp = "exp",
|
|
495
539
|
Log = "log",
|
|
496
540
|
Sqrt = "sqrt",
|
|
@@ -566,6 +610,7 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
566
610
|
outDim: number;
|
|
567
611
|
};
|
|
568
612
|
[Primitive.JitCall]: {
|
|
613
|
+
name: string;
|
|
569
614
|
jaxpr: Jaxpr;
|
|
570
615
|
numConsts: number;
|
|
571
616
|
};
|
|
@@ -581,8 +626,10 @@ declare enum CompareOp {
|
|
|
581
626
|
LessEqual = "less_equal",
|
|
582
627
|
}
|
|
583
628
|
/** @inline */
|
|
629
|
+
type Axis = number | number[] | null;
|
|
630
|
+
/** @inline */
|
|
584
631
|
type ReduceOpts = {
|
|
585
|
-
|
|
632
|
+
keepdims?: boolean;
|
|
586
633
|
};
|
|
587
634
|
type MainTrace = {
|
|
588
635
|
level: number;
|
|
@@ -602,10 +649,40 @@ declare abstract class Trace {
|
|
|
602
649
|
abstract lift(val: Tracer): Tracer;
|
|
603
650
|
abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
|
|
604
651
|
}
|
|
652
|
+
/** Internal representation of an array value. */
|
|
605
653
|
interface AbstractValue {
|
|
654
|
+
/** Shape of the array. Must be a static tuple of non-negative dimensions. */
|
|
606
655
|
shape: number[];
|
|
656
|
+
/** Concrete data type of array elements. */
|
|
607
657
|
dtype: DType;
|
|
658
|
+
/**
|
|
659
|
+
* Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
|
|
660
|
+
* _weakly typed_ unless a dtype is explicitly specified.
|
|
661
|
+
*
|
|
662
|
+
* Weakly typed values will automatically cast to the data type of other
|
|
663
|
+
* arrays when used as an operand in an expression. This property only affects
|
|
664
|
+
* how they promote in type casting; their memory layout is still determined
|
|
665
|
+
* by the actual `dtype` field.
|
|
666
|
+
*
|
|
667
|
+
* ```ts
|
|
668
|
+
* const x = np.array(3); // weakType = true, dtype = float32
|
|
669
|
+
* const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
|
|
670
|
+
* const z = x.add(y); // z has dtype int32 because x is weakly typed
|
|
671
|
+
* ```
|
|
672
|
+
*
|
|
673
|
+
* Weak types are present in JIT programs in their spec (e.g., Jaxpr inputs
|
|
674
|
+
* and outputs can be weakly typed) form. But they're solely a frontend
|
|
675
|
+
* concept. Backends are not aware of weak types.
|
|
676
|
+
*/
|
|
677
|
+
weakType: boolean;
|
|
608
678
|
}
|
|
679
|
+
/**
|
|
680
|
+
* Broadcast shapes and promote types with casting for two avals.
|
|
681
|
+
*
|
|
682
|
+
* This implements the weak type behavior described in `promoteTypes()`, but not
|
|
683
|
+
* implemented in that function as `weakType` is not passed.
|
|
684
|
+
*/
|
|
685
|
+
|
|
609
686
|
declare abstract class Tracer {
|
|
610
687
|
/** @ignore */
|
|
611
688
|
readonly _trace: Trace;
|
|
@@ -659,9 +736,20 @@ declare abstract class Tracer {
|
|
|
659
736
|
* ```
|
|
660
737
|
*/
|
|
661
738
|
abstract dispose(): void;
|
|
739
|
+
/** The shape of the array. */
|
|
662
740
|
get shape(): number[];
|
|
741
|
+
/** The total number of elements in the array. */
|
|
663
742
|
get size(): number;
|
|
743
|
+
/** The dtype of elements stored in the array. */
|
|
664
744
|
get dtype(): DType;
|
|
745
|
+
/**
|
|
746
|
+
* Whether the array is weakly typed.
|
|
747
|
+
*
|
|
748
|
+
* Weakly typed arrays will cast to the dtype of the other operand. See
|
|
749
|
+
* `promoteTypes()` for details.
|
|
750
|
+
*/
|
|
751
|
+
get weakType(): boolean;
|
|
752
|
+
/** The number of dimensions of the array. */
|
|
665
753
|
get ndim(): number;
|
|
666
754
|
/** @ignore */
|
|
667
755
|
fullLower(): Tracer;
|
|
@@ -675,11 +763,11 @@ declare abstract class Tracer {
|
|
|
675
763
|
greaterEqual(other: this | TracerValue): this;
|
|
676
764
|
lessEqual(other: this | TracerValue): this;
|
|
677
765
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
678
|
-
sum(axis?:
|
|
766
|
+
sum(axis?: Axis, opts?: ReduceOpts): this;
|
|
679
767
|
/** Product of the array elements over a given axis. */
|
|
680
|
-
prod(axis?:
|
|
768
|
+
prod(axis?: Axis, opts?: ReduceOpts): this;
|
|
681
769
|
/** Compute the average of the array elements along the specified axis. */
|
|
682
|
-
mean(axis?:
|
|
770
|
+
mean(axis?: Axis, opts?: ReduceOpts): this;
|
|
683
771
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
684
772
|
transpose(perm?: number[]): this;
|
|
685
773
|
/**
|
|
@@ -752,7 +840,8 @@ declare abstract class Tracer {
|
|
|
752
840
|
declare class ShapedArray implements AbstractValue {
|
|
753
841
|
readonly shape: number[];
|
|
754
842
|
readonly dtype: DType;
|
|
755
|
-
|
|
843
|
+
readonly weakType: boolean;
|
|
844
|
+
constructor(shape: number[], dtype: DType, weakType: boolean);
|
|
756
845
|
static fromAval(aval: AbstractValue): ShapedArray;
|
|
757
846
|
get ndim(): number;
|
|
758
847
|
toString(): string;
|
|
@@ -788,6 +877,14 @@ type DTypeAndDevice = {
|
|
|
788
877
|
dtype?: DType;
|
|
789
878
|
device?: Device;
|
|
790
879
|
};
|
|
880
|
+
type ArrayConstructorArgs = {
|
|
881
|
+
source: AluExp | Slot;
|
|
882
|
+
st: ShapeTracker;
|
|
883
|
+
dtype: DType;
|
|
884
|
+
weakType: boolean;
|
|
885
|
+
backend: Backend;
|
|
886
|
+
pending?: Iterable<PendingExecute>;
|
|
887
|
+
};
|
|
791
888
|
/**
|
|
792
889
|
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
793
890
|
*
|
|
@@ -807,7 +904,7 @@ declare class Array extends Tracer {
|
|
|
807
904
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
808
905
|
* will be freed when the array is disposed.
|
|
809
906
|
*/
|
|
810
|
-
constructor(
|
|
907
|
+
constructor(args: ArrayConstructorArgs);
|
|
811
908
|
/** @ignore */
|
|
812
909
|
get aval(): ShapedArray;
|
|
813
910
|
/** Return a simple string representation of the array's dimensions. */
|
|
@@ -836,8 +933,11 @@ declare class Array extends Tracer {
|
|
|
836
933
|
*
|
|
837
934
|
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
838
935
|
* dispatch of operations as well.
|
|
936
|
+
*
|
|
937
|
+
* **Note:** `jax.blockUntilReady()` is a higher-level API; it calls this
|
|
938
|
+
* asynchronously for multiple arrays.
|
|
839
939
|
*/
|
|
840
|
-
|
|
940
|
+
blockUntilReady(): Promise<Array>;
|
|
841
941
|
/**
|
|
842
942
|
* Realize the array and return it as data. This is a sync variant and not
|
|
843
943
|
* recommended for performance reasons, as it will block rendering.
|
|
@@ -866,11 +966,6 @@ declare class Array extends Tracer {
|
|
|
866
966
|
static _implRules(): typeof implRules;
|
|
867
967
|
_realizeSource(): number;
|
|
868
968
|
}
|
|
869
|
-
/** Construct an array from a single scalar constant. */
|
|
870
|
-
declare function scalar(value: number | boolean, {
|
|
871
|
-
dtype,
|
|
872
|
-
device
|
|
873
|
-
}?: DTypeAndDevice): Array;
|
|
874
969
|
/** Constructor for creating a new array from data. */
|
|
875
970
|
declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
876
971
|
shape,
|
|
@@ -908,7 +1003,7 @@ declare function eye(numRows: number, numCols?: number, {
|
|
|
908
1003
|
dtype,
|
|
909
1004
|
device
|
|
910
1005
|
}?: DTypeAndDevice): Array;
|
|
911
|
-
/** Return the identity
|
|
1006
|
+
/** Return the identity matrix, with ones on the main diagonal. */
|
|
912
1007
|
declare function identity$1(n: number, {
|
|
913
1008
|
dtype,
|
|
914
1009
|
device
|
|
@@ -945,7 +1040,7 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
945
1040
|
device
|
|
946
1041
|
}?: DTypeAndDevice): Array;
|
|
947
1042
|
declare namespace numpy_d_exports {
|
|
948
|
-
export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack, concatenate, cos, cosh, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, identity$1 as identity, inf, int32, less, lessEqual, linspace, log, log10, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, pad, permuteDims, pi, prod, ravel, reciprocal,
|
|
1043
|
+
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
949
1044
|
}
|
|
950
1045
|
declare const float32 = DType.Float32;
|
|
951
1046
|
declare const int32 = DType.Int32;
|
|
@@ -962,54 +1057,66 @@ declare const inf: number;
|
|
|
962
1057
|
declare const nan: number;
|
|
963
1058
|
/** This is Pi, `π = 3.14159265358979...` */
|
|
964
1059
|
declare const pi: number;
|
|
965
|
-
/** Element-wise addition, with broadcasting. */
|
|
1060
|
+
/** @function Element-wise addition, with broadcasting. */
|
|
966
1061
|
declare const add: (x: ArrayLike, y: ArrayLike) => Array;
|
|
967
|
-
/** Element-wise multiplication, with broadcasting. */
|
|
1062
|
+
/** @function Element-wise multiplication, with broadcasting. */
|
|
968
1063
|
declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
|
|
969
|
-
/** Numerical negative of every element of an array. */
|
|
1064
|
+
/** @function Numerical negative of every element of an array. */
|
|
970
1065
|
declare const negative: (x: ArrayLike) => Array;
|
|
971
|
-
/** Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
1066
|
+
/** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
972
1067
|
declare const reciprocal: (x: ArrayLike) => Array;
|
|
973
|
-
/** Element-wise sine function (takes radians). */
|
|
1068
|
+
/** @function Element-wise sine function (takes radians). */
|
|
974
1069
|
declare const sin: (x: ArrayLike) => Array;
|
|
975
|
-
/** Element-wise cosine function (takes radians). */
|
|
1070
|
+
/** @function Element-wise cosine function (takes radians). */
|
|
976
1071
|
declare const cos: (x: ArrayLike) => Array;
|
|
977
|
-
/**
|
|
1072
|
+
/** @function Element-wise inverse sine function (inverse of sin). */
|
|
1073
|
+
declare const asin: (x: ArrayLike) => Array;
|
|
1074
|
+
/** @function Element-wise inverse tangent function (inverse of tan). */
|
|
1075
|
+
declare const atan: (x: ArrayLike) => Array;
|
|
1076
|
+
/** @function Calculate the exponential of all elements in the input array. */
|
|
978
1077
|
declare const exp: (x: ArrayLike) => Array;
|
|
979
|
-
/** Calculate the natural logarithm of all elements in the input array. */
|
|
1078
|
+
/** @function Calculate the natural logarithm of all elements in the input array. */
|
|
980
1079
|
declare const log: (x: ArrayLike) => Array;
|
|
981
|
-
/** Calculate the square root of all elements in the input array. */
|
|
1080
|
+
/** @function Calculate the square root of all elements in the input array. */
|
|
982
1081
|
declare const sqrt: (x: ArrayLike) => Array;
|
|
983
|
-
/** Return element-wise minimum of the input arrays. */
|
|
1082
|
+
/** @function Return element-wise minimum of the input arrays. */
|
|
984
1083
|
declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
985
|
-
/** Return element-wise maximum of the input arrays. */
|
|
1084
|
+
/** @function Return element-wise maximum of the input arrays. */
|
|
986
1085
|
declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
987
|
-
/** Compare two arrays element-wise. */
|
|
1086
|
+
/** @function Compare two arrays element-wise. */
|
|
988
1087
|
declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
|
|
989
|
-
/** Compare two arrays element-wise. */
|
|
1088
|
+
/** @function Compare two arrays element-wise. */
|
|
990
1089
|
declare const less: (x: ArrayLike, y: ArrayLike) => Array;
|
|
991
|
-
/** Compare two arrays element-wise. */
|
|
1090
|
+
/** @function Compare two arrays element-wise. */
|
|
992
1091
|
declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
|
|
993
|
-
/** Compare two arrays element-wise. */
|
|
1092
|
+
/** @function Compare two arrays element-wise. */
|
|
994
1093
|
declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
995
|
-
/** Compare two arrays element-wise. */
|
|
1094
|
+
/** @function Compare two arrays element-wise. */
|
|
996
1095
|
declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
997
|
-
/** Compare two arrays element-wise. */
|
|
1096
|
+
/** @function Compare two arrays element-wise. */
|
|
998
1097
|
declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
999
|
-
/** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1098
|
+
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1000
1099
|
declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
|
|
1001
|
-
/**
|
|
1100
|
+
/**
|
|
1101
|
+
* @function
|
|
1102
|
+
* Permute the dimensions of an array. Defaults to reversing the axis order.
|
|
1103
|
+
*/
|
|
1002
1104
|
declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
|
|
1003
1105
|
/**
|
|
1106
|
+
* @function
|
|
1004
1107
|
* Give a new shape to an array without changing its data.
|
|
1005
1108
|
*
|
|
1006
1109
|
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
1007
1110
|
* length of the array and remaining dimensions.
|
|
1008
1111
|
*/
|
|
1009
1112
|
declare const reshape: (x: ArrayLike, shape: number[]) => Array;
|
|
1010
|
-
/**
|
|
1113
|
+
/**
|
|
1114
|
+
* @function
|
|
1115
|
+
* Move axes of an array to new positions. Other axes retain original order.
|
|
1116
|
+
*/
|
|
1011
1117
|
declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
1012
1118
|
/**
|
|
1119
|
+
* @function
|
|
1013
1120
|
* Add padding (zeros) to an array.
|
|
1014
1121
|
*
|
|
1015
1122
|
* The `width` argument is either an integer or pair of integers, in which case
|
|
@@ -1017,15 +1124,27 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
|
1017
1124
|
* pair specifies the padding for its corresponding axis.
|
|
1018
1125
|
*/
|
|
1019
1126
|
declare const pad: (x: ArrayLike, width: number | [number, number] | [number, number][]) => Array;
|
|
1020
|
-
/**
|
|
1127
|
+
/**
|
|
1128
|
+
* @function
|
|
1129
|
+
* Return the number of dimensions of an array. Does not consume array reference.
|
|
1130
|
+
*/
|
|
1021
1131
|
declare const ndim: (x: ArrayLike) => number;
|
|
1022
|
-
/** Return the shape of an array. Does not consume array reference. */
|
|
1132
|
+
/** @function Return the shape of an array. Does not consume array reference. */
|
|
1023
1133
|
declare const shape$1: (x: ArrayLike) => number[];
|
|
1024
|
-
/**
|
|
1134
|
+
/**
|
|
1135
|
+
* @function
|
|
1136
|
+
* Return an array of zeros with the same shape and type as a given array.
|
|
1137
|
+
*/
|
|
1025
1138
|
declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1026
|
-
/**
|
|
1139
|
+
/**
|
|
1140
|
+
* @function
|
|
1141
|
+
* Return an array of ones with the same shape and type as a given array.
|
|
1142
|
+
*/
|
|
1027
1143
|
declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1028
|
-
/**
|
|
1144
|
+
/**
|
|
1145
|
+
* @function
|
|
1146
|
+
* Return a full array with the same shape and type as a given array.
|
|
1147
|
+
*/
|
|
1029
1148
|
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
|
|
1030
1149
|
/**
|
|
1031
1150
|
* Return the number of elements in an array, optionally along an axis.
|
|
@@ -1035,15 +1154,15 @@ declare function size(a: ArrayLike, axis?: number): number;
|
|
|
1035
1154
|
/** Convert an array to a specified dtype. */
|
|
1036
1155
|
declare function astype(a: ArrayLike, dtype: DType): Array;
|
|
1037
1156
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
1038
|
-
declare function sum(a: ArrayLike, axis?:
|
|
1157
|
+
declare function sum(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1039
1158
|
/** Product of the array elements over a given axis. */
|
|
1040
|
-
declare function prod(a: ArrayLike, axis?:
|
|
1159
|
+
declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1041
1160
|
/** Return the minimum of array elements along a given axis. */
|
|
1042
|
-
declare function min(a: ArrayLike, axis?:
|
|
1161
|
+
declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1043
1162
|
/** Return the maximum of array elements along a given axis. */
|
|
1044
|
-
declare function max(a: ArrayLike, axis?:
|
|
1163
|
+
declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1045
1164
|
/** Compute the average of the array elements along the specified axis. */
|
|
1046
|
-
declare function mean(a: ArrayLike, axis?:
|
|
1165
|
+
declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1047
1166
|
/**
|
|
1048
1167
|
* Returns the indices of the minimum values along an axis.
|
|
1049
1168
|
*
|
|
@@ -1059,7 +1178,7 @@ declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
|
1059
1178
|
*/
|
|
1060
1179
|
declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
1061
1180
|
/** Reverse the elements in an array along the given axes. */
|
|
1062
|
-
declare function flip(x: ArrayLike, axis?:
|
|
1181
|
+
declare function flip(x: ArrayLike, axis?: Axis): Array;
|
|
1063
1182
|
/**
|
|
1064
1183
|
* Join a sequence of arrays along an existing axis.
|
|
1065
1184
|
*
|
|
@@ -1103,9 +1222,36 @@ declare function columnStack(xs: ArrayLike[]): Array;
|
|
|
1103
1222
|
declare function flipud(x: ArrayLike): Array;
|
|
1104
1223
|
/** Flip an array horizontally (axis=1). */
|
|
1105
1224
|
declare function fliplr(x: ArrayLike): Array;
|
|
1225
|
+
/** @function Alternative name for `numpy.transpose()`. */
|
|
1106
1226
|
declare const permuteDims: (x: ArrayLike, perm?: number[]) => Array;
|
|
1107
1227
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
1108
1228
|
declare function ravel(a: ArrayLike): Array;
|
|
1229
|
+
/**
|
|
1230
|
+
* Repeat each element of an array after itself.
|
|
1231
|
+
*
|
|
1232
|
+
* If no axis is provided, use the flattened input array, and return a flat
|
|
1233
|
+
* output array.
|
|
1234
|
+
*/
|
|
1235
|
+
declare function repeat(a: ArrayLike, repeats: number, axis?: number): Array;
|
|
1236
|
+
/**
|
|
1237
|
+
* Construct an array by repeating A the number of times given by reps.
|
|
1238
|
+
*
|
|
1239
|
+
* If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
|
|
1240
|
+
* integers, the resulting array will have a shape of `(reps[0] * d1,
|
|
1241
|
+
* reps[1] * d2, ..., reps[n-1] * dn)`, with `A` tiled along each dimension.
|
|
1242
|
+
*/
|
|
1243
|
+
declare function tile(a: ArrayLike, reps: number | number[]): Array;
|
|
1244
|
+
/**
|
|
1245
|
+
* Broadcast an array to a shape, with NumPy-style broadcasting rules.
|
|
1246
|
+
*
|
|
1247
|
+
* In other words, this lets you prepend axes on the left, and/or expand
|
|
1248
|
+
* dimensions where the shape is 1.
|
|
1249
|
+
*/
|
|
1250
|
+
declare function broadcastTo(a: ArrayLike, shape: number[]): Array;
|
|
1251
|
+
/** Broadcast input shapes to a common output shape. */
|
|
1252
|
+
declare function broadcastShapes(...shapes: number[][]): number[];
|
|
1253
|
+
/** Broadcast arrays to a common shape. */
|
|
1254
|
+
declare function broadcastArrays(...arrays: ArrayLike[]): Array[];
|
|
1109
1255
|
/**
|
|
1110
1256
|
* Return specified diagonals.
|
|
1111
1257
|
*
|
|
@@ -1133,8 +1279,28 @@ declare function allclose(actual: Parameters<typeof array>[0], expected: Paramet
|
|
|
1133
1279
|
declare function matmul(x: ArrayLike, y: ArrayLike): Array;
|
|
1134
1280
|
/** Dot product of two arrays. */
|
|
1135
1281
|
declare function dot(x: ArrayLike, y: ArrayLike): Array;
|
|
1136
|
-
/**
|
|
1137
|
-
|
|
1282
|
+
/**
|
|
1283
|
+
* Compute the inner product of two arrays.
|
|
1284
|
+
*
|
|
1285
|
+
* Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
|
|
1286
|
+
* contraction on the last axis.
|
|
1287
|
+
*
|
|
1288
|
+
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
1289
|
+
*/
|
|
1290
|
+
declare function inner(x: ArrayLike, y: ArrayLike): Array;
|
|
1291
|
+
/**
|
|
1292
|
+
* Compute the outer product of two arrays.
|
|
1293
|
+
*
|
|
1294
|
+
* If the input arrays are not 1D, they will be flattened. Returned array will
|
|
1295
|
+
* be of shape `[x.size, y.size]`.
|
|
1296
|
+
*/
|
|
1297
|
+
declare function outer(x: ArrayLike, y: ArrayLike): Array;
|
|
1298
|
+
/** Vector dot product of two arrays along a given axis. */
|
|
1299
|
+
declare function vecdot(x: ArrayLike, y: ArrayLike, {
|
|
1300
|
+
axis
|
|
1301
|
+
}?: {
|
|
1302
|
+
axis?: number;
|
|
1303
|
+
}): Array;
|
|
1138
1304
|
/**
|
|
1139
1305
|
* Return the dot product of two vectors.
|
|
1140
1306
|
*
|
|
@@ -1152,6 +1318,21 @@ declare function meshgrid(xs: Array[], {
|
|
|
1152
1318
|
}?: {
|
|
1153
1319
|
indexing?: "xy" | "ij";
|
|
1154
1320
|
}): Array[];
|
|
1321
|
+
/**
|
|
1322
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
1323
|
+
*
|
|
1324
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
1325
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
1326
|
+
* `k>0` is above it.
|
|
1327
|
+
*/
|
|
1328
|
+
declare function tri(n: number, m?: number, k?: number, {
|
|
1329
|
+
dtype,
|
|
1330
|
+
device
|
|
1331
|
+
}?: DTypeAndDevice): Array;
|
|
1332
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
1333
|
+
declare function tril(a: ArrayLike, k?: number): Array;
|
|
1334
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
1335
|
+
declare function triu(a: ArrayLike, k?: number): Array;
|
|
1155
1336
|
/**
|
|
1156
1337
|
* Clip (limit) the values in an array.
|
|
1157
1338
|
*
|
|
@@ -1168,15 +1349,50 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
|
|
|
1168
1349
|
* This is the same function as `jax.numpy.abs()`.
|
|
1169
1350
|
*/
|
|
1170
1351
|
declare function absolute(x: ArrayLike): Array;
|
|
1171
|
-
/** Alias of `jax.numpy.absolute()`. */
|
|
1352
|
+
/** @function Alias of `jax.numpy.absolute()`. */
|
|
1172
1353
|
declare const abs: typeof absolute;
|
|
1354
|
+
/** Return an element-wise indication of sign of the input. */
|
|
1355
|
+
declare function sign(x: ArrayLike): Array;
|
|
1173
1356
|
/** Calculate element-wise square of the input array. */
|
|
1174
1357
|
declare function square(x: ArrayLike): Array;
|
|
1175
|
-
/**
|
|
1358
|
+
/** Element-wise tangent function (takes radians). */
|
|
1176
1359
|
declare function tan(x: ArrayLike): Array;
|
|
1360
|
+
/** Element-wise inverse cosine function (inverse of cos). */
|
|
1361
|
+
declare function acos(x: ArrayLike): Array;
|
|
1362
|
+
/**
|
|
1363
|
+
* @function
|
|
1364
|
+
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
1365
|
+
*
|
|
1366
|
+
* In the original NumPy/JAX implementation, this function is more numerically
|
|
1367
|
+
* stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
|
|
1368
|
+
* improvements.
|
|
1369
|
+
*/
|
|
1370
|
+
declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1371
|
+
/**
|
|
1372
|
+
* @function
|
|
1373
|
+
* Element-wise arc tangent of y/x with correct quadrant.
|
|
1374
|
+
*
|
|
1375
|
+
* Returns the angle in radians between the positive x-axis and the point (x, y).
|
|
1376
|
+
* The result is in the range [-π, π].
|
|
1377
|
+
*
|
|
1378
|
+
* Uses numerically stable formulas:
|
|
1379
|
+
* - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
|
|
1380
|
+
* - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
|
|
1381
|
+
*
|
|
1382
|
+
* The output is ill-defined when both x and y are zero.
|
|
1383
|
+
*/
|
|
1384
|
+
declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
1385
|
+
/** @function Alias of `jax.numpy.acos()`. */
|
|
1386
|
+
declare const arccos: typeof acos;
|
|
1387
|
+
/** @function Alias of `jax.numpy.atan()`. */
|
|
1388
|
+
declare const arctan: (x: ArrayLike) => Array;
|
|
1389
|
+
/** @function Alias of `jax.numpy.atan2()`. */
|
|
1390
|
+
declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
1391
|
+
/** Element-wise subtraction, with broadcasting. */
|
|
1392
|
+
declare function subtract(x: ArrayLike, y: ArrayLike): Array;
|
|
1177
1393
|
/** Calculates the floating-point division of x by y element-wise. */
|
|
1178
1394
|
declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
|
|
1179
|
-
/** Alias of `jax.numpy.trueDivide()`. */
|
|
1395
|
+
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
1180
1396
|
declare const divide: typeof trueDivide;
|
|
1181
1397
|
/** Round input to the nearest integer towards zero. */
|
|
1182
1398
|
declare function trunc(x: ArrayLike): Array;
|
|
@@ -1186,26 +1402,112 @@ declare function exp2(p: ArrayLike): Array;
|
|
|
1186
1402
|
declare function log2(x: ArrayLike): Array;
|
|
1187
1403
|
/** Return the base-10 logarithm of x, element-wise. */
|
|
1188
1404
|
declare function log10(x: ArrayLike): Array;
|
|
1405
|
+
/** Calculate `exp(x) - 1` element-wise. */
|
|
1406
|
+
declare function expm1(x: ArrayLike): Array;
|
|
1407
|
+
/** Calculate the natural logarithm of `1 + x` element-wise. */
|
|
1408
|
+
declare function log1p(x: ArrayLike): Array;
|
|
1409
|
+
/** Convert angles from degrees to radians. */
|
|
1410
|
+
declare function deg2rad(x: ArrayLike): Array;
|
|
1411
|
+
/** @function Alias of `jax.numpy.deg2rad()`. */
|
|
1412
|
+
declare const radians: typeof deg2rad;
|
|
1413
|
+
/** Convert angles from radians to degrees. */
|
|
1414
|
+
declare function rad2deg(x: ArrayLike): Array;
|
|
1415
|
+
/** @function Alias of `jax.numpy.rad2deg()`. */
|
|
1416
|
+
declare const degrees: typeof rad2deg;
|
|
1417
|
+
/**
|
|
1418
|
+
* @function
|
|
1419
|
+
* Computes first array raised to power of second array, element-wise.
|
|
1420
|
+
*/
|
|
1421
|
+
declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1422
|
+
/** @function Alias of `jax.numpy.power()`. */
|
|
1423
|
+
declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1424
|
+
/** @function Calculate the element-wise cube root of the input array. */
|
|
1425
|
+
declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1189
1426
|
/**
|
|
1427
|
+
* @function
|
|
1190
1428
|
* Calculate element-wise hyperbolic sine of input.
|
|
1191
1429
|
*
|
|
1192
1430
|
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
1193
1431
|
*/
|
|
1194
|
-
declare
|
|
1432
|
+
declare const sinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1195
1433
|
/**
|
|
1434
|
+
* @function
|
|
1196
1435
|
* Calculate element-wise hyperbolic cosine of input.
|
|
1197
1436
|
*
|
|
1198
1437
|
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
1199
1438
|
*/
|
|
1200
|
-
declare
|
|
1439
|
+
declare const cosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1201
1440
|
/**
|
|
1441
|
+
* @function
|
|
1202
1442
|
* Calculate element-wise hyperbolic tangent of input.
|
|
1203
1443
|
*
|
|
1204
1444
|
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
1205
1445
|
*/
|
|
1206
|
-
declare
|
|
1446
|
+
declare const tanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1447
|
+
/**
|
|
1448
|
+
* @function
|
|
1449
|
+
* Calculate element-wise inverse hyperbolic sine of input.
|
|
1450
|
+
*
|
|
1451
|
+
* `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
|
|
1452
|
+
*/
|
|
1453
|
+
declare const arcsinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1454
|
+
/**
|
|
1455
|
+
* @function
|
|
1456
|
+
* Calculate element-wise inverse hyperbolic cosine of input.
|
|
1457
|
+
*
|
|
1458
|
+
* `arccosh(x) = ln(x + sqrt(x^2 - 1))`
|
|
1459
|
+
*/
|
|
1460
|
+
declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1461
|
+
/**
|
|
1462
|
+
* @function
|
|
1463
|
+
* Calculate element-wise inverse hyperbolic tangent of input.
|
|
1464
|
+
*
|
|
1465
|
+
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
1466
|
+
*/
|
|
1467
|
+
declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1468
|
+
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
1469
|
+
declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1470
|
+
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
1471
|
+
declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1472
|
+
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
1473
|
+
declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1474
|
+
/**
|
|
1475
|
+
* Compute the variance of an array.
|
|
1476
|
+
*
|
|
1477
|
+
* The variance is computed for the flattened array by default, otherwise over
|
|
1478
|
+
* the specified axis.
|
|
1479
|
+
*
|
|
1480
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
1481
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
1482
|
+
*/
|
|
1483
|
+
declare function var_(x: ArrayLike, axis?: Axis, opts?: {
|
|
1484
|
+
mean?: ArrayLike;
|
|
1485
|
+
correction?: number;
|
|
1486
|
+
} & ReduceOpts): Array;
|
|
1487
|
+
/**
|
|
1488
|
+
* Compute the standard deviation of an array.
|
|
1489
|
+
*
|
|
1490
|
+
* The standard deviation is computed for the flattened array by default,
|
|
1491
|
+
* otherwise over the specified axis.
|
|
1492
|
+
*
|
|
1493
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
1494
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
1495
|
+
*/
|
|
1496
|
+
declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
1497
|
+
mean?: ArrayLike;
|
|
1498
|
+
correction?: number;
|
|
1499
|
+
} & ReduceOpts): Array;
|
|
1207
1500
|
//#endregion
|
|
1208
1501
|
//#region src/frontend/jaxpr.d.ts
|
|
1502
|
+
/**
|
|
1503
|
+
* Function callback with an associated dispose() method.
|
|
1504
|
+
*
|
|
1505
|
+
* The dispose() method should be called to clean up any tracer resources needed
|
|
1506
|
+
* by the function after the last time it is called.
|
|
1507
|
+
*/
|
|
1508
|
+
type OwnedFunction<F extends Function> = F & {
|
|
1509
|
+
dispose: () => void;
|
|
1510
|
+
};
|
|
1209
1511
|
/** Variable in a Jaxpr expression. */
|
|
1210
1512
|
declare class Var {
|
|
1211
1513
|
#private;
|
|
@@ -1216,10 +1518,10 @@ declare class Var {
|
|
|
1216
1518
|
}
|
|
1217
1519
|
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
1218
1520
|
declare class Lit {
|
|
1219
|
-
readonly dtype: DType;
|
|
1220
1521
|
readonly value: number;
|
|
1221
1522
|
readonly aval: ShapedArray;
|
|
1222
|
-
|
|
1523
|
+
get dtype(): DType;
|
|
1524
|
+
constructor(aval: AbstractValue, value: number);
|
|
1223
1525
|
}
|
|
1224
1526
|
type Atom = Var | Lit;
|
|
1225
1527
|
declare class VarPrinter {
|
|
@@ -1297,7 +1599,7 @@ declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding:
|
|
|
1297
1599
|
/** Reduce a computation over padded windows. */
|
|
1298
1600
|
declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
|
|
1299
1601
|
declare namespace nn_d_exports {
|
|
1300
|
-
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
|
|
1602
|
+
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
|
|
1301
1603
|
}
|
|
1302
1604
|
/**
|
|
1303
1605
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -1329,6 +1631,7 @@ declare function softplus(x: ArrayLike): Array;
|
|
|
1329
1631
|
*/
|
|
1330
1632
|
declare function softSign(x: ArrayLike): Array;
|
|
1331
1633
|
/**
|
|
1634
|
+
* @function
|
|
1332
1635
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
1333
1636
|
* Swish, computed element-wise:
|
|
1334
1637
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -1337,8 +1640,9 @@ declare function softSign(x: ArrayLike): Array;
|
|
|
1337
1640
|
*
|
|
1338
1641
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1339
1642
|
*/
|
|
1340
|
-
declare const silu: (x: ArrayLike) => Array
|
|
1643
|
+
declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1341
1644
|
/**
|
|
1645
|
+
* @function
|
|
1342
1646
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
1343
1647
|
* Swish, computed element-wise:
|
|
1344
1648
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -1347,31 +1651,35 @@ declare const silu: (x: ArrayLike) => Array;
|
|
|
1347
1651
|
*
|
|
1348
1652
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1349
1653
|
*/
|
|
1350
|
-
declare const swish: (x: ArrayLike) => Array
|
|
1654
|
+
declare const swish: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1351
1655
|
/**
|
|
1352
1656
|
* Log-sigmoid activation function, computed element-wise:
|
|
1353
1657
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
1354
1658
|
*/
|
|
1355
1659
|
declare function logSigmoid(x: ArrayLike): Array;
|
|
1356
|
-
/**
|
|
1660
|
+
/**
|
|
1661
|
+
* @function
|
|
1662
|
+
* Identity activation function. Returns the argument unmodified.
|
|
1663
|
+
*/
|
|
1357
1664
|
declare const identity: (x: ArrayLike) => Array;
|
|
1358
1665
|
/** Leaky rectified linear (ReLU) activation function */
|
|
1359
|
-
declare function leakyRelu(x: ArrayLike, negativeSlope?:
|
|
1666
|
+
declare function leakyRelu(x: ArrayLike, negativeSlope?: ArrayLike): Array;
|
|
1360
1667
|
/**
|
|
1361
1668
|
* Exponential linear unit activation function.
|
|
1362
1669
|
*
|
|
1363
1670
|
* Computes the element-wise function:
|
|
1364
1671
|
* `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
|
|
1365
1672
|
*/
|
|
1366
|
-
declare function elu(x: ArrayLike, alpha?:
|
|
1673
|
+
declare function elu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
1367
1674
|
/**
|
|
1368
1675
|
* Continuously-differentiable exponential linear unit activation function.
|
|
1369
1676
|
*
|
|
1370
1677
|
* Computes the element-wise function:
|
|
1371
1678
|
* `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
|
|
1372
1679
|
*/
|
|
1373
|
-
declare function celu(x: ArrayLike, alpha?:
|
|
1680
|
+
declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
1374
1681
|
/**
|
|
1682
|
+
* @function
|
|
1375
1683
|
* Gaussion error linear unit (GELU) activation function.
|
|
1376
1684
|
*
|
|
1377
1685
|
* This is computed element-wise. Currently jax-js does not support the erf() or
|
|
@@ -1382,7 +1690,7 @@ declare function celu(x: ArrayLike, alpha?: number): Array;
|
|
|
1382
1690
|
*
|
|
1383
1691
|
* This will be improved in the future.
|
|
1384
1692
|
*/
|
|
1385
|
-
declare const gelu: (x: ArrayLike) => Array
|
|
1693
|
+
declare const gelu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1386
1694
|
/**
|
|
1387
1695
|
* Gated linear unit (GLU) activation function.
|
|
1388
1696
|
*
|
|
@@ -1390,6 +1698,13 @@ declare const gelu: (x: ArrayLike) => Array;
|
|
|
1390
1698
|
* computes `a * sigmoid(b)`.
|
|
1391
1699
|
*/
|
|
1392
1700
|
declare function glu(x: ArrayLike, axis?: number): Array;
|
|
1701
|
+
/**
|
|
1702
|
+
* Squareplus activation function.
|
|
1703
|
+
*
|
|
1704
|
+
* Computes the element-wise function:
|
|
1705
|
+
* `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
|
|
1706
|
+
*/
|
|
1707
|
+
declare function squareplus(x: ArrayLike, b?: ArrayLike): Array;
|
|
1393
1708
|
/**
|
|
1394
1709
|
* Mish activation function.
|
|
1395
1710
|
*
|
|
@@ -1405,7 +1720,7 @@ declare function mish(x: ArrayLike): Array;
|
|
|
1405
1720
|
*
|
|
1406
1721
|
* Reference: https://en.wikipedia.org/wiki/Softmax_function
|
|
1407
1722
|
*/
|
|
1408
|
-
declare function softmax(x: ArrayLike, axis?:
|
|
1723
|
+
declare function softmax(x: ArrayLike, axis?: Axis): Array;
|
|
1409
1724
|
/**
|
|
1410
1725
|
* Log-Softmax function.
|
|
1411
1726
|
*
|
|
@@ -1414,7 +1729,7 @@ declare function softmax(x: ArrayLike, axis?: number | number[]): Array;
|
|
|
1414
1729
|
*
|
|
1415
1730
|
* If `axis` is not specified, it defaults to the last axis.
|
|
1416
1731
|
*/
|
|
1417
|
-
declare function logSoftmax(x: ArrayLike, axis?:
|
|
1732
|
+
declare function logSoftmax(x: ArrayLike, axis?: Axis): Array;
|
|
1418
1733
|
/**
|
|
1419
1734
|
* Log-sum-exp reduction. Also a multivariate version of `softplus`.
|
|
1420
1735
|
*
|
|
@@ -1423,7 +1738,22 @@ declare function logSoftmax(x: ArrayLike, axis?: number | number[]): Array;
|
|
|
1423
1738
|
*
|
|
1424
1739
|
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
1425
1740
|
*/
|
|
1426
|
-
declare function logsumexp(x: ArrayLike, axis?:
|
|
1741
|
+
declare function logsumexp(x: ArrayLike, axis?: Axis): Array;
|
|
1742
|
+
/** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
|
|
1743
|
+
declare function logmeanexp(x: ArrayLike, axis?: Axis): Array;
|
|
1744
|
+
/**
|
|
1745
|
+
* Standardizes input to zero mean and unit variance.
|
|
1746
|
+
*
|
|
1747
|
+
* By default, this is computed over the last axis. You can pass in a different
|
|
1748
|
+
* axis, or `null` to standardize over all elements.
|
|
1749
|
+
*
|
|
1750
|
+
* Epsilon is added to denominator, it defaults to `1e-5` for stability.
|
|
1751
|
+
*/
|
|
1752
|
+
declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
1753
|
+
mean?: ArrayLike;
|
|
1754
|
+
variance?: ArrayLike;
|
|
1755
|
+
epsilon?: ArrayLike;
|
|
1756
|
+
}): Array;
|
|
1427
1757
|
/**
|
|
1428
1758
|
* One-hot encodes the given indices.
|
|
1429
1759
|
*
|
|
@@ -1442,7 +1772,7 @@ declare function logsumexp(x: ArrayLike, axis?: number | number[]): Array;
|
|
|
1442
1772
|
*/
|
|
1443
1773
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
1444
1774
|
declare namespace random_d_exports {
|
|
1445
|
-
export { bits, key, split, uniform };
|
|
1775
|
+
export { bernoulli, bits, exponential, key, normal, split, uniform };
|
|
1446
1776
|
}
|
|
1447
1777
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
1448
1778
|
declare function key(seed: number): Array;
|
|
@@ -1450,34 +1780,71 @@ declare function key(seed: number): Array;
|
|
|
1450
1780
|
declare function split(key: Array, num?: number | number[]): Array;
|
|
1451
1781
|
/** Sample uniform bits in the form of unsigned integers. */
|
|
1452
1782
|
declare function bits(key: Array, shape?: number[]): Array;
|
|
1453
|
-
/**
|
|
1454
|
-
|
|
1455
|
-
|
|
1456
|
-
|
|
1457
|
-
|
|
1458
|
-
minval?: number;
|
|
1459
|
-
maxval?: number;
|
|
1460
|
-
})
|
|
1783
|
+
/**
|
|
1784
|
+
* @function
|
|
1785
|
+
* Sample uniform random values in [minval, maxval) with given shape.
|
|
1786
|
+
*/
|
|
1787
|
+
declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined, args_2?: {
|
|
1788
|
+
minval?: number | undefined;
|
|
1789
|
+
maxval?: number | undefined;
|
|
1790
|
+
} | undefined) => Array>;
|
|
1791
|
+
/**
|
|
1792
|
+
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
1793
|
+
*
|
|
1794
|
+
* Returns a random Boolean array with the specified shape. `p` can be an array
|
|
1795
|
+
* and must be broadcastable to `shape`.
|
|
1796
|
+
*/
|
|
1797
|
+
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
1798
|
+
/**
|
|
1799
|
+
* @function
|
|
1800
|
+
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
1801
|
+
*/
|
|
1802
|
+
declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1803
|
+
/**
|
|
1804
|
+
* @function
|
|
1805
|
+
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
1806
|
+
*
|
|
1807
|
+
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
1808
|
+
* directly inverts the CDF, but we don't have support for that yet. Outputs will not be
|
|
1809
|
+
* bitwise identical to JAX.
|
|
1810
|
+
*/
|
|
1811
|
+
declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1461
1812
|
//#endregion
|
|
1462
1813
|
//#region src/index.d.ts
|
|
1463
|
-
/**
|
|
1814
|
+
/**
|
|
1815
|
+
* @function
|
|
1816
|
+
* Compute the forward-mode Jacobian-vector product for a function.
|
|
1817
|
+
*/
|
|
1464
1818
|
declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
|
|
1465
|
-
/**
|
|
1819
|
+
/**
|
|
1820
|
+
* @function
|
|
1821
|
+
* Vectorize an operation on a batched axis for one or more inputs.
|
|
1822
|
+
*/
|
|
1466
1823
|
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1467
|
-
/**
|
|
1824
|
+
/**
|
|
1825
|
+
* @function
|
|
1826
|
+
* Compute the Jacobian evaluated column-by-column by forward-mode AD.
|
|
1827
|
+
*/
|
|
1468
1828
|
declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1469
|
-
/**
|
|
1829
|
+
/**
|
|
1830
|
+
* @function
|
|
1831
|
+
* Construct a Jaxpr by dynamically tracing a function with example inputs.
|
|
1832
|
+
*/
|
|
1470
1833
|
declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
|
|
1471
1834
|
jaxpr: Jaxpr;
|
|
1472
1835
|
consts: Array[];
|
|
1473
1836
|
treedef: JsTreeDef;
|
|
1474
1837
|
};
|
|
1475
1838
|
/**
|
|
1839
|
+
* @function
|
|
1476
1840
|
* Mark a function for automatic JIT compilation, with operator fusion.
|
|
1477
1841
|
*
|
|
1478
1842
|
* The function will be compiled the first time it is called with a set of
|
|
1479
1843
|
* argument shapes.
|
|
1480
1844
|
*
|
|
1845
|
+
* You can call `.dispose()` on the returned, JIT-compiled function after all
|
|
1846
|
+
* calls to free memory associated with array constants.
|
|
1847
|
+
*
|
|
1481
1848
|
* **Options:**
|
|
1482
1849
|
* - `staticArgnums`: An array of argument indices to treat as static
|
|
1483
1850
|
* (compile-time constant). These arguments must be hashable, won't be traced,
|
|
@@ -1485,24 +1852,48 @@ declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) =>
|
|
|
1485
1852
|
* - `device`: The device to place the computation on. If not specified, the
|
|
1486
1853
|
* computation will be placed on the default device.
|
|
1487
1854
|
*/
|
|
1488
|
-
declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F
|
|
1855
|
+
declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => OwnedFunction<(...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>>;
|
|
1489
1856
|
/**
|
|
1857
|
+
* @function
|
|
1490
1858
|
* Produce a local linear approximation to a function at a point using jvp() and
|
|
1491
1859
|
* partial evaluation.
|
|
1492
1860
|
*/
|
|
1493
1861
|
declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
|
|
1494
|
-
/**
|
|
1862
|
+
/**
|
|
1863
|
+
* @function
|
|
1864
|
+
* Calculate the reverse-mode vector-Jacobian product for a function.
|
|
1865
|
+
*/
|
|
1495
1866
|
declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
|
|
1496
1867
|
/**
|
|
1868
|
+
* @function
|
|
1497
1869
|
* Compute the gradient of a scalar-valued function `f` with respect to its
|
|
1498
1870
|
* first argument.
|
|
1499
1871
|
*/
|
|
1500
1872
|
declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
|
|
1501
|
-
/**
|
|
1873
|
+
/**
|
|
1874
|
+
* @function
|
|
1875
|
+
* Create a function that evaluates both `f` and the gradient of `f`.
|
|
1876
|
+
*/
|
|
1502
1877
|
declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
|
|
1503
|
-
/**
|
|
1878
|
+
/**
|
|
1879
|
+
* @function
|
|
1880
|
+
* Compute the Jacobian evaluated row-by-row by reverse-mode AD.
|
|
1881
|
+
*/
|
|
1504
1882
|
declare const jacrev: typeof jacfwd;
|
|
1505
|
-
/**
|
|
1883
|
+
/**
|
|
1884
|
+
* @function
|
|
1885
|
+
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
1886
|
+
*/
|
|
1506
1887
|
declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1888
|
+
/**
|
|
1889
|
+
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
1890
|
+
*
|
|
1891
|
+
* This can be used to wait for the results of an intermediate computation to
|
|
1892
|
+
* finish. It's recommended to call this regularly in an iterative computation
|
|
1893
|
+
* to avoid queueing up too many pending operations.
|
|
1894
|
+
*
|
|
1895
|
+
* Does not consume reference to the arrays.
|
|
1896
|
+
*/
|
|
1897
|
+
declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
1507
1898
|
//#endregion
|
|
1508
|
-
export { DType, type Device, type JsTree, type JsTreeDef, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random,
|
|
1899
|
+
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|