@jax-js/jax 0.0.5 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.cts CHANGED
@@ -124,6 +124,7 @@ declare class ShapeTracker {
124
124
  /** Like pad(), but allows for negative values. */
125
125
  padOrShrink(arg: Pair[]): ShapeTracker;
126
126
  }
127
+ //# sourceMappingURL=shape.d.ts.map
127
128
  //#endregion
128
129
  //#region src/utils.d.ts
129
130
  /**
@@ -167,9 +168,10 @@ declare enum DType {
167
168
  Uint32 = "uint32",
168
169
  Bool = "bool",
169
170
  Float16 = "float16",
171
+ Float64 = "float64",
170
172
  }
171
173
  /** @inline */
172
- type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
174
+ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer> | Float64Array<ArrayBuffer>;
173
175
  /**
174
176
  * Promote two dtypes to their join according to the type lattice.
175
177
  *
@@ -179,7 +181,7 @@ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Arr
179
181
  *
180
182
  * **Type lattice:**
181
183
  * ```text
182
- * bool -> uint32 -> int32 -> float16 -> float32
184
+ * bool -> uint32 -> int32 -> float16 -> float32 -> float64
183
185
  * weakType --^
184
186
  * ```
185
187
  *
@@ -222,6 +224,8 @@ declare class AluExp implements FpHashable {
222
224
  static atan(a: AluExp): AluExp;
223
225
  static exp(a: AluExp): AluExp;
224
226
  static log(a: AluExp): AluExp;
227
+ static erf(a: AluExp): AluExp;
228
+ static erfc(a: AluExp): AluExp;
225
229
  static sqrt(a: AluExp): AluExp;
226
230
  static reciprocal(a: AluExp): AluExp;
227
231
  static cast(dtype: DType, a: AluExp): AluExp;
@@ -240,6 +244,7 @@ declare class AluExp implements FpHashable {
240
244
  static u32(value: number): AluExp;
241
245
  static bool(value: boolean): AluExp;
242
246
  static f16(value: number): AluExp;
247
+ static f64(value: number): AluExp;
243
248
  not(): AluExp;
244
249
  /** Compute a reasonable expression hash with low collision rate. */
245
250
  getHash(): bigint;
@@ -289,8 +294,8 @@ declare class AluExp implements FpHashable {
289
294
  rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
290
295
  /** Collect all nodes that satisfy a predicate. */
291
296
  collect(predicate: (exp: AluExp) => boolean): AluExp[];
292
- /** Produce a list of all distinct AluOp in this expression. */
293
- distinctOps(): Set<AluOp>;
297
+ /** Produce all distinct AluOp in this expression, with their dtypes. */
298
+ distinctOps(): Map<AluOp, Set<DType>>;
294
299
  /** Rewrite GlobalView operations to GlobalIndex operations. */
295
300
  rewriteGlobalViews(): AluExp;
296
301
  }
@@ -309,6 +314,8 @@ declare enum AluOp {
309
314
  Atan = "Atan",
310
315
  Exp = "Exp",
311
316
  Log = "Log",
317
+ Erf = "Erf",
318
+ Erfc = "Erfc",
312
319
  Sqrt = "Sqrt",
313
320
  Reciprocal = "Reciprocal",
314
321
  Cast = "Cast",
@@ -465,7 +472,7 @@ type JsTree<T> = T | JsTree<T>[] | {
465
472
  [key: string]: JsTree<T>;
466
473
  };
467
474
  type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
468
- type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
475
+ type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
469
476
  /** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
470
477
  type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
471
478
  /** Represents the structure of a JsTree. */
@@ -477,6 +484,8 @@ declare class JsTreeDef {
477
484
  constructor(nodeType: NodeType, nodeMetadata: any,
478
485
  // Must be comparable with deepEqual.
479
486
  childTreedefs: JsTreeDef[]);
487
+ /** Get the total number of leaves in the tree. */
488
+ get size(): number;
480
489
  /** Returns a string representation of this tree definition. */
481
490
  toString(root?: boolean): string;
482
491
  /** Compare this tree definition with another. */
@@ -540,6 +549,8 @@ declare enum Primitive {
540
549
  Atan = "atan",
541
550
  Exp = "exp",
542
551
  Log = "log",
552
+ Erf = "erf",
553
+ Erfc = "erfc",
543
554
  Sqrt = "sqrt",
544
555
  Min = "min",
545
556
  Max = "max",
@@ -621,11 +632,9 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
621
632
  /** Type of parameters taken by each primitive. */
622
633
  type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
623
634
  declare enum CompareOp {
624
- Greater = "greater",
625
635
  Less = "less",
626
636
  Equal = "equal",
627
637
  NotEqual = "not_equal",
628
- GreaterEqual = "greater_equal",
629
638
  LessEqual = "less_equal",
630
639
  }
631
640
  /** @inline */
@@ -886,12 +895,13 @@ type ArrayConstructorArgs = {
886
895
  dtype: DType;
887
896
  weakType: boolean;
888
897
  backend: Backend;
898
+ committed: boolean;
889
899
  pending?: Iterable<PendingExecute>;
890
900
  };
891
901
  /**
892
902
  * A multidimensional numeric array with data stored on CPU or GPU.
893
903
  *
894
- * This is the library's core data type. Equivalent to `jnp.Array` from JAX, or
904
+ * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
895
905
  * `torch.Tensor`.
896
906
  *
897
907
  * Not to be confused with the JavaScript "Array" constructor. Avoid importing
@@ -967,10 +977,15 @@ declare class Array extends Tracer {
967
977
  item(): number;
968
978
  /** @private Internal plumbing method for Array / Tracer ops. */
969
979
  static _implRules(): typeof implRules;
980
+ /** @private */
970
981
  _realizeSource(): number;
982
+ /** @private Put this array on a new backend, asynchronously. */
983
+ _put(backend: Backend): Promise<Array>;
984
+ /** @private Put this array on a new backend, synchronously. */
985
+ _putSync(backend: Backend): Array;
971
986
  }
972
987
  /** Constructor for creating a new array from data. */
973
- declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
988
+ declare function array(values: Array | DataArray | RecursiveArray<number> | RecursiveArray<boolean>, {
974
989
  shape,
975
990
  dtype,
976
991
  device
@@ -1043,13 +1058,14 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
1043
1058
  device
1044
1059
  }?: DTypeAndDevice): Array;
1045
1060
  declare namespace numpy_d_exports {
1046
- export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1061
+ export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1047
1062
  }
1048
1063
  declare const float32 = DType.Float32;
1049
1064
  declare const int32 = DType.Int32;
1050
1065
  declare const uint32 = DType.Uint32;
1051
1066
  declare const bool = DType.Bool;
1052
1067
  declare const float16 = DType.Float16;
1068
+ declare const float64 = DType.Float64;
1053
1069
  /** Euler's constant, `e = 2.7182818284590...` */
1054
1070
  declare const e: number;
1055
1071
  /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
@@ -1126,7 +1142,7 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
1126
1142
  * all axes are padded with the same width. Or if it is an array of pairs, each
1127
1143
  * pair specifies the padding for its corresponding axis.
1128
1144
  */
1129
- declare const pad: (x: ArrayLike, width: number | [number, number] | [number, number][]) => Array;
1145
+ declare const pad: (x: ArrayLike, width: number | Pair | Pair[]) => Array;
1130
1146
  /**
1131
1147
  * @function
1132
1148
  * Return the number of dimensions of an array. Does not consume array reference.
@@ -1356,6 +1372,26 @@ declare function absolute(x: ArrayLike): Array;
1356
1372
  declare const abs: typeof absolute;
1357
1373
  /** Return an element-wise indication of sign of the input. */
1358
1374
  declare function sign(x: ArrayLike): Array;
1375
+ /**
1376
+ * Return the Hamming window of size M, a taper with a weighted cosine bell.
1377
+ *
1378
+ * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1379
+ */
1380
+ declare function hamming(M: number): Array;
1381
+ /**
1382
+ * Return the Hann window of size M, a taper with a weighted cosine bell.
1383
+ *
1384
+ * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1385
+ */
1386
+ declare function hann(M: number): Array;
1387
+ /**
1388
+ * @function
1389
+ * Compute the Heaviside step function. It is defined piecewise:
1390
+ * - `heaviside(x1, x2) = 0` for `x1 < 0`,
1391
+ * - `heaviside(x1, x2) = x2` for `x1 == 0`,
1392
+ * - `heaviside(x1, x2) = 1` for `x1 > 0`.
1393
+ */
1394
+ declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1359
1395
  /** Calculate element-wise square of the input array. */
1360
1396
  declare function square(x: ArrayLike): Array;
1361
1397
  /** Element-wise tangent function (takes radians). */
@@ -1367,8 +1403,8 @@ declare function acos(x: ArrayLike): Array;
1367
1403
  * Return element-wise hypotenuse for the given legs of a right triangle.
1368
1404
  *
1369
1405
  * In the original NumPy/JAX implementation, this function is more numerically
1370
- * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
1371
- * improvements.
1406
+ * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
1407
+ * stability improvements.
1372
1408
  */
1373
1409
  declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1374
1410
  /**
@@ -1500,6 +1536,20 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
1500
1536
  mean?: ArrayLike;
1501
1537
  correction?: number;
1502
1538
  } & ReduceOpts): Array;
1539
+ /** Test element-wise for positive or negative infinity, return bool array. */
1540
+ declare function isinf(x: ArrayLike): Array;
1541
+ /** Test element-wise for NaN (Not a Number). */
1542
+ declare function isnan(x: ArrayLike): Array;
1543
+ /** Test element-wise for negative infinity, return bool array. */
1544
+ declare function isneginf(x: ArrayLike): Array;
1545
+ /** Test element-wise for positive infinity, return bool array. */
1546
+ declare function isposinf(x: ArrayLike): Array;
1547
+ /**
1548
+ * @function
1549
+ * Test element-wise for finite values (not infinity or NaN).
1550
+ */
1551
+ declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
1552
+ //# sourceMappingURL=numpy.d.ts.map
1503
1553
  //#endregion
1504
1554
  //#region src/frontend/jaxpr.d.ts
1505
1555
  /**
@@ -1574,10 +1624,9 @@ declare class Jaxpr implements FpHashable {
1574
1624
  /** @inline */
1575
1625
  type JitOpts = {
1576
1626
  staticArgnums?: number[];
1577
- device?: Device;
1578
1627
  };
1579
1628
  declare namespace lax_d_exports {
1580
- export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
1629
+ export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, erf, erfc, reduceWindow, stopGradient };
1581
1630
  }
1582
1631
  type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
1583
1632
  /**
@@ -1601,6 +1650,23 @@ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: n
1601
1650
  declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
1602
1651
  /** Reduce a computation over padded windows. */
1603
1652
  declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
1653
+ /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
1654
+ declare function erf(x: ArrayLike): Array;
1655
+ /**
1656
+ * The complementary error function: `erfc(x) = 1 - erf(x)`.
1657
+ *
1658
+ * This function is more accurate than `1 - erf(x)` for large values of `x`,
1659
+ * where `erf(x)` is very close to 1.
1660
+ */
1661
+ declare function erfc(x: ArrayLike): Array;
1662
+ /**
1663
+ * Stops gradient computation.
1664
+ *
1665
+ * Behaves as the identity function but prevents the flow of gradients during
1666
+ * forward or reverse-mode automatic differentiation.
1667
+ */
1668
+ declare function stopGradient(x: ArrayLike): Array;
1669
+ //# sourceMappingURL=lax.d.ts.map
1604
1670
  declare namespace nn_d_exports {
1605
1671
  export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
1606
1672
  }
@@ -1685,15 +1751,17 @@ declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
1685
1751
  * @function
1686
1752
  * Gaussion error linear unit (GELU) activation function.
1687
1753
  *
1688
- * This is computed element-wise. Currently jax-js does not support the erf() or
1689
- * gelu() functions exactly as primitives, so an approximation is used:
1690
- * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
1754
+ * This is computed element-wise. There are two variants depending on whether
1755
+ * `approximate` is set (default true):
1691
1756
  *
1692
- * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
1757
+ * - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
1758
+ * - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
1693
1759
  *
1694
- * This will be improved in the future.
1760
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
1695
1761
  */
1696
- declare const gelu: OwnedFunction<(x: ArrayLike) => Array>;
1762
+ declare const gelu: OwnedFunction<(x: ArrayLike, opts?: {
1763
+ approximate?: boolean | undefined;
1764
+ } | undefined) => Array>;
1697
1765
  /**
1698
1766
  * Gated linear unit (GLU) activation function.
1699
1767
  *
@@ -1774,6 +1842,7 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
1774
1842
  * ```
1775
1843
  */
1776
1844
  declare function oneHot(x: Array, numClasses: number): Array;
1845
+ //# sourceMappingURL=nn.d.ts.map
1777
1846
  declare namespace random_d_exports {
1778
1847
  export { bernoulli, bits, exponential, key, normal, split, uniform };
1779
1848
  }
@@ -1812,6 +1881,15 @@ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | und
1812
1881
  * bitwise identical to JAX.
1813
1882
  */
1814
1883
  declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
1884
+ //# sourceMappingURL=random.d.ts.map
1885
+ declare namespace scipy_special_d_exports {
1886
+ export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
1887
+ }
1888
+ /**
1889
+ * @function
1890
+ * The logit function, `logit(p) = log(p / (1-p))`.
1891
+ */
1892
+ declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
1815
1893
  //#endregion
1816
1894
  //#region src/index.d.ts
1817
1895
  /**
@@ -1823,7 +1901,7 @@ declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals:
1823
1901
  * @function
1824
1902
  * Vectorize an operation on a batched axis for one or more inputs.
1825
1903
  */
1826
- declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1904
+ declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | (number | null | JsTree<number | null>)[]) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1827
1905
  /**
1828
1906
  * @function
1829
1907
  * Compute the Jacobian evaluated column-by-column by forward-mode AD.
@@ -1898,5 +1976,19 @@ declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJs
1898
1976
  * Does not consume reference to the arrays.
1899
1977
  */
1900
1978
  declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
1979
+ /**
1980
+ * Transfer `x` to `device`.
1981
+ *
1982
+ * `x` may be a nested container of arrays or scalars. The resulting structure
1983
+ * is committed to the device.
1984
+ *
1985
+ * If `device` is not specified, this function behaves as identity if the input
1986
+ * is already an `Array`, otherwise it places the scalar uncommitted on the
1987
+ * default device.
1988
+ */
1989
+ declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
1990
+ //# sourceMappingURL=index.d.ts.map
1991
+
1901
1992
  //#endregion
1902
- export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1993
+ export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1994
+ //# sourceMappingURL=index.d.cts.map
package/dist/index.d.ts CHANGED
@@ -121,6 +121,7 @@ declare class ShapeTracker {
121
121
  /** Like pad(), but allows for negative values. */
122
122
  padOrShrink(arg: Pair[]): ShapeTracker;
123
123
  }
124
+ //# sourceMappingURL=shape.d.ts.map
124
125
  //#endregion
125
126
  //#region src/utils.d.ts
126
127
  /**
@@ -164,9 +165,10 @@ declare enum DType {
164
165
  Uint32 = "uint32",
165
166
  Bool = "bool",
166
167
  Float16 = "float16",
168
+ Float64 = "float64",
167
169
  }
168
170
  /** @inline */
169
- type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
171
+ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer> | Float64Array<ArrayBuffer>;
170
172
  /**
171
173
  * Promote two dtypes to their join according to the type lattice.
172
174
  *
@@ -176,7 +178,7 @@ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Arr
176
178
  *
177
179
  * **Type lattice:**
178
180
  * ```text
179
- * bool -> uint32 -> int32 -> float16 -> float32
181
+ * bool -> uint32 -> int32 -> float16 -> float32 -> float64
180
182
  * weakType --^
181
183
  * ```
182
184
  *
@@ -219,6 +221,8 @@ declare class AluExp implements FpHashable {
219
221
  static atan(a: AluExp): AluExp;
220
222
  static exp(a: AluExp): AluExp;
221
223
  static log(a: AluExp): AluExp;
224
+ static erf(a: AluExp): AluExp;
225
+ static erfc(a: AluExp): AluExp;
222
226
  static sqrt(a: AluExp): AluExp;
223
227
  static reciprocal(a: AluExp): AluExp;
224
228
  static cast(dtype: DType, a: AluExp): AluExp;
@@ -237,6 +241,7 @@ declare class AluExp implements FpHashable {
237
241
  static u32(value: number): AluExp;
238
242
  static bool(value: boolean): AluExp;
239
243
  static f16(value: number): AluExp;
244
+ static f64(value: number): AluExp;
240
245
  not(): AluExp;
241
246
  /** Compute a reasonable expression hash with low collision rate. */
242
247
  getHash(): bigint;
@@ -286,8 +291,8 @@ declare class AluExp implements FpHashable {
286
291
  rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
287
292
  /** Collect all nodes that satisfy a predicate. */
288
293
  collect(predicate: (exp: AluExp) => boolean): AluExp[];
289
- /** Produce a list of all distinct AluOp in this expression. */
290
- distinctOps(): Set<AluOp>;
294
+ /** Produce all distinct AluOp in this expression, with their dtypes. */
295
+ distinctOps(): Map<AluOp, Set<DType>>;
291
296
  /** Rewrite GlobalView operations to GlobalIndex operations. */
292
297
  rewriteGlobalViews(): AluExp;
293
298
  }
@@ -306,6 +311,8 @@ declare enum AluOp {
306
311
  Atan = "Atan",
307
312
  Exp = "Exp",
308
313
  Log = "Log",
314
+ Erf = "Erf",
315
+ Erfc = "Erfc",
309
316
  Sqrt = "Sqrt",
310
317
  Reciprocal = "Reciprocal",
311
318
  Cast = "Cast",
@@ -462,7 +469,7 @@ type JsTree<T> = T | JsTree<T>[] | {
462
469
  [key: string]: JsTree<T>;
463
470
  };
464
471
  type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
465
- type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
472
+ type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
466
473
  /** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
467
474
  type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
468
475
  /** Represents the structure of a JsTree. */
@@ -474,6 +481,8 @@ declare class JsTreeDef {
474
481
  constructor(nodeType: NodeType, nodeMetadata: any,
475
482
  // Must be comparable with deepEqual.
476
483
  childTreedefs: JsTreeDef[]);
484
+ /** Get the total number of leaves in the tree. */
485
+ get size(): number;
477
486
  /** Returns a string representation of this tree definition. */
478
487
  toString(root?: boolean): string;
479
488
  /** Compare this tree definition with another. */
@@ -537,6 +546,8 @@ declare enum Primitive {
537
546
  Atan = "atan",
538
547
  Exp = "exp",
539
548
  Log = "log",
549
+ Erf = "erf",
550
+ Erfc = "erfc",
540
551
  Sqrt = "sqrt",
541
552
  Min = "min",
542
553
  Max = "max",
@@ -618,11 +629,9 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
618
629
  /** Type of parameters taken by each primitive. */
619
630
  type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
620
631
  declare enum CompareOp {
621
- Greater = "greater",
622
632
  Less = "less",
623
633
  Equal = "equal",
624
634
  NotEqual = "not_equal",
625
- GreaterEqual = "greater_equal",
626
635
  LessEqual = "less_equal",
627
636
  }
628
637
  /** @inline */
@@ -883,12 +892,13 @@ type ArrayConstructorArgs = {
883
892
  dtype: DType;
884
893
  weakType: boolean;
885
894
  backend: Backend;
895
+ committed: boolean;
886
896
  pending?: Iterable<PendingExecute>;
887
897
  };
888
898
  /**
889
899
  * A multidimensional numeric array with data stored on CPU or GPU.
890
900
  *
891
- * This is the library's core data type. Equivalent to `jnp.Array` from JAX, or
901
+ * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
892
902
  * `torch.Tensor`.
893
903
  *
894
904
  * Not to be confused with the JavaScript "Array" constructor. Avoid importing
@@ -964,10 +974,15 @@ declare class Array extends Tracer {
964
974
  item(): number;
965
975
  /** @private Internal plumbing method for Array / Tracer ops. */
966
976
  static _implRules(): typeof implRules;
977
+ /** @private */
967
978
  _realizeSource(): number;
979
+ /** @private Put this array on a new backend, asynchronously. */
980
+ _put(backend: Backend): Promise<Array>;
981
+ /** @private Put this array on a new backend, synchronously. */
982
+ _putSync(backend: Backend): Array;
968
983
  }
969
984
  /** Constructor for creating a new array from data. */
970
- declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
985
+ declare function array(values: Array | DataArray | RecursiveArray<number> | RecursiveArray<boolean>, {
971
986
  shape,
972
987
  dtype,
973
988
  device
@@ -1040,13 +1055,14 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
1040
1055
  device
1041
1056
  }?: DTypeAndDevice): Array;
1042
1057
  declare namespace numpy_d_exports {
1043
- export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1058
+ export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1044
1059
  }
1045
1060
  declare const float32 = DType.Float32;
1046
1061
  declare const int32 = DType.Int32;
1047
1062
  declare const uint32 = DType.Uint32;
1048
1063
  declare const bool = DType.Bool;
1049
1064
  declare const float16 = DType.Float16;
1065
+ declare const float64 = DType.Float64;
1050
1066
  /** Euler's constant, `e = 2.7182818284590...` */
1051
1067
  declare const e: number;
1052
1068
  /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
@@ -1123,7 +1139,7 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
1123
1139
  * all axes are padded with the same width. Or if it is an array of pairs, each
1124
1140
  * pair specifies the padding for its corresponding axis.
1125
1141
  */
1126
- declare const pad: (x: ArrayLike, width: number | [number, number] | [number, number][]) => Array;
1142
+ declare const pad: (x: ArrayLike, width: number | Pair | Pair[]) => Array;
1127
1143
  /**
1128
1144
  * @function
1129
1145
  * Return the number of dimensions of an array. Does not consume array reference.
@@ -1353,6 +1369,26 @@ declare function absolute(x: ArrayLike): Array;
1353
1369
  declare const abs: typeof absolute;
1354
1370
  /** Return an element-wise indication of sign of the input. */
1355
1371
  declare function sign(x: ArrayLike): Array;
1372
+ /**
1373
+ * Return the Hamming window of size M, a taper with a weighted cosine bell.
1374
+ *
1375
+ * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1376
+ */
1377
+ declare function hamming(M: number): Array;
1378
+ /**
1379
+ * Return the Hann window of size M, a taper with a weighted cosine bell.
1380
+ *
1381
+ * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1382
+ */
1383
+ declare function hann(M: number): Array;
1384
+ /**
1385
+ * @function
1386
+ * Compute the Heaviside step function. It is defined piecewise:
1387
+ * - `heaviside(x1, x2) = 0` for `x1 < 0`,
1388
+ * - `heaviside(x1, x2) = x2` for `x1 == 0`,
1389
+ * - `heaviside(x1, x2) = 1` for `x1 > 0`.
1390
+ */
1391
+ declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1356
1392
  /** Calculate element-wise square of the input array. */
1357
1393
  declare function square(x: ArrayLike): Array;
1358
1394
  /** Element-wise tangent function (takes radians). */
@@ -1364,8 +1400,8 @@ declare function acos(x: ArrayLike): Array;
1364
1400
  * Return element-wise hypotenuse for the given legs of a right triangle.
1365
1401
  *
1366
1402
  * In the original NumPy/JAX implementation, this function is more numerically
1367
- * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
1368
- * improvements.
1403
+ * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
1404
+ * stability improvements.
1369
1405
  */
1370
1406
  declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1371
1407
  /**
@@ -1497,6 +1533,20 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
1497
1533
  mean?: ArrayLike;
1498
1534
  correction?: number;
1499
1535
  } & ReduceOpts): Array;
1536
+ /** Test element-wise for positive or negative infinity, return bool array. */
1537
+ declare function isinf(x: ArrayLike): Array;
1538
+ /** Test element-wise for NaN (Not a Number). */
1539
+ declare function isnan(x: ArrayLike): Array;
1540
+ /** Test element-wise for negative infinity, return bool array. */
1541
+ declare function isneginf(x: ArrayLike): Array;
1542
+ /** Test element-wise for positive infinity, return bool array. */
1543
+ declare function isposinf(x: ArrayLike): Array;
1544
+ /**
1545
+ * @function
1546
+ * Test element-wise for finite values (not infinity or NaN).
1547
+ */
1548
+ declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
1549
+ //# sourceMappingURL=numpy.d.ts.map
1500
1550
  //#endregion
1501
1551
  //#region src/frontend/jaxpr.d.ts
1502
1552
  /**
@@ -1571,10 +1621,9 @@ declare class Jaxpr implements FpHashable {
1571
1621
  /** @inline */
1572
1622
  type JitOpts = {
1573
1623
  staticArgnums?: number[];
1574
- device?: Device;
1575
1624
  };
1576
1625
  declare namespace lax_d_exports {
1577
- export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
1626
+ export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, erf, erfc, reduceWindow, stopGradient };
1578
1627
  }
1579
1628
  type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
1580
1629
  /**
@@ -1598,6 +1647,23 @@ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: n
1598
1647
  declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
1599
1648
  /** Reduce a computation over padded windows. */
1600
1649
  declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
1650
+ /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
1651
+ declare function erf(x: ArrayLike): Array;
1652
+ /**
1653
+ * The complementary error function: `erfc(x) = 1 - erf(x)`.
1654
+ *
1655
+ * This function is more accurate than `1 - erf(x)` for large values of `x`,
1656
+ * where `erf(x)` is very close to 1.
1657
+ */
1658
+ declare function erfc(x: ArrayLike): Array;
1659
+ /**
1660
+ * Stops gradient computation.
1661
+ *
1662
+ * Behaves as the identity function but prevents the flow of gradients during
1663
+ * forward or reverse-mode automatic differentiation.
1664
+ */
1665
+ declare function stopGradient(x: ArrayLike): Array;
1666
+ //# sourceMappingURL=lax.d.ts.map
1601
1667
  declare namespace nn_d_exports {
1602
1668
  export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
1603
1669
  }
@@ -1682,15 +1748,17 @@ declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
1682
1748
  * @function
1683
1749
  * Gaussion error linear unit (GELU) activation function.
1684
1750
  *
1685
- * This is computed element-wise. Currently jax-js does not support the erf() or
1686
- * gelu() functions exactly as primitives, so an approximation is used:
1687
- * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
1751
+ * This is computed element-wise. There are two variants depending on whether
1752
+ * `approximate` is set (default true):
1688
1753
  *
1689
- * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
1754
+ * - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
1755
+ * - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
1690
1756
  *
1691
- * This will be improved in the future.
1757
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
1692
1758
  */
1693
- declare const gelu: OwnedFunction<(x: ArrayLike) => Array>;
1759
+ declare const gelu: OwnedFunction<(x: ArrayLike, opts?: {
1760
+ approximate?: boolean | undefined;
1761
+ } | undefined) => Array>;
1694
1762
  /**
1695
1763
  * Gated linear unit (GLU) activation function.
1696
1764
  *
@@ -1771,6 +1839,7 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
1771
1839
  * ```
1772
1840
  */
1773
1841
  declare function oneHot(x: Array, numClasses: number): Array;
1842
+ //# sourceMappingURL=nn.d.ts.map
1774
1843
  declare namespace random_d_exports {
1775
1844
  export { bernoulli, bits, exponential, key, normal, split, uniform };
1776
1845
  }
@@ -1809,6 +1878,15 @@ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | und
1809
1878
  * bitwise identical to JAX.
1810
1879
  */
1811
1880
  declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
1881
+ //# sourceMappingURL=random.d.ts.map
1882
+ declare namespace scipy_special_d_exports {
1883
+ export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
1884
+ }
1885
+ /**
1886
+ * @function
1887
+ * The logit function, `logit(p) = log(p / (1-p))`.
1888
+ */
1889
+ declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
1812
1890
  //#endregion
1813
1891
  //#region src/index.d.ts
1814
1892
  /**
@@ -1820,7 +1898,7 @@ declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals:
1820
1898
  * @function
1821
1899
  * Vectorize an operation on a batched axis for one or more inputs.
1822
1900
  */
1823
- declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1901
+ declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | (number | null | JsTree<number | null>)[]) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1824
1902
  /**
1825
1903
  * @function
1826
1904
  * Compute the Jacobian evaluated column-by-column by forward-mode AD.
@@ -1895,5 +1973,19 @@ declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJs
1895
1973
  * Does not consume reference to the arrays.
1896
1974
  */
1897
1975
  declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
1976
+ /**
1977
+ * Transfer `x` to `device`.
1978
+ *
1979
+ * `x` may be a nested container of arrays or scalars. The resulting structure
1980
+ * is committed to the device.
1981
+ *
1982
+ * If `device` is not specified, this function behaves as identity if the input
1983
+ * is already an `Array`, otherwise it places the scalar uncommitted on the
1984
+ * default device.
1985
+ */
1986
+ declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
1987
+ //# sourceMappingURL=index.d.ts.map
1988
+
1898
1989
  //#endregion
1899
- export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1990
+ export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1991
+ //# sourceMappingURL=index.d.ts.map