@jax-js/jax 0.1.6 → 0.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.cts CHANGED
@@ -541,1757 +541,1780 @@ declare class Executable<T = any> {
541
541
  source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
542
542
  data: T);
543
543
  }
544
- declare namespace numpy_fft_d_exports {
545
- export { ComplexPair, fft, ifft };
544
+ declare namespace tree_d_exports {
545
+ export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
546
546
  }
547
- /**
548
- * A pair of arrays representing real and imaginary part `a + bj`. Both arrays
549
- * must have the same shape.
550
- */
551
- type ComplexPair = {
552
- real: Array;
553
- imag: Array;
547
+ declare enum NodeType {
548
+ Array = "Array",
549
+ Object = "Object",
550
+ Leaf = "Leaf",
551
+ }
552
+ /** Analog to the JAX "pytree" object, but for JavaScript. */
553
+ type JsTree<T> = T | JsTree<T>[] | {
554
+ [key: string]: JsTree<T>;
554
555
  };
555
- /**
556
- * Compute a one-dimensional discrete Fourier transform.
557
- *
558
- * Currently, the size of the axis must be a power of two.
559
- */
560
- declare function fft(a: ComplexPair, axis?: number): ComplexPair;
561
- /**
562
- * Compute a one-dimensional inverse discrete Fourier transform.
563
- *
564
- * Currently, the size of the axis must be a power of two.
565
- */
566
- declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
567
- declare namespace numpy_linalg_d_exports {
568
- export { cholesky$1 as cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
556
+ type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
557
+ type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
558
+ /** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
559
+ type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
560
+ /** Represents the structure of a JsTree. */
561
+ declare class JsTreeDef {
562
+ readonly nodeType: NodeType;
563
+ readonly nodeMetadata: any;
564
+ readonly childTreedefs: JsTreeDef[];
565
+ static leaf: JsTreeDef;
566
+ constructor(nodeType: NodeType, nodeMetadata: any,
567
+ // Must be comparable with deepEqual.
568
+ childTreedefs: JsTreeDef[]);
569
+ /** Get the total number of leaves in the tree. */
570
+ get size(): number;
571
+ /** Returns a string representation of this tree definition. */
572
+ toString(root?: boolean): string;
573
+ /** Compare this tree definition with another. */
574
+ equals(other: JsTreeDef): boolean;
575
+ }
576
+ /** Flatten a structured object, returning the tree definition. */
577
+ declare function flatten<T>(tree: JsTree<T>): [T[], JsTreeDef];
578
+ /** Get the leaves of a tree. */
579
+ declare function leaves<T>(tree: JsTree<T>): T[];
580
+ /** Get the treedef for a tree. */
581
+ declare function structure<T>(tree: JsTree<T>): JsTreeDef;
582
+ /** Reconstruct a structured object from the flattened representation. */
583
+ declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T>;
584
+ /** Maps a multi-input function over pytree args to produce a new pytree. */
585
+ declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
586
+ /** Take a reference of every array in a tree. */
587
+ declare function ref<Tree extends JsTree<any>>(tree: Tree): Tree;
588
+ /** Dispose every array in a tree. */
589
+ declare function dispose<Tree extends JsTree<any>>(tree: Tree | null | undefined): void;
590
+ //#endregion
591
+ //#region src/frontend/convolution.d.ts
592
+ /** Definition of a general dilated convolution. Should be valid on creation. */
593
+ interface ConvParams {
594
+ vmapDims: number;
595
+ strides: number[];
596
+ padding: Pair[];
597
+ lhsDilation: number[];
598
+ rhsDilation: number[];
569
599
  }
570
600
  /**
571
- * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
601
+ * Check that the shapes and parameters passed to convolution are valid.
602
+ * Expected shapes of the lhs and rhs of the convolution are:
572
603
  *
573
- * This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
574
- * the input matrix, which is on by default.
604
+ * - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
605
+ * - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
606
+ *
607
+ * If the check succeeds, returns the output shape.
575
608
  */
576
- declare function cholesky$1(a: ArrayLike, {
577
- upper,
578
- symmetrizeInput
579
- }?: {
580
- upper?: boolean;
581
- symmetrizeInput?: boolean;
582
- }): Array;
583
- /** Compute the determinant of a square matrix (batched). */
584
- declare function det(a: ArrayLike): Array;
585
- /** Compute the inverse of a square matrix (batched). */
586
- declare function inv(a: ArrayLike): Array;
609
+ //#endregion
610
+ //#region src/frontend/jaxpr.d.ts
587
611
  /**
588
- * Return the least-squares solution to a linear equation.
589
- *
590
- * For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
591
- * For underdetermined systems, this finds the minimum-norm solution for `x`.
592
- *
593
- * This currently uses Cholesky decomposition to solve the normal equations,
594
- * under the hood. The method is not as robust as QR or SVD.
612
+ * Function callback with an associated dispose() method.
595
613
  *
596
- * @param a coefficient matrix of shape `(M, N)`
597
- * @param b right-hand side of shape `(M,)` or `(M, K)`
598
- * @return least-squares solution of shape `(N,)` or `(N, K)`
614
+ * The dispose() method should be called to clean up any tracer resources needed
615
+ * by the function after the last time it is called.
599
616
  */
600
- declare function lstsq(a: ArrayLike, b: ArrayLike): Array;
601
- /** Raise a square matrix to an integer power, via repeated squarings. */
602
- declare function matrixPower(a: ArrayLike, n: number): Array;
603
- /** Return sign and natural logarithm of the determinant of `a`. */
604
- declare function slogdet(a: ArrayLike): [Array, Array];
617
+ type OwnedFunction<F extends Function> = F & {
618
+ dispose: () => void;
619
+ };
620
+ /** Variable in a Jaxpr expression. */
621
+ declare class Var {
622
+ #private;
623
+ readonly id: number;
624
+ readonly aval: ShapedArray;
625
+ constructor(aval: ShapedArray);
626
+ toString(): string;
627
+ }
628
+ /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
629
+ declare class Lit {
630
+ readonly value: number;
631
+ readonly aval: ShapedArray;
632
+ get dtype(): DType;
633
+ constructor(aval: AbstractValue, value: number);
634
+ }
635
+ type Atom = Var | Lit;
636
+ declare class VarPrinter {
637
+ #private;
638
+ names: Map<Var, string>;
639
+ name(v: Var): string;
640
+ nameType(v: Var): string;
641
+ }
642
+ /** A single statement / binding in a Jaxpr, in ANF form. */
643
+ declare class JaxprEqn {
644
+ readonly primitive: Primitive;
645
+ readonly inputs: Atom[];
646
+ readonly params: Record<string, any>;
647
+ readonly outBinders: Var[];
648
+ constructor(primitive: Primitive, inputs: Atom[], params: Record<string, any>, outBinders: Var[]);
649
+ pprint(usedVars?: Set<Var>, vp?: VarPrinter): PPrint;
650
+ toString(): string;
651
+ }
652
+ /** Typed intermediate representation for traced computations. */
653
+ declare class Jaxpr implements FpHashable {
654
+ #private;
655
+ readonly inBinders: Var[];
656
+ readonly eqns: JaxprEqn[];
657
+ readonly outs: Atom[];
658
+ constructor(inBinders: Var[], eqns: JaxprEqn[], outs: Atom[]);
659
+ pprint(): PPrint;
660
+ toString(): string;
661
+ /**
662
+ * Gets a hash of this Jaxpr.
663
+ *
664
+ * Var identity is not considered in the hash, so two Jaxprs with the same
665
+ * order of assignments and operators but different variable IDs will resolve
666
+ * to the same hash (and toString representation).
667
+ */
668
+ getHash(): bigint;
669
+ hash(state: FpHash): void;
670
+ /**
671
+ * Produce a simplified Jaxpr with basic optimizations applied.
672
+ * - Trim away unused variables.
673
+ * - Fold away *1, *0, or +0 operations against literals.
674
+ * - Remove no-op movement operations.
675
+ */
676
+ simplify(): Jaxpr;
677
+ /** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
678
+ flatten(): Jaxpr;
679
+ }
680
+ /** Jaxpr with a collection of associated, traced constants. */
681
+ declare class ClosedJaxpr {
682
+ readonly jaxpr: Jaxpr;
683
+ readonly consts: Tracer[];
684
+ constructor(jaxpr: Jaxpr, consts: Tracer[]);
685
+ /** String representation of this Jaxpr. */
686
+ toString(): string;
687
+ /** Apply a function to the underlying Jaxpr. */
688
+ mapJaxpr(f: (jaxpr: Jaxpr) => Jaxpr): ClosedJaxpr;
689
+ /** Dispose of the constants in this Jaxpr. */
690
+ dispose(): void;
691
+ }
692
+ /** @inline */
693
+ type JitOpts = {
694
+ staticArgnums?: number[];
695
+ };
696
+ //#endregion
697
+ //#region src/frontend/core.d.ts
605
698
  /**
606
- * Solve a linear system of equations.
699
+ * Frontend primitive operations, which are lowered into Kernel objects before
700
+ * being dispatched to the backend.
607
701
  *
608
- * This solves a (batched) linear system of equations `a @ x = b` for `x` given
609
- * `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
702
+ * Any operation between arrays can be described in these parts. This is also
703
+ * the set of primitives that can occur in Jaxpr programs, and the level at
704
+ * which transformations like vmap, grad, and jvp occur. They are loosely based
705
+ * on [XLA](https://openxla.org/xla/operation_semantics).
610
706
  *
611
- * @param a - Coefficient matrix of shape `(..., N, N)`.
612
- * @param b - Values of shape `(N,)` or `(..., N, M)`.
613
- * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
707
+ * All n-ary operations support broadcasting, with NumPy semantics.
614
708
  */
615
- declare function solve(a: ArrayLike, b: ArrayLike): Array;
616
- //#endregion
617
- //#region src/library/numpy/dtype-info.d.ts
618
- /** @inline */
619
- type FInfo = Readonly<{
620
- /** The number of bits occupied by the type. */
621
- bits: number;
622
- /** Returns the _dtype_ for which finfo returns information. */
623
- dtype: DType;
624
- /** The difference between 1.0 and the next smallest representable float larger than 1.0. */
625
- eps: number;
626
- /** The difference between 1.0 and the next largest representable float smaller than 1.0. */
627
- epsneg: number;
628
- /** The exponent that yields `eps`. */
629
- machep: number;
630
- /** The largest representable finite number. */
631
- max: number;
632
- /** The smallest positive power of the base (2) that causes overflow. */
633
- maxexp: number;
634
- /** The smallest representable (most negative) finite number. */
635
- min: number;
636
- /** The largest negative power of the base (2) without leading zeros in mantissa. */
637
- minexp: number;
638
- /** The exponent that yields `epsneg`. */
639
- negep: number;
640
- /** Number of bits in the exponent portion. */
641
- nexp: number;
642
- /** Number of bits in the mantissa portion. */
643
- nmant: number;
644
- /** The approximate number of decimal digits to which this kind of float is precise. */
645
- precision: number;
646
- /** The approximate decimal resolution, i.e., `10 ** -precision`. */
647
- resolution: number;
648
- /** The smallest positive normal number. */
649
- smallestNormal: number;
650
- /** The smallest positive subnormal number. */
651
- smallestSubnormal: number;
652
- }>;
653
- /** Machine limits for floating-point types. */
654
- declare function finfo(dtype: DType): FInfo;
655
- /** @inline */
656
- type IInfo = Readonly<{
657
- /** The number of bits occupied by the type. */
658
- bits: number;
659
- /** Returns the _dtype_ for which iinfo returns information. */
660
- dtype: DType;
661
- /** The largest representable integer. */
662
- max: number;
663
- /** The smallest representable integer. */
664
- min: number;
665
- }>;
666
- /** Machine limits for integer types. */
667
- declare function iinfo(dtype: DType): IInfo;
668
- declare namespace numpy_d_exports {
669
- export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
709
+ declare enum Primitive {
710
+ Add = "add",
711
+ Mul = "mul",
712
+ Idiv = "idiv",
713
+ Mod = "mod",
714
+ // uses sign of numerator, C-style, matches JS but not Python
715
+ Min = "min",
716
+ Max = "max",
717
+ Neg = "neg",
718
+ Reciprocal = "reciprocal",
719
+ Floor = "floor",
720
+ Ceil = "ceil",
721
+ StopGradient = "stop_gradient",
722
+ Cast = "cast",
723
+ Bitcast = "bitcast",
724
+ Sin = "sin",
725
+ Cos = "cos",
726
+ Asin = "asin",
727
+ Atan = "atan",
728
+ Exp = "exp",
729
+ Log = "log",
730
+ Erf = "erf",
731
+ Erfc = "erfc",
732
+ Sqrt = "sqrt",
733
+ Reduce = "reduce",
734
+ Dot = "dot",
735
+ // sum(x*y, axis=-1)
736
+ Conv = "conv",
737
+ // see lax.conv_general_dilated
738
+ Pool = "pool",
739
+ PoolTranspose = "pool_transpose",
740
+ Compare = "compare",
741
+ Where = "where",
742
+ Concatenate = "concatenate",
743
+ Split = "split",
744
+ RandomBits = "random_bits",
745
+ Gather = "gather",
746
+ Transpose = "transpose",
747
+ Broadcast = "broadcast",
748
+ Reshape = "reshape",
749
+ Flip = "flip",
750
+ Shrink = "shrink",
751
+ Pad = "pad",
752
+ Sort = "sort",
753
+ // sort(x, axis=-1)
754
+ Argsort = "argsort",
755
+ // argsort(x, axis=-1)
756
+ TriangularSolve = "triangular_solve",
757
+ // A is upper triangular, A @ X.T = B.T
758
+ Cholesky = "cholesky",
759
+ // A is positive-definite, A = L @ L^T
760
+ LU = "lu",
761
+ // LU decomposition with partial pivoting
762
+ Jit = "jit",
670
763
  }
671
- declare const float32 = DType.Float32;
672
- declare const int32 = DType.Int32;
673
- declare const uint32 = DType.Uint32;
674
- declare const bool = DType.Bool;
675
- declare const float16 = DType.Float16;
676
- declare const float64 = DType.Float64;
677
- /** Euler's constant, `e = 2.7182818284590...` */
678
- declare const e: number;
679
- /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
680
- declare const eulerGamma = 0.5772156649015329;
681
- /** Positive infinity. */
682
- declare const inf: number;
683
- /** Floating-point representation of NaN. */
684
- declare const nan: number;
685
- /** This is Pi, `π = 3.14159265358979...` */
686
- declare const pi: number;
687
- /** @function Element-wise addition, with broadcasting. */
688
- declare const add: (x: ArrayLike, y: ArrayLike) => Array;
689
- /** @function Element-wise multiplication, with broadcasting. */
690
- declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
691
- /** @function Numerical negative of every element of an array. */
692
- declare const negative: (x: ArrayLike) => Array;
693
- /** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
694
- declare const reciprocal: (x: ArrayLike) => Array;
695
- /** @function Round input down to the nearest integer. */
696
- declare const floor: (x: ArrayLike) => Array;
697
- /** @function Round input up to the nearest integer. */
698
- declare const ceil: (x: ArrayLike) => Array;
699
- /** @function Element-wise sine function (takes radians). */
700
- declare const sin: (x: ArrayLike) => Array;
701
- /** @function Element-wise cosine function (takes radians). */
702
- declare const cos: (x: ArrayLike) => Array;
703
- /** @function Element-wise inverse sine function (inverse of sin). */
704
- declare const asin: (x: ArrayLike) => Array;
705
- /** @function Element-wise inverse tangent function (inverse of tan). */
706
- declare const atan: (x: ArrayLike) => Array;
707
- /** @function Calculate the exponential of all elements in the input array. */
708
- declare const exp: (x: ArrayLike) => Array;
709
- /** @function Calculate the natural logarithm of all elements in the input array. */
710
- declare const log: (x: ArrayLike) => Array;
711
- /** @function Calculate the square root of all elements in the input array. */
712
- declare const sqrt: (x: ArrayLike) => Array;
713
- /** @function Return element-wise minimum of the input arrays. */
714
- declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
715
- /** @function Return element-wise maximum of the input arrays. */
716
- declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
717
- /** @function Compare two arrays element-wise. */
718
- declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
719
- /** @function Compare two arrays element-wise. */
720
- declare const less: (x: ArrayLike, y: ArrayLike) => Array;
721
- /** @function Compare two arrays element-wise. */
722
- declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
723
- /** @function Compare two arrays element-wise. */
724
- declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
725
- /** @function Compare two arrays element-wise. */
726
- declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
727
- /** @function Compare two arrays element-wise. */
728
- declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
729
- /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
730
- declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
764
+ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
765
+ [Primitive.Cast]: {
766
+ dtype: DType;
767
+ };
768
+ [Primitive.Bitcast]: {
769
+ dtype: DType;
770
+ };
771
+ [Primitive.Reduce]: {
772
+ op: AluOp;
773
+ axis: number[];
774
+ };
775
+ [Primitive.Conv]: ConvParams;
776
+ [Primitive.Pool]: {
777
+ window: number[];
778
+ strides: number[];
779
+ };
780
+ [Primitive.PoolTranspose]: {
781
+ inShape: number[];
782
+ window: number[];
783
+ strides: number[];
784
+ };
785
+ [Primitive.Compare]: {
786
+ op: CompareOp;
787
+ };
788
+ [Primitive.Concatenate]: {
789
+ axis: number;
790
+ };
791
+ [Primitive.Split]: {
792
+ axis: number;
793
+ sizes: number[];
794
+ };
795
+ [Primitive.RandomBits]: {
796
+ shape: number[];
797
+ mode: "xor" | 0 | 1;
798
+ };
799
+ [Primitive.Gather]: {
800
+ axis: number[];
801
+ outDim: number;
802
+ };
803
+ [Primitive.Transpose]: {
804
+ perm: number[];
805
+ };
806
+ [Primitive.Broadcast]: {
807
+ shape: number[];
808
+ axis: number[];
809
+ };
810
+ [Primitive.Reshape]: {
811
+ shape: number[];
812
+ };
813
+ [Primitive.Flip]: {
814
+ axis: number[];
815
+ };
816
+ [Primitive.Shrink]: {
817
+ slice: Pair[];
818
+ };
819
+ [Primitive.Pad]: {
820
+ width: Pair[];
821
+ };
822
+ [Primitive.TriangularSolve]: {
823
+ unitDiagonal: boolean;
824
+ };
825
+ [Primitive.Jit]: {
826
+ name: string;
827
+ jaxpr: Jaxpr;
828
+ numConsts: number;
829
+ };
830
+ }
831
+ /** Type of parameters taken by each primitive. */
832
+ type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
833
+ declare enum CompareOp {
834
+ Less = "less",
835
+ Equal = "equal",
836
+ NotEqual = "not_equal",
837
+ LessEqual = "less_equal",
838
+ }
839
+ /** @inline */
840
+ type Axis = number | number[] | null;
841
+ /** @inline */
842
+ type ReduceOpts = {
843
+ keepdims?: boolean;
844
+ };
845
+ type MainTrace = {
846
+ level: number;
847
+ traceType: new (main: MainTrace) => Trace;
848
+ globalData: any | null;
849
+ };
731
850
  /**
732
- * @function
733
- * Permute the dimensions of an array. Defaults to reversing the axis order.
851
+ * Push an interpreter onto the trace stack. Use this like:
852
+ * `using main = newMain(...);`
734
853
  */
735
- declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
854
+
855
+ type TracerValue = Tracer | number | boolean;
856
+ declare abstract class Trace {
857
+ readonly main: MainTrace;
858
+ constructor(main: MainTrace);
859
+ abstract pure(val: TracerValue): Tracer;
860
+ abstract lift(val: Tracer): Tracer;
861
+ abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
862
+ }
863
+ /** Internal representation of an array value. */
864
+ interface AbstractValue {
865
+ /** Shape of the array. Must be a static tuple of non-negative dimensions. */
866
+ shape: number[];
867
+ /** Concrete data type of array elements. */
868
+ dtype: DType;
869
+ /**
870
+ * Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
871
+ * _weakly typed_ unless a dtype is explicitly specified.
872
+ *
873
+ * Weakly typed values will automatically cast to the data type of other
874
+ * arrays when used as an operand as an expression. This property only affects
875
+ * how they promote in type casting; their memory layout is still determined
876
+ * by the actual `dtype` field.
877
+ *
878
+ * ```ts
879
+ * const x = np.array(3); // weakType = true, dtype = float32
880
+ * const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
881
+ * const z = x.add(y); // z has dtype int32 because x is weakly typed
882
+ * ```
883
+ *
884
+ * Weak types are present in JIT programs in their spec (e.g., Jaxpr inputs
885
+ * and outputs can be weakly typed) form. But they're solely a frontend
886
+ * concept. Backends are not aware of weak types.
887
+ */
888
+ weakType: boolean;
889
+ }
736
890
  /**
737
- * @function
738
- * Give a new shape to an array without changing its data.
891
+ * Broadcast shapes and promote types with casting for two avals.
739
892
  *
740
- * One shape dimension can be -1. In this case, the value is inferred from the
741
- * length of the array and remaining dimensions.
893
+ * This implements the weak type behavior described in `promoteTypes()`, but not
894
+ * implemented in that function as `weakType` is not passed.
742
895
  */
743
- declare const reshape: (x: ArrayLike, shape: number[]) => Array;
744
- /**
745
- * @function
746
- * Move axes of an array to new positions. Other axes retain original order.
747
- */
748
- declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
749
- /**
750
- * @function
751
- * Add padding (zeros) to an array.
752
- *
753
- * The `width` argument is either an integer or pair of integers, in which case
754
- * all axes are padded with the same width. Or if it is an array of pairs, each
755
- * pair specifies the padding for its corresponding axis.
756
- */
757
- declare const pad: (x: ArrayLike, width: number | Pair | Pair[]) => Array;
758
- /**
759
- * @function
760
- * Return the number of dimensions of an array. Does not consume array reference.
761
- */
762
- declare const ndim: (x: ArrayLike) => number;
763
- /** @function Return the shape of an array. Does not consume array reference. */
764
- declare const shape$1: (x: ArrayLike) => number[];
765
- /**
766
- * @function
767
- * Return an array of zeros with the same shape and type as a given array.
768
- */
769
- declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
770
- /**
771
- * @function
772
- * Return an array of ones with the same shape and type as a given array.
773
- */
774
- declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
775
- /**
776
- * @function
777
- * Return a full array with the same shape and type as a given array.
778
- */
779
- declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
780
- /**
781
- * Return the number of elements in an array, optionally along an axis.
782
- * Does not consume array reference.
783
- */
784
- declare function size(a: ArrayLike, axis?: number): number;
785
- /** Convert an array to a specified dtype. */
786
- declare function astype(a: ArrayLike, dtype: DType): Array;
787
- /** Sum of the elements of the array over a given axis, or axes. */
788
- declare function sum(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
789
- /** Product of the array elements over a given axis. */
790
- declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
791
- /** Return the minimum of array elements along a given axis. */
792
- declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
793
- /** Return the maximum of array elements along a given axis. */
794
- declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
795
- /**
796
- * Test whether all array elements along a given axis evaluate to True.
797
- *
798
- * Returns a boolean array with the same shape as `a` with the specified axis
799
- * removed. If axis is None, returns a scalar.
800
- */
801
- declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
802
- /**
803
- * Test whether any array element along a given axis evaluates to True.
804
- *
805
- * Returns a boolean array with the same shape as `a` with the specified axis
806
- * removed. If axis is None, returns a scalar.
807
- */
808
- declare function any(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
809
- /** Return the peak-to-peak range along a given axis (`max - min`). */
810
- declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
811
- /** Compute the average of the array elements along the specified axis. */
812
- declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
813
- /**
814
- * Returns the indices of the minimum values along an axis.
815
- *
816
- * By default, index is into the flatted array, otherwise it is along the
817
- * specified axis.
818
- */
819
- declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
820
- /**
821
- * Returns the indices of the maximum values along an axis.
822
- *
823
- * By default, index is into the flatted array, otherwise it is along the
824
- * specified axis.
825
- */
826
- declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
827
- /**
828
- * Cumulative sum of elements along an axis.
829
- *
830
- * Currently this function is `O(n^2)`, we'll improve this later on with a
831
- * two-phase parallel reduction algorithm.
832
- */
833
- declare function cumsum(a: ArrayLike, axis?: number): Array;
834
- /** Reverse the elements in an array along the given axes. */
835
- declare function flip(x: ArrayLike, axis?: Axis): Array;
836
- /**
837
- * Split an array into multiple sub-arrays along an axis.
838
- *
839
- * @param a - The input array to split.
840
- * @param indicesOrSections - If an integer, it indicates the number of equal
841
- * sections to create along the specified axis. If a list of integers, it
842
- * specifies the indices at which to split the array.
843
- * @param axis - The axis along which to split the array. Default is 0.
844
- */
845
- declare function split$1(a: ArrayLike, indicesOrSections: number | number[], axis?: number): Array[];
846
- /**
847
- * Join a sequence of arrays along an existing axis.
848
- *
849
- * The arrays must have the same shape, except in the dimension corresponding to
850
- * `axis` (the first, by default).
851
- *
852
- * No scalars can be passed to this function, as the axis is then ambiguous.
853
- */
854
- declare function concatenate(xs: Array[], axis?: number): Array;
855
- /**
856
- * Join a sequence of arrays along a new axis.
857
- *
858
- * The `axis` parameter specifies the index of the new axis in the dimensions of
859
- * the result. For example, if `axis=0` it will be the first dimension and if
860
- * `axis=-1` it will be the last dimension.
861
- *
862
- * All shapes must have the same shape.
863
- */
864
- declare function stack(xs: ArrayLike[], axis?: number): Array;
865
- /**
866
- * Horizontally stack arrays. Inputs are promoted to rank at least 1, then
867
- * concatenated along axis 1 (if rank-2 or higher) or 0 (if rank-1).
868
- */
869
- declare function hstack(xs: ArrayLike[]): Array;
870
- /**
871
- * Vertically stack arrays. Inputs are promoted to rank at least 2, then
872
- * concatenated along axis 0.
873
- */
874
- declare function vstack(xs: ArrayLike[]): Array;
875
- /**
876
- * Stack arrays depth-wise. Inputs are promoted to rank at least 3, then
877
- * concatenated along axis 2.
878
- */
879
- declare function dstack(xs: ArrayLike[]): Array;
880
- /**
881
- * Stack arrays column-wise. Inputs are promoted to rank at least 2, then
882
- * concatenated along axis 1.
883
- */
884
- declare function columnStack(xs: ArrayLike[]): Array;
885
- /** Flip an array vertically (axis=0). */
886
- declare function flipud(x: ArrayLike): Array;
887
- /** Flip an array horizontally (axis=1). */
888
- declare function fliplr(x: ArrayLike): Array;
889
- /** Interchange two axes of an array. */
890
- declare function swapaxes(a: ArrayLike, axis1: number, axis2: number): Array;
891
- /** Transpose the last two dimensions of an array. */
892
- declare function matrixTranspose(a: ArrayLike): Array;
893
- /** Return a 1-D flattened array containing the elements of the input. */
894
- declare function ravel(a: ArrayLike): Array;
895
- /** Remove one or more length-1 axes from an array. */
896
- declare function squeeze(a: ArrayLike, axis?: Axis): Array;
896
+
897
+ declare abstract class Tracer {
898
+ /** @ignore */
899
+ readonly _trace: Trace;
900
+ constructor(trace: Trace);
901
+ abstract get aval(): AbstractValue;
902
+ abstract toString(): string;
903
+ /**
904
+ * Access an array by reference, incrementing the reference count.
905
+ *
906
+ * jax-js handles freeing arrays by using "move" semantics, like in Rust/C++.
907
+ * Whenever you pass an array into a function, that function should consume
908
+ * the array, and it will no longer be usable. For example, if you had:
909
+ *
910
+ * ```
911
+ * const x = np.array([1, 2, 3]);
912
+ * const y = np.add(x, x);
913
+ * ```
914
+ *
915
+ * The second line does not work because the first parameter consumes `x`, and
916
+ * then the second parameter will already have been freed / disposed.
917
+ *
918
+ * To fix this, you can write:
919
+ *
920
+ * ```
921
+ * const y = np.add(x.ref, x);
922
+ * ```
923
+ *
924
+ * Under the hood, every access to `.ref` increments the internal reference
925
+ * count of the array. The reference count starts at 1. When it hits 0, the
926
+ * memory behind the array is freed.
927
+ */
928
+ abstract get ref(): this;
929
+ /**
930
+ * Manually decrement the reference count of the array.
931
+ *
932
+ * Arrays are created with reference count 1. Whenever it is used as argument
933
+ * to a function or other operation, it is disposed (i.e., reference count
934
+ * decreases by 1) automatically. Whenever a `.ref` is created, the reference
935
+ * count increases.
936
+ *
937
+ * You generally don't need to call this function directly since arrays are
938
+ * automatically disposed after being passed into an operation. One common
939
+ * exception is when writing a function and ignoring one of its arguments. In
940
+ * that case, by convention you should dispose of that argument manually.
941
+ *
942
+ * ```
943
+ * function myCustomOperation(a: np.Array, b: np.Array) {
944
+ * b.dispose(); // Needed to satisfy "move" rules.
945
+ * return a.add(1);
946
+ * }
947
+ * ```
948
+ */
949
+ abstract dispose(): void;
950
+ /** The shape of the array. */
951
+ get shape(): number[];
952
+ /** The total number of elements in the array. */
953
+ get size(): number;
954
+ /** The dtype of elements stored in the array. */
955
+ get dtype(): DType;
956
+ /**
957
+ * Whether the array is weakly typed.
958
+ *
959
+ * Weakly typed arrays will cast to the dtype of the other operand. See
960
+ * `promoteTypes()` for details.
961
+ */
962
+ get weakType(): boolean;
963
+ /** The number of dimensions of the array. */
964
+ get ndim(): number;
965
+ /** @ignore */
966
+ fullLower(): Tracer;
967
+ neg(): this;
968
+ add(other: this | TracerValue): this;
969
+ mul(other: this | TracerValue): this;
970
+ mod(other: this | TracerValue): this;
971
+ greater(other: this | TracerValue): this;
972
+ less(other: this | TracerValue): this;
973
+ equal(other: this | TracerValue): this;
974
+ notEqual(other: this | TracerValue): this;
975
+ greaterEqual(other: this | TracerValue): this;
976
+ lessEqual(other: this | TracerValue): this;
977
+ /** Sum of the elements of the array over a given axis, or axes. */
978
+ sum(axis?: Axis, opts?: ReduceOpts): this;
979
+ /** Product of the array elements over a given axis. */
980
+ prod(axis?: Axis, opts?: ReduceOpts): this;
981
+ /** Compute the average of the array elements along the specified axis. */
982
+ mean(axis?: Axis, opts?: ReduceOpts): this;
983
+ /** Minimum of the elements of the array along a given axis. */
984
+ min(axis?: Axis, opts?: ReduceOpts): this;
985
+ /** Maximum of the elements of the array along a given axis. */
986
+ max(axis?: Axis, opts?: ReduceOpts): this;
987
+ /** Test whether all array elements along a given axis evaluate to true. */
988
+ all(axis?: Axis, opts?: ReduceOpts): this;
989
+ /** Test whether any array element along a given axis evaluates to true. */
990
+ any(axis?: Axis, opts?: ReduceOpts): this;
991
+ /** Permute the dimensions of an array. Defaults to reversing the axis order. */
992
+ transpose(perm?: number[]): this;
993
+ /**
994
+ * Give a new shape to an array without changing its data.
995
+ *
996
+ * One shape dimension can be -1. In this case, the value is inferred from the
997
+ * length of the array and remaining dimensions.
998
+ */
999
+ reshape(shape: number | number[]): this;
1000
+ /** Copy the array and cast to a specified dtype. */
1001
+ astype(dtype: DType): this;
1002
+ /** Subtract an array from this one. */
1003
+ sub(other: this | TracerValue): this;
1004
+ /** Divide an array by this one. */
1005
+ div(other: this | TracerValue): this;
1006
+ /** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
1007
+ diagonal(offset?: number, axis1?: number, axis2?: number): this;
1008
+ /** Flatten the array without changing its data. */
1009
+ flatten(): this;
1010
+ /** Flatten the array without changing its data. */
1011
+ ravel(): this;
1012
+ /**
1013
+ * Iterate over the first dimension of this array, returning slices.
1014
+ *
1015
+ * This can be used to destructure arrays. For example:
1016
+ *
1017
+ * ```js
1018
+ * let x = np.array([[1, 2], [3, 4]]);
1019
+ * let [a, b] = x;
1020
+ * console.log(a.js()); // [1, 2]
1021
+ * console.log(b.js()); // [3, 4]
1022
+ * ```
1023
+ */
1024
+ [Symbol.iterator](): IterableIterator<this>;
1025
+ /**
1026
+ * Return a sorted copy of an array in ascending order.
1027
+ *
1028
+ * See `jax.numpy.sort` for full docs.
1029
+ */
1030
+ sort(axis?: number): this;
1031
+ /**
1032
+ * Return the indices that would sort an array. This may not be a stable
1033
+ * sorting algorithm; it need not preserve order of indices in ties.
1034
+ *
1035
+ * See `jax.numpy.argsort` for full docs.
1036
+ */
1037
+ argsort(axis?: number): this;
1038
+ /**
1039
+ * Slice an array along one or more axes.
1040
+ *
1041
+ * This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
1042
+ * mimic this in JavaScript, we would write:
1043
+ *
1044
+ * ```js
1045
+ * x.slice([1, 3], 2, [], null);
1046
+ * ```
1047
+ *
1048
+ * The `slice` method accepts a variable number of arguments, each of which
1049
+ * can be a number, an empty array, a single-element array, a two-element
1050
+ * array, or `null`. The arguments are interpreted as follows:
1051
+ *
1052
+ * - A number `n` means to access the `n`-th element along that axis, removing
1053
+ * that axis from the resulting shape.
1054
+ * - An empty array `[]` means to keep that axis as-is, like `:` in Python.
1055
+ * - A single-element array `[i]` means to start slicing from index `i`
1056
+ * (inclusive) to the end of the axis, like `x[i:]`.
1057
+ * - A two-element array `[i, j]` means to slice from index `i` (inclusive)
1058
+ * to index `j` (exclusive), like `x[i:j]`.
1059
+ * - `null` means to add a new axis at that position, like `np.newaxis`.
1060
+ *
1061
+ * Like in Python, negative indices are supported, which count from the end of
1062
+ * the axis. For example, `-1` means the last element.
1063
+ *
1064
+ * Strided slices are not yet implemented, so you cannot write `x[::2]` or
1065
+ * similar.
1066
+ *
1067
+ * Advanced indexing by integer arrays is also supported. This translates to
1068
+ * the "gather" primitive, and it allows you to access specific elements of
1069
+ * the array by integer indices stored in another array.
1070
+ */
1071
+ slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
1072
+ }
1073
+ declare class ShapedArray implements AbstractValue {
1074
+ readonly shape: number[];
1075
+ readonly dtype: DType;
1076
+ readonly weakType: boolean;
1077
+ constructor(shape: number[], dtype: DType, weakType: boolean);
1078
+ static fromAval(aval: AbstractValue): ShapedArray;
1079
+ get ndim(): number;
1080
+ get size(): number;
1081
+ scalar(): ShapedArray;
1082
+ toString(): string;
1083
+ equals(other: ShapedArray): boolean;
1084
+ }
1085
+ //#endregion
1086
+ //#region src/frontend/array.d.ts
1087
+ type ArrayLike = Array | number | boolean;
1088
+ /** Version of pureArray with fudged types. */
1089
+
897
1090
  /**
898
- * Expand the shape of an array by inserting new axes of length 1.
899
- *
900
- * @param a - Input array.
901
- * @param axis - Position(s) in the expanded axes where the new axis (or axes)
902
- * is placed. Can be a single integer or an array of integers.
903
- * @returns Array with the number of dimensions increased.
1091
+ * An executable operation that will be dispatched to the backend.
904
1092
  *
905
- * @example
906
- * ```ts
907
- * const x = np.array([1, 2]);
908
- * np.expandDims(x, 0); // Shape [1, 2]
909
- * np.expandDims(x, 1); // Shape [2, 1]
910
- * np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
911
- * ```
1093
+ * This holds a reference to all input buffers used in the operation. After the
1094
+ * operation is dispatched, the references should be released.
912
1095
  */
913
- declare function expandDims(a: ArrayLike, axis: number | number[]): Array;
1096
+ declare class PendingExecute {
1097
+ #private;
1098
+ readonly backend: Backend;
1099
+ readonly source: Kernel | Routine;
1100
+ readonly inputs: Slot[];
1101
+ readonly outputs: Slot[];
1102
+ prepared: Executable | null;
1103
+ submitted: boolean;
1104
+ constructor(backend: Backend, source: Kernel | Routine, inputs: Slot[], outputs: Slot[]);
1105
+ updateRc(delta: number): void;
1106
+ prepare(): Promise<void>;
1107
+ prepareSync(): void;
1108
+ submit(): void;
1109
+ }
1110
+ /** @inline */
1111
+ type DTypeAndDevice = {
1112
+ dtype?: DType;
1113
+ device?: Device;
1114
+ };
1115
+ type ArrayConstructorArgs = {
1116
+ source: AluExp | Slot;
1117
+ st: ShapeTracker;
1118
+ dtype: DType;
1119
+ weakType: boolean;
1120
+ backend: Backend;
1121
+ committed: boolean;
1122
+ pending?: Iterable<PendingExecute>;
1123
+ };
914
1124
  /**
915
- * Repeat each element of an array after themselves.
1125
+ * A multidimensional numeric array with data stored on CPU or GPU.
916
1126
  *
917
- * If no axis is provided, use the flattened input array, and return a flat
918
- * output array.
919
- */
920
- declare function repeat(a: ArrayLike, repeats: number, axis?: number): Array;
921
- /**
922
- * Construct an array by repeating A the number of times given by reps.
1127
+ * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
1128
+ * `torch.Tensor`.
923
1129
  *
924
- * If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
925
- * integers, the resulting array will have a shape of `(reps[0] * d1,
926
- * reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
1130
+ * Not to be confused with the JavaScript "Array" constructor. Avoid importing
1131
+ * this into your code's namespace if you're already using the JavaScript
1132
+ * "Array" type by name.
927
1133
  */
928
- declare function tile(a: ArrayLike, reps: number | number[]): Array;
1134
+ declare class Array extends Tracer {
1135
+ #private;
1136
+ /**
1137
+ * @ignore
1138
+ * Constructs an array from source, shape and backend. Note that if the source
1139
+ * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
1140
+ * will be freed when the array is disposed.
1141
+ */
1142
+ constructor(args: ArrayConstructorArgs);
1143
+ /** @ignore */
1144
+ get aval(): ShapedArray;
1145
+ /** Return a simple string representation of the array's dimensions. */
1146
+ toString(): string;
1147
+ get device(): Device;
1148
+ get ref(): this;
1149
+ /** Get the current reference count (for debugging memory management). */
1150
+ get refCount(): number;
1151
+ dispose(): void;
1152
+ /**
1153
+ * Convert this array into a primitive value.
1154
+ *
1155
+ * This only works for scalars (0-dimensional arrays). It lets you get values
1156
+ * "out" of the JAX system. For instance, if `x = np.array(5)`, then you can
1157
+ * evaluate `x + 1` and `x ** 2` to get `6` and `25`, respectively.
1158
+ *
1159
+ * This method is also called for `==` equality.
1160
+ */
1161
+ [Symbol.toPrimitive](): any;
1162
+ /** Realize the array and return it as data. */
1163
+ data(): Promise<DataArray>;
1164
+ /**
1165
+ * Wait for this array to finish evaluation.
1166
+ *
1167
+ * Operations and data loading in jax-js are lazy, so this function ensures
1168
+ * that pending operations are dispatched and fully executed before it
1169
+ * returns.
1170
+ *
1171
+ * If you are mapping from `data()` or `dataSync()`, it will also trigger
1172
+ * dispatch of operations as well.
1173
+ *
1174
+ * **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
1175
+ * asynchronously for multiple arrays.
1176
+ */
1177
+ blockUntilReady(): Promise<Array>;
1178
+ /**
1179
+ * Realize the array and return it as data. This is a sync variant and not
1180
+ * recommended for performance reasons, as it will block rendering.
1181
+ */
1182
+ dataSync(): DataArray;
1183
+ /**
1184
+ * Convert this array into a JavaScript object.
1185
+ *
1186
+ * This is a blocking operation that will compile all of the shaders and wait
1187
+ * for execution to complete, synchronously. No other JavaScript code on the
1188
+ * site will be run during shader execution.
1189
+ *
1190
+ * To avoid blocking, prefer `jsAsync()` when possible.
1191
+ */
1192
+ js(): any;
1193
+ /** Convert this array into a JavaScript object, asynchronously. */
1194
+ jsAsync(): Promise<any>;
1195
+ /**
1196
+ * Copy an element of an array to a numeric scalar and return it.
1197
+ *
1198
+ * Throws an error if the array does not have a single element. The array must
1199
+ * either be rank-0, or all dimensions of the shape are 1.
1200
+ */
1201
+ item(): number;
1202
+ /** @private Internal plumbing method for Array / Tracer ops. */
1203
+ static _implRules(): typeof implRules;
1204
+ /** @private */
1205
+ _realizeSource(): number;
1206
+ /** @private Put this array on a new backend, asynchronously. */
1207
+ _put(backend: Backend): Promise<Array>;
1208
+ /** @private Put this array on a new backend, synchronously. */
1209
+ _putSync(backend: Backend): Array;
1210
+ }
1211
+ /** Constructor for creating a new array from data. */
1212
+ declare function array(values: Array | DataArray | RecursiveArray<number> | RecursiveArray<boolean>, {
1213
+ shape,
1214
+ dtype,
1215
+ device
1216
+ }?: {
1217
+ shape?: number[];
1218
+ } & DTypeAndDevice): Array;
1219
+ /** If x is a value, lift it into an array, otherwise leave it be. */
1220
+
1221
+ type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
1222
+ declare const implRules: { [P in Primitive]: ImplRule<P> };
1223
+ /** Return a new array of given shape and type, filled with zeros. */
1224
+ declare function zeros(shape: number[], {
1225
+ dtype,
1226
+ device
1227
+ }?: DTypeAndDevice): Array;
1228
+ /** Return a new array of given shape and type, filled with ones. */
1229
+ declare function ones(shape: number[], {
1230
+ dtype,
1231
+ device
1232
+ }?: DTypeAndDevice): Array;
1233
+ /** Return a new array of given shape and type, filled with `fill_value`. */
1234
+ declare function full(shape: number[], fillValue: number | boolean | Array, {
1235
+ dtype,
1236
+ device
1237
+ }?: DTypeAndDevice): Array;
929
1238
  /**
930
- * Broadcast an array to a shape, with NumPy-style broadcasing rules.
1239
+ * Create an identity matrix.
931
1240
  *
932
- * In other words, this lets you append axes to the left, and/or expand
933
- * dimensions where the shape is 1.
1241
+ * If numCols is not provided, it defaults to numRows, i.e., a square identity
1242
+ * matrix with ones on the diagonal.
934
1243
  */
935
- declare function broadcastTo(a: ArrayLike, shape: number[]): Array;
936
- /** Broadcast input shapes to a common output shape. */
937
- declare function broadcastShapes(...shapes: number[][]): number[];
938
- /** Broadcast arrays to a common shape. */
939
- declare function broadcastArrays(...arrays: ArrayLike[]): Array[];
1244
+ declare function eye(numRows: number, numCols?: number, {
1245
+ dtype,
1246
+ device
1247
+ }?: DTypeAndDevice): Array;
1248
+ /** Return the identity matrix, with ones on the main diagonal. */
1249
+ declare function identity$1(n: number, {
1250
+ dtype,
1251
+ device
1252
+ }?: DTypeAndDevice): Array;
940
1253
  /**
941
- * Return specified diagonals.
1254
+ * Return evenly spaced values within a given interval.
942
1255
  *
943
- * If a is 2D, return the diagonal of the array with the given offset. If a is
944
- * 3D or higher, compute diagonals along the two given axes (default: 0, 1).
1256
+ * This can be called with a varying number of arguments, just like the range()
1257
+ * builtin function in Python.
945
1258
  *
946
- * This returns a view over the existing array. The shape of the resulting array
947
- * is determined by removing the two axes along which the diagonal is taken,
948
- * then appending a new axis to the right with holding the diagonals.
949
- */
950
- declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
951
- /**
952
- * Extract a diagonal or construct a diagonal array.
1259
+ * - `arange(stop)` is equivalent to `arange(0, stop, 1)`.
1260
+ * - `arange(start, stop)` is equivalent to `arange(start, stop, 1)`.
1261
+ * - `arange(start, stop, step)` creates an array starting at `start`, ending
1262
+ * before `stop`, with a step size of `step`.
953
1263
  *
954
- * If v is a 2D array, return the k-th diagonal of v (as a view). If v is a 1D
955
- * array, return a 2D array with v on the k-th diagonal.
1264
+ * Defaults to an integer data type. This can produce unintended results when
1265
+ * using a non-integer step, so prefer linspace() in those cases.
956
1266
  */
957
- declare function diag(v: ArrayLike, k?: number): Array;
958
- /** Calculate the sum of the diagonal of an array along the given axes. */
959
- declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
1267
+ declare function arange(start: number, stop?: number, step?: number, {
1268
+ dtype,
1269
+ device
1270
+ }?: DTypeAndDevice): Array;
960
1271
  /**
961
- * Return a sorted copy of an array.
1272
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
962
1273
  *
963
- * The array is sorted along a specified axis (the last by default). This may be
964
- * an unstable sort, and it dispatches to device-specific implementation.
1274
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
1275
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
1276
+ * `k>0` is above it.
965
1277
  */
966
- declare function sort(a: ArrayLike, axis?: number): Array;
1278
+ declare function tri(n: number, m?: number, k?: number, {
1279
+ dtype,
1280
+ device
1281
+ }?: DTypeAndDevice): Array;
1282
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
1283
+ declare function tril(a: ArrayLike, k?: number): Array;
1284
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
1285
+ declare function triu(a: ArrayLike, k?: number): Array;
967
1286
  /**
968
- * Return indices that would sort an array. This may be an unstable sorting
969
- * algorithm; it need not preserve order of indices in ties.
1287
+ * Return evenly spaced numbers over a specified interval.
970
1288
  *
971
- * Returns an array of `int32` indices.
1289
+ * Returns _num_ evenly spaced samples, calculated over the interval
1290
+ * [`start`, `stop`]. The endpoint `stop` is included in the result by default,
1291
+ * but this is controlled by the `endpoint` parameter.
972
1292
  *
973
- * The array is sorted along a specified axis (the last by default).
1293
+ * The default data type is Float32. Use arange() for integer steps.
974
1294
  */
975
- declare function argsort(a: ArrayLike, axis?: number): Array;
1295
+ declare function linspace(start: number, stop: number, num?: number, endpoint?: boolean, {
1296
+ dtype,
1297
+ device
1298
+ }?: DTypeAndDevice): Array;
976
1299
  /**
977
- * Take elements from an array along an axis.
1300
+ * Return numbers spaced evenly on a log scale.
978
1301
  *
979
- * This is equivalent to advanced indexing with integer indices over that
980
- * numbered axis. By default, the flattened array is used.
981
- */
982
- declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
983
- /** Return if two arrays are element-wise equal within a tolerance. */
984
- declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
985
- rtol?: number;
986
- atol?: number;
987
- }): boolean;
988
- /** Matrix product of two arrays. */
989
- declare function matmul(x: ArrayLike, y: ArrayLike): Array;
990
- /** Dot product of two arrays. */
991
- declare function dot$1(x: ArrayLike, y: ArrayLike): Array;
992
- /**
993
- * Compute the tensor dot product of two N-dimensional arrays.
1302
+ * In linear space, the sequence starts at `base ** start` and ends at
1303
+ * `base ** stop` (see `endpoint` below).
994
1304
  *
995
- * The behavior is determined by `axes`. If an integer `k`, sum over the last
996
- * `k` axes of x and the first `k` axes of y. If a tuple, then the first array
997
- * corresponds to the axes of x and the second to the axes of y.
1305
+ * @param start - `base ** start` is the starting value of the sequence.
1306
+ * @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
1307
+ * @param num - Number of samples to generate. Default is 50.
1308
+ * @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
1309
+ * @param base - The base of the log space. Default is 10.
1310
+ * @returns Array of evenly spaced values on a log scale.
998
1311
  */
999
- declare function tensordot(x: ArrayLike, y: ArrayLike, axes?: number | [number[], number[]]): Array;
1312
+ declare function logspace(start: number, stop: number, num?: number, endpoint?: boolean, base?: number, {
1313
+ dtype,
1314
+ device
1315
+ }?: DTypeAndDevice): Array;
1316
+ //#endregion
1317
+ //#region src/frontend/linearize.d.ts
1318
+ /** @inline */
1319
+ type GradOpts = {
1320
+ /**
1321
+ * Integer or sequence of integers. Specifies which positional argument(s) to
1322
+ * differentiate with respect to.
1323
+ *
1324
+ * Defaults to `0` (the first argument).
1325
+ */
1326
+ argnums?: number | number[];
1327
+ /**
1328
+ * The input function returns a pair of `[out, aux]` including an auxiliary
1329
+ * value. This `aux` is not differentiated, but is returned alongside the
1330
+ * gradient when evaluating the function.
1331
+ */
1332
+ hasAux?: boolean;
1333
+ };
1334
+ declare namespace lax_linalg_d_exports {
1335
+ export { cholesky$1 as cholesky, lu, triangularSolve };
1336
+ }
1000
1337
  /**
1001
- * Einstein summation with string subscripts.
1338
+ * Compute the Cholesky decomposition of a symmetric positive-definite matrix.
1339
+ *
1340
+ * The Cholesky decomposition of a matrix `A` is:
1341
+ *
1342
+ * - A = L @ L^T (for upper=false, default)
1343
+ * - A = U^T @ U (for upper=true)
1344
+ *
1345
+ * where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
1346
+ * The input matrix must be symmetric and positive-definite.
1002
1347
  *
1003
1348
  * @example
1004
1349
  * ```ts
1005
- * import { numpy as np } from "@jax-js/jax";
1350
+ * import { lax, numpy as np } from "@jax-js/jax";
1006
1351
  *
1007
- * const a = np.ones([2, 3]);
1008
- * const b = np.ones([3]);
1009
- * np.einsum("ij,j", a, b); // Shape [2]
1352
+ * const x = np.array([[2., 1.], [1., 2.]]);
1353
+ *
1354
+ * // Lower Cholesky factorization (default):
1355
+ * const L = lax.linalg.cholesky(x);
1356
+ * // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
1357
+ *
1358
+ * // Upper Cholesky factorization:
1359
+ * const U = lax.linalg.cholesky(x, { upper: true });
1360
+ * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
1010
1361
  * ```
1011
1362
  */
1012
- declare function einsum(subscripts: string, ...operands: ArrayLike[]): Array;
1363
+ declare function cholesky$1(a: ArrayLike, {
1364
+ upper
1365
+ }?: {
1366
+ upper?: boolean;
1367
+ }): Array;
1013
1368
  /**
1014
- * Einstein summation alternating between arrays and numeric indices.
1369
+ * LU decomposition with partial pivoting.
1370
+ *
1371
+ * Computes the matrix decomposition: `P @ A = L @ U`, where `P` is a
1372
+ * permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
1373
+ * and `U` is upper-triangular.
1374
+ *
1375
+ * @param x - A batch of matrices with shape `[..., m, n]`.
1376
+ *
1377
+ * @returns A tuple `(lu, pivots, permutation)` where:
1378
+ * - `lu`: combined lower and upper triangular matrices.
1379
+ * - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
1380
+ * - `permutation`: the permutation generated by pivots with shape `[..., m]`.
1015
1381
  *
1016
1382
  * @example
1017
1383
  * ```ts
1018
- * import { numpy as np } from "@jax-js/jax";
1384
+ * import { lax, numpy as np } from "@jax-js/jax";
1019
1385
  *
1020
- * const a = np.ones([2, 3]);
1021
- * const b = np.ones([3]);
1022
- * np.einsum(a, [0, 1], b, [1]); // Shape [2]
1386
+ * const A = np.array([[4., 3.], [6., 3.]]);
1387
+ * const [lu, pivots, permutation] = lax.linalg.lu(A);
1388
+ * // lu ≈ [[6., 3.], [0.6666667, 1.0]]
1389
+ * // pivots = [1, 1]
1390
+ * // permutation = [1, 0]
1023
1391
  * ```
1024
1392
  */
1025
- declare function einsum(...args: (ArrayLike | number[])[]): Array;
1393
+ declare function lu(x: ArrayLike): [Array, Array, Array];
1026
1394
  /**
1027
- * Compute the inner product of two arrays.
1395
+ * Solve a triangular linear system.
1028
1396
  *
1029
- * Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
1030
- * contraction on the last axis.
1397
+ * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
1398
+ * where `a` is a triangular matrix.
1031
1399
  *
1032
- * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
1033
- */
1034
- declare function inner(x: ArrayLike, y: ArrayLike): Array;
1035
- /**
1036
- * Compute the outer product of two arrays.
1400
+ * @example
1401
+ * ```ts
1402
+ * import { lax, numpy as np } from "@jax-js/jax";
1037
1403
  *
1038
- * If the input arrays are not 1D, they will be flattened. Returned array will
1039
- * be of shape `[x.size, y.size]`.
1404
+ * const L = np.array([[2., 0.], [1., 3.]]);
1405
+ * const b = np.array([4., 7.]).reshape([2, 1]);
1406
+ *
1407
+ * // Solve L @ x = b
1408
+ * const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
1409
+ * // x = [[2.], [5./3.]]
1410
+ * ```
1040
1411
  */
1041
- declare function outer(x: ArrayLike, y: ArrayLike): Array;
1042
- /** Vector dot product of two arrays along a given axis. */
1043
- declare function vecdot(x: ArrayLike, y: ArrayLike, {
1044
- axis
1412
+ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
1413
+ leftSide,
1414
+ lower,
1415
+ transposeA,
1416
+ unitDiagonal
1045
1417
  }?: {
1046
- axis?: number;
1418
+ leftSide?: boolean;
1419
+ lower?: boolean;
1420
+ transposeA?: boolean;
1421
+ unitDiagonal?: boolean;
1047
1422
  }): Array;
1423
+ declare namespace lax_d_exports {
1424
+ export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
1425
+ }
1048
1426
  /**
1049
- * Return the dot product of two vectors.
1427
+ * Dimension numbers for general `dot()` primitive.
1050
1428
  *
1051
- * Like vecdot() but flattens the arguments first into vectors.
1429
+ * Contracting dimensions act as a tensor contraction (reduction) along the
1430
+ * given axis. They must be the same size in both operands. Batch dimensions
1431
+ * are treated as vectorized, leading batch dimensions.
1432
+ *
1433
+ * The return value has a shape where the first dimensions are shared batch
1434
+ * dimensions, followed by `lhs` non-contracting dimensions, followed by
1435
+ * `rhs` non-contracting dimensions.
1052
1436
  */
1053
- declare function vdot(x: ArrayLike, y: ArrayLike): Array;
1054
- /** Convolution of two one-dimensional arrays. */
1055
- declare function convolve(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
1056
- /** Correlation of two one dimensional arrays. */
1057
- declare function correlate(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
1437
+ type DotDimensionNumbers = {
1438
+ lhsContractingDims?: number[];
1439
+ rhsContractingDims?: number[];
1440
+ lhsBatchDims?: number[];
1441
+ rhsBatchDims?: number[];
1442
+ };
1443
+ /**
1444
+ * General dot product/contraction operator.
1445
+ *
1446
+ * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
1447
+ * `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
1448
+ */
1449
+ declare function dot$1(lhs: Array, rhs: Array, {
1450
+ lhsContractingDims: lc,
1451
+ rhsContractingDims: rc,
1452
+ lhsBatchDims: lb,
1453
+ rhsBatchDims: rb
1454
+ }?: DotDimensionNumbers): Array;
1455
+ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
1058
1456
  /**
1059
- * Return a tuple of coordinate matrices from coordinate vectors.
1457
+ * General n-dimensional convolution operator, with optional dilation.
1060
1458
  *
1061
- * Make N-D coordinate arrays for vectorized evaluations of N-D scalar/vector
1062
- * fields over N-D grids, given one-dimensional coordinate arrays x1, x2,…, xn.
1459
+ * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
1460
+ * function in JAX, which wraps XLA's general convolution operator.
1461
+ *
1462
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
1463
+ * @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
1464
+ * @param windowStrides - Strides for each spatial dimension
1465
+ * @param padding - Padding for each spatial dimension, or a string
1466
+ * (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
1063
1467
  */
1064
- declare function meshgrid(xs: Array[], {
1065
- indexing
1468
+ declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
1469
+ lhsDilation,
1470
+ rhsDilation,
1471
+ featureGroupCount
1066
1472
  }?: {
1067
- indexing?: "xy" | "ij";
1068
- }): Array[];
1473
+ lhsDilation?: number[];
1474
+ rhsDilation?: number[];
1475
+ featureGroupCount?: number;
1476
+ }): Array;
1477
+ /** Convenience wrapper around `convGeneralDilated`. */
1478
+ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
1479
+ /** Convenience wrapper around `convGeneralDilated`. */
1480
+ declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
1069
1481
  /**
1070
- * Clip (limit) the values in an array.
1482
+ * Convenience wrapper for calculating the N-d convolution "transpose".
1071
1483
  *
1072
- * Given an interval, values outside the interval are clipped to the interval
1073
- * edges. For example, if an interval of [0, 1] is specified, values smaller
1074
- * than 0 become 0, and values larger than 1 become 1.
1484
+ * This function directly calculates a fractionally strided conv rather than
1485
+ * indirectly calculating the gradient (transpose) of a forward convolution.
1486
+ * It is equivalent to the JAX version, except:
1075
1487
  *
1076
- * If either bound is undefined, it is ignored.
1077
- */
1078
- declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
1079
- /**
1080
- * Calculate the absolute value element-wise.
1488
+ * - The `use_consistent_padding` option is not available. We only have the
1489
+ * consistent padding case (JAX version >0.8.4).
1490
+ * - The order of dimensions matches `lax.conv_general_dilated`.
1081
1491
  *
1082
- * This is the same function as `jax.numpy.abs()`.
1492
+ * Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
1493
+ * dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
1494
+ * `transposeKernel` to true.
1495
+ *
1496
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
1497
+ * @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
1498
+ * @param strides - Sequence of n integers, sets fractional stride
1499
+ * @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
1500
+ * each side of the input, so it acts like gradient of `conv()`
1501
+ * @param rhsDilation - Atrous dilation for the kernel
1502
+ * @param transposeKernel - Flip spatial axes and swap the input/output channels
1503
+ * of the kernel; its shape should be `[C_in, C_out, ...ks]`
1083
1504
  */
1084
- declare function absolute(x: ArrayLike): Array;
1085
- /** Return an element-wise indication of sign of the input. */
1086
- declare function sign(x: ArrayLike): Array;
1087
- /** @function Return element-wise positive values of the input (no-op). */
1088
- declare const positive: (x: ArrayLike) => Array;
1505
+ declare function convTranspose(lhs: Array, rhs: Array, strides: number[], padding: PaddingType, {
1506
+ rhsDilation,
1507
+ transposeKernel
1508
+ }?: {
1509
+ rhsDilation?: number[];
1510
+ transposeKernel?: boolean;
1511
+ }): Array;
1512
+ /** Reduce a computation over padded windows. */
1513
+ declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
1514
+ /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
1515
+ declare function erf(x: ArrayLike): Array;
1089
1516
  /**
1090
- * Return the Hamming window of size M, a taper with a weighted cosine bell.
1517
+ * The complementary error function: `erfc(x) = 1 - erf(x)`.
1091
1518
  *
1092
- * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1519
+ * This function is more accurate than `1 - erf(x)` for large values of `x`,
1520
+ * where `erf(x)` is very close to 1.
1093
1521
  */
1094
- declare function hamming(M: number): Array;
1522
+ declare function erfc(x: ArrayLike): Array;
1095
1523
  /**
1096
- * Return the Hann window of size M, a taper with a weighted cosine bell.
1524
+ * Stops gradient computation.
1097
1525
  *
1098
- * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1526
+ * Behaves as the identity function but prevents the flow of gradients during
1527
+ * forward or reverse-mode automatic differentiation.
1099
1528
  */
1100
- declare function hann(M: number): Array;
1529
+ declare function stopGradient(x: ArrayLike): Array;
1530
+ declare namespace numpy_fft_d_exports {
1531
+ export { ComplexPair, fft, ifft };
1532
+ }
1101
1533
  /**
1102
- * @function
1103
- * Compute the Heaviside step function. It is defined piecewise:
1104
- * - `heaviside(x1, x2) = 0` for `x1 < 0`,
1105
- * - `heaviside(x1, x2) = x2` for `x1 == 0`,
1106
- * - `heaviside(x1, x2) = 1` for `x1 > 0`.
1534
+ * A pair of arrays representing real and imaginary part `a + bj`. Both arrays
1535
+ * must have the same shape.
1107
1536
  */
1108
- declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1109
- /** Calculate element-wise square of the input array. */
1110
- declare function square(x: ArrayLike): Array;
1111
- /** Element-wise tangent function (takes radians). */
1112
- declare function tan(x: ArrayLike): Array;
1537
+ type ComplexPair = {
1538
+ real: Array;
1539
+ imag: Array;
1540
+ };
1113
1541
  /**
1114
- * @function
1115
- * Return the normalized sinc function.
1542
+ * Compute a one-dimensional discrete Fourier transform.
1116
1543
  *
1117
- * The sinc function is defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`.
1118
- * This is the normalized sinc function commonly used in signal processing.
1544
+ * Currently, the size of the axis must be a power of two.
1545
+ */
1546
+ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
1547
+ /**
1548
+ * Compute a one-dimensional inverse discrete Fourier transform.
1119
1549
  *
1120
- * **Note:** JVP is not supported at x=0 due to discontinuous derivative. This
1121
- * requires a custom JVP rule to handle properly (see JAX implementation).
1550
+ * Currently, the size of the axis must be a power of two.
1122
1551
  */
1123
- declare const sinc: OwnedFunction<(x: ArrayLike) => Array>;
1124
- /** Element-wise inverse cosine function (inverse of cos). */
1125
- declare function acos(x: ArrayLike): Array;
1552
+ declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
1553
+ declare namespace numpy_linalg_d_exports {
1554
+ export { cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
1555
+ }
1126
1556
  /**
1127
- * @function
1128
- * Return element-wise hypotenuse for the given legs of a right triangle.
1557
+ * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
1129
1558
  *
1130
- * In the original NumPy/JAX implementation, this function is more numerically
1131
- * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
1132
- * stability improvements.
1559
+ * This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
1560
+ * the input matrix, which is on by default.
1133
1561
  */
1134
- declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1562
+ declare function cholesky(a: ArrayLike, {
1563
+ upper,
1564
+ symmetrizeInput
1565
+ }?: {
1566
+ upper?: boolean;
1567
+ symmetrizeInput?: boolean;
1568
+ }): Array;
1569
+ /** Compute the determinant of a square matrix (batched). */
1570
+ declare function det(a: ArrayLike): Array;
1571
+ /** Compute the inverse of a square matrix (batched). */
1572
+ declare function inv(a: ArrayLike): Array;
1135
1573
  /**
1136
- * @function
1137
- * Element-wise arc tangent of y/x with correct quadrant.
1574
+ * Return the least-squares solution to a linear equation.
1138
1575
  *
1139
- * Returns the angle in radians between the positive x-axis and the point (x, y).
1140
- * The result is in the range [-π, π].
1576
+ * For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
1577
+ * For underdetermined systems, this finds the minimum-norm solution for `x`.
1141
1578
  *
1142
- * Uses numerically stable formulas:
1143
- * - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
1144
- * - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
1579
+ * This currently uses Cholesky decomposition to solve the normal equations,
1580
+ * under the hood. The method is not as robust as QR or SVD.
1145
1581
  *
1146
- * The output is ill-defined when both x and y are zero.
1582
+ * @param a coefficient matrix of shape `(M, N)`
1583
+ * @param b right-hand side of shape `(M,)` or `(M, K)`
1584
+ * @return least-squares solution of shape `(N,)` or `(N, K)`
1147
1585
  */
1148
- declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
1149
- /** Element-wise subtraction, with broadcasting. */
1150
- declare function subtract(x: ArrayLike, y: ArrayLike): Array;
1151
- /** Calculates the floating-point division of x by y element-wise. */
1152
- declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
1586
+ declare function lstsq(a: ArrayLike, b: ArrayLike): Array;
1587
+ /** Raise a square matrix to an integer power, via repeated squarings. */
1588
+ declare function matrixPower(a: ArrayLike, n: number): Array;
1589
+ /** Return sign and natural logarithm of the determinant of `a`. */
1590
+ declare function slogdet(a: ArrayLike): [Array, Array];
1153
1591
  /**
1154
- * Return the largest integer smaller or equal to the division of the inputs.
1155
- *
1156
- * The result is always rounded towards negative infinity.
1592
+ * Solve a linear system of equations.
1157
1593
  *
1158
- * For floating-point inputs, this is equivalent to `floor(x / y)`.
1159
- * For integer inputs, we use `(x - remainder(x, y)) / y` to handle
1160
- * negative values correctly (note: may overflow near int32 boundaries).
1594
+ * This solves a (batched) linear system of equations `a @ x = b` for `x` given
1595
+ * `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
1161
1596
  *
1162
- * @param x - Dividend array.
1163
- * @param y - Divisor array.
1164
- * @returns Element-wise floor division of x by y.
1597
+ * @param a - Coefficient matrix of shape `(..., N, N)`.
1598
+ * @param b - Values of shape `(N,)` or `(..., N, M)`.
1599
+ * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
1165
1600
  */
1166
- declare function floorDivide(x: ArrayLike, y: ArrayLike): Array;
1601
+ declare function solve(a: ArrayLike, b: ArrayLike): Array;
1602
+ //#endregion
1603
+ //#region src/library/numpy/dtype-info.d.ts
1604
+ /** @inline */
1605
+ type FInfo = Readonly<{
1606
+ /** The number of bits occupied by the type. */
1607
+ bits: number;
1608
+ /** Returns the _dtype_ for which finfo returns information. */
1609
+ dtype: DType;
1610
+ /** The difference between 1.0 and the next smallest representable float larger than 1.0. */
1611
+ eps: number;
1612
+ /** The difference between 1.0 and the next largest representable float smaller than 1.0. */
1613
+ epsneg: number;
1614
+ /** The exponent that yields `eps`. */
1615
+ machep: number;
1616
+ /** The largest representable finite number. */
1617
+ max: number;
1618
+ /** The smallest positive power of the base (2) that causes overflow. */
1619
+ maxexp: number;
1620
+ /** The smallest representable (most negative) finite number. */
1621
+ min: number;
1622
+ /** The largest negative power of the base (2) without leading zeros in mantissa. */
1623
+ minexp: number;
1624
+ /** The exponent that yields `epsneg`. */
1625
+ negep: number;
1626
+ /** Number of bits in the exponent portion. */
1627
+ nexp: number;
1628
+ /** Number of bits in the mantissa portion. */
1629
+ nmant: number;
1630
+ /** The approximate number of decimal digits to which this kind of float is precise. */
1631
+ precision: number;
1632
+ /** The approximate decimal resolution, i.e., `10 ** -precision`. */
1633
+ resolution: number;
1634
+ /** The smallest positive normal number. */
1635
+ smallestNormal: number;
1636
+ /** The smallest positive subnormal number. */
1637
+ smallestSubnormal: number;
1638
+ }>;
1639
+ /** Machine limits for floating-point types. */
1640
+ declare function finfo(dtype: DType): FInfo;
1641
+ /** @inline */
1642
+ type IInfo = Readonly<{
1643
+ /** The number of bits occupied by the type. */
1644
+ bits: number;
1645
+ /** Returns the _dtype_ for which iinfo returns information. */
1646
+ dtype: DType;
1647
+ /** The largest representable integer. */
1648
+ max: number;
1649
+ /** The smallest representable integer. */
1650
+ min: number;
1651
+ }>;
1652
+ /** Machine limits for integer types. */
1653
+ declare function iinfo(dtype: DType): IInfo;
1654
+ declare namespace numpy_d_exports {
1655
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1656
+ }
1657
+ declare const float32 = DType.Float32;
1658
+ declare const int32 = DType.Int32;
1659
+ declare const uint32 = DType.Uint32;
1660
+ declare const bool = DType.Bool;
1661
+ declare const float16 = DType.Float16;
1662
+ declare const float64 = DType.Float64;
1663
+ /** Euler's constant, `e = 2.7182818284590...` */
1664
+ declare const e: number;
1665
+ /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
1666
+ declare const eulerGamma = 0.5772156649015329;
1667
+ /** Positive infinity. */
1668
+ declare const inf: number;
1669
+ /** Floating-point representation of NaN. */
1670
+ declare const nan: number;
1671
+ /** This is Pi, `π = 3.14159265358979...` */
1672
+ declare const pi: number;
1673
+ /** @function Element-wise addition, with broadcasting. */
1674
+ declare const add: (x: ArrayLike, y: ArrayLike) => Array;
1675
+ /** @function Element-wise multiplication, with broadcasting. */
1676
+ declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
1677
+ /** @function Numerical negative of every element of an array. */
1678
+ declare const negative: (x: ArrayLike) => Array;
1679
+ /** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
1680
+ declare const reciprocal: (x: ArrayLike) => Array;
1681
+ /** @function Round input down to the nearest integer. */
1682
+ declare const floor: (x: ArrayLike) => Array;
1683
+ /** @function Round input up to the nearest integer. */
1684
+ declare const ceil: (x: ArrayLike) => Array;
1685
+ /** @function Element-wise sine function (takes radians). */
1686
+ declare const sin: (x: ArrayLike) => Array;
1687
+ /** @function Element-wise cosine function (takes radians). */
1688
+ declare const cos: (x: ArrayLike) => Array;
1689
+ /** @function Element-wise inverse sine function (inverse of sin). */
1690
+ declare const asin: (x: ArrayLike) => Array;
1691
+ /** @function Element-wise inverse tangent function (inverse of tan). */
1692
+ declare const atan: (x: ArrayLike) => Array;
1693
+ /** @function Calculate the exponential of all elements in the input array. */
1694
+ declare const exp: (x: ArrayLike) => Array;
1695
+ /** @function Calculate the natural logarithm of all elements in the input array. */
1696
+ declare const log: (x: ArrayLike) => Array;
1697
+ /** @function Calculate the square root of all elements in the input array. */
1698
+ declare const sqrt: (x: ArrayLike) => Array;
1699
+ /** @function Return element-wise minimum of the input arrays. */
1700
+ declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
1701
+ /** @function Return element-wise maximum of the input arrays. */
1702
+ declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
1703
+ /** @function Compare two arrays element-wise. */
1704
+ declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
1705
+ /** @function Compare two arrays element-wise. */
1706
+ declare const less: (x: ArrayLike, y: ArrayLike) => Array;
1707
+ /** @function Compare two arrays element-wise. */
1708
+ declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
1709
+ /** @function Compare two arrays element-wise. */
1710
+ declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
1711
+ /** @function Compare two arrays element-wise. */
1712
+ declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
1713
+ /** @function Compare two arrays element-wise. */
1714
+ declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
1715
+ /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1716
+ declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
1167
1717
  /**
1168
1718
  * @function
1169
- * Calculate element-wise floating-point modulo operation.
1719
+ * Permute the dimensions of an array. Defaults to reversing the axis order.
1170
1720
  */
1171
- declare const fmod: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
1721
+ declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
1172
1722
  /**
1173
1723
  * @function
1174
- * Calculate element-wise remainder of the division (matches sign of y).
1175
- */
1176
- declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
1177
- /**
1178
- * Return element-wise quotient and remainder simultaneously.
1179
- *
1180
- * Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
1181
- *
1182
- * @param x - Dividend array.
1183
- * @param y - Divisor array.
1184
- * @returns Tuple of [quotient, remainder].
1185
- */
1186
- declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
1187
- /** Round input to the nearest integer towards zero. */
1188
- declare function trunc(x: ArrayLike): Array;
1189
- /**
1190
- * Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
1191
- *
1192
- * This is the inverse of `frexp()`.
1193
- */
1194
- declare function ldexp(x1: ArrayLike, x2: ArrayLike): Array;
1195
- /**
1196
- * Decompose floating-point values into mantissa and two's exponent.
1724
+ * Give a new shape to an array without changing its data.
1197
1725
  *
1198
- * The mantissa is returned in the range `(-1, 1)` with magnitude `>= 0.5` if
1199
- * `x != 0`, and the exponent is an integer such that
1200
- * `x = mantissa * 2**exponent`.
1201
- */
1202
- declare function frexp(x: ArrayLike): [Array, Array];
1203
- /** Calculate `2**p` for all p in the input array. */
1204
- declare function exp2(p: ArrayLike): Array;
1205
- /** Return the base-2 logarithm of x, element-wise. */
1206
- declare function log2(x: ArrayLike): Array;
1207
- /** Return the base-10 logarithm of x, element-wise. */
1208
- declare function log10(x: ArrayLike): Array;
1209
- /** Calculate `exp(x) - 1` element-wise. */
1210
- declare function expm1(x: ArrayLike): Array;
1211
- /** Calculate the natural logarithm of `1 + x` element-wise. */
1212
- declare function log1p(x: ArrayLike): Array;
1213
- /** Convert angles from degrees to radians. */
1214
- declare function deg2rad(x: ArrayLike): Array;
1215
- /** @function Alias of `jax.numpy.deg2rad()`. */
1216
- declare const radians: typeof deg2rad;
1217
- /** Convert angles from radians to degrees. */
1218
- declare function rad2deg(x: ArrayLike): Array;
1219
- /** @function Alias of `jax.numpy.rad2deg()`. */
1220
- declare const degrees: typeof rad2deg;
1221
- /**
1222
- * @function
1223
- * Computes first array raised to power of second array, element-wise.
1726
+ * One shape dimension can be -1. In this case, the value is inferred from the
1727
+ * length of the array and remaining dimensions.
1224
1728
  */
1225
- declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1226
- /** @function Calculate the element-wise cube root of the input array. */
1227
- declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
1729
+ declare const reshape: (x: ArrayLike, shape: number[]) => Array;
1228
1730
  /**
1229
1731
  * @function
1230
- * Calculate element-wise hyperbolic sine of input.
1231
- *
1232
- * `sinh(x) = (exp(x) - exp(-x)) / 2`
1732
+ * Move axes of an array to new positions. Other axes retain original order.
1233
1733
  */
1234
- declare const sinh: OwnedFunction<(x: ArrayLike) => Array>;
1734
+ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
1235
1735
  /**
1236
1736
  * @function
1237
- * Calculate element-wise hyperbolic cosine of input.
1737
+ * Add padding (zeros) to an array.
1238
1738
  *
1239
- * `cosh(x) = (exp(x) + exp(-x)) / 2`
1739
+ * The `width` argument is either an integer or pair of integers, in which case
1740
+ * all axes are padded with the same width. Or if it is an array of pairs, each
1741
+ * pair specifies the padding for its corresponding axis.
1240
1742
  */
1241
- declare const cosh: OwnedFunction<(x: ArrayLike) => Array>;
1743
+ declare const pad: (x: ArrayLike, width: number | Pair | Pair[] | Record<number, Pair>) => Array;
1242
1744
  /**
1243
1745
  * @function
1244
- * Calculate element-wise hyperbolic tangent of input.
1245
- *
1246
- * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
1746
+ * Return the number of dimensions of an array. Does not consume array reference.
1247
1747
  */
1248
- declare const tanh: OwnedFunction<(x: ArrayLike) => Array>;
1748
+ declare const ndim: (x: ArrayLike) => number;
1749
+ /** @function Return the shape of an array. Does not consume array reference. */
1750
+ declare const shape$1: (x: ArrayLike) => number[];
1249
1751
  /**
1250
1752
  * @function
1251
- * Calculate element-wise inverse hyperbolic sine of input.
1252
- *
1253
- * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
1753
+ * Return an array of zeros with the same shape and type as a given array.
1254
1754
  */
1255
- declare const arcsinh: OwnedFunction<(x: ArrayLike) => Array>;
1755
+ declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
1256
1756
  /**
1257
1757
  * @function
1258
- * Calculate element-wise inverse hyperbolic cosine of input.
1259
- *
1260
- * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
1758
+ * Return an array of ones with the same shape and type as a given array.
1261
1759
  */
1262
- declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
1760
+ declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
1263
1761
  /**
1264
1762
  * @function
1265
- * Calculate element-wise inverse hyperbolic tangent of input.
1266
- *
1267
- * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
1268
- */
1269
- declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
1270
- /**
1271
- * Compute the variance of an array.
1272
- *
1273
- * The variance is computed for the flattened array by default, otherwise over
1274
- * the specified axis.
1275
- *
1276
- * If `correction` is provided, the divisor in calculation is `N - correction`,
1277
- * where `N` represents the number of elements (e.g., for Bessel's correction).
1278
- */
1279
- declare function var_(x: ArrayLike, axis?: Axis, opts?: {
1280
- mean?: ArrayLike;
1281
- correction?: number;
1282
- } & ReduceOpts): Array;
1283
- /**
1284
- * Compute the standard deviation of an array.
1285
- *
1286
- * The standard deviation is computed for the flattened array by default,
1287
- * otherwise over the specified axis.
1288
- *
1289
- * If `correction` is provided, the divisor in calculation is `N - correction`,
1290
- * where `N` represents the number of elements (e.g., for Bessel's correction).
1763
+ * Return a full array with the same shape and type as a given array.
1291
1764
  */
1292
- declare function std(x: ArrayLike, axis?: Axis, opts?: {
1293
- mean?: ArrayLike;
1294
- correction?: number;
1295
- } & ReduceOpts): Array;
1296
- /** Estimate the sample covariance of a set of variables. */
1297
- declare function cov(x: ArrayLike, y?: ArrayLike | null, {
1298
- rowvar
1299
- }?: {
1300
- rowvar?: boolean;
1301
- }): Array;
1302
- /** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
1303
- declare function corrcoef(x: ArrayLike, y?: ArrayLike): Array;
1304
- /** Test element-wise for positive or negative infinity, return bool array. */
1305
- declare function isinf(x: ArrayLike): Array;
1306
- /** Test element-wise for NaN (Not a Number). */
1307
- declare function isnan(x: ArrayLike): Array;
1308
- /** Test element-wise for negative infinity, return bool array. */
1309
- declare function isneginf(x: ArrayLike): Array;
1310
- /** Test element-wise for positive infinity, return bool array. */
1311
- declare function isposinf(x: ArrayLike): Array;
1765
+ declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
1312
1766
  /**
1313
- * @function
1314
- * Test element-wise for finite values (not infinity or NaN).
1767
+ * Return the number of elements in an array, optionally along an axis.
1768
+ * Does not consume array reference.
1315
1769
  */
1316
- declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
1317
- declare namespace tree_d_exports {
1318
- export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
1319
- }
1320
- declare enum NodeType {
1321
- Array = "Array",
1322
- Object = "Object",
1323
- Leaf = "Leaf",
1324
- }
1325
- /** Analog to the JAX "pytree" object, but for JavaScript. */
1326
- type JsTree<T> = T | JsTree<T>[] | {
1327
- [key: string]: JsTree<T>;
1328
- };
1329
- type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
1330
- type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
1331
- /** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
1332
- type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
1333
- /** Represents the structure of a JsTree. */
1334
- declare class JsTreeDef {
1335
- readonly nodeType: NodeType;
1336
- readonly nodeMetadata: any;
1337
- readonly childTreedefs: JsTreeDef[];
1338
- static leaf: JsTreeDef;
1339
- constructor(nodeType: NodeType, nodeMetadata: any,
1340
- // Must be comparable with deepEqual.
1341
- childTreedefs: JsTreeDef[]);
1342
- /** Get the total number of leaves in the tree. */
1343
- get size(): number;
1344
- /** Returns a string representation of this tree definition. */
1345
- toString(root?: boolean): string;
1346
- /** Compare this tree definition with another. */
1347
- equals(other: JsTreeDef): boolean;
1348
- }
1349
- /** Flatten a structured object, returning the tree definition. */
1350
- declare function flatten<T>(tree: JsTree<T>): [T[], JsTreeDef];
1351
- /** Get the leaves of a tree. */
1352
- declare function leaves<T>(tree: JsTree<T>): T[];
1353
- /** Get the treedef for a tree. */
1354
- declare function structure<T>(tree: JsTree<T>): JsTreeDef;
1355
- /** Reconstruct a structured object from the flattened representation. */
1356
- declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T>;
1357
- /** Maps a multi-input function over pytree args to produce a new pytree. */
1358
- declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
1359
- /** Take a reference of every array in a tree. */
1360
- declare function ref<Tree extends JsTree<Array>>(tree: Tree): Tree;
1361
- /** Dispose every array in a tree. */
1362
- declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefined): void;
1363
- //#endregion
1364
- //#region src/frontend/convolution.d.ts
1365
- /** Definition of a general dilated convolution. Should be valid on creation. */
1366
- interface ConvParams {
1367
- vmapDims: number;
1368
- strides: number[];
1369
- padding: Pair[];
1370
- lhsDilation: number[];
1371
- rhsDilation: number[];
1372
- }
1770
+ declare function size(a: ArrayLike, axis?: number): number;
1771
+ /** Convert an array to a specified dtype. */
1772
+ declare function astype(a: ArrayLike, dtype: DType): Array;
1773
+ /** Sum of the elements of the array over a given axis, or axes. */
1774
+ declare function sum(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1775
+ /** Product of the array elements over a given axis. */
1776
+ declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1777
+ /** Return the minimum of array elements along a given axis. */
1778
+ declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1779
+ /** Return the maximum of array elements along a given axis. */
1780
+ declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1373
1781
  /**
1374
- * Check that the shapes and parameters passed to convolution are valid.
1375
- * Expected shapes of the lhs and rhs of the convolution are:
1376
- *
1377
- * - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
1378
- * - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
1782
+ * Test whether any array element along a given axis evaluates to True.
1379
1783
  *
1380
- * If the check succeeds, returns the output shape.
1784
+ * Returns a boolean array with the same shape as `a` with the specified axis
1785
+ * removed. If axis is None, returns a scalar.
1381
1786
  */
1382
- //#endregion
1383
- //#region src/frontend/jaxpr.d.ts
1787
+ declare function any(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1384
1788
  /**
1385
- * Function callback with an associated dispose() method.
1789
+ * Test whether all array elements along a given axis evaluate to True.
1386
1790
  *
1387
- * The dispose() method should be called to clean up any tracer resources needed
1388
- * by the function after the last time it is called.
1791
+ * Returns a boolean array with the same shape as `a` with the specified axis
1792
+ * removed. If axis is None, returns a scalar.
1389
1793
  */
1390
- type OwnedFunction<F extends Function> = F & {
1391
- dispose: () => void;
1392
- };
1393
- /** Variable in a Jaxpr expression. */
1394
- declare class Var {
1395
- #private;
1396
- readonly id: number;
1397
- readonly aval: ShapedArray;
1398
- constructor(aval: ShapedArray);
1399
- toString(): string;
1400
- }
1401
- /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
1402
- declare class Lit {
1403
- readonly value: number;
1404
- readonly aval: ShapedArray;
1405
- get dtype(): DType;
1406
- constructor(aval: AbstractValue, value: number);
1407
- }
1408
- type Atom = Var | Lit;
1409
- declare class VarPrinter {
1410
- #private;
1411
- names: Map<Var, string>;
1412
- name(v: Var): string;
1413
- nameType(v: Var): string;
1414
- }
1415
- /** A single statement / binding in a Jaxpr, in ANF form. */
1416
- declare class JaxprEqn {
1417
- readonly primitive: Primitive;
1418
- readonly inputs: Atom[];
1419
- readonly params: Record<string, any>;
1420
- readonly outBinders: Var[];
1421
- constructor(primitive: Primitive, inputs: Atom[], params: Record<string, any>, outBinders: Var[]);
1422
- pprint(usedVars?: Set<Var>, vp?: VarPrinter): PPrint;
1423
- toString(): string;
1424
- }
1425
- /** Typed intermediate representation for traced computations. */
1426
- declare class Jaxpr implements FpHashable {
1427
- #private;
1428
- readonly inBinders: Var[];
1429
- readonly eqns: JaxprEqn[];
1430
- readonly outs: Atom[];
1431
- constructor(inBinders: Var[], eqns: JaxprEqn[], outs: Atom[]);
1432
- pprint(): PPrint;
1433
- toString(): string;
1434
- /**
1435
- * Gets a hash of this Jaxpr.
1436
- *
1437
- * Var identity is not considered in the hash, so two Jaxprs with the same
1438
- * order of assignments and operators but different variable IDs will resolve
1439
- * to the same hash (and toString representation).
1440
- */
1441
- getHash(): bigint;
1442
- hash(state: FpHash): void;
1443
- /**
1444
- * Produce a simplified Jaxpr with basic optimizations applied.
1445
- * - Trim away unused variables.
1446
- * - Fold away *1, *0, or +0 operations against literals.
1447
- * - Remove no-op movement operations.
1448
- */
1449
- simplify(): Jaxpr;
1450
- /** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
1451
- flatten(): Jaxpr;
1452
- }
1453
- /** Jaxpr with a collection of associated, traced constants. */
1454
- declare class ClosedJaxpr {
1455
- readonly jaxpr: Jaxpr;
1456
- readonly consts: Tracer[];
1457
- constructor(jaxpr: Jaxpr, consts: Tracer[]);
1458
- /** String representation of this Jaxpr. */
1459
- toString(): string;
1460
- /** Apply a function to the underlying Jaxpr. */
1461
- mapJaxpr(f: (jaxpr: Jaxpr) => Jaxpr): ClosedJaxpr;
1462
- /** Dispose of the constants in this Jaxpr. */
1463
- dispose(): void;
1464
- }
1465
- /** @inline */
1466
- type JitOpts = {
1467
- staticArgnums?: number[];
1468
- };
1469
- //#endregion
1470
- //#region src/frontend/core.d.ts
1794
+ declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1795
+ /** Return the peak-to-peak range along a given axis (`max - min`). */
1796
+ declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1797
+ /** Compute the average of the array elements along the specified axis. */
1798
+ declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1471
1799
  /**
1472
- * Frontend primitive operations, which are lowered into Kernel objects before
1473
- * being dispatched to the backend.
1474
- *
1475
- * Any operation between arrays can be described in these parts. This is also
1476
- * the set of primitives that can occur in Jaxpr programs, and the level at
1477
- * which transformations like vmap, grad, and jvp occur. They are loosely based
1478
- * on [XLA](https://openxla.org/xla/operation_semantics).
1479
- *
1480
- * All n-ary operations support broadcasting, with NumPy semantics.
1481
- */
1482
- declare enum Primitive {
1483
- Add = "add",
1484
- Mul = "mul",
1485
- Idiv = "idiv",
1486
- Mod = "mod",
1487
- // uses sign of numerator, C-style, matches JS but not Python
1488
- Min = "min",
1489
- Max = "max",
1490
- Neg = "neg",
1491
- Reciprocal = "reciprocal",
1492
- Floor = "floor",
1493
- Ceil = "ceil",
1494
- StopGradient = "stop_gradient",
1495
- Cast = "cast",
1496
- Bitcast = "bitcast",
1497
- Sin = "sin",
1498
- Cos = "cos",
1499
- Asin = "asin",
1500
- Atan = "atan",
1501
- Exp = "exp",
1502
- Log = "log",
1503
- Erf = "erf",
1504
- Erfc = "erfc",
1505
- Sqrt = "sqrt",
1506
- Reduce = "reduce",
1507
- Dot = "dot",
1508
- // sum(x*y, axis=-1)
1509
- Conv = "conv",
1510
- // see lax.conv_general_dilated
1511
- Pool = "pool",
1512
- PoolTranspose = "pool_transpose",
1513
- Compare = "compare",
1514
- Where = "where",
1515
- Concatenate = "concatenate",
1516
- Split = "split",
1517
- RandomBits = "random_bits",
1518
- Gather = "gather",
1519
- Transpose = "transpose",
1520
- Broadcast = "broadcast",
1521
- Reshape = "reshape",
1522
- Flip = "flip",
1523
- Shrink = "shrink",
1524
- Pad = "pad",
1525
- Sort = "sort",
1526
- // sort(x, axis=-1)
1527
- Argsort = "argsort",
1528
- // argsort(x, axis=-1)
1529
- TriangularSolve = "triangular_solve",
1530
- // A is upper triangular, A @ X.T = B.T
1531
- Cholesky = "cholesky",
1532
- // A is positive-definite, A = L @ L^T
1533
- LU = "lu",
1534
- // LU decomposition with partial pivoting
1535
- Jit = "jit",
1536
- }
1537
- interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
1538
- [Primitive.Cast]: {
1539
- dtype: DType;
1540
- };
1541
- [Primitive.Bitcast]: {
1542
- dtype: DType;
1543
- };
1544
- [Primitive.Reduce]: {
1545
- op: AluOp;
1546
- axis: number[];
1547
- };
1548
- [Primitive.Conv]: ConvParams;
1549
- [Primitive.Pool]: {
1550
- window: number[];
1551
- strides: number[];
1552
- };
1553
- [Primitive.PoolTranspose]: {
1554
- inShape: number[];
1555
- window: number[];
1556
- strides: number[];
1557
- };
1558
- [Primitive.Compare]: {
1559
- op: CompareOp;
1560
- };
1561
- [Primitive.Concatenate]: {
1562
- axis: number;
1563
- };
1564
- [Primitive.Split]: {
1565
- axis: number;
1566
- sizes: number[];
1567
- };
1568
- [Primitive.RandomBits]: {
1569
- shape: number[];
1570
- mode: "xor" | 0 | 1;
1571
- };
1572
- [Primitive.Gather]: {
1573
- axis: number[];
1574
- outDim: number;
1575
- };
1576
- [Primitive.Transpose]: {
1577
- perm: number[];
1578
- };
1579
- [Primitive.Broadcast]: {
1580
- shape: number[];
1581
- axis: number[];
1582
- };
1583
- [Primitive.Reshape]: {
1584
- shape: number[];
1585
- };
1586
- [Primitive.Flip]: {
1587
- axis: number[];
1588
- };
1589
- [Primitive.Shrink]: {
1590
- slice: Pair[];
1591
- };
1592
- [Primitive.Pad]: {
1593
- width: Pair[];
1594
- };
1595
- [Primitive.TriangularSolve]: {
1596
- unitDiagonal: boolean;
1597
- };
1598
- [Primitive.Jit]: {
1599
- name: string;
1600
- jaxpr: Jaxpr;
1601
- numConsts: number;
1602
- };
1603
- }
1604
- /** Type of parameters taken by each primitive. */
1605
- type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
1606
- declare enum CompareOp {
1607
- Less = "less",
1608
- Equal = "equal",
1609
- NotEqual = "not_equal",
1610
- LessEqual = "less_equal",
1611
- }
1612
- /** @inline */
1613
- type Axis = number | number[] | null;
1614
- /** @inline */
1615
- type ReduceOpts = {
1616
- keepdims?: boolean;
1617
- };
1618
- type MainTrace = {
1619
- level: number;
1620
- traceType: new (main: MainTrace) => Trace;
1621
- globalData: any | null;
1622
- };
1800
+ * Returns the indices of the minimum values along an axis.
1801
+ *
1802
+ * By default, index is into the flatted array, otherwise it is along the
1803
+ * specified axis.
1804
+ */
1805
+ declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
1623
1806
  /**
1624
- * Push an interpreter onto the trace stack. Use this like:
1625
- * `using main = newMain(...);`
1807
+ * Returns the indices of the maximum values along an axis.
1808
+ *
1809
+ * By default, index is into the flatted array, otherwise it is along the
1810
+ * specified axis.
1626
1811
  */
1627
-
1628
- type TracerValue = Tracer | number | boolean;
1629
- declare abstract class Trace {
1630
- readonly main: MainTrace;
1631
- constructor(main: MainTrace);
1632
- abstract pure(val: TracerValue): Tracer;
1633
- abstract lift(val: Tracer): Tracer;
1634
- abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
1635
- }
1636
- /** Internal representation of an array value. */
1637
- interface AbstractValue {
1638
- /** Shape of the array. Must be a static tuple of non-negative dimensions. */
1639
- shape: number[];
1640
- /** Concrete data type of array elements. */
1641
- dtype: DType;
1642
- /**
1643
- * Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
1644
- * _weakly typed_ unless a dtype is explicitly specified.
1645
- *
1646
- * Weakly typed values will automatically cast to the data type of other
1647
- * arrays when used as an operand as an expression. This property only affects
1648
- * how they promote in type casting; their memory layout is still determined
1649
- * by the actual `dtype` field.
1650
- *
1651
- * ```ts
1652
- * const x = np.array(3); // weakType = true, dtype = float32
1653
- * const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
1654
- * const z = x.add(y); // z has dtype int32 because x is weakly typed
1655
- * ```
1656
- *
1657
- * Weak types are present in JIT programs in their spec (e.g., Jaxpr inputs
1658
- * and outputs can be weakly typed) form. But they're solely a frontend
1659
- * concept. Backends are not aware of weak types.
1660
- */
1661
- weakType: boolean;
1662
- }
1812
+ declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
1663
1813
  /**
1664
- * Broadcast shapes and promote types with casting for two avals.
1814
+ * Cumulative sum of elements along an axis.
1665
1815
  *
1666
- * This implements the weak type behavior described in `promoteTypes()`, but not
1667
- * implemented in that function as `weakType` is not passed.
1816
+ * Currently this function is `O(n^2)`, we'll improve this later on with a
1817
+ * two-phase parallel reduction algorithm.
1668
1818
  */
1669
-
1670
- declare abstract class Tracer {
1671
- /** @ignore */
1672
- readonly _trace: Trace;
1673
- constructor(trace: Trace);
1674
- abstract get aval(): AbstractValue;
1675
- abstract toString(): string;
1676
- /**
1677
- * Access an array by reference, incrementing the reference count.
1678
- *
1679
- * jax-js handles freeing arrays by using "move" semantics, like in Rust/C++.
1680
- * Whenever you pass an array into a function, that function should consume
1681
- * the array, and it will no longer be usable. For example, if you had:
1682
- *
1683
- * ```
1684
- * const x = np.array([1, 2, 3]);
1685
- * const y = np.add(x, x);
1686
- * ```
1687
- *
1688
- * The second line does not work because the first parameter consumes `x`, and
1689
- * then the second parameter will already have been freed / disposed.
1690
- *
1691
- * To fix this, you can write:
1692
- *
1693
- * ```
1694
- * const y = np.add(x.ref, x);
1695
- * ```
1696
- *
1697
- * Under the hood, every access to `.ref` increments the internal reference
1698
- * count of the array. The reference count starts at 1. When it hits 0, the
1699
- * memory behind the array is freed.
1700
- */
1701
- abstract get ref(): this;
1702
- /**
1703
- * Manually decrement the reference count of the array.
1704
- *
1705
- * Arrays are created with reference count 1. Whenever it is used as argument
1706
- * to a function or other operation, it is disposed (i.e., reference count
1707
- * decreases by 1) automatically. Whenever a `.ref` is created, the reference
1708
- * count increases.
1709
- *
1710
- * You generally don't need to call this function directly since arrays are
1711
- * automatically disposed after being passed into an operation. One common
1712
- * exception is when writing a function and ignoring one of its arguments. In
1713
- * that case, by convention you should dispose of that argument manually.
1714
- *
1715
- * ```
1716
- * function myCustomOperation(a: np.Array, b: np.Array) {
1717
- * b.dispose(); // Needed to satisfy "move" rules.
1718
- * return a.add(1);
1719
- * }
1720
- * ```
1721
- */
1722
- abstract dispose(): void;
1723
- /** The shape of the array. */
1724
- get shape(): number[];
1725
- /** The total number of elements in the array. */
1726
- get size(): number;
1727
- /** The dtype of elements stored in the array. */
1728
- get dtype(): DType;
1729
- /**
1730
- * Whether the array is weakly typed.
1731
- *
1732
- * Weakly typed arrays will cast to the dtype of the other operand. See
1733
- * `promoteTypes()` for details.
1734
- */
1735
- get weakType(): boolean;
1736
- /** The number of dimensions of the array. */
1737
- get ndim(): number;
1738
- /** @ignore */
1739
- fullLower(): Tracer;
1740
- neg(): this;
1741
- add(other: this | TracerValue): this;
1742
- mul(other: this | TracerValue): this;
1743
- mod(other: this | TracerValue): this;
1744
- greater(other: this | TracerValue): this;
1745
- less(other: this | TracerValue): this;
1746
- equal(other: this | TracerValue): this;
1747
- notEqual(other: this | TracerValue): this;
1748
- greaterEqual(other: this | TracerValue): this;
1749
- lessEqual(other: this | TracerValue): this;
1750
- /** Sum of the elements of the array over a given axis, or axes. */
1751
- sum(axis?: Axis, opts?: ReduceOpts): this;
1752
- /** Product of the array elements over a given axis. */
1753
- prod(axis?: Axis, opts?: ReduceOpts): this;
1754
- /** Compute the average of the array elements along the specified axis. */
1755
- mean(axis?: Axis, opts?: ReduceOpts): this;
1756
- /** Permute the dimensions of an array. Defaults to reversing the axis order. */
1757
- transpose(perm?: number[]): this;
1758
- /**
1759
- * Give a new shape to an array without changing its data.
1760
- *
1761
- * One shape dimension can be -1. In this case, the value is inferred from the
1762
- * length of the array and remaining dimensions.
1763
- */
1764
- reshape(shape: number | number[]): this;
1765
- /** Copy the array and cast to a specified dtype. */
1766
- astype(dtype: DType): this;
1767
- /** Subtract an array from this one. */
1768
- sub(other: this | TracerValue): this;
1769
- /** Divide an array by this one. */
1770
- div(other: this | TracerValue): this;
1771
- /** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
1772
- diagonal(offset?: number, axis1?: number, axis2?: number): this;
1773
- /** Flatten the array without changing its data. */
1774
- flatten(): this;
1775
- /** Flatten the array without changing its data. */
1776
- ravel(): this;
1777
- /**
1778
- * Iterate over the first dimension of this array, returning slices.
1779
- *
1780
- * This can be used to destructure arrays. For example:
1781
- *
1782
- * ```js
1783
- * let x = np.array([[1, 2], [3, 4]]);
1784
- * let [a, b] = x;
1785
- * console.log(a.js()); // [1, 2]
1786
- * console.log(b.js()); // [3, 4]
1787
- * ```
1788
- */
1789
- [Symbol.iterator](): IterableIterator<this>;
1790
- /**
1791
- * Return a sorted copy of an array in ascending order.
1792
- *
1793
- * See `jax.numpy.sort` for full docs.
1794
- */
1795
- sort(axis?: number): this;
1796
- /**
1797
- * Return the indices that would sort an array. This may not be a stable
1798
- * sorting algorithm; it need not preserve order of indices in ties.
1799
- *
1800
- * See `jax.numpy.argsort` for full docs.
1801
- */
1802
- argsort(axis?: number): this;
1803
- /**
1804
- * Slice an array along one or more axes.
1805
- *
1806
- * This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
1807
- * mimic this in JavaScript, we would write:
1808
- *
1809
- * ```js
1810
- * x.slice([1, 3], 2, [], null);
1811
- * ```
1812
- *
1813
- * The `slice` method accepts a variable number of arguments, each of which
1814
- * can be a number, an empty array, a single-element array, a two-element
1815
- * array, or `null`. The arguments are interpreted as follows:
1816
- *
1817
- * - A number `n` means to access the `n`-th element along that axis, removing
1818
- * that axis from the resulting shape.
1819
- * - An empty array `[]` means to keep that axis as-is, like `:` in Python.
1820
- * - A single-element array `[i]` means to start slicing from index `i`
1821
- * (inclusive) to the end of the axis, like `x[i:]`.
1822
- * - A two-element array `[i, j]` means to slice from index `i` (inclusive)
1823
- * to index `j` (exclusive), like `x[i:j]`.
1824
- * - `null` means to add a new axis at that position, like `np.newaxis`.
1825
- *
1826
- * Like in Python, negative indices are supported, which count from the end of
1827
- * the axis. For example, `-1` means the last element.
1828
- *
1829
- * Strided slices are not yet implemented, so you cannot write `x[::2]` or
1830
- * similar.
1831
- *
1832
- * Advanced indexing by integer arrays is also supported. This translates to
1833
- * the "gather" primitive, and it allows you to access specific elements of
1834
- * the array by integer indices stored in another array.
1835
- */
1836
- slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
1837
- }
1838
- declare class ShapedArray implements AbstractValue {
1839
- readonly shape: number[];
1840
- readonly dtype: DType;
1841
- readonly weakType: boolean;
1842
- constructor(shape: number[], dtype: DType, weakType: boolean);
1843
- static fromAval(aval: AbstractValue): ShapedArray;
1844
- get ndim(): number;
1845
- get size(): number;
1846
- scalar(): ShapedArray;
1847
- toString(): string;
1848
- equals(other: ShapedArray): boolean;
1849
- }
1850
- //#endregion
1851
- //#region src/frontend/array.d.ts
1852
- type ArrayLike = Array | number | boolean;
1853
- /** Version of pureArray with fudged types. */
1854
-
1819
+ declare function cumsum(a: ArrayLike, axis?: number): Array;
1820
+ /** Reverse the elements in an array along the given axes. */
1821
+ declare function flip(x: ArrayLike, axis?: Axis): Array;
1822
+ /**
1823
+ * Split an array into multiple sub-arrays along an axis.
1824
+ *
1825
+ * @param a - The input array to split.
1826
+ * @param indicesOrSections - If an integer, it indicates the number of equal
1827
+ * sections to create along the specified axis. If a list of integers, it
1828
+ * specifies the indices at which to split the array.
1829
+ * @param axis - The axis along which to split the array. Default is 0.
1830
+ */
1831
+ declare function split$1(a: ArrayLike, indicesOrSections: number | number[], axis?: number): Array[];
1832
+ /**
1833
+ * Join a sequence of arrays along an existing axis.
1834
+ *
1835
+ * The arrays must have the same shape, except in the dimension corresponding to
1836
+ * `axis` (the first, by default).
1837
+ *
1838
+ * No scalars can be passed to this function, as the axis is then ambiguous.
1839
+ */
1840
+ declare function concatenate(xs: Array[], axis?: number): Array;
1855
1841
  /**
1856
- * An executable operation that will be dispatched to the backend.
1842
+ * Join a sequence of arrays along a new axis.
1857
1843
  *
1858
- * This holds a reference to all input buffers used in the operation. After the
1859
- * operation is dispatched, the references should be released.
1844
+ * The `axis` parameter specifies the index of the new axis in the dimensions of
1845
+ * the result. For example, if `axis=0` it will be the first dimension and if
1846
+ * `axis=-1` it will be the last dimension.
1847
+ *
1848
+ * All shapes must have the same shape.
1860
1849
  */
1861
- declare class PendingExecute {
1862
- #private;
1863
- readonly backend: Backend;
1864
- readonly source: Kernel | Routine;
1865
- readonly inputs: Slot[];
1866
- readonly outputs: Slot[];
1867
- prepared: Executable | null;
1868
- submitted: boolean;
1869
- constructor(backend: Backend, source: Kernel | Routine, inputs: Slot[], outputs: Slot[]);
1870
- updateRc(delta: number): void;
1871
- prepare(): Promise<void>;
1872
- prepareSync(): void;
1873
- submit(): void;
1874
- }
1875
- /** @inline */
1876
- type DTypeAndDevice = {
1877
- dtype?: DType;
1878
- device?: Device;
1879
- };
1880
- type ArrayConstructorArgs = {
1881
- source: AluExp | Slot;
1882
- st: ShapeTracker;
1883
- dtype: DType;
1884
- weakType: boolean;
1885
- backend: Backend;
1886
- committed: boolean;
1887
- pending?: Iterable<PendingExecute>;
1888
- };
1850
+ declare function stack(xs: ArrayLike[], axis?: number): Array;
1889
1851
  /**
1890
- * A multidimensional numeric array with data stored on CPU or GPU.
1852
+ * Horizontally stack arrays. Inputs are promoted to rank at least 1, then
1853
+ * concatenated along axis 1 (if rank-2 or higher) or 0 (if rank-1).
1854
+ */
1855
+ declare function hstack(xs: ArrayLike[]): Array;
1856
+ /**
1857
+ * Vertically stack arrays. Inputs are promoted to rank at least 2, then
1858
+ * concatenated along axis 0.
1859
+ */
1860
+ declare function vstack(xs: ArrayLike[]): Array;
1861
+ /**
1862
+ * Stack arrays depth-wise. Inputs are promoted to rank at least 3, then
1863
+ * concatenated along axis 2.
1864
+ */
1865
+ declare function dstack(xs: ArrayLike[]): Array;
1866
+ /**
1867
+ * Stack arrays column-wise. Inputs are promoted to rank at least 2, then
1868
+ * concatenated along axis 1.
1869
+ */
1870
+ declare function columnStack(xs: ArrayLike[]): Array;
1871
+ /** Flip an array vertically (axis=0). */
1872
+ declare function flipud(x: ArrayLike): Array;
1873
+ /** Flip an array horizontally (axis=1). */
1874
+ declare function fliplr(x: ArrayLike): Array;
1875
+ /** Interchange two axes of an array. */
1876
+ declare function swapaxes(a: ArrayLike, axis1: number, axis2: number): Array;
1877
+ /** Transpose the last two dimensions of an array. */
1878
+ declare function matrixTranspose(a: ArrayLike): Array;
1879
+ /** Return a 1-D flattened array containing the elements of the input. */
1880
+ declare function ravel(a: ArrayLike): Array;
1881
+ /** Remove one or more length-1 axes from an array. */
1882
+ declare function squeeze(a: ArrayLike, axis?: Axis): Array;
1883
+ /**
1884
+ * Expand the shape of an array by inserting new axes of length 1.
1891
1885
  *
1892
- * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
1893
- * `torch.Tensor`.
1886
+ * @param a - Input array.
1887
+ * @param axis - Position(s) in the expanded axes where the new axis (or axes)
1888
+ * is placed. Can be a single integer or an array of integers.
1889
+ * @returns Array with the number of dimensions increased.
1894
1890
  *
1895
- * Not to be confused with the JavaScript "Array" constructor. Avoid importing
1896
- * this into your code's namespace if you're already using the JavaScript
1897
- * "Array" type by name.
1891
+ * @example
1892
+ * ```ts
1893
+ * const x = np.array([1, 2]);
1894
+ * np.expandDims(x, 0); // Shape [1, 2]
1895
+ * np.expandDims(x, 1); // Shape [2, 1]
1896
+ * np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
1897
+ * ```
1898
1898
  */
1899
- declare class Array extends Tracer {
1900
- #private;
1901
- /**
1902
- * @ignore
1903
- * Constructs an array from source, shape and backend. Note that if the source
1904
- * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
1905
- * will be freed when the array is disposed.
1906
- */
1907
- constructor(args: ArrayConstructorArgs);
1908
- /** @ignore */
1909
- get aval(): ShapedArray;
1910
- /** Return a simple string representation of the array's dimensions. */
1911
- toString(): string;
1912
- get device(): Device;
1913
- get ref(): this;
1914
- /** Get the current reference count (for debugging memory management). */
1915
- get refCount(): number;
1916
- dispose(): void;
1917
- /**
1918
- * Convert this array into a primitive value.
1919
- *
1920
- * This only works for scalars (0-dimensional arrays). It lets you get values
1921
- * "out" of the JAX system. For instance, if `x = np.array(5)`, then you can
1922
- * evaluate `x + 1` and `x ** 2` to get `6` and `25`, respectively.
1923
- *
1924
- * This method is also called for `==` equality.
1925
- */
1926
- [Symbol.toPrimitive](): any;
1927
- /** Realize the array and return it as data. */
1928
- data(): Promise<DataArray>;
1929
- /**
1930
- * Wait for this array to finish evaluation.
1931
- *
1932
- * Operations and data loading in jax-js are lazy, so this function ensures
1933
- * that pending operations are dispatched and fully executed before it
1934
- * returns.
1935
- *
1936
- * If you are mapping from `data()` or `dataSync()`, it will also trigger
1937
- * dispatch of operations as well.
1938
- *
1939
- * **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
1940
- * asynchronously for multiple arrays.
1941
- */
1942
- blockUntilReady(): Promise<Array>;
1943
- /**
1944
- * Realize the array and return it as data. This is a sync variant and not
1945
- * recommended for performance reasons, as it will block rendering.
1946
- */
1947
- dataSync(): DataArray;
1948
- /**
1949
- * Convert this array into a JavaScript object.
1950
- *
1951
- * This is a blocking operation that will compile all of the shaders and wait
1952
- * for execution to complete, synchronously. No other JavaScript code on the
1953
- * site will be run during shader execution.
1954
- *
1955
- * To avoid blocking, prefer `jsAsync()` when possible.
1956
- */
1957
- js(): any;
1958
- /** Convert this array into a JavaScript object, asynchronously. */
1959
- jsAsync(): Promise<any>;
1960
- /**
1961
- * Copy an element of an array to a numeric scalar and return it.
1962
- *
1963
- * Throws an error if the array does not have a single element. The array must
1964
- * either be rank-0, or all dimensions of the shape are 1.
1965
- */
1966
- item(): number;
1967
- /** @private Internal plumbing method for Array / Tracer ops. */
1968
- static _implRules(): typeof implRules;
1969
- /** @private */
1970
- _realizeSource(): number;
1971
- /** @private Put this array on a new backend, asynchronously. */
1972
- _put(backend: Backend): Promise<Array>;
1973
- /** @private Put this array on a new backend, synchronously. */
1974
- _putSync(backend: Backend): Array;
1975
- }
1976
- /** Constructor for creating a new array from data. */
1977
- declare function array(values: Array | DataArray | RecursiveArray<number> | RecursiveArray<boolean>, {
1978
- shape,
1979
- dtype,
1980
- device
1981
- }?: {
1982
- shape?: number[];
1983
- } & DTypeAndDevice): Array;
1984
- /** If x is a value, lift it into an array, otherwise leave it be. */
1985
-
1986
- type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
1987
- declare const implRules: { [P in Primitive]: ImplRule<P> };
1988
- /** Return a new array of given shape and type, filled with zeros. */
1989
- declare function zeros(shape: number[], {
1990
- dtype,
1991
- device
1992
- }?: DTypeAndDevice): Array;
1993
- /** Return a new array of given shape and type, filled with ones. */
1994
- declare function ones(shape: number[], {
1995
- dtype,
1996
- device
1997
- }?: DTypeAndDevice): Array;
1998
- /** Return a new array of given shape and type, filled with `fill_value`. */
1999
- declare function full(shape: number[], fillValue: number | boolean | Array, {
2000
- dtype,
2001
- device
2002
- }?: DTypeAndDevice): Array;
1899
+ declare function expandDims(a: ArrayLike, axis: number | number[]): Array;
1900
+ /**
1901
+ * Repeat each element of an array after themselves.
1902
+ *
1903
+ * If no axis is provided, use the flattened input array, and return a flat
1904
+ * output array.
1905
+ */
1906
+ declare function repeat(a: ArrayLike, repeats: number, axis?: number): Array;
1907
+ /**
1908
+ * Construct an array by repeating A the number of times given by reps.
1909
+ *
1910
+ * If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
1911
+ * integers, the resulting array will have a shape of `(reps[0] * d1,
1912
+ * reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
1913
+ */
1914
+ declare function tile(a: ArrayLike, reps: number | number[]): Array;
1915
+ /**
1916
+ * Broadcast an array to a shape, with NumPy-style broadcasing rules.
1917
+ *
1918
+ * In other words, this lets you append axes to the left, and/or expand
1919
+ * dimensions where the shape is 1.
1920
+ */
1921
+ declare function broadcastTo(a: ArrayLike, shape: number[]): Array;
1922
+ /** Broadcast input shapes to a common output shape. */
1923
+ declare function broadcastShapes(...shapes: number[][]): number[];
1924
+ /** Broadcast arrays to a common shape. */
1925
+ declare function broadcastArrays(...arrays: ArrayLike[]): Array[];
1926
+ /**
1927
+ * Return specified diagonals.
1928
+ *
1929
+ * If a is 2D, return the diagonal of the array with the given offset. If a is
1930
+ * 3D or higher, compute diagonals along the two given axes (default: 0, 1).
1931
+ *
1932
+ * This returns a view over the existing array. The shape of the resulting array
1933
+ * is determined by removing the two axes along which the diagonal is taken,
1934
+ * then appending a new axis to the right with holding the diagonals.
1935
+ */
1936
+ declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
1937
+ /**
1938
+ * Extract a diagonal or construct a diagonal array.
1939
+ *
1940
+ * If v is a 2D array, return the k-th diagonal of v (as a view). If v is a 1D
1941
+ * array, return a 2D array with v on the k-th diagonal.
1942
+ */
1943
+ declare function diag(v: ArrayLike, k?: number): Array;
1944
+ /** Calculate the sum of the diagonal of an array along the given axes. */
1945
+ declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
1946
+ /**
1947
+ * Return a sorted copy of an array.
1948
+ *
1949
+ * The array is sorted along a specified axis (the last by default). This may be
1950
+ * an unstable sort, and it dispatches to device-specific implementation.
1951
+ */
1952
+ declare function sort(a: ArrayLike, axis?: number): Array;
1953
+ /**
1954
+ * Return indices that would sort an array. This may be an unstable sorting
1955
+ * algorithm; it need not preserve order of indices in ties.
1956
+ *
1957
+ * Returns an array of `int32` indices.
1958
+ *
1959
+ * The array is sorted along a specified axis (the last by default).
1960
+ */
1961
+ declare function argsort(a: ArrayLike, axis?: number): Array;
2003
1962
  /**
2004
- * Create an identity matrix.
1963
+ * Take elements from an array along an axis.
2005
1964
  *
2006
- * If numCols is not provided, it defaults to numRows, i.e., a square identity
2007
- * matrix with ones on the diagonal.
1965
+ * This is equivalent to advanced indexing with integer indices over that
1966
+ * numbered axis. By default, the flattened array is used.
2008
1967
  */
2009
- declare function eye(numRows: number, numCols?: number, {
2010
- dtype,
2011
- device
2012
- }?: DTypeAndDevice): Array;
2013
- /** Return the identity matrix, with ones on the main diagonal. */
2014
- declare function identity$1(n: number, {
2015
- dtype,
2016
- device
2017
- }?: DTypeAndDevice): Array;
1968
+ declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
1969
+ /** Return if two arrays are element-wise equal within a tolerance. */
1970
+ declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
1971
+ rtol?: number;
1972
+ atol?: number;
1973
+ }): boolean;
1974
+ /** Matrix product of two arrays. */
1975
+ declare function matmul(x: ArrayLike, y: ArrayLike): Array;
1976
+ /** Dot product of two arrays. */
1977
+ declare function dot(x: ArrayLike, y: ArrayLike): Array;
2018
1978
  /**
2019
- * Return evenly spaced values within a given interval.
1979
+ * Compute the tensor dot product of two N-dimensional arrays.
2020
1980
  *
2021
- * This can be called with a varying number of arguments, just like the range()
2022
- * builtin function in Python.
1981
+ * The behavior is determined by `axes`. If an integer `k`, sum over the last
1982
+ * `k` axes of x and the first `k` axes of y. If a tuple, then the first array
1983
+ * corresponds to the axes of x and the second to the axes of y.
1984
+ */
1985
+ declare function tensordot(x: ArrayLike, y: ArrayLike, axes?: number | [number[], number[]]): Array;
1986
+ /**
1987
+ * Einstein summation with string subscripts.
2023
1988
  *
2024
- * - `arange(stop)` is equivalent to `arange(0, stop, 1)`.
2025
- * - `arange(start, stop)` is equivalent to `arange(start, stop, 1)`.
2026
- * - `arange(start, stop, step)` creates an array starting at `start`, ending
2027
- * before `stop`, with a step size of `step`.
1989
+ * @example
1990
+ * ```ts
1991
+ * import { numpy as np } from "@jax-js/jax";
2028
1992
  *
2029
- * Defaults to an integer data type. This can produce unintended results when
2030
- * using a non-integer step, so prefer linspace() in those cases.
1993
+ * const a = np.ones([2, 3]);
1994
+ * const b = np.ones([3]);
1995
+ * np.einsum("ij,j", a, b); // Shape [2]
1996
+ * ```
2031
1997
  */
2032
- declare function arange(start: number, stop?: number, step?: number, {
2033
- dtype,
2034
- device
2035
- }?: DTypeAndDevice): Array;
1998
+ declare function einsum(subscripts: string, ...operands: ArrayLike[]): Array;
2036
1999
  /**
2037
- * Return an array with ones on and below the diagonal and zeros elsewhere.
2000
+ * Einstein summation alternating between arrays and numeric indices.
2038
2001
  *
2039
- * If `k` is provided, it specifies the sub-diagonal on and below which the
2040
- * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
2041
- * `k>0` is above it.
2002
+ * @example
2003
+ * ```ts
2004
+ * import { numpy as np } from "@jax-js/jax";
2005
+ *
2006
+ * const a = np.ones([2, 3]);
2007
+ * const b = np.ones([3]);
2008
+ * np.einsum(a, [0, 1], b, [1]); // Shape [2]
2009
+ * ```
2042
2010
  */
2043
- declare function tri(n: number, m?: number, k?: number, {
2044
- dtype,
2045
- device
2046
- }?: DTypeAndDevice): Array;
2047
- /** Return the lower triangle of an array. Must be of dimension >= 2. */
2048
- declare function tril(a: ArrayLike, k?: number): Array;
2049
- /** Return the upper triangle of an array. Must be of dimension >= 2. */
2050
- declare function triu(a: ArrayLike, k?: number): Array;
2011
+ declare function einsum(...args: (ArrayLike | number[])[]): Array;
2051
2012
  /**
2052
- * Return evenly spaced numbers over a specified interval.
2013
+ * Compute the inner product of two arrays.
2053
2014
  *
2054
- * Returns _num_ evenly spaced samples, calculated over the interval
2055
- * [`start`, `stop`]. The endpoint `stop` is included in the result by default,
2056
- * but this is controlled by the `endpoint` parameter.
2015
+ * Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
2016
+ * contraction on the last axis.
2057
2017
  *
2058
- * The default data type is Float32. Use arange() for integer steps.
2018
+ * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
2059
2019
  */
2060
- declare function linspace(start: number, stop: number, num?: number, endpoint?: boolean, {
2061
- dtype,
2062
- device
2063
- }?: DTypeAndDevice): Array;
2020
+ declare function inner(x: ArrayLike, y: ArrayLike): Array;
2064
2021
  /**
2065
- * Return numbers spaced evenly on a log scale.
2022
+ * Compute the outer product of two arrays.
2066
2023
  *
2067
- * In linear space, the sequence starts at `base ** start` and ends at
2068
- * `base ** stop` (see `endpoint` below).
2024
+ * If the input arrays are not 1D, they will be flattened. Returned array will
2025
+ * be of shape `[x.size, y.size]`.
2026
+ */
2027
+ declare function outer(x: ArrayLike, y: ArrayLike): Array;
2028
+ /** Vector dot product of two arrays along a given axis. */
2029
+ declare function vecdot(x: ArrayLike, y: ArrayLike, {
2030
+ axis
2031
+ }?: {
2032
+ axis?: number;
2033
+ }): Array;
2034
+ /**
2035
+ * Return the dot product of two vectors.
2069
2036
  *
2070
- * @param start - `base ** start` is the starting value of the sequence.
2071
- * @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
2072
- * @param num - Number of samples to generate. Default is 50.
2073
- * @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
2074
- * @param base - The base of the log space. Default is 10.
2075
- * @returns Array of evenly spaced values on a log scale.
2037
+ * Like vecdot() but flattens the arguments first into vectors.
2076
2038
  */
2077
- declare function logspace(start: number, stop: number, num?: number, endpoint?: boolean, base?: number, {
2078
- dtype,
2079
- device
2080
- }?: DTypeAndDevice): Array;
2081
- //#endregion
2082
- //#region src/frontend/linearize.d.ts
2083
- /** @inline */
2084
- type GradOpts = {
2085
- /**
2086
- * Integer or sequence of integers. Specifies which positional argument(s) to
2087
- * differentiate with respect to.
2088
- *
2089
- * Defaults to `0` (the first argument).
2090
- */
2091
- argnums?: number | number[];
2092
- /**
2093
- * The input function returns a pair of `[out, aux]` including an auxiliary
2094
- * value. This `aux` is not differentiated, but is returned alongside the
2095
- * gradient when evaluating the function.
2096
- */
2097
- hasAux?: boolean;
2098
- };
2099
- declare namespace lax_linalg_d_exports {
2100
- export { cholesky, lu, triangularSolve };
2101
- }
2039
+ declare function vdot(x: ArrayLike, y: ArrayLike): Array;
2040
+ /** Convolution of two one-dimensional arrays. */
2041
+ declare function convolve(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
2042
+ /** Correlation of two one dimensional arrays. */
2043
+ declare function correlate(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
2102
2044
  /**
2103
- * Compute the Cholesky decomposition of a symmetric positive-definite matrix.
2045
+ * Return a tuple of coordinate matrices from coordinate vectors.
2104
2046
  *
2105
- * The Cholesky decomposition of a matrix `A` is:
2047
+ * Make N-D coordinate arrays for vectorized evaluations of N-D scalar/vector
2048
+ * fields over N-D grids, given one-dimensional coordinate arrays x1, x2,…, xn.
2049
+ */
2050
+ declare function meshgrid(xs: Array[], {
2051
+ indexing
2052
+ }?: {
2053
+ indexing?: "xy" | "ij";
2054
+ }): Array[];
2055
+ /**
2056
+ * Clip (limit) the values in an array.
2106
2057
  *
2107
- * - A = L @ L^T (for upper=false, default)
2108
- * - A = U^T @ U (for upper=true)
2058
+ * Given an interval, values outside the interval are clipped to the interval
2059
+ * edges. For example, if an interval of [0, 1] is specified, values smaller
2060
+ * than 0 become 0, and values larger than 1 become 1.
2109
2061
  *
2110
- * where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
2111
- * The input matrix must be symmetric and positive-definite.
2062
+ * If either bound is undefined, it is ignored.
2063
+ */
2064
+ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
2065
+ /**
2066
+ * Calculate the absolute value element-wise.
2112
2067
  *
2113
- * @example
2114
- * ```ts
2115
- * import { lax, numpy as np } from "@jax-js/jax";
2068
+ * This is the same function as `jax.numpy.abs()`.
2069
+ */
2070
+ declare function absolute(x: ArrayLike): Array;
2071
+ /** Return an element-wise indication of sign of the input. */
2072
+ declare function sign(x: ArrayLike): Array;
2073
+ /** @function Return element-wise positive values of the input (no-op). */
2074
+ declare const positive: (x: ArrayLike) => Array;
2075
+ /**
2076
+ * Return the Hamming window of size M, a taper with a weighted cosine bell.
2116
2077
  *
2117
- * const x = np.array([[2., 1.], [1., 2.]]);
2078
+ * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
2079
+ */
2080
+ declare function hamming(M: number): Array;
2081
+ /**
2082
+ * Return the Hann window of size M, a taper with a weighted cosine bell.
2118
2083
  *
2119
- * // Lower Cholesky factorization (default):
2120
- * const L = lax.linalg.cholesky(x);
2121
- * // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
2084
+ * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
2085
+ */
2086
+ declare function hann(M: number): Array;
2087
+ /**
2088
+ * @function
2089
+ * Compute the Heaviside step function. It is defined piecewise:
2090
+ * - `heaviside(x1, x2) = 0` for `x1 < 0`,
2091
+ * - `heaviside(x1, x2) = x2` for `x1 == 0`,
2092
+ * - `heaviside(x1, x2) = 1` for `x1 > 0`.
2093
+ */
2094
+ declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
2095
+ /** Calculate element-wise square of the input array. */
2096
+ declare function square(x: ArrayLike): Array;
2097
+ /** Element-wise tangent function (takes radians). */
2098
+ declare function tan(x: ArrayLike): Array;
2099
+ /**
2100
+ * @function
2101
+ * Return the normalized sinc function.
2122
2102
  *
2123
- * // Upper Cholesky factorization:
2124
- * const U = lax.linalg.cholesky(x, { upper: true });
2125
- * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
2126
- * ```
2103
+ * The sinc function is defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`.
2104
+ * This is the normalized sinc function commonly used in signal processing.
2105
+ *
2106
+ * **Note:** JVP is not supported at x=0 due to discontinuous derivative. This
2107
+ * requires a custom JVP rule to handle properly (see JAX implementation).
2127
2108
  */
2128
- declare function cholesky(a: ArrayLike, {
2129
- upper
2130
- }?: {
2131
- upper?: boolean;
2132
- }): Array;
2109
+ declare const sinc: OwnedFunction<(x: ArrayLike) => Array>;
2110
+ /** Element-wise inverse cosine function (inverse of cos). */
2111
+ declare function acos(x: ArrayLike): Array;
2133
2112
  /**
2134
- * LU decomposition with partial pivoting.
2113
+ * @function
2114
+ * Return element-wise hypotenuse for the given legs of a right triangle.
2115
+ *
2116
+ * In the original NumPy/JAX implementation, this function is more numerically
2117
+ * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
2118
+ * stability improvements.
2119
+ */
2120
+ declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
2121
+ /**
2122
+ * @function
2123
+ * Element-wise arc tangent of y/x with correct quadrant.
2124
+ *
2125
+ * Returns the angle in radians between the positive x-axis and the point (x, y).
2126
+ * The result is in the range [-π, π].
2127
+ *
2128
+ * Uses numerically stable formulas:
2129
+ * - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
2130
+ * - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
2131
+ *
2132
+ * The output is ill-defined when both x and y are zero.
2133
+ */
2134
+ declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
2135
+ /** Element-wise subtraction, with broadcasting. */
2136
+ declare function subtract(x: ArrayLike, y: ArrayLike): Array;
2137
+ /** Calculates the floating-point division of x by y element-wise. */
2138
+ declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
2139
+ /**
2140
+ * Return the largest integer smaller or equal to the division of the inputs.
2135
2141
  *
2136
- * Computes the matrix decomposition: `P @ A = L @ U`, where `P` is a
2137
- * permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
2138
- * and `U` is upper-triangular.
2142
+ * The result is always rounded towards negative infinity.
2139
2143
  *
2140
- * @param x - A batch of matrices with shape `[..., m, n]`.
2144
+ * For floating-point inputs, this is equivalent to `floor(x / y)`.
2145
+ * For integer inputs, we use `(x - remainder(x, y)) / y` to handle
2146
+ * negative values correctly (note: may overflow near int32 boundaries).
2141
2147
  *
2142
- * @returns A tuple `(lu, pivots, permutation)` where:
2143
- * - `lu`: combined lower and upper triangular matrices.
2144
- * - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
2145
- * - `permutation`: the permutation generated by pivots with shape `[..., m]`.
2148
+ * @param x - Dividend array.
2149
+ * @param y - Divisor array.
2150
+ * @returns Element-wise floor division of x by y.
2151
+ */
2152
+ declare function floorDivide(x: ArrayLike, y: ArrayLike): Array;
2153
+ /**
2154
+ * @function
2155
+ * Calculate element-wise floating-point modulo operation.
2156
+ */
2157
+ declare const fmod: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
2158
+ /**
2159
+ * @function
2160
+ * Calculate element-wise remainder of the division (matches sign of y).
2161
+ */
2162
+ declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
2163
+ /**
2164
+ * Return element-wise quotient and remainder simultaneously.
2146
2165
  *
2147
- * @example
2148
- * ```ts
2149
- * import { lax, numpy as np } from "@jax-js/jax";
2166
+ * Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
2150
2167
  *
2151
- * const A = np.array([[4., 3.], [6., 3.]]);
2152
- * const [lu, pivots, permutation] = lax.linalg.lu(A);
2153
- * // lu [[6., 3.], [0.6666667, 1.0]]
2154
- * // pivots = [1, 1]
2155
- * // permutation = [1, 0]
2156
- * ```
2168
+ * @param x - Dividend array.
2169
+ * @param y - Divisor array.
2170
+ * @returns Tuple of [quotient, remainder].
2157
2171
  */
2158
- declare function lu(x: ArrayLike): [Array, Array, Array];
2172
+ declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
2173
+ /** Round input to the nearest integer towards zero. */
2174
+ declare function trunc(x: ArrayLike): Array;
2159
2175
  /**
2160
- * Solve a triangular linear system.
2161
- *
2162
- * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
2163
- * where `a` is a triangular matrix.
2176
+ * Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
2164
2177
  *
2165
- * @example
2166
- * ```ts
2167
- * import { lax, numpy as np } from "@jax-js/jax";
2178
+ * This is the inverse of `frexp()`.
2179
+ */
2180
+ declare function ldexp(x1: ArrayLike, x2: ArrayLike): Array;
2181
+ /**
2182
+ * Decompose floating-point values into mantissa and two's exponent.
2168
2183
  *
2169
- * const L = np.array([[2., 0.], [1., 3.]]);
2170
- * const b = np.array([4., 7.]).reshape([2, 1]);
2184
+ * The mantissa is returned in the range `(-1, 1)` with magnitude `>= 0.5` if
2185
+ * `x != 0`, and the exponent is an integer such that
2186
+ * `x = mantissa * 2**exponent`.
2187
+ */
2188
+ declare function frexp(x: ArrayLike): [Array, Array];
2189
+ /** Calculate `2**p` for all p in the input array. */
2190
+ declare function exp2(p: ArrayLike): Array;
2191
+ /** Return the base-2 logarithm of x, element-wise. */
2192
+ declare function log2(x: ArrayLike): Array;
2193
+ /** Return the base-10 logarithm of x, element-wise. */
2194
+ declare function log10(x: ArrayLike): Array;
2195
+ /** Calculate `exp(x) - 1` element-wise. */
2196
+ declare function expm1(x: ArrayLike): Array;
2197
+ /** Calculate the natural logarithm of `1 + x` element-wise. */
2198
+ declare function log1p(x: ArrayLike): Array;
2199
+ /** Convert angles from degrees to radians. */
2200
+ declare function deg2rad(x: ArrayLike): Array;
2201
+ /** @function Alias of `jax.numpy.deg2rad()`. */
2202
+ declare const radians: typeof deg2rad;
2203
+ /** Convert angles from radians to degrees. */
2204
+ declare function rad2deg(x: ArrayLike): Array;
2205
+ /** @function Alias of `jax.numpy.rad2deg()`. */
2206
+ declare const degrees: typeof rad2deg;
2207
+ /**
2208
+ * @function
2209
+ * Computes first array raised to power of second array, element-wise.
2210
+ */
2211
+ declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
2212
+ /** @function Calculate the element-wise cube root of the input array. */
2213
+ declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
2214
+ /**
2215
+ * @function
2216
+ * Calculate element-wise hyperbolic sine of input.
2171
2217
  *
2172
- * // Solve L @ x = b
2173
- * const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
2174
- * // x = [[2.], [5./3.]]
2175
- * ```
2218
+ * `sinh(x) = (exp(x) - exp(-x)) / 2`
2176
2219
  */
2177
- declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
2178
- leftSide,
2179
- lower,
2180
- transposeA,
2181
- unitDiagonal
2182
- }?: {
2183
- leftSide?: boolean;
2184
- lower?: boolean;
2185
- transposeA?: boolean;
2186
- unitDiagonal?: boolean;
2187
- }): Array;
2188
- declare namespace lax_d_exports {
2189
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
2190
- }
2220
+ declare const sinh: OwnedFunction<(x: ArrayLike) => Array>;
2191
2221
  /**
2192
- * Dimension numbers for general `dot()` primitive.
2222
+ * @function
2223
+ * Calculate element-wise hyperbolic cosine of input.
2193
2224
  *
2194
- * Contracting dimensions act as a tensor contraction (reduction) along the
2195
- * given axis. They must be the same size in both operands. Batch dimensions
2196
- * are treated as vectorized, leading batch dimensions.
2225
+ * `cosh(x) = (exp(x) + exp(-x)) / 2`
2226
+ */
2227
+ declare const cosh: OwnedFunction<(x: ArrayLike) => Array>;
2228
+ /**
2229
+ * @function
2230
+ * Calculate element-wise hyperbolic tangent of input.
2197
2231
  *
2198
- * The return value has a shape where the first dimensions are shared batch
2199
- * dimensions, followed by `lhs` non-contracting dimensions, followed by
2200
- * `rhs` non-contracting dimensions.
2232
+ * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
2201
2233
  */
2202
- type DotDimensionNumbers = {
2203
- lhsContractingDims?: number[];
2204
- rhsContractingDims?: number[];
2205
- lhsBatchDims?: number[];
2206
- rhsBatchDims?: number[];
2207
- };
2234
+ declare const tanh: OwnedFunction<(x: ArrayLike) => Array>;
2208
2235
  /**
2209
- * General dot product/contraction operator.
2236
+ * @function
2237
+ * Calculate element-wise inverse hyperbolic sine of input.
2210
2238
  *
2211
- * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
2212
- * `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
2239
+ * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
2213
2240
  */
2214
- declare function dot(lhs: Array, rhs: Array, {
2215
- lhsContractingDims: lc,
2216
- rhsContractingDims: rc,
2217
- lhsBatchDims: lb,
2218
- rhsBatchDims: rb
2219
- }?: DotDimensionNumbers): Array;
2220
- type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
2241
+ declare const arcsinh: OwnedFunction<(x: ArrayLike) => Array>;
2221
2242
  /**
2222
- * General n-dimensional convolution operator, with optional dilation.
2243
+ * @function
2244
+ * Calculate element-wise inverse hyperbolic cosine of input.
2223
2245
  *
2224
- * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
2225
- * function in JAX, which wraps XLA's general convolution operator.
2246
+ * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
2247
+ */
2248
+ declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
2249
+ /**
2250
+ * @function
2251
+ * Calculate element-wise inverse hyperbolic tangent of input.
2226
2252
  *
2227
- * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
2228
- * @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
2229
- * @param windowStrides - Strides for each spatial dimension
2230
- * @param padding - Padding for each spatial dimension, or a string
2231
- * (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
2253
+ * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
2232
2254
  */
2233
- declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
2234
- lhsDilation,
2235
- rhsDilation,
2236
- featureGroupCount
2237
- }?: {
2238
- lhsDilation?: number[];
2239
- rhsDilation?: number[];
2240
- featureGroupCount?: number;
2241
- }): Array;
2242
- /** Convenience wrapper around `convGeneralDilated`. */
2243
- declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
2244
- /** Convenience wrapper around `convGeneralDilated`. */
2245
- declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
2255
+ declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
2246
2256
  /**
2247
- * Convenience wrapper for calculating the N-d convolution "transpose".
2257
+ * Compute the variance of an array.
2248
2258
  *
2249
- * This function directly calculates a fractionally strided conv rather than
2250
- * indirectly calculating the gradient (transpose) of a forward convolution.
2251
- * It is equivalent to the JAX version, except:
2259
+ * The variance is computed for the flattened array by default, otherwise over
2260
+ * the specified axis.
2252
2261
  *
2253
- * - The `use_consistent_padding` option is not available. We only have the
2254
- * consistent padding case (JAX version >0.8.4).
2255
- * - The order of dimensions matches `lax.conv_general_dilated`.
2262
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
2263
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
2264
+ */
2265
+ declare function var_(x: ArrayLike, axis?: Axis, opts?: {
2266
+ mean?: ArrayLike;
2267
+ correction?: number;
2268
+ } & ReduceOpts): Array;
2269
+ /**
2270
+ * Compute the standard deviation of an array.
2256
2271
  *
2257
- * Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
2258
- * dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
2259
- * `transposeKernel` to true.
2272
+ * The standard deviation is computed for the flattened array by default,
2273
+ * otherwise over the specified axis.
2260
2274
  *
2261
- * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
2262
- * @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
2263
- * @param strides - Sequence of n integers, sets fractional stride
2264
- * @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
2265
- * each side of the input, so it acts like gradient of `conv()`
2266
- * @param rhsDilation - Atrous dilation for the kernel
2267
- * @param transposeKernel - Flip spatial axes and swap the input/output channels
2268
- * of the kernel; its shape should be `[C_in, C_out, ...ks]`
2275
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
2276
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
2269
2277
  */
2270
- declare function convTranspose(lhs: Array, rhs: Array, strides: number[], padding: PaddingType, {
2271
- rhsDilation,
2272
- transposeKernel
2278
+ declare function std(x: ArrayLike, axis?: Axis, opts?: {
2279
+ mean?: ArrayLike;
2280
+ correction?: number;
2281
+ } & ReduceOpts): Array;
2282
+ /** Estimate the sample covariance of a set of variables. */
2283
+ declare function cov(x: ArrayLike, y?: ArrayLike | null, {
2284
+ rowvar
2273
2285
  }?: {
2274
- rhsDilation?: number[];
2275
- transposeKernel?: boolean;
2286
+ rowvar?: boolean;
2276
2287
  }): Array;
2277
- /** Reduce a computation over padded windows. */
2278
- declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
2279
- /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
2280
- declare function erf(x: ArrayLike): Array;
2288
+ /** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
2289
+ declare function corrcoef(x: ArrayLike, y?: ArrayLike): Array;
2290
+ /** Test element-wise for positive or negative infinity, return bool array. */
2291
+ declare function isinf(x: ArrayLike): Array;
2292
+ /** Test element-wise for NaN (Not a Number). */
2293
+ declare function isnan(x: ArrayLike): Array;
2294
+ /** Test element-wise for negative infinity, return bool array. */
2295
+ declare function isneginf(x: ArrayLike): Array;
2296
+ /** Test element-wise for positive infinity, return bool array. */
2297
+ declare function isposinf(x: ArrayLike): Array;
2281
2298
  /**
2282
- * The complementary error function: `erfc(x) = 1 - erf(x)`.
2299
+ * Replace NaN and infinite entries in an array.
2283
2300
  *
2284
- * This function is more accurate than `1 - erf(x)` for large values of `x`,
2285
- * where `erf(x)` is very close to 1.
2301
+ * By default, NaNs are replaced with `0.0`, and infinities are are substituted
2302
+ * with the corresponding maximum or minimum finite values.
2286
2303
  */
2287
- declare function erfc(x: ArrayLike): Array;
2304
+ declare function nanToNum(x: ArrayLike, {
2305
+ nan,
2306
+ posinf,
2307
+ neginf
2308
+ }?: {
2309
+ nan?: ArrayLike;
2310
+ posinf?: ArrayLike | null;
2311
+ neginf?: ArrayLike | null;
2312
+ }): Array;
2288
2313
  /**
2289
- * Stops gradient computation.
2290
- *
2291
- * Behaves as the identity function but prevents the flow of gradients during
2292
- * forward or reverse-mode automatic differentiation.
2314
+ * @function
2315
+ * Test element-wise for finite values (not infinity or NaN).
2293
2316
  */
2294
- declare function stopGradient(x: ArrayLike): Array;
2317
+ declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
2295
2318
  declare namespace nn_d_exports {
2296
2319
  export { celu, dotProductAttention, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
2297
2320
  }