@jax-js/jax 0.0.2 → 0.0.4

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions exactly as they appear in their respective public registries.
package/dist/index.d.cts CHANGED
@@ -23,6 +23,7 @@ interface Stringable {
23
23
  }
24
24
  //#endregion
25
25
  //#region src/shape.d.ts
26
+ /** @inline */
26
27
  type Pair = [number, number];
27
28
  /**
28
29
  * A multidimensional view into memory. An array can be thought of as the
@@ -49,6 +50,8 @@ declare class View {
49
50
  get size(): number;
50
51
  /** Whether this is a default, contiguous, unaltered view of the data (identity). */
51
52
  get contiguous(): boolean;
53
+ /** Return the range of data being indexed in this view, or [0, 0] if none. */
54
+ dataRange(): [number, number];
52
55
  /** Produce an AluExp for evaluating this view at an index. */
53
56
  toAluExp(idxs: AluExp[]): [AluExp, AluExp];
54
57
  /**
@@ -109,9 +112,33 @@ declare class ShapeTracker {
109
112
  reshape(newShape: number[]): ShapeTracker;
110
113
  /** Broadcast along the given new axes, then expand the shape. */
111
114
  broadcast(newShape: number[], axis: number[]): ShapeTracker;
115
+ /**
116
+ * Repeat data in each axis by a positive number of repetitions.
117
+ *
118
+ * - If `tile` is true (default): [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
119
+ * - If `tile` is false: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
120
+ */
121
+ repeat(reps: number[], tile?: boolean): ShapeTracker;
122
+ /** Move axis i to axis j. */
123
+ moveaxis(i: number, j: number): ShapeTracker;
124
+ /** Like pad(), but allows for negative values. */
125
+ padOrShrink(arg: Pair[]): ShapeTracker;
112
126
  }
113
127
  //#endregion
114
128
  //#region src/utils.d.ts
129
+ /**
130
+ * Set the debug level for verbose logging.
131
+ *
132
+ * 1. JIT compile logs
133
+ * 2. Shader code
134
+ * 3. Expressions and metadata
135
+ * 4. JIT programs, tuning details
136
+ * 5. Most verbose operation traces
137
+ *
138
+ * This is an experimental API and may change in behavior. Do not rely on this
139
+ * in production.
140
+ */
141
+ declare function setDebug(level: number): void;
115
142
  /** @inline */
116
143
  type RecursiveArray<T> = T | RecursiveArray<T>[];
