@jax-js/jax 0.1.8 → 0.1.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-nEolvdLv.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-Ctqs8la1.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -807,6 +807,11 @@ var Tracer = class Tracer {
807
807
  if (this.dtype === dtype) return this;
808
808
  return cast(this, dtype);
809
809
  }
810
+ /** Return a bitwise cast of the array, viewed as a new dtype. */
811
+ view(dtype) {
812
+ if (!dtype || dtype === this.dtype) return this;
813
+ return bitcast(this, dtype);
814
+ }
810
815
  /** Subtract an array from this one. */
811
816
  sub(other) {
812
817
  return this.add(neg(other));
@@ -889,18 +894,25 @@ var Tracer = class Tracer {
889
894
  return sort$1(this.transpose(perm)).transpose(invertPermutation(perm));
890
895
  }
891
896
  /**
892
- * Return the indices that would sort an array. This may not be a stable
893
- * sorting algorithm; it need not preserve order of indices in ties.
897
+ * Return the indices that would sort an array. Unlike `sort`, this is
898
+ * guaranteed to be a stable sorting algorithm; it always returns the smaller
899
+ * index first in event of ties.
894
900
  *
895
901
  * See `jax.numpy.argsort` for full docs.
896
902
  */
897
903
  argsort(axis = -1) {
898
904
  axis = checkAxis(axis, this.ndim);
899
- if (axis === this.ndim - 1) return argsort$1(this)[1];
905
+ if (axis === this.ndim - 1) {
906
+ const [y$1, yi$1] = argsort$1(this);
907
+ y$1.dispose();
908
+ return yi$1;
909
+ }
900
910
  const perm = range(this.ndim);
901
911
  perm.splice(axis, 1);
902
912
  perm.push(axis);
903
- return argsort$1(this.transpose(perm))[1].transpose(invertPermutation(perm));
913
+ const [y, yi] = argsort$1(this.transpose(perm));
914
+ y.dispose();
915
+ return yi.transpose(invertPermutation(perm));
904
916
  }
905
917
  /**
906
918
  * Slice an array along one or more axes.
@@ -1617,7 +1629,7 @@ const abstractEvalRules = {
1617
1629
  return [new ShapedArray(x.shape, dtype, false)];
1618
1630
  },
1619
1631
  [Primitive.Bitcast]([x], { dtype }) {
1620
- if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
1632
+ if (x.dtype !== dtype && (x.dtype === DType.Bool || dtype === DType.Bool)) throw new TypeError("Bitcast to/from bool is not allowed");
1621
1633
  if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
1622
1634
  return [new ShapedArray(x.shape, dtype, false)];
1623
1635
  },
@@ -3039,8 +3051,8 @@ var Array$1 = class Array$1 extends Tracer {
3039
3051
  return [x.#unary(AluOp.Cast, dtype)];
3040
3052
  },
3041
3053
  [Primitive.Bitcast]([x], { dtype }) {
3042
- if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
3043
3054
  if (x.dtype === dtype) return [x];
3055
+ if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
3044
3056
  if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
3045
3057
  if (x.#source instanceof AluExp) return [x.#unary(AluOp.Bitcast, dtype)];
3046
3058
  else {
@@ -3381,32 +3393,26 @@ function fullInternal(aval, fillValue, device) {
3381
3393
  committed: device != void 0
3382
3394
  });
3383
3395
  }
3384
- function zerosLike$1(val, dtype) {
3385
- return fullLike(val, 0, dtype);
3396
+ function zerosLike$1(val, opts) {
3397
+ return fullLike(val, 0, opts);
3386
3398
  }
3387
- function onesLike$1(val, dtype) {
3388
- return fullLike(val, 1, dtype);
3399
+ function onesLike$1(val, opts) {
3400
+ return fullLike(val, 1, opts);
3389
3401
  }
3390
- function fullLike(val, fillValue, dtype) {
3402
+ function fullLike(val, fillValue, { dtype, shape: shape$1, device } = {}) {
3391
3403
  const aval = getAval(val);
3392
3404
  if (val instanceof Tracer) val.dispose();
3393
3405
  if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
3394
- const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
3395
- return fullInternal(sa, fillValue);
3406
+ const sa = new ShapedArray(shape$1 ?? aval.shape, dtype ?? aval.dtype, aval.weakType && dtype === void 0);
3407
+ return fullInternal(sa, fillValue, device);
3396
3408
  }
3397
3409
  /** Return a new array of given shape and type, filled with zeros. */
3398
- function zeros(shape$1, { dtype, device } = {}) {
3399
- return full(shape$1, 0, {
3400
- dtype,
3401
- device
3402
- });
3410
+ function zeros(shape$1, opts) {
3411
+ return full(shape$1, 0, opts);
3403
3412
  }
3404
3413
  /** Return a new array of given shape and type, filled with ones. */
3405
- function ones(shape$1, { dtype, device } = {}) {
3406
- return full(shape$1, 1, {
3407
- dtype,
3408
- device
3409
- });
3414
+ function ones(shape$1, opts) {
3415
+ return full(shape$1, 1, opts);
3410
3416
  }
3411
3417
  /** Return a new array of given shape and type, filled with `fill_value`. */
3412
3418
  function full(shape$1, fillValue, { dtype, device } = {}) {
@@ -4141,6 +4147,7 @@ const jvpRules = {
4141
4147
  },
4142
4148
  [Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
4143
4149
  const x = triangularSolve$1(a.ref, b, { unitDiagonal });
4150
+ da = unitDiagonal ? triu(da, 1) : triu(da);
4144
4151
  const dax = batchMatmulT(da, x.ref);
4145
4152
  const rhsT = db.sub(mT(dax));
4146
4153
  const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
@@ -5216,6 +5223,7 @@ function ifft(a, axis = -1) {
5216
5223
  var numpy_linalg_exports = {};
5217
5224
  __export(numpy_linalg_exports, {
5218
5225
  cholesky: () => cholesky,
5226
+ cross: () => cross$1,
5219
5227
  det: () => det,
5220
5228
  diagonal: () => diagonal,
5221
5229
  inv: () => inv,
@@ -5246,6 +5254,19 @@ function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
5246
5254
  if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
5247
5255
  return cholesky$1(a, { upper });
5248
5256
  }
5257
+ /**
5258
+ * Compute the cross-product of two 3D vectors.
5259
+ *
5260
+ * This is a simpler and less flexible version of `jax.numpy.cross()`.
5261
+ * Both inputs must have size 3 along the specified axis.
5262
+ */
5263
+ function cross$1(x1, x2, axis = -1) {
5264
+ const a1 = checkAxis(axis, ndim(x1));
5265
+ const a2 = checkAxis(axis, ndim(x2));
5266
+ if (shape(x1)[a1] !== 3) throw new Error(`linalg.cross: x1 must have size 3 along axis ${axis}, got ${shape(x1)[a1]}`);
5267
+ if (shape(x2)[a2] !== 3) throw new Error(`linalg.cross: x2 must have size 3 along axis ${axis}, got ${shape(x2)[a2]}`);
5268
+ return cross(x1, x2, { axis });
5269
+ }
5249
5270
  /** Compute the determinant of a square matrix (batched). */
5250
5271
  function det(a) {
5251
5272
  a = fudgeArray(a);
@@ -5261,7 +5282,7 @@ function det(a) {
5261
5282
  function inv(a) {
5262
5283
  a = fudgeArray(a);
5263
5284
  const n = checkSquare("inv", a);
5264
- return solve(a, eye(n));
5285
+ return solve(a, eye(n, void 0, { dtype: a.dtype }));
5265
5286
  }
5266
5287
  /**
5267
5288
  * Return the least-squares solution to a linear equation.
@@ -5295,7 +5316,7 @@ function lstsq(a, b) {
5295
5316
  lower: true,
5296
5317
  transposeA: true
5297
5318
  });
5298
- return matmul(at, llb.ref);
5319
+ return matmul(at, llb);
5299
5320
  } else {
5300
5321
  const ata = matmul(at.ref, a);
5301
5322
  const l = cholesky(ata, { symmetrizeInput: false });
@@ -5318,8 +5339,9 @@ function matrixPower(a, n) {
5318
5339
  a = fudgeArray(a);
5319
5340
  const m = checkSquare("matrixPower", a);
5320
5341
  if (n === 0) {
5342
+ const dtype = a.dtype;
5321
5343
  a.dispose();
5322
- return broadcastTo(eye(m), a.shape);
5344
+ return broadcastTo(eye(m, void 0, { dtype }), a.shape);
5323
5345
  }
5324
5346
  if (n < 0) {
5325
5347
  a = inv(a);
@@ -5386,7 +5408,7 @@ function solve(a, b) {
5386
5408
  lower: true,
5387
5409
  unitDiagonal: true
5388
5410
  });
5389
- let x = triangularSolve(lu$2, LPb.ref, {
5411
+ let x = triangularSolve(lu$2, LPb, {
5390
5412
  leftSide: true,
5391
5413
  lower: false
5392
5414
  });
@@ -5501,13 +5523,17 @@ __export(numpy_exports, {
5501
5523
  argmax: () => argmax,
5502
5524
  argmin: () => argmin,
5503
5525
  argsort: () => argsort,
5526
+ around: () => round,
5504
5527
  array: () => array,
5528
+ arrayEqual: () => arrayEqual,
5529
+ arrayEquiv: () => arrayEquiv,
5505
5530
  asin: () => asin,
5506
5531
  asinh: () => arcsinh,
5507
5532
  astype: () => astype,
5508
5533
  atan: () => atan,
5509
5534
  atan2: () => atan2,
5510
5535
  atanh: () => arctanh,
5536
+ average: () => average,
5511
5537
  bool: () => bool,
5512
5538
  broadcastArrays: () => broadcastArrays,
5513
5539
  broadcastShapes: () => broadcastShapes,
@@ -5518,11 +5544,13 @@ __export(numpy_exports, {
5518
5544
  columnStack: () => columnStack,
5519
5545
  concatenate: () => concatenate,
5520
5546
  convolve: () => convolve,
5547
+ copysign: () => copysign,
5521
5548
  corrcoef: () => corrcoef,
5522
5549
  correlate: () => correlate,
5523
5550
  cos: () => cos,
5524
5551
  cosh: () => cosh,
5525
5552
  cov: () => cov,
5553
+ cross: () => cross,
5526
5554
  cumsum: () => cumsum,
5527
5555
  cumulativeSum: () => cumsum,
5528
5556
  deg2rad: () => deg2rad,
@@ -5558,7 +5586,6 @@ __export(numpy_exports, {
5558
5586
  fullLike: () => fullLike$1,
5559
5587
  greater: () => greater,
5560
5588
  greaterEqual: () => greaterEqual,
5561
- hamming: () => hamming,
5562
5589
  hann: () => hann,
5563
5590
  heaviside: () => heaviside,
5564
5591
  hstack: () => hstack,
@@ -5582,9 +5609,14 @@ __export(numpy_exports, {
5582
5609
  log10: () => log10,
5583
5610
  log1p: () => log1p,
5584
5611
  log2: () => log2,
5612
+ logicalAnd: () => logicalAnd,
5613
+ logicalNot: () => logicalNot,
5614
+ logicalOr: () => logicalOr,
5615
+ logicalXor: () => logicalXor,
5585
5616
  logspace: () => logspace,
5586
5617
  matmul: () => matmul,
5587
5618
  matrixTranspose: () => matrixTranspose,
5619
+ matvec: () => matvec,
5588
5620
  max: () => max,
5589
5621
  maximum: () => maximum,
5590
5622
  mean: () => mean,
@@ -5617,6 +5649,8 @@ __export(numpy_exports, {
5617
5649
  remainder: () => remainder,
5618
5650
  repeat: () => repeat,
5619
5651
  reshape: () => reshape,
5652
+ rint: () => rint,
5653
+ round: () => round,
5620
5654
  shape: () => shape,
5621
5655
  sign: () => sign,
5622
5656
  sin: () => sin,
@@ -5649,6 +5683,7 @@ __export(numpy_exports, {
5649
5683
  var_: () => var_,
5650
5684
  vdot: () => vdot,
5651
5685
  vecdot: () => vecdot,
5686
+ vecmat: () => vecmat,
5652
5687
  vstack: () => vstack,
5653
5688
  where: () => where,
5654
5689
  zeros: () => zeros,
@@ -5712,6 +5747,22 @@ const notEqual = notEqual$1;
5712
5747
  const greaterEqual = greaterEqual$1;
5713
5748
  /** @function Compare two arrays element-wise. */
5714
5749
  const lessEqual = lessEqual$1;
5750
+ /** Compute element-wise logical AND. */
5751
+ function logicalAnd(x, y) {
5752
+ return astype(x, DType.Bool).mul(astype(y, DType.Bool));
5753
+ }
5754
+ /** Compute element-wise logical OR. */
5755
+ function logicalOr(x, y) {
5756
+ return astype(x, DType.Bool).add(astype(y, DType.Bool));
5757
+ }
5758
+ /** Compute element-wise logical XOR. */
5759
+ function logicalXor(x, y) {
5760
+ return notEqual(astype(x, DType.Bool), astype(y, DType.Bool));
5761
+ }
5762
+ /** Compute element-wise logical NOT. */
5763
+ function logicalNot(x) {
5764
+ return notEqual(astype(x, DType.Bool), true);
5765
+ }
5715
5766
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
5716
5767
  const where = where$1;
5717
5768
  /**
@@ -5819,6 +5870,34 @@ function mean(a, axis = null, opts) {
5819
5870
  return fudgeArray(a).mean(axis, opts);
5820
5871
  }
5821
5872
  /**
5873
+ * Compute the weighted average along the specified axis.
5874
+ *
5875
+ * If no axis is specified, mean is computed along all the axes. The weights
5876
+ * should have shape matching that of `a`, or if an axis is specified, it should
5877
+ * match the shape along those axes.
5878
+ */
5879
+ function average(a, axis = null, opts) {
5880
+ a = fudgeArray(a);
5881
+ if (opts?.weights == null) return mean(a, axis, opts);
5882
+ const weights = fudgeArray(opts.weights);
5883
+ axis = normalizeAxis(axis, ndim(a));
5884
+ const wShape = weights.shape;
5885
+ const aShape = a.shape;
5886
+ if (deepEqual(wShape, aShape)) {
5887
+ const scl = sum(weights.ref, axis, opts);
5888
+ return sum(multiply(a, weights), axis, opts).div(scl);
5889
+ } else if (axis.length === 1 && wShape.length === 1 && wShape[0] === aShape[axis[0]]) {
5890
+ const broadcastShape = aShape.map((_, i) => i === axis[0] ? wShape[0] : 1);
5891
+ const wReshaped = reshape(weights, broadcastShape);
5892
+ const scl = sum(wReshaped.ref, axis, opts);
5893
+ return sum(multiply(a, wReshaped), axis, opts).div(scl);
5894
+ } else {
5895
+ weights.dispose();
5896
+ a.dispose();
5897
+ throw new Error(`average: weights shape ${JSON.stringify(wShape)} is not compatible with array shape ${JSON.stringify(aShape)} and axis ${JSON.stringify(axis)}`);
5898
+ }
5899
+ }
5900
+ /**
5822
5901
  * Returns the indices of the minimum values along an axis.
5823
5902
  *
5824
5903
  * By default, index is into the flatted array, otherwise it is along the
@@ -6197,8 +6276,9 @@ function sort(a, axis = -1) {
6197
6276
  return fudgeArray(a).sort(axis);
6198
6277
  }
6199
6278
  /**
6200
- * Return indices that would sort an array. This may be an unstable sorting
6201
- * algorithm; it need not preserve order of indices in ties.
6279
+ * Return indices that would sort an array. Unlike `sort`, this is guaranteed to
6280
+ * be a stable sorting algorithm; it always returns the smaller index first in
6281
+ * event of ties.
6202
6282
  *
6203
6283
  * Returns an array of `int32` indices.
6204
6284
  *
@@ -6221,20 +6301,63 @@ function take(a, indices, axis = null) {
6221
6301
  axis = checkAxis(axis, ndim(a));
6222
6302
  return gather(a, [indices], [axis], axis);
6223
6303
  }
6224
- /** Return if two arrays are element-wise equal within a tolerance. */
6304
+ /**
6305
+ * Return if two arrays are element-wise equal within a tolerance.
6306
+ *
6307
+ * The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
6308
+ * NaN values comparing equal if `equalNaN` is true.
6309
+ */
6225
6310
  function allclose(actual, expected, options) {
6226
- const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
6311
+ const { rtol = 1e-5, atol = 1e-7, equalNaN = false } = options ?? {};
6227
6312
  const x = array(actual);
6228
6313
  const y = array(expected);
6229
6314
  if (!deepEqual(x.shape, y.shape)) return false;
6230
6315
  const xData = x.dataSync();
6231
6316
  const yData = y.dataSync();
6232
6317
  for (let i = 0; i < xData.length; i++) {
6233
- if (isNaN(xData[i]) !== isNaN(yData[i])) return false;
6318
+ if (equalNaN ? isNaN(xData[i]) !== isNaN(yData[i]) : isNaN(xData[i]) || isNaN(yData[i])) return false;
6234
6319
  if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
6235
6320
  }
6236
6321
  return true;
6237
6322
  }
6323
+ /**
6324
+ * Check if two arrays are element-wise equal.
6325
+ *
6326
+ * Returns False if the arrays have different shapes. If `equalNaN` is True,
6327
+ * NaNs in the same position are considered equal.
6328
+ */
6329
+ function arrayEqual(a1, a2, opts) {
6330
+ a1 = fudgeArray(a1);
6331
+ a2 = fudgeArray(a2);
6332
+ if (!deepEqual(a1.shape, a2.shape)) {
6333
+ a1.dispose();
6334
+ a2.dispose();
6335
+ return array(false);
6336
+ }
6337
+ if (opts?.equalNaN) {
6338
+ const nanMask = isnan(a1.ref).mul(isnan(a2.ref));
6339
+ return where(nanMask, true, equal(a1, a2)).all();
6340
+ }
6341
+ return equal(a1, a2).all();
6342
+ }
6343
+ /**
6344
+ * Check if two arrays are element-wise equal after broadcasting.
6345
+ *
6346
+ * Unlike `arrayEqual`, this allows inputs with different but
6347
+ * broadcast-compatible shapes.
6348
+ */
6349
+ function arrayEquiv(a1, a2) {
6350
+ a1 = fudgeArray(a1);
6351
+ a2 = fudgeArray(a2);
6352
+ try {
6353
+ const [b1, b2] = broadcastArrays(a1, a2);
6354
+ return equal(b1, b2).all();
6355
+ } catch {
6356
+ a1.dispose();
6357
+ a2.dispose();
6358
+ return array(false);
6359
+ }
6360
+ }
6238
6361
  /** Matrix product of two arrays. */
6239
6362
  function matmul(x, y) {
6240
6363
  if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
@@ -6248,6 +6371,16 @@ function matmul(x, y) {
6248
6371
  rhsBatchDims: range(-2 - numBatchDims, -2)
6249
6372
  });
6250
6373
  }
6374
+ /** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
6375
+ function matvec(x1, x2) {
6376
+ if (ndim(x1) < 2 || ndim(x2) < 1) throw new Error("matvec: x1 must be at least 2D and x2 at least 1D");
6377
+ return einsum("...mn,...n->...m", x1, x2);
6378
+ }
6379
+ /** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
6380
+ function vecmat(x1, x2) {
6381
+ if (ndim(x1) < 1 || ndim(x2) < 2) throw new Error("vecmat: x1 must be at least 1D and x2 at least 2D");
6382
+ return einsum("...n,...nm->...m", x1, x2);
6383
+ }
6251
6384
  /** Dot product of two arrays. */
6252
6385
  function dot$1(x, y) {
6253
6386
  if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
@@ -6406,6 +6539,49 @@ function outer(x, y) {
6406
6539
  y = ravel(y);
6407
6540
  return multiply(x.reshape([x.shape[0], 1]), y);
6408
6541
  }
6542
+ /**
6543
+ * @function Compute the cross product of two arrays.
6544
+ *
6545
+ * Supports 2D (scalar result) and 3D cross products, with optional axis
6546
+ * arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
6547
+ */
6548
+ const cross = jit$1(function cross$2(a, b, { axisa = -1, axisb = -1, axisc = -1, axis } = {}) {
6549
+ if (axis !== void 0) {
6550
+ axisa = axis;
6551
+ axisb = axis;
6552
+ axisc = axis;
6553
+ }
6554
+ axisa = checkAxis(axisa, ndim(a));
6555
+ axisb = checkAxis(axisb, ndim(b));
6556
+ a = moveaxis$1(a, axisa, -1);
6557
+ b = moveaxis$1(b, axisb, -1);
6558
+ const da = a.shape.at(-1);
6559
+ const db = b.shape.at(-1);
6560
+ if (da !== 2 && da !== 3 || db !== 2 && db !== 3) throw new Error(`cross: incompatible dimensions for cross product (got ${da} and ${db})`);
6561
+ if (da === 2 && db === 2) {
6562
+ const [a0$1, a1$1] = split$1(a, 2, -1);
6563
+ const [b0$1, b1$1] = split$1(b, 2, -1);
6564
+ return squeeze(a0$1.mul(b1$1).sub(a1$1.mul(b0$1)), -1);
6565
+ }
6566
+ if (da === 2) {
6567
+ const zeroShape = [...a.shape.slice(0, -1), 1];
6568
+ a = concatenate([a, zeros(zeroShape)], -1);
6569
+ }
6570
+ if (db === 2) {
6571
+ const zeroShape = [...b.shape.slice(0, -1), 1];
6572
+ b = concatenate([b, zeros(zeroShape)], -1);
6573
+ }
6574
+ const [a0, a1, a2] = split$1(a, 3, -1);
6575
+ const [b0, b1, b2] = split$1(b, 3, -1);
6576
+ const c0 = a1.ref.mul(b2.ref).sub(a2.ref.mul(b1.ref));
6577
+ const c1 = a2.mul(b0.ref).sub(a0.ref.mul(b2));
6578
+ const c2 = a0.mul(b1).sub(a1.mul(b0));
6579
+ return moveaxis$1(concatenate([
6580
+ c0,
6581
+ c1,
6582
+ c2
6583
+ ], -1), -1, axisc);
6584
+ }, { staticArgnums: [2] });
6409
6585
  /** Vector dot product of two arrays along a given axis. */
6410
6586
  function vecdot(x, y, { axis } = {}) {
6411
6587
  const xaxis = checkAxis(axis ?? -1, ndim(x));
@@ -6500,18 +6676,17 @@ function absolute(x) {
6500
6676
  /** Return an element-wise indication of sign of the input. */
6501
6677
  function sign(x) {
6502
6678
  x = fudgeArray(x);
6503
- return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
6679
+ return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
6504
6680
  }
6505
- /** @function Return element-wise positive values of the input (no-op). */
6506
- const positive = fudgeArray;
6507
6681
  /**
6508
- * Return the Hamming window of size M, a taper with a weighted cosine bell.
6509
- *
6510
- * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
6682
+ * @function
6683
+ * Return the value with the magnitude of x and the sign of y, element-wise.
6511
6684
  */
6512
- function hamming(M) {
6513
- return cos(linspace(0, 2 * Math.PI, M)).mul(-.46).add(.54);
6514
- }
6685
+ const copysign = jit$1(function copysign$1(x, y) {
6686
+ return absolute(x).mul(sign(y));
6687
+ });
6688
+ /** @function Return element-wise positive values of the input (no-op). */
6689
+ const positive = fudgeArray;
6515
6690
  /**
6516
6691
  * Return the Hann window of size M, a taper with a weighted cosine bell.
6517
6692
  *
@@ -6657,6 +6832,27 @@ function trunc(x) {
6657
6832
  return idiv(x, 1);
6658
6833
  }
6659
6834
  /**
6835
+ * @function
6836
+ * Round to the given number of decimals.
6837
+ *
6838
+ * Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
6839
+ */
6840
+ const round = jit$1(function round$1(a, decimals = 0) {
6841
+ if (decimals === 0) return rint(a);
6842
+ const factor = 10 ** decimals;
6843
+ return rint(a.mul(factor)).mul(1 / factor);
6844
+ }, { staticArgnums: [1] });
6845
+ /**
6846
+ * @function
6847
+ * Round to the nearest integer, with ties going to the nearest even integer.
6848
+ */
6849
+ const rint = jit$1(function rint$1(x) {
6850
+ const rounded = floor(x.ref.add(.5));
6851
+ const half = x.ref.sub(floor(x)).equal(.5);
6852
+ const odd = remainder(rounded.ref, 2).notEqual(0);
6853
+ return where(half.mul(odd), rounded.ref.sub(1), rounded);
6854
+ });
6855
+ /**
6660
6856
  * Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
6661
6857
  *
6662
6858
  * This is the inverse of `frexp()`.
@@ -6984,6 +7180,7 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
6984
7180
  //#region src/library/lax.ts
6985
7181
  var lax_exports = {};
6986
7182
  __export(lax_exports, {
7183
+ bitcastConvertType: () => bitcastConvertType,
6987
7184
  conv: () => conv,
6988
7185
  convGeneralDilated: () => convGeneralDilated,
6989
7186
  convTranspose: () => convTranspose,
@@ -6993,9 +7190,14 @@ __export(lax_exports, {
6993
7190
  erfc: () => erfc,
6994
7191
  linalg: () => lax_linalg_exports,
6995
7192
  reduceWindow: () => reduceWindow,
6996
- stopGradient: () => stopGradient$1
7193
+ stopGradient: () => stopGradient$1,
7194
+ topK: () => topK
6997
7195
  });
6998
7196
  const JsArray = globalThis.Array;
7197
+ /** Elementwise bitcast an array into a new dtype. */
7198
+ function bitcastConvertType(x, newDtype) {
7199
+ return fudgeArray(x).view(newDtype);
7200
+ }
6999
7201
  /**
7000
7202
  * General dot product/contraction operator.
7001
7203
  *
@@ -7217,6 +7419,39 @@ function erfc(x) {
7217
7419
  function stopGradient$1(x) {
7218
7420
  return stopGradient(x);
7219
7421
  }
7422
+ /**
7423
+ * Returns top `k` values and their indices along the specified axis of operand.
7424
+ *
7425
+ * This is a _stable_ algorithm: If two elements are equal, the lower-index
7426
+ * element appears first.
7427
+ *
7428
+ * @returns A tuple of `(values, indices)`, where `values` and `indices` have
7429
+ * the same shape as `x`, except along `axis` where they have size `k`.
7430
+ */
7431
+ function topK(x, k, axis = -1) {
7432
+ x = fudgeArray(x);
7433
+ axis = checkAxis(axis, x.ndim);
7434
+ const size$1 = x.shape[axis];
7435
+ if (k < 0 || k > size$1) throw new Error(`topK: k must be in the range [0, ${size$1}], got ${k}`);
7436
+ if (k === 0) {
7437
+ const outShape = x.shape.slice();
7438
+ outShape[axis] = 0;
7439
+ const y$1 = zerosLike$1(x.ref, { shape: outShape });
7440
+ const yi$1 = zerosLike$1(x, {
7441
+ dtype: DType.Int32,
7442
+ shape: outShape
7443
+ });
7444
+ return [y$1, yi$1];
7445
+ }
7446
+ x = flip$1(x, [axis]);
7447
+ x = moveaxis(x, axis, -1);
7448
+ const [y, yi] = argsort$1(x);
7449
+ const extract = (a) => {
7450
+ a = a.slice(...rep(a.ndim - 1, []), [-k]);
7451
+ return flip$1(moveaxis(a, -1, axis), [axis]);
7452
+ };
7453
+ return [extract(y), extract(yi.neg().add(size$1 - 1))];
7454
+ }
7220
7455
 
7221
7456
  //#endregion
7222
7457
  //#region src/library/nn.ts
@@ -7408,7 +7643,7 @@ const gelu = jit$1(function gelu$1(x, opts) {
7408
7643
  if (opts?.approximate ?? true) {
7409
7644
  const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
7410
7645
  return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
7411
- } else return x.ref.mul(.5).mul(erfc$1(negative(x.ref.mul(Math.SQRT1_2))));
7646
+ } else return x.ref.mul(.5).mul(erfc$1(negative(x.mul(Math.SQRT1_2))));
7412
7647
  }, { staticArgnums: [1] });
7413
7648
  /**
7414
7649
  * Gated linear unit (GLU) activation function.
@@ -7666,6 +7901,7 @@ var random_exports = {};
7666
7901
  __export(random_exports, {
7667
7902
  bernoulli: () => bernoulli,
7668
7903
  bits: () => bits,
7904
+ categorical: () => categorical,
7669
7905
  cauchy: () => cauchy,
7670
7906
  exponential: () => exponential,
7671
7907
  gumbel: () => gumbel,
@@ -7693,7 +7929,9 @@ function getK01(key$1) {
7693
7929
  function key(seed) {
7694
7930
  seed = array(seed, { dtype: DType.Uint32 });
7695
7931
  if (seed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seed.shape} - use jax.vmap for batching.`);
7696
- return stack([0, seed]);
7932
+ const key$1 = stack([0, seed]);
7933
+ if (key$1 instanceof Array$1) key$1._realizeSource();
7934
+ return key$1;
7697
7935
  }
7698
7936
  /** Splits a PRNG key into `num` new keys by adding a leading axis. */
7699
7937
  function split(key$1, num = 2) {
@@ -7737,6 +7975,47 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
7737
7975
  }
7738
7976
  /**
7739
7977
  * @function
7978
+ * Sample random values from categorical distributions.
7979
+ *
7980
+ * Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
7981
+ * trick for sampling without replacement.
7982
+ *
7983
+ * Note: Sampling without replacement currently uses argsort and slices the last
7984
+ * k elements. This should be replaced with a more efficient topK implementation.
7985
+ *
7986
+ * - `key` - PRNG key
7987
+ * - `logits` - Unnormalized log probabilities of the categorical distribution(s).
7988
+ * `softmax(logits, axis)` gives the corresponding probabilities.
7989
+ * - `axis` - Axis along which logits belong to the same categorical distribution.
7990
+ * - `shape` - Result batch shape. Must be broadcast-compatible with
7991
+ * `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
7992
+ * - `replace` - If true (default), sample with replacement. If false, sample
7993
+ * without replacement (each category can only be selected once per batch).
7994
+ * @returns A random array with int dtype and shape given by `shape` if provided,
7995
+ * otherwise `logits.shape` with `axis` removed.
7996
+ */
7997
+ const categorical = jit$1(function categorical$1(key$1, logits, { axis = -1, shape: shape$1, replace = true } = {}) {
7998
+ logits = fudgeArray(logits);
7999
+ axis = checkAxis(axis, logits.ndim);
8000
+ const numCategories = logits.shape[axis];
8001
+ const batchShape = logits.shape.toSpliced(axis, 1);
8002
+ if (shape$1 === void 0) shape$1 = batchShape;
8003
+ else if (!deepEqual(generalBroadcast(shape$1, batchShape), shape$1)) throw new Error(`Shape ${shape$1} is not broadcast-compatible with batch shape ${batchShape}.`);
8004
+ const shapePrefix = shape$1.slice(0, shape$1.length - batchShape.length);
8005
+ if (replace) {
8006
+ const noise = gumbel(key$1, [...shapePrefix, ...logits.shape]);
8007
+ return argmax(noise.add(logits), axis + shapePrefix.length);
8008
+ } else {
8009
+ const k = shapePrefix.reduce((a, b) => a * b, 1);
8010
+ if (k > numCategories) throw new Error(`Number of samples without replacement (${k}) cannot exceed number of categories (${numCategories}).`);
8011
+ const noise = gumbel(key$1, logits.shape);
8012
+ const [values, indices] = topK(noise.add(logits), k, axis);
8013
+ values.dispose();
8014
+ return indices.reshape(shape$1);
8015
+ }
8016
+ }, { staticArgnums: [2] });
8017
+ /**
8018
+ * @function
7740
8019
  * Sample from a Cauchy distribution with location 0 and scale 1.
7741
8020
  *
7742
8021
  * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
@@ -7847,6 +8126,11 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
7847
8126
 
7848
8127
  //#endregion
7849
8128
  //#region src/index.ts
8129
+ /** @namespace */
8130
+ const profiler = {
8131
+ startTrace,
8132
+ stopTrace
8133
+ };
7850
8134
  /**
7851
8135
  * @function
7852
8136
  * Compute the forward-mode Jacobian-vector product for a function.
@@ -8007,4 +8291,4 @@ async function devicePut(x, device) {
8007
8291
  }
8008
8292
 
8009
8293
  //#endregion
8010
- export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
8294
+ export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, profiler, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
@@ -1,4 +1,4 @@
1
- import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-nEolvdLv.js";
1
+ import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-Ctqs8la1.js";
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-B3foXiV_.cjs');
1
+ const require_backend = require('./backend-DMauYnfl.cjs');
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `