@jax-js/jax 0.1.5 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.cts CHANGED
@@ -666,7 +666,7 @@ type IInfo = Readonly<{
666
666
  /** Machine limits for integer types. */
667
667
  declare function iinfo(dtype: DType): IInfo;
668
668
  declare namespace numpy_d_exports {
669
- export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
669
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
670
670
  }
671
671
  declare const float32 = DType.Float32;
672
672
  declare const int32 = DType.Int32;
@@ -886,6 +886,8 @@ declare function columnStack(xs: ArrayLike[]): Array;
886
886
  declare function flipud(x: ArrayLike): Array;
887
887
  /** Flip an array horizontally (axis=1). */
888
888
  declare function fliplr(x: ArrayLike): Array;
889
+ /** Interchange two axes of an array. */
890
+ declare function swapaxes(a: ArrayLike, axis1: number, axis2: number): Array;
889
891
  /** Transpose the last two dimensions of an array. */
890
892
  declare function matrixTranspose(a: ArrayLike): Array;
891
893
  /** Return a 1-D flattened array containing the elements of the input. */
@@ -1590,14 +1592,14 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
1590
1592
  [Primitive.Pad]: {
1591
1593
  width: Pair[];
1592
1594
  };
1595
+ [Primitive.TriangularSolve]: {
1596
+ unitDiagonal: boolean;
1597
+ };
1593
1598
  [Primitive.Jit]: {
1594
1599
  name: string;
1595
1600
  jaxpr: Jaxpr;
1596
1601
  numConsts: number;
1597
1602
  };
1598
- [Primitive.TriangularSolve]: {
1599
- unitDiagonal: boolean;
1600
- };
1601
1603
  }
1602
1604
  /** Type of parameters taken by each primitive. */
1603
1605
  type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
@@ -2076,6 +2078,24 @@ declare function logspace(start: number, stop: number, num?: number, endpoint?:
2076
2078
  dtype,
2077
2079
  device
2078
2080
  }?: DTypeAndDevice): Array;
2081
+ //#endregion
2082
+ //#region src/frontend/linearize.d.ts
2083
+ /** @inline */
2084
+ type GradOpts = {
2085
+ /**
2086
+ * Integer or sequence of integers. Specifies which positional argument(s) to
2087
+ * differentiate with respect to.
2088
+ *
2089
+ * Defaults to `0` (the first argument).
2090
+ */
2091
+ argnums?: number | number[];
2092
+ /**
2093
+ * The input function returns a pair of `[out, aux]` including an auxiliary
2094
+ * value. This `aux` is not differentiated, but is returned alongside the
2095
+ * gradient when evaluating the function.
2096
+ */
2097
+ hasAux?: boolean;
2098
+ };
2079
2099
  declare namespace lax_linalg_d_exports {
2080
2100
  export { cholesky, lu, triangularSolve };
2081
2101
  }
@@ -2166,7 +2186,7 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
2166
2186
  unitDiagonal?: boolean;
2167
2187
  }): Array;
2168
2188
  declare namespace lax_d_exports {
2169
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
2189
+ export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
2170
2190
  }
2171
2191
  /**
2172
2192
  * Dimension numbers for general `dot()` primitive.
@@ -2204,7 +2224,11 @@ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
2204
2224
  * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
2205
2225
  * function in JAX, which wraps XLA's general convolution operator.
2206
2226
  *
2207
- * Grouped convolutions are not supported right now.
2227
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
2228
+ * @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
2229
+ * @param windowStrides - Strides for each spatial dimension
2230
+ * @param padding - Padding for each spatial dimension, or a string
2231
+ * (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
2208
2232
  */
2209
2233
  declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
2210
2234
  lhsDilation,
@@ -2219,6 +2243,37 @@ declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: numbe
2219
2243
  declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
2220
2244
  /** Convenience wrapper around `convGeneralDilated`. */
2221
2245
  declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
2246
+ /**
2247
+ * Convenience wrapper for calculating the N-d convolution "transpose".
2248
+ *
2249
+ * This function directly calculates a fractionally strided conv rather than
2250
+ * indirectly calculating the gradient (transpose) of a forward convolution.
2251
+ * It is equivalent to the JAX version, except:
2252
+ *
2253
+ * - The `use_consistent_padding` option is not available. We only have the
2254
+ * consistent padding case (JAX version >0.8.4).
2255
+ * - The order of dimensions matches `lax.conv_general_dilated`.
2256
+ *
2257
+ * Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
2258
+ * dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
2259
+ * `transposeKernel` to true.
2260
+ *
2261
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
2262
+ * @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
2263
+ * @param strides - Sequence of n integers, sets fractional stride
2264
+ * @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
2265
+ * each side of the input, so it acts like gradient of `conv()`
2266
+ * @param rhsDilation - Atrous dilation for the kernel
2267
+ * @param transposeKernel - Flip spatial axes and swap the input/output channels
2268
+ * of the kernel; its shape should be `[C_in, C_out, ...ks]`
2269
+ */
2270
+ declare function convTranspose(lhs: Array, rhs: Array, strides: number[], padding: PaddingType, {
2271
+ rhsDilation,
2272
+ transposeKernel
2273
+ }?: {
2274
+ rhsDilation?: number[];
2275
+ transposeKernel?: boolean;
2276
+ }): Array;
2222
2277
  /** Reduce a computation over padded windows. */
2223
2278
  declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
2224
2279
  /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
@@ -2238,7 +2293,7 @@ declare function erfc(x: ArrayLike): Array;
2238
2293
  */
2239
2294
  declare function stopGradient(x: ArrayLike): Array;
2240
2295
  declare namespace nn_d_exports {
2241
- export { celu, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
2296
+ export { celu, dotProductAttention, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
2242
2297
  }
2243
2298
  /**
2244
2299
  * Rectified Linear Unit (ReLU) activation function:
@@ -2435,6 +2490,56 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
2435
2490
  * ```
2436
2491
  */
2437
2492
  declare function oneHot(x: Array, numClasses: number): Array;
2493
+ /**
2494
+ * Scaled dot product attention (SDPA).
2495
+ *
2496
+ * Computes `softmax((Q @ K^T) / sqrt(d) + bias) @ V`, where `Q` is the query,
2497
+ * `K` is the key, `V` is the value, and `d` is the dimensionality of each key
2498
+ * and query vector.
2499
+ *
2500
+ * Multi-query attention is applied when input `key` and `value` tensors have
2501
+ * fewer heads than `query`.
2502
+ *
2503
+ * We use the following uppercase letters to denote array shapes:
2504
+ * - `B` = batch size
2505
+ * - `S` = length of key/value sequences (source)
2506
+ * - `L` = length of query sequences
2507
+ * - `N` = number of attention heads
2508
+ * - `H` = dimensionality of each attention head
2509
+ * - `K` = number of key/value heads (for grouped-query attention)
2510
+ *
2511
+ * The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
2512
+ * case it must be omitted from all inputs.
2513
+ *
2514
+ * @param query - Query array; shape `[B, L, N, H]`
2515
+ * @param key - Key array; shape `[B, S, K, H]`
2516
+ * @param value - Value array; same shape as `key`
2517
+ * @param opts.bias - Optional bias to add to the attention logits; shape
2518
+ * `[B, N, L, S]` or broadcastable to it.
2519
+ * @param opts.mask - Optional mask to apply to the attention logits; should be
2520
+ * a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
2521
+ * the element should take part in attention.
2522
+ * @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
2523
+ * @param opts.isCausal - If true, applies a causal mask.
2524
+ * @param opts.querySeqLengths - Optional sequence lengths for the queries;
2525
+ * shape `(B,)`. Taken from the beginning of the tensor.
2526
+ * @param opts.keyValueSeqLengths - Optional sequence lengths for the keys and
2527
+ * values; shape `(B,)`. Taken from the beginning of the tensor.
2528
+ * @param opts.localWindowSize - If specified, applies a local attention window
2529
+ * of the given size. Can be a single number or a tuple `[left, right]`.
2530
+ *
2531
+ * @returns The result of the attention operation; shape is the same as query
2532
+ * `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
2533
+ */
2534
+ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: ArrayLike, opts?: {
2535
+ bias?: ArrayLike;
2536
+ mask?: ArrayLike;
2537
+ scale?: number;
2538
+ isCausal?: boolean;
2539
+ querySeqLengths?: ArrayLike;
2540
+ keyValueSeqLengths?: ArrayLike;
2541
+ localWindowSize?: number | [number, number];
2542
+ }): Array;
2438
2543
  declare namespace random_d_exports {
2439
2544
  export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2440
2545
  }
@@ -2526,7 +2631,9 @@ declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
2526
2631
  * @function
2527
2632
  * Compute the forward-mode Jacobian-vector product for a function.
2528
2633
  */
2529
- declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
2634
+ declare const jvp: <F extends (...args: any[]) => JsTree<Array>, HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
2635
+ hasAux?: HA;
2636
+ }) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, Out, Aux] : never : [ReturnType<F>, ReturnType<F>];
2530
2637
  /**
2531
2638
  * @function
2532
2639
  * Vectorize an operation on a batched axis for one or more inputs.
@@ -2568,28 +2675,100 @@ declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: Ji
2568
2675
  * Produce a local linear approximation to a function at a point using jvp() and
2569
2676
  * partial evaluation.
2570
2677
  */