117
144
  interface FpHashable {
@@ -127,19 +154,47 @@ interface FpHashable {
127
154
  declare class FpHash {
128
155
  #private;
129
156
  value: bigint;
130
- update(...values: (string | boolean | number | bigint | null | undefined | FpHashable)[]): this;
157
+ update(x: string | boolean | number | bigint | null | undefined | FpHashable): this;
131
158
  static hash(...values: (string | boolean | number | bigint | null | undefined | FpHashable)[]): bigint;
132
159
  }
133
160
  /** Run a function while caching it inline inside a `Map`. */
134
161
  //#endregion
135
162
  //#region src/alu.d.ts
163
+ /** A numerical data type for array contents. */
136
164
  declare enum DType {
137
165
  Float32 = "float32",
138
166
  Int32 = "int32",
139
167
  Uint32 = "uint32",
140
168
  Bool = "bool",
141
- Complex64 = "complex64",
169
+ Float16 = "float16",
142
170
  }
171
+ /** @inline */
172
+ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
173
+ /**
174
+ * Promote two dtypes to their join according to the type lattice.
175
+ *
176
+ * When performing operations between arrays of different types, we need to
177
+ * promote both operands to a common type that can represent values from both
178
+ * input types. This follows JAX's type promotion rules.
179
+ *
180
+ * **Type lattice:**
181
+ * ```text
182
+ * bool -> uint32 -> int32 -> float16 -> float32
183
+ * weak f* --^
184
+ * ```
185
+ *
186
+ * The asterisk f* is a weak type used for JS number constants. When creating
187
+ * arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
188
+ * any array they are first combined with.
189
+ *
190
+ * **Examples:**
191
+ * - `promoteTypes(bool, int32) → int32`
192
+ * - `promoteTypes(uint32, int32) → int32`
193
+ * - `promoteTypes(int32, float16) → float16`
194
+ * - `promoteTypes(float16, float32) → float32`
195
+ * - `promoteTypes(uint32, float32) → float32`
196
+ */
197
+ declare function promoteTypes(dtype1: DType, dtype2: DType): DType;
143
198
  /**
144
199
  * Mathematical expression on scalar values.
145
200
  *
@@ -163,8 +218,11 @@ declare class AluExp implements FpHashable {
163
218
  static max(a: AluExp, b: AluExp): AluExp;
164
219
  static sin(a: AluExp): AluExp;
165
220
  static cos(a: AluExp): AluExp;
221
+ static asin(a: AluExp): AluExp;
222
+ static atan(a: AluExp): AluExp;
166
223
  static exp(a: AluExp): AluExp;
167
224
  static log(a: AluExp): AluExp;
225
+ static sqrt(a: AluExp): AluExp;
168
226
  static reciprocal(a: AluExp): AluExp;
169
227
  static cast(dtype: DType, a: AluExp): AluExp;
170
228
  static bitcast(dtype: DType, a: AluExp): AluExp;
@@ -175,12 +233,13 @@ declare class AluExp implements FpHashable {
175
233
  static const(dtype: DType, value: any): AluExp;
176
234
  static special(dtype: DType, name: string, n: number): AluExp;
177
235
  static variable(dtype: DType, name: string): AluExp;
178
- static globalIndex(dtype: DType, gid: number, bufidx: AluExp): AluExp;
236
+ static globalIndex(dtype: DType, gid: number, len: number, bufidx: AluExp): AluExp;
179
237
  static globalView(dtype: DType, gid: number, st: ShapeTracker, indices: AluExp[]): AluExp;
238
+ static f32(value: number): AluExp;
180
239
  static i32(value: number): AluExp;
181
240
  static u32(value: number): AluExp;
182
- static f32(value: number): AluExp;
183
241
  static bool(value: boolean): AluExp;
242
+ static f16(value: number): AluExp;
184
243
  not(): AluExp;
185
244
  /** Compute a reasonable expression hash with low collision rate. */
186
245
  getHash(): bigint;
@@ -191,6 +250,19 @@ declare class AluExp implements FpHashable {
191
250
  reindexGids(gidMap: Map<number, number>): AluExp;
192
251
  get min(): number;
193
252
  get max(): number;
253
+ /** Largest known integer that divides self. */
254
+ constFactor(): number;
255
+ /**
256
+ * Checks if divisible by an integer v and returns the quotient if it is, or
257
+ * `null` if it's not divisible.
258
+ */
259
+ divides(v: number): AluExp | null;
260
+ /**
261
+ * Get all expressions by deeply matching an operation.
262
+ *
263
+ * For example: `((2+(3*5))+4).splitOp(+) -> [2,(3*5),4]`.
264
+ */
265
+ splitOp(sep: AluOp): IterableIterator<AluExp>;
194
266
  /**
195
267
  * Simplify the expression by replacing any known patterns and deduping
196
268
  * identical subexpressions.
@@ -211,10 +283,16 @@ declare class AluExp implements FpHashable {
211
283
  toString(): string;
212
284
  /** Generic fold() operation with a reducer over the expression tree. */
213
285
  fold<T = void>(reducer: (exp: AluExp, mappedSrc: T[]) => T): T;
286
+ /** Check if any expression in the tree satisfies a predicate. */
287
+ some(predicate: (exp: AluExp) => boolean): boolean;
214
288
  /** Rewrite the expression recursively using a visitor. */
215
289
  rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
216
290
  /** Collect all nodes that satisfy a predicate. */
217
291
  collect(predicate: (exp: AluExp) => boolean): AluExp[];
292
+ /** Produce a list of all distinct AluOp in this expression. */
293
+ distinctOps(): Set<AluOp>;
294
+ /** Rewrite GlobalView operations to GlobalIndex operations. */
295
+ rewriteGlobalViews(): AluExp;
218
296
  }
219
297
  /** Symbolic form for each mathematical operation. */
220
298
  declare enum AluOp {
@@ -227,8 +305,11 @@ declare enum AluOp {
227
305
  Max = "Max",
228
306
  Sin = "Sin",
229
307
  Cos = "Cos",
308
+ Asin = "Asin",
309
+ Atan = "Atan",
230
310
  Exp = "Exp",
231
311
  Log = "Log",
312
+ Sqrt = "Sqrt",
232
313
  Reciprocal = "Reciprocal",
233
314
  Cast = "Cast",
234
315
  Bitcast = "Bitcast",
@@ -245,7 +326,7 @@ declare enum AluOp {
245
326
  Variable = "Variable",
246
327
  // arg = variable
247
328
  GlobalIndex = "GlobalIndex",
248
- // arg = gid; src = [bufidx]
329
+ // arg = [gid, len]; src = [bufidx]
249
330
  GlobalView = "GlobalView",
250
331
  }
251
332
  /**
@@ -300,12 +381,12 @@ declare class Reduction implements FpHashable {
300
381
  /** Size of the reduction axis. */
301
382
  readonly size: number;
302
383
  /** Follow-up expression defined with the "acc" variable, defaults to identity. */
303
- readonly fusion: AluExp;
384
+ readonly epilogue: AluExp;
304
385
  constructor(/** Data type of the values being reduced over. */
305
386
  dtype: DType, /** Operation to perform. Only ops in `AluGroup.Reduce` are supported. */
306
387
  op: AluOp, /** Size of the reduction axis. */
307
388
  size: number, /** Follow-up expression defined with the "acc" variable, defaults to identity. */
308
- fusion?: AluExp);
389
+ epilogue?: AluExp);
309
390
  hash(state: FpHash): void;
310
391
  toString(): string;
311
392
  /** Get the identity for this reduction operation. */
@@ -316,10 +397,10 @@ declare class Reduction implements FpHashable {
316
397
  /** Expression for accessing `indices` in input array with the given shape. */
317
398
  //#endregion
318
399
  //#region src/backend.d.ts
319
- type Device = "cpu" | "webgpu";
400
+ type Device = "cpu" | "wasm" | "webgpu";
320
401
  declare const devices: Device[];
321
- /** Set the default device backend (must be initialized). */
322
- declare function setDevice(device: Device): void;
402
+ /** Configure the default device for arrays. */
403
+ declare function defaultDevice(device?: Device): Device;
323
404
  /**
324
405
  * Initialize `jax-js` library backends.
325
406
  *
@@ -339,7 +420,7 @@ interface Backend {
339
420
  /** Maximum number of arguments per dispatched kernel. */
340
421
  readonly maxArgs: number;
341
422
  /** Allocate a new slot with reference count 1. */
342
- malloc(size: number, initialData?: ArrayBuffer): Slot;
423
+ malloc(size: number, initialData?: Uint8Array): Slot;
343
424
  /** Increment the reference count of the slot. */
344
425
  incRef(slot: Slot): void;
345
426
  /**
@@ -348,9 +429,9 @@ interface Backend {
348
429
  */
349
430
  decRef(slot: Slot): void;
350
431
  /** Read a range of bytes from a buffer. */
351
- read(slot: Slot, start?: number, count?: number): Promise<ArrayBuffer>;
432
+ read(slot: Slot, start?: number, count?: number): Promise<Uint8Array<ArrayBuffer>>;
352
433
  /** Read a range of bytes from a buffer, blocking variant. */
353
- readSync(slot: Slot, start?: number, count?: number): ArrayBuffer;
434
+ readSync(slot: Slot, start?: number, count?: number): Uint8Array<ArrayBuffer>;
354
435
  /** Prepare an expression to be executed later. */
355
436
  prepare(kernel: Kernel): Promise<Executable>;
356
437
  /** Prepare an expression to be executed later, blocking variant. */
@@ -372,18 +453,22 @@ declare class Executable<T = any> {
372
453
  data: T);
373
454
  }
374
455
  declare namespace tree_d_exports {
375
- export { JsTree, JsTreeDef, MapJsTree, NodeType, flatten, leaves, map, ref, structure, unflatten };
456
+ export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
376
457
  }
377
458
  declare enum NodeType {
378
459
  Array = "Array",
379
460
  Object = "Object",
380
461
  Leaf = "Leaf",
381
462
  }
463
+ /** Analog to the JAX "pytree" object, but for JavaScript. */
382
464
  type JsTree<T> = T | JsTree<T>[] | {
383
465
  [key: string]: JsTree<T>;
384
466
  };
385
- type MapJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
386
- /** Analog to the JAX "pytree" object, but for JavaScript. */
467
+ type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
468
+ type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
469
+ /** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
470
+ type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
471
+ /** Represents the structure of a JsTree. */
387
472
  declare class JsTreeDef {
388
473
  readonly nodeType: NodeType;
389
474
  readonly nodeMetadata: any;
@@ -409,6 +494,23 @@ declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T
409
494
  declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
410
495
  /** Take a reference of every array in a tree. */
411
496
  declare function ref<Tree extends JsTree<Array>>(tree: Tree): Tree;
497
+ /** Dispose every array in a tree. */
498
+ declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefined): void;
499
+ //#endregion
500
+ //#region src/frontend/convolution.d.ts
501
+ /** Definition of a general dilated convolution. Should be valid on creation. */
502
+ interface ConvParams {
503
+ strides: number[];
504
+ padding: [number, number][];
505
+ lhsDilation: number[];
506
+ rhsDilation: number[];
507
+ }
508
+ /**
509
+ * Check that the shapes and parameters passed to convolution are valid.
510
+ *
511
+ * If the check succeeds, returns the output shape.
512
+ */
513
+
412
514
  //#endregion
413
515
  //#region src/frontend/core.d.ts
414
516
  /**
@@ -434,12 +536,20 @@ declare enum Primitive {
434
536
  RandomBits = "random_bits",
435
537
  Sin = "sin",
436
538
  Cos = "cos",
539
+ Asin = "asin",
540
+ Atan = "atan",
437
541
  Exp = "exp",
438
542
  Log = "log",
543
+ Sqrt = "sqrt",
439
544
  Min = "min",
440
545
  Max = "max",
441
546
  Reduce = "reduce",
442
547
  Dot = "dot",
548
+ // sum(x*y, axis=-1)
549
+ Conv = "conv",
550
+ // see lax.conv_general_dilated
551
+ Pool = "pool",
552
+ PoolTranspose = "pool_transpose",
443
553
  Compare = "compare",
444
554
  Where = "where",
445
555
  Transpose = "transpose",
@@ -451,7 +561,7 @@ declare enum Primitive {
451
561
  Gather = "gather",
452
562
  JitCall = "jit_call",
453
563
  }
454
- interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>> {
564
+ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
455
565
  [Primitive.Cast]: {
456
566
  dtype: DType;
457
567
  };
@@ -462,6 +572,16 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>>
462
572
  op: AluOp;
463
573
  axis: number[];
464
574
  };
575
+ [Primitive.Conv]: ConvParams;
576
+ [Primitive.Pool]: {
577
+ window: number[];
578
+ strides: number[];
579
+ };
580
+ [Primitive.PoolTranspose]: {
581
+ inShape: number[];
582
+ window: number[];
583
+ strides: number[];
584
+ };
465
585
  [Primitive.Compare]: {
466
586
  op: CompareOp;
467
587
  };
@@ -483,10 +603,10 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>>
483
603
  axis: number[];
484
604
  };
485
605
  [Primitive.Shrink]: {
486
- slice: [number, number][];
606
+ slice: Pair[];
487
607
  };
488
608
  [Primitive.Pad]: {
489
- width: [number, number][];
609
+ width: Pair[];
490
610
  };
491
611
  [Primitive.Gather]: {
492
612
  axis: number[];
@@ -508,8 +628,10 @@ declare enum CompareOp {
508
628
  LessEqual = "less_equal",
509
629
  }
510
630
  /** @inline */
631
+ type Axis = number | number[] | null;
632
+ /** @inline */
511
633
  type ReduceOpts = {
512
- keepDims?: boolean;
634
+ keepdims?: boolean;
513
635
  };
514
636
  type MainTrace = {
515
637
  level: number;
@@ -586,8 +708,13 @@ declare abstract class Tracer {
586
708
  * ```
587
709
  */
588
710
  abstract dispose(): void;
711
+ /** The shape of the array. */
589
712
  get shape(): number[];
713
+ /** The total number of elements in the array. */
714
+ get size(): number;
715
+ /** The dtype of the array. */
590
716
  get dtype(): DType;
717
+ /** The number of dimensions of the array. */
591
718
  get ndim(): number;
592
719
  /** @ignore */
593
720
  fullLower(): Tracer;
@@ -601,11 +728,11 @@ declare abstract class Tracer {
601
728
  greaterEqual(other: this | TracerValue): this;
602
729
  lessEqual(other: this | TracerValue): this;
603
730
  /** Sum of the elements of the array over a given axis, or axes. */
604
- sum(axis?: number | number[], opts?: ReduceOpts): this;
731
+ sum(axis?: Axis, opts?: ReduceOpts): this;
605
732
  /** Product of the array elements over a given axis. */
606
- prod(axis?: number | number[], opts?: ReduceOpts): this;
733
+ prod(axis?: Axis, opts?: ReduceOpts): this;
607
734
  /** Compute the average of the array elements along the specified axis. */
608
- mean(axis?: number | number[], opts?: ReduceOpts): this;
735
+ mean(axis?: Axis, opts?: ReduceOpts): this;
609
736
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
610
737
  transpose(perm?: number[]): this;
611
738
  /**
@@ -673,7 +800,7 @@ declare abstract class Tracer {
673
800
  * the "gather" primitive, and it allows you to access specific elements of
674
801
  * the array by integer indices stored in another array.
675
802
  */
676
- slice(...index: (number | [] | [number] | [number, number] | null | Tracer)[]): this;
803
+ slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
677
804
  }
678
805
  declare class ShapedArray implements AbstractValue {
679
806
  readonly shape: number[];
@@ -681,7 +808,7 @@ declare class ShapedArray implements AbstractValue {
681
808
  constructor(shape: number[], dtype: DType);
682
809
  static fromAval(aval: AbstractValue): ShapedArray;
683
810
  get ndim(): number;
684
- strShort(): string;
811
+ toString(): string;
685
812
  equals(other: ShapedArray): boolean;
686
813
  }
687
814
  //#endregion
@@ -733,7 +860,11 @@ declare class Array extends Tracer {
733
860
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
734
861
  * will be freed when the array is disposed.
735
862
  */
736
- constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend, pending?: Iterable<PendingExecute> | null);
863
+ constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend, {
864
+ pending
865
+ }?: {
866
+ pending?: Iterable<PendingExecute> | null;
867
+ });
737
868
  /** @ignore */
738
869
  get aval(): ShapedArray;
739
870
  /** Return a simple string representation of the array's dimensions. */
@@ -752,14 +883,26 @@ declare class Array extends Tracer {
752
883
  */
753
884
  [Symbol.toPrimitive](): any;
754
885
  /** Realize the array and return it as data. */
755
- data(): Promise<Float32Array | Int32Array | Uint32Array>;
756
- /** Wait for this array to be placed on the backend, if needed. */
757
- wait(): Promise<void>;
886
+ data(): Promise<DataArray>;
887
+ /**
888
+ * Wait for this array to finish evaluation.
889
+ *
890
+ * Operations and data loading in jax-js are lazy, so this function ensures
891
+ * that pending operations are dispatched and fully executed before it
892
+ * returns.
893
+ *
894
+ * If you are mapping from `data()` or `dataSync()`, it will also trigger
895
+ * dispatch of operations as well.
896
+ *
897
+ * **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
898
+ * asynchronously for multiple arrays.
899
+ */
900
+ blockUntilReady(): Promise<Array>;
758
901
  /**
759
902
  * Realize the array and return it as data. This is a sync variant and not
760
903
  * recommended for performance reasons, as it will block rendering.
761
904
  */
762
- dataSync(): Float32Array | Int32Array | Uint32Array;
905
+ dataSync(): DataArray;
763
906
  /**
764
907
  * Convert this array into a JavaScript object.
765
908
  *
@@ -772,17 +915,21 @@ declare class Array extends Tracer {
772
915
  js(): any;
773
916
  /** Convert this array into a JavaScript object, asynchronously. */
774
917
  jsAsync(): Promise<any>;
918
+ /**
919
+ * Copy an element of an array to a numeric scalar and return it.
920
+ *
921
+ * Throws an error if the array does not have a single element. The array must
922
+ * either be rank-0, or all dimensions of the shape are 1.
923
+ */
924
+ item(): number;
775
925
  /** @private Internal plumbing method for Array / Tracer ops. */
776
926
  static _implRules(): typeof implRules;
777
927
  _realizeSource(): number;
778
928
  }
779
929
  /** Construct an array from a single scalar constant. */
780
- declare function scalar(value: number | boolean, {
781
- dtype,
782
- device
783
- }?: DTypeAndDevice): Array;
930
+
784
931
  /** Constructor for creating a new array from data. */
785
- declare function array(values: Array | Float32Array | Int32Array | RecursiveArray<number> | RecursiveArray<boolean>, {
932
+ declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
786
933
  shape,
787
934
  dtype,
788
935
  device
@@ -818,7 +965,7 @@ declare function eye(numRows: number, numCols?: number, {
818
965
  dtype,
819
966
  device
820
967
  }?: DTypeAndDevice): Array;
821
- /** Return the identity array, with ones on the main diagonal. */
968
+ /** Return the identity matrix, with ones on the main diagonal. */
822
969
  declare function identity$1(n: number, {
823
970
  dtype,
824
971
  device
@@ -854,15 +1001,14 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
854
1001
  dtype,
855
1002
  device
856
1003
  }?: DTypeAndDevice): Array;
857
- /** Translate a `CompareOp` into an `AluExp` on two sub-expressions. */
858
1004
  declare namespace numpy_d_exports {
859
- export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack, complex64, concatenate, cos, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, eye, flip, fliplr, flipud, float32, full, greater, greaterEqual, hstack, identity$1 as identity, inf, int32, less, lessEqual, linspace, log, log10, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, pad, permuteDims, pi, prod, ravel, reciprocal, reshape, scalar, shape$1 as shape, sin, size, square, stack, sum, tan, transpose, trueDivide, trunc, uint32, vdot, vecdot, vstack, where, zeros };
1005
+ export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
860
1006
  }
861
1007
  declare const float32 = DType.Float32;
862
1008
  declare const int32 = DType.Int32;
863
1009
  declare const uint32 = DType.Uint32;
864
1010
  declare const bool = DType.Bool;
865
- declare const complex64 = DType.Complex64;
1011
+ declare const float16 = DType.Float16;
866
1012
  /** Euler's constant, `e = 2.7182818284590...` */
867
1013
  declare const e: number;
868
1014
  /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
@@ -873,52 +1019,66 @@ declare const inf: number;
873
1019
  declare const nan: number;
874
1020
  /** This is Pi, `π = 3.14159265358979...` */
875
1021
  declare const pi: number;
876
- /** Element-wise addition, with broadcasting. */
1022
+ /** @function Element-wise addition, with broadcasting. */
877
1023
  declare const add: (x: ArrayLike, y: ArrayLike) => Array;
878
- /** Element-wise multiplication, with broadcasting. */
1024
+ /** @function Element-wise multiplication, with broadcasting. */
879
1025
  declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
880
- /** Numerical negative of every element of an array. */
1026
+ /** @function Numerical negative of every element of an array. */
881
1027
  declare const negative: (x: ArrayLike) => Array;
882
- /** Calculate element-wise reciprocal of the input. This is `1/x`. */
1028
+ /** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
883
1029
  declare const reciprocal: (x: ArrayLike) => Array;
884
- /** Element-wise sine function (takes radians). */
1030
+ /** @function Element-wise sine function (takes radians). */
885
1031
  declare const sin: (x: ArrayLike) => Array;
886
- /** Element-wise cosine function (takes radians). */
1032
+ /** @function Element-wise cosine function (takes radians). */
887
1033
  declare const cos: (x: ArrayLike) => Array;
888
- /** Calculate the exponential of all elements in the input array. */
1034
+ /** @function Element-wise inverse sine function (inverse of sin). */
1035
+ declare const asin: (x: ArrayLike) => Array;
1036
+ /** @function Element-wise inverse tangent function (inverse of tan). */
1037
+ declare const atan: (x: ArrayLike) => Array;
1038
+ /** @function Calculate the exponential of all elements in the input array. */
889
1039
  declare const exp: (x: ArrayLike) => Array;
890
- /** Calculate the natural logarithm of all elements in the input array. */
1040
+ /** @function Calculate the natural logarithm of all elements in the input array. */
891
1041
  declare const log: (x: ArrayLike) => Array;
892
- /** Return element-wise minimum of the input arrays. */
1042
+ /** @function Calculate the square root of all elements in the input array. */
1043
+ declare const sqrt: (x: ArrayLike) => Array;
1044
+ /** @function Return element-wise minimum of the input arrays. */
893
1045
  declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
894
- /** Return element-wise maximum of the input arrays. */
1046
+ /** @function Return element-wise maximum of the input arrays. */
895
1047
  declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
896
- /** Compare two arrays element-wise. */
1048
+ /** @function Compare two arrays element-wise. */
897
1049
  declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
898
- /** Compare two arrays element-wise. */
1050
+ /** @function Compare two arrays element-wise. */
899
1051
  declare const less: (x: ArrayLike, y: ArrayLike) => Array;
900
- /** Compare two arrays element-wise. */
1052
+ /** @function Compare two arrays element-wise. */
901
1053
  declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
902
- /** Compare two arrays element-wise. */
1054
+ /** @function Compare two arrays element-wise. */
903
1055
  declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
904
- /** Compare two arrays element-wise. */
1056
+ /** @function Compare two arrays element-wise. */
905
1057
  declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
906
- /** Compare two arrays element-wise. */
1058
+ /** @function Compare two arrays element-wise. */
907
1059
  declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
908
- /** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1060
+ /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
909
1061
  declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
910
- /** Permute the dimensions of an array. Defaults to reversing the axis order. */
1062
+ /**
1063
+ * @function
1064
+ * Permute the dimensions of an array. Defaults to reversing the axis order.
1065
+ */
911
1066
  declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
912
1067
  /**
1068
+ * @function
913
1069
  * Give a new shape to an array without changing its data.
914
1070
  *
915
1071
  * One shape dimension can be -1. In this case, the value is inferred from the
916
1072
  * length of the array and remaining dimensions.
917
1073
  */
918
1074
  declare const reshape: (x: ArrayLike, shape: number[]) => Array;
919
- /** Move axes of an array to new positions. Other axes retain original order. */
1075
+ /**
1076
+ * @function
1077
+ * Move axes of an array to new positions. Other axes retain original order.
1078
+ */
920
1079
  declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
921
1080
  /**
1081
+ * @function
922
1082
  * Add padding (zeros) to an array.
923
1083
  *
924
1084
  * The `width` argument is either an integer or pair of integers, in which case
@@ -926,10 +1086,28 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
926
1086
  * pair specifies the padding for its corresponding axis.
927
1087
  */
928
1088
  declare const pad: (x: ArrayLike, width: number | [number, number] | [number, number][]) => Array;
929
- /** Return the number of dimensions of an array. Does not consume array reference. */
1089
+ /**
1090
+ * @function
1091
+ * Return the number of dimensions of an array. Does not consume array reference.
1092
+ */
930
1093
  declare const ndim: (x: ArrayLike) => number;
931
- /** Return the shape of an array. Does not consume array reference. */
1094
+ /** @function Return the shape of an array. Does not consume array reference. */
932
1095
  declare const shape$1: (x: ArrayLike) => number[];
1096
+ /**
1097
+ * @function
1098
+ * Return an array of zeros with the same shape and type as a given array.
1099
+ */
1100
+ declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
1101
+ /**
1102
+ * @function
1103
+ * Return an array of ones with the same shape and type as a given array.
1104
+ */
1105
+ declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
1106
+ /**
1107
+ * @function
1108
+ * Return a full array with the same shape and type as a given array.
1109
+ */
1110
+ declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
933
1111
  /**
934
1112
  * Return the number of elements in an array, optionally along an axis.
935
1113
  * Does not consume array reference.
@@ -938,15 +1116,15 @@ declare function size(a: ArrayLike, axis?: number): number;
938
1116
  /** Convert an array to a specified dtype. */
939
1117
  declare function astype(a: ArrayLike, dtype: DType): Array;
940
1118
  /** Sum of the elements of the array over a given axis, or axes. */
941
- declare function sum(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1119
+ declare function sum(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
942
1120
  /** Product of the array elements over a given axis. */
943
- declare function prod(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1121
+ declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
944
1122
  /** Return the minimum of array elements along a given axis. */
945
- declare function min(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1123
+ declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
946
1124
  /** Return the maximum of array elements along a given axis. */
947
- declare function max(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1125
+ declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
948
1126
  /** Compute the average of the array elements along the specified axis. */
949
- declare function mean(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
1127
+ declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
950
1128
  /**
951
1129
  * Returns the indices of the minimum values along an axis.
952
1130
  *
@@ -962,7 +1140,7 @@ declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
962
1140
  */
963
1141
  declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
964
1142
  /** Reverse the elements in an array along the given axes. */
965
- declare function flip(x: ArrayLike, axis?: number | number[]): Array;
1143
+ declare function flip(x: ArrayLike, axis?: Axis): Array;
966
1144
  /**
967
1145
  * Join a sequence of arrays along an existing axis.
968
1146
  *
@@ -1006,16 +1184,45 @@ declare function columnStack(xs: ArrayLike[]): Array;
1006
1184
  declare function flipud(x: ArrayLike): Array;
1007
1185
  /** Flip an array horizontally (axis=1). */
1008
1186
  declare function fliplr(x: ArrayLike): Array;
1187
+ /** @function Alternative name for `numpy.transpose()`. */
1009
1188
  declare const permuteDims: (x: ArrayLike, perm?: number[]) => Array;
1010
1189
  /** Return a 1-D flattened array containing the elements of the input. */
1011
1190
  declare function ravel(a: ArrayLike): Array;
1191
+ /**
1192
+ * Repeat each element of an array, placing the copies right after it.
1193
+ *
1194
+ * If no axis is provided, use the flattened input array, and return a flat
1195
+ * output array.
1196
+ */
1197
+ declare function repeat(a: ArrayLike, repeats: number, axis?: number): Array;
1198
+ /**
1199
+ * Construct an array by repeating A the number of times given by reps.
1200
+ *
1201
+ * If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
1202
+ * integers, the resulting array will have a shape of `(reps[0] * d1,
1203
+ * reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
1204
+ */
1205
+ declare function tile(a: ArrayLike, reps: number | number[]): Array;
1206
+ /**
1207
+ * Broadcast an array to a shape, with NumPy-style broadcasting rules.
1208
+ *
1209
+ * In other words, this lets you append axes to the left, and/or expand
1210
+ * dimensions where the shape is 1.
1211
+ */
1212
+ declare function broadcastTo(a: ArrayLike, shape: number[]): Array;
1213
+ /** Broadcast input shapes to a common output shape. */
1214
+ declare function broadcastShapes(...shapes: number[][]): number[];
1215
+ /** Broadcast arrays to a common shape. */
1216
+ declare function broadcastArrays(...arrays: ArrayLike[]): Array[];
1012
1217
  /**
1013
1218
  * Return specified diagonals.
1014
1219
  *
1015
1220
  * If a is 2D, return the diagonal of the array with the given offset. If a is
1016
- * 3D or higher, compute diagonals along the two given axes.
1221
+ * 3D or higher, compute diagonals along the two given axes (default: 0, 1).
1017
1222
  *
1018
- * This returns a view over the existing array.
1223
+ * This returns a view over the existing array. The shape of the resulting array
1224
+ * is determined by removing the two axes along which the diagonal is taken,
1225
+ * then appending a new axis on the right that holds the diagonals.
1019
1226
  */
1020
1227
  declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
1021
1228
  /**
@@ -1034,8 +1241,28 @@ declare function allclose(actual: Parameters<typeof array>[0], expected: Paramet
1034
1241
  declare function matmul(x: ArrayLike, y: ArrayLike): Array;
1035
1242
  /** Dot product of two arrays. */
1036
1243
  declare function dot(x: ArrayLike, y: ArrayLike): Array;
1037
- /** Vector dot product of two arrays. */
1038
- declare function vecdot(x: ArrayLike, y: ArrayLike): Array;
1244
+ /**
1245
+ * Compute the inner product of two arrays.
1246
+ *
1247
+ * Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
1248
+ * contraction on the last axis.
1249
+ *
1250
+ * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
1251
+ */
1252
+ declare function inner(x: ArrayLike, y: ArrayLike): Array;
1253
+ /**
1254
+ * Compute the outer product of two arrays.
1255
+ *
1256
+ * If the input arrays are not 1D, they will be flattened. Returned array will
1257
+ * be of shape `[x.size, y.size]`.
1258
+ */
1259
+ declare function outer(x: ArrayLike, y: ArrayLike): Array;
1260
+ /** Vector dot product of two arrays along a given axis. */
1261
+ declare function vecdot(x: ArrayLike, y: ArrayLike, {
1262
+ axis
1263
+ }?: {
1264
+ axis?: number;
1265
+ }): Array;
1039
1266
  /**
1040
1267
  * Return the dot product of two vectors.
1041
1268
  *
@@ -1053,6 +1280,21 @@ declare function meshgrid(xs: Array[], {
1053
1280
  }?: {
1054
1281
  indexing?: "xy" | "ij";
1055
1282
  }): Array[];
1283
+ /**
1284
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
1285
+ *
1286
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
1287
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
1288
+ * `k>0` is above it.
1289
+ */
1290
+ declare function tri(n: number, m?: number, k?: number, {
1291
+ dtype,
1292
+ device
1293
+ }?: DTypeAndDevice): Array;
1294
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
1295
+ declare function tril(a: ArrayLike, k?: number): Array;
1296
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
1297
+ declare function triu(a: ArrayLike, k?: number): Array;
1056
1298
  /**
1057
1299
  * Clip (limit) the values in an array.
1058
1300
  *
@@ -1069,15 +1311,50 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
1069
1311
  * This is the same function as `jax.numpy.abs()`.
1070
1312
  */
1071
1313
  declare function absolute(x: ArrayLike): Array;
1072
- /** Alias of `jax.numpy.absolute()`. */
1314
+ /** @function Alias of `jax.numpy.absolute()`. */
1073
1315
  declare const abs: typeof absolute;
1316
+ /** Return an element-wise indication of sign of the input. */
1317
+ declare function sign(x: ArrayLike): Array;
1074
1318
  /** Calculate element-wise square of the input array. */
1075
1319
  declare function square(x: ArrayLike): Array;
1076
- /** Compute a trigonometric tangent of each element of input. */
1320
+ /** Element-wise tangent function (takes radians). */
1077
1321
  declare function tan(x: ArrayLike): Array;
1322
+ /** Element-wise inverse cosine function (inverse of cos). */
1323
+ declare function acos(x: ArrayLike): Array;
1324
+ /**
1325
+ * @function
1326
+ * Return element-wise hypotenuse for the given legs of a right triangle.
1327
+ *
1328
+ * In the original NumPy/JAX implementation, this function is more numerically
1329
+ * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
1330
+ * improvements.
1331
+ */
1332
+ declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1333
+ /**
1334
+ * @function
1335
+ * Element-wise arc tangent of y/x with correct quadrant.
1336
+ *
1337
+ * Returns the angle in radians between the positive x-axis and the point (x, y).
1338
+ * The result is in the range [-π, π].
1339
+ *
1340
+ * Uses numerically stable formulas:
1341
+ * - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
1342
+ * - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
1343
+ *
1344
+ * The output is ill-defined when both x and y are zero.
1345
+ */
1346
+ declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
1347
+ /** @function Alias of `jax.numpy.acos()`. */
1348
+ declare const arccos: typeof acos;
1349
+ /** @function Alias of `jax.numpy.atan()`. */
1350
+ declare const arctan: (x: ArrayLike) => Array;
1351
+ /** @function Alias of `jax.numpy.atan2()`. */
1352
+ declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
1353
+ /** Element-wise subtraction, with broadcasting. */
1354
+ declare function subtract(x: ArrayLike, y: ArrayLike): Array;
1078
1355
  /** Calculates the floating-point division of x by y element-wise. */
1079
1356
  declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
1080
- /** Alias of `jax.numpy.trueDivide()`. */
1357
+ /** @function Alias of `jax.numpy.trueDivide()`. */
1081
1358
  declare const divide: typeof trueDivide;
1082
1359
  /** Round input to the nearest integer towards zero. */
1083
1360
  declare function trunc(x: ArrayLike): Array;
@@ -1087,8 +1364,112 @@ declare function exp2(p: ArrayLike): Array;
1087
1364
  declare function log2(x: ArrayLike): Array;
1088
1365
  /** Return the base-10 logarithm of x, element-wise. */
1089
1366
  declare function log10(x: ArrayLike): Array;
1367
+ /** Calculate `exp(x) - 1` element-wise. */
1368
+ declare function expm1(x: ArrayLike): Array;
1369
+ /** Calculate the natural logarithm of `1 + x` element-wise. */
1370
+ declare function log1p(x: ArrayLike): Array;
1371
+ /** Convert angles from degrees to radians. */
1372
+ declare function deg2rad(x: ArrayLike): Array;
1373
+ /** @function Alias of `jax.numpy.deg2rad()`. */
1374
+ declare const radians: typeof deg2rad;
1375
+ /** Convert angles from radians to degrees. */
1376
+ declare function rad2deg(x: ArrayLike): Array;
1377
+ /** @function Alias of `jax.numpy.rad2deg()`. */
1378
+ declare const degrees: typeof rad2deg;
1379
+ /**
1380
+ * @function
1381
+ * Computes first array raised to power of second array, element-wise.
1382
+ */
1383
+ declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1384
+ /** @function Alias of `jax.numpy.power()`. */
1385
+ declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1386
+ /** @function Calculate the element-wise cube root of the input array. */
1387
+ declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
1388
+ /**
1389
+ * @function
1390
+ * Calculate element-wise hyperbolic sine of input.
1391
+ *
1392
+ * `sinh(x) = (exp(x) - exp(-x)) / 2`
1393
+ */
1394
+ declare const sinh: OwnedFunction<(x: ArrayLike) => Array>;
1395
+ /**
1396
+ * @function
1397
+ * Calculate element-wise hyperbolic cosine of input.
1398
+ *
1399
+ * `cosh(x) = (exp(x) + exp(-x)) / 2`
1400
+ */
1401
+ declare const cosh: OwnedFunction<(x: ArrayLike) => Array>;
1402
+ /**
1403
+ * @function
1404
+ * Calculate element-wise hyperbolic tangent of input.
1405
+ *
1406
+ * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
1407
+ */
1408
+ declare const tanh: OwnedFunction<(x: ArrayLike) => Array>;
1409
+ /**
1410
+ * @function
1411
+ * Calculate element-wise inverse hyperbolic sine of input.
1412
+ *
1413
+ * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
1414
+ */
1415
+ declare const arcsinh: OwnedFunction<(x: ArrayLike) => Array>;
1416
+ /**
1417
+ * @function
1418
+ * Calculate element-wise inverse hyperbolic cosine of input.
1419
+ *
1420
+ * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
1421
+ */
1422
+ declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
1423
+ /**
1424
+ * @function
1425
+ * Calculate element-wise inverse hyperbolic tangent of input.
1426
+ *
1427
+ * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
1428
+ */
1429
+ declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
1430
+ /** @function Alias of `jax.numpy.arcsinh()`. */
1431
+ declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
1432
+ /** @function Alias of `jax.numpy.arccosh()`. */
1433
+ declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
1434
+ /** @function Alias of `jax.numpy.arctanh()`. */
1435
+ declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
1436
+ /**
1437
+ * Compute the variance of an array.
1438
+ *
1439
+ * The variance is computed for the flattened array by default, otherwise over
1440
+ * the specified axis.
1441
+ *
1442
+ * If `correction` is provided, the divisor used in the calculation is `N - correction`,
1443
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
1444
+ */
1445
+ declare function var_(x: ArrayLike, axis?: Axis, opts?: {
1446
+ mean?: ArrayLike;
1447
+ correction?: number;
1448
+ } & ReduceOpts): Array;
1449
+ /**
1450
+ * Compute the standard deviation of an array.
1451
+ *
1452
+ * The standard deviation is computed for the flattened array by default,
1453
+ * otherwise over the specified axis.
1454
+ *
1455
+ * If `correction` is provided, the divisor used in the calculation is `N - correction`,
1456
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
1457
+ */
1458
+ declare function std(x: ArrayLike, axis?: Axis, opts?: {
1459
+ mean?: ArrayLike;
1460
+ correction?: number;
1461
+ } & ReduceOpts): Array;
1090
1462
  //#endregion
1091
1463
  //#region src/frontend/jaxpr.d.ts
1464
+ /**
1465
+ * Function callback with an associated dispose() method.
1466
+ *
1467
+ * The dispose() method should be called to clean up any tracer resources needed
1468
+ * by the function after the last time it is called.
1469
+ */
1470
+ type OwnedFunction<F extends Function> = F & {
1471
+ dispose: () => void;
1472
+ };
1092
1473
  /** Variable in a Jaxpr expression. */
1093
1474
  declare class Var {
1094
1475
  #private;
@@ -1149,8 +1530,38 @@ declare class Jaxpr implements FpHashable {
1149
1530
  /** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
1150
1531
  flatten(): Jaxpr;
1151
1532
  }
1533
+ /** @inline */
1534
+ type JitOpts = {
1535
+ staticArgnums?: number[];
1536
+ device?: Device;
1537
+ };
1538
+ declare namespace lax_d_exports {
1539
+ export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
1540
+ }
1541
+ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
1542
+ /**
1543
+ * General n-dimensional convolution operator, with optional dilation.
1544
+ *
1545
+ * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
1546
+ * function in JAX, which wraps XLA's general convolution operator.
1547
+ *
1548
+ * Grouped convolutions are not supported right now.
1549
+ */
1550
+ declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
1551
+ lhsDilation,
1552
+ rhsDilation
1553
+ }?: {
1554
+ lhsDilation?: number[];
1555
+ rhsDilation?: number[];
1556
+ }): Array;
1557
+ /** Convenience wrapper around `convGeneralDilated`. */
1558
+ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
1559
+ /** Convenience wrapper around `convGeneralDilated`. */
1560
+ declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
1561
+ /** Reduce a computation over padded windows. */
1562
+ declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
1152
1563
  declare namespace nn_d_exports {
1153
- export { identity, logSigmoid, logSoftmax, logsumexp, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
1564
+ export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
1154
1565
  }
1155
1566
  /**
1156
1567
  * Rectified Linear Unit (ReLU) activation function:
@@ -1182,6 +1593,7 @@ declare function softplus(x: ArrayLike): Array;
1182
1593
  */
1183
1594
  declare function softSign(x: ArrayLike): Array;
1184
1595
  /**
1596
+ * @function
1185
1597
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
1186
1598
  * Swish, computed element-wise:
1187
1599
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -1190,8 +1602,9 @@ declare function softSign(x: ArrayLike): Array;
1190
1602
  *
1191
1603
  * Reference: https://en.wikipedia.org/wiki/Swish_function
1192
1604
  */
1193
- declare function silu(x: ArrayLike): Array;
1605
+ declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
1194
1606
  /**
1607
+ * @function
1195
1608
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
1196
1609
  * Swish, computed element-wise:
1197
1610
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -1200,14 +1613,67 @@ declare function silu(x: ArrayLike): Array;
1200
1613
  *
1201
1614
  * Reference: https://en.wikipedia.org/wiki/Swish_function
1202
1615
  */
1203
- declare const swish: typeof silu;
1616
+ declare const swish: OwnedFunction<(x: ArrayLike) => Array>;
1204
1617
  /**
1205
1618
  * Log-sigmoid activation function, computed element-wise:
1206
1619
  * `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
1207
1620
  */
1208
1621
  declare function logSigmoid(x: ArrayLike): Array;
1209
- /** Identity activation function. Returns the argument unmodified. */
1622
+ /**
1623
+ * @function
1624
+ * Identity activation function. Returns the argument unmodified.
1625
+ */
1210
1626
  declare const identity: (x: ArrayLike) => Array;
1627
+ /** Leaky rectified linear unit (ReLU) activation function. */
1628
+ declare function leakyRelu(x: ArrayLike, negativeSlope?: ArrayLike): Array;
1629
+ /**
1630
+ * Exponential linear unit activation function.
1631
+ *
1632
+ * Computes the element-wise function:
1633
+ * `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
1634
+ */
1635
+ declare function elu(x: ArrayLike, alpha?: ArrayLike): Array;
1636
+ /**
1637
+ * Continuously-differentiable exponential linear unit activation function.
1638
+ *
1639
+ * Computes the element-wise function:
1640
+ * `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
1641
+ */
1642
+ declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
1643
+ /**
1644
+ * @function
1645
+ * Gaussian error linear unit (GELU) activation function.
1646
+ *
1647
+ * This is computed element-wise. Currently jax-js does not support the erf() or
1648
+ * gelu() functions exactly as primitives, so an approximation is used:
1649
+ * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
1650
+ *
1651
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
1652
+ *
1653
+ * This will be improved in the future.
1654
+ */
1655
+ declare const gelu: OwnedFunction<(x: ArrayLike) => Array>;
1656
+ /**
1657
+ * Gated linear unit (GLU) activation function.
1658
+ *
1659
+ * Splits the `axis` dimension of the input into two halves, a and b, then
1660
+ * computes `a * sigmoid(b)`.
1661
+ */
1662
+ declare function glu(x: ArrayLike, axis?: number): Array;
1663
+ /**
1664
+ * Squareplus activation function.
1665
+ *
1666
+ * Computes the element-wise function:
1667
+ * `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
1668
+ */
1669
+ declare function squareplus(x: ArrayLike, b?: ArrayLike): Array;
1670
+ /**
1671
+ * Mish activation function.
1672
+ *
1673
+ * Computes the element-wise function:
1674
+ * `mish(x) = x * tanh(softplus(x))`
1675
+ */
1676
+ declare function mish(x: ArrayLike): Array;
1211
1677
  /**
1212
1678
  * Softmax function. Computes the function which rescales elements to the range
1213
1679
  * [0, 1] such that the elements along `axis` sum to 1.
@@ -1216,7 +1682,7 @@ declare const identity: (x: ArrayLike) => Array;
1216
1682
  *
1217
1683
  * Reference: https://en.wikipedia.org/wiki/Softmax_function
1218
1684
  */
1219
- declare function softmax(x: ArrayLike, axis?: number | number[]): Array;
1685
+ declare function softmax(x: ArrayLike, axis?: Axis): Array;
1220
1686
  /**
1221
1687
  * Log-Softmax function.
1222
1688
  *
@@ -1225,7 +1691,7 @@ declare function softmax(x: ArrayLike, axis?: number | number[]): Array;
1225
1691
  *
1226
1692
  * If `axis` is not specified, it defaults to the last axis.
1227
1693
  */
1228
- declare function logSoftmax(x: ArrayLike, axis?: number | number[]): Array;
1694
+ declare function logSoftmax(x: ArrayLike, axis?: Axis): Array;
1229
1695
  /**
1230
1696
  * Log-sum-exp reduction. Also a multivariate version of `softplus`.
1231
1697
  *
@@ -1234,7 +1700,22 @@ declare function logSoftmax(x: ArrayLike, axis?: number | number[]): Array;
1234
1700
  *
1235
1701
  * Reference: https://en.wikipedia.org/wiki/LogSumExp
1236
1702
  */
1237
- declare function logsumexp(x: ArrayLike, axis?: number | number[]): Array;
1703
+ declare function logsumexp(x: ArrayLike, axis?: Axis): Array;
1704
+ /** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
1705
+ declare function logmeanexp(x: ArrayLike, axis?: Axis): Array;
1706
+ /**
1707
+ * Standardizes input to zero mean and unit variance.
1708
+ *
1709
+ * By default, this is computed over the last axis. You can pass in a different
1710
+ * axis, or `null` to standardize over all elements.
1711
+ *
1712
+ * Epsilon is added to the denominator for numerical stability; it defaults to `1e-5`.
1713
+ */
1714
+ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
1715
+ mean?: ArrayLike;
1716
+ variance?: ArrayLike;
1717
+ epsilon?: ArrayLike;
1718
+ }): Array;
1238
1719
  /**
1239
1720
  * One-hot encodes the given indices.
1240
1721
  *
@@ -1253,7 +1734,7 @@ declare function logsumexp(x: ArrayLike, axis?: number | number[]): Array;
1253
1734
  */
1254
1735
  declare function oneHot(x: Array, numClasses: number): Array;
1255
1736
  declare namespace random_d_exports {
1256
- export { bits, key, split, uniform };
1737
+ export { bernoulli, bits, exponential, key, normal, split, uniform };
1257
1738
  }
1258
1739
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
1259
1740
  declare function key(seed: number): Array;
@@ -1269,40 +1750,108 @@ declare function uniform(key: Array, shape?: number[], {
1269
1750
  minval?: number;
1270
1751
  maxval?: number;
1271
1752
  }): Array;
1753
+ /**
1754
+ * Sample Bernoulli random variables with given mean (0,1 categorical).
1755
+ *
1756
+ * Returns a random Boolean array with the specified shape. `p` can be an array
1757
+ * and must be broadcastable to `shape`.
1758
+ */
1759
+ declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
1760
+ /** Sample exponential random values according to `p(x) = exp(-x)`. */
1761
+ declare function exponential(key: Array, shape?: number[]): Array;
1762
+ /**
1763
+ * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
1764
+ *
1765
+ * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
1766
+ * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
1767
+ * bitwise identical to JAX.
1768
+ */
1769
+ declare function normal(key: Array, shape?: number[]): Array;
1272
1770
  //#endregion
1273
1771
  //#region src/index.d.ts
1274
- /** @inline */
1275
- type WithArgsSubtype<F extends (args: any[]) => any, T> = Parameters<F> extends T ? F : never;
1276
- /** Compute the forward-mode Jacobian-vector product for a function. */
1277
- declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>, tangents: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
1278
- /** Vectorize an operation on a batched axis for one or more inputs. */
1279
- declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
1280
- /** Compute the Jacobian evaluated column-by-column by forward-mode AD. */
1281
- declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
1282
- /** Construct a Jaxpr by dynamically tracing a function with example inputs. */
1283
- declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...args: Parameters<F>) => {
1772
+ /**
1773
+ * @function
1774
+ * Compute the forward-mode Jacobian-vector product for a function.
1775
+ */
1776
+ declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
1777
+ /**
1778
+ * @function
1779
+ * Vectorize an operation on a batched axis for one or more inputs.
1780
+ */
1781
+ declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1782
+ /**
1783
+ * @function
1784
+ * Compute the Jacobian evaluated column-by-column by forward-mode AD.
1785
+ */
1786
+ declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1787
+ /**
1788
+ * @function
1789
+ * Construct a Jaxpr by dynamically tracing a function with example inputs.
1790
+ */
1791
+ declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
1284
1792
  jaxpr: Jaxpr;
1285
1793
  consts: Array[];
1286
1794
  treedef: JsTreeDef;
1287
1795
  };
1288
- declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
1289
1796
  /**
1797
+ * @function
1798
+ * Mark a function for automatic JIT compilation, with operator fusion.
1799
+ *
1800
+ * The function will be compiled the first time it is called with a set of
1801
+ * argument shapes.
1802
+ *
1803
+ * You can call `.dispose()` on the returned, JIT-compiled function after all
1804
+ * calls to free memory associated with array constants.
1805
+ *
1806
+ * **Options:**
1807
+ * - `staticArgnums`: An array of argument indices to treat as static
1808
+ * (compile-time constant). These arguments must be hashable, won't be traced,
1809
+ * and different values will trigger recompilation.
1810
+ * - `device`: The device to place the computation on. If not specified, the
1811
+ * computation will be placed on the default device.
1812
+ */
1813
+ declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => OwnedFunction<(...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>>;
1814
+ /**
1815
+ * @function
1290
1816
  * Produce a local linear approximation to a function at a point using jvp() and
1291
1817
  * partial evaluation.
1292
1818
  */
1293
- declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, ...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>];
1294
- /** Calculate the reverse-mode vector-Jacobian product for a function. */
1295
- declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, ...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
1819
+ declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
1296
1820
  /**
1821
+ * @function
1822
+ * Calculate the reverse-mode vector-Jacobian product for a function.
1823
+ */
1824
+ declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
1825
+ /**
1826
+ * @function
1297
1827
  * Compute the gradient of a scalar-valued function `f` with respect to its
1298
1828
  * first argument.
1299
1829
  */
1300
- declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
1301
- /** Create a function that evaluates both `f` and the gradient of `f`. */
1302
- declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
1303
- /** Compute the Jacobian evaluated row-by-row by reverse-mode AD. */
1830
+ declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
1831
+ /**
1832
+ * @function
1833
+ * Create a function that evaluates both `f` and the gradient of `f`.
1834
+ */
1835
+ declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
1836
+ /**
1837
+ * @function
1838
+ * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
1839
+ */
1304
1840
  declare const jacrev: typeof jacfwd;
1305
- /** Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`. */
1306
- declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
1841
+ /**
1842
+ * @function
1843
+ * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
1844
+ */
1845
+ declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1846
+ /**
1847
+ * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
1848
+ *
1849
+ * This can be used to wait for the results of an intermediate computation to
1850
+ * finish. It's recommended to call this regularly in an iterative computation
1851
+ * to avoid queueing up too many pending operations.
1852
+ *
1853
+ * Does not consume reference to the arrays.
1854
+ */
1855
+ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
1307
1856
  //#endregion
1308
- export { type Device, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDevice, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1857
+ export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };