@jax-js/jax 0.1.9 → 0.1.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-DpI0riom.cjs');
33
+ const require_backend = require('./backend-DMauYnfl.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
@@ -838,6 +838,11 @@ var Tracer = class Tracer {
838
838
  if (this.dtype === dtype) return this;
839
839
  return cast(this, dtype);
840
840
  }
841
+ /** Return a bitwise cast of the array, viewed as a new dtype. */
842
+ view(dtype) {
843
+ if (!dtype || dtype === this.dtype) return this;
844
+ return bitcast(this, dtype);
845
+ }
841
846
  /** Subtract an array from this one. */
842
847
  sub(other) {
843
848
  return this.add(neg(other));
@@ -1659,7 +1664,7 @@ const abstractEvalRules = {
1659
1664
  return [new ShapedArray(x.shape, dtype, false)];
1660
1665
  },
1661
1666
  [Primitive.Bitcast]([x], { dtype }) {
1662
- if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
1667
+ if (x.dtype !== dtype && (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool)) throw new TypeError("Bitcast to/from bool is not allowed");
1663
1668
  if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
1664
1669
  return [new ShapedArray(x.shape, dtype, false)];
1665
1670
  },
@@ -3081,8 +3086,8 @@ var Array$1 = class Array$1 extends Tracer {
3081
3086
  return [x.#unary(require_backend.AluOp.Cast, dtype)];
3082
3087
  },
3083
3088
  [Primitive.Bitcast]([x], { dtype }) {
3084
- if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
3085
3089
  if (x.dtype === dtype) return [x];
3090
+ if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
3086
3091
  if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
3087
3092
  if (x.#source instanceof require_backend.AluExp) return [x.#unary(require_backend.AluOp.Bitcast, dtype)];
3088
3093
  else {
@@ -4179,6 +4184,7 @@ const jvpRules = {
4179
4184
  },
4180
4185
  [Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
4181
4186
  const x = triangularSolve$1(a.ref, b, { unitDiagonal });
4187
+ da = unitDiagonal ? triu(da, 1) : triu(da);
4182
4188
  const dax = batchMatmulT(da, x.ref);
4183
4189
  const rhsT = db.sub(mT(dax));
4184
4190
  const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
@@ -5254,6 +5260,7 @@ function ifft(a, axis = -1) {
5254
5260
  var numpy_linalg_exports = {};
5255
5261
  __export(numpy_linalg_exports, {
5256
5262
  cholesky: () => cholesky,
5263
+ cross: () => cross$1,
5257
5264
  det: () => det,
5258
5265
  diagonal: () => diagonal,
5259
5266
  inv: () => inv,
@@ -5284,6 +5291,19 @@ function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
5284
5291
  if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
5285
5292
  return cholesky$1(a, { upper });
5286
5293
  }
5294
/**
 * Compute the cross-product of two 3D vectors.
 *
 * This is a simpler and less flexible version of `jax.numpy.cross()`.
 * Both inputs must have size 3 along the specified axis; after validating
 * that, the work is delegated to the general `cross()` with the same axis.
 *
 * @param x1 First operand.
 * @param x2 Second operand.
 * @param axis Axis holding the vector components in both inputs (default -1).
 * @throws Error if either input does not have size 3 along `axis`.
 */
function cross$1(x1, x2, axis = -1) {
  const [ax1, ax2] = [x1, x2].map((operand) => require_backend.checkAxis(axis, ndim(operand)));
  const size1 = shape(x1)[ax1];
  if (size1 !== 3) throw new Error(`linalg.cross: x1 must have size 3 along axis ${axis}, got ${size1}`);
  const size2 = shape(x2)[ax2];
  if (size2 !== 3) throw new Error(`linalg.cross: x2 must have size 3 along axis ${axis}, got ${size2}`);
  return cross(x1, x2, { axis });
}
5287
5307
  /** Compute the determinant of a square matrix (batched). */
5288
5308
  function det(a) {
5289
5309
  a = fudgeArray(a);
@@ -5299,7 +5319,7 @@ function det(a) {
5299
5319
  function inv(a) {
5300
5320
  a = fudgeArray(a);
5301
5321
  const n = checkSquare("inv", a);
5302
- return solve(a, eye(n));
5322
+ return solve(a, eye(n, void 0, { dtype: a.dtype }));
5303
5323
  }
5304
5324
  /**
5305
5325
  * Return the least-squares solution to a linear equation.
@@ -5356,8 +5376,9 @@ function matrixPower(a, n) {
5356
5376
  a = fudgeArray(a);
5357
5377
  const m = checkSquare("matrixPower", a);
5358
5378
  if (n === 0) {
5379
+ const dtype = a.dtype;
5359
5380
  a.dispose();
5360
- return broadcastTo(eye(m), a.shape);
5381
+ return broadcastTo(eye(m, void 0, { dtype }), a.shape);
5361
5382
  }
5362
5383
  if (n < 0) {
5363
5384
  a = inv(a);
@@ -5539,13 +5560,17 @@ __export(numpy_exports, {
5539
5560
  argmax: () => argmax,
5540
5561
  argmin: () => argmin,
5541
5562
  argsort: () => argsort,
5563
+ around: () => round,
5542
5564
  array: () => array,
5565
+ arrayEqual: () => arrayEqual,
5566
+ arrayEquiv: () => arrayEquiv,
5543
5567
  asin: () => asin,
5544
5568
  asinh: () => arcsinh,
5545
5569
  astype: () => astype,
5546
5570
  atan: () => atan,
5547
5571
  atan2: () => atan2,
5548
5572
  atanh: () => arctanh,
5573
+ average: () => average,
5549
5574
  bool: () => bool,
5550
5575
  broadcastArrays: () => broadcastArrays,
5551
5576
  broadcastShapes: () => broadcastShapes,
@@ -5556,11 +5581,13 @@ __export(numpy_exports, {
5556
5581
  columnStack: () => columnStack,
5557
5582
  concatenate: () => concatenate,
5558
5583
  convolve: () => convolve,
5584
+ copysign: () => copysign,
5559
5585
  corrcoef: () => corrcoef,
5560
5586
  correlate: () => correlate,
5561
5587
  cos: () => cos,
5562
5588
  cosh: () => cosh,
5563
5589
  cov: () => cov,
5590
+ cross: () => cross,
5564
5591
  cumsum: () => cumsum,
5565
5592
  cumulativeSum: () => cumsum,
5566
5593
  deg2rad: () => deg2rad,
@@ -5596,7 +5623,6 @@ __export(numpy_exports, {
5596
5623
  fullLike: () => fullLike$1,
5597
5624
  greater: () => greater,
5598
5625
  greaterEqual: () => greaterEqual,
5599
- hamming: () => hamming,
5600
5626
  hann: () => hann,
5601
5627
  heaviside: () => heaviside,
5602
5628
  hstack: () => hstack,
@@ -5620,9 +5646,14 @@ __export(numpy_exports, {
5620
5646
  log10: () => log10,
5621
5647
  log1p: () => log1p,
5622
5648
  log2: () => log2,
5649
+ logicalAnd: () => logicalAnd,
5650
+ logicalNot: () => logicalNot,
5651
+ logicalOr: () => logicalOr,
5652
+ logicalXor: () => logicalXor,
5623
5653
  logspace: () => logspace,
5624
5654
  matmul: () => matmul,
5625
5655
  matrixTranspose: () => matrixTranspose,
5656
+ matvec: () => matvec,
5626
5657
  max: () => max,
5627
5658
  maximum: () => maximum,
5628
5659
  mean: () => mean,
@@ -5655,6 +5686,8 @@ __export(numpy_exports, {
5655
5686
  remainder: () => remainder,
5656
5687
  repeat: () => repeat,
5657
5688
  reshape: () => reshape,
5689
+ rint: () => rint,
5690
+ round: () => round,
5658
5691
  shape: () => shape,
5659
5692
  sign: () => sign,
5660
5693
  sin: () => sin,
@@ -5687,6 +5720,7 @@ __export(numpy_exports, {
5687
5720
  var_: () => var_,
5688
5721
  vdot: () => vdot,
5689
5722
  vecdot: () => vecdot,
5723
+ vecmat: () => vecmat,
5690
5724
  vstack: () => vstack,
5691
5725
  where: () => where,
5692
5726
  zeros: () => zeros,
@@ -5750,6 +5784,22 @@ const notEqual = notEqual$1;
5750
5784
  const greaterEqual = greaterEqual$1;
5751
5785
  /** @function Compare two arrays element-wise. */
5752
5786
  const lessEqual = lessEqual$1;
5787
/**
 * Compute element-wise logical AND.
 *
 * Both operands are cast to bool and multiplied, relying on bool
 * multiplication acting as AND.
 */
function logicalAnd(x, y) {
  const { Bool } = require_backend.DType;
  const lhs = astype(x, Bool);
  const rhs = astype(y, Bool);
  return lhs.mul(rhs);
}
5791
/**
 * Compute element-wise logical OR.
 *
 * Both operands are cast to bool and combined with `add`; this assumes bool
 * addition behaves as OR (saturating at true) in this library.
 */
function logicalOr(x, y) {
  const { Bool } = require_backend.DType;
  const lhs = astype(x, Bool);
  const rhs = astype(y, Bool);
  return lhs.add(rhs);
}
5795
/**
 * Compute element-wise logical XOR.
 *
 * True exactly where the bool-cast operands differ.
 */
function logicalXor(x, y) {
  const toBool = (v) => astype(v, require_backend.DType.Bool);
  return notEqual(toBool(x), toBool(y));
}
5799
/**
 * Compute element-wise logical NOT.
 *
 * The input is cast to bool and compared against `true`, so true becomes
 * false and false becomes true.
 */
function logicalNot(x) {
  const asBool = astype(x, require_backend.DType.Bool);
  return notEqual(asBool, true);
}
5753
5803
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
5754
5804
  const where = where$1;
5755
5805
  /**
@@ -5857,6 +5907,34 @@ function mean(a, axis = null, opts) {
5857
5907
  return fudgeArray(a).mean(axis, opts);
5858
5908
  }
5859
5909
  /**
5910
+ * Compute the weighted average along the specified axis.
5911
+ *
5912
+ * If no axis is specified, mean is computed along all the axes. The weights
5913
+ * should have shape matching that of `a`, or if an axis is specified, it should
5914
+ * match the shape along those axes.
5915
+ */
5916
+ function average(a, axis = null, opts) {
5917
+ a = fudgeArray(a);
5918
+ if (opts?.weights == null) return mean(a, axis, opts);
5919
+ const weights = fudgeArray(opts.weights);
5920
+ axis = require_backend.normalizeAxis(axis, ndim(a));
5921
+ const wShape = weights.shape;
5922
+ const aShape = a.shape;
5923
+ if (require_backend.deepEqual(wShape, aShape)) {
5924
+ const scl = sum(weights.ref, axis, opts);
5925
+ return sum(multiply(a, weights), axis, opts).div(scl);
5926
+ } else if (axis.length === 1 && wShape.length === 1 && wShape[0] === aShape[axis[0]]) {
5927
+ const broadcastShape = aShape.map((_, i) => i === axis[0] ? wShape[0] : 1);
5928
+ const wReshaped = reshape(weights, broadcastShape);
5929
+ const scl = sum(wReshaped.ref, axis, opts);
5930
+ return sum(multiply(a, wReshaped), axis, opts).div(scl);
5931
+ } else {
5932
+ weights.dispose();
5933
+ a.dispose();
5934
+ throw new Error(`average: weights shape ${JSON.stringify(wShape)} is not compatible with array shape ${JSON.stringify(aShape)} and axis ${JSON.stringify(axis)}`);
5935
+ }
5936
+ }
5937
+ /**
5860
5938
  * Returns the indices of the minimum values along an axis.
5861
5939
  *
5862
5940
  * By default, index is into the flatted array, otherwise it is along the
@@ -6260,20 +6338,63 @@ function take(a, indices, axis = null) {
6260
6338
  axis = require_backend.checkAxis(axis, ndim(a));
6261
6339
  return gather(a, [indices], [axis], axis);
6262
6340
  }
6263
/**
 * Return if two arrays are element-wise equal within a tolerance.
 *
 * For finite values the test is `|actual - expected| <= atol + rtol * |expected|`.
 * Infinite values are only close to an infinity of the same sign. NaN is never
 * close to anything unless `equalNaN` is true, in which case two NaNs in the
 * same position compare equal. Arrays of different shapes are never close.
 *
 * @param actual Array or value to check.
 * @param expected Reference array or value; `rtol` is relative to it.
 * @param options `rtol` (default 1e-5), `atol` (default 1e-7), `equalNaN` (default false).
 * @returns A plain boolean, computed synchronously on the host.
 */
function allclose(actual, expected, options) {
  const { rtol = 1e-5, atol = 1e-7, equalNaN = false } = options ?? {};
  const x = array(actual);
  const y = array(expected);
  if (!require_backend.deepEqual(x.shape, y.shape)) return false;
  const xData = x.dataSync();
  const yData = y.dataSync();
  for (let i = 0; i < xData.length; i++) {
    const a = xData[i];
    const b = yData[i];
    const aNaN = isNaN(a);
    const bNaN = isNaN(b);
    if (aNaN || bNaN) {
      if (equalNaN && aNaN && bNaN) continue;
      return false;
    }
    // With an infinity involved the tolerance formula degenerates to `inf > inf`
    // (false), so compare exactly instead.
    if (!Number.isFinite(a) || !Number.isFinite(b)) {
      if (a !== b) return false;
      continue;
    }
    if (Math.abs(a - b) > atol + rtol * Math.abs(b)) return false;
  }
  return true;
}
6360
/**
 * Check if two arrays are element-wise equal.
 *
 * Returns an array holding false if the shapes differ; otherwise the result
 * of `.all()` over the element-wise equality. With `opts.equalNaN`, positions
 * where both inputs are NaN count as equal.
 */
function arrayEqual(a1, a2, opts) {
  const lhs = fudgeArray(a1);
  const rhs = fudgeArray(a2);
  if (require_backend.deepEqual(lhs.shape, rhs.shape)) {
    if (!opts?.equalNaN) return equal(lhs, rhs).all();
    // Positions where both sides are NaN are forced to "equal".
    const bothNaN = isnan(lhs.ref).mul(isnan(rhs.ref));
    return where(bothNaN, true, equal(lhs, rhs)).all();
  }
  // Shapes differ: release both inputs and report inequality.
  lhs.dispose();
  rhs.dispose();
  return array(false);
}
6380
/**
 * Check if two arrays are element-wise equal after broadcasting.
 *
 * Unlike `arrayEqual`, this allows inputs with different but
 * broadcast-compatible shapes. Returns an array holding false when the shapes
 * cannot be broadcast together.
 */
function arrayEquiv(a1, a2) {
  a1 = fudgeArray(a1);
  a2 = fudgeArray(a2);
  let b1;
  let b2;
  try {
    [b1, b2] = broadcastArrays(a1, a2);
  } catch {
    // Only a broadcasting failure means "not equivalent". Like the previous
    // handler, this assumes broadcastArrays throws before taking ownership of
    // its inputs, so they are released here.
    a1.dispose();
    a2.dispose();
    return array(false);
  }
  // Errors from the comparison itself now propagate instead of being reported
  // as `false` (which also disposed inputs broadcasting had already consumed).
  return equal(b1, b2).all();
}
6277
6398
  /** Matrix product of two arrays. */
6278
6399
  function matmul(x, y) {
6279
6400
  if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
@@ -6287,6 +6408,16 @@ function matmul(x, y) {
6287
6408
  rhsBatchDims: require_backend.range(-2 - numBatchDims, -2)
6288
6409
  });
6289
6410
  }
6411
/** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
function matvec(x1, x2) {
  const validRanks = ndim(x1) >= 2 && ndim(x2) >= 1;
  if (!validRanks) throw new Error("matvec: x1 must be at least 2D and x2 at least 1D");
  // Contract over the shared N axis; "..." carries any leading dimensions.
  return einsum("...mn,...n->...m", x1, x2);
}
6416
/** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
function vecmat(x1, x2) {
  const validRanks = ndim(x1) >= 1 && ndim(x2) >= 2;
  if (!validRanks) throw new Error("vecmat: x1 must be at least 1D and x2 at least 2D");
  // Contract over the shared N axis; "..." carries any leading dimensions.
  return einsum("...n,...nm->...m", x1, x2);
}
6290
6421
  /** Dot product of two arrays. */
6291
6422
  function dot$1(x, y) {
6292
6423
  if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
@@ -6445,6 +6576,49 @@ function outer(x, y) {
6445
6576
  y = ravel(y);
6446
6577
  return multiply(x.reshape([x.shape[0], 1]), y);
6447
6578
  }
6579
/**
 * @function Compute the cross product of two arrays.
 *
 * Supports 2D (scalar result) and 3D cross products, with optional axis
 * arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
 *
 * The vector components are first moved to the last axis of each operand.
 * When one operand has 2 components and the other 3, the 2-vector is padded
 * with a zero third component created in that operand's own dtype, so the
 * padding does not mix the default dtype into `concatenate`. For two
 * 2-vectors only the z-component is returned and `axisc` is unused.
 */
const cross = jit$1(function cross$2(a, b, { axisa = -1, axisb = -1, axisc = -1, axis } = {}) {
  if (axis !== void 0) {
    axisa = axis;
    axisb = axis;
    axisc = axis;
  }
  axisa = require_backend.checkAxis(axisa, ndim(a));
  axisb = require_backend.checkAxis(axisb, ndim(b));
  a = moveaxis$1(a, axisa, -1);
  b = moveaxis$1(b, axisb, -1);
  const da = a.shape.at(-1);
  const db = b.shape.at(-1);
  if (da !== 2 && da !== 3 || db !== 2 && db !== 3) throw new Error(`cross: incompatible dimensions for cross product (got ${da} and ${db})`);
  if (da === 2 && db === 2) {
    // z-component of (a0, a1, 0) x (b0, b1, 0).
    const [a0$1, a1$1] = split$1(a, 2, -1);
    const [b0$1, b1$1] = split$1(b, 2, -1);
    return squeeze(a0$1.mul(b1$1).sub(a1$1.mul(b0$1)), -1);
  }
  if (da === 2) {
    const zeroShape = [...a.shape.slice(0, -1), 1];
    a = concatenate([a, zeros(zeroShape, { dtype: a.dtype })], -1);
  }
  if (db === 2) {
    const zeroShape = [...b.shape.slice(0, -1), 1];
    b = concatenate([b, zeros(zeroShape, { dtype: b.dtype })], -1);
  }
  // c = a x b component-wise; each split piece is used twice, so the first
  // use takes an extra `.ref` to keep it alive for the second.
  const [a0, a1, a2] = split$1(a, 3, -1);
  const [b0, b1, b2] = split$1(b, 3, -1);
  const c0 = a1.ref.mul(b2.ref).sub(a2.ref.mul(b1.ref));
  const c1 = a2.mul(b0.ref).sub(a0.ref.mul(b2));
  const c2 = a0.mul(b1).sub(a1.mul(b0));
  return moveaxis$1(concatenate([
    c0,
    c1,
    c2
  ], -1), -1, axisc);
}, { staticArgnums: [2] });
6448
6622
  /** Vector dot product of two arrays along a given axis. */
6449
6623
  function vecdot(x, y, { axis } = {}) {
6450
6624
  const xaxis = require_backend.checkAxis(axis ?? -1, ndim(x));
@@ -6541,16 +6715,15 @@ function sign(x) {
6541
6715
  x = fudgeArray(x);
6542
6716
  return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
6543
6717
  }
6544
/**
 * @function
 * Return the value with the magnitude of x and the sign of y, element-wise.
 *
 * The result is `-|x|` where `y < 0` and `|x|` everywhere else, so a zero or
 * NaN `y` yields `|x|` rather than zero. A negative-zero `y` is not
 * distinguished from positive zero and also yields `|x|`.
 */
const copysign = jit$1(function copysign$1(x, y) {
  // Use a ±1 factor instead of sign(y): sign(0) is 0, which would wipe out
  // the magnitude wherever y is zero.
  return absolute(x).mul(where(less(y, 0), -1, 1));
});
/** @function Return element-wise positive values of the input (no-op). */
const positive = fudgeArray;
6554
6727
  /**
6555
6728
  * Return the Hann window of size M, a taper with a weighted cosine bell.
6556
6729
  *
@@ -6696,6 +6869,27 @@ function trunc(x) {
6696
6869
  return idiv(x, 1);
6697
6870
  }
6698
6871
  /**
6872
+ * @function
6873
+ * Round to the given number of decimals.
6874
+ *
6875
+ * Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
6876
+ */
6877
+ const round = jit$1(function round$1(a, decimals = 0) {
6878
+ if (decimals === 0) return rint(a);
6879
+ const factor = 10 ** decimals;
6880
+ return rint(a.mul(factor)).mul(1 / factor);
6881
+ }, { staticArgnums: [1] });
6882
/**
 * @function
 * Round to the nearest integer, with ties going to the nearest even integer.
 *
 * Works from `lo = floor(x)` and the fractional part `x - lo` instead of
 * `floor(x + 0.5)`. The addition can round by itself: large odd values get
 * pushed to the next even integer, and the largest value below 0.5 becomes 1.
 * That gave wrong results. NaN and infinities pass through unchanged.
 */
const rint = jit$1(function rint$1(x) {
  const lo = floor(x.ref);
  const frac = x.sub(lo.ref);
  const hi = lo.ref.add(1);
  // On an exact tie, pick whichever neighbour is even.
  const loIsOdd = remainder(lo.ref, 2).notEqual(0);
  const tieChoice = where(loIsOdd, hi.ref, lo.ref);
  // Below one half -> lo; above one half -> hi; exactly one half -> tieChoice.
  return where(less(frac.ref, .5), lo, where(greater(frac, .5), hi, tieChoice));
});
6892
+ /**
6699
6893
  * Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
6700
6894
  *
6701
6895
  * This is the inverse of `frexp()`.
@@ -7023,6 +7217,7 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
7023
7217
  //#region src/library/lax.ts
7024
7218
  var lax_exports = {};
7025
7219
  __export(lax_exports, {
7220
+ bitcastConvertType: () => bitcastConvertType,
7026
7221
  conv: () => conv,
7027
7222
  convGeneralDilated: () => convGeneralDilated,
7028
7223
  convTranspose: () => convTranspose,
@@ -7036,6 +7231,10 @@ __export(lax_exports, {
7036
7231
  topK: () => topK
7037
7232
  });
7038
7233
  const JsArray = globalThis.Array;
7234
/**
 * Elementwise bitcast an array into a new dtype.
 *
 * Converts `x` to an array and returns `.view(newDtype)` of it.
 */
function bitcastConvertType(x, newDtype) {
  const arr = fudgeArray(x);
  return arr.view(newDtype);
}
7039
7238
  /**
7040
7239
  * General dot product/contraction operator.
7041
7240
  *
@@ -7767,7 +7966,9 @@ function getK01(key$1) {
7767
7966
  function key(seed) {
7768
7967
  seed = array(seed, { dtype: require_backend.DType.Uint32 });
7769
7968
  if (seed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seed.shape} - use jax.vmap for batching.`);
7770
- return stack([0, seed]);
7969
+ const key$1 = stack([0, seed]);
7970
+ if (key$1 instanceof Array$1) key$1._realizeSource();
7971
+ return key$1;
7771
7972
  }
7772
7973
  /** Splits a PRNG key into `num` new keys by adding a leading axis. */
7773
7974
  function split(key$1, num = 2) {
@@ -7962,6 +8163,11 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
7962
8163
 
7963
8164
  //#endregion
7964
8165
  //#region src/index.ts
8166
/**
 * @namespace
 * Kernel-trace profiling controls, forwarded directly from the backend module.
 */
const profiler = {
  /** Start collecting kernel traces. */
  startTrace: require_backend.startTrace,
  /** Stop collecting kernel traces. */
  stopTrace: require_backend.stopTrace,
};
7965
8171
  /**
7966
8172
  * @function
7967
8173
  * Compute the forward-mode Jacobian-vector product for a function.
@@ -8158,6 +8364,7 @@ Object.defineProperty(exports, 'numpy', {
8158
8364
  return numpy_exports;
8159
8365
  }
8160
8366
  });
8367
+ exports.profiler = profiler;
8161
8368
  Object.defineProperty(exports, 'random', {
8162
8369
  enumerable: true,
8163
8370
  get: function () {
package/dist/index.d.cts CHANGED
@@ -1004,6 +1004,8 @@ declare abstract class Tracer {
1004
1004
  reshape(shape: number | number[]): this;
1005
1005
  /** Copy the array and cast to a specified dtype. */
1006
1006
  astype(dtype: DType): this;
1007
+ /** Return a bitwise cast of the array, viewed as a new dtype. */
1008
+ view(dtype?: DType): this;
1007
1009
  /** Subtract an array from this one. */
1008
1010
  sub(other: this | TracerValue): this;
1009
1011
  /** Divide an array by this one. */
@@ -1427,8 +1429,10 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
1427
1429
  unitDiagonal?: boolean;
1428
1430
  }): Array;
1429
1431
  declare namespace lax_d_exports {
1430
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
1432
+ export { DotDimensionNumbers, PaddingType, bitcastConvertType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
1431
1433
  }
1434
+ /** Elementwise bitcast an array into a new dtype. */
1435
+ declare function bitcastConvertType(x: ArrayLike, newDtype: DType): Array;
1432
1436
  /**
1433
1437
  * Dimension numbers for general `dot()` primitive.
1434
1438
  *
@@ -1567,7 +1571,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
1567
1571
  */
1568
1572
  declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
1569
1573
  declare namespace numpy_linalg_d_exports {
1570
- export { cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
1574
+ export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
1571
1575
  }
1572
1576
  /**
1573
1577
  * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
@@ -1582,6 +1586,13 @@ declare function cholesky(a: ArrayLike, {
1582
1586
  upper?: boolean;
1583
1587
  symmetrizeInput?: boolean;
1584
1588
  }): Array;
1589
+ /**
1590
+ * Compute the cross-product of two 3D vectors.
1591
+ *
1592
+ * This is a simpler and less flexible version of `jax.numpy.cross()`.
1593
+ * Both inputs must have size 3 along the specified axis.
1594
+ */
1595
+ declare function cross$1(x1: ArrayLike, x2: ArrayLike, axis?: number): Array;
1585
1596
  /** Compute the determinant of a square matrix (batched). */
1586
1597
  declare function det(a: ArrayLike): Array;
1587
1598
  /** Compute the inverse of a square matrix (batched). */
@@ -1668,7 +1679,7 @@ type IInfo = Readonly<{
1668
1679
  /** Machine limits for integer types. */
1669
1680
  declare function iinfo(dtype: DType): IInfo;
1670
1681
  declare namespace numpy_d_exports {
1671
- export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logspace, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
1682
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
1672
1683
  }
1673
1684
  declare const float32 = DType.Float32;
1674
1685
  declare const int32 = DType.Int32;
@@ -1728,6 +1739,14 @@ declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
1728
1739
  declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
1729
1740
  /** @function Compare two arrays element-wise. */
1730
1741
  declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
1742
+ /** Compute element-wise logical AND. */
1743
+ declare function logicalAnd(x: ArrayLike, y: ArrayLike): Array;
1744
+ /** Compute element-wise logical OR. */
1745
+ declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
1746
+ /** Compute element-wise logical XOR. */
1747
+ declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
1748
+ /** Compute element-wise logical NOT. */
1749
+ declare function logicalNot(x: ArrayLike): Array;
1731
1750
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1732
1751
  declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
1733
1752
  /**
@@ -1812,6 +1831,16 @@ declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1812
1831
  declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1813
1832
  /** Compute the average of the array elements along the specified axis. */
1814
1833
  declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
1834
/**
 * Compute the weighted average along the specified axis.
 *
 * If no axis is specified, the average is taken over all axes. Without
 * `opts.weights` this is the same as `mean()`. With weights it computes
 * `sum(a * weights) / sum(weights)` along `axis`. The weights must either
 * have the same shape as `a`, or be 1-D with length `a.shape[axis]` when
 * exactly one axis is given. Any other weights shape throws an error.
 */
declare function average(a: ArrayLike, axis?: Axis, opts?: {
  weights?: ArrayLike;
} & ReduceOpts): Array;
1815
1844
  /**
1816
1845
  * Returns the indices of the minimum values along an axis.
1817
1846
  *
@@ -1983,13 +2012,39 @@ declare function argsort(a: ArrayLike, axis?: number): Array;
1983
2012
  * numbered axis. By default, the flattened array is used.
1984
2013
  */
1985
2014
  declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
1986
- /** Return if two arrays are element-wise equal within a tolerance. */
2015
+ /**
2016
+ * Return if two arrays are element-wise equal within a tolerance.
2017
+ *
2018
+ * The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
2019
+ * NaN values comparing equal if `equalNaN` is true.
2020
+ */
1987
2021
  declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
1988
2022
  rtol?: number;
1989
2023
  atol?: number;
2024
+ equalNaN?: boolean;
1990
2025
  }): boolean;
2026
+ /**
2027
+ * Check if two arrays are element-wise equal.
2028
+ *
2029
+ * Returns False if the arrays have different shapes. If `equalNaN` is True,
2030
+ * NaNs in the same position are considered equal.
2031
+ */
2032
+ declare function arrayEqual(a1: ArrayLike, a2: ArrayLike, opts?: {
2033
+ equalNaN?: boolean;
2034
+ }): Array;
2035
+ /**
2036
+ * Check if two arrays are element-wise equal after broadcasting.
2037
+ *
2038
+ * Unlike `arrayEqual`, this allows inputs with different but
2039
+ * broadcast-compatible shapes.
2040
+ */
2041
+ declare function arrayEquiv(a1: ArrayLike, a2: ArrayLike): Array;
1991
2042
  /** Matrix product of two arrays. */
1992
2043
  declare function matmul(x: ArrayLike, y: ArrayLike): Array;
2044
+ /** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
2045
+ declare function matvec(x1: ArrayLike, x2: ArrayLike): Array;
2046
+ /** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
2047
+ declare function vecmat(x1: ArrayLike, x2: ArrayLike): Array;
1993
2048
  /** Dot product of two arrays. */
1994
2049
  declare function dot(x: ArrayLike, y: ArrayLike): Array;
1995
2050
  /**
@@ -2042,6 +2097,18 @@ declare function inner(x: ArrayLike, y: ArrayLike): Array;
2042
2097
  * be of shape `[x.size, y.size]`.
2043
2098
  */
2044
2099
  declare function outer(x: ArrayLike, y: ArrayLike): Array;
2100
+ /**
2101
+ * @function Compute the cross product of two arrays.
2102
+ *
2103
+ * Supports 2D (scalar result) and 3D cross products, with optional axis
2104
+ * arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
2105
+ */
2106
+ declare const cross: OwnedFunction<(a: ArrayLike, b: ArrayLike, args_2?: {
2107
+ axisa?: number | undefined;
2108
+ axisb?: number | undefined;
2109
+ axisc?: number | undefined;
2110
+ axis?: number | undefined;
2111
+ } | undefined) => Array>;
2045
2112
  /** Vector dot product of two arrays along a given axis. */
2046
2113
  declare function vecdot(x: ArrayLike, y: ArrayLike, {
2047
2114
  axis
@@ -2087,14 +2154,13 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
2087
2154
  declare function absolute(x: ArrayLike): Array;
2088
2155
  /** Return an element-wise indication of sign of the input. */
2089
2156
  declare function sign(x: ArrayLike): Array;
2090
- /** @function Return element-wise positive values of the input (no-op). */
2091
- declare const positive: (x: ArrayLike) => Array;
2092
2157
  /**
2093
- * Return the Hamming window of size M, a taper with a weighted cosine bell.
2094
- *
2095
- * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
2158
+ * @function
2159
+ * Return the value with the magnitude of x and the sign of y, element-wise.
2096
2160
  */
2097
- declare function hamming(M: number): Array;
2161
+ declare const copysign: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
2162
+ /** @function Return element-wise positive values of the input (no-op). */
2163
+ declare const positive: (x: ArrayLike) => Array;
2098
2164
  /**
2099
2165
  * Return the Hann window of size M, a taper with a weighted cosine bell.
2100
2166
  *
@@ -2189,6 +2255,18 @@ declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
2189
2255
  declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
2190
2256
  /** Round input to the nearest integer towards zero. */
2191
2257
  declare function trunc(x: ArrayLike): Array;
2258
+ /**
2259
+ * @function
2260
+ * Round to the given number of decimals.
2261
+ *
2262
+ * Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
2263
+ */
2264
+ declare const round: OwnedFunction<(a: ArrayLike, decimals?: number | undefined) => Array>;
2265
+ /**
2266
+ * @function
2267
+ * Round to the nearest integer, with ties going to the nearest even integer.
2268
+ */
2269
+ declare const rint: OwnedFunction<(x: ArrayLike) => Array>;
2192
2270
  /**
2193
2271
  * Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
2194
2272
  *
@@ -2691,8 +2769,31 @@ declare namespace scipy_special_d_exports {
2691
2769
  * The logit function, `logit(p) = log(p / (1-p))`.
2692
2770
  */
2693
2771
  declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
2772
+ //#endregion
2773
+ //#region src/tracing.d.ts
2774
/**
 * Start collecting kernel traces.
 *
 * Traces appear in developer tools under the "Performance" tab, and they are
 * useful for measuring fine-grained kernel execution time.
 */
declare function startTrace(): void;
/**
 * Stop collecting kernel traces.
 *
 * Traces appear in developer tools under the "Performance" tab, and they are
 * useful for measuring fine-grained kernel execution time.
 */
declare function stopTrace(): void;
2694
2790
  //#endregion
2695
2791
  //#region src/index.d.ts
2792
+ /** @namespace */
2793
+ declare const profiler: {
2794
+ startTrace: typeof startTrace;
2795
+ stopTrace: typeof stopTrace;
2796
+ };
2696
2797
  /**
2697
2798
  * @function
2698
2799
  * Compute the forward-mode Jacobian-vector product for a function.
@@ -2857,4 +2958,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
2857
2958
  */
2858
2959
  declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
2859
2960
  //#endregion
2860
- export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
2961
+ export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };