@jax-js/jax 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.ts CHANGED
@@ -538,1704 +538,1782 @@ declare class Executable<T = any> {
538
538
  source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
539
539
  data: T);
540
540
  }
541
- declare namespace numpy_fft_d_exports {
542
- export { ComplexPair, fft, ifft };
541
+ declare namespace tree_d_exports {
542
+ export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
543
543
  }
544
- /**
545
- * A pair of arrays representing real and imaginary part `a + bj`. Both arrays
546
- * must have the same shape.
547
- */
548
- type ComplexPair = {
549
- real: Array;
550
- imag: Array;
544
+ declare enum NodeType {
545
+ Array = "Array",
546
+ Object = "Object",
547
+ Leaf = "Leaf",
548
+ }
549
+ /** Analog to the JAX "pytree" object, but for JavaScript. */
550
+ type JsTree<T> = T | JsTree<T>[] | {
551
+ [key: string]: JsTree<T>;
551
552
  };
552
- /**
553
- * Compute a one-dimensional discrete Fourier transform.
554
- *
555
- * Currently, the size of the axis must be a power of two.
556
- */
557
- declare function fft(a: ComplexPair, axis?: number): ComplexPair;
558
- /**
559
- * Compute a one-dimensional inverse discrete Fourier transform.
560
- *
561
- * Currently, the size of the axis must be a power of two.
562
- */
563
- declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
564
- declare namespace numpy_linalg_d_exports {
565
- export { cholesky$1 as cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
553
+ type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
554
+ type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
555
+ /** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
556
+ type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
557
+ /** Represents the structure of a JsTree. */
558
+ declare class JsTreeDef {
559
+ readonly nodeType: NodeType;
560
+ readonly nodeMetadata: any;
561
+ readonly childTreedefs: JsTreeDef[];
562
+ static leaf: JsTreeDef;
563
+ constructor(nodeType: NodeType, nodeMetadata: any,
564
+ // Must be comparable with deepEqual.
565
+ childTreedefs: JsTreeDef[]);
566
+ /** Get the total number of leaves in the tree. */
567
+ get size(): number;
568
+ /** Returns a string representation of this tree definition. */
569
+ toString(root?: boolean): string;
570
+ /** Compare this tree definition with another. */
571
+ equals(other: JsTreeDef): boolean;
572
+ }
573
+ /** Flatten a structured object, returning the tree definition. */
574
+ declare function flatten<T>(tree: JsTree<T>): [T[], JsTreeDef];
575
+ /** Get the leaves of a tree. */
576
+ declare function leaves<T>(tree: JsTree<T>): T[];
577
+ /** Get the treedef for a tree. */
578
+ declare function structure<T>(tree: JsTree<T>): JsTreeDef;
579
+ /** Reconstruct a structured object from the flattened representation. */
580
+ declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T>;
581
+ /** Maps a multi-input function over pytree args to produce a new pytree. */
582
+ declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
583
+ /** Take a reference of every array in a tree. */
584
+ declare function ref<Tree extends JsTree<any>>(tree: Tree): Tree;
585
+ /** Dispose every array in a tree. */
586
+ declare function dispose<Tree extends JsTree<any>>(tree: Tree | null | undefined): void;
587
+ //#endregion
588
+ //#region src/frontend/convolution.d.ts
589
+ /** Definition of a general dilated convolution. Should be valid on creation. */
590
+ interface ConvParams {
591
+ vmapDims: number;
592
+ strides: number[];
593
+ padding: Pair[];
594
+ lhsDilation: number[];
595
+ rhsDilation: number[];
566
596
  }
567
597
  /**
568
- * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
598
+ * Check that the shapes and parameters passed to convolution are valid.
599
+ * Expected shapes of the lhs and rhs of the convolution are:
569
600
  *
570
- * This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
571
- * the input matrix, which is on by default.
601
+ * - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
602
+ * - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
603
+ *
604
+ * If the check succeeds, returns the output shape.
572
605
  */
573
- declare function cholesky$1(a: ArrayLike, {
574
- upper,
575
- symmetrizeInput
576
- }?: {
577
- upper?: boolean;
578
- symmetrizeInput?: boolean;
579
- }): Array;
580
- /** Compute the determinant of a square matrix (batched). */
581
- declare function det(a: ArrayLike): Array;
582
- /** Compute the inverse of a square matrix (batched). */
583
- declare function inv(a: ArrayLike): Array;
606
+ //#endregion
607
+ //#region src/frontend/jaxpr.d.ts
584
608
  /**
585
- * Return the least-squares solution to a linear equation.
586
- *
587
- * For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
588
- * For underdetermined systems, this finds the minimum-norm solution for `x`.
589
- *
590
- * This currently uses Cholesky decomposition to solve the normal equations,
591
- * under the hood. The method is not as robust as QR or SVD.
609
+ * Function callback with an associated dispose() method.
592
610
  *
593
- * @param a coefficient matrix of shape `(M, N)`
594
- * @param b right-hand side of shape `(M,)` or `(M, K)`
595
- * @return least-squares solution of shape `(N,)` or `(N, K)`
611
+ * The dispose() method should be called to clean up any tracer resources needed
612
+ * by the function after the last time it is called.
596
613
  */
597
- declare function lstsq(a: ArrayLike, b: ArrayLike): Array;
598
- /** Raise a square matrix to an integer power, via repeated squarings. */
599
- declare function matrixPower(a: ArrayLike, n: number): Array;
600
- /** Return sign and natural logarithm of the determinant of `a`. */
601
- declare function slogdet(a: ArrayLike): [Array, Array];
614
+ type OwnedFunction<F extends Function> = F & {
615
+ dispose: () => void;
616
+ };
617
+ /** Variable in a Jaxpr expression. */
618
+ declare class Var {
619
+ #private;
620
+ readonly id: number;
621
+ readonly aval: ShapedArray;
622
+ constructor(aval: ShapedArray);
623
+ toString(): string;
624
+ }
625
+ /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
626
+ declare class Lit {
627
+ readonly value: number;
628
+ readonly aval: ShapedArray;
629
+ get dtype(): DType;
630
+ constructor(aval: AbstractValue, value: number);
631
+ }
632
+ type Atom = Var | Lit;
633
+ declare class VarPrinter {
634
+ #private;
635
+ names: Map<Var, string>;
636
+ name(v: Var): string;
637
+ nameType(v: Var): string;
638
+ }
639
+ /** A single statement / binding in a Jaxpr, in ANF form. */
640
+ declare class JaxprEqn {
641
+ readonly primitive: Primitive;
642
+ readonly inputs: Atom[];
643
+ readonly params: Record<string, any>;
644
+ readonly outBinders: Var[];
645
+ constructor(primitive: Primitive, inputs: Atom[], params: Record<string, any>, outBinders: Var[]);
646
+ pprint(usedVars?: Set<Var>, vp?: VarPrinter): PPrint;
647
+ toString(): string;
648
+ }
649
+ /** Typed intermediate representation for traced computations. */
650
+ declare class Jaxpr implements FpHashable {
651
+ #private;
652
+ readonly inBinders: Var[];
653
+ readonly eqns: JaxprEqn[];
654
+ readonly outs: Atom[];
655
+ constructor(inBinders: Var[], eqns: JaxprEqn[], outs: Atom[]);
656
+ pprint(): PPrint;
657
+ toString(): string;
658
+ /**
659
+ * Gets a hash of this Jaxpr.
660
+ *
661
+ * Var identity is not considered in the hash, so two Jaxprs with the same
662
+ * order of assignments and operators but different variable IDs will resolve
663
+ * to the same hash (and toString representation).
664
+ */
665
+ getHash(): bigint;
666
+ hash(state: FpHash): void;
667
+ /**
668
+ * Produce a simplified Jaxpr with basic optimizations applied.
669
+ * - Trim away unused variables.
670
+ * - Fold away *1, *0, or +0 operations against literals.
671
+ * - Remove no-op movement operations.
672
+ */
673
+ simplify(): Jaxpr;
674
+ /** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
675
+ flatten(): Jaxpr;
676
+ }
677
+ /** Jaxpr with a collection of associated, traced constants. */
678
+ declare class ClosedJaxpr {
679
+ readonly jaxpr: Jaxpr;
680
+ readonly consts: Tracer[];
681
+ constructor(jaxpr: Jaxpr, consts: Tracer[]);
682
+ /** String representation of this Jaxpr. */
683
+ toString(): string;
684
+ /** Apply a function to the underlying Jaxpr. */
685
+ mapJaxpr(f: (jaxpr: Jaxpr) => Jaxpr): ClosedJaxpr;
686
+ /** Dispose of the constants in this Jaxpr. */
687
+ dispose(): void;
688
+ }
689
+ /** @inline */
690
+ type JitOpts = {
691
+ staticArgnums?: number[];
692
+ };
693
+ //#endregion
694
+ //#region src/frontend/core.d.ts
602
695
  /**
603
- * Solve a linear system of equations.
696
+ * Frontend primitive operations, which are lowered into Kernel objects before
697
+ * being dispatched to the backend.
604
698
  *
605
- * This solves a (batched) linear system of equations `a @ x = b` for `x` given
606
- * `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
699
+ * Any operation between arrays can be described in these parts. This is also
700
+ * the set of primitives that can occur in Jaxpr programs, and the level at
701
+ * which transformations like vmap, grad, and jvp occur. They are loosely based
702
+ * on [XLA](https://openxla.org/xla/operation_semantics).
607
703
  *
608
- * @param a - Coefficient matrix of shape `(..., N, N)`.
609
- * @param b - Values of shape `(N,)` or `(..., N, M)`.
610
- * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
704
+ * All n-ary operations support broadcasting, with NumPy semantics.
611
705
  */
612
- declare function solve(a: ArrayLike, b: ArrayLike): Array;
613
- //#endregion
614
- //#region src/library/numpy/dtype-info.d.ts
615
- /** @inline */
616
- type FInfo = Readonly<{
617
- /** The number of bits occupied by the type. */
618
- bits: number;
619
- /** Returns the _dtype_ for which finfo returns information. */
620
- dtype: DType;
621
- /** The difference between 1.0 and the next smallest representable float larger than 1.0. */
622
- eps: number;
623
- /** The difference between 1.0 and the next largest representable float smaller than 1.0. */
624
- epsneg: number;
625
- /** The exponent that yields `eps`. */
626
- machep: number;
627
- /** The largest representable finite number. */
628
- max: number;
629
- /** The smallest positive power of the base (2) that causes overflow. */
630
- maxexp: number;
631
- /** The smallest representable (most negative) finite number. */
632
- min: number;
633
- /** The largest negative power of the base (2) without leading zeros in mantissa. */
634
- minexp: number;
635
- /** The exponent that yields `epsneg`. */
636
- negep: number;
637
- /** Number of bits in the exponent portion. */
638
- nexp: number;
639
- /** Number of bits in the mantissa portion. */
640
- nmant: number;
641
- /** The approximate number of decimal digits to which this kind of float is precise. */
642
- precision: number;
643
- /** The approximate decimal resolution, i.e., `10 ** -precision`. */
644
- resolution: number;
645
- /** The smallest positive normal number. */
646
- smallestNormal: number;
647
- /** The smallest positive subnormal number. */
648
- smallestSubnormal: number;
649
- }>;
650
- /** Machine limits for floating-point types. */
651
- declare function finfo(dtype: DType): FInfo;
652
- /** @inline */
653
- type IInfo = Readonly<{
654
- /** The number of bits occupied by the type. */
655
- bits: number;
656
- /** Returns the _dtype_ for which iinfo returns information. */
657
- dtype: DType;
658
- /** The largest representable integer. */
659
- max: number;
660
- /** The smallest representable integer. */
661
- min: number;
662
- }>;
663
- /** Machine limits for integer types. */
664
- declare function iinfo(dtype: DType): IInfo;
665
- declare namespace numpy_d_exports {
666
- export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
706
+ declare enum Primitive {
707
+ Add = "add",
708
+ Mul = "mul",
709
+ Idiv = "idiv",
710
+ Mod = "mod",
711
+ // uses sign of numerator, C-style, matches JS but not Python
712
+ Min = "min",
713
+ Max = "max",
714
+ Neg = "neg",
715
+ Reciprocal = "reciprocal",
716
+ Floor = "floor",
717
+ Ceil = "ceil",
718
+ StopGradient = "stop_gradient",
719
+ Cast = "cast",
720
+ Bitcast = "bitcast",
721
+ Sin = "sin",
722
+ Cos = "cos",
723
+ Asin = "asin",
724
+ Atan = "atan",
725
+ Exp = "exp",
726
+ Log = "log",
727
+ Erf = "erf",
728
+ Erfc = "erfc",
729
+ Sqrt = "sqrt",
730
+ Reduce = "reduce",
731
+ Dot = "dot",
732
+ // sum(x*y, axis=-1)
733
+ Conv = "conv",
734
+ // see lax.conv_general_dilated
735
+ Pool = "pool",
736
+ PoolTranspose = "pool_transpose",
737
+ Compare = "compare",
738
+ Where = "where",
739
+ Concatenate = "concatenate",
740
+ Split = "split",
741
+ RandomBits = "random_bits",
742
+ Gather = "gather",
743
+ Transpose = "transpose",
744
+ Broadcast = "broadcast",
745
+ Reshape = "reshape",
746
+ Flip = "flip",
747
+ Shrink = "shrink",
748
+ Pad = "pad",
749
+ Sort = "sort",
750
+ // sort(x, axis=-1)
751
+ Argsort = "argsort",
752
+ // argsort(x, axis=-1)
753
+ TriangularSolve = "triangular_solve",
754
+ // A is upper triangular, A @ X.T = B.T
755
+ Cholesky = "cholesky",
756
+ // A is positive-definite, A = L @ L^T
757
+ LU = "lu",
758
+ // LU decomposition with partial pivoting
759
+ Jit = "jit",
667
760
  }
668
- declare const float32 = DType.Float32;
669
- declare const int32 = DType.Int32;
670
- declare const uint32 = DType.Uint32;
671
- declare const bool = DType.Bool;
672
- declare const float16 = DType.Float16;
673
- declare const float64 = DType.Float64;
674
- /** Euler's constant, `e = 2.7182818284590...` */
675
- declare const e: number;
676
- /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
677
- declare const eulerGamma = 0.5772156649015329;
678
- /** Positive infinity. */
679
- declare const inf: number;
680
- /** Floating-point representation of NaN. */
681
- declare const nan: number;
682
- /** This is Pi, `π = 3.14159265358979...` */
683
- declare const pi: number;
684
- /** @function Element-wise addition, with broadcasting. */
685
- declare const add: (x: ArrayLike, y: ArrayLike) => Array;
686
- /** @function Element-wise multiplication, with broadcasting. */
687
- declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
688
- /** @function Numerical negative of every element of an array. */
689
- declare const negative: (x: ArrayLike) => Array;
690
- /** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
691
- declare const reciprocal: (x: ArrayLike) => Array;
692
- /** @function Round input down to the nearest integer. */
693
- declare const floor: (x: ArrayLike) => Array;
694
- /** @function Round input up to the nearest integer. */
695
- declare const ceil: (x: ArrayLike) => Array;
696
- /** @function Element-wise sine function (takes radians). */
697
- declare const sin: (x: ArrayLike) => Array;
698
- /** @function Element-wise cosine function (takes radians). */
699
- declare const cos: (x: ArrayLike) => Array;
700
- /** @function Element-wise inverse sine function (inverse of sin). */
701
- declare const asin: (x: ArrayLike) => Array;
702
- /** @function Element-wise inverse tangent function (inverse of tan). */
703
- declare const atan: (x: ArrayLike) => Array;
704
- /** @function Calculate the exponential of all elements in the input array. */
705
- declare const exp: (x: ArrayLike) => Array;
706
- /** @function Calculate the natural logarithm of all elements in the input array. */
707
- declare const log: (x: ArrayLike) => Array;
708
- /** @function Calculate the square root of all elements in the input array. */
709
- declare const sqrt: (x: ArrayLike) => Array;
710
- /** @function Return element-wise minimum of the input arrays. */
711
- declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
712
- /** @function Return element-wise maximum of the input arrays. */
713
- declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
714
- /** @function Compare two arrays element-wise. */
715
- declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
716
- /** @function Compare two arrays element-wise. */
717
- declare const less: (x: ArrayLike, y: ArrayLike) => Array;
718
- /** @function Compare two arrays element-wise. */
719
- declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
720
- /** @function Compare two arrays element-wise. */
721
- declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
722
- /** @function Compare two arrays element-wise. */
723
- declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
724
- /** @function Compare two arrays element-wise. */
725
- declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
726
- /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
727
- declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
761
+ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
762
+ [Primitive.Cast]: {
763
+ dtype: DType;
764
+ };
765
+ [Primitive.Bitcast]: {
766
+ dtype: DType;
767
+ };
768
+ [Primitive.Reduce]: {
769
+ op: AluOp;
770
+ axis: number[];
771
+ };
772
+ [Primitive.Conv]: ConvParams;
773
+ [Primitive.Pool]: {
774
+ window: number[];
775
+ strides: number[];
776
+ };
777
+ [Primitive.PoolTranspose]: {
778
+ inShape: number[];
779
+ window: number[];
780
+ strides: number[];
781
+ };
782
+ [Primitive.Compare]: {
783
+ op: CompareOp;
784
+ };
785
+ [Primitive.Concatenate]: {
786
+ axis: number;
787
+ };
788
+ [Primitive.Split]: {
789
+ axis: number;
790
+ sizes: number[];
791
+ };
792
+ [Primitive.RandomBits]: {
793
+ shape: number[];
794
+ mode: "xor" | 0 | 1;
795
+ };
796
+ [Primitive.Gather]: {
797
+ axis: number[];
798
+ outDim: number;
799
+ };
800
+ [Primitive.Transpose]: {
801
+ perm: number[];
802
+ };
803
+ [Primitive.Broadcast]: {
804
+ shape: number[];
805
+ axis: number[];
806
+ };
807
+ [Primitive.Reshape]: {
808
+ shape: number[];
809
+ };
810
+ [Primitive.Flip]: {
811
+ axis: number[];
812
+ };
813
+ [Primitive.Shrink]: {
814
+ slice: Pair[];
815
+ };
816
+ [Primitive.Pad]: {
817
+ width: Pair[];
818
+ };
819
+ [Primitive.TriangularSolve]: {
820
+ unitDiagonal: boolean;
821
+ };
822
+ [Primitive.Jit]: {
823
+ name: string;
824
+ jaxpr: Jaxpr;
825
+ numConsts: number;
826
+ };
827
+ }
828
+ /** Type of parameters taken by each primitive. */
829
+ type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
830
+ declare enum CompareOp {
831
+ Less = "less",
832
+ Equal = "equal",
833
+ NotEqual = "not_equal",
834
+ LessEqual = "less_equal",
835
+ }
836
+ /** @inline */
837
+ type Axis = number | number[] | null;
838
+ /** @inline */
839
+ type ReduceOpts = {
840
+ keepdims?: boolean;
841
+ };
842
+ type MainTrace = {
843
+ level: number;
844
+ traceType: new (main: MainTrace) => Trace;
845
+ globalData: any | null;
846
+ };
728
847
  /**
729
- * @function
730
- * Permute the dimensions of an array. Defaults to reversing the axis order.
848
+ * Push an interpreter onto the trace stack. Use this like:
849
+ * `using main = newMain(...);`
731
850
  */
732
- declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
851
+
852
+ type TracerValue = Tracer | number | boolean;
853
+ declare abstract class Trace {
854
+ readonly main: MainTrace;
855
+ constructor(main: MainTrace);
856
+ abstract pure(val: TracerValue): Tracer;
857
+ abstract lift(val: Tracer): Tracer;
858
+ abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
859
+ }
860
+ /** Internal representation of an array value. */
861
+ interface AbstractValue {
862
+ /** Shape of the array. Must be a static tuple of non-negative dimensions. */
863
+ shape: number[];
864
+ /** Concrete data type of array elements. */
865
+ dtype: DType;
866
+ /**
867
+ * Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
868
+ * _weakly typed_ unless a dtype is explicitly specified.
869
+ *
870
+ * Weakly typed values will automatically cast to the data type of other
871
+ * arrays when used as an operand as an expression. This property only affects
872
+ * how they promote in type casting; their memory layout is still determined
873
+ * by the actual `dtype` field.
874
+ *
875
+ * ```ts
876
+ * const x = np.array(3); // weakType = true, dtype = float32
877
+ * const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
878
+ * const z = x.add(y); // z has dtype int32 because x is weakly typed
879
+ * ```
880
+ *
881
+ * Weak types are present in JIT programs in their spec (e.g., Jaxpr inputs
882
+ * and outputs can be weakly typed) form. But they're solely a frontend
883
+ * concept. Backends are not aware of weak types.
884
+ */
885
+ weakType: boolean;
886
+ }
733
887
  /**
734
- * @function
735
- * Give a new shape to an array without changing its data.
888
+ * Broadcast shapes and promote types with casting for two avals.
736
889
  *
737
- * One shape dimension can be -1. In this case, the value is inferred from the
738
- * length of the array and remaining dimensions.
890
+ * This implements the weak type behavior described in `promoteTypes()`, but not
891
+ * implemented in that function as `weakType` is not passed.
739
892
  */
740
- declare const reshape: (x: ArrayLike, shape: number[]) => Array;
741
- /**
742
- * @function
743
- * Move axes of an array to new positions. Other axes retain original order.
744
- */
745
- declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
893
+
894
+ declare abstract class Tracer {
895
+ /** @ignore */
896
+ readonly _trace: Trace;
897
+ constructor(trace: Trace);
898
+ abstract get aval(): AbstractValue;
899
+ abstract toString(): string;
900
+ /**
901
+ * Access an array by reference, incrementing the reference count.
902
+ *
903
+ * jax-js handles freeing arrays by using "move" semantics, like in Rust/C++.
904
+ * Whenever you pass an array into a function, that function should consume
905
+ * the array, and it will no longer be usable. For example, if you had:
906
+ *
907
+ * ```
908
+ * const x = np.array([1, 2, 3]);
909
+ * const y = np.add(x, x);
910
+ * ```
911
+ *
912
+ * The second line does not work because the first parameter consumes `x`, and
913
+ * then the second parameter will already have been freed / disposed.
914
+ *
915
+ * To fix this, you can write:
916
+ *
917
+ * ```
918
+ * const y = np.add(x.ref, x);
919
+ * ```
920
+ *
921
+ * Under the hood, every access to `.ref` increments the internal reference
922
+ * count of the array. The reference count starts at 1. When it hits 0, the
923
+ * memory behind the array is freed.
924
+ */
925
+ abstract get ref(): this;
926
+ /**
927
+ * Manually decrement the reference count of the array.
928
+ *
929
+ * Arrays are created with reference count 1. Whenever it is used as argument
930
+ * to a function or other operation, it is disposed (i.e., reference count
931
+ * decreases by 1) automatically. Whenever a `.ref` is created, the reference
932
+ * count increases.
933
+ *
934
+ * You generally don't need to call this function directly since arrays are
935
+ * automatically disposed after being passed into an operation. One common
936
+ * exception is when writing a function and ignoring one of its arguments. In
937
+ * that case, by convention you should dispose of that argument manually.
938
+ *
939
+ * ```
940
+ * function myCustomOperation(a: np.Array, b: np.Array) {
941
+ * b.dispose(); // Needed to satisfy "move" rules.
942
+ * return a.add(1);
943
+ * }
944
+ * ```
945
+ */
946
+ abstract dispose(): void;
947
+ /** The shape of the array. */
948
+ get shape(): number[];
949
+ /** The total number of elements in the array. */
950
+ get size(): number;
951
+ /** The dtype of elements stored in the array. */
952
+ get dtype(): DType;
953
+ /**
954
+ * Whether the array is weakly typed.
955
+ *
956
+ * Weakly typed arrays will cast to the dtype of the other operand. See
957
+ * `promoteTypes()` for details.
958
+ */
959
+ get weakType(): boolean;
960
+ /** The number of dimensions of the array. */
961
+ get ndim(): number;
962
+ /** @ignore */
963
+ fullLower(): Tracer;
964
+ neg(): this;
965
+ add(other: this | TracerValue): this;
966
+ mul(other: this | TracerValue): this;
967
+ mod(other: this | TracerValue): this;
968
+ greater(other: this | TracerValue): this;
969
+ less(other: this | TracerValue): this;
970
+ equal(other: this | TracerValue): this;
971
+ notEqual(other: this | TracerValue): this;
972
+ greaterEqual(other: this | TracerValue): this;
973
+ lessEqual(other: this | TracerValue): this;
974
+ /** Sum of the elements of the array over a given axis, or axes. */
975
+ sum(axis?: Axis, opts?: ReduceOpts): this;
976
+ /** Product of the array elements over a given axis. */
977
+ prod(axis?: Axis, opts?: ReduceOpts): this;
978
+ /** Compute the average of the array elements along the specified axis. */
979
+ mean(axis?: Axis, opts?: ReduceOpts): this;
980
+ /** Minimum of the elements of the array along a given axis. */
981
+ min(axis?: Axis, opts?: ReduceOpts): this;
982
+ /** Maximum of the elements of the array along a given axis. */
983
+ max(axis?: Axis, opts?: ReduceOpts): this;
984
+ /** Test whether all array elements along a given axis evaluate to true. */
985
+ all(axis?: Axis, opts?: ReduceOpts): this;
986
+ /** Test whether any array element along a given axis evaluates to true. */
987
+ any(axis?: Axis, opts?: ReduceOpts): this;
988
+ /** Permute the dimensions of an array. Defaults to reversing the axis order. */
989
+ transpose(perm?: number[]): this;
990
+ /**
991
+ * Give a new shape to an array without changing its data.
992
+ *
993
+ * One shape dimension can be -1. In this case, the value is inferred from the
994
+ * length of the array and remaining dimensions.
995
+ */
996
+ reshape(shape: number | number[]): this;
997
+ /** Copy the array and cast to a specified dtype. */
998
+ astype(dtype: DType): this;
999
+ /** Subtract an array from this one. */
1000
+ sub(other: this | TracerValue): this;
1001
+ /** Divide an array by this one. */
1002
+ div(other: this | TracerValue): this;
1003
+ /** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
1004
+ diagonal(offset?: number, axis1?: number, axis2?: number): this;
1005
+ /** Flatten the array without changing its data. */
1006
+ flatten(): this;
1007
+ /** Flatten the array without changing its data. */
1008
+ ravel(): this;
1009
+ /**
1010
+ * Iterate over the first dimension of this array, returning slices.
1011
+ *
1012
+ * This can be used to destructure arrays. For example:
1013
+ *
1014
+ * ```js
1015
+ * let x = np.array([[1, 2], [3, 4]]);
1016
+ * let [a, b] = x;
1017
+ * console.log(a.js()); // [1, 2]
1018
+ * console.log(b.js()); // [3, 4]
1019
+ * ```
1020
+ */
1021
+ [Symbol.iterator](): IterableIterator<this>;
1022
+ /**
1023
+ * Return a sorted copy of an array in ascending order.
1024
+ *
1025
+ * See `jax.numpy.sort` for full docs.
1026
+ */
1027
+ sort(axis?: number): this;
1028
+ /**
1029
+ * Return the indices that would sort an array. This may not be a stable
1030
+ * sorting algorithm; it need not preserve order of indices in ties.
1031
+ *
1032
+ * See `jax.numpy.argsort` for full docs.
1033
+ */
1034
+ argsort(axis?: number): this;
1035
+ /**
1036
+ * Slice an array along one or more axes.
1037
+ *
1038
+ * This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
1039
+ * mimic this in JavaScript, we would write:
1040
+ *
1041
+ * ```js
1042
+ * x.slice([1, 3], 2, [], null);
1043
+ * ```
1044
+ *
1045
+ * The `slice` method accepts a variable number of arguments, each of which
1046
+ * can be a number, an empty array, a single-element array, a two-element
1047
+ * array, or `null`. The arguments are interpreted as follows:
1048
+ *
1049
+ * - A number `n` means to access the `n`-th element along that axis, removing
1050
+ * that axis from the resulting shape.
1051
+ * - An empty array `[]` means to keep that axis as-is, like `:` in Python.
1052
+ * - A single-element array `[i]` means to start slicing from index `i`
1053
+ * (inclusive) to the end of the axis, like `x[i:]`.
1054
+ * - A two-element array `[i, j]` means to slice from index `i` (inclusive)
1055
+ * to index `j` (exclusive), like `x[i:j]`.
1056
+ * - `null` means to add a new axis at that position, like `np.newaxis`.
1057
+ *
1058
+ * Like in Python, negative indices are supported, which count from the end of
1059
+ * the axis. For example, `-1` means the last element.
1060
+ *
1061
+ * Strided slices are not yet implemented, so you cannot write `x[::2]` or
1062
+ * similar.
1063
+ *
1064
+ * Advanced indexing by integer arrays is also supported. This translates to
1065
+ * the "gather" primitive, and it allows you to access specific elements of
1066
+ * the array by integer indices stored in another array.
1067
+ */
1068
+ slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
1069
+ }
1070
+ declare class ShapedArray implements AbstractValue {
1071
+ readonly shape: number[];
1072
+ readonly dtype: DType;
1073
+ readonly weakType: boolean;
1074
+ constructor(shape: number[], dtype: DType, weakType: boolean);
1075
+ static fromAval(aval: AbstractValue): ShapedArray;
1076
+ get ndim(): number;
1077
+ get size(): number;
1078
+ scalar(): ShapedArray;
1079
+ toString(): string;
1080
+ equals(other: ShapedArray): boolean;
1081
+ }
1082
+ //#endregion
1083
+ //#region src/frontend/array.d.ts
1084
+ type ArrayLike = Array | number | boolean;
1085
+ /** Version of pureArray with fudged types. */
1086
+
746
1087
  /**
747
- * @function
748
- * Add padding (zeros) to an array.
1088
+ * An executable operation that will be dispatched to the backend.
749
1089
  *
750
- * The `width` argument is either an integer or pair of integers, in which case
751
- * all axes are padded with the same width. Or if it is an array of pairs, each
752
- * pair specifies the padding for its corresponding axis.
753
- */
754
- declare const pad: (x: ArrayLike, width: number | Pair | Pair[]) => Array;
755
- /**
756
- * @function
757
- * Return the number of dimensions of an array. Does not consume array reference.
758
- */
759
- declare const ndim: (x: ArrayLike) => number;
760
- /** @function Return the shape of an array. Does not consume array reference. */
761
- declare const shape$1: (x: ArrayLike) => number[];
762
- /**
763
- * @function
764
- * Return an array of zeros with the same shape and type as a given array.
765
- */
766
- declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
767
- /**
768
- * @function
769
- * Return an array of ones with the same shape and type as a given array.
770
- */
771
- declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
772
- /**
773
- * @function
774
- * Return a full array with the same shape and type as a given array.
775
- */
776
- declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
777
- /**
778
- * Return the number of elements in an array, optionally along an axis.
779
- * Does not consume array reference.
1090
+ * This holds a reference to all input buffers used in the operation. After the
1091
+ * operation is dispatched, the references should be released.
780
1092
  */
781
- declare function size(a: ArrayLike, axis?: number): number;
782
- /** Convert an array to a specified dtype. */
783
- declare function astype(a: ArrayLike, dtype: DType): Array;
784
- /** Sum of the elements of the array over a given axis, or axes. */
785
- declare function sum(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
786
- /** Product of the array elements over a given axis. */
787
- declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
788
- /** Return the minimum of array elements along a given axis. */
789
- declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
790
- /** Return the maximum of array elements along a given axis. */
791
- declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1093
+ declare class PendingExecute {
1094
+ #private;
1095
+ readonly backend: Backend;
1096
+ readonly source: Kernel | Routine;
1097
+ readonly inputs: Slot[];
1098
+ readonly outputs: Slot[];
1099
+ prepared: Executable | null;
1100
+ submitted: boolean;
1101
+ constructor(backend: Backend, source: Kernel | Routine, inputs: Slot[], outputs: Slot[]);
1102
+ updateRc(delta: number): void;
1103
+ prepare(): Promise<void>;
1104
+ prepareSync(): void;
1105
+ submit(): void;
1106
+ }
1107
+ /** @inline */
1108
+ type DTypeAndDevice = {
1109
+ dtype?: DType;
1110
+ device?: Device;
1111
+ };
1112
+ type ArrayConstructorArgs = {
1113
+ source: AluExp | Slot;
1114
+ st: ShapeTracker;
1115
+ dtype: DType;
1116
+ weakType: boolean;
1117
+ backend: Backend;
1118
+ committed: boolean;
1119
+ pending?: Iterable<PendingExecute>;
1120
+ };
792
1121
  /**
793
- * Test whether all array elements along a given axis evaluate to True.
1122
+ * A multidimensional numeric array with data stored on CPU or GPU.
794
1123
  *
795
- * Returns a boolean array with the same shape as `a` with the specified axis
796
- * removed. If axis is None, returns a scalar.
1124
+ * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
1125
+ * `torch.Tensor`.
1126
+ *
1127
+ * Not to be confused with the JavaScript "Array" constructor. Avoid importing
1128
+ * this into your code's namespace if you're already using the JavaScript
1129
+ * "Array" type by name.
797
1130
  */
798
- declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1131
+ declare class Array extends Tracer {
1132
+ #private;
1133
+ /**
1134
+ * @ignore
1135
+ * Constructs an array from source, shape and backend. Note that if the source
1136
+ * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
1137
+ * will be freed when the array is disposed.
1138
+ */
1139
+ constructor(args: ArrayConstructorArgs);
1140
+ /** @ignore */
1141
+ get aval(): ShapedArray;
1142
+ /** Return a simple string representation of the array's dimensions. */
1143
+ toString(): string;
1144
+ get device(): Device;
1145
+ get ref(): this;
1146
+ /** Get the current reference count (for debugging memory management). */
1147
+ get refCount(): number;
1148
+ dispose(): void;
1149
+ /**
1150
+ * Convert this array into a primitive value.
1151
+ *
1152
+ * This only works for scalars (0-dimensional arrays). It lets you get values
1153
+ * "out" of the JAX system. For instance, if `x = np.array(5)`, then you can
1154
+ * evaluate `x + 1` and `x ** 2` to get `6` and `25`, respectively.
1155
+ *
1156
+ * This method is also called for `==` equality.
1157
+ */
1158
+ [Symbol.toPrimitive](): any;
1159
+ /** Realize the array and return it as data. */
1160
+ data(): Promise<DataArray>;
1161
+ /**
1162
+ * Wait for this array to finish evaluation.
1163
+ *
1164
+ * Operations and data loading in jax-js are lazy, so this function ensures
1165
+ * that pending operations are dispatched and fully executed before it
1166
+ * returns.
1167
+ *
1168
+ * If you are mapping from `data()` or `dataSync()`, it will also trigger
1169
+ * dispatch of operations as well.
1170
+ *
1171
+ * **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
1172
+ * asynchronously for multiple arrays.
1173
+ */
1174
+ blockUntilReady(): Promise<Array>;
1175
+ /**
1176
+ * Realize the array and return it as data. This is a sync variant and not
1177
+ * recommended for performance reasons, as it will block rendering.
1178
+ */
1179
+ dataSync(): DataArray;
1180
+ /**
1181
+ * Convert this array into a JavaScript object.
1182
+ *
1183
+ * This is a blocking operation that will compile all of the shaders and wait
1184
+ * for execution to complete, synchronously. No other JavaScript code on the
1185
+ * site will be run during shader execution.
1186
+ *
1187
+ * To avoid blocking, prefer `jsAsync()` when possible.
1188
+ */
1189
+ js(): any;
1190
+ /** Convert this array into a JavaScript object, asynchronously. */
1191
+ jsAsync(): Promise<any>;
1192
+ /**
1193
+ * Copy an element of an array to a numeric scalar and return it.
1194
+ *
1195
+ * Throws an error if the array does not have a single element. The array must
1196
+ * either be rank-0, or all dimensions of the shape are 1.
1197
+ */
1198
+ item(): number;
1199
+ /** @private Internal plumbing method for Array / Tracer ops. */
1200
+ static _implRules(): typeof implRules;
1201
+ /** @private */
1202
+ _realizeSource(): number;
1203
+ /** @private Put this array on a new backend, asynchronously. */
1204
+ _put(backend: Backend): Promise<Array>;
1205
+ /** @private Put this array on a new backend, synchronously. */
1206
+ _putSync(backend: Backend): Array;
1207
+ }
1208
+ /** Constructor for creating a new array from data. */
1209
+ declare function array(values: Array | DataArray | RecursiveArray<number> | RecursiveArray<boolean>, {
1210
+ shape,
1211
+ dtype,
1212
+ device
1213
+ }?: {
1214
+ shape?: number[];
1215
+ } & DTypeAndDevice): Array;
1216
+ /** If x is a value, lift it into an array, otherwise leave it be. */
1217
+
1218
+ type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
1219
+ declare const implRules: { [P in Primitive]: ImplRule<P> };
1220
+ /** Return a new array of given shape and type, filled with zeros. */
1221
+ declare function zeros(shape: number[], {
1222
+ dtype,
1223
+ device
1224
+ }?: DTypeAndDevice): Array;
1225
+ /** Return a new array of given shape and type, filled with ones. */
1226
+ declare function ones(shape: number[], {
1227
+ dtype,
1228
+ device
1229
+ }?: DTypeAndDevice): Array;
1230
+ /** Return a new array of given shape and type, filled with `fill_value`. */
1231
+ declare function full(shape: number[], fillValue: number | boolean | Array, {
1232
+ dtype,
1233
+ device
1234
+ }?: DTypeAndDevice): Array;
799
1235
  /**
800
- * Test whether any array element along a given axis evaluates to True.
1236
+ * Create an identity matrix.
801
1237
  *
802
- * Returns a boolean array with the same shape as `a` with the specified axis
803
- * removed. If axis is None, returns a scalar.
1238
+ * If numCols is not provided, it defaults to numRows, i.e., a square identity
1239
+ * matrix with ones on the diagonal.
804
1240
  */
805
- declare function any(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
806
- /** Return the peak-to-peak range along a given axis (`max - min`). */
807
- declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
808
- /** Compute the average of the array elements along the specified axis. */
809
- declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1241
+ declare function eye(numRows: number, numCols?: number, {
1242
+ dtype,
1243
+ device
1244
+ }?: DTypeAndDevice): Array;
1245
+ /** Return the identity matrix, with ones on the main diagonal. */
1246
+ declare function identity$1(n: number, {
1247
+ dtype,
1248
+ device
1249
+ }?: DTypeAndDevice): Array;
810
1250
  /**
811
- * Returns the indices of the minimum values along an axis.
1251
+ * Return evenly spaced values within a given interval.
812
1252
  *
813
- * By default, index is into the flatted array, otherwise it is along the
814
- * specified axis.
1253
+ * This can be called with a varying number of arguments, just like the range()
1254
+ * builtin function in Python.
1255
+ *
1256
+ * - `arange(stop)` is equivalent to `arange(0, stop, 1)`.
1257
+ * - `arange(start, stop)` is equivalent to `arange(start, stop, 1)`.
1258
+ * - `arange(start, stop, step)` creates an array starting at `start`, ending
1259
+ * before `stop`, with a step size of `step`.
1260
+ *
1261
+ * Defaults to an integer data type. This can produce unintended results when
1262
+ * using a non-integer step, so prefer linspace() in those cases.
815
1263
  */
816
- declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
1264
+ declare function arange(start: number, stop?: number, step?: number, {
1265
+ dtype,
1266
+ device
1267
+ }?: DTypeAndDevice): Array;
817
1268
  /**
818
- * Returns the indices of the maximum values along an axis.
1269
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
819
1270
  *
820
- * By default, index is into the flatted array, otherwise it is along the
821
- * specified axis.
1271
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
1272
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
1273
+ * `k>0` is above it.
822
1274
  */
823
- declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
1275
+ declare function tri(n: number, m?: number, k?: number, {
1276
+ dtype,
1277
+ device
1278
+ }?: DTypeAndDevice): Array;
1279
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
1280
+ declare function tril(a: ArrayLike, k?: number): Array;
1281
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
1282
+ declare function triu(a: ArrayLike, k?: number): Array;
824
1283
  /**
825
- * Cumulative sum of elements along an axis.
1284
+ * Return evenly spaced numbers over a specified interval.
826
1285
  *
827
- * Currently this function is `O(n^2)`, we'll improve this later on with a
828
- * two-phase parallel reduction algorithm.
1286
+ * Returns _num_ evenly spaced samples, calculated over the interval
1287
+ * [`start`, `stop`]. The endpoint `stop` is included in the result by default,
1288
+ * but this is controlled by the `endpoint` parameter.
1289
+ *
1290
+ * The default data type is Float32. Use arange() for integer steps.
829
1291
  */
830
- declare function cumsum(a: ArrayLike, axis?: number): Array;
831
- /** Reverse the elements in an array along the given axes. */
832
- declare function flip(x: ArrayLike, axis?: Axis): Array;
1292
+ declare function linspace(start: number, stop: number, num?: number, endpoint?: boolean, {
1293
+ dtype,
1294
+ device
1295
+ }?: DTypeAndDevice): Array;
833
1296
  /**
834
- * Split an array into multiple sub-arrays along an axis.
1297
+ * Return numbers spaced evenly on a log scale.
835
1298
  *
836
- * @param a - The input array to split.
837
- * @param indicesOrSections - If an integer, it indicates the number of equal
838
- * sections to create along the specified axis. If a list of integers, it
839
- * specifies the indices at which to split the array.
840
- * @param axis - The axis along which to split the array. Default is 0.
1299
+ * In linear space, the sequence starts at `base ** start` and ends at
1300
+ * `base ** stop` (see `endpoint` below).
1301
+ *
1302
+ * @param start - `base ** start` is the starting value of the sequence.
1303
+ * @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
1304
+ * @param num - Number of samples to generate. Default is 50.
1305
+ * @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
1306
+ * @param base - The base of the log space. Default is 10.
1307
+ * @returns Array of evenly spaced values on a log scale.
841
1308
  */
842
- declare function split$1(a: ArrayLike, indicesOrSections: number | number[], axis?: number): Array[];
1309
+ declare function logspace(start: number, stop: number, num?: number, endpoint?: boolean, base?: number, {
1310
+ dtype,
1311
+ device
1312
+ }?: DTypeAndDevice): Array;
1313
+ //#endregion
1314
+ //#region src/frontend/linearize.d.ts
1315
+ /** @inline */
1316
+ type GradOpts = {
1317
+ /**
1318
+ * Integer or sequence of integers. Specifies which positional argument(s) to
1319
+ * differentiate with respect to.
1320
+ *
1321
+ * Defaults to `0` (the first argument).
1322
+ */
1323
+ argnums?: number | number[];
1324
+ /**
1325
+ * The input function returns a pair of `[out, aux]` including an auxiliary
1326
+ * value. This `aux` is not differentiated, but is returned alongside the
1327
+ * gradient when evaluating the function.
1328
+ */
1329
+ hasAux?: boolean;
1330
+ };
1331
+ declare namespace lax_linalg_d_exports {
1332
+ export { cholesky$1 as cholesky, lu, triangularSolve };
1333
+ }
843
1334
  /**
844
- * Join a sequence of arrays along an existing axis.
1335
+ * Compute the Cholesky decomposition of a symmetric positive-definite matrix.
845
1336
  *
846
- * The arrays must have the same shape, except in the dimension corresponding to
847
- * `axis` (the first, by default).
1337
+ * The Cholesky decomposition of a matrix `A` is:
848
1338
  *
849
- * No scalars can be passed to this function, as the axis is then ambiguous.
1339
+ * - A = L @ L^T (for upper=false, default)
1340
+ * - A = U^T @ U (for upper=true)
1341
+ *
1342
+ * where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
1343
+ * The input matrix must be symmetric and positive-definite.
1344
+ *
1345
+ * @example
1346
+ * ```ts
1347
+ * import { lax, numpy as np } from "@jax-js/jax";
1348
+ *
1349
+ * const x = np.array([[2., 1.], [1., 2.]]);
1350
+ *
1351
+ * // Lower Cholesky factorization (default):
1352
+ * const L = lax.linalg.cholesky(x);
1353
+ * // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
1354
+ *
1355
+ * // Upper Cholesky factorization:
1356
+ * const U = lax.linalg.cholesky(x, { upper: true });
1357
+ * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
1358
+ * ```
850
1359
  */
851
- declare function concatenate(xs: Array[], axis?: number): Array;
1360
+ declare function cholesky$1(a: ArrayLike, {
1361
+ upper
1362
+ }?: {
1363
+ upper?: boolean;
1364
+ }): Array;
852
1365
  /**
853
- * Join a sequence of arrays along a new axis.
1366
+ * LU decomposition with partial pivoting.
854
1367
  *
855
- * The `axis` parameter specifies the index of the new axis in the dimensions of
856
- * the result. For example, if `axis=0` it will be the first dimension and if
857
- * `axis=-1` it will be the last dimension.
1368
+ * Computes the matrix decomposition: `P @ A = L @ U`, where `P` is a
1369
+ * permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
1370
+ * and `U` is upper-triangular.
858
1371
  *
859
- * All shapes must have the same shape.
1372
+ * @param x - A batch of matrices with shape `[..., m, n]`.
1373
+ *
1374
+ * @returns A tuple `(lu, pivots, permutation)` where:
1375
+ * - `lu`: combined lower and upper triangular matrices.
1376
+ * - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
1377
+ * - `permutation`: the permutation generated by pivots with shape `[..., m]`.
1378
+ *
1379
+ * @example
1380
+ * ```ts
1381
+ * import { lax, numpy as np } from "@jax-js/jax";
1382
+ *
1383
+ * const A = np.array([[4., 3.], [6., 3.]]);
1384
+ * const [lu, pivots, permutation] = lax.linalg.lu(A);
1385
+ * // lu ≈ [[6., 3.], [0.6666667, 1.0]]
1386
+ * // pivots = [1, 1]
1387
+ * // permutation = [1, 0]
1388
+ * ```
860
1389
  */
861
- declare function stack(xs: ArrayLike[], axis?: number): Array;
1390
+ declare function lu(x: ArrayLike): [Array, Array, Array];
862
1391
  /**
863
- * Horizontally stack arrays. Inputs are promoted to rank at least 1, then
864
- * concatenated along axis 1 (if rank-2 or higher) or 0 (if rank-1).
1392
+ * Solve a triangular linear system.
1393
+ *
1394
+ * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
1395
+ * where `a` is a triangular matrix.
1396
+ *
1397
+ * @example
1398
+ * ```ts
1399
+ * import { lax, numpy as np } from "@jax-js/jax";
1400
+ *
1401
+ * const L = np.array([[2., 0.], [1., 3.]]);
1402
+ * const b = np.array([4., 7.]).reshape([2, 1]);
1403
+ *
1404
+ * // Solve L @ x = b
1405
+ * const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
1406
+ * // x = [[2.], [5./3.]]
1407
+ * ```
865
1408
  */
866
- declare function hstack(xs: ArrayLike[]): Array;
1409
+ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
1410
+ leftSide,
1411
+ lower,
1412
+ transposeA,
1413
+ unitDiagonal
1414
+ }?: {
1415
+ leftSide?: boolean;
1416
+ lower?: boolean;
1417
+ transposeA?: boolean;
1418
+ unitDiagonal?: boolean;
1419
+ }): Array;
1420
+ declare namespace lax_d_exports {
1421
+ export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
1422
+ }
867
1423
  /**
868
- * Vertically stack arrays. Inputs are promoted to rank at least 2, then
869
- * concatenated along axis 0.
1424
+ * Dimension numbers for general `dot()` primitive.
1425
+ *
1426
+ * Contracting dimensions act as a tensor contraction (reduction) along the
1427
+ * given axis. They must be the same size in both operands. Batch dimensions
1428
+ * are treated as vectorized, leading batch dimensions.
1429
+ *
1430
+ * The return value has a shape where the first dimensions are shared batch
1431
+ * dimensions, followed by `lhs` non-contracting dimensions, followed by
1432
+ * `rhs` non-contracting dimensions.
870
1433
  */
871
- declare function vstack(xs: ArrayLike[]): Array;
1434
+ type DotDimensionNumbers = {
1435
+ lhsContractingDims?: number[];
1436
+ rhsContractingDims?: number[];
1437
+ lhsBatchDims?: number[];
1438
+ rhsBatchDims?: number[];
1439
+ };
872
1440
  /**
873
- * Stack arrays depth-wise. Inputs are promoted to rank at least 3, then
874
- * concatenated along axis 2.
1441
+ * General dot product/contraction operator.
1442
+ *
1443
+ * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
1444
+ * `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
875
1445
  */
876
- declare function dstack(xs: ArrayLike[]): Array;
1446
+ declare function dot$1(lhs: Array, rhs: Array, {
1447
+ lhsContractingDims: lc,
1448
+ rhsContractingDims: rc,
1449
+ lhsBatchDims: lb,
1450
+ rhsBatchDims: rb
1451
+ }?: DotDimensionNumbers): Array;
1452
+ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
877
1453
  /**
878
- * Stack arrays column-wise. Inputs are promoted to rank at least 2, then
879
- * concatenated along axis 1.
1454
+ * General n-dimensional convolution operator, with optional dilation.
1455
+ *
1456
+ * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
1457
+ * function in JAX, which wraps XLA's general convolution operator.
1458
+ *
1459
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
1460
+ * @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
1461
+ * @param windowStrides - Strides for each spatial dimension
1462
+ * @param padding - Padding for each spatial dimension, or a string
1463
+ * (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
880
1464
  */
881
- declare function columnStack(xs: ArrayLike[]): Array;
882
- /** Flip an array vertically (axis=0). */
883
- declare function flipud(x: ArrayLike): Array;
884
- /** Flip an array horizontally (axis=1). */
885
- declare function fliplr(x: ArrayLike): Array;
886
- /** Transpose the last two dimensions of an array. */
887
- declare function matrixTranspose(a: ArrayLike): Array;
888
- /** Return a 1-D flattened array containing the elements of the input. */
889
- declare function ravel(a: ArrayLike): Array;
890
- /** Remove one or more length-1 axes from an array. */
891
- declare function squeeze(a: ArrayLike, axis?: Axis): Array;
1465
+ declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
1466
+ lhsDilation,
1467
+ rhsDilation,
1468
+ featureGroupCount
1469
+ }?: {
1470
+ lhsDilation?: number[];
1471
+ rhsDilation?: number[];
1472
+ featureGroupCount?: number;
1473
+ }): Array;
1474
+ /** Convenience wrapper around `convGeneralDilated`. */
1475
+ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
1476
+ /** Convenience wrapper around `convGeneralDilated`. */
1477
+ declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
892
1478
  /**
893
- * Expand the shape of an array by inserting new axes of length 1.
1479
+ * Convenience wrapper for calculating the N-d convolution "transpose".
894
1480
  *
895
- * @param a - Input array.
896
- * @param axis - Position(s) in the expanded axes where the new axis (or axes)
897
- * is placed. Can be a single integer or an array of integers.
898
- * @returns Array with the number of dimensions increased.
1481
+ * This function directly calculates a fractionally strided conv rather than
1482
+ * indirectly calculating the gradient (transpose) of a forward convolution.
1483
+ * It is equivalent to the JAX version, except:
899
1484
  *
900
- * @example
901
- * ```ts
902
- * const x = np.array([1, 2]);
903
- * np.expandDims(x, 0); // Shape [1, 2]
904
- * np.expandDims(x, 1); // Shape [2, 1]
905
- * np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
906
- * ```
1485
+ * - The `use_consistent_padding` option is not available. We only have the
1486
+ * consistent padding case (JAX version >0.8.4).
1487
+ * - The order of dimensions matches `lax.conv_general_dilated`.
1488
+ *
1489
+ * Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
1490
+ * dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
1491
+ * `transposeKernel` to true.
1492
+ *
1493
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
1494
+ * @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
1495
+ * @param strides - Sequence of n integers, sets fractional stride
1496
+ * @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
1497
+ * each side of the input, so it acts like gradient of `conv()`
1498
+ * @param rhsDilation - Atrous dilation for the kernel
1499
+ * @param transposeKernel - Flip spatial axes and swap the input/output channels
1500
+ * of the kernel; its shape should be `[C_in, C_out, ...ks]`
907
1501
  */
908
- declare function expandDims(a: ArrayLike, axis: number | number[]): Array;
1502
+ declare function convTranspose(lhs: Array, rhs: Array, strides: number[], padding: PaddingType, {
1503
+ rhsDilation,
1504
+ transposeKernel
1505
+ }?: {
1506
+ rhsDilation?: number[];
1507
+ transposeKernel?: boolean;
1508
+ }): Array;
1509
+ /** Reduce a computation over padded windows. */
1510
+ declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
1511
+ /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
1512
+ declare function erf(x: ArrayLike): Array;
909
1513
  /**
910
- * Repeat each element of an array after themselves.
1514
+ * The complementary error function: `erfc(x) = 1 - erf(x)`.
911
1515
  *
912
- * If no axis is provided, use the flattened input array, and return a flat
913
- * output array.
1516
+ * This function is more accurate than `1 - erf(x)` for large values of `x`,
1517
+ * where `erf(x)` is very close to 1.
914
1518
  */
915
- declare function repeat(a: ArrayLike, repeats: number, axis?: number): Array;
1519
+ declare function erfc(x: ArrayLike): Array;
916
1520
  /**
917
- * Construct an array by repeating A the number of times given by reps.
1521
+ * Stops gradient computation.
918
1522
  *
919
- * If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
920
- * integers, the resulting array will have a shape of `(reps[0] * d1,
921
- * reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
1523
+ * Behaves as the identity function but prevents the flow of gradients during
1524
+ * forward or reverse-mode automatic differentiation.
922
1525
  */
923
- declare function tile(a: ArrayLike, reps: number | number[]): Array;
1526
+ declare function stopGradient(x: ArrayLike): Array;
1527
+ declare namespace numpy_fft_d_exports {
1528
+ export { ComplexPair, fft, ifft };
1529
+ }
924
1530
  /**
925
- * Broadcast an array to a shape, with NumPy-style broadcasing rules.
926
- *
927
- * In other words, this lets you append axes to the left, and/or expand
928
- * dimensions where the shape is 1.
1531
+ * A pair of arrays representing real and imaginary part `a + bj`. Both arrays
1532
+ * must have the same shape.
929
1533
  */
930
- declare function broadcastTo(a: ArrayLike, shape: number[]): Array;
931
- /** Broadcast input shapes to a common output shape. */
932
- declare function broadcastShapes(...shapes: number[][]): number[];
933
- /** Broadcast arrays to a common shape. */
934
- declare function broadcastArrays(...arrays: ArrayLike[]): Array[];
1534
+ type ComplexPair = {
1535
+ real: Array;
1536
+ imag: Array;
1537
+ };
935
1538
  /**
936
- * Return specified diagonals.
937
- *
938
- * If a is 2D, return the diagonal of the array with the given offset. If a is
939
- * 3D or higher, compute diagonals along the two given axes (default: 0, 1).
1539
+ * Compute a one-dimensional discrete Fourier transform.
940
1540
  *
941
- * This returns a view over the existing array. The shape of the resulting array
942
- * is determined by removing the two axes along which the diagonal is taken,
943
- * then appending a new axis to the right with holding the diagonals.
1541
+ * Currently, the size of the axis must be a power of two.
944
1542
  */
945
- declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
1543
+ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
946
1544
  /**
947
- * Extract a diagonal or construct a diagonal array.
1545
+ * Compute a one-dimensional inverse discrete Fourier transform.
948
1546
  *
949
- * If v is a 2D array, return the k-th diagonal of v (as a view). If v is a 1D
950
- * array, return a 2D array with v on the k-th diagonal.
1547
+ * Currently, the size of the axis must be a power of two.
951
1548
  */
952
- declare function diag(v: ArrayLike, k?: number): Array;
953
- /** Calculate the sum of the diagonal of an array along the given axes. */
954
- declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
1549
+ declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
1550
+ declare namespace numpy_linalg_d_exports {
1551
+ export { cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
1552
+ }
955
1553
  /**
956
- * Return a sorted copy of an array.
1554
+ * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
957
1555
  *
958
- * The array is sorted along a specified axis (the last by default). This may be
959
- * an unstable sort, and it dispatches to device-specific implementation.
1556
+ * This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
1557
+ * the input matrix, which is on by default.
960
1558
  */
961
- declare function sort(a: ArrayLike, axis?: number): Array;
1559
+ declare function cholesky(a: ArrayLike, {
1560
+ upper,
1561
+ symmetrizeInput
1562
+ }?: {
1563
+ upper?: boolean;
1564
+ symmetrizeInput?: boolean;
1565
+ }): Array;
1566
+ /** Compute the determinant of a square matrix (batched). */
1567
+ declare function det(a: ArrayLike): Array;
1568
+ /** Compute the inverse of a square matrix (batched). */
1569
+ declare function inv(a: ArrayLike): Array;
962
1570
  /**
963
- * Return indices that would sort an array. This may be an unstable sorting
964
- * algorithm; it need not preserve order of indices in ties.
1571
+ * Return the least-squares solution to a linear equation.
965
1572
  *
966
- * Returns an array of `int32` indices.
1573
+ * For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
1574
+ * For underdetermined systems, this finds the minimum-norm solution for `x`.
967
1575
  *
968
- * The array is sorted along a specified axis (the last by default).
1576
+ * This currently uses Cholesky decomposition to solve the normal equations,
1577
+ * under the hood. The method is not as robust as QR or SVD.
1578
+ *
1579
+ * @param a coefficient matrix of shape `(M, N)`
1580
+ * @param b right-hand side of shape `(M,)` or `(M, K)`
1581
+ * @return least-squares solution of shape `(N,)` or `(N, K)`
969
1582
  */
970
- declare function argsort(a: ArrayLike, axis?: number): Array;
1583
+ declare function lstsq(a: ArrayLike, b: ArrayLike): Array;
1584
+ /** Raise a square matrix to an integer power, via repeated squarings. */
1585
+ declare function matrixPower(a: ArrayLike, n: number): Array;
1586
+ /** Return sign and natural logarithm of the determinant of `a`. */
1587
+ declare function slogdet(a: ArrayLike): [Array, Array];
971
1588
  /**
972
- * Take elements from an array along an axis.
1589
+ * Solve a linear system of equations.
973
1590
  *
974
- * This is equivalent to advanced indexing with integer indices over that
975
- * numbered axis. By default, the flattened array is used.
1591
+ * This solves a (batched) linear system of equations `a @ x = b` for `x` given
1592
+ * `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
1593
+ *
1594
+ * @param a - Coefficient matrix of shape `(..., N, N)`.
1595
+ * @param b - Values of shape `(N,)` or `(..., N, M)`.
1596
+ * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
976
1597
  */
977
- declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
978
- /** Return if two arrays are element-wise equal within a tolerance. */
979
- declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
980
- rtol?: number;
981
- atol?: number;
982
- }): boolean;
983
- /** Matrix product of two arrays. */
984
- declare function matmul(x: ArrayLike, y: ArrayLike): Array;
985
- /** Dot product of two arrays. */
986
- declare function dot$1(x: ArrayLike, y: ArrayLike): Array;
1598
+ declare function solve(a: ArrayLike, b: ArrayLike): Array;
1599
+ //#endregion
1600
+ //#region src/library/numpy/dtype-info.d.ts
1601
+ /** @inline */
1602
+ type FInfo = Readonly<{
1603
+ /** The number of bits occupied by the type. */
1604
+ bits: number;
1605
+ /** Returns the _dtype_ for which finfo returns information. */
1606
+ dtype: DType;
1607
+ /** The difference between 1.0 and the next smallest representable float larger than 1.0. */
1608
+ eps: number;
1609
+ /** The difference between 1.0 and the next largest representable float smaller than 1.0. */
1610
+ epsneg: number;
1611
+ /** The exponent that yields `eps`. */
1612
+ machep: number;
1613
+ /** The largest representable finite number. */
1614
+ max: number;
1615
+ /** The smallest positive power of the base (2) that causes overflow. */
1616
+ maxexp: number;
1617
+ /** The smallest representable (most negative) finite number. */
1618
+ min: number;
1619
+ /** The largest negative power of the base (2) without leading zeros in mantissa. */
1620
+ minexp: number;
1621
+ /** The exponent that yields `epsneg`. */
1622
+ negep: number;
1623
+ /** Number of bits in the exponent portion. */
1624
+ nexp: number;
1625
+ /** Number of bits in the mantissa portion. */
1626
+ nmant: number;
1627
+ /** The approximate number of decimal digits to which this kind of float is precise. */
1628
+ precision: number;
1629
+ /** The approximate decimal resolution, i.e., `10 ** -precision`. */
1630
+ resolution: number;
1631
+ /** The smallest positive normal number. */
1632
+ smallestNormal: number;
1633
+ /** The smallest positive subnormal number. */
1634
+ smallestSubnormal: number;
1635
+ }>;
1636
+ /** Machine limits for floating-point types. */
1637
+ declare function finfo(dtype: DType): FInfo;
1638
+ /** @inline */
1639
+ type IInfo = Readonly<{
1640
+ /** The number of bits occupied by the type. */
1641
+ bits: number;
1642
+ /** Returns the _dtype_ for which iinfo returns information. */
1643
+ dtype: DType;
1644
+ /** The largest representable integer. */
1645
+ max: number;
1646
+ /** The smallest representable integer. */
1647
+ min: number;
1648
+ }>;
1649
+ /** Machine limits for integer types. */
1650
+ declare function iinfo(dtype: DType): IInfo;
1651
+ declare namespace numpy_d_exports {
1652
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1653
+ }
1654
+ declare const float32 = DType.Float32;
1655
+ declare const int32 = DType.Int32;
1656
+ declare const uint32 = DType.Uint32;
1657
+ declare const bool = DType.Bool;
1658
+ declare const float16 = DType.Float16;
1659
+ declare const float64 = DType.Float64;
1660
+ /** Euler's constant, `e = 2.7182818284590...` */
1661
+ declare const e: number;
1662
+ /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
1663
+ declare const eulerGamma = 0.5772156649015329;
1664
+ /** Positive infinity. */
1665
+ declare const inf: number;
1666
+ /** Floating-point representation of NaN. */
1667
+ declare const nan: number;
1668
+ /** This is Pi, `π = 3.14159265358979...` */
1669
+ declare const pi: number;
1670
+ /** @function Element-wise addition, with broadcasting. */
1671
+ declare const add: (x: ArrayLike, y: ArrayLike) => Array;
1672
+ /** @function Element-wise multiplication, with broadcasting. */
1673
+ declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
1674
+ /** @function Numerical negative of every element of an array. */
1675
+ declare const negative: (x: ArrayLike) => Array;
1676
+ /** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
1677
+ declare const reciprocal: (x: ArrayLike) => Array;
1678
+ /** @function Round input down to the nearest integer. */
1679
+ declare const floor: (x: ArrayLike) => Array;
1680
+ /** @function Round input up to the nearest integer. */
1681
+ declare const ceil: (x: ArrayLike) => Array;
1682
+ /** @function Element-wise sine function (takes radians). */
1683
+ declare const sin: (x: ArrayLike) => Array;
1684
+ /** @function Element-wise cosine function (takes radians). */
1685
+ declare const cos: (x: ArrayLike) => Array;
1686
+ /** @function Element-wise inverse sine function (inverse of sin). */
1687
+ declare const asin: (x: ArrayLike) => Array;
1688
+ /** @function Element-wise inverse tangent function (inverse of tan). */
1689
+ declare const atan: (x: ArrayLike) => Array;
1690
+ /** @function Calculate the exponential of all elements in the input array. */
1691
+ declare const exp: (x: ArrayLike) => Array;
1692
+ /** @function Calculate the natural logarithm of all elements in the input array. */
1693
+ declare const log: (x: ArrayLike) => Array;
1694
+ /** @function Calculate the square root of all elements in the input array. */
1695
+ declare const sqrt: (x: ArrayLike) => Array;
1696
+ /** @function Return element-wise minimum of the input arrays. */
1697
+ declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
1698
+ /** @function Return element-wise maximum of the input arrays. */
1699
+ declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
1700
+ /** @function Compare two arrays element-wise. */
1701
+ declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
1702
+ /** @function Compare two arrays element-wise. */
1703
+ declare const less: (x: ArrayLike, y: ArrayLike) => Array;
1704
+ /** @function Compare two arrays element-wise. */
1705
+ declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
1706
+ /** @function Compare two arrays element-wise. */
1707
+ declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
1708
+ /** @function Compare two arrays element-wise. */
1709
+ declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
1710
+ /** @function Compare two arrays element-wise. */
1711
+ declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
1712
+ /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1713
+ declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
987
1714
  /**
988
- * Compute the tensor dot product of two N-dimensional arrays.
989
- *
990
- * The behavior is determined by `axes`. If an integer `k`, sum over the last
991
- * `k` axes of x and the first `k` axes of y. If a tuple, then the first array
992
- * corresponds to the axes of x and the second to the axes of y.
1715
+ * @function
1716
+ * Permute the dimensions of an array. Defaults to reversing the axis order.
993
1717
  */
994
- declare function tensordot(x: ArrayLike, y: ArrayLike, axes?: number | [number[], number[]]): Array;
1718
+ declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
995
1719
  /**
996
- * Einstein summation with string subscripts.
997
- *
998
- * @example
999
- * ```ts
1000
- * import { numpy as np } from "@jax-js/jax";
1720
+ * @function
1721
+ * Give a new shape to an array without changing its data.
1001
1722
  *
1002
- * const a = np.ones([2, 3]);
1003
- * const b = np.ones([3]);
1004
- * np.einsum("ij,j", a, b); // Shape [2]
1005
- * ```
1723
+ * One shape dimension can be -1. In this case, the value is inferred from the
1724
+ * length of the array and remaining dimensions.
1725
+ */
1726
+ declare const reshape: (x: ArrayLike, shape: number[]) => Array;
1727
+ /**
1728
+ * @function
1729
+ * Move axes of an array to new positions. Other axes retain original order.
1006
1730
  */
1007
- declare function einsum(subscripts: string, ...operands: ArrayLike[]): Array;
1731
+ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
1008
1732
  /**
1009
- * Einstein summation alternating between arrays and numeric indices.
1010
- *
1011
- * @example
1012
- * ```ts
1013
- * import { numpy as np } from "@jax-js/jax";
1733
+ * @function
1734
+ * Add padding (zeros) to an array.
1014
1735
  *
1015
- * const a = np.ones([2, 3]);
1016
- * const b = np.ones([3]);
1017
- * np.einsum(a, [0, 1], b, [1]); // Shape [2]
1018
- * ```
1736
+ * The `width` argument is either an integer or pair of integers, in which case
1737
+ * all axes are padded with the same width. Or if it is an array of pairs, each
1738
+ * pair specifies the padding for its corresponding axis.
1019
1739
  */
1020
- declare function einsum(...args: (ArrayLike | number[])[]): Array;
1740
+ declare const pad: (x: ArrayLike, width: number | Pair | Pair[] | Record<number, Pair>) => Array;
1021
1741
  /**
1022
- * Compute the inner product of two arrays.
1023
- *
1024
- * Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
1025
- * contraction on the last axis.
1026
- *
1027
- * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
1742
+ * @function
1743
+ * Return the number of dimensions of an array. Does not consume array reference.
1028
1744
  */
1029
- declare function inner(x: ArrayLike, y: ArrayLike): Array;
1745
+ declare const ndim: (x: ArrayLike) => number;
1746
+ /** @function Return the shape of an array. Does not consume array reference. */
1747
+ declare const shape$1: (x: ArrayLike) => number[];
1030
1748
  /**
1031
- * Compute the outer product of two arrays.
1032
- *
1033
- * If the input arrays are not 1D, they will be flattened. Returned array will
1034
- * be of shape `[x.size, y.size]`.
1749
+ * @function
1750
+ * Return an array of zeros with the same shape and type as a given array.
1035
1751
  */
1036
- declare function outer(x: ArrayLike, y: ArrayLike): Array;
1037
- /** Vector dot product of two arrays along a given axis. */
1038
- declare function vecdot(x: ArrayLike, y: ArrayLike, {
1039
- axis
1040
- }?: {
1041
- axis?: number;
1042
- }): Array;
1752
+ declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
1043
1753
  /**
1044
- * Return the dot product of two vectors.
1045
- *
1046
- * Like vecdot() but flattens the arguments first into vectors.
1754
+ * @function
1755
+ * Return an array of ones with the same shape and type as a given array.
1047
1756
  */
1048
- declare function vdot(x: ArrayLike, y: ArrayLike): Array;
1049
- /** Convolution of two one-dimensional arrays. */
1050
- declare function convolve(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
1051
- /** Correlation of two one dimensional arrays. */
1052
- declare function correlate(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
1757
+ declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
1053
1758
  /**
1054
- * Return a tuple of coordinate matrices from coordinate vectors.
1055
- *
1056
- * Make N-D coordinate arrays for vectorized evaluations of N-D scalar/vector
1057
- * fields over N-D grids, given one-dimensional coordinate arrays x1, x2,…, xn.
1759
+ * @function
1760
+ * Return a full array with the same shape and type as a given array.
1058
1761
  */
1059
- declare function meshgrid(xs: Array[], {
1060
- indexing
1061
- }?: {
1062
- indexing?: "xy" | "ij";
1063
- }): Array[];
1762
+ declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
1064
1763
  /**
1065
- * Clip (limit) the values in an array.
1066
- *
1067
- * Given an interval, values outside the interval are clipped to the interval
1068
- * edges. For example, if an interval of [0, 1] is specified, values smaller
1069
- * than 0 become 0, and values larger than 1 become 1.
1070
- *
1071
- * If either bound is undefined, it is ignored.
1764
+ * Return the number of elements in an array, optionally along an axis.
1765
+ * Does not consume array reference.
1072
1766
  */
1073
- declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
1767
+ declare function size(a: ArrayLike, axis?: number): number;
1768
+ /** Convert an array to a specified dtype. */
1769
+ declare function astype(a: ArrayLike, dtype: DType): Array;
1770
+ /** Sum of the elements of the array over a given axis, or axes. */
1771
+ declare function sum(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1772
+ /** Product of the array elements over a given axis. */
1773
+ declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1774
+ /** Return the minimum of array elements along a given axis. */
1775
+ declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1776
+ /** Return the maximum of array elements along a given axis. */
1777
+ declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1074
1778
  /**
1075
- * Calculate the absolute value element-wise.
1779
+ * Test whether any array element along a given axis evaluates to True.
1076
1780
  *
1077
- * This is the same function as `jax.numpy.abs()`.
1781
+ * Returns a boolean array with the same shape as `a` with the specified axis
1782
+ * removed. If axis is None, returns a scalar.
1078
1783
  */
1079
- declare function absolute(x: ArrayLike): Array;
1080
- /** Return an element-wise indication of sign of the input. */
1081
- declare function sign(x: ArrayLike): Array;
1082
- /** @function Return element-wise positive values of the input (no-op). */
1083
- declare const positive: (x: ArrayLike) => Array;
1784
+ declare function any(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1084
1785
  /**
1085
- * Return the Hamming window of size M, a taper with a weighted cosine bell.
1786
+ * Test whether all array elements along a given axis evaluate to True.
1086
1787
  *
1087
- * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1788
+ * Returns a boolean array with the same shape as `a` with the specified axis
1789
+ * removed. If axis is None, returns a scalar.
1088
1790
  */
1089
- declare function hamming(M: number): Array;
1791
+ declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1792
+ /** Return the peak-to-peak range along a given axis (`max - min`). */
1793
+ declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1794
+ /** Compute the average of the array elements along the specified axis. */
1795
+ declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1090
1796
  /**
1091
- * Return the Hann window of size M, a taper with a weighted cosine bell.
1797
+ * Returns the indices of the minimum values along an axis.
1092
1798
  *
1093
- * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1799
+ * By default, index is into the flatted array, otherwise it is along the
1800
+ * specified axis.
1094
1801
  */
1095
- declare function hann(M: number): Array;
1802
+ declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
1096
1803
  /**
1097
- * @function
1098
- * Compute the Heaviside step function. It is defined piecewise:
1099
- * - `heaviside(x1, x2) = 0` for `x1 < 0`,
1100
- * - `heaviside(x1, x2) = x2` for `x1 == 0`,
1101
- * - `heaviside(x1, x2) = 1` for `x1 > 0`.
1804
+ * Returns the indices of the maximum values along an axis.
1805
+ *
1806
+ * By default, index is into the flatted array, otherwise it is along the
1807
+ * specified axis.
1102
1808
  */
1103
- declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1104
- /** Calculate element-wise square of the input array. */
1105
- declare function square(x: ArrayLike): Array;
1106
- /** Element-wise tangent function (takes radians). */
1107
- declare function tan(x: ArrayLike): Array;
1809
+ declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
1108
1810
  /**
1109
- * @function
1110
- * Return the normalized sinc function.
1111
- *
1112
- * The sinc function is defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`.
1113
- * This is the normalized sinc function commonly used in signal processing.
1811
+ * Cumulative sum of elements along an axis.
1114
1812
  *
1115
- * **Note:** JVP is not supported at x=0 due to discontinuous derivative. This
1116
- * requires a custom JVP rule to handle properly (see JAX implementation).
1813
+ * Currently this function is `O(n^2)`, we'll improve this later on with a
1814
+ * two-phase parallel reduction algorithm.
1117
1815
  */
1118
- declare const sinc: OwnedFunction<(x: ArrayLike) => Array>;
1119
- /** Element-wise inverse cosine function (inverse of cos). */
1120
- declare function acos(x: ArrayLike): Array;
1816
+ declare function cumsum(a: ArrayLike, axis?: number): Array;
1817
+ /** Reverse the elements in an array along the given axes. */
1818
+ declare function flip(x: ArrayLike, axis?: Axis): Array;
1121
1819
  /**
1122
- * @function
1123
- * Return element-wise hypotenuse for the given legs of a right triangle.
1820
+ * Split an array into multiple sub-arrays along an axis.
1124
1821
  *
1125
- * In the original NumPy/JAX implementation, this function is more numerically
1126
- * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
1127
- * stability improvements.
1822
+ * @param a - The input array to split.
1823
+ * @param indicesOrSections - If an integer, it indicates the number of equal
1824
+ * sections to create along the specified axis. If a list of integers, it
1825
+ * specifies the indices at which to split the array.
1826
+ * @param axis - The axis along which to split the array. Default is 0.
1128
1827
  */
1129
- declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1828
+ declare function split$1(a: ArrayLike, indicesOrSections: number | number[], axis?: number): Array[];
1130
1829
  /**
1131
- * @function
1132
- * Element-wise arc tangent of y/x with correct quadrant.
1133
- *
1134
- * Returns the angle in radians between the positive x-axis and the point (x, y).
1135
- * The result is in the range [-π, π].
1830
+ * Join a sequence of arrays along an existing axis.
1136
1831
  *
1137
- * Uses numerically stable formulas:
1138
- * - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
1139
- * - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
1832
+ * The arrays must have the same shape, except in the dimension corresponding to
1833
+ * `axis` (the first, by default).
1140
1834
  *
1141
- * The output is ill-defined when both x and y are zero.
1835
+ * No scalars can be passed to this function, as the axis is then ambiguous.
1142
1836
  */
1143
- declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
1144
- /** Element-wise subtraction, with broadcasting. */
1145
- declare function subtract(x: ArrayLike, y: ArrayLike): Array;
1146
- /** Calculates the floating-point division of x by y element-wise. */
1147
- declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
1837
+ declare function concatenate(xs: Array[], axis?: number): Array;
1148
1838
  /**
1149
- * Return the largest integer smaller or equal to the division of the inputs.
1150
- *
1151
- * The result is always rounded towards negative infinity.
1839
+ * Join a sequence of arrays along a new axis.
1152
1840
  *
1153
- * For floating-point inputs, this is equivalent to `floor(x / y)`.
1154
- * For integer inputs, we use `(x - remainder(x, y)) / y` to handle
1155
- * negative values correctly (note: may overflow near int32 boundaries).
1841
+ * The `axis` parameter specifies the index of the new axis in the dimensions of
1842
+ * the result. For example, if `axis=0` it will be the first dimension and if
1843
+ * `axis=-1` it will be the last dimension.
1156
1844
  *
1157
- * @param x - Dividend array.
1158
- * @param y - Divisor array.
1159
- * @returns Element-wise floor division of x by y.
1845
+ * All shapes must have the same shape.
1160
1846
  */
1161
- declare function floorDivide(x: ArrayLike, y: ArrayLike): Array;
1847
+ declare function stack(xs: ArrayLike[], axis?: number): Array;
1162
1848
  /**
1163
- * @function
1164
- * Calculate element-wise floating-point modulo operation.
1849
+ * Horizontally stack arrays. Inputs are promoted to rank at least 1, then
1850
+ * concatenated along axis 1 (if rank-2 or higher) or 0 (if rank-1).
1165
1851
  */
1166
- declare const fmod: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
1852
+ declare function hstack(xs: ArrayLike[]): Array;
1167
1853
  /**
1168
- * @function
1169
- * Calculate element-wise remainder of the division (matches sign of y).
1854
+ * Vertically stack arrays. Inputs are promoted to rank at least 2, then
1855
+ * concatenated along axis 0.
1170
1856
  */
1171
- declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
1857
+ declare function vstack(xs: ArrayLike[]): Array;
1172
1858
  /**
1173
- * Return element-wise quotient and remainder simultaneously.
1859
+ * Stack arrays depth-wise. Inputs are promoted to rank at least 3, then
1860
+ * concatenated along axis 2.
1861
+ */
1862
+ declare function dstack(xs: ArrayLike[]): Array;
1863
+ /**
1864
+ * Stack arrays column-wise. Inputs are promoted to rank at least 2, then
1865
+ * concatenated along axis 1.
1866
+ */
1867
+ declare function columnStack(xs: ArrayLike[]): Array;
1868
+ /** Flip an array vertically (axis=0). */
1869
+ declare function flipud(x: ArrayLike): Array;
1870
+ /** Flip an array horizontally (axis=1). */
1871
+ declare function fliplr(x: ArrayLike): Array;
1872
+ /** Interchange two axes of an array. */
1873
+ declare function swapaxes(a: ArrayLike, axis1: number, axis2: number): Array;
1874
+ /** Transpose the last two dimensions of an array. */
1875
+ declare function matrixTranspose(a: ArrayLike): Array;
1876
+ /** Return a 1-D flattened array containing the elements of the input. */
1877
+ declare function ravel(a: ArrayLike): Array;
1878
+ /** Remove one or more length-1 axes from an array. */
1879
+ declare function squeeze(a: ArrayLike, axis?: Axis): Array;
1880
+ /**
1881
+ * Expand the shape of an array by inserting new axes of length 1.
1174
1882
  *
1175
- * Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
1883
+ * @param a - Input array.
1884
+ * @param axis - Position(s) in the expanded axes where the new axis (or axes)
1885
+ * is placed. Can be a single integer or an array of integers.
1886
+ * @returns Array with the number of dimensions increased.
1176
1887
  *
1177
- * @param x - Dividend array.
1178
- * @param y - Divisor array.
1179
- * @returns Tuple of [quotient, remainder].
1888
+ * @example
1889
+ * ```ts
1890
+ * const x = np.array([1, 2]);
1891
+ * np.expandDims(x, 0); // Shape [1, 2]
1892
+ * np.expandDims(x, 1); // Shape [2, 1]
1893
+ * np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
1894
+ * ```
1180
1895
  */
1181
- declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
1182
- /** Round input to the nearest integer towards zero. */
1183
- declare function trunc(x: ArrayLike): Array;
1896
+ declare function expandDims(a: ArrayLike, axis: number | number[]): Array;
1897
+ /**
1898
+ * Repeat each element of an array after themselves.
1899
+ *
1900
+ * If no axis is provided, use the flattened input array, and return a flat
1901
+ * output array.
1902
+ */
1903
+ declare function repeat(a: ArrayLike, repeats: number, axis?: number): Array;
1184
1904
  /**
1185
- * Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
1905
+ * Construct an array by repeating A the number of times given by reps.
1186
1906
  *
1187
- * This is the inverse of `frexp()`.
1907
+ * If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
1908
+ * integers, the resulting array will have a shape of `(reps[0] * d1,
1909
+ * reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
1188
1910
  */
1189
- declare function ldexp(x1: ArrayLike, x2: ArrayLike): Array;
1911
+ declare function tile(a: ArrayLike, reps: number | number[]): Array;
1190
1912
  /**
1191
- * Decompose floating-point values into mantissa and two's exponent.
1913
+ * Broadcast an array to a shape, with NumPy-style broadcasing rules.
1192
1914
  *
1193
- * The mantissa is returned in the range `(-1, 1)` with magnitude `>= 0.5` if
1194
- * `x != 0`, and the exponent is an integer such that
1195
- * `x = mantissa * 2**exponent`.
1915
+ * In other words, this lets you append axes to the left, and/or expand
1916
+ * dimensions where the shape is 1.
1196
1917
  */
1197
- declare function frexp(x: ArrayLike): [Array, Array];
1198
- /** Calculate `2**p` for all p in the input array. */
1199
- declare function exp2(p: ArrayLike): Array;
1200
- /** Return the base-2 logarithm of x, element-wise. */
1201
- declare function log2(x: ArrayLike): Array;
1202
- /** Return the base-10 logarithm of x, element-wise. */
1203
- declare function log10(x: ArrayLike): Array;
1204
- /** Calculate `exp(x) - 1` element-wise. */
1205
- declare function expm1(x: ArrayLike): Array;
1206
- /** Calculate the natural logarithm of `1 + x` element-wise. */
1207
- declare function log1p(x: ArrayLike): Array;
1208
- /** Convert angles from degrees to radians. */
1209
- declare function deg2rad(x: ArrayLike): Array;
1210
- /** @function Alias of `jax.numpy.deg2rad()`. */
1211
- declare const radians: typeof deg2rad;
1212
- /** Convert angles from radians to degrees. */
1213
- declare function rad2deg(x: ArrayLike): Array;
1214
- /** @function Alias of `jax.numpy.rad2deg()`. */
1215
- declare const degrees: typeof rad2deg;
1918
+ declare function broadcastTo(a: ArrayLike, shape: number[]): Array;
1919
+ /** Broadcast input shapes to a common output shape. */
1920
+ declare function broadcastShapes(...shapes: number[][]): number[];
1921
+ /** Broadcast arrays to a common shape. */
1922
+ declare function broadcastArrays(...arrays: ArrayLike[]): Array[];
1216
1923
  /**
1217
- * @function
1218
- * Computes first array raised to power of second array, element-wise.
1924
+ * Return specified diagonals.
1925
+ *
1926
+ * If a is 2D, return the diagonal of the array with the given offset. If a is
1927
+ * 3D or higher, compute diagonals along the two given axes (default: 0, 1).
1928
+ *
1929
+ * This returns a view over the existing array. The shape of the resulting array
1930
+ * is determined by removing the two axes along which the diagonal is taken,
1931
+ * then appending a new axis to the right with holding the diagonals.
1219
1932
  */
1220
- declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1221
- /** @function Calculate the element-wise cube root of the input array. */
1222
- declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
1933
+ declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
1223
1934
  /**
1224
- * @function
1225
- * Calculate element-wise hyperbolic sine of input.
1935
+ * Extract a diagonal or construct a diagonal array.
1226
1936
  *
1227
- * `sinh(x) = (exp(x) - exp(-x)) / 2`
1937
+ * If v is a 2D array, return the k-th diagonal of v (as a view). If v is a 1D
1938
+ * array, return a 2D array with v on the k-th diagonal.
1228
1939
  */
1229
- declare const sinh: OwnedFunction<(x: ArrayLike) => Array>;
1940
+ declare function diag(v: ArrayLike, k?: number): Array;
1941
+ /** Calculate the sum of the diagonal of an array along the given axes. */
1942
+ declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
1230
1943
  /**
1231
- * @function
1232
- * Calculate element-wise hyperbolic cosine of input.
1944
+ * Return a sorted copy of an array.
1233
1945
  *
1234
- * `cosh(x) = (exp(x) + exp(-x)) / 2`
1946
+ * The array is sorted along a specified axis (the last by default). This may be
1947
+ * an unstable sort, and it dispatches to device-specific implementation.
1235
1948
  */
1236
- declare const cosh: OwnedFunction<(x: ArrayLike) => Array>;
1949
+ declare function sort(a: ArrayLike, axis?: number): Array;
1237
1950
  /**
1238
- * @function
1239
- * Calculate element-wise hyperbolic tangent of input.
1951
+ * Return indices that would sort an array. This may be an unstable sorting
1952
+ * algorithm; it need not preserve order of indices in ties.
1240
1953
  *
1241
- * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
1954
+ * Returns an array of `int32` indices.
1955
+ *
1956
+ * The array is sorted along a specified axis (the last by default).
1242
1957
  */
1243
- declare const tanh: OwnedFunction<(x: ArrayLike) => Array>;
1958
+ declare function argsort(a: ArrayLike, axis?: number): Array;
1244
1959
  /**
1245
- * @function
1246
- * Calculate element-wise inverse hyperbolic sine of input.
1960
+ * Take elements from an array along an axis.
1247
1961
  *
1248
- * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
1962
+ * This is equivalent to advanced indexing with integer indices over that
1963
+ * numbered axis. By default, the flattened array is used.
1249
1964
  */
1250
- declare const arcsinh: OwnedFunction<(x: ArrayLike) => Array>;
1965
+ declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
1966
+ /** Return if two arrays are element-wise equal within a tolerance. */
1967
+ declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
1968
+ rtol?: number;
1969
+ atol?: number;
1970
+ }): boolean;
1971
+ /** Matrix product of two arrays. */
1972
+ declare function matmul(x: ArrayLike, y: ArrayLike): Array;
1973
+ /** Dot product of two arrays. */
1974
+ declare function dot(x: ArrayLike, y: ArrayLike): Array;
1251
1975
  /**
1252
- * @function
1253
- * Calculate element-wise inverse hyperbolic cosine of input.
1976
+ * Compute the tensor dot product of two N-dimensional arrays.
1254
1977
  *
1255
- * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
1978
+ * The behavior is determined by `axes`. If an integer `k`, sum over the last
1979
+ * `k` axes of x and the first `k` axes of y. If a tuple, then the first array
1980
+ * corresponds to the axes of x and the second to the axes of y.
1256
1981
  */
1257
- declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
1982
+ declare function tensordot(x: ArrayLike, y: ArrayLike, axes?: number | [number[], number[]]): Array;
1258
1983
  /**
1259
- * @function
1260
- * Calculate element-wise inverse hyperbolic tangent of input.
1984
+ * Einstein summation with string subscripts.
1261
1985
  *
1262
- * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
1986
+ * @example
1987
+ * ```ts
1988
+ * import { numpy as np } from "@jax-js/jax";
1989
+ *
1990
+ * const a = np.ones([2, 3]);
1991
+ * const b = np.ones([3]);
1992
+ * np.einsum("ij,j", a, b); // Shape [2]
1993
+ * ```
1263
1994
  */
1264
- declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
1995
+ declare function einsum(subscripts: string, ...operands: ArrayLike[]): Array;
1265
1996
  /**
1266
- * Compute the variance of an array.
1997
+ * Einstein summation alternating between arrays and numeric indices.
1267
1998
  *
1268
- * The variance is computed for the flattened array by default, otherwise over
1269
- * the specified axis.
1999
+ * @example
2000
+ * ```ts
2001
+ * import { numpy as np } from "@jax-js/jax";
1270
2002
  *
1271
- * If `correction` is provided, the divisor in calculation is `N - correction`,
1272
- * where `N` represents the number of elements (e.g., for Bessel's correction).
2003
+ * const a = np.ones([2, 3]);
2004
+ * const b = np.ones([3]);
2005
+ * np.einsum(a, [0, 1], b, [1]); // Shape [2]
2006
+ * ```
1273
2007
  */
1274
- declare function var_(x: ArrayLike, axis?: Axis, opts?: {
1275
- mean?: ArrayLike;
1276
- correction?: number;
1277
- } & ReduceOpts): Array;
2008
+ declare function einsum(...args: (ArrayLike | number[])[]): Array;
1278
2009
  /**
1279
- * Compute the standard deviation of an array.
2010
+ * Compute the inner product of two arrays.
1280
2011
  *
1281
- * The standard deviation is computed for the flattened array by default,
1282
- * otherwise over the specified axis.
2012
+ * Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
2013
+ * contraction on the last axis.
1283
2014
  *
1284
- * If `correction` is provided, the divisor in calculation is `N - correction`,
1285
- * where `N` represents the number of elements (e.g., for Bessel's correction).
2015
+ * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
1286
2016
  */
1287
- declare function std(x: ArrayLike, axis?: Axis, opts?: {
1288
- mean?: ArrayLike;
1289
- correction?: number;
1290
- } & ReduceOpts): Array;
1291
- /** Estimate the sample covariance of a set of variables. */
1292
- declare function cov(x: ArrayLike, y?: ArrayLike | null, {
1293
- rowvar
2017
+ declare function inner(x: ArrayLike, y: ArrayLike): Array;
2018
+ /**
2019
+ * Compute the outer product of two arrays.
2020
+ *
2021
+ * If the input arrays are not 1D, they will be flattened. Returned array will
2022
+ * be of shape `[x.size, y.size]`.
2023
+ */
2024
+ declare function outer(x: ArrayLike, y: ArrayLike): Array;
2025
+ /** Vector dot product of two arrays along a given axis. */
2026
+ declare function vecdot(x: ArrayLike, y: ArrayLike, {
2027
+ axis
1294
2028
  }?: {
1295
- rowvar?: boolean;
2029
+ axis?: number;
1296
2030
  }): Array;
1297
- /** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
1298
- declare function corrcoef(x: ArrayLike, y?: ArrayLike): Array;
1299
- /** Test element-wise for positive or negative infinity, return bool array. */
1300
- declare function isinf(x: ArrayLike): Array;
1301
- /** Test element-wise for NaN (Not a Number). */
1302
- declare function isnan(x: ArrayLike): Array;
1303
- /** Test element-wise for negative infinity, return bool array. */
1304
- declare function isneginf(x: ArrayLike): Array;
1305
- /** Test element-wise for positive infinity, return bool array. */
1306
- declare function isposinf(x: ArrayLike): Array;
1307
2031
  /**
1308
- * @function
1309
- * Test element-wise for finite values (not infinity or NaN).
2032
+ * Return the dot product of two vectors.
2033
+ *
2034
+ * Like vecdot() but flattens the arguments first into vectors.
1310
2035
  */
1311
- declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
1312
- declare namespace tree_d_exports {
1313
- export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
1314
- }
1315
- declare enum NodeType {
1316
- Array = "Array",
1317
- Object = "Object",
1318
- Leaf = "Leaf",
1319
- }
1320
- /** Analog to the JAX "pytree" object, but for JavaScript. */
1321
- type JsTree<T> = T | JsTree<T>[] | {
1322
- [key: string]: JsTree<T>;
1323
- };
1324
- type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
1325
- type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
1326
- /** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
1327
- type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
1328
- /** Represents the structure of a JsTree. */
1329
- declare class JsTreeDef {
1330
- readonly nodeType: NodeType;
1331
- readonly nodeMetadata: any;
1332
- readonly childTreedefs: JsTreeDef[];
1333
- static leaf: JsTreeDef;
1334
- constructor(nodeType: NodeType, nodeMetadata: any,
1335
- // Must be comparable with deepEqual.
1336
- childTreedefs: JsTreeDef[]);
1337
- /** Get the total number of leaves in the tree. */
1338
- get size(): number;
1339
- /** Returns a string representation of this tree definition. */
1340
- toString(root?: boolean): string;
1341
- /** Compare this tree definition with another. */
1342
- equals(other: JsTreeDef): boolean;
1343
- }
1344
- /** Flatten a structured object, returning the tree definition. */
1345
- declare function flatten<T>(tree: JsTree<T>): [T[], JsTreeDef];
1346
- /** Get the leaves of a tree. */
1347
- declare function leaves<T>(tree: JsTree<T>): T[];
1348
- /** Get the treedef for a tree. */
1349
- declare function structure<T>(tree: JsTree<T>): JsTreeDef;
1350
- /** Reconstruct a structured object from the flattened representation. */
1351
- declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T>;
1352
- /** Maps a multi-input function over pytree args to produce a new pytree. */
1353
- declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
1354
- /** Take a reference of every array in a tree. */
1355
- declare function ref<Tree extends JsTree<Array>>(tree: Tree): Tree;
1356
- /** Dispose every array in a tree. */
1357
- declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefined): void;
1358
- //#endregion
1359
- //#region src/frontend/convolution.d.ts
1360
- /** Definition of a general dilated convolution. Should be valid on creation. */
1361
- interface ConvParams {
1362
- vmapDims: number;
1363
- strides: number[];
1364
- padding: Pair[];
1365
- lhsDilation: number[];
1366
- rhsDilation: number[];
1367
- }
2036
+ declare function vdot(x: ArrayLike, y: ArrayLike): Array;
2037
+ /** Convolution of two one-dimensional arrays. */
2038
+ declare function convolve(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
2039
+ /** Correlation of two one dimensional arrays. */
2040
+ declare function correlate(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
2041
+ /**
2042
+ * Return a tuple of coordinate matrices from coordinate vectors.
2043
+ *
2044
+ * Make N-D coordinate arrays for vectorized evaluations of N-D scalar/vector
2045
+ * fields over N-D grids, given one-dimensional coordinate arrays x1, x2,…, xn.
2046
+ */
2047
+ declare function meshgrid(xs: Array[], {
2048
+ indexing
2049
+ }?: {
2050
+ indexing?: "xy" | "ij";
2051
+ }): Array[];
1368
2052
  /**
1369
- * Check that the shapes and parameters passed to convolution are valid.
1370
- * Expected shapes of the lhs and rhs of the convolution are:
2053
+ * Clip (limit) the values in an array.
1371
2054
  *
1372
- * - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
1373
- * - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
2055
+ * Given an interval, values outside the interval are clipped to the interval
2056
+ * edges. For example, if an interval of [0, 1] is specified, values smaller
2057
+ * than 0 become 0, and values larger than 1 become 1.
1374
2058
  *
1375
- * If the check succeeds, returns the output shape.
2059
+ * If either bound is undefined, it is ignored.
1376
2060
  */
1377
- //#endregion
1378
- //#region src/frontend/jaxpr.d.ts
2061
+ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
1379
2062
  /**
1380
- * Function callback with an associated dispose() method.
2063
+ * Calculate the absolute value element-wise.
1381
2064
  *
1382
- * The dispose() method should be called to clean up any tracer resources needed
1383
- * by the function after the last time it is called.
2065
+ * This is the same function as `jax.numpy.abs()`.
1384
2066
  */
1385
- type OwnedFunction<F extends Function> = F & {
1386
- dispose: () => void;
1387
- };
1388
- /** Variable in a Jaxpr expression. */
1389
- declare class Var {
1390
- #private;
1391
- readonly id: number;
1392
- readonly aval: ShapedArray;
1393
- constructor(aval: ShapedArray);
1394
- toString(): string;
1395
- }
1396
- /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
1397
- declare class Lit {
1398
- readonly value: number;
1399
- readonly aval: ShapedArray;
1400
- get dtype(): DType;
1401
- constructor(aval: AbstractValue, value: number);
1402
- }
1403
- type Atom = Var | Lit;
1404
- declare class VarPrinter {
1405
- #private;
1406
- names: Map<Var, string>;
1407
- name(v: Var): string;
1408
- nameType(v: Var): string;
1409
- }
1410
- /** A single statement / binding in a Jaxpr, in ANF form. */
1411
- declare class JaxprEqn {
1412
- readonly primitive: Primitive;
1413
- readonly inputs: Atom[];
1414
- readonly params: Record<string, any>;
1415
- readonly outBinders: Var[];
1416
- constructor(primitive: Primitive, inputs: Atom[], params: Record<string, any>, outBinders: Var[]);
1417
- pprint(usedVars?: Set<Var>, vp?: VarPrinter): PPrint;
1418
- toString(): string;
1419
- }
1420
- /** Typed intermediate representation for traced computations. */
1421
- declare class Jaxpr implements FpHashable {
1422
- #private;
1423
- readonly inBinders: Var[];
1424
- readonly eqns: JaxprEqn[];
1425
- readonly outs: Atom[];
1426
- constructor(inBinders: Var[], eqns: JaxprEqn[], outs: Atom[]);
1427
- pprint(): PPrint;
1428
- toString(): string;
1429
- /**
1430
- * Gets a hash of this Jaxpr.
1431
- *
1432
- * Var identity is not considered in the hash, so two Jaxprs with the same
1433
- * order of assignments and operators but different variable IDs will resolve
1434
- * to the same hash (and toString representation).
1435
- */
1436
- getHash(): bigint;
1437
- hash(state: FpHash): void;
1438
- /**
1439
- * Produce a simplified Jaxpr with basic optimizations applied.
1440
- * - Trim away unused variables.
1441
- * - Fold away *1, *0, or +0 operations against literals.
1442
- * - Remove no-op movement operations.
1443
- */
1444
- simplify(): Jaxpr;
1445
- /** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
1446
- flatten(): Jaxpr;
1447
- }
1448
- /** Jaxpr with a collection of associated, traced constants. */
1449
- declare class ClosedJaxpr {
1450
- readonly jaxpr: Jaxpr;
1451
- readonly consts: Tracer[];
1452
- constructor(jaxpr: Jaxpr, consts: Tracer[]);
1453
- /** String representation of this Jaxpr. */
1454
- toString(): string;
1455
- /** Apply a function to the underlying Jaxpr. */
1456
- mapJaxpr(f: (jaxpr: Jaxpr) => Jaxpr): ClosedJaxpr;
1457
- /** Dispose of the constants in this Jaxpr. */
1458
- dispose(): void;
1459
- }
1460
- /** @inline */
1461
- type JitOpts = {
1462
- staticArgnums?: number[];
1463
- };
1464
- //#endregion
1465
- //#region src/frontend/core.d.ts
2067
+ declare function absolute(x: ArrayLike): Array;
2068
+ /** Return an element-wise indication of sign of the input. */
2069
+ declare function sign(x: ArrayLike): Array;
2070
+ /** @function Return element-wise positive values of the input (no-op). */
2071
+ declare const positive: (x: ArrayLike) => Array;
1466
2072
  /**
1467
- * Frontend primitive operations, which are lowered into Kernel objects before
1468
- * being dispatched to the backend.
1469
- *
1470
- * Any operation between arrays can be described in these parts. This is also
1471
- * the set of primitives that can occur in Jaxpr programs, and the level at
1472
- * which transformations like vmap, grad, and jvp occur. They are loosely based
1473
- * on [XLA](https://openxla.org/xla/operation_semantics).
1474
- *
1475
- * All n-ary operations support broadcasting, with NumPy semantics.
1476
- */
1477
- declare enum Primitive {
1478
- Add = "add",
1479
- Mul = "mul",
1480
- Idiv = "idiv",
1481
- Mod = "mod",
1482
- // uses sign of numerator, C-style, matches JS but not Python
1483
- Min = "min",
1484
- Max = "max",
1485
- Neg = "neg",
1486
- Reciprocal = "reciprocal",
1487
- Floor = "floor",
1488
- Ceil = "ceil",
1489
- StopGradient = "stop_gradient",
1490
- Cast = "cast",
1491
- Bitcast = "bitcast",
1492
- Sin = "sin",
1493
- Cos = "cos",
1494
- Asin = "asin",
1495
- Atan = "atan",
1496
- Exp = "exp",
1497
- Log = "log",
1498
- Erf = "erf",
1499
- Erfc = "erfc",
1500
- Sqrt = "sqrt",
1501
- Reduce = "reduce",
1502
- Dot = "dot",
1503
- // sum(x*y, axis=-1)
1504
- Conv = "conv",
1505
- // see lax.conv_general_dilated
1506
- Pool = "pool",
1507
- PoolTranspose = "pool_transpose",
1508
- Compare = "compare",
1509
- Where = "where",
1510
- Concatenate = "concatenate",
1511
- Split = "split",
1512
- RandomBits = "random_bits",
1513
- Gather = "gather",
1514
- Transpose = "transpose",
1515
- Broadcast = "broadcast",
1516
- Reshape = "reshape",
1517
- Flip = "flip",
1518
- Shrink = "shrink",
1519
- Pad = "pad",
1520
- Sort = "sort",
1521
- // sort(x, axis=-1)
1522
- Argsort = "argsort",
1523
- // argsort(x, axis=-1)
1524
- TriangularSolve = "triangular_solve",
1525
- // A is upper triangular, A @ X.T = B.T
1526
- Cholesky = "cholesky",
1527
- // A is positive-definite, A = L @ L^T
1528
- LU = "lu",
1529
- // LU decomposition with partial pivoting
1530
- Jit = "jit",
1531
- }
1532
- interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
1533
- [Primitive.Cast]: {
1534
- dtype: DType;
1535
- };
1536
- [Primitive.Bitcast]: {
1537
- dtype: DType;
1538
- };
1539
- [Primitive.Reduce]: {
1540
- op: AluOp;
1541
- axis: number[];
1542
- };
1543
- [Primitive.Conv]: ConvParams;
1544
- [Primitive.Pool]: {
1545
- window: number[];
1546
- strides: number[];
1547
- };
1548
- [Primitive.PoolTranspose]: {
1549
- inShape: number[];
1550
- window: number[];
1551
- strides: number[];
1552
- };
1553
- [Primitive.Compare]: {
1554
- op: CompareOp;
1555
- };
1556
- [Primitive.Concatenate]: {
1557
- axis: number;
1558
- };
1559
- [Primitive.Split]: {
1560
- axis: number;
1561
- sizes: number[];
1562
- };
1563
- [Primitive.RandomBits]: {
1564
- shape: number[];
1565
- mode: "xor" | 0 | 1;
1566
- };
1567
- [Primitive.Gather]: {
1568
- axis: number[];
1569
- outDim: number;
1570
- };
1571
- [Primitive.Transpose]: {
1572
- perm: number[];
1573
- };
1574
- [Primitive.Broadcast]: {
1575
- shape: number[];
1576
- axis: number[];
1577
- };
1578
- [Primitive.Reshape]: {
1579
- shape: number[];
1580
- };
1581
- [Primitive.Flip]: {
1582
- axis: number[];
1583
- };
1584
- [Primitive.Shrink]: {
1585
- slice: Pair[];
1586
- };
1587
- [Primitive.Pad]: {
1588
- width: Pair[];
1589
- };
1590
- [Primitive.Jit]: {
1591
- name: string;
1592
- jaxpr: Jaxpr;
1593
- numConsts: number;
1594
- };
1595
- [Primitive.TriangularSolve]: {
1596
- unitDiagonal: boolean;
1597
- };
1598
- }
1599
- /** Type of parameters taken by each primitive. */
1600
- type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
1601
- declare enum CompareOp {
1602
- Less = "less",
1603
- Equal = "equal",
1604
- NotEqual = "not_equal",
1605
- LessEqual = "less_equal",
1606
- }
1607
- /** @inline */
1608
- type Axis = number | number[] | null;
1609
- /** @inline */
1610
- type ReduceOpts = {
1611
- keepdims?: boolean;
1612
- };
1613
- type MainTrace = {
1614
- level: number;
1615
- traceType: new (main: MainTrace) => Trace;
1616
- globalData: any | null;
1617
- };
2073
+ * Return the Hamming window of size M, a taper with a weighted cosine bell.
2074
+ *
2075
+ * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
2076
+ */
2077
+ declare function hamming(M: number): Array;
1618
2078
  /**
1619
- * Push an interpreter onto the trace stack. Use this like:
1620
- * `using main = newMain(...);`
2079
+ * Return the Hann window of size M, a taper with a weighted cosine bell.
2080
+ *
2081
+ * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1621
2082
  */
1622
-
1623
- type TracerValue = Tracer | number | boolean;
1624
- declare abstract class Trace {
1625
- readonly main: MainTrace;
1626
- constructor(main: MainTrace);
1627
- abstract pure(val: TracerValue): Tracer;
1628
- abstract lift(val: Tracer): Tracer;
1629
- abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
1630
- }
1631
- /** Internal representation of an array value. */
1632
- interface AbstractValue {
1633
- /** Shape of the array. Must be a static tuple of non-negative dimensions. */
1634
- shape: number[];
1635
- /** Concrete data type of array elements. */
1636
- dtype: DType;
1637
- /**
1638
- * Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
1639
- * _weakly typed_ unless a dtype is explicitly specified.
1640
- *
1641
- * Weakly typed values will automatically cast to the data type of other
1642
- * arrays when used as an operand as an expression. This property only affects
1643
- * how they promote in type casting; their memory layout is still determined
1644
- * by the actual `dtype` field.
1645
- *
1646
- * ```ts
1647
- * const x = np.array(3); // weakType = true, dtype = float32
1648
- * const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
1649
- * const z = x.add(y); // z has dtype int32 because x is weakly typed
1650
- * ```
1651
- *
1652
- * Weak types are present in JIT programs in their spec (e.g., Jaxpr inputs
1653
- * and outputs can be weakly typed) form. But they're solely a frontend
1654
- * concept. Backends are not aware of weak types.
1655
- */
1656
- weakType: boolean;
1657
- }
2083
+ declare function hann(M: number): Array;
1658
2084
  /**
1659
- * Broadcast shapes and promote types with casting for two avals.
2085
+ * @function
2086
+ * Compute the Heaviside step function. It is defined piecewise:
2087
+ * - `heaviside(x1, x2) = 0` for `x1 < 0`,
2088
+ * - `heaviside(x1, x2) = x2` for `x1 == 0`,
2089
+ * - `heaviside(x1, x2) = 1` for `x1 > 0`.
2090
+ */
2091
+ declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
2092
+ /** Calculate element-wise square of the input array. */
2093
+ declare function square(x: ArrayLike): Array;
2094
+ /** Element-wise tangent function (takes radians). */
2095
+ declare function tan(x: ArrayLike): Array;
2096
+ /**
2097
+ * @function
2098
+ * Return the normalized sinc function.
1660
2099
  *
1661
- * This implements the weak type behavior described in `promoteTypes()`, but not
1662
- * implemented in that function as `weakType` is not passed.
2100
+ * The sinc function is defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`.
2101
+ * This is the normalized sinc function commonly used in signal processing.
2102
+ *
2103
+ * **Note:** JVP is not supported at x=0 due to discontinuous derivative. This
2104
+ * requires a custom JVP rule to handle properly (see JAX implementation).
1663
2105
  */
1664
-
1665
- declare abstract class Tracer {
1666
- /** @ignore */
1667
- readonly _trace: Trace;
1668
- constructor(trace: Trace);
1669
- abstract get aval(): AbstractValue;
1670
- abstract toString(): string;
1671
- /**
1672
- * Access an array by reference, incrementing the reference count.
1673
- *
1674
- * jax-js handles freeing arrays by using "move" semantics, like in Rust/C++.
1675
- * Whenever you pass an array into a function, that function should consume
1676
- * the array, and it will no longer be usable. For example, if you had:
1677
- *
1678
- * ```
1679
- * const x = np.array([1, 2, 3]);
1680
- * const y = np.add(x, x);
1681
- * ```
1682
- *
1683
- * The second line does not work because the first parameter consumes `x`, and
1684
- * then the second parameter will already have been freed / disposed.
1685
- *
1686
- * To fix this, you can write:
1687
- *
1688
- * ```
1689
- * const y = np.add(x.ref, x);
1690
- * ```
1691
- *
1692
- * Under the hood, every access to `.ref` increments the internal reference
1693
- * count of the array. The reference count starts at 1. When it hits 0, the
1694
- * memory behind the array is freed.
1695
- */
1696
- abstract get ref(): this;
1697
- /**
1698
- * Manually decrement the reference count of the array.
1699
- *
1700
- * Arrays are created with reference count 1. Whenever it is used as argument
1701
- * to a function or other operation, it is disposed (i.e., reference count
1702
- * decreases by 1) automatically. Whenever a `.ref` is created, the reference
1703
- * count increases.
1704
- *
1705
- * You generally don't need to call this function directly since arrays are
1706
- * automatically disposed after being passed into an operation. One common
1707
- * exception is when writing a function and ignoring one of its arguments. In
1708
- * that case, by convention you should dispose of that argument manually.
1709
- *
1710
- * ```
1711
- * function myCustomOperation(a: np.Array, b: np.Array) {
1712
- * b.dispose(); // Needed to satisfy "move" rules.
1713
- * return a.add(1);
1714
- * }
1715
- * ```
1716
- */
1717
- abstract dispose(): void;
1718
- /** The shape of the array. */
1719
- get shape(): number[];
1720
- /** The total number of elements in the array. */
1721
- get size(): number;
1722
- /** The dtype of elements stored in the array. */
1723
- get dtype(): DType;
1724
- /**
1725
- * Whether the array is weakly typed.
1726
- *
1727
- * Weakly typed arrays will cast to the dtype of the other operand. See
1728
- * `promoteTypes()` for details.
1729
- */
1730
- get weakType(): boolean;
1731
- /** The number of dimensions of the array. */
1732
- get ndim(): number;
1733
- /** @ignore */
1734
- fullLower(): Tracer;
1735
- neg(): this;
1736
- add(other: this | TracerValue): this;
1737
- mul(other: this | TracerValue): this;
1738
- mod(other: this | TracerValue): this;
1739
- greater(other: this | TracerValue): this;
1740
- less(other: this | TracerValue): this;
1741
- equal(other: this | TracerValue): this;
1742
- notEqual(other: this | TracerValue): this;
1743
- greaterEqual(other: this | TracerValue): this;
1744
- lessEqual(other: this | TracerValue): this;
1745
- /** Sum of the elements of the array over a given axis, or axes. */
1746
- sum(axis?: Axis, opts?: ReduceOpts): this;
1747
- /** Product of the array elements over a given axis. */
1748
- prod(axis?: Axis, opts?: ReduceOpts): this;
1749
- /** Compute the average of the array elements along the specified axis. */
1750
- mean(axis?: Axis, opts?: ReduceOpts): this;
1751
- /** Permute the dimensions of an array. Defaults to reversing the axis order. */
1752
- transpose(perm?: number[]): this;
1753
- /**
1754
- * Give a new shape to an array without changing its data.
1755
- *
1756
- * One shape dimension can be -1. In this case, the value is inferred from the
1757
- * length of the array and remaining dimensions.
1758
- */
1759
- reshape(shape: number | number[]): this;
1760
- /** Copy the array and cast to a specified dtype. */
1761
- astype(dtype: DType): this;
1762
- /** Subtract an array from this one. */
1763
- sub(other: this | TracerValue): this;
1764
- /** Divide an array by this one. */
1765
- div(other: this | TracerValue): this;
1766
- /** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
1767
- diagonal(offset?: number, axis1?: number, axis2?: number): this;
1768
- /** Flatten the array without changing its data. */
1769
- flatten(): this;
1770
- /** Flatten the array without changing its data. */
1771
- ravel(): this;
1772
- /**
1773
- * Iterate over the first dimension of this array, returning slices.
1774
- *
1775
- * This can be used to destructure arrays. For example:
1776
- *
1777
- * ```js
1778
- * let x = np.array([[1, 2], [3, 4]]);
1779
- * let [a, b] = x;
1780
- * console.log(a.js()); // [1, 2]
1781
- * console.log(b.js()); // [3, 4]
1782
- * ```
1783
- */
1784
- [Symbol.iterator](): IterableIterator<this>;
1785
- /**
1786
- * Return a sorted copy of an array in ascending order.
1787
- *
1788
- * See `jax.numpy.sort` for full docs.
1789
- */
1790
- sort(axis?: number): this;
1791
- /**
1792
- * Return the indices that would sort an array. This may not be a stable
1793
- * sorting algorithm; it need not preserve order of indices in ties.
1794
- *
1795
- * See `jax.numpy.argsort` for full docs.
1796
- */
1797
- argsort(axis?: number): this;
1798
- /**
1799
- * Slice an array along one or more axes.
1800
- *
1801
- * This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
1802
- * mimic this in JavaScript, we would write:
1803
- *
1804
- * ```js
1805
- * x.slice([1, 3], 2, [], null);
1806
- * ```
1807
- *
1808
- * The `slice` method accepts a variable number of arguments, each of which
1809
- * can be a number, an empty array, a single-element array, a two-element
1810
- * array, or `null`. The arguments are interpreted as follows:
1811
- *
1812
- * - A number `n` means to access the `n`-th element along that axis, removing
1813
- * that axis from the resulting shape.
1814
- * - An empty array `[]` means to keep that axis as-is, like `:` in Python.
1815
- * - A single-element array `[i]` means to start slicing from index `i`
1816
- * (inclusive) to the end of the axis, like `x[i:]`.
1817
- * - A two-element array `[i, j]` means to slice from index `i` (inclusive)
1818
- * to index `j` (exclusive), like `x[i:j]`.
1819
- * - `null` means to add a new axis at that position, like `np.newaxis`.
1820
- *
1821
- * Like in Python, negative indices are supported, which count from the end of
1822
- * the axis. For example, `-1` means the last element.
1823
- *
1824
- * Strided slices are not yet implemented, so you cannot write `x[::2]` or
1825
- * similar.
1826
- *
1827
- * Advanced indexing by integer arrays is also supported. This translates to
1828
- * the "gather" primitive, and it allows you to access specific elements of
1829
- * the array by integer indices stored in another array.
1830
- */
1831
- slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
1832
- }
1833
- declare class ShapedArray implements AbstractValue {
1834
- readonly shape: number[];
1835
- readonly dtype: DType;
1836
- readonly weakType: boolean;
1837
- constructor(shape: number[], dtype: DType, weakType: boolean);
1838
- static fromAval(aval: AbstractValue): ShapedArray;
1839
- get ndim(): number;
1840
- get size(): number;
1841
- scalar(): ShapedArray;
1842
- toString(): string;
1843
- equals(other: ShapedArray): boolean;
1844
- }
1845
- //#endregion
1846
- //#region src/frontend/array.d.ts
1847
- type ArrayLike = Array | number | boolean;
1848
- /** Version of pureArray with fudged types. */
1849
-
2106
+ declare const sinc: OwnedFunction<(x: ArrayLike) => Array>;
2107
+ /** Element-wise inverse cosine function (inverse of cos). */
2108
+ declare function acos(x: ArrayLike): Array;
1850
2109
  /**
1851
- * An executable operation that will be dispatched to the backend.
2110
+ * @function
2111
+ * Return element-wise hypotenuse for the given legs of a right triangle.
1852
2112
  *
1853
- * This holds a reference to all input buffers used in the operation. After the
1854
- * operation is dispatched, the references should be released.
2113
+ * In the original NumPy/JAX implementation, this function is more numerically
2114
+ * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
2115
+ * stability improvements.
1855
2116
  */
1856
- declare class PendingExecute {
1857
- #private;
1858
- readonly backend: Backend;
1859
- readonly source: Kernel | Routine;
1860
- readonly inputs: Slot[];
1861
- readonly outputs: Slot[];
1862
- prepared: Executable | null;
1863
- submitted: boolean;
1864
- constructor(backend: Backend, source: Kernel | Routine, inputs: Slot[], outputs: Slot[]);
1865
- updateRc(delta: number): void;
1866
- prepare(): Promise<void>;
1867
- prepareSync(): void;
1868
- submit(): void;
1869
- }
1870
- /** @inline */
1871
- type DTypeAndDevice = {
1872
- dtype?: DType;
1873
- device?: Device;
1874
- };
1875
- type ArrayConstructorArgs = {
1876
- source: AluExp | Slot;
1877
- st: ShapeTracker;
1878
- dtype: DType;
1879
- weakType: boolean;
1880
- backend: Backend;
1881
- committed: boolean;
1882
- pending?: Iterable<PendingExecute>;
1883
- };
2117
+ declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1884
2118
  /**
1885
- * A multidimensional numeric array with data stored on CPU or GPU.
2119
+ * @function
2120
+ * Element-wise arc tangent of y/x with correct quadrant.
1886
2121
  *
1887
- * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
1888
- * `torch.Tensor`.
2122
+ * Returns the angle in radians between the positive x-axis and the point (x, y).
2123
+ * The result is in the range [-π, π].
1889
2124
  *
1890
- * Not to be confused with the JavaScript "Array" constructor. Avoid importing
1891
- * this into your code's namespace if you're already using the JavaScript
1892
- * "Array" type by name.
2125
+ * Uses numerically stable formulas:
2126
+ * - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
2127
+ * - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
2128
+ *
2129
+ * The output is ill-defined when both x and y are zero.
1893
2130
  */
1894
- declare class Array extends Tracer {
1895
- #private;
1896
- /**
1897
- * @ignore
1898
- * Constructs an array from source, shape and backend. Note that if the source
1899
- * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
1900
- * will be freed when the array is disposed.
1901
- */
1902
- constructor(args: ArrayConstructorArgs);
1903
- /** @ignore */
1904
- get aval(): ShapedArray;
1905
- /** Return a simple string representation of the array's dimensions. */
1906
- toString(): string;
1907
- get device(): Device;
1908
- get ref(): this;
1909
- /** Get the current reference count (for debugging memory management). */
1910
- get refCount(): number;
1911
- dispose(): void;
1912
- /**
1913
- * Convert this array into a primitive value.
1914
- *
1915
- * This only works for scalars (0-dimensional arrays). It lets you get values
1916
- * "out" of the JAX system. For instance, if `x = np.array(5)`, then you can
1917
- * evaluate `x + 1` and `x ** 2` to get `6` and `25`, respectively.
1918
- *
1919
- * This method is also called for `==` equality.
1920
- */
1921
- [Symbol.toPrimitive](): any;
1922
- /** Realize the array and return it as data. */
1923
- data(): Promise<DataArray>;
1924
- /**
1925
- * Wait for this array to finish evaluation.
1926
- *
1927
- * Operations and data loading in jax-js are lazy, so this function ensures
1928
- * that pending operations are dispatched and fully executed before it
1929
- * returns.
1930
- *
1931
- * If you are mapping from `data()` or `dataSync()`, it will also trigger
1932
- * dispatch of operations as well.
1933
- *
1934
- * **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
1935
- * asynchronously for multiple arrays.
1936
- */
1937
- blockUntilReady(): Promise<Array>;
1938
- /**
1939
- * Realize the array and return it as data. This is a sync variant and not
1940
- * recommended for performance reasons, as it will block rendering.
1941
- */
1942
- dataSync(): DataArray;
1943
- /**
1944
- * Convert this array into a JavaScript object.
1945
- *
1946
- * This is a blocking operation that will compile all of the shaders and wait
1947
- * for execution to complete, synchronously. No other JavaScript code on the
1948
- * site will be run during shader execution.
1949
- *
1950
- * To avoid blocking, prefer `jsAsync()` when possible.
1951
- */
1952
- js(): any;
1953
- /** Convert this array into a JavaScript object, asynchronously. */
1954
- jsAsync(): Promise<any>;
1955
- /**
1956
- * Copy an element of an array to a numeric scalar and return it.
1957
- *
1958
- * Throws an error if the array does not have a single element. The array must
1959
- * either be rank-0, or all dimensions of the shape are 1.
1960
- */
1961
- item(): number;
1962
- /** @private Internal plumbing method for Array / Tracer ops. */
1963
- static _implRules(): typeof implRules;
1964
- /** @private */
1965
- _realizeSource(): number;
1966
- /** @private Put this array on a new backend, asynchronously. */
1967
- _put(backend: Backend): Promise<Array>;
1968
- /** @private Put this array on a new backend, synchronously. */
1969
- _putSync(backend: Backend): Array;
1970
- }
1971
- /** Constructor for creating a new array from data. */
1972
- declare function array(values: Array | DataArray | RecursiveArray<number> | RecursiveArray<boolean>, {
1973
- shape,
1974
- dtype,
1975
- device
1976
- }?: {
1977
- shape?: number[];
1978
- } & DTypeAndDevice): Array;
1979
- /** If x is a value, lift it into an array, otherwise leave it be. */
1980
-
1981
- type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
1982
- declare const implRules: { [P in Primitive]: ImplRule<P> };
1983
- /** Return a new array of given shape and type, filled with zeros. */
1984
- declare function zeros(shape: number[], {
1985
- dtype,
1986
- device
1987
- }?: DTypeAndDevice): Array;
1988
- /** Return a new array of given shape and type, filled with ones. */
1989
- declare function ones(shape: number[], {
1990
- dtype,
1991
- device
1992
- }?: DTypeAndDevice): Array;
1993
- /** Return a new array of given shape and type, filled with `fill_value`. */
1994
- declare function full(shape: number[], fillValue: number | boolean | Array, {
1995
- dtype,
1996
- device
1997
- }?: DTypeAndDevice): Array;
2131
+ declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
2132
+ /** Element-wise subtraction, with broadcasting. */
2133
+ declare function subtract(x: ArrayLike, y: ArrayLike): Array;
2134
+ /** Calculates the floating-point division of x by y element-wise. */
2135
+ declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
1998
2136
  /**
1999
- * Create an identity matrix.
2137
+ * Return the largest integer smaller or equal to the division of the inputs.
2000
2138
  *
2001
- * If numCols is not provided, it defaults to numRows, i.e., a square identity
2002
- * matrix with ones on the diagonal.
2139
+ * The result is always rounded towards negative infinity.
2140
+ *
2141
+ * For floating-point inputs, this is equivalent to `floor(x / y)`.
2142
+ * For integer inputs, we use `(x - remainder(x, y)) / y` to handle
2143
+ * negative values correctly (note: may overflow near int32 boundaries).
2144
+ *
2145
+ * @param x - Dividend array.
2146
+ * @param y - Divisor array.
2147
+ * @returns Element-wise floor division of x by y.
2003
2148
  */
2004
- declare function eye(numRows: number, numCols?: number, {
2005
- dtype,
2006
- device
2007
- }?: DTypeAndDevice): Array;
2008
- /** Return the identity matrix, with ones on the main diagonal. */
2009
- declare function identity$1(n: number, {
2010
- dtype,
2011
- device
2012
- }?: DTypeAndDevice): Array;
2149
+ declare function floorDivide(x: ArrayLike, y: ArrayLike): Array;
2013
2150
  /**
2014
- * Return evenly spaced values within a given interval.
2151
+ * @function
2152
+ * Calculate element-wise floating-point modulo operation.
2153
+ */
2154
+ declare const fmod: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
2155
+ /**
2156
+ * @function
2157
+ * Calculate element-wise remainder of the division (matches sign of y).
2158
+ */
2159
+ declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
2160
+ /**
2161
+ * Return element-wise quotient and remainder simultaneously.
2015
2162
  *
2016
- * This can be called with a varying number of arguments, just like the range()
2017
- * builtin function in Python.
2163
+ * Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
2018
2164
  *
2019
- * - `arange(stop)` is equivalent to `arange(0, stop, 1)`.
2020
- * - `arange(start, stop)` is equivalent to `arange(start, stop, 1)`.
2021
- * - `arange(start, stop, step)` creates an array starting at `start`, ending
2022
- * before `stop`, with a step size of `step`.
2165
+ * @param x - Dividend array.
2166
+ * @param y - Divisor array.
2167
+ * @returns Tuple of [quotient, remainder].
2168
+ */
2169
+ declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
2170
+ /** Round input to the nearest integer towards zero. */
2171
+ declare function trunc(x: ArrayLike): Array;
2172
+ /**
2173
+ * Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
2023
2174
  *
2024
- * Defaults to an integer data type. This can produce unintended results when
2025
- * using a non-integer step, so prefer linspace() in those cases.
2175
+ * This is the inverse of `frexp()`.
2026
2176
  */
2027
- declare function arange(start: number, stop?: number, step?: number, {
2028
- dtype,
2029
- device
2030
- }?: DTypeAndDevice): Array;
2177
+ declare function ldexp(x1: ArrayLike, x2: ArrayLike): Array;
2031
2178
  /**
2032
- * Return an array with ones on and below the diagonal and zeros elsewhere.
2179
+ * Decompose floating-point values into mantissa and two's exponent.
2033
2180
  *
2034
- * If `k` is provided, it specifies the sub-diagonal on and below which the
2035
- * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
2036
- * `k>0` is above it.
2181
+ * The mantissa is returned in the range `(-1, 1)` with magnitude `>= 0.5` if
2182
+ * `x != 0`, and the exponent is an integer such that
2183
+ * `x = mantissa * 2**exponent`.
2037
2184
  */
2038
- declare function tri(n: number, m?: number, k?: number, {
2039
- dtype,
2040
- device
2041
- }?: DTypeAndDevice): Array;
2042
- /** Return the lower triangle of an array. Must be of dimension >= 2. */
2043
- declare function tril(a: ArrayLike, k?: number): Array;
2044
- /** Return the upper triangle of an array. Must be of dimension >= 2. */
2045
- declare function triu(a: ArrayLike, k?: number): Array;
2185
+ declare function frexp(x: ArrayLike): [Array, Array];
2186
+ /** Calculate `2**p` for all p in the input array. */
2187
+ declare function exp2(p: ArrayLike): Array;
2188
+ /** Return the base-2 logarithm of x, element-wise. */
2189
+ declare function log2(x: ArrayLike): Array;
2190
+ /** Return the base-10 logarithm of x, element-wise. */
2191
+ declare function log10(x: ArrayLike): Array;
2192
+ /** Calculate `exp(x) - 1` element-wise. */
2193
+ declare function expm1(x: ArrayLike): Array;
2194
+ /** Calculate the natural logarithm of `1 + x` element-wise. */
2195
+ declare function log1p(x: ArrayLike): Array;
2196
+ /** Convert angles from degrees to radians. */
2197
+ declare function deg2rad(x: ArrayLike): Array;
2198
+ /** @function Alias of `jax.numpy.deg2rad()`. */
2199
+ declare const radians: typeof deg2rad;
2200
+ /** Convert angles from radians to degrees. */
2201
+ declare function rad2deg(x: ArrayLike): Array;
2202
+ /** @function Alias of `jax.numpy.rad2deg()`. */
2203
+ declare const degrees: typeof rad2deg;
2046
2204
  /**
2047
- * Return evenly spaced numbers over a specified interval.
2048
- *
2049
- * Returns _num_ evenly spaced samples, calculated over the interval
2050
- * [`start`, `stop`]. The endpoint `stop` is included in the result by default,
2051
- * but this is controlled by the `endpoint` parameter.
2052
- *
2053
- * The default data type is Float32. Use arange() for integer steps.
2205
+ * @function
2206
+ * Computes first array raised to power of second array, element-wise.
2054
2207
  */
2055
- declare function linspace(start: number, stop: number, num?: number, endpoint?: boolean, {
2056
- dtype,
2057
- device
2058
- }?: DTypeAndDevice): Array;
2208
+ declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
2209
+ /** @function Calculate the element-wise cube root of the input array. */
2210
+ declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
2059
2211
  /**
2060
- * Return numbers spaced evenly on a log scale.
2061
- *
2062
- * In linear space, the sequence starts at `base ** start` and ends at
2063
- * `base ** stop` (see `endpoint` below).
2212
+ * @function
2213
+ * Calculate element-wise hyperbolic sine of input.
2064
2214
  *
2065
- * @param start - `base ** start` is the starting value of the sequence.
2066
- * @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
2067
- * @param num - Number of samples to generate. Default is 50.
2068
- * @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
2069
- * @param base - The base of the log space. Default is 10.
2070
- * @returns Array of evenly spaced values on a log scale.
2215
+ * `sinh(x) = (exp(x) - exp(-x)) / 2`
2071
2216
  */
2072
- declare function logspace(start: number, stop: number, num?: number, endpoint?: boolean, base?: number, {
2073
- dtype,
2074
- device
2075
- }?: DTypeAndDevice): Array;
2076
- declare namespace lax_linalg_d_exports {
2077
- export { cholesky, lu, triangularSolve };
2078
- }
2217
+ declare const sinh: OwnedFunction<(x: ArrayLike) => Array>;
2079
2218
  /**
2080
- * Compute the Cholesky decomposition of a symmetric positive-definite matrix.
2081
- *
2082
- * The Cholesky decomposition of a matrix `A` is:
2083
- *
2084
- * - A = L @ L^T (for upper=false, default)
2085
- * - A = U^T @ U (for upper=true)
2086
- *
2087
- * where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
2088
- * The input matrix must be symmetric and positive-definite.
2089
- *
2090
- * @example
2091
- * ```ts
2092
- * import { lax, numpy as np } from "@jax-js/jax";
2093
- *
2094
- * const x = np.array([[2., 1.], [1., 2.]]);
2095
- *
2096
- * // Lower Cholesky factorization (default):
2097
- * const L = lax.linalg.cholesky(x);
2098
- * // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
2219
+ * @function
2220
+ * Calculate element-wise hyperbolic cosine of input.
2099
2221
  *
2100
- * // Upper Cholesky factorization:
2101
- * const U = lax.linalg.cholesky(x, { upper: true });
2102
- * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
2103
- * ```
2222
+ * `cosh(x) = (exp(x) + exp(-x)) / 2`
2104
2223
  */
2105
- declare function cholesky(a: ArrayLike, {
2106
- upper
2107
- }?: {
2108
- upper?: boolean;
2109
- }): Array;
2224
+ declare const cosh: OwnedFunction<(x: ArrayLike) => Array>;
2110
2225
  /**
2111
- * LU decomposition with partial pivoting.
2112
- *
2113
- * Computes the matrix decomposition: `P @ A = L @ U`, where `P` is a
2114
- * permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
2115
- * and `U` is upper-triangular.
2116
- *
2117
- * @param x - A batch of matrices with shape `[..., m, n]`.
2118
- *
2119
- * @returns A tuple `(lu, pivots, permutation)` where:
2120
- * - `lu`: combined lower and upper triangular matrices.
2121
- * - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
2122
- * - `permutation`: the permutation generated by pivots with shape `[..., m]`.
2123
- *
2124
- * @example
2125
- * ```ts
2126
- * import { lax, numpy as np } from "@jax-js/jax";
2226
+ * @function
2227
+ * Calculate element-wise hyperbolic tangent of input.
2127
2228
  *
2128
- * const A = np.array([[4., 3.], [6., 3.]]);
2129
- * const [lu, pivots, permutation] = lax.linalg.lu(A);
2130
- * // lu ≈ [[6., 3.], [0.6666667, 1.0]]
2131
- * // pivots = [1, 1]
2132
- * // permutation = [1, 0]
2133
- * ```
2229
+ * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
2134
2230
  */
2135
- declare function lu(x: ArrayLike): [Array, Array, Array];
2231
+ declare const tanh: OwnedFunction<(x: ArrayLike) => Array>;
2136
2232
  /**
2137
- * Solve a triangular linear system.
2138
- *
2139
- * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
2140
- * where `a` is a triangular matrix.
2141
- *
2142
- * @example
2143
- * ```ts
2144
- * import { lax, numpy as np } from "@jax-js/jax";
2145
- *
2146
- * const L = np.array([[2., 0.], [1., 3.]]);
2147
- * const b = np.array([4., 7.]).reshape([2, 1]);
2233
+ * @function
2234
+ * Calculate element-wise inverse hyperbolic sine of input.
2148
2235
  *
2149
- * // Solve L @ x = b
2150
- * const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
2151
- * // x = [[2.], [5./3.]]
2152
- * ```
2236
+ * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
2153
2237
  */
2154
- declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
2155
- leftSide,
2156
- lower,
2157
- transposeA,
2158
- unitDiagonal
2159
- }?: {
2160
- leftSide?: boolean;
2161
- lower?: boolean;
2162
- transposeA?: boolean;
2163
- unitDiagonal?: boolean;
2164
- }): Array;
2165
- declare namespace lax_d_exports {
2166
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
2167
- }
2238
+ declare const arcsinh: OwnedFunction<(x: ArrayLike) => Array>;
2168
2239
  /**
2169
- * Dimension numbers for general `dot()` primitive.
2240
+ * @function
2241
+ * Calculate element-wise inverse hyperbolic cosine of input.
2170
2242
  *
2171
- * Contracting dimensions act as a tensor contraction (reduction) along the
2172
- * given axis. They must be the same size in both operands. Batch dimensions
2173
- * are treated as vectorized, leading batch dimensions.
2243
+ * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
2244
+ */
2245
+ declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
2246
+ /**
2247
+ * @function
2248
+ * Calculate element-wise inverse hyperbolic tangent of input.
2174
2249
  *
2175
- * The return value has a shape where the first dimensions are shared batch
2176
- * dimensions, followed by `lhs` non-contracting dimensions, followed by
2177
- * `rhs` non-contracting dimensions.
2250
+ * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
2178
2251
  */
2179
- type DotDimensionNumbers = {
2180
- lhsContractingDims?: number[];
2181
- rhsContractingDims?: number[];
2182
- lhsBatchDims?: number[];
2183
- rhsBatchDims?: number[];
2184
- };
2252
+ declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
2185
2253
  /**
2186
- * General dot product/contraction operator.
2254
+ * Compute the variance of an array.
2187
2255
  *
2188
- * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
2189
- * `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
2256
+ * The variance is computed for the flattened array by default, otherwise over
2257
+ * the specified axis.
2258
+ *
2259
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
2260
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
2190
2261
  */
2191
- declare function dot(lhs: Array, rhs: Array, {
2192
- lhsContractingDims: lc,
2193
- rhsContractingDims: rc,
2194
- lhsBatchDims: lb,
2195
- rhsBatchDims: rb
2196
- }?: DotDimensionNumbers): Array;
2197
- type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
2262
+ declare function var_(x: ArrayLike, axis?: Axis, opts?: {
2263
+ mean?: ArrayLike;
2264
+ correction?: number;
2265
+ } & ReduceOpts): Array;
2198
2266
  /**
2199
- * General n-dimensional convolution operator, with optional dilation.
2267
+ * Compute the standard deviation of an array.
2200
2268
  *
2201
- * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
2202
- * function in JAX, which wraps XLA's general convolution operator.
2269
+ * The standard deviation is computed for the flattened array by default,
2270
+ * otherwise over the specified axis.
2203
2271
  *
2204
- * Grouped convolutions are not supported right now.
2272
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
2273
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
2205
2274
  */
2206
- declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
2207
- lhsDilation,
2208
- rhsDilation,
2209
- featureGroupCount
2275
+ declare function std(x: ArrayLike, axis?: Axis, opts?: {
2276
+ mean?: ArrayLike;
2277
+ correction?: number;
2278
+ } & ReduceOpts): Array;
2279
+ /** Estimate the sample covariance of a set of variables. */
2280
+ declare function cov(x: ArrayLike, y?: ArrayLike | null, {
2281
+ rowvar
2210
2282
  }?: {
2211
- lhsDilation?: number[];
2212
- rhsDilation?: number[];
2213
- featureGroupCount?: number;
2283
+ rowvar?: boolean;
2214
2284
  }): Array;
2215
- /** Convenience wrapper around `convGeneralDilated`. */
2216
- declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
2217
- /** Convenience wrapper around `convGeneralDilated`. */
2218
- declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
2219
- /** Reduce a computation over padded windows. */
2220
- declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
2221
- /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
2222
- declare function erf(x: ArrayLike): Array;
2285
+ /** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
2286
+ declare function corrcoef(x: ArrayLike, y?: ArrayLike): Array;
2287
+ /** Test element-wise for positive or negative infinity, return bool array. */
2288
+ declare function isinf(x: ArrayLike): Array;
2289
+ /** Test element-wise for NaN (Not a Number). */
2290
+ declare function isnan(x: ArrayLike): Array;
2291
+ /** Test element-wise for negative infinity, return bool array. */
2292
+ declare function isneginf(x: ArrayLike): Array;
2293
+ /** Test element-wise for positive infinity, return bool array. */
2294
+ declare function isposinf(x: ArrayLike): Array;
2223
2295
  /**
2224
- * The complementary error function: `erfc(x) = 1 - erf(x)`.
2296
+ * Replace NaN and infinite entries in an array.
2225
2297
  *
2226
- * This function is more accurate than `1 - erf(x)` for large values of `x`,
2227
- * where `erf(x)` is very close to 1.
2298
+ * By default, NaNs are replaced with `0.0`, and infinities are are substituted
2299
+ * with the corresponding maximum or minimum finite values.
2228
2300
  */
2229
- declare function erfc(x: ArrayLike): Array;
2301
+ declare function nanToNum(x: ArrayLike, {
2302
+ nan,
2303
+ posinf,
2304
+ neginf
2305
+ }?: {
2306
+ nan?: ArrayLike;
2307
+ posinf?: ArrayLike | null;
2308
+ neginf?: ArrayLike | null;
2309
+ }): Array;
2230
2310
  /**
2231
- * Stops gradient computation.
2232
- *
2233
- * Behaves as the identity function but prevents the flow of gradients during
2234
- * forward or reverse-mode automatic differentiation.
2311
+ * @function
2312
+ * Test element-wise for finite values (not infinity or NaN).
2235
2313
  */
2236
- declare function stopGradient(x: ArrayLike): Array;
2314
+ declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
2237
2315
  declare namespace nn_d_exports {
2238
- export { celu, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
2316
+ export { celu, dotProductAttention, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
2239
2317
  }
2240
2318
  /**
2241
2319
  * Rectified Linear Unit (ReLU) activation function:
@@ -2432,6 +2510,56 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
2432
2510
  * ```
2433
2511
  */
2434
2512
  declare function oneHot(x: Array, numClasses: number): Array;
2513
+ /**
2514
+ * Scaled dot product attention (SDPA).
2515
+ *
2516
+ * Computes `softmax((Q @ K^T) / sqrt(d) + bias) @ V`, where `Q` is the query,
2517
+ * `K` is the key, `V` is the value, and `d` is the dimensionality of each key
2518
+ * and query vector.
2519
+ *
2520
+ * Multi-query attention is applied when input `key` and `value` tensors have
2521
+ * fewer heads than `query`.
2522
+ *
2523
+ * We use the following uppercase letters to denote array shapes:
2524
+ * - `B` = batch size
2525
+ * - `S` = length of key/value sequences (source)
2526
+ * - `L` = length of query sequences
2527
+ * - `N` = number of attention heads
2528
+ * - `H` = dimensionality of each attention head
2529
+ * - `K` = number of key/value heads (for grouped-query attention)
2530
+ *
2531
+ * The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
2532
+ * case it must be omitted from all inputs.
2533
+ *
2534
+ * @param query - Query array; shape `[B, L, N, H]`
2535
+ * @param key - Key array; shape `[B, S, K, H]`
2536
+ * @param value - Value array; same shape as `key`
2537
+ * @param opts.bias - Optional bias to add to the attention logits; shape
2538
+ * `[B, N, L, S]` or broadcastable to it.
2539
+ * @param opts.mask - Optional mask to apply to the attention logits; should be
2540
+ * a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
2541
+ * the element should take part in attention.
2542
+ * @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
2543
+ * @param opts.isCausal - If true, applies a casual mask.
2544
+ * @param opts.querySeqLengths - Optional sequence lengths for the queries;
2545
+ * shape `(B,)`. Taken from the beginning of the tensor.
2546
+ * @param opts.keyValueSeqLengths - Optional sequence lengths for the keys and
2547
+ * values; shape `(B,)`. Taken from the beginning of the tensor.
2548
+ * @param opts.localWindowSize - If specified, applies a local attention window
2549
+ * of the given size. Can be a single number or a tuple `[left, right]`.
2550
+ *
2551
+ * @returns The result of the attention operation; shape is the same as query
2552
+ * `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
2553
+ */
2554
+ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: ArrayLike, opts?: {
2555
+ bias?: ArrayLike;
2556
+ mask?: ArrayLike;
2557
+ scale?: number;
2558
+ isCausal?: boolean;
2559
+ querySeqLengths?: ArrayLike;
2560
+ keyValueSeqLengths?: ArrayLike;
2561
+ localWindowSize?: number | [number, number];
2562
+ }): Array;
2435
2563
  declare namespace random_d_exports {
2436
2564
  export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2437
2565
  }
@@ -2523,7 +2651,9 @@ declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
2523
2651
  * @function
2524
2652
  * Compute the forward-mode Jacobian-vector product for a function.
2525
2653
  */
2526
- declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
2654
+ declare const jvp: <F extends (...args: any[]) => JsTree<Array>, HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
2655
+ hasAux?: HA;
2656
+ }) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, Out, Aux] : never : [ReturnType<F>, ReturnType<F>];
2527
2657
  /**
2528
2658
  * @function
2529
2659
  * Vectorize an operation on a batched axis for one or more inputs.
@@ -2565,28 +2695,100 @@ declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: Ji
2565
2695
  * Produce a local linear approximation to a function at a point using jvp() and
2566
2696
  * partial evaluation.
2567
2697
  */
2568
- declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
2698
+ declare const linearize: <F extends (...args: any[]) => JsTree<Array>, HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
2699
+ hasAux?: HA;
2700
+ }) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, OwnedFunction<(...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => Out>, Aux] : never : [ReturnType<F>, OwnedFunction<(...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>>];
2569
2701
  /**
2570
2702
  * @function
2571
2703
  * Calculate the reverse-mode vector-Jacobian product for a function.
2704
+ *
2705
+ * The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
2706
+ * `f(primals)`, and `vjpFn` is a function that takes in cotangents for each
2707
+ * output and returns the cotangents for each input.
2708
+ *
2709
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
2710
+ * `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
2711
+ *
2712
+ * @example
2713
+ * ```ts
2714
+ * const [y, vjpFn] = vjp(f, [x]);
2715
+ *
2716
+ * // With hasAux
2717
+ * const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
2718
+ * ```
2572
2719
  */
2573
- declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
2720
+ declare const vjp: <F extends (...args: any[]) => JsTree<Array>, const HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
2721
+ hasAux?: HA;
2722
+ }) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, OwnedFunction<(cotangents: MapJsTree<Out, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>>, Aux] : never : [ReturnType<F>, OwnedFunction<(cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>>];
2723
+ /** @inline */
2724
+ type GradOutputType<I, F extends (...args: any[]) => any> = MapJsTree<I extends undefined ? Parameters<F>[0] : I extends number ? Parameters<F>[I] : I extends number[] ? { [K in keyof I]: I[K] extends number ? Parameters<F>[I[K]] : never } : never, ArrayLike, Array>;
2574
2725
  /**
2575
2726
  * @function
2576
2727
  * Compute the gradient of a scalar-valued function `f` with respect to its
2577
2728
  * first argument.
2729
+ *
2730
+ * Pass in different `argnums` to differentiate with respect to other
2731
+ * arguments. If a tuple is provided, the return value will be a tuple of
2732
+ * gradients corresponding to each argument index.
2733
+ *
2734
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return a
2735
+ * `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
2736
+ *
2737
+ * @example
2738
+ * ```ts
2739
+ * const gradient = grad(f)(x);
2740
+ *
2741
+ * // With `argnums`
2742
+ * const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
2743
+ *
2744
+ * // With `hasAux`
2745
+ * const [gradient, aux] = grad(f, { hasAux: true })(x);
2746
+ * ```
2578
2747
  */
2579
- declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
2748
+ declare const grad: <F extends (...args: any[]) => JsTree<Array>, const I extends undefined | number | number[] = undefined, const HA extends boolean = false>(f: F, opts?: Omit<GradOpts, "argnums" | "hasAux"> & {
2749
+ argnums?: I;
2750
+ hasAux?: HA;
2751
+ }) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => HA extends true ? ReturnType<F> extends [any, infer Aux] ? [GradOutputType<I, F>, Aux] : never : GradOutputType<I, F>;
2580
2752
  /**
2581
2753
  * @function
2582
2754
  * Create a function that evaluates both `f` and the gradient of `f`.
2755
+ *
2756
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
2757
+ * `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
2758
+ *
2759
+ * @example
2760
+ * ```ts
2761
+ * // Without hasAux
2762
+ * const [value, gradient] = valueAndGrad(f)(x);
2763
+ *
2764
+ * // With hasAux
2765
+ * const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
2766
+ * ```
2583
2767
  */
2584
- declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
2768
+ declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>, const I extends undefined | number | number[] = undefined, const HA extends boolean = false>(f: F, opts?: Omit<GradOpts, "argnums"> & {
2769
+ argnums?: I;
2770
+ hasAux?: HA;
2771
+ }) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, GradOutputType<I, F>];
2585
2772
  /**
2586
2773
  * @function
2587
2774
  * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
2588
2775
  */
2589
2776
  declare const jacrev: typeof jacfwd;
2777
+ /**
2778
+ * @function
2779
+ * Compute the Hessian matrix of a scalar-valued function.
2780
+ *
2781
+ * The Hessian is the matrix of second-order partial derivatives of a function.
2782
+ * This is implemented as `jacfwd(grad(f))`.
2783
+ *
2784
+ * @example
2785
+ * ```ts
2786
+ * const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // x^3
2787
+ * const H = hessian(f)(np.array([1, 2, 3]));
2788
+ * // H[i,j] = d^2f / dx_i dx_j
2789
+ * ```
2790
+ */
2791
+ declare const hessian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
2590
2792
  /**
2591
2793
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
2592
2794
  *
@@ -2609,4 +2811,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
2609
2811
  */
2610
2812
  declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
2611
2813
  //#endregion
2612
- export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
2814
+ export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };