@jax-js/jax 0.1.8 → 0.1.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-B3foXiV_.cjs');
33
+ const require_backend = require('./backend-DMauYnfl.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
@@ -838,6 +838,11 @@ var Tracer = class Tracer {
838
838
  if (this.dtype === dtype) return this;
839
839
  return cast(this, dtype);
840
840
  }
841
+ /** Return a bitwise cast of the array, viewed as a new dtype. */
842
+ view(dtype) {
843
+ if (!dtype || dtype === this.dtype) return this;
844
+ return bitcast(this, dtype);
845
+ }
841
846
  /** Subtract an array from this one. */
842
847
  sub(other) {
843
848
  return this.add(neg(other));
@@ -920,18 +925,25 @@ var Tracer = class Tracer {
920
925
  return sort$1(this.transpose(perm)).transpose(require_backend.invertPermutation(perm));
921
926
  }
922
927
  /**
923
- * Return the indices that would sort an array. This may not be a stable
924
- * sorting algorithm; it need not preserve order of indices in ties.
928
+ * Return the indices that would sort an array. Unlike `sort`, this is
929
+ * guaranteed to be a stable sorting algorithm; it always returns the smaller
930
+ * index first in event of ties.
925
931
  *
926
932
  * See `jax.numpy.argsort` for full docs.
927
933
  */
928
934
  argsort(axis = -1) {
929
935
  axis = require_backend.checkAxis(axis, this.ndim);
930
- if (axis === this.ndim - 1) return argsort$1(this)[1];
936
+ if (axis === this.ndim - 1) {
937
+ const [y$1, yi$1] = argsort$1(this);
938
+ y$1.dispose();
939
+ return yi$1;
940
+ }
931
941
  const perm = require_backend.range(this.ndim);
932
942
  perm.splice(axis, 1);
933
943
  perm.push(axis);
934
- return argsort$1(this.transpose(perm))[1].transpose(require_backend.invertPermutation(perm));
944
+ const [y, yi] = argsort$1(this.transpose(perm));
945
+ y.dispose();
946
+ return yi.transpose(require_backend.invertPermutation(perm));
935
947
  }
936
948
  /**
937
949
  * Slice an array along one or more axes.
@@ -1652,7 +1664,7 @@ const abstractEvalRules = {
1652
1664
  return [new ShapedArray(x.shape, dtype, false)];
1653
1665
  },
1654
1666
  [Primitive.Bitcast]([x], { dtype }) {
1655
- if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
1667
+ if (x.dtype !== dtype && (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool)) throw new TypeError("Bitcast to/from bool is not allowed");
1656
1668
  if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
1657
1669
  return [new ShapedArray(x.shape, dtype, false)];
1658
1670
  },
@@ -3074,8 +3086,8 @@ var Array$1 = class Array$1 extends Tracer {
3074
3086
  return [x.#unary(require_backend.AluOp.Cast, dtype)];
3075
3087
  },
3076
3088
  [Primitive.Bitcast]([x], { dtype }) {
3077
- if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
3078
3089
  if (x.dtype === dtype) return [x];
3090
+ if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
3079
3091
  if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
3080
3092
  if (x.#source instanceof require_backend.AluExp) return [x.#unary(require_backend.AluOp.Bitcast, dtype)];
3081
3093
  else {
@@ -3416,32 +3428,26 @@ function fullInternal(aval, fillValue, device) {
3416
3428
  committed: device != void 0
3417
3429
  });
3418
3430
  }
3419
- function zerosLike$1(val, dtype) {
3420
- return fullLike(val, 0, dtype);
3431
+ function zerosLike$1(val, opts) {
3432
+ return fullLike(val, 0, opts);
3421
3433
  }
3422
- function onesLike$1(val, dtype) {
3423
- return fullLike(val, 1, dtype);
3434
+ function onesLike$1(val, opts) {
3435
+ return fullLike(val, 1, opts);
3424
3436
  }
3425
- function fullLike(val, fillValue, dtype) {
3437
+ function fullLike(val, fillValue, { dtype, shape: shape$1, device } = {}) {
3426
3438
  const aval = getAval(val);
3427
3439
  if (val instanceof Tracer) val.dispose();
3428
3440
  if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
3429
- const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
3430
- return fullInternal(sa, fillValue);
3441
+ const sa = new ShapedArray(shape$1 ?? aval.shape, dtype ?? aval.dtype, aval.weakType && dtype === void 0);
3442
+ return fullInternal(sa, fillValue, device);
3431
3443
  }
3432
3444
  /** Return a new array of given shape and type, filled with zeros. */
3433
- function zeros(shape$1, { dtype, device } = {}) {
3434
- return full(shape$1, 0, {
3435
- dtype,
3436
- device
3437
- });
3445
+ function zeros(shape$1, opts) {
3446
+ return full(shape$1, 0, opts);
3438
3447
  }
3439
3448
  /** Return a new array of given shape and type, filled with ones. */
3440
- function ones(shape$1, { dtype, device } = {}) {
3441
- return full(shape$1, 1, {
3442
- dtype,
3443
- device
3444
- });
3449
+ function ones(shape$1, opts) {
3450
+ return full(shape$1, 1, opts);
3445
3451
  }
3446
3452
  /** Return a new array of given shape and type, filled with `fill_value`. */
3447
3453
  function full(shape$1, fillValue, { dtype, device } = {}) {
@@ -4178,6 +4184,7 @@ const jvpRules = {
4178
4184
  },
4179
4185
  [Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
4180
4186
  const x = triangularSolve$1(a.ref, b, { unitDiagonal });
4187
+ da = unitDiagonal ? triu(da, 1) : triu(da);
4181
4188
  const dax = batchMatmulT(da, x.ref);
4182
4189
  const rhsT = db.sub(mT(dax));
4183
4190
  const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
@@ -5253,6 +5260,7 @@ function ifft(a, axis = -1) {
5253
5260
  var numpy_linalg_exports = {};
5254
5261
  __export(numpy_linalg_exports, {
5255
5262
  cholesky: () => cholesky,
5263
+ cross: () => cross$1,
5256
5264
  det: () => det,
5257
5265
  diagonal: () => diagonal,
5258
5266
  inv: () => inv,
@@ -5283,6 +5291,19 @@ function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
5283
5291
  if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
5284
5292
  return cholesky$1(a, { upper });
5285
5293
  }
5294
+ /**
5295
+ * Compute the cross-product of two 3D vectors.
5296
+ *
5297
+ * This is a simpler and less flexible version of `jax.numpy.cross()`.
5298
+ * Both inputs must have size 3 along the specified axis.
5299
+ */
5300
+ function cross$1(x1, x2, axis = -1) {
5301
+ const a1 = require_backend.checkAxis(axis, ndim(x1));
5302
+ const a2 = require_backend.checkAxis(axis, ndim(x2));
5303
+ if (shape(x1)[a1] !== 3) throw new Error(`linalg.cross: x1 must have size 3 along axis ${axis}, got ${shape(x1)[a1]}`);
5304
+ if (shape(x2)[a2] !== 3) throw new Error(`linalg.cross: x2 must have size 3 along axis ${axis}, got ${shape(x2)[a2]}`);
5305
+ return cross(x1, x2, { axis });
5306
+ }
5286
5307
  /** Compute the determinant of a square matrix (batched). */
5287
5308
  function det(a) {
5288
5309
  a = fudgeArray(a);
@@ -5298,7 +5319,7 @@ function det(a) {
5298
5319
  function inv(a) {
5299
5320
  a = fudgeArray(a);
5300
5321
  const n = checkSquare("inv", a);
5301
- return solve(a, eye(n));
5322
+ return solve(a, eye(n, void 0, { dtype: a.dtype }));
5302
5323
  }
5303
5324
  /**
5304
5325
  * Return the least-squares solution to a linear equation.
@@ -5332,7 +5353,7 @@ function lstsq(a, b) {
5332
5353
  lower: true,
5333
5354
  transposeA: true
5334
5355
  });
5335
- return matmul(at, llb.ref);
5356
+ return matmul(at, llb);
5336
5357
  } else {
5337
5358
  const ata = matmul(at.ref, a);
5338
5359
  const l = cholesky(ata, { symmetrizeInput: false });
@@ -5355,8 +5376,9 @@ function matrixPower(a, n) {
5355
5376
  a = fudgeArray(a);
5356
5377
  const m = checkSquare("matrixPower", a);
5357
5378
  if (n === 0) {
5379
+ const dtype = a.dtype;
5358
5380
  a.dispose();
5359
- return broadcastTo(eye(m), a.shape);
5381
+ return broadcastTo(eye(m, void 0, { dtype }), a.shape);
5360
5382
  }
5361
5383
  if (n < 0) {
5362
5384
  a = inv(a);
@@ -5423,7 +5445,7 @@ function solve(a, b) {
5423
5445
  lower: true,
5424
5446
  unitDiagonal: true
5425
5447
  });
5426
- let x = triangularSolve(lu$2, LPb.ref, {
5448
+ let x = triangularSolve(lu$2, LPb, {
5427
5449
  leftSide: true,
5428
5450
  lower: false
5429
5451
  });
@@ -5538,13 +5560,17 @@ __export(numpy_exports, {
5538
5560
  argmax: () => argmax,
5539
5561
  argmin: () => argmin,
5540
5562
  argsort: () => argsort,
5563
+ around: () => round,
5541
5564
  array: () => array,
5565
+ arrayEqual: () => arrayEqual,
5566
+ arrayEquiv: () => arrayEquiv,
5542
5567
  asin: () => asin,
5543
5568
  asinh: () => arcsinh,
5544
5569
  astype: () => astype,
5545
5570
  atan: () => atan,
5546
5571
  atan2: () => atan2,
5547
5572
  atanh: () => arctanh,
5573
+ average: () => average,
5548
5574
  bool: () => bool,
5549
5575
  broadcastArrays: () => broadcastArrays,
5550
5576
  broadcastShapes: () => broadcastShapes,
@@ -5555,11 +5581,13 @@ __export(numpy_exports, {
5555
5581
  columnStack: () => columnStack,
5556
5582
  concatenate: () => concatenate,
5557
5583
  convolve: () => convolve,
5584
+ copysign: () => copysign,
5558
5585
  corrcoef: () => corrcoef,
5559
5586
  correlate: () => correlate,
5560
5587
  cos: () => cos,
5561
5588
  cosh: () => cosh,
5562
5589
  cov: () => cov,
5590
+ cross: () => cross,
5563
5591
  cumsum: () => cumsum,
5564
5592
  cumulativeSum: () => cumsum,
5565
5593
  deg2rad: () => deg2rad,
@@ -5595,7 +5623,6 @@ __export(numpy_exports, {
5595
5623
  fullLike: () => fullLike$1,
5596
5624
  greater: () => greater,
5597
5625
  greaterEqual: () => greaterEqual,
5598
- hamming: () => hamming,
5599
5626
  hann: () => hann,
5600
5627
  heaviside: () => heaviside,
5601
5628
  hstack: () => hstack,
@@ -5619,9 +5646,14 @@ __export(numpy_exports, {
5619
5646
  log10: () => log10,
5620
5647
  log1p: () => log1p,
5621
5648
  log2: () => log2,
5649
+ logicalAnd: () => logicalAnd,
5650
+ logicalNot: () => logicalNot,
5651
+ logicalOr: () => logicalOr,
5652
+ logicalXor: () => logicalXor,
5622
5653
  logspace: () => logspace,
5623
5654
  matmul: () => matmul,
5624
5655
  matrixTranspose: () => matrixTranspose,
5656
+ matvec: () => matvec,
5625
5657
  max: () => max,
5626
5658
  maximum: () => maximum,
5627
5659
  mean: () => mean,
@@ -5654,6 +5686,8 @@ __export(numpy_exports, {
5654
5686
  remainder: () => remainder,
5655
5687
  repeat: () => repeat,
5656
5688
  reshape: () => reshape,
5689
+ rint: () => rint,
5690
+ round: () => round,
5657
5691
  shape: () => shape,
5658
5692
  sign: () => sign,
5659
5693
  sin: () => sin,
@@ -5686,6 +5720,7 @@ __export(numpy_exports, {
5686
5720
  var_: () => var_,
5687
5721
  vdot: () => vdot,
5688
5722
  vecdot: () => vecdot,
5723
+ vecmat: () => vecmat,
5689
5724
  vstack: () => vstack,
5690
5725
  where: () => where,
5691
5726
  zeros: () => zeros,
@@ -5749,6 +5784,22 @@ const notEqual = notEqual$1;
5749
5784
  const greaterEqual = greaterEqual$1;
5750
5785
  /** @function Compare two arrays element-wise. */
5751
5786
  const lessEqual = lessEqual$1;
5787
+ /** Compute element-wise logical AND. */
5788
+ function logicalAnd(x, y) {
5789
+ return astype(x, require_backend.DType.Bool).mul(astype(y, require_backend.DType.Bool));
5790
+ }
5791
+ /** Compute element-wise logical OR. */
5792
+ function logicalOr(x, y) {
5793
+ return astype(x, require_backend.DType.Bool).add(astype(y, require_backend.DType.Bool));
5794
+ }
5795
+ /** Compute element-wise logical XOR. */
5796
+ function logicalXor(x, y) {
5797
+ return notEqual(astype(x, require_backend.DType.Bool), astype(y, require_backend.DType.Bool));
5798
+ }
5799
+ /** Compute element-wise logical NOT. */
5800
+ function logicalNot(x) {
5801
+ return notEqual(astype(x, require_backend.DType.Bool), true);
5802
+ }
5752
5803
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
5753
5804
  const where = where$1;
5754
5805
  /**
@@ -5856,6 +5907,34 @@ function mean(a, axis = null, opts) {
5856
5907
  return fudgeArray(a).mean(axis, opts);
5857
5908
  }
5858
5909
  /**
5910
+ * Compute the weighted average along the specified axis.
5911
+ *
5912
+ * If no axis is specified, mean is computed along all the axes. The weights
5913
+ * should have shape matching that of `a`, or if an axis is specified, it should
5914
+ * match the shape along those axes.
5915
+ */
5916
+ function average(a, axis = null, opts) {
5917
+ a = fudgeArray(a);
5918
+ if (opts?.weights == null) return mean(a, axis, opts);
5919
+ const weights = fudgeArray(opts.weights);
5920
+ axis = require_backend.normalizeAxis(axis, ndim(a));
5921
+ const wShape = weights.shape;
5922
+ const aShape = a.shape;
5923
+ if (require_backend.deepEqual(wShape, aShape)) {
5924
+ const scl = sum(weights.ref, axis, opts);
5925
+ return sum(multiply(a, weights), axis, opts).div(scl);
5926
+ } else if (axis.length === 1 && wShape.length === 1 && wShape[0] === aShape[axis[0]]) {
5927
+ const broadcastShape = aShape.map((_, i) => i === axis[0] ? wShape[0] : 1);
5928
+ const wReshaped = reshape(weights, broadcastShape);
5929
+ const scl = sum(wReshaped.ref, axis, opts);
5930
+ return sum(multiply(a, wReshaped), axis, opts).div(scl);
5931
+ } else {
5932
+ weights.dispose();
5933
+ a.dispose();
5934
+ throw new Error(`average: weights shape ${JSON.stringify(wShape)} is not compatible with array shape ${JSON.stringify(aShape)} and axis ${JSON.stringify(axis)}`);
5935
+ }
5936
+ }
5937
+ /**
5859
5938
  * Returns the indices of the minimum values along an axis.
5860
5939
  *
5861
5940
  * By default, index is into the flatted array, otherwise it is along the
@@ -6234,8 +6313,9 @@ function sort(a, axis = -1) {
6234
6313
  return fudgeArray(a).sort(axis);
6235
6314
  }
6236
6315
  /**
6237
- * Return indices that would sort an array. This may be an unstable sorting
6238
- * algorithm; it need not preserve order of indices in ties.
6316
+ * Return indices that would sort an array. Unlike `sort`, this is guaranteed to
6317
+ * be a stable sorting algorithm; it always returns the smaller index first in
6318
+ * event of ties.
6239
6319
  *
6240
6320
  * Returns an array of `int32` indices.
6241
6321
  *
@@ -6258,20 +6338,63 @@ function take(a, indices, axis = null) {
6258
6338
  axis = require_backend.checkAxis(axis, ndim(a));
6259
6339
  return gather(a, [indices], [axis], axis);
6260
6340
  }
6261
- /** Return if two arrays are element-wise equal within a tolerance. */
6341
+ /**
6342
+ * Return if two arrays are element-wise equal within a tolerance.
6343
+ *
6344
+ * The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
6345
+ * NaN values comparing equal if `equalNaN` is true.
6346
+ */
6262
6347
  function allclose(actual, expected, options) {
6263
- const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
6348
+ const { rtol = 1e-5, atol = 1e-7, equalNaN = false } = options ?? {};
6264
6349
  const x = array(actual);
6265
6350
  const y = array(expected);
6266
6351
  if (!require_backend.deepEqual(x.shape, y.shape)) return false;
6267
6352
  const xData = x.dataSync();
6268
6353
  const yData = y.dataSync();
6269
6354
  for (let i = 0; i < xData.length; i++) {
6270
- if (isNaN(xData[i]) !== isNaN(yData[i])) return false;
6355
+ if (equalNaN ? isNaN(xData[i]) !== isNaN(yData[i]) : isNaN(xData[i]) || isNaN(yData[i])) return false;
6271
6356
  if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
6272
6357
  }
6273
6358
  return true;
6274
6359
  }
6360
+ /**
6361
+ * Check if two arrays are element-wise equal.
6362
+ *
6363
+ * Returns False if the arrays have different shapes. If `equalNaN` is True,
6364
+ * NaNs in the same position are considered equal.
6365
+ */
6366
+ function arrayEqual(a1, a2, opts) {
6367
+ a1 = fudgeArray(a1);
6368
+ a2 = fudgeArray(a2);
6369
+ if (!require_backend.deepEqual(a1.shape, a2.shape)) {
6370
+ a1.dispose();
6371
+ a2.dispose();
6372
+ return array(false);
6373
+ }
6374
+ if (opts?.equalNaN) {
6375
+ const nanMask = isnan(a1.ref).mul(isnan(a2.ref));
6376
+ return where(nanMask, true, equal(a1, a2)).all();
6377
+ }
6378
+ return equal(a1, a2).all();
6379
+ }
6380
+ /**
6381
+ * Check if two arrays are element-wise equal after broadcasting.
6382
+ *
6383
+ * Unlike `arrayEqual`, this allows inputs with different but
6384
+ * broadcast-compatible shapes.
6385
+ */
6386
+ function arrayEquiv(a1, a2) {
6387
+ a1 = fudgeArray(a1);
6388
+ a2 = fudgeArray(a2);
6389
+ try {
6390
+ const [b1, b2] = broadcastArrays(a1, a2);
6391
+ return equal(b1, b2).all();
6392
+ } catch {
6393
+ a1.dispose();
6394
+ a2.dispose();
6395
+ return array(false);
6396
+ }
6397
+ }
6275
6398
  /** Matrix product of two arrays. */
6276
6399
  function matmul(x, y) {
6277
6400
  if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
@@ -6285,6 +6408,16 @@ function matmul(x, y) {
6285
6408
  rhsBatchDims: require_backend.range(-2 - numBatchDims, -2)
6286
6409
  });
6287
6410
  }
6411
+ /** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
6412
+ function matvec(x1, x2) {
6413
+ if (ndim(x1) < 2 || ndim(x2) < 1) throw new Error("matvec: x1 must be at least 2D and x2 at least 1D");
6414
+ return einsum("...mn,...n->...m", x1, x2);
6415
+ }
6416
+ /** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
6417
+ function vecmat(x1, x2) {
6418
+ if (ndim(x1) < 1 || ndim(x2) < 2) throw new Error("vecmat: x1 must be at least 1D and x2 at least 2D");
6419
+ return einsum("...n,...nm->...m", x1, x2);
6420
+ }
6288
6421
  /** Dot product of two arrays. */
6289
6422
  function dot$1(x, y) {
6290
6423
  if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
@@ -6443,6 +6576,49 @@ function outer(x, y) {
6443
6576
  y = ravel(y);
6444
6577
  return multiply(x.reshape([x.shape[0], 1]), y);
6445
6578
  }
6579
+ /**
6580
+ * @function Compute the cross product of two arrays.
6581
+ *
6582
+ * Supports 2D (scalar result) and 3D cross products, with optional axis
6583
+ * arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
6584
+ */
6585
+ const cross = jit$1(function cross$2(a, b, { axisa = -1, axisb = -1, axisc = -1, axis } = {}) {
6586
+ if (axis !== void 0) {
6587
+ axisa = axis;
6588
+ axisb = axis;
6589
+ axisc = axis;
6590
+ }
6591
+ axisa = require_backend.checkAxis(axisa, ndim(a));
6592
+ axisb = require_backend.checkAxis(axisb, ndim(b));
6593
+ a = moveaxis$1(a, axisa, -1);
6594
+ b = moveaxis$1(b, axisb, -1);
6595
+ const da = a.shape.at(-1);
6596
+ const db = b.shape.at(-1);
6597
+ if (da !== 2 && da !== 3 || db !== 2 && db !== 3) throw new Error(`cross: incompatible dimensions for cross product (got ${da} and ${db})`);
6598
+ if (da === 2 && db === 2) {
6599
+ const [a0$1, a1$1] = split$1(a, 2, -1);
6600
+ const [b0$1, b1$1] = split$1(b, 2, -1);
6601
+ return squeeze(a0$1.mul(b1$1).sub(a1$1.mul(b0$1)), -1);
6602
+ }
6603
+ if (da === 2) {
6604
+ const zeroShape = [...a.shape.slice(0, -1), 1];
6605
+ a = concatenate([a, zeros(zeroShape)], -1);
6606
+ }
6607
+ if (db === 2) {
6608
+ const zeroShape = [...b.shape.slice(0, -1), 1];
6609
+ b = concatenate([b, zeros(zeroShape)], -1);
6610
+ }
6611
+ const [a0, a1, a2] = split$1(a, 3, -1);
6612
+ const [b0, b1, b2] = split$1(b, 3, -1);
6613
+ const c0 = a1.ref.mul(b2.ref).sub(a2.ref.mul(b1.ref));
6614
+ const c1 = a2.mul(b0.ref).sub(a0.ref.mul(b2));
6615
+ const c2 = a0.mul(b1).sub(a1.mul(b0));
6616
+ return moveaxis$1(concatenate([
6617
+ c0,
6618
+ c1,
6619
+ c2
6620
+ ], -1), -1, axisc);
6621
+ }, { staticArgnums: [2] });
6446
6622
  /** Vector dot product of two arrays along a given axis. */
6447
6623
  function vecdot(x, y, { axis } = {}) {
6448
6624
  const xaxis = require_backend.checkAxis(axis ?? -1, ndim(x));
@@ -6537,18 +6713,17 @@ function absolute(x) {
6537
6713
  /** Return an element-wise indication of sign of the input. */
6538
6714
  function sign(x) {
6539
6715
  x = fudgeArray(x);
6540
- return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
6716
+ return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
6541
6717
  }
6542
- /** @function Return element-wise positive values of the input (no-op). */
6543
- const positive = fudgeArray;
6544
6718
  /**
6545
- * Return the Hamming window of size M, a taper with a weighted cosine bell.
6546
- *
6547
- * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
6719
+ * @function
6720
+ * Return the value with the magnitude of x and the sign of y, element-wise.
6548
6721
  */
6549
- function hamming(M) {
6550
- return cos(linspace(0, 2 * Math.PI, M)).mul(-.46).add(.54);
6551
- }
6722
+ const copysign = jit$1(function copysign$1(x, y) {
6723
+ return absolute(x).mul(sign(y));
6724
+ });
6725
+ /** @function Return element-wise positive values of the input (no-op). */
6726
+ const positive = fudgeArray;
6552
6727
  /**
6553
6728
  * Return the Hann window of size M, a taper with a weighted cosine bell.
6554
6729
  *
@@ -6694,6 +6869,27 @@ function trunc(x) {
6694
6869
  return idiv(x, 1);
6695
6870
  }
6696
6871
  /**
6872
+ * @function
6873
+ * Round to the given number of decimals.
6874
+ *
6875
+ * Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
6876
+ */
6877
+ const round = jit$1(function round$1(a, decimals = 0) {
6878
+ if (decimals === 0) return rint(a);
6879
+ const factor = 10 ** decimals;
6880
+ return rint(a.mul(factor)).mul(1 / factor);
6881
+ }, { staticArgnums: [1] });
6882
+ /**
6883
+ * @function
6884
+ * Round to the nearest integer, with ties going to the nearest even integer.
6885
+ */
6886
+ const rint = jit$1(function rint$1(x) {
6887
+ const rounded = floor(x.ref.add(.5));
6888
+ const half = x.ref.sub(floor(x)).equal(.5);
6889
+ const odd = remainder(rounded.ref, 2).notEqual(0);
6890
+ return where(half.mul(odd), rounded.ref.sub(1), rounded);
6891
+ });
6892
+ /**
6697
6893
  * Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
6698
6894
  *
6699
6895
  * This is the inverse of `frexp()`.
@@ -7021,6 +7217,7 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
7021
7217
  //#region src/library/lax.ts
7022
7218
  var lax_exports = {};
7023
7219
  __export(lax_exports, {
7220
+ bitcastConvertType: () => bitcastConvertType,
7024
7221
  conv: () => conv,
7025
7222
  convGeneralDilated: () => convGeneralDilated,
7026
7223
  convTranspose: () => convTranspose,
@@ -7030,9 +7227,14 @@ __export(lax_exports, {
7030
7227
  erfc: () => erfc,
7031
7228
  linalg: () => lax_linalg_exports,
7032
7229
  reduceWindow: () => reduceWindow,
7033
- stopGradient: () => stopGradient$1
7230
+ stopGradient: () => stopGradient$1,
7231
+ topK: () => topK
7034
7232
  });
7035
7233
  const JsArray = globalThis.Array;
7234
+ /** Elementwise bitcast an array into a new dtype. */
7235
+ function bitcastConvertType(x, newDtype) {
7236
+ return fudgeArray(x).view(newDtype);
7237
+ }
7036
7238
  /**
7037
7239
  * General dot product/contraction operator.
7038
7240
  *
@@ -7254,6 +7456,39 @@ function erfc(x) {
7254
7456
  function stopGradient$1(x) {
7255
7457
  return stopGradient(x);
7256
7458
  }
7459
+ /**
7460
+ * Returns top `k` values and their indices along the specified axis of operand.
7461
+ *
7462
+ * This is a _stable_ algorithm: If two elements are equal, the lower-index
7463
+ * element appears first.
7464
+ *
7465
+ * @returns A tuple of `(values, indices)`, where `values` and `indices` have
7466
+ * the same shape as `x`, except along `axis` where they have size `k`.
7467
+ */
7468
+ function topK(x, k, axis = -1) {
7469
+ x = fudgeArray(x);
7470
+ axis = require_backend.checkAxis(axis, x.ndim);
7471
+ const size$1 = x.shape[axis];
7472
+ if (k < 0 || k > size$1) throw new Error(`topK: k must be in the range [0, ${size$1}], got ${k}`);
7473
+ if (k === 0) {
7474
+ const outShape = x.shape.slice();
7475
+ outShape[axis] = 0;
7476
+ const y$1 = zerosLike$1(x.ref, { shape: outShape });
7477
+ const yi$1 = zerosLike$1(x, {
7478
+ dtype: require_backend.DType.Int32,
7479
+ shape: outShape
7480
+ });
7481
+ return [y$1, yi$1];
7482
+ }
7483
+ x = flip$1(x, [axis]);
7484
+ x = moveaxis(x, axis, -1);
7485
+ const [y, yi] = argsort$1(x);
7486
+ const extract = (a) => {
7487
+ a = a.slice(...require_backend.rep(a.ndim - 1, []), [-k]);
7488
+ return flip$1(moveaxis(a, -1, axis), [axis]);
7489
+ };
7490
+ return [extract(y), extract(yi.neg().add(size$1 - 1))];
7491
+ }
7257
7492
 
7258
7493
  //#endregion
7259
7494
  //#region src/library/nn.ts
@@ -7445,7 +7680,7 @@ const gelu = jit$1(function gelu$1(x, opts) {
7445
7680
  if (opts?.approximate ?? true) {
7446
7681
  const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
7447
7682
  return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
7448
- } else return x.ref.mul(.5).mul(erfc$1(negative(x.ref.mul(Math.SQRT1_2))));
7683
+ } else return x.ref.mul(.5).mul(erfc$1(negative(x.mul(Math.SQRT1_2))));
7449
7684
  }, { staticArgnums: [1] });
7450
7685
  /**
7451
7686
  * Gated linear unit (GLU) activation function.
@@ -7703,6 +7938,7 @@ var random_exports = {};
7703
7938
  __export(random_exports, {
7704
7939
  bernoulli: () => bernoulli,
7705
7940
  bits: () => bits,
7941
+ categorical: () => categorical,
7706
7942
  cauchy: () => cauchy,
7707
7943
  exponential: () => exponential,
7708
7944
  gumbel: () => gumbel,
@@ -7730,7 +7966,9 @@ function getK01(key$1) {
7730
7966
  function key(seed) {
7731
7967
  seed = array(seed, { dtype: require_backend.DType.Uint32 });
7732
7968
  if (seed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seed.shape} - use jax.vmap for batching.`);
7733
- return stack([0, seed]);
7969
+ const key$1 = stack([0, seed]);
7970
+ if (key$1 instanceof Array$1) key$1._realizeSource();
7971
+ return key$1;
7734
7972
  }
7735
7973
  /** Splits a PRNG key into `num` new keys by adding a leading axis. */
7736
7974
  function split(key$1, num = 2) {
@@ -7774,6 +8012,47 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
7774
8012
  }
7775
8013
  /**
7776
8014
  * @function
8015
+ * Sample random values from categorical distributions.
8016
+ *
8017
+ * Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
8018
+ * trick for sampling without replacement.
8019
+ *
8020
+ * Note: Sampling without replacement currently uses argsort and slices the last
8021
+ * k elements. This should be replaced with a more efficient topK implementation.
8022
+ *
8023
+ * - `key` - PRNG key
8024
+ * - `logits` - Unnormalized log probabilities of the categorical distribution(s).
8025
+ * `softmax(logits, axis)` gives the corresponding probabilities.
8026
+ * - `axis` - Axis along which logits belong to the same categorical distribution.
8027
+ * - `shape` - Result batch shape. Must be broadcast-compatible with
8028
+ * `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
8029
+ * - `replace` - If true (default), sample with replacement. If false, sample
8030
+ * without replacement (each category can only be selected once per batch).
8031
+ * @returns A random array with int dtype and shape given by `shape` if provided,
8032
+ * otherwise `logits.shape` with `axis` removed.
8033
+ */
8034
+ const categorical = jit$1(function categorical$1(key$1, logits, { axis = -1, shape: shape$1, replace = true } = {}) {
8035
+ logits = fudgeArray(logits);
8036
+ axis = require_backend.checkAxis(axis, logits.ndim);
8037
+ const numCategories = logits.shape[axis];
8038
+ const batchShape = logits.shape.toSpliced(axis, 1);
8039
+ if (shape$1 === void 0) shape$1 = batchShape;
8040
+ else if (!require_backend.deepEqual(require_backend.generalBroadcast(shape$1, batchShape), shape$1)) throw new Error(`Shape ${shape$1} is not broadcast-compatible with batch shape ${batchShape}.`);
8041
+ const shapePrefix = shape$1.slice(0, shape$1.length - batchShape.length);
8042
+ if (replace) {
8043
+ const noise = gumbel(key$1, [...shapePrefix, ...logits.shape]);
8044
+ return argmax(noise.add(logits), axis + shapePrefix.length);
8045
+ } else {
8046
+ const k = shapePrefix.reduce((a, b) => a * b, 1);
8047
+ if (k > numCategories) throw new Error(`Number of samples without replacement (${k}) cannot exceed number of categories (${numCategories}).`);
8048
+ const noise = gumbel(key$1, logits.shape);
8049
+ const [values, indices] = topK(noise.add(logits), k, axis);
8050
+ values.dispose();
8051
+ return indices.reshape(shape$1);
8052
+ }
8053
+ }, { staticArgnums: [2] });
8054
+ /**
8055
+ * @function
7777
8056
  * Sample from a Cauchy distribution with location 0 and scale 1.
7778
8057
  *
7779
8058
  * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
@@ -7884,6 +8163,11 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
7884
8163
 
7885
8164
  //#endregion
7886
8165
  //#region src/index.ts
8166
+ /** @namespace */
8167
+ const profiler = {
8168
+ startTrace: require_backend.startTrace,
8169
+ stopTrace: require_backend.stopTrace
8170
+ };
7887
8171
  /**
7888
8172
  * @function
7889
8173
  * Compute the forward-mode Jacobian-vector product for a function.
@@ -8080,6 +8364,7 @@ Object.defineProperty(exports, 'numpy', {
8080
8364
  return numpy_exports;
8081
8365
  }
8082
8366
  });
8367
+ exports.profiler = profiler;
8083
8368
  Object.defineProperty(exports, 'random', {
8084
8369
  enumerable: true,
8085
8370
  get: function () {