@jax-js/jax 0.1.1 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.cts CHANGED
@@ -227,6 +227,8 @@ declare class AluExp implements FpHashable {
227
227
  static erf(a: AluExp): AluExp;
228
228
  static erfc(a: AluExp): AluExp;
229
229
  static sqrt(a: AluExp): AluExp;
230
+ static floor(a: AluExp): AluExp;
231
+ static ceil(a: AluExp): AluExp;
230
232
  static reciprocal(a: AluExp): AluExp;
231
233
  static cast(dtype: DType, a: AluExp): AluExp;
232
234
  static bitcast(dtype: DType, a: AluExp): AluExp;
@@ -252,7 +254,7 @@ declare class AluExp implements FpHashable {
252
254
  /** Substitute variables in this AluExp to values. */
253
255
  substitute(variables: Record<string, AluExp>): AluExp;
254
256
  /** Reindex gid values in this expression as needed. */
255
- reindexGids(gidMap: Map<number, number>): AluExp;
257
+ reindexGids(newGids: number[]): AluExp;
256
258
  get min(): number;
257
259
  get max(): number;
258
260
  /** Largest known integer that divides self. */
@@ -317,6 +319,8 @@ declare enum AluOp {
317
319
  Erf = "Erf",
318
320
  Erfc = "Erfc",
319
321
  Sqrt = "Sqrt",
322
+ Floor = "Floor",
323
+ Ceil = "Ceil",
320
324
  Reciprocal = "Reciprocal",
321
325
  Cast = "Cast",
322
326
  Bitcast = "Bitcast",
@@ -459,679 +463,85 @@ declare class Executable<T = any> {
459
463
  constructor(kernel: Kernel, /** Extra data specific to the backend running this kernel. */
460
464
  data: T);
461
465
  }
462
- declare namespace tree_d_exports {
463
- export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
464
- }
465
- declare enum NodeType {
466
- Array = "Array",
467
- Object = "Object",
468
- Leaf = "Leaf",
469
- }
470
- /** Analog to the JAX "pytree" object, but for JavaScript. */
471
- type JsTree<T> = T | JsTree<T>[] | {
472
- [key: string]: JsTree<T>;
473
- };
474
- type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
475
- type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
476
- /** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
477
- type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
478
- /** Represents the structure of a JsTree. */
479
- declare class JsTreeDef {
480
- readonly nodeType: NodeType;
481
- readonly nodeMetadata: any;
482
- readonly childTreedefs: JsTreeDef[];
483
- static leaf: JsTreeDef;
484
- constructor(nodeType: NodeType, nodeMetadata: any,
485
- // Must be comparable with deepEqual.
486
- childTreedefs: JsTreeDef[]);
487
- /** Get the total number of leaves in the tree. */
488
- get size(): number;
489
- /** Returns a string representation of this tree definition. */
490
- toString(root?: boolean): string;
491
- /** Compare this tree definition with another. */
492
- equals(other: JsTreeDef): boolean;
493
- }
494
- /** Flatten a structured object, returning the tree definition. */
495
- declare function flatten<T>(tree: JsTree<T>): [T[], JsTreeDef];
496
- /** Get the leaves of a tree. */
497
- declare function leaves<T>(tree: JsTree<T>): T[];
498
- /** Get the treedef for a tree. */
499
- declare function structure<T>(tree: JsTree<T>): JsTreeDef;
500
- /** Reconstruct a structured object from the flattened representation. */
501
- declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T>;
502
- /** Maps a multi-input function over pytree args to produce a new pytree. */
503
- declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
504
- /** Take a reference of every array in a tree. */
505
- declare function ref<Tree extends JsTree<Array>>(tree: Tree): Tree;
506
- /** Dispose every array in a tree. */
507
- declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefined): void;
508
- //#endregion
509
- //#region src/frontend/convolution.d.ts
510
- /** Definition of a general dilated convolution. Should be valid on creation. */
511
- interface ConvParams {
512
- strides: number[];
513
- padding: [number, number][];
514
- lhsDilation: number[];
515
- rhsDilation: number[];
466
+ declare namespace numpy_d_exports {
467
+ export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, cos, cosh, cumsum, cumulativeSum, deg2rad, degrees, diag, diagonal, divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, positive, pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
516
468
  }
469
+ declare const float32 = DType.Float32;
470
+ declare const int32 = DType.Int32;
471
+ declare const uint32 = DType.Uint32;
472
+ declare const bool = DType.Bool;
473
+ declare const float16 = DType.Float16;
474
+ declare const float64 = DType.Float64;
475
+ /** Euler's constant, `e = 2.7182818284590...` */
476
+ declare const e: number;
477
+ /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
478
+ declare const eulerGamma = 0.5772156649015329;
479
+ /** Positive infinity. */
480
+ declare const inf: number;
481
+ /** Floating-point representation of NaN. */
482
+ declare const nan: number;
483
+ /** This is Pi, `π = 3.14159265358979...` */
484
+ declare const pi: number;
485
+ /** @function Element-wise addition, with broadcasting. */
486
+ declare const add: (x: ArrayLike, y: ArrayLike) => Array;
487
+ /** @function Element-wise multiplication, with broadcasting. */
488
+ declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
489
+ /** @function Numerical negative of every element of an array. */
490
+ declare const negative: (x: ArrayLike) => Array;
491
+ /** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
492
+ declare const reciprocal: (x: ArrayLike) => Array;
493
+ /** @function Round input down to the nearest integer. */
494
+ declare const floor: (x: ArrayLike) => Array;
495
+ /** @function Round input up to the nearest integer. */
496
+ declare const ceil: (x: ArrayLike) => Array;
497
+ /** @function Element-wise sine function (takes radians). */
498
+ declare const sin: (x: ArrayLike) => Array;
499
+ /** @function Element-wise cosine function (takes radians). */
500
+ declare const cos: (x: ArrayLike) => Array;
501
+ /** @function Element-wise inverse sine function (inverse of sin). */
502
+ declare const asin: (x: ArrayLike) => Array;
503
+ /** @function Element-wise inverse tangent function (inverse of tan). */
504
+ declare const atan: (x: ArrayLike) => Array;
505
+ /** @function Calculate the exponential of all elements in the input array. */
506
+ declare const exp: (x: ArrayLike) => Array;
507
+ /** @function Calculate the natural logarithm of all elements in the input array. */
508
+ declare const log: (x: ArrayLike) => Array;
509
+ /** @function Calculate the square root of all elements in the input array. */
510
+ declare const sqrt: (x: ArrayLike) => Array;
511
+ /** @function Return element-wise minimum of the input arrays. */
512
+ declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
513
+ /** @function Return element-wise maximum of the input arrays. */
514
+ declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
515
+ /** @function Compare two arrays element-wise. */
516
+ declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
517
+ /** @function Compare two arrays element-wise. */
518
+ declare const less: (x: ArrayLike, y: ArrayLike) => Array;
519
+ /** @function Compare two arrays element-wise. */
520
+ declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
521
+ /** @function Compare two arrays element-wise. */
522
+ declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
523
+ /** @function Compare two arrays element-wise. */
524
+ declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
525
+ /** @function Compare two arrays element-wise. */
526
+ declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
527
+ /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
528
+ declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
517
529
  /**
518
- * Check that the shapes and parameters passed to convolution are valid.
519
- *
520
- * If the check succeeds, returns the output shape.
530
+ * @function
531
+ * Permute the dimensions of an array. Defaults to reversing the axis order.
521
532
  */
522
-
523
- //#endregion
524
- //#region src/frontend/core.d.ts
533
+ declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
525
534
  /**
526
- * Frontend primitive operations, which are lowered into Kernel objects before
527
- * being dispatched to the backend.
528
- *
529
- * Any operation between arrays can be described in these parts. This is also
530
- * the set of primitives that can occur in Jaxpr programs, and the level at
531
- * which transformations like vmap, grad, and jvp occur. They are loosely based
532
- * on [XLA](https://openxla.org/xla/operation_semantics).
535
+ * @function
536
+ * Give a new shape to an array without changing its data.
533
537
  *
534
- * All n-ary operations support broadcasting, with NumPy semantics.
538
+ * One shape dimension can be -1. In this case, the value is inferred from the
539
+ * length of the array and remaining dimensions.
535
540
  */
536
- declare enum Primitive {
537
- Add = "add",
538
- Mul = "mul",
539
- Idiv = "idiv",
540
- Neg = "neg",
541
- Reciprocal = "reciprocal",
542
- StopGradient = "stop_gradient",
543
- Cast = "cast",
544
- Bitcast = "bitcast",
545
- RandomBits = "random_bits",
546
- Sin = "sin",
547
- Cos = "cos",
548
- Asin = "asin",
549
- Atan = "atan",
550
- Exp = "exp",
551
- Log = "log",
552
- Erf = "erf",
553
- Erfc = "erfc",
554
- Sqrt = "sqrt",
555
- Min = "min",
556
- Max = "max",
557
- Reduce = "reduce",
558
- Dot = "dot",
559
- // sum(x*y, axis=-1)
560
- Conv = "conv",
561
- // see lax.conv_general_dilated
562
- Pool = "pool",
563
- PoolTranspose = "pool_transpose",
564
- Compare = "compare",
565
- Where = "where",
566
- Transpose = "transpose",
567
- Broadcast = "broadcast",
568
- Reshape = "reshape",
569
- Flip = "flip",
570
- Shrink = "shrink",
571
- Pad = "pad",
572
- Gather = "gather",
573
- JitCall = "jit_call",
574
- }
575
- interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
576
- [Primitive.Cast]: {
577
- dtype: DType;
578
- };
579
- [Primitive.Bitcast]: {
580
- dtype: DType;
581
- };
582
- [Primitive.Reduce]: {
583
- op: AluOp;
584
- axis: number[];
585
- };
586
- [Primitive.Conv]: ConvParams;
587
- [Primitive.Pool]: {
588
- window: number[];
589
- strides: number[];
590
- };
591
- [Primitive.PoolTranspose]: {
592
- inShape: number[];
593
- window: number[];
594
- strides: number[];
595
- };
596
- [Primitive.Compare]: {
597
- op: CompareOp;
598
- };
599
- [Primitive.Transpose]: {
600
- perm: number[];
601
- };
602
- [Primitive.Broadcast]: {
603
- shape: number[];
604
- axis: number[];
605
- };
606
- [Primitive.RandomBits]: {
607
- shape: number[];
608
- mode: "xor" | 0 | 1;
609
- };
610
- [Primitive.Reshape]: {
611
- shape: number[];
612
- };
613
- [Primitive.Flip]: {
614
- axis: number[];
615
- };
616
- [Primitive.Shrink]: {
617
- slice: Pair[];
618
- };
619
- [Primitive.Pad]: {
620
- width: Pair[];
621
- };
622
- [Primitive.Gather]: {
623
- axis: number[];
624
- outDim: number;
625
- };
626
- [Primitive.JitCall]: {
627
- name: string;
628
- jaxpr: Jaxpr;
629
- numConsts: number;
630
- };
631
- }
632
- /** Type of parameters taken by each primitive. */
633
- type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
634
- declare enum CompareOp {
635
- Less = "less",
636
- Equal = "equal",
637
- NotEqual = "not_equal",
638
- LessEqual = "less_equal",
639
- }
640
- /** @inline */
641
- type Axis = number | number[] | null;
642
- /** @inline */
643
- type ReduceOpts = {
644
- keepdims?: boolean;
645
- };
646
- type MainTrace = {
647
- level: number;
648
- traceType: new (main: MainTrace) => Trace;
649
- globalData: any | null;
650
- };
541
+ declare const reshape: (x: ArrayLike, shape: number[]) => Array;
651
542
  /**
652
- * Push an interpreter onto the trace stack. Use this like:
653
- * `using main = newMain(...);`
654
- */
655
-
656
- type TracerValue = Tracer | number | boolean;
657
- declare abstract class Trace {
658
- readonly main: MainTrace;
659
- constructor(main: MainTrace);
660
- abstract pure(val: TracerValue): Tracer;
661
- abstract lift(val: Tracer): Tracer;
662
- abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
663
- }
664
- /** Internal representation of an array value. */
665
- interface AbstractValue {
666
- /** Shape of the array. Must be a static tuple of non-negative dimensions. */
667
- shape: number[];
668
- /** Concrete data type of array elements. */
669
- dtype: DType;
670
- /**
671
- * Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
672
- * _weakly typed_ unless a dtype is explicitly specified.
673
- *
674
- * Weakly typed values will automatically cast to the data type of other
675
- * arrays when used as an operand as an expression. This property only affects
676
- * how they promote in type casting; their memory layout is still determined
677
- * by the actual `dtype` field.
678
- *
679
- * ```ts
680
- * const x = np.array(3); // weakType = true, dtype = float32
681
- * const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
682
- * const z = x.add(y); // z has dtype int32 because x is weakly typed
683
- * ```
684
- *
685
- * Weak types are present in JIT programs in their spec (e.g., Jaxpr inputs
686
- * and outputs can be weakly typed) form. But they're solely a frontend
687
- * concept. Backends are not aware of weak types.
688
- */
689
- weakType: boolean;
690
- }
691
- /**
692
- * Broadcast shapes and promote types with casting for two avals.
693
- *
694
- * This implements the weak type behavior described in `promoteTypes()`, but not
695
- * implemented in that function as `weakType` is not passed.
696
- */
697
-
698
- declare abstract class Tracer {
699
- /** @ignore */
700
- readonly _trace: Trace;
701
- constructor(trace: Trace);
702
- abstract get aval(): AbstractValue;
703
- abstract toString(): string;
704
- /**
705
- * Access an array by reference, incrementing the reference count.
706
- *
707
- * jax-js handles freeing arrays by using "move" semantics, like in Rust/C++.
708
- * Whenever you pass an array into a function, that function should consume
709
- * the array, and it will no longer be usable. For example, if you had:
710
- *
711
- * ```
712
- * const x = np.array([1, 2, 3]);
713
- * const y = np.add(x, x);
714
- * ```
715
- *
716
- * The second line does not work because the first parameter consumes `x`, and
717
- * then the second parameter will already have been freed / disposed.
718
- *
719
- * To fix this, you can write:
720
- *
721
- * ```
722
- * const y = np.add(x.ref, x);
723
- * ```
724
- *
725
- * Under the hood, every access to `.ref` increments the internal reference
726
- * count of the array. The reference count starts at 1. When it hits 0, the
727
- * memory behind the array is freed.
728
- */
729
- abstract get ref(): this;
730
- /**
731
- * Manually decrement the reference count of the array.
732
- *
733
- * Arrays are created with reference count 1. Whenever it is used as argument
734
- * to a function or other operation, it is disposed (i.e., reference count
735
- * decreases by 1) automatically. Whenever a `.ref` is created, the reference
736
- * count increases.
737
- *
738
- * You generally don't need to call this function directly since arrays are
739
- * automatically disposed after being passed into an operation. One common
740
- * exception is when writing a function and ignoring one of its arguments. In
741
- * that case, by convention you should dispose of that argument manually.
742
- *
743
- * ```
744
- * function myCustomOperation(a: np.Array, b: np.Array) {
745
- * b.dispose(); // Needed to satisfy "move" rules.
746
- * return a.add(1);
747
- * }
748
- * ```
749
- */
750
- abstract dispose(): void;
751
- /** The shape of the array. */
752
- get shape(): number[];
753
- /** The total number of elements in the array. */
754
- get size(): number;
755
- /** The dtype of elements stored in the array. */
756
- get dtype(): DType;
757
- /**
758
- * Whether the array is weakly typed.
759
- *
760
- * Weakly typed arrays will cast to the dtype of the other operand. See
761
- * `promoteTypes()` for details.
762
- */
763
- get weakType(): boolean;
764
- /** The number of dimensions of the array. */
765
- get ndim(): number;
766
- /** @ignore */
767
- fullLower(): Tracer;
768
- neg(): this;
769
- add(other: this | TracerValue): this;
770
- mul(other: this | TracerValue): this;
771
- greater(other: this | TracerValue): this;
772
- less(other: this | TracerValue): this;
773
- equal(other: this | TracerValue): this;
774
- notEqual(other: this | TracerValue): this;
775
- greaterEqual(other: this | TracerValue): this;
776
- lessEqual(other: this | TracerValue): this;
777
- /** Sum of the elements of the array over a given axis, or axes. */
778
- sum(axis?: Axis, opts?: ReduceOpts): this;
779
- /** Product of the array elements over a given axis. */
780
- prod(axis?: Axis, opts?: ReduceOpts): this;
781
- /** Compute the average of the array elements along the specified axis. */
782
- mean(axis?: Axis, opts?: ReduceOpts): this;
783
- /** Permute the dimensions of an array. Defaults to reversing the axis order. */
784
- transpose(perm?: number[]): this;
785
- /**
786
- * Give a new shape to an array without changing its data.
787
- *
788
- * One shape dimension can be -1. In this case, the value is inferred from the
789
- * length of the array and remaining dimensions.
790
- */
791
- reshape(shape: number | number[]): this;
792
- /** Copy the array and cast to a specified dtype. */
793
- astype(dtype: DType): this;
794
- /** Subtract an array from this one. */
795
- sub(other: this | TracerValue): this;
796
- /** Divide an array by this one. */
797
- div(other: this | TracerValue): this;
798
- /** Return specified diagonals. See `numpy.diagonal` for full docs. */
799
- diagonal(offset?: number, axis1?: number, axis2?: number): this;
800
- /** Flatten the array without changing its data. */
801
- flatten(): this;
802
- /** Flatten the array without changing its data. */
803
- ravel(): this;
804
- /**
805
- * Iterate over the first dimension of this array, returning slices.
806
- *
807
- * This can be used to destructure arrays. For example:
808
- *
809
- * ```js
810
- * let x = np.array([[1, 2], [3, 4]]);
811
- * let [a, b] = x;
812
- * console.log(a.js()); // [1, 2]
813
- * console.log(b.js()); // [3, 4]
814
- * ```
815
- */
816
- [Symbol.iterator](): IterableIterator<this>;
817
- /**
818
- * Slice an array along one or more axes.
819
- *
820
- * This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
821
- * mimic this in JavaScript, we would write:
822
- *
823
- * ```js
824
- * x.slice([1, 3], 2, [], null);
825
- * ```
826
- *
827
- * The `slice` method accepts a variable number of arguments, each of which
828
- * can be a number, an empty array, a single-element array, a two-element
829
- * array, or `null`. The arguments are interpreted as follows:
830
- *
831
- * - A number `n` means to access the `n`-th element along that axis, removing
832
- * that axis from the resulting shape.
833
- * - An empty array `[]` means to keep that axis as-is, like `:` in Python.
834
- * - A single-element array `[i]` means to start slicing from index `i`
835
- * (inclusive) to the end of the axis, like `x[i:]`.
836
- * - A two-element array `[i, j]` means to slice from index `i` (inclusive)
837
- * to index `j` (exclusive), like `x[i:j]`.
838
- * - `null` means to add a new axis at that position, like `np.newaxis`.
839
- *
840
- * Like in Python, negative indices are supported, which count from the end of
841
- * the axis. For example, `-1` means the last element.
842
- *
843
- * Strided slices are not yet implemented, so you cannot write `x[::2]` or
844
- * similar.
845
- *
846
- * Advanced indexing by integer arrays is also supported. This translates to
847
- * the "gather" primitive, and it allows you to access specific elements of
848
- * the array by integer indices stored in another array.
849
- */
850
- slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
851
- }
852
- declare class ShapedArray implements AbstractValue {
853
- readonly shape: number[];
854
- readonly dtype: DType;
855
- readonly weakType: boolean;
856
- constructor(shape: number[], dtype: DType, weakType: boolean);
857
- static fromAval(aval: AbstractValue): ShapedArray;
858
- get ndim(): number;
859
- toString(): string;
860
- equals(other: ShapedArray): boolean;
861
- }
862
- //#endregion
863
- //#region src/frontend/array.d.ts
864
- type ArrayLike = Array | number | boolean;
865
- /** Version of pureArray with fudged types. */
866
-
867
- /**
868
- * An executable operation that will be dispatched to the backend.
869
- *
870
- * This holds a reference to all input buffers used in the operation. After the
871
- * operation is dispatched, the references should be released.
872
- */
873
- declare class PendingExecute {
874
- #private;
875
- readonly backend: Backend;
876
- readonly kernel: Kernel;
877
- readonly inputs: Slot[];
878
- readonly outputs: Slot[];
879
- prepared: Executable | null;
880
- submitted: boolean;
881
- constructor(backend: Backend, kernel: Kernel, inputs: Slot[], outputs: Slot[]);
882
- updateRc(delta: number): void;
883
- prepare(): Promise<void>;
884
- prepareSync(): void;
885
- submit(): void;
886
- }
887
- /** @inline */
888
- type DTypeAndDevice = {
889
- dtype?: DType;
890
- device?: Device;
891
- };
892
- type ArrayConstructorArgs = {
893
- source: AluExp | Slot;
894
- st: ShapeTracker;
895
- dtype: DType;
896
- weakType: boolean;
897
- backend: Backend;
898
- committed: boolean;
899
- pending?: Iterable<PendingExecute>;
900
- };
901
- /**
902
- * A multidimensional numeric array with data stored on CPU or GPU.
903
- *
904
- * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
905
- * `torch.Tensor`.
906
- *
907
- * Not to be confused with the JavaScript "Array" constructor. Avoid importing
908
- * this into your code's namespace if you're already using the JavaScript
909
- * "Array" type by name.
910
- */
911
- declare class Array extends Tracer {
912
- #private;
913
- id: number;
914
- /**
915
- * @ignore
916
- * Constructs an array from source, shape and backend. Note that if the source
917
- * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
918
- * will be freed when the array is disposed.
919
- */
920
- constructor(args: ArrayConstructorArgs);
921
- /** @ignore */
922
- get aval(): ShapedArray;
923
- /** Return a simple string representation of the array's dimensions. */
924
- toString(): string;
925
- get device(): Device;
926
- get ref(): this;
927
- dispose(): void;
928
- /**
929
- * Convert this array into a primitive value.
930
- *
931
- * This only works for scalars (0-dimensional arrays). It lets you get values
932
- * "out" of the JAX system. For instance, if `x = np.array(5)`, then you can
933
- * evaluate `x + 1` and `x ** 2` to get `6` and `25`, respectively.
934
- *
935
- * This method is also called for `==` equality.
936
- */
937
- [Symbol.toPrimitive](): any;
938
- /** Realize the array and return it as data. */
939
- data(): Promise<DataArray>;
940
- /**
941
- * Wait for this array to finish evaluation.
942
- *
943
- * Operations and data loading in jax-js are lazy, so this function ensures
944
- * that pending operations are dispatched and fully executed before it
945
- * returns.
946
- *
947
- * If you are mapping from `data()` or `dataSync()`, it will also trigger
948
- * dispatch of operations as well.
949
- *
950
- * **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
951
- * asynchronously for multiple arrays.
952
- */
953
- blockUntilReady(): Promise<Array>;
954
- /**
955
- * Realize the array and return it as data. This is a sync variant and not
956
- * recommended for performance reasons, as it will block rendering.
957
- */
958
- dataSync(): DataArray;
959
- /**
960
- * Convert this array into a JavaScript object.
961
- *
962
- * This is a blocking operation that will compile all of the shaders and wait
963
- * for execution to complete, synchronously. No other JavaScript code on the
964
- * site will be run during shader execution.
965
- *
966
- * To avoid blocking, prefer `jsAsync()` when possible.
967
- */
968
- js(): any;
969
- /** Convert this array into a JavaScript object, asynchronously. */
970
- jsAsync(): Promise<any>;
971
- /**
972
- * Copy an element of an array to a numeric scalar and return it.
973
- *
974
- * Throws an error if the array does not have a single element. The array must
975
- * either be rank-0, or all dimensions of the shape are 1.
976
- */
977
- item(): number;
978
- /** @private Internal plumbing method for Array / Tracer ops. */
979
- static _implRules(): typeof implRules;
980
- /** @private */
981
- _realizeSource(): number;
982
- /** @private Put this array on a new backend, asynchronously. */
983
- _put(backend: Backend): Promise<Array>;
984
- /** @private Put this array on a new backend, synchronously. */
985
- _putSync(backend: Backend): Array;
986
- }
987
- /** Constructor for creating a new array from data. */
988
- declare function array(values: Array | DataArray | RecursiveArray<number> | RecursiveArray<boolean>, {
989
- shape,
990
- dtype,
991
- device
992
- }?: {
993
- shape?: number[];
994
- } & DTypeAndDevice): Array;
995
- /** If x is a value, lift it into an array, otherwise leave it be. */
996
-
997
- type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
998
- declare const implRules: { [P in Primitive]: ImplRule<P> };
999
- /** Return a new array of given shape and type, filled with zeros. */
1000
- declare function zeros(shape: number[], {
1001
- dtype,
1002
- device
1003
- }?: DTypeAndDevice): Array;
1004
- /** Return a new array of given shape and type, filled with ones. */
1005
- declare function ones(shape: number[], {
1006
- dtype,
1007
- device
1008
- }?: DTypeAndDevice): Array;
1009
- /** Return a new array of given shape and type, filled with `fill_value`. */
1010
- declare function full(shape: number[], fillValue: number | boolean | Array, {
1011
- dtype,
1012
- device
1013
- }?: DTypeAndDevice): Array;
1014
- /**
1015
- * Create an identity matrix.
1016
- *
1017
- * If numCols is not provided, it defaults to numRows, i.e., a square identity
1018
- * matrix with ones on the diagonal.
1019
- */
1020
- declare function eye(numRows: number, numCols?: number, {
1021
- dtype,
1022
- device
1023
- }?: DTypeAndDevice): Array;
1024
- /** Return the identity matrix, with ones on the main diagonal. */
1025
- declare function identity$1(n: number, {
1026
- dtype,
1027
- device
1028
- }?: DTypeAndDevice): Array;
1029
- /**
1030
- * Return evenly spaced values within a given interval.
1031
- *
1032
- * This can be called with a varying number of arguments, just like the range()
1033
- * builtin function in Python.
1034
- *
1035
- * - `arange(stop)` is equivalent to `arange(0, stop, 1)`.
1036
- * - `arange(start, stop)` is equivalent to `arange(start, stop, 1)`.
1037
- * - `arange(start, stop, step)` creates an array starting at `start`, ending
1038
- * before `stop`, with a step size of `step`.
1039
- *
1040
- * Defaults to an integer data type. This can produce unintended results when
1041
- * using a non-integer step, so prefer linspace() in those cases.
1042
- */
1043
- declare function arange(start: number, stop?: number, step?: number, {
1044
- dtype,
1045
- device
1046
- }?: DTypeAndDevice): Array;
1047
- /**
1048
- * Return evenly spaced numbers over a specified interval.
1049
- *
1050
- * Returns _num_ evenly spaced samples, calculated over the interval
1051
- * [`start`, `stop`]. The endpoint `stop` is included in the result by default,
1052
- * but this is controlled by the `endpoint` parameter.
1053
- *
1054
- * The default data type is Float32. Use arange() for integer steps.
1055
- */
1056
- declare function linspace(start: number, stop: number, num?: number, endpoint?: boolean, {
1057
- dtype,
1058
- device
1059
- }?: DTypeAndDevice): Array;
1060
- declare namespace numpy_d_exports {
1061
- export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1062
- }
1063
- declare const float32 = DType.Float32;
1064
- declare const int32 = DType.Int32;
1065
- declare const uint32 = DType.Uint32;
1066
- declare const bool = DType.Bool;
1067
- declare const float16 = DType.Float16;
1068
- declare const float64 = DType.Float64;
1069
- /** Euler's constant, `e = 2.7182818284590...` */
1070
- declare const e: number;
1071
- /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
1072
- declare const eulerGamma = 0.5772156649015329;
1073
- /** Positive infinity. */
1074
- declare const inf: number;
1075
- /** Floating-point representation of NaN. */
1076
- declare const nan: number;
1077
- /** This is Pi, `π = 3.14159265358979...` */
1078
- declare const pi: number;
1079
- /** @function Element-wise addition, with broadcasting. */
1080
- declare const add: (x: ArrayLike, y: ArrayLike) => Array;
1081
- /** @function Element-wise multiplication, with broadcasting. */
1082
- declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
1083
- /** @function Numerical negative of every element of an array. */
1084
- declare const negative: (x: ArrayLike) => Array;
1085
- /** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
1086
- declare const reciprocal: (x: ArrayLike) => Array;
1087
- /** @function Element-wise sine function (takes radians). */
1088
- declare const sin: (x: ArrayLike) => Array;
1089
- /** @function Element-wise cosine function (takes radians). */
1090
- declare const cos: (x: ArrayLike) => Array;
1091
- /** @function Element-wise inverse sine function (inverse of sin). */
1092
- declare const asin: (x: ArrayLike) => Array;
1093
- /** @function Element-wise inverse tangent function (inverse of tan). */
1094
- declare const atan: (x: ArrayLike) => Array;
1095
- /** @function Calculate the exponential of all elements in the input array. */
1096
- declare const exp: (x: ArrayLike) => Array;
1097
- /** @function Calculate the natural logarithm of all elements in the input array. */
1098
- declare const log: (x: ArrayLike) => Array;
1099
- /** @function Calculate the square root of all elements in the input array. */
1100
- declare const sqrt: (x: ArrayLike) => Array;
1101
- /** @function Return element-wise minimum of the input arrays. */
1102
- declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
1103
- /** @function Return element-wise maximum of the input arrays. */
1104
- declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
1105
- /** @function Compare two arrays element-wise. */
1106
- declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
1107
- /** @function Compare two arrays element-wise. */
1108
- declare const less: (x: ArrayLike, y: ArrayLike) => Array;
1109
- /** @function Compare two arrays element-wise. */
1110
- declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
1111
- /** @function Compare two arrays element-wise. */
1112
- declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
1113
- /** @function Compare two arrays element-wise. */
1114
- declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
1115
- /** @function Compare two arrays element-wise. */
1116
- declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
1117
- /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1118
- declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
1119
- /**
1120
- * @function
1121
- * Permute the dimensions of an array. Defaults to reversing the axis order.
1122
- */
1123
- declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
1124
- /**
1125
- * @function
1126
- * Give a new shape to an array without changing its data.
1127
- *
1128
- * One shape dimension can be -1. In this case, the value is inferred from the
1129
- * length of the array and remaining dimensions.
1130
- */
1131
- declare const reshape: (x: ArrayLike, shape: number[]) => Array;
1132
- /**
1133
- * @function
1134
- * Move axes of an array to new positions. Other axes retain original order.
543
+ * @function
544
+ * Move axes of an array to new positions. Other axes retain original order.
1135
545
  */
1136
546
  declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
1137
547
  /**
@@ -1180,6 +590,8 @@ declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1180
590
  declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1181
591
  /** Return the maximum of array elements along a given axis. */
1182
592
  declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
593
+ /** Return the peak-to-peak range along a given axis (`max - min`). */
594
+ declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1183
595
  /** Compute the average of the array elements along the specified axis. */
1184
596
  declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1185
597
  /**
@@ -1196,6 +608,15 @@ declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
1196
608
  * specified axis.
1197
609
  */
1198
610
  declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
611
+ /**
612
+ * Cumulative sum of elements along an axis.
613
+ *
614
+ * Currently this function is `O(n^2)`, we'll improve this later on with a
615
+ * two-phase parallel reduction algorithm.
616
+ */
617
+ declare function cumsum(a: ArrayLike, axis?: number): Array;
618
+ /** @function Alternative name for `jax.numpy.cumsum()`. */
619
+ declare const cumulativeSum: typeof cumsum;
1199
620
  /** Reverse the elements in an array along the given axes. */
1200
621
  declare function flip(x: ArrayLike, axis?: Axis): Array;
1201
622
  /**
@@ -1245,6 +666,8 @@ declare function fliplr(x: ArrayLike): Array;
1245
666
  declare const permuteDims: (x: ArrayLike, perm?: number[]) => Array;
1246
667
  /** Return a 1-D flattened array containing the elements of the input. */
1247
668
  declare function ravel(a: ArrayLike): Array;
669
+ /** Remove one or more length-1 axes from an array. */
670
+ declare function squeeze(a: ArrayLike, axis?: Axis): Array;
1248
671
  /**
1249
672
  * Repeat each element of an array after themselves.
1250
673
  *
@@ -1253,381 +676,1078 @@ declare function ravel(a: ArrayLike): Array;
1253
676
  */
1254
677
  declare function repeat(a: ArrayLike, repeats: number, axis?: number): Array;
1255
678
  /**
1256
- * Construct an array by repeating A the number of times given by reps.
1257
- *
1258
- * If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
1259
- * integers, the resulting array will have a shape of `(reps[0] * d1,
1260
- * reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
679
+ * Construct an array by repeating A the number of times given by reps.
680
+ *
681
+ * If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
682
+ * integers, the resulting array will have a shape of `(reps[0] * d1,
683
+ * reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
684
+ */
685
+ declare function tile(a: ArrayLike, reps: number | number[]): Array;
686
+ /**
687
+ * Broadcast an array to a shape, with NumPy-style broadcasing rules.
688
+ *
689
+ * In other words, this lets you append axes to the left, and/or expand
690
+ * dimensions where the shape is 1.
691
+ */
692
+ declare function broadcastTo(a: ArrayLike, shape: number[]): Array;
693
+ /** Broadcast input shapes to a common output shape. */
694
+ declare function broadcastShapes(...shapes: number[][]): number[];
695
+ /** Broadcast arrays to a common shape. */
696
+ declare function broadcastArrays(...arrays: ArrayLike[]): Array[];
697
+ /**
698
+ * Return specified diagonals.
699
+ *
700
+ * If a is 2D, return the diagonal of the array with the given offset. If a is
701
+ * 3D or higher, compute diagonals along the two given axes (default: 0, 1).
702
+ *
703
+ * This returns a view over the existing array. The shape of the resulting array
704
+ * is determined by removing the two axes along which the diagonal is taken,
705
+ * then appending a new axis to the right with holding the diagonals.
706
+ */
707
+ declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
708
+ /**
709
+ * Extract a diagonal or construct a diagonal array.
710
+ *
711
+ * If v is a 2D array, return the k-th diagonal of v (as a view). If v is a 1D
712
+ * array, return a 2D array with v on the k-th diagonal.
713
+ */
714
+ declare function diag(v: ArrayLike, k?: number): Array;
715
+ /** Calculate the sum of the diagonal of an array along the given axes. */
716
+ declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
717
+ /** Return if two arrays are element-wise equal within a tolerance. */
718
+ declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
719
+ rtol?: number;
720
+ atol?: number;
721
+ }): boolean;
722
+ /** Matrix product of two arrays. */
723
+ declare function matmul(x: ArrayLike, y: ArrayLike): Array;
724
+ /** Dot product of two arrays. */
725
+ declare function dot$1(x: ArrayLike, y: ArrayLike): Array;
726
+ /**
727
+ * Compute the tensor dot product of two N-dimensional arrays.
728
+ *
729
+ * The behavior is determined by `axes`. If an integer `k`, sum over the last
730
+ * `k` axes of x and the first `k` axes of y. If a tuple, then the first array
731
+ * corresponds to the axes of x and the second to the axes of y.
732
+ */
733
+ declare function tensordot(x: ArrayLike, y: ArrayLike, axes?: number | [number[], number[]]): Array;
734
+ /**
735
+ * Einstein summation with string subscripts.
736
+ *
737
+ * @example
738
+ * ```ts
739
+ * import { numpy as np } from "@jax-js/jax";
740
+ *
741
+ * const a = np.ones([2, 3]);
742
+ * const b = np.ones([3]);
743
+ * np.einsum("ij,j", a, b); // Shape [2]
744
+ * ```
745
+ */
746
+ declare function einsum(subscripts: string, ...operands: ArrayLike[]): Array;
747
+ /**
748
+ * Einstein summation alternating between arrays and numeric indices.
749
+ *
750
+ * @example
751
+ * ```ts
752
+ * import { numpy as np } from "@jax-js/jax";
753
+ *
754
+ * const a = np.ones([2, 3]);
755
+ * const b = np.ones([3]);
756
+ * np.einsum(a, [0, 1], b, [1]); // Shape [2]
757
+ * ```
758
+ */
759
+ declare function einsum(...args: (ArrayLike | number[])[]): Array;
760
+ /**
761
+ * Compute the inner product of two arrays.
762
+ *
763
+ * Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
764
+ * contraction on the last axis.
765
+ *
766
+ * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
767
+ */
768
+ declare function inner(x: ArrayLike, y: ArrayLike): Array;
769
+ /**
770
+ * Compute the outer product of two arrays.
771
+ *
772
+ * If the input arrays are not 1D, they will be flattened. Returned array will
773
+ * be of shape `[x.size, y.size]`.
774
+ */
775
+ declare function outer(x: ArrayLike, y: ArrayLike): Array;
776
+ /** Vector dot product of two arrays along a given axis. */
777
+ declare function vecdot(x: ArrayLike, y: ArrayLike, {
778
+ axis
779
+ }?: {
780
+ axis?: number;
781
+ }): Array;
782
+ /**
783
+ * Return the dot product of two vectors.
784
+ *
785
+ * Like vecdot() but flattens the arguments first into vectors.
786
+ */
787
+ declare function vdot(x: ArrayLike, y: ArrayLike): Array;
788
+ /**
789
+ * Return a tuple of coordinate matrices from coordinate vectors.
790
+ *
791
+ * Make N-D coordinate arrays for vectorized evaluations of N-D scalar/vector
792
+ * fields over N-D grids, given one-dimensional coordinate arrays x1, x2,…, xn.
793
+ */
794
+ declare function meshgrid(xs: Array[], {
795
+ indexing
796
+ }?: {
797
+ indexing?: "xy" | "ij";
798
+ }): Array[];
799
+ /**
800
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
801
+ *
802
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
803
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
804
+ * `k>0` is above it.
805
+ */
806
+ declare function tri(n: number, m?: number, k?: number, {
807
+ dtype,
808
+ device
809
+ }?: DTypeAndDevice): Array;
810
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
811
+ declare function tril(a: ArrayLike, k?: number): Array;
812
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
813
+ declare function triu(a: ArrayLike, k?: number): Array;
814
+ /**
815
+ * Clip (limit) the values in an array.
816
+ *
817
+ * Given an interval, values outside the interval are clipped to the interval
818
+ * edges. For example, if an interval of [0, 1] is specified, values smaller
819
+ * than 0 become 0, and values larger than 1 become 1.
820
+ *
821
+ * If either bound is undefined, it is ignored.
822
+ */
823
+ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
824
+ /**
825
+ * Calculate the absolute value element-wise.
826
+ *
827
+ * This is the same function as `jax.numpy.abs()`.
828
+ */
829
+ declare function absolute(x: ArrayLike): Array;
830
+ /** @function Alias of `jax.numpy.absolute()`. */
831
+ declare const abs: typeof absolute;
832
+ /** Return an element-wise indication of sign of the input. */
833
+ declare function sign(x: ArrayLike): Array;
834
+ /** @function Return element-wise positive values of the input (no-op). */
835
+ declare const positive: (x: ArrayLike) => Array;
836
+ /**
837
+ * Return the Hamming window of size M, a taper with a weighted cosine bell.
838
+ *
839
+ * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
840
+ */
841
+ declare function hamming(M: number): Array;
842
+ /**
843
+ * Return the Hann window of size M, a taper with a weighted cosine bell.
844
+ *
845
+ * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
846
+ */
847
+ declare function hann(M: number): Array;
848
+ /**
849
+ * @function
850
+ * Compute the Heaviside step function. It is defined piecewise:
851
+ * - `heaviside(x1, x2) = 0` for `x1 < 0`,
852
+ * - `heaviside(x1, x2) = x2` for `x1 == 0`,
853
+ * - `heaviside(x1, x2) = 1` for `x1 > 0`.
854
+ */
855
+ declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
856
+ /** Calculate element-wise square of the input array. */
857
+ declare function square(x: ArrayLike): Array;
858
+ /** Element-wise tangent function (takes radians). */
859
+ declare function tan(x: ArrayLike): Array;
860
+ /** Element-wise inverse cosine function (inverse of cos). */
861
+ declare function acos(x: ArrayLike): Array;
862
+ /**
863
+ * @function
864
+ * Return element-wise hypotenuse for the given legs of a right triangle.
865
+ *
866
+ * In the original NumPy/JAX implementation, this function is more numerically
867
+ * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
868
+ * stability improvements.
869
+ */
870
+ declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
871
+ /**
872
+ * @function
873
+ * Element-wise arc tangent of y/x with correct quadrant.
874
+ *
875
+ * Returns the angle in radians between the positive x-axis and the point (x, y).
876
+ * The result is in the range [-π, π].
877
+ *
878
+ * Uses numerically stable formulas:
879
+ * - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
880
+ * - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
881
+ *
882
+ * The output is ill-defined when both x and y are zero.
883
+ */
884
+ declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
885
+ /** @function Alias of `jax.numpy.acos()`. */
886
+ declare const arccos: typeof acos;
887
+ /** @function Alias of `jax.numpy.atan()`. */
888
+ declare const arctan: (x: ArrayLike) => Array;
889
+ /** @function Alias of `jax.numpy.atan2()`. */
890
+ declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
891
+ /** Element-wise subtraction, with broadcasting. */
892
+ declare function subtract(x: ArrayLike, y: ArrayLike): Array;
893
+ /** Calculates the floating-point division of x by y element-wise. */
894
+ declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
895
+ /**
896
+ * @function
897
+ * Calculate element-wise floating-point modulo operation.
898
+ */
899
+ declare const fmod: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
900
+ /**
901
+ * @function
902
+ * Calculate element-wise remainder of the division (matches sign of y).
903
+ */
904
+ declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
905
+ /** @function Alias of `jax.numpy.trueDivide()`. */
906
+ declare const divide: typeof trueDivide;
907
+ /** Round input to the nearest integer towards zero. */
908
+ declare function trunc(x: ArrayLike): Array;
909
+ /**
910
+ * Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
911
+ *
912
+ * This is the inverse of `frexp()`.
913
+ */
914
+ declare function ldexp(x1: ArrayLike, x2: ArrayLike): Array;
915
+ /**
916
+ * Decompose floating-point values into mantissa and two's exponent.
917
+ *
918
+ * The mantissa is returned in the range `(-1, 1)` with magnitude `>= 0.5` if
919
+ * `x != 0`, and the exponent is an integer such that
920
+ * `x = mantissa * 2**exponent`.
921
+ */
922
+ declare function frexp(x: ArrayLike): [Array, Array];
923
+ /** Calculate `2**p` for all p in the input array. */
924
+ declare function exp2(p: ArrayLike): Array;
925
+ /** Return the base-2 logarithm of x, element-wise. */
926
+ declare function log2(x: ArrayLike): Array;
927
+ /** Return the base-10 logarithm of x, element-wise. */
928
+ declare function log10(x: ArrayLike): Array;
929
+ /** Calculate `exp(x) - 1` element-wise. */
930
+ declare function expm1(x: ArrayLike): Array;
931
+ /** Calculate the natural logarithm of `1 + x` element-wise. */
932
+ declare function log1p(x: ArrayLike): Array;
933
+ /** Convert angles from degrees to radians. */
934
+ declare function deg2rad(x: ArrayLike): Array;
935
+ /** @function Alias of `jax.numpy.deg2rad()`. */
936
+ declare const radians: typeof deg2rad;
937
+ /** Convert angles from radians to degrees. */
938
+ declare function rad2deg(x: ArrayLike): Array;
939
+ /** @function Alias of `jax.numpy.rad2deg()`. */
940
+ declare const degrees: typeof rad2deg;
941
+ /**
942
+ * @function
943
+ * Computes first array raised to power of second array, element-wise.
1261
944
  */
1262
- declare function tile(a: ArrayLike, reps: number | number[]): Array;
945
+ declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
946
+ /** @function Alias of `jax.numpy.power()`. */
947
+ declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
948
+ /** @function Calculate the element-wise cube root of the input array. */
949
+ declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
1263
950
  /**
1264
- * Broadcast an array to a shape, with NumPy-style broadcasing rules.
951
+ * @function
952
+ * Calculate element-wise hyperbolic sine of input.
1265
953
  *
1266
- * In other words, this lets you append axes to the left, and/or expand
1267
- * dimensions where the shape is 1.
954
+ * `sinh(x) = (exp(x) - exp(-x)) / 2`
1268
955
  */
1269
- declare function broadcastTo(a: ArrayLike, shape: number[]): Array;
1270
- /** Broadcast input shapes to a common output shape. */
1271
- declare function broadcastShapes(...shapes: number[][]): number[];
1272
- /** Broadcast arrays to a common shape. */
1273
- declare function broadcastArrays(...arrays: ArrayLike[]): Array[];
956
+ declare const sinh: OwnedFunction<(x: ArrayLike) => Array>;
1274
957
  /**
1275
- * Return specified diagonals.
1276
- *
1277
- * If a is 2D, return the diagonal of the array with the given offset. If a is
1278
- * 3D or higher, compute diagonals along the two given axes (default: 0, 1).
958
+ * @function
959
+ * Calculate element-wise hyperbolic cosine of input.
1279
960
  *
1280
- * This returns a view over the existing array. The shape of the resulting array
1281
- * is determined by removing the two axes along which the diagonal is taken,
1282
- * then appending a new axis to the right with holding the diagonals.
961
+ * `cosh(x) = (exp(x) + exp(-x)) / 2`
1283
962
  */
1284
- declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
963
+ declare const cosh: OwnedFunction<(x: ArrayLike) => Array>;
1285
964
  /**
1286
- * Extract a diagonal or construct a diagonal array.
965
+ * @function
966
+ * Calculate element-wise hyperbolic tangent of input.
1287
967
  *
1288
- * If v is a 2D array, return the k-th diagonal of v (as a view). If v is a 1D
1289
- * array, return a 2D array with v on the k-th diagonal.
968
+ * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
1290
969
  */
1291
- declare function diag(v: ArrayLike, k?: number): Array;
1292
- /** Return if two arrays are element-wise equal within a tolerance. */
1293
- declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
1294
- rtol?: number;
1295
- atol?: number;
1296
- }): boolean;
1297
- /** Matrix product of two arrays. */
1298
- declare function matmul(x: ArrayLike, y: ArrayLike): Array;
1299
- /** Dot product of two arrays. */
1300
- declare function dot(x: ArrayLike, y: ArrayLike): Array;
970
+ declare const tanh: OwnedFunction<(x: ArrayLike) => Array>;
1301
971
  /**
1302
- * Compute the inner product of two arrays.
1303
- *
1304
- * Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
1305
- * contraction on the last axis.
972
+ * @function
973
+ * Calculate element-wise inverse hyperbolic sine of input.
1306
974
  *
1307
- * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
975
+ * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
1308
976
  */
1309
- declare function inner(x: ArrayLike, y: ArrayLike): Array;
977
+ declare const arcsinh: OwnedFunction<(x: ArrayLike) => Array>;
1310
978
  /**
1311
- * Compute the outer product of two arrays.
979
+ * @function
980
+ * Calculate element-wise inverse hyperbolic cosine of input.
1312
981
  *
1313
- * If the input arrays are not 1D, they will be flattened. Returned array will
1314
- * be of shape `[x.size, y.size]`.
982
+ * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
1315
983
  */
1316
- declare function outer(x: ArrayLike, y: ArrayLike): Array;
1317
- /** Vector dot product of two arrays along a given axis. */
1318
- declare function vecdot(x: ArrayLike, y: ArrayLike, {
1319
- axis
1320
- }?: {
1321
- axis?: number;
1322
- }): Array;
984
+ declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
1323
985
  /**
1324
- * Return the dot product of two vectors.
986
+ * @function
987
+ * Calculate element-wise inverse hyperbolic tangent of input.
1325
988
  *
1326
- * Like vecdot() but flattens the arguments first into vectors.
989
+ * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
1327
990
  */
1328
- declare function vdot(x: ArrayLike, y: ArrayLike): Array;
991
+ declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
992
+ /** @function Alias of `jax.numpy.arcsinh()`. */
993
+ declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
994
+ /** @function Alias of `jax.numpy.arccosh()`. */
995
+ declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
996
+ /** @function Alias of `jax.numpy.arctanh()`. */
997
+ declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
1329
998
  /**
1330
- * Return a tuple of coordinate matrices from coordinate vectors.
999
+ * Compute the variance of an array.
1331
1000
  *
1332
- * Make N-D coordinate arrays for vectorized evaluations of N-D scalar/vector
1333
- * fields over N-D grids, given one-dimensional coordinate arrays x1, x2,…, xn.
1334
- */
1335
- declare function meshgrid(xs: Array[], {
1336
- indexing
1337
- }?: {
1338
- indexing?: "xy" | "ij";
1339
- }): Array[];
1340
- /**
1341
- * Return an array with ones on and below the diagonal and zeros elsewhere.
1001
+ * The variance is computed for the flattened array by default, otherwise over
1002
+ * the specified axis.
1342
1003
  *
1343
- * If `k` is provided, it specifies the sub-diagonal on and below which the
1344
- * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
1345
- * `k>0` is above it.
1004
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
1005
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
1346
1006
  */
1347
- declare function tri(n: number, m?: number, k?: number, {
1348
- dtype,
1349
- device
1350
- }?: DTypeAndDevice): Array;
1351
- /** Return the lower triangle of an array. Must be of dimension >= 2. */
1352
- declare function tril(a: ArrayLike, k?: number): Array;
1353
- /** Return the upper triangle of an array. Must be of dimension >= 2. */
1354
- declare function triu(a: ArrayLike, k?: number): Array;
1007
+ declare function var_(x: ArrayLike, axis?: Axis, opts?: {
1008
+ mean?: ArrayLike;
1009
+ correction?: number;
1010
+ } & ReduceOpts): Array;
1355
1011
  /**
1356
- * Clip (limit) the values in an array.
1012
+ * Compute the standard deviation of an array.
1357
1013
  *
1358
- * Given an interval, values outside the interval are clipped to the interval
1359
- * edges. For example, if an interval of [0, 1] is specified, values smaller
1360
- * than 0 become 0, and values larger than 1 become 1.
1014
+ * The standard deviation is computed for the flattened array by default,
1015
+ * otherwise over the specified axis.
1361
1016
  *
1362
- * If either bound is undefined, it is ignored.
1017
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
1018
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
1363
1019
  */
1364
- declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
1020
+ declare function std(x: ArrayLike, axis?: Axis, opts?: {
1021
+ mean?: ArrayLike;
1022
+ correction?: number;
1023
+ } & ReduceOpts): Array;
1024
+ /** Test element-wise for positive or negative infinity, return bool array. */
1025
+ declare function isinf(x: ArrayLike): Array;
1026
+ /** Test element-wise for NaN (Not a Number). */
1027
+ declare function isnan(x: ArrayLike): Array;
1028
+ /** Test element-wise for negative infinity, return bool array. */
1029
+ declare function isneginf(x: ArrayLike): Array;
1030
+ /** Test element-wise for positive infinity, return bool array. */
1031
+ declare function isposinf(x: ArrayLike): Array;
1365
1032
  /**
1366
- * Calculate the absolute value element-wise.
1033
+ * @function
1034
+ * Test element-wise for finite values (not infinity or NaN).
1035
+ */
1036
+ declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
1037
+ //# sourceMappingURL=numpy.d.ts.map
1038
+ declare namespace tree_d_exports {
1039
+ export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
1040
+ }
1041
+ declare enum NodeType {
1042
+ Array = "Array",
1043
+ Object = "Object",
1044
+ Leaf = "Leaf",
1045
+ }
1046
+ /** Analog to the JAX "pytree" object, but for JavaScript. */
1047
+ type JsTree<T> = T | JsTree<T>[] | {
1048
+ [key: string]: JsTree<T>;
1049
+ };
1050
+ type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
1051
+ type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
1052
+ /** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
1053
+ type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
1054
+ /** Represents the structure of a JsTree. */
1055
+ declare class JsTreeDef {
1056
+ readonly nodeType: NodeType;
1057
+ readonly nodeMetadata: any;
1058
+ readonly childTreedefs: JsTreeDef[];
1059
+ static leaf: JsTreeDef;
1060
+ constructor(nodeType: NodeType, nodeMetadata: any,
1061
+ // Must be comparable with deepEqual.
1062
+ childTreedefs: JsTreeDef[]);
1063
+ /** Get the total number of leaves in the tree. */
1064
+ get size(): number;
1065
+ /** Returns a string representation of this tree definition. */
1066
+ toString(root?: boolean): string;
1067
+ /** Compare this tree definition with another. */
1068
+ equals(other: JsTreeDef): boolean;
1069
+ }
1070
+ /** Flatten a structured object, returning the tree definition. */
1071
+ declare function flatten<T>(tree: JsTree<T>): [T[], JsTreeDef];
1072
+ /** Get the leaves of a tree. */
1073
+ declare function leaves<T>(tree: JsTree<T>): T[];
1074
+ /** Get the treedef for a tree. */
1075
+ declare function structure<T>(tree: JsTree<T>): JsTreeDef;
1076
+ /** Reconstruct a structured object from the flattened representation. */
1077
+ declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T>;
1078
+ /** Maps a multi-input function over pytree args to produce a new pytree. */
1079
+ declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
1080
+ /** Take a reference of every array in a tree. */
1081
+ declare function ref<Tree extends JsTree<Array>>(tree: Tree): Tree;
1082
+ /** Dispose every array in a tree. */
1083
+ declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefined): void;
1084
+ //#endregion
1085
+ //#region src/frontend/convolution.d.ts
1086
+ /** Definition of a general dilated convolution. Should be valid on creation. */
1087
+ interface ConvParams {
1088
+ vmapDims: number;
1089
+ strides: number[];
1090
+ padding: [number, number][];
1091
+ lhsDilation: number[];
1092
+ rhsDilation: number[];
1093
+ }
1094
+ /**
1095
+ * Check that the shapes and parameters passed to convolution are valid.
1096
+ * Expected shapes of the lhs and rhs of the convolution are:
1367
1097
  *
1368
- * This is the same function as `jax.numpy.abs()`.
1098
+ * - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
1099
+ * - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
1100
+ *
1101
+ * If the check succeeds, returns the output shape.
1369
1102
  */
1370
- declare function absolute(x: ArrayLike): Array;
1371
- /** @function Alias of `jax.numpy.absolute()`. */
1372
- declare const abs: typeof absolute;
1373
- /** Return an element-wise indication of sign of the input. */
1374
- declare function sign(x: ArrayLike): Array;
1103
+ //#endregion
1104
+ //#region src/frontend/jaxpr.d.ts
1375
1105
  /**
1376
- * Return the Hamming window of size M, a taper with a weighted cosine bell.
1106
+ * Function callback with an associated dispose() method.
1377
1107
  *
1378
- * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1108
+ * The dispose() method should be called to clean up any tracer resources needed
1109
+ * by the function after the last time it is called.
1379
1110
  */
1380
- declare function hamming(M: number): Array;
1111
+ type OwnedFunction<F extends Function> = F & {
1112
+ dispose: () => void;
1113
+ };
1114
+ /** Variable in a Jaxpr expression. */
1115
+ declare class Var {
1116
+ #private;
1117
+ readonly id: number;
1118
+ readonly aval: ShapedArray;
1119
+ constructor(aval: ShapedArray);
1120
+ toString(): string;
1121
+ }
1122
+ /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
1123
+ declare class Lit {
1124
+ readonly value: number;
1125
+ readonly aval: ShapedArray;
1126
+ get dtype(): DType;
1127
+ constructor(aval: AbstractValue, value: number);
1128
+ }
1129
+ type Atom = Var | Lit;
1130
+ declare class VarPrinter {
1131
+ #private;
1132
+ names: Map<Var, string>;
1133
+ name(v: Var): string;
1134
+ nameType(v: Var): string;
1135
+ }
1136
+ /** A single statement / binding in a Jaxpr, in ANF form. */
1137
+ declare class JaxprEqn {
1138
+ readonly primitive: Primitive;
1139
+ readonly inputs: Atom[];
1140
+ readonly params: Record<string, any>;
1141
+ readonly outBinders: Var[];
1142
+ constructor(primitive: Primitive, inputs: Atom[], params: Record<string, any>, outBinders: Var[]);
1143
+ pprint(usedVars?: Set<Var>, vp?: VarPrinter): PPrint;
1144
+ toString(): string;
1145
+ }
1146
+ /** Typed intermediate representation for traced computations. */
1147
+ declare class Jaxpr implements FpHashable {
1148
+ #private;
1149
+ readonly inBinders: Var[];
1150
+ readonly eqns: JaxprEqn[];
1151
+ readonly outs: Atom[];
1152
+ constructor(inBinders: Var[], eqns: JaxprEqn[], outs: Atom[]);
1153
+ pprint(): PPrint;
1154
+ toString(): string;
1155
+ /**
1156
+ * Gets a hash of this Jaxpr.
1157
+ *
1158
+ * Var identity is not considered in the hash, so two Jaxprs with the same
1159
+ * order of assignments and operators but different variable IDs will resolve
1160
+ * to the same hash (and toString representation).
1161
+ */
1162
+ getHash(): bigint;
1163
+ hash(state: FpHash): void;
1164
+ /**
1165
+ * Produce a simplified Jaxpr with basic optimizations applied.
1166
+ * - Trim away unused variables.
1167
+ * - Fold away *1, *0, or +0 operations against literals.
1168
+ * - Remove no-op movement operations.
1169
+ */
1170
+ simplify(): Jaxpr;
1171
+ /** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
1172
+ flatten(): Jaxpr;
1173
+ }
1174
+ /** @inline */
1175
+ type JitOpts = {
1176
+ staticArgnums?: number[];
1177
+ };
1178
+ //#endregion
1179
+ //#region src/frontend/core.d.ts
1381
1180
  /**
1382
- * Return the Hann window of size M, a taper with a weighted cosine bell.
1181
+ * Frontend primitive operations, which are lowered into Kernel objects before
1182
+ * being dispatched to the backend.
1383
1183
  *
1384
- * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1385
- */
1386
- declare function hann(M: number): Array;
1387
- /**
1388
- * @function
1389
- * Compute the Heaviside step function. It is defined piecewise:
1390
- * - `heaviside(x1, x2) = 0` for `x1 < 0`,
1391
- * - `heaviside(x1, x2) = x2` for `x1 == 0`,
1392
- * - `heaviside(x1, x2) = 1` for `x1 > 0`.
1393
- */
1394
- declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1395
- /** Calculate element-wise square of the input array. */
1396
- declare function square(x: ArrayLike): Array;
1397
- /** Element-wise tangent function (takes radians). */
1398
- declare function tan(x: ArrayLike): Array;
1399
- /** Element-wise inverse cosine function (inverse of cos). */
1400
- declare function acos(x: ArrayLike): Array;
1401
- /**
1402
- * @function
1403
- * Return element-wise hypotenuse for the given legs of a right triangle.
1184
+ * Any operation between arrays can be described in these parts. This is also
1185
+ * the set of primitives that can occur in Jaxpr programs, and the level at
1186
+ * which transformations like vmap, grad, and jvp occur. They are loosely based
1187
+ * on [XLA](https://openxla.org/xla/operation_semantics).
1404
1188
  *
1405
- * In the original NumPy/JAX implementation, this function is more numerically
1406
- * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
1407
- * stability improvements.
1189
+ * All n-ary operations support broadcasting, with NumPy semantics.
1408
1190
  */
1409
- declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1191
+ declare enum Primitive {
1192
+ Add = "add",
1193
+ Mul = "mul",
1194
+ Idiv = "idiv",
1195
+ Mod = "mod",
1196
+ // uses sign of dividend, C-style, matches JS but not Python
1197
+ Neg = "neg",
1198
+ Reciprocal = "reciprocal",
1199
+ Floor = "floor",
1200
+ Ceil = "ceil",
1201
+ StopGradient = "stop_gradient",
1202
+ Cast = "cast",
1203
+ Bitcast = "bitcast",
1204
+ RandomBits = "random_bits",
1205
+ Sin = "sin",
1206
+ Cos = "cos",
1207
+ Asin = "asin",
1208
+ Atan = "atan",
1209
+ Exp = "exp",
1210
+ Log = "log",
1211
+ Erf = "erf",
1212
+ Erfc = "erfc",
1213
+ Sqrt = "sqrt",
1214
+ Min = "min",
1215
+ Max = "max",
1216
+ Reduce = "reduce",
1217
+ Dot = "dot",
1218
+ // sum(x*y, axis=-1)
1219
+ Conv = "conv",
1220
+ // see lax.conv_general_dilated
1221
+ Pool = "pool",
1222
+ PoolTranspose = "pool_transpose",
1223
+ Compare = "compare",
1224
+ Where = "where",
1225
+ Transpose = "transpose",
1226
+ Broadcast = "broadcast",
1227
+ Reshape = "reshape",
1228
+ Flip = "flip",
1229
+ Shrink = "shrink",
1230
+ Pad = "pad",
1231
+ Gather = "gather",
1232
+ JitCall = "jit_call",
1233
+ }
1234
+ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
1235
+ [Primitive.Cast]: {
1236
+ dtype: DType;
1237
+ };
1238
+ [Primitive.Bitcast]: {
1239
+ dtype: DType;
1240
+ };
1241
+ [Primitive.Reduce]: {
1242
+ op: AluOp;
1243
+ axis: number[];
1244
+ };
1245
+ [Primitive.Conv]: ConvParams;
1246
+ [Primitive.Pool]: {
1247
+ window: number[];
1248
+ strides: number[];
1249
+ };
1250
+ [Primitive.PoolTranspose]: {
1251
+ inShape: number[];
1252
+ window: number[];
1253
+ strides: number[];
1254
+ };
1255
+ [Primitive.Compare]: {
1256
+ op: CompareOp;
1257
+ };
1258
+ [Primitive.Transpose]: {
1259
+ perm: number[];
1260
+ };
1261
+ [Primitive.Broadcast]: {
1262
+ shape: number[];
1263
+ axis: number[];
1264
+ };
1265
+ [Primitive.RandomBits]: {
1266
+ shape: number[];
1267
+ mode: "xor" | 0 | 1;
1268
+ };
1269
+ [Primitive.Reshape]: {
1270
+ shape: number[];
1271
+ };
1272
+ [Primitive.Flip]: {
1273
+ axis: number[];
1274
+ };
1275
+ [Primitive.Shrink]: {
1276
+ slice: Pair[];
1277
+ };
1278
+ [Primitive.Pad]: {
1279
+ width: Pair[];
1280
+ };
1281
+ [Primitive.Gather]: {
1282
+ axis: number[];
1283
+ outDim: number;
1284
+ };
1285
+ [Primitive.JitCall]: {
1286
+ name: string;
1287
+ jaxpr: Jaxpr;
1288
+ numConsts: number;
1289
+ };
1290
+ }
1291
+ /** Type of parameters taken by each primitive. */
1292
+ type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
1293
+ declare enum CompareOp {
1294
+ Less = "less",
1295
+ Equal = "equal",
1296
+ NotEqual = "not_equal",
1297
+ LessEqual = "less_equal",
1298
+ }
1299
+ /** @inline */
1300
+ type Axis = number | number[] | null;
1301
+ /** @inline */
1302
+ type ReduceOpts = {
1303
+ keepdims?: boolean;
1304
+ };
1305
+ type MainTrace = {
1306
+ level: number;
1307
+ traceType: new (main: MainTrace) => Trace;
1308
+ globalData: any | null;
1309
+ };
1410
1310
  /**
1411
- * @function
1412
- * Element-wise arc tangent of y/x with correct quadrant.
1413
- *
1414
- * Returns the angle in radians between the positive x-axis and the point (x, y).
1415
- * The result is in the range [-π, π].
1416
- *
1417
- * Uses numerically stable formulas:
1418
- * - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
1419
- * - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
1420
- *
1421
- * The output is ill-defined when both x and y are zero.
1311
+ * Push an interpreter onto the trace stack. Use this like:
1312
+ * `using main = newMain(...);`
1422
1313
  */
1423
- declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
1424
- /** @function Alias of `jax.numpy.acos()`. */
1425
- declare const arccos: typeof acos;
1426
- /** @function Alias of `jax.numpy.atan()`. */
1427
- declare const arctan: (x: ArrayLike) => Array;
1428
- /** @function Alias of `jax.numpy.atan2()`. */
1429
- declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
1430
- /** Element-wise subtraction, with broadcasting. */
1431
- declare function subtract(x: ArrayLike, y: ArrayLike): Array;
1432
- /** Calculates the floating-point division of x by y element-wise. */
1433
- declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
1434
- /** @function Alias of `jax.numpy.trueDivide()`. */
1435
- declare const divide: typeof trueDivide;
1436
- /** Round input to the nearest integer towards zero. */
1437
- declare function trunc(x: ArrayLike): Array;
1438
- /** Calculate `2**p` for all p in the input array. */
1439
- declare function exp2(p: ArrayLike): Array;
1440
- /** Return the base-2 logarithm of x, element-wise. */
1441
- declare function log2(x: ArrayLike): Array;
1442
- /** Return the base-10 logarithm of x, element-wise. */
1443
- declare function log10(x: ArrayLike): Array;
1444
- /** Calculate `exp(x) - 1` element-wise. */
1445
- declare function expm1(x: ArrayLike): Array;
1446
- /** Calculate the natural logarithm of `1 + x` element-wise. */
1447
- declare function log1p(x: ArrayLike): Array;
1448
- /** Convert angles from degrees to radians. */
1449
- declare function deg2rad(x: ArrayLike): Array;
1450
- /** @function Alias of `jax.numpy.deg2rad()`. */
1451
- declare const radians: typeof deg2rad;
1452
- /** Convert angles from radians to degrees. */
1453
- declare function rad2deg(x: ArrayLike): Array;
1454
- /** @function Alias of `jax.numpy.rad2deg()`. */
1455
- declare const degrees: typeof rad2deg;
1314
+
1315
+ type TracerValue = Tracer | number | boolean;
1316
+ declare abstract class Trace {
1317
+ readonly main: MainTrace;
1318
+ constructor(main: MainTrace);
1319
+ abstract pure(val: TracerValue): Tracer;
1320
+ abstract lift(val: Tracer): Tracer;
1321
+ abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
1322
+ }
1323
+ /** Internal representation of an array value. */
1324
+ interface AbstractValue {
1325
+ /** Shape of the array. Must be a static tuple of non-negative dimensions. */
1326
+ shape: number[];
1327
+ /** Concrete data type of array elements. */
1328
+ dtype: DType;
1329
+ /**
1330
+ * Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
1331
+ * _weakly typed_ unless a dtype is explicitly specified.
1332
+ *
1333
+ * Weakly typed values will automatically cast to the data type of other
1334
+ * arrays when used as an operand as an expression. This property only affects
1335
+ * how they promote in type casting; their memory layout is still determined
1336
+ * by the actual `dtype` field.
1337
+ *
1338
+ * ```ts
1339
+ * const x = np.array(3); // weakType = true, dtype = float32
1340
+ * const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
1341
+ * const z = x.add(y); // z has dtype int32 because x is weakly typed
1342
+ * ```
1343
+ *
1344
+ * Weak types are present in JIT programs in their spec (e.g., Jaxpr inputs
1345
+ * and outputs can be weakly typed) form. But they're solely a frontend
1346
+ * concept. Backends are not aware of weak types.
1347
+ */
1348
+ weakType: boolean;
1349
+ }
1456
1350
  /**
1457
- * @function
1458
- * Computes first array raised to power of second array, element-wise.
1351
+ * Broadcast shapes and promote types with casting for two avals.
1352
+ *
1353
+ * This implements the weak type behavior described in `promoteTypes()`, but not
1354
+ * implemented in that function as `weakType` is not passed.
1459
1355
  */
1460
- declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1461
- /** @function Alias of `jax.numpy.power()`. */
1462
- declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1463
- /** @function Calculate the element-wise cube root of the input array. */
1464
- declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
1356
+
1357
+ declare abstract class Tracer {
1358
+ /** @ignore */
1359
+ readonly _trace: Trace;
1360
+ constructor(trace: Trace);
1361
+ abstract get aval(): AbstractValue;
1362
+ abstract toString(): string;
1363
+ /**
1364
+ * Access an array by reference, incrementing the reference count.
1365
+ *
1366
+ * jax-js handles freeing arrays by using "move" semantics, like in Rust/C++.
1367
+ * Whenever you pass an array into a function, that function should consume
1368
+ * the array, and it will no longer be usable. For example, if you had:
1369
+ *
1370
+ * ```
1371
+ * const x = np.array([1, 2, 3]);
1372
+ * const y = np.add(x, x);
1373
+ * ```
1374
+ *
1375
+ * The second line does not work because the first parameter consumes `x`, and
1376
+ * then the second parameter will already have been freed / disposed.
1377
+ *
1378
+ * To fix this, you can write:
1379
+ *
1380
+ * ```
1381
+ * const y = np.add(x.ref, x);
1382
+ * ```
1383
+ *
1384
+ * Under the hood, every access to `.ref` increments the internal reference
1385
+ * count of the array. The reference count starts at 1. When it hits 0, the
1386
+ * memory behind the array is freed.
1387
+ */
1388
+ abstract get ref(): this;
1389
+ /**
1390
+ * Manually decrement the reference count of the array.
1391
+ *
1392
+ * Arrays are created with reference count 1. Whenever it is used as argument
1393
+ * to a function or other operation, it is disposed (i.e., reference count
1394
+ * decreases by 1) automatically. Whenever a `.ref` is created, the reference
1395
+ * count increases.
1396
+ *
1397
+ * You generally don't need to call this function directly since arrays are
1398
+ * automatically disposed after being passed into an operation. One common
1399
+ * exception is when writing a function and ignoring one of its arguments. In
1400
+ * that case, by convention you should dispose of that argument manually.
1401
+ *
1402
+ * ```
1403
+ * function myCustomOperation(a: np.Array, b: np.Array) {
1404
+ * b.dispose(); // Needed to satisfy "move" rules.
1405
+ * return a.add(1);
1406
+ * }
1407
+ * ```
1408
+ */
1409
+ abstract dispose(): void;
1410
+ /** The shape of the array. */
1411
+ get shape(): number[];
1412
+ /** The total number of elements in the array. */
1413
+ get size(): number;
1414
+ /** The dtype of elements stored in the array. */
1415
+ get dtype(): DType;
1416
+ /**
1417
+ * Whether the array is weakly typed.
1418
+ *
1419
+ * Weakly typed arrays will cast to the dtype of the other operand. See
1420
+ * `promoteTypes()` for details.
1421
+ */
1422
+ get weakType(): boolean;
1423
+ /** The number of dimensions of the array. */
1424
+ get ndim(): number;
1425
+ /** @ignore */
1426
+ fullLower(): Tracer;
1427
+ neg(): this;
1428
+ add(other: this | TracerValue): this;
1429
+ mul(other: this | TracerValue): this;
1430
+ greater(other: this | TracerValue): this;
1431
+ less(other: this | TracerValue): this;
1432
+ equal(other: this | TracerValue): this;
1433
+ notEqual(other: this | TracerValue): this;
1434
+ greaterEqual(other: this | TracerValue): this;
1435
+ lessEqual(other: this | TracerValue): this;
1436
+ /** Sum of the elements of the array over a given axis, or axes. */
1437
+ sum(axis?: Axis, opts?: ReduceOpts): this;
1438
+ /** Product of the array elements over a given axis. */
1439
+ prod(axis?: Axis, opts?: ReduceOpts): this;
1440
+ /** Compute the average of the array elements along the specified axis. */
1441
+ mean(axis?: Axis, opts?: ReduceOpts): this;
1442
+ /** Permute the dimensions of an array. Defaults to reversing the axis order. */
1443
+ transpose(perm?: number[]): this;
1444
+ /**
1445
+ * Give a new shape to an array without changing its data.
1446
+ *
1447
+ * One shape dimension can be -1. In this case, the value is inferred from the
1448
+ * length of the array and remaining dimensions.
1449
+ */
1450
+ reshape(shape: number | number[]): this;
1451
+ /** Copy the array and cast to a specified dtype. */
1452
+ astype(dtype: DType): this;
1453
+ /** Subtract an array from this one. */
1454
+ sub(other: this | TracerValue): this;
1455
+ /** Divide an array by this one. */
1456
+ div(other: this | TracerValue): this;
1457
+ /** Return specified diagonals. See `numpy.diagonal` for full docs. */
1458
+ diagonal(offset?: number, axis1?: number, axis2?: number): this;
1459
+ /** Flatten the array without changing its data. */
1460
+ flatten(): this;
1461
+ /** Flatten the array without changing its data. */
1462
+ ravel(): this;
1463
+ /**
1464
+ * Iterate over the first dimension of this array, returning slices.
1465
+ *
1466
+ * This can be used to destructure arrays. For example:
1467
+ *
1468
+ * ```js
1469
+ * let x = np.array([[1, 2], [3, 4]]);
1470
+ * let [a, b] = x;
1471
+ * console.log(a.js()); // [1, 2]
1472
+ * console.log(b.js()); // [3, 4]
1473
+ * ```
1474
+ */
1475
+ [Symbol.iterator](): IterableIterator<this>;
1476
+ /**
1477
+ * Slice an array along one or more axes.
1478
+ *
1479
+ * This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
1480
+ * mimic this in JavaScript, we would write:
1481
+ *
1482
+ * ```js
1483
+ * x.slice([1, 3], 2, [], null);
1484
+ * ```
1485
+ *
1486
+ * The `slice` method accepts a variable number of arguments, each of which
1487
+ * can be a number, an empty array, a single-element array, a two-element
1488
+ * array, or `null`. The arguments are interpreted as follows:
1489
+ *
1490
+ * - A number `n` means to access the `n`-th element along that axis, removing
1491
+ * that axis from the resulting shape.
1492
+ * - An empty array `[]` means to keep that axis as-is, like `:` in Python.
1493
+ * - A single-element array `[i]` means to start slicing from index `i`
1494
+ * (inclusive) to the end of the axis, like `x[i:]`.
1495
+ * - A two-element array `[i, j]` means to slice from index `i` (inclusive)
1496
+ * to index `j` (exclusive), like `x[i:j]`.
1497
+ * - `null` means to add a new axis at that position, like `np.newaxis`.
1498
+ *
1499
+ * Like in Python, negative indices are supported, which count from the end of
1500
+ * the axis. For example, `-1` means the last element.
1501
+ *
1502
+ * Strided slices are not yet implemented, so you cannot write `x[::2]` or
1503
+ * similar.
1504
+ *
1505
+ * Advanced indexing by integer arrays is also supported. This translates to
1506
+ * the "gather" primitive, and it allows you to access specific elements of
1507
+ * the array by integer indices stored in another array.
1508
+ */
1509
+ slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
1510
+ }
1511
+ declare class ShapedArray implements AbstractValue {
1512
+ readonly shape: number[];
1513
+ readonly dtype: DType;
1514
+ readonly weakType: boolean;
1515
+ constructor(shape: number[], dtype: DType, weakType: boolean);
1516
+ static fromAval(aval: AbstractValue): ShapedArray;
1517
+ get ndim(): number;
1518
+ toString(): string;
1519
+ equals(other: ShapedArray): boolean;
1520
+ }
1521
+ //#endregion
1522
+ //#region src/frontend/array.d.ts
1523
+ type ArrayLike = Array | number | boolean;
1524
+ /** Version of pureArray with fudged types. */
1525
+
1465
1526
  /**
1466
- * @function
1467
- * Calculate element-wise hyperbolic sine of input.
1527
+ * An executable operation that will be dispatched to the backend.
1468
1528
  *
1469
- * `sinh(x) = (exp(x) - exp(-x)) / 2`
1529
+ * This holds a reference to all input buffers used in the operation. After the
1530
+ * operation is dispatched, the references should be released.
1470
1531
  */
1471
- declare const sinh: OwnedFunction<(x: ArrayLike) => Array>;
1532
+ declare class PendingExecute {
1533
+ #private;
1534
+ readonly backend: Backend;
1535
+ readonly kernel: Kernel;
1536
+ readonly inputs: Slot[];
1537
+ readonly outputs: Slot[];
1538
+ prepared: Executable | null;
1539
+ submitted: boolean;
1540
+ constructor(backend: Backend, kernel: Kernel, inputs: Slot[], outputs: Slot[]);
1541
+ updateRc(delta: number): void;
1542
+ prepare(): Promise<void>;
1543
+ prepareSync(): void;
1544
+ submit(): void;
1545
+ }
1546
+ /** @inline */
1547
+ type DTypeAndDevice = {
1548
+ dtype?: DType;
1549
+ device?: Device;
1550
+ };
1551
+ type ArrayConstructorArgs = {
1552
+ source: AluExp | Slot;
1553
+ st: ShapeTracker;
1554
+ dtype: DType;
1555
+ weakType: boolean;
1556
+ backend: Backend;
1557
+ committed: boolean;
1558
+ pending?: Iterable<PendingExecute>;
1559
+ };
1472
1560
  /**
1473
- * @function
1474
- * Calculate element-wise hyperbolic cosine of input.
1561
+ * A multidimensional numeric array with data stored on CPU or GPU.
1475
1562
  *
1476
- * `cosh(x) = (exp(x) + exp(-x)) / 2`
1563
+ * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
1564
+ * `torch.Tensor`.
1565
+ *
1566
+ * Not to be confused with the JavaScript "Array" constructor. Avoid importing
1567
+ * this into your code's namespace if you're already using the JavaScript
1568
+ * "Array" type by name.
1477
1569
  */
1478
- declare const cosh: OwnedFunction<(x: ArrayLike) => Array>;
1570
+ declare class Array extends Tracer {
1571
+ #private;
1572
+ id: number;
1573
+ /**
1574
+ * @ignore
1575
+ * Constructs an array from source, shape and backend. Note that if the source
1576
+ * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
1577
+ * will be freed when the array is disposed.
1578
+ */
1579
+ constructor(args: ArrayConstructorArgs);
1580
+ /** @ignore */
1581
+ get aval(): ShapedArray;
1582
+ /** Return a simple string representation of the array's dimensions. */
1583
+ toString(): string;
1584
+ get device(): Device;
1585
+ get ref(): this;
1586
+ dispose(): void;
1587
+ /**
1588
+ * Convert this array into a primitive value.
1589
+ *
1590
+ * This only works for scalars (0-dimensional arrays). It lets you get values
1591
+ * "out" of the JAX system. For instance, if `x = np.array(5)`, then you can
1592
+ * evaluate `x + 1` and `x ** 2` to get `6` and `25`, respectively.
1593
+ *
1594
+ * This method is also called for `==` equality.
1595
+ */
1596
+ [Symbol.toPrimitive](): any;
1597
+ /** Realize the array and return it as data. */
1598
+ data(): Promise<DataArray>;
1599
+ /**
1600
+ * Wait for this array to finish evaluation.
1601
+ *
1602
+ * Operations and data loading in jax-js are lazy, so this function ensures
1603
+ * that pending operations are dispatched and fully executed before it
1604
+ * returns.
1605
+ *
1606
+ * If you are mapping from `data()` or `dataSync()`, it will also trigger
1607
+ * dispatch of operations as well.
1608
+ *
1609
+ * **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
1610
+ * asynchronously for multiple arrays.
1611
+ */
1612
+ blockUntilReady(): Promise<Array>;
1613
+ /**
1614
+ * Realize the array and return it as data. This is a sync variant and not
1615
+ * recommended for performance reasons, as it will block rendering.
1616
+ */
1617
+ dataSync(): DataArray;
1618
+ /**
1619
+ * Convert this array into a JavaScript object.
1620
+ *
1621
+ * This is a blocking operation that will compile all of the shaders and wait
1622
+ * for execution to complete, synchronously. No other JavaScript code on the
1623
+ * site will be run during shader execution.
1624
+ *
1625
+ * To avoid blocking, prefer `jsAsync()` when possible.
1626
+ */
1627
+ js(): any;
1628
+ /** Convert this array into a JavaScript object, asynchronously. */
1629
+ jsAsync(): Promise<any>;
1630
+ /**
1631
+ * Copy an element of an array to a numeric scalar and return it.
1632
+ *
1633
+ * Throws an error if the array does not have a single element. The array must
1634
+ * either be rank-0, or all dimensions of the shape are 1.
1635
+ */
1636
+ item(): number;
1637
+ /** @private Internal plumbing method for Array / Tracer ops. */
1638
+ static _implRules(): typeof implRules;
1639
+ /** @private */
1640
+ _realizeSource(): number;
1641
+ /** @private Put this array on a new backend, asynchronously. */
1642
+ _put(backend: Backend): Promise<Array>;
1643
+ /** @private Put this array on a new backend, synchronously. */
1644
+ _putSync(backend: Backend): Array;
1645
+ }
1646
+ /** Constructor for creating a new array from data. */
1647
+ declare function array(values: Array | DataArray | RecursiveArray<number> | RecursiveArray<boolean>, {
1648
+ shape,
1649
+ dtype,
1650
+ device
1651
+ }?: {
1652
+ shape?: number[];
1653
+ } & DTypeAndDevice): Array;
1654
+ /** If x is a value, lift it into an array, otherwise leave it be. */
1655
+
1656
+ type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
1657
+ declare const implRules: { [P in Primitive]: ImplRule<P> };
1658
+ /** Return a new array of given shape and type, filled with zeros. */
1659
+ declare function zeros(shape: number[], {
1660
+ dtype,
1661
+ device
1662
+ }?: DTypeAndDevice): Array;
1663
+ /** Return a new array of given shape and type, filled with ones. */
1664
+ declare function ones(shape: number[], {
1665
+ dtype,
1666
+ device
1667
+ }?: DTypeAndDevice): Array;
1668
+ /** Return a new array of given shape and type, filled with `fill_value`. */
1669
+ declare function full(shape: number[], fillValue: number | boolean | Array, {
1670
+ dtype,
1671
+ device
1672
+ }?: DTypeAndDevice): Array;
1479
1673
  /**
1480
- * @function
1481
- * Calculate element-wise hyperbolic tangent of input.
1674
+ * Create an identity matrix.
1482
1675
  *
1483
- * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
1676
+ * If numCols is not provided, it defaults to numRows, i.e., a square identity
1677
+ * matrix with ones on the diagonal.
1484
1678
  */
1485
- declare const tanh: OwnedFunction<(x: ArrayLike) => Array>;
1679
+ declare function eye(numRows: number, numCols?: number, {
1680
+ dtype,
1681
+ device
1682
+ }?: DTypeAndDevice): Array;
1683
+ /** Return the identity matrix, with ones on the main diagonal. */
1684
+ declare function identity$1(n: number, {
1685
+ dtype,
1686
+ device
1687
+ }?: DTypeAndDevice): Array;
1486
1688
  /**
1487
- * @function
1488
- * Calculate element-wise inverse hyperbolic sine of input.
1689
+ * Return evenly spaced values within a given interval.
1489
1690
  *
1490
- * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
1491
- */
1492
- declare const arcsinh: OwnedFunction<(x: ArrayLike) => Array>;
1493
- /**
1494
- * @function
1495
- * Calculate element-wise inverse hyperbolic cosine of input.
1691
+ * This can be called with a varying number of arguments, just like the range()
1692
+ * builtin function in Python.
1496
1693
  *
1497
- * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
1498
- */
1499
- declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
1500
- /**
1501
- * @function
1502
- * Calculate element-wise inverse hyperbolic tangent of input.
1694
+ * - `arange(stop)` is equivalent to `arange(0, stop, 1)`.
1695
+ * - `arange(start, stop)` is equivalent to `arange(start, stop, 1)`.
1696
+ * - `arange(start, stop, step)` creates an array starting at `start`, ending
1697
+ * before `stop`, with a step size of `step`.
1503
1698
  *
1504
- * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
1699
+ * Defaults to an integer data type. This can produce unintended results when
1700
+ * using a non-integer step, so prefer linspace() in those cases.
1505
1701
  */
1506
- declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
1507
- /** @function Alias of `jax.numpy.arcsinh()`. */
1508
- declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
1509
- /** @function Alias of `jax.numpy.arccosh()`. */
1510
- declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
1511
- /** @function Alias of `jax.numpy.arctanh()`. */
1512
- declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
1702
+ declare function arange(start: number, stop?: number, step?: number, {
1703
+ dtype,
1704
+ device
1705
+ }?: DTypeAndDevice): Array;
1513
1706
  /**
1514
- * Compute the variance of an array.
1707
+ * Return evenly spaced numbers over a specified interval.
1515
1708
  *
1516
- * The variance is computed for the flattened array by default, otherwise over
1517
- * the specified axis.
1709
+ * Returns _num_ evenly spaced samples, calculated over the interval
1710
+ * [`start`, `stop`]. The endpoint `stop` is included in the result by default,
1711
+ * but this is controlled by the `endpoint` parameter.
1518
1712
  *
1519
- * If `correction` is provided, the divisor in calculation is `N - correction`,
1520
- * where `N` represents the number of elements (e.g., for Bessel's correction).
1713
+ * The default data type is Float32. Use arange() for integer steps.
1521
1714
  */
1522
- declare function var_(x: ArrayLike, axis?: Axis, opts?: {
1523
- mean?: ArrayLike;
1524
- correction?: number;
1525
- } & ReduceOpts): Array;
1715
+ declare function linspace(start: number, stop: number, num?: number, endpoint?: boolean, {
1716
+ dtype,
1717
+ device
1718
+ }?: DTypeAndDevice): Array;
1719
+ declare namespace lax_d_exports {
1720
+ export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, reduceWindow, stopGradient };
1721
+ }
1526
1722
  /**
1527
- * Compute the standard deviation of an array.
1723
+ * Dimension numbers for general `dot()` primitive.
1528
1724
  *
1529
- * The standard deviation is computed for the flattened array by default,
1530
- * otherwise over the specified axis.
1725
+ * Contracting dimensions act as a tensor contraction (reduction) along the
1726
+ * given axis. They must be the same size in both operands. Batch dimensions
1727
+ * are treated as vectorized, leading batch dimensions.
1531
1728
  *
1532
- * If `correction` is provided, the divisor in calculation is `N - correction`,
1533
- * where `N` represents the number of elements (e.g., for Bessel's correction).
1534
- */
1535
- declare function std(x: ArrayLike, axis?: Axis, opts?: {
1536
- mean?: ArrayLike;
1537
- correction?: number;
1538
- } & ReduceOpts): Array;
1539
- /** Test element-wise for positive or negative infinity, return bool array. */
1540
- declare function isinf(x: ArrayLike): Array;
1541
- /** Test element-wise for NaN (Not a Number). */
1542
- declare function isnan(x: ArrayLike): Array;
1543
- /** Test element-wise for negative infinity, return bool array. */
1544
- declare function isneginf(x: ArrayLike): Array;
1545
- /** Test element-wise for positive infinity, return bool array. */
1546
- declare function isposinf(x: ArrayLike): Array;
1547
- /**
1548
- * @function
1549
- * Test element-wise for finite values (not infinity or NaN).
1729
+ * The return value has a shape where the first dimensions are shared batch
1730
+ * dimensions, followed by `lhs` non-contracting dimensions, followed by
1731
+ * `rhs` non-contracting dimensions.
1550
1732
  */
1551
- declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
1552
- //# sourceMappingURL=numpy.d.ts.map
1553
- //#endregion
1554
- //#region src/frontend/jaxpr.d.ts
1733
+ type DotDimensionNumbers = {
1734
+ lhsContractingDims?: number[];
1735
+ rhsContractingDims?: number[];
1736
+ lhsBatchDims?: number[];
1737
+ rhsBatchDims?: number[];
1738
+ };
1555
1739
  /**
1556
- * Function callback with an associated dispose() method.
1740
+ * General dot product/contraction operator.
1557
1741
  *
1558
- * The dispose() method should be called to clean up any tracer resources needed
1559
- * by the function after the last time it is called.
1742
+ * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
1743
+ * `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
1560
1744
  */
1561
- type OwnedFunction<F extends Function> = F & {
1562
- dispose: () => void;
1563
- };
1564
- /** Variable in a Jaxpr expression. */
1565
- declare class Var {
1566
- #private;
1567
- readonly id: number;
1568
- readonly aval: ShapedArray;
1569
- constructor(aval: ShapedArray);
1570
- toString(): string;
1571
- }
1572
- /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
1573
- declare class Lit {
1574
- readonly value: number;
1575
- readonly aval: ShapedArray;
1576
- get dtype(): DType;
1577
- constructor(aval: AbstractValue, value: number);
1578
- }
1579
- type Atom = Var | Lit;
1580
- declare class VarPrinter {
1581
- #private;
1582
- names: Map<Var, string>;
1583
- name(v: Var): string;
1584
- nameType(v: Var): string;
1585
- }
1586
- /** A single statement / binding in a Jaxpr, in ANF form. */
1587
- declare class JaxprEqn {
1588
- readonly primitive: Primitive;
1589
- readonly inputs: Atom[];
1590
- readonly params: Record<string, any>;
1591
- readonly outBinders: Var[];
1592
- constructor(primitive: Primitive, inputs: Atom[], params: Record<string, any>, outBinders: Var[]);
1593
- pprint(usedVars?: Set<Var>, vp?: VarPrinter): PPrint;
1594
- toString(): string;
1595
- }
1596
- /** Typed intermediate representation for traced computations. */
1597
- declare class Jaxpr implements FpHashable {
1598
- #private;
1599
- readonly inBinders: Var[];
1600
- readonly eqns: JaxprEqn[];
1601
- readonly outs: Atom[];
1602
- constructor(inBinders: Var[], eqns: JaxprEqn[], outs: Atom[]);
1603
- pprint(): PPrint;
1604
- toString(): string;
1605
- /**
1606
- * Gets a hash of this Jaxpr.
1607
- *
1608
- * Var identity is not considered in the hash, so two Jaxprs with the same
1609
- * order of assignments and operators but different variable IDs will resolve
1610
- * to the same hash (and toString representation).
1611
- */
1612
- getHash(): bigint;
1613
- hash(state: FpHash): void;
1614
- /**
1615
- * Produce a simplified Jaxpr with basic optimizations applied.
1616
- * - Trim away unused variables.
1617
- * - Fold away *1, *0, or +0 operations against literals.
1618
- * - Remove no-op movement operations.
1619
- */
1620
- simplify(): Jaxpr;
1621
- /** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
1622
- flatten(): Jaxpr;
1623
- }
1624
- /** @inline */
1625
- type JitOpts = {
1626
- staticArgnums?: number[];
1627
- };
1628
- declare namespace lax_d_exports {
1629
- export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, erf, erfc, reduceWindow, stopGradient };
1630
- }
1745
+ declare function dot(lhs: Array, rhs: Array, {
1746
+ lhsContractingDims: lc,
1747
+ rhsContractingDims: rc,
1748
+ lhsBatchDims: lb,
1749
+ rhsBatchDims: rb
1750
+ }?: DotDimensionNumbers): Array;
1631
1751
  type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
1632
1752
  /**
1633
1753
  * General n-dimensional convolution operator, with optional dilation.
@@ -1639,10 +1759,12 @@ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
1639
1759
  */
1640
1760
  declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
1641
1761
  lhsDilation,
1642
- rhsDilation
1762
+ rhsDilation,
1763
+ featureGroupCount
1643
1764
  }?: {
1644
1765
  lhsDilation?: number[];
1645
1766
  rhsDilation?: number[];
1767
+ featureGroupCount?: number;
1646
1768
  }): Array;
1647
1769
  /** Convenience wrapper around `convGeneralDilated`. */
1648
1770
  declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
@@ -1809,9 +1931,9 @@ declare function logSoftmax(x: ArrayLike, axis?: Axis): Array;
1809
1931
  *
1810
1932
  * Reference: https://en.wikipedia.org/wiki/LogSumExp
1811
1933
  */
1812
- declare function logsumexp(x: ArrayLike, axis?: Axis): Array;
1934
+ declare function logsumexp(x: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1813
1935
  /** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
1814
- declare function logmeanexp(x: ArrayLike, axis?: Axis): Array;
1936
+ declare function logmeanexp(x: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1815
1937
  /**
1816
1938
  * Standardizes input to zero mean and unit variance.
1817
1939
  *