@jax-js/jax 0.0.4 → 0.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.cts CHANGED
@@ -180,12 +180,12 @@ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Arr
180
180
  * **Type lattice:**
181
181
  * ```text
182
182
  * bool -> uint32 -> int32 -> float16 -> float32
183
- * weak f* --^
183
+ * weakType --^
184
184
  * ```
185
185
  *
186
- * The asterisk f* is a weak type used for JS number constants. When creating
187
- * arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
188
- * any array they are first combined with.
186
+ * `weakType` represents weakly typed arrays. These are created for JS numbers,
187
+ * which default to float32 but "weak" so they cast to the dtype of any array
188
+ * they are first combined with, except `bool`.
189
189
  *
190
190
  * **Examples:**
191
191
  * - `promoteTypes(bool, int32) → int32`
@@ -613,6 +613,7 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
613
613
  outDim: number;
614
614
  };
615
615
  [Primitive.JitCall]: {
616
+ name: string;
616
617
  jaxpr: Jaxpr;
617
618
  numConsts: number;
618
619
  };
@@ -651,10 +652,40 @@ declare abstract class Trace {
651
652
  abstract lift(val: Tracer): Tracer;
652
653
  abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
653
654
  }
655
+ /** Internal representation of an array value. */
654
656
  interface AbstractValue {
657
+ /** Shape of the array. Must be a static tuple of non-negative dimensions. */
655
658
  shape: number[];
659
+ /** Concrete data type of array elements. */
656
660
  dtype: DType;
661
+ /**
662
+ * Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
663
+ * _weakly typed_ unless a dtype is explicitly specified.
664
+ *
665
+ * Weakly typed values will automatically cast to the data type of other
666
+ * arrays when used as an operand as an expression. This property only affects
667
+ * how they promote in type casting; their memory layout is still determined
668
+ * by the actual `dtype` field.
669
+ *
670
+ * ```ts
671
+ * const x = np.array(3); // weakType = true, dtype = float32
672
+ * const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
673
+ * const z = x.add(y); // z has dtype int32 because x is weakly typed
674
+ * ```
675
+ *
676
+ * Weak types are present in JIT programs in their spec (e.g., Jaxpr inputs
677
+ * and outputs can be weakly typed) form. But they're solely a frontend
678
+ * concept. Backends are not aware of weak types.
679
+ */
680
+ weakType: boolean;
657
681
  }
682
+ /**
683
+ * Broadcast shapes and promote types with casting for two avals.
684
+ *
685
+ * This implements the weak type behavior described in `promoteTypes()`, but not
686
+ * implemented in that function as `weakType` is not passed.
687
+ */
688
+
658
689
  declare abstract class Tracer {
659
690
  /** @ignore */
660
691
  readonly _trace: Trace;
@@ -712,8 +743,15 @@ declare abstract class Tracer {
712
743
  get shape(): number[];
713
744
  /** The total number of elements in the array. */
714
745
  get size(): number;
715
- /** The dtype of the array. */
746
+ /** The dtype of elements stored in the array. */
716
747
  get dtype(): DType;
748
+ /**
749
+ * Whether the array is weakly typed.
750
+ *
751
+ * Weakly typed arrays will cast to the dtype of the other operand. See
752
+ * `promoteTypes()` for details.
753
+ */
754
+ get weakType(): boolean;
717
755
  /** The number of dimensions of the array. */
718
756
  get ndim(): number;
719
757
  /** @ignore */
@@ -805,7 +843,8 @@ declare abstract class Tracer {
805
843
  declare class ShapedArray implements AbstractValue {
806
844
  readonly shape: number[];
807
845
  readonly dtype: DType;
808
- constructor(shape: number[], dtype: DType);
846
+ readonly weakType: boolean;
847
+ constructor(shape: number[], dtype: DType, weakType: boolean);
809
848
  static fromAval(aval: AbstractValue): ShapedArray;
810
849
  get ndim(): number;
811
850
  toString(): string;
@@ -841,6 +880,14 @@ type DTypeAndDevice = {
841
880
  dtype?: DType;
842
881
  device?: Device;
843
882
  };
883
+ type ArrayConstructorArgs = {
884
+ source: AluExp | Slot;
885
+ st: ShapeTracker;
886
+ dtype: DType;
887
+ weakType: boolean;
888
+ backend: Backend;
889
+ pending?: Iterable<PendingExecute>;
890
+ };
844
891
  /**
845
892
  * A multidimensional numeric array with data stored on CPU or GPU.
846
893
  *
@@ -860,11 +907,7 @@ declare class Array extends Tracer {
860
907
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
861
908
  * will be freed when the array is disposed.
862
909
  */
863
- constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend, {
864
- pending
865
- }?: {
866
- pending?: Iterable<PendingExecute> | null;
867
- });
910
+ constructor(args: ArrayConstructorArgs);
868
911
  /** @ignore */
869
912
  get aval(): ShapedArray;
870
913
  /** Return a simple string representation of the array's dimensions. */
@@ -926,8 +969,6 @@ declare class Array extends Tracer {
926
969
  static _implRules(): typeof implRules;
927
970
  _realizeSource(): number;
928
971
  }
929
- /** Construct an array from a single scalar constant. */
930
-
931
972
  /** Constructor for creating a new array from data. */
932
973
  declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
933
974
  shape,
@@ -1480,10 +1521,10 @@ declare class Var {
1480
1521
  }
1481
1522
  /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
1482
1523
  declare class Lit {
1483
- readonly dtype: DType;
1484
1524
  readonly value: number;
1485
1525
  readonly aval: ShapedArray;
1486
- constructor(dtype: DType, value: number);
1526
+ get dtype(): DType;
1527
+ constructor(aval: AbstractValue, value: number);
1487
1528
  }
1488
1529
  type Atom = Var | Lit;
1489
1530
  declare class VarPrinter {
@@ -1742,14 +1783,14 @@ declare function key(seed: number): Array;
1742
1783
  declare function split(key: Array, num?: number | number[]): Array;
1743
1784
  /** Sample uniform bits in the form of unsigned integers. */
1744
1785
  declare function bits(key: Array, shape?: number[]): Array;
1745
- /** Sample uniform random values in [minval, maxval) with given shape. */
1746
- declare function uniform(key: Array, shape?: number[], {
1747
- minval,
1748
- maxval
1749
- }?: {
1750
- minval?: number;
1751
- maxval?: number;
1752
- }): Array;
1786
+ /**
1787
+ * @function
1788
+ * Sample uniform random values in [minval, maxval) with given shape.
1789
+ */
1790
+ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined, args_2?: {
1791
+ minval?: number | undefined;
1792
+ maxval?: number | undefined;
1793
+ } | undefined) => Array>;
1753
1794
  /**
1754
1795
  * Sample Bernoulli random variables with given mean (0,1 categorical).
1755
1796
  *
@@ -1757,16 +1798,20 @@ declare function uniform(key: Array, shape?: number[], {
1757
1798
  * and must be broadcastable to `shape`.
1758
1799
  */
1759
1800
  declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
1760
- /** Sample exponential random values according to `p(x) = exp(-x)`. */
1761
- declare function exponential(key: Array, shape?: number[]): Array;
1762
1801
  /**
1802
+ * @function
1803
+ * Sample exponential random values according to `p(x) = exp(-x)`.
1804
+ */
1805
+ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
1806
+ /**
1807
+ * @function
1763
1808
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
1764
1809
  *
1765
1810
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
1766
1811
  * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
1767
1812
  * bitwise identical to JAX.
1768
1813
  */
1769
- declare function normal(key: Array, shape?: number[]): Array;
1814
+ declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
1770
1815
  //#endregion
1771
1816
  //#region src/index.d.ts
1772
1817
  /**
package/dist/index.d.ts CHANGED
@@ -177,12 +177,12 @@ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Arr
177
177
  * **Type lattice:**
178
178
  * ```text
179
179
  * bool -> uint32 -> int32 -> float16 -> float32
180
- * weak f* --^
180
+ * weakType --^
181
181
  * ```
182
182
  *
183
- * The asterisk f* is a weak type used for JS number constants. When creating
184
- * arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
185
- * any array they are first combined with.
183
+ * `weakType` represents weakly typed arrays. These are created for JS numbers,
184
+ * which default to float32 but "weak" so they cast to the dtype of any array
185
+ * they are first combined with, except `bool`.
186
186
  *
187
187
  * **Examples:**
188
188
  * - `promoteTypes(bool, int32) → int32`
@@ -610,6 +610,7 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
610
610
  outDim: number;
611
611
  };
612
612
  [Primitive.JitCall]: {
613
+ name: string;
613
614
  jaxpr: Jaxpr;
614
615
  numConsts: number;
615
616
  };
@@ -648,10 +649,40 @@ declare abstract class Trace {
648
649
  abstract lift(val: Tracer): Tracer;
649
650
  abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
650
651
  }
652
+ /** Internal representation of an array value. */
651
653
  interface AbstractValue {
654
+ /** Shape of the array. Must be a static tuple of non-negative dimensions. */
652
655
  shape: number[];
656
+ /** Concrete data type of array elements. */
653
657
  dtype: DType;
658
+ /**
659
+ * Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
660
+ * _weakly typed_ unless a dtype is explicitly specified.
661
+ *
662
+ * Weakly typed values will automatically cast to the data type of other
663
+ * arrays when used as an operand as an expression. This property only affects
664
+ * how they promote in type casting; their memory layout is still determined
665
+ * by the actual `dtype` field.
666
+ *
667
+ * ```ts
668
+ * const x = np.array(3); // weakType = true, dtype = float32
669
+ * const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
670
+ * const z = x.add(y); // z has dtype int32 because x is weakly typed
671
+ * ```
672
+ *
673
+ * Weak types are present in JIT programs in their spec (e.g., Jaxpr inputs
674
+ * and outputs can be weakly typed) form. But they're solely a frontend
675
+ * concept. Backends are not aware of weak types.
676
+ */
677
+ weakType: boolean;
654
678
  }
679
+ /**
680
+ * Broadcast shapes and promote types with casting for two avals.
681
+ *
682
+ * This implements the weak type behavior described in `promoteTypes()`, but not
683
+ * implemented in that function as `weakType` is not passed.
684
+ */
685
+
655
686
  declare abstract class Tracer {
656
687
  /** @ignore */
657
688
  readonly _trace: Trace;
@@ -709,8 +740,15 @@ declare abstract class Tracer {
709
740
  get shape(): number[];
710
741
  /** The total number of elements in the array. */
711
742
  get size(): number;
712
- /** The dtype of the array. */
743
+ /** The dtype of elements stored in the array. */
713
744
  get dtype(): DType;
745
+ /**
746
+ * Whether the array is weakly typed.
747
+ *
748
+ * Weakly typed arrays will cast to the dtype of the other operand. See
749
+ * `promoteTypes()` for details.
750
+ */
751
+ get weakType(): boolean;
714
752
  /** The number of dimensions of the array. */
715
753
  get ndim(): number;
716
754
  /** @ignore */
@@ -802,7 +840,8 @@ declare abstract class Tracer {
802
840
  declare class ShapedArray implements AbstractValue {
803
841
  readonly shape: number[];
804
842
  readonly dtype: DType;
805
- constructor(shape: number[], dtype: DType);
843
+ readonly weakType: boolean;
844
+ constructor(shape: number[], dtype: DType, weakType: boolean);
806
845
  static fromAval(aval: AbstractValue): ShapedArray;
807
846
  get ndim(): number;
808
847
  toString(): string;
@@ -838,6 +877,14 @@ type DTypeAndDevice = {
838
877
  dtype?: DType;
839
878
  device?: Device;
840
879
  };
880
+ type ArrayConstructorArgs = {
881
+ source: AluExp | Slot;
882
+ st: ShapeTracker;
883
+ dtype: DType;
884
+ weakType: boolean;
885
+ backend: Backend;
886
+ pending?: Iterable<PendingExecute>;
887
+ };
841
888
  /**
842
889
  * A multidimensional numeric array with data stored on CPU or GPU.
843
890
  *
@@ -857,11 +904,7 @@ declare class Array extends Tracer {
857
904
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
858
905
  * will be freed when the array is disposed.
859
906
  */
860
- constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend, {
861
- pending
862
- }?: {
863
- pending?: Iterable<PendingExecute> | null;
864
- });
907
+ constructor(args: ArrayConstructorArgs);
865
908
  /** @ignore */
866
909
  get aval(): ShapedArray;
867
910
  /** Return a simple string representation of the array's dimensions. */
@@ -923,8 +966,6 @@ declare class Array extends Tracer {
923
966
  static _implRules(): typeof implRules;
924
967
  _realizeSource(): number;
925
968
  }
926
- /** Construct an array from a single scalar constant. */
927
-
928
969
  /** Constructor for creating a new array from data. */
929
970
  declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
930
971
  shape,
@@ -1477,10 +1518,10 @@ declare class Var {
1477
1518
  }
1478
1519
  /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
1479
1520
  declare class Lit {
1480
- readonly dtype: DType;
1481
1521
  readonly value: number;
1482
1522
  readonly aval: ShapedArray;
1483
- constructor(dtype: DType, value: number);
1523
+ get dtype(): DType;
1524
+ constructor(aval: AbstractValue, value: number);
1484
1525
  }
1485
1526
  type Atom = Var | Lit;
1486
1527
  declare class VarPrinter {
@@ -1739,14 +1780,14 @@ declare function key(seed: number): Array;
1739
1780
  declare function split(key: Array, num?: number | number[]): Array;
1740
1781
  /** Sample uniform bits in the form of unsigned integers. */
1741
1782
  declare function bits(key: Array, shape?: number[]): Array;
1742
- /** Sample uniform random values in [minval, maxval) with given shape. */
1743
- declare function uniform(key: Array, shape?: number[], {
1744
- minval,
1745
- maxval
1746
- }?: {
1747
- minval?: number;
1748
- maxval?: number;
1749
- }): Array;
1783
+ /**
1784
+ * @function
1785
+ * Sample uniform random values in [minval, maxval) with given shape.
1786
+ */
1787
+ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined, args_2?: {
1788
+ minval?: number | undefined;
1789
+ maxval?: number | undefined;
1790
+ } | undefined) => Array>;
1750
1791
  /**
1751
1792
  * Sample Bernoulli random variables with given mean (0,1 categorical).
1752
1793
  *
@@ -1754,16 +1795,20 @@ declare function uniform(key: Array, shape?: number[], {
1754
1795
  * and must be broadcastable to `shape`.
1755
1796
  */
1756
1797
  declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
1757
- /** Sample exponential random values according to `p(x) = exp(-x)`. */
1758
- declare function exponential(key: Array, shape?: number[]): Array;
1759
1798
  /**
1799
+ * @function
1800
+ * Sample exponential random values according to `p(x) = exp(-x)`.
1801
+ */
1802
+ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
1803
+ /**
1804
+ * @function
1760
1805
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
1761
1806
  *
1762
1807
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
1763
1808
  * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
1764
1809
  * bitwise identical to JAX.
1765
1810
  */
1766
- declare function normal(key: Array, shape?: number[]): Array;
1811
+ declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
1767
1812
  //#endregion
1768
1813
  //#region src/index.d.ts
1769
1814
  /**