@jax-js/jax 0.0.4 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.cts CHANGED
@@ -124,6 +124,7 @@ declare class ShapeTracker {
124
124
  /** Like pad(), but allows for negative values. */
125
125
  padOrShrink(arg: Pair[]): ShapeTracker;
126
126
  }
127
+ //# sourceMappingURL=shape.d.ts.map
127
128
  //#endregion
128
129
  //#region src/utils.d.ts
129
130
  /**
@@ -180,12 +181,12 @@ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Arr
180
181
  * **Type lattice:**
181
182
  * ```text
182
183
  * bool -> uint32 -> int32 -> float16 -> float32
183
- * weak f* --^
184
+ * weakType --^
184
185
  * ```
185
186
  *
186
- * The asterisk f* is a weak type used for JS number constants. When creating
187
- * arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
188
- * any array they are first combined with.
187
+ * `weakType` represents weakly typed arrays. These are created for JS numbers,
188
+ * which default to float32 but "weak" so they cast to the dtype of any array
189
+ * they are first combined with, except `bool`.
189
190
  *
190
191
  * **Examples:**
191
192
  * - `promoteTypes(bool, int32) → int32`
@@ -222,6 +223,8 @@ declare class AluExp implements FpHashable {
222
223
  static atan(a: AluExp): AluExp;
223
224
  static exp(a: AluExp): AluExp;
224
225
  static log(a: AluExp): AluExp;
226
+ static erf(a: AluExp): AluExp;
227
+ static erfc(a: AluExp): AluExp;
225
228
  static sqrt(a: AluExp): AluExp;
226
229
  static reciprocal(a: AluExp): AluExp;
227
230
  static cast(dtype: DType, a: AluExp): AluExp;
@@ -289,8 +292,8 @@ declare class AluExp implements FpHashable {
289
292
  rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
290
293
  /** Collect all nodes that satisfy a predicate. */
291
294
  collect(predicate: (exp: AluExp) => boolean): AluExp[];
292
- /** Produce a list of all distinct AluOp in this expression. */
293
- distinctOps(): Set<AluOp>;
295
+ /** Produce all distinct AluOp in this expression, with their dtypes. */
296
+ distinctOps(): Map<AluOp, Set<DType>>;
294
297
  /** Rewrite GlobalView operations to GlobalIndex operations. */
295
298
  rewriteGlobalViews(): AluExp;
296
299
  }
@@ -309,6 +312,8 @@ declare enum AluOp {
309
312
  Atan = "Atan",
310
313
  Exp = "Exp",
311
314
  Log = "Log",
315
+ Erf = "Erf",
316
+ Erfc = "Erfc",
312
317
  Sqrt = "Sqrt",
313
318
  Reciprocal = "Reciprocal",
314
319
  Cast = "Cast",
@@ -465,7 +470,7 @@ type JsTree<T> = T | JsTree<T>[] | {
465
470
  [key: string]: JsTree<T>;
466
471
  };
467
472
  type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
468
- type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
473
+ type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
469
474
  /** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
470
475
  type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
471
476
  /** Represents the structure of a JsTree. */
@@ -477,6 +482,8 @@ declare class JsTreeDef {
477
482
  constructor(nodeType: NodeType, nodeMetadata: any,
478
483
  // Must be comparable with deepEqual.
479
484
  childTreedefs: JsTreeDef[]);
485
+ /** Get the total number of leaves in the tree. */
486
+ get size(): number;
480
487
  /** Returns a string representation of this tree definition. */
481
488
  toString(root?: boolean): string;
482
489
  /** Compare this tree definition with another. */
@@ -540,6 +547,8 @@ declare enum Primitive {
540
547
  Atan = "atan",
541
548
  Exp = "exp",
542
549
  Log = "log",
550
+ Erf = "erf",
551
+ Erfc = "erfc",
543
552
  Sqrt = "sqrt",
544
553
  Min = "min",
545
554
  Max = "max",
@@ -613,6 +622,7 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
613
622
  outDim: number;
614
623
  };
615
624
  [Primitive.JitCall]: {
625
+ name: string;
616
626
  jaxpr: Jaxpr;
617
627
  numConsts: number;
618
628
  };
@@ -651,10 +661,40 @@ declare abstract class Trace {
651
661
  abstract lift(val: Tracer): Tracer;
652
662
  abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
653
663
  }
664
+ /** Internal representation of an array value. */
654
665
  interface AbstractValue {
666
+ /** Shape of the array. Must be a static tuple of non-negative dimensions. */
655
667
  shape: number[];
668
+ /** Concrete data type of array elements. */
656
669
  dtype: DType;
670
+ /**
671
+ * Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
672
+ * _weakly typed_ unless a dtype is explicitly specified.
673
+ *
674
+ * Weakly typed values will automatically cast to the data type of other
675
+ * arrays when used as an operand as an expression. This property only affects
676
+ * how they promote in type casting; their memory layout is still determined
677
+ * by the actual `dtype` field.
678
+ *
679
+ * ```ts
680
+ * const x = np.array(3); // weakType = true, dtype = float32
681
+ * const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
682
+ * const z = x.add(y); // z has dtype int32 because x is weakly typed
683
+ * ```
684
+ *
685
+ * Weak types are present in JIT programs in their spec (e.g., Jaxpr inputs
686
+ * and outputs can be weakly typed) form. But they're solely a frontend
687
+ * concept. Backends are not aware of weak types.
688
+ */
689
+ weakType: boolean;
657
690
  }
691
+ /**
692
+ * Broadcast shapes and promote types with casting for two avals.
693
+ *
694
+ * This implements the weak type behavior described in `promoteTypes()`, but not
695
+ * implemented in that function as `weakType` is not passed.
696
+ */
697
+
658
698
  declare abstract class Tracer {
659
699
  /** @ignore */
660
700
  readonly _trace: Trace;
@@ -712,8 +752,15 @@ declare abstract class Tracer {
712
752
  get shape(): number[];
713
753
  /** The total number of elements in the array. */
714
754
  get size(): number;
715
- /** The dtype of the array. */
755
+ /** The dtype of elements stored in the array. */
716
756
  get dtype(): DType;
757
+ /**
758
+ * Whether the array is weakly typed.
759
+ *
760
+ * Weakly typed arrays will cast to the dtype of the other operand. See
761
+ * `promoteTypes()` for details.
762
+ */
763
+ get weakType(): boolean;
717
764
  /** The number of dimensions of the array. */
718
765
  get ndim(): number;
719
766
  /** @ignore */
@@ -805,7 +852,8 @@ declare abstract class Tracer {
805
852
  declare class ShapedArray implements AbstractValue {
806
853
  readonly shape: number[];
807
854
  readonly dtype: DType;
808
- constructor(shape: number[], dtype: DType);
855
+ readonly weakType: boolean;
856
+ constructor(shape: number[], dtype: DType, weakType: boolean);
809
857
  static fromAval(aval: AbstractValue): ShapedArray;
810
858
  get ndim(): number;
811
859
  toString(): string;
@@ -841,10 +889,19 @@ type DTypeAndDevice = {
841
889
  dtype?: DType;
842
890
  device?: Device;
843
891
  };
892
+ type ArrayConstructorArgs = {
893
+ source: AluExp | Slot;
894
+ st: ShapeTracker;
895
+ dtype: DType;
896
+ weakType: boolean;
897
+ backend: Backend;
898
+ committed: boolean;
899
+ pending?: Iterable<PendingExecute>;
900
+ };
844
901
  /**
845
902
  * A multidimensional numeric array with data stored on CPU or GPU.
846
903
  *
847
- * This is the library's core data type. Equivalent to `jnp.Array` from JAX, or
904
+ * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
848
905
  * `torch.Tensor`.
849
906
  *
850
907
  * Not to be confused with the JavaScript "Array" constructor. Avoid importing
@@ -860,11 +917,7 @@ declare class Array extends Tracer {
860
917
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
861
918
  * will be freed when the array is disposed.
862
919
  */
863
- constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend, {
864
- pending
865
- }?: {
866
- pending?: Iterable<PendingExecute> | null;
867
- });
920
+ constructor(args: ArrayConstructorArgs);
868
921
  /** @ignore */
869
922
  get aval(): ShapedArray;
870
923
  /** Return a simple string representation of the array's dimensions. */
@@ -924,10 +977,13 @@ declare class Array extends Tracer {
924
977
  item(): number;
925
978
  /** @private Internal plumbing method for Array / Tracer ops. */
926
979
  static _implRules(): typeof implRules;
980
+ /** @private */
927
981
  _realizeSource(): number;
982
+ /** @private Put this array on a new backend, asynchronously. */
983
+ _put(backend: Backend): Promise<Array>;
984
+ /** @private Put this array on a new backend, synchronously. */
985
+ _putSync(backend: Backend): Array;
928
986
  }
929
- /** Construct an array from a single scalar constant. */
930
-
931
987
  /** Constructor for creating a new array from data. */
932
988
  declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
933
989
  shape,
@@ -1002,7 +1058,7 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
1002
1058
  device
1003
1059
  }?: DTypeAndDevice): Array;
1004
1060
  declare namespace numpy_d_exports {
1005
- export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1061
+ export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1006
1062
  }
1007
1063
  declare const float32 = DType.Float32;
1008
1064
  declare const int32 = DType.Int32;
@@ -1085,7 +1141,7 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
1085
1141
  * all axes are padded with the same width. Or if it is an array of pairs, each
1086
1142
  * pair specifies the padding for its corresponding axis.
1087
1143
  */
1088
- declare const pad: (x: ArrayLike, width: number | [number, number] | [number, number][]) => Array;
1144
+ declare const pad: (x: ArrayLike, width: number | Pair | Pair[]) => Array;
1089
1145
  /**
1090
1146
  * @function
1091
1147
  * Return the number of dimensions of an array. Does not consume array reference.
@@ -1315,6 +1371,26 @@ declare function absolute(x: ArrayLike): Array;
1315
1371
  declare const abs: typeof absolute;
1316
1372
  /** Return an element-wise indication of sign of the input. */
1317
1373
  declare function sign(x: ArrayLike): Array;
1374
+ /**
1375
+ * Return the Hamming window of size M, a taper with a weighted cosine bell.
1376
+ *
1377
+ * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1378
+ */
1379
+ declare function hamming(M: number): Array;
1380
+ /**
1381
+ * Return the Hann window of size M, a taper with a weighted cosine bell.
1382
+ *
1383
+ * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
1384
+ */
1385
+ declare function hann(M: number): Array;
1386
+ /**
1387
+ * @function
1388
+ * Compute the Heaviside step function. It is defined piecewise:
1389
+ * - `heaviside(x1, x2) = 0` for `x1 < 0`,
1390
+ * - `heaviside(x1, x2) = x2` for `x1 == 0`,
1391
+ * - `heaviside(x1, x2) = 1` for `x1 > 0`.
1392
+ */
1393
+ declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1318
1394
  /** Calculate element-wise square of the input array. */
1319
1395
  declare function square(x: ArrayLike): Array;
1320
1396
  /** Element-wise tangent function (takes radians). */
@@ -1326,8 +1402,8 @@ declare function acos(x: ArrayLike): Array;
1326
1402
  * Return element-wise hypotenuse for the given legs of a right triangle.
1327
1403
  *
1328
1404
  * In the original NumPy/JAX implementation, this function is more numerically
1329
- * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
1330
- * improvements.
1405
+ * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
1406
+ * stability improvements.
1331
1407
  */
1332
1408
  declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
1333
1409
  /**
@@ -1459,6 +1535,7 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
1459
1535
  mean?: ArrayLike;
1460
1536
  correction?: number;
1461
1537
  } & ReduceOpts): Array;
1538
+ //# sourceMappingURL=numpy.d.ts.map
1462
1539
  //#endregion
1463
1540
  //#region src/frontend/jaxpr.d.ts
1464
1541
  /**
@@ -1480,10 +1557,10 @@ declare class Var {
1480
1557
  }
1481
1558
  /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
1482
1559
  declare class Lit {
1483
- readonly dtype: DType;
1484
1560
  readonly value: number;
1485
1561
  readonly aval: ShapedArray;
1486
- constructor(dtype: DType, value: number);
1562
+ get dtype(): DType;
1563
+ constructor(aval: AbstractValue, value: number);
1487
1564
  }
1488
1565
  type Atom = Var | Lit;
1489
1566
  declare class VarPrinter {
@@ -1533,10 +1610,9 @@ declare class Jaxpr implements FpHashable {
1533
1610
  /** @inline */
1534
1611
  type JitOpts = {
1535
1612
  staticArgnums?: number[];
1536
- device?: Device;
1537
1613
  };
1538
1614
  declare namespace lax_d_exports {
1539
- export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
1615
+ export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, erf, erfc, reduceWindow, stopGradient };
1540
1616
  }
1541
1617
  type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
1542
1618
  /**
@@ -1560,6 +1636,23 @@ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: n
1560
1636
  declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
1561
1637
  /** Reduce a computation over padded windows. */
1562
1638
  declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
1639
+ /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
1640
+ declare function erf(x: ArrayLike): Array;
1641
+ /**
1642
+ * The complementary error function: `erfc(x) = 1 - erf(x)`.
1643
+ *
1644
+ * This function is more accurate than `1 - erf(x)` for large values of `x`,
1645
+ * where `erf(x)` is very close to 1.
1646
+ */
1647
+ declare function erfc(x: ArrayLike): Array;
1648
+ /**
1649
+ * Stops gradient computation.
1650
+ *
1651
+ * Behaves as the identity function but prevents the flow of gradients during
1652
+ * forward or reverse-mode automatic differentiation.
1653
+ */
1654
+ declare function stopGradient(x: ArrayLike): Array;
1655
+ //# sourceMappingURL=lax.d.ts.map
1563
1656
  declare namespace nn_d_exports {
1564
1657
  export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
1565
1658
  }
@@ -1644,15 +1737,17 @@ declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
1644
1737
  * @function
1645
1738
  * Gaussion error linear unit (GELU) activation function.
1646
1739
  *
1647
- * This is computed element-wise. Currently jax-js does not support the erf() or
1648
- * gelu() functions exactly as primitives, so an approximation is used:
1649
- * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
1740
+ * This is computed element-wise. There are two variants depending on whether
1741
+ * `approximate` is set (default true):
1650
1742
  *
1651
- * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
1743
+ * - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
1744
+ * - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
1652
1745
  *
1653
- * This will be improved in the future.
1746
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
1654
1747
  */
1655
- declare const gelu: OwnedFunction<(x: ArrayLike) => Array>;
1748
+ declare const gelu: OwnedFunction<(x: ArrayLike, opts?: {
1749
+ approximate?: boolean | undefined;
1750
+ } | undefined) => Array>;
1656
1751
  /**
1657
1752
  * Gated linear unit (GLU) activation function.
1658
1753
  *
@@ -1733,6 +1828,7 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
1733
1828
  * ```
1734
1829
  */
1735
1830
  declare function oneHot(x: Array, numClasses: number): Array;
1831
+ //# sourceMappingURL=nn.d.ts.map
1736
1832
  declare namespace random_d_exports {
1737
1833
  export { bernoulli, bits, exponential, key, normal, split, uniform };
1738
1834
  }
@@ -1742,14 +1838,14 @@ declare function key(seed: number): Array;
1742
1838
  declare function split(key: Array, num?: number | number[]): Array;
1743
1839
  /** Sample uniform bits in the form of unsigned integers. */
1744
1840
  declare function bits(key: Array, shape?: number[]): Array;
1745
- /** Sample uniform random values in [minval, maxval) with given shape. */
1746
- declare function uniform(key: Array, shape?: number[], {
1747
- minval,
1748
- maxval
1749
- }?: {
1750
- minval?: number;
1751
- maxval?: number;
1752
- }): Array;
1841
+ /**
1842
+ * @function
1843
+ * Sample uniform random values in [minval, maxval) with given shape.
1844
+ */
1845
+ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined, args_2?: {
1846
+ minval?: number | undefined;
1847
+ maxval?: number | undefined;
1848
+ } | undefined) => Array>;
1753
1849
  /**
1754
1850
  * Sample Bernoulli random variables with given mean (0,1 categorical).
1755
1851
  *
@@ -1757,16 +1853,29 @@ declare function uniform(key: Array, shape?: number[], {
1757
1853
  * and must be broadcastable to `shape`.
1758
1854
  */
1759
1855
  declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
1760
- /** Sample exponential random values according to `p(x) = exp(-x)`. */
1761
- declare function exponential(key: Array, shape?: number[]): Array;
1762
1856
  /**
1857
+ * @function
1858
+ * Sample exponential random values according to `p(x) = exp(-x)`.
1859
+ */
1860
+ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
1861
+ /**
1862
+ * @function
1763
1863
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
1764
1864
  *
1765
1865
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
1766
1866
  * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
1767
1867
  * bitwise identical to JAX.
1768
1868
  */
1769
- declare function normal(key: Array, shape?: number[]): Array;
1869
+ declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
1870
+ //# sourceMappingURL=random.d.ts.map
1871
+ declare namespace scipy_special_d_exports {
1872
+ export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
1873
+ }
1874
+ /**
1875
+ * @function
1876
+ * The logit function, `logit(p) = log(p / (1-p))`.
1877
+ */
1878
+ declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
1770
1879
  //#endregion
1771
1880
  //#region src/index.d.ts
1772
1881
  /**
@@ -1778,7 +1887,7 @@ declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals:
1778
1887
  * @function
1779
1888
  * Vectorize an operation on a batched axis for one or more inputs.
1780
1889
  */
1781
- declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1890
+ declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | (number | null | JsTree<number | null>)[]) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1782
1891
  /**
1783
1892
  * @function
1784
1893
  * Compute the Jacobian evaluated column-by-column by forward-mode AD.
@@ -1853,5 +1962,19 @@ declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJs
1853
1962
  * Does not consume reference to the arrays.
1854
1963
  */
1855
1964
  declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
1965
+ /**
1966
+ * Transfer `x` to `device`.
1967
+ *
1968
+ * `x` may be a nested container of arrays or scalars. The resulting structure
1969
+ * is committed to the device.
1970
+ *
1971
+ * If `device` is not specified, this function behaves as identity if the input
1972
+ * is already an `Array`, otherwise it places the scalar uncommitted on the
1973
+ * default device.
1974
+ */
1975
+ declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
1976
+ //# sourceMappingURL=index.d.ts.map
1977
+
1856
1978
  //#endregion
1857
- export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1979
+ export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1980
+ //# sourceMappingURL=index.d.cts.map