@jax-js/jax 0.0.3 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.cts CHANGED
@@ -126,6 +126,19 @@ declare class ShapeTracker {
126
126
  }
127
127
  //#endregion
128
128
  //#region src/utils.d.ts
129
+ /**
130
+ * Set the debug level for verbose logging.
131
+ *
132
+ * 1. JIT compile logs
133
+ * 2. Shader code
134
+ * 3. Expressions and metadata
135
+ * 4. JIT programs, tuning details
136
+ * 5. Most verbose operation traces
137
+ *
138
+ * This is an experimental API and may change in behavior. Do not rely on this
139
+ * in production.
140
+ */
141
+ declare function setDebug(level: number): void;
129
142
  /** @inline */
130
143
  type RecursiveArray<T> = T | RecursiveArray<T>[];
131
144
  interface FpHashable {
@@ -141,7 +154,7 @@ interface FpHashable {
141
154
  declare class FpHash {
142
155
  #private;
143
156
  value: bigint;
144
- update(...values: (string | boolean | number | bigint | null | undefined | FpHashable)[]): this;
157
+ update(x: string | boolean | number | bigint | null | undefined | FpHashable): this;
145
158
  static hash(...values: (string | boolean | number | bigint | null | undefined | FpHashable)[]): bigint;
146
159
  }
147
160
  /** Run a function while caching it inline inside a `Map`. */
@@ -157,6 +170,31 @@ declare enum DType {
157
170
  }
158
171
  /** @inline */
159
172
  type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
173
+ /**
174
+ * Promote two dtypes to their join according to the type lattice.
175
+ *
176
+ * When performing operations between arrays of different types, we need to
177
+ * promote both operands to a common type that can represent values from both
178
+ * input types. This follows JAX's type promotion rules.
179
+ *
180
+ * **Type lattice:**
181
+ * ```text
182
+ * bool -> uint32 -> int32 -> float16 -> float32
183
+ * weak f* --^
184
+ * ```
185
+ *
186
+ * The asterisk f* is a weak type used for JS number constants. When creating
187
+ * arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
188
+ * any array they are first combined with.
189
+ *
190
+ * **Examples:**
191
+ * - `promoteTypes(bool, int32) → int32`
192
+ * - `promoteTypes(uint32, int32) → int32`
193
+ * - `promoteTypes(int32, float16) → float16`
194
+ * - `promoteTypes(float16, float32) → float32`
195
+ * - `promoteTypes(uint32, float32) → float32`
196
+ */
197
+ declare function promoteTypes(dtype1: DType, dtype2: DType): DType;
160
198
  /**
161
199
  * Mathematical expression on scalar values.
162
200
  *
@@ -180,6 +218,8 @@ declare class AluExp implements FpHashable {
180
218
  static max(a: AluExp, b: AluExp): AluExp;
181
219
  static sin(a: AluExp): AluExp;
182
220
  static cos(a: AluExp): AluExp;
221
+ static asin(a: AluExp): AluExp;
222
+ static atan(a: AluExp): AluExp;
183
223
  static exp(a: AluExp): AluExp;
184
224
  static log(a: AluExp): AluExp;
185
225
  static sqrt(a: AluExp): AluExp;
@@ -265,6 +305,8 @@ declare enum AluOp {
265
305
  Max = "Max",
266
306
  Sin = "Sin",
267
307
  Cos = "Cos",
308
+ Asin = "Asin",
309
+ Atan = "Atan",
268
310
  Exp = "Exp",
269
311
  Log = "Log",
270
312
  Sqrt = "Sqrt",
@@ -357,8 +399,8 @@ declare class Reduction implements FpHashable {
357
399
  //#region src/backend.d.ts
358
400
  type Device = "cpu" | "wasm" | "webgpu";
359
401
  declare const devices: Device[];
360
- /** Set the default device backend (must be initialized). */
361
- declare function setDevice(device: Device): void;
402
+ /** Configure the default device for arrays. */
403
+ declare function defaultDevice(device?: Device): Device;
362
404
  /**
363
405
  * Initialize `jax-js` library backends.
364
406
  *
@@ -494,6 +536,8 @@ declare enum Primitive {
494
536
  RandomBits = "random_bits",
495
537
  Sin = "sin",
496
538
  Cos = "cos",
539
+ Asin = "asin",
540
+ Atan = "atan",
497
541
  Exp = "exp",
498
542
  Log = "log",
499
543
  Sqrt = "sqrt",
@@ -584,8 +628,10 @@ declare enum CompareOp {
584
628
  LessEqual = "less_equal",
585
629
  }
586
630
  /** @inline */
631
+ type Axis = number | number[] | null;
632
+ /** @inline */
587
633
  type ReduceOpts = {
588
- keepDims?: boolean;
634
+ keepdims?: boolean;
589
635
  };
590
636
  type MainTrace = {
591
637
  level: number;
@@ -662,9 +708,13 @@ declare abstract class Tracer {
662
708
  * ```
663
709
  */
664
710
  abstract dispose(): void;
711
+ /** The shape of the array. */
665
712
  get shape(): number[];
713
+ /** The total number of elements in the array. */
666
714
  get size(): number;
715
+ /** The dtype of the array. */
667
716
  get dtype(): DType;
717
+ /** The number of dimensions of the array. */
668
718
  get ndim(): number;
669
719
  /** @ignore */
670
720
  fullLower(): Tracer;
@@ -678,11 +728,11 @@ declare abstract class Tracer {
678
728
  greaterEqual(other: this | TracerValue): this;
679
729
  lessEqual(other: this | TracerValue): this;
680
730
  /** Sum of the elements of the array over a given axis, or axes. */
681
- sum(axis?: number | number[], opts?: ReduceOpts): this;
731
+ sum(axis?: Axis, opts?: ReduceOpts): this;
682
732
  /** Product of the array elements over a given axis. */
683
- prod(axis?: number | number[], opts?: ReduceOpts): this;
733
+ prod(axis?: Axis, opts?: ReduceOpts): this;
684
734
  /** Compute the average of the array elements along the specified axis. */
685
- mean(axis?: number | number[], opts?: ReduceOpts): this;
735
+ mean(axis?: Axis, opts?: ReduceOpts): this;
686
736
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
687
737
  transpose(perm?: number[]): this;
688
738
  /**
@@ -810,7 +860,11 @@ declare class Array extends Tracer {
810
860
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
811
861
  * will be freed when the array is disposed.
812
862
  */
813
- constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend, pending?: Iterable<PendingExecute> | null);
863
+ constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend, {
864
+ pending
865
+ }?: {
866
+ pending?: Iterable<PendingExecute> | null;
867
+ });
814
868
  /** @ignore */
815
869
  get aval(): ShapedArray;
816
870
  /** Return a simple string representation of the array's dimensions. */
@@ -839,8 +893,11 @@ declare class Array extends Tracer {
839
893
  *
840
894
  * If you are mapping from `data()` or `dataSync()`, it will also trigger
841
895
  * dispatch of operations as well.
896
+ *
897
+ * **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
898
+ * asynchronously for multiple arrays.
842
899
  */
843
- wait(): Promise<Array>;
900
+ blockUntilReady(): Promise<Array>;
844
901
  /**
845
902
  * Realize the array and return it as data. This is a sync variant and not
846
903
  * recommended for performance reasons, as it will block rendering.
@@ -870,10 +927,7 @@ declare class Array extends Tracer {
870
927
  _realizeSource(): number;
871
928
  }
872
929
  /** Construct an array from a single scalar constant. */
873
- declare function scalar(value: number | boolean, {
874
- dtype,
875
- device
876
- }?: DTypeAndDevice): Array;
930
+
877
931
  /** Constructor for creating a new array from data. */
878
932
  declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
879
933
  shape,
@@ -911,7 +965,7 @@ declare function eye(numRows: number, numCols?: number, {
911
965
  dtype,
912
966
  device
913
967
  }?: DTypeAndDevice): Array;
914
- /** Return the identity array, with ones on the main diagonal. */
968
+ /** Return the identity matrix, with ones on the main diagonal. */
915
969
  declare function identity$1(n: number, {
916
970
  dtype,
917
971
  device
@@ -948,7 +1002,7 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
948
1002
  device
949
1003
  }?: DTypeAndDevice): Array;
950
1004
  declare namespace numpy_d_exports {
951
- export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack, concatenate, cos, cosh, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, identity$1 as identity, inf, int32, less, lessEqual, linspace, log, log10, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, pad, permuteDims, pi, prod, ravel, reciprocal, reshape, scalar, shape$1 as shape, sin, sinh, size, sqrt, square, stack, sum, tan, tanh, transpose, trueDivide, trunc, uint32, vdot, vecdot, vstack, where, zeros, zerosLike };
1005
+ export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
952
1006
  }
953
1007
  declare const float32 = DType.Float32;
954
1008
  declare const int32 = DType.Int32;
@@ -965,54 +1019,66 @@ declare const inf: number;
965
1019
  declare const nan: number;
966
1020
  /** This is Pi, `π = 3.14159265358979...` */
967
1021
  declare const pi: number;
968
- /** Element-wise addition, with broadcasting. */
1022
+ /** @function Element-wise addition, with broadcasting. */
969
1023
  declare const add: (x: ArrayLike, y: ArrayLike) => Array;
970
- /** Element-wise multiplication, with broadcasting. */
1024
+ /** @function Element-wise multiplication, with broadcasting. */
971
1025
  declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
972
- /** Numerical negative of every element of an array. */
1026
+ /** @function Numerical negative of every element of an array. */
973
1027
  declare const negative: (x: ArrayLike) => Array;
974
- /** Calculate element-wise reciprocal of the input. This is `1/x`. */
1028
+ /** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
975
1029
  declare const reciprocal: (x: ArrayLike) => Array;
976
- /** Element-wise sine function (takes radians). */
1030
+ /** @function Element-wise sine function (takes radians). */
977
1031
  declare const sin: (x: ArrayLike) => Array;
978
- /** Element-wise cosine function (takes radians). */
1032
+ /** @function Element-wise cosine function (takes radians). */
979
1033
  declare const cos: (x: ArrayLike) => Array;
980
- /** Calculate the exponential of all elements in the input array. */
1034
+ /** @function Element-wise inverse sine function (inverse of sin). */
1035
+ declare const asin: (x: ArrayLike) => Array;
1036
+ /** @function Element-wise inverse tangent function (inverse of tan). */
1037
+ declare const atan: (x: ArrayLike) => Array;
1038
+ /** @function Calculate the exponential of all elements in the input array. */
981
1039
  declare const exp: (x: ArrayLike) => Array;
982
- /** Calculate the natural logarithm of all elements in the input array. */
1040
+ /** @function Calculate the natural logarithm of all elements in the input array. */
983
1041
  declare const log: (x: ArrayLike) => Array;
984
- /** Calculate the square root of all elements in the input array. */
1042
+ /** @function Calculate the square root of all elements in the input array. */
985
1043
  declare const sqrt: (x: ArrayLike) => Array;
986
- /** Return element-wise minimum of the input arrays. */
1044
+ /** @function Return element-wise minimum of the input arrays. */
987
1045
  declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
988
- /** Return element-wise maximum of the input arrays. */
1046
+ /** @function Return element-wise maximum of the input arrays. */
989
1047
  declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
990
- /** Compare two arrays element-wise. */
1048
+ /** @function Compare two arrays element-wise. */
991
1049
  declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
992
- /** Compare two arrays element-wise. */
1050
+ /** @function Compare two arrays element-wise. */
993
1051
  declare const less: (x: ArrayLike, y: ArrayLike) => Array;
994
- /** Compare two arrays element-wise. */
1052
+ /** @function Compare two arrays element-wise. */
995
1053
  declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
996
- /** Compare two arrays element-wise. */
1054
+ /** @function Compare two arrays element-wise. */
997
1055
  declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
998
- /** Compare two arrays element-wise. */
1056
+ /** @function Compare two arrays element-wise. */
999
1057
  declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
1000
- /** Compare two arrays element-wise. */
1058
+ /** @function Compare two arrays element-wise. */
1001
1059
  declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
1002
- /** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1060
+ /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1003
1061
  declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
1004
- /** Permute the dimensions of an array. Defaults to reversing the axis order. */
1062
+ /**
1063
+ * @function
1064
+ * Permute the dimensions of an array. Defaults to reversing the axis order.
1065
+ */
1005
1066
  declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
1006
1067
  /**
1068
+ * @function
1007
1069
  * Give a new shape to an array without changing its data.
1008
1070
  *
1009
1071
  * One shape dimension can be -1. In this case, the value is inferred from the
1010
1072
  * length of the array and remaining dimensions.
1011
1073
  */
1012
1074
  declare const reshape: (x: ArrayLike, shape: number[]) => Array;
1013
- /** Move axes of an array to new positions. Other axes retain original order. */
1075
+ /**
1076
+ * @function
1077
+ * Move axes of an array to new positions. Other axes retain original order.
1078
+ */
1014
1079
  declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
1015
1080
  /**
1081
+ * @function
1016
1082
  * Add padding (zeros) to an array.
1017
1083
  *
1018
1084
  * The `width` argument is either an integer or pair of integers, in which case
@@ -1020,15 +1086,27 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
1020
1086
  * pair specifies the padding for its corresponding axis.
1021
1087
  */
1022
1088
  declare const pad: (x: ArrayLike, width: number | [number, number] | [number, number][]) => Array;
1023
- /** Return the number of dimensions of an array. Does not consume array reference. */
1089
+ /**
1090
+ * @function
1091
+ * Return the number of dimensions of an array. Does not consume array reference.
1092
+ */
1024
1093
  declare const ndim: (x: ArrayLike) => number;
1025
- /** Return the shape of an array. Does not consume array reference. */
1094
+ /** @function Return the shape of an array. Does not consume array reference. */
1026
1095
  declare const shape$1: (x: ArrayLike) => number[];
1027
- /** Return an array of zeros with the same shape and type as a given array. */
1096
+ /**
1097
+ * @function
1098
+ * Return an array of zeros with the same shape and type as a given array.
1099
+ */
1028
1100
  declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
1029
- /** Return an array of ones with the same shape and type as a given array. */
1101
+ /**
1102
+ * @function
1103
+ * Return an array of ones with the same shape and type as a given array.
1104
+ */
1030
1105
  declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
1031
- /** Return a full array with the same shape and type as a given array. */
1106
+ /**
1107
+ * @function
1108
+ * Return a full array with the same shape and type as a given array.
1109
+ */
1032
1110
  declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
1033
1111
  /**
1034
1112
  * Return the number of elements in an array, optionally along an axis.
@@ -1038,15 +1116,15 @@ declare function size(a: ArrayLike, axis?: number): number;
1038
1116
  /** Convert an array to a specified dtype. */
1039
1117
  declare function astype(a: ArrayLike, dtype: DType): Array;
1040
1118
  /** Sum of the elements of the array over a given axis, or axes. */
1041
- declare function sum(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1119
+ declare function sum(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1042
1120
  /** Product of the array elements over a given axis. */
1043
- declare function prod(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1121
+ declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1044
1122
  /** Return the minimum of array elements along a given axis. */
1045
- declare function min(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1123
+ declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1046
1124
  /** Return the maximum of array elements along a given axis. */
1047
- declare function max(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1125
+ declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1048
1126
  /** Compute the average of the array elements along the specified axis. */
1049
- declare function mean(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1127
+ declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1050
1128
  /**
1051
1129
  * Returns the indices of the minimum values along an axis.
1052
1130
  *
@@ -1062,7 +1140,7 @@ declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
1062
1140
  */
1063
1141
  declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
1064
1142
  /** Reverse the elements in an array along the given axes. */
1065
- declare function flip(x: ArrayLike, axis?: number | number[]): Array;
1143
+ declare function flip(x: ArrayLike, axis?: Axis): Array;
1066
1144
  /**
1067
1145
  * Join a sequence of arrays along an existing axis.
1068
1146
  *
@@ -1106,9 +1184,36 @@ declare function columnStack(xs: ArrayLike[]): Array;
1106
1184
  declare function flipud(x: ArrayLike): Array;
1107
1185
  /** Flip an array horizontally (axis=1). */
1108
1186
  declare function fliplr(x: ArrayLike): Array;
1187
+ /** @function Alternative name for `numpy.transpose()`. */
1109
1188
  declare const permuteDims: (x: ArrayLike, perm?: number[]) => Array;
1110
1189
  /** Return a 1-D flattened array containing the elements of the input. */
1111
1190
  declare function ravel(a: ArrayLike): Array;
1191
+ /**
1192
+ * Repeat each element of an array after themselves.
1193
+ *
1194
+ * If no axis is provided, use the flattened input array, and return a flat
1195
+ * output array.
1196
+ */
1197
+ declare function repeat(a: ArrayLike, repeats: number, axis?: number): Array;
1198
+ /**
1199
+ * Construct an array by repeating A the number of times given by reps.
1200
+ *
1201
+ * If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
1202
+ * integers, the resulting array will have a shape of `(reps[0] * d1,
1203
+ * reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
1204
+ */
1205
+ declare function tile(a: ArrayLike, reps: number | number[]): Array;
1206
+ /**
1207
+ * Broadcast an array to a shape, with NumPy-style broadcasing rules.
1208
+ *
1209
+ * In other words, this lets you append axes to the left, and/or expand
1210
+ * dimensions where the shape is 1.
1211
+ */
1212
+ declare function broadcastTo(a: ArrayLike, shape: number[]): Array;
1213
+ /** Broadcast input shapes to a common output shape. */
1214
+ declare function broadcastShapes(...shapes: number[][]): number[];
1215
+ /** Broadcast arrays to a common shape. */
1216
+ declare function broadcastArrays(...arrays: ArrayLike[]): Array[];
1112
1217
  /**
1113
1218
  * Return specified diagonals.
1114
1219
  *
@@ -1136,8 +1241,28 @@ declare function allclose(actual: Parameters<typeof array>[0], expected: Paramet
1136
1241
  declare function matmul(x: ArrayLike, y: ArrayLike): Array;
1137
1242
  /** Dot product of two arrays. */
1138
1243
  declare function dot(x: ArrayLike, y: ArrayLike): Array;
1139
- /** Vector dot product of two arrays. */
1140
- declare function vecdot(x: ArrayLike, y: ArrayLike): Array;
1244
+ /**
1245
+ * Compute the inner product of two arrays.
1246
+ *
1247
+ * Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
1248
+ * contraction on the last axis.
1249
+ *
1250
+ * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
1251
+ */
1252
+ declare function inner(x: ArrayLike, y: ArrayLike): Array;
1253
+ /**
1254
+ * Compute the outer product of two arrays.
1255
+ *
1256
+ * If the input arrays are not 1D, they will be flattened. Returned array will
1257
+ * be of shape `[x.size, y.size]`.
1258
+ */
1259
+ declare function outer(x: ArrayLike, y: ArrayLike): Array;
1260
+ /** Vector dot product of two arrays along a given axis. */
1261
+ declare function vecdot(x: ArrayLike, y: ArrayLike, {
1262
+ axis
1263
+ }?: {
1264
+ axis?: number;
1265
+ }): Array;
1141
1266
  /**
1142
1267
  * Return the dot product of two vectors.
1143
1268
  *
@@ -1155,6 +1280,21 @@ declare function meshgrid(xs: Array[], {
1155
1280
  }?: {
1156
1281
  indexing?: "xy" | "ij";
1157
1282
  }): Array[];
1283
+ /**
1284
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
1285
+ *
1286
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
1287
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
1288
+ * `k>0` is above it.
1289
+ */
1290
+ declare function tri(n: number, m?: number, k?: number, {
1291
+ dtype,
1292
+ device
1293
+ }?: DTypeAndDevice): Array;
1294
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
1295
+ declare function tril(a: ArrayLike, k?: number): Array;
1296
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
1297
+ declare function triu(a: ArrayLike, k?: number): Array;
1158
1298
  /**
1159
1299
  * Clip (limit) the values in an array.
1160
1300
  *
@@ -1171,15 +1311,50 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
1171
1311
  * This is the same function as `jax.numpy.abs()`.
1172
1312
  */
1173
1313
  declare function absolute(x: ArrayLike): Array;
1174
- /** Alias of `jax.numpy.absolute()`. */
1314
+ /** @function Alias of `jax.numpy.absolute()`. */
1175
1315
  declare const abs: typeof absolute;
1316
+ /** Return an element-wise indication of sign of the input. */
1317
+ declare function sign(x: ArrayLike): Array;
1176
1318
  /** Calculate element-wise square of the input array. */
1177
1319
  declare function square(x: ArrayLike): Array;
1178
- /** Compute a trigonometric tangent of each element of input. */
1320
+ /** Element-wise tangent function (takes radians). */
1179
1321
  declare function tan(x: ArrayLike): Array;
1322
+ /** Element-wise inverse cosine function (inverse of cos). */
1323
+ declare function acos(x: ArrayLike): Array;
1324
+ /**
1325
+ * @function
1326
+ * Return element-wise hypotenuse for the given legs of a right triangle.
1327
+ *
1328
+ * In the original NumPy/JAX implementation, this function is more numerically
1329
+ * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
1330
+ * improvements.
1331
+ */
1332
+ declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1333
+ /**
1334
+ * @function
1335
+ * Element-wise arc tangent of y/x with correct quadrant.
1336
+ *
1337
+ * Returns the angle in radians between the positive x-axis and the point (x, y).
1338
+ * The result is in the range [-π, π].
1339
+ *
1340
+ * Uses numerically stable formulas:
1341
+ * - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
1342
+ * - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
1343
+ *
1344
+ * The output is ill-defined when both x and y are zero.
1345
+ */
1346
+ declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
1347
+ /** @function Alias of `jax.numpy.acos()`. */
1348
+ declare const arccos: typeof acos;
1349
+ /** @function Alias of `jax.numpy.atan()`. */
1350
+ declare const arctan: (x: ArrayLike) => Array;
1351
+ /** @function Alias of `jax.numpy.atan2()`. */
1352
+ declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
1353
+ /** Element-wise subtraction, with broadcasting. */
1354
+ declare function subtract(x: ArrayLike, y: ArrayLike): Array;
1180
1355
  /** Calculates the floating-point division of x by y element-wise. */
1181
1356
  declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
1182
- /** Alias of `jax.numpy.trueDivide()`. */
1357
+ /** @function Alias of `jax.numpy.trueDivide()`. */
1183
1358
  declare const divide: typeof trueDivide;
1184
1359
  /** Round input to the nearest integer towards zero. */
1185
1360
  declare function trunc(x: ArrayLike): Array;
@@ -1189,26 +1364,112 @@ declare function exp2(p: ArrayLike): Array;
1189
1364
  declare function log2(x: ArrayLike): Array;
1190
1365
  /** Return the base-10 logarithm of x, element-wise. */
1191
1366
  declare function log10(x: ArrayLike): Array;
1367
+ /** Calculate `exp(x) - 1` element-wise. */
1368
+ declare function expm1(x: ArrayLike): Array;
1369
+ /** Calculate the natural logarithm of `1 + x` element-wise. */
1370
+ declare function log1p(x: ArrayLike): Array;
1371
+ /** Convert angles from degrees to radians. */
1372
+ declare function deg2rad(x: ArrayLike): Array;
1373
+ /** @function Alias of `jax.numpy.deg2rad()`. */
1374
+ declare const radians: typeof deg2rad;
1375
+ /** Convert angles from radians to degrees. */
1376
+ declare function rad2deg(x: ArrayLike): Array;
1377
+ /** @function Alias of `jax.numpy.rad2deg()`. */
1378
+ declare const degrees: typeof rad2deg;
1192
1379
  /**
1380
+ * @function
1381
+ * Computes first array raised to power of second array, element-wise.
1382
+ */
1383
+ declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1384
+ /** @function Alias of `jax.numpy.power()`. */
1385
+ declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1386
+ /** @function Calculate the element-wise cube root of the input array. */
1387
+ declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
1388
+ /**
1389
+ * @function
1193
1390
  * Calculate element-wise hyperbolic sine of input.
1194
1391
  *
1195
1392
  * `sinh(x) = (exp(x) - exp(-x)) / 2`
1196
1393
  */
1197
- declare function sinh(x: ArrayLike): Array;
1394
+ declare const sinh: OwnedFunction<(x: ArrayLike) => Array>;
1198
1395
  /**
1396
+ * @function
1199
1397
  * Calculate element-wise hyperbolic cosine of input.
1200
1398
  *
1201
1399
  * `cosh(x) = (exp(x) + exp(-x)) / 2`
1202
1400
  */
1203
- declare function cosh(x: ArrayLike): Array;
1401
+ declare const cosh: OwnedFunction<(x: ArrayLike) => Array>;
1204
1402
  /**
1403
+ * @function
1205
1404
  * Calculate element-wise hyperbolic tangent of input.
1206
1405
  *
1207
1406
  * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
1208
1407
  */
1209
- declare function tanh(x: ArrayLike): Array;
1408
+ declare const tanh: OwnedFunction<(x: ArrayLike) => Array>;
1409
+ /**
1410
+ * @function
1411
+ * Calculate element-wise inverse hyperbolic sine of input.
1412
+ *
1413
+ * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
1414
+ */
1415
+ declare const arcsinh: OwnedFunction<(x: ArrayLike) => Array>;
1416
+ /**
1417
+ * @function
1418
+ * Calculate element-wise inverse hyperbolic cosine of input.
1419
+ *
1420
+ * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
1421
+ */
1422
+ declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
1423
+ /**
1424
+ * @function
1425
+ * Calculate element-wise inverse hyperbolic tangent of input.
1426
+ *
1427
+ * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
1428
+ */
1429
+ declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
1430
+ /** @function Alias of `jax.numpy.arcsinh()`. */
1431
+ declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
1432
+ /** @function Alias of `jax.numpy.arccosh()`. */
1433
+ declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
1434
+ /** @function Alias of `jax.numpy.arctanh()`. */
1435
+ declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
1436
+ /**
1437
+ * Compute the variance of an array.
1438
+ *
1439
+ * The variance is computed for the flattened array by default, otherwise over
1440
+ * the specified axis.
1441
+ *
1442
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
1443
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
1444
+ */
1445
+ declare function var_(x: ArrayLike, axis?: Axis, opts?: {
1446
+ mean?: ArrayLike;
1447
+ correction?: number;
1448
+ } & ReduceOpts): Array;
1449
+ /**
1450
+ * Compute the standard deviation of an array.
1451
+ *
1452
+ * The standard deviation is computed for the flattened array by default,
1453
+ * otherwise over the specified axis.
1454
+ *
1455
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
1456
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
1457
+ */
1458
+ declare function std(x: ArrayLike, axis?: Axis, opts?: {
1459
+ mean?: ArrayLike;
1460
+ correction?: number;
1461
+ } & ReduceOpts): Array;
1210
1462
  //#endregion
1211
1463
  //#region src/frontend/jaxpr.d.ts
1464
+ /**
1465
+ * Function callback with an associated dispose() method.
1466
+ *
1467
+ * The dispose() method should be called to clean up any tracer resources needed
1468
+ * by the function after the last time it is called.
1469
+ */
1470
+ type OwnedFunction<F extends Function> = F & {
1471
+ dispose: () => void;
1472
+ };
1212
1473
  /** Variable in a Jaxpr expression. */
1213
1474
  declare class Var {
1214
1475
  #private;
@@ -1300,7 +1561,7 @@ declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding:
1300
1561
  /** Reduce a computation over padded windows. */
1301
1562
  declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
1302
1563
  declare namespace nn_d_exports {
1303
- export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
1564
+ export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
1304
1565
  }
1305
1566
  /**
1306
1567
  * Rectified Linear Unit (ReLU) activation function:
@@ -1332,6 +1593,7 @@ declare function softplus(x: ArrayLike): Array;
1332
1593
  */
1333
1594
  declare function softSign(x: ArrayLike): Array;
1334
1595
  /**
1596
+ * @function
1335
1597
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
1336
1598
  * Swish, computed element-wise:
1337
1599
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -1340,8 +1602,9 @@ declare function softSign(x: ArrayLike): Array;
1340
1602
  *
1341
1603
  * Reference: https://en.wikipedia.org/wiki/Swish_function
1342
1604
  */
1343
- declare const silu: (x: ArrayLike) => Array;
1605
+ declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
1344
1606
  /**
1607
+ * @function
1345
1608
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
1346
1609
  * Swish, computed element-wise:
1347
1610
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -1350,31 +1613,35 @@ declare const silu: (x: ArrayLike) => Array;
1350
1613
  *
1351
1614
  * Reference: https://en.wikipedia.org/wiki/Swish_function
1352
1615
  */
1353
- declare const swish: (x: ArrayLike) => Array;
1616
+ declare const swish: OwnedFunction<(x: ArrayLike) => Array>;
1354
1617
  /**
1355
1618
  * Log-sigmoid activation function, computed element-wise:
1356
1619
  * `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
1357
1620
  */
1358
1621
  declare function logSigmoid(x: ArrayLike): Array;
1359
- /** Identity activation function. Returns the argument unmodified. */
1622
+ /**
1623
+ * @function
1624
+ * Identity activation function. Returns the argument unmodified.
1625
+ */
1360
1626
  declare const identity: (x: ArrayLike) => Array;
1361
1627
  /** Leaky rectified linear (ReLU) activation function */
1362
- declare function leakyRelu(x: ArrayLike, negativeSlope?: number): Array;
1628
+ declare function leakyRelu(x: ArrayLike, negativeSlope?: ArrayLike): Array;
1363
1629
  /**
1364
1630
  * Exponential linear unit activation function.
1365
1631
  *
1366
1632
  * Computes the element-wise function:
1367
1633
  * `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
1368
1634
  */
1369
- declare function elu(x: ArrayLike, alpha?: number): Array;
1635
+ declare function elu(x: ArrayLike, alpha?: ArrayLike): Array;
1370
1636
  /**
1371
1637
  * Continuously-differentiable exponential linear unit activation function.
1372
1638
  *
1373
1639
  * Computes the element-wise function:
1374
1640
  * `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
1375
1641
  */
1376
- declare function celu(x: ArrayLike, alpha?: number): Array;
1642
+ declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
1377
1643
  /**
1644
+ * @function
1378
1645
  * Gaussion error linear unit (GELU) activation function.
1379
1646
  *
1380
1647
  * This is computed element-wise. Currently jax-js does not support the erf() or
@@ -1385,7 +1652,7 @@ declare function celu(x: ArrayLike, alpha?: number): Array;
1385
1652
  *
1386
1653
  * This will be improved in the future.
1387
1654
  */
1388
- declare const gelu: (x: ArrayLike) => Array;
1655
+ declare const gelu: OwnedFunction<(x: ArrayLike) => Array>;
1389
1656
  /**
1390
1657
  * Gated linear unit (GLU) activation function.
1391
1658
  *
@@ -1393,6 +1660,13 @@ declare const gelu: (x: ArrayLike) => Array;
1393
1660
  * computes `a * sigmoid(b)`.
1394
1661
  */
1395
1662
  declare function glu(x: ArrayLike, axis?: number): Array;
1663
+ /**
1664
+ * Squareplus activation function.
1665
+ *
1666
+ * Computes the element-wise function:
1667
+ * `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
1668
+ */
1669
+ declare function squareplus(x: ArrayLike, b?: ArrayLike): Array;
1396
1670
  /**
1397
1671
  * Mish activation function.
1398
1672
  *
@@ -1408,7 +1682,7 @@ declare function mish(x: ArrayLike): Array;
1408
1682
  *
1409
1683
  * Reference: https://en.wikipedia.org/wiki/Softmax_function
1410
1684
  */
1411
- declare function softmax(x: ArrayLike, axis?: number | number[]): Array;
1685
+ declare function softmax(x: ArrayLike, axis?: Axis): Array;
1412
1686
  /**
1413
1687
  * Log-Softmax function.
1414
1688
  *
@@ -1417,7 +1691,7 @@ declare function softmax(x: ArrayLike, axis?: number | number[]): Array;
1417
1691
  *
1418
1692
  * If `axis` is not specified, it defaults to the last axis.
1419
1693
  */
1420
- declare function logSoftmax(x: ArrayLike, axis?: number | number[]): Array;
1694
+ declare function logSoftmax(x: ArrayLike, axis?: Axis): Array;
1421
1695
  /**
1422
1696
  * Log-sum-exp reduction. Also a multivariate version of `softplus`.
1423
1697
  *
@@ -1426,7 +1700,22 @@ declare function logSoftmax(x: ArrayLike, axis?: number | number[]): Array;
1426
1700
  *
1427
1701
  * Reference: https://en.wikipedia.org/wiki/LogSumExp
1428
1702
  */
1429
- declare function logsumexp(x: ArrayLike, axis?: number | number[]): Array;
1703
+ declare function logsumexp(x: ArrayLike, axis?: Axis): Array;
1704
+ /** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
1705
+ declare function logmeanexp(x: ArrayLike, axis?: Axis): Array;
1706
+ /**
1707
+ * Standardizes input to zero mean and unit variance.
1708
+ *
1709
+ * By default, this is computed over the last axis. You can pass in a different
1710
+ * axis, or `null` to standardize over all elements.
1711
+ *
1712
+ * Epsilon is added to denominator, it defaults to `1e-5` for stability.
1713
+ */
1714
+ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
1715
+ mean?: ArrayLike;
1716
+ variance?: ArrayLike;
1717
+ epsilon?: ArrayLike;
1718
+ }): Array;
1430
1719
  /**
1431
1720
  * One-hot encodes the given indices.
1432
1721
  *
@@ -1445,7 +1734,7 @@ declare function logsumexp(x: ArrayLike, axis?: number | number[]): Array;
1445
1734
  */
1446
1735
  declare function oneHot(x: Array, numClasses: number): Array;
1447
1736
  declare namespace random_d_exports {
1448
- export { bits, key, split, uniform };
1737
+ export { bernoulli, bits, exponential, key, normal, split, uniform };
1449
1738
  }
1450
1739
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
1451
1740
  declare function key(seed: number): Array;
@@ -1461,26 +1750,59 @@ declare function uniform(key: Array, shape?: number[], {
1461
1750
  minval?: number;
1462
1751
  maxval?: number;
1463
1752
  }): Array;
1753
+ /**
1754
+ * Sample Bernoulli random variables with given mean (0,1 categorical).
1755
+ *
1756
+ * Returns a random Boolean array with the specified shape. `p` can be an array
1757
+ * and must be broadcastable to `shape`.
1758
+ */
1759
+ declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
1760
+ /** Sample exponential random values according to `p(x) = exp(-x)`. */
1761
+ declare function exponential(key: Array, shape?: number[]): Array;
1762
+ /**
1763
+ * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
1764
+ *
1765
+ * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
1766
+ * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
1767
+ * bitwise identical to JAX.
1768
+ */
1769
+ declare function normal(key: Array, shape?: number[]): Array;
1464
1770
  //#endregion
1465
1771
  //#region src/index.d.ts
1466
- /** Compute the forward-mode Jacobian-vector product for a function. */
1772
+ /**
1773
+ * @function
1774
+ * Compute the forward-mode Jacobian-vector product for a function.
1775
+ */
1467
1776
  declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
1468
- /** Vectorize an operation on a batched axis for one or more inputs. */
1777
+ /**
1778
+ * @function
1779
+ * Vectorize an operation on a batched axis for one or more inputs.
1780
+ */
1469
1781
  declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1470
- /** Compute the Jacobian evaluated column-by-column by forward-mode AD. */
1782
+ /**
1783
+ * @function
1784
+ * Compute the Jacobian evaluated column-by-column by forward-mode AD.
1785
+ */
1471
1786
  declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1472
- /** Construct a Jaxpr by dynamically tracing a function with example inputs. */
1787
+ /**
1788
+ * @function
1789
+ * Construct a Jaxpr by dynamically tracing a function with example inputs.
1790
+ */
1473
1791
  declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
1474
1792
  jaxpr: Jaxpr;
1475
1793
  consts: Array[];
1476
1794
  treedef: JsTreeDef;
1477
1795
  };
1478
1796
  /**
1797
+ * @function
1479
1798
  * Mark a function for automatic JIT compilation, with operator fusion.
1480
1799
  *
1481
1800
  * The function will be compiled the first time it is called with a set of
1482
1801
  * argument shapes.
1483
1802
  *
1803
+ * You can call `.dispose()` on the returned, JIT-compiled function after all
1804
+ * calls to free memory associated with array constants.
1805
+ *
1484
1806
  * **Options:**
1485
1807
  * - `staticArgnums`: An array of argument indices to treat as static
1486
1808
  * (compile-time constant). These arguments must be hashable, won't be traced,
@@ -1488,24 +1810,48 @@ declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) =>
1488
1810
  * - `device`: The device to place the computation on. If not specified, the
1489
1811
  * computation will be placed on the default device.
1490
1812
  */
1491
- declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1813
+ declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => OwnedFunction<(...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>>;
1492
1814
  /**
1815
+ * @function
1493
1816
  * Produce a local linear approximation to a function at a point using jvp() and
1494
1817
  * partial evaluation.
1495
1818
  */
1496
1819
  declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
1497
- /** Calculate the reverse-mode vector-Jacobian product for a function. */
1820
+ /**
1821
+ * @function
1822
+ * Calculate the reverse-mode vector-Jacobian product for a function.
1823
+ */
1498
1824
  declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
1499
1825
  /**
1826
+ * @function
1500
1827
  * Compute the gradient of a scalar-valued function `f` with respect to its
1501
1828
  * first argument.
1502
1829
  */
1503
1830
  declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
1504
- /** Create a function that evaluates both `f` and the gradient of `f`. */
1831
+ /**
1832
+ * @function
1833
+ * Create a function that evaluates both `f` and the gradient of `f`.
1834
+ */
1505
1835
  declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
1506
- /** Compute the Jacobian evaluated row-by-row by reverse-mode AD. */
1836
+ /**
1837
+ * @function
1838
+ * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
1839
+ */
1507
1840
  declare const jacrev: typeof jacfwd;
1508
- /** Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`. */
1841
+ /**
1842
+ * @function
1843
+ * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
1844
+ */
1509
1845
  declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1846
+ /**
1847
+ * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
1848
+ *
1849
+ * This can be used to wait for the results of an intermediate computation to
1850
+ * finish. It's recommended to call this regularly in an iterative computation
1851
+ * to avoid queueing up too many pending operations.
1852
+ *
1853
+ * Does not consume reference to the arrays.
1854
+ */
1855
+ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
1510
1856
  //#endregion
1511
- export { DType, type Device, type JsTree, type JsTreeDef, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDevice, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1857
+ export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };