@jax-js/jax 0.1.9 → 0.1.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +31 -18
- package/dist/{backend-BId79r5b.js → backend-Ctqs8la1.js} +107 -11
- package/dist/{backend-DpI0riom.cjs → backend-DMauYnfl.cjs} +142 -10
- package/dist/index.cjs +225 -18
- package/dist/index.d.cts +112 -11
- package/dist/index.d.ts +112 -11
- package/dist/index.js +225 -19
- package/dist/{webgl-DnGrclTz.js → webgl-CvQ1QBX1.js} +1 -1
- package/dist/{webgl-C5NjXc1p.cjs → webgl-kvVt7-T7.cjs} +1 -1
- package/dist/{webgpu-CdjiJSa7.cjs → webgpu-DMSx7a6M.cjs} +136 -6
- package/dist/{webgpu-AN0cG_nB.js → webgpu-v_W_-oKw.js} +136 -6
- package/package.json +5 -16
package/dist/index.d.ts
CHANGED
|
@@ -1001,6 +1001,8 @@ declare abstract class Tracer {
|
|
|
1001
1001
|
reshape(shape: number | number[]): this;
|
|
1002
1002
|
/** Copy the array and cast to a specified dtype. */
|
|
1003
1003
|
astype(dtype: DType): this;
|
|
1004
|
+
/** Return a bitwise cast of the array, viewed as a new dtype. */
|
|
1005
|
+
view(dtype?: DType): this;
|
|
1004
1006
|
/** Subtract an array from this one. */
|
|
1005
1007
|
sub(other: this | TracerValue): this;
|
|
1006
1008
|
/** Divide an array by this one. */
|
|
@@ -1424,8 +1426,10 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
|
|
|
1424
1426
|
unitDiagonal?: boolean;
|
|
1425
1427
|
}): Array;
|
|
1426
1428
|
declare namespace lax_d_exports {
|
|
1427
|
-
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
|
|
1429
|
+
export { DotDimensionNumbers, PaddingType, bitcastConvertType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
|
|
1428
1430
|
}
|
|
1431
|
+
/** Elementwise bitcast an array into a new dtype. */
|
|
1432
|
+
declare function bitcastConvertType(x: ArrayLike, newDtype: DType): Array;
|
|
1429
1433
|
/**
|
|
1430
1434
|
* Dimension numbers for general `dot()` primitive.
|
|
1431
1435
|
*
|
|
@@ -1564,7 +1568,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
|
|
|
1564
1568
|
*/
|
|
1565
1569
|
declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
|
|
1566
1570
|
declare namespace numpy_linalg_d_exports {
|
|
1567
|
-
export { cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
|
|
1571
|
+
export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
|
|
1568
1572
|
}
|
|
1569
1573
|
/**
|
|
1570
1574
|
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
@@ -1579,6 +1583,13 @@ declare function cholesky(a: ArrayLike, {
|
|
|
1579
1583
|
upper?: boolean;
|
|
1580
1584
|
symmetrizeInput?: boolean;
|
|
1581
1585
|
}): Array;
|
|
1586
|
+
/**
|
|
1587
|
+
* Compute the cross-product of two 3D vectors.
|
|
1588
|
+
*
|
|
1589
|
+
* This is a simpler and less flexible version of `jax.numpy.cross()`.
|
|
1590
|
+
* Both inputs must have size 3 along the specified axis.
|
|
1591
|
+
*/
|
|
1592
|
+
declare function cross$1(x1: ArrayLike, x2: ArrayLike, axis?: number): Array;
|
|
1582
1593
|
/** Compute the determinant of a square matrix (batched). */
|
|
1583
1594
|
declare function det(a: ArrayLike): Array;
|
|
1584
1595
|
/** Compute the inverse of a square matrix (batched). */
|
|
@@ -1665,7 +1676,7 @@ type IInfo = Readonly<{
|
|
|
1665
1676
|
/** Machine limits for integer types. */
|
|
1666
1677
|
declare function iinfo(dtype: DType): IInfo;
|
|
1667
1678
|
declare namespace numpy_d_exports {
|
|
1668
|
-
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual,
|
|
1679
|
+
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
|
|
1669
1680
|
}
|
|
1670
1681
|
declare const float32 = DType.Float32;
|
|
1671
1682
|
declare const int32 = DType.Int32;
|
|
@@ -1725,6 +1736,14 @@ declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
|
1725
1736
|
declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1726
1737
|
/** @function Compare two arrays element-wise. */
|
|
1727
1738
|
declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1739
|
+
/** Compute element-wise logical AND. */
|
|
1740
|
+
declare function logicalAnd(x: ArrayLike, y: ArrayLike): Array;
|
|
1741
|
+
/** Compute element-wise logical OR. */
|
|
1742
|
+
declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
|
|
1743
|
+
/** Compute element-wise logical XOR. */
|
|
1744
|
+
declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
|
|
1745
|
+
/** Compute element-wise logical NOT. */
|
|
1746
|
+
declare function logicalNot(x: ArrayLike): Array;
|
|
1728
1747
|
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1729
1748
|
declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
|
|
1730
1749
|
/**
|
|
@@ -1809,6 +1828,16 @@ declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
|
1809
1828
|
declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1810
1829
|
/** Compute the average of the array elements along the specified axis. */
|
|
1811
1830
|
declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1831
|
+
/**
|
|
1832
|
+
* Compute the weighted average along the specified axis.
|
|
1833
|
+
*
|
|
1834
|
+
* If no axis is specified, mean is computed along all the axes. The weights
|
|
1835
|
+
* should have shape matching that of `a`, or if an axis is specified, it should
|
|
1836
|
+
* match the shape along those axes.
|
|
1837
|
+
*/
|
|
1838
|
+
declare function average(a: ArrayLike, axis?: Axis, opts?: {
|
|
1839
|
+
weights?: ArrayLike;
|
|
1840
|
+
} & ReduceOpts): Array;
|
|
1812
1841
|
/**
|
|
1813
1842
|
* Returns the indices of the minimum values along an axis.
|
|
1814
1843
|
*
|
|
@@ -1980,13 +2009,39 @@ declare function argsort(a: ArrayLike, axis?: number): Array;
|
|
|
1980
2009
|
* numbered axis. By default, the flattened array is used.
|
|
1981
2010
|
*/
|
|
1982
2011
|
declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
|
|
1983
|
-
/**
|
|
2012
|
+
/**
|
|
2013
|
+
* Return if two arrays are element-wise equal within a tolerance.
|
|
2014
|
+
*
|
|
2015
|
+
* The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
|
|
2016
|
+
* NaN values comparing equal if `equalNaN` is true.
|
|
2017
|
+
*/
|
|
1984
2018
|
declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
|
|
1985
2019
|
rtol?: number;
|
|
1986
2020
|
atol?: number;
|
|
2021
|
+
equalNaN?: boolean;
|
|
1987
2022
|
}): boolean;
|
|
2023
|
+
/**
|
|
2024
|
+
* Check if two arrays are element-wise equal.
|
|
2025
|
+
*
|
|
2026
|
+
* Returns False if the arrays have different shapes. If `equalNaN` is True,
|
|
2027
|
+
* NaNs in the same position are considered equal.
|
|
2028
|
+
*/
|
|
2029
|
+
declare function arrayEqual(a1: ArrayLike, a2: ArrayLike, opts?: {
|
|
2030
|
+
equalNaN?: boolean;
|
|
2031
|
+
}): Array;
|
|
2032
|
+
/**
|
|
2033
|
+
* Check if two arrays are element-wise equal after broadcasting.
|
|
2034
|
+
*
|
|
2035
|
+
* Unlike `arrayEqual`, this allows inputs with different but
|
|
2036
|
+
* broadcast-compatible shapes.
|
|
2037
|
+
*/
|
|
2038
|
+
declare function arrayEquiv(a1: ArrayLike, a2: ArrayLike): Array;
|
|
1988
2039
|
/** Matrix product of two arrays. */
|
|
1989
2040
|
declare function matmul(x: ArrayLike, y: ArrayLike): Array;
|
|
2041
|
+
/** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
|
|
2042
|
+
declare function matvec(x1: ArrayLike, x2: ArrayLike): Array;
|
|
2043
|
+
/** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
|
|
2044
|
+
declare function vecmat(x1: ArrayLike, x2: ArrayLike): Array;
|
|
1990
2045
|
/** Dot product of two arrays. */
|
|
1991
2046
|
declare function dot(x: ArrayLike, y: ArrayLike): Array;
|
|
1992
2047
|
/**
|
|
@@ -2039,6 +2094,18 @@ declare function inner(x: ArrayLike, y: ArrayLike): Array;
|
|
|
2039
2094
|
* be of shape `[x.size, y.size]`.
|
|
2040
2095
|
*/
|
|
2041
2096
|
declare function outer(x: ArrayLike, y: ArrayLike): Array;
|
|
2097
|
+
/**
|
|
2098
|
+
* @function Compute the cross product of two arrays.
|
|
2099
|
+
*
|
|
2100
|
+
* Supports 2D (scalar result) and 3D cross products, with optional axis
|
|
2101
|
+
* arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
|
|
2102
|
+
*/
|
|
2103
|
+
declare const cross: OwnedFunction<(a: ArrayLike, b: ArrayLike, args_2?: {
|
|
2104
|
+
axisa?: number | undefined;
|
|
2105
|
+
axisb?: number | undefined;
|
|
2106
|
+
axisc?: number | undefined;
|
|
2107
|
+
axis?: number | undefined;
|
|
2108
|
+
} | undefined) => Array>;
|
|
2042
2109
|
/** Vector dot product of two arrays along a given axis. */
|
|
2043
2110
|
declare function vecdot(x: ArrayLike, y: ArrayLike, {
|
|
2044
2111
|
axis
|
|
@@ -2084,14 +2151,13 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
|
|
|
2084
2151
|
declare function absolute(x: ArrayLike): Array;
|
|
2085
2152
|
/** Return an element-wise indication of sign of the input. */
|
|
2086
2153
|
declare function sign(x: ArrayLike): Array;
|
|
2087
|
-
/** @function Return element-wise positive values of the input (no-op). */
|
|
2088
|
-
declare const positive: (x: ArrayLike) => Array;
|
|
2089
2154
|
/**
|
|
2090
|
-
*
|
|
2091
|
-
*
|
|
2092
|
-
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
2155
|
+
* @function
|
|
2156
|
+
* Return the value with the magnitude of x and the sign of y, element-wise.
|
|
2093
2157
|
*/
|
|
2094
|
-
declare
|
|
2158
|
+
declare const copysign: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
2159
|
+
/** @function Return element-wise positive values of the input (no-op). */
|
|
2160
|
+
declare const positive: (x: ArrayLike) => Array;
|
|
2095
2161
|
/**
|
|
2096
2162
|
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
2097
2163
|
*
|
|
@@ -2186,6 +2252,18 @@ declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
|
2186
2252
|
declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
|
|
2187
2253
|
/** Round input to the nearest integer towards zero. */
|
|
2188
2254
|
declare function trunc(x: ArrayLike): Array;
|
|
2255
|
+
/**
|
|
2256
|
+
* @function
|
|
2257
|
+
* Round to the given number of decimals.
|
|
2258
|
+
*
|
|
2259
|
+
* Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
|
|
2260
|
+
*/
|
|
2261
|
+
declare const round: OwnedFunction<(a: ArrayLike, decimals?: number | undefined) => Array>;
|
|
2262
|
+
/**
|
|
2263
|
+
* @function
|
|
2264
|
+
* Round to the nearest integer, with ties going to the nearest even integer.
|
|
2265
|
+
*/
|
|
2266
|
+
declare const rint: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2189
2267
|
/**
|
|
2190
2268
|
* Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
|
|
2191
2269
|
*
|
|
@@ -2688,8 +2766,31 @@ declare namespace scipy_special_d_exports {
|
|
|
2688
2766
|
* The logit function, `logit(p) = log(p / (1-p))`.
|
|
2689
2767
|
*/
|
|
2690
2768
|
declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2769
|
+
//#endregion
|
|
2770
|
+
//#region src/tracing.d.ts
|
|
2771
|
+
/**
|
|
2772
|
+
* Start collecting kernel traces.
|
|
2773
|
+
*
|
|
2774
|
+
* Traces appear in developer tools under the "Performance" tab, and they are
|
|
2775
|
+
* useful for measuring fine-grained kernel execution time.
|
|
2776
|
+
*/
|
|
2777
|
+
declare function startTrace(): void;
|
|
2778
|
+
/**
|
|
2779
|
+
* Stop collecting kernel traces.
|
|
2780
|
+
*
|
|
2781
|
+
* Traces appear in developer tools under the "Performance" tab, and they are
|
|
2782
|
+
* useful for measuring fine-grained kernel execution time.
|
|
2783
|
+
*/
|
|
2784
|
+
declare function stopTrace(): void;
|
|
2785
|
+
/** Check if tracing is currently enabled. */
|
|
2786
|
+
|
|
2691
2787
|
//#endregion
|
|
2692
2788
|
//#region src/index.d.ts
|
|
2789
|
+
/** @namespace */
|
|
2790
|
+
declare const profiler: {
|
|
2791
|
+
startTrace: typeof startTrace;
|
|
2792
|
+
stopTrace: typeof stopTrace;
|
|
2793
|
+
};
|
|
2693
2794
|
/**
|
|
2694
2795
|
* @function
|
|
2695
2796
|
* Compute the forward-mode Jacobian-vector product for a function.
|
|
@@ -2854,4 +2955,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
|
2854
2955
|
*/
|
|
2855
2956
|
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
2856
2957
|
//#endregion
|
|
2857
|
-
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
2958
|
+
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-Ctqs8la1.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -807,6 +807,11 @@ var Tracer = class Tracer {
|
|
|
807
807
|
if (this.dtype === dtype) return this;
|
|
808
808
|
return cast(this, dtype);
|
|
809
809
|
}
|
|
810
|
+
/** Return a bitwise cast of the array, viewed as a new dtype. */
|
|
811
|
+
view(dtype) {
|
|
812
|
+
if (!dtype || dtype === this.dtype) return this;
|
|
813
|
+
return bitcast(this, dtype);
|
|
814
|
+
}
|
|
810
815
|
/** Subtract an array from this one. */
|
|
811
816
|
sub(other) {
|
|
812
817
|
return this.add(neg(other));
|
|
@@ -1624,7 +1629,7 @@ const abstractEvalRules = {
|
|
|
1624
1629
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1625
1630
|
},
|
|
1626
1631
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
1627
|
-
if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
1632
|
+
if (x.dtype !== dtype && (x.dtype === DType.Bool || dtype === DType.Bool)) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
1628
1633
|
if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
1629
1634
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1630
1635
|
},
|
|
@@ -3046,8 +3051,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3046
3051
|
return [x.#unary(AluOp.Cast, dtype)];
|
|
3047
3052
|
},
|
|
3048
3053
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
3049
|
-
if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
3050
3054
|
if (x.dtype === dtype) return [x];
|
|
3055
|
+
if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
3051
3056
|
if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
3052
3057
|
if (x.#source instanceof AluExp) return [x.#unary(AluOp.Bitcast, dtype)];
|
|
3053
3058
|
else {
|
|
@@ -4142,6 +4147,7 @@ const jvpRules = {
|
|
|
4142
4147
|
},
|
|
4143
4148
|
[Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
|
|
4144
4149
|
const x = triangularSolve$1(a.ref, b, { unitDiagonal });
|
|
4150
|
+
da = unitDiagonal ? triu(da, 1) : triu(da);
|
|
4145
4151
|
const dax = batchMatmulT(da, x.ref);
|
|
4146
4152
|
const rhsT = db.sub(mT(dax));
|
|
4147
4153
|
const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
|
|
@@ -5217,6 +5223,7 @@ function ifft(a, axis = -1) {
|
|
|
5217
5223
|
var numpy_linalg_exports = {};
|
|
5218
5224
|
__export(numpy_linalg_exports, {
|
|
5219
5225
|
cholesky: () => cholesky,
|
|
5226
|
+
cross: () => cross$1,
|
|
5220
5227
|
det: () => det,
|
|
5221
5228
|
diagonal: () => diagonal,
|
|
5222
5229
|
inv: () => inv,
|
|
@@ -5247,6 +5254,19 @@ function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
|
|
|
5247
5254
|
if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
|
|
5248
5255
|
return cholesky$1(a, { upper });
|
|
5249
5256
|
}
|
|
5257
|
+
/**
|
|
5258
|
+
* Compute the cross-product of two 3D vectors.
|
|
5259
|
+
*
|
|
5260
|
+
* This is a simpler and less flexible version of `jax.numpy.cross()`.
|
|
5261
|
+
* Both inputs must have size 3 along the specified axis.
|
|
5262
|
+
*/
|
|
5263
|
+
function cross$1(x1, x2, axis = -1) {
|
|
5264
|
+
const a1 = checkAxis(axis, ndim(x1));
|
|
5265
|
+
const a2 = checkAxis(axis, ndim(x2));
|
|
5266
|
+
if (shape(x1)[a1] !== 3) throw new Error(`linalg.cross: x1 must have size 3 along axis ${axis}, got ${shape(x1)[a1]}`);
|
|
5267
|
+
if (shape(x2)[a2] !== 3) throw new Error(`linalg.cross: x2 must have size 3 along axis ${axis}, got ${shape(x2)[a2]}`);
|
|
5268
|
+
return cross(x1, x2, { axis });
|
|
5269
|
+
}
|
|
5250
5270
|
/** Compute the determinant of a square matrix (batched). */
|
|
5251
5271
|
function det(a) {
|
|
5252
5272
|
a = fudgeArray(a);
|
|
@@ -5262,7 +5282,7 @@ function det(a) {
|
|
|
5262
5282
|
function inv(a) {
|
|
5263
5283
|
a = fudgeArray(a);
|
|
5264
5284
|
const n = checkSquare("inv", a);
|
|
5265
|
-
return solve(a, eye(n));
|
|
5285
|
+
return solve(a, eye(n, void 0, { dtype: a.dtype }));
|
|
5266
5286
|
}
|
|
5267
5287
|
/**
|
|
5268
5288
|
* Return the least-squares solution to a linear equation.
|
|
@@ -5319,8 +5339,9 @@ function matrixPower(a, n) {
|
|
|
5319
5339
|
a = fudgeArray(a);
|
|
5320
5340
|
const m = checkSquare("matrixPower", a);
|
|
5321
5341
|
if (n === 0) {
|
|
5342
|
+
const dtype = a.dtype;
|
|
5322
5343
|
a.dispose();
|
|
5323
|
-
return broadcastTo(eye(m), a.shape);
|
|
5344
|
+
return broadcastTo(eye(m, void 0, { dtype }), a.shape);
|
|
5324
5345
|
}
|
|
5325
5346
|
if (n < 0) {
|
|
5326
5347
|
a = inv(a);
|
|
@@ -5502,13 +5523,17 @@ __export(numpy_exports, {
|
|
|
5502
5523
|
argmax: () => argmax,
|
|
5503
5524
|
argmin: () => argmin,
|
|
5504
5525
|
argsort: () => argsort,
|
|
5526
|
+
around: () => round,
|
|
5505
5527
|
array: () => array,
|
|
5528
|
+
arrayEqual: () => arrayEqual,
|
|
5529
|
+
arrayEquiv: () => arrayEquiv,
|
|
5506
5530
|
asin: () => asin,
|
|
5507
5531
|
asinh: () => arcsinh,
|
|
5508
5532
|
astype: () => astype,
|
|
5509
5533
|
atan: () => atan,
|
|
5510
5534
|
atan2: () => atan2,
|
|
5511
5535
|
atanh: () => arctanh,
|
|
5536
|
+
average: () => average,
|
|
5512
5537
|
bool: () => bool,
|
|
5513
5538
|
broadcastArrays: () => broadcastArrays,
|
|
5514
5539
|
broadcastShapes: () => broadcastShapes,
|
|
@@ -5519,11 +5544,13 @@ __export(numpy_exports, {
|
|
|
5519
5544
|
columnStack: () => columnStack,
|
|
5520
5545
|
concatenate: () => concatenate,
|
|
5521
5546
|
convolve: () => convolve,
|
|
5547
|
+
copysign: () => copysign,
|
|
5522
5548
|
corrcoef: () => corrcoef,
|
|
5523
5549
|
correlate: () => correlate,
|
|
5524
5550
|
cos: () => cos,
|
|
5525
5551
|
cosh: () => cosh,
|
|
5526
5552
|
cov: () => cov,
|
|
5553
|
+
cross: () => cross,
|
|
5527
5554
|
cumsum: () => cumsum,
|
|
5528
5555
|
cumulativeSum: () => cumsum,
|
|
5529
5556
|
deg2rad: () => deg2rad,
|
|
@@ -5559,7 +5586,6 @@ __export(numpy_exports, {
|
|
|
5559
5586
|
fullLike: () => fullLike$1,
|
|
5560
5587
|
greater: () => greater,
|
|
5561
5588
|
greaterEqual: () => greaterEqual,
|
|
5562
|
-
hamming: () => hamming,
|
|
5563
5589
|
hann: () => hann,
|
|
5564
5590
|
heaviside: () => heaviside,
|
|
5565
5591
|
hstack: () => hstack,
|
|
@@ -5583,9 +5609,14 @@ __export(numpy_exports, {
|
|
|
5583
5609
|
log10: () => log10,
|
|
5584
5610
|
log1p: () => log1p,
|
|
5585
5611
|
log2: () => log2,
|
|
5612
|
+
logicalAnd: () => logicalAnd,
|
|
5613
|
+
logicalNot: () => logicalNot,
|
|
5614
|
+
logicalOr: () => logicalOr,
|
|
5615
|
+
logicalXor: () => logicalXor,
|
|
5586
5616
|
logspace: () => logspace,
|
|
5587
5617
|
matmul: () => matmul,
|
|
5588
5618
|
matrixTranspose: () => matrixTranspose,
|
|
5619
|
+
matvec: () => matvec,
|
|
5589
5620
|
max: () => max,
|
|
5590
5621
|
maximum: () => maximum,
|
|
5591
5622
|
mean: () => mean,
|
|
@@ -5618,6 +5649,8 @@ __export(numpy_exports, {
|
|
|
5618
5649
|
remainder: () => remainder,
|
|
5619
5650
|
repeat: () => repeat,
|
|
5620
5651
|
reshape: () => reshape,
|
|
5652
|
+
rint: () => rint,
|
|
5653
|
+
round: () => round,
|
|
5621
5654
|
shape: () => shape,
|
|
5622
5655
|
sign: () => sign,
|
|
5623
5656
|
sin: () => sin,
|
|
@@ -5650,6 +5683,7 @@ __export(numpy_exports, {
|
|
|
5650
5683
|
var_: () => var_,
|
|
5651
5684
|
vdot: () => vdot,
|
|
5652
5685
|
vecdot: () => vecdot,
|
|
5686
|
+
vecmat: () => vecmat,
|
|
5653
5687
|
vstack: () => vstack,
|
|
5654
5688
|
where: () => where,
|
|
5655
5689
|
zeros: () => zeros,
|
|
@@ -5713,6 +5747,22 @@ const notEqual = notEqual$1;
|
|
|
5713
5747
|
const greaterEqual = greaterEqual$1;
|
|
5714
5748
|
/** @function Compare two arrays element-wise. */
|
|
5715
5749
|
const lessEqual = lessEqual$1;
|
|
5750
|
+
/** Compute element-wise logical AND. */
|
|
5751
|
+
function logicalAnd(x, y) {
|
|
5752
|
+
return astype(x, DType.Bool).mul(astype(y, DType.Bool));
|
|
5753
|
+
}
|
|
5754
|
+
/** Compute element-wise logical OR. */
|
|
5755
|
+
function logicalOr(x, y) {
|
|
5756
|
+
return astype(x, DType.Bool).add(astype(y, DType.Bool));
|
|
5757
|
+
}
|
|
5758
|
+
/** Compute element-wise logical XOR. */
|
|
5759
|
+
function logicalXor(x, y) {
|
|
5760
|
+
return notEqual(astype(x, DType.Bool), astype(y, DType.Bool));
|
|
5761
|
+
}
|
|
5762
|
+
/** Compute element-wise logical NOT. */
|
|
5763
|
+
function logicalNot(x) {
|
|
5764
|
+
return notEqual(astype(x, DType.Bool), true);
|
|
5765
|
+
}
|
|
5716
5766
|
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
5717
5767
|
const where = where$1;
|
|
5718
5768
|
/**
|
|
@@ -5820,6 +5870,34 @@ function mean(a, axis = null, opts) {
|
|
|
5820
5870
|
return fudgeArray(a).mean(axis, opts);
|
|
5821
5871
|
}
|
|
5822
5872
|
/**
|
|
5873
|
+
* Compute the weighted average along the specified axis.
|
|
5874
|
+
*
|
|
5875
|
+
* If no axis is specified, mean is computed along all the axes. The weights
|
|
5876
|
+
* should have shape matching that of `a`, or if an axis is specified, it should
|
|
5877
|
+
* match the shape along those axes.
|
|
5878
|
+
*/
|
|
5879
|
+
function average(a, axis = null, opts) {
|
|
5880
|
+
a = fudgeArray(a);
|
|
5881
|
+
if (opts?.weights == null) return mean(a, axis, opts);
|
|
5882
|
+
const weights = fudgeArray(opts.weights);
|
|
5883
|
+
axis = normalizeAxis(axis, ndim(a));
|
|
5884
|
+
const wShape = weights.shape;
|
|
5885
|
+
const aShape = a.shape;
|
|
5886
|
+
if (deepEqual(wShape, aShape)) {
|
|
5887
|
+
const scl = sum(weights.ref, axis, opts);
|
|
5888
|
+
return sum(multiply(a, weights), axis, opts).div(scl);
|
|
5889
|
+
} else if (axis.length === 1 && wShape.length === 1 && wShape[0] === aShape[axis[0]]) {
|
|
5890
|
+
const broadcastShape = aShape.map((_, i) => i === axis[0] ? wShape[0] : 1);
|
|
5891
|
+
const wReshaped = reshape(weights, broadcastShape);
|
|
5892
|
+
const scl = sum(wReshaped.ref, axis, opts);
|
|
5893
|
+
return sum(multiply(a, wReshaped), axis, opts).div(scl);
|
|
5894
|
+
} else {
|
|
5895
|
+
weights.dispose();
|
|
5896
|
+
a.dispose();
|
|
5897
|
+
throw new Error(`average: weights shape ${JSON.stringify(wShape)} is not compatible with array shape ${JSON.stringify(aShape)} and axis ${JSON.stringify(axis)}`);
|
|
5898
|
+
}
|
|
5899
|
+
}
|
|
5900
|
+
/**
|
|
5823
5901
|
* Returns the indices of the minimum values along an axis.
|
|
5824
5902
|
*
|
|
5825
5903
|
* By default, index is into the flatted array, otherwise it is along the
|
|
@@ -6223,20 +6301,63 @@ function take(a, indices, axis = null) {
|
|
|
6223
6301
|
axis = checkAxis(axis, ndim(a));
|
|
6224
6302
|
return gather(a, [indices], [axis], axis);
|
|
6225
6303
|
}
|
|
6226
|
-
/**
|
|
6304
|
+
/**
|
|
6305
|
+
* Return if two arrays are element-wise equal within a tolerance.
|
|
6306
|
+
*
|
|
6307
|
+
* The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
|
|
6308
|
+
* NaN values comparing equal if `equalNaN` is true.
|
|
6309
|
+
*/
|
|
6227
6310
|
function allclose(actual, expected, options) {
|
|
6228
|
-
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
6311
|
+
const { rtol = 1e-5, atol = 1e-7, equalNaN = false } = options ?? {};
|
|
6229
6312
|
const x = array(actual);
|
|
6230
6313
|
const y = array(expected);
|
|
6231
6314
|
if (!deepEqual(x.shape, y.shape)) return false;
|
|
6232
6315
|
const xData = x.dataSync();
|
|
6233
6316
|
const yData = y.dataSync();
|
|
6234
6317
|
for (let i = 0; i < xData.length; i++) {
|
|
6235
|
-
if (isNaN(xData[i]) !== isNaN(yData[i])) return false;
|
|
6318
|
+
if (equalNaN ? isNaN(xData[i]) !== isNaN(yData[i]) : isNaN(xData[i]) || isNaN(yData[i])) return false;
|
|
6236
6319
|
if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
|
|
6237
6320
|
}
|
|
6238
6321
|
return true;
|
|
6239
6322
|
}
|
|
6323
|
+
/**
|
|
6324
|
+
* Check if two arrays are element-wise equal.
|
|
6325
|
+
*
|
|
6326
|
+
* Returns False if the arrays have different shapes. If `equalNaN` is True,
|
|
6327
|
+
* NaNs in the same position are considered equal.
|
|
6328
|
+
*/
|
|
6329
|
+
function arrayEqual(a1, a2, opts) {
|
|
6330
|
+
a1 = fudgeArray(a1);
|
|
6331
|
+
a2 = fudgeArray(a2);
|
|
6332
|
+
if (!deepEqual(a1.shape, a2.shape)) {
|
|
6333
|
+
a1.dispose();
|
|
6334
|
+
a2.dispose();
|
|
6335
|
+
return array(false);
|
|
6336
|
+
}
|
|
6337
|
+
if (opts?.equalNaN) {
|
|
6338
|
+
const nanMask = isnan(a1.ref).mul(isnan(a2.ref));
|
|
6339
|
+
return where(nanMask, true, equal(a1, a2)).all();
|
|
6340
|
+
}
|
|
6341
|
+
return equal(a1, a2).all();
|
|
6342
|
+
}
|
|
6343
|
+
/**
|
|
6344
|
+
* Check if two arrays are element-wise equal after broadcasting.
|
|
6345
|
+
*
|
|
6346
|
+
* Unlike `arrayEqual`, this allows inputs with different but
|
|
6347
|
+
* broadcast-compatible shapes.
|
|
6348
|
+
*/
|
|
6349
|
+
function arrayEquiv(a1, a2) {
|
|
6350
|
+
a1 = fudgeArray(a1);
|
|
6351
|
+
a2 = fudgeArray(a2);
|
|
6352
|
+
try {
|
|
6353
|
+
const [b1, b2] = broadcastArrays(a1, a2);
|
|
6354
|
+
return equal(b1, b2).all();
|
|
6355
|
+
} catch {
|
|
6356
|
+
a1.dispose();
|
|
6357
|
+
a2.dispose();
|
|
6358
|
+
return array(false);
|
|
6359
|
+
}
|
|
6360
|
+
}
|
|
6240
6361
|
/** Matrix product of two arrays. */
|
|
6241
6362
|
function matmul(x, y) {
|
|
6242
6363
|
if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
|
|
@@ -6250,6 +6371,16 @@ function matmul(x, y) {
|
|
|
6250
6371
|
rhsBatchDims: range(-2 - numBatchDims, -2)
|
|
6251
6372
|
});
|
|
6252
6373
|
}
|
|
6374
|
+
/** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
|
|
6375
|
+
function matvec(x1, x2) {
|
|
6376
|
+
if (ndim(x1) < 2 || ndim(x2) < 1) throw new Error("matvec: x1 must be at least 2D and x2 at least 1D");
|
|
6377
|
+
return einsum("...mn,...n->...m", x1, x2);
|
|
6378
|
+
}
|
|
6379
|
+
/**
* Vector-matrix product, batched over leading dimensions.
*
* `x1` has shape `[..., N]` and `x2` has shape `[..., N, M]`; the result has
* shape `[..., M]`.
*/
function vecmat(x1, x2) {
	const ranksOk = ndim(x1) >= 1 && ndim(x2) >= 2;
	if (!ranksOk) throw new Error("vecmat: x1 must be at least 1D and x2 at least 2D");
	const spec = "...n,...nm->...m";
	return einsum(spec, x1, x2);
}
|
|
6253
6384
|
/** Dot product of two arrays. */
|
|
6254
6385
|
function dot$1(x, y) {
|
|
6255
6386
|
if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
|
|
@@ -6408,6 +6539,49 @@ function outer(x, y) {
|
|
|
6408
6539
|
y = ravel(y);
|
|
6409
6540
|
return multiply(x.reshape([x.shape[0], 1]), y);
|
|
6410
6541
|
}
|
|
6542
|
+
/**
* @function
* Compute the cross product of two arrays.
*
* Supports 2D (scalar result) and 3D cross products, with optional axis
* arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
*
* When one operand has 2 components and the other has 3, the 2-component
* operand is padded with a zero third component and the 3D product is taken.
* The options object (argument index 2) is a static argument to `jit`.
*/
const cross = jit$1(function cross$2(a, b, { axisa = -1, axisb = -1, axisc = -1, axis } = {}) {
	if (axis !== void 0) {
		axisa = axis;
		axisb = axis;
		axisc = axis;
	}
	axisa = checkAxis(axisa, ndim(a));
	axisb = checkAxis(axisb, ndim(b));
	// Move each operand's vector axis to the end so components can be split off.
	a = moveaxis$1(a, axisa, -1);
	b = moveaxis$1(b, axisb, -1);
	const da = a.shape.at(-1);
	const db = b.shape.at(-1);
	if (da !== 2 && da !== 3 || db !== 2 && db !== 3) throw new Error(`cross: incompatible dimensions for cross product (got ${da} and ${db})`);
	if (da === 2 && db === 2) {
		// 2D x 2D: only the z-component a0*b1 - a1*b0 exists. Each split piece
		// keeps a size-1 last axis, which is squeezed away; `axisc` is not used.
		const [a0$1, a1$1] = split$1(a, 2, -1);
		const [b0$1, b1$1] = split$1(b, 2, -1);
		return squeeze(a0$1.mul(b1$1).sub(a1$1.mul(b0$1)), -1);
	}
	// Mixed 2D/3D: append a zero third component to the 2-component operand.
	// NOTE(review): zeros() is called without a dtype; confirm concatenate
	// accepts this for non-float32 inputs.
	if (da === 2) {
		const zeroShape = [...a.shape.slice(0, -1), 1];
		a = concatenate([a, zeros(zeroShape)], -1);
	}
	if (db === 2) {
		const zeroShape = [...b.shape.slice(0, -1), 1];
		b = concatenate([b, zeros(zeroShape)], -1);
	}
	const [a0, a1, a2] = split$1(a, 3, -1);
	const [b0, b1, b2] = split$1(b, 3, -1);
	// Each component feeds two products. Every use except the last goes through
	// `.ref` (presumably to keep the array alive under the library's ownership
	// rules), so the order of these three statements matters.
	const c0 = a1.ref.mul(b2.ref).sub(a2.ref.mul(b1.ref));
	const c1 = a2.mul(b0.ref).sub(a0.ref.mul(b2));
	const c2 = a0.mul(b1).sub(a1.mul(b0));
	// Reassemble [c0, c1, c2] along the last axis, then place that axis at `axisc`.
	return moveaxis$1(concatenate([
		c0,
		c1,
		c2
	], -1), -1, axisc);
}, { staticArgnums: [2] });
|
|
6411
6585
|
/** Vector dot product of two arrays along a given axis. */
|
|
6412
6586
|
function vecdot(x, y, { axis } = {}) {
|
|
6413
6587
|
const xaxis = checkAxis(axis ?? -1, ndim(x));
|
|
@@ -6504,16 +6678,15 @@ function sign(x) {
|
|
|
6504
6678
|
x = fudgeArray(x);
|
|
6505
6679
|
return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
|
|
6506
6680
|
}
|
|
6507
|
-
/** @function Return element-wise positive values of the input (no-op). */
|
|
6508
|
-
const positive = fudgeArray;
|
|
6509
6681
|
/**
* @function
* Return the value with the magnitude of x and the sign of y, element-wise.
*
* The result is `-|x|` where `y < 0` and `|x|` otherwise, so a zero `y` keeps
* the magnitude of `x`. Multiplying by `sign(y)` would wrongly give 0 there.
*
* NOTE: the sign bit of `-0.0` and of NaN is not inspected, so for those `y`
* values the result is `+|x|` (NumPy would return `-|x|` for `y = -0.0`).
*/
const copysign = jit$1(function copysign$1(x, y) {
	const magnitude = absolute(x);
	return where(less(y, 0), magnitude.ref.mul(-1), magnitude);
});
|
|
6688
|
+
/**
* @function
* Return element-wise positive values of the input (no-op).
*
* This is an alias of `fudgeArray`: the input is converted to an array and
* otherwise returned unchanged.
*/
const positive = fudgeArray;
|
|
6517
6690
|
/**
|
|
6518
6691
|
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
6519
6692
|
*
|
|
@@ -6659,6 +6832,27 @@ function trunc(x) {
|
|
|
6659
6832
|
return idiv(x, 1);
|
|
6660
6833
|
}
|
|
6661
6834
|
/**
* @function
* Round to the given number of decimals.
*
* Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
* `decimals` may be negative to round to tens, hundreds, and so on. It is
* marked as a static argument for `jit` (`staticArgnums: [1]`).
*/
const round = jit$1(function round$1(a, decimals = 0) {
	if (decimals === 0) return rint(a);
	// Always scale by an exact power of ten (never by its inexact reciprocal,
	// e.g. 0.1 or 0.01) so the unscaling step is a correctly rounded division or
	// multiplication, as in NumPy's implementation.
	if (decimals > 0) {
		const factor = 10 ** decimals;
		return rint(a.mul(factor)).div(factor);
	}
	const factor = 10 ** -decimals;
	return rint(a.div(factor)).mul(factor);
}, { staticArgnums: [1] });
|
|
6845
|
+
/**
* @function
* Round to the nearest integer, with ties going to the nearest even integer.
*
* Works from the fractional part `x - floor(x)` instead of `floor(x + 0.5)`.
* The addition can itself round up: 0.49999997 in float32 would become 1, and
* large odd float32 integers such as 8388609 would be bumped to the next even
* value. NaN and +/-Infinity pass through unchanged.
*/
const rint = jit$1(function rint$1(x) {
	const lower = floor(x.ref);
	const frac = x.sub(lower.ref);
	const upper = lower.ref.add(1);
	const lowerIsEven = remainder(lower.ref, 2).equal(0);
	// On an exact tie, pick `lower` only when it is the even neighbor.
	const tieToLower = frac.ref.equal(.5).mul(lowerIsEven);
	// frac < 0.5 -> lower; frac == 0.5 -> even neighbor; otherwise (incl. NaN) -> upper.
	return where(less(frac, .5), lower.ref, where(tieToLower, lower, upper));
});
|
|
6855
|
+
/**
|
|
6662
6856
|
* Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
|
|
6663
6857
|
*
|
|
6664
6858
|
* This is the inverse of `frexp()`.
|
|
@@ -6986,6 +7180,7 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
|
|
|
6986
7180
|
//#region src/library/lax.ts
|
|
6987
7181
|
var lax_exports = {};
|
|
6988
7182
|
__export(lax_exports, {
|
|
7183
|
+
bitcastConvertType: () => bitcastConvertType,
|
|
6989
7184
|
conv: () => conv,
|
|
6990
7185
|
convGeneralDilated: () => convGeneralDilated,
|
|
6991
7186
|
convTranspose: () => convTranspose,
|
|
@@ -6999,6 +7194,10 @@ __export(lax_exports, {
|
|
|
6999
7194
|
topK: () => topK
|
|
7000
7195
|
});
|
|
7001
7196
|
// Handle to the built-in JavaScript Array constructor. NOTE(review): presumably
// needed because the library's own `Array` class (bundled as `Array$1` and
// exported as `Array`) shadows the global name in the original source module.
const JsArray = globalThis.Array;
|
|
7197
|
+
/**
* Elementwise bitcast an array into a new dtype.
*
* The input is converted with `fudgeArray` and its bits are reinterpreted via
* `Array.view(newDtype)`; no numeric conversion takes place.
*/
function bitcastConvertType(x, newDtype) {
	const source = fudgeArray(x);
	return source.view(newDtype);
}
|
|
7002
7201
|
/**
|
|
7003
7202
|
* General dot product/contraction operator.
|
|
7004
7203
|
*
|
|
@@ -7730,7 +7929,9 @@ function getK01(key$1) {
|
|
|
7730
7929
|
function key(seed) {
	const scalarSeed = array(seed, { dtype: DType.Uint32 });
	if (scalarSeed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${scalarSeed.shape} - use jax.vmap for batching.`);
	// The key is the pair [0, seed].
	const result = stack([0, scalarSeed]);
	// Only concrete `Array$1` results are realized here; anything else
	// (presumably a tracer) is returned untouched.
	if (result instanceof Array$1) result._realizeSource();
	return result;
}
|
|
7735
7936
|
/** Splits a PRNG key into `num` new keys by adding a leading axis. */
|
|
7736
7937
|
function split(key$1, num = 2) {
|
|
@@ -7925,6 +8126,11 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
|
|
|
7925
8126
|
|
|
7926
8127
|
//#endregion
|
|
7927
8128
|
//#region src/index.ts
|
|
8129
|
+
/**
* @namespace
* Profiling helpers, exported from the package entry point as `profiler`:
* `startTrace` and `stopTrace`.
*/
const profiler = {
	startTrace: startTrace,
	stopTrace: stopTrace
};
|
|
7928
8134
|
/**
|
|
7929
8135
|
* @function
|
|
7930
8136
|
* Compute the forward-mode Jacobian-vector product for a function.
|
|
@@ -8085,4 +8291,4 @@ async function devicePut(x, device) {
|
|
|
8085
8291
|
}
|
|
8086
8292
|
|
|
8087
8293
|
//#endregion
|
|
8088
|
-
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
8294
|
+
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, profiler, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-
|
|
1
|
+
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-Ctqs8la1.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgl/builtins.ts
|
|
4
4
|
const threefrySrc = `
|