@jax-js/jax 0.1.5 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/{backend-DziQSaoQ.cjs → backend-D7s-Retx.cjs} +23 -4
- package/dist/{backend-DaqL-MNz.js → backend-Dx6Ob2D1.js} +18 -5
- package/dist/index.cjs +365 -110
- package/dist/index.d.cts +192 -13
- package/dist/index.d.ts +192 -13
- package/dist/index.js +365 -111
- package/dist/{webgl-RSuZKvgc.js → webgl-CLLvzJlO.js} +1 -1
- package/dist/{webgl-ClIYb8jP.cjs → webgl-CyfzNW8T.cjs} +1 -1
- package/dist/{webgpu-Dh7k9io0.js → webgpu-C-VfevQW.js} +1 -1
- package/dist/{webgpu-Db2JrNBr.cjs → webgpu-rraa6dfz.cjs} +1 -1
- package/package.json +1 -1
package/dist/index.d.cts
CHANGED
|
@@ -666,7 +666,7 @@ type IInfo = Readonly<{
|
|
|
666
666
|
/** Machine limits for integer types. */
|
|
667
667
|
declare function iinfo(dtype: DType): IInfo;
|
|
668
668
|
declare namespace numpy_d_exports {
|
|
669
|
-
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
669
|
+
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
670
670
|
}
|
|
671
671
|
declare const float32 = DType.Float32;
|
|
672
672
|
declare const int32 = DType.Int32;
|
|
@@ -886,6 +886,8 @@ declare function columnStack(xs: ArrayLike[]): Array;
|
|
|
886
886
|
declare function flipud(x: ArrayLike): Array;
|
|
887
887
|
/** Flip an array horizontally (axis=1). */
|
|
888
888
|
declare function fliplr(x: ArrayLike): Array;
|
|
889
|
+
/** Interchange two axes of an array. */
|
|
890
|
+
declare function swapaxes(a: ArrayLike, axis1: number, axis2: number): Array;
|
|
889
891
|
/** Transpose the last two dimensions of an array. */
|
|
890
892
|
declare function matrixTranspose(a: ArrayLike): Array;
|
|
891
893
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
@@ -1590,14 +1592,14 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
1590
1592
|
[Primitive.Pad]: {
|
|
1591
1593
|
width: Pair[];
|
|
1592
1594
|
};
|
|
1595
|
+
[Primitive.TriangularSolve]: {
|
|
1596
|
+
unitDiagonal: boolean;
|
|
1597
|
+
};
|
|
1593
1598
|
[Primitive.Jit]: {
|
|
1594
1599
|
name: string;
|
|
1595
1600
|
jaxpr: Jaxpr;
|
|
1596
1601
|
numConsts: number;
|
|
1597
1602
|
};
|
|
1598
|
-
[Primitive.TriangularSolve]: {
|
|
1599
|
-
unitDiagonal: boolean;
|
|
1600
|
-
};
|
|
1601
1603
|
}
|
|
1602
1604
|
/** Type of parameters taken by each primitive. */
|
|
1603
1605
|
type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
|
|
@@ -2076,6 +2078,24 @@ declare function logspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
2076
2078
|
dtype,
|
|
2077
2079
|
device
|
|
2078
2080
|
}?: DTypeAndDevice): Array;
|
|
2081
|
+
//#endregion
|
|
2082
|
+
//#region src/frontend/linearize.d.ts
|
|
2083
|
+
/** @inline */
|
|
2084
|
+
type GradOpts = {
|
|
2085
|
+
/**
|
|
2086
|
+
* Integer or sequence of integers. Specifies which positional argument(s) to
|
|
2087
|
+
* differentiate with respect to.
|
|
2088
|
+
*
|
|
2089
|
+
* Defaults to `0` (the first argument).
|
|
2090
|
+
*/
|
|
2091
|
+
argnums?: number | number[];
|
|
2092
|
+
/**
|
|
2093
|
+
* The input function returns a pair of `[out, aux]` including an auxiliary
|
|
2094
|
+
* value. This `aux` is not differentiated, but is returned alongside the
|
|
2095
|
+
* gradient when evaluating the function.
|
|
2096
|
+
*/
|
|
2097
|
+
hasAux?: boolean;
|
|
2098
|
+
};
|
|
2079
2099
|
declare namespace lax_linalg_d_exports {
|
|
2080
2100
|
export { cholesky, lu, triangularSolve };
|
|
2081
2101
|
}
|
|
@@ -2166,7 +2186,7 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
|
|
|
2166
2186
|
unitDiagonal?: boolean;
|
|
2167
2187
|
}): Array;
|
|
2168
2188
|
declare namespace lax_d_exports {
|
|
2169
|
-
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
|
|
2189
|
+
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
|
|
2170
2190
|
}
|
|
2171
2191
|
/**
|
|
2172
2192
|
* Dimension numbers for general `dot()` primitive.
|
|
@@ -2204,7 +2224,11 @@ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
|
|
|
2204
2224
|
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
2205
2225
|
* function in JAX, which wraps XLA's general convolution operator.
|
|
2206
2226
|
*
|
|
2207
|
-
*
|
|
2227
|
+
* @param lhs - Input tensor; shape `[N, C_in, ...xs]`
|
|
2228
|
+
* @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
|
|
2229
|
+
* @param windowStrides - Strides for each spatial dimension
|
|
2230
|
+
* @param padding - Padding for each spatial dimension, or a string
|
|
2231
|
+
* (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
|
|
2208
2232
|
*/
|
|
2209
2233
|
declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
|
|
2210
2234
|
lhsDilation,
|
|
@@ -2219,6 +2243,37 @@ declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: numbe
|
|
|
2219
2243
|
declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
|
|
2220
2244
|
/** Convenience wrapper around `convGeneralDilated`. */
|
|
2221
2245
|
declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
|
|
2246
|
+
/**
|
|
2247
|
+
* Convenience wrapper for calculating the N-d convolution "transpose".
|
|
2248
|
+
*
|
|
2249
|
+
* This function directly calculates a fractionally strided conv rather than
|
|
2250
|
+
* indirectly calculating the gradient (transpose) of a forward convolution.
|
|
2251
|
+
* It is equivalent to the JAX version, except:
|
|
2252
|
+
*
|
|
2253
|
+
* - The `use_consistent_padding` option is not available. We only have the
|
|
2254
|
+
* consistent padding case (JAX version >0.8.4).
|
|
2255
|
+
* - The order of dimensions matches `lax.conv_general_dilated`.
|
|
2256
|
+
*
|
|
2257
|
+
* Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
|
|
2258
|
+
* dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
|
|
2259
|
+
* `transposeKernel` to true.
|
|
2260
|
+
*
|
|
2261
|
+
* @param lhs - Input tensor; shape `[N, C_in, ...xs]`
|
|
2262
|
+
* @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
|
|
2263
|
+
* @param strides - Sequence of n integers, sets fractional stride
|
|
2264
|
+
* @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
|
|
2265
|
+
* each side of the input, so it acts like gradient of `conv()`
|
|
2266
|
+
* @param rhsDilation - Atrous dilation for the kernel
|
|
2267
|
+
* @param transposeKernel - Flip spatial axes and swap the input/output channels
|
|
2268
|
+
* of the kernel; its shape should be `[C_in, C_out, ...ks]`
|
|
2269
|
+
*/
|
|
2270
|
+
declare function convTranspose(lhs: Array, rhs: Array, strides: number[], padding: PaddingType, {
|
|
2271
|
+
rhsDilation,
|
|
2272
|
+
transposeKernel
|
|
2273
|
+
}?: {
|
|
2274
|
+
rhsDilation?: number[];
|
|
2275
|
+
transposeKernel?: boolean;
|
|
2276
|
+
}): Array;
|
|
2222
2277
|
/** Reduce a computation over padded windows. */
|
|
2223
2278
|
declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
|
|
2224
2279
|
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
@@ -2238,7 +2293,7 @@ declare function erfc(x: ArrayLike): Array;
|
|
|
2238
2293
|
*/
|
|
2239
2294
|
declare function stopGradient(x: ArrayLike): Array;
|
|
2240
2295
|
declare namespace nn_d_exports {
|
|
2241
|
-
export { celu, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
|
|
2296
|
+
export { celu, dotProductAttention, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
|
|
2242
2297
|
}
|
|
2243
2298
|
/**
|
|
2244
2299
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -2435,6 +2490,56 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
2435
2490
|
* ```
|
|
2436
2491
|
*/
|
|
2437
2492
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
2493
|
+
/**
|
|
2494
|
+
* Scaled dot product attention (SDPA).
|
|
2495
|
+
*
|
|
2496
|
+
* Computes `softmax((Q @ K^T) / sqrt(d) + bias) @ V`, where `Q` is the query,
|
|
2497
|
+
* `K` is the key, `V` is the value, and `d` is the dimensionality of each key
|
|
2498
|
+
* and query vector.
|
|
2499
|
+
*
|
|
2500
|
+
* Multi-query attention is applied when input `key` and `value` tensors have
|
|
2501
|
+
* fewer heads than `query`.
|
|
2502
|
+
*
|
|
2503
|
+
* We use the following uppercase letters to denote array shapes:
|
|
2504
|
+
* - `B` = batch size
|
|
2505
|
+
* - `S` = length of key/value sequences (source)
|
|
2506
|
+
* - `L` = length of query sequences
|
|
2507
|
+
* - `N` = number of attention heads
|
|
2508
|
+
* - `H` = dimensionality of each attention head
|
|
2509
|
+
* - `K` = number of key/value heads (for grouped-query attention)
|
|
2510
|
+
*
|
|
2511
|
+
* The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
|
|
2512
|
+
* case it must be omitted from all inputs.
|
|
2513
|
+
*
|
|
2514
|
+
* @param query - Query array; shape `[B, L, N, H]`
|
|
2515
|
+
* @param key - Key array; shape `[B, S, K, H]`
|
|
2516
|
+
* @param value - Value array; same shape as `key`
|
|
2517
|
+
* @param opts.bias - Optional bias to add to the attention logits; shape
|
|
2518
|
+
* `[B, N, L, S]` or broadcastable to it.
|
|
2519
|
+
* @param opts.mask - Optional mask to apply to the attention logits; should be
|
|
2520
|
+
* a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
|
|
2521
|
+
* the element should take part in attention.
|
|
2522
|
+
* @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
|
|
2523
|
+
* @param opts.isCausal - If true, applies a causal mask.
|
|
2524
|
+
* @param opts.querySeqLengths - Optional sequence lengths for the queries;
|
|
2525
|
+
* shape `(B,)`. Taken from the beginning of the tensor.
|
|
2526
|
+
* @param opts.keyValueSeqLengths - Optional sequence lengths for the keys and
|
|
2527
|
+
* values; shape `(B,)`. Taken from the beginning of the tensor.
|
|
2528
|
+
* @param opts.localWindowSize - If specified, applies a local attention window
|
|
2529
|
+
* of the given size. Can be a single number or a tuple `[left, right]`.
|
|
2530
|
+
*
|
|
2531
|
+
* @returns The result of the attention operation; shape is the same as query
|
|
2532
|
+
* `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
|
|
2533
|
+
*/
|
|
2534
|
+
declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: ArrayLike, opts?: {
|
|
2535
|
+
bias?: ArrayLike;
|
|
2536
|
+
mask?: ArrayLike;
|
|
2537
|
+
scale?: number;
|
|
2538
|
+
isCausal?: boolean;
|
|
2539
|
+
querySeqLengths?: ArrayLike;
|
|
2540
|
+
keyValueSeqLengths?: ArrayLike;
|
|
2541
|
+
localWindowSize?: number | [number, number];
|
|
2542
|
+
}): Array;
|
|
2438
2543
|
declare namespace random_d_exports {
|
|
2439
2544
|
export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
2440
2545
|
}
|
|
@@ -2526,7 +2631,9 @@ declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
|
|
|
2526
2631
|
* @function
|
|
2527
2632
|
* Compute the forward-mode Jacobian-vector product for a function.
|
|
2528
2633
|
*/
|
|
2529
|
-
declare const jvp: <F extends (...args: any[]) => JsTree<Array
|
|
2634
|
+
declare const jvp: <F extends (...args: any[]) => JsTree<Array>, HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
|
|
2635
|
+
hasAux?: HA;
|
|
2636
|
+
}) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, Out, Aux] : never : [ReturnType<F>, ReturnType<F>];
|
|
2530
2637
|
/**
|
|
2531
2638
|
* @function
|
|
2532
2639
|
* Vectorize an operation on a batched axis for one or more inputs.
|
|
@@ -2568,28 +2675,100 @@ declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: Ji
|
|
|
2568
2675
|
* Produce a local linear approximation to a function at a point using jvp() and
|
|
2569
2676
|
* partial evaluation.
|
|
2570
2677
|
*/
|
|
2571
|
-
declare const linearize: <F extends (...args: any[]) => JsTree<Array
|
|
2678
|
+
declare const linearize: <F extends (...args: any[]) => JsTree<Array>, HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
|
|
2679
|
+
hasAux?: HA;
|
|
2680
|
+
}) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, OwnedFunction<(...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => Out>, Aux] : never : [ReturnType<F>, OwnedFunction<(...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>>];
|
|
2572
2681
|
/**
|
|
2573
2682
|
* @function
|
|
2574
2683
|
* Calculate the reverse-mode vector-Jacobian product for a function.
|
|
2684
|
+
*
|
|
2685
|
+
* The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
|
|
2686
|
+
* `f(primals)`, and `vjpFn` is a function that takes in cotangents for each
|
|
2687
|
+
* output and returns the cotangents for each input.
|
|
2688
|
+
*
|
|
2689
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
|
|
2690
|
+
* `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
|
|
2691
|
+
*
|
|
2692
|
+
* @example
|
|
2693
|
+
* ```ts
|
|
2694
|
+
* const [y, vjpFn] = vjp(f, [x]);
|
|
2695
|
+
*
|
|
2696
|
+
* // With hasAux
|
|
2697
|
+
* const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
|
|
2698
|
+
* ```
|
|
2575
2699
|
*/
|
|
2576
|
-
declare const vjp: <F extends (...args: any[]) => JsTree<Array
|
|
2700
|
+
declare const vjp: <F extends (...args: any[]) => JsTree<Array>, const HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
|
|
2701
|
+
hasAux?: HA;
|
|
2702
|
+
}) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, OwnedFunction<(cotangents: MapJsTree<Out, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>>, Aux] : never : [ReturnType<F>, OwnedFunction<(cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>>];
|
|
2703
|
+
/** @inline */
|
|
2704
|
+
type GradOutputType<I, F extends (...args: any[]) => any> = MapJsTree<I extends undefined ? Parameters<F>[0] : I extends number ? Parameters<F>[I] : I extends number[] ? { [K in keyof I]: I[K] extends number ? Parameters<F>[I[K]] : never } : never, ArrayLike, Array>;
|
|
2577
2705
|
/**
|
|
2578
2706
|
* @function
|
|
2579
2707
|
* Compute the gradient of a scalar-valued function `f` with respect to its
|
|
2580
2708
|
* first argument.
|
|
2709
|
+
*
|
|
2710
|
+
* Pass in different `argnums` to differentiate with respect to other
|
|
2711
|
+
* arguments. If a tuple is provided, the return value will be a tuple of
|
|
2712
|
+
* gradients corresponding to each argument index.
|
|
2713
|
+
*
|
|
2714
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
|
|
2715
|
+
* `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
|
|
2716
|
+
*
|
|
2717
|
+
* @example
|
|
2718
|
+
* ```ts
|
|
2719
|
+
* const gradient = grad(f)(x);
|
|
2720
|
+
*
|
|
2721
|
+
* // With `argnums`
|
|
2722
|
+
* const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
|
|
2723
|
+
*
|
|
2724
|
+
* // With `hasAux`
|
|
2725
|
+
* const [gradient, aux] = grad(f, { hasAux: true })(x);
|
|
2726
|
+
* ```
|
|
2581
2727
|
*/
|
|
2582
|
-
declare const grad: <F extends (...args: any[]) => JsTree<Array
|
|
2728
|
+
declare const grad: <F extends (...args: any[]) => JsTree<Array>, const I extends undefined | number | number[] = undefined, const HA extends boolean = false>(f: F, opts?: Omit<GradOpts, "argnums" | "hasAux"> & {
|
|
2729
|
+
argnums?: I;
|
|
2730
|
+
hasAux?: HA;
|
|
2731
|
+
}) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => HA extends true ? ReturnType<F> extends [any, infer Aux] ? [GradOutputType<I, F>, Aux] : never : GradOutputType<I, F>;
|
|
2583
2732
|
/**
|
|
2584
2733
|
* @function
|
|
2585
2734
|
* Create a function that evaluates both `f` and the gradient of `f`.
|
|
2735
|
+
*
|
|
2736
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
|
|
2737
|
+
* `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
|
|
2738
|
+
*
|
|
2739
|
+
* @example
|
|
2740
|
+
* ```ts
|
|
2741
|
+
* // Without hasAux
|
|
2742
|
+
* const [value, gradient] = valueAndGrad(f)(x);
|
|
2743
|
+
*
|
|
2744
|
+
* // With hasAux
|
|
2745
|
+
* const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
|
|
2746
|
+
* ```
|
|
2586
2747
|
*/
|
|
2587
|
-
declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array
|
|
2748
|
+
declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>, const I extends undefined | number | number[] = undefined, const HA extends boolean = false>(f: F, opts?: Omit<GradOpts, "argnums"> & {
|
|
2749
|
+
argnums?: I;
|
|
2750
|
+
hasAux?: HA;
|
|
2751
|
+
}) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, GradOutputType<I, F>];
|
|
2588
2752
|
/**
|
|
2589
2753
|
* @function
|
|
2590
2754
|
* Compute the Jacobian evaluated row-by-row by reverse-mode AD.
|
|
2591
2755
|
*/
|
|
2592
2756
|
declare const jacrev: typeof jacfwd;
|
|
2757
|
+
/**
|
|
2758
|
+
* @function
|
|
2759
|
+
* Compute the Hessian matrix of a scalar-valued function.
|
|
2760
|
+
*
|
|
2761
|
+
* The Hessian is the matrix of second-order partial derivatives of a function.
|
|
2762
|
+
* This is implemented as `jacfwd(grad(f))`.
|
|
2763
|
+
*
|
|
2764
|
+
* @example
|
|
2765
|
+
* ```ts
|
|
2766
|
+
* const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // x^3
|
|
2767
|
+
* const H = hessian(f)(np.array([1, 2, 3]));
|
|
2768
|
+
* // H[i,j] = d^2f / dx_i dx_j
|
|
2769
|
+
* ```
|
|
2770
|
+
*/
|
|
2771
|
+
declare const hessian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
2593
2772
|
/**
|
|
2594
2773
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
2595
2774
|
*
|
|
@@ -2612,4 +2791,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
|
2612
2791
|
*/
|
|
2613
2792
|
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
2614
2793
|
//#endregion
|
|
2615
|
-
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
2794
|
+
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
package/dist/index.d.ts
CHANGED
|
@@ -663,7 +663,7 @@ type IInfo = Readonly<{
|
|
|
663
663
|
/** Machine limits for integer types. */
|
|
664
664
|
declare function iinfo(dtype: DType): IInfo;
|
|
665
665
|
declare namespace numpy_d_exports {
|
|
666
|
-
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
666
|
+
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
667
667
|
}
|
|
668
668
|
declare const float32 = DType.Float32;
|
|
669
669
|
declare const int32 = DType.Int32;
|
|
@@ -883,6 +883,8 @@ declare function columnStack(xs: ArrayLike[]): Array;
|
|
|
883
883
|
declare function flipud(x: ArrayLike): Array;
|
|
884
884
|
/** Flip an array horizontally (axis=1). */
|
|
885
885
|
declare function fliplr(x: ArrayLike): Array;
|
|
886
|
+
/** Interchange two axes of an array. */
|
|
887
|
+
declare function swapaxes(a: ArrayLike, axis1: number, axis2: number): Array;
|
|
886
888
|
/** Transpose the last two dimensions of an array. */
|
|
887
889
|
declare function matrixTranspose(a: ArrayLike): Array;
|
|
888
890
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
@@ -1587,14 +1589,14 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
1587
1589
|
[Primitive.Pad]: {
|
|
1588
1590
|
width: Pair[];
|
|
1589
1591
|
};
|
|
1592
|
+
[Primitive.TriangularSolve]: {
|
|
1593
|
+
unitDiagonal: boolean;
|
|
1594
|
+
};
|
|
1590
1595
|
[Primitive.Jit]: {
|
|
1591
1596
|
name: string;
|
|
1592
1597
|
jaxpr: Jaxpr;
|
|
1593
1598
|
numConsts: number;
|
|
1594
1599
|
};
|
|
1595
|
-
[Primitive.TriangularSolve]: {
|
|
1596
|
-
unitDiagonal: boolean;
|
|
1597
|
-
};
|
|
1598
1600
|
}
|
|
1599
1601
|
/** Type of parameters taken by each primitive. */
|
|
1600
1602
|
type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
|
|
@@ -2073,6 +2075,24 @@ declare function logspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
2073
2075
|
dtype,
|
|
2074
2076
|
device
|
|
2075
2077
|
}?: DTypeAndDevice): Array;
|
|
2078
|
+
//#endregion
|
|
2079
|
+
//#region src/frontend/linearize.d.ts
|
|
2080
|
+
/** @inline */
|
|
2081
|
+
type GradOpts = {
|
|
2082
|
+
/**
|
|
2083
|
+
* Integer or sequence of integers. Specifies which positional argument(s) to
|
|
2084
|
+
* differentiate with respect to.
|
|
2085
|
+
*
|
|
2086
|
+
* Defaults to `0` (the first argument).
|
|
2087
|
+
*/
|
|
2088
|
+
argnums?: number | number[];
|
|
2089
|
+
/**
|
|
2090
|
+
* The input function returns a pair of `[out, aux]` including an auxiliary
|
|
2091
|
+
* value. This `aux` is not differentiated, but is returned alongside the
|
|
2092
|
+
* gradient when evaluating the function.
|
|
2093
|
+
*/
|
|
2094
|
+
hasAux?: boolean;
|
|
2095
|
+
};
|
|
2076
2096
|
declare namespace lax_linalg_d_exports {
|
|
2077
2097
|
export { cholesky, lu, triangularSolve };
|
|
2078
2098
|
}
|
|
@@ -2163,7 +2183,7 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
|
|
|
2163
2183
|
unitDiagonal?: boolean;
|
|
2164
2184
|
}): Array;
|
|
2165
2185
|
declare namespace lax_d_exports {
|
|
2166
|
-
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
|
|
2186
|
+
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
|
|
2167
2187
|
}
|
|
2168
2188
|
/**
|
|
2169
2189
|
* Dimension numbers for general `dot()` primitive.
|
|
@@ -2201,7 +2221,11 @@ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
|
|
|
2201
2221
|
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
2202
2222
|
* function in JAX, which wraps XLA's general convolution operator.
|
|
2203
2223
|
*
|
|
2204
|
-
*
|
|
2224
|
+
* @param lhs - Input tensor; shape `[N, C_in, ...xs]`
|
|
2225
|
+
* @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
|
|
2226
|
+
* @param windowStrides - Strides for each spatial dimension
|
|
2227
|
+
* @param padding - Padding for each spatial dimension, or a string
|
|
2228
|
+
* (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
|
|
2205
2229
|
*/
|
|
2206
2230
|
declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
|
|
2207
2231
|
lhsDilation,
|
|
@@ -2216,6 +2240,37 @@ declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: numbe
|
|
|
2216
2240
|
declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
|
|
2217
2241
|
/** Convenience wrapper around `convGeneralDilated`. */
|
|
2218
2242
|
declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
|
|
2243
|
+
/**
|
|
2244
|
+
* Convenience wrapper for calculating the N-d convolution "transpose".
|
|
2245
|
+
*
|
|
2246
|
+
* This function directly calculates a fractionally strided conv rather than
|
|
2247
|
+
* indirectly calculating the gradient (transpose) of a forward convolution.
|
|
2248
|
+
* It is equivalent to the JAX version, except:
|
|
2249
|
+
*
|
|
2250
|
+
* - The `use_consistent_padding` option is not available. We only have the
|
|
2251
|
+
* consistent padding case (JAX version >0.8.4).
|
|
2252
|
+
* - The order of dimensions matches `lax.conv_general_dilated`.
|
|
2253
|
+
*
|
|
2254
|
+
* Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
|
|
2255
|
+
* dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
|
|
2256
|
+
* `transposeKernel` to true.
|
|
2257
|
+
*
|
|
2258
|
+
* @param lhs - Input tensor; shape `[N, C_in, ...xs]`
|
|
2259
|
+
* @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
|
|
2260
|
+
* @param strides - Sequence of n integers, sets fractional stride
|
|
2261
|
+
* @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
|
|
2262
|
+
* each side of the input, so it acts like the gradient of `conv()`
|
|
2263
|
+
* @param rhsDilation - Atrous dilation for the kernel
|
|
2264
|
+
* @param transposeKernel - Flip spatial axes and swap the input/output channels
|
|
2265
|
+
* of the kernel; its shape should be `[C_in, C_out, ...ks]`
|
|
2266
|
+
*/
|
|
2267
|
+
declare function convTranspose(lhs: Array, rhs: Array, strides: number[], padding: PaddingType, {
|
|
2268
|
+
rhsDilation,
|
|
2269
|
+
transposeKernel
|
|
2270
|
+
}?: {
|
|
2271
|
+
rhsDilation?: number[];
|
|
2272
|
+
transposeKernel?: boolean;
|
|
2273
|
+
}): Array;
|
|
2219
2274
|
/** Reduce a computation over padded windows. */
|
|
2220
2275
|
declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
|
|
2221
2276
|
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
@@ -2235,7 +2290,7 @@ declare function erfc(x: ArrayLike): Array;
|
|
|
2235
2290
|
*/
|
|
2236
2291
|
declare function stopGradient(x: ArrayLike): Array;
|
|
2237
2292
|
declare namespace nn_d_exports {
|
|
2238
|
-
export { celu, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
|
|
2293
|
+
export { celu, dotProductAttention, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
|
|
2239
2294
|
}
|
|
2240
2295
|
/**
|
|
2241
2296
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -2432,6 +2487,56 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
2432
2487
|
* ```
|
|
2433
2488
|
*/
|
|
2434
2489
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
2490
|
+
/**
|
|
2491
|
+
* Scaled dot product attention (SDPA).
|
|
2492
|
+
*
|
|
2493
|
+
* Computes `softmax((Q @ K^T) / sqrt(d) + bias) @ V`, where `Q` is the query,
|
|
2494
|
+
* `K` is the key, `V` is the value, and `d` is the dimensionality of each key
|
|
2495
|
+
* and query vector.
|
|
2496
|
+
*
|
|
2497
|
+
* Multi-query or grouped-query attention is applied when input `key` and `value` tensors have
|
|
2498
|
+
* fewer heads than `query`.
|
|
2499
|
+
*
|
|
2500
|
+
* We use the following uppercase letters to denote array shapes:
|
|
2501
|
+
* - `B` = batch size
|
|
2502
|
+
* - `S` = length of key/value sequences (source)
|
|
2503
|
+
* - `L` = length of query sequences
|
|
2504
|
+
* - `N` = number of attention heads
|
|
2505
|
+
* - `H` = dimensionality of each attention head
|
|
2506
|
+
* - `K` = number of key/value heads (for grouped-query attention)
|
|
2507
|
+
*
|
|
2508
|
+
* The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
|
|
2509
|
+
* case it must be omitted from all inputs.
|
|
2510
|
+
*
|
|
2511
|
+
* @param query - Query array; shape `[B, L, N, H]`
|
|
2512
|
+
* @param key - Key array; shape `[B, S, K, H]`
|
|
2513
|
+
* @param value - Value array; same shape as `key`
|
|
2514
|
+
* @param opts.bias - Optional bias to add to the attention logits; shape
|
|
2515
|
+
* `[B, N, L, S]` or broadcastable to it.
|
|
2516
|
+
* @param opts.mask - Optional mask to apply to the attention logits; should be
|
|
2517
|
+
* a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
|
|
2518
|
+
* the element should take part in attention.
|
|
2519
|
+
* @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
|
|
2520
|
+
* @param opts.isCausal - If true, applies a causal mask.
|
|
2521
|
+
* @param opts.querySeqLengths - Optional sequence lengths for the queries;
|
|
2522
|
+
* shape `(B,)`. Taken from the beginning of the tensor.
|
|
2523
|
+
* @param opts.keyValueSeqLengths - Optional sequence lengths for the keys and
|
|
2524
|
+
* values; shape `(B,)`. Taken from the beginning of the tensor.
|
|
2525
|
+
* @param opts.localWindowSize - If specified, applies a local attention window
|
|
2526
|
+
* of the given size. Can be a single number or a tuple `[left, right]`.
|
|
2527
|
+
*
|
|
2528
|
+
* @returns The result of the attention operation; shape is the same as query
|
|
2529
|
+
* `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
|
|
2530
|
+
*/
|
|
2531
|
+
declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: ArrayLike, opts?: {
|
|
2532
|
+
bias?: ArrayLike;
|
|
2533
|
+
mask?: ArrayLike;
|
|
2534
|
+
scale?: number;
|
|
2535
|
+
isCausal?: boolean;
|
|
2536
|
+
querySeqLengths?: ArrayLike;
|
|
2537
|
+
keyValueSeqLengths?: ArrayLike;
|
|
2538
|
+
localWindowSize?: number | [number, number];
|
|
2539
|
+
}): Array;
|
|
2435
2540
|
declare namespace random_d_exports {
|
|
2436
2541
|
export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
2437
2542
|
}
|
|
@@ -2523,7 +2628,9 @@ declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
|
|
|
2523
2628
|
* @function
|
|
2524
2629
|
* Compute the forward-mode Jacobian-vector product for a function.
|
|
2525
2630
|
*/
|
|
2526
|
-
declare const jvp: <F extends (...args: any[]) => JsTree<Array
|
|
2631
|
+
declare const jvp: <F extends (...args: any[]) => JsTree<Array>, HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
|
|
2632
|
+
hasAux?: HA;
|
|
2633
|
+
}) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, Out, Aux] : never : [ReturnType<F>, ReturnType<F>];
|
|
2527
2634
|
/**
|
|
2528
2635
|
* @function
|
|
2529
2636
|
* Vectorize an operation on a batched axis for one or more inputs.
|
|
@@ -2565,28 +2672,100 @@ declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: Ji
|
|
|
2565
2672
|
* Produce a local linear approximation to a function at a point using jvp() and
|
|
2566
2673
|
* partial evaluation.
|
|
2567
2674
|
*/
|
|
2568
|
-
declare const linearize: <F extends (...args: any[]) => JsTree<Array
|
|
2675
|
+
declare const linearize: <F extends (...args: any[]) => JsTree<Array>, HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
|
|
2676
|
+
hasAux?: HA;
|
|
2677
|
+
}) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, OwnedFunction<(...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => Out>, Aux] : never : [ReturnType<F>, OwnedFunction<(...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>>];
|
|
2569
2678
|
/**
|
|
2570
2679
|
* @function
|
|
2571
2680
|
* Calculate the reverse-mode vector-Jacobian product for a function.
|
|
2681
|
+
*
|
|
2682
|
+
* The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
|
|
2683
|
+
* `f(primals)`, and `vjpFn` is a function that takes in cotangents for each
|
|
2684
|
+
* output and returns the cotangents for each input.
|
|
2685
|
+
*
|
|
2686
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
|
|
2687
|
+
* `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
|
|
2688
|
+
*
|
|
2689
|
+
* @example
|
|
2690
|
+
* ```ts
|
|
2691
|
+
* const [y, vjpFn] = vjp(f, [x]);
|
|
2692
|
+
*
|
|
2693
|
+
* // With hasAux
|
|
2694
|
+
* const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
|
|
2695
|
+
* ```
|
|
2572
2696
|
*/
|
|
2573
|
-
declare const vjp: <F extends (...args: any[]) => JsTree<Array
|
|
2697
|
+
declare const vjp: <F extends (...args: any[]) => JsTree<Array>, const HA extends boolean = false>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, opts?: {
|
|
2698
|
+
hasAux?: HA;
|
|
2699
|
+
}) => HA extends true ? ReturnType<F> extends [infer Out, infer Aux] ? [Out, OwnedFunction<(cotangents: MapJsTree<Out, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>>, Aux] : never : [ReturnType<F>, OwnedFunction<(cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>>];
|
|
2700
|
+
/** @inline */
|
|
2701
|
+
type GradOutputType<I, F extends (...args: any[]) => any> = MapJsTree<I extends undefined ? Parameters<F>[0] : I extends number ? Parameters<F>[I] : I extends number[] ? { [K in keyof I]: I[K] extends number ? Parameters<F>[I[K]] : never } : never, ArrayLike, Array>;
|
|
2574
2702
|
/**
|
|
2575
2703
|
* @function
|
|
2576
2704
|
* Compute the gradient of a scalar-valued function `f` with respect to its
|
|
2577
2705
|
* first argument.
|
|
2706
|
+
*
|
|
2707
|
+
* Pass in different `argnums` to differentiate with respect to other
|
|
2708
|
+
* arguments. If a tuple is provided, the return value will be a tuple of
|
|
2709
|
+
* gradients corresponding to each argument index.
|
|
2710
|
+
*
|
|
2711
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
|
|
2712
|
+
* `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
|
|
2713
|
+
*
|
|
2714
|
+
* @example
|
|
2715
|
+
* ```ts
|
|
2716
|
+
* const gradient = grad(f)(x);
|
|
2717
|
+
*
|
|
2718
|
+
* // With `argnums`
|
|
2719
|
+
* const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
|
|
2720
|
+
*
|
|
2721
|
+
* // With `hasAux`
|
|
2722
|
+
* const [gradient, aux] = grad(f, { hasAux: true })(x);
|
|
2723
|
+
* ```
|
|
2578
2724
|
*/
|
|
2579
|
-
declare const grad: <F extends (...args: any[]) => JsTree<Array
|
|
2725
|
+
declare const grad: <F extends (...args: any[]) => JsTree<Array>, const I extends undefined | number | number[] = undefined, const HA extends boolean = false>(f: F, opts?: Omit<GradOpts, "argnums" | "hasAux"> & {
|
|
2726
|
+
argnums?: I;
|
|
2727
|
+
hasAux?: HA;
|
|
2728
|
+
}) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => HA extends true ? ReturnType<F> extends [any, infer Aux] ? [GradOutputType<I, F>, Aux] : never : GradOutputType<I, F>;
|
|
2580
2729
|
/**
|
|
2581
2730
|
* @function
|
|
2582
2731
|
* Create a function that evaluates both `f` and the gradient of `f`.
|
|
2732
|
+
*
|
|
2733
|
+
* When `{ hasAux: true }` is passed, the function `f` is expected to return an
|
|
2734
|
+
* `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
|
|
2735
|
+
*
|
|
2736
|
+
* @example
|
|
2737
|
+
* ```ts
|
|
2738
|
+
* // Without hasAux
|
|
2739
|
+
* const [value, gradient] = valueAndGrad(f)(x);
|
|
2740
|
+
*
|
|
2741
|
+
* // With hasAux
|
|
2742
|
+
* const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
|
|
2743
|
+
* ```
|
|
2583
2744
|
*/
|
|
2584
|
-
declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array
|
|
2745
|
+
declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>, const I extends undefined | number | number[] = undefined, const HA extends boolean = false>(f: F, opts?: Omit<GradOpts, "argnums"> & {
|
|
2746
|
+
argnums?: I;
|
|
2747
|
+
hasAux?: HA;
|
|
2748
|
+
}) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, GradOutputType<I, F>];
|
|
2585
2749
|
/**
|
|
2586
2750
|
* @function
|
|
2587
2751
|
* Compute the Jacobian evaluated row-by-row by reverse-mode AD.
|
|
2588
2752
|
*/
|
|
2589
2753
|
declare const jacrev: typeof jacfwd;
|
|
2754
|
+
/**
|
|
2755
|
+
* @function
|
|
2756
|
+
* Compute the Hessian matrix of a scalar-valued function.
|
|
2757
|
+
*
|
|
2758
|
+
* The Hessian is the matrix of second-order partial derivatives of a function.
|
|
2759
|
+
* This is implemented as `jacfwd(grad(f))`.
|
|
2760
|
+
*
|
|
2761
|
+
* @example
|
|
2762
|
+
* ```ts
|
|
2763
|
+
* const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // sum(x_i^3)
|
|
2764
|
+
* const H = hessian(f)(np.array([1, 2, 3]));
|
|
2765
|
+
* // H[i,j] = d^2f / dx_i dx_j
|
|
2766
|
+
* ```
|
|
2767
|
+
*/
|
|
2768
|
+
declare const hessian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
2590
2769
|
/**
|
|
2591
2770
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
2592
2771
|
*
|
|
@@ -2609,4 +2788,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
|
2609
2788
|
*/
|
|
2610
2789
|
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
2611
2790
|
//#endregion
|
|
2612
|
-
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
2791
|
+
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|