@jax-js/jax 0.0.2 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.ts CHANGED
@@ -20,6 +20,7 @@ interface Stringable {
20
20
  //#endregion
21
21
  //#region src/shape.d.ts
22
22
 
23
+ /** @inline */
23
24
  type Pair = [number, number];
24
25
  /**
25
26
  * A multidimensional view into memory. An array can be thought of as the
@@ -46,6 +47,8 @@ declare class View {
46
47
  get size(): number;
47
48
  /** Whether this is a default, contiguous, unaltered view of the data (identity). */
48
49
  get contiguous(): boolean;
50
+ /** Return the range of data being indexed in this view, or [0, 0] if none. */
51
+ dataRange(): [number, number];
49
52
  /** Produce an AluExp for evaluating this view at an index. */
50
53
  toAluExp(idxs: AluExp[]): [AluExp, AluExp];
51
54
  /**
@@ -106,9 +109,33 @@ declare class ShapeTracker {
106
109
  reshape(newShape: number[]): ShapeTracker;
107
110
  /** Broadcast along the given new axes, then expand the shape. */
108
111
  broadcast(newShape: number[], axis: number[]): ShapeTracker;
112
+ /**
113
+ * Repeat data in each axis by a positive number of repetitions.
114
+ *
115
+ * - If `tile` is true (default): [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
116
+ * - If `tile` is false: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
117
+ */
118
+ repeat(reps: number[], tile?: boolean): ShapeTracker;
119
+ /** Move axis i to axis j. */
120
+ moveaxis(i: number, j: number): ShapeTracker;
121
+ /** Like pad(), but allows for negative values. */
122
+ padOrShrink(arg: Pair[]): ShapeTracker;
109
123
  }
110
124
  //#endregion
111
125
  //#region src/utils.d.ts
126
+ /**
127
+ * Set the debug level for verbose logging.
128
+ *
129
+ * 1. JIT compile logs
130
+ * 2. Shader code
131
+ * 3. Expressions and metadata
132
+ * 4. JIT programs, tuning details
133
+ * 5. Most verbose operation traces
134
+ *
135
+ * This is an experimental API and may change in behavior. Do not rely on this
136
+ * in production.
137
+ */
138
+ declare function setDebug(level: number): void;
112
139
  /** @inline */
113
140
  type RecursiveArray<T> = T | RecursiveArray<T>[];
114
141
  interface FpHashable {
@@ -124,19 +151,47 @@ interface FpHashable {
124
151
  declare class FpHash {
125
152
  #private;
126
153
  value: bigint;
127
- update(...values: (string | boolean | number | bigint | null | undefined | FpHashable)[]): this;
154
+ update(x: string | boolean | number | bigint | null | undefined | FpHashable): this;
128
155
  static hash(...values: (string | boolean | number | bigint | null | undefined | FpHashable)[]): bigint;
129
156
  }
130
157
  /** Run a function while caching it inline inside a `Map`. */
131
158
  //#endregion
132
159
  //#region src/alu.d.ts
160
+ /** A numerical data type for array contents. */
133
161
  declare enum DType {
134
162
  Float32 = "float32",
135
163
  Int32 = "int32",
136
164
  Uint32 = "uint32",
137
165
  Bool = "bool",
138
- Complex64 = "complex64",
166
+ Float16 = "float16",
139
167
  }
168
+ /** @inline */
169
+ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
170
+ /**
171
+ * Promote two dtypes to their join according to the type lattice.
172
+ *
173
+ * When performing operations between arrays of different types, we need to
174
+ * promote both operands to a common type that can represent values from both
175
+ * input types. This follows JAX's type promotion rules.
176
+ *
177
+ * **Type lattice:**
178
+ * ```text
179
+ * bool -> uint32 -> int32 -> float16 -> float32
180
+ * weak f* --^
181
+ * ```
182
+ *
183
+ * The asterisk f* is a weak type used for JS number constants. When creating
184
+ * arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
185
+ * any array they are first combined with.
186
+ *
187
+ * **Examples:**
188
+ * - `promoteTypes(bool, int32) → int32`
189
+ * - `promoteTypes(uint32, int32) → int32`
190
+ * - `promoteTypes(int32, float16) → float16`
191
+ * - `promoteTypes(float16, float32) → float32`
192
+ * - `promoteTypes(uint32, float32) → float32`
193
+ */
194
+ declare function promoteTypes(dtype1: DType, dtype2: DType): DType;
140
195
  /**
141
196
  * Mathematical expression on scalar values.
142
197
  *
@@ -160,8 +215,11 @@ declare class AluExp implements FpHashable {
160
215
  static max(a: AluExp, b: AluExp): AluExp;
161
216
  static sin(a: AluExp): AluExp;
162
217
  static cos(a: AluExp): AluExp;
218
+ static asin(a: AluExp): AluExp;
219
+ static atan(a: AluExp): AluExp;
163
220
  static exp(a: AluExp): AluExp;
164
221
  static log(a: AluExp): AluExp;
222
+ static sqrt(a: AluExp): AluExp;
165
223
  static reciprocal(a: AluExp): AluExp;
166
224
  static cast(dtype: DType, a: AluExp): AluExp;
167
225
  static bitcast(dtype: DType, a: AluExp): AluExp;
@@ -172,12 +230,13 @@ declare class AluExp implements FpHashable {
172
230
  static const(dtype: DType, value: any): AluExp;
173
231
  static special(dtype: DType, name: string, n: number): AluExp;
174
232
  static variable(dtype: DType, name: string): AluExp;
175
- static globalIndex(dtype: DType, gid: number, bufidx: AluExp): AluExp;
233
+ static globalIndex(dtype: DType, gid: number, len: number, bufidx: AluExp): AluExp;
176
234
  static globalView(dtype: DType, gid: number, st: ShapeTracker, indices: AluExp[]): AluExp;
235
+ static f32(value: number): AluExp;
177
236
  static i32(value: number): AluExp;
178
237
  static u32(value: number): AluExp;
179
- static f32(value: number): AluExp;
180
238
  static bool(value: boolean): AluExp;
239
+ static f16(value: number): AluExp;
181
240
  not(): AluExp;
182
241
  /** Compute a reasonable expression hash with low collision rate. */
183
242
  getHash(): bigint;
@@ -188,6 +247,19 @@ declare class AluExp implements FpHashable {
188
247
  reindexGids(gidMap: Map<number, number>): AluExp;
189
248
  get min(): number;
190
249
  get max(): number;
250
+ /** Largest known integer that divides self. */
251
+ constFactor(): number;
252
+ /**
253
+ * Checks if divisible by an integer v and returns the quotient if it is, or
254
+ * `null` if it's not divisible.
255
+ */
256
+ divides(v: number): AluExp | null;
257
+ /**
258
+ * Get all expressions by deeply matching an operation.
259
+ *
260
+ * For example: `((2+(3*5))+4).splitOp(+) -> [2,(3*5),4]`.
261
+ */
262
+ splitOp(sep: AluOp): IterableIterator<AluExp>;
191
263
  /**
192
264
  * Simplify the expression by replacing any known patterns and deduping
193
265
  * identical subexpressions.
@@ -208,10 +280,16 @@ declare class AluExp implements FpHashable {
208
280
  toString(): string;
209
281
  /** Generic fold() operation with a reducer over the expression tree. */
210
282
  fold<T = void>(reducer: (exp: AluExp, mappedSrc: T[]) => T): T;
283
+ /** Check if any expression in the tree satisfies a predicate. */
284
+ some(predicate: (exp: AluExp) => boolean): boolean;
211
285
  /** Rewrite the expression recursively using a visitor. */
212
286
  rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
213
287
  /** Collect all nodes that satisfy a predicate. */
214
288
  collect(predicate: (exp: AluExp) => boolean): AluExp[];
289
+ /** Produce a list of all distinct AluOp in this expression. */
290
+ distinctOps(): Set<AluOp>;
291
+ /** Rewrite GlobalView operations to GlobalIndex operations. */
292
+ rewriteGlobalViews(): AluExp;
215
293
  }
216
294
  /** Symbolic form for each mathematical operation. */
217
295
  declare enum AluOp {
@@ -224,8 +302,11 @@ declare enum AluOp {
224
302
  Max = "Max",
225
303
  Sin = "Sin",
226
304
  Cos = "Cos",
305
+ Asin = "Asin",
306
+ Atan = "Atan",
227
307
  Exp = "Exp",
228
308
  Log = "Log",
309
+ Sqrt = "Sqrt",
229
310
  Reciprocal = "Reciprocal",
230
311
  Cast = "Cast",
231
312
  Bitcast = "Bitcast",
@@ -242,7 +323,7 @@ declare enum AluOp {
242
323
  Variable = "Variable",
243
324
  // arg = variable
244
325
  GlobalIndex = "GlobalIndex",
245
- // arg = gid; src = [bufidx]
326
+ // arg = [gid, len]; src = [bufidx]
246
327
  GlobalView = "GlobalView",
247
328
  }
248
329
  /**
@@ -297,12 +378,12 @@ declare class Reduction implements FpHashable {
297
378
  /** Size of the reduction axis. */
298
379
  readonly size: number;
299
380
  /** Follow-up expression defined with the "acc" variable, defaults to identity. */
300
- readonly fusion: AluExp;
381
+ readonly epilogue: AluExp;
301
382
  constructor(/** Data type of the values being reduced over. */
302
383
  dtype: DType, /** Operation to perform. Only ops in `AluGroup.Reduce` are supported. */
303
384
  op: AluOp, /** Size of the reduction axis. */
304
385
  size: number, /** Follow-up expression defined with the "acc" variable, defaults to identity. */
305
- fusion?: AluExp);
386
+ epilogue?: AluExp);
306
387
  hash(state: FpHash): void;
307
388
  toString(): string;
308
389
  /** Get the identity for this reduction operation. */
@@ -313,10 +394,10 @@ declare class Reduction implements FpHashable {
313
394
  /** Expression for accessing `indices` in input array with the given shape. */
314
395
  //#endregion
315
396
  //#region src/backend.d.ts
316
- type Device = "cpu" | "webgpu";
397
+ type Device = "cpu" | "wasm" | "webgpu";
317
398
  declare const devices: Device[];
318
- /** Set the default device backend (must be initialized). */
319
- declare function setDevice(device: Device): void;
399
+ /** Configure the default device for arrays. */
400
+ declare function defaultDevice(device?: Device): Device;
320
401
  /**
321
402
  * Initialize `jax-js` library backends.
322
403
  *
@@ -336,7 +417,7 @@ interface Backend {
336
417
  /** Maximum number of arguments per dispatched kernel. */
337
418
  readonly maxArgs: number;
338
419
  /** Allocate a new slot with reference count 1. */
339
- malloc(size: number, initialData?: ArrayBuffer): Slot;
420
+ malloc(size: number, initialData?: Uint8Array): Slot;
340
421
  /** Increment the reference count of the slot. */
341
422
  incRef(slot: Slot): void;
342
423
  /**
@@ -345,9 +426,9 @@ interface Backend {
345
426
  */
346
427
  decRef(slot: Slot): void;
347
428
  /** Read a range of bytes from a buffer. */
348
- read(slot: Slot, start?: number, count?: number): Promise<ArrayBuffer>;
429
+ read(slot: Slot, start?: number, count?: number): Promise<Uint8Array<ArrayBuffer>>;
349
430
  /** Read a range of bytes from a buffer, blocking variant. */
350
- readSync(slot: Slot, start?: number, count?: number): ArrayBuffer;
431
+ readSync(slot: Slot, start?: number, count?: number): Uint8Array<ArrayBuffer>;
351
432
  /** Prepare an expression to be executed later. */
352
433
  prepare(kernel: Kernel): Promise<Executable>;
353
434
  /** Prepare an expression to be executed later, blocking variant. */
@@ -369,18 +450,22 @@ declare class Executable<T = any> {
369
450
  data: T);
370
451
  }
371
452
  declare namespace tree_d_exports {
372
- export { JsTree, JsTreeDef, MapJsTree, NodeType, flatten, leaves, map, ref, structure, unflatten };
453
+ export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
373
454
  }
374
455
  declare enum NodeType {
375
456
  Array = "Array",
376
457
  Object = "Object",
377
458
  Leaf = "Leaf",
378
459
  }
460
+ /** Analog to the JAX "pytree" object, but for JavaScript. */
379
461
  type JsTree<T> = T | JsTree<T>[] | {
380
462
  [key: string]: JsTree<T>;
381
463
  };
382
- type MapJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
383
- /** Analog to the JAX "pytree" object, but for JavaScript. */
464
+ type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
465
+ type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
466
+ /** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
467
+ type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
468
+ /** Represents the structure of a JsTree. */
384
469
  declare class JsTreeDef {
385
470
  readonly nodeType: NodeType;
386
471
  readonly nodeMetadata: any;
@@ -406,6 +491,23 @@ declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T
406
491
  declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
407
492
  /** Take a reference of every array in a tree. */
408
493
  declare function ref<Tree extends JsTree<Array>>(tree: Tree): Tree;
494
+ /** Dispose every array in a tree. */
495
+ declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefined): void;
496
+ //#endregion
497
+ //#region src/frontend/convolution.d.ts
498
+ /** Definition of a general dilated convolution. Should be valid on creation. */
499
+ interface ConvParams {
500
+ strides: number[];
501
+ padding: [number, number][];
502
+ lhsDilation: number[];
503
+ rhsDilation: number[];
504
+ }
505
+ /**
506
+ * Check that the shapes and parameters passed to convolution are valid.
507
+ *
508
+ * If the check succeeds, returns the output shape.
509
+ */
510
+
409
511
  //#endregion
410
512
  //#region src/frontend/core.d.ts
411
513
  /**
@@ -431,12 +533,20 @@ declare enum Primitive {
431
533
  RandomBits = "random_bits",
432
534
  Sin = "sin",
433
535
  Cos = "cos",
536
+ Asin = "asin",
537
+ Atan = "atan",
434
538
  Exp = "exp",
435
539
  Log = "log",
540
+ Sqrt = "sqrt",
436
541
  Min = "min",
437
542
  Max = "max",
438
543
  Reduce = "reduce",
439
544
  Dot = "dot",
545
+ // sum(x*y, axis=-1)
546
+ Conv = "conv",
547
+ // see lax.conv_general_dilated
548
+ Pool = "pool",
549
+ PoolTranspose = "pool_transpose",
440
550
  Compare = "compare",
441
551
  Where = "where",
442
552
  Transpose = "transpose",
@@ -448,7 +558,7 @@ declare enum Primitive {
448
558
  Gather = "gather",
449
559
  JitCall = "jit_call",
450
560
  }
451
- interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>> {
561
+ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
452
562
  [Primitive.Cast]: {
453
563
  dtype: DType;
454
564
  };
@@ -459,6 +569,16 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>>
459
569
  op: AluOp;
460
570
  axis: number[];
461
571
  };
572
+ [Primitive.Conv]: ConvParams;
573
+ [Primitive.Pool]: {
574
+ window: number[];
575
+ strides: number[];
576
+ };
577
+ [Primitive.PoolTranspose]: {
578
+ inShape: number[];
579
+ window: number[];
580
+ strides: number[];
581
+ };
462
582
  [Primitive.Compare]: {
463
583
  op: CompareOp;
464
584
  };
@@ -480,10 +600,10 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>>
480
600
  axis: number[];
481
601
  };
482
602
  [Primitive.Shrink]: {
483
- slice: [number, number][];
603
+ slice: Pair[];
484
604
  };
485
605
  [Primitive.Pad]: {
486
- width: [number, number][];
606
+ width: Pair[];
487
607
  };
488
608
  [Primitive.Gather]: {
489
609
  axis: number[];
@@ -505,8 +625,10 @@ declare enum CompareOp {
505
625
  LessEqual = "less_equal",
506
626
  }
507
627
  /** @inline */
628
+ type Axis = number | number[] | null;
629
+ /** @inline */
508
630
  type ReduceOpts = {
509
- keepDims?: boolean;
631
+ keepdims?: boolean;
510
632
  };
511
633
  type MainTrace = {
512
634
  level: number;
@@ -583,8 +705,13 @@ declare abstract class Tracer {
583
705
  * ```
584
706
  */
585
707
  abstract dispose(): void;
708
+ /** The shape of the array. */
586
709
  get shape(): number[];
710
+ /** The total number of elements in the array. */
711
+ get size(): number;
712
+ /** The dtype of the array. */
587
713
  get dtype(): DType;
714
+ /** The number of dimensions of the array. */
588
715
  get ndim(): number;
589
716
  /** @ignore */
590
717
  fullLower(): Tracer;
@@ -598,11 +725,11 @@ declare abstract class Tracer {
598
725
  greaterEqual(other: this | TracerValue): this;
599
726
  lessEqual(other: this | TracerValue): this;
600
727
  /** Sum of the elements of the array over a given axis, or axes. */
601
- sum(axis?: number | number[], opts?: ReduceOpts): this;
728
+ sum(axis?: Axis, opts?: ReduceOpts): this;
602
729
  /** Product of the array elements over a given axis. */
603
- prod(axis?: number | number[], opts?: ReduceOpts): this;
730
+ prod(axis?: Axis, opts?: ReduceOpts): this;
604
731
  /** Compute the average of the array elements along the specified axis. */
605
- mean(axis?: number | number[], opts?: ReduceOpts): this;
732
+ mean(axis?: Axis, opts?: ReduceOpts): this;
606
733
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
607
734
  transpose(perm?: number[]): this;
608
735
  /**
@@ -670,7 +797,7 @@ declare abstract class Tracer {
670
797
  * the "gather" primitive, and it allows you to access specific elements of
671
798
  * the array by integer indices stored in another array.
672
799
  */
673
- slice(...index: (number | [] | [number] | [number, number] | null | Tracer)[]): this;
800
+ slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
674
801
  }
675
802
  declare class ShapedArray implements AbstractValue {
676
803
  readonly shape: number[];
@@ -678,7 +805,7 @@ declare class ShapedArray implements AbstractValue {
678
805
  constructor(shape: number[], dtype: DType);
679
806
  static fromAval(aval: AbstractValue): ShapedArray;
680
807
  get ndim(): number;
681
- strShort(): string;
808
+ toString(): string;
682
809
  equals(other: ShapedArray): boolean;
683
810
  }
684
811
  //#endregion
@@ -730,7 +857,11 @@ declare class Array extends Tracer {
730
857
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
731
858
  * will be freed when the array is disposed.
732
859
  */
733
- constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend, pending?: Iterable<PendingExecute> | null);
860
+ constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend, {
861
+ pending
862
+ }?: {
863
+ pending?: Iterable<PendingExecute> | null;
864
+ });
734
865
  /** @ignore */
735
866
  get aval(): ShapedArray;
736
867
  /** Return a simple string representation of the array's dimensions. */
@@ -749,14 +880,26 @@ declare class Array extends Tracer {
749
880
  */
750
881
  [Symbol.toPrimitive](): any;
751
882
  /** Realize the array and return it as data. */
752
- data(): Promise<Float32Array | Int32Array | Uint32Array>;
753
- /** Wait for this array to be placed on the backend, if needed. */
754
- wait(): Promise<void>;
883
+ data(): Promise<DataArray>;
884
+ /**
885
+ * Wait for this array to finish evaluation.
886
+ *
887
+ * Operations and data loading in jax-js are lazy, so this function ensures
888
+ * that pending operations are dispatched and fully executed before it
889
+ * returns.
890
+ *
891
+ * If you are mapping from `data()` or `dataSync()`, it will also trigger
892
+ * dispatch of operations as well.
893
+ *
894
+ * **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
895
+ * asynchronously for multiple arrays.
896
+ */
897
+ blockUntilReady(): Promise<Array>;
755
898
  /**
756
899
  * Realize the array and return it as data. This is a sync variant and not
757
900
  * recommended for performance reasons, as it will block rendering.
758
901
  */
759
- dataSync(): Float32Array | Int32Array | Uint32Array;
902
+ dataSync(): DataArray;
760
903
  /**
761
904
  * Convert this array into a JavaScript object.
762
905
  *
@@ -769,17 +912,21 @@ declare class Array extends Tracer {
769
912
  js(): any;
770
913
  /** Convert this array into a JavaScript object, asynchronously. */
771
914
  jsAsync(): Promise<any>;
915
+ /**
916
+ * Copy an element of an array to a numeric scalar and return it.
917
+ *
918
+ * Throws an error if the array does not have a single element. The array must
919
+ * either be rank-0, or all dimensions of the shape are 1.
920
+ */
921
+ item(): number;
772
922
  /** @private Internal plumbing method for Array / Tracer ops. */
773
923
  static _implRules(): typeof implRules;
774
924
  _realizeSource(): number;
775
925
  }
776
926
  /** Construct an array from a single scalar constant. */
777
- declare function scalar(value: number | boolean, {
778
- dtype,
779
- device
780
- }?: DTypeAndDevice): Array;
927
+
781
928
  /** Constructor for creating a new array from data. */
782
- declare function array(values: Array | Float32Array | Int32Array | RecursiveArray<number> | RecursiveArray<boolean>, {
929
+ declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
783
930
  shape,
784
931
  dtype,
785
932
  device
@@ -815,7 +962,7 @@ declare function eye(numRows: number, numCols?: number, {
815
962
  dtype,
816
963
  device
817
964
  }?: DTypeAndDevice): Array;
818
- /** Return the identity array, with ones on the main diagonal. */
965
+ /** Return the identity matrix, with ones on the main diagonal. */
819
966
  declare function identity$1(n: number, {
820
967
  dtype,
821
968
  device
@@ -851,15 +998,14 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
851
998
  dtype,
852
999
  device
853
1000
  }?: DTypeAndDevice): Array;
854
- /** Translate a `CompareOp` into an `AluExp` on two sub-expressions. */
855
1001
  declare namespace numpy_d_exports {
856
- export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack, complex64, concatenate, cos, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, eye, flip, fliplr, flipud, float32, full, greater, greaterEqual, hstack, identity$1 as identity, inf, int32, less, lessEqual, linspace, log, log10, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, pad, permuteDims, pi, prod, ravel, reciprocal, reshape, scalar, shape$1 as shape, sin, size, square, stack, sum, tan, transpose, trueDivide, trunc, uint32, vdot, vecdot, vstack, where, zeros };
1002
+ export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
857
1003
  }
858
1004
  declare const float32 = DType.Float32;
859
1005
  declare const int32 = DType.Int32;
860
1006
  declare const uint32 = DType.Uint32;
861
1007
  declare const bool = DType.Bool;
862
- declare const complex64 = DType.Complex64;
1008
+ declare const float16 = DType.Float16;
863
1009
  /** Euler's constant, `e = 2.7182818284590...` */
864
1010
  declare const e: number;
865
1011
  /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
@@ -870,52 +1016,66 @@ declare const inf: number;
870
1016
  declare const nan: number;
871
1017
  /** This is Pi, `π = 3.14159265358979...` */
872
1018
  declare const pi: number;
873
- /** Element-wise addition, with broadcasting. */
1019
+ /** @function Element-wise addition, with broadcasting. */
874
1020
  declare const add: (x: ArrayLike, y: ArrayLike) => Array;
875
- /** Element-wise multiplication, with broadcasting. */
1021
+ /** @function Element-wise multiplication, with broadcasting. */
876
1022
  declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
877
- /** Numerical negative of every element of an array. */
1023
+ /** @function Numerical negative of every element of an array. */
878
1024
  declare const negative: (x: ArrayLike) => Array;
879
- /** Calculate element-wise reciprocal of the input. This is `1/x`. */
1025
+ /** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
880
1026
  declare const reciprocal: (x: ArrayLike) => Array;
881
- /** Element-wise sine function (takes radians). */
1027
+ /** @function Element-wise sine function (takes radians). */
882
1028
  declare const sin: (x: ArrayLike) => Array;
883
- /** Element-wise cosine function (takes radians). */
1029
+ /** @function Element-wise cosine function (takes radians). */
884
1030
  declare const cos: (x: ArrayLike) => Array;
885
- /** Calculate the exponential of all elements in the input array. */
1031
+ /** @function Element-wise inverse sine function (inverse of sin). */
1032
+ declare const asin: (x: ArrayLike) => Array;
1033
+ /** @function Element-wise inverse tangent function (inverse of tan). */
1034
+ declare const atan: (x: ArrayLike) => Array;
1035
+ /** @function Calculate the exponential of all elements in the input array. */
886
1036
  declare const exp: (x: ArrayLike) => Array;
887
- /** Calculate the natural logarithm of all elements in the input array. */
1037
+ /** @function Calculate the natural logarithm of all elements in the input array. */
888
1038
  declare const log: (x: ArrayLike) => Array;
889
- /** Return element-wise minimum of the input arrays. */
1039
+ /** @function Calculate the square root of all elements in the input array. */
1040
+ declare const sqrt: (x: ArrayLike) => Array;
1041
+ /** @function Return element-wise minimum of the input arrays. */
890
1042
  declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
891
- /** Return element-wise maximum of the input arrays. */
1043
+ /** @function Return element-wise maximum of the input arrays. */
892
1044
  declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
893
- /** Compare two arrays element-wise. */
1045
+ /** @function Compare two arrays element-wise. */
894
1046
  declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
895
- /** Compare two arrays element-wise. */
1047
+ /** @function Compare two arrays element-wise. */
896
1048
  declare const less: (x: ArrayLike, y: ArrayLike) => Array;
897
- /** Compare two arrays element-wise. */
1049
+ /** @function Compare two arrays element-wise. */
898
1050
  declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
899
- /** Compare two arrays element-wise. */
1051
+ /** @function Compare two arrays element-wise. */
900
1052
  declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
901
- /** Compare two arrays element-wise. */
1053
+ /** @function Compare two arrays element-wise. */
902
1054
  declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
903
- /** Compare two arrays element-wise. */
1055
+ /** @function Compare two arrays element-wise. */
904
1056
  declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
905
- /** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1057
+ /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
906
1058
  declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
907
- /** Permute the dimensions of an array. Defaults to reversing the axis order. */
1059
+ /**
1060
+ * @function
1061
+ * Permute the dimensions of an array. Defaults to reversing the axis order.
1062
+ */
908
1063
  declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
909
1064
  /**
1065
+ * @function
910
1066
  * Give a new shape to an array without changing its data.
911
1067
  *
912
1068
  * One shape dimension can be -1. In this case, the value is inferred from the
913
1069
  * length of the array and remaining dimensions.
914
1070
  */
915
1071
  declare const reshape: (x: ArrayLike, shape: number[]) => Array;
916
- /** Move axes of an array to new positions. Other axes retain original order. */
1072
+ /**
1073
+ * @function
1074
+ * Move axes of an array to new positions. Other axes retain original order.
1075
+ */
917
1076
  declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
918
1077
  /**
1078
+ * @function
919
1079
  * Add padding (zeros) to an array.
920
1080
  *
921
1081
  * The `width` argument is either an integer or pair of integers, in which case
@@ -923,10 +1083,28 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
923
1083
  * pair specifies the padding for its corresponding axis.
924
1084
  */
925
1085
  declare const pad: (x: ArrayLike, width: number | [number, number] | [number, number][]) => Array;
926
- /** Return the number of dimensions of an array. Does not consume array reference. */
1086
+ /**
1087
+ * @function
1088
+ * Return the number of dimensions of an array. Does not consume array reference.
1089
+ */
927
1090
  declare const ndim: (x: ArrayLike) => number;
928
- /** Return the shape of an array. Does not consume array reference. */
1091
+ /** @function Return the shape of an array. Does not consume array reference. */
929
1092
  declare const shape$1: (x: ArrayLike) => number[];
1093
+ /**
1094
+ * @function
1095
+ * Return an array of zeros with the same shape and type as a given array.
1096
+ */
1097
+ declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
1098
+ /**
1099
+ * @function
1100
+ * Return an array of ones with the same shape and type as a given array.
1101
+ */
1102
+ declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
1103
+ /**
1104
+ * @function
1105
+ * Return a full array with the same shape and type as a given array.
1106
+ */
1107
+ declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
930
1108
  /**
931
1109
  * Return the number of elements in an array, optionally along an axis.
932
1110
  * Does not consume array reference.
@@ -935,15 +1113,15 @@ declare function size(a: ArrayLike, axis?: number): number;
935
1113
  /** Convert an array to a specified dtype. */
936
1114
  declare function astype(a: ArrayLike, dtype: DType): Array;
937
1115
  /** Sum of the elements of the array over a given axis, or axes. */
938
- declare function sum(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1116
+ declare function sum(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
939
1117
  /** Product of the array elements over a given axis. */
940
- declare function prod(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1118
+ declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
941
1119
  /** Return the minimum of array elements along a given axis. */
942
- declare function min(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1120
+ declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
943
1121
  /** Return the maximum of array elements along a given axis. */
944
- declare function max(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1122
+ declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
945
1123
  /** Compute the average of the array elements along the specified axis. */
946
- declare function mean(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1124
+ declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
947
1125
  /**
948
1126
  * Returns the indices of the minimum values along an axis.
949
1127
  *
@@ -959,7 +1137,7 @@ declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
959
1137
  */
960
1138
  declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
961
1139
  /** Reverse the elements in an array along the given axes. */
962
- declare function flip(x: ArrayLike, axis?: number | number[]): Array;
1140
+ declare function flip(x: ArrayLike, axis?: Axis): Array;
963
1141
  /**
964
1142
  * Join a sequence of arrays along an existing axis.
965
1143
  *
@@ -1003,16 +1181,45 @@ declare function columnStack(xs: ArrayLike[]): Array;
1003
1181
  declare function flipud(x: ArrayLike): Array;
1004
1182
  /** Flip an array horizontally (axis=1). */
1005
1183
  declare function fliplr(x: ArrayLike): Array;
1184
+ /** @function Alternative name for `numpy.transpose()`. */
1006
1185
  declare const permuteDims: (x: ArrayLike, perm?: number[]) => Array;
1007
1186
  /** Return a 1-D flattened array containing the elements of the input. */
1008
1187
  declare function ravel(a: ArrayLike): Array;
1188
+ /**
1189
+ * Repeat each element of an array after themselves.
1190
+ *
1191
+ * If no axis is provided, use the flattened input array, and return a flat
1192
+ * output array.
1193
+ */
1194
+ declare function repeat(a: ArrayLike, repeats: number, axis?: number): Array;
1195
+ /**
1196
+ * Construct an array by repeating A the number of times given by reps.
1197
+ *
1198
+ * If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
1199
+ * integers, the resulting array will have a shape of `(reps[0] * d1,
1200
+ * reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
1201
+ */
1202
+ declare function tile(a: ArrayLike, reps: number | number[]): Array;
1203
+ /**
1204
+ * Broadcast an array to a shape, with NumPy-style broadcasing rules.
1205
+ *
1206
+ * In other words, this lets you append axes to the left, and/or expand
1207
+ * dimensions where the shape is 1.
1208
+ */
1209
+ declare function broadcastTo(a: ArrayLike, shape: number[]): Array;
1210
+ /** Broadcast input shapes to a common output shape. */
1211
+ declare function broadcastShapes(...shapes: number[][]): number[];
1212
+ /** Broadcast arrays to a common shape. */
1213
+ declare function broadcastArrays(...arrays: ArrayLike[]): Array[];
1009
1214
  /**
1010
1215
  * Return specified diagonals.
1011
1216
  *
1012
1217
  * If a is 2D, return the diagonal of the array with the given offset. If a is
1013
- * 3D or higher, compute diagonals along the two given axes.
1218
+ * 3D or higher, compute diagonals along the two given axes (default: 0, 1).
1014
1219
  *
1015
- * This returns a view over the existing array.
1220
+ * This returns a view over the existing array. The shape of the resulting array
1221
+ * is determined by removing the two axes along which the diagonal is taken,
1222
+ * then appending a new axis to the right with holding the diagonals.
1016
1223
  */
1017
1224
  declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
1018
1225
  /**
@@ -1031,8 +1238,28 @@ declare function allclose(actual: Parameters<typeof array>[0], expected: Paramet
1031
1238
  declare function matmul(x: ArrayLike, y: ArrayLike): Array;
1032
1239
  /** Dot product of two arrays. */
1033
1240
  declare function dot(x: ArrayLike, y: ArrayLike): Array;
1034
- /** Vector dot product of two arrays. */
1035
- declare function vecdot(x: ArrayLike, y: ArrayLike): Array;
1241
+ /**
1242
+ * Compute the inner product of two arrays.
1243
+ *
1244
+ * Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
1245
+ * contraction on the last axis.
1246
+ *
1247
+ * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
1248
+ */
1249
+ declare function inner(x: ArrayLike, y: ArrayLike): Array;
1250
+ /**
1251
+ * Compute the outer product of two arrays.
1252
+ *
1253
+ * If the input arrays are not 1D, they will be flattened. Returned array will
1254
+ * be of shape `[x.size, y.size]`.
1255
+ */
1256
+ declare function outer(x: ArrayLike, y: ArrayLike): Array;
1257
+ /** Vector dot product of two arrays along a given axis. */
1258
+ declare function vecdot(x: ArrayLike, y: ArrayLike, {
1259
+ axis
1260
+ }?: {
1261
+ axis?: number;
1262
+ }): Array;
1036
1263
  /**
1037
1264
  * Return the dot product of two vectors.
1038
1265
  *
@@ -1050,6 +1277,21 @@ declare function meshgrid(xs: Array[], {
1050
1277
  }?: {
1051
1278
  indexing?: "xy" | "ij";
1052
1279
  }): Array[];
1280
+ /**
1281
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
1282
+ *
1283
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
1284
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
1285
+ * `k>0` is above it.
1286
+ */
1287
+ declare function tri(n: number, m?: number, k?: number, {
1288
+ dtype,
1289
+ device
1290
+ }?: DTypeAndDevice): Array;
1291
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
1292
+ declare function tril(a: ArrayLike, k?: number): Array;
1293
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
1294
+ declare function triu(a: ArrayLike, k?: number): Array;
1053
1295
  /**
1054
1296
  * Clip (limit) the values in an array.
1055
1297
  *
@@ -1066,15 +1308,50 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
1066
1308
  * This is the same function as `jax.numpy.abs()`.
1067
1309
  */
1068
1310
  declare function absolute(x: ArrayLike): Array;
1069
- /** Alias of `jax.numpy.absolute()`. */
1311
+ /** @function Alias of `jax.numpy.absolute()`. */
1070
1312
  declare const abs: typeof absolute;
1313
+ /** Return an element-wise indication of sign of the input. */
1314
+ declare function sign(x: ArrayLike): Array;
1071
1315
  /** Calculate element-wise square of the input array. */
1072
1316
  declare function square(x: ArrayLike): Array;
1073
- /** Compute a trigonometric tangent of each element of input. */
1317
+ /** Element-wise tangent function (takes radians). */
1074
1318
  declare function tan(x: ArrayLike): Array;
1319
+ /** Element-wise inverse cosine function (inverse of cos). */
1320
+ declare function acos(x: ArrayLike): Array;
1321
+ /**
1322
+ * @function
1323
+ * Return element-wise hypotenuse for the given legs of a right triangle.
1324
+ *
1325
+ * In the original NumPy/JAX implementation, this function is more numerically
1326
+ * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
1327
+ * improvements.
1328
+ */
1329
+ declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1330
+ /**
1331
+ * @function
1332
+ * Element-wise arc tangent of y/x with correct quadrant.
1333
+ *
1334
+ * Returns the angle in radians between the positive x-axis and the point (x, y).
1335
+ * The result is in the range [-π, π].
1336
+ *
1337
+ * Uses numerically stable formulas:
1338
+ * - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
1339
+ * - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
1340
+ *
1341
+ * The output is ill-defined when both x and y are zero.
1342
+ */
1343
+ declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
1344
+ /** @function Alias of `jax.numpy.acos()`. */
1345
+ declare const arccos: typeof acos;
1346
+ /** @function Alias of `jax.numpy.atan()`. */
1347
+ declare const arctan: (x: ArrayLike) => Array;
1348
+ /** @function Alias of `jax.numpy.atan2()`. */
1349
+ declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
1350
+ /** Element-wise subtraction, with broadcasting. */
1351
+ declare function subtract(x: ArrayLike, y: ArrayLike): Array;
1075
1352
  /** Calculates the floating-point division of x by y element-wise. */
1076
1353
  declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
1077
- /** Alias of `jax.numpy.trueDivide()`. */
1354
+ /** @function Alias of `jax.numpy.trueDivide()`. */
1078
1355
  declare const divide: typeof trueDivide;
1079
1356
  /** Round input to the nearest integer towards zero. */
1080
1357
  declare function trunc(x: ArrayLike): Array;
@@ -1084,8 +1361,112 @@ declare function exp2(p: ArrayLike): Array;
1084
1361
  declare function log2(x: ArrayLike): Array;
1085
1362
  /** Return the base-10 logarithm of x, element-wise. */
1086
1363
  declare function log10(x: ArrayLike): Array;
1364
+ /** Calculate `exp(x) - 1` element-wise. */
1365
+ declare function expm1(x: ArrayLike): Array;
1366
+ /** Calculate the natural logarithm of `1 + x` element-wise. */
1367
+ declare function log1p(x: ArrayLike): Array;
1368
+ /** Convert angles from degrees to radians. */
1369
+ declare function deg2rad(x: ArrayLike): Array;
1370
+ /** @function Alias of `jax.numpy.deg2rad()`. */
1371
+ declare const radians: typeof deg2rad;
1372
+ /** Convert angles from radians to degrees. */
1373
+ declare function rad2deg(x: ArrayLike): Array;
1374
+ /** @function Alias of `jax.numpy.rad2deg()`. */
1375
+ declare const degrees: typeof rad2deg;
1376
+ /**
1377
+ * @function
1378
+ * Computes first array raised to power of second array, element-wise.
1379
+ */
1380
+ declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1381
+ /** @function Alias of `jax.numpy.power()`. */
1382
+ declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1383
+ /** @function Calculate the element-wise cube root of the input array. */
1384
+ declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
1385
+ /**
1386
+ * @function
1387
+ * Calculate element-wise hyperbolic sine of input.
1388
+ *
1389
+ * `sinh(x) = (exp(x) - exp(-x)) / 2`
1390
+ */
1391
+ declare const sinh: OwnedFunction<(x: ArrayLike) => Array>;
1392
+ /**
1393
+ * @function
1394
+ * Calculate element-wise hyperbolic cosine of input.
1395
+ *
1396
+ * `cosh(x) = (exp(x) + exp(-x)) / 2`
1397
+ */
1398
+ declare const cosh: OwnedFunction<(x: ArrayLike) => Array>;
1399
+ /**
1400
+ * @function
1401
+ * Calculate element-wise hyperbolic tangent of input.
1402
+ *
1403
+ * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
1404
+ */
1405
+ declare const tanh: OwnedFunction<(x: ArrayLike) => Array>;
1406
+ /**
1407
+ * @function
1408
+ * Calculate element-wise inverse hyperbolic sine of input.
1409
+ *
1410
+ * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
1411
+ */
1412
+ declare const arcsinh: OwnedFunction<(x: ArrayLike) => Array>;
1413
+ /**
1414
+ * @function
1415
+ * Calculate element-wise inverse hyperbolic cosine of input.
1416
+ *
1417
+ * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
1418
+ */
1419
+ declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
1420
+ /**
1421
+ * @function
1422
+ * Calculate element-wise inverse hyperbolic tangent of input.
1423
+ *
1424
+ * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
1425
+ */
1426
+ declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
1427
+ /** @function Alias of `jax.numpy.arcsinh()`. */
1428
+ declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
1429
+ /** @function Alias of `jax.numpy.arccosh()`. */
1430
+ declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
1431
+ /** @function Alias of `jax.numpy.arctanh()`. */
1432
+ declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
1433
+ /**
1434
+ * Compute the variance of an array.
1435
+ *
1436
+ * The variance is computed for the flattened array by default, otherwise over
1437
+ * the specified axis.
1438
+ *
1439
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
1440
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
1441
+ */
1442
+ declare function var_(x: ArrayLike, axis?: Axis, opts?: {
1443
+ mean?: ArrayLike;
1444
+ correction?: number;
1445
+ } & ReduceOpts): Array;
1446
+ /**
1447
+ * Compute the standard deviation of an array.
1448
+ *
1449
+ * The standard deviation is computed for the flattened array by default,
1450
+ * otherwise over the specified axis.
1451
+ *
1452
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
1453
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
1454
+ */
1455
+ declare function std(x: ArrayLike, axis?: Axis, opts?: {
1456
+ mean?: ArrayLike;
1457
+ correction?: number;
1458
+ } & ReduceOpts): Array;
1087
1459
  //#endregion
1088
1460
  //#region src/frontend/jaxpr.d.ts
1461
+ /**
1462
+ * Function callback with an associated dispose() method.
1463
+ *
1464
+ * The dispose() method should be called to clean up any tracer resources needed
1465
+ * by the function after the last time it is called.
1466
+ */
1467
+ type OwnedFunction<F extends Function> = F & {
1468
+ dispose: () => void;
1469
+ };
1089
1470
  /** Variable in a Jaxpr expression. */
1090
1471
  declare class Var {
1091
1472
  #private;
@@ -1146,8 +1527,38 @@ declare class Jaxpr implements FpHashable {
1146
1527
  /** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
1147
1528
  flatten(): Jaxpr;
1148
1529
  }
1530
+ /** @inline */
1531
+ type JitOpts = {
1532
+ staticArgnums?: number[];
1533
+ device?: Device;
1534
+ };
1535
+ declare namespace lax_d_exports {
1536
+ export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
1537
+ }
1538
+ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
1539
+ /**
1540
+ * General n-dimensional convolution operator, with optional dilation.
1541
+ *
1542
+ * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
1543
+ * function in JAX, which wraps XLA's general convolution operator.
1544
+ *
1545
+ * Grouped convolutions are not supported right now.
1546
+ */
1547
+ declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
1548
+ lhsDilation,
1549
+ rhsDilation
1550
+ }?: {
1551
+ lhsDilation?: number[];
1552
+ rhsDilation?: number[];
1553
+ }): Array;
1554
+ /** Convenience wrapper around `convGeneralDilated`. */
1555
+ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
1556
+ /** Convenience wrapper around `convGeneralDilated`. */
1557
+ declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
1558
+ /** Reduce a computation over padded windows. */
1559
+ declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
1149
1560
  declare namespace nn_d_exports {
1150
- export { identity, logSigmoid, logSoftmax, logsumexp, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
1561
+ export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
1151
1562
  }
1152
1563
  /**
1153
1564
  * Rectified Linear Unit (ReLU) activation function:
@@ -1179,6 +1590,7 @@ declare function softplus(x: ArrayLike): Array;
1179
1590
  */
1180
1591
  declare function softSign(x: ArrayLike): Array;
1181
1592
  /**
1593
+ * @function
1182
1594
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
1183
1595
  * Swish, computed element-wise:
1184
1596
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -1187,8 +1599,9 @@ declare function softSign(x: ArrayLike): Array;
1187
1599
  *
1188
1600
  * Reference: https://en.wikipedia.org/wiki/Swish_function
1189
1601
  */
1190
- declare function silu(x: ArrayLike): Array;
1602
+ declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
1191
1603
  /**
1604
+ * @function
1192
1605
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
1193
1606
  * Swish, computed element-wise:
1194
1607
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -1197,14 +1610,67 @@ declare function silu(x: ArrayLike): Array;
1197
1610
  *
1198
1611
  * Reference: https://en.wikipedia.org/wiki/Swish_function
1199
1612
  */
1200
- declare const swish: typeof silu;
1613
+ declare const swish: OwnedFunction<(x: ArrayLike) => Array>;
1201
1614
  /**
1202
1615
  * Log-sigmoid activation function, computed element-wise:
1203
1616
  * `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
1204
1617
  */
1205
1618
  declare function logSigmoid(x: ArrayLike): Array;
1206
- /** Identity activation function. Returns the argument unmodified. */
1619
+ /**
1620
+ * @function
1621
+ * Identity activation function. Returns the argument unmodified.
1622
+ */
1207
1623
  declare const identity: (x: ArrayLike) => Array;
1624
+ /** Leaky rectified linear (ReLU) activation function */
1625
+ declare function leakyRelu(x: ArrayLike, negativeSlope?: ArrayLike): Array;
1626
+ /**
1627
+ * Exponential linear unit activation function.
1628
+ *
1629
+ * Computes the element-wise function:
1630
+ * `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
1631
+ */
1632
+ declare function elu(x: ArrayLike, alpha?: ArrayLike): Array;
1633
+ /**
1634
+ * Continuously-differentiable exponential linear unit activation function.
1635
+ *
1636
+ * Computes the element-wise function:
1637
+ * `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
1638
+ */
1639
+ declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
1640
+ /**
1641
+ * @function
1642
+ * Gaussion error linear unit (GELU) activation function.
1643
+ *
1644
+ * This is computed element-wise. Currently jax-js does not support the erf() or
1645
+ * gelu() functions exactly as primitives, so an approximation is used:
1646
+ * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
1647
+ *
1648
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
1649
+ *
1650
+ * This will be improved in the future.
1651
+ */
1652
+ declare const gelu: OwnedFunction<(x: ArrayLike) => Array>;
1653
+ /**
1654
+ * Gated linear unit (GLU) activation function.
1655
+ *
1656
+ * Splits the `axis` dimension of the input into two halves, a and b, then
1657
+ * computes `a * sigmoid(b)`.
1658
+ */
1659
+ declare function glu(x: ArrayLike, axis?: number): Array;
1660
+ /**
1661
+ * Squareplus activation function.
1662
+ *
1663
+ * Computes the element-wise function:
1664
+ * `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
1665
+ */
1666
+ declare function squareplus(x: ArrayLike, b?: ArrayLike): Array;
1667
+ /**
1668
+ * Mish activation function.
1669
+ *
1670
+ * Computes the element-wise function:
1671
+ * `mish(x) = x * tanh(softplus(x))`
1672
+ */
1673
+ declare function mish(x: ArrayLike): Array;
1208
1674
  /**
1209
1675
  * Softmax function. Computes the function which rescales elements to the range
1210
1676
  * [0, 1] such that the elements along `axis` sum to 1.
@@ -1213,7 +1679,7 @@ declare const identity: (x: ArrayLike) => Array;
1213
1679
  *
1214
1680
  * Reference: https://en.wikipedia.org/wiki/Softmax_function
1215
1681
  */
1216
- declare function softmax(x: ArrayLike, axis?: number | number[]): Array;
1682
+ declare function softmax(x: ArrayLike, axis?: Axis): Array;
1217
1683
  /**
1218
1684
  * Log-Softmax function.
1219
1685
  *
@@ -1222,7 +1688,7 @@ declare function softmax(x: ArrayLike, axis?: number | number[]): Array;
1222
1688
  *
1223
1689
  * If `axis` is not specified, it defaults to the last axis.
1224
1690
  */
1225
- declare function logSoftmax(x: ArrayLike, axis?: number | number[]): Array;
1691
+ declare function logSoftmax(x: ArrayLike, axis?: Axis): Array;
1226
1692
  /**
1227
1693
  * Log-sum-exp reduction. Also a multivariate version of `softplus`.
1228
1694
  *
@@ -1231,7 +1697,22 @@ declare function logSoftmax(x: ArrayLike, axis?: number | number[]): Array;
1231
1697
  *
1232
1698
  * Reference: https://en.wikipedia.org/wiki/LogSumExp
1233
1699
  */
1234
- declare function logsumexp(x: ArrayLike, axis?: number | number[]): Array;
1700
+ declare function logsumexp(x: ArrayLike, axis?: Axis): Array;
1701
+ /** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
1702
+ declare function logmeanexp(x: ArrayLike, axis?: Axis): Array;
1703
+ /**
1704
+ * Standardizes input to zero mean and unit variance.
1705
+ *
1706
+ * By default, this is computed over the last axis. You can pass in a different
1707
+ * axis, or `null` to standardize over all elements.
1708
+ *
1709
+ * Epsilon is added to denominator, it defaults to `1e-5` for stability.
1710
+ */
1711
+ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
1712
+ mean?: ArrayLike;
1713
+ variance?: ArrayLike;
1714
+ epsilon?: ArrayLike;
1715
+ }): Array;
1235
1716
  /**
1236
1717
  * One-hot encodes the given indices.
1237
1718
  *
@@ -1250,7 +1731,7 @@ declare function logsumexp(x: ArrayLike, axis?: number | number[]): Array;
1250
1731
  */
1251
1732
  declare function oneHot(x: Array, numClasses: number): Array;
1252
1733
  declare namespace random_d_exports {
1253
- export { bits, key, split, uniform };
1734
+ export { bernoulli, bits, exponential, key, normal, split, uniform };
1254
1735
  }
1255
1736
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
1256
1737
  declare function key(seed: number): Array;
@@ -1266,40 +1747,108 @@ declare function uniform(key: Array, shape?: number[], {
1266
1747
  minval?: number;
1267
1748
  maxval?: number;
1268
1749
  }): Array;
1750
+ /**
1751
+ * Sample Bernoulli random variables with given mean (0,1 categorical).
1752
+ *
1753
+ * Returns a random Boolean array with the specified shape. `p` can be an array
1754
+ * and must be broadcastable to `shape`.
1755
+ */
1756
+ declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
1757
+ /** Sample exponential random values according to `p(x) = exp(-x)`. */
1758
+ declare function exponential(key: Array, shape?: number[]): Array;
1759
+ /**
1760
+ * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
1761
+ *
1762
+ * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
1763
+ * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
1764
+ * bitwise identical to JAX.
1765
+ */
1766
+ declare function normal(key: Array, shape?: number[]): Array;
1269
1767
  //#endregion
1270
1768
  //#region src/index.d.ts
1271
- /** @inline */
1272
- type WithArgsSubtype<F extends (args: any[]) => any, T> = Parameters<F> extends T ? F : never;
1273
- /** Compute the forward-mode Jacobian-vector product for a function. */
1274
- declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>, tangents: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
1275
- /** Vectorize an operation on a batched axis for one or more inputs. */
1276
- declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
1277
- /** Compute the Jacobian evaluated column-by-column by forward-mode AD. */
1278
- declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
1279
- /** Construct a Jaxpr by dynamically tracing a function with example inputs. */
1280
- declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...args: Parameters<F>) => {
1769
+ /**
1770
+ * @function
1771
+ * Compute the forward-mode Jacobian-vector product for a function.
1772
+ */
1773
+ declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
1774
+ /**
1775
+ * @function
1776
+ * Vectorize an operation on a batched axis for one or more inputs.
1777
+ */
1778
+ declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1779
+ /**
1780
+ * @function
1781
+ * Compute the Jacobian evaluated column-by-column by forward-mode AD.
1782
+ */
1783
+ declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1784
+ /**
1785
+ * @function
1786
+ * Construct a Jaxpr by dynamically tracing a function with example inputs.
1787
+ */
1788
+ declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
1281
1789
  jaxpr: Jaxpr;
1282
1790
  consts: Array[];
1283
1791
  treedef: JsTreeDef;
1284
1792
  };
1285
- declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
1286
1793
  /**
1794
+ * @function
1795
+ * Mark a function for automatic JIT compilation, with operator fusion.
1796
+ *
1797
+ * The function will be compiled the first time it is called with a set of
1798
+ * argument shapes.
1799
+ *
1800
+ * You can call `.dispose()` on the returned, JIT-compiled function after all
1801
+ * calls to free memory associated with array constants.
1802
+ *
1803
+ * **Options:**
1804
+ * - `staticArgnums`: An array of argument indices to treat as static
1805
+ * (compile-time constant). These arguments must be hashable, won't be traced,
1806
+ * and different values will trigger recompilation.
1807
+ * - `device`: The device to place the computation on. If not specified, the
1808
+ * computation will be placed on the default device.
1809
+ */
1810
+ declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => OwnedFunction<(...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>>;
1811
+ /**
1812
+ * @function
1287
1813
  * Produce a local linear approximation to a function at a point using jvp() and
1288
1814
  * partial evaluation.
1289
1815
  */
1290
- declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, ...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>];
1291
- /** Calculate the reverse-mode vector-Jacobian product for a function. */
1292
- declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, ...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
1816
+ declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
1293
1817
  /**
1818
+ * @function
1819
+ * Calculate the reverse-mode vector-Jacobian product for a function.
1820
+ */
1821
+ declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
1822
+ /**
1823
+ * @function
1294
1824
  * Compute the gradient of a scalar-valued function `f` with respect to its
1295
1825
  * first argument.
1296
1826
  */
1297
- declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
1298
- /** Create a function that evaluates both `f` and the gradient of `f`. */
1299
- declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
1300
- /** Compute the Jacobian evaluated row-by-row by reverse-mode AD. */
1827
+ declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
1828
+ /**
1829
+ * @function
1830
+ * Create a function that evaluates both `f` and the gradient of `f`.
1831
+ */
1832
+ declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
1833
+ /**
1834
+ * @function
1835
+ * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
1836
+ */
1301
1837
  declare const jacrev: typeof jacfwd;
1302
- /** Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`. */
1303
- declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
1838
+ /**
1839
+ * @function
1840
+ * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
1841
+ */
1842
+ declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1843
+ /**
1844
+ * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
1845
+ *
1846
+ * This can be used to wait for the results of an intermediate computation to
1847
+ * finish. It's recommended to call this regularly in an iterative computation
1848
+ * to avoid queueing up too many pending operations.
1849
+ *
1850
+ * Does not consume reference to the arrays.
1851
+ */
1852
+ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
1304
1853
  //#endregion
1305
- export { type Device, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDevice, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1854
+ export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };