@jax-js/jax 0.1.5 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +60 -7
- package/dist/{backend-DziQSaoQ.cjs → backend-B3foXiV_.cjs} +25 -6
- package/dist/{backend-DaqL-MNz.js → backend-nEolvdLv.js} +20 -7
- package/dist/index.cjs +450 -129
- package/dist/index.d.cts +1669 -1467
- package/dist/index.d.ts +1669 -1467
- package/dist/index.js +450 -130
- package/dist/{webgl-ClIYb8jP.cjs → webgl-DIIbKJ0G.cjs} +1 -1
- package/dist/{webgl-RSuZKvgc.js → webgl-DweKSWEm.js} +1 -1
- package/dist/{webgpu-Dh7k9io0.js → webgpu-B96vzWGE.js} +1 -1
- package/dist/{webgpu-Db2JrNBr.cjs → webgpu-BykvF26B.cjs} +1 -1
- package/package.json +1 -1
package/dist/index.d.cts
CHANGED
|
@@ -541,1704 +541,1782 @@ declare class Executable<T = any> {
|
|
|
541
541
|
source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
|
|
542
542
|
data: T);
|
|
543
543
|
}
|
|
544
|
-
declare namespace
|
|
545
|
-
export {
|
|
544
|
+
declare namespace tree_d_exports {
|
|
545
|
+
export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
|
|
546
546
|
}
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
547
|
+
declare enum NodeType {
|
|
548
|
+
Array = "Array",
|
|
549
|
+
Object = "Object",
|
|
550
|
+
Leaf = "Leaf",
|
|
551
|
+
}
|
|
552
|
+
/** Analog to the JAX "pytree" object, but for JavaScript. */
|
|
553
|
+
type JsTree<T> = T | JsTree<T>[] | {
|
|
554
|
+
[key: string]: JsTree<T>;
|
|
554
555
|
};
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
*/
|
|
560
|
-
declare
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
556
|
+
type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
|
|
557
|
+
type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
558
|
+
/** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
|
|
559
|
+
type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
|
|
560
|
+
/** Represents the structure of a JsTree. */
|
|
561
|
+
declare class JsTreeDef {
|
|
562
|
+
readonly nodeType: NodeType;
|
|
563
|
+
readonly nodeMetadata: any;
|
|
564
|
+
readonly childTreedefs: JsTreeDef[];
|
|
565
|
+
static leaf: JsTreeDef;
|
|
566
|
+
constructor(nodeType: NodeType, nodeMetadata: any,
|
|
567
|
+
// Must be comparable with deepEqual.
|
|
568
|
+
childTreedefs: JsTreeDef[]);
|
|
569
|
+
/** Get the total number of leaves in the tree. */
|
|
570
|
+
get size(): number;
|
|
571
|
+
/** Returns a string representation of this tree definition. */
|
|
572
|
+
toString(root?: boolean): string;
|
|
573
|
+
/** Compare this tree definition with another. */
|
|
574
|
+
equals(other: JsTreeDef): boolean;
|
|
575
|
+
}
|
|
576
|
+
/** Flatten a structured object, returning the tree definition. */
|
|
577
|
+
declare function flatten<T>(tree: JsTree<T>): [T[], JsTreeDef];
|
|
578
|
+
/** Get the leaves of a tree. */
|
|
579
|
+
declare function leaves<T>(tree: JsTree<T>): T[];
|
|
580
|
+
/** Get the treedef for a tree. */
|
|
581
|
+
declare function structure<T>(tree: JsTree<T>): JsTreeDef;
|
|
582
|
+
/** Reconstruct a structured object from the flattened representation. */
|
|
583
|
+
declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T>;
|
|
584
|
+
/** Maps a multi-input function over pytree args to produce a new pytree. */
|
|
585
|
+
declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
|
|
586
|
+
/** Take a reference of every array in a tree. */
|
|
587
|
+
declare function ref<Tree extends JsTree<any>>(tree: Tree): Tree;
|
|
588
|
+
/** Dispose every array in a tree. */
|
|
589
|
+
declare function dispose<Tree extends JsTree<any>>(tree: Tree | null | undefined): void;
|
|
590
|
+
//#endregion
|
|
591
|
+
//#region src/frontend/convolution.d.ts
|
|
592
|
+
/** Definition of a general dilated convolution. Should be valid on creation. */
|
|
593
|
+
interface ConvParams {
|
|
594
|
+
vmapDims: number;
|
|
595
|
+
strides: number[];
|
|
596
|
+
padding: Pair[];
|
|
597
|
+
lhsDilation: number[];
|
|
598
|
+
rhsDilation: number[];
|
|
569
599
|
}
|
|
570
600
|
/**
|
|
571
|
-
*
|
|
601
|
+
* Check that the shapes and parameters passed to convolution are valid.
|
|
602
|
+
* Expected shapes of the lhs and rhs of the convolution are:
|
|
572
603
|
*
|
|
573
|
-
*
|
|
574
|
-
*
|
|
604
|
+
* - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
|
|
605
|
+
* - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
|
|
606
|
+
*
|
|
607
|
+
* If the check succeeds, returns the output shape.
|
|
575
608
|
*/
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
symmetrizeInput
|
|
579
|
-
}?: {
|
|
580
|
-
upper?: boolean;
|
|
581
|
-
symmetrizeInput?: boolean;
|
|
582
|
-
}): Array;
|
|
583
|
-
/** Compute the determinant of a square matrix (batched). */
|
|
584
|
-
declare function det(a: ArrayLike): Array;
|
|
585
|
-
/** Compute the inverse of a square matrix (batched). */
|
|
586
|
-
declare function inv(a: ArrayLike): Array;
|
|
609
|
+
//#endregion
|
|
610
|
+
//#region src/frontend/jaxpr.d.ts
|
|
587
611
|
/**
|
|
588
|
-
*
|
|
589
|
-
*
|
|
590
|
-
* For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
|
|
591
|
-
* For underdetermined systems, this finds the minimum-norm solution for `x`.
|
|
592
|
-
*
|
|
593
|
-
* This currently uses Cholesky decomposition to solve the normal equations,
|
|
594
|
-
* under the hood. The method is not as robust as QR or SVD.
|
|
612
|
+
* Function callback with an associated dispose() method.
|
|
595
613
|
*
|
|
596
|
-
*
|
|
597
|
-
*
|
|
598
|
-
* @return least-squares solution of shape `(N,)` or `(N, K)`
|
|
614
|
+
* The dispose() method should be called to clean up any tracer resources needed
|
|
615
|
+
* by the function after the last time it is called.
|
|
599
616
|
*/
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
/**
|
|
604
|
-
declare
|
|
617
|
+
type OwnedFunction<F extends Function> = F & {
|
|
618
|
+
dispose: () => void;
|
|
619
|
+
};
|
|
620
|
+
/** Variable in a Jaxpr expression. */
|
|
621
|
+
declare class Var {
|
|
622
|
+
#private;
|
|
623
|
+
readonly id: number;
|
|
624
|
+
readonly aval: ShapedArray;
|
|
625
|
+
constructor(aval: ShapedArray);
|
|
626
|
+
toString(): string;
|
|
627
|
+
}
|
|
628
|
+
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
629
|
+
declare class Lit {
|
|
630
|
+
readonly value: number;
|
|
631
|
+
readonly aval: ShapedArray;
|
|
632
|
+
get dtype(): DType;
|
|
633
|
+
constructor(aval: AbstractValue, value: number);
|
|
634
|
+
}
|
|
635
|
+
type Atom = Var | Lit;
|
|
636
|
+
declare class VarPrinter {
|
|
637
|
+
#private;
|
|
638
|
+
names: Map<Var, string>;
|
|
639
|
+
name(v: Var): string;
|
|
640
|
+
nameType(v: Var): string;
|
|
641
|
+
}
|
|
642
|
+
/** A single statement / binding in a Jaxpr, in ANF form. */
|
|
643
|
+
declare class JaxprEqn {
|
|
644
|
+
readonly primitive: Primitive;
|
|
645
|
+
readonly inputs: Atom[];
|
|
646
|
+
readonly params: Record<string, any>;
|
|
647
|
+
readonly outBinders: Var[];
|
|
648
|
+
constructor(primitive: Primitive, inputs: Atom[], params: Record<string, any>, outBinders: Var[]);
|
|
649
|
+
pprint(usedVars?: Set<Var>, vp?: VarPrinter): PPrint;
|
|
650
|
+
toString(): string;
|
|
651
|
+
}
|
|
652
|
+
/** Typed intermediate representation for traced computations. */
|
|
653
|
+
declare class Jaxpr implements FpHashable {
|
|
654
|
+
#private;
|
|
655
|
+
readonly inBinders: Var[];
|
|
656
|
+
readonly eqns: JaxprEqn[];
|
|
657
|
+
readonly outs: Atom[];
|
|
658
|
+
constructor(inBinders: Var[], eqns: JaxprEqn[], outs: Atom[]);
|
|
659
|
+
pprint(): PPrint;
|
|
660
|
+
toString(): string;
|
|
661
|
+
/**
|
|
662
|
+
* Gets a hash of this Jaxpr.
|
|
663
|
+
*
|
|
664
|
+
* Var identity is not considered in the hash, so two Jaxprs with the same
|
|
665
|
+
* order of assignments and operators but different variable IDs will resolve
|
|
666
|
+
* to the same hash (and toString representation).
|
|
667
|
+
*/
|
|
668
|
+
getHash(): bigint;
|
|
669
|
+
hash(state: FpHash): void;
|
|
670
|
+
/**
|
|
671
|
+
* Produce a simplified Jaxpr with basic optimizations applied.
|
|
672
|
+
* - Trim away unused variables.
|
|
673
|
+
* - Fold away *1, *0, or +0 operations against literals.
|
|
674
|
+
* - Remove no-op movement operations.
|
|
675
|
+
*/
|
|
676
|
+
simplify(): Jaxpr;
|
|
677
|
+
/** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
|
|
678
|
+
flatten(): Jaxpr;
|
|
679
|
+
}
|
|
680
|
+
/** Jaxpr with a collection of associated, traced constants. */
|
|
681
|
+
declare class ClosedJaxpr {
|
|
682
|
+
readonly jaxpr: Jaxpr;
|
|
683
|
+
readonly consts: Tracer[];
|
|
684
|
+
constructor(jaxpr: Jaxpr, consts: Tracer[]);
|
|
685
|
+
/** String representation of this Jaxpr. */
|
|
686
|
+
toString(): string;
|
|
687
|
+
/** Apply a function to the underlying Jaxpr. */
|
|
688
|
+
mapJaxpr(f: (jaxpr: Jaxpr) => Jaxpr): ClosedJaxpr;
|
|
689
|
+
/** Dispose of the constants in this Jaxpr. */
|
|
690
|
+
dispose(): void;
|
|
691
|
+
}
|
|
692
|
+
/** @inline */
|
|
693
|
+
type JitOpts = {
|
|
694
|
+
staticArgnums?: number[];
|
|
695
|
+
};
|
|
696
|
+
//#endregion
|
|
697
|
+
//#region src/frontend/core.d.ts
|
|
605
698
|
/**
|
|
606
|
-
*
|
|
699
|
+
* Frontend primitive operations, which are lowered into Kernel objects before
|
|
700
|
+
* being dispatched to the backend.
|
|
607
701
|
*
|
|
608
|
-
*
|
|
609
|
-
*
|
|
702
|
+
* Any operation between arrays can be described in these parts. This is also
|
|
703
|
+
* the set of primitives that can occur in Jaxpr programs, and the level at
|
|
704
|
+
* which transformations like vmap, grad, and jvp occur. They are loosely based
|
|
705
|
+
* on [XLA](https://openxla.org/xla/operation_semantics).
|
|
610
706
|
*
|
|
611
|
-
*
|
|
612
|
-
* @param b - Values of shape `(N,)` or `(..., N, M)`.
|
|
613
|
-
* @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
|
|
707
|
+
* All n-ary operations support broadcasting, with NumPy semantics.
|
|
614
708
|
*/
|
|
615
|
-
declare
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
709
|
+
declare enum Primitive {
|
|
710
|
+
Add = "add",
|
|
711
|
+
Mul = "mul",
|
|
712
|
+
Idiv = "idiv",
|
|
713
|
+
Mod = "mod",
|
|
714
|
+
// uses sign of numerator, C-style, matches JS but not Python
|
|
715
|
+
Min = "min",
|
|
716
|
+
Max = "max",
|
|
717
|
+
Neg = "neg",
|
|
718
|
+
Reciprocal = "reciprocal",
|
|
719
|
+
Floor = "floor",
|
|
720
|
+
Ceil = "ceil",
|
|
721
|
+
StopGradient = "stop_gradient",
|
|
722
|
+
Cast = "cast",
|
|
723
|
+
Bitcast = "bitcast",
|
|
724
|
+
Sin = "sin",
|
|
725
|
+
Cos = "cos",
|
|
726
|
+
Asin = "asin",
|
|
727
|
+
Atan = "atan",
|
|
728
|
+
Exp = "exp",
|
|
729
|
+
Log = "log",
|
|
730
|
+
Erf = "erf",
|
|
731
|
+
Erfc = "erfc",
|
|
732
|
+
Sqrt = "sqrt",
|
|
733
|
+
Reduce = "reduce",
|
|
734
|
+
Dot = "dot",
|
|
735
|
+
// sum(x*y, axis=-1)
|
|
736
|
+
Conv = "conv",
|
|
737
|
+
// see lax.conv_general_dilated
|
|
738
|
+
Pool = "pool",
|
|
739
|
+
PoolTranspose = "pool_transpose",
|
|
740
|
+
Compare = "compare",
|
|
741
|
+
Where = "where",
|
|
742
|
+
Concatenate = "concatenate",
|
|
743
|
+
Split = "split",
|
|
744
|
+
RandomBits = "random_bits",
|
|
745
|
+
Gather = "gather",
|
|
746
|
+
Transpose = "transpose",
|
|
747
|
+
Broadcast = "broadcast",
|
|
748
|
+
Reshape = "reshape",
|
|
749
|
+
Flip = "flip",
|
|
750
|
+
Shrink = "shrink",
|
|
751
|
+
Pad = "pad",
|
|
752
|
+
Sort = "sort",
|
|
753
|
+
// sort(x, axis=-1)
|
|
754
|
+
Argsort = "argsort",
|
|
755
|
+
// argsort(x, axis=-1)
|
|
756
|
+
TriangularSolve = "triangular_solve",
|
|
757
|
+
// A is upper triangular, A @ X.T = B.T
|
|
758
|
+
Cholesky = "cholesky",
|
|
759
|
+
// A is positive-definite, A = L @ L^T
|
|
760
|
+
LU = "lu",
|
|
761
|
+
// LU decomposition with partial pivoting
|
|
762
|
+
Jit = "jit",
|
|
670
763
|
}
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
764
|
+
interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
765
|
+
[Primitive.Cast]: {
|
|
766
|
+
dtype: DType;
|
|
767
|
+
};
|
|
768
|
+
[Primitive.Bitcast]: {
|
|
769
|
+
dtype: DType;
|
|
770
|
+
};
|
|
771
|
+
[Primitive.Reduce]: {
|
|
772
|
+
op: AluOp;
|
|
773
|
+
axis: number[];
|
|
774
|
+
};
|
|
775
|
+
[Primitive.Conv]: ConvParams;
|
|
776
|
+
[Primitive.Pool]: {
|
|
777
|
+
window: number[];
|
|
778
|
+
strides: number[];
|
|
779
|
+
};
|
|
780
|
+
[Primitive.PoolTranspose]: {
|
|
781
|
+
inShape: number[];
|
|
782
|
+
window: number[];
|
|
783
|
+
strides: number[];
|
|
784
|
+
};
|
|
785
|
+
[Primitive.Compare]: {
|
|
786
|
+
op: CompareOp;
|
|
787
|
+
};
|
|
788
|
+
[Primitive.Concatenate]: {
|
|
789
|
+
axis: number;
|
|
790
|
+
};
|
|
791
|
+
[Primitive.Split]: {
|
|
792
|
+
axis: number;
|
|
793
|
+
sizes: number[];
|
|
794
|
+
};
|
|
795
|
+
[Primitive.RandomBits]: {
|
|
796
|
+
shape: number[];
|
|
797
|
+
mode: "xor" | 0 | 1;
|
|
798
|
+
};
|
|
799
|
+
[Primitive.Gather]: {
|
|
800
|
+
axis: number[];
|
|
801
|
+
outDim: number;
|
|
802
|
+
};
|
|
803
|
+
[Primitive.Transpose]: {
|
|
804
|
+
perm: number[];
|
|
805
|
+
};
|
|
806
|
+
[Primitive.Broadcast]: {
|
|
807
|
+
shape: number[];
|
|
808
|
+
axis: number[];
|
|
809
|
+
};
|
|
810
|
+
[Primitive.Reshape]: {
|
|
811
|
+
shape: number[];
|
|
812
|
+
};
|
|
813
|
+
[Primitive.Flip]: {
|
|
814
|
+
axis: number[];
|
|
815
|
+
};
|
|
816
|
+
[Primitive.Shrink]: {
|
|
817
|
+
slice: Pair[];
|
|
818
|
+
};
|
|
819
|
+
[Primitive.Pad]: {
|
|
820
|
+
width: Pair[];
|
|
821
|
+
};
|
|
822
|
+
[Primitive.TriangularSolve]: {
|
|
823
|
+
unitDiagonal: boolean;
|
|
824
|
+
};
|
|
825
|
+
[Primitive.Jit]: {
|
|
826
|
+
name: string;
|
|
827
|
+
jaxpr: Jaxpr;
|
|
828
|
+
numConsts: number;
|
|
829
|
+
};
|
|
830
|
+
}
|
|
831
|
+
/** Type of parameters taken by each primitive. */
|
|
832
|
+
type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
|
|
833
|
+
declare enum CompareOp {
|
|
834
|
+
Less = "less",
|
|
835
|
+
Equal = "equal",
|
|
836
|
+
NotEqual = "not_equal",
|
|
837
|
+
LessEqual = "less_equal",
|
|
838
|
+
}
|
|
839
|
+
/** @inline */
|
|
840
|
+
type Axis = number | number[] | null;
|
|
841
|
+
/** @inline */
|
|
842
|
+
type ReduceOpts = {
|
|
843
|
+
keepdims?: boolean;
|
|
844
|
+
};
|
|
845
|
+
type MainTrace = {
|
|
846
|
+
level: number;
|
|
847
|
+
traceType: new (main: MainTrace) => Trace;
|
|
848
|
+
globalData: any | null;
|
|
849
|
+
};
|
|
731
850
|
/**
|
|
732
|
-
*
|
|
733
|
-
*
|
|
851
|
+
* Push an interpreter onto the trace stack. Use this like:
|
|
852
|
+
* `using main = newMain(...);`
|
|
734
853
|
*/
|
|
735
|
-
|
|
854
|
+
|
|
855
|
+
type TracerValue = Tracer | number | boolean;
|
|
856
|
+
declare abstract class Trace {
|
|
857
|
+
readonly main: MainTrace;
|
|
858
|
+
constructor(main: MainTrace);
|
|
859
|
+
abstract pure(val: TracerValue): Tracer;
|
|
860
|
+
abstract lift(val: Tracer): Tracer;
|
|
861
|
+
abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
|
|
862
|
+
}
|
|
863
|
+
/** Internal representation of an array value. */
|
|
864
|
+
interface AbstractValue {
|
|
865
|
+
/** Shape of the array. Must be a static tuple of non-negative dimensions. */
|
|
866
|
+
shape: number[];
|
|
867
|
+
/** Concrete data type of array elements. */
|
|
868
|
+
dtype: DType;
|
|
869
|
+
/**
|
|
870
|
+
* Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
|
|
871
|
+
* _weakly typed_ unless a dtype is explicitly specified.
|
|
872
|
+
*
|
|
873
|
+
* Weakly typed values will automatically cast to the data type of other
|
|
874
|
+
* arrays when used as an operand as an expression. This property only affects
|
|
875
|
+
* how they promote in type casting; their memory layout is still determined
|
|
876
|
+
* by the actual `dtype` field.
|
|
877
|
+
*
|
|
878
|
+
* ```ts
|
|
879
|
+
* const x = np.array(3); // weakType = true, dtype = float32
|
|
880
|
+
* const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
|
|
881
|
+
* const z = x.add(y); // z has dtype int32 because x is weakly typed
|
|
882
|
+
* ```
|
|
883
|
+
*
|
|
884
|
+
* Weak types are present in JIT programs in their spec (e.g., Jaxpr inputs
|
|
885
|
+
* and outputs can be weakly typed) form. But they're solely a frontend
|
|
886
|
+
* concept. Backends are not aware of weak types.
|
|
887
|
+
*/
|
|
888
|
+
weakType: boolean;
|
|
889
|
+
}
|
|
736
890
|
/**
|
|
737
|
-
*
|
|
738
|
-
* Give a new shape to an array without changing its data.
|
|
891
|
+
* Broadcast shapes and promote types with casting for two avals.
|
|
739
892
|
*
|
|
740
|
-
*
|
|
741
|
-
*
|
|
893
|
+
* This implements the weak type behavior described in `promoteTypes()`, but not
|
|
894
|
+
* implemented in that function as `weakType` is not passed.
|
|
742
895
|
*/
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
896
|
+
|
|
897
|
+
declare abstract class Tracer {
|
|
898
|
+
/** @ignore */
|
|
899
|
+
readonly _trace: Trace;
|
|
900
|
+
constructor(trace: Trace);
|
|
901
|
+
abstract get aval(): AbstractValue;
|
|
902
|
+
abstract toString(): string;
|
|
903
|
+
/**
|
|
904
|
+
* Access an array by reference, incrementing the reference count.
|
|
905
|
+
*
|
|
906
|
+
* jax-js handles freeing arrays by using "move" semantics, like in Rust/C++.
|
|
907
|
+
* Whenever you pass an array into a function, that function should consume
|
|
908
|
+
* the array, and it will no longer be usable. For example, if you had:
|
|
909
|
+
*
|
|
910
|
+
* ```
|
|
911
|
+
* const x = np.array([1, 2, 3]);
|
|
912
|
+
* const y = np.add(x, x);
|
|
913
|
+
* ```
|
|
914
|
+
*
|
|
915
|
+
* The second line does not work because the first parameter consumes `x`, and
|
|
916
|
+
* then the second parameter will already have been freed / disposed.
|
|
917
|
+
*
|
|
918
|
+
* To fix this, you can write:
|
|
919
|
+
*
|
|
920
|
+
* ```
|
|
921
|
+
* const y = np.add(x.ref, x);
|
|
922
|
+
* ```
|
|
923
|
+
*
|
|
924
|
+
* Under the hood, every access to `.ref` increments the internal reference
|
|
925
|
+
* count of the array. The reference count starts at 1. When it hits 0, the
|
|
926
|
+
* memory behind the array is freed.
|
|
927
|
+
*/
|
|
928
|
+
abstract get ref(): this;
|
|
929
|
+
/**
|
|
930
|
+
* Manually decrement the reference count of the array.
|
|
931
|
+
*
|
|
932
|
+
* Arrays are created with reference count 1. Whenever it is used as argument
|
|
933
|
+
* to a function or other operation, it is disposed (i.e., reference count
|
|
934
|
+
* decreases by 1) automatically. Whenever a `.ref` is created, the reference
|
|
935
|
+
* count increases.
|
|
936
|
+
*
|
|
937
|
+
* You generally don't need to call this function directly since arrays are
|
|
938
|
+
* automatically disposed after being passed into an operation. One common
|
|
939
|
+
* exception is when writing a function and ignoring one of its arguments. In
|
|
940
|
+
* that case, by convention you should dispose of that argument manually.
|
|
941
|
+
*
|
|
942
|
+
* ```
|
|
943
|
+
* function myCustomOperation(a: np.Array, b: np.Array) {
|
|
944
|
+
* b.dispose(); // Needed to satisfy "move" rules.
|
|
945
|
+
* return a.add(1);
|
|
946
|
+
* }
|
|
947
|
+
* ```
|
|
948
|
+
*/
|
|
949
|
+
abstract dispose(): void;
|
|
950
|
+
/** The shape of the array. */
|
|
951
|
+
get shape(): number[];
|
|
952
|
+
/** The total number of elements in the array. */
|
|
953
|
+
get size(): number;
|
|
954
|
+
/** The dtype of elements stored in the array. */
|
|
955
|
+
get dtype(): DType;
|
|
956
|
+
/**
|
|
957
|
+
* Whether the array is weakly typed.
|
|
958
|
+
*
|
|
959
|
+
* Weakly typed arrays will cast to the dtype of the other operand. See
|
|
960
|
+
* `promoteTypes()` for details.
|
|
961
|
+
*/
|
|
962
|
+
get weakType(): boolean;
|
|
963
|
+
/** The number of dimensions of the array. */
|
|
964
|
+
get ndim(): number;
|
|
965
|
+
/** @ignore */
|
|
966
|
+
fullLower(): Tracer;
|
|
967
|
+
neg(): this;
|
|
968
|
+
add(other: this | TracerValue): this;
|
|
969
|
+
mul(other: this | TracerValue): this;
|
|
970
|
+
mod(other: this | TracerValue): this;
|
|
971
|
+
greater(other: this | TracerValue): this;
|
|
972
|
+
less(other: this | TracerValue): this;
|
|
973
|
+
equal(other: this | TracerValue): this;
|
|
974
|
+
notEqual(other: this | TracerValue): this;
|
|
975
|
+
greaterEqual(other: this | TracerValue): this;
|
|
976
|
+
lessEqual(other: this | TracerValue): this;
|
|
977
|
+
/** Sum of the elements of the array over a given axis, or axes. */
|
|
978
|
+
sum(axis?: Axis, opts?: ReduceOpts): this;
|
|
979
|
+
/** Product of the array elements over a given axis. */
|
|
980
|
+
prod(axis?: Axis, opts?: ReduceOpts): this;
|
|
981
|
+
/** Compute the average of the array elements along the specified axis. */
|
|
982
|
+
mean(axis?: Axis, opts?: ReduceOpts): this;
|
|
983
|
+
/** Minimum of the elements of the array along a given axis. */
|
|
984
|
+
min(axis?: Axis, opts?: ReduceOpts): this;
|
|
985
|
+
/** Maximum of the elements of the array along a given axis. */
|
|
986
|
+
max(axis?: Axis, opts?: ReduceOpts): this;
|
|
987
|
+
/** Test whether all array elements along a given axis evaluate to true. */
|
|
988
|
+
all(axis?: Axis, opts?: ReduceOpts): this;
|
|
989
|
+
/** Test whether any array element along a given axis evaluates to true. */
|
|
990
|
+
any(axis?: Axis, opts?: ReduceOpts): this;
|
|
991
|
+
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
992
|
+
transpose(perm?: number[]): this;
|
|
993
|
+
/**
|
|
994
|
+
* Give a new shape to an array without changing its data.
|
|
995
|
+
*
|
|
996
|
+
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
997
|
+
* length of the array and remaining dimensions.
|
|
998
|
+
*/
|
|
999
|
+
reshape(shape: number | number[]): this;
|
|
1000
|
+
/** Copy the array and cast to a specified dtype. */
|
|
1001
|
+
astype(dtype: DType): this;
|
|
1002
|
+
/** Subtract an array from this one. */
|
|
1003
|
+
sub(other: this | TracerValue): this;
|
|
1004
|
+
/** Divide an array by this one. */
|
|
1005
|
+
div(other: this | TracerValue): this;
|
|
1006
|
+
/** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
|
|
1007
|
+
diagonal(offset?: number, axis1?: number, axis2?: number): this;
|
|
1008
|
+
/** Flatten the array without changing its data. */
|
|
1009
|
+
flatten(): this;
|
|
1010
|
+
/** Flatten the array without changing its data. */
|
|
1011
|
+
ravel(): this;
|
|
1012
|
+
/**
|
|
1013
|
+
* Iterate over the first dimension of this array, returning slices.
|
|
1014
|
+
*
|
|
1015
|
+
* This can be used to destructure arrays. For example:
|
|
1016
|
+
*
|
|
1017
|
+
* ```js
|
|
1018
|
+
* let x = np.array([[1, 2], [3, 4]]);
|
|
1019
|
+
* let [a, b] = x;
|
|
1020
|
+
* console.log(a.js()); // [1, 2]
|
|
1021
|
+
* console.log(b.js()); // [3, 4]
|
|
1022
|
+
* ```
|
|
1023
|
+
*/
|
|
1024
|
+
[Symbol.iterator](): IterableIterator<this>;
|
|
1025
|
+
/**
|
|
1026
|
+
* Return a sorted copy of an array in ascending order.
|
|
1027
|
+
*
|
|
1028
|
+
* See `jax.numpy.sort` for full docs.
|
|
1029
|
+
*/
|
|
1030
|
+
sort(axis?: number): this;
|
|
1031
|
+
/**
|
|
1032
|
+
* Return the indices that would sort an array. This may not be a stable
|
|
1033
|
+
* sorting algorithm; it need not preserve order of indices in ties.
|
|
1034
|
+
*
|
|
1035
|
+
* See `jax.numpy.argsort` for full docs.
|
|
1036
|
+
*/
|
|
1037
|
+
argsort(axis?: number): this;
|
|
1038
|
+
/**
|
|
1039
|
+
* Slice an array along one or more axes.
|
|
1040
|
+
*
|
|
1041
|
+
* This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
|
|
1042
|
+
* mimic this in JavaScript, we would write:
|
|
1043
|
+
*
|
|
1044
|
+
* ```js
|
|
1045
|
+
* x.slice([1, 3], 2, [], null);
|
|
1046
|
+
* ```
|
|
1047
|
+
*
|
|
1048
|
+
* The `slice` method accepts a variable number of arguments, each of which
|
|
1049
|
+
* can be a number, an empty array, a single-element array, a two-element
|
|
1050
|
+
* array, or `null`. The arguments are interpreted as follows:
|
|
1051
|
+
*
|
|
1052
|
+
* - A number `n` means to access the `n`-th element along that axis, removing
|
|
1053
|
+
* that axis from the resulting shape.
|
|
1054
|
+
* - An empty array `[]` means to keep that axis as-is, like `:` in Python.
|
|
1055
|
+
* - A single-element array `[i]` means to start slicing from index `i`
|
|
1056
|
+
* (inclusive) to the end of the axis, like `x[i:]`.
|
|
1057
|
+
* - A two-element array `[i, j]` means to slice from index `i` (inclusive)
|
|
1058
|
+
* to index `j` (exclusive), like `x[i:j]`.
|
|
1059
|
+
* - `null` means to add a new axis at that position, like `np.newaxis`.
|
|
1060
|
+
*
|
|
1061
|
+
* Like in Python, negative indices are supported, which count from the end of
|
|
1062
|
+
* the axis. For example, `-1` means the last element.
|
|
1063
|
+
*
|
|
1064
|
+
* Strided slices are not yet implemented, so you cannot write `x[::2]` or
|
|
1065
|
+
* similar.
|
|
1066
|
+
*
|
|
1067
|
+
* Advanced indexing by integer arrays is also supported. This translates to
|
|
1068
|
+
* the "gather" primitive, and it allows you to access specific elements of
|
|
1069
|
+
* the array by integer indices stored in another array.
|
|
1070
|
+
*/
|
|
1071
|
+
slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
|
|
1072
|
+
}
|
|
1073
|
+
declare class ShapedArray implements AbstractValue {
|
|
1074
|
+
readonly shape: number[];
|
|
1075
|
+
readonly dtype: DType;
|
|
1076
|
+
readonly weakType: boolean;
|
|
1077
|
+
constructor(shape: number[], dtype: DType, weakType: boolean);
|
|
1078
|
+
static fromAval(aval: AbstractValue): ShapedArray;
|
|
1079
|
+
get ndim(): number;
|
|
1080
|
+
get size(): number;
|
|
1081
|
+
scalar(): ShapedArray;
|
|
1082
|
+
toString(): string;
|
|
1083
|
+
equals(other: ShapedArray): boolean;
|
|
1084
|
+
}
|
|
1085
|
+
//#endregion
|
|
1086
|
+
//#region src/frontend/array.d.ts
|
|
1087
|
+
type ArrayLike = Array | number | boolean;
|
|
1088
|
+
/** Version of pureArray with fudged types. */
|
|
1089
|
+
|
|
749
1090
|
/**
|
|
750
|
-
*
|
|
751
|
-
* Add padding (zeros) to an array.
|
|
1091
|
+
* An executable operation that will be dispatched to the backend.
|
|
752
1092
|
*
|
|
753
|
-
*
|
|
754
|
-
*
|
|
755
|
-
* pair specifies the padding for its corresponding axis.
|
|
756
|
-
*/
|
|
757
|
-
declare const pad: (x: ArrayLike, width: number | Pair | Pair[]) => Array;
|
|
758
|
-
/**
|
|
759
|
-
* @function
|
|
760
|
-
* Return the number of dimensions of an array. Does not consume array reference.
|
|
761
|
-
*/
|
|
762
|
-
declare const ndim: (x: ArrayLike) => number;
|
|
763
|
-
/** @function Return the shape of an array. Does not consume array reference. */
|
|
764
|
-
declare const shape$1: (x: ArrayLike) => number[];
|
|
765
|
-
/**
|
|
766
|
-
* @function
|
|
767
|
-
* Return an array of zeros with the same shape and type as a given array.
|
|
768
|
-
*/
|
|
769
|
-
declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
770
|
-
/**
|
|
771
|
-
* @function
|
|
772
|
-
* Return an array of ones with the same shape and type as a given array.
|
|
773
|
-
*/
|
|
774
|
-
declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
775
|
-
/**
|
|
776
|
-
* @function
|
|
777
|
-
* Return a full array with the same shape and type as a given array.
|
|
778
|
-
*/
|
|
779
|
-
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
|
|
780
|
-
/**
|
|
781
|
-
* Return the number of elements in an array, optionally along an axis.
|
|
782
|
-
* Does not consume array reference.
|
|
1093
|
+
* This holds a reference to all input buffers used in the operation. After the
|
|
1094
|
+
* operation is dispatched, the references should be released.
|
|
783
1095
|
*/
|
|
784
|
-
declare
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
1096
|
+
declare class PendingExecute {
|
|
1097
|
+
#private;
|
|
1098
|
+
readonly backend: Backend;
|
|
1099
|
+
readonly source: Kernel | Routine;
|
|
1100
|
+
readonly inputs: Slot[];
|
|
1101
|
+
readonly outputs: Slot[];
|
|
1102
|
+
prepared: Executable | null;
|
|
1103
|
+
submitted: boolean;
|
|
1104
|
+
constructor(backend: Backend, source: Kernel | Routine, inputs: Slot[], outputs: Slot[]);
|
|
1105
|
+
updateRc(delta: number): void;
|
|
1106
|
+
prepare(): Promise<void>;
|
|
1107
|
+
prepareSync(): void;
|
|
1108
|
+
submit(): void;
|
|
1109
|
+
}
|
|
1110
|
+
/** @inline */
|
|
1111
|
+
type DTypeAndDevice = {
|
|
1112
|
+
dtype?: DType;
|
|
1113
|
+
device?: Device;
|
|
1114
|
+
};
|
|
1115
|
+
type ArrayConstructorArgs = {
|
|
1116
|
+
source: AluExp | Slot;
|
|
1117
|
+
st: ShapeTracker;
|
|
1118
|
+
dtype: DType;
|
|
1119
|
+
weakType: boolean;
|
|
1120
|
+
backend: Backend;
|
|
1121
|
+
committed: boolean;
|
|
1122
|
+
pending?: Iterable<PendingExecute>;
|
|
1123
|
+
};
|
|
795
1124
|
/**
|
|
796
|
-
*
|
|
1125
|
+
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
797
1126
|
*
|
|
798
|
-
*
|
|
799
|
-
*
|
|
1127
|
+
* This is the library's core data type. Equivalent to `jax.Array` from JAX, or
|
|
1128
|
+
* `torch.Tensor`.
|
|
1129
|
+
*
|
|
1130
|
+
* Not to be confused with the JavaScript "Array" constructor. Avoid importing
|
|
1131
|
+
* this into your code's namespace if you're already using the JavaScript
|
|
1132
|
+
* "Array" type by name.
|
|
800
1133
|
*/
|
|
801
|
-
declare
|
|
1134
|
+
declare class Array extends Tracer {
|
|
1135
|
+
#private;
|
|
1136
|
+
/**
|
|
1137
|
+
* @ignore
|
|
1138
|
+
* Constructs an array from source, shape and backend. Note that if the source
|
|
1139
|
+
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
1140
|
+
* will be freed when the array is disposed.
|
|
1141
|
+
*/
|
|
1142
|
+
constructor(args: ArrayConstructorArgs);
|
|
1143
|
+
/** @ignore */
|
|
1144
|
+
get aval(): ShapedArray;
|
|
1145
|
+
/** Return a simple string representation of the array's dimensions. */
|
|
1146
|
+
toString(): string;
|
|
1147
|
+
get device(): Device;
|
|
1148
|
+
get ref(): this;
|
|
1149
|
+
/** Get the current reference count (for debugging memory management). */
|
|
1150
|
+
get refCount(): number;
|
|
1151
|
+
dispose(): void;
|
|
1152
|
+
/**
|
|
1153
|
+
* Convert this array into a primitive value.
|
|
1154
|
+
*
|
|
1155
|
+
* This only works for scalars (0-dimensional arrays). It lets you get values
|
|
1156
|
+
* "out" of the JAX system. For instance, if `x = np.array(5)`, then you can
|
|
1157
|
+
* evaluate `x + 1` and `x ** 2` to get `6` and `25`, respectively.
|
|
1158
|
+
*
|
|
1159
|
+
* This method is also called for `==` equality.
|
|
1160
|
+
*/
|
|
1161
|
+
[Symbol.toPrimitive](): any;
|
|
1162
|
+
/** Realize the array and return it as data. */
|
|
1163
|
+
data(): Promise<DataArray>;
|
|
1164
|
+
/**
|
|
1165
|
+
* Wait for this array to finish evaluation.
|
|
1166
|
+
*
|
|
1167
|
+
* Operations and data loading in jax-js are lazy, so this function ensures
|
|
1168
|
+
* that pending operations are dispatched and fully executed before it
|
|
1169
|
+
* returns.
|
|
1170
|
+
*
|
|
1171
|
+
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
1172
|
+
* dispatch of operations as well.
|
|
1173
|
+
*
|
|
1174
|
+
* **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
|
|
1175
|
+
* asynchronously for multiple arrays.
|
|
1176
|
+
*/
|
|
1177
|
+
blockUntilReady(): Promise<Array>;
|
|
1178
|
+
/**
|
|
1179
|
+
* Realize the array and return it as data. This is a sync variant and not
|
|
1180
|
+
* recommended for performance reasons, as it will block rendering.
|
|
1181
|
+
*/
|
|
1182
|
+
dataSync(): DataArray;
|
|
1183
|
+
/**
|
|
1184
|
+
* Convert this array into a JavaScript object.
|
|
1185
|
+
*
|
|
1186
|
+
* This is a blocking operation that will compile all of the shaders and wait
|
|
1187
|
+
* for execution to complete, synchronously. No other JavaScript code on the
|
|
1188
|
+
* site will be run during shader execution.
|
|
1189
|
+
*
|
|
1190
|
+
* To avoid blocking, prefer `jsAsync()` when possible.
|
|
1191
|
+
*/
|
|
1192
|
+
js(): any;
|
|
1193
|
+
/** Convert this array into a JavaScript object, asynchronously. */
|
|
1194
|
+
jsAsync(): Promise<any>;
|
|
1195
|
+
/**
|
|
1196
|
+
* Copy an element of an array to a numeric scalar and return it.
|
|
1197
|
+
*
|
|
1198
|
+
* Throws an error if the array does not have a single element. The array must
|
|
1199
|
+
* either be rank-0, or all dimensions of the shape are 1.
|
|
1200
|
+
*/
|
|
1201
|
+
item(): number;
|
|
1202
|
+
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
1203
|
+
static _implRules(): typeof implRules;
|
|
1204
|
+
/** @private */
|
|
1205
|
+
_realizeSource(): number;
|
|
1206
|
+
/** @private Put this array on a new backend, asynchronously. */
|
|
1207
|
+
_put(backend: Backend): Promise<Array>;
|
|
1208
|
+
/** @private Put this array on a new backend, synchronously. */
|
|
1209
|
+
_putSync(backend: Backend): Array;
|
|
1210
|
+
}
|
|
1211
|
+
/** Constructor for creating a new array from data. */
|
|
1212
|
+
declare function array(values: Array | DataArray | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
1213
|
+
shape,
|
|
1214
|
+
dtype,
|
|
1215
|
+
device
|
|
1216
|
+
}?: {
|
|
1217
|
+
shape?: number[];
|
|
1218
|
+
} & DTypeAndDevice): Array;
|
|
1219
|
+
/** If x is a value, lift it into an array, otherwise leave it be. */
|
|
1220
|
+
|
|
1221
|
+
type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
|
|
1222
|
+
declare const implRules: { [P in Primitive]: ImplRule<P> };
|
|
1223
|
+
/** Return a new array of given shape and type, filled with zeros. */
|
|
1224
|
+
declare function zeros(shape: number[], {
|
|
1225
|
+
dtype,
|
|
1226
|
+
device
|
|
1227
|
+
}?: DTypeAndDevice): Array;
|
|
1228
|
+
/** Return a new array of given shape and type, filled with ones. */
|
|
1229
|
+
declare function ones(shape: number[], {
|
|
1230
|
+
dtype,
|
|
1231
|
+
device
|
|
1232
|
+
}?: DTypeAndDevice): Array;
|
|
1233
|
+
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
1234
|
+
declare function full(shape: number[], fillValue: number | boolean | Array, {
|
|
1235
|
+
dtype,
|
|
1236
|
+
device
|
|
1237
|
+
}?: DTypeAndDevice): Array;
|
|
802
1238
|
/**
|
|
803
|
-
*
|
|
1239
|
+
* Create an identity matrix.
|
|
804
1240
|
*
|
|
805
|
-
*
|
|
806
|
-
*
|
|
1241
|
+
* If numCols is not provided, it defaults to numRows, i.e., a square identity
|
|
1242
|
+
* matrix with ones on the diagonal.
|
|
807
1243
|
*/
|
|
808
|
-
declare function
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
1244
|
+
declare function eye(numRows: number, numCols?: number, {
|
|
1245
|
+
dtype,
|
|
1246
|
+
device
|
|
1247
|
+
}?: DTypeAndDevice): Array;
|
|
1248
|
+
/** Return the identity matrix, with ones on the main diagonal. */
|
|
1249
|
+
declare function identity$1(n: number, {
|
|
1250
|
+
dtype,
|
|
1251
|
+
device
|
|
1252
|
+
}?: DTypeAndDevice): Array;
|
|
813
1253
|
/**
|
|
814
|
-
*
|
|
1254
|
+
* Return evenly spaced values within a given interval.
|
|
815
1255
|
*
|
|
816
|
-
*
|
|
817
|
-
*
|
|
1256
|
+
* This can be called with a varying number of arguments, just like the range()
|
|
1257
|
+
* builtin function in Python.
|
|
1258
|
+
*
|
|
1259
|
+
* - `arange(stop)` is equivalent to `arange(0, stop, 1)`.
|
|
1260
|
+
* - `arange(start, stop)` is equivalent to `arange(start, stop, 1)`.
|
|
1261
|
+
* - `arange(start, stop, step)` creates an array starting at `start`, ending
|
|
1262
|
+
* before `stop`, with a step size of `step`.
|
|
1263
|
+
*
|
|
1264
|
+
* Defaults to an integer data type. This can produce unintended results when
|
|
1265
|
+
* using a non-integer step, so prefer linspace() in those cases.
|
|
818
1266
|
*/
|
|
819
|
-
declare function
|
|
1267
|
+
declare function arange(start: number, stop?: number, step?: number, {
|
|
1268
|
+
dtype,
|
|
1269
|
+
device
|
|
1270
|
+
}?: DTypeAndDevice): Array;
|
|
820
1271
|
/**
|
|
821
|
-
*
|
|
1272
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
822
1273
|
*
|
|
823
|
-
*
|
|
824
|
-
*
|
|
1274
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
1275
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
1276
|
+
* `k>0` is above it.
|
|
825
1277
|
*/
|
|
826
|
-
declare function
|
|
1278
|
+
declare function tri(n: number, m?: number, k?: number, {
|
|
1279
|
+
dtype,
|
|
1280
|
+
device
|
|
1281
|
+
}?: DTypeAndDevice): Array;
|
|
1282
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
1283
|
+
declare function tril(a: ArrayLike, k?: number): Array;
|
|
1284
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
1285
|
+
declare function triu(a: ArrayLike, k?: number): Array;
|
|
827
1286
|
/**
|
|
828
|
-
*
|
|
1287
|
+
* Return evenly spaced numbers over a specified interval.
|
|
829
1288
|
*
|
|
830
|
-
*
|
|
831
|
-
*
|
|
1289
|
+
* Returns _num_ evenly spaced samples, calculated over the interval
|
|
1290
|
+
* [`start`, `stop`]. The endpoint `stop` is included in the result by default,
|
|
1291
|
+
* but this is controlled by the `endpoint` parameter.
|
|
1292
|
+
*
|
|
1293
|
+
* The default data type is Float32. Use arange() for integer steps.
|
|
832
1294
|
*/
|
|
833
|
-
declare function
|
|
834
|
-
|
|
835
|
-
|
|
1295
|
+
declare function linspace(start: number, stop: number, num?: number, endpoint?: boolean, {
|
|
1296
|
+
dtype,
|
|
1297
|
+
device
|
|
1298
|
+
}?: DTypeAndDevice): Array;
|
|
836
1299
|
/**
|
|
837
|
-
*
|
|
1300
|
+
* Return numbers spaced evenly on a log scale.
|
|
838
1301
|
*
|
|
839
|
-
*
|
|
840
|
-
*
|
|
841
|
-
*
|
|
842
|
-
*
|
|
843
|
-
* @param
|
|
1302
|
+
* In linear space, the sequence starts at `base ** start` and ends at
|
|
1303
|
+
* `base ** stop` (see `endpoint` below).
|
|
1304
|
+
*
|
|
1305
|
+
* @param start - `base ** start` is the starting value of the sequence.
|
|
1306
|
+
* @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
|
|
1307
|
+
* @param num - Number of samples to generate. Default is 50.
|
|
1308
|
+
* @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
|
|
1309
|
+
* @param base - The base of the log space. Default is 10.
|
|
1310
|
+
* @returns Array of evenly spaced values on a log scale.
|
|
844
1311
|
*/
|
|
845
|
-
declare function
|
|
1312
|
+
declare function logspace(start: number, stop: number, num?: number, endpoint?: boolean, base?: number, {
|
|
1313
|
+
dtype,
|
|
1314
|
+
device
|
|
1315
|
+
}?: DTypeAndDevice): Array;
|
|
1316
|
+
//#endregion
|
|
1317
|
+
//#region src/frontend/linearize.d.ts
|
|
1318
|
+
/** @inline */
|
|
1319
|
+
type GradOpts = {
|
|
1320
|
+
/**
|
|
1321
|
+
* Integer or sequence of integers. Specifies which positional argument(s) to
|
|
1322
|
+
* differentiate with respect to.
|
|
1323
|
+
*
|
|
1324
|
+
* Defaults to `0` (the first argument).
|
|
1325
|
+
*/
|
|
1326
|
+
argnums?: number | number[];
|
|
1327
|
+
/**
|
|
1328
|
+
* The input function returns a pair of `[out, aux]` including an auxiliary
|
|
1329
|
+
* value. This `aux` is not differentiated, but is returned alongside the
|
|
1330
|
+
* gradient when evaluating the function.
|
|
1331
|
+
*/
|
|
1332
|
+
hasAux?: boolean;
|
|
1333
|
+
};
|
|
1334
|
+
declare namespace lax_linalg_d_exports {
|
|
1335
|
+
export { cholesky$1 as cholesky, lu, triangularSolve };
|
|
1336
|
+
}
|
|
846
1337
|
/**
|
|
847
|
-
*
|
|
1338
|
+
* Compute the Cholesky decomposition of a symmetric positive-definite matrix.
|
|
848
1339
|
*
|
|
849
|
-
* The
|
|
850
|
-
* `axis` (the first, by default).
|
|
1340
|
+
* The Cholesky decomposition of a matrix `A` is:
|
|
851
1341
|
*
|
|
852
|
-
*
|
|
1342
|
+
* - A = L @ L^T (for upper=false, default)
|
|
1343
|
+
* - A = U^T @ U (for upper=true)
|
|
1344
|
+
*
|
|
1345
|
+
* where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
|
|
1346
|
+
* The input matrix must be symmetric and positive-definite.
|
|
1347
|
+
*
|
|
1348
|
+
* @example
|
|
1349
|
+
* ```ts
|
|
1350
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
1351
|
+
*
|
|
1352
|
+
* const x = np.array([[2., 1.], [1., 2.]]);
|
|
1353
|
+
*
|
|
1354
|
+
* // Lower Cholesky factorization (default):
|
|
1355
|
+
* const L = lax.linalg.cholesky(x);
|
|
1356
|
+
* // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
|
|
1357
|
+
*
|
|
1358
|
+
* // Upper Cholesky factorization:
|
|
1359
|
+
* const U = lax.linalg.cholesky(x, { upper: true });
|
|
1360
|
+
* // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
|
|
1361
|
+
* ```
|
|
853
1362
|
*/
|
|
854
|
-
declare function
|
|
1363
|
+
declare function cholesky$1(a: ArrayLike, {
|
|
1364
|
+
upper
|
|
1365
|
+
}?: {
|
|
1366
|
+
upper?: boolean;
|
|
1367
|
+
}): Array;
|
|
855
1368
|
/**
|
|
856
|
-
*
|
|
1369
|
+
* LU decomposition with partial pivoting.
|
|
857
1370
|
*
|
|
858
|
-
*
|
|
859
|
-
*
|
|
860
|
-
* `
|
|
1371
|
+
* Computes the matrix decomposition: `P @ A = L @ U`, where `P` is a
|
|
1372
|
+
* permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
|
|
1373
|
+
* and `U` is upper-triangular.
|
|
861
1374
|
*
|
|
862
|
-
*
|
|
1375
|
+
* @param x - A batch of matrices with shape `[..., m, n]`.
|
|
1376
|
+
*
|
|
1377
|
+
* @returns A tuple `(lu, pivots, permutation)` where:
|
|
1378
|
+
* - `lu`: combined lower and upper triangular matrices.
|
|
1379
|
+
* - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
|
|
1380
|
+
* - `permutation`: the permutation generated by pivots with shape `[..., m]`.
|
|
1381
|
+
*
|
|
1382
|
+
* @example
|
|
1383
|
+
* ```ts
|
|
1384
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
1385
|
+
*
|
|
1386
|
+
* const A = np.array([[4., 3.], [6., 3.]]);
|
|
1387
|
+
* const [lu, pivots, permutation] = lax.linalg.lu(A);
|
|
1388
|
+
* // lu ≈ [[6., 3.], [0.6666667, 1.0]]
|
|
1389
|
+
* // pivots = [1, 1]
|
|
1390
|
+
* // permutation = [1, 0]
|
|
1391
|
+
* ```
|
|
863
1392
|
*/
|
|
864
|
-
declare function
|
|
1393
|
+
declare function lu(x: ArrayLike): [Array, Array, Array];
|
|
865
1394
|
/**
|
|
866
|
-
*
|
|
867
|
-
*
|
|
1395
|
+
* Solve a triangular linear system.
|
|
1396
|
+
*
|
|
1397
|
+
* Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
|
|
1398
|
+
* where `a` is a triangular matrix.
|
|
1399
|
+
*
|
|
1400
|
+
* @example
|
|
1401
|
+
* ```ts
|
|
1402
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
1403
|
+
*
|
|
1404
|
+
* const L = np.array([[2., 0.], [1., 3.]]);
|
|
1405
|
+
* const b = np.array([4., 7.]).reshape([2, 1]);
|
|
1406
|
+
*
|
|
1407
|
+
* // Solve L @ x = b
|
|
1408
|
+
* const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
|
|
1409
|
+
* // x = [[2.], [5./3.]]
|
|
1410
|
+
* ```
|
|
868
1411
|
*/
|
|
869
|
-
declare function
|
|
1412
|
+
declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
|
|
1413
|
+
leftSide,
|
|
1414
|
+
lower,
|
|
1415
|
+
transposeA,
|
|
1416
|
+
unitDiagonal
|
|
1417
|
+
}?: {
|
|
1418
|
+
leftSide?: boolean;
|
|
1419
|
+
lower?: boolean;
|
|
1420
|
+
transposeA?: boolean;
|
|
1421
|
+
unitDiagonal?: boolean;
|
|
1422
|
+
}): Array;
|
|
1423
|
+
declare namespace lax_d_exports {
|
|
1424
|
+
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
|
|
1425
|
+
}
|
|
870
1426
|
/**
|
|
871
|
-
*
|
|
872
|
-
*
|
|
1427
|
+
* Dimension numbers for general `dot()` primitive.
|
|
1428
|
+
*
|
|
1429
|
+
* Contracting dimensions act as a tensor contraction (reduction) along the
|
|
1430
|
+
* given axis. They must be the same size in both operands. Batch dimensions
|
|
1431
|
+
* are treated as vectorized, leading batch dimensions.
|
|
1432
|
+
*
|
|
1433
|
+
* The return value has a shape where the first dimensions are shared batch
|
|
1434
|
+
* dimensions, followed by `lhs` non-contracting dimensions, followed by
|
|
1435
|
+
* `rhs` non-contracting dimensions.
|
|
873
1436
|
*/
|
|
874
|
-
|
|
1437
|
+
type DotDimensionNumbers = {
|
|
1438
|
+
lhsContractingDims?: number[];
|
|
1439
|
+
rhsContractingDims?: number[];
|
|
1440
|
+
lhsBatchDims?: number[];
|
|
1441
|
+
rhsBatchDims?: number[];
|
|
1442
|
+
};
|
|
875
1443
|
/**
|
|
876
|
-
*
|
|
877
|
-
*
|
|
1444
|
+
* General dot product/contraction operator.
|
|
1445
|
+
*
|
|
1446
|
+
* Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
|
|
1447
|
+
* `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
|
|
878
1448
|
*/
|
|
879
|
-
declare function
|
|
1449
|
+
declare function dot$1(lhs: Array, rhs: Array, {
|
|
1450
|
+
lhsContractingDims: lc,
|
|
1451
|
+
rhsContractingDims: rc,
|
|
1452
|
+
lhsBatchDims: lb,
|
|
1453
|
+
rhsBatchDims: rb
|
|
1454
|
+
}?: DotDimensionNumbers): Array;
|
|
1455
|
+
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
|
|
880
1456
|
/**
|
|
881
|
-
*
|
|
882
|
-
*
|
|
1457
|
+
* General n-dimensional convolution operator, with optional dilation.
|
|
1458
|
+
*
|
|
1459
|
+
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
1460
|
+
* function in JAX, which wraps XLA's general convolution operator.
|
|
1461
|
+
*
|
|
1462
|
+
* @param lhs - Input tensor; shape `[N, C_in, ...xs]`
|
|
1463
|
+
* @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
|
|
1464
|
+
* @param windowStrides - Strides for each spatial dimension
|
|
1465
|
+
* @param padding - Padding for each spatial dimension, or a string
|
|
1466
|
+
* (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
|
|
883
1467
|
*/
|
|
884
|
-
declare function
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
/**
|
|
894
|
-
declare function
|
|
1468
|
+
declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
|
|
1469
|
+
lhsDilation,
|
|
1470
|
+
rhsDilation,
|
|
1471
|
+
featureGroupCount
|
|
1472
|
+
}?: {
|
|
1473
|
+
lhsDilation?: number[];
|
|
1474
|
+
rhsDilation?: number[];
|
|
1475
|
+
featureGroupCount?: number;
|
|
1476
|
+
}): Array;
|
|
1477
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
1478
|
+
declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
|
|
1479
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
1480
|
+
declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
|
|
895
1481
|
/**
|
|
896
|
-
*
|
|
1482
|
+
* Convenience wrapper for calculating the N-d convolution "transpose".
|
|
897
1483
|
*
|
|
898
|
-
*
|
|
899
|
-
*
|
|
900
|
-
*
|
|
901
|
-
* @returns Array with the number of dimensions increased.
|
|
1484
|
+
* This function directly calculates a fractionally strided conv rather than
|
|
1485
|
+
* indirectly calculating the gradient (transpose) of a forward convolution.
|
|
1486
|
+
* It is equivalent to the JAX version, except:
|
|
902
1487
|
*
|
|
903
|
-
*
|
|
904
|
-
*
|
|
905
|
-
*
|
|
906
|
-
*
|
|
907
|
-
*
|
|
908
|
-
*
|
|
909
|
-
*
|
|
1488
|
+
* - The `use_consistent_padding` option is not available. We only have the
|
|
1489
|
+
* consistent padding case (JAX version >0.8.4).
|
|
1490
|
+
* - The order of dimensions matches `lax.conv_general_dilated`.
|
|
1491
|
+
*
|
|
1492
|
+
* Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
|
|
1493
|
+
* dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
|
|
1494
|
+
* `transposeKernel` to true.
|
|
1495
|
+
*
|
|
1496
|
+
* @param lhs - Input tensor; shape `[N, C_in, ...xs]`
|
|
1497
|
+
* @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
|
|
1498
|
+
* @param strides - Sequence of n integers, sets fractional stride
|
|
1499
|
+
* @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
|
|
1500
|
+
* each side of the input, so it acts like gradient of `conv()`
|
|
1501
|
+
* @param rhsDilation - Atrous dilation for the kernel
|
|
1502
|
+
* @param transposeKernel - Flip spatial axes and swap the input/output channels
|
|
1503
|
+
* of the kernel; its shape should be `[C_in, C_out, ...ks]`
|
|
910
1504
|
*/
|
|
911
|
-
declare function
|
|
1505
|
+
declare function convTranspose(lhs: Array, rhs: Array, strides: number[], padding: PaddingType, {
|
|
1506
|
+
rhsDilation,
|
|
1507
|
+
transposeKernel
|
|
1508
|
+
}?: {
|
|
1509
|
+
rhsDilation?: number[];
|
|
1510
|
+
transposeKernel?: boolean;
|
|
1511
|
+
}): Array;
|
|
1512
|
+
/** Reduce a computation over padded windows. */
|
|
1513
|
+
declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
|
|
1514
|
+
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
1515
|
+
declare function erf(x: ArrayLike): Array;
|
|
912
1516
|
/**
|
|
913
|
-
*
|
|
1517
|
+
* The complementary error function: `erfc(x) = 1 - erf(x)`.
|
|
914
1518
|
*
|
|
915
|
-
*
|
|
916
|
-
*
|
|
1519
|
+
* This function is more accurate than `1 - erf(x)` for large values of `x`,
|
|
1520
|
+
* where `erf(x)` is very close to 1.
|
|
917
1521
|
*/
|
|
918
|
-
declare function
|
|
1522
|
+
declare function erfc(x: ArrayLike): Array;
|
|
919
1523
|
/**
|
|
920
|
-
*
|
|
1524
|
+
* Stops gradient computation.
|
|
921
1525
|
*
|
|
922
|
-
*
|
|
923
|
-
*
|
|
924
|
-
* reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
|
|
1526
|
+
* Behaves as the identity function but prevents the flow of gradients during
|
|
1527
|
+
* forward or reverse-mode automatic differentiation.
|
|
925
1528
|
*/
|
|
926
|
-
declare function
|
|
1529
|
+
declare function stopGradient(x: ArrayLike): Array;
|
|
1530
|
+
declare namespace numpy_fft_d_exports {
|
|
1531
|
+
export { ComplexPair, fft, ifft };
|
|
1532
|
+
}
|
|
927
1533
|
/**
|
|
928
|
-
*
|
|
929
|
-
*
|
|
930
|
-
* In other words, this lets you append axes to the left, and/or expand
|
|
931
|
-
* dimensions where the shape is 1.
|
|
1534
|
+
* A pair of arrays representing real and imaginary part `a + bj`. Both arrays
|
|
1535
|
+
* must have the same shape.
|
|
932
1536
|
*/
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
declare function broadcastArrays(...arrays: ArrayLike[]): Array[];
|
|
1537
|
+
type ComplexPair = {
|
|
1538
|
+
real: Array;
|
|
1539
|
+
imag: Array;
|
|
1540
|
+
};
|
|
938
1541
|
/**
|
|
939
|
-
*
|
|
940
|
-
*
|
|
941
|
-
* If a is 2D, return the diagonal of the array with the given offset. If a is
|
|
942
|
-
* 3D or higher, compute diagonals along the two given axes (default: 0, 1).
|
|
1542
|
+
* Compute a one-dimensional discrete Fourier transform.
|
|
943
1543
|
*
|
|
944
|
-
*
|
|
945
|
-
* is determined by removing the two axes along which the diagonal is taken,
|
|
946
|
-
* then appending a new axis to the right with holding the diagonals.
|
|
1544
|
+
* Currently, the size of the axis must be a power of two.
|
|
947
1545
|
*/
|
|
948
|
-
declare function
|
|
1546
|
+
declare function fft(a: ComplexPair, axis?: number): ComplexPair;
|
|
949
1547
|
/**
|
|
950
|
-
*
|
|
1548
|
+
* Compute a one-dimensional inverse discrete Fourier transform.
|
|
951
1549
|
*
|
|
952
|
-
*
|
|
953
|
-
* array, return a 2D array with v on the k-th diagonal.
|
|
1550
|
+
* Currently, the size of the axis must be a power of two.
|
|
954
1551
|
*/
|
|
955
|
-
declare function
|
|
956
|
-
|
|
957
|
-
|
|
1552
|
+
declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
|
|
1553
|
+
declare namespace numpy_linalg_d_exports {
|
|
1554
|
+
export { cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
|
|
1555
|
+
}
|
|
958
1556
|
/**
|
|
959
|
-
*
|
|
1557
|
+
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
960
1558
|
*
|
|
961
|
-
*
|
|
962
|
-
*
|
|
1559
|
+
* This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
|
|
1560
|
+
* the input matrix, which is on by default.
|
|
963
1561
|
*/
|
|
964
|
-
declare function
|
|
1562
|
+
declare function cholesky(a: ArrayLike, {
|
|
1563
|
+
upper,
|
|
1564
|
+
symmetrizeInput
|
|
1565
|
+
}?: {
|
|
1566
|
+
upper?: boolean;
|
|
1567
|
+
symmetrizeInput?: boolean;
|
|
1568
|
+
}): Array;
|
|
1569
|
+
/** Compute the determinant of a square matrix (batched). */
|
|
1570
|
+
declare function det(a: ArrayLike): Array;
|
|
1571
|
+
/** Compute the inverse of a square matrix (batched). */
|
|
1572
|
+
declare function inv(a: ArrayLike): Array;
|
|
965
1573
|
/**
|
|
966
|
-
* Return
|
|
967
|
-
* algorithm; it need not preserve order of indices in ties.
|
|
1574
|
+
* Return the least-squares solution to a linear equation.
|
|
968
1575
|
*
|
|
969
|
-
*
|
|
1576
|
+
* For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
|
|
1577
|
+
* For underdetermined systems, this finds the minimum-norm solution for `x`.
|
|
970
1578
|
*
|
|
971
|
-
*
|
|
1579
|
+
* This currently uses Cholesky decomposition to solve the normal equations,
|
|
1580
|
+
* under the hood. The method is not as robust as QR or SVD.
|
|
1581
|
+
*
|
|
1582
|
+
* @param a coefficient matrix of shape `(M, N)`
|
|
1583
|
+
* @param b right-hand side of shape `(M,)` or `(M, K)`
|
|
1584
|
+
* @return least-squares solution of shape `(N,)` or `(N, K)`
|
|
972
1585
|
*/
|
|
973
|
-
declare function
|
|
1586
|
+
declare function lstsq(a: ArrayLike, b: ArrayLike): Array;
|
|
1587
|
+
/** Raise a square matrix to an integer power, via repeated squarings. */
|
|
1588
|
+
declare function matrixPower(a: ArrayLike, n: number): Array;
|
|
1589
|
+
/** Return sign and natural logarithm of the determinant of `a`. */
|
|
1590
|
+
declare function slogdet(a: ArrayLike): [Array, Array];
|
|
974
1591
|
/**
|
|
975
|
-
*
|
|
1592
|
+
* Solve a linear system of equations.
|
|
976
1593
|
*
|
|
977
|
-
* This
|
|
978
|
-
*
|
|
1594
|
+
* This solves a (batched) linear system of equations `a @ x = b` for `x` given
|
|
1595
|
+
* `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
|
|
1596
|
+
*
|
|
1597
|
+
* @param a - Coefficient matrix of shape `(..., N, N)`.
|
|
1598
|
+
* @param b - Values of shape `(N,)` or `(..., N, M)`.
|
|
1599
|
+
* @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
|
|
979
1600
|
*/
|
|
980
|
-
declare function
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
1601
|
+
declare function solve(a: ArrayLike, b: ArrayLike): Array;
|
|
1602
|
+
//#endregion
|
|
1603
|
+
//#region src/library/numpy/dtype-info.d.ts
|
|
1604
|
+
/** @inline */
|
|
1605
|
+
type FInfo = Readonly<{
|
|
1606
|
+
/** The number of bits occupied by the type. */
|
|
1607
|
+
bits: number;
|
|
1608
|
+
/** Returns the _dtype_ for which finfo returns information. */
|
|
1609
|
+
dtype: DType;
|
|
1610
|
+
/** The difference between 1.0 and the next smallest representable float larger than 1.0. */
|
|
1611
|
+
eps: number;
|
|
1612
|
+
/** The difference between 1.0 and the next largest representable float smaller than 1.0. */
|
|
1613
|
+
epsneg: number;
|
|
1614
|
+
/** The exponent that yields `eps`. */
|
|
1615
|
+
machep: number;
|
|
1616
|
+
/** The largest representable finite number. */
|
|
1617
|
+
max: number;
|
|
1618
|
+
/** The smallest positive power of the base (2) that causes overflow. */
|
|
1619
|
+
maxexp: number;
|
|
1620
|
+
/** The smallest representable (most negative) finite number. */
|
|
1621
|
+
min: number;
|
|
1622
|
+
/** The largest negative power of the base (2) without leading zeros in mantissa. */
|
|
1623
|
+
minexp: number;
|
|
1624
|
+
/** The exponent that yields `epsneg`. */
|
|
1625
|
+
negep: number;
|
|
1626
|
+
/** Number of bits in the exponent portion. */
|
|
1627
|
+
nexp: number;
|
|
1628
|
+
/** Number of bits in the mantissa portion. */
|
|
1629
|
+
nmant: number;
|
|
1630
|
+
/** The approximate number of decimal digits to which this kind of float is precise. */
|
|
1631
|
+
precision: number;
|
|
1632
|
+
/** The approximate decimal resolution, i.e., `10 ** -precision`. */
|
|
1633
|
+
resolution: number;
|
|
1634
|
+
/** The smallest positive normal number. */
|
|
1635
|
+
smallestNormal: number;
|
|
1636
|
+
/** The smallest positive subnormal number. */
|
|
1637
|
+
smallestSubnormal: number;
|
|
1638
|
+
}>;
|
|
1639
|
+
/** Machine limits for floating-point types. */
|
|
1640
|
+
declare function finfo(dtype: DType): FInfo;
|
|
1641
|
+
/** @inline */
|
|
1642
|
+
type IInfo = Readonly<{
|
|
1643
|
+
/** The number of bits occupied by the type. */
|
|
1644
|
+
bits: number;
|
|
1645
|
+
/** Returns the _dtype_ for which iinfo returns information. */
|
|
1646
|
+
dtype: DType;
|
|
1647
|
+
/** The largest representable integer. */
|
|
1648
|
+
max: number;
|
|
1649
|
+
/** The smallest representable integer. */
|
|
1650
|
+
min: number;
|
|
1651
|
+
}>;
|
|
1652
|
+
/** Machine limits for integer types. */
|
|
1653
|
+
declare function iinfo(dtype: DType): IInfo;
|
|
1654
|
+
declare namespace numpy_d_exports {
|
|
1655
|
+
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
1656
|
+
}
|
|
1657
|
+
declare const float32 = DType.Float32;
|
|
1658
|
+
declare const int32 = DType.Int32;
|
|
1659
|
+
declare const uint32 = DType.Uint32;
|
|
1660
|
+
declare const bool = DType.Bool;
|
|
1661
|
+
declare const float16 = DType.Float16;
|
|
1662
|
+
declare const float64 = DType.Float64;
|
|
1663
|
+
/** Euler's constant, `e = 2.7182818284590...` */
|
|
1664
|
+
declare const e: number;
|
|
1665
|
+
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
1666
|
+
declare const eulerGamma = 0.5772156649015329;
|
|
1667
|
+
/** Positive infinity. */
|
|
1668
|
+
declare const inf: number;
|
|
1669
|
+
/** Floating-point representation of NaN. */
|
|
1670
|
+
declare const nan: number;
|
|
1671
|
+
/** This is Pi, `π = 3.14159265358979...` */
|
|
1672
|
+
declare const pi: number;
|
|
1673
|
+
/** @function Element-wise addition, with broadcasting. */
|
|
1674
|
+
declare const add: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1675
|
+
/** @function Element-wise multiplication, with broadcasting. */
|
|
1676
|
+
declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1677
|
+
/** @function Numerical negative of every element of an array. */
|
|
1678
|
+
declare const negative: (x: ArrayLike) => Array;
|
|
1679
|
+
/** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
1680
|
+
declare const reciprocal: (x: ArrayLike) => Array;
|
|
1681
|
+
/** @function Round input down to the nearest integer. */
|
|
1682
|
+
declare const floor: (x: ArrayLike) => Array;
|
|
1683
|
+
/** @function Round input up to the nearest integer. */
|
|
1684
|
+
declare const ceil: (x: ArrayLike) => Array;
|
|
1685
|
+
/** @function Element-wise sine function (takes radians). */
|
|
1686
|
+
declare const sin: (x: ArrayLike) => Array;
|
|
1687
|
+
/** @function Element-wise cosine function (takes radians). */
|
|
1688
|
+
declare const cos: (x: ArrayLike) => Array;
|
|
1689
|
+
/** @function Element-wise inverse sine function (inverse of sin). */
|
|
1690
|
+
declare const asin: (x: ArrayLike) => Array;
|
|
1691
|
+
/** @function Element-wise inverse tangent function (inverse of tan). */
|
|
1692
|
+
declare const atan: (x: ArrayLike) => Array;
|
|
1693
|
+
/** @function Calculate the exponential of all elements in the input array. */
|
|
1694
|
+
declare const exp: (x: ArrayLike) => Array;
|
|
1695
|
+
/** @function Calculate the natural logarithm of all elements in the input array. */
|
|
1696
|
+
declare const log: (x: ArrayLike) => Array;
|
|
1697
|
+
/** @function Calculate the square root of all elements in the input array. */
|
|
1698
|
+
declare const sqrt: (x: ArrayLike) => Array;
|
|
1699
|
+
/** @function Return element-wise minimum of the input arrays. */
|
|
1700
|
+
declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1701
|
+
/** @function Return element-wise maximum of the input arrays. */
|
|
1702
|
+
declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1703
|
+
/** @function Compare two arrays element-wise. */
|
|
1704
|
+
declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1705
|
+
/** @function Compare two arrays element-wise. */
|
|
1706
|
+
declare const less: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1707
|
+
/** @function Compare two arrays element-wise. */
|
|
1708
|
+
declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1709
|
+
/** @function Compare two arrays element-wise. */
|
|
1710
|
+
declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1711
|
+
/** @function Compare two arrays element-wise. */
|
|
1712
|
+
declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1713
|
+
/** @function Compare two arrays element-wise. */
|
|
1714
|
+
declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1715
|
+
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1716
|
+
declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
|
|
990
1717
|
/**
|
|
991
|
-
*
|
|
992
|
-
*
|
|
993
|
-
* The behavior is determined by `axes`. If an integer `k`, sum over the last
|
|
994
|
-
* `k` axes of x and the first `k` axes of y. If a tuple, then the first array
|
|
995
|
-
* corresponds to the axes of x and the second to the axes of y.
|
|
1718
|
+
* @function
|
|
1719
|
+
* Permute the dimensions of an array. Defaults to reversing the axis order.
|
|
996
1720
|
*/
|
|
997
|
-
declare
|
|
1721
|
+
declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
|
|
998
1722
|
/**
|
|
999
|
-
*
|
|
1000
|
-
*
|
|
1001
|
-
* @example
|
|
1002
|
-
* ```ts
|
|
1003
|
-
* import { numpy as np } from "@jax-js/jax";
|
|
1723
|
+
* @function
|
|
1724
|
+
* Give a new shape to an array without changing its data.
|
|
1004
1725
|
*
|
|
1005
|
-
*
|
|
1006
|
-
*
|
|
1007
|
-
|
|
1008
|
-
|
|
1726
|
+
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
1727
|
+
* length of the array and remaining dimensions.
|
|
1728
|
+
*/
|
|
1729
|
+
declare const reshape: (x: ArrayLike, shape: number[]) => Array;
|
|
1730
|
+
/**
|
|
1731
|
+
* @function
|
|
1732
|
+
* Move axes of an array to new positions. Other axes retain original order.
|
|
1009
1733
|
*/
|
|
1010
|
-
declare
|
|
1734
|
+
declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
1011
1735
|
/**
|
|
1012
|
-
*
|
|
1013
|
-
*
|
|
1014
|
-
* @example
|
|
1015
|
-
* ```ts
|
|
1016
|
-
* import { numpy as np } from "@jax-js/jax";
|
|
1736
|
+
* @function
|
|
1737
|
+
* Add padding (zeros) to an array.
|
|
1017
1738
|
*
|
|
1018
|
-
*
|
|
1019
|
-
*
|
|
1020
|
-
*
|
|
1021
|
-
* ```
|
|
1739
|
+
* The `width` argument is either an integer or pair of integers, in which case
|
|
1740
|
+
* all axes are padded with the same width. Or if it is an array of pairs, each
|
|
1741
|
+
* pair specifies the padding for its corresponding axis.
|
|
1022
1742
|
*/
|
|
1023
|
-
declare
|
|
1743
|
+
declare const pad: (x: ArrayLike, width: number | Pair | Pair[] | Record<number, Pair>) => Array;
|
|
1024
1744
|
/**
|
|
1025
|
-
*
|
|
1026
|
-
*
|
|
1027
|
-
* Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
|
|
1028
|
-
* contraction on the last axis.
|
|
1029
|
-
*
|
|
1030
|
-
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
1745
|
+
* @function
|
|
1746
|
+
* Return the number of dimensions of an array. Does not consume array reference.
|
|
1031
1747
|
*/
|
|
1032
|
-
declare
|
|
1748
|
+
declare const ndim: (x: ArrayLike) => number;
|
|
1749
|
+
/** @function Return the shape of an array. Does not consume array reference. */
|
|
1750
|
+
declare const shape$1: (x: ArrayLike) => number[];
|
|
1033
1751
|
/**
|
|
1034
|
-
*
|
|
1035
|
-
*
|
|
1036
|
-
* If the input arrays are not 1D, they will be flattened. Returned array will
|
|
1037
|
-
* be of shape `[x.size, y.size]`.
|
|
1752
|
+
* @function
|
|
1753
|
+
* Return an array of zeros with the same shape and type as a given array.
|
|
1038
1754
|
*/
|
|
1039
|
-
declare
|
|
1040
|
-
/** Vector dot product of two arrays along a given axis. */
|
|
1041
|
-
declare function vecdot(x: ArrayLike, y: ArrayLike, {
|
|
1042
|
-
axis
|
|
1043
|
-
}?: {
|
|
1044
|
-
axis?: number;
|
|
1045
|
-
}): Array;
|
|
1755
|
+
declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1046
1756
|
/**
|
|
1047
|
-
*
|
|
1048
|
-
*
|
|
1049
|
-
* Like vecdot() but flattens the arguments first into vectors.
|
|
1757
|
+
* @function
|
|
1758
|
+
* Return an array of ones with the same shape and type as a given array.
|
|
1050
1759
|
*/
|
|
1051
|
-
declare
|
|
1052
|
-
/** Convolution of two one-dimensional arrays. */
|
|
1053
|
-
declare function convolve(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
|
|
1054
|
-
/** Correlation of two one dimensional arrays. */
|
|
1055
|
-
declare function correlate(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
|
|
1760
|
+
declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1056
1761
|
/**
|
|
1057
|
-
*
|
|
1058
|
-
*
|
|
1059
|
-
* Make N-D coordinate arrays for vectorized evaluations of N-D scalar/vector
|
|
1060
|
-
* fields over N-D grids, given one-dimensional coordinate arrays x1, x2,…, xn.
|
|
1762
|
+
* @function
|
|
1763
|
+
* Return a full array with the same shape and type as a given array.
|
|
1061
1764
|
*/
|
|
1062
|
-
declare
|
|
1063
|
-
indexing
|
|
1064
|
-
}?: {
|
|
1065
|
-
indexing?: "xy" | "ij";
|
|
1066
|
-
}): Array[];
|
|
1765
|
+
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
|
|
1067
1766
|
/**
|
|
1068
|
-
*
|
|
1069
|
-
*
|
|
1070
|
-
* Given an interval, values outside the interval are clipped to the interval
|
|
1071
|
-
* edges. For example, if an interval of [0, 1] is specified, values smaller
|
|
1072
|
-
* than 0 become 0, and values larger than 1 become 1.
|
|
1073
|
-
*
|
|
1074
|
-
* If either bound is undefined, it is ignored.
|
|
1767
|
+
* Return the number of elements in an array, optionally along an axis.
|
|
1768
|
+
* Does not consume array reference.
|
|
1075
1769
|
*/
|
|
1076
|
-
declare function
|
|
1770
|
+
declare function size(a: ArrayLike, axis?: number): number;
|
|
1771
|
+
/** Convert an array to a specified dtype. */
|
|
1772
|
+
declare function astype(a: ArrayLike, dtype: DType): Array;
|
|
1773
|
+
/** Sum of the elements of the array over a given axis, or axes. */
|
|
1774
|
+
declare function sum(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1775
|
+
/** Product of the array elements over a given axis. */
|
|
1776
|
+
declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1777
|
+
/** Return the minimum of array elements along a given axis. */
|
|
1778
|
+
declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1779
|
+
/** Return the maximum of array elements along a given axis. */
|
|
1780
|
+
declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1077
1781
|
/**
|
|
1078
|
-
*
|
|
1782
|
+
* Test whether any array element along a given axis evaluates to True.
|
|
1079
1783
|
*
|
|
1080
|
-
*
|
|
1784
|
+
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
1785
|
+
* removed. If axis is None, returns a scalar.
|
|
1081
1786
|
*/
|
|
1082
|
-
declare function
|
|
1083
|
-
/** Return an element-wise indication of sign of the input. */
|
|
1084
|
-
declare function sign(x: ArrayLike): Array;
|
|
1085
|
-
/** @function Return element-wise positive values of the input (no-op). */
|
|
1086
|
-
declare const positive: (x: ArrayLike) => Array;
|
|
1787
|
+
declare function any(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1087
1788
|
/**
|
|
1088
|
-
*
|
|
1789
|
+
* Test whether all array elements along a given axis evaluate to True.
|
|
1089
1790
|
*
|
|
1090
|
-
*
|
|
1791
|
+
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
1792
|
+
* removed. If axis is None, returns a scalar.
|
|
1091
1793
|
*/
|
|
1092
|
-
declare function
|
|
1794
|
+
declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1795
|
+
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
1796
|
+
declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1797
|
+
/** Compute the average of the array elements along the specified axis. */
|
|
1798
|
+
declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1093
1799
|
/**
|
|
1094
|
-
*
|
|
1800
|
+
* Returns the indices of the minimum values along an axis.
|
|
1095
1801
|
*
|
|
1096
|
-
*
|
|
1802
|
+
* By default, index is into the flatted array, otherwise it is along the
|
|
1803
|
+
* specified axis.
|
|
1097
1804
|
*/
|
|
1098
|
-
declare function
|
|
1805
|
+
declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
1099
1806
|
/**
|
|
1100
|
-
*
|
|
1101
|
-
*
|
|
1102
|
-
*
|
|
1103
|
-
*
|
|
1104
|
-
* - `heaviside(x1, x2) = 1` for `x1 > 0`.
|
|
1807
|
+
* Returns the indices of the maximum values along an axis.
|
|
1808
|
+
*
|
|
1809
|
+
* By default, index is into the flatted array, otherwise it is along the
|
|
1810
|
+
* specified axis.
|
|
1105
1811
|
*/
|
|
1106
|
-
declare
|
|
1107
|
-
/** Calculate element-wise square of the input array. */
|
|
1108
|
-
declare function square(x: ArrayLike): Array;
|
|
1109
|
-
/** Element-wise tangent function (takes radians). */
|
|
1110
|
-
declare function tan(x: ArrayLike): Array;
|
|
1812
|
+
declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
1111
1813
|
/**
|
|
1112
|
-
*
|
|
1113
|
-
* Return the normalized sinc function.
|
|
1114
|
-
*
|
|
1115
|
-
* The sinc function is defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`.
|
|
1116
|
-
* This is the normalized sinc function commonly used in signal processing.
|
|
1814
|
+
* Cumulative sum of elements along an axis.
|
|
1117
1815
|
*
|
|
1118
|
-
*
|
|
1119
|
-
*
|
|
1816
|
+
* Currently this function is `O(n^2)`, we'll improve this later on with a
|
|
1817
|
+
* two-phase parallel reduction algorithm.
|
|
1120
1818
|
*/
|
|
1121
|
-
declare
|
|
1122
|
-
/**
|
|
1123
|
-
declare function
|
|
1819
|
+
declare function cumsum(a: ArrayLike, axis?: number): Array;
|
|
1820
|
+
/** Reverse the elements in an array along the given axes. */
|
|
1821
|
+
declare function flip(x: ArrayLike, axis?: Axis): Array;
|
|
1124
1822
|
/**
|
|
1125
|
-
*
|
|
1126
|
-
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
1823
|
+
* Split an array into multiple sub-arrays along an axis.
|
|
1127
1824
|
*
|
|
1128
|
-
*
|
|
1129
|
-
*
|
|
1130
|
-
*
|
|
1825
|
+
* @param a - The input array to split.
|
|
1826
|
+
* @param indicesOrSections - If an integer, it indicates the number of equal
|
|
1827
|
+
* sections to create along the specified axis. If a list of integers, it
|
|
1828
|
+
* specifies the indices at which to split the array.
|
|
1829
|
+
* @param axis - The axis along which to split the array. Default is 0.
|
|
1131
1830
|
*/
|
|
1132
|
-
declare
|
|
1831
|
+
declare function split$1(a: ArrayLike, indicesOrSections: number | number[], axis?: number): Array[];
|
|
1133
1832
|
/**
|
|
1134
|
-
*
|
|
1135
|
-
* Element-wise arc tangent of y/x with correct quadrant.
|
|
1136
|
-
*
|
|
1137
|
-
* Returns the angle in radians between the positive x-axis and the point (x, y).
|
|
1138
|
-
* The result is in the range [-π, π].
|
|
1833
|
+
* Join a sequence of arrays along an existing axis.
|
|
1139
1834
|
*
|
|
1140
|
-
*
|
|
1141
|
-
*
|
|
1142
|
-
* - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
|
|
1835
|
+
* The arrays must have the same shape, except in the dimension corresponding to
|
|
1836
|
+
* `axis` (the first, by default).
|
|
1143
1837
|
*
|
|
1144
|
-
*
|
|
1838
|
+
* No scalars can be passed to this function, as the axis is then ambiguous.
|
|
1145
1839
|
*/
|
|
1146
|
-
declare
|
|
1147
|
-
/** Element-wise subtraction, with broadcasting. */
|
|
1148
|
-
declare function subtract(x: ArrayLike, y: ArrayLike): Array;
|
|
1149
|
-
/** Calculates the floating-point division of x by y element-wise. */
|
|
1150
|
-
declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
|
|
1840
|
+
declare function concatenate(xs: Array[], axis?: number): Array;
|
|
1151
1841
|
/**
|
|
1152
|
-
*
|
|
1153
|
-
*
|
|
1154
|
-
* The result is always rounded towards negative infinity.
|
|
1842
|
+
* Join a sequence of arrays along a new axis.
|
|
1155
1843
|
*
|
|
1156
|
-
*
|
|
1157
|
-
* For
|
|
1158
|
-
*
|
|
1844
|
+
* The `axis` parameter specifies the index of the new axis in the dimensions of
|
|
1845
|
+
* the result. For example, if `axis=0` it will be the first dimension and if
|
|
1846
|
+
* `axis=-1` it will be the last dimension.
|
|
1159
1847
|
*
|
|
1160
|
-
*
|
|
1161
|
-
* @param y - Divisor array.
|
|
1162
|
-
* @returns Element-wise floor division of x by y.
|
|
1848
|
+
* All shapes must have the same shape.
|
|
1163
1849
|
*/
|
|
1164
|
-
declare function
|
|
1850
|
+
declare function stack(xs: ArrayLike[], axis?: number): Array;
|
|
1165
1851
|
/**
|
|
1166
|
-
*
|
|
1167
|
-
*
|
|
1852
|
+
* Horizontally stack arrays. Inputs are promoted to rank at least 1, then
|
|
1853
|
+
* concatenated along axis 1 (if rank-2 or higher) or 0 (if rank-1).
|
|
1168
1854
|
*/
|
|
1169
|
-
declare
|
|
1855
|
+
declare function hstack(xs: ArrayLike[]): Array;
|
|
1170
1856
|
/**
|
|
1171
|
-
*
|
|
1172
|
-
*
|
|
1857
|
+
* Vertically stack arrays. Inputs are promoted to rank at least 2, then
|
|
1858
|
+
* concatenated along axis 0.
|
|
1173
1859
|
*/
|
|
1174
|
-
declare
|
|
1860
|
+
declare function vstack(xs: ArrayLike[]): Array;
|
|
1175
1861
|
/**
|
|
1176
|
-
*
|
|
1862
|
+
* Stack arrays depth-wise. Inputs are promoted to rank at least 3, then
|
|
1863
|
+
* concatenated along axis 2.
|
|
1864
|
+
*/
|
|
1865
|
+
declare function dstack(xs: ArrayLike[]): Array;
|
|
1866
|
+
/**
|
|
1867
|
+
* Stack arrays column-wise. Inputs are promoted to rank at least 2, then
|
|
1868
|
+
* concatenated along axis 1.
|
|
1869
|
+
*/
|
|
1870
|
+
declare function columnStack(xs: ArrayLike[]): Array;
|
|
1871
|
+
/** Flip an array vertically (axis=0). */
|
|
1872
|
+
declare function flipud(x: ArrayLike): Array;
|
|
1873
|
+
/** Flip an array horizontally (axis=1). */
|
|
1874
|
+
declare function fliplr(x: ArrayLike): Array;
|
|
1875
|
+
/** Interchange two axes of an array. */
|
|
1876
|
+
declare function swapaxes(a: ArrayLike, axis1: number, axis2: number): Array;
|
|
1877
|
+
/** Transpose the last two dimensions of an array. */
|
|
1878
|
+
declare function matrixTranspose(a: ArrayLike): Array;
|
|
1879
|
+
/** Return a 1-D flattened array containing the elements of the input. */
|
|
1880
|
+
declare function ravel(a: ArrayLike): Array;
|
|
1881
|
+
/** Remove one or more length-1 axes from an array. */
|
|
1882
|
+
declare function squeeze(a: ArrayLike, axis?: Axis): Array;
|
|
1883
|
+
/**
|
|
1884
|
+
* Expand the shape of an array by inserting new axes of length 1.
|
|
1177
1885
|
*
|
|
1178
|
-
*
|
|
1886
|
+
* @param a - Input array.
|
|
1887
|
+
* @param axis - Position(s) in the expanded axes where the new axis (or axes)
|
|
1888
|
+
* is placed. Can be a single integer or an array of integers.
|
|
1889
|
+
* @returns Array with the number of dimensions increased.
|
|
1179
1890
|
*
|
|
1180
|
-
* @
|
|
1181
|
-
*
|
|
1182
|
-
*
|
|
1891
|
+
* @example
|
|
1892
|
+
* ```ts
|
|
1893
|
+
* const x = np.array([1, 2]);
|
|
1894
|
+
* np.expandDims(x, 0); // Shape [1, 2]
|
|
1895
|
+
* np.expandDims(x, 1); // Shape [2, 1]
|
|
1896
|
+
* np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
|
|
1897
|
+
* ```
|
|
1183
1898
|
*/
|
|
1184
|
-
declare function
|
|
1185
|
-
/**
|
|
1186
|
-
|
|
1899
|
+
declare function expandDims(a: ArrayLike, axis: number | number[]): Array;
|
|
1900
|
+
/**
|
|
1901
|
+
* Repeat each element of an array after themselves.
|
|
1902
|
+
*
|
|
1903
|
+
* If no axis is provided, use the flattened input array, and return a flat
|
|
1904
|
+
* output array.
|
|
1905
|
+
*/
|
|
1906
|
+
declare function repeat(a: ArrayLike, repeats: number, axis?: number): Array;
|
|
1187
1907
|
/**
|
|
1188
|
-
*
|
|
1908
|
+
* Construct an array by repeating A the number of times given by reps.
|
|
1189
1909
|
*
|
|
1190
|
-
*
|
|
1910
|
+
* If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
|
|
1911
|
+
* integers, the resulting array will have a shape of `(reps[0] * d1,
|
|
1912
|
+
* reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
|
|
1191
1913
|
*/
|
|
1192
|
-
declare function
|
|
1914
|
+
declare function tile(a: ArrayLike, reps: number | number[]): Array;
|
|
1193
1915
|
/**
|
|
1194
|
-
*
|
|
1916
|
+
* Broadcast an array to a shape, with NumPy-style broadcasing rules.
|
|
1195
1917
|
*
|
|
1196
|
-
*
|
|
1197
|
-
*
|
|
1198
|
-
* `x = mantissa * 2**exponent`.
|
|
1918
|
+
* In other words, this lets you append axes to the left, and/or expand
|
|
1919
|
+
* dimensions where the shape is 1.
|
|
1199
1920
|
*/
|
|
1200
|
-
declare function
|
|
1201
|
-
/**
|
|
1202
|
-
declare function
|
|
1203
|
-
/**
|
|
1204
|
-
declare function
|
|
1205
|
-
/** Return the base-10 logarithm of x, element-wise. */
|
|
1206
|
-
declare function log10(x: ArrayLike): Array;
|
|
1207
|
-
/** Calculate `exp(x) - 1` element-wise. */
|
|
1208
|
-
declare function expm1(x: ArrayLike): Array;
|
|
1209
|
-
/** Calculate the natural logarithm of `1 + x` element-wise. */
|
|
1210
|
-
declare function log1p(x: ArrayLike): Array;
|
|
1211
|
-
/** Convert angles from degrees to radians. */
|
|
1212
|
-
declare function deg2rad(x: ArrayLike): Array;
|
|
1213
|
-
/** @function Alias of `jax.numpy.deg2rad()`. */
|
|
1214
|
-
declare const radians: typeof deg2rad;
|
|
1215
|
-
/** Convert angles from radians to degrees. */
|
|
1216
|
-
declare function rad2deg(x: ArrayLike): Array;
|
|
1217
|
-
/** @function Alias of `jax.numpy.rad2deg()`. */
|
|
1218
|
-
declare const degrees: typeof rad2deg;
|
|
1921
|
+
declare function broadcastTo(a: ArrayLike, shape: number[]): Array;
|
|
1922
|
+
/** Broadcast input shapes to a common output shape. */
|
|
1923
|
+
declare function broadcastShapes(...shapes: number[][]): number[];
|
|
1924
|
+
/** Broadcast arrays to a common shape. */
|
|
1925
|
+
declare function broadcastArrays(...arrays: ArrayLike[]): Array[];
|
|
1219
1926
|
/**
|
|
1220
|
-
*
|
|
1221
|
-
*
|
|
1927
|
+
* Return specified diagonals.
|
|
1928
|
+
*
|
|
1929
|
+
* If a is 2D, return the diagonal of the array with the given offset. If a is
|
|
1930
|
+
* 3D or higher, compute diagonals along the two given axes (default: 0, 1).
|
|
1931
|
+
*
|
|
1932
|
+
* This returns a view over the existing array. The shape of the resulting array
|
|
1933
|
+
* is determined by removing the two axes along which the diagonal is taken,
|
|
1934
|
+
* then appending a new axis to the right with holding the diagonals.
|
|
1222
1935
|
*/
|
|
1223
|
-
declare
|
|
1224
|
-
/** @function Calculate the element-wise cube root of the input array. */
|
|
1225
|
-
declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1936
|
+
declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
|
|
1226
1937
|
/**
|
|
1227
|
-
*
|
|
1228
|
-
* Calculate element-wise hyperbolic sine of input.
|
|
1938
|
+
* Extract a diagonal or construct a diagonal array.
|
|
1229
1939
|
*
|
|
1230
|
-
*
|
|
1940
|
+
* If v is a 2D array, return the k-th diagonal of v (as a view). If v is a 1D
|
|
1941
|
+
* array, return a 2D array with v on the k-th diagonal.
|
|
1231
1942
|
*/
|
|
1232
|
-
declare
|
|
1943
|
+
declare function diag(v: ArrayLike, k?: number): Array;
|
|
1944
|
+
/** Calculate the sum of the diagonal of an array along the given axes. */
|
|
1945
|
+
declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
|
|
1233
1946
|
/**
|
|
1234
|
-
*
|
|
1235
|
-
* Calculate element-wise hyperbolic cosine of input.
|
|
1947
|
+
* Return a sorted copy of an array.
|
|
1236
1948
|
*
|
|
1237
|
-
*
|
|
1949
|
+
* The array is sorted along a specified axis (the last by default). This may be
|
|
1950
|
+
* an unstable sort, and it dispatches to device-specific implementation.
|
|
1238
1951
|
*/
|
|
1239
|
-
declare
|
|
1952
|
+
declare function sort(a: ArrayLike, axis?: number): Array;
|
|
1240
1953
|
/**
|
|
1241
|
-
*
|
|
1242
|
-
*
|
|
1954
|
+
* Return indices that would sort an array. This may be an unstable sorting
|
|
1955
|
+
* algorithm; it need not preserve order of indices in ties.
|
|
1243
1956
|
*
|
|
1244
|
-
*
|
|
1957
|
+
* Returns an array of `int32` indices.
|
|
1958
|
+
*
|
|
1959
|
+
* The array is sorted along a specified axis (the last by default).
|
|
1245
1960
|
*/
|
|
1246
|
-
declare
|
|
1961
|
+
declare function argsort(a: ArrayLike, axis?: number): Array;
|
|
1247
1962
|
/**
|
|
1248
|
-
*
|
|
1249
|
-
* Calculate element-wise inverse hyperbolic sine of input.
|
|
1963
|
+
* Take elements from an array along an axis.
|
|
1250
1964
|
*
|
|
1251
|
-
*
|
|
1965
|
+
* This is equivalent to advanced indexing with integer indices over that
|
|
1966
|
+
* numbered axis. By default, the flattened array is used.
|
|
1252
1967
|
*/
|
|
1253
|
-
declare
|
|
1968
|
+
declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
|
|
1969
|
+
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
1970
|
+
declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
|
|
1971
|
+
rtol?: number;
|
|
1972
|
+
atol?: number;
|
|
1973
|
+
}): boolean;
|
|
1974
|
+
/** Matrix product of two arrays. */
|
|
1975
|
+
declare function matmul(x: ArrayLike, y: ArrayLike): Array;
|
|
1976
|
+
/** Dot product of two arrays. */
|
|
1977
|
+
declare function dot(x: ArrayLike, y: ArrayLike): Array;
|
|
1254
1978
|
/**
|
|
1255
|
-
*
|
|
1256
|
-
* Calculate element-wise inverse hyperbolic cosine of input.
|
|
1979
|
+
* Compute the tensor dot product of two N-dimensional arrays.
|
|
1257
1980
|
*
|
|
1258
|
-
* `
|
|
1981
|
+
* The behavior is determined by `axes`. If an integer `k`, sum over the last
|
|
1982
|
+
* `k` axes of x and the first `k` axes of y. If a tuple, then the first array
|
|
1983
|
+
* corresponds to the axes of x and the second to the axes of y.
|
|
1259
1984
|
*/
|
|
1260
|
-
declare
|
|
1985
|
+
declare function tensordot(x: ArrayLike, y: ArrayLike, axes?: number | [number[], number[]]): Array;
|
|
1261
1986
|
/**
|
|
1262
|
-
*
|
|
1263
|
-
* Calculate element-wise inverse hyperbolic tangent of input.
|
|
1987
|
+
* Einstein summation with string subscripts.
|
|
1264
1988
|
*
|
|
1265
|
-
*
|
|
1989
|
+
* @example
|
|
1990
|
+
* ```ts
|
|
1991
|
+
* import { numpy as np } from "@jax-js/jax";
|
|
1992
|
+
*
|
|
1993
|
+
* const a = np.ones([2, 3]);
|
|
1994
|
+
* const b = np.ones([3]);
|
|
1995
|
+
* np.einsum("ij,j", a, b); // Shape [2]
|
|
1996
|
+
* ```
|
|
1266
1997
|
*/
|
|
1267
|
-
declare
|
|
1998
|
+
declare function einsum(subscripts: string, ...operands: ArrayLike[]): Array;
|
|
1268
1999
|
/**
|
|
1269
|
-
*
|
|
2000
|
+
* Einstein summation alternating between arrays and numeric indices.
|
|
1270
2001
|
*
|
|
1271
|
-
*
|
|
1272
|
-
*
|
|
2002
|
+
* @example
|
|
2003
|
+
* ```ts
|
|
2004
|
+
* import { numpy as np } from "@jax-js/jax";
|
|
1273
2005
|
*
|
|
1274
|
-
*
|
|
1275
|
-
*
|
|
2006
|
+
* const a = np.ones([2, 3]);
|
|
2007
|
+
* const b = np.ones([3]);
|
|
2008
|
+
* np.einsum(a, [0, 1], b, [1]); // Shape [2]
|
|
2009
|
+
* ```
|
|
1276
2010
|
*/
|
|
1277
|
-
declare function
|
|
1278
|
-
mean?: ArrayLike;
|
|
1279
|
-
correction?: number;
|
|
1280
|
-
} & ReduceOpts): Array;
|
|
2011
|
+
declare function einsum(...args: (ArrayLike | number[])[]): Array;
|
|
1281
2012
|
/**
|
|
1282
|
-
* Compute the
|
|
2013
|
+
* Compute the inner product of two arrays.
|
|
1283
2014
|
*
|
|
1284
|
-
*
|
|
1285
|
-
*
|
|
2015
|
+
* Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
|
|
2016
|
+
* contraction on the last axis.
|
|
1286
2017
|
*
|
|
1287
|
-
*
|
|
1288
|
-
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
2018
|
+
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
1289
2019
|
*/
|
|
1290
|
-
declare function
|
|
1291
|
-
|
|
1292
|
-
|
|
1293
|
-
|
|
1294
|
-
|
|
1295
|
-
|
|
1296
|
-
|
|
2020
|
+
declare function inner(x: ArrayLike, y: ArrayLike): Array;
|
|
2021
|
+
/**
|
|
2022
|
+
* Compute the outer product of two arrays.
|
|
2023
|
+
*
|
|
2024
|
+
* If the input arrays are not 1D, they will be flattened. Returned array will
|
|
2025
|
+
* be of shape `[x.size, y.size]`.
|
|
2026
|
+
*/
|
|
2027
|
+
declare function outer(x: ArrayLike, y: ArrayLike): Array;
|
|
2028
|
+
/** Vector dot product of two arrays along a given axis. */
|
|
2029
|
+
declare function vecdot(x: ArrayLike, y: ArrayLike, {
|
|
2030
|
+
axis
|
|
1297
2031
|
}?: {
|
|
1298
|
-
|
|
2032
|
+
axis?: number;
|
|
1299
2033
|
}): Array;
|
|
1300
|
-
/** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
|
|
1301
|
-
declare function corrcoef(x: ArrayLike, y?: ArrayLike): Array;
|
|
1302
|
-
/** Test element-wise for positive or negative infinity, return bool array. */
|
|
1303
|
-
declare function isinf(x: ArrayLike): Array;
|
|
1304
|
-
/** Test element-wise for NaN (Not a Number). */
|
|
1305
|
-
declare function isnan(x: ArrayLike): Array;
|
|
1306
|
-
/** Test element-wise for negative infinity, return bool array. */
|
|
1307
|
-
declare function isneginf(x: ArrayLike): Array;
|
|
1308
|
-
/** Test element-wise for positive infinity, return bool array. */
|
|
1309
|
-
declare function isposinf(x: ArrayLike): Array;
|
|
1310
2034
|
/**
|
|
1311
|
-
*
|
|
1312
|
-
*
|
|
2035
|
+
* Return the dot product of two vectors.
|
|
2036
|
+
*
|
|
2037
|
+
* Like vecdot() but flattens the arguments first into vectors.
|
|
1313
2038
|
*/
|
|
1314
|
-
declare
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
|
|
1318
|
-
declare
|
|
1319
|
-
|
|
1320
|
-
|
|
1321
|
-
|
|
1322
|
-
|
|
1323
|
-
|
|
1324
|
-
|
|
1325
|
-
|
|
1326
|
-
|
|
1327
|
-
|
|
1328
|
-
|
|
1329
|
-
|
|
1330
|
-
type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
|
|
1331
|
-
/** Represents the structure of a JsTree. */
|
|
1332
|
-
declare class JsTreeDef {
|
|
1333
|
-
readonly nodeType: NodeType;
|
|
1334
|
-
readonly nodeMetadata: any;
|
|
1335
|
-
readonly childTreedefs: JsTreeDef[];
|
|
1336
|
-
static leaf: JsTreeDef;
|
|
1337
|
-
constructor(nodeType: NodeType, nodeMetadata: any,
|
|
1338
|
-
// Must be comparable with deepEqual.
|
|
1339
|
-
childTreedefs: JsTreeDef[]);
|
|
1340
|
-
/** Get the total number of leaves in the tree. */
|
|
1341
|
-
get size(): number;
|
|
1342
|
-
/** Returns a string representation of this tree definition. */
|
|
1343
|
-
toString(root?: boolean): string;
|
|
1344
|
-
/** Compare this tree definition with another. */
|
|
1345
|
-
equals(other: JsTreeDef): boolean;
|
|
1346
|
-
}
|
|
1347
|
-
/** Flatten a structured object, returning the tree definition. */
|
|
1348
|
-
declare function flatten<T>(tree: JsTree<T>): [T[], JsTreeDef];
|
|
1349
|
-
/** Get the leaves of a tree. */
|
|
1350
|
-
declare function leaves<T>(tree: JsTree<T>): T[];
|
|
1351
|
-
/** Get the treedef for a tree. */
|
|
1352
|
-
declare function structure<T>(tree: JsTree<T>): JsTreeDef;
|
|
1353
|
-
/** Reconstruct a structured object from the flattened representation. */
|
|
1354
|
-
declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T>;
|
|
1355
|
-
/** Maps a multi-input function over pytree args to produce a new pytree. */
|
|
1356
|
-
declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
|
|
1357
|
-
/** Take a reference of every array in a tree. */
|
|
1358
|
-
declare function ref<Tree extends JsTree<Array>>(tree: Tree): Tree;
|
|
1359
|
-
/** Dispose every array in a tree. */
|
|
1360
|
-
declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefined): void;
|
|
1361
|
-
//#endregion
|
|
1362
|
-
//#region src/frontend/convolution.d.ts
|
|
1363
|
-
/** Definition of a general dilated convolution. Should be valid on creation. */
|
|
1364
|
-
interface ConvParams {
|
|
1365
|
-
vmapDims: number;
|
|
1366
|
-
strides: number[];
|
|
1367
|
-
padding: Pair[];
|
|
1368
|
-
lhsDilation: number[];
|
|
1369
|
-
rhsDilation: number[];
|
|
1370
|
-
}
|
|
2039
|
+
declare function vdot(x: ArrayLike, y: ArrayLike): Array;
|
|
2040
|
+
/** Convolution of two one-dimensional arrays. */
|
|
2041
|
+
declare function convolve(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
|
|
2042
|
+
/** Correlation of two one dimensional arrays. */
|
|
2043
|
+
declare function correlate(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
|
|
2044
|
+
/**
|
|
2045
|
+
* Return a tuple of coordinate matrices from coordinate vectors.
|
|
2046
|
+
*
|
|
2047
|
+
* Make N-D coordinate arrays for vectorized evaluations of N-D scalar/vector
|
|
2048
|
+
* fields over N-D grids, given one-dimensional coordinate arrays x1, x2,…, xn.
|
|
2049
|
+
*/
|
|
2050
|
+
declare function meshgrid(xs: Array[], {
|
|
2051
|
+
indexing
|
|
2052
|
+
}?: {
|
|
2053
|
+
indexing?: "xy" | "ij";
|
|
2054
|
+
}): Array[];
|
|
1371
2055
|
/**
|
|
1372
|
-
*
|
|
1373
|
-
* Expected shapes of the lhs and rhs of the convolution are:
|
|
2056
|
+
* Clip (limit) the values in an array.
|
|
1374
2057
|
*
|
|
1375
|
-
*
|
|
1376
|
-
*
|
|
2058
|
+
* Given an interval, values outside the interval are clipped to the interval
|
|
2059
|
+
* edges. For example, if an interval of [0, 1] is specified, values smaller
|
|
2060
|
+
* than 0 become 0, and values larger than 1 become 1.
|
|
1377
2061
|
*
|
|
1378
|
-
* If
|
|
2062
|
+
* If either bound is undefined, it is ignored.
|
|
1379
2063
|
*/
|
|
1380
|
-
|
|
1381
|
-
//#region src/frontend/jaxpr.d.ts
|
|
2064
|
+
declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
|
|
1382
2065
|
/**
|
|
1383
|
-
*
|
|
2066
|
+
* Calculate the absolute value element-wise.
|
|
1384
2067
|
*
|
|
1385
|
-
*
|
|
1386
|
-
* by the function after the last time it is called.
|
|
2068
|
+
* This is the same function as `jax.numpy.abs()`.
|
|
1387
2069
|
*/
|
|
1388
|
-
|
|
1389
|
-
|
|
1390
|
-
|
|
1391
|
-
/**
|
|
1392
|
-
declare
|
|
1393
|
-
#private;
|
|
1394
|
-
readonly id: number;
|
|
1395
|
-
readonly aval: ShapedArray;
|
|
1396
|
-
constructor(aval: ShapedArray);
|
|
1397
|
-
toString(): string;
|
|
1398
|
-
}
|
|
1399
|
-
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
1400
|
-
declare class Lit {
|
|
1401
|
-
readonly value: number;
|
|
1402
|
-
readonly aval: ShapedArray;
|
|
1403
|
-
get dtype(): DType;
|
|
1404
|
-
constructor(aval: AbstractValue, value: number);
|
|
1405
|
-
}
|
|
1406
|
-
type Atom = Var | Lit;
|
|
1407
|
-
declare class VarPrinter {
|
|
1408
|
-
#private;
|
|
1409
|
-
names: Map<Var, string>;
|
|
1410
|
-
name(v: Var): string;
|
|
1411
|
-
nameType(v: Var): string;
|
|
1412
|
-
}
|
|
1413
|
-
/** A single statement / binding in a Jaxpr, in ANF form. */
|
|
1414
|
-
declare class JaxprEqn {
|
|
1415
|
-
readonly primitive: Primitive;
|
|
1416
|
-
readonly inputs: Atom[];
|
|
1417
|
-
readonly params: Record<string, any>;
|
|
1418
|
-
readonly outBinders: Var[];
|
|
1419
|
-
constructor(primitive: Primitive, inputs: Atom[], params: Record<string, any>, outBinders: Var[]);
|
|
1420
|
-
pprint(usedVars?: Set<Var>, vp?: VarPrinter): PPrint;
|
|
1421
|
-
toString(): string;
|
|
1422
|
-
}
|
|
1423
|
-
/** Typed intermediate representation for traced computations. */
|
|
1424
|
-
declare class Jaxpr implements FpHashable {
|
|
1425
|
-
#private;
|
|
1426
|
-
readonly inBinders: Var[];
|
|
1427
|
-
readonly eqns: JaxprEqn[];
|
|
1428
|
-
readonly outs: Atom[];
|
|
1429
|
-
constructor(inBinders: Var[], eqns: JaxprEqn[], outs: Atom[]);
|
|
1430
|
-
pprint(): PPrint;
|
|
1431
|
-
toString(): string;
|
|
1432
|
-
/**
|
|
1433
|
-
* Gets a hash of this Jaxpr.
|
|
1434
|
-
*
|
|
1435
|
-
* Var identity is not considered in the hash, so two Jaxprs with the same
|
|
1436
|
-
* order of assignments and operators but different variable IDs will resolve
|
|
1437
|
-
* to the same hash (and toString representation).
|
|
1438
|
-
*/
|
|
1439
|
-
getHash(): bigint;
|
|
1440
|
-
hash(state: FpHash): void;
|
|
1441
|
-
/**
|
|
1442
|
-
* Produce a simplified Jaxpr with basic optimizations applied.
|
|
1443
|
-
* - Trim away unused variables.
|
|
1444
|
-
* - Fold away *1, *0, or +0 operations against literals.
|
|
1445
|
-
* - Remove no-op movement operations.
|
|
1446
|
-
*/
|
|
1447
|
-
simplify(): Jaxpr;
|
|
1448
|
-
/** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1449
|
-
flatten(): Jaxpr;
|
|
1450
|
-
}
|
|
1451
|
-
/** Jaxpr with a collection of associated, traced constants. */
|
|
1452
|
-
declare class ClosedJaxpr {
|
|
1453
|
-
readonly jaxpr: Jaxpr;
|
|
1454
|
-
readonly consts: Tracer[];
|
|
1455
|
-
constructor(jaxpr: Jaxpr, consts: Tracer[]);
|
|
1456
|
-
/** String representation of this Jaxpr. */
|
|
1457
|
-
toString(): string;
|
|
1458
|
-
/** Apply a function to the underlying Jaxpr. */
|
|
1459
|
-
mapJaxpr(f: (jaxpr: Jaxpr) => Jaxpr): ClosedJaxpr;
|
|
1460
|
-
/** Dispose of the constants in this Jaxpr. */
|
|
1461
|
-
dispose(): void;
|
|
1462
|
-
}
|
|
1463
|
-
/** @inline */
|
|
1464
|
-
type JitOpts = {
|
|
1465
|
-
staticArgnums?: number[];
|
|
1466
|
-
};
|
|
1467
|
-
//#endregion
|
|
1468
|
-
//#region src/frontend/core.d.ts
|
|
2070
|
+
declare function absolute(x: ArrayLike): Array;
|
|
2071
|
+
/** Return an element-wise indication of sign of the input. */
|
|
2072
|
+
declare function sign(x: ArrayLike): Array;
|
|
2073
|
+
/** @function Return element-wise positive values of the input (no-op). */
|
|
2074
|
+
declare const positive: (x: ArrayLike) => Array;
|
|
1469
2075
|
/**
|
|
1470
|
-
*
|
|
1471
|
-
*
|
|
1472
|
-
*
|
|
1473
|
-
|
|
1474
|
-
|
|
1475
|
-
* which transformations like vmap, grad, and jvp occur. They are loosely based
|
|
1476
|
-
* on [XLA](https://openxla.org/xla/operation_semantics).
|
|
1477
|
-
*
|
|
1478
|
-
* All n-ary operations support broadcasting, with NumPy semantics.
|
|
1479
|
-
*/
|
|
1480
|
-
declare enum Primitive {
|
|
1481
|
-
Add = "add",
|
|
1482
|
-
Mul = "mul",
|
|
1483
|
-
Idiv = "idiv",
|
|
1484
|
-
Mod = "mod",
|
|
1485
|
-
// uses sign of numerator, C-style, matches JS but not Python
|
|
1486
|
-
Min = "min",
|
|
1487
|
-
Max = "max",
|
|
1488
|
-
Neg = "neg",
|
|
1489
|
-
Reciprocal = "reciprocal",
|
|
1490
|
-
Floor = "floor",
|
|
1491
|
-
Ceil = "ceil",
|
|
1492
|
-
StopGradient = "stop_gradient",
|
|
1493
|
-
Cast = "cast",
|
|
1494
|
-
Bitcast = "bitcast",
|
|
1495
|
-
Sin = "sin",
|
|
1496
|
-
Cos = "cos",
|
|
1497
|
-
Asin = "asin",
|
|
1498
|
-
Atan = "atan",
|
|
1499
|
-
Exp = "exp",
|
|
1500
|
-
Log = "log",
|
|
1501
|
-
Erf = "erf",
|
|
1502
|
-
Erfc = "erfc",
|
|
1503
|
-
Sqrt = "sqrt",
|
|
1504
|
-
Reduce = "reduce",
|
|
1505
|
-
Dot = "dot",
|
|
1506
|
-
// sum(x*y, axis=-1)
|
|
1507
|
-
Conv = "conv",
|
|
1508
|
-
// see lax.conv_general_dilated
|
|
1509
|
-
Pool = "pool",
|
|
1510
|
-
PoolTranspose = "pool_transpose",
|
|
1511
|
-
Compare = "compare",
|
|
1512
|
-
Where = "where",
|
|
1513
|
-
Concatenate = "concatenate",
|
|
1514
|
-
Split = "split",
|
|
1515
|
-
RandomBits = "random_bits",
|
|
1516
|
-
Gather = "gather",
|
|
1517
|
-
Transpose = "transpose",
|
|
1518
|
-
Broadcast = "broadcast",
|
|
1519
|
-
Reshape = "reshape",
|
|
1520
|
-
Flip = "flip",
|
|
1521
|
-
Shrink = "shrink",
|
|
1522
|
-
Pad = "pad",
|
|
1523
|
-
Sort = "sort",
|
|
1524
|
-
// sort(x, axis=-1)
|
|
1525
|
-
Argsort = "argsort",
|
|
1526
|
-
// argsort(x, axis=-1)
|
|
1527
|
-
TriangularSolve = "triangular_solve",
|
|
1528
|
-
// A is upper triangular, A @ X.T = B.T
|
|
1529
|
-
Cholesky = "cholesky",
|
|
1530
|
-
// A is positive-definite, A = L @ L^T
|
|
1531
|
-
LU = "lu",
|
|
1532
|
-
// LU decomposition with partial pivoting
|
|
1533
|
-
Jit = "jit",
|
|
1534
|
-
}
|
|
1535
|
-
interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
1536
|
-
[Primitive.Cast]: {
|
|
1537
|
-
dtype: DType;
|
|
1538
|
-
};
|
|
1539
|
-
[Primitive.Bitcast]: {
|
|
1540
|
-
dtype: DType;
|
|
1541
|
-
};
|
|
1542
|
-
[Primitive.Reduce]: {
|
|
1543
|
-
op: AluOp;
|
|
1544
|
-
axis: number[];
|
|
1545
|
-
};
|
|
1546
|
-
[Primitive.Conv]: ConvParams;
|
|
1547
|
-
[Primitive.Pool]: {
|
|
1548
|
-
window: number[];
|
|
1549
|
-
strides: number[];
|
|
1550
|
-
};
|
|
1551
|
-
[Primitive.PoolTranspose]: {
|
|
1552
|
-
inShape: number[];
|
|
1553
|
-
window: number[];
|
|
1554
|
-
strides: number[];
|
|
1555
|
-
};
|
|
1556
|
-
[Primitive.Compare]: {
|
|
1557
|
-
op: CompareOp;
|
|
1558
|
-
};
|
|
1559
|
-
[Primitive.Concatenate]: {
|
|
1560
|
-
axis: number;
|
|
1561
|
-
};
|
|
1562
|
-
[Primitive.Split]: {
|
|
1563
|
-
axis: number;
|
|
1564
|
-
sizes: number[];
|
|
1565
|
-
};
|
|
1566
|
-
[Primitive.RandomBits]: {
|
|
1567
|
-
shape: number[];
|
|
1568
|
-
mode: "xor" | 0 | 1;
|
|
1569
|
-
};
|
|
1570
|
-
[Primitive.Gather]: {
|
|
1571
|
-
axis: number[];
|
|
1572
|
-
outDim: number;
|
|
1573
|
-
};
|
|
1574
|
-
[Primitive.Transpose]: {
|
|
1575
|
-
perm: number[];
|
|
1576
|
-
};
|
|
1577
|
-
[Primitive.Broadcast]: {
|
|
1578
|
-
shape: number[];
|
|
1579
|
-
axis: number[];
|
|
1580
|
-
};
|
|
1581
|
-
[Primitive.Reshape]: {
|
|
1582
|
-
shape: number[];
|
|
1583
|
-
};
|
|
1584
|
-
[Primitive.Flip]: {
|
|
1585
|
-
axis: number[];
|
|
1586
|
-
};
|
|
1587
|
-
[Primitive.Shrink]: {
|
|
1588
|
-
slice: Pair[];
|
|
1589
|
-
};
|
|
1590
|
-
[Primitive.Pad]: {
|
|
1591
|
-
width: Pair[];
|
|
1592
|
-
};
|
|
1593
|
-
[Primitive.Jit]: {
|
|
1594
|
-
name: string;
|
|
1595
|
-
jaxpr: Jaxpr;
|
|
1596
|
-
numConsts: number;
|
|
1597
|
-
};
|
|
1598
|
-
[Primitive.TriangularSolve]: {
|
|
1599
|
-
unitDiagonal: boolean;
|
|
1600
|
-
};
|
|
1601
|
-
}
|
|
1602
|
-
/** Type of parameters taken by each primitive. */
|
|
1603
|
-
type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
|
|
1604
|
-
declare enum CompareOp {
|
|
1605
|
-
Less = "less",
|
|
1606
|
-
Equal = "equal",
|
|
1607
|
-
NotEqual = "not_equal",
|
|
1608
|
-
LessEqual = "less_equal",
|
|
1609
|
-
}
|
|
1610
|
-
/** @inline */
|
|
1611
|
-
type Axis = number | number[] | null;
|
|
1612
|
-
/** @inline */
|
|
1613
|
-
type ReduceOpts = {
|
|
1614
|
-
keepdims?: boolean;
|
|
1615
|
-
};
|
|
1616
|
-
type MainTrace = {
|
|
1617
|
-
level: number;
|
|
1618
|
-
traceType: new (main: MainTrace) => Trace;
|
|
1619
|
-
globalData: any | null;
|
|
1620
|
-
};
|
|
2076
|
+
* Return the Hamming window of size M, a taper with a weighted cosine bell.
|
|
2077
|
+
*
|
|
2078
|
+
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
2079
|
+
*/
|
|
2080
|
+
declare function hamming(M: number): Array;
|
|
1621
2081
|
/**
|
|
1622
|
-
*
|
|
1623
|
-
*
|
|
2082
|
+
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
2083
|
+
*
|
|
2084
|
+
* `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
1624
2085
|
*/
|
|
1625
|
-
|
|
1626
|
-
type TracerValue = Tracer | number | boolean;
|
|
1627
|
-
declare abstract class Trace {
|
|
1628
|
-
readonly main: MainTrace;
|
|
1629
|
-
constructor(main: MainTrace);
|
|
1630
|
-
abstract pure(val: TracerValue): Tracer;
|
|
1631
|
-
abstract lift(val: Tracer): Tracer;
|
|
1632
|
-
abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
|
|
1633
|
-
}
|
|
1634
|
-
/** Internal representation of an array value. */
|
|
1635
|
-
interface AbstractValue {
|
|
1636
|
-
/** Shape of the array. Must be a static tuple of non-negative dimensions. */
|
|
1637
|
-
shape: number[];
|
|
1638
|
-
/** Concrete data type of array elements. */
|
|
1639
|
-
dtype: DType;
|
|
1640
|
-
/**
|
|
1641
|
-
* Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
|
|
1642
|
-
* _weakly typed_ unless a dtype is explicitly specified.
|
|
1643
|
-
*
|
|
1644
|
-
* Weakly typed values will automatically cast to the data type of other
|
|
1645
|
-
* arrays when used as an operand as an expression. This property only affects
|
|
1646
|
-
* how they promote in type casting; their memory layout is still determined
|
|
1647
|
-
* by the actual `dtype` field.
|
|
1648
|
-
*
|
|
1649
|
-
* ```ts
|
|
1650
|
-
* const x = np.array(3); // weakType = true, dtype = float32
|
|
1651
|
-
* const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
|
|
1652
|
-
* const z = x.add(y); // z has dtype int32 because x is weakly typed
|
|
1653
|
-
* ```
|
|
1654
|
-
*
|
|
1655
|
-
* Weak types are present in JIT programs in their spec (e.g., Jaxpr inputs
|
|
1656
|
-
* and outputs can be weakly typed) form. But they're solely a frontend
|
|
1657
|
-
* concept. Backends are not aware of weak types.
|
|
1658
|
-
*/
|
|
1659
|
-
weakType: boolean;
|
|
1660
|
-
}
|
|
2086
|
+
declare function hann(M: number): Array;
|
|
1661
2087
|
/**
|
|
1662
|
-
*
|
|
2088
|
+
* @function
|
|
2089
|
+
* Compute the Heaviside step function. It is defined piecewise:
|
|
2090
|
+
* - `heaviside(x1, x2) = 0` for `x1 < 0`,
|
|
2091
|
+
* - `heaviside(x1, x2) = x2` for `x1 == 0`,
|
|
2092
|
+
* - `heaviside(x1, x2) = 1` for `x1 > 0`.
|
|
2093
|
+
*/
|
|
2094
|
+
declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
2095
|
+
/** Calculate element-wise square of the input array. */
|
|
2096
|
+
declare function square(x: ArrayLike): Array;
|
|
2097
|
+
/** Element-wise tangent function (takes radians). */
|
|
2098
|
+
declare function tan(x: ArrayLike): Array;
|
|
2099
|
+
/**
|
|
2100
|
+
* @function
|
|
2101
|
+
* Return the normalized sinc function.
|
|
1663
2102
|
*
|
|
1664
|
-
*
|
|
1665
|
-
*
|
|
2103
|
+
* The sinc function is defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`.
|
|
2104
|
+
* This is the normalized sinc function commonly used in signal processing.
|
|
2105
|
+
*
|
|
2106
|
+
* **Note:** JVP is not supported at x=0 due to discontinuous derivative. This
|
|
2107
|
+
* requires a custom JVP rule to handle properly (see JAX implementation).
|
|
1666
2108
|
*/
|
|
1667
|
-
|
|
1668
|
-
|
|
1669
|
-
|
|
1670
|
-
readonly _trace: Trace;
|
|
1671
|
-
constructor(trace: Trace);
|
|
1672
|
-
abstract get aval(): AbstractValue;
|
|
1673
|
-
abstract toString(): string;
|
|
1674
|
-
/**
|
|
1675
|
-
* Access an array by reference, incrementing the reference count.
|
|
1676
|
-
*
|
|
1677
|
-
* jax-js handles freeing arrays by using "move" semantics, like in Rust/C++.
|
|
1678
|
-
* Whenever you pass an array into a function, that function should consume
|
|
1679
|
-
* the array, and it will no longer be usable. For example, if you had:
|
|
1680
|
-
*
|
|
1681
|
-
* ```
|
|
1682
|
-
* const x = np.array([1, 2, 3]);
|
|
1683
|
-
* const y = np.add(x, x);
|
|
1684
|
-
* ```
|
|
1685
|
-
*
|
|
1686
|
-
* The second line does not work because the first parameter consumes `x`, and
|
|
1687
|
-
* then the second parameter will already have been freed / disposed.
|
|
1688
|
-
*
|
|
1689
|
-
* To fix this, you can write:
|
|
1690
|
-
*
|
|
1691
|
-
* ```
|
|
1692
|
-
* const y = np.add(x.ref, x);
|
|
1693
|
-
* ```
|
|
1694
|
-
*
|
|
1695
|
-
* Under the hood, every access to `.ref` increments the internal reference
|
|
1696
|
-
* count of the array. The reference count starts at 1. When it hits 0, the
|
|
1697
|
-
* memory behind the array is freed.
|
|
1698
|
-
*/
|
|
1699
|
-
abstract get ref(): this;
|
|
1700
|
-
/**
|
|
1701
|
-
* Manually decrement the reference count of the array.
|
|
1702
|
-
*
|
|
1703
|
-
* Arrays are created with reference count 1. Whenever it is used as argument
|
|
1704
|
-
* to a function or other operation, it is disposed (i.e., reference count
|
|
1705
|
-
* decreases by 1) automatically. Whenever a `.ref` is created, the reference
|
|
1706
|
-
* count increases.
|
|
1707
|
-
*
|
|
1708
|
-
* You generally don't need to call this function directly since arrays are
|
|
1709
|
-
* automatically disposed after being passed into an operation. One common
|
|
1710
|
-
* exception is when writing a function and ignoring one of its arguments. In
|
|
1711
|
-
* that case, by convention you should dispose of that argument manually.
|
|
1712
|
-
*
|
|
1713
|
-
* ```
|
|
1714
|
-
* function myCustomOperation(a: np.Array, b: np.Array) {
|
|
1715
|
-
* b.dispose(); // Needed to satisfy "move" rules.
|
|
1716
|
-
* return a.add(1);
|
|
1717
|
-
* }
|
|
1718
|
-
* ```
|
|
1719
|
-
*/
|
|
1720
|
-
abstract dispose(): void;
|
|
1721
|
-
/** The shape of the array. */
|
|
1722
|
-
get shape(): number[];
|
|
1723
|
-
/** The total number of elements in the array. */
|
|
1724
|
-
get size(): number;
|
|
1725
|
-
/** The dtype of elements stored in the array. */
|
|
1726
|
-
get dtype(): DType;
|
|
1727
|
-
/**
|
|
1728
|
-
* Whether the array is weakly typed.
|
|
1729
|
-
*
|
|
1730
|
-
* Weakly typed arrays will cast to the dtype of the other operand. See
|
|
1731
|
-
* `promoteTypes()` for details.
|
|
1732
|
-
*/
|
|
1733
|
-
get weakType(): boolean;
|
|
1734
|
-
/** The number of dimensions of the array. */
|
|
1735
|
-
get ndim(): number;
|
|
1736
|
-
/** @ignore */
|
|
1737
|
-
fullLower(): Tracer;
|
|
1738
|
-
neg(): this;
|
|
1739
|
-
add(other: this | TracerValue): this;
|
|
1740
|
-
mul(other: this | TracerValue): this;
|
|
1741
|
-
mod(other: this | TracerValue): this;
|
|
1742
|
-
greater(other: this | TracerValue): this;
|
|
1743
|
-
less(other: this | TracerValue): this;
|
|
1744
|
-
equal(other: this | TracerValue): this;
|
|
1745
|
-
notEqual(other: this | TracerValue): this;
|
|
1746
|
-
greaterEqual(other: this | TracerValue): this;
|
|
1747
|
-
lessEqual(other: this | TracerValue): this;
|
|
1748
|
-
/** Sum of the elements of the array over a given axis, or axes. */
|
|
1749
|
-
sum(axis?: Axis, opts?: ReduceOpts): this;
|
|
1750
|
-
/** Product of the array elements over a given axis. */
|
|
1751
|
-
prod(axis?: Axis, opts?: ReduceOpts): this;
|
|
1752
|
-
/** Compute the average of the array elements along the specified axis. */
|
|
1753
|
-
mean(axis?: Axis, opts?: ReduceOpts): this;
|
|
1754
|
-
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
1755
|
-
transpose(perm?: number[]): this;
|
|
1756
|
-
/**
|
|
1757
|
-
* Give a new shape to an array without changing its data.
|
|
1758
|
-
*
|
|
1759
|
-
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
1760
|
-
* length of the array and remaining dimensions.
|
|
1761
|
-
*/
|
|
1762
|
-
reshape(shape: number | number[]): this;
|
|
1763
|
-
/** Copy the array and cast to a specified dtype. */
|
|
1764
|
-
astype(dtype: DType): this;
|
|
1765
|
-
/** Subtract an array from this one. */
|
|
1766
|
-
sub(other: this | TracerValue): this;
|
|
1767
|
-
/** Divide an array by this one. */
|
|
1768
|
-
div(other: this | TracerValue): this;
|
|
1769
|
-
/** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
|
|
1770
|
-
diagonal(offset?: number, axis1?: number, axis2?: number): this;
|
|
1771
|
-
/** Flatten the array without changing its data. */
|
|
1772
|
-
flatten(): this;
|
|
1773
|
-
/** Flatten the array without changing its data. */
|
|
1774
|
-
ravel(): this;
|
|
1775
|
-
/**
|
|
1776
|
-
* Iterate over the first dimension of this array, returning slices.
|
|
1777
|
-
*
|
|
1778
|
-
* This can be used to destructure arrays. For example:
|
|
1779
|
-
*
|
|
1780
|
-
* ```js
|
|
1781
|
-
* let x = np.array([[1, 2], [3, 4]]);
|
|
1782
|
-
* let [a, b] = x;
|
|
1783
|
-
* console.log(a.js()); // [1, 2]
|
|
1784
|
-
* console.log(b.js()); // [3, 4]
|
|
1785
|
-
* ```
|
|
1786
|
-
*/
|
|
1787
|
-
[Symbol.iterator](): IterableIterator<this>;
|
|
1788
|
-
/**
|
|
1789
|
-
* Return a sorted copy of an array in ascending order.
|
|
1790
|
-
*
|
|
1791
|
-
* See `jax.numpy.sort` for full docs.
|
|
1792
|
-
*/
|
|
1793
|
-
sort(axis?: number): this;
|
|
1794
|
-
/**
|
|
1795
|
-
* Return the indices that would sort an array. This may not be a stable
|
|
1796
|
-
* sorting algorithm; it need not preserve order of indices in ties.
|
|
1797
|
-
*
|
|
1798
|
-
* See `jax.numpy.argsort` for full docs.
|
|
1799
|
-
*/
|
|
1800
|
-
argsort(axis?: number): this;
|
|
1801
|
-
/**
|
|
1802
|
-
* Slice an array along one or more axes.
|
|
1803
|
-
*
|
|
1804
|
-
* This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
|
|
1805
|
-
* mimic this in JavaScript, we would write:
|
|
1806
|
-
*
|
|
1807
|
-
* ```js
|
|
1808
|
-
* x.slice([1, 3], 2, [], null);
|
|
1809
|
-
* ```
|
|
1810
|
-
*
|
|
1811
|
-
* The `slice` method accepts a variable number of arguments, each of which
|
|
1812
|
-
* can be a number, an empty array, a single-element array, a two-element
|
|
1813
|
-
* array, or `null`. The arguments are interpreted as follows:
|
|
1814
|
-
*
|
|
1815
|
-
* - A number `n` means to access the `n`-th element along that axis, removing
|
|
1816
|
-
* that axis from the resulting shape.
|
|
1817
|
-
* - An empty array `[]` means to keep that axis as-is, like `:` in Python.
|
|
1818
|
-
* - A single-element array `[i]` means to start slicing from index `i`
|
|
1819
|
-
* (inclusive) to the end of the axis, like `x[i:]`.
|
|
1820
|
-
* - A two-element array `[i, j]` means to slice from index `i` (inclusive)
|
|
1821
|
-
* to index `j` (exclusive), like `x[i:j]`.
|
|
1822
|
-
* - `null` means to add a new axis at that position, like `np.newaxis`.
|
|
1823
|
-
*
|
|
1824
|
-
* Like in Python, negative indices are supported, which count from the end of
|
|
1825
|
-
* the axis. For example, `-1` means the last element.
|
|
1826
|
-
*
|
|
1827
|
-
* Strided slices are not yet implemented, so you cannot write `x[::2]` or
|
|
1828
|
-
* similar.
|
|
1829
|
-
*
|
|
1830
|
-
* Advanced indexing by integer arrays is also supported. This translates to
|
|
1831
|
-
* the "gather" primitive, and it allows you to access specific elements of
|
|
1832
|
-
* the array by integer indices stored in another array.
|
|
1833
|
-
*/
|
|
1834
|
-
slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
|
|
1835
|
-
}
|
|
1836
|
-
declare class ShapedArray implements AbstractValue {
|
|
1837
|
-
readonly shape: number[];
|
|
1838
|
-
readonly dtype: DType;
|
|
1839
|
-
readonly weakType: boolean;
|
|
1840
|
-
constructor(shape: number[], dtype: DType, weakType: boolean);
|
|
1841
|
-
static fromAval(aval: AbstractValue): ShapedArray;
|
|
1842
|
-
get ndim(): number;
|
|
1843
|
-
get size(): number;
|
|
1844
|
-
scalar(): ShapedArray;
|
|
1845
|
-
toString(): string;
|
|
1846
|
-
equals(other: ShapedArray): boolean;
|
|
1847
|
-
}
|
|
1848
|
-
//#endregion
|
|
1849
|
-
//#region src/frontend/array.d.ts
|
|
1850
|
-
type ArrayLike = Array | number | boolean;
|
|
1851
|
-
/** Version of pureArray with fudged types. */
|
|
1852
|
-
|
|
2109
|
+
declare const sinc: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2110
|
+
/** Element-wise inverse cosine function (inverse of cos). */
|
|
2111
|
+
declare function acos(x: ArrayLike): Array;
|
|
1853
2112
|
/**
|
|
1854
|
-
*
|
|
2113
|
+
* @function
|
|
2114
|
+
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
1855
2115
|
*
|
|
1856
|
-
*
|
|
1857
|
-
*
|
|
2116
|
+
* In the original NumPy/JAX implementation, this function is more numerically
|
|
2117
|
+
* stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
|
|
2118
|
+
* stability improvements.
|
|
1858
2119
|
*/
|
|
1859
|
-
declare
|
|
1860
|
-
#private;
|
|
1861
|
-
readonly backend: Backend;
|
|
1862
|
-
readonly source: Kernel | Routine;
|
|
1863
|
-
readonly inputs: Slot[];
|
|
1864
|
-
readonly outputs: Slot[];
|
|
1865
|
-
prepared: Executable | null;
|
|
1866
|
-
submitted: boolean;
|
|
1867
|
-
constructor(backend: Backend, source: Kernel | Routine, inputs: Slot[], outputs: Slot[]);
|
|
1868
|
-
updateRc(delta: number): void;
|
|
1869
|
-
prepare(): Promise<void>;
|
|
1870
|
-
prepareSync(): void;
|
|
1871
|
-
submit(): void;
|
|
1872
|
-
}
|
|
1873
|
-
/** @inline */
|
|
1874
|
-
type DTypeAndDevice = {
|
|
1875
|
-
dtype?: DType;
|
|
1876
|
-
device?: Device;
|
|
1877
|
-
};
|
|
1878
|
-
type ArrayConstructorArgs = {
|
|
1879
|
-
source: AluExp | Slot;
|
|
1880
|
-
st: ShapeTracker;
|
|
1881
|
-
dtype: DType;
|
|
1882
|
-
weakType: boolean;
|
|
1883
|
-
backend: Backend;
|
|
1884
|
-
committed: boolean;
|
|
1885
|
-
pending?: Iterable<PendingExecute>;
|
|
1886
|
-
};
|
|
2120
|
+
declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1887
2121
|
/**
|
|
1888
|
-
*
|
|
2122
|
+
* @function
|
|
2123
|
+
* Element-wise arc tangent of y/x with correct quadrant.
|
|
1889
2124
|
*
|
|
1890
|
-
*
|
|
1891
|
-
*
|
|
2125
|
+
* Returns the angle in radians between the positive x-axis and the point (x, y).
|
|
2126
|
+
* The result is in the range [-π, π].
|
|
1892
2127
|
*
|
|
1893
|
-
*
|
|
1894
|
-
*
|
|
1895
|
-
*
|
|
2128
|
+
* Uses numerically stable formulas:
|
|
2129
|
+
* - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
|
|
2130
|
+
* - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
|
|
2131
|
+
*
|
|
2132
|
+
* The output is ill-defined when both x and y are zero.
|
|
1896
2133
|
*/
|
|
1897
|
-
declare
|
|
1898
|
-
|
|
1899
|
-
|
|
1900
|
-
|
|
1901
|
-
|
|
1902
|
-
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
1903
|
-
* will be freed when the array is disposed.
|
|
1904
|
-
*/
|
|
1905
|
-
constructor(args: ArrayConstructorArgs);
|
|
1906
|
-
/** @ignore */
|
|
1907
|
-
get aval(): ShapedArray;
|
|
1908
|
-
/** Return a simple string representation of the array's dimensions. */
|
|
1909
|
-
toString(): string;
|
|
1910
|
-
get device(): Device;
|
|
1911
|
-
get ref(): this;
|
|
1912
|
-
/** Get the current reference count (for debugging memory management). */
|
|
1913
|
-
get refCount(): number;
|
|
1914
|
-
dispose(): void;
|
|
1915
|
-
/**
|
|
1916
|
-
* Convert this array into a primitive value.
|
|
1917
|
-
*
|
|
1918
|
-
* This only works for scalars (0-dimensional arrays). It lets you get values
|
|
1919
|
-
* "out" of the JAX system. For instance, if `x = np.array(5)`, then you can
|
|
1920
|
-
* evaluate `x + 1` and `x ** 2` to get `6` and `25`, respectively.
|
|
1921
|
-
*
|
|
1922
|
-
* This method is also called for `==` equality.
|
|
1923
|
-
*/
|
|
1924
|
-
[Symbol.toPrimitive](): any;
|
|
1925
|
-
/** Realize the array and return it as data. */
|
|
1926
|
-
data(): Promise<DataArray>;
|
|
1927
|
-
/**
|
|
1928
|
-
* Wait for this array to finish evaluation.
|
|
1929
|
-
*
|
|
1930
|
-
* Operations and data loading in jax-js are lazy, so this function ensures
|
|
1931
|
-
* that pending operations are dispatched and fully executed before it
|
|
1932
|
-
* returns.
|
|
1933
|
-
*
|
|
1934
|
-
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
1935
|
-
* dispatch of operations as well.
|
|
1936
|
-
*
|
|
1937
|
-
* **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
|
|
1938
|
-
* asynchronously for multiple arrays.
|
|
1939
|
-
*/
|
|
1940
|
-
blockUntilReady(): Promise<Array>;
|
|
1941
|
-
/**
|
|
1942
|
-
* Realize the array and return it as data. This is a sync variant and not
|
|
1943
|
-
* recommended for performance reasons, as it will block rendering.
|
|
1944
|
-
*/
|
|
1945
|
-
dataSync(): DataArray;
|
|
1946
|
-
/**
|
|
1947
|
-
* Convert this array into a JavaScript object.
|
|
1948
|
-
*
|
|
1949
|
-
* This is a blocking operation that will compile all of the shaders and wait
|
|
1950
|
-
* for execution to complete, synchronously. No other JavaScript code on the
|
|
1951
|
-
* site will be run during shader execution.
|
|
1952
|
-
*
|
|
1953
|
-
* To avoid blocking, prefer `jsAsync()` when possible.
|
|
1954
|
-
*/
|
|
1955
|
-
js(): any;
|
|
1956
|
-
/** Convert this array into a JavaScript object, asynchronously. */
|
|
1957
|
-
jsAsync(): Promise<any>;
|
|
1958
|
-
/**
|
|
1959
|
-
* Copy an element of an array to a numeric scalar and return it.
|
|
1960
|
-
*
|
|
1961
|
-
* Throws an error if the array does not have a single element. The array must
|
|
1962
|
-
* either be rank-0, or all dimensions of the shape are 1.
|
|
1963
|
-
*/
|
|
1964
|
-
item(): number;
|
|
1965
|
-
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
1966
|
-
static _implRules(): typeof implRules;
|
|
1967
|
-
/** @private */
|
|
1968
|
-
_realizeSource(): number;
|
|
1969
|
-
/** @private Put this array on a new backend, asynchronously. */
|
|
1970
|
-
_put(backend: Backend): Promise<Array>;
|
|
1971
|
-
/** @private Put this array on a new backend, synchronously. */
|
|
1972
|
-
_putSync(backend: Backend): Array;
|
|
1973
|
-
}
|
|
1974
|
-
/** Constructor for creating a new array from data. */
|
|
1975
|
-
declare function array(values: Array | DataArray | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
1976
|
-
shape,
|
|
1977
|
-
dtype,
|
|
1978
|
-
device
|
|
1979
|
-
}?: {
|
|
1980
|
-
shape?: number[];
|
|
1981
|
-
} & DTypeAndDevice): Array;
|
|
1982
|
-
/** If x is a value, lift it into an array, otherwise leave it be. */
|
|
1983
|
-
|
|
1984
|
-
type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
|
|
1985
|
-
declare const implRules: { [P in Primitive]: ImplRule<P> };
|
|
1986
|
-
/** Return a new array of given shape and type, filled with zeros. */
|
|
1987
|
-
declare function zeros(shape: number[], {
|
|
1988
|
-
dtype,
|
|
1989
|
-
device
|
|
1990
|
-
}?: DTypeAndDevice): Array;
|
|
1991
|
-
/** Return a new array of given shape and type, filled with ones. */
|
|
1992
|
-
declare function ones(shape: number[], {
|
|
1993
|
-
dtype,
|
|
1994
|
-
device
|
|
1995
|
-
}?: DTypeAndDevice): Array;
|
|
1996
|
-
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
1997
|
-
declare function full(shape: number[], fillValue: number | boolean | Array, {
|
|
1998
|
-
dtype,
|
|
1999
|
-
device
|
|
2000
|
-
}?: DTypeAndDevice): Array;
|
|
2134
|
+
declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
2135
|
+
/** Element-wise subtraction, with broadcasting. */
|
|
2136
|
+
declare function subtract(x: ArrayLike, y: ArrayLike): Array;
|
|
2137
|
+
/** Calculates the floating-point division of x by y element-wise. */
|
|
2138
|
+
declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
|
|
2001
2139
|
/**
|
|
2002
|
-
*
|
|
2140
|
+
* Return the largest integer smaller or equal to the division of the inputs.
|
|
2003
2141
|
*
|
|
2004
|
-
*
|
|
2005
|
-
*
|
|
2142
|
+
* The result is always rounded towards negative infinity.
|
|
2143
|
+
*
|
|
2144
|
+
* For floating-point inputs, this is equivalent to `floor(x / y)`.
|
|
2145
|
+
* For integer inputs, we use `(x - remainder(x, y)) / y` to handle
|
|
2146
|
+
* negative values correctly (note: may overflow near int32 boundaries).
|
|
2147
|
+
*
|
|
2148
|
+
* @param x - Dividend array.
|
|
2149
|
+
* @param y - Divisor array.
|
|
2150
|
+
* @returns Element-wise floor division of x by y.
|
|
2006
2151
|
*/
|
|
2007
|
-
declare function
|
|
2008
|
-
dtype,
|
|
2009
|
-
device
|
|
2010
|
-
}?: DTypeAndDevice): Array;
|
|
2011
|
-
/** Return the identity matrix, with ones on the main diagonal. */
|
|
2012
|
-
declare function identity$1(n: number, {
|
|
2013
|
-
dtype,
|
|
2014
|
-
device
|
|
2015
|
-
}?: DTypeAndDevice): Array;
|
|
2152
|
+
declare function floorDivide(x: ArrayLike, y: ArrayLike): Array;
|
|
2016
2153
|
/**
|
|
2017
|
-
*
|
|
2154
|
+
* @function
|
|
2155
|
+
* Calculate element-wise floating-point modulo operation.
|
|
2156
|
+
*/
|
|
2157
|
+
declare const fmod: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
2158
|
+
/**
|
|
2159
|
+
* @function
|
|
2160
|
+
* Calculate element-wise remainder of the division (matches sign of y).
|
|
2161
|
+
*/
|
|
2162
|
+
declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
2163
|
+
/**
|
|
2164
|
+
* Return element-wise quotient and remainder simultaneously.
|
|
2018
2165
|
*
|
|
2019
|
-
*
|
|
2020
|
-
* builtin function in Python.
|
|
2166
|
+
* Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
|
|
2021
2167
|
*
|
|
2022
|
-
*
|
|
2023
|
-
*
|
|
2024
|
-
*
|
|
2025
|
-
|
|
2168
|
+
* @param x - Dividend array.
|
|
2169
|
+
* @param y - Divisor array.
|
|
2170
|
+
* @returns Tuple of [quotient, remainder].
|
|
2171
|
+
*/
|
|
2172
|
+
declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
|
|
2173
|
+
/** Round input to the nearest integer towards zero. */
|
|
2174
|
+
declare function trunc(x: ArrayLike): Array;
|
|
2175
|
+
/**
|
|
2176
|
+
* Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
|
|
2026
2177
|
*
|
|
2027
|
-
*
|
|
2028
|
-
* using a non-integer step, so prefer linspace() in those cases.
|
|
2178
|
+
* This is the inverse of `frexp()`.
|
|
2029
2179
|
*/
|
|
2030
|
-
declare function
|
|
2031
|
-
dtype,
|
|
2032
|
-
device
|
|
2033
|
-
}?: DTypeAndDevice): Array;
|
|
2180
|
+
declare function ldexp(x1: ArrayLike, x2: ArrayLike): Array;
|
|
2034
2181
|
/**
|
|
2035
|
-
*
|
|
2182
|
+
* Decompose floating-point values into mantissa and two's exponent.
|
|
2036
2183
|
*
|
|
2037
|
-
*
|
|
2038
|
-
*
|
|
2039
|
-
* `
|
|
2184
|
+
* The mantissa is returned in the range `(-1, 1)` with magnitude `>= 0.5` if
|
|
2185
|
+
* `x != 0`, and the exponent is an integer such that
|
|
2186
|
+
* `x = mantissa * 2**exponent`.
|
|
2040
2187
|
*/
|
|
2041
|
-
declare function
|
|
2042
|
-
|
|
2043
|
-
|
|
2044
|
-
|
|
2045
|
-
|
|
2046
|
-
|
|
2047
|
-
|
|
2048
|
-
|
|
2188
|
+
declare function frexp(x: ArrayLike): [Array, Array];
|
|
2189
|
+
/** Calculate `2**p` for all p in the input array. */
|
|
2190
|
+
declare function exp2(p: ArrayLike): Array;
|
|
2191
|
+
/** Return the base-2 logarithm of x, element-wise. */
|
|
2192
|
+
declare function log2(x: ArrayLike): Array;
|
|
2193
|
+
/** Return the base-10 logarithm of x, element-wise. */
|
|
2194
|
+
declare function log10(x: ArrayLike): Array;
|
|
2195
|
+
/** Calculate `exp(x) - 1` element-wise. */
|
|
2196
|
+
declare function expm1(x: ArrayLike): Array;
|
|
2197
|
+
/** Calculate the natural logarithm of `1 + x` element-wise. */
|
|
2198
|
+
declare function log1p(x: ArrayLike): Array;
|
|
2199
|
+
/** Convert angles from degrees to radians. */
|
|
2200
|
+
declare function deg2rad(x: ArrayLike): Array;
|
|
2201
|
+
/** @function Alias of `jax.numpy.deg2rad()`. */
|
|
2202
|
+
declare const radians: typeof deg2rad;
|
|
2203
|
+
/** Convert angles from radians to degrees. */
|
|
2204
|
+
declare function rad2deg(x: ArrayLike): Array;
|
|
2205
|
+
/** @function Alias of `jax.numpy.rad2deg()`. */
|
|
2206
|
+
declare const degrees: typeof rad2deg;
|
|
2049
2207
|
/**
|
|
2050
|
-
*
|
|
2051
|
-
*
|
|
2052
|
-
* Returns _num_ evenly spaced samples, calculated over the interval
|
|
2053
|
-
* [`start`, `stop`]. The endpoint `stop` is included in the result by default,
|
|
2054
|
-
* but this is controlled by the `endpoint` parameter.
|
|
2055
|
-
*
|
|
2056
|
-
* The default data type is Float32. Use arange() for integer steps.
|
|
2208
|
+
* @function
|
|
2209
|
+
* Computes first array raised to power of second array, element-wise.
|
|
2057
2210
|
*/
|
|
2058
|
-
declare
|
|
2059
|
-
|
|
2060
|
-
|
|
2061
|
-
}?: DTypeAndDevice): Array;
|
|
2211
|
+
declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
2212
|
+
/** @function Calculate the element-wise cube root of the input array. */
|
|
2213
|
+
declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2062
2214
|
/**
|
|
2063
|
-
*
|
|
2064
|
-
*
|
|
2065
|
-
* In linear space, the sequence starts at `base ** start` and ends at
|
|
2066
|
-
* `base ** stop` (see `endpoint` below).
|
|
2215
|
+
* @function
|
|
2216
|
+
* Calculate element-wise hyperbolic sine of input.
|
|
2067
2217
|
*
|
|
2068
|
-
*
|
|
2069
|
-
* @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
|
|
2070
|
-
* @param num - Number of samples to generate. Default is 50.
|
|
2071
|
-
* @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
|
|
2072
|
-
* @param base - The base of the log space. Default is 10.
|
|
2073
|
-
* @returns Array of evenly spaced values on a log scale.
|
|
2218
|
+
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
2074
2219
|
*/
|
|
2075
|
-
declare
|
|
2076
|
-
dtype,
|
|
2077
|
-
device
|
|
2078
|
-
}?: DTypeAndDevice): Array;
|
|
2079
|
-
declare namespace lax_linalg_d_exports {
|
|
2080
|
-
export { cholesky, lu, triangularSolve };
|
|
2081
|
-
}
|
|
2220
|
+
declare const sinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2082
2221
|
/**
|
|
2083
|
-
*
|
|
2084
|
-
*
|
|
2085
|
-
* The Cholesky decomposition of a matrix `A` is:
|
|
2086
|
-
*
|
|
2087
|
-
* - A = L @ L^T (for upper=false, default)
|
|
2088
|
-
* - A = U^T @ U (for upper=true)
|
|
2089
|
-
*
|
|
2090
|
-
* where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
|
|
2091
|
-
* The input matrix must be symmetric and positive-definite.
|
|
2092
|
-
*
|
|
2093
|
-
* @example
|
|
2094
|
-
* ```ts
|
|
2095
|
-
* import { lax, numpy as np } from "@jax-js/jax";
|
|
2096
|
-
*
|
|
2097
|
-
* const x = np.array([[2., 1.], [1., 2.]]);
|
|
2098
|
-
*
|
|
2099
|
-
* // Lower Cholesky factorization (default):
|
|
2100
|
-
* const L = lax.linalg.cholesky(x);
|
|
2101
|
-
* // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
|
|
2222
|
+
* @function
|
|
2223
|
+
* Calculate element-wise hyperbolic cosine of input.
|
|
2102
2224
|
*
|
|
2103
|
-
*
|
|
2104
|
-
* const U = lax.linalg.cholesky(x, { upper: true });
|
|
2105
|
-
* // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
|
|
2106
|
-
* ```
|
|
2225
|
+
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
2107
2226
|
*/
|
|
2108
|
-
declare
|
|
2109
|
-
upper
|
|
2110
|
-
}?: {
|
|
2111
|
-
upper?: boolean;
|
|
2112
|
-
}): Array;
|
|
2227
|
+
declare const cosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2113
2228
|
/**
|
|
2114
|
-
*
|
|
2115
|
-
*
|
|
2116
|
-
* Computes the matrix decomposition: `P @ A = L @ U`, where `P` is a
|
|
2117
|
-
* permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
|
|
2118
|
-
* and `U` is upper-triangular.
|
|
2119
|
-
*
|
|
2120
|
-
* @param x - A batch of matrices with shape `[..., m, n]`.
|
|
2121
|
-
*
|
|
2122
|
-
* @returns A tuple `(lu, pivots, permutation)` where:
|
|
2123
|
-
* - `lu`: combined lower and upper triangular matrices.
|
|
2124
|
-
* - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
|
|
2125
|
-
* - `permutation`: the permutation generated by pivots with shape `[..., m]`.
|
|
2126
|
-
*
|
|
2127
|
-
* @example
|
|
2128
|
-
* ```ts
|
|
2129
|
-
* import { lax, numpy as np } from "@jax-js/jax";
|
|
2229
|
+
* @function
|
|
2230
|
+
* Calculate element-wise hyperbolic tangent of input.
|
|
2130
2231
|
*
|
|
2131
|
-
*
|
|
2132
|
-
* const [lu, pivots, permutation] = lax.linalg.lu(A);
|
|
2133
|
-
* // lu ≈ [[6., 3.], [0.6666667, 1.0]]
|
|
2134
|
-
* // pivots = [1, 1]
|
|
2135
|
-
* // permutation = [1, 0]
|
|
2136
|
-
* ```
|
|
2232
|
+
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
2137
2233
|
*/
|
|
2138
|
-
declare
|
|
2234
|
+
declare const tanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2139
2235
|
/**
|
|
2140
|
-
*
|
|
2141
|
-
*
|
|
2142
|
-
* Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
|
|
2143
|
-
* where `a` is a triangular matrix.
|
|
2144
|
-
*
|
|
2145
|
-
* @example
|
|
2146
|
-
* ```ts
|
|
2147
|
-
* import { lax, numpy as np } from "@jax-js/jax";
|
|
2148
|
-
*
|
|
2149
|
-
* const L = np.array([[2., 0.], [1., 3.]]);
|
|
2150
|
-
* const b = np.array([4., 7.]).reshape([2, 1]);
|
|
2236
|
+
* @function
|
|
2237
|
+
* Calculate element-wise inverse hyperbolic sine of input.
|
|
2151
2238
|
*
|
|
2152
|
-
*
|
|
2153
|
-
* const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
|
|
2154
|
-
* // x = [[2.], [5./3.]]
|
|
2155
|
-
* ```
|
|
2239
|
+
* `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
|
|
2156
2240
|
*/
|
|
2157
|
-
declare
|
|
2158
|
-
leftSide,
|
|
2159
|
-
lower,
|
|
2160
|
-
transposeA,
|
|
2161
|
-
unitDiagonal
|
|
2162
|
-
}?: {
|
|
2163
|
-
leftSide?: boolean;
|
|
2164
|
-
lower?: boolean;
|
|
2165
|
-
transposeA?: boolean;
|
|
2166
|
-
unitDiagonal?: boolean;
|
|
2167
|
-
}): Array;
|
|
2168
|
-
declare namespace lax_d_exports {
|
|
2169
|
-
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
|
|
2170
|
-
}
|
|
2241
|
+
declare const arcsinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2171
2242
|
/**
|
|
2172
|
-
*
|
|
2243
|
+
* @function
|
|
2244
|
+
* Calculate element-wise inverse hyperbolic cosine of input.
|
|
2173
2245
|
*
|
|
2174
|
-
*
|
|
2175
|
-
|
|
2176
|
-
|
|
2246
|
+
* `arccosh(x) = ln(x + sqrt(x^2 - 1))`
|
|
2247
|
+
*/
|
|
2248
|
+
declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2249
|
+
/**
|
|
2250
|
+
* @function
|
|
2251
|
+
* Calculate element-wise inverse hyperbolic tangent of input.
|
|
2177
2252
|
*
|
|
2178
|
-
*
|
|
2179
|
-
* dimensions, followed by `lhs` non-contracting dimensions, followed by
|
|
2180
|
-
* `rhs` non-contracting dimensions.
|
|
2253
|
+
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
2181
2254
|
*/
|
|
2182
|
-
|
|
2183
|
-
lhsContractingDims?: number[];
|
|
2184
|
-
rhsContractingDims?: number[];
|
|
2185
|
-
lhsBatchDims?: number[];
|
|
2186
|
-
rhsBatchDims?: number[];
|
|
2187
|
-
};
|
|
2255
|
+
declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2188
2256
|
/**
|
|
2189
|
-
*
|
|
2257
|
+
* Compute the variance of an array.
|
|
2190
2258
|
*
|
|
2191
|
-
*
|
|
2192
|
-
*
|
|
2259
|
+
* The variance is computed for the flattened array by default, otherwise over
|
|
2260
|
+
* the specified axis.
|
|
2261
|
+
*
|
|
2262
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
2263
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
2193
2264
|
*/
|
|
2194
|
-
declare function
|
|
2195
|
-
|
|
2196
|
-
|
|
2197
|
-
|
|
2198
|
-
rhsBatchDims: rb
|
|
2199
|
-
}?: DotDimensionNumbers): Array;
|
|
2200
|
-
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
|
|
2265
|
+
declare function var_(x: ArrayLike, axis?: Axis, opts?: {
|
|
2266
|
+
mean?: ArrayLike;
|
|
2267
|
+
correction?: number;
|
|
2268
|
+
} & ReduceOpts): Array;
|
|
2201
2269
|
/**
|
|
2202
|
-
*
|
|
2270
|
+
* Compute the standard deviation of an array.
|
|
2203
2271
|
*
|
|
2204
|
-
* The
|
|
2205
|
-
*
|
|
2272
|
+
* The standard deviation is computed for the flattened array by default,
|
|
2273
|
+
* otherwise over the specified axis.
|
|
2206
2274
|
*
|
|
2207
|
-
*
|
|
2275
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
2276
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
2208
2277
|
*/
|
|
2209
|
-
declare function
|
|
2210
|
-
|
|
2211
|
-
|
|
2212
|
-
|
|
2278
|
+
declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
2279
|
+
mean?: ArrayLike;
|
|
2280
|
+
correction?: number;
|
|
2281
|
+
} & ReduceOpts): Array;
|
|
2282
|
+
/** Estimate the sample covariance of a set of variables. */
|
|
2283
|
+
declare function cov(x: ArrayLike, y?: ArrayLike | null, {
|
|
2284
|
+
rowvar
|
|
2213
2285
|
}?: {
|
|
2214
|
-
|
|
2215
|
-
rhsDilation?: number[];
|
|
2216
|
-
featureGroupCount?: number;
|
|
2286
|
+
rowvar?: boolean;
|
|
2217
2287
|
}): Array;
|
|
2218
|
-
/**
|
|
2219
|
-
declare function
|
|
2220
|
-
/**
|
|
2221
|
-
declare function
|
|
2222
|
-
/**
|
|
2223
|
-
declare function
|
|
2224
|
-
/**
|
|
2225
|
-
declare function
|
|
2288
|
+
/** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
|
|
2289
|
+
declare function corrcoef(x: ArrayLike, y?: ArrayLike): Array;
|
|
2290
|
+
/** Test element-wise for positive or negative infinity, return bool array. */
|
|
2291
|
+
declare function isinf(x: ArrayLike): Array;
|
|
2292
|
+
/** Test element-wise for NaN (Not a Number). */
|
|
2293
|
+
declare function isnan(x: ArrayLike): Array;
|
|
2294
|
+
/** Test element-wise for negative infinity, return bool array. */
|
|
2295
|
+
declare function isneginf(x: ArrayLike): Array;
|
|
2296
|
+
/** Test element-wise for positive infinity, return bool array. */
|
|
2297
|
+
declare function isposinf(x: ArrayLike): Array;
|
|
2226
2298
|
/**
|
|
2227
|
-
*
|
|
2299
|
+
* Replace NaN and infinite entries in an array.
|
|
2228
2300
|
*
|
|
2229
|
-
*
|
|
2230
|
-
*
|
|
2301
|
+
* By default, NaNs are replaced with `0.0`, and infinities are are substituted
|
|
2302
|
+
* with the corresponding maximum or minimum finite values.
|
|
2231
2303
|
*/
|
|
2232
|
-
declare function
|
|
2304
|
+
declare function nanToNum(x: ArrayLike, {
|
|
2305
|
+
nan,
|
|
2306
|
+
posinf,
|
|
2307
|
+
neginf
|
|
2308
|
+
}?: {
|
|
2309
|
+
nan?: ArrayLike;
|
|
2310
|
+
posinf?: ArrayLike | null;
|
|
2311
|
+
neginf?: ArrayLike | null;
|
|
2312
|
+
}): Array;
|
|
2233
2313
|
/**
|
|
2234
|
-
*
|
|
2235
|
-
*
|
|
2236
|
-
* Behaves as the identity function but prevents the flow of gradients during
|
|
2237
|
-
* forward or reverse-mode automatic differentiation.
|
|
2314
|
+
* @function
|
|
2315
|
+
* Test element-wise for finite values (not infinity or NaN).
|
|
2238
2316
|
*/
|
|
2239
|
-
declare
|
|
2317
|
+
declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2240
2318
|
declare namespace nn_d_exports {
|
|
2241
|
-
export { celu, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
|
|
2319
|
+
export { celu, dotProductAttention, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
|
|
2242
2320
|
}
|
|
2243
2321
|
/**
|
|
2244
2322
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -2435,6 +2513,56 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
2435
2513
|
* ```
|
|
2436
2514
|
*/
|
|
2437
2515
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
2516
|
+
/**
|
|
2517
|
+
* Scaled dot product attention (SDPA).
|
|
2518
|
+
*
|
|
2519
|
+
* Computes `softmax((Q @ K^T) / sqrt(d) + bias) @ V`, where `Q` is the query,
|
|
2520
|
+
* `K` is the key, `V` is the value, and `d` is the dimensionality of each key
|
|
2521
|
+
* and query vector.
|
|
2522
|
+
*
|
|
2523
|
+
* Multi-query attention is applied when input `key` and `value` tensors have
|
|
2524
|
+
* fewer heads than `query`.
|
|
2525
|
+
*
|
|
2526
|
+
* We use the following uppercase letters to denote array shapes:
|
|
2527
|
+
* - `B` = batch size
|
|
2528
|
+
* - `S` = length of key/value sequences (source)
|
|
2529
|
+
* - `L` = length of query sequences
|
|
2530
|
+
* - `N` = number of attention heads
|
|
2531
|
+
* - `H` = dimensionality of each attention head
|
|
2532
|
+
* - `K` = number of key/value heads (for grouped-query attention)
|
|
2533
|
+
*
|
|
2534
|
+
* The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
|
|
2535
|
+
* case it must be omitted from all inputs.
|
|
2536
|
+
*
|
|
2537
|
+
* @param query - Query array; shape `[B, L, N, H]`
|
|
2538
|
+
* @param key - Key array; shape `[B, S, K, H]`
|
|
2539
|
+
* @param value - Value array; same shape as `key`
|
|
2540
|
+
* @param opts.bias - Optional bias to add to the attention logits; shape
|
|
2541
|
+
* `[B, N, L, S]` or broadcastable to it.
|
|
2542
|
+
* @param opts.mask - Optional mask to apply to the attention logits; should be
|
|
2543
|
+
* a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
|
|
2544
|
+
* the element should take part in attention.
|
|
2545
|
+
* @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
|
|
2546
|
+
* @param opts.isCausal - If true, applies a casual mask.
|
|
2547
|
+
* @param opts.querySeqLengths - Optional sequence lengths for the queries;
|
|
2548
|
+
* shape `(B,)`. Taken from the beginning of the tensor.
|
|
2549
|
+
* @param opts.keyValueSeqLengths - Optional sequence lengths for the keys and
|
|
2550
|
+
* values; shape `(B,)`. Taken from the beginning of the tensor.
|
|
2551
|
+
* @param opts.localWindowSize - If specified, applies a local attention window
|
|
2552
|
+
* of the given size. Can be a single number or a tuple `[left, right]`.
|
|
2553
|
+
*
|
|
2554
|
+
* @returns The result of the attention operation; shape is the same as query
|
|
2555
|
+
* `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
|
|
2556
|
+
*/
|
|
2557
|
+
declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: ArrayLike, opts?: {
|
|
2558
|
+
bias?: ArrayLike;
|
|
2559
|
+
mask?: ArrayLike;
|
|
2560
|
+
scale?: number;
|
|
2561
|
+
isCausal?: boolean;
|
|
2562
|
+
querySeqLengths?: ArrayLike;
|
|
2563
|
+
keyValueSeqLengths?: ArrayLike;
|
|
2564
|
+
localWindowSize?: number | [number, number];
|
|
2565
|
+
}): Array;
|
|
2438
2566
|
declare namespace random_d_exports {
|
|
2439
2567
|
export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
2440
2568
|
}
|
|
@@ -2526,7 +2654,9 @@ declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
|
|
|
2526
2654
|
* @function
|
|
2527
2655
|
* Compute the forward-mode Jacobian-vector product for a function.
|
|
2528
2656
|
*/
|
|
2529
|
-
declare const jvp: <F extends (...args: any[]) => JsTree<Array
|
|
2657
|
+
declare const jvp: <F extends (...args: any[]) => JsTree<Array>, HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
|
|
2658
|
+
hasAux?: HA;
|
|
2659
|
+
}) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, Out, Aux] : never : [ReturnType<F>, ReturnType<F>];
|
|
2530
2660
|
/**
|
|
2531
2661
|
* @function
|
|
2532
2662
|
* Vectorize an operation on a batched axis for one or more inputs.
|
|
@@ -2568,28 +2698,100 @@ declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: Ji
|
|
|
2568
2698
|
* Produce a local linear approximation to a function at a point using jvp() and
|
|
2569
2699
|
* partial evaluation.
|
|
2570
2700
|
*/
|
|
2571
|
-
declare const linearize: <F extends (...args: any[]) => JsTree<Array
|
|
2701
|
+
declare const linearize: <F extends (...args: any[]) => JsTree<Array>, HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
|
|
2702
|
+
hasAux?: HA;
|
|
2703
|
+
}) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, OwnedFunction<(...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => Out>, Aux] : never : [ReturnType<F>, OwnedFunction<(...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>>];
|
|
2572
2704
|
/**
|
|
2573
2705
|
* @function
|
|
2574
2706
|
* Calculate the reverse-mode vector-Jacobian product for a function.
|
|
2707
|
+
*
|
|
2708
|
+
* The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
|
|
2709
|
+
* `f(primals)`, and `vjpFn` is a function that takes in cotangents for each
|
|
2710
|
+
* output and returns the cotangents for each input.
|
|
2711
|
+
*
|
|
2712
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
|
|
2713
|
+
* `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
|
|
2714
|
+
*
|
|
2715
|
+
* @example
|
|
2716
|
+
* ```ts
|
|
2717
|
+
* const [y, vjpFn] = vjp(f, [x]);
|
|
2718
|
+
*
|
|
2719
|
+
* // With hasAux
|
|
2720
|
+
* const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
|
|
2721
|
+
* ```
|
|
2575
2722
|
*/
|
|
2576
|
-
declare const vjp: <F extends (...args: any[]) => JsTree<Array
|
|
2723
|
+
declare const vjp: <F extends (...args: any[]) => JsTree<Array>, const HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
|
|
2724
|
+
hasAux?: HA;
|
|
2725
|
+
}) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, OwnedFunction<(cotangents: MapJsTree<Out, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>>, Aux] : never : [ReturnType<F>, OwnedFunction<(cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>>];
|
|
2726
|
+
/** @inline */
|
|
2727
|
+
type GradOutputType<I, F extends (...args: any[]) => any> = MapJsTree<I extends undefined ? Parameters<F>[0] : I extends number ? Parameters<F>[I] : I extends number[] ? { [K in keyof I]: I[K] extends number ? Parameters<F>[I[K]] : never } : never, ArrayLike, Array>;
|
|
2577
2728
|
/**
|
|
2578
2729
|
* @function
|
|
2579
2730
|
* Compute the gradient of a scalar-valued function `f` with respect to its
|
|
2580
2731
|
* first argument.
|
|
2732
|
+
*
|
|
2733
|
+
* Pass in different `argnums` to differentiate with respect to other
|
|
2734
|
+
* arguments. If a tuple is provided, the return value will be a tuple of
|
|
2735
|
+
* gradients corresponding to each argument index.
|
|
2736
|
+
*
|
|
2737
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return a
|
|
2738
|
+
* `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
|
|
2739
|
+
*
|
|
2740
|
+
* @example
|
|
2741
|
+
* ```ts
|
|
2742
|
+
* const gradient = grad(f)(x);
|
|
2743
|
+
*
|
|
2744
|
+
* // With `argnums`
|
|
2745
|
+
* const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
|
|
2746
|
+
*
|
|
2747
|
+
* // With `hasAux`
|
|
2748
|
+
* const [gradient, aux] = grad(f, { hasAux: true })(x);
|
|
2749
|
+
* ```
|
|
2581
2750
|
*/
|
|
2582
|
-
declare const grad: <F extends (...args: any[]) => JsTree<Array
|
|
2751
|
+
declare const grad: <F extends (...args: any[]) => JsTree<Array>, const I extends undefined | number | number[] = undefined, const HA extends boolean = false>(f: F, opts?: Omit<GradOpts, "argnums" | "hasAux"> & {
|
|
2752
|
+
argnums?: I;
|
|
2753
|
+
hasAux?: HA;
|
|
2754
|
+
}) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => HA extends true ? ReturnType<F> extends [any, infer Aux] ? [GradOutputType<I, F>, Aux] : never : GradOutputType<I, F>;
|
|
2583
2755
|
/**
|
|
2584
2756
|
* @function
|
|
2585
2757
|
* Create a function that evaluates both `f` and the gradient of `f`.
|
|
2758
|
+
*
|
|
2759
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
|
|
2760
|
+
* `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
|
|
2761
|
+
*
|
|
2762
|
+
* @example
|
|
2763
|
+
* ```ts
|
|
2764
|
+
* // Without hasAux
|
|
2765
|
+
* const [value, gradient] = valueAndGrad(f)(x);
|
|
2766
|
+
*
|
|
2767
|
+
* // With hasAux
|
|
2768
|
+
* const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
|
|
2769
|
+
* ```
|
|
2586
2770
|
*/
|
|
2587
|
-
declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array
|
|
2771
|
+
declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>, const I extends undefined | number | number[] = undefined, const HA extends boolean = false>(f: F, opts?: Omit<GradOpts, "argnums"> & {
|
|
2772
|
+
argnums?: I;
|
|
2773
|
+
hasAux?: HA;
|
|
2774
|
+
}) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, GradOutputType<I, F>];
|
|
2588
2775
|
/**
|
|
2589
2776
|
* @function
|
|
2590
2777
|
* Compute the Jacobian evaluated row-by-row by reverse-mode AD.
|
|
2591
2778
|
*/
|
|
2592
2779
|
declare const jacrev: typeof jacfwd;
|
|
2780
|
+
/**
|
|
2781
|
+
* @function
|
|
2782
|
+
* Compute the Hessian matrix of a scalar-valued function.
|
|
2783
|
+
*
|
|
2784
|
+
* The Hessian is the matrix of second-order partial derivatives of a function.
|
|
2785
|
+
* This is implemented as `jacfwd(grad(f))`.
|
|
2786
|
+
*
|
|
2787
|
+
* @example
|
|
2788
|
+
* ```ts
|
|
2789
|
+
* const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // x^3
|
|
2790
|
+
* const H = hessian(f)(np.array([1, 2, 3]));
|
|
2791
|
+
* // H[i,j] = d^2f / dx_i dx_j
|
|
2792
|
+
* ```
|
|
2793
|
+
*/
|
|
2794
|
+
declare const hessian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
2593
2795
|
/**
|
|
2594
2796
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
2595
2797
|
*
|
|
@@ -2612,4 +2814,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
|
2612
2814
|
*/
|
|
2613
2815
|
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
2614
2816
|
//#endregion
|
|
2615
|
-
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
2817
|
+
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|