@jax-js/jax 0.1.2 → 0.1.3

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
package/dist/index.d.cts CHANGED
@@ -254,7 +254,7 @@ declare class AluExp implements FpHashable {
254
254
  /** Substitute variables in this AluExp to values. */
255
255
  substitute(variables: Record<string, AluExp>): AluExp;
256
256
  /** Reindex gid values in this expression as needed. */
257
- reindexGids(gidMap: Map<number, number>): AluExp;
257
+ reindexGids(newGids: number[]): AluExp;
258
258
  get min(): number;
259
259
  get max(): number;
260
260
  /** Largest known integer that divides self. */
@@ -464,7 +464,7 @@ declare class Executable<T = any> {
464
464
  data: T);
465
465
  }
466
466
  declare namespace numpy_d_exports {
467
- export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, positive, pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
467
+ export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, cos, cosh, cumsum, cumulativeSum, deg2rad, degrees, diag, diagonal, divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, positive, pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
468
468
  }
469
469
  declare const float32 = DType.Float32;
470
470
  declare const int32 = DType.Int32;
@@ -608,6 +608,15 @@ declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
608
608
  * specified axis.
609
609
  */
610
610
  declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
611
+ /**
612
+ * Cumulative sum of elements along an axis.
613
+ *
614
+ * Currently this function is `O(n^2)`, we'll improve this later on with a
615
+ * two-phase parallel reduction algorithm.
616
+ */
617
+ declare function cumsum(a: ArrayLike, axis?: number): Array;
618
+ /** @function Alternative name for `jax.numpy.cumsum()`. */
619
+ declare const cumulativeSum: typeof cumsum;
611
620
  /** Reverse the elements in an array along the given axes. */
612
621
  declare function flip(x: ArrayLike, axis?: Axis): Array;
613
622
  /**
@@ -1076,6 +1085,7 @@ declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefin
1076
1085
  //#region src/frontend/convolution.d.ts
1077
1086
  /** Definition of a general dilated convolution. Should be valid on creation. */
1078
1087
  interface ConvParams {
1088
+ vmapDims: number;
1079
1089
  strides: number[];
1080
1090
  padding: [number, number][];
1081
1091
  lhsDilation: number[];
@@ -1083,6 +1093,10 @@ interface ConvParams {
1083
1093
  }
1084
1094
  /**
1085
1095
  * Check that the shapes and parameters passed to convolution are valid.
1096
+ * Expected shapes of the lhs and rhs of the convolution are:
1097
+ *
1098
+ * - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
1099
+ * - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
1086
1100
  *
1087
1101
  * If the check succeeds, returns the output shape.
1088
1102
  */
@@ -1745,10 +1759,12 @@ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
1745
1759
  */
1746
1760
  declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
1747
1761
  lhsDilation,
1748
- rhsDilation
1762
+ rhsDilation,
1763
+ featureGroupCount
1749
1764
  }?: {
1750
1765
  lhsDilation?: number[];
1751
1766
  rhsDilation?: number[];
1767
+ featureGroupCount?: number;
1752
1768
  }): Array;
1753
1769
  /** Convenience wrapper around `convGeneralDilated`. */
1754
1770
  declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
@@ -1915,9 +1931,9 @@ declare function logSoftmax(x: ArrayLike, axis?: Axis): Array;
1915
1931
  *
1916
1932
  * Reference: https://en.wikipedia.org/wiki/LogSumExp
1917
1933
  */
1918
- declare function logsumexp(x: ArrayLike, axis?: Axis): Array;
1934
+ declare function logsumexp(x: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1919
1935
  /** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
1920
- declare function logmeanexp(x: ArrayLike, axis?: Axis): Array;
1936
+ declare function logmeanexp(x: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1921
1937
  /**
1922
1938
  * Standardizes input to zero mean and unit variance.
1923
1939
  *
package/dist/index.d.ts CHANGED
@@ -251,7 +251,7 @@ declare class AluExp implements FpHashable {
251
251
  /** Substitute variables in this AluExp to values. */
252
252
  substitute(variables: Record<string, AluExp>): AluExp;
253
253
  /** Reindex gid values in this expression as needed. */
254
- reindexGids(gidMap: Map<number, number>): AluExp;
254
+ reindexGids(newGids: number[]): AluExp;
255
255
  get min(): number;
256
256
  get max(): number;
257
257
  /** Largest known integer that divides self. */
@@ -461,7 +461,7 @@ declare class Executable<T = any> {
461
461
  data: T);
462
462
  }
463
463
  declare namespace numpy_d_exports {
464
- export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, positive, pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
464
+ export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, cos, cosh, cumsum, cumulativeSum, deg2rad, degrees, diag, diagonal, divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, positive, pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
465
465
  }
466
466
  declare const float32 = DType.Float32;
467
467
  declare const int32 = DType.Int32;
@@ -605,6 +605,15 @@ declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
605
605
  * specified axis.
606
606
  */
607
607
  declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
608
+ /**
609
+ * Cumulative sum of elements along an axis.
610
+ *
611
+ * Currently this function is `O(n^2)`, we'll improve this later on with a
612
+ * two-phase parallel reduction algorithm.
613
+ */
614
+ declare function cumsum(a: ArrayLike, axis?: number): Array;
615
+ /** @function Alternative name for `jax.numpy.cumsum()`. */
616
+ declare const cumulativeSum: typeof cumsum;
608
617
  /** Reverse the elements in an array along the given axes. */
609
618
  declare function flip(x: ArrayLike, axis?: Axis): Array;
610
619
  /**
@@ -1073,6 +1082,7 @@ declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefin
1073
1082
  //#region src/frontend/convolution.d.ts
1074
1083
  /** Definition of a general dilated convolution. Should be valid on creation. */
1075
1084
  interface ConvParams {
1085
+ vmapDims: number;
1076
1086
  strides: number[];
1077
1087
  padding: [number, number][];
1078
1088
  lhsDilation: number[];
@@ -1080,6 +1090,10 @@ interface ConvParams {
1080
1090
  }
1081
1091
  /**
1082
1092
  * Check that the shapes and parameters passed to convolution are valid.
1093
+ * Expected shapes of the lhs and rhs of the convolution are:
1094
+ *
1095
+ * - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
1096
+ * - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
1083
1097
  *
1084
1098
  * If the check succeeds, returns the output shape.
1085
1099
  */
@@ -1742,10 +1756,12 @@ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
1742
1756
  */
1743
1757
  declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
1744
1758
  lhsDilation,
1745
- rhsDilation
1759
+ rhsDilation,
1760
+ featureGroupCount
1746
1761
  }?: {
1747
1762
  lhsDilation?: number[];
1748
1763
  rhsDilation?: number[];
1764
+ featureGroupCount?: number;
1749
1765
  }): Array;
1750
1766
  /** Convenience wrapper around `convGeneralDilated`. */
1751
1767
  declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
@@ -1912,9 +1928,9 @@ declare function logSoftmax(x: ArrayLike, axis?: Axis): Array;
1912
1928
  *
1913
1929
  * Reference: https://en.wikipedia.org/wiki/LogSumExp
1914
1930
  */
1915
- declare function logsumexp(x: ArrayLike, axis?: Axis): Array;
1931
+ declare function logsumexp(x: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1916
1932
  /** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
1917
- declare function logmeanexp(x: ArrayLike, axis?: Axis): Array;
1933
+ declare function logmeanexp(x: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1918
1934
  /**
1919
1935
  * Standardizes input to zero mean and unit variance.
1920
1936
  *