@jax-js/jax 0.1.9 → 0.1.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.ts CHANGED
@@ -1001,6 +1001,8 @@ declare abstract class Tracer {
1001
1001
  reshape(shape: number | number[]): this;
1002
1002
  /** Copy the array and cast to a specified dtype. */
1003
1003
  astype(dtype: DType): this;
1004
+ /** Return a bitwise cast of the array, viewed as a new dtype. */
1005
+ view(dtype?: DType): this;
1004
1006
  /** Subtract an array from this one. */
1005
1007
  sub(other: this | TracerValue): this;
1006
1008
  /** Divide an array by this one. */
@@ -1424,8 +1426,10 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
1424
1426
  unitDiagonal?: boolean;
1425
1427
  }): Array;
1426
1428
  declare namespace lax_d_exports {
1427
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
1429
+ export { DotDimensionNumbers, PaddingType, bitcastConvertType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
1428
1430
  }
1431
+ /** Elementwise bitcast an array into a new dtype. */
1432
+ declare function bitcastConvertType(x: ArrayLike, newDtype: DType): Array;
1429
1433
  /**
1430
1434
  * Dimension numbers for general `dot()` primitive.
1431
1435
  *
@@ -1564,7 +1568,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
1564
1568
  */
1565
1569
  declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
1566
1570
  declare namespace numpy_linalg_d_exports {
1567
- export { cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
1571
+ export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
1568
1572
  }
1569
1573
  /**
1570
1574
  * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
@@ -1579,6 +1583,13 @@ declare function cholesky(a: ArrayLike, {
1579
1583
  upper?: boolean;
1580
1584
  symmetrizeInput?: boolean;
1581
1585
  }): Array;
1586
+ /**
1587
+ * Compute the cross-product of two 3D vectors.
1588
+ *
1589
+ * This is a simpler and less flexible version of `jax.numpy.cross()`.
1590
+ * Both inputs must have size 3 along the specified axis.
1591
+ */
1592
+ declare function cross$1(x1: ArrayLike, x2: ArrayLike, axis?: number): Array;
1582
1593
  /** Compute the determinant of a square matrix (batched). */
1583
1594
  declare function det(a: ArrayLike): Array;
1584
1595
  /** Compute the inverse of a square matrix (batched). */
@@ -1665,7 +1676,7 @@ type IInfo = Readonly<{
1665
1676
  /** Machine limits for integer types. */
1666
1677
  declare function iinfo(dtype: DType): IInfo;
1667
1678
  declare namespace numpy_d_exports {
1668
- export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1679
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
1669
1680
  }
1670
1681
  declare const float32 = DType.Float32;
1671
1682
  declare const int32 = DType.Int32;
@@ -1725,6 +1736,14 @@ declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
1725
1736
  declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
1726
1737
  /** @function Compare two arrays element-wise. */
1727
1738
  declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
1739
+ /** Compute element-wise logical AND. */
1740
+ declare function logicalAnd(x: ArrayLike, y: ArrayLike): Array;
1741
+ /** Compute element-wise logical OR. */
1742
+ declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
1743
+ /** Compute element-wise logical XOR. */
1744
+ declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
1745
+ /** Compute element-wise logical NOT. */
1746
+ declare function logicalNot(x: ArrayLike): Array;
1728
1747
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1729
1748
  declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
1730
1749
  /**
@@ -1809,6 +1828,16 @@ declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1809
1828
  declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1810
1829
  /** Compute the average of the array elements along the specified axis. */
1811
1830
  declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1831
/**
 * Compute the weighted average along the specified axis.
 *
 * Without `opts.weights` this is the same as `mean()`. If no axis is
 * specified, the average is computed over all axes. The weights must either
 * have exactly the same shape as `a`, or, when a single `axis` is given, be
 * 1-D with length equal to the size of `a` along that axis. Any other weights
 * shape raises an error.
 */
declare function average(a: ArrayLike, axis?: Axis, opts?: {
  weights?: ArrayLike;
} & ReduceOpts): Array;
1812
1841
  /**
1813
1842
  * Returns the indices of the minimum values along an axis.
1814
1843
  *
@@ -1980,13 +2009,39 @@ declare function argsort(a: ArrayLike, axis?: number): Array;
1980
2009
  * numbered axis. By default, the flattened array is used.
1981
2010
  */
1982
2011
  declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
1983
- /** Return if two arrays are element-wise equal within a tolerance. */
2012
+ /**
2013
+ * Return if two arrays are element-wise equal within a tolerance.
2014
+ *
2015
+ * The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
2016
+ * NaN values comparing equal if `equalNaN` is true.
2017
+ */
1984
2018
  declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
1985
2019
  rtol?: number;
1986
2020
  atol?: number;
2021
+ equalNaN?: boolean;
1987
2022
  }): boolean;
2023
/**
 * Check if two arrays are element-wise equal.
 *
 * Returns a scalar boolean array: `false` if the arrays have different
 * shapes, otherwise whether every element is equal. If `equalNaN` is true,
 * NaNs in the same position are considered equal.
 */
declare function arrayEqual(a1: ArrayLike, a2: ArrayLike, opts?: {
  equalNaN?: boolean;
}): Array;
2032
+ /**
2033
+ * Check if two arrays are element-wise equal after broadcasting.
2034
+ *
2035
+ * Unlike `arrayEqual`, this allows inputs with different but
2036
+ * broadcast-compatible shapes.
2037
+ */
2038
+ declare function arrayEquiv(a1: ArrayLike, a2: ArrayLike): Array;
1988
2039
  /** Matrix product of two arrays. */
1989
2040
  declare function matmul(x: ArrayLike, y: ArrayLike): Array;
2041
+ /** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
2042
+ declare function matvec(x1: ArrayLike, x2: ArrayLike): Array;
2043
+ /** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
2044
+ declare function vecmat(x1: ArrayLike, x2: ArrayLike): Array;
1990
2045
  /** Dot product of two arrays. */
1991
2046
  declare function dot(x: ArrayLike, y: ArrayLike): Array;
1992
2047
  /**
@@ -2039,6 +2094,18 @@ declare function inner(x: ArrayLike, y: ArrayLike): Array;
2039
2094
  * be of shape `[x.size, y.size]`.
2040
2095
  */
2041
2096
  declare function outer(x: ArrayLike, y: ArrayLike): Array;
2097
+ /**
2098
+ * @function Compute the cross product of two arrays.
2099
+ *
2100
+ * Supports 2D (scalar result) and 3D cross products, with optional axis
2101
+ * arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
2102
+ */
2103
+ declare const cross: OwnedFunction<(a: ArrayLike, b: ArrayLike, args_2?: {
2104
+ axisa?: number | undefined;
2105
+ axisb?: number | undefined;
2106
+ axisc?: number | undefined;
2107
+ axis?: number | undefined;
2108
+ } | undefined) => Array>;
2042
2109
  /** Vector dot product of two arrays along a given axis. */
2043
2110
  declare function vecdot(x: ArrayLike, y: ArrayLike, {
2044
2111
  axis
@@ -2084,14 +2151,13 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
2084
2151
  declare function absolute(x: ArrayLike): Array;
2085
2152
  /** Return an element-wise indication of sign of the input. */
2086
2153
  declare function sign(x: ArrayLike): Array;
2087
- /** @function Return element-wise positive values of the input (no-op). */
2088
- declare const positive: (x: ArrayLike) => Array;
2089
2154
  /**
2090
- * Return the Hamming window of size M, a taper with a weighted cosine bell.
2091
- *
2092
- * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
2155
+ * @function
2156
+ * Return the value with the magnitude of x and the sign of y, element-wise.
2093
2157
  */
2094
- declare function hamming(M: number): Array;
2158
+ declare const copysign: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
2159
+ /** @function Return element-wise positive values of the input (no-op). */
2160
+ declare const positive: (x: ArrayLike) => Array;
2095
2161
  /**
2096
2162
  * Return the Hann window of size M, a taper with a weighted cosine bell.
2097
2163
  *
@@ -2186,6 +2252,18 @@ declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
2186
2252
  declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
2187
2253
  /** Round input to the nearest integer towards zero. */
2188
2254
  declare function trunc(x: ArrayLike): Array;
2255
+ /**
2256
+ * @function
2257
+ * Round to the given number of decimals.
2258
+ *
2259
+ * Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
2260
+ */
2261
+ declare const round: OwnedFunction<(a: ArrayLike, decimals?: number | undefined) => Array>;
2262
+ /**
2263
+ * @function
2264
+ * Round to the nearest integer, with ties going to the nearest even integer.
2265
+ */
2266
+ declare const rint: OwnedFunction<(x: ArrayLike) => Array>;
2189
2267
  /**
2190
2268
  * Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
2191
2269
  *
@@ -2688,8 +2766,31 @@ declare namespace scipy_special_d_exports {
2688
2766
  * The logit function, `logit(p) = log(p / (1-p))`.
2689
2767
  */
2690
2768
  declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
2769
//#endregion
//#region src/tracing.d.ts
/**
 * Start collecting kernel traces.
 *
 * Traces appear in developer tools under the "Performance" tab, and they are
 * useful for measuring fine-grained kernel execution time.
 */
declare function startTrace(): void;
/**
 * Stop collecting kernel traces.
 *
 * Traces appear in developer tools under the "Performance" tab, and they are
 * useful for measuring fine-grained kernel execution time.
 */
declare function stopTrace(): void;
// NOTE(review): a doc comment for a "is tracing currently enabled" query was
// emitted here without any declaration following it; no such function is
// declared in this bundle, so it is not part of the public typings.
2691
2787
  //#endregion
2692
2788
  //#region src/index.d.ts
2789
+ /** @namespace */
2790
+ declare const profiler: {
2791
+ startTrace: typeof startTrace;
2792
+ stopTrace: typeof stopTrace;
2793
+ };
2693
2794
  /**
2694
2795
  * @function
2695
2796
  * Compute the forward-mode Jacobian-vector product for a function.
@@ -2854,4 +2955,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
2854
2955
  */
2855
2956
  declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
2856
2957
  //#endregion
2857
- export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
2958
+ export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-BId79r5b.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-Ctqs8la1.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -807,6 +807,11 @@ var Tracer = class Tracer {
807
807
  if (this.dtype === dtype) return this;
808
808
  return cast(this, dtype);
809
809
  }
810
+ /** Return a bitwise cast of the array, viewed as a new dtype. */
811
+ view(dtype) {
812
+ if (!dtype || dtype === this.dtype) return this;
813
+ return bitcast(this, dtype);
814
+ }
810
815
  /** Subtract an array from this one. */
811
816
  sub(other) {
812
817
  return this.add(neg(other));
@@ -1624,7 +1629,7 @@ const abstractEvalRules = {
1624
1629
  return [new ShapedArray(x.shape, dtype, false)];
1625
1630
  },
1626
1631
  [Primitive.Bitcast]([x], { dtype }) {
1627
- if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
1632
+ if (x.dtype !== dtype && (x.dtype === DType.Bool || dtype === DType.Bool)) throw new TypeError("Bitcast to/from bool is not allowed");
1628
1633
  if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
1629
1634
  return [new ShapedArray(x.shape, dtype, false)];
1630
1635
  },
@@ -3046,8 +3051,8 @@ var Array$1 = class Array$1 extends Tracer {
3046
3051
  return [x.#unary(AluOp.Cast, dtype)];
3047
3052
  },
3048
3053
  [Primitive.Bitcast]([x], { dtype }) {
3049
- if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
3050
3054
  if (x.dtype === dtype) return [x];
3055
+ if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
3051
3056
  if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
3052
3057
  if (x.#source instanceof AluExp) return [x.#unary(AluOp.Bitcast, dtype)];
3053
3058
  else {
@@ -4142,6 +4147,7 @@ const jvpRules = {
4142
4147
  },
4143
4148
  [Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
4144
4149
  const x = triangularSolve$1(a.ref, b, { unitDiagonal });
4150
+ da = unitDiagonal ? triu(da, 1) : triu(da);
4145
4151
  const dax = batchMatmulT(da, x.ref);
4146
4152
  const rhsT = db.sub(mT(dax));
4147
4153
  const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
@@ -5217,6 +5223,7 @@ function ifft(a, axis = -1) {
5217
5223
  var numpy_linalg_exports = {};
5218
5224
  __export(numpy_linalg_exports, {
5219
5225
  cholesky: () => cholesky,
5226
+ cross: () => cross$1,
5220
5227
  det: () => det,
5221
5228
  diagonal: () => diagonal,
5222
5229
  inv: () => inv,
@@ -5247,6 +5254,19 @@ function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
5247
5254
  if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
5248
5255
  return cholesky$1(a, { upper });
5249
5256
  }
5257
+ /**
5258
+ * Compute the cross-product of two 3D vectors.
5259
+ *
5260
+ * This is a simpler and less flexible version of `jax.numpy.cross()`.
5261
+ * Both inputs must have size 3 along the specified axis.
5262
+ */
5263
+ function cross$1(x1, x2, axis = -1) {
5264
+ const a1 = checkAxis(axis, ndim(x1));
5265
+ const a2 = checkAxis(axis, ndim(x2));
5266
+ if (shape(x1)[a1] !== 3) throw new Error(`linalg.cross: x1 must have size 3 along axis ${axis}, got ${shape(x1)[a1]}`);
5267
+ if (shape(x2)[a2] !== 3) throw new Error(`linalg.cross: x2 must have size 3 along axis ${axis}, got ${shape(x2)[a2]}`);
5268
+ return cross(x1, x2, { axis });
5269
+ }
5250
5270
  /** Compute the determinant of a square matrix (batched). */
5251
5271
  function det(a) {
5252
5272
  a = fudgeArray(a);
@@ -5262,7 +5282,7 @@ function det(a) {
5262
5282
  function inv(a) {
5263
5283
  a = fudgeArray(a);
5264
5284
  const n = checkSquare("inv", a);
5265
- return solve(a, eye(n));
5285
+ return solve(a, eye(n, void 0, { dtype: a.dtype }));
5266
5286
  }
5267
5287
  /**
5268
5288
  * Return the least-squares solution to a linear equation.
@@ -5319,8 +5339,9 @@ function matrixPower(a, n) {
5319
5339
  a = fudgeArray(a);
5320
5340
  const m = checkSquare("matrixPower", a);
5321
5341
  if (n === 0) {
5342
+ const dtype = a.dtype;
5322
5343
  a.dispose();
5323
- return broadcastTo(eye(m), a.shape);
5344
+ return broadcastTo(eye(m, void 0, { dtype }), a.shape);
5324
5345
  }
5325
5346
  if (n < 0) {
5326
5347
  a = inv(a);
@@ -5502,13 +5523,17 @@ __export(numpy_exports, {
5502
5523
  argmax: () => argmax,
5503
5524
  argmin: () => argmin,
5504
5525
  argsort: () => argsort,
5526
+ around: () => round,
5505
5527
  array: () => array,
5528
+ arrayEqual: () => arrayEqual,
5529
+ arrayEquiv: () => arrayEquiv,
5506
5530
  asin: () => asin,
5507
5531
  asinh: () => arcsinh,
5508
5532
  astype: () => astype,
5509
5533
  atan: () => atan,
5510
5534
  atan2: () => atan2,
5511
5535
  atanh: () => arctanh,
5536
+ average: () => average,
5512
5537
  bool: () => bool,
5513
5538
  broadcastArrays: () => broadcastArrays,
5514
5539
  broadcastShapes: () => broadcastShapes,
@@ -5519,11 +5544,13 @@ __export(numpy_exports, {
5519
5544
  columnStack: () => columnStack,
5520
5545
  concatenate: () => concatenate,
5521
5546
  convolve: () => convolve,
5547
+ copysign: () => copysign,
5522
5548
  corrcoef: () => corrcoef,
5523
5549
  correlate: () => correlate,
5524
5550
  cos: () => cos,
5525
5551
  cosh: () => cosh,
5526
5552
  cov: () => cov,
5553
+ cross: () => cross,
5527
5554
  cumsum: () => cumsum,
5528
5555
  cumulativeSum: () => cumsum,
5529
5556
  deg2rad: () => deg2rad,
@@ -5559,7 +5586,6 @@ __export(numpy_exports, {
5559
5586
  fullLike: () => fullLike$1,
5560
5587
  greater: () => greater,
5561
5588
  greaterEqual: () => greaterEqual,
5562
- hamming: () => hamming,
5563
5589
  hann: () => hann,
5564
5590
  heaviside: () => heaviside,
5565
5591
  hstack: () => hstack,
@@ -5583,9 +5609,14 @@ __export(numpy_exports, {
5583
5609
  log10: () => log10,
5584
5610
  log1p: () => log1p,
5585
5611
  log2: () => log2,
5612
+ logicalAnd: () => logicalAnd,
5613
+ logicalNot: () => logicalNot,
5614
+ logicalOr: () => logicalOr,
5615
+ logicalXor: () => logicalXor,
5586
5616
  logspace: () => logspace,
5587
5617
  matmul: () => matmul,
5588
5618
  matrixTranspose: () => matrixTranspose,
5619
+ matvec: () => matvec,
5589
5620
  max: () => max,
5590
5621
  maximum: () => maximum,
5591
5622
  mean: () => mean,
@@ -5618,6 +5649,8 @@ __export(numpy_exports, {
5618
5649
  remainder: () => remainder,
5619
5650
  repeat: () => repeat,
5620
5651
  reshape: () => reshape,
5652
+ rint: () => rint,
5653
+ round: () => round,
5621
5654
  shape: () => shape,
5622
5655
  sign: () => sign,
5623
5656
  sin: () => sin,
@@ -5650,6 +5683,7 @@ __export(numpy_exports, {
5650
5683
  var_: () => var_,
5651
5684
  vdot: () => vdot,
5652
5685
  vecdot: () => vecdot,
5686
+ vecmat: () => vecmat,
5653
5687
  vstack: () => vstack,
5654
5688
  where: () => where,
5655
5689
  zeros: () => zeros,
@@ -5713,6 +5747,22 @@ const notEqual = notEqual$1;
5713
5747
  const greaterEqual = greaterEqual$1;
5714
5748
  /** @function Compare two arrays element-wise. */
5715
5749
  const lessEqual = lessEqual$1;
5750
+ /** Compute element-wise logical AND. */
5751
+ function logicalAnd(x, y) {
5752
+ return astype(x, DType.Bool).mul(astype(y, DType.Bool));
5753
+ }
5754
+ /** Compute element-wise logical OR. */
5755
+ function logicalOr(x, y) {
5756
+ return astype(x, DType.Bool).add(astype(y, DType.Bool));
5757
+ }
5758
+ /** Compute element-wise logical XOR. */
5759
+ function logicalXor(x, y) {
5760
+ return notEqual(astype(x, DType.Bool), astype(y, DType.Bool));
5761
+ }
5762
+ /** Compute element-wise logical NOT. */
5763
+ function logicalNot(x) {
5764
+ return notEqual(astype(x, DType.Bool), true);
5765
+ }
5716
5766
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
5717
5767
  const where = where$1;
5718
5768
  /**
@@ -5820,6 +5870,34 @@ function mean(a, axis = null, opts) {
5820
5870
  return fudgeArray(a).mean(axis, opts);
5821
5871
  }
5822
5872
/**
 * Compute the weighted average along the specified axis.
 *
 * Without `opts.weights` this is exactly `mean(a, axis, opts)`. Otherwise the
 * weights must either have the same shape as `a`, or be 1-D with length
 * `a.shape[axis]` when a single axis is reduced. The result is
 * `sum(a * w, axis) / sum(w, axis)`; incompatible weights dispose both inputs
 * and throw.
 */
function average(a, axis = null, opts) {
  a = fudgeArray(a);
  if (opts?.weights == null) return mean(a, axis, opts);
  let weights = fudgeArray(opts.weights);
  axis = normalizeAxis(axis, ndim(a));
  const wShape = weights.shape;
  const aShape = a.shape;
  const sameShape = deepEqual(wShape, aShape);
  const alongSingleAxis = !sameShape && axis.length === 1 && wShape.length === 1 && wShape[0] === aShape[axis[0]];
  if (!sameShape && !alongSingleAxis) {
    weights.dispose();
    a.dispose();
    throw new Error(`average: weights shape ${JSON.stringify(wShape)} is not compatible with array shape ${JSON.stringify(aShape)} and axis ${JSON.stringify(axis)}`);
  }
  if (alongSingleAxis) {
    // Give the 1-D weights size-1 dims everywhere except the reduced axis,
    // so they broadcast against `a`.
    const target = aShape.map((_, i) => (i === axis[0] ? wShape[0] : 1));
    weights = reshape(weights, target);
  }
  const denominator = sum(weights.ref, axis, opts);
  return sum(multiply(a, weights), axis, opts).div(denominator);
}
5900
+ /**
5823
5901
  * Returns the indices of the minimum values along an axis.
5824
5902
  *
5825
5903
  * By default, index is into the flatted array, otherwise it is along the
@@ -6223,20 +6301,63 @@ function take(a, indices, axis = null) {
6223
6301
  axis = checkAxis(axis, ndim(a));
6224
6302
  return gather(a, [indices], [axis], axis);
6225
6303
  }
6226
- /** Return if two arrays are element-wise equal within a tolerance. */
6304
+ /**
6305
+ * Return if two arrays are element-wise equal within a tolerance.
6306
+ *
6307
+ * The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
6308
+ * NaN values comparing equal if `equalNaN` is true.
6309
+ */
6227
6310
  function allclose(actual, expected, options) {
6228
- const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
6311
+ const { rtol = 1e-5, atol = 1e-7, equalNaN = false } = options ?? {};
6229
6312
  const x = array(actual);
6230
6313
  const y = array(expected);
6231
6314
  if (!deepEqual(x.shape, y.shape)) return false;
6232
6315
  const xData = x.dataSync();
6233
6316
  const yData = y.dataSync();
6234
6317
  for (let i = 0; i < xData.length; i++) {
6235
- if (isNaN(xData[i]) !== isNaN(yData[i])) return false;
6318
+ if (equalNaN ? isNaN(xData[i]) !== isNaN(yData[i]) : isNaN(xData[i]) || isNaN(yData[i])) return false;
6236
6319
  if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
6237
6320
  }
6238
6321
  return true;
6239
6322
  }
6323
/**
 * Check if two arrays are element-wise equal.
 *
 * Returns a scalar boolean array: `false` when the shapes differ (both inputs
 * are disposed), otherwise whether every element matches. With
 * `opts.equalNaN`, positions where both inputs are NaN count as equal.
 */
function arrayEqual(a1, a2, opts) {
  a1 = fudgeArray(a1);
  a2 = fudgeArray(a2);
  const shapesMatch = deepEqual(a1.shape, a2.shape);
  if (!shapesMatch) {
    a1.dispose();
    a2.dispose();
    return array(false);
  }
  if (!opts?.equalNaN) return equal(a1, a2).all();
  // Force "equal" wherever both sides are NaN, since NaN != NaN otherwise.
  const bothNaN = isnan(a1.ref).mul(isnan(a2.ref));
  return where(bothNaN, true, equal(a1, a2)).all();
}
6343
+ /**
6344
+ * Check if two arrays are element-wise equal after broadcasting.
6345
+ *
6346
+ * Unlike `arrayEqual`, this allows inputs with different but
6347
+ * broadcast-compatible shapes.
6348
+ */
6349
+ function arrayEquiv(a1, a2) {
6350
+ a1 = fudgeArray(a1);
6351
+ a2 = fudgeArray(a2);
6352
+ try {
6353
+ const [b1, b2] = broadcastArrays(a1, a2);
6354
+ return equal(b1, b2).all();
6355
+ } catch {
6356
+ a1.dispose();
6357
+ a2.dispose();
6358
+ return array(false);
6359
+ }
6360
+ }
6240
6361
  /** Matrix product of two arrays. */
6241
6362
  function matmul(x, y) {
6242
6363
  if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
@@ -6250,6 +6371,16 @@ function matmul(x, y) {
6250
6371
  rhsBatchDims: range(-2 - numBatchDims, -2)
6251
6372
  });
6252
6373
  }
6374
+ /** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
6375
+ function matvec(x1, x2) {
6376
+ if (ndim(x1) < 2 || ndim(x2) < 1) throw new Error("matvec: x1 must be at least 2D and x2 at least 1D");
6377
+ return einsum("...mn,...n->...m", x1, x2);
6378
+ }
6379
+ /** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
6380
+ function vecmat(x1, x2) {
6381
+ if (ndim(x1) < 1 || ndim(x2) < 2) throw new Error("vecmat: x1 must be at least 1D and x2 at least 2D");
6382
+ return einsum("...n,...nm->...m", x1, x2);
6383
+ }
6253
6384
  /** Dot product of two arrays. */
6254
6385
  function dot$1(x, y) {
6255
6386
  if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
@@ -6408,6 +6539,49 @@ function outer(x, y) {
6408
6539
  y = ravel(y);
6409
6540
  return multiply(x.reshape([x.shape[0], 1]), y);
6410
6541
  }
6542
/**
 * @function Compute the cross product of two arrays.
 *
 * Supports 2D (scalar result) and 3D cross products, with optional axis
 * arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
 * When one operand has 2 components and the other has 3, the 2-component
 * operand is treated as having a zero third component.
 */
const cross = jit$1(function cross$2(a, b, { axisa = -1, axisb = -1, axisc = -1, axis } = {}) {
  if (axis !== void 0) {
    axisa = axis;
    axisb = axis;
    axisc = axis;
  }
  axisa = checkAxis(axisa, ndim(a));
  axisb = checkAxis(axisb, ndim(b));
  // Move the vector components of each operand onto the last axis.
  a = moveaxis$1(a, axisa, -1);
  b = moveaxis$1(b, axisb, -1);
  const da = a.shape.at(-1);
  const db = b.shape.at(-1);
  if (da !== 2 && da !== 3 || db !== 2 && db !== 3) throw new Error(`cross: incompatible dimensions for cross product (got ${da} and ${db})`);
  if (da === 2 && db === 2) {
    // 2D x 2D: only the z-component a0*b1 - a1*b0 is nonzero, returned with
    // the component axis squeezed away (so `axisc` does not apply).
    const [ax, ay] = split$1(a, 2, -1);
    const [bx, by] = split$1(b, 2, -1);
    return squeeze(ax.mul(by).sub(ay.mul(bx)), -1);
  }
  // Pad a 2-component operand with a zero z-component. The zeros are created
  // in the operand's own dtype so the padding cannot change it (e.g. int32
  // inputs are not concatenated with default-dtype zeros).
  if (da === 2) {
    const zeroShape = [...a.shape.slice(0, -1), 1];
    a = concatenate([a, zeros(zeroShape, { dtype: a.dtype })], -1);
  }
  if (db === 2) {
    const zeroShape = [...b.shape.slice(0, -1), 1];
    b = concatenate([b, zeros(zeroShape, { dtype: b.dtype })], -1);
  }
  const [a0, a1, a2] = split$1(a, 3, -1);
  const [b0, b1, b2] = split$1(b, 3, -1);
  // Each component is used twice across c0..c2, hence the `.ref` on first use.
  const c0 = a1.ref.mul(b2.ref).sub(a2.ref.mul(b1.ref));
  const c1 = a2.mul(b0.ref).sub(a0.ref.mul(b2));
  const c2 = a0.mul(b1).sub(a1.mul(b0));
  return moveaxis$1(concatenate([
    c0,
    c1,
    c2
  ], -1), -1, axisc);
}, { staticArgnums: [2] });
6411
6585
  /** Vector dot product of two arrays along a given axis. */
6412
6586
  function vecdot(x, y, { axis } = {}) {
6413
6587
  const xaxis = checkAxis(axis ?? -1, ndim(x));
@@ -6504,16 +6678,15 @@ function sign(x) {
6504
6678
  x = fudgeArray(x);
6505
6679
  return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
6506
6680
  }
6507
- /** @function Return element-wise positive values of the input (no-op). */
6508
- const positive = fudgeArray;
6509
6681
  /**
6510
- * Return the Hamming window of size M, a taper with a weighted cosine bell.
6511
- *
6512
- * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
6682
+ * @function
6683
+ * Return the value with the magnitude of x and the sign of y, element-wise.
6513
6684
  */
6514
- function hamming(M) {
6515
- return cos(linspace(0, 2 * Math.PI, M)).mul(-.46).add(.54);
6516
- }
6685
const copysign = jit$1(function copysign$1(x, y) {
	// Scale |x| by +1 or -1, chosen from the sign of y. The previous form,
	// `absolute(x).mul(sign(y))`, returned 0 whenever y == 0 because
	// sign(0) == 0. NumPy/JAX return copysign(x, 0) == |x|.
	// NOTE(review): y == -0.0 is treated as positive here. NumPy inspects the
	// sign bit and would give -|x|.
	return absolute(x).mul(where(less(y, 0), -1, 1));
});
6688
/**
 * @function
 * Return the input unchanged, element-wise (NumPy's unary `+x`).
 *
 * This is a direct alias of `fudgeArray`. It performs no arithmetic on the
 * values; it only turns the argument into an array.
 */
const positive = fudgeArray;
6517
6690
  /**
6518
6691
  * Return the Hann window of size M, a taper with a weighted cosine bell.
6519
6692
  *
@@ -6659,6 +6832,27 @@ function trunc(x) {
6659
6832
  return idiv(x, 1);
6660
6833
  }
6661
6834
  /**
6835
+ * @function
6836
+ * Round to the given number of decimals.
6837
+ *
6838
+ * Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
6839
+ */
6840
+ const round = jit$1(function round$1(a, decimals = 0) {
6841
+ if (decimals === 0) return rint(a);
6842
+ const factor = 10 ** decimals;
6843
+ return rint(a.mul(factor)).mul(1 / factor);
6844
+ }, { staticArgnums: [1] });
6845
/**
 * @function
 * Round to the nearest integer, with ties going to the nearest even integer.
 *
 * The result is derived from `lo = floor(x)` and the fractional part
 * `x - lo`, not from `floor(x + 0.5)`. The latter is wrong whenever
 * `x + 0.5` itself rounds. For example, in float32 it maps 0.49999997 to 1
 * instead of 0, and it is off by one for odd integers at or above 2**23.
 */
const rint = jit$1(function rint$1(x) {
	const lo = floor(x.ref);
	// This subtraction is exact for x >= 0 and x <= -0.5. For x in (-0.5, 0) it
	// may round up, but only to 0.5 or 1. Both still select `hi` = 0, which is
	// the correct answer.
	const frac = x.sub(lo.ref);
	const hi = lo.ref.add(1);
	// Compare the remainder with 0 so the result does not depend on the sign
	// convention of `remainder` for negative inputs.
	const loIsEven = remainder(lo.ref, 2).equal(0);
	// On an exact tie, keep `lo` only if it is already even.
	const keepLoOnTie = frac.ref.equal(.5).mul(loIsEven);
	return where(less(frac, .5), lo.ref, where(keepLoOnTie, lo, hi));
});
6855
+ /**
6662
6856
  * Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
6663
6857
  *
6664
6858
  * This is the inverse of `frexp()`.
@@ -6986,6 +7180,7 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
6986
7180
  //#region src/library/lax.ts
6987
7181
  var lax_exports = {};
6988
7182
  __export(lax_exports, {
7183
+ bitcastConvertType: () => bitcastConvertType,
6989
7184
  conv: () => conv,
6990
7185
  convGeneralDilated: () => convGeneralDilated,
6991
7186
  convTranspose: () => convTranspose,
@@ -6999,6 +7194,10 @@ __export(lax_exports, {
6999
7194
  topK: () => topK
7000
7195
  });
7001
7196
  const JsArray = globalThis.Array;
7197
+ /** Elementwise bitcast an array into a new dtype. */
7198
+ function bitcastConvertType(x, newDtype) {
7199
+ return fudgeArray(x).view(newDtype);
7200
+ }
7002
7201
  /**
7003
7202
  * General dot product/contraction operator.
7004
7203
  *
@@ -7730,7 +7929,9 @@ function getK01(key$1) {
7730
7929
  function key(seed) {
7731
7930
  seed = array(seed, { dtype: DType.Uint32 });
7732
7931
  if (seed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seed.shape} - use jax.vmap for batching.`);
7733
- return stack([0, seed]);
7932
+ const key$1 = stack([0, seed]);
7933
+ if (key$1 instanceof Array$1) key$1._realizeSource();
7934
+ return key$1;
7734
7935
  }
7735
7936
  /** Splits a PRNG key into `num` new keys by adding a leading axis. */
7736
7937
  function split(key$1, num = 2) {
@@ -7925,6 +8126,11 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
7925
8126
 
7926
8127
  //#endregion
7927
8128
  //#region src/index.ts
8129
/**
 * Profiling entry points, exported together as `profiler`.
 *
 * @namespace
 */
const profiler = {
	startTrace: startTrace,
	stopTrace: stopTrace
};
7928
8134
  /**
7929
8135
  * @function
7930
8136
  * Compute the forward-mode Jacobian-vector product for a function.
@@ -8085,4 +8291,4 @@ async function devicePut(x, device) {
8085
8291
  }
8086
8292
 
8087
8293
  //#endregion
8088
- export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
8294
+ export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, profiler, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
@@ -1,4 +1,4 @@
1
- import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-BId79r5b.js";
1
+ import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-Ctqs8la1.js";
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-DpI0riom.cjs');
1
+ const require_backend = require('./backend-DMauYnfl.cjs');
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `