@jax-js/jax 0.0.4 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.ts CHANGED
@@ -121,6 +121,7 @@ declare class ShapeTracker {
121
121
  /** Like pad(), but allows for negative values. */
122
122
  padOrShrink(arg: Pair[]): ShapeTracker;
123
123
  }
124
+ //# sourceMappingURL=shape.d.ts.map
124
125
  //#endregion
125
126
  //#region src/utils.d.ts
126
127
  /**
@@ -177,12 +178,12 @@ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Arr
177
178
  * **Type lattice:**
178
179
  * ```text
179
180
  * bool -> uint32 -> int32 -> float16 -> float32
180
- * weak f* --^
181
+ * weakType --^
181
182
  * ```
182
183
  *
183
- * The asterisk f* is a weak type used for JS number constants. When creating
184
- * arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
185
- * any array they are first combined with.
184
+ * `weakType` represents weakly typed arrays. These are created for JS numbers,
185
+ * which default to float32 but are "weak": they cast to the dtype of any array
186
+ * they are first combined with, unless that array is `bool`.
186
187
  *
187
188
  * **Examples:**
188
189
  * - `promoteTypes(bool, int32) → int32`
@@ -219,6 +220,8 @@ declare class AluExp implements FpHashable {
219
220
  static atan(a: AluExp): AluExp;
220
221
  static exp(a: AluExp): AluExp;
221
222
  static log(a: AluExp): AluExp;
223
+ static erf(a: AluExp): AluExp;
224
+ static erfc(a: AluExp): AluExp;
222
225
  static sqrt(a: AluExp): AluExp;
223
226
  static reciprocal(a: AluExp): AluExp;
224
227
  static cast(dtype: DType, a: AluExp): AluExp;
@@ -286,8 +289,8 @@ declare class AluExp implements FpHashable {
286
289
  rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
287
290
  /** Collect all nodes that satisfy a predicate. */
288
291
  collect(predicate: (exp: AluExp) => boolean): AluExp[];
289
- /** Produce a list of all distinct AluOp in this expression. */
290
- distinctOps(): Set<AluOp>;
292
+ /** Produce all distinct AluOps in this expression, each mapped to its set of dtypes. */
293
+ distinctOps(): Map<AluOp, Set<DType>>;
291
294
  /** Rewrite GlobalView operations to GlobalIndex operations. */
292
295
  rewriteGlobalViews(): AluExp;
293
296
  }
@@ -306,6 +309,8 @@ declare enum AluOp {
306
309
  Atan = "Atan",
307
310
  Exp = "Exp",
308
311
  Log = "Log",
312
+ Erf = "Erf",
313
+ Erfc = "Erfc",
309
314
  Sqrt = "Sqrt",
310
315
  Reciprocal = "Reciprocal",
311
316
  Cast = "Cast",
@@ -462,7 +467,7 @@ type JsTree<T> = T | JsTree<T>[] | {
462
467
  [key: string]: JsTree<T>;
463
468
  };
464
469
  type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
465
- type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
470
+ type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
466
471
  /** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
467
472
  type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
468
473
  /** Represents the structure of a JsTree. */
@@ -474,6 +479,8 @@ declare class JsTreeDef {
474
479
  constructor(nodeType: NodeType, nodeMetadata: any,
475
480
  // Must be comparable with deepEqual.
476
481
  childTreedefs: JsTreeDef[]);
482
+ /** Get the total number of leaves in the tree. */
483
+ get size(): number;
477
484
  /** Returns a string representation of this tree definition. */
478
485
  toString(root?: boolean): string;
479
486
  /** Compare this tree definition with another. */
@@ -537,6 +544,8 @@ declare enum Primitive {
537
544
  Atan = "atan",
538
545
  Exp = "exp",
539
546
  Log = "log",
547
+ Erf = "erf",
548
+ Erfc = "erfc",
540
549
  Sqrt = "sqrt",
541
550
  Min = "min",
542
551
  Max = "max",
@@ -610,6 +619,7 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
610
619
  outDim: number;
611
620
  };
612
621
  [Primitive.JitCall]: {
622
+ name: string;
613
623
  jaxpr: Jaxpr;
614
624
  numConsts: number;
615
625
  };
@@ -648,10 +658,40 @@ declare abstract class Trace {
648
658
  abstract lift(val: Tracer): Tracer;
649
659
  abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
650
660
  }
661
+ /** Internal representation of an array value. */
651
662
  interface AbstractValue {
663
+ /** Shape of the array. Must be a static tuple of non-negative dimensions. */
652
664
  shape: number[];
665
+ /** Concrete data type of array elements. */
653
666
  dtype: DType;
667
+ /**
668
+ * Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
669
+ * _weakly typed_ unless a dtype is explicitly specified.
670
+ *
671
+ * Weakly typed values will automatically cast to the data type of other
672
+ * arrays when used as an operand in an expression. This property only affects
673
+ * how they promote in type casting; their memory layout is still determined
674
+ * by the actual `dtype` field.
675
+ *
676
+ * ```ts
677
+ * const x = np.array(3); // weakType = true, dtype = float32
678
+ * const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
679
+ * const z = x.add(y); // z has dtype int32 because x is weakly typed
680
+ * ```
681
+ *
682
+ * Weak types appear in the specs of JIT programs (e.g., Jaxpr inputs and
683
+ * outputs can be weakly typed), but they are solely a frontend concept.
684
+ * Backends are not aware of weak types.
685
+ */
686
+ weakType: boolean;
654
687
  }
688
+ /*
689
+ * Broadcast shapes and promote types with casting for two avals.
690
+ *
691
+ * This implements the weak-type behavior described in `promoteTypes()`; that
692
+ * function cannot apply it itself because it is not passed `weakType`.
693
+ */
694
+
655
695
  declare abstract class Tracer {
656
696
  /** @ignore */
657
697
  readonly _trace: Trace;
@@ -709,8 +749,15 @@ declare abstract class Tracer {
709
749
  get shape(): number[];
710
750
  /** The total number of elements in the array. */
711
751
  get size(): number;
712
- /** The dtype of the array. */
752
+ /** The dtype of elements stored in the array. */
713
753
  get dtype(): DType;
754
+ /**
755
+ * Whether the array is weakly typed.
756
+ *
757
+ * Weakly typed arrays will cast to the dtype of the other operand. See
758
+ * `promoteTypes()` for details.
759
+ */
760
+ get weakType(): boolean;
714
761
  /** The number of dimensions of the array. */
715
762
  get ndim(): number;
716
763
  /** @ignore */
@@ -802,7 +849,8 @@ declare abstract class Tracer {
802
849
  declare class ShapedArray implements AbstractValue {
803
850
  readonly shape: number[];
804
851
  readonly dtype: DType;
805
- constructor(shape: number[], dtype: DType);
852
+ readonly weakType: boolean;
853
+ constructor(shape: number[], dtype: DType, weakType: boolean);
806
854
  static fromAval(aval: AbstractValue): ShapedArray;
807
855
  get ndim(): number;
808
856
  toString(): string;
@@ -838,10 +886,19 @@ type DTypeAndDevice = {
838
886
  dtype?: DType;
839
887
  device?: Device;
840
888
  };
889
+ type ArrayConstructorArgs = {
890
+ source: AluExp | Slot;
891
+ st: ShapeTracker;
892
+ dtype: DType;
893
+ weakType: boolean;
894
+ backend: Backend;
895
+ committed: boolean;
896
+ pending?: Iterable<PendingExecute>;
897
+ };
841
898
  /**
842
899
  * A multidimensional numeric array with data stored on CPU or GPU.
843
900
  *
844
- * This is the library's core data type. Equivalent to `jnp.Array` from JAX, or
901
+ * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
845
902
  * `torch.Tensor`.
846
903
  *
847
904
  * Not to be confused with the JavaScript "Array" constructor. Avoid importing
@@ -857,11 +914,7 @@ declare class Array extends Tracer {
857
914
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
858
915
  * will be freed when the array is disposed.
859
916
  */
860
- constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend, {
861
- pending
862
- }?: {
863
- pending?: Iterable<PendingExecute> | null;
864
- });
917
+ constructor(args: ArrayConstructorArgs);
865
918
  /** @ignore */
866
919
  get aval(): ShapedArray;
867
920
  /** Return a simple string representation of the array's dimensions. */
@@ -921,10 +974,13 @@ declare class Array extends Tracer {
921
974
  item(): number;
922
975
  /** @private Internal plumbing method for Array / Tracer ops. */
923
976
  static _implRules(): typeof implRules;
977
+ /** @private */
924
978
  _realizeSource(): number;
979
+ /** @private Put this array on a new backend, asynchronously. */
980
+ _put(backend: Backend): Promise<Array>;
981
+ /** @private Put this array on a new backend, synchronously. */
982
+ _putSync(backend: Backend): Array;
925
983
  }
926
- /** Construct an array from a single scalar constant. */
927
-
928
984
  /** Constructor for creating a new array from data. */
929
985
  declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
930
986
  shape,
@@ -999,7 +1055,7 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
999
1055
  device
1000
1056
  }?: DTypeAndDevice): Array;
1001
1057
  declare namespace numpy_d_exports {
1002
- export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1058
+ export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1003
1059
  }
1004
1060
  declare const float32 = DType.Float32;
1005
1061
  declare const int32 = DType.Int32;
@@ -1082,7 +1138,7 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
1082
1138
  * all axes are padded with the same width. Or if it is an array of pairs, each
1083
1139
  * pair specifies the padding for its corresponding axis.
1084
1140
  */
1085
- declare const pad: (x: ArrayLike, width: number | [number, number] | [number, number][]) => Array;
1141
+ declare const pad: (x: ArrayLike, width: number | Pair | Pair[]) => Array;
1086
1142
  /**
1087
1143
  * @function
1088
1144
  * Return the number of dimensions of an array. Does not consume array reference.
@@ -1312,6 +1368,26 @@ declare function absolute(x: ArrayLike): Array;
1312
1368
  declare const abs: typeof absolute;
1313
1369
  /** Return an element-wise indication of sign of the input. */
1314
1370
  declare function sign(x: ArrayLike): Array;
1371
+ /**
1372
+ * Return the Hamming window of size M, a taper with a weighted cosine bell.
1373
+ *
1374
+ * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1375
+ */
1376
+ declare function hamming(M: number): Array;
1377
+ /**
1378
+ * Return the Hann window of size M, a taper with a weighted cosine bell.
1379
+ *
1380
+ * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1381
+ */
1382
+ declare function hann(M: number): Array;
1383
+ /**
1384
+ * @function
1385
+ * Compute the Heaviside step function. It is defined piecewise:
1386
+ * - `heaviside(x1, x2) = 0` for `x1 < 0`,
1387
+ * - `heaviside(x1, x2) = x2` for `x1 == 0`,
1388
+ * - `heaviside(x1, x2) = 1` for `x1 > 0`.
1389
+ */
1390
+ declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1315
1391
  /** Calculate element-wise square of the input array. */
1316
1392
  declare function square(x: ArrayLike): Array;
1317
1393
  /** Element-wise tangent function (takes radians). */
@@ -1323,8 +1399,8 @@ declare function acos(x: ArrayLike): Array;
1323
1399
  * Return element-wise hypotenuse for the given legs of a right triangle.
1324
1400
  *
1325
1401
  * In the original NumPy/JAX implementation, this function is more numerically
1326
- * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
1327
- * improvements.
1402
+ * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
1403
+ * stability improvements.
1328
1404
  */
1329
1405
  declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1330
1406
  /**
@@ -1456,6 +1532,7 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
1456
1532
  mean?: ArrayLike;
1457
1533
  correction?: number;
1458
1534
  } & ReduceOpts): Array;
1535
+ //# sourceMappingURL=numpy.d.ts.map
1459
1536
  //#endregion
1460
1537
  //#region src/frontend/jaxpr.d.ts
1461
1538
  /**
@@ -1477,10 +1554,10 @@ declare class Var {
1477
1554
  }
1478
1555
  /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
1479
1556
  declare class Lit {
1480
- readonly dtype: DType;
1481
1557
  readonly value: number;
1482
1558
  readonly aval: ShapedArray;
1483
- constructor(dtype: DType, value: number);
1559
+ get dtype(): DType;
1560
+ constructor(aval: AbstractValue, value: number);
1484
1561
  }
1485
1562
  type Atom = Var | Lit;
1486
1563
  declare class VarPrinter {
@@ -1530,10 +1607,9 @@ declare class Jaxpr implements FpHashable {
1530
1607
  /** @inline */
1531
1608
  type JitOpts = {
1532
1609
  staticArgnums?: number[];
1533
- device?: Device;
1534
1610
  };
1535
1611
  declare namespace lax_d_exports {
1536
- export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
1612
+ export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, erf, erfc, reduceWindow, stopGradient };
1537
1613
  }
1538
1614
  type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
1539
1615
  /**
@@ -1557,6 +1633,23 @@ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: n
1557
1633
  declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
1558
1634
  /** Reduce a computation over padded windows. */
1559
1635
  declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
1636
+ /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
1637
+ declare function erf(x: ArrayLike): Array;
1638
+ /**
1639
+ * The complementary error function: `erfc(x) = 1 - erf(x)`.
1640
+ *
1641
+ * This function is more accurate than `1 - erf(x)` for large values of `x`,
1642
+ * where `erf(x)` is very close to 1.
1643
+ */
1644
+ declare function erfc(x: ArrayLike): Array;
1645
+ /**
1646
+ * Stops gradient computation.
1647
+ *
1648
+ * Behaves as the identity function but prevents the flow of gradients during
1649
+ * forward or reverse-mode automatic differentiation.
1650
+ */
1651
+ declare function stopGradient(x: ArrayLike): Array;
1652
+ //# sourceMappingURL=lax.d.ts.map
1560
1653
  declare namespace nn_d_exports {
1561
1654
  export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
1562
1655
  }
@@ -1641,15 +1734,17 @@ declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
1641
1734
  * @function
1642
1735
  * Gaussian error linear unit (GELU) activation function.
1643
1736
  *
1644
- * This is computed element-wise. Currently jax-js does not support the erf() or
1645
- * gelu() functions exactly as primitives, so an approximation is used:
1646
- * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
1737
+ * This is computed element-wise. There are two variants depending on whether
1738
+ * `approximate` is set (default true):
1647
1739
  *
1648
- * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
1740
+ * - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
1741
+ * - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
1649
1742
  *
1650
- * This will be improved in the future.
1743
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
1651
1744
  */
1652
- declare const gelu: OwnedFunction<(x: ArrayLike) => Array>;
1745
+ declare const gelu: OwnedFunction<(x: ArrayLike, opts?: {
1746
+ approximate?: boolean | undefined;
1747
+ } | undefined) => Array>;
1653
1748
  /**
1654
1749
  * Gated linear unit (GLU) activation function.
1655
1750
  *
@@ -1730,6 +1825,7 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
1730
1825
  * ```
1731
1826
  */
1732
1827
  declare function oneHot(x: Array, numClasses: number): Array;
1828
+ //# sourceMappingURL=nn.d.ts.map
1733
1829
  declare namespace random_d_exports {
1734
1830
  export { bernoulli, bits, exponential, key, normal, split, uniform };
1735
1831
  }
@@ -1739,14 +1835,14 @@ declare function key(seed: number): Array;
1739
1835
  declare function split(key: Array, num?: number | number[]): Array;
1740
1836
  /** Sample uniform bits in the form of unsigned integers. */
1741
1837
  declare function bits(key: Array, shape?: number[]): Array;
1742
- /** Sample uniform random values in [minval, maxval) with given shape. */
1743
- declare function uniform(key: Array, shape?: number[], {
1744
- minval,
1745
- maxval
1746
- }?: {
1747
- minval?: number;
1748
- maxval?: number;
1749
- }): Array;
1838
+ /**
1839
+ * @function
1840
+ * Sample uniform random values in [minval, maxval) with given shape.
1841
+ */
1842
+ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined, args_2?: {
1843
+ minval?: number | undefined;
1844
+ maxval?: number | undefined;
1845
+ } | undefined) => Array>;
1750
1846
  /**
1751
1847
  * Sample Bernoulli random variables with given mean (0,1 categorical).
1752
1848
  *
@@ -1754,16 +1850,29 @@ declare function uniform(key: Array, shape?: number[], {
1754
1850
  * and must be broadcastable to `shape`.
1755
1851
  */
1756
1852
  declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
1757
- /** Sample exponential random values according to `p(x) = exp(-x)`. */
1758
- declare function exponential(key: Array, shape?: number[]): Array;
1759
1853
  /**
1854
+ * @function
1855
+ * Sample exponential random values according to `p(x) = exp(-x)`.
1856
+ */
1857
+ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
1858
+ /**
1859
+ * @function
1760
1860
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
1761
1861
  *
1762
1862
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
1763
1863
  * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
1764
1864
  * bitwise identical to JAX.
1765
1865
  */
1766
- declare function normal(key: Array, shape?: number[]): Array;
1866
+ declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
1867
+ //# sourceMappingURL=random.d.ts.map
1868
+ declare namespace scipy_special_d_exports {
1869
+ export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
1870
+ }
1871
+ /**
1872
+ * @function
1873
+ * The logit function, `logit(p) = log(p / (1-p))`.
1874
+ */
1875
+ declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
1767
1876
  //#endregion
1768
1877
  //#region src/index.d.ts
1769
1878
  /**
@@ -1775,7 +1884,7 @@ declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals:
1775
1884
  * @function
1776
1885
  * Vectorize an operation on a batched axis for one or more inputs.
1777
1886
  */
1778
- declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1887
+ declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | (number | null | JsTree<number | null>)[]) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1779
1888
  /**
1780
1889
  * @function
1781
1890
  * Compute the Jacobian evaluated column-by-column by forward-mode AD.
@@ -1850,5 +1959,19 @@ declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJs
1850
1959
  * Does not consume reference to the arrays.
1851
1960
  */
1852
1961
  declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
1962
+ /**
1963
+ * Transfer `x` to `device`.
1964
+ *
1965
+ * `x` may be a nested container of arrays or scalars. The resulting structure
1966
+ * is committed to the device.
1967
+ *
1968
+ * If `device` is not specified, this function behaves as identity if the input
1969
+ * is already an `Array`, otherwise it places the scalar uncommitted on the
1970
+ * default device.
1971
+ */
1972
+ declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
1973
+ //# sourceMappingURL=index.d.ts.map
1974
+
1853
1975
  //#endregion
1854
- export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1976
+ export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1977
+ //# sourceMappingURL=index.d.ts.map