2571
- declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
2678
+ declare const linearize: <F extends (...args: any[]) => JsTree<Array>, HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
2679
+ hasAux?: HA;
2680
+ }) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, OwnedFunction<(...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => Out>, Aux] : never : [ReturnType<F>, OwnedFunction<(...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>>];
2572
2681
  /**
2573
2682
  * @function
2574
2683
  * Calculate the reverse-mode vector-Jacobian product for a function.
2684
+ *
2685
+ * The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
2686
+ * `f(primals)`, and `vjpFn` is a function that takes in cotangents for each
2687
+ * output and returns the cotangents for each input.
2688
+ *
2689
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
2690
+ * `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
2691
+ *
2692
+ * @example
2693
+ * ```ts
2694
+ * const [y, vjpFn] = vjp(f, [x]);
2695
+ *
2696
+ * // With hasAux
2697
+ * const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
2698
+ * ```
2575
2699
  */
2576
- declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
2700
+ declare const vjp: <F extends (...args: any[]) => JsTree<Array>, const HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
2701
+ hasAux?: HA;
2702
+ }) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, OwnedFunction<(cotangents: MapJsTree<Out, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>>, Aux] : never : [ReturnType<F>, OwnedFunction<(cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>>];
2703
+ /** @inline */
2704
+ type GradOutputType<I, F extends (...args: any[]) => any> = MapJsTree<I extends undefined ? Parameters<F>[0] : I extends number ? Parameters<F>[I] : I extends number[] ? { [K in keyof I]: I[K] extends number ? Parameters<F>[I[K]] : never } : never, ArrayLike, Array>;
2577
2705
  /**
2578
2706
  * @function
2579
2707
  * Compute the gradient of a scalar-valued function `f` with respect to its
2580
2708
  * first argument.
2709
+ *
2710
+ * Pass in different `argnums` to differentiate with respect to other
2711
+ * arguments. If a tuple is provided, the return value will be a tuple of
2712
+ * gradients corresponding to each argument index.
2713
+ *
2714
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
2715
+ * `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
2716
+ *
2717
+ * @example
2718
+ * ```ts
2719
+ * const gradient = grad(f)(x);
2720
+ *
2721
+ * // With `argnums`
2722
+ * const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
2723
+ *
2724
+ * // With `hasAux`
2725
+ * const [gradient, aux] = grad(f, { hasAux: true })(x);
2726
+ * ```
2581
2727
  */
2582
- declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
2728
+ declare const grad: <F extends (...args: any[]) => JsTree<Array>, const I extends undefined | number | number[] = undefined, const HA extends boolean = false>(f: F, opts?: Omit<GradOpts, "argnums" | "hasAux"> & {
2729
+ argnums?: I;
2730
+ hasAux?: HA;
2731
+ }) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => HA extends true ? ReturnType<F> extends [any, infer Aux] ? [GradOutputType<I, F>, Aux] : never : GradOutputType<I, F>;
2583
2732
  /**
2584
2733
  * @function
2585
2734
  * Create a function that evaluates both `f` and the gradient of `f`.
2735
+ *
2736
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
2737
+ * `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
2738
+ *
2739
+ * @example
2740
+ * ```ts
2741
+ * // Without hasAux
2742
+ * const [value, gradient] = valueAndGrad(f)(x);
2743
+ *
2744
+ * // With hasAux
2745
+ * const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
2746
+ * ```
2586
2747
  */
2587
- declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
2748
+ declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>, const I extends undefined | number | number[] = undefined, const HA extends boolean = false>(f: F, opts?: Omit<GradOpts, "argnums"> & {
2749
+ argnums?: I;
2750
+ hasAux?: HA;
2751
+ }) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, GradOutputType<I, F>];
2588
2752
  /**
2589
2753
  * @function
2590
2754
  * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
2591
2755
  */
2592
2756
  declare const jacrev: typeof jacfwd;
2757
+ /**
2758
+ * @function
2759
+ * Compute the Hessian matrix of a scalar-valued function.
2760
+ *
2761
+ * The Hessian is the matrix of second-order partial derivatives of a function.
2762
+ * This is implemented as `jacfwd(grad(f))`.
2763
+ *
2764
+ * @example
2765
+ * ```ts
2766
+ * const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // x^3
2767
+ * const H = hessian(f)(np.array([1, 2, 3]));
2768
+ * // H[i,j] = d^2f / dx_i dx_j
2769
+ * ```
2770
+ */
2771
+ declare const hessian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
2593
2772
  /**
2594
2773
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
2595
2774
  *
@@ -2612,4 +2791,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
2612
2791
  */
2613
2792
  declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
2614
2793
  //#endregion
2615
- export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
2794
+ export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
package/dist/index.d.ts CHANGED
@@ -663,7 +663,7 @@ type IInfo = Readonly<{
663
663
  /** Machine limits for integer types. */
664
664
  declare function iinfo(dtype: DType): IInfo;
665
665
  declare namespace numpy_d_exports {
666
- export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
666
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
667
667
  }
668
668
  declare const float32 = DType.Float32;
669
669
  declare const int32 = DType.Int32;
@@ -883,6 +883,8 @@ declare function columnStack(xs: ArrayLike[]): Array;
883
883
  declare function flipud(x: ArrayLike): Array;
884
884
  /** Flip an array horizontally (axis=1). */
885
885
  declare function fliplr(x: ArrayLike): Array;
886
+ /** Interchange two axes of an array. */
887
+ declare function swapaxes(a: ArrayLike, axis1: number, axis2: number): Array;
886
888
  /** Transpose the last two dimensions of an array. */
887
889
  declare function matrixTranspose(a: ArrayLike): Array;
888
890
  /** Return a 1-D flattened array containing the elements of the input. */
@@ -1587,14 +1589,14 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
1587
1589
  [Primitive.Pad]: {
1588
1590
  width: Pair[];
1589
1591
  };
1592
+ [Primitive.TriangularSolve]: {
1593
+ unitDiagonal: boolean;
1594
+ };
1590
1595
  [Primitive.Jit]: {
1591
1596
  name: string;
1592
1597
  jaxpr: Jaxpr;
1593
1598
  numConsts: number;
1594
1599
  };
1595
- [Primitive.TriangularSolve]: {
1596
- unitDiagonal: boolean;
1597
- };
1598
1600
  }
1599
1601
  /** Type of parameters taken by each primitive. */
1600
1602
  type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
@@ -2073,6 +2075,24 @@ declare function logspace(start: number, stop: number, num?: number, endpoint?:
2073
2075
  dtype,
2074
2076
  device
2075
2077
  }?: DTypeAndDevice): Array;
2078
+ //#endregion
2079
+ //#region src/frontend/linearize.d.ts
2080
+ /** @inline */
2081
+ type GradOpts = {
2082
+ /**
2083
+ * Integer or sequence of integers. Specifies which positional argument(s) to
2084
+ * differentiate with respect to.
2085
+ *
2086
+ * Defaults to `0` (the first argument).
2087
+ */
2088
+ argnums?: number | number[];
2089
+ /**
2090
+ * The input function returns a pair of `[out, aux]` including an auxiliary
2091
+ * value. This `aux` is not differentiated, but is returned alongside the
2092
+ * gradient when evaluating the function.
2093
+ */
2094
+ hasAux?: boolean;
2095
+ };
2076
2096
  declare namespace lax_linalg_d_exports {
2077
2097
  export { cholesky, lu, triangularSolve };
2078
2098
  }
@@ -2163,7 +2183,7 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
2163
2183
  unitDiagonal?: boolean;
2164
2184
  }): Array;
2165
2185
  declare namespace lax_d_exports {
2166
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
2186
+ export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
2167
2187
  }
2168
2188
  /**
2169
2189
  * Dimension numbers for general `dot()` primitive.
@@ -2201,7 +2221,11 @@ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
2201
2221
  * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
2202
2222
  * function in JAX, which wraps XLA's general convolution operator.
2203
2223
  *
2204
- * Grouped convolutions are not supported right now.
2224
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
2225
+ * @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
2226
+ * @param windowStrides - Strides for each spatial dimension
2227
+ * @param padding - Padding for each spatial dimension, or a string
2228
+ * (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
2205
2229
  */
2206
2230
  declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
2207
2231
  lhsDilation,
@@ -2216,6 +2240,37 @@ declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: numbe
2216
2240
  declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
2217
2241
  /** Convenience wrapper around `convGeneralDilated`. */
2218
2242
  declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
2243
+ /**
2244
+ * Convenience wrapper for calculating the N-d convolution "transpose".
2245
+ *
2246
+ * This function directly calculates a fractionally strided conv rather than
2247
+ * indirectly calculating the gradient (transpose) of a forward convolution.
2248
+ * It is equivalent to the JAX version, except:
2249
+ *
2250
+ * - The `use_consistent_padding` option is not available. We only have the
2251
+ * consistent padding case (JAX version >0.8.4).
2252
+ * - The order of dimensions matches `lax.conv_general_dilated`.
2253
+ *
2254
+ * Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
2255
+ * dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
2256
+ * `transposeKernel` to true.
2257
+ *
2258
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
2259
+ * @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
2260
+ * @param strides - Sequence of n integers, sets fractional stride
2261
+ * @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
2262
+ * each side of the input, so it acts like gradient of `conv()`
2263
+ * @param rhsDilation - Atrous dilation for the kernel
2264
+ * @param transposeKernel - Flip spatial axes and swap the input/output channels
2265
+ * of the kernel; its shape should be `[C_in, C_out, ...ks]`
2266
+ */
2267
+ declare function convTranspose(lhs: Array, rhs: Array, strides: number[], padding: PaddingType, {
2268
+ rhsDilation,
2269
+ transposeKernel
2270
+ }?: {
2271
+ rhsDilation?: number[];
2272
+ transposeKernel?: boolean;
2273
+ }): Array;
2219
2274
  /** Reduce a computation over padded windows. */
2220
2275
  declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
2221
2276
  /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
@@ -2235,7 +2290,7 @@ declare function erfc(x: ArrayLike): Array;
2235
2290
  */
2236
2291
  declare function stopGradient(x: ArrayLike): Array;
2237
2292
  declare namespace nn_d_exports {
2238
- export { celu, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
2293
+ export { celu, dotProductAttention, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
2239
2294
  }
2240
2295
  /**
2241
2296
  * Rectified Linear Unit (ReLU) activation function:
@@ -2432,6 +2487,56 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
2432
2487
  * ```
2433
2488
  */
2434
2489
  declare function oneHot(x: Array, numClasses: number): Array;
2490
+ /**
2491
+ * Scaled dot product attention (SDPA).
2492
+ *
2493
+ * Computes `softmax((Q @ K^T) / sqrt(d) + bias) @ V`, where `Q` is the query,
2494
+ * `K` is the key, `V` is the value, and `d` is the dimensionality of each key
2495
+ * and query vector.
2496
+ *
2497
+ * Grouped-query attention is applied when input `key` and `value` tensors have
2498
+ * fewer heads than `query`.
2499
+ *
2500
+ * We use the following uppercase letters to denote array shapes:
2501
+ * - `B` = batch size
2502
+ * - `S` = length of key/value sequences (source)
2503
+ * - `L` = length of query sequences
2504
+ * - `N` = number of attention heads
2505
+ * - `H` = dimensionality of each attention head
2506
+ * - `K` = number of key/value heads (for grouped-query attention)
2507
+ *
2508
+ * The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
2509
+ * case it must be omitted from all inputs.
2510
+ *
2511
+ * @param query - Query array; shape `[B, L, N, H]`
2512
+ * @param key - Key array; shape `[B, S, K, H]`
2513
+ * @param value - Value array; same shape as `key`
2514
+ * @param opts.bias - Optional bias to add to the attention logits; shape
2515
+ * `[B, N, L, S]` or broadcastable to it.
2516
+ * @param opts.mask - Optional mask to apply to the attention logits; should be
2517
+ * a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
2518
+ * the element should take part in attention.
2519
+ * @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
2520
+ * @param opts.isCausal - If true, applies a causal mask.
2521
+ * @param opts.querySeqLengths - Optional sequence lengths for the queries;
2522
+ * shape `[B]`. Taken from the beginning of the tensor.
2523
+ * @param opts.keyValueSeqLengths - Optional sequence lengths for the keys and
2524
+ * values; shape `[B]`. Taken from the beginning of the tensor.
2525
+ * @param opts.localWindowSize - If specified, applies a local attention window
2526
+ * of the given size. Can be a single number or a tuple `[left, right]`.
2527
+ *
2528
+ * @returns The result of the attention operation; shape is the same as query
2529
+ * `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
2530
+ */
2531
+ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: ArrayLike, opts?: {
2532
+ bias?: ArrayLike;
2533
+ mask?: ArrayLike;
2534
+ scale?: number;
2535
+ isCausal?: boolean;
2536
+ querySeqLengths?: ArrayLike;
2537
+ keyValueSeqLengths?: ArrayLike;
2538
+ localWindowSize?: number | [number, number];
2539
+ }): Array;
2435
2540
  declare namespace random_d_exports {
2436
2541
  export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2437
2542
  }
@@ -2523,7 +2628,9 @@ declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
2523
2628
  * @function
2524
2629
  * Compute the forward-mode Jacobian-vector product for a function.
2525
2630
  */
2526
- declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
2631
+ declare const jvp: <F extends (...args: any[]) => JsTree<Array>, HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
2632
+ hasAux?: HA;
2633
+ }) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, Out, Aux] : never : [ReturnType<F>, ReturnType<F>];
2527
2634
  /**
2528
2635
  * @function
2529
2636
  * Vectorize an operation on a batched axis for one or more inputs.
@@ -2565,28 +2672,100 @@ declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: Ji
2565
2672
  * Produce a local linear approximation to a function at a point using jvp() and
2566
2673
  * partial evaluation.
2567
2674
  */
2568
- declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
2675
+ declare const linearize: <F extends (...args: any[]) => JsTree<Array>, HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
2676
+ hasAux?: HA;
2677
+ }) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, OwnedFunction<(...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => Out>, Aux] : never : [ReturnType<F>, OwnedFunction<(...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>>];
2569
2678
  /**
2570
2679
  * @function
2571
2680
  * Calculate the reverse-mode vector-Jacobian product for a function.
2681
+ *
2682
+ * The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
2683
+ * `f(primals)`, and `vjpFn` is a function that takes in cotangents for each
2684
+ * output and returns the cotangents for each input.
2685
+ *
2686
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
2687
+ * `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
2688
+ *
2689
+ * @example
2690
+ * ```ts
2691
+ * const [y, vjpFn] = vjp(f, [x]);
2692
+ *
2693
+ * // With hasAux
2694
+ * const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
2695
+ * ```
2572
2696
  */
2573
- declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
2697
+ declare const vjp: <F extends (...args: any[]) => JsTree<Array>, const HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
2698
+ hasAux?: HA;
2699
+ }) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, OwnedFunction<(cotangents: MapJsTree<Out, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>>, Aux] : never : [ReturnType<F>, OwnedFunction<(cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>>];
2700
+ /** @inline */
2701
+ type GradOutputType<I, F extends (...args: any[]) => any> = MapJsTree<I extends undefined ? Parameters<F>[0] : I extends number ? Parameters<F>[I] : I extends number[] ? { [K in keyof I]: I[K] extends number ? Parameters<F>[I[K]] : never } : never, ArrayLike, Array>;
2574
2702
  /**
2575
2703
  * @function
2576
2704
  * Compute the gradient of a scalar-valued function `f` with respect to its
2577
2705
  * first argument.
2706
+ *
2707
+ * Pass in different `argnums` to differentiate with respect to other
2708
+ * arguments. If a tuple is provided, the return value will be a tuple of
2709
+ * gradients corresponding to each argument index.
2710
+ *
2711
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
2712
+ * `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
2713
+ *
2714
+ * @example
2715
+ * ```ts
2716
+ * const gradient = grad(f)(x);
2717
+ *
2718
+ * // With `argnums`
2719
+ * const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
2720
+ *
2721
+ * // With `hasAux`
2722
+ * const [gradient, aux] = grad(f, { hasAux: true })(x);
2723
+ * ```
2578
2724
  */
2579
- declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
2725
+ declare const grad: <F extends (...args: any[]) => JsTree<Array>, const I extends undefined | number | number[] = undefined, const HA extends boolean = false>(f: F, opts?: Omit<GradOpts, "argnums" | "hasAux"> & {
2726
+ argnums?: I;
2727
+ hasAux?: HA;
2728
+ }) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => HA extends true ? ReturnType<F> extends [any, infer Aux] ? [GradOutputType<I, F>, Aux] : never : GradOutputType<I, F>;
2580
2729
  /**
2581
2730
  * @function
2582
2731
  * Create a function that evaluates both `f` and the gradient of `f`.
2732
+ *
2733
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
2734
+ * `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
2735
+ *
2736
+ * @example
2737
+ * ```ts
2738
+ * // Without hasAux
2739
+ * const [value, gradient] = valueAndGrad(f)(x);
2740
+ *
2741
+ * // With hasAux
2742
+ * const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
2743
+ * ```
2583
2744
  */
2584
- declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
2745
+ declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>, const I extends undefined | number | number[] = undefined, const HA extends boolean = false>(f: F, opts?: Omit<GradOpts, "argnums"> & {
2746
+ argnums?: I;
2747
+ hasAux?: HA;
2748
+ }) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, GradOutputType<I, F>];
2585
2749
  /**
2586
2750
  * @function
2587
2751
  * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
2588
2752
  */
2589
2753
  declare const jacrev: typeof jacfwd;
2754
+ /**
2755
+ * @function
2756
+ * Compute the Hessian matrix of a scalar-valued function.
2757
+ *
2758
+ * The Hessian is the matrix of second-order partial derivatives of a function.
2759
+ * This is implemented as `jacfwd(grad(f))`.
2760
+ *
2761
+ * @example
2762
+ * ```ts
2763
+ * const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // x^3
2764
+ * const H = hessian(f)(np.array([1, 2, 3]));
2765
+ * // H[i,j] = d^2f / dx_i dx_j
2766
+ * ```
2767
+ */
2768
+ declare const hessian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
2590
2769
  /**
2591
2770
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
2592
2771
  *
@@ -2609,4 +2788,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
2609
2788
  */
2610
2789
  declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
2611
2790
  //#endregion
2612
- export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
2791
+ export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };