@jax-js/jax 0.1.0 → 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/{backend-DwIAd0AG.js → backend-BqymqzuU.js} +194 -73
- package/dist/{backend-FtkbO6pI.cjs → backend-DeVfWEFS.cjs} +194 -73
- package/dist/index.cjs +2725 -2206
- package/dist/index.d.cts +964 -844
- package/dist/index.d.ts +964 -844
- package/dist/index.js +2698 -2179
- package/dist/{webgpu-LGi2A3mS.js → webgpu-BGuG58KZ.js} +20 -13
- package/dist/{webgpu-BE7zA_01.cjs → webgpu-CcGP160M.cjs} +20 -13
- package/package.json +1 -1
package/dist/index.d.cts
CHANGED
|
@@ -168,9 +168,10 @@ declare enum DType {
|
|
|
168
168
|
Uint32 = "uint32",
|
|
169
169
|
Bool = "bool",
|
|
170
170
|
Float16 = "float16",
|
|
171
|
+
Float64 = "float64",
|
|
171
172
|
}
|
|
172
173
|
/** @inline */
|
|
173
|
-
type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
|
|
174
|
+
type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer> | Float64Array<ArrayBuffer>;
|
|
174
175
|
/**
|
|
175
176
|
* Promote two dtypes to their join according to the type lattice.
|
|
176
177
|
*
|
|
@@ -180,7 +181,7 @@ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Arr
|
|
|
180
181
|
*
|
|
181
182
|
* **Type lattice:**
|
|
182
183
|
* ```text
|
|
183
|
-
* bool -> uint32 -> int32 -> float16 -> float32
|
|
184
|
+
* bool -> uint32 -> int32 -> float16 -> float32 -> float64
|
|
184
185
|
* weakType --^
|
|
185
186
|
* ```
|
|
186
187
|
*
|
|
@@ -226,6 +227,8 @@ declare class AluExp implements FpHashable {
|
|
|
226
227
|
static erf(a: AluExp): AluExp;
|
|
227
228
|
static erfc(a: AluExp): AluExp;
|
|
228
229
|
static sqrt(a: AluExp): AluExp;
|
|
230
|
+
static floor(a: AluExp): AluExp;
|
|
231
|
+
static ceil(a: AluExp): AluExp;
|
|
229
232
|
static reciprocal(a: AluExp): AluExp;
|
|
230
233
|
static cast(dtype: DType, a: AluExp): AluExp;
|
|
231
234
|
static bitcast(dtype: DType, a: AluExp): AluExp;
|
|
@@ -243,6 +246,7 @@ declare class AluExp implements FpHashable {
|
|
|
243
246
|
static u32(value: number): AluExp;
|
|
244
247
|
static bool(value: boolean): AluExp;
|
|
245
248
|
static f16(value: number): AluExp;
|
|
249
|
+
static f64(value: number): AluExp;
|
|
246
250
|
not(): AluExp;
|
|
247
251
|
/** Compute a reasonable expression hash with low collision rate. */
|
|
248
252
|
getHash(): bigint;
|
|
@@ -315,6 +319,8 @@ declare enum AluOp {
|
|
|
315
319
|
Erf = "Erf",
|
|
316
320
|
Erfc = "Erfc",
|
|
317
321
|
Sqrt = "Sqrt",
|
|
322
|
+
Floor = "Floor",
|
|
323
|
+
Ceil = "Ceil",
|
|
318
324
|
Reciprocal = "Reciprocal",
|
|
319
325
|
Cast = "Cast",
|
|
320
326
|
Bitcast = "Bitcast",
|
|
@@ -457,680 +463,85 @@ declare class Executable<T = any> {
|
|
|
457
463
|
constructor(kernel: Kernel, /** Extra data specific to the backend running this kernel. */
|
|
458
464
|
data: T);
|
|
459
465
|
}
|
|
460
|
-
declare namespace
|
|
461
|
-
export {
|
|
462
|
-
}
|
|
463
|
-
declare enum NodeType {
|
|
464
|
-
Array = "Array",
|
|
465
|
-
Object = "Object",
|
|
466
|
-
Leaf = "Leaf",
|
|
467
|
-
}
|
|
468
|
-
/** Analog to the JAX "pytree" object, but for JavaScript. */
|
|
469
|
-
type JsTree<T> = T | JsTree<T>[] | {
|
|
470
|
-
[key: string]: JsTree<T>;
|
|
471
|
-
};
|
|
472
|
-
type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
|
|
473
|
-
type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
474
|
-
/** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
|
|
475
|
-
type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
|
|
476
|
-
/** Represents the structure of a JsTree. */
|
|
477
|
-
declare class JsTreeDef {
|
|
478
|
-
readonly nodeType: NodeType;
|
|
479
|
-
readonly nodeMetadata: any;
|
|
480
|
-
readonly childTreedefs: JsTreeDef[];
|
|
481
|
-
static leaf: JsTreeDef;
|
|
482
|
-
constructor(nodeType: NodeType, nodeMetadata: any,
|
|
483
|
-
// Must be comparable with deepEqual.
|
|
484
|
-
childTreedefs: JsTreeDef[]);
|
|
485
|
-
/** Get the total number of leaves in the tree. */
|
|
486
|
-
get size(): number;
|
|
487
|
-
/** Returns a string representation of this tree definition. */
|
|
488
|
-
toString(root?: boolean): string;
|
|
489
|
-
/** Compare this tree definition with another. */
|
|
490
|
-
equals(other: JsTreeDef): boolean;
|
|
491
|
-
}
|
|
492
|
-
/** Flatten a structured object, returning the tree definition. */
|
|
493
|
-
declare function flatten<T>(tree: JsTree<T>): [T[], JsTreeDef];
|
|
494
|
-
/** Get the leaves of a tree. */
|
|
495
|
-
declare function leaves<T>(tree: JsTree<T>): T[];
|
|
496
|
-
/** Get the treedef for a tree. */
|
|
497
|
-
declare function structure<T>(tree: JsTree<T>): JsTreeDef;
|
|
498
|
-
/** Reconstruct a structured object from the flattened representation. */
|
|
499
|
-
declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T>;
|
|
500
|
-
/** Maps a multi-input function over pytree args to produce a new pytree. */
|
|
501
|
-
declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
|
|
502
|
-
/** Take a reference of every array in a tree. */
|
|
503
|
-
declare function ref<Tree extends JsTree<Array>>(tree: Tree): Tree;
|
|
504
|
-
/** Dispose every array in a tree. */
|
|
505
|
-
declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefined): void;
|
|
506
|
-
//#endregion
|
|
507
|
-
//#region src/frontend/convolution.d.ts
|
|
508
|
-
/** Definition of a general dilated convolution. Should be valid on creation. */
|
|
509
|
-
interface ConvParams {
|
|
510
|
-
strides: number[];
|
|
511
|
-
padding: [number, number][];
|
|
512
|
-
lhsDilation: number[];
|
|
513
|
-
rhsDilation: number[];
|
|
466
|
+
declare namespace numpy_d_exports {
|
|
467
|
+
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, positive, pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
514
468
|
}
|
|
469
|
+
declare const float32 = DType.Float32;
|
|
470
|
+
declare const int32 = DType.Int32;
|
|
471
|
+
declare const uint32 = DType.Uint32;
|
|
472
|
+
declare const bool = DType.Bool;
|
|
473
|
+
declare const float16 = DType.Float16;
|
|
474
|
+
declare const float64 = DType.Float64;
|
|
475
|
+
/** Euler's constant, `e = 2.7182818284590...` */
|
|
476
|
+
declare const e: number;
|
|
477
|
+
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
478
|
+
declare const eulerGamma = 0.5772156649015329;
|
|
479
|
+
/** Positive infinity. */
|
|
480
|
+
declare const inf: number;
|
|
481
|
+
/** Floating-point representation of NaN. */
|
|
482
|
+
declare const nan: number;
|
|
483
|
+
/** This is Pi, `π = 3.14159265358979...` */
|
|
484
|
+
declare const pi: number;
|
|
485
|
+
/** @function Element-wise addition, with broadcasting. */
|
|
486
|
+
declare const add: (x: ArrayLike, y: ArrayLike) => Array;
|
|
487
|
+
/** @function Element-wise multiplication, with broadcasting. */
|
|
488
|
+
declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
|
|
489
|
+
/** @function Numerical negative of every element of an array. */
|
|
490
|
+
declare const negative: (x: ArrayLike) => Array;
|
|
491
|
+
/** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
492
|
+
declare const reciprocal: (x: ArrayLike) => Array;
|
|
493
|
+
/** @function Round input down to the nearest integer. */
|
|
494
|
+
declare const floor: (x: ArrayLike) => Array;
|
|
495
|
+
/** @function Round input up to the nearest integer. */
|
|
496
|
+
declare const ceil: (x: ArrayLike) => Array;
|
|
497
|
+
/** @function Element-wise sine function (takes radians). */
|
|
498
|
+
declare const sin: (x: ArrayLike) => Array;
|
|
499
|
+
/** @function Element-wise cosine function (takes radians). */
|
|
500
|
+
declare const cos: (x: ArrayLike) => Array;
|
|
501
|
+
/** @function Element-wise inverse sine function (inverse of sin). */
|
|
502
|
+
declare const asin: (x: ArrayLike) => Array;
|
|
503
|
+
/** @function Element-wise inverse tangent function (inverse of tan). */
|
|
504
|
+
declare const atan: (x: ArrayLike) => Array;
|
|
505
|
+
/** @function Calculate the exponential of all elements in the input array. */
|
|
506
|
+
declare const exp: (x: ArrayLike) => Array;
|
|
507
|
+
/** @function Calculate the natural logarithm of all elements in the input array. */
|
|
508
|
+
declare const log: (x: ArrayLike) => Array;
|
|
509
|
+
/** @function Calculate the square root of all elements in the input array. */
|
|
510
|
+
declare const sqrt: (x: ArrayLike) => Array;
|
|
511
|
+
/** @function Return element-wise minimum of the input arrays. */
|
|
512
|
+
declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
513
|
+
/** @function Return element-wise maximum of the input arrays. */
|
|
514
|
+
declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
515
|
+
/** @function Compare two arrays element-wise. */
|
|
516
|
+
declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
|
|
517
|
+
/** @function Compare two arrays element-wise. */
|
|
518
|
+
declare const less: (x: ArrayLike, y: ArrayLike) => Array;
|
|
519
|
+
/** @function Compare two arrays element-wise. */
|
|
520
|
+
declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
|
|
521
|
+
/** @function Compare two arrays element-wise. */
|
|
522
|
+
declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
523
|
+
/** @function Compare two arrays element-wise. */
|
|
524
|
+
declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
525
|
+
/** @function Compare two arrays element-wise. */
|
|
526
|
+
declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
527
|
+
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
528
|
+
declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
|
|
515
529
|
/**
|
|
516
|
-
*
|
|
517
|
-
*
|
|
518
|
-
* If the check succeeds, returns the output shape.
|
|
530
|
+
* @function
|
|
531
|
+
* Permute the dimensions of an array. Defaults to reversing the axis order.
|
|
519
532
|
*/
|
|
520
|
-
|
|
521
|
-
//#endregion
|
|
522
|
-
//#region src/frontend/core.d.ts
|
|
533
|
+
declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
|
|
523
534
|
/**
|
|
524
|
-
*
|
|
525
|
-
*
|
|
526
|
-
*
|
|
527
|
-
* Any operation between arrays can be described in these parts. This is also
|
|
528
|
-
* the set of primitives that can occur in Jaxpr programs, and the level at
|
|
529
|
-
* which transformations like vmap, grad, and jvp occur. They are loosely based
|
|
530
|
-
* on [XLA](https://openxla.org/xla/operation_semantics).
|
|
535
|
+
* @function
|
|
536
|
+
* Give a new shape to an array without changing its data.
|
|
531
537
|
*
|
|
532
|
-
*
|
|
538
|
+
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
539
|
+
* length of the array and remaining dimensions.
|
|
533
540
|
*/
|
|
534
|
-
declare
|
|
535
|
-
Add = "add",
|
|
536
|
-
Mul = "mul",
|
|
537
|
-
Idiv = "idiv",
|
|
538
|
-
Neg = "neg",
|
|
539
|
-
Reciprocal = "reciprocal",
|
|
540
|
-
StopGradient = "stop_gradient",
|
|
541
|
-
Cast = "cast",
|
|
542
|
-
Bitcast = "bitcast",
|
|
543
|
-
RandomBits = "random_bits",
|
|
544
|
-
Sin = "sin",
|
|
545
|
-
Cos = "cos",
|
|
546
|
-
Asin = "asin",
|
|
547
|
-
Atan = "atan",
|
|
548
|
-
Exp = "exp",
|
|
549
|
-
Log = "log",
|
|
550
|
-
Erf = "erf",
|
|
551
|
-
Erfc = "erfc",
|
|
552
|
-
Sqrt = "sqrt",
|
|
553
|
-
Min = "min",
|
|
554
|
-
Max = "max",
|
|
555
|
-
Reduce = "reduce",
|
|
556
|
-
Dot = "dot",
|
|
557
|
-
// sum(x*y, axis=-1)
|
|
558
|
-
Conv = "conv",
|
|
559
|
-
// see lax.conv_general_dilated
|
|
560
|
-
Pool = "pool",
|
|
561
|
-
PoolTranspose = "pool_transpose",
|
|
562
|
-
Compare = "compare",
|
|
563
|
-
Where = "where",
|
|
564
|
-
Transpose = "transpose",
|
|
565
|
-
Broadcast = "broadcast",
|
|
566
|
-
Reshape = "reshape",
|
|
567
|
-
Flip = "flip",
|
|
568
|
-
Shrink = "shrink",
|
|
569
|
-
Pad = "pad",
|
|
570
|
-
Gather = "gather",
|
|
571
|
-
JitCall = "jit_call",
|
|
572
|
-
}
|
|
573
|
-
interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
574
|
-
[Primitive.Cast]: {
|
|
575
|
-
dtype: DType;
|
|
576
|
-
};
|
|
577
|
-
[Primitive.Bitcast]: {
|
|
578
|
-
dtype: DType;
|
|
579
|
-
};
|
|
580
|
-
[Primitive.Reduce]: {
|
|
581
|
-
op: AluOp;
|
|
582
|
-
axis: number[];
|
|
583
|
-
};
|
|
584
|
-
[Primitive.Conv]: ConvParams;
|
|
585
|
-
[Primitive.Pool]: {
|
|
586
|
-
window: number[];
|
|
587
|
-
strides: number[];
|
|
588
|
-
};
|
|
589
|
-
[Primitive.PoolTranspose]: {
|
|
590
|
-
inShape: number[];
|
|
591
|
-
window: number[];
|
|
592
|
-
strides: number[];
|
|
593
|
-
};
|
|
594
|
-
[Primitive.Compare]: {
|
|
595
|
-
op: CompareOp;
|
|
596
|
-
};
|
|
597
|
-
[Primitive.Transpose]: {
|
|
598
|
-
perm: number[];
|
|
599
|
-
};
|
|
600
|
-
[Primitive.Broadcast]: {
|
|
601
|
-
shape: number[];
|
|
602
|
-
axis: number[];
|
|
603
|
-
};
|
|
604
|
-
[Primitive.RandomBits]: {
|
|
605
|
-
shape: number[];
|
|
606
|
-
mode: "xor" | 0 | 1;
|
|
607
|
-
};
|
|
608
|
-
[Primitive.Reshape]: {
|
|
609
|
-
shape: number[];
|
|
610
|
-
};
|
|
611
|
-
[Primitive.Flip]: {
|
|
612
|
-
axis: number[];
|
|
613
|
-
};
|
|
614
|
-
[Primitive.Shrink]: {
|
|
615
|
-
slice: Pair[];
|
|
616
|
-
};
|
|
617
|
-
[Primitive.Pad]: {
|
|
618
|
-
width: Pair[];
|
|
619
|
-
};
|
|
620
|
-
[Primitive.Gather]: {
|
|
621
|
-
axis: number[];
|
|
622
|
-
outDim: number;
|
|
623
|
-
};
|
|
624
|
-
[Primitive.JitCall]: {
|
|
625
|
-
name: string;
|
|
626
|
-
jaxpr: Jaxpr;
|
|
627
|
-
numConsts: number;
|
|
628
|
-
};
|
|
629
|
-
}
|
|
630
|
-
/** Type of parameters taken by each primitive. */
|
|
631
|
-
type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
|
|
632
|
-
declare enum CompareOp {
|
|
633
|
-
Greater = "greater",
|
|
634
|
-
Less = "less",
|
|
635
|
-
Equal = "equal",
|
|
636
|
-
NotEqual = "not_equal",
|
|
637
|
-
GreaterEqual = "greater_equal",
|
|
638
|
-
LessEqual = "less_equal",
|
|
639
|
-
}
|
|
640
|
-
/** @inline */
|
|
641
|
-
type Axis = number | number[] | null;
|
|
642
|
-
/** @inline */
|
|
643
|
-
type ReduceOpts = {
|
|
644
|
-
keepdims?: boolean;
|
|
645
|
-
};
|
|
646
|
-
type MainTrace = {
|
|
647
|
-
level: number;
|
|
648
|
-
traceType: new (main: MainTrace) => Trace;
|
|
649
|
-
globalData: any | null;
|
|
650
|
-
};
|
|
541
|
+
declare const reshape: (x: ArrayLike, shape: number[]) => Array;
|
|
651
542
|
/**
|
|
652
|
-
*
|
|
653
|
-
*
|
|
654
|
-
*/
|
|
655
|
-
|
|
656
|
-
type TracerValue = Tracer | number | boolean;
|
|
657
|
-
declare abstract class Trace {
|
|
658
|
-
readonly main: MainTrace;
|
|
659
|
-
constructor(main: MainTrace);
|
|
660
|
-
abstract pure(val: TracerValue): Tracer;
|
|
661
|
-
abstract lift(val: Tracer): Tracer;
|
|
662
|
-
abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
|
|
663
|
-
}
|
|
664
|
-
/** Internal representation of an array value. */
|
|
665
|
-
interface AbstractValue {
|
|
666
|
-
/** Shape of the array. Must be a static tuple of non-negative dimensions. */
|
|
667
|
-
shape: number[];
|
|
668
|
-
/** Concrete data type of array elements. */
|
|
669
|
-
dtype: DType;
|
|
670
|
-
/**
|
|
671
|
-
* Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
|
|
672
|
-
* _weakly typed_ unless a dtype is explicitly specified.
|
|
673
|
-
*
|
|
674
|
-
* Weakly typed values will automatically cast to the data type of other
|
|
675
|
-
* arrays when used as an operand as an expression. This property only affects
|
|
676
|
-
* how they promote in type casting; their memory layout is still determined
|
|
677
|
-
* by the actual `dtype` field.
|
|
678
|
-
*
|
|
679
|
-
* ```ts
|
|
680
|
-
* const x = np.array(3); // weakType = true, dtype = float32
|
|
681
|
-
* const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
|
|
682
|
-
* const z = x.add(y); // z has dtype int32 because x is weakly typed
|
|
683
|
-
* ```
|
|
684
|
-
*
|
|
685
|
-
* Weak types are present in JIT programs in their spec (e.g., Jaxpr inputs
|
|
686
|
-
* and outputs can be weakly typed) form. But they're solely a frontend
|
|
687
|
-
* concept. Backends are not aware of weak types.
|
|
688
|
-
*/
|
|
689
|
-
weakType: boolean;
|
|
690
|
-
}
|
|
691
|
-
/**
|
|
692
|
-
* Broadcast shapes and promote types with casting for two avals.
|
|
693
|
-
*
|
|
694
|
-
* This implements the weak type behavior described in `promoteTypes()`, but not
|
|
695
|
-
* implemented in that function as `weakType` is not passed.
|
|
696
|
-
*/
|
|
697
|
-
|
|
698
|
-
declare abstract class Tracer {
|
|
699
|
-
/** @ignore */
|
|
700
|
-
readonly _trace: Trace;
|
|
701
|
-
constructor(trace: Trace);
|
|
702
|
-
abstract get aval(): AbstractValue;
|
|
703
|
-
abstract toString(): string;
|
|
704
|
-
/**
|
|
705
|
-
* Access an array by reference, incrementing the reference count.
|
|
706
|
-
*
|
|
707
|
-
* jax-js handles freeing arrays by using "move" semantics, like in Rust/C++.
|
|
708
|
-
* Whenever you pass an array into a function, that function should consume
|
|
709
|
-
* the array, and it will no longer be usable. For example, if you had:
|
|
710
|
-
*
|
|
711
|
-
* ```
|
|
712
|
-
* const x = np.array([1, 2, 3]);
|
|
713
|
-
* const y = np.add(x, x);
|
|
714
|
-
* ```
|
|
715
|
-
*
|
|
716
|
-
* The second line does not work because the first parameter consumes `x`, and
|
|
717
|
-
* then the second parameter will already have been freed / disposed.
|
|
718
|
-
*
|
|
719
|
-
* To fix this, you can write:
|
|
720
|
-
*
|
|
721
|
-
* ```
|
|
722
|
-
* const y = np.add(x.ref, x);
|
|
723
|
-
* ```
|
|
724
|
-
*
|
|
725
|
-
* Under the hood, every access to `.ref` increments the internal reference
|
|
726
|
-
* count of the array. The reference count starts at 1. When it hits 0, the
|
|
727
|
-
* memory behind the array is freed.
|
|
728
|
-
*/
|
|
729
|
-
abstract get ref(): this;
|
|
730
|
-
/**
|
|
731
|
-
* Manually decrement the reference count of the array.
|
|
732
|
-
*
|
|
733
|
-
* Arrays are created with reference count 1. Whenever it is used as argument
|
|
734
|
-
* to a function or other operation, it is disposed (i.e., reference count
|
|
735
|
-
* decreases by 1) automatically. Whenever a `.ref` is created, the reference
|
|
736
|
-
* count increases.
|
|
737
|
-
*
|
|
738
|
-
* You generally don't need to call this function directly since arrays are
|
|
739
|
-
* automatically disposed after being passed into an operation. One common
|
|
740
|
-
* exception is when writing a function and ignoring one of its arguments. In
|
|
741
|
-
* that case, by convention you should dispose of that argument manually.
|
|
742
|
-
*
|
|
743
|
-
* ```
|
|
744
|
-
* function myCustomOperation(a: np.Array, b: np.Array) {
|
|
745
|
-
* b.dispose(); // Needed to satisfy "move" rules.
|
|
746
|
-
* return a.add(1);
|
|
747
|
-
* }
|
|
748
|
-
* ```
|
|
749
|
-
*/
|
|
750
|
-
abstract dispose(): void;
|
|
751
|
-
/** The shape of the array. */
|
|
752
|
-
get shape(): number[];
|
|
753
|
-
/** The total number of elements in the array. */
|
|
754
|
-
get size(): number;
|
|
755
|
-
/** The dtype of elements stored in the array. */
|
|
756
|
-
get dtype(): DType;
|
|
757
|
-
/**
|
|
758
|
-
* Whether the array is weakly typed.
|
|
759
|
-
*
|
|
760
|
-
* Weakly typed arrays will cast to the dtype of the other operand. See
|
|
761
|
-
* `promoteTypes()` for details.
|
|
762
|
-
*/
|
|
763
|
-
get weakType(): boolean;
|
|
764
|
-
/** The number of dimensions of the array. */
|
|
765
|
-
get ndim(): number;
|
|
766
|
-
/** @ignore */
|
|
767
|
-
fullLower(): Tracer;
|
|
768
|
-
neg(): this;
|
|
769
|
-
add(other: this | TracerValue): this;
|
|
770
|
-
mul(other: this | TracerValue): this;
|
|
771
|
-
greater(other: this | TracerValue): this;
|
|
772
|
-
less(other: this | TracerValue): this;
|
|
773
|
-
equal(other: this | TracerValue): this;
|
|
774
|
-
notEqual(other: this | TracerValue): this;
|
|
775
|
-
greaterEqual(other: this | TracerValue): this;
|
|
776
|
-
lessEqual(other: this | TracerValue): this;
|
|
777
|
-
/** Sum of the elements of the array over a given axis, or axes. */
|
|
778
|
-
sum(axis?: Axis, opts?: ReduceOpts): this;
|
|
779
|
-
/** Product of the array elements over a given axis. */
|
|
780
|
-
prod(axis?: Axis, opts?: ReduceOpts): this;
|
|
781
|
-
/** Compute the average of the array elements along the specified axis. */
|
|
782
|
-
mean(axis?: Axis, opts?: ReduceOpts): this;
|
|
783
|
-
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
784
|
-
transpose(perm?: number[]): this;
|
|
785
|
-
/**
|
|
786
|
-
* Give a new shape to an array without changing its data.
|
|
787
|
-
*
|
|
788
|
-
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
789
|
-
* length of the array and remaining dimensions.
|
|
790
|
-
*/
|
|
791
|
-
reshape(shape: number | number[]): this;
|
|
792
|
-
/** Copy the array and cast to a specified dtype. */
|
|
793
|
-
astype(dtype: DType): this;
|
|
794
|
-
/** Subtract an array from this one. */
|
|
795
|
-
sub(other: this | TracerValue): this;
|
|
796
|
-
/** Divide an array by this one. */
|
|
797
|
-
div(other: this | TracerValue): this;
|
|
798
|
-
/** Return specified diagonals. See `numpy.diagonal` for full docs. */
|
|
799
|
-
diagonal(offset?: number, axis1?: number, axis2?: number): this;
|
|
800
|
-
/** Flatten the array without changing its data. */
|
|
801
|
-
flatten(): this;
|
|
802
|
-
/** Flatten the array without changing its data. */
|
|
803
|
-
ravel(): this;
|
|
804
|
-
/**
|
|
805
|
-
* Iterate over the first dimension of this array, returning slices.
|
|
806
|
-
*
|
|
807
|
-
* This can be used to destructure arrays. For example:
|
|
808
|
-
*
|
|
809
|
-
* ```js
|
|
810
|
-
* let x = np.array([[1, 2], [3, 4]]);
|
|
811
|
-
* let [a, b] = x;
|
|
812
|
-
* console.log(a.js()); // [1, 2]
|
|
813
|
-
* console.log(b.js()); // [3, 4]
|
|
814
|
-
* ```
|
|
815
|
-
*/
|
|
816
|
-
[Symbol.iterator](): IterableIterator<this>;
|
|
817
|
-
/**
|
|
818
|
-
* Slice an array along one or more axes.
|
|
819
|
-
*
|
|
820
|
-
* This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
|
|
821
|
-
* mimic this in JavaScript, we would write:
|
|
822
|
-
*
|
|
823
|
-
* ```js
|
|
824
|
-
* x.slice([1, 3], 2, [], null);
|
|
825
|
-
* ```
|
|
826
|
-
*
|
|
827
|
-
* The `slice` method accepts a variable number of arguments, each of which
|
|
828
|
-
* can be a number, an empty array, a single-element array, a two-element
|
|
829
|
-
* array, or `null`. The arguments are interpreted as follows:
|
|
830
|
-
*
|
|
831
|
-
* - A number `n` means to access the `n`-th element along that axis, removing
|
|
832
|
-
* that axis from the resulting shape.
|
|
833
|
-
* - An empty array `[]` means to keep that axis as-is, like `:` in Python.
|
|
834
|
-
* - A single-element array `[i]` means to start slicing from index `i`
|
|
835
|
-
* (inclusive) to the end of the axis, like `x[i:]`.
|
|
836
|
-
* - A two-element array `[i, j]` means to slice from index `i` (inclusive)
|
|
837
|
-
* to index `j` (exclusive), like `x[i:j]`.
|
|
838
|
-
* - `null` means to add a new axis at that position, like `np.newaxis`.
|
|
839
|
-
*
|
|
840
|
-
* Like in Python, negative indices are supported, which count from the end of
|
|
841
|
-
* the axis. For example, `-1` means the last element.
|
|
842
|
-
*
|
|
843
|
-
* Strided slices are not yet implemented, so you cannot write `x[::2]` or
|
|
844
|
-
* similar.
|
|
845
|
-
*
|
|
846
|
-
* Advanced indexing by integer arrays is also supported. This translates to
|
|
847
|
-
* the "gather" primitive, and it allows you to access specific elements of
|
|
848
|
-
* the array by integer indices stored in another array.
|
|
849
|
-
*/
|
|
850
|
-
slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
|
|
851
|
-
}
|
|
852
|
-
declare class ShapedArray implements AbstractValue {
|
|
853
|
-
readonly shape: number[];
|
|
854
|
-
readonly dtype: DType;
|
|
855
|
-
readonly weakType: boolean;
|
|
856
|
-
constructor(shape: number[], dtype: DType, weakType: boolean);
|
|
857
|
-
static fromAval(aval: AbstractValue): ShapedArray;
|
|
858
|
-
get ndim(): number;
|
|
859
|
-
toString(): string;
|
|
860
|
-
equals(other: ShapedArray): boolean;
|
|
861
|
-
}
|
|
862
|
-
//#endregion
|
|
863
|
-
//#region src/frontend/array.d.ts
|
|
864
|
-
type ArrayLike = Array | number | boolean;
|
|
865
|
-
/** Version of pureArray with fudged types. */
|
|
866
|
-
|
|
867
|
-
/**
|
|
868
|
-
* An executable operation that will be dispatched to the backend.
|
|
869
|
-
*
|
|
870
|
-
* This holds a reference to all input buffers used in the operation. After the
|
|
871
|
-
* operation is dispatched, the references should be released.
|
|
872
|
-
*/
|
|
873
|
-
declare class PendingExecute {
|
|
874
|
-
#private;
|
|
875
|
-
readonly backend: Backend;
|
|
876
|
-
readonly kernel: Kernel;
|
|
877
|
-
readonly inputs: Slot[];
|
|
878
|
-
readonly outputs: Slot[];
|
|
879
|
-
prepared: Executable | null;
|
|
880
|
-
submitted: boolean;
|
|
881
|
-
constructor(backend: Backend, kernel: Kernel, inputs: Slot[], outputs: Slot[]);
|
|
882
|
-
updateRc(delta: number): void;
|
|
883
|
-
prepare(): Promise<void>;
|
|
884
|
-
prepareSync(): void;
|
|
885
|
-
submit(): void;
|
|
886
|
-
}
|
|
887
|
-
/** @inline */
|
|
888
|
-
type DTypeAndDevice = {
|
|
889
|
-
dtype?: DType;
|
|
890
|
-
device?: Device;
|
|
891
|
-
};
|
|
892
|
-
type ArrayConstructorArgs = {
|
|
893
|
-
source: AluExp | Slot;
|
|
894
|
-
st: ShapeTracker;
|
|
895
|
-
dtype: DType;
|
|
896
|
-
weakType: boolean;
|
|
897
|
-
backend: Backend;
|
|
898
|
-
committed: boolean;
|
|
899
|
-
pending?: Iterable<PendingExecute>;
|
|
900
|
-
};
|
|
901
|
-
/**
|
|
902
|
-
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
903
|
-
*
|
|
904
|
-
* This is the library's core data type. Equivalent to `jax.Array` from JAX, or
|
|
905
|
-
* `torch.Tensor`.
|
|
906
|
-
*
|
|
907
|
-
* Not to be confused with the JavaScript "Array" constructor. Avoid importing
|
|
908
|
-
* this into your code's namespace if you're already using the JavaScript
|
|
909
|
-
* "Array" type by name.
|
|
910
|
-
*/
|
|
911
|
-
declare class Array extends Tracer {
|
|
912
|
-
#private;
|
|
913
|
-
id: number;
|
|
914
|
-
/**
|
|
915
|
-
* @ignore
|
|
916
|
-
* Constructs an array from source, shape and backend. Note that if the source
|
|
917
|
-
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
918
|
-
* will be freed when the array is disposed.
|
|
919
|
-
*/
|
|
920
|
-
constructor(args: ArrayConstructorArgs);
|
|
921
|
-
/** @ignore */
|
|
922
|
-
get aval(): ShapedArray;
|
|
923
|
-
/** Return a simple string representation of the array's dimensions. */
|
|
924
|
-
toString(): string;
|
|
925
|
-
get device(): Device;
|
|
926
|
-
get ref(): this;
|
|
927
|
-
dispose(): void;
|
|
928
|
-
/**
|
|
929
|
-
* Convert this array into a primitive value.
|
|
930
|
-
*
|
|
931
|
-
* This only works for scalars (0-dimensional arrays). It lets you get values
|
|
932
|
-
* "out" of the JAX system. For instance, if `x = np.array(5)`, then you can
|
|
933
|
-
* evaluate `x + 1` and `x ** 2` to get `6` and `25`, respectively.
|
|
934
|
-
*
|
|
935
|
-
* This method is also called for `==` equality.
|
|
936
|
-
*/
|
|
937
|
-
[Symbol.toPrimitive](): any;
|
|
938
|
-
/** Realize the array and return it as data. */
|
|
939
|
-
data(): Promise<DataArray>;
|
|
940
|
-
/**
|
|
941
|
-
* Wait for this array to finish evaluation.
|
|
942
|
-
*
|
|
943
|
-
* Operations and data loading in jax-js are lazy, so this function ensures
|
|
944
|
-
* that pending operations are dispatched and fully executed before it
|
|
945
|
-
* returns.
|
|
946
|
-
*
|
|
947
|
-
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
948
|
-
* dispatch of operations as well.
|
|
949
|
-
*
|
|
950
|
-
* **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
|
|
951
|
-
* asynchronously for multiple arrays.
|
|
952
|
-
*/
|
|
953
|
-
blockUntilReady(): Promise<Array>;
|
|
954
|
-
/**
|
|
955
|
-
* Realize the array and return it as data. This is a sync variant and not
|
|
956
|
-
* recommended for performance reasons, as it will block rendering.
|
|
957
|
-
*/
|
|
958
|
-
dataSync(): DataArray;
|
|
959
|
-
/**
|
|
960
|
-
* Convert this array into a JavaScript object.
|
|
961
|
-
*
|
|
962
|
-
* This is a blocking operation that will compile all of the shaders and wait
|
|
963
|
-
* for execution to complete, synchronously. No other JavaScript code on the
|
|
964
|
-
* site will be run during shader execution.
|
|
965
|
-
*
|
|
966
|
-
* To avoid blocking, prefer `jsAsync()` when possible.
|
|
967
|
-
*/
|
|
968
|
-
js(): any;
|
|
969
|
-
/** Convert this array into a JavaScript object, asynchronously. */
|
|
970
|
-
jsAsync(): Promise<any>;
|
|
971
|
-
/**
|
|
972
|
-
* Copy an element of an array to a numeric scalar and return it.
|
|
973
|
-
*
|
|
974
|
-
* Throws an error if the array does not have a single element. The array must
|
|
975
|
-
* either be rank-0, or all dimensions of the shape are 1.
|
|
976
|
-
*/
|
|
977
|
-
item(): number;
|
|
978
|
-
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
979
|
-
static _implRules(): typeof implRules;
|
|
980
|
-
/** @private */
|
|
981
|
-
_realizeSource(): number;
|
|
982
|
-
/** @private Put this array on a new backend, asynchronously. */
|
|
983
|
-
_put(backend: Backend): Promise<Array>;
|
|
984
|
-
/** @private Put this array on a new backend, synchronously. */
|
|
985
|
-
_putSync(backend: Backend): Array;
|
|
986
|
-
}
|
|
987
|
-
/** Constructor for creating a new array from data. */
|
|
988
|
-
declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
989
|
-
shape,
|
|
990
|
-
dtype,
|
|
991
|
-
device
|
|
992
|
-
}?: {
|
|
993
|
-
shape?: number[];
|
|
994
|
-
} & DTypeAndDevice): Array;
|
|
995
|
-
/** If x is a value, lift it into an array, otherwise leave it be. */
|
|
996
|
-
|
|
997
|
-
type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
|
|
998
|
-
declare const implRules: { [P in Primitive]: ImplRule<P> };
|
|
999
|
-
/** Return a new array of given shape and type, filled with zeros. */
|
|
1000
|
-
declare function zeros(shape: number[], {
|
|
1001
|
-
dtype,
|
|
1002
|
-
device
|
|
1003
|
-
}?: DTypeAndDevice): Array;
|
|
1004
|
-
/** Return a new array of given shape and type, filled with ones. */
|
|
1005
|
-
declare function ones(shape: number[], {
|
|
1006
|
-
dtype,
|
|
1007
|
-
device
|
|
1008
|
-
}?: DTypeAndDevice): Array;
|
|
1009
|
-
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
1010
|
-
declare function full(shape: number[], fillValue: number | boolean | Array, {
|
|
1011
|
-
dtype,
|
|
1012
|
-
device
|
|
1013
|
-
}?: DTypeAndDevice): Array;
|
|
1014
|
-
/**
|
|
1015
|
-
* Create an identity matrix.
|
|
1016
|
-
*
|
|
1017
|
-
* If numCols is not provided, it defaults to numRows, i.e., a square identity
|
|
1018
|
-
* matrix with ones on the diagonal.
|
|
1019
|
-
*/
|
|
1020
|
-
declare function eye(numRows: number, numCols?: number, {
|
|
1021
|
-
dtype,
|
|
1022
|
-
device
|
|
1023
|
-
}?: DTypeAndDevice): Array;
|
|
1024
|
-
/** Return the identity matrix, with ones on the main diagonal. */
|
|
1025
|
-
declare function identity$1(n: number, {
|
|
1026
|
-
dtype,
|
|
1027
|
-
device
|
|
1028
|
-
}?: DTypeAndDevice): Array;
|
|
1029
|
-
/**
|
|
1030
|
-
* Return evenly spaced values within a given interval.
|
|
1031
|
-
*
|
|
1032
|
-
* This can be called with a varying number of arguments, just like the range()
|
|
1033
|
-
* builtin function in Python.
|
|
1034
|
-
*
|
|
1035
|
-
* - `arange(stop)` is equivalent to `arange(0, stop, 1)`.
|
|
1036
|
-
* - `arange(start, stop)` is equivalent to `arange(start, stop, 1)`.
|
|
1037
|
-
* - `arange(start, stop, step)` creates an array starting at `start`, ending
|
|
1038
|
-
* before `stop`, with a step size of `step`.
|
|
1039
|
-
*
|
|
1040
|
-
* Defaults to an integer data type. This can produce unintended results when
|
|
1041
|
-
* using a non-integer step, so prefer linspace() in those cases.
|
|
1042
|
-
*/
|
|
1043
|
-
declare function arange(start: number, stop?: number, step?: number, {
|
|
1044
|
-
dtype,
|
|
1045
|
-
device
|
|
1046
|
-
}?: DTypeAndDevice): Array;
|
|
1047
|
-
/**
|
|
1048
|
-
* Return evenly spaced numbers over a specified interval.
|
|
1049
|
-
*
|
|
1050
|
-
* Returns _num_ evenly spaced samples, calculated over the interval
|
|
1051
|
-
* [`start`, `stop`]. The endpoint `stop` is included in the result by default,
|
|
1052
|
-
* but this is controlled by the `endpoint` parameter.
|
|
1053
|
-
*
|
|
1054
|
-
* The default data type is Float32. Use arange() for integer steps.
|
|
1055
|
-
*/
|
|
1056
|
-
declare function linspace(start: number, stop: number, num?: number, endpoint?: boolean, {
|
|
1057
|
-
dtype,
|
|
1058
|
-
device
|
|
1059
|
-
}?: DTypeAndDevice): Array;
|
|
1060
|
-
declare namespace numpy_d_exports {
|
|
1061
|
-
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
1062
|
-
}
|
|
1063
|
-
declare const float32 = DType.Float32;
|
|
1064
|
-
declare const int32 = DType.Int32;
|
|
1065
|
-
declare const uint32 = DType.Uint32;
|
|
1066
|
-
declare const bool = DType.Bool;
|
|
1067
|
-
declare const float16 = DType.Float16;
|
|
1068
|
-
/** Euler's constant, `e = 2.7182818284590...` */
|
|
1069
|
-
declare const e: number;
|
|
1070
|
-
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
1071
|
-
declare const eulerGamma = 0.5772156649015329;
|
|
1072
|
-
/** Positive infinity. */
|
|
1073
|
-
declare const inf: number;
|
|
1074
|
-
/** Floating-point representation of NaN. */
|
|
1075
|
-
declare const nan: number;
|
|
1076
|
-
/** This is Pi, `π = 3.14159265358979...` */
|
|
1077
|
-
declare const pi: number;
|
|
1078
|
-
/** @function Element-wise addition, with broadcasting. */
|
|
1079
|
-
declare const add: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1080
|
-
/** @function Element-wise multiplication, with broadcasting. */
|
|
1081
|
-
declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1082
|
-
/** @function Numerical negative of every element of an array. */
|
|
1083
|
-
declare const negative: (x: ArrayLike) => Array;
|
|
1084
|
-
/** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
1085
|
-
declare const reciprocal: (x: ArrayLike) => Array;
|
|
1086
|
-
/** @function Element-wise sine function (takes radians). */
|
|
1087
|
-
declare const sin: (x: ArrayLike) => Array;
|
|
1088
|
-
/** @function Element-wise cosine function (takes radians). */
|
|
1089
|
-
declare const cos: (x: ArrayLike) => Array;
|
|
1090
|
-
/** @function Element-wise inverse sine function (inverse of sin). */
|
|
1091
|
-
declare const asin: (x: ArrayLike) => Array;
|
|
1092
|
-
/** @function Element-wise inverse tangent function (inverse of tan). */
|
|
1093
|
-
declare const atan: (x: ArrayLike) => Array;
|
|
1094
|
-
/** @function Calculate the exponential of all elements in the input array. */
|
|
1095
|
-
declare const exp: (x: ArrayLike) => Array;
|
|
1096
|
-
/** @function Calculate the natural logarithm of all elements in the input array. */
|
|
1097
|
-
declare const log: (x: ArrayLike) => Array;
|
|
1098
|
-
/** @function Calculate the square root of all elements in the input array. */
|
|
1099
|
-
declare const sqrt: (x: ArrayLike) => Array;
|
|
1100
|
-
/** @function Return element-wise minimum of the input arrays. */
|
|
1101
|
-
declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1102
|
-
/** @function Return element-wise maximum of the input arrays. */
|
|
1103
|
-
declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1104
|
-
/** @function Compare two arrays element-wise. */
|
|
1105
|
-
declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1106
|
-
/** @function Compare two arrays element-wise. */
|
|
1107
|
-
declare const less: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1108
|
-
/** @function Compare two arrays element-wise. */
|
|
1109
|
-
declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1110
|
-
/** @function Compare two arrays element-wise. */
|
|
1111
|
-
declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1112
|
-
/** @function Compare two arrays element-wise. */
|
|
1113
|
-
declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1114
|
-
/** @function Compare two arrays element-wise. */
|
|
1115
|
-
declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1116
|
-
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1117
|
-
declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
|
|
1118
|
-
/**
|
|
1119
|
-
* @function
|
|
1120
|
-
* Permute the dimensions of an array. Defaults to reversing the axis order.
|
|
1121
|
-
*/
|
|
1122
|
-
declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
|
|
1123
|
-
/**
|
|
1124
|
-
* @function
|
|
1125
|
-
* Give a new shape to an array without changing its data.
|
|
1126
|
-
*
|
|
1127
|
-
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
1128
|
-
* length of the array and remaining dimensions.
|
|
1129
|
-
*/
|
|
1130
|
-
declare const reshape: (x: ArrayLike, shape: number[]) => Array;
|
|
1131
|
-
/**
|
|
1132
|
-
* @function
|
|
1133
|
-
* Move axes of an array to new positions. Other axes retain original order.
|
|
543
|
+
* @function
|
|
544
|
+
* Move axes of an array to new positions. Other axes retain original order.
|
|
1134
545
|
*/
|
|
1135
546
|
declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
1136
547
|
/**
|
|
@@ -1179,6 +590,8 @@ declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
|
1179
590
|
declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1180
591
|
/** Return the maximum of array elements along a given axis. */
|
|
1181
592
|
declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
593
|
+
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
594
|
+
declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1182
595
|
/** Compute the average of the array elements along the specified axis. */
|
|
1183
596
|
declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1184
597
|
/**
|
|
@@ -1244,6 +657,8 @@ declare function fliplr(x: ArrayLike): Array;
|
|
|
1244
657
|
declare const permuteDims: (x: ArrayLike, perm?: number[]) => Array;
|
|
1245
658
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
1246
659
|
declare function ravel(a: ArrayLike): Array;
|
|
660
|
+
/** Remove one or more length-1 axes from an array. */
|
|
661
|
+
declare function squeeze(a: ArrayLike, axis?: Axis): Array;
|
|
1247
662
|
/**
|
|
1248
663
|
* Repeat each element of an array after themselves.
|
|
1249
664
|
*
|
|
@@ -1288,6 +703,8 @@ declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?:
|
|
|
1288
703
|
* array, return a 2D array with v on the k-th diagonal.
|
|
1289
704
|
*/
|
|
1290
705
|
declare function diag(v: ArrayLike, k?: number): Array;
|
|
706
|
+
/** Calculate the sum of the diagonal of an array along the given axes. */
|
|
707
|
+
declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
|
|
1291
708
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
1292
709
|
declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
|
|
1293
710
|
rtol?: number;
|
|
@@ -1296,7 +713,41 @@ declare function allclose(actual: Parameters<typeof array>[0], expected: Paramet
|
|
|
1296
713
|
/** Matrix product of two arrays. */
|
|
1297
714
|
declare function matmul(x: ArrayLike, y: ArrayLike): Array;
|
|
1298
715
|
/** Dot product of two arrays. */
|
|
1299
|
-
declare function dot(x: ArrayLike, y: ArrayLike): Array;
|
|
716
|
+
declare function dot$1(x: ArrayLike, y: ArrayLike): Array;
|
|
717
|
+
/**
|
|
718
|
+
* Compute the tensor dot product of two N-dimensional arrays.
|
|
719
|
+
*
|
|
720
|
+
* The behavior is determined by `axes`. If an integer `k`, sum over the last
|
|
721
|
+
* `k` axes of x and the first `k` axes of y. If a tuple, then the first array
|
|
722
|
+
* corresponds to the axes of x and the second to the axes of y.
|
|
723
|
+
*/
|
|
724
|
+
declare function tensordot(x: ArrayLike, y: ArrayLike, axes?: number | [number[], number[]]): Array;
|
|
725
|
+
/**
|
|
726
|
+
* Einstein summation with string subscripts.
|
|
727
|
+
*
|
|
728
|
+
* @example
|
|
729
|
+
* ```ts
|
|
730
|
+
* import { numpy as np } from "@jax-js/jax";
|
|
731
|
+
*
|
|
732
|
+
* const a = np.ones([2, 3]);
|
|
733
|
+
* const b = np.ones([3]);
|
|
734
|
+
* np.einsum("ij,j", a, b); // Shape [2]
|
|
735
|
+
* ```
|
|
736
|
+
*/
|
|
737
|
+
declare function einsum(subscripts: string, ...operands: ArrayLike[]): Array;
|
|
738
|
+
/**
|
|
739
|
+
* Einstein summation alternating between arrays and numeric indices.
|
|
740
|
+
*
|
|
741
|
+
* @example
|
|
742
|
+
* ```ts
|
|
743
|
+
* import { numpy as np } from "@jax-js/jax";
|
|
744
|
+
*
|
|
745
|
+
* const a = np.ones([2, 3]);
|
|
746
|
+
* const b = np.ones([3]);
|
|
747
|
+
* np.einsum(a, [0, 1], b, [1]); // Shape [2]
|
|
748
|
+
* ```
|
|
749
|
+
*/
|
|
750
|
+
declare function einsum(...args: (ArrayLike | number[])[]): Array;
|
|
1300
751
|
/**
|
|
1301
752
|
* Compute the inner product of two arrays.
|
|
1302
753
|
*
|
|
@@ -1371,6 +822,8 @@ declare function absolute(x: ArrayLike): Array;
|
|
|
1371
822
|
declare const abs: typeof absolute;
|
|
1372
823
|
/** Return an element-wise indication of sign of the input. */
|
|
1373
824
|
declare function sign(x: ArrayLike): Array;
|
|
825
|
+
/** @function Return element-wise positive values of the input (no-op). */
|
|
826
|
+
declare const positive: (x: ArrayLike) => Array;
|
|
1374
827
|
/**
|
|
1375
828
|
* Return the Hamming window of size M, a taper with a weighted cosine bell.
|
|
1376
829
|
*
|
|
@@ -1407,213 +860,880 @@ declare function acos(x: ArrayLike): Array;
|
|
|
1407
860
|
*/
|
|
1408
861
|
declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1409
862
|
/**
|
|
1410
|
-
* @function
|
|
1411
|
-
* Element-wise arc tangent of y/x with correct quadrant.
|
|
1412
|
-
*
|
|
1413
|
-
* Returns the angle in radians between the positive x-axis and the point (x, y).
|
|
1414
|
-
* The result is in the range [-π, π].
|
|
1415
|
-
*
|
|
1416
|
-
* Uses numerically stable formulas:
|
|
1417
|
-
* - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
|
|
1418
|
-
* - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
|
|
863
|
+
* @function
|
|
864
|
+
* Element-wise arc tangent of y/x with correct quadrant.
|
|
865
|
+
*
|
|
866
|
+
* Returns the angle in radians between the positive x-axis and the point (x, y).
|
|
867
|
+
* The result is in the range [-π, π].
|
|
868
|
+
*
|
|
869
|
+
* Uses numerically stable formulas:
|
|
870
|
+
* - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
|
|
871
|
+
* - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
|
|
872
|
+
*
|
|
873
|
+
* The output is ill-defined when both x and y are zero.
|
|
874
|
+
*/
|
|
875
|
+
declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
876
|
+
/** @function Alias of `jax.numpy.acos()`. */
|
|
877
|
+
declare const arccos: typeof acos;
|
|
878
|
+
/** @function Alias of `jax.numpy.atan()`. */
|
|
879
|
+
declare const arctan: (x: ArrayLike) => Array;
|
|
880
|
+
/** @function Alias of `jax.numpy.atan2()`. */
|
|
881
|
+
declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
882
|
+
/** Element-wise subtraction, with broadcasting. */
|
|
883
|
+
declare function subtract(x: ArrayLike, y: ArrayLike): Array;
|
|
884
|
+
/** Calculates the floating-point division of x by y element-wise. */
|
|
885
|
+
declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
|
|
886
|
+
/**
|
|
887
|
+
* @function
|
|
888
|
+
* Calculate element-wise floating-point modulo operation.
|
|
889
|
+
*/
|
|
890
|
+
declare const fmod: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
891
|
+
/**
|
|
892
|
+
* @function
|
|
893
|
+
* Calculate element-wise remainder of the division (matches sign of y).
|
|
894
|
+
*/
|
|
895
|
+
declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
896
|
+
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
897
|
+
declare const divide: typeof trueDivide;
|
|
898
|
+
/** Round input to the nearest integer towards zero. */
|
|
899
|
+
declare function trunc(x: ArrayLike): Array;
|
|
900
|
+
/**
|
|
901
|
+
* Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
|
|
902
|
+
*
|
|
903
|
+
* This is the inverse of `frexp()`.
|
|
904
|
+
*/
|
|
905
|
+
declare function ldexp(x1: ArrayLike, x2: ArrayLike): Array;
|
|
906
|
+
/**
|
|
907
|
+
* Decompose floating-point values into mantissa and two's exponent.
|
|
908
|
+
*
|
|
909
|
+
* The mantissa is returned in the range `(-1, 1)` with magnitude `>= 0.5` if
|
|
910
|
+
* `x != 0`, and the exponent is an integer such that
|
|
911
|
+
* `x = mantissa * 2**exponent`.
|
|
912
|
+
*/
|
|
913
|
+
declare function frexp(x: ArrayLike): [Array, Array];
|
|
914
|
+
/** Calculate `2**p` for all p in the input array. */
|
|
915
|
+
declare function exp2(p: ArrayLike): Array;
|
|
916
|
+
/** Return the base-2 logarithm of x, element-wise. */
|
|
917
|
+
declare function log2(x: ArrayLike): Array;
|
|
918
|
+
/** Return the base-10 logarithm of x, element-wise. */
|
|
919
|
+
declare function log10(x: ArrayLike): Array;
|
|
920
|
+
/** Calculate `exp(x) - 1` element-wise. */
|
|
921
|
+
declare function expm1(x: ArrayLike): Array;
|
|
922
|
+
/** Calculate the natural logarithm of `1 + x` element-wise. */
|
|
923
|
+
declare function log1p(x: ArrayLike): Array;
|
|
924
|
+
/** Convert angles from degrees to radians. */
|
|
925
|
+
declare function deg2rad(x: ArrayLike): Array;
|
|
926
|
+
/** @function Alias of `jax.numpy.deg2rad()`. */
|
|
927
|
+
declare const radians: typeof deg2rad;
|
|
928
|
+
/** Convert angles from radians to degrees. */
|
|
929
|
+
declare function rad2deg(x: ArrayLike): Array;
|
|
930
|
+
/** @function Alias of `jax.numpy.rad2deg()`. */
|
|
931
|
+
declare const degrees: typeof rad2deg;
|
|
932
|
+
/**
|
|
933
|
+
* @function
|
|
934
|
+
* Computes first array raised to power of second array, element-wise.
|
|
935
|
+
*/
|
|
936
|
+
declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
937
|
+
/** @function Alias of `jax.numpy.power()`. */
|
|
938
|
+
declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
939
|
+
/** @function Calculate the element-wise cube root of the input array. */
|
|
940
|
+
declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
|
|
941
|
+
/**
|
|
942
|
+
* @function
|
|
943
|
+
* Calculate element-wise hyperbolic sine of input.
|
|
944
|
+
*
|
|
945
|
+
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
946
|
+
*/
|
|
947
|
+
declare const sinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
948
|
+
/**
|
|
949
|
+
* @function
|
|
950
|
+
* Calculate element-wise hyperbolic cosine of input.
|
|
951
|
+
*
|
|
952
|
+
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
953
|
+
*/
|
|
954
|
+
declare const cosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
955
|
+
/**
|
|
956
|
+
* @function
|
|
957
|
+
* Calculate element-wise hyperbolic tangent of input.
|
|
958
|
+
*
|
|
959
|
+
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
960
|
+
*/
|
|
961
|
+
declare const tanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
962
|
+
/**
|
|
963
|
+
* @function
|
|
964
|
+
* Calculate element-wise inverse hyperbolic sine of input.
|
|
965
|
+
*
|
|
966
|
+
* `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
|
|
967
|
+
*/
|
|
968
|
+
declare const arcsinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
969
|
+
/**
|
|
970
|
+
* @function
|
|
971
|
+
* Calculate element-wise inverse hyperbolic cosine of input.
|
|
972
|
+
*
|
|
973
|
+
* `arccosh(x) = ln(x + sqrt(x^2 - 1))`
|
|
974
|
+
*/
|
|
975
|
+
declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
976
|
+
/**
|
|
977
|
+
* @function
|
|
978
|
+
* Calculate element-wise inverse hyperbolic tangent of input.
|
|
979
|
+
*
|
|
980
|
+
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
981
|
+
*/
|
|
982
|
+
declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
983
|
+
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
984
|
+
declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
985
|
+
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
986
|
+
declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
987
|
+
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
988
|
+
declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
989
|
+
/**
|
|
990
|
+
* Compute the variance of an array.
|
|
991
|
+
*
|
|
992
|
+
* The variance is computed for the flattened array by default, otherwise over
|
|
993
|
+
* the specified axis.
|
|
994
|
+
*
|
|
995
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
996
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
997
|
+
*/
|
|
998
|
+
declare function var_(x: ArrayLike, axis?: Axis, opts?: {
|
|
999
|
+
mean?: ArrayLike;
|
|
1000
|
+
correction?: number;
|
|
1001
|
+
} & ReduceOpts): Array;
|
|
1002
|
+
/**
|
|
1003
|
+
* Compute the standard deviation of an array.
|
|
1004
|
+
*
|
|
1005
|
+
* The standard deviation is computed for the flattened array by default,
|
|
1006
|
+
* otherwise over the specified axis.
|
|
1007
|
+
*
|
|
1008
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
1009
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
1010
|
+
*/
|
|
1011
|
+
declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
1012
|
+
mean?: ArrayLike;
|
|
1013
|
+
correction?: number;
|
|
1014
|
+
} & ReduceOpts): Array;
|
|
1015
|
+
/** Test element-wise for positive or negative infinity, return bool array. */
|
|
1016
|
+
declare function isinf(x: ArrayLike): Array;
|
|
1017
|
+
/** Test element-wise for NaN (Not a Number). */
|
|
1018
|
+
declare function isnan(x: ArrayLike): Array;
|
|
1019
|
+
/** Test element-wise for negative infinity, return bool array. */
|
|
1020
|
+
declare function isneginf(x: ArrayLike): Array;
|
|
1021
|
+
/** Test element-wise for positive infinity, return bool array. */
|
|
1022
|
+
declare function isposinf(x: ArrayLike): Array;
|
|
1023
|
+
/**
|
|
1024
|
+
* @function
|
|
1025
|
+
* Test element-wise for finite values (not infinity or NaN).
|
|
1026
|
+
*/
|
|
1027
|
+
declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1028
|
+
//# sourceMappingURL=numpy.d.ts.map
|
|
1029
|
+
declare namespace tree_d_exports {
|
|
1030
|
+
export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
|
|
1031
|
+
}
|
|
1032
|
+
declare enum NodeType {
|
|
1033
|
+
Array = "Array",
|
|
1034
|
+
Object = "Object",
|
|
1035
|
+
Leaf = "Leaf",
|
|
1036
|
+
}
|
|
1037
|
+
/** Analog to the JAX "pytree" object, but for JavaScript. */
|
|
1038
|
+
type JsTree<T> = T | JsTree<T>[] | {
|
|
1039
|
+
[key: string]: JsTree<T>;
|
|
1040
|
+
};
|
|
1041
|
+
type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
|
|
1042
|
+
type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
1043
|
+
/** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
|
|
1044
|
+
type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
|
|
1045
|
+
/** Represents the structure of a JsTree. */
|
|
1046
|
+
declare class JsTreeDef {
|
|
1047
|
+
readonly nodeType: NodeType;
|
|
1048
|
+
readonly nodeMetadata: any;
|
|
1049
|
+
readonly childTreedefs: JsTreeDef[];
|
|
1050
|
+
static leaf: JsTreeDef;
|
|
1051
|
+
constructor(nodeType: NodeType, nodeMetadata: any,
|
|
1052
|
+
// Must be comparable with deepEqual.
|
|
1053
|
+
childTreedefs: JsTreeDef[]);
|
|
1054
|
+
/** Get the total number of leaves in the tree. */
|
|
1055
|
+
get size(): number;
|
|
1056
|
+
/** Returns a string representation of this tree definition. */
|
|
1057
|
+
toString(root?: boolean): string;
|
|
1058
|
+
/** Compare this tree definition with another. */
|
|
1059
|
+
equals(other: JsTreeDef): boolean;
|
|
1060
|
+
}
|
|
1061
|
+
/** Flatten a structured object, returning the tree definition. */
|
|
1062
|
+
declare function flatten<T>(tree: JsTree<T>): [T[], JsTreeDef];
|
|
1063
|
+
/** Get the leaves of a tree. */
|
|
1064
|
+
declare function leaves<T>(tree: JsTree<T>): T[];
|
|
1065
|
+
/** Get the treedef for a tree. */
|
|
1066
|
+
declare function structure<T>(tree: JsTree<T>): JsTreeDef;
|
|
1067
|
+
/** Reconstruct a structured object from the flattened representation. */
|
|
1068
|
+
declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T>;
|
|
1069
|
+
/** Maps a multi-input function over pytree args to produce a new pytree. */
|
|
1070
|
+
declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
|
|
1071
|
+
/** Take a reference of every array in a tree. */
|
|
1072
|
+
declare function ref<Tree extends JsTree<Array>>(tree: Tree): Tree;
|
|
1073
|
+
/** Dispose every array in a tree. */
|
|
1074
|
+
declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefined): void;
|
|
1075
|
+
//#endregion
|
|
1076
|
+
//#region src/frontend/convolution.d.ts
|
|
1077
|
+
/** Definition of a general dilated convolution. Should be valid on creation. */
|
|
1078
|
+
interface ConvParams {
|
|
1079
|
+
strides: number[];
|
|
1080
|
+
padding: [number, number][];
|
|
1081
|
+
lhsDilation: number[];
|
|
1082
|
+
rhsDilation: number[];
|
|
1083
|
+
}
|
|
1084
|
+
/**
|
|
1085
|
+
* Check that the shapes and parameters passed to convolution are valid.
|
|
1419
1086
|
*
|
|
1420
|
-
*
|
|
1087
|
+
* If the check succeeds, returns the output shape.
|
|
1421
1088
|
*/
|
|
1422
|
-
|
|
1423
|
-
|
|
1424
|
-
declare const arccos: typeof acos;
|
|
1425
|
-
/** @function Alias of `jax.numpy.atan()`. */
|
|
1426
|
-
declare const arctan: (x: ArrayLike) => Array;
|
|
1427
|
-
/** @function Alias of `jax.numpy.atan2()`. */
|
|
1428
|
-
declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
1429
|
-
/** Element-wise subtraction, with broadcasting. */
|
|
1430
|
-
declare function subtract(x: ArrayLike, y: ArrayLike): Array;
|
|
1431
|
-
/** Calculates the floating-point division of x by y element-wise. */
|
|
1432
|
-
declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
|
|
1433
|
-
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
1434
|
-
declare const divide: typeof trueDivide;
|
|
1435
|
-
/** Round input to the nearest integer towards zero. */
|
|
1436
|
-
declare function trunc(x: ArrayLike): Array;
|
|
1437
|
-
/** Calculate `2**p` for all p in the input array. */
|
|
1438
|
-
declare function exp2(p: ArrayLike): Array;
|
|
1439
|
-
/** Return the base-2 logarithm of x, element-wise. */
|
|
1440
|
-
declare function log2(x: ArrayLike): Array;
|
|
1441
|
-
/** Return the base-10 logarithm of x, element-wise. */
|
|
1442
|
-
declare function log10(x: ArrayLike): Array;
|
|
1443
|
-
/** Calculate `exp(x) - 1` element-wise. */
|
|
1444
|
-
declare function expm1(x: ArrayLike): Array;
|
|
1445
|
-
/** Calculate the natural logarithm of `1 + x` element-wise. */
|
|
1446
|
-
declare function log1p(x: ArrayLike): Array;
|
|
1447
|
-
/** Convert angles from degrees to radians. */
|
|
1448
|
-
declare function deg2rad(x: ArrayLike): Array;
|
|
1449
|
-
/** @function Alias of `jax.numpy.deg2rad()`. */
|
|
1450
|
-
declare const radians: typeof deg2rad;
|
|
1451
|
-
/** Convert angles from radians to degrees. */
|
|
1452
|
-
declare function rad2deg(x: ArrayLike): Array;
|
|
1453
|
-
/** @function Alias of `jax.numpy.rad2deg()`. */
|
|
1454
|
-
declare const degrees: typeof rad2deg;
|
|
1089
|
+
//#endregion
|
|
1090
|
+
//#region src/frontend/jaxpr.d.ts
|
|
1455
1091
|
/**
|
|
1456
|
-
*
|
|
1457
|
-
*
|
|
1092
|
+
* Function callback with an associated dispose() method.
|
|
1093
|
+
*
|
|
1094
|
+
* The dispose() method should be called to clean up any tracer resources needed
|
|
1095
|
+
* by the function after the last time it is called.
|
|
1458
1096
|
*/
|
|
1459
|
-
|
|
1460
|
-
|
|
1461
|
-
|
|
1462
|
-
/**
|
|
1463
|
-
declare
|
|
1097
|
+
type OwnedFunction<F extends Function> = F & {
|
|
1098
|
+
dispose: () => void;
|
|
1099
|
+
};
|
|
1100
|
+
/** Variable in a Jaxpr expression. */
|
|
1101
|
+
declare class Var {
|
|
1102
|
+
#private;
|
|
1103
|
+
readonly id: number;
|
|
1104
|
+
readonly aval: ShapedArray;
|
|
1105
|
+
constructor(aval: ShapedArray);
|
|
1106
|
+
toString(): string;
|
|
1107
|
+
}
|
|
1108
|
+
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
1109
|
+
declare class Lit {
|
|
1110
|
+
readonly value: number;
|
|
1111
|
+
readonly aval: ShapedArray;
|
|
1112
|
+
get dtype(): DType;
|
|
1113
|
+
constructor(aval: AbstractValue, value: number);
|
|
1114
|
+
}
|
|
1115
|
+
type Atom = Var | Lit;
|
|
1116
|
+
declare class VarPrinter {
|
|
1117
|
+
#private;
|
|
1118
|
+
names: Map<Var, string>;
|
|
1119
|
+
name(v: Var): string;
|
|
1120
|
+
nameType(v: Var): string;
|
|
1121
|
+
}
|
|
1122
|
+
/** A single statement / binding in a Jaxpr, in ANF form. */
|
|
1123
|
+
declare class JaxprEqn {
|
|
1124
|
+
readonly primitive: Primitive;
|
|
1125
|
+
readonly inputs: Atom[];
|
|
1126
|
+
readonly params: Record<string, any>;
|
|
1127
|
+
readonly outBinders: Var[];
|
|
1128
|
+
constructor(primitive: Primitive, inputs: Atom[], params: Record<string, any>, outBinders: Var[]);
|
|
1129
|
+
pprint(usedVars?: Set<Var>, vp?: VarPrinter): PPrint;
|
|
1130
|
+
toString(): string;
|
|
1131
|
+
}
|
|
1132
|
+
/** Typed intermediate representation for traced computations. */
|
|
1133
|
+
declare class Jaxpr implements FpHashable {
|
|
1134
|
+
#private;
|
|
1135
|
+
readonly inBinders: Var[];
|
|
1136
|
+
readonly eqns: JaxprEqn[];
|
|
1137
|
+
readonly outs: Atom[];
|
|
1138
|
+
constructor(inBinders: Var[], eqns: JaxprEqn[], outs: Atom[]);
|
|
1139
|
+
pprint(): PPrint;
|
|
1140
|
+
toString(): string;
|
|
1141
|
+
/**
|
|
1142
|
+
* Gets a hash of this Jaxpr.
|
|
1143
|
+
*
|
|
1144
|
+
* Var identity is not considered in the hash, so two Jaxprs with the same
|
|
1145
|
+
* order of assignments and operators but different variable IDs will resolve
|
|
1146
|
+
* to the same hash (and toString representation).
|
|
1147
|
+
*/
|
|
1148
|
+
getHash(): bigint;
|
|
1149
|
+
hash(state: FpHash): void;
|
|
1150
|
+
/**
|
|
1151
|
+
* Produce a simplified Jaxpr with basic optimizations applied.
|
|
1152
|
+
* - Trim away unused variables.
|
|
1153
|
+
* - Fold away *1, *0, or +0 operations against literals.
|
|
1154
|
+
* - Remove no-op movement operations.
|
|
1155
|
+
*/
|
|
1156
|
+
simplify(): Jaxpr;
|
|
1157
|
+
/** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1158
|
+
flatten(): Jaxpr;
|
|
1159
|
+
}
|
|
1160
|
+
/** @inline */
|
|
1161
|
+
type JitOpts = {
|
|
1162
|
+
staticArgnums?: number[];
|
|
1163
|
+
};
|
|
1164
|
+
//#endregion
|
|
1165
|
+
//#region src/frontend/core.d.ts
|
|
1464
1166
|
/**
|
|
1465
|
-
*
|
|
1466
|
-
*
|
|
1167
|
+
* Frontend primitive operations, which are lowered into Kernel objects before
|
|
1168
|
+
* being dispatched to the backend.
|
|
1467
1169
|
*
|
|
1468
|
-
*
|
|
1170
|
+
* Any operation between arrays can be described in these parts. This is also
|
|
1171
|
+
* the set of primitives that can occur in Jaxpr programs, and the level at
|
|
1172
|
+
* which transformations like vmap, grad, and jvp occur. They are loosely based
|
|
1173
|
+
* on [XLA](https://openxla.org/xla/operation_semantics).
|
|
1174
|
+
*
|
|
1175
|
+
* All n-ary operations support broadcasting, with NumPy semantics.
|
|
1176
|
+
*/
|
|
1177
|
+
declare enum Primitive {
|
|
1178
|
+
Add = "add",
|
|
1179
|
+
Mul = "mul",
|
|
1180
|
+
Idiv = "idiv",
|
|
1181
|
+
Mod = "mod",
|
|
1182
|
+
// uses sign of dividend, C-style, matches JS but not Python
|
|
1183
|
+
Neg = "neg",
|
|
1184
|
+
Reciprocal = "reciprocal",
|
|
1185
|
+
Floor = "floor",
|
|
1186
|
+
Ceil = "ceil",
|
|
1187
|
+
StopGradient = "stop_gradient",
|
|
1188
|
+
Cast = "cast",
|
|
1189
|
+
Bitcast = "bitcast",
|
|
1190
|
+
RandomBits = "random_bits",
|
|
1191
|
+
Sin = "sin",
|
|
1192
|
+
Cos = "cos",
|
|
1193
|
+
Asin = "asin",
|
|
1194
|
+
Atan = "atan",
|
|
1195
|
+
Exp = "exp",
|
|
1196
|
+
Log = "log",
|
|
1197
|
+
Erf = "erf",
|
|
1198
|
+
Erfc = "erfc",
|
|
1199
|
+
Sqrt = "sqrt",
|
|
1200
|
+
Min = "min",
|
|
1201
|
+
Max = "max",
|
|
1202
|
+
Reduce = "reduce",
|
|
1203
|
+
Dot = "dot",
|
|
1204
|
+
// sum(x*y, axis=-1)
|
|
1205
|
+
Conv = "conv",
|
|
1206
|
+
// see lax.conv_general_dilated
|
|
1207
|
+
Pool = "pool",
|
|
1208
|
+
PoolTranspose = "pool_transpose",
|
|
1209
|
+
Compare = "compare",
|
|
1210
|
+
Where = "where",
|
|
1211
|
+
Transpose = "transpose",
|
|
1212
|
+
Broadcast = "broadcast",
|
|
1213
|
+
Reshape = "reshape",
|
|
1214
|
+
Flip = "flip",
|
|
1215
|
+
Shrink = "shrink",
|
|
1216
|
+
Pad = "pad",
|
|
1217
|
+
Gather = "gather",
|
|
1218
|
+
JitCall = "jit_call",
|
|
1219
|
+
}
|
|
1220
|
+
interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
1221
|
+
[Primitive.Cast]: {
|
|
1222
|
+
dtype: DType;
|
|
1223
|
+
};
|
|
1224
|
+
[Primitive.Bitcast]: {
|
|
1225
|
+
dtype: DType;
|
|
1226
|
+
};
|
|
1227
|
+
[Primitive.Reduce]: {
|
|
1228
|
+
op: AluOp;
|
|
1229
|
+
axis: number[];
|
|
1230
|
+
};
|
|
1231
|
+
[Primitive.Conv]: ConvParams;
|
|
1232
|
+
[Primitive.Pool]: {
|
|
1233
|
+
window: number[];
|
|
1234
|
+
strides: number[];
|
|
1235
|
+
};
|
|
1236
|
+
[Primitive.PoolTranspose]: {
|
|
1237
|
+
inShape: number[];
|
|
1238
|
+
window: number[];
|
|
1239
|
+
strides: number[];
|
|
1240
|
+
};
|
|
1241
|
+
[Primitive.Compare]: {
|
|
1242
|
+
op: CompareOp;
|
|
1243
|
+
};
|
|
1244
|
+
[Primitive.Transpose]: {
|
|
1245
|
+
perm: number[];
|
|
1246
|
+
};
|
|
1247
|
+
[Primitive.Broadcast]: {
|
|
1248
|
+
shape: number[];
|
|
1249
|
+
axis: number[];
|
|
1250
|
+
};
|
|
1251
|
+
[Primitive.RandomBits]: {
|
|
1252
|
+
shape: number[];
|
|
1253
|
+
mode: "xor" | 0 | 1;
|
|
1254
|
+
};
|
|
1255
|
+
[Primitive.Reshape]: {
|
|
1256
|
+
shape: number[];
|
|
1257
|
+
};
|
|
1258
|
+
[Primitive.Flip]: {
|
|
1259
|
+
axis: number[];
|
|
1260
|
+
};
|
|
1261
|
+
[Primitive.Shrink]: {
|
|
1262
|
+
slice: Pair[];
|
|
1263
|
+
};
|
|
1264
|
+
[Primitive.Pad]: {
|
|
1265
|
+
width: Pair[];
|
|
1266
|
+
};
|
|
1267
|
+
[Primitive.Gather]: {
|
|
1268
|
+
axis: number[];
|
|
1269
|
+
outDim: number;
|
|
1270
|
+
};
|
|
1271
|
+
[Primitive.JitCall]: {
|
|
1272
|
+
name: string;
|
|
1273
|
+
jaxpr: Jaxpr;
|
|
1274
|
+
numConsts: number;
|
|
1275
|
+
};
|
|
1276
|
+
}
|
|
1277
|
+
/** Type of parameters taken by each primitive. */
|
|
1278
|
+
type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
|
|
1279
|
+
declare enum CompareOp {
|
|
1280
|
+
Less = "less",
|
|
1281
|
+
Equal = "equal",
|
|
1282
|
+
NotEqual = "not_equal",
|
|
1283
|
+
LessEqual = "less_equal",
|
|
1284
|
+
}
|
|
1285
|
+
/** @inline */
|
|
1286
|
+
type Axis = number | number[] | null;
|
|
1287
|
+
/** @inline */
|
|
1288
|
+
type ReduceOpts = {
|
|
1289
|
+
keepdims?: boolean;
|
|
1290
|
+
};
|
|
1291
|
+
type MainTrace = {
|
|
1292
|
+
level: number;
|
|
1293
|
+
traceType: new (main: MainTrace) => Trace;
|
|
1294
|
+
globalData: any | null;
|
|
1295
|
+
};
|
|
1296
|
+
/**
|
|
1297
|
+
* Push an interpreter onto the trace stack. Use this like:
|
|
1298
|
+
* `using main = newMain(...);`
|
|
1469
1299
|
*/
|
|
1470
|
-
|
|
1300
|
+
|
|
1301
|
+
type TracerValue = Tracer | number | boolean;
|
|
1302
|
+
declare abstract class Trace {
|
|
1303
|
+
readonly main: MainTrace;
|
|
1304
|
+
constructor(main: MainTrace);
|
|
1305
|
+
abstract pure(val: TracerValue): Tracer;
|
|
1306
|
+
abstract lift(val: Tracer): Tracer;
|
|
1307
|
+
abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
|
|
1308
|
+
}
|
|
1309
|
+
/** Internal representation of an array value. */
|
|
1310
|
+
interface AbstractValue {
|
|
1311
|
+
/** Shape of the array. Must be a static tuple of non-negative dimensions. */
|
|
1312
|
+
shape: number[];
|
|
1313
|
+
/** Concrete data type of array elements. */
|
|
1314
|
+
dtype: DType;
|
|
1315
|
+
/**
|
|
1316
|
+
* Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
|
|
1317
|
+
* _weakly typed_ unless a dtype is explicitly specified.
|
|
1318
|
+
*
|
|
1319
|
+
* Weakly typed values will automatically cast to the data type of other
|
|
1320
|
+
* arrays when used as an operand as an expression. This property only affects
|
|
1321
|
+
* how they promote in type casting; their memory layout is still determined
|
|
1322
|
+
* by the actual `dtype` field.
|
|
1323
|
+
*
|
|
1324
|
+
* ```ts
|
|
1325
|
+
* const x = np.array(3); // weakType = true, dtype = float32
|
|
1326
|
+
* const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
|
|
1327
|
+
* const z = x.add(y); // z has dtype int32 because x is weakly typed
|
|
1328
|
+
* ```
|
|
1329
|
+
*
|
|
1330
|
+
* Weak types are present in JIT programs in their spec (e.g., Jaxpr inputs
|
|
1331
|
+
* and outputs can be weakly typed) form. But they're solely a frontend
|
|
1332
|
+
* concept. Backends are not aware of weak types.
|
|
1333
|
+
*/
|
|
1334
|
+
weakType: boolean;
|
|
1335
|
+
}
|
|
1471
1336
|
/**
|
|
1472
|
-
*
|
|
1473
|
-
* Calculate element-wise hyperbolic cosine of input.
|
|
1337
|
+
* Broadcast shapes and promote types with casting for two avals.
|
|
1474
1338
|
*
|
|
1475
|
-
*
|
|
1339
|
+
* This implements the weak type behavior described in `promoteTypes()`, but not
|
|
1340
|
+
* implemented in that function as `weakType` is not passed.
|
|
1476
1341
|
*/
|
|
1477
|
-
|
|
1342
|
+
|
|
1343
|
+
declare abstract class Tracer {
|
|
1344
|
+
/** @ignore */
|
|
1345
|
+
readonly _trace: Trace;
|
|
1346
|
+
constructor(trace: Trace);
|
|
1347
|
+
abstract get aval(): AbstractValue;
|
|
1348
|
+
abstract toString(): string;
|
|
1349
|
+
/**
|
|
1350
|
+
* Access an array by reference, incrementing the reference count.
|
|
1351
|
+
*
|
|
1352
|
+
* jax-js handles freeing arrays by using "move" semantics, like in Rust/C++.
|
|
1353
|
+
* Whenever you pass an array into a function, that function should consume
|
|
1354
|
+
* the array, and it will no longer be usable. For example, if you had:
|
|
1355
|
+
*
|
|
1356
|
+
* ```
|
|
1357
|
+
* const x = np.array([1, 2, 3]);
|
|
1358
|
+
* const y = np.add(x, x);
|
|
1359
|
+
* ```
|
|
1360
|
+
*
|
|
1361
|
+
* The second line does not work because the first parameter consumes `x`, and
|
|
1362
|
+
* then the second parameter will already have been freed / disposed.
|
|
1363
|
+
*
|
|
1364
|
+
* To fix this, you can write:
|
|
1365
|
+
*
|
|
1366
|
+
* ```
|
|
1367
|
+
* const y = np.add(x.ref, x);
|
|
1368
|
+
* ```
|
|
1369
|
+
*
|
|
1370
|
+
* Under the hood, every access to `.ref` increments the internal reference
|
|
1371
|
+
* count of the array. The reference count starts at 1. When it hits 0, the
|
|
1372
|
+
* memory behind the array is freed.
|
|
1373
|
+
*/
|
|
1374
|
+
abstract get ref(): this;
|
|
1375
|
+
/**
|
|
1376
|
+
* Manually decrement the reference count of the array.
|
|
1377
|
+
*
|
|
1378
|
+
* Arrays are created with reference count 1. Whenever it is used as argument
|
|
1379
|
+
* to a function or other operation, it is disposed (i.e., reference count
|
|
1380
|
+
* decreases by 1) automatically. Whenever a `.ref` is created, the reference
|
|
1381
|
+
* count increases.
|
|
1382
|
+
*
|
|
1383
|
+
* You generally don't need to call this function directly since arrays are
|
|
1384
|
+
* automatically disposed after being passed into an operation. One common
|
|
1385
|
+
* exception is when writing a function and ignoring one of its arguments. In
|
|
1386
|
+
* that case, by convention you should dispose of that argument manually.
|
|
1387
|
+
*
|
|
1388
|
+
* ```
|
|
1389
|
+
* function myCustomOperation(a: np.Array, b: np.Array) {
|
|
1390
|
+
* b.dispose(); // Needed to satisfy "move" rules.
|
|
1391
|
+
* return a.add(1);
|
|
1392
|
+
* }
|
|
1393
|
+
* ```
|
|
1394
|
+
*/
|
|
1395
|
+
abstract dispose(): void;
|
|
1396
|
+
/** The shape of the array. */
|
|
1397
|
+
get shape(): number[];
|
|
1398
|
+
/** The total number of elements in the array. */
|
|
1399
|
+
get size(): number;
|
|
1400
|
+
/** The dtype of elements stored in the array. */
|
|
1401
|
+
get dtype(): DType;
|
|
1402
|
+
/**
|
|
1403
|
+
* Whether the array is weakly typed.
|
|
1404
|
+
*
|
|
1405
|
+
* Weakly typed arrays will cast to the dtype of the other operand. See
|
|
1406
|
+
* `promoteTypes()` for details.
|
|
1407
|
+
*/
|
|
1408
|
+
get weakType(): boolean;
|
|
1409
|
+
/** The number of dimensions of the array. */
|
|
1410
|
+
get ndim(): number;
|
|
1411
|
+
/** @ignore */
|
|
1412
|
+
fullLower(): Tracer;
|
|
1413
|
+
neg(): this;
|
|
1414
|
+
add(other: this | TracerValue): this;
|
|
1415
|
+
mul(other: this | TracerValue): this;
|
|
1416
|
+
greater(other: this | TracerValue): this;
|
|
1417
|
+
less(other: this | TracerValue): this;
|
|
1418
|
+
equal(other: this | TracerValue): this;
|
|
1419
|
+
notEqual(other: this | TracerValue): this;
|
|
1420
|
+
greaterEqual(other: this | TracerValue): this;
|
|
1421
|
+
lessEqual(other: this | TracerValue): this;
|
|
1422
|
+
/** Sum of the elements of the array over a given axis, or axes. */
|
|
1423
|
+
sum(axis?: Axis, opts?: ReduceOpts): this;
|
|
1424
|
+
/** Product of the array elements over a given axis. */
|
|
1425
|
+
prod(axis?: Axis, opts?: ReduceOpts): this;
|
|
1426
|
+
/** Compute the average of the array elements along the specified axis. */
|
|
1427
|
+
mean(axis?: Axis, opts?: ReduceOpts): this;
|
|
1428
|
+
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
1429
|
+
transpose(perm?: number[]): this;
|
|
1430
|
+
/**
|
|
1431
|
+
* Give a new shape to an array without changing its data.
|
|
1432
|
+
*
|
|
1433
|
+
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
1434
|
+
* length of the array and remaining dimensions.
|
|
1435
|
+
*/
|
|
1436
|
+
reshape(shape: number | number[]): this;
|
|
1437
|
+
/** Copy the array and cast to a specified dtype. */
|
|
1438
|
+
astype(dtype: DType): this;
|
|
1439
|
+
/** Subtract an array from this one. */
|
|
1440
|
+
sub(other: this | TracerValue): this;
|
|
1441
|
+
/** Divide an array by this one. */
|
|
1442
|
+
div(other: this | TracerValue): this;
|
|
1443
|
+
/** Return specified diagonals. See `numpy.diagonal` for full docs. */
|
|
1444
|
+
diagonal(offset?: number, axis1?: number, axis2?: number): this;
|
|
1445
|
+
/** Flatten the array without changing its data. */
|
|
1446
|
+
flatten(): this;
|
|
1447
|
+
/** Flatten the array without changing its data. */
|
|
1448
|
+
ravel(): this;
|
|
1449
|
+
/**
|
|
1450
|
+
* Iterate over the first dimension of this array, returning slices.
|
|
1451
|
+
*
|
|
1452
|
+
* This can be used to destructure arrays. For example:
|
|
1453
|
+
*
|
|
1454
|
+
* ```js
|
|
1455
|
+
* let x = np.array([[1, 2], [3, 4]]);
|
|
1456
|
+
* let [a, b] = x;
|
|
1457
|
+
* console.log(a.js()); // [1, 2]
|
|
1458
|
+
* console.log(b.js()); // [3, 4]
|
|
1459
|
+
* ```
|
|
1460
|
+
*/
|
|
1461
|
+
[Symbol.iterator](): IterableIterator<this>;
|
|
1462
|
+
/**
|
|
1463
|
+
* Slice an array along one or more axes.
|
|
1464
|
+
*
|
|
1465
|
+
* This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
|
|
1466
|
+
* mimic this in JavaScript, we would write:
|
|
1467
|
+
*
|
|
1468
|
+
* ```js
|
|
1469
|
+
* x.slice([1, 3], 2, [], null);
|
|
1470
|
+
* ```
|
|
1471
|
+
*
|
|
1472
|
+
* The `slice` method accepts a variable number of arguments, each of which
|
|
1473
|
+
* can be a number, an empty array, a single-element array, a two-element
|
|
1474
|
+
* array, or `null`. The arguments are interpreted as follows:
|
|
1475
|
+
*
|
|
1476
|
+
* - A number `n` means to access the `n`-th element along that axis, removing
|
|
1477
|
+
* that axis from the resulting shape.
|
|
1478
|
+
* - An empty array `[]` means to keep that axis as-is, like `:` in Python.
|
|
1479
|
+
* - A single-element array `[i]` means to start slicing from index `i`
|
|
1480
|
+
* (inclusive) to the end of the axis, like `x[i:]`.
|
|
1481
|
+
* - A two-element array `[i, j]` means to slice from index `i` (inclusive)
|
|
1482
|
+
* to index `j` (exclusive), like `x[i:j]`.
|
|
1483
|
+
* - `null` means to add a new axis at that position, like `np.newaxis`.
|
|
1484
|
+
*
|
|
1485
|
+
* Like in Python, negative indices are supported, which count from the end of
|
|
1486
|
+
* the axis. For example, `-1` means the last element.
|
|
1487
|
+
*
|
|
1488
|
+
* Strided slices are not yet implemented, so you cannot write `x[::2]` or
|
|
1489
|
+
* similar.
|
|
1490
|
+
*
|
|
1491
|
+
* Advanced indexing by integer arrays is also supported. This translates to
|
|
1492
|
+
* the "gather" primitive, and it allows you to access specific elements of
|
|
1493
|
+
* the array by integer indices stored in another array.
|
|
1494
|
+
*/
|
|
1495
|
+
slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
|
|
1496
|
+
}
|
|
1497
|
+
declare class ShapedArray implements AbstractValue {
|
|
1498
|
+
readonly shape: number[];
|
|
1499
|
+
readonly dtype: DType;
|
|
1500
|
+
readonly weakType: boolean;
|
|
1501
|
+
constructor(shape: number[], dtype: DType, weakType: boolean);
|
|
1502
|
+
static fromAval(aval: AbstractValue): ShapedArray;
|
|
1503
|
+
get ndim(): number;
|
|
1504
|
+
toString(): string;
|
|
1505
|
+
equals(other: ShapedArray): boolean;
|
|
1506
|
+
}
|
|
1507
|
+
//#endregion
|
|
1508
|
+
//#region src/frontend/array.d.ts
|
|
1509
|
+
type ArrayLike = Array | number | boolean;
|
|
1510
|
+
/** Version of pureArray with fudged types. */
|
|
1511
|
+
|
|
1478
1512
|
/**
|
|
1479
|
-
*
|
|
1480
|
-
* Calculate element-wise hyperbolic tangent of input.
|
|
1513
|
+
* An executable operation that will be dispatched to the backend.
|
|
1481
1514
|
*
|
|
1482
|
-
*
|
|
1515
|
+
* This holds a reference to all input buffers used in the operation. After the
|
|
1516
|
+
* operation is dispatched, the references should be released.
|
|
1483
1517
|
*/
|
|
1484
|
-
declare
|
|
1518
|
+
declare class PendingExecute {
|
|
1519
|
+
#private;
|
|
1520
|
+
readonly backend: Backend;
|
|
1521
|
+
readonly kernel: Kernel;
|
|
1522
|
+
readonly inputs: Slot[];
|
|
1523
|
+
readonly outputs: Slot[];
|
|
1524
|
+
prepared: Executable | null;
|
|
1525
|
+
submitted: boolean;
|
|
1526
|
+
constructor(backend: Backend, kernel: Kernel, inputs: Slot[], outputs: Slot[]);
|
|
1527
|
+
updateRc(delta: number): void;
|
|
1528
|
+
prepare(): Promise<void>;
|
|
1529
|
+
prepareSync(): void;
|
|
1530
|
+
submit(): void;
|
|
1531
|
+
}
|
|
1532
|
+
/** @inline */
|
|
1533
|
+
type DTypeAndDevice = {
|
|
1534
|
+
dtype?: DType;
|
|
1535
|
+
device?: Device;
|
|
1536
|
+
};
|
|
1537
|
+
type ArrayConstructorArgs = {
|
|
1538
|
+
source: AluExp | Slot;
|
|
1539
|
+
st: ShapeTracker;
|
|
1540
|
+
dtype: DType;
|
|
1541
|
+
weakType: boolean;
|
|
1542
|
+
backend: Backend;
|
|
1543
|
+
committed: boolean;
|
|
1544
|
+
pending?: Iterable<PendingExecute>;
|
|
1545
|
+
};
|
|
1485
1546
|
/**
|
|
1486
|
-
*
|
|
1487
|
-
* Calculate element-wise inverse hyperbolic sine of input.
|
|
1547
|
+
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
1488
1548
|
*
|
|
1489
|
-
*
|
|
1549
|
+
* This is the library's core data type. Equivalent to `jax.Array` from JAX, or
|
|
1550
|
+
* `torch.Tensor`.
|
|
1551
|
+
*
|
|
1552
|
+
* Not to be confused with the JavaScript "Array" constructor. Avoid importing
|
|
1553
|
+
* this into your code's namespace if you're already using the JavaScript
|
|
1554
|
+
* "Array" type by name.
|
|
1490
1555
|
*/
|
|
1491
|
-
declare
|
|
1556
|
+
declare class Array extends Tracer {
|
|
1557
|
+
#private;
|
|
1558
|
+
id: number;
|
|
1559
|
+
/**
|
|
1560
|
+
* @ignore
|
|
1561
|
+
* Constructs an array from source, shape and backend. Note that if the source
|
|
1562
|
+
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
1563
|
+
* will be freed when the array is disposed.
|
|
1564
|
+
*/
|
|
1565
|
+
constructor(args: ArrayConstructorArgs);
|
|
1566
|
+
/** @ignore */
|
|
1567
|
+
get aval(): ShapedArray;
|
|
1568
|
+
/** Return a simple string representation of the array's dimensions. */
|
|
1569
|
+
toString(): string;
|
|
1570
|
+
get device(): Device;
|
|
1571
|
+
get ref(): this;
|
|
1572
|
+
dispose(): void;
|
|
1573
|
+
/**
|
|
1574
|
+
* Convert this array into a primitive value.
|
|
1575
|
+
*
|
|
1576
|
+
* This only works for scalars (0-dimensional arrays). It lets you get values
|
|
1577
|
+
* "out" of the JAX system. For instance, if `x = np.array(5)`, then you can
|
|
1578
|
+
* evaluate `x + 1` and `x ** 2` to get `6` and `25`, respectively.
|
|
1579
|
+
*
|
|
1580
|
+
* This method is also called for `==` equality.
|
|
1581
|
+
*/
|
|
1582
|
+
[Symbol.toPrimitive](): any;
|
|
1583
|
+
/** Realize the array and return it as data. */
|
|
1584
|
+
data(): Promise<DataArray>;
|
|
1585
|
+
/**
|
|
1586
|
+
* Wait for this array to finish evaluation.
|
|
1587
|
+
*
|
|
1588
|
+
* Operations and data loading in jax-js are lazy, so this function ensures
|
|
1589
|
+
* that pending operations are dispatched and fully executed before it
|
|
1590
|
+
* returns.
|
|
1591
|
+
*
|
|
1592
|
+
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
1593
|
+
* dispatch of operations as well.
|
|
1594
|
+
*
|
|
1595
|
+
* **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
|
|
1596
|
+
* asynchronously for multiple arrays.
|
|
1597
|
+
*/
|
|
1598
|
+
blockUntilReady(): Promise<Array>;
|
|
1599
|
+
/**
|
|
1600
|
+
* Realize the array and return it as data. This is a sync variant and not
|
|
1601
|
+
* recommended for performance reasons, as it will block rendering.
|
|
1602
|
+
*/
|
|
1603
|
+
dataSync(): DataArray;
|
|
1604
|
+
/**
|
|
1605
|
+
* Convert this array into a JavaScript object.
|
|
1606
|
+
*
|
|
1607
|
+
* This is a blocking operation that will compile all of the shaders and wait
|
|
1608
|
+
* for execution to complete, synchronously. No other JavaScript code on the
|
|
1609
|
+
* site will be run during shader execution.
|
|
1610
|
+
*
|
|
1611
|
+
* To avoid blocking, prefer `jsAsync()` when possible.
|
|
1612
|
+
*/
|
|
1613
|
+
js(): any;
|
|
1614
|
+
/** Convert this array into a JavaScript object, asynchronously. */
|
|
1615
|
+
jsAsync(): Promise<any>;
|
|
1616
|
+
/**
|
|
1617
|
+
* Copy an element of an array to a numeric scalar and return it.
|
|
1618
|
+
*
|
|
1619
|
+
* Throws an error if the array does not have a single element. The array must
|
|
1620
|
+
* either be rank-0, or all dimensions of the shape are 1.
|
|
1621
|
+
*/
|
|
1622
|
+
item(): number;
|
|
1623
|
+
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
1624
|
+
static _implRules(): typeof implRules;
|
|
1625
|
+
/** @private */
|
|
1626
|
+
_realizeSource(): number;
|
|
1627
|
+
/** @private Put this array on a new backend, asynchronously. */
|
|
1628
|
+
_put(backend: Backend): Promise<Array>;
|
|
1629
|
+
/** @private Put this array on a new backend, synchronously. */
|
|
1630
|
+
_putSync(backend: Backend): Array;
|
|
1631
|
+
}
|
|
1632
|
+
/** Constructor for creating a new array from data. */
|
|
1633
|
+
declare function array(values: Array | DataArray | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
1634
|
+
shape,
|
|
1635
|
+
dtype,
|
|
1636
|
+
device
|
|
1637
|
+
}?: {
|
|
1638
|
+
shape?: number[];
|
|
1639
|
+
} & DTypeAndDevice): Array;
|
|
1640
|
+
/** If x is a value, lift it into an array, otherwise leave it be. */
|
|
1641
|
+
|
|
1642
|
+
type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
|
|
1643
|
+
declare const implRules: { [P in Primitive]: ImplRule<P> };
|
|
1644
|
+
/** Return a new array of given shape and type, filled with zeros. */
|
|
1645
|
+
declare function zeros(shape: number[], {
|
|
1646
|
+
dtype,
|
|
1647
|
+
device
|
|
1648
|
+
}?: DTypeAndDevice): Array;
|
|
1649
|
+
/** Return a new array of given shape and type, filled with ones. */
|
|
1650
|
+
declare function ones(shape: number[], {
|
|
1651
|
+
dtype,
|
|
1652
|
+
device
|
|
1653
|
+
}?: DTypeAndDevice): Array;
|
|
1654
|
+
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
1655
|
+
declare function full(shape: number[], fillValue: number | boolean | Array, {
|
|
1656
|
+
dtype,
|
|
1657
|
+
device
|
|
1658
|
+
}?: DTypeAndDevice): Array;
|
|
1492
1659
|
/**
|
|
1493
|
-
*
|
|
1494
|
-
* Calculate element-wise inverse hyperbolic cosine of input.
|
|
1660
|
+
* Create an identity matrix.
|
|
1495
1661
|
*
|
|
1496
|
-
*
|
|
1662
|
+
* If numCols is not provided, it defaults to numRows, i.e., a square identity
|
|
1663
|
+
* matrix with ones on the diagonal.
|
|
1497
1664
|
*/
|
|
1498
|
-
declare
|
|
1665
|
+
declare function eye(numRows: number, numCols?: number, {
|
|
1666
|
+
dtype,
|
|
1667
|
+
device
|
|
1668
|
+
}?: DTypeAndDevice): Array;
|
|
1669
|
+
/** Return the identity matrix, with ones on the main diagonal. */
|
|
1670
|
+
declare function identity$1(n: number, {
|
|
1671
|
+
dtype,
|
|
1672
|
+
device
|
|
1673
|
+
}?: DTypeAndDevice): Array;
|
|
1499
1674
|
/**
|
|
1500
|
-
*
|
|
1501
|
-
* Calculate element-wise inverse hyperbolic tangent of input.
|
|
1675
|
+
* Return evenly spaced values within a given interval.
|
|
1502
1676
|
*
|
|
1503
|
-
*
|
|
1677
|
+
* This can be called with a varying number of arguments, just like the range()
|
|
1678
|
+
* builtin function in Python.
|
|
1679
|
+
*
|
|
1680
|
+
* - `arange(stop)` is equivalent to `arange(0, stop, 1)`.
|
|
1681
|
+
* - `arange(start, stop)` is equivalent to `arange(start, stop, 1)`.
|
|
1682
|
+
* - `arange(start, stop, step)` creates an array starting at `start`, ending
|
|
1683
|
+
* before `stop`, with a step size of `step`.
|
|
1684
|
+
*
|
|
1685
|
+
* Defaults to an integer data type. This can produce unintended results when
|
|
1686
|
+
* using a non-integer step, so prefer linspace() in those cases.
|
|
1504
1687
|
*/
|
|
1505
|
-
declare
|
|
1506
|
-
|
|
1507
|
-
|
|
1508
|
-
|
|
1509
|
-
declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1510
|
-
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
1511
|
-
declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1688
|
+
declare function arange(start: number, stop?: number, step?: number, {
|
|
1689
|
+
dtype,
|
|
1690
|
+
device
|
|
1691
|
+
}?: DTypeAndDevice): Array;
|
|
1512
1692
|
/**
|
|
1513
|
-
*
|
|
1693
|
+
* Return evenly spaced numbers over a specified interval.
|
|
1514
1694
|
*
|
|
1515
|
-
*
|
|
1516
|
-
* the
|
|
1695
|
+
* Returns _num_ evenly spaced samples, calculated over the interval
|
|
1696
|
+
* [`start`, `stop`]. The endpoint `stop` is included in the result by default,
|
|
1697
|
+
* but this is controlled by the `endpoint` parameter.
|
|
1517
1698
|
*
|
|
1518
|
-
*
|
|
1519
|
-
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
1699
|
+
* The default data type is Float32. Use arange() for integer steps.
|
|
1520
1700
|
*/
|
|
1521
|
-
declare function
|
|
1522
|
-
|
|
1523
|
-
|
|
1524
|
-
}
|
|
1701
|
+
declare function linspace(start: number, stop: number, num?: number, endpoint?: boolean, {
|
|
1702
|
+
dtype,
|
|
1703
|
+
device
|
|
1704
|
+
}?: DTypeAndDevice): Array;
|
|
1705
|
+
declare namespace lax_d_exports {
|
|
1706
|
+
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, reduceWindow, stopGradient };
|
|
1707
|
+
}
|
|
1525
1708
|
/**
|
|
1526
|
-
*
|
|
1709
|
+
* Dimension numbers for general `dot()` primitive.
|
|
1527
1710
|
*
|
|
1528
|
-
*
|
|
1529
|
-
*
|
|
1711
|
+
* Contracting dimensions act as a tensor contraction (reduction) along the
|
|
1712
|
+
* given axis. They must be the same size in both operands. Batch dimensions
|
|
1713
|
+
* are treated as vectorized, leading batch dimensions.
|
|
1530
1714
|
*
|
|
1531
|
-
*
|
|
1532
|
-
*
|
|
1715
|
+
* The return value has a shape where the first dimensions are shared batch
|
|
1716
|
+
* dimensions, followed by `lhs` non-contracting dimensions, followed by
|
|
1717
|
+
* `rhs` non-contracting dimensions.
|
|
1533
1718
|
*/
|
|
1534
|
-
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
|
|
1539
|
-
|
|
1540
|
-
//#region src/frontend/jaxpr.d.ts
|
|
1719
|
+
type DotDimensionNumbers = {
|
|
1720
|
+
lhsContractingDims?: number[];
|
|
1721
|
+
rhsContractingDims?: number[];
|
|
1722
|
+
lhsBatchDims?: number[];
|
|
1723
|
+
rhsBatchDims?: number[];
|
|
1724
|
+
};
|
|
1541
1725
|
/**
|
|
1542
|
-
*
|
|
1726
|
+
* General dot product/contraction operator.
|
|
1543
1727
|
*
|
|
1544
|
-
*
|
|
1545
|
-
*
|
|
1728
|
+
* Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
|
|
1729
|
+
* `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
|
|
1546
1730
|
*/
|
|
1547
|
-
|
|
1548
|
-
|
|
1549
|
-
|
|
1550
|
-
|
|
1551
|
-
|
|
1552
|
-
|
|
1553
|
-
readonly id: number;
|
|
1554
|
-
readonly aval: ShapedArray;
|
|
1555
|
-
constructor(aval: ShapedArray);
|
|
1556
|
-
toString(): string;
|
|
1557
|
-
}
|
|
1558
|
-
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
1559
|
-
declare class Lit {
|
|
1560
|
-
readonly value: number;
|
|
1561
|
-
readonly aval: ShapedArray;
|
|
1562
|
-
get dtype(): DType;
|
|
1563
|
-
constructor(aval: AbstractValue, value: number);
|
|
1564
|
-
}
|
|
1565
|
-
type Atom = Var | Lit;
|
|
1566
|
-
declare class VarPrinter {
|
|
1567
|
-
#private;
|
|
1568
|
-
names: Map<Var, string>;
|
|
1569
|
-
name(v: Var): string;
|
|
1570
|
-
nameType(v: Var): string;
|
|
1571
|
-
}
|
|
1572
|
-
/** A single statement / binding in a Jaxpr, in ANF form. */
|
|
1573
|
-
declare class JaxprEqn {
|
|
1574
|
-
readonly primitive: Primitive;
|
|
1575
|
-
readonly inputs: Atom[];
|
|
1576
|
-
readonly params: Record<string, any>;
|
|
1577
|
-
readonly outBinders: Var[];
|
|
1578
|
-
constructor(primitive: Primitive, inputs: Atom[], params: Record<string, any>, outBinders: Var[]);
|
|
1579
|
-
pprint(usedVars?: Set<Var>, vp?: VarPrinter): PPrint;
|
|
1580
|
-
toString(): string;
|
|
1581
|
-
}
|
|
1582
|
-
/** Typed intermediate representation for traced computations. */
|
|
1583
|
-
declare class Jaxpr implements FpHashable {
|
|
1584
|
-
#private;
|
|
1585
|
-
readonly inBinders: Var[];
|
|
1586
|
-
readonly eqns: JaxprEqn[];
|
|
1587
|
-
readonly outs: Atom[];
|
|
1588
|
-
constructor(inBinders: Var[], eqns: JaxprEqn[], outs: Atom[]);
|
|
1589
|
-
pprint(): PPrint;
|
|
1590
|
-
toString(): string;
|
|
1591
|
-
/**
|
|
1592
|
-
* Gets a hash of this Jaxpr.
|
|
1593
|
-
*
|
|
1594
|
-
* Var identity is not considered in the hash, so two Jaxprs with the same
|
|
1595
|
-
* order of assignments and operators but different variable IDs will resolve
|
|
1596
|
-
* to the same hash (and toString representation).
|
|
1597
|
-
*/
|
|
1598
|
-
getHash(): bigint;
|
|
1599
|
-
hash(state: FpHash): void;
|
|
1600
|
-
/**
|
|
1601
|
-
* Produce a simplified Jaxpr with basic optimizations applied.
|
|
1602
|
-
* - Trim away unused variables.
|
|
1603
|
-
* - Fold away *1, *0, or +0 operations against literals.
|
|
1604
|
-
* - Remove no-op movement operations.
|
|
1605
|
-
*/
|
|
1606
|
-
simplify(): Jaxpr;
|
|
1607
|
-
/** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1608
|
-
flatten(): Jaxpr;
|
|
1609
|
-
}
|
|
1610
|
-
/** @inline */
|
|
1611
|
-
type JitOpts = {
|
|
1612
|
-
staticArgnums?: number[];
|
|
1613
|
-
};
|
|
1614
|
-
declare namespace lax_d_exports {
|
|
1615
|
-
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, erf, erfc, reduceWindow, stopGradient };
|
|
1616
|
-
}
|
|
1731
|
+
declare function dot(lhs: Array, rhs: Array, {
|
|
1732
|
+
lhsContractingDims: lc,
|
|
1733
|
+
rhsContractingDims: rc,
|
|
1734
|
+
lhsBatchDims: lb,
|
|
1735
|
+
rhsBatchDims: rb
|
|
1736
|
+
}?: DotDimensionNumbers): Array;
|
|
1617
1737
|
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
|
|
1618
1738
|
/**
|
|
1619
1739
|
* General n-dimensional convolution operator, with optional dilation.
|