@jax-js/jax 0.0.5 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.cts CHANGED
@@ -124,6 +124,7 @@ declare class ShapeTracker {
124
124
  /** Like pad(), but allows for negative values. */
125
125
  padOrShrink(arg: Pair[]): ShapeTracker;
126
126
  }
127
+ //# sourceMappingURL=shape.d.ts.map
127
128
  //#endregion
128
129
  //#region src/utils.d.ts
129
130
  /**
@@ -222,6 +223,8 @@ declare class AluExp implements FpHashable {
222
223
  static atan(a: AluExp): AluExp;
223
224
  static exp(a: AluExp): AluExp;
224
225
  static log(a: AluExp): AluExp;
226
+ static erf(a: AluExp): AluExp;
227
+ static erfc(a: AluExp): AluExp;
225
228
  static sqrt(a: AluExp): AluExp;
226
229
  static reciprocal(a: AluExp): AluExp;
227
230
  static cast(dtype: DType, a: AluExp): AluExp;
@@ -289,8 +292,8 @@ declare class AluExp implements FpHashable {
289
292
  rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
290
293
  /** Collect all nodes that satisfy a predicate. */
291
294
  collect(predicate: (exp: AluExp) => boolean): AluExp[];
292
- /** Produce a list of all distinct AluOp in this expression. */
293
- distinctOps(): Set<AluOp>;
295
+ /** Produce all distinct AluOp in this expression, with their dtypes. */
296
+ distinctOps(): Map<AluOp, Set<DType>>;
294
297
  /** Rewrite GlobalView operations to GlobalIndex operations. */
295
298
  rewriteGlobalViews(): AluExp;
296
299
  }
@@ -309,6 +312,8 @@ declare enum AluOp {
309
312
  Atan = "Atan",
310
313
  Exp = "Exp",
311
314
  Log = "Log",
315
+ Erf = "Erf",
316
+ Erfc = "Erfc",
312
317
  Sqrt = "Sqrt",
313
318
  Reciprocal = "Reciprocal",
314
319
  Cast = "Cast",
@@ -465,7 +470,7 @@ type JsTree<T> = T | JsTree<T>[] | {
465
470
  [key: string]: JsTree<T>;
466
471
  };
467
472
  type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
468
- type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
473
+ type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
469
474
  /** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
470
475
  type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
471
476
  /** Represents the structure of a JsTree. */
@@ -477,6 +482,8 @@ declare class JsTreeDef {
477
482
  constructor(nodeType: NodeType, nodeMetadata: any,
478
483
  // Must be comparable with deepEqual.
479
484
  childTreedefs: JsTreeDef[]);
485
+ /** Get the total number of leaves in the tree. */
486
+ get size(): number;
480
487
  /** Returns a string representation of this tree definition. */
481
488
  toString(root?: boolean): string;
482
489
  /** Compare this tree definition with another. */
@@ -540,6 +547,8 @@ declare enum Primitive {
540
547
  Atan = "atan",
541
548
  Exp = "exp",
542
549
  Log = "log",
550
+ Erf = "erf",
551
+ Erfc = "erfc",
543
552
  Sqrt = "sqrt",
544
553
  Min = "min",
545
554
  Max = "max",
@@ -886,12 +895,13 @@ type ArrayConstructorArgs = {
886
895
  dtype: DType;
887
896
  weakType: boolean;
888
897
  backend: Backend;
898
+ committed: boolean;
889
899
  pending?: Iterable<PendingExecute>;
890
900
  };
891
901
  /**
892
902
  * A multidimensional numeric array with data stored on CPU or GPU.
893
903
  *
894
- * This is the library's core data type. Equivalent to `jnp.Array` from JAX, or
904
+ * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
895
905
  * `torch.Tensor`.
896
906
  *
897
907
  * Not to be confused with the JavaScript "Array" constructor. Avoid importing
@@ -967,7 +977,12 @@ declare class Array extends Tracer {
967
977
  item(): number;
968
978
  /** @private Internal plumbing method for Array / Tracer ops. */
969
979
  static _implRules(): typeof implRules;
980
+ /** @private */
970
981
  _realizeSource(): number;
982
+ /** @private Put this array on a new backend, asynchronously. */
983
+ _put(backend: Backend): Promise<Array>;
984
+ /** @private Put this array on a new backend, synchronously. */
985
+ _putSync(backend: Backend): Array;
971
986
  }
972
987
  /** Constructor for creating a new array from data. */
973
988
  declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
@@ -1043,7 +1058,7 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
1043
1058
  device
1044
1059
  }?: DTypeAndDevice): Array;
1045
1060
  declare namespace numpy_d_exports {
1046
- export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1061
+ export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1047
1062
  }
1048
1063
  declare const float32 = DType.Float32;
1049
1064
  declare const int32 = DType.Int32;
@@ -1126,7 +1141,7 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
1126
1141
  * all axes are padded with the same width. Or if it is an array of pairs, each
1127
1142
  * pair specifies the padding for its corresponding axis.
1128
1143
  */
1129
- declare const pad: (x: ArrayLike, width: number | [number, number] | [number, number][]) => Array;
1144
+ declare const pad: (x: ArrayLike, width: number | Pair | Pair[]) => Array;
1130
1145
  /**
1131
1146
  * @function
1132
1147
  * Return the number of dimensions of an array. Does not consume array reference.
@@ -1356,6 +1371,26 @@ declare function absolute(x: ArrayLike): Array;
1356
1371
  declare const abs: typeof absolute;
1357
1372
  /** Return an element-wise indication of sign of the input. */
1358
1373
  declare function sign(x: ArrayLike): Array;
1374
+ /**
1375
+ * Return the Hamming window of size M, a taper with a weighted cosine bell.
1376
+ *
1377
+ * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1378
+ */
1379
+ declare function hamming(M: number): Array;
1380
+ /**
1381
+ * Return the Hann window of size M, a taper with a weighted cosine bell.
1382
+ *
1383
+ * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1384
+ */
1385
+ declare function hann(M: number): Array;
1386
+ /**
1387
+ * @function
1388
+ * Compute the Heaviside step function. It is defined piecewise:
1389
+ * - `heaviside(x1, x2) = 0` for `x1 < 0`,
1390
+ * - `heaviside(x1, x2) = x2` for `x1 == 0`,
1391
+ * - `heaviside(x1, x2) = 1` for `x1 > 0`.
1392
+ */
1393
+ declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1359
1394
  /** Calculate element-wise square of the input array. */
1360
1395
  declare function square(x: ArrayLike): Array;
1361
1396
  /** Element-wise tangent function (takes radians). */
@@ -1367,8 +1402,8 @@ declare function acos(x: ArrayLike): Array;
1367
1402
  * Return element-wise hypotenuse for the given legs of a right triangle.
1368
1403
  *
1369
1404
  * In the original NumPy/JAX implementation, this function is more numerically
1370
- * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
1371
- * improvements.
1405
+ * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
1406
+ * stability improvements.
1372
1407
  */
1373
1408
  declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1374
1409
  /**
@@ -1500,6 +1535,7 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
1500
1535
  mean?: ArrayLike;
1501
1536
  correction?: number;
1502
1537
  } & ReduceOpts): Array;
1538
+ //# sourceMappingURL=numpy.d.ts.map
1503
1539
  //#endregion
1504
1540
  //#region src/frontend/jaxpr.d.ts
1505
1541
  /**
@@ -1574,10 +1610,9 @@ declare class Jaxpr implements FpHashable {
1574
1610
  /** @inline */
1575
1611
  type JitOpts = {
1576
1612
  staticArgnums?: number[];
1577
- device?: Device;
1578
1613
  };
1579
1614
  declare namespace lax_d_exports {
1580
- export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
1615
+ export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, erf, erfc, reduceWindow, stopGradient };
1581
1616
  }
1582
1617
  type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
1583
1618
  /**
@@ -1601,6 +1636,23 @@ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: n
1601
1636
  declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
1602
1637
  /** Reduce a computation over padded windows. */
1603
1638
  declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
1639
+ /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
1640
+ declare function erf(x: ArrayLike): Array;
1641
+ /**
1642
+ * The complementary error function: `erfc(x) = 1 - erf(x)`.
1643
+ *
1644
+ * This function is more accurate than `1 - erf(x)` for large values of `x`,
1645
+ * where `erf(x)` is very close to 1.
1646
+ */
1647
+ declare function erfc(x: ArrayLike): Array;
1648
+ /**
1649
+ * Stops gradient computation.
1650
+ *
1651
+ * Behaves as the identity function but prevents the flow of gradients during
1652
+ * forward or reverse-mode automatic differentiation.
1653
+ */
1654
+ declare function stopGradient(x: ArrayLike): Array;
1655
+ //# sourceMappingURL=lax.d.ts.map
1604
1656
  declare namespace nn_d_exports {
1605
1657
  export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
1606
1658
  }
@@ -1685,15 +1737,17 @@ declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
1685
1737
  * @function
1686
1738
  * Gaussian error linear unit (GELU) activation function.
1687
1739
  *
1688
- * This is computed element-wise. Currently jax-js does not support the erf() or
1689
- * gelu() functions exactly as primitives, so an approximation is used:
1690
- * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
1740
+ * This is computed element-wise. There are two variants depending on whether
1741
+ * `approximate` is set (default true):
1691
1742
  *
1692
- * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
1743
+ * - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
1744
+ * - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
1693
1745
  *
1694
- * This will be improved in the future.
1746
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
1695
1747
  */
1696
- declare const gelu: OwnedFunction<(x: ArrayLike) => Array>;
1748
+ declare const gelu: OwnedFunction<(x: ArrayLike, opts?: {
1749
+ approximate?: boolean | undefined;
1750
+ } | undefined) => Array>;
1697
1751
  /**
1698
1752
  * Gated linear unit (GLU) activation function.
1699
1753
  *
@@ -1774,6 +1828,7 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
1774
1828
  * ```
1775
1829
  */
1776
1830
  declare function oneHot(x: Array, numClasses: number): Array;
1831
+ //# sourceMappingURL=nn.d.ts.map
1777
1832
  declare namespace random_d_exports {
1778
1833
  export { bernoulli, bits, exponential, key, normal, split, uniform };
1779
1834
  }
@@ -1812,6 +1867,15 @@ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | und
1812
1867
  * bitwise identical to JAX.
1813
1868
  */
1814
1869
  declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
1870
+ //# sourceMappingURL=random.d.ts.map
1871
+ declare namespace scipy_special_d_exports {
1872
+ export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
1873
+ }
1874
+ /**
1875
+ * @function
1876
+ * The logit function, `logit(p) = log(p / (1-p))`.
1877
+ */
1878
+ declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
1815
1879
  //#endregion
1816
1880
  //#region src/index.d.ts
1817
1881
  /**
@@ -1823,7 +1887,7 @@ declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals:
1823
1887
  * @function
1824
1888
  * Vectorize an operation on a batched axis for one or more inputs.
1825
1889
  */
1826
- declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1890
+ declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | (number | null | JsTree<number | null>)[]) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1827
1891
  /**
1828
1892
  * @function
1829
1893
  * Compute the Jacobian evaluated column-by-column by forward-mode AD.
@@ -1898,5 +1962,19 @@ declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJs
1898
1962
  * Does not consume reference to the arrays.
1899
1963
  */
1900
1964
  declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
1965
+ /**
1966
+ * Transfer `x` to `device`.
1967
+ *
1968
+ * `x` may be a nested container of arrays or scalars. The resulting structure
1969
+ * is committed to the device.
1970
+ *
1971
+ * If `device` is not specified, this function behaves as identity if the input
1972
+ * is already an `Array`, otherwise it places the scalar uncommitted on the
1973
+ * default device.
1974
+ */
1975
+ declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
1976
+ //# sourceMappingURL=index.d.ts.map
1977
+
1901
1978
  //#endregion
1902
- export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1979
+ export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1980
+ //# sourceMappingURL=index.d.cts.map
package/dist/index.d.ts CHANGED
@@ -121,6 +121,7 @@ declare class ShapeTracker {
121
121
  /** Like pad(), but allows for negative values. */
122
122
  padOrShrink(arg: Pair[]): ShapeTracker;
123
123
  }
124
+ //# sourceMappingURL=shape.d.ts.map
124
125
  //#endregion
125
126
  //#region src/utils.d.ts
126
127
  /**
@@ -219,6 +220,8 @@ declare class AluExp implements FpHashable {
219
220
  static atan(a: AluExp): AluExp;
220
221
  static exp(a: AluExp): AluExp;
221
222
  static log(a: AluExp): AluExp;
223
+ static erf(a: AluExp): AluExp;
224
+ static erfc(a: AluExp): AluExp;
222
225
  static sqrt(a: AluExp): AluExp;
223
226
  static reciprocal(a: AluExp): AluExp;
224
227
  static cast(dtype: DType, a: AluExp): AluExp;
@@ -286,8 +289,8 @@ declare class AluExp implements FpHashable {
286
289
  rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
287
290
  /** Collect all nodes that satisfy a predicate. */
288
291
  collect(predicate: (exp: AluExp) => boolean): AluExp[];
289
- /** Produce a list of all distinct AluOp in this expression. */
290
- distinctOps(): Set<AluOp>;
292
+ /** Produce all distinct AluOp in this expression, with their dtypes. */
293
+ distinctOps(): Map<AluOp, Set<DType>>;
291
294
  /** Rewrite GlobalView operations to GlobalIndex operations. */
292
295
  rewriteGlobalViews(): AluExp;
293
296
  }
@@ -306,6 +309,8 @@ declare enum AluOp {
306
309
  Atan = "Atan",
307
310
  Exp = "Exp",
308
311
  Log = "Log",
312
+ Erf = "Erf",
313
+ Erfc = "Erfc",
309
314
  Sqrt = "Sqrt",
310
315
  Reciprocal = "Reciprocal",
311
316
  Cast = "Cast",
@@ -462,7 +467,7 @@ type JsTree<T> = T | JsTree<T>[] | {
462
467
  [key: string]: JsTree<T>;
463
468
  };
464
469
  type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
465
- type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
470
+ type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
466
471
  /** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
467
472
  type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
468
473
  /** Represents the structure of a JsTree. */
@@ -474,6 +479,8 @@ declare class JsTreeDef {
474
479
  constructor(nodeType: NodeType, nodeMetadata: any,
475
480
  // Must be comparable with deepEqual.
476
481
  childTreedefs: JsTreeDef[]);
482
+ /** Get the total number of leaves in the tree. */
483
+ get size(): number;
477
484
  /** Returns a string representation of this tree definition. */
478
485
  toString(root?: boolean): string;
479
486
  /** Compare this tree definition with another. */
@@ -537,6 +544,8 @@ declare enum Primitive {
537
544
  Atan = "atan",
538
545
  Exp = "exp",
539
546
  Log = "log",
547
+ Erf = "erf",
548
+ Erfc = "erfc",
540
549
  Sqrt = "sqrt",
541
550
  Min = "min",
542
551
  Max = "max",
@@ -883,12 +892,13 @@ type ArrayConstructorArgs = {
883
892
  dtype: DType;
884
893
  weakType: boolean;
885
894
  backend: Backend;
895
+ committed: boolean;
886
896
  pending?: Iterable<PendingExecute>;
887
897
  };
888
898
  /**
889
899
  * A multidimensional numeric array with data stored on CPU or GPU.
890
900
  *
891
- * This is the library's core data type. Equivalent to `jnp.Array` from JAX, or
901
+ * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
892
902
  * `torch.Tensor`.
893
903
  *
894
904
  * Not to be confused with the JavaScript "Array" constructor. Avoid importing
@@ -964,7 +974,12 @@ declare class Array extends Tracer {
964
974
  item(): number;
965
975
  /** @private Internal plumbing method for Array / Tracer ops. */
966
976
  static _implRules(): typeof implRules;
977
+ /** @private */
967
978
  _realizeSource(): number;
979
+ /** @private Put this array on a new backend, asynchronously. */
980
+ _put(backend: Backend): Promise<Array>;
981
+ /** @private Put this array on a new backend, synchronously. */
982
+ _putSync(backend: Backend): Array;
968
983
  }
969
984
  /** Constructor for creating a new array from data. */
970
985
  declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
@@ -1040,7 +1055,7 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
1040
1055
  device
1041
1056
  }?: DTypeAndDevice): Array;
1042
1057
  declare namespace numpy_d_exports {
1043
- export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1058
+ export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1044
1059
  }
1045
1060
  declare const float32 = DType.Float32;
1046
1061
  declare const int32 = DType.Int32;
@@ -1123,7 +1138,7 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
1123
1138
  * all axes are padded with the same width. Or if it is an array of pairs, each
1124
1139
  * pair specifies the padding for its corresponding axis.
1125
1140
  */
1126
- declare const pad: (x: ArrayLike, width: number | [number, number] | [number, number][]) => Array;
1141
+ declare const pad: (x: ArrayLike, width: number | Pair | Pair[]) => Array;
1127
1142
  /**
1128
1143
  * @function
1129
1144
  * Return the number of dimensions of an array. Does not consume array reference.
@@ -1353,6 +1368,26 @@ declare function absolute(x: ArrayLike): Array;
1353
1368
  declare const abs: typeof absolute;
1354
1369
  /** Return an element-wise indication of sign of the input. */
1355
1370
  declare function sign(x: ArrayLike): Array;
1371
+ /**
1372
+ * Return the Hamming window of size M, a taper with a weighted cosine bell.
1373
+ *
1374
+ * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1375
+ */
1376
+ declare function hamming(M: number): Array;
1377
+ /**
1378
+ * Return the Hann window of size M, a taper with a weighted cosine bell.
1379
+ *
1380
+ * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1381
+ */
1382
+ declare function hann(M: number): Array;
1383
+ /**
1384
+ * @function
1385
+ * Compute the Heaviside step function. It is defined piecewise:
1386
+ * - `heaviside(x1, x2) = 0` for `x1 < 0`,
1387
+ * - `heaviside(x1, x2) = x2` for `x1 == 0`,
1388
+ * - `heaviside(x1, x2) = 1` for `x1 > 0`.
1389
+ */
1390
+ declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1356
1391
  /** Calculate element-wise square of the input array. */
1357
1392
  declare function square(x: ArrayLike): Array;
1358
1393
  /** Element-wise tangent function (takes radians). */
@@ -1364,8 +1399,8 @@ declare function acos(x: ArrayLike): Array;
1364
1399
  * Return element-wise hypotenuse for the given legs of a right triangle.
1365
1400
  *
1366
1401
  * In the original NumPy/JAX implementation, this function is more numerically
1367
- * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
1368
- * improvements.
1402
+ * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
1403
+ * stability improvements.
1369
1404
  */
1370
1405
  declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1371
1406
  /**
@@ -1497,6 +1532,7 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
1497
1532
  mean?: ArrayLike;
1498
1533
  correction?: number;
1499
1534
  } & ReduceOpts): Array;
1535
+ //# sourceMappingURL=numpy.d.ts.map
1500
1536
  //#endregion
1501
1537
  //#region src/frontend/jaxpr.d.ts
1502
1538
  /**
@@ -1571,10 +1607,9 @@ declare class Jaxpr implements FpHashable {
1571
1607
  /** @inline */
1572
1608
  type JitOpts = {
1573
1609
  staticArgnums?: number[];
1574
- device?: Device;
1575
1610
  };
1576
1611
  declare namespace lax_d_exports {
1577
- export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
1612
+ export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, erf, erfc, reduceWindow, stopGradient };
1578
1613
  }
1579
1614
  type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
1580
1615
  /**
@@ -1598,6 +1633,23 @@ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: n
1598
1633
  declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
1599
1634
  /** Reduce a computation over padded windows. */
1600
1635
  declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
1636
+ /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
1637
+ declare function erf(x: ArrayLike): Array;
1638
+ /**
1639
+ * The complementary error function: `erfc(x) = 1 - erf(x)`.
1640
+ *
1641
+ * This function is more accurate than `1 - erf(x)` for large values of `x`,
1642
+ * where `erf(x)` is very close to 1.
1643
+ */
1644
+ declare function erfc(x: ArrayLike): Array;
1645
+ /**
1646
+ * Stops gradient computation.
1647
+ *
1648
+ * Behaves as the identity function but prevents the flow of gradients during
1649
+ * forward or reverse-mode automatic differentiation.
1650
+ */
1651
+ declare function stopGradient(x: ArrayLike): Array;
1652
+ //# sourceMappingURL=lax.d.ts.map
1601
1653
  declare namespace nn_d_exports {
1602
1654
  export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
1603
1655
  }
@@ -1682,15 +1734,17 @@ declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
1682
1734
  * @function
1683
1735
  * Gaussian error linear unit (GELU) activation function.
1684
1736
  *
1685
- * This is computed element-wise. Currently jax-js does not support the erf() or
1686
- * gelu() functions exactly as primitives, so an approximation is used:
1687
- * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
1737
+ * This is computed element-wise. There are two variants depending on whether
1738
+ * `approximate` is set (default true):
1688
1739
  *
1689
- * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
1740
+ * - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
1741
+ * - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
1690
1742
  *
1691
- * This will be improved in the future.
1743
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
1692
1744
  */
1693
- declare const gelu: OwnedFunction<(x: ArrayLike) => Array>;
1745
+ declare const gelu: OwnedFunction<(x: ArrayLike, opts?: {
1746
+ approximate?: boolean | undefined;
1747
+ } | undefined) => Array>;
1694
1748
  /**
1695
1749
  * Gated linear unit (GLU) activation function.
1696
1750
  *
@@ -1771,6 +1825,7 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
1771
1825
  * ```
1772
1826
  */
1773
1827
  declare function oneHot(x: Array, numClasses: number): Array;
1828
+ //# sourceMappingURL=nn.d.ts.map
1774
1829
  declare namespace random_d_exports {
1775
1830
  export { bernoulli, bits, exponential, key, normal, split, uniform };
1776
1831
  }
@@ -1809,6 +1864,15 @@ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | und
1809
1864
  * bitwise identical to JAX.
1810
1865
  */
1811
1866
  declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
1867
+ //# sourceMappingURL=random.d.ts.map
1868
+ declare namespace scipy_special_d_exports {
1869
+ export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
1870
+ }
1871
+ /**
1872
+ * @function
1873
+ * The logit function, `logit(p) = log(p / (1-p))`.
1874
+ */
1875
+ declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
1812
1876
  //#endregion
1813
1877
  //#region src/index.d.ts
1814
1878
  /**
@@ -1820,7 +1884,7 @@ declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals:
1820
1884
  * @function
1821
1885
  * Vectorize an operation on a batched axis for one or more inputs.
1822
1886
  */
1823
- declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1887
+ declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | (number | null | JsTree<number | null>)[]) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1824
1888
  /**
1825
1889
  * @function
1826
1890
  * Compute the Jacobian evaluated column-by-column by forward-mode AD.
@@ -1895,5 +1959,19 @@ declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJs
1895
1959
  * Does not consume reference to the arrays.
1896
1960
  */
1897
1961
  declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
1962
+ /**
1963
+ * Transfer `x` to `device`.
1964
+ *
1965
+ * `x` may be a nested container of arrays or scalars. The resulting structure
1966
+ * is committed to the device.
1967
+ *
1968
+ * If `device` is not specified, this function behaves as identity if the input
1969
+ * is already an `Array`, otherwise it places the scalar uncommitted on the
1970
+ * default device.
1971
+ */
1972
+ declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
1973
+ //# sourceMappingURL=index.d.ts.map
1974
+
1898
1975
  //#endregion
1899
- export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1976
+ export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1977
+ //# sourceMappingURL=index.d.ts.map