@jax-js/jax 0.0.3 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.ts CHANGED
@@ -123,6 +123,19 @@ declare class ShapeTracker {
123
123
  }
124
124
  //#endregion
125
125
  //#region src/utils.d.ts
126
+ /**
127
+ * Set the debug level for verbose logging.
128
+ *
129
+ * 1. JIT compile logs
130
+ * 2. Shader code
131
+ * 3. Expressions and metadata
132
+ * 4. JIT programs, tuning details
133
+ * 5. Most verbose operation traces
134
+ *
135
+ * This is an experimental API and may change in behavior. Do not rely on this
136
+ * in production.
137
+ */
138
+ declare function setDebug(level: number): void;
126
139
  /** @inline */
127
140
  type RecursiveArray<T> = T | RecursiveArray<T>[];
128
141
  interface FpHashable {
@@ -138,7 +151,7 @@ interface FpHashable {
138
151
  declare class FpHash {
139
152
  #private;
140
153
  value: bigint;
141
- update(...values: (string | boolean | number | bigint | null | undefined | FpHashable)[]): this;
154
+ update(x: string | boolean | number | bigint | null | undefined | FpHashable): this;
142
155
  static hash(...values: (string | boolean | number | bigint | null | undefined | FpHashable)[]): bigint;
143
156
  }
144
157
  /** Run a function while caching it inline inside a `Map`. */
@@ -154,6 +167,31 @@ declare enum DType {
154
167
  }
155
168
  /** @inline */
156
169
  type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
170
+ /**
171
+ * Promote two dtypes to their join according to the type lattice.
172
+ *
173
+ * When performing operations between arrays of different types, we need to
174
+ * promote both operands to a common type that can represent values from both
175
+ * input types. This follows JAX's type promotion rules.
176
+ *
177
+ * **Type lattice:**
178
+ * ```text
179
+ * bool -> uint32 -> int32 -> float16 -> float32
180
+ * weak f* --^
181
+ * ```
182
+ *
183
+ * Here f* (marked with an asterisk) is a weak type used for JS number constants. When creating
184
+ * arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
185
+ * any array they are first combined with.
186
+ *
187
+ * **Examples:**
188
+ * - `promoteTypes(bool, int32) → int32`
189
+ * - `promoteTypes(uint32, int32) → int32`
190
+ * - `promoteTypes(int32, float16) → float16`
191
+ * - `promoteTypes(float16, float32) → float32`
192
+ * - `promoteTypes(uint32, float32) → float32`
193
+ */
194
+ declare function promoteTypes(dtype1: DType, dtype2: DType): DType;
157
195
  /**
158
196
  * Mathematical expression on scalar values.
159
197
  *
@@ -177,6 +215,8 @@ declare class AluExp implements FpHashable {
177
215
  static max(a: AluExp, b: AluExp): AluExp;
178
216
  static sin(a: AluExp): AluExp;
179
217
  static cos(a: AluExp): AluExp;
218
+ static asin(a: AluExp): AluExp;
219
+ static atan(a: AluExp): AluExp;
180
220
  static exp(a: AluExp): AluExp;
181
221
  static log(a: AluExp): AluExp;
182
222
  static sqrt(a: AluExp): AluExp;
@@ -262,6 +302,8 @@ declare enum AluOp {
262
302
  Max = "Max",
263
303
  Sin = "Sin",
264
304
  Cos = "Cos",
305
+ Asin = "Asin",
306
+ Atan = "Atan",
265
307
  Exp = "Exp",
266
308
  Log = "Log",
267
309
  Sqrt = "Sqrt",
@@ -354,8 +396,8 @@ declare class Reduction implements FpHashable {
354
396
  //#region src/backend.d.ts
355
397
  type Device = "cpu" | "wasm" | "webgpu";
356
398
  declare const devices: Device[];
357
- /** Set the default device backend (must be initialized). */
358
- declare function setDevice(device: Device): void;
399
+ /** Configure the default device for arrays. */
400
+ declare function defaultDevice(device?: Device): Device;
359
401
  /**
360
402
  * Initialize `jax-js` library backends.
361
403
  *
@@ -491,6 +533,8 @@ declare enum Primitive {
491
533
  RandomBits = "random_bits",
492
534
  Sin = "sin",
493
535
  Cos = "cos",
536
+ Asin = "asin",
537
+ Atan = "atan",
494
538
  Exp = "exp",
495
539
  Log = "log",
496
540
  Sqrt = "sqrt",
@@ -581,8 +625,10 @@ declare enum CompareOp {
581
625
  LessEqual = "less_equal",
582
626
  }
583
627
  /** @inline */
628
+ type Axis = number | number[] | null;
629
+ /** @inline */
584
630
  type ReduceOpts = {
585
- keepDims?: boolean;
631
+ keepdims?: boolean;
586
632
  };
587
633
  type MainTrace = {
588
634
  level: number;
@@ -659,9 +705,13 @@ declare abstract class Tracer {
659
705
  * ```
660
706
  */
661
707
  abstract dispose(): void;
708
+ /** The shape of the array. */
662
709
  get shape(): number[];
710
+ /** The total number of elements in the array. */
663
711
  get size(): number;
712
+ /** The dtype of the array. */
664
713
  get dtype(): DType;
714
+ /** The number of dimensions of the array. */
665
715
  get ndim(): number;
666
716
  /** @ignore */
667
717
  fullLower(): Tracer;
@@ -675,11 +725,11 @@ declare abstract class Tracer {
675
725
  greaterEqual(other: this | TracerValue): this;
676
726
  lessEqual(other: this | TracerValue): this;
677
727
  /** Sum of the elements of the array over a given axis, or axes. */
678
- sum(axis?: number | number[], opts?: ReduceOpts): this;
728
+ sum(axis?: Axis, opts?: ReduceOpts): this;
679
729
  /** Product of the array elements over a given axis. */
680
- prod(axis?: number | number[], opts?: ReduceOpts): this;
730
+ prod(axis?: Axis, opts?: ReduceOpts): this;
681
731
  /** Compute the average of the array elements along the specified axis. */
682
- mean(axis?: number | number[], opts?: ReduceOpts): this;
732
+ mean(axis?: Axis, opts?: ReduceOpts): this;
683
733
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
684
734
  transpose(perm?: number[]): this;
685
735
  /**
@@ -807,7 +857,11 @@ declare class Array extends Tracer {
807
857
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
808
858
  * will be freed when the array is disposed.
809
859
  */
810
- constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend, pending?: Iterable<PendingExecute> | null);
860
+ constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend, {
861
+ pending
862
+ }?: {
863
+ pending?: Iterable<PendingExecute> | null;
864
+ });
811
865
  /** @ignore */
812
866
  get aval(): ShapedArray;
813
867
  /** Return a simple string representation of the array's dimensions. */
@@ -836,8 +890,11 @@ declare class Array extends Tracer {
836
890
  *
837
891
  * If you are mapping from `data()` or `dataSync()`, it will also trigger
838
892
  * dispatch of operations as well.
893
+ *
894
+ * **Note:** `jax.blockUntilReady()` is a higher-level API; it calls this
895
+ * asynchronously for multiple arrays.
839
896
  */
840
- wait(): Promise<Array>;
897
+ blockUntilReady(): Promise<Array>;
841
898
  /**
842
899
  * Realize the array and return it as data. This is a sync variant and not
843
900
  * recommended for performance reasons, as it will block rendering.
@@ -867,10 +924,7 @@ declare class Array extends Tracer {
867
924
  _realizeSource(): number;
868
925
  }
869
926
  /** Construct an array from a single scalar constant. */
870
- declare function scalar(value: number | boolean, {
871
- dtype,
872
- device
873
- }?: DTypeAndDevice): Array;
927
+
874
928
  /** Constructor for creating a new array from data. */
875
929
  declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
876
930
  shape,
@@ -908,7 +962,7 @@ declare function eye(numRows: number, numCols?: number, {
908
962
  dtype,
909
963
  device
910
964
  }?: DTypeAndDevice): Array;
911
- /** Return the identity array, with ones on the main diagonal. */
965
+ /** Return the identity matrix, with ones on the main diagonal. */
912
966
  declare function identity$1(n: number, {
913
967
  dtype,
914
968
  device
@@ -945,7 +999,7 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
945
999
  device
946
1000
  }?: DTypeAndDevice): Array;
947
1001
  declare namespace numpy_d_exports {
948
- export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack, concatenate, cos, cosh, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, identity$1 as identity, inf, int32, less, lessEqual, linspace, log, log10, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, pad, permuteDims, pi, prod, ravel, reciprocal, reshape, scalar, shape$1 as shape, sin, sinh, size, sqrt, square, stack, sum, tan, tanh, transpose, trueDivide, trunc, uint32, vdot, vecdot, vstack, where, zeros, zerosLike };
1002
+ export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
949
1003
  }
950
1004
  declare const float32 = DType.Float32;
951
1005
  declare const int32 = DType.Int32;
@@ -962,54 +1016,66 @@ declare const inf: number;
962
1016
  declare const nan: number;
963
1017
  /** This is Pi, `π = 3.14159265358979...` */
964
1018
  declare const pi: number;
965
- /** Element-wise addition, with broadcasting. */
1019
+ /** @function Element-wise addition, with broadcasting. */
966
1020
  declare const add: (x: ArrayLike, y: ArrayLike) => Array;
967
- /** Element-wise multiplication, with broadcasting. */
1021
+ /** @function Element-wise multiplication, with broadcasting. */
968
1022
  declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
969
- /** Numerical negative of every element of an array. */
1023
+ /** @function Numerical negative of every element of an array. */
970
1024
  declare const negative: (x: ArrayLike) => Array;
971
- /** Calculate element-wise reciprocal of the input. This is `1/x`. */
1025
+ /** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
972
1026
  declare const reciprocal: (x: ArrayLike) => Array;
973
- /** Element-wise sine function (takes radians). */
1027
+ /** @function Element-wise sine function (takes radians). */
974
1028
  declare const sin: (x: ArrayLike) => Array;
975
- /** Element-wise cosine function (takes radians). */
1029
+ /** @function Element-wise cosine function (takes radians). */
976
1030
  declare const cos: (x: ArrayLike) => Array;
977
- /** Calculate the exponential of all elements in the input array. */
1031
+ /** @function Element-wise inverse sine function (inverse of sin). */
1032
+ declare const asin: (x: ArrayLike) => Array;
1033
+ /** @function Element-wise inverse tangent function (inverse of tan). */
1034
+ declare const atan: (x: ArrayLike) => Array;
1035
+ /** @function Calculate the exponential of all elements in the input array. */
978
1036
  declare const exp: (x: ArrayLike) => Array;
979
- /** Calculate the natural logarithm of all elements in the input array. */
1037
+ /** @function Calculate the natural logarithm of all elements in the input array. */
980
1038
  declare const log: (x: ArrayLike) => Array;
981
- /** Calculate the square root of all elements in the input array. */
1039
+ /** @function Calculate the square root of all elements in the input array. */
982
1040
  declare const sqrt: (x: ArrayLike) => Array;
983
- /** Return element-wise minimum of the input arrays. */
1041
+ /** @function Return element-wise minimum of the input arrays. */
984
1042
  declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
985
- /** Return element-wise maximum of the input arrays. */
1043
+ /** @function Return element-wise maximum of the input arrays. */
986
1044
  declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
987
- /** Compare two arrays element-wise. */
1045
+ /** @function Compare two arrays element-wise. */
988
1046
  declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
989
- /** Compare two arrays element-wise. */
1047
+ /** @function Compare two arrays element-wise. */
990
1048
  declare const less: (x: ArrayLike, y: ArrayLike) => Array;
991
- /** Compare two arrays element-wise. */
1049
+ /** @function Compare two arrays element-wise. */
992
1050
  declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
993
- /** Compare two arrays element-wise. */
1051
+ /** @function Compare two arrays element-wise. */
994
1052
  declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
995
- /** Compare two arrays element-wise. */
1053
+ /** @function Compare two arrays element-wise. */
996
1054
  declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
997
- /** Compare two arrays element-wise. */
1055
+ /** @function Compare two arrays element-wise. */
998
1056
  declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
999
- /** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1057
+ /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1000
1058
  declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
1001
- /** Permute the dimensions of an array. Defaults to reversing the axis order. */
1059
+ /**
1060
+ * @function
1061
+ * Permute the dimensions of an array. Defaults to reversing the axis order.
1062
+ */
1002
1063
  declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
1003
1064
  /**
1065
+ * @function
1004
1066
  * Give a new shape to an array without changing its data.
1005
1067
  *
1006
1068
  * One shape dimension can be -1. In this case, the value is inferred from the
1007
1069
  * length of the array and remaining dimensions.
1008
1070
  */
1009
1071
  declare const reshape: (x: ArrayLike, shape: number[]) => Array;
1010
- /** Move axes of an array to new positions. Other axes retain original order. */
1072
+ /**
1073
+ * @function
1074
+ * Move axes of an array to new positions. Other axes retain original order.
1075
+ */
1011
1076
  declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
1012
1077
  /**
1078
+ * @function
1013
1079
  * Add padding (zeros) to an array.
1014
1080
  *
1015
1081
  * The `width` argument is either an integer or pair of integers, in which case
@@ -1017,15 +1083,27 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
1017
1083
  * pair specifies the padding for its corresponding axis.
1018
1084
  */
1019
1085
  declare const pad: (x: ArrayLike, width: number | [number, number] | [number, number][]) => Array;
1020
- /** Return the number of dimensions of an array. Does not consume array reference. */
1086
+ /**
1087
+ * @function
1088
+ * Return the number of dimensions of an array. Does not consume array reference.
1089
+ */
1021
1090
  declare const ndim: (x: ArrayLike) => number;
1022
- /** Return the shape of an array. Does not consume array reference. */
1091
+ /** @function Return the shape of an array. Does not consume array reference. */
1023
1092
  declare const shape$1: (x: ArrayLike) => number[];
1024
- /** Return an array of zeros with the same shape and type as a given array. */
1093
+ /**
1094
+ * @function
1095
+ * Return an array of zeros with the same shape and type as a given array.
1096
+ */
1025
1097
  declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
1026
- /** Return an array of ones with the same shape and type as a given array. */
1098
+ /**
1099
+ * @function
1100
+ * Return an array of ones with the same shape and type as a given array.
1101
+ */
1027
1102
  declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
1028
- /** Return a full array with the same shape and type as a given array. */
1103
+ /**
1104
+ * @function
1105
+ * Return a full array with the same shape and type as a given array.
1106
+ */
1029
1107
  declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
1030
1108
  /**
1031
1109
  * Return the number of elements in an array, optionally along an axis.
@@ -1035,15 +1113,15 @@ declare function size(a: ArrayLike, axis?: number): number;
1035
1113
  /** Convert an array to a specified dtype. */
1036
1114
  declare function astype(a: ArrayLike, dtype: DType): Array;
1037
1115
  /** Sum of the elements of the array over a given axis, or axes. */
1038
- declare function sum(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1116
+ declare function sum(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1039
1117
  /** Product of the array elements over a given axis. */
1040
- declare function prod(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1118
+ declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1041
1119
  /** Return the minimum of array elements along a given axis. */
1042
- declare function min(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1120
+ declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1043
1121
  /** Return the maximum of array elements along a given axis. */
1044
- declare function max(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1122
+ declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1045
1123
  /** Compute the average of the array elements along the specified axis. */
1046
- declare function mean(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1124
+ declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1047
1125
  /**
1048
1126
  * Returns the indices of the minimum values along an axis.
1049
1127
  *
@@ -1059,7 +1137,7 @@ declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
1059
1137
  */
1060
1138
  declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
1061
1139
  /** Reverse the elements in an array along the given axes. */
1062
- declare function flip(x: ArrayLike, axis?: number | number[]): Array;
1140
+ declare function flip(x: ArrayLike, axis?: Axis): Array;
1063
1141
  /**
1064
1142
  * Join a sequence of arrays along an existing axis.
1065
1143
  *
@@ -1103,9 +1181,36 @@ declare function columnStack(xs: ArrayLike[]): Array;
1103
1181
  declare function flipud(x: ArrayLike): Array;
1104
1182
  /** Flip an array horizontally (axis=1). */
1105
1183
  declare function fliplr(x: ArrayLike): Array;
1184
+ /** @function Alternative name for `numpy.transpose()`. */
1106
1185
  declare const permuteDims: (x: ArrayLike, perm?: number[]) => Array;
1107
1186
  /** Return a 1-D flattened array containing the elements of the input. */
1108
1187
  declare function ravel(a: ArrayLike): Array;
1188
+ /**
1189
+ * Repeat each element of an array after themselves.
1190
+ *
1191
+ * If no axis is provided, use the flattened input array, and return a flat
1192
+ * output array.
1193
+ */
1194
+ declare function repeat(a: ArrayLike, repeats: number, axis?: number): Array;
1195
+ /**
1196
+ * Construct an array by repeating A the number of times given by reps.
1197
+ *
1198
+ * If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
1199
+ * integers, the resulting array will have a shape of `(reps[0] * d1,
1200
+ * reps[1] * d2, ..., reps[n - 1] * dn)`, with `A` tiled along each dimension.
1201
+ */
1202
+ declare function tile(a: ArrayLike, reps: number | number[]): Array;
1203
+ /**
1204
+ * Broadcast an array to a shape, with NumPy-style broadcasting rules.
1205
+ *
1206
+ * In other words, this lets you append axes to the left, and/or expand
1207
+ * dimensions where the shape is 1.
1208
+ */
1209
+ declare function broadcastTo(a: ArrayLike, shape: number[]): Array;
1210
+ /** Broadcast input shapes to a common output shape. */
1211
+ declare function broadcastShapes(...shapes: number[][]): number[];
1212
+ /** Broadcast arrays to a common shape. */
1213
+ declare function broadcastArrays(...arrays: ArrayLike[]): Array[];
1109
1214
  /**
1110
1215
  * Return specified diagonals.
1111
1216
  *
@@ -1133,8 +1238,28 @@ declare function allclose(actual: Parameters<typeof array>[0], expected: Paramet
1133
1238
  declare function matmul(x: ArrayLike, y: ArrayLike): Array;
1134
1239
  /** Dot product of two arrays. */
1135
1240
  declare function dot(x: ArrayLike, y: ArrayLike): Array;
1136
- /** Vector dot product of two arrays. */
1137
- declare function vecdot(x: ArrayLike, y: ArrayLike): Array;
1241
+ /**
1242
+ * Compute the inner product of two arrays.
1243
+ *
1244
+ * Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
1245
+ * contraction on the last axis.
1246
+ *
1247
+ * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
1248
+ */
1249
+ declare function inner(x: ArrayLike, y: ArrayLike): Array;
1250
+ /**
1251
+ * Compute the outer product of two arrays.
1252
+ *
1253
+ * If the input arrays are not 1D, they will be flattened. Returned array will
1254
+ * be of shape `[x.size, y.size]`.
1255
+ */
1256
+ declare function outer(x: ArrayLike, y: ArrayLike): Array;
1257
+ /** Vector dot product of two arrays along a given axis. */
1258
+ declare function vecdot(x: ArrayLike, y: ArrayLike, {
1259
+ axis
1260
+ }?: {
1261
+ axis?: number;
1262
+ }): Array;
1138
1263
  /**
1139
1264
  * Return the dot product of two vectors.
1140
1265
  *
@@ -1152,6 +1277,21 @@ declare function meshgrid(xs: Array[], {
1152
1277
  }?: {
1153
1278
  indexing?: "xy" | "ij";
1154
1279
  }): Array[];
1280
+ /**
1281
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
1282
+ *
1283
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
1284
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
1285
+ * `k>0` is above it.
1286
+ */
1287
+ declare function tri(n: number, m?: number, k?: number, {
1288
+ dtype,
1289
+ device
1290
+ }?: DTypeAndDevice): Array;
1291
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
1292
+ declare function tril(a: ArrayLike, k?: number): Array;
1293
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
1294
+ declare function triu(a: ArrayLike, k?: number): Array;
1155
1295
  /**
1156
1296
  * Clip (limit) the values in an array.
1157
1297
  *
@@ -1168,15 +1308,50 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
1168
1308
  * This is the same function as `jax.numpy.abs()`.
1169
1309
  */
1170
1310
  declare function absolute(x: ArrayLike): Array;
1171
- /** Alias of `jax.numpy.absolute()`. */
1311
+ /** @function Alias of `jax.numpy.absolute()`. */
1172
1312
  declare const abs: typeof absolute;
1313
+ /** Return an element-wise indication of sign of the input. */
1314
+ declare function sign(x: ArrayLike): Array;
1173
1315
  /** Calculate element-wise square of the input array. */
1174
1316
  declare function square(x: ArrayLike): Array;
1175
- /** Compute a trigonometric tangent of each element of input. */
1317
+ /** Element-wise tangent function (takes radians). */
1176
1318
  declare function tan(x: ArrayLike): Array;
1319
+ /** Element-wise inverse cosine function (inverse of cos). */
1320
+ declare function acos(x: ArrayLike): Array;
1321
+ /**
1322
+ * @function
1323
+ * Return element-wise hypotenuse for the given legs of a right triangle.
1324
+ *
1325
+ * In the original NumPy/JAX implementation, this function is more numerically
1326
+ * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
1327
+ * improvements.
1328
+ */
1329
+ declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1330
+ /**
1331
+ * @function
1332
+ * Element-wise arc tangent of y/x with correct quadrant.
1333
+ *
1334
+ * Returns the angle in radians between the positive x-axis and the point (x, y).
1335
+ * The result is in the range [-π, π].
1336
+ *
1337
+ * Uses numerically stable formulas:
1338
+ * - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
1339
+ * - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
1340
+ *
1341
+ * The output is ill-defined when both x and y are zero.
1342
+ */
1343
+ declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
1344
+ /** @function Alias of `jax.numpy.acos()`. */
1345
+ declare const arccos: typeof acos;
1346
+ /** @function Alias of `jax.numpy.atan()`. */
1347
+ declare const arctan: (x: ArrayLike) => Array;
1348
+ /** @function Alias of `jax.numpy.atan2()`. */
1349
+ declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
1350
+ /** Element-wise subtraction, with broadcasting. */
1351
+ declare function subtract(x: ArrayLike, y: ArrayLike): Array;
1177
1352
  /** Calculates the floating-point division of x by y element-wise. */
1178
1353
  declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
1179
- /** Alias of `jax.numpy.trueDivide()`. */
1354
+ /** @function Alias of `jax.numpy.trueDivide()`. */
1180
1355
  declare const divide: typeof trueDivide;
1181
1356
  /** Round input to the nearest integer towards zero. */
1182
1357
  declare function trunc(x: ArrayLike): Array;
@@ -1186,26 +1361,112 @@ declare function exp2(p: ArrayLike): Array;
1186
1361
  declare function log2(x: ArrayLike): Array;
1187
1362
  /** Return the base-10 logarithm of x, element-wise. */
1188
1363
  declare function log10(x: ArrayLike): Array;
1364
+ /** Calculate `exp(x) - 1` element-wise. */
1365
+ declare function expm1(x: ArrayLike): Array;
1366
+ /** Calculate the natural logarithm of `1 + x` element-wise. */
1367
+ declare function log1p(x: ArrayLike): Array;
1368
+ /** Convert angles from degrees to radians. */
1369
+ declare function deg2rad(x: ArrayLike): Array;
1370
+ /** @function Alias of `jax.numpy.deg2rad()`. */
1371
+ declare const radians: typeof deg2rad;
1372
+ /** Convert angles from radians to degrees. */
1373
+ declare function rad2deg(x: ArrayLike): Array;
1374
+ /** @function Alias of `jax.numpy.rad2deg()`. */
1375
+ declare const degrees: typeof rad2deg;
1189
1376
  /**
1377
+ * @function
1378
+ * Computes first array raised to power of second array, element-wise.
1379
+ */
1380
+ declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1381
+ /** @function Alias of `jax.numpy.power()`. */
1382
+ declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1383
+ /** @function Calculate the element-wise cube root of the input array. */
1384
+ declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
1385
+ /**
1386
+ * @function
1190
1387
  * Calculate element-wise hyperbolic sine of input.
1191
1388
  *
1192
1389
  * `sinh(x) = (exp(x) - exp(-x)) / 2`
1193
1390
  */
1194
- declare function sinh(x: ArrayLike): Array;
1391
+ declare const sinh: OwnedFunction<(x: ArrayLike) => Array>;
1195
1392
  /**
1393
+ * @function
1196
1394
  * Calculate element-wise hyperbolic cosine of input.
1197
1395
  *
1198
1396
  * `cosh(x) = (exp(x) + exp(-x)) / 2`
1199
1397
  */
1200
- declare function cosh(x: ArrayLike): Array;
1398
+ declare const cosh: OwnedFunction<(x: ArrayLike) => Array>;
1201
1399
  /**
1400
+ * @function
1202
1401
  * Calculate element-wise hyperbolic tangent of input.
1203
1402
  *
1204
1403
  * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
1205
1404
  */
1206
- declare function tanh(x: ArrayLike): Array;
1405
+ declare const tanh: OwnedFunction<(x: ArrayLike) => Array>;
1406
+ /**
1407
+ * @function
1408
+ * Calculate element-wise inverse hyperbolic sine of input.
1409
+ *
1410
+ * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
1411
+ */
1412
+ declare const arcsinh: OwnedFunction<(x: ArrayLike) => Array>;
1413
+ /**
1414
+ * @function
1415
+ * Calculate element-wise inverse hyperbolic cosine of input.
1416
+ *
1417
+ * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
1418
+ */
1419
+ declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
1420
+ /**
1421
+ * @function
1422
+ * Calculate element-wise inverse hyperbolic tangent of input.
1423
+ *
1424
+ * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
1425
+ */
1426
+ declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
1427
+ /** @function Alias of `jax.numpy.arcsinh()`. */
1428
+ declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
1429
+ /** @function Alias of `jax.numpy.arccosh()`. */
1430
+ declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
1431
+ /** @function Alias of `jax.numpy.arctanh()`. */
1432
+ declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
1433
+ /**
1434
+ * Compute the variance of an array.
1435
+ *
1436
+ * The variance is computed for the flattened array by default, otherwise over
1437
+ * the specified axis.
1438
+ *
1439
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
1440
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
1441
+ */
1442
+ declare function var_(x: ArrayLike, axis?: Axis, opts?: {
1443
+ mean?: ArrayLike;
1444
+ correction?: number;
1445
+ } & ReduceOpts): Array;
1446
+ /**
1447
+ * Compute the standard deviation of an array.
1448
+ *
1449
+ * The standard deviation is computed for the flattened array by default,
1450
+ * otherwise over the specified axis.
1451
+ *
1452
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
1453
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
1454
+ */
1455
+ declare function std(x: ArrayLike, axis?: Axis, opts?: {
1456
+ mean?: ArrayLike;
1457
+ correction?: number;
1458
+ } & ReduceOpts): Array;
1207
1459
  //#endregion
1208
1460
  //#region src/frontend/jaxpr.d.ts
1461
+ /**
1462
+ * Function callback with an associated dispose() method.
1463
+ *
1464
+ * The dispose() method should be called to clean up any tracer resources needed
1465
+ * by the function after the last time it is called.
1466
+ */
1467
+ type OwnedFunction<F extends Function> = F & {
1468
+ dispose: () => void;
1469
+ };
1209
1470
  /** Variable in a Jaxpr expression. */
1210
1471
  declare class Var {
1211
1472
  #private;
@@ -1297,7 +1558,7 @@ declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding:
1297
1558
  /** Reduce a computation over padded windows. */
1298
1559
  declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
1299
1560
  declare namespace nn_d_exports {
1300
- export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
1561
+ export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
1301
1562
  }
1302
1563
  /**
1303
1564
  * Rectified Linear Unit (ReLU) activation function:
@@ -1329,6 +1590,7 @@ declare function softplus(x: ArrayLike): Array;
1329
1590
  */
1330
1591
  declare function softSign(x: ArrayLike): Array;
1331
1592
  /**
1593
+ * @function
1332
1594
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
1333
1595
  * Swish, computed element-wise:
1334
1596
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -1337,8 +1599,9 @@ declare function softSign(x: ArrayLike): Array;
1337
1599
  *
1338
1600
  * Reference: https://en.wikipedia.org/wiki/Swish_function
1339
1601
  */
1340
- declare const silu: (x: ArrayLike) => Array;
1602
+ declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
1341
1603
  /**
1604
+ * @function
1342
1605
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
1343
1606
  * Swish, computed element-wise:
1344
1607
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -1347,31 +1610,35 @@ declare const silu: (x: ArrayLike) => Array;
1347
1610
  *
1348
1611
  * Reference: https://en.wikipedia.org/wiki/Swish_function
1349
1612
  */
1350
- declare const swish: (x: ArrayLike) => Array;
1613
+ declare const swish: OwnedFunction<(x: ArrayLike) => Array>;
1351
1614
  /**
1352
1615
  * Log-sigmoid activation function, computed element-wise:
1353
1616
  * `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
1354
1617
  */
1355
1618
  declare function logSigmoid(x: ArrayLike): Array;
1356
- /** Identity activation function. Returns the argument unmodified. */
1619
+ /**
1620
+ * @function
1621
+ * Identity activation function. Returns the argument unmodified.
1622
+ */
1357
1623
  declare const identity: (x: ArrayLike) => Array;
1358
1624
  /** Leaky rectified linear unit (Leaky ReLU) activation function. */
1359
- declare function leakyRelu(x: ArrayLike, negativeSlope?: number): Array;
1625
+ declare function leakyRelu(x: ArrayLike, negativeSlope?: ArrayLike): Array;
1360
1626
  /**
1361
1627
  * Exponential linear unit activation function.
1362
1628
  *
1363
1629
  * Computes the element-wise function:
1364
1630
  * `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
1365
1631
  */
1366
- declare function elu(x: ArrayLike, alpha?: number): Array;
1632
+ declare function elu(x: ArrayLike, alpha?: ArrayLike): Array;
1367
1633
  /**
1368
1634
  * Continuously-differentiable exponential linear unit activation function.
1369
1635
  *
1370
1636
  * Computes the element-wise function:
1371
1637
  * `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
1372
1638
  */
1373
- declare function celu(x: ArrayLike, alpha?: number): Array;
1639
+ declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
1374
1640
  /**
1641
+ * @function
1375
1642
  * Gaussian error linear unit (GELU) activation function.
1376
1643
  *
1377
1644
  * This is computed element-wise. Currently jax-js does not support the erf() or
@@ -1382,7 +1649,7 @@ declare function celu(x: ArrayLike, alpha?: number): Array;
1382
1649
  *
1383
1650
  * This will be improved in the future.
1384
1651
  */
1385
- declare const gelu: (x: ArrayLike) => Array;
1652
+ declare const gelu: OwnedFunction<(x: ArrayLike) => Array>;
1386
1653
  /**
1387
1654
  * Gated linear unit (GLU) activation function.
1388
1655
  *
@@ -1390,6 +1657,13 @@ declare const gelu: (x: ArrayLike) => Array;
1390
1657
  * computes `a * sigmoid(b)`.
1391
1658
  */
1392
1659
  declare function glu(x: ArrayLike, axis?: number): Array;
1660
+ /**
1661
+ * Squareplus activation function.
1662
+ *
1663
+ * Computes the element-wise function:
1664
+ * `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
1665
+ */
1666
+ declare function squareplus(x: ArrayLike, b?: ArrayLike): Array;
1393
1667
  /**
1394
1668
  * Mish activation function.
1395
1669
  *
@@ -1405,7 +1679,7 @@ declare function mish(x: ArrayLike): Array;
1405
1679
  *
1406
1680
  * Reference: https://en.wikipedia.org/wiki/Softmax_function
1407
1681
  */
1408
- declare function softmax(x: ArrayLike, axis?: number | number[]): Array;
1682
+ declare function softmax(x: ArrayLike, axis?: Axis): Array;
1409
1683
  /**
1410
1684
  * Log-Softmax function.
1411
1685
  *
@@ -1414,7 +1688,7 @@ declare function softmax(x: ArrayLike, axis?: number | number[]): Array;
1414
1688
  *
1415
1689
  * If `axis` is not specified, it defaults to the last axis.
1416
1690
  */
1417
- declare function logSoftmax(x: ArrayLike, axis?: number | number[]): Array;
1691
+ declare function logSoftmax(x: ArrayLike, axis?: Axis): Array;
1418
1692
  /**
1419
1693
  * Log-sum-exp reduction. Also a multivariate version of `softplus`.
1420
1694
  *
@@ -1423,7 +1697,22 @@ declare function logSoftmax(x: ArrayLike, axis?: number | number[]): Array;
1423
1697
  *
1424
1698
  * Reference: https://en.wikipedia.org/wiki/LogSumExp
1425
1699
  */
1426
- declare function logsumexp(x: ArrayLike, axis?: number | number[]): Array;
1700
+ declare function logsumexp(x: ArrayLike, axis?: Axis): Array;
1701
+ /** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
1702
+ declare function logmeanexp(x: ArrayLike, axis?: Axis): Array;
1703
+ /**
1704
+ * Standardizes input to zero mean and unit variance.
1705
+ *
1706
+ * By default, this is computed over the last axis. You can pass in a different
1707
+ * axis, or `null` to standardize over all elements.
1708
+ *
1709
+ * Epsilon is added to the denominator for numerical stability; it defaults to `1e-5`.
1710
+ */
1711
+ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
1712
+ mean?: ArrayLike;
1713
+ variance?: ArrayLike;
1714
+ epsilon?: ArrayLike;
1715
+ }): Array;
1427
1716
  /**
1428
1717
  * One-hot encodes the given indices.
1429
1718
  *
@@ -1442,7 +1731,7 @@ declare function logsumexp(x: ArrayLike, axis?: number | number[]): Array;
1442
1731
  */
1443
1732
  declare function oneHot(x: Array, numClasses: number): Array;
1444
1733
  declare namespace random_d_exports {
1445
- export { bits, key, split, uniform };
1734
+ export { bernoulli, bits, exponential, key, normal, split, uniform };
1446
1735
  }
1447
1736
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
1448
1737
  declare function key(seed: number): Array;
@@ -1458,26 +1747,59 @@ declare function uniform(key: Array, shape?: number[], {
1458
1747
  minval?: number;
1459
1748
  maxval?: number;
1460
1749
  }): Array;
1750
+ /**
1751
+ * Sample Bernoulli random variables with given mean (0,1 categorical).
1752
+ *
1753
+ * Returns a random Boolean array with the specified shape. `p` can be an array
1754
+ * and must be broadcastable to `shape`.
1755
+ */
1756
+ declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
1757
+ /** Sample exponential random values according to `p(x) = exp(-x)`. */
1758
+ declare function exponential(key: Array, shape?: number[]): Array;
1759
+ /**
1760
+ * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
1761
+ *
1762
+ * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
1763
+ * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
1764
+ * bitwise identical to JAX.
1765
+ */
1766
+ declare function normal(key: Array, shape?: number[]): Array;
1461
1767
  //#endregion
1462
1768
  //#region src/index.d.ts
1463
- /** Compute the forward-mode Jacobian-vector product for a function. */
1769
+ /**
1770
+ * @function
1771
+ * Compute the forward-mode Jacobian-vector product for a function.
1772
+ */
1464
1773
  declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
1465
- /** Vectorize an operation on a batched axis for one or more inputs. */
1774
+ /**
1775
+ * @function
1776
+ * Vectorize an operation on a batched axis for one or more inputs.
1777
+ */
1466
1778
  declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1467
- /** Compute the Jacobian evaluated column-by-column by forward-mode AD. */
1779
+ /**
1780
+ * @function
1781
+ * Compute the Jacobian evaluated column-by-column by forward-mode AD.
1782
+ */
1468
1783
  declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1469
- /** Construct a Jaxpr by dynamically tracing a function with example inputs. */
1784
+ /**
1785
+ * @function
1786
+ * Construct a Jaxpr by dynamically tracing a function with example inputs.
1787
+ */
1470
1788
  declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
1471
1789
  jaxpr: Jaxpr;
1472
1790
  consts: Array[];
1473
1791
  treedef: JsTreeDef;
1474
1792
  };
1475
1793
  /**
1794
+ * @function
1476
1795
  * Mark a function for automatic JIT compilation, with operator fusion.
1477
1796
  *
1478
1797
  * The function will be compiled the first time it is called with a set of
1479
1798
  * argument shapes.
1480
1799
  *
1800
+ * You can call `.dispose()` on the returned, JIT-compiled function after all
1801
+ * calls to free memory associated with array constants.
1802
+ *
1481
1803
  * **Options:**
1482
1804
  * - `staticArgnums`: An array of argument indices to treat as static
1483
1805
  * (compile-time constant). These arguments must be hashable, won't be traced,
@@ -1485,24 +1807,48 @@ declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) =>
1485
1807
  * - `device`: The device to place the computation on. If not specified, the
1486
1808
  * computation will be placed on the default device.
1487
1809
  */
1488
- declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1810
+ declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => OwnedFunction<(...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>>;
1489
1811
  /**
1812
+ * @function
1490
1813
  * Produce a local linear approximation to a function at a point using jvp() and
1491
1814
  * partial evaluation.
1492
1815
  */
1493
1816
  declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
1494
- /** Calculate the reverse-mode vector-Jacobian product for a function. */
1817
+ /**
1818
+ * @function
1819
+ * Calculate the reverse-mode vector-Jacobian product for a function.
1820
+ */
1495
1821
  declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
1496
1822
  /**
1823
+ * @function
1497
1824
  * Compute the gradient of a scalar-valued function `f` with respect to its
1498
1825
  * first argument.
1499
1826
  */
1500
1827
  declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
1501
- /** Create a function that evaluates both `f` and the gradient of `f`. */
1828
+ /**
1829
+ * @function
1830
+ * Create a function that evaluates both `f` and the gradient of `f`.
1831
+ */
1502
1832
  declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
1503
- /** Compute the Jacobian evaluated row-by-row by reverse-mode AD. */
1833
+ /**
1834
+ * @function
1835
+ * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
1836
+ */
1504
1837
  declare const jacrev: typeof jacfwd;
1505
- /** Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`. */
1838
+ /**
1839
+ * @function
1840
+ * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
1841
+ */
1506
1842
  declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1843
+ /**
1844
+ * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
1845
+ *
1846
+ * This can be used to wait for the results of an intermediate computation to
1847
+ * finish. It's recommended to call this regularly in an iterative computation
1848
+ * to avoid queueing up too many pending operations.
1849
+ *
1850
+ * Does not consume references to the arrays.
1851
+ */
1852
+ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
1507
1853
  //#endregion
1508
- export { DType, type Device, type JsTree, type JsTreeDef, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDevice, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1854
+ export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };