@jax-js/jax 0.0.4 → 0.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +67 -24
- package/dist/{backend-EBRGmEYw.js → backend-CdcTZEOF.js} +35 -6
- package/dist/{backend-Ss1Mev_-.cjs → backend-yEU0L_ig.cjs} +40 -5
- package/dist/index.cjs +324 -225
- package/dist/index.d.cts +71 -26
- package/dist/index.d.ts +71 -26
- package/dist/index.js +314 -215
- package/dist/{webgpu-ow0Pn_6q.js → webgpu-CM-xNYzW.js} +1 -1
- package/dist/{webgpu-BVdMaO9T.cjs → webgpu-CNOpiO5T.cjs} +1 -1
- package/package.json +1 -1
package/dist/index.d.cts
CHANGED
|
@@ -180,12 +180,12 @@ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Arr
|
|
|
180
180
|
* **Type lattice:**
|
|
181
181
|
* ```text
|
|
182
182
|
* bool -> uint32 -> int32 -> float16 -> float32
|
|
183
|
-
*
|
|
183
|
+
* weakType --^
|
|
184
184
|
* ```
|
|
185
185
|
*
|
|
186
|
-
*
|
|
187
|
-
*
|
|
188
|
-
*
|
|
186
|
+
* `weakType` represents weakly typed arrays. These are created for JS numbers,
|
|
187
|
+
* which default to float32 but are "weak", so they cast to the dtype of any array
|
|
188
|
+
* they are first combined with, except `bool`.
|
|
189
189
|
*
|
|
190
190
|
* **Examples:**
|
|
191
191
|
* - `promoteTypes(bool, int32) → int32`
|
|
@@ -613,6 +613,7 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
613
613
|
outDim: number;
|
|
614
614
|
};
|
|
615
615
|
[Primitive.JitCall]: {
|
|
616
|
+
name: string;
|
|
616
617
|
jaxpr: Jaxpr;
|
|
617
618
|
numConsts: number;
|
|
618
619
|
};
|
|
@@ -651,10 +652,40 @@ declare abstract class Trace {
|
|
|
651
652
|
abstract lift(val: Tracer): Tracer;
|
|
652
653
|
abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
|
|
653
654
|
}
|
|
655
|
+
/** Internal representation of an array value. */
|
|
654
656
|
interface AbstractValue {
|
|
657
|
+
/** Shape of the array. Must be a static tuple of non-negative dimensions. */
|
|
655
658
|
shape: number[];
|
|
659
|
+
/** Concrete data type of array elements. */
|
|
656
660
|
dtype: DType;
|
|
661
|
+
/**
|
|
662
|
+
* Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
|
|
663
|
+
* _weakly typed_ unless a dtype is explicitly specified.
|
|
664
|
+
*
|
|
665
|
+
* Weakly typed values will automatically cast to the data type of other
|
|
666
|
+
* arrays when used as an operand in an expression. This property only affects
|
|
667
|
+
* how they promote in type casting; their memory layout is still determined
|
|
668
|
+
* by the actual `dtype` field.
|
|
669
|
+
*
|
|
670
|
+
* ```ts
|
|
671
|
+
* const x = np.array(3); // weakType = true, dtype = float32
|
|
672
|
+
* const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
|
|
673
|
+
* const z = x.add(y); // z has dtype int32 because x is weakly typed
|
|
674
|
+
* ```
|
|
675
|
+
*
|
|
676
|
+
* Weak types appear in the specification of JIT programs (e.g., Jaxpr inputs
|
|
677
|
+
* and outputs can be weakly typed), but they are solely a frontend
|
|
678
|
+
* concept. Backends are not aware of weak types.
|
|
679
|
+
*/
|
|
680
|
+
weakType: boolean;
|
|
657
681
|
}
|
|
682
|
+
/**
|
|
683
|
+
* Broadcast shapes and promote types with casting for two avals.
|
|
684
|
+
*
|
|
685
|
+
* This implements the weak-type behavior described in `promoteTypes()`, which
|
|
686
|
+
* that function cannot apply itself because it does not receive `weakType`.
|
|
687
|
+
*/
|
|
688
|
+
|
|
658
689
|
declare abstract class Tracer {
|
|
659
690
|
/** @ignore */
|
|
660
691
|
readonly _trace: Trace;
|
|
@@ -712,8 +743,15 @@ declare abstract class Tracer {
|
|
|
712
743
|
get shape(): number[];
|
|
713
744
|
/** The total number of elements in the array. */
|
|
714
745
|
get size(): number;
|
|
715
|
-
/** The dtype of the array. */
|
|
746
|
+
/** The dtype of elements stored in the array. */
|
|
716
747
|
get dtype(): DType;
|
|
748
|
+
/**
|
|
749
|
+
* Whether the array is weakly typed.
|
|
750
|
+
*
|
|
751
|
+
* Weakly typed arrays will cast to the dtype of the other operand. See
|
|
752
|
+
* `promoteTypes()` for details.
|
|
753
|
+
*/
|
|
754
|
+
get weakType(): boolean;
|
|
717
755
|
/** The number of dimensions of the array. */
|
|
718
756
|
get ndim(): number;
|
|
719
757
|
/** @ignore */
|
|
@@ -805,7 +843,8 @@ declare abstract class Tracer {
|
|
|
805
843
|
declare class ShapedArray implements AbstractValue {
|
|
806
844
|
readonly shape: number[];
|
|
807
845
|
readonly dtype: DType;
|
|
808
|
-
|
|
846
|
+
readonly weakType: boolean;
|
|
847
|
+
constructor(shape: number[], dtype: DType, weakType: boolean);
|
|
809
848
|
static fromAval(aval: AbstractValue): ShapedArray;
|
|
810
849
|
get ndim(): number;
|
|
811
850
|
toString(): string;
|
|
@@ -841,6 +880,14 @@ type DTypeAndDevice = {
|
|
|
841
880
|
dtype?: DType;
|
|
842
881
|
device?: Device;
|
|
843
882
|
};
|
|
883
|
+
type ArrayConstructorArgs = {
|
|
884
|
+
source: AluExp | Slot;
|
|
885
|
+
st: ShapeTracker;
|
|
886
|
+
dtype: DType;
|
|
887
|
+
weakType: boolean;
|
|
888
|
+
backend: Backend;
|
|
889
|
+
pending?: Iterable<PendingExecute>;
|
|
890
|
+
};
|
|
844
891
|
/**
|
|
845
892
|
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
846
893
|
*
|
|
@@ -860,11 +907,7 @@ declare class Array extends Tracer {
|
|
|
860
907
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
861
908
|
* will be freed when the array is disposed.
|
|
862
909
|
*/
|
|
863
|
-
constructor(
|
|
864
|
-
pending
|
|
865
|
-
}?: {
|
|
866
|
-
pending?: Iterable<PendingExecute> | null;
|
|
867
|
-
});
|
|
910
|
+
constructor(args: ArrayConstructorArgs);
|
|
868
911
|
/** @ignore */
|
|
869
912
|
get aval(): ShapedArray;
|
|
870
913
|
/** Return a simple string representation of the array's dimensions. */
|
|
@@ -926,8 +969,6 @@ declare class Array extends Tracer {
|
|
|
926
969
|
static _implRules(): typeof implRules;
|
|
927
970
|
_realizeSource(): number;
|
|
928
971
|
}
|
|
929
|
-
/** Construct an array from a single scalar constant. */
|
|
930
|
-
|
|
931
972
|
/** Constructor for creating a new array from data. */
|
|
932
973
|
declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
933
974
|
shape,
|
|
@@ -1480,10 +1521,10 @@ declare class Var {
|
|
|
1480
1521
|
}
|
|
1481
1522
|
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
1482
1523
|
declare class Lit {
|
|
1483
|
-
readonly dtype: DType;
|
|
1484
1524
|
readonly value: number;
|
|
1485
1525
|
readonly aval: ShapedArray;
|
|
1486
|
-
|
|
1526
|
+
get dtype(): DType;
|
|
1527
|
+
constructor(aval: AbstractValue, value: number);
|
|
1487
1528
|
}
|
|
1488
1529
|
type Atom = Var | Lit;
|
|
1489
1530
|
declare class VarPrinter {
|
|
@@ -1742,14 +1783,14 @@ declare function key(seed: number): Array;
|
|
|
1742
1783
|
declare function split(key: Array, num?: number | number[]): Array;
|
|
1743
1784
|
/** Sample uniform bits in the form of unsigned integers. */
|
|
1744
1785
|
declare function bits(key: Array, shape?: number[]): Array;
|
|
1745
|
-
/**
|
|
1746
|
-
|
|
1747
|
-
|
|
1748
|
-
|
|
1749
|
-
|
|
1750
|
-
minval?: number;
|
|
1751
|
-
maxval?: number;
|
|
1752
|
-
})
|
|
1786
|
+
/**
|
|
1787
|
+
* @function
|
|
1788
|
+
* Sample uniform random values in [minval, maxval) with given shape.
|
|
1789
|
+
*/
|
|
1790
|
+
declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined, args_2?: {
|
|
1791
|
+
minval?: number | undefined;
|
|
1792
|
+
maxval?: number | undefined;
|
|
1793
|
+
} | undefined) => Array>;
|
|
1753
1794
|
/**
|
|
1754
1795
|
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
1755
1796
|
*
|
|
@@ -1757,16 +1798,20 @@ declare function uniform(key: Array, shape?: number[], {
|
|
|
1757
1798
|
* and must be broadcastable to `shape`.
|
|
1758
1799
|
*/
|
|
1759
1800
|
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
1760
|
-
/** Sample exponential random values according to `p(x) = exp(-x)`. */
|
|
1761
|
-
declare function exponential(key: Array, shape?: number[]): Array;
|
|
1762
1801
|
/**
|
|
1802
|
+
* @function
|
|
1803
|
+
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
1804
|
+
*/
|
|
1805
|
+
declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1806
|
+
/**
|
|
1807
|
+
* @function
|
|
1763
1808
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
1764
1809
|
*
|
|
1765
1810
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
1766
1811
|
* directly inverts the CDF, but we don't have support for that yet. Outputs will not be
|
|
1767
1812
|
* bitwise identical to JAX.
|
|
1768
1813
|
*/
|
|
1769
|
-
declare
|
|
1814
|
+
declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1770
1815
|
//#endregion
|
|
1771
1816
|
//#region src/index.d.ts
|
|
1772
1817
|
/**
|
package/dist/index.d.ts
CHANGED
|
@@ -177,12 +177,12 @@ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Arr
|
|
|
177
177
|
* **Type lattice:**
|
|
178
178
|
* ```text
|
|
179
179
|
* bool -> uint32 -> int32 -> float16 -> float32
|
|
180
|
-
*
|
|
180
|
+
* weakType --^
|
|
181
181
|
* ```
|
|
182
182
|
*
|
|
183
|
-
*
|
|
184
|
-
*
|
|
185
|
-
*
|
|
183
|
+
* `weakType` represents weakly typed arrays. These are created for JS numbers,
|
|
184
|
+
* which default to float32 but are "weak", so they cast to the dtype of any array
|
|
185
|
+
* they are first combined with, except `bool`.
|
|
186
186
|
*
|
|
187
187
|
* **Examples:**
|
|
188
188
|
* - `promoteTypes(bool, int32) → int32`
|
|
@@ -610,6 +610,7 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
610
610
|
outDim: number;
|
|
611
611
|
};
|
|
612
612
|
[Primitive.JitCall]: {
|
|
613
|
+
name: string;
|
|
613
614
|
jaxpr: Jaxpr;
|
|
614
615
|
numConsts: number;
|
|
615
616
|
};
|
|
@@ -648,10 +649,40 @@ declare abstract class Trace {
|
|
|
648
649
|
abstract lift(val: Tracer): Tracer;
|
|
649
650
|
abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
|
|
650
651
|
}
|
|
652
|
+
/** Internal representation of an array value. */
|
|
651
653
|
interface AbstractValue {
|
|
654
|
+
/** Shape of the array. Must be a static tuple of non-negative dimensions. */
|
|
652
655
|
shape: number[];
|
|
656
|
+
/** Concrete data type of array elements. */
|
|
653
657
|
dtype: DType;
|
|
658
|
+
/**
|
|
659
|
+
* Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
|
|
660
|
+
* _weakly typed_ unless a dtype is explicitly specified.
|
|
661
|
+
*
|
|
662
|
+
* Weakly typed values will automatically cast to the data type of other
|
|
663
|
+
* arrays when used as an operand in an expression. This property only affects
|
|
664
|
+
* how they promote in type casting; their memory layout is still determined
|
|
665
|
+
* by the actual `dtype` field.
|
|
666
|
+
*
|
|
667
|
+
* ```ts
|
|
668
|
+
* const x = np.array(3); // weakType = true, dtype = float32
|
|
669
|
+
* const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
|
|
670
|
+
* const z = x.add(y); // z has dtype int32 because x is weakly typed
|
|
671
|
+
* ```
|
|
672
|
+
*
|
|
673
|
+
* Weak types appear in the specification of JIT programs (e.g., Jaxpr inputs
|
|
674
|
+
* and outputs can be weakly typed), but they are solely a frontend
|
|
675
|
+
* concept. Backends are not aware of weak types.
|
|
676
|
+
*/
|
|
677
|
+
weakType: boolean;
|
|
654
678
|
}
|
|
679
|
+
/**
|
|
680
|
+
* Broadcast shapes and promote types with casting for two avals.
|
|
681
|
+
*
|
|
682
|
+
* This implements the weak-type behavior described in `promoteTypes()`, which
|
|
683
|
+
* that function cannot apply itself because it does not receive `weakType`.
|
|
684
|
+
*/
|
|
685
|
+
|
|
655
686
|
declare abstract class Tracer {
|
|
656
687
|
/** @ignore */
|
|
657
688
|
readonly _trace: Trace;
|
|
@@ -709,8 +740,15 @@ declare abstract class Tracer {
|
|
|
709
740
|
get shape(): number[];
|
|
710
741
|
/** The total number of elements in the array. */
|
|
711
742
|
get size(): number;
|
|
712
|
-
/** The dtype of the array. */
|
|
743
|
+
/** The dtype of elements stored in the array. */
|
|
713
744
|
get dtype(): DType;
|
|
745
|
+
/**
|
|
746
|
+
* Whether the array is weakly typed.
|
|
747
|
+
*
|
|
748
|
+
* Weakly typed arrays will cast to the dtype of the other operand. See
|
|
749
|
+
* `promoteTypes()` for details.
|
|
750
|
+
*/
|
|
751
|
+
get weakType(): boolean;
|
|
714
752
|
/** The number of dimensions of the array. */
|
|
715
753
|
get ndim(): number;
|
|
716
754
|
/** @ignore */
|
|
@@ -802,7 +840,8 @@ declare abstract class Tracer {
|
|
|
802
840
|
declare class ShapedArray implements AbstractValue {
|
|
803
841
|
readonly shape: number[];
|
|
804
842
|
readonly dtype: DType;
|
|
805
|
-
|
|
843
|
+
readonly weakType: boolean;
|
|
844
|
+
constructor(shape: number[], dtype: DType, weakType: boolean);
|
|
806
845
|
static fromAval(aval: AbstractValue): ShapedArray;
|
|
807
846
|
get ndim(): number;
|
|
808
847
|
toString(): string;
|
|
@@ -838,6 +877,14 @@ type DTypeAndDevice = {
|
|
|
838
877
|
dtype?: DType;
|
|
839
878
|
device?: Device;
|
|
840
879
|
};
|
|
880
|
+
type ArrayConstructorArgs = {
|
|
881
|
+
source: AluExp | Slot;
|
|
882
|
+
st: ShapeTracker;
|
|
883
|
+
dtype: DType;
|
|
884
|
+
weakType: boolean;
|
|
885
|
+
backend: Backend;
|
|
886
|
+
pending?: Iterable<PendingExecute>;
|
|
887
|
+
};
|
|
841
888
|
/**
|
|
842
889
|
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
843
890
|
*
|
|
@@ -857,11 +904,7 @@ declare class Array extends Tracer {
|
|
|
857
904
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
858
905
|
* will be freed when the array is disposed.
|
|
859
906
|
*/
|
|
860
|
-
constructor(
|
|
861
|
-
pending
|
|
862
|
-
}?: {
|
|
863
|
-
pending?: Iterable<PendingExecute> | null;
|
|
864
|
-
});
|
|
907
|
+
constructor(args: ArrayConstructorArgs);
|
|
865
908
|
/** @ignore */
|
|
866
909
|
get aval(): ShapedArray;
|
|
867
910
|
/** Return a simple string representation of the array's dimensions. */
|
|
@@ -923,8 +966,6 @@ declare class Array extends Tracer {
|
|
|
923
966
|
static _implRules(): typeof implRules;
|
|
924
967
|
_realizeSource(): number;
|
|
925
968
|
}
|
|
926
|
-
/** Construct an array from a single scalar constant. */
|
|
927
|
-
|
|
928
969
|
/** Constructor for creating a new array from data. */
|
|
929
970
|
declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
930
971
|
shape,
|
|
@@ -1477,10 +1518,10 @@ declare class Var {
|
|
|
1477
1518
|
}
|
|
1478
1519
|
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
1479
1520
|
declare class Lit {
|
|
1480
|
-
readonly dtype: DType;
|
|
1481
1521
|
readonly value: number;
|
|
1482
1522
|
readonly aval: ShapedArray;
|
|
1483
|
-
|
|
1523
|
+
get dtype(): DType;
|
|
1524
|
+
constructor(aval: AbstractValue, value: number);
|
|
1484
1525
|
}
|
|
1485
1526
|
type Atom = Var | Lit;
|
|
1486
1527
|
declare class VarPrinter {
|
|
@@ -1739,14 +1780,14 @@ declare function key(seed: number): Array;
|
|
|
1739
1780
|
declare function split(key: Array, num?: number | number[]): Array;
|
|
1740
1781
|
/** Sample uniform bits in the form of unsigned integers. */
|
|
1741
1782
|
declare function bits(key: Array, shape?: number[]): Array;
|
|
1742
|
-
/**
|
|
1743
|
-
|
|
1744
|
-
|
|
1745
|
-
|
|
1746
|
-
|
|
1747
|
-
minval?: number;
|
|
1748
|
-
maxval?: number;
|
|
1749
|
-
})
|
|
1783
|
+
/**
|
|
1784
|
+
* @function
|
|
1785
|
+
* Sample uniform random values in [minval, maxval) with given shape.
|
|
1786
|
+
*/
|
|
1787
|
+
declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined, args_2?: {
|
|
1788
|
+
minval?: number | undefined;
|
|
1789
|
+
maxval?: number | undefined;
|
|
1790
|
+
} | undefined) => Array>;
|
|
1750
1791
|
/**
|
|
1751
1792
|
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
1752
1793
|
*
|
|
@@ -1754,16 +1795,20 @@ declare function uniform(key: Array, shape?: number[], {
|
|
|
1754
1795
|
* and must be broadcastable to `shape`.
|
|
1755
1796
|
*/
|
|
1756
1797
|
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
1757
|
-
/** Sample exponential random values according to `p(x) = exp(-x)`. */
|
|
1758
|
-
declare function exponential(key: Array, shape?: number[]): Array;
|
|
1759
1798
|
/**
|
|
1799
|
+
* @function
|
|
1800
|
+
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
1801
|
+
*/
|
|
1802
|
+
declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1803
|
+
/**
|
|
1804
|
+
* @function
|
|
1760
1805
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
1761
1806
|
*
|
|
1762
1807
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
1763
1808
|
* directly inverts the CDF, but we don't have support for that yet. Outputs will not be
|
|
1764
1809
|
* bitwise identical to JAX.
|
|
1765
1810
|
*/
|
|
1766
|
-
declare
|
|
1811
|
+
declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1767
1812
|
//#endregion
|
|
1768
1813
|
//#region src/index.d.ts
|
|
1769
1814
|
/**
|