@jax-js/jax 0.1.9 → 0.1.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-DpI0riom.cjs');
33
+ const require_backend = require('./backend-DlYlOYqN.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
@@ -364,6 +364,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
364
364
  Primitive$1["Mod"] = "mod";
365
365
  Primitive$1["Min"] = "min";
366
366
  Primitive$1["Max"] = "max";
367
+ Primitive$1["BitCombine"] = "bit_combine";
368
+ Primitive$1["BitShift"] = "bit_shift";
367
369
  Primitive$1["Neg"] = "neg";
368
370
  Primitive$1["Reciprocal"] = "reciprocal";
369
371
  Primitive$1["Floor"] = "floor";
@@ -437,6 +439,12 @@ function min$1(x, y) {
437
439
  function max$1(x, y) {
438
440
  return bind1(Primitive.Max, [x, y]);
439
441
  }
442
+ function bitCombine(x, y, op) {
443
+ return bind1(Primitive.BitCombine, [x, y], { op });
444
+ }
445
+ function bitShift(x, y, op) {
446
+ return bind1(Primitive.BitShift, [x, y], { op });
447
+ }
440
448
  function neg(x) {
441
449
  return bind1(Primitive.Neg, [x]);
442
450
  }
@@ -838,6 +846,11 @@ var Tracer = class Tracer {
838
846
  if (this.dtype === dtype) return this;
839
847
  return cast(this, dtype);
840
848
  }
849
+ /** Return a bitwise cast of the array, viewed as a new dtype. */
850
+ view(dtype) {
851
+ if (!dtype || dtype === this.dtype) return this;
852
+ return bitcast(this, dtype);
853
+ }
841
854
  /** Subtract an array from this one. */
842
855
  sub(other) {
843
856
  return this.add(neg(other));
@@ -1650,6 +1663,16 @@ const abstractEvalRules = {
1650
1663
  [Primitive.Mod]: binopAbstractEval,
1651
1664
  [Primitive.Min]: binopAbstractEval,
1652
1665
  [Primitive.Max]: binopAbstractEval,
1666
+ [Primitive.BitCombine]([x, y]) {
1667
+ const aval = promoteAvals(x, y);
1668
+ if (require_backend.isFloatDtype(aval.dtype)) throw new TypeError(`bitwise operations require integer or boolean inputs, got ${aval.dtype}`);
1669
+ return [aval];
1670
+ },
1671
+ [Primitive.BitShift]([x, y]) {
1672
+ const shape$1 = require_backend.generalBroadcast(x.shape, y.shape);
1673
+ if (require_backend.isFloatDtype(x.dtype) || require_backend.isFloatDtype(y.dtype) || x.dtype === require_backend.DType.Bool || y.dtype === require_backend.DType.Bool) throw new TypeError(`bit shift operations require integer inputs, got ${x} and ${y}`);
1674
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
1675
+ },
1653
1676
  [Primitive.Neg]: vectorizedUnopAbstractEval,
1654
1677
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
1655
1678
  [Primitive.Floor]: vectorizedUnopAbstractEval,
@@ -1659,7 +1682,7 @@ const abstractEvalRules = {
1659
1682
  return [new ShapedArray(x.shape, dtype, false)];
1660
1683
  },
1661
1684
  [Primitive.Bitcast]([x], { dtype }) {
1662
- if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
1685
+ if (x.dtype !== dtype && (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool)) throw new TypeError("Bitcast to/from bool is not allowed");
1663
1686
  if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
1664
1687
  return [new ShapedArray(x.shape, dtype, false)];
1665
1688
  },
@@ -2185,6 +2208,8 @@ const jitRules = {
2185
2208
  [Primitive.Mod]: broadcastedJit(([a, b]) => require_backend.AluExp.mod(a, b)),
2186
2209
  [Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
2187
2210
  [Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
2211
+ [Primitive.BitCombine]: broadcastedJit(([a, b], { op }) => require_backend.AluExp.bitCombine(a, b, op)),
2212
+ [Primitive.BitShift]: broadcastedJit(([a, b], { op }) => require_backend.AluExp.bitShift(a, b, op)),
2188
2213
  [Primitive.Neg]: unopJit((a) => require_backend.AluExp.sub(require_backend.AluExp.const(a.dtype, 0), a)),
2189
2214
  [Primitive.Reciprocal]: unopJit(require_backend.AluExp.reciprocal),
2190
2215
  [Primitive.Floor]: unopJit(require_backend.AluExp.floor),
@@ -2377,7 +2402,9 @@ function splitGraphDataflow(backend, jaxpr) {
2377
2402
  case Primitive.Idiv:
2378
2403
  case Primitive.Mod:
2379
2404
  case Primitive.Min:
2380
- case Primitive.Max: {
2405
+ case Primitive.Max:
2406
+ case Primitive.BitCombine:
2407
+ case Primitive.BitShift: {
2381
2408
  const otherInput = nextEqn.inputs.find((v) => v !== outVar);
2382
2409
  if (otherInput instanceof Lit || require_backend.deepEqual(require_backend.generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
2383
2410
  head = usages[0];
@@ -3016,6 +3043,42 @@ var Array$1 = class Array$1 extends Tracer {
3016
3043
  return require_backend.dtypedArray(this.dtype, buf);
3017
3044
  }
3018
3045
  /**
3046
+ * Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
3047
+ *
3048
+ * Only available on the WebGPU backend. The array's memory is still managed
3049
+ * by jax-js, and it will be freed when the buffer is no longer in use. You
3050
+ * _should not_ mutate the buffer's contents.
3051
+ *
3052
+ * Note that the GPU buffer may be slightly larger than the array's size; it
3053
+ * will always be aligned to 4 bytes.
3054
+ */
3055
+ async gpuBuffer() {
3056
+ if (this.device !== "webgpu") throw new Error(`gpuBuffer() is only available on WebGPU backend`);
3057
+ this.#realize();
3058
+ const pending = this.#pending;
3059
+ if (pending) {
3060
+ await Promise.all(pending.map((p) => p.prepare()));
3061
+ for (const p of pending) p.submit();
3062
+ }
3063
+ const backend = this.#backend;
3064
+ const { buffer } = backend.buffers.get(this.#source);
3065
+ this.dispose();
3066
+ return buffer;
3067
+ }
3068
+ /** Synchronous version of `Array.gpuBuffer()`. */
3069
+ gpuBufferSync() {
3070
+ if (this.device !== "webgpu") throw new Error(`gpuBufferSync() is only available on WebGPU backend`);
3071
+ this.#realize();
3072
+ for (const p of this.#pending) {
3073
+ p.prepareSync();
3074
+ p.submit();
3075
+ }
3076
+ const backend = this.#backend;
3077
+ const { buffer } = backend.buffers.get(this.#source);
3078
+ this.dispose();
3079
+ return buffer;
3080
+ }
3081
+ /**
3019
3082
  * Convert this array into a JavaScript object.
3020
3083
  *
3021
3084
  * This is a blocking operation that will compile all of the shaders and wait
@@ -3062,6 +3125,14 @@ var Array$1 = class Array$1 extends Tracer {
3062
3125
  [Primitive.Max]([x, y]) {
3063
3126
  return [x.#binary(require_backend.AluOp.Max, y)];
3064
3127
  },
3128
+ [Primitive.BitCombine]([x, y], { op }) {
3129
+ const custom = (src) => require_backend.AluExp.bitCombine(src[0], src[1], op);
3130
+ return [Array$1.#naryCustom("bit_combine", custom, [x, y])];
3131
+ },
3132
+ [Primitive.BitShift]([x, y], { op }) {
3133
+ const custom = (src) => require_backend.AluExp.bitShift(src[0], src[1], op);
3134
+ return [Array$1.#naryCustom("bit_shift", custom, [x, y], { dtypeOverride: [void 0, y.dtype] })];
3135
+ },
3065
3136
  [Primitive.Neg]([x]) {
3066
3137
  return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
3067
3138
  },
@@ -3081,8 +3152,8 @@ var Array$1 = class Array$1 extends Tracer {
3081
3152
  return [x.#unary(require_backend.AluOp.Cast, dtype)];
3082
3153
  },
3083
3154
  [Primitive.Bitcast]([x], { dtype }) {
3084
- if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
3085
3155
  if (x.dtype === dtype) return [x];
3156
+ if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
3086
3157
  if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
3087
3158
  if (x.#source instanceof require_backend.AluExp) return [x.#unary(require_backend.AluOp.Bitcast, dtype)];
3088
3159
  else {
@@ -3754,6 +3825,8 @@ const vmapRules = {
3754
3825
  [Primitive.Mod]: broadcastBatcher(Primitive.Mod),
3755
3826
  [Primitive.Min]: broadcastBatcher(Primitive.Min),
3756
3827
  [Primitive.Max]: broadcastBatcher(Primitive.Max),
3828
+ [Primitive.BitCombine]: broadcastBatcher(Primitive.BitCombine),
3829
+ [Primitive.BitShift]: broadcastBatcher(Primitive.BitShift),
3757
3830
  [Primitive.Neg]: unopBatcher(Primitive.Neg),
3758
3831
  [Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
3759
3832
  [Primitive.Floor]: unopBatcher(Primitive.Floor),
@@ -4077,6 +4150,8 @@ const jvpRules = {
4077
4150
  [Primitive.Max]([x, y], [dx, dy]) {
4078
4151
  return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
4079
4152
  },
4153
+ [Primitive.BitCombine]: zeroTangentsJvp(Primitive.BitCombine),
4154
+ [Primitive.BitShift]: zeroTangentsJvp(Primitive.BitShift),
4080
4155
  [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
4081
4156
  [Primitive.Reciprocal]([x], [dx]) {
4082
4157
  const xRecip = reciprocal$1(x.ref);
@@ -4179,6 +4254,7 @@ const jvpRules = {
4179
4254
  },
4180
4255
  [Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
4181
4256
  const x = triangularSolve$1(a.ref, b, { unitDiagonal });
4257
+ da = unitDiagonal ? triu(da, 1) : triu(da);
4182
4258
  const dax = batchMatmulT(da, x.ref);
4183
4259
  const rhsT = db.sub(mT(dax));
4184
4260
  const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
@@ -5254,6 +5330,7 @@ function ifft(a, axis = -1) {
5254
5330
  var numpy_linalg_exports = {};
5255
5331
  __export(numpy_linalg_exports, {
5256
5332
  cholesky: () => cholesky,
5333
+ cross: () => cross$1,
5257
5334
  det: () => det,
5258
5335
  diagonal: () => diagonal,
5259
5336
  inv: () => inv,
@@ -5266,7 +5343,8 @@ __export(numpy_linalg_exports, {
5266
5343
  solve: () => solve,
5267
5344
  tensordot: () => tensordot,
5268
5345
  trace: () => trace,
5269
- vecdot: () => vecdot
5346
+ vecdot: () => vecdot,
5347
+ vectorNorm: () => vectorNorm
5270
5348
  });
5271
5349
  function checkSquare(name, a) {
5272
5350
  if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
@@ -5284,6 +5362,19 @@ function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
5284
5362
  if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
5285
5363
  return cholesky$1(a, { upper });
5286
5364
  }
5365
+ /**
5366
+ * Compute the cross-product of two 3D vectors.
5367
+ *
5368
+ * This is a simpler and less flexible version of `jax.numpy.cross()`.
5369
+ * Both inputs must have size 3 along the specified axis.
5370
+ */
5371
+ function cross$1(x1, x2, axis = -1) {
5372
+ const a1 = require_backend.checkAxis(axis, ndim(x1));
5373
+ const a2 = require_backend.checkAxis(axis, ndim(x2));
5374
+ if (shape(x1)[a1] !== 3) throw new Error(`linalg.cross: x1 must have size 3 along axis ${axis}, got ${shape(x1)[a1]}`);
5375
+ if (shape(x2)[a2] !== 3) throw new Error(`linalg.cross: x2 must have size 3 along axis ${axis}, got ${shape(x2)[a2]}`);
5376
+ return cross(x1, x2, { axis });
5377
+ }
5287
5378
  /** Compute the determinant of a square matrix (batched). */
5288
5379
  function det(a) {
5289
5380
  a = fudgeArray(a);
@@ -5299,7 +5390,7 @@ function det(a) {
5299
5390
  function inv(a) {
5300
5391
  a = fudgeArray(a);
5301
5392
  const n = checkSquare("inv", a);
5302
- return solve(a, eye(n));
5393
+ return solve(a, eye(n, void 0, { dtype: a.dtype }));
5303
5394
  }
5304
5395
  /**
5305
5396
  * Return the least-squares solution to a linear equation.
@@ -5356,8 +5447,9 @@ function matrixPower(a, n) {
5356
5447
  a = fudgeArray(a);
5357
5448
  const m = checkSquare("matrixPower", a);
5358
5449
  if (n === 0) {
5450
+ const dtype = a.dtype;
5359
5451
  a.dispose();
5360
- return broadcastTo(eye(m), a.shape);
5452
+ return broadcastTo(eye(m, void 0, { dtype }), a.shape);
5361
5453
  }
5362
5454
  if (n < 0) {
5363
5455
  a = inv(a);
@@ -5431,6 +5523,23 @@ function solve(a, b) {
5431
5523
  if (bIs1d) x = squeeze(x, -1);
5432
5524
  return x;
5433
5525
  }
5526
+ /**
5527
+ * Compute the vector norm of an array.
5528
+ *
5529
+ * @param x - Input array.
5530
+ * @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
5531
+ * @param axis - Axis/axes to reduce over (default: all axes).
5532
+ * @param keepdims - Whether to keep reduced dimensions as size 1.
5533
+ * @returns The norm of `x`, reduced over the given axes.
5534
+ */
5535
+ function vectorNorm(x, { ord = 2, axis = null, keepdims = false } = {}) {
5536
+ x = fudgeArray(x);
5537
+ const ax = axis ?? null;
5538
+ if (ord === Infinity) return max(absolute(x), ax, { keepdims });
5539
+ else if (ord === -Infinity) return min(absolute(x), ax, { keepdims });
5540
+ else if (ord === 0) return x.notEqual(0).astype(x.dtype).sum(ax, { keepdims });
5541
+ else return power(power(absolute(x), ord).sum(ax, { keepdims }), 1 / ord);
5542
+ }
5434
5543
 
5435
5544
  //#endregion
5436
5545
  //#region src/library/numpy/dtype-info.ts
@@ -5539,13 +5648,24 @@ __export(numpy_exports, {
5539
5648
  argmax: () => argmax,
5540
5649
  argmin: () => argmin,
5541
5650
  argsort: () => argsort,
5651
+ around: () => round,
5542
5652
  array: () => array,
5653
+ arrayEqual: () => arrayEqual,
5654
+ arrayEquiv: () => arrayEquiv,
5543
5655
  asin: () => asin,
5544
5656
  asinh: () => arcsinh,
5545
5657
  astype: () => astype,
5546
5658
  atan: () => atan,
5547
5659
  atan2: () => atan2,
5548
5660
  atanh: () => arctanh,
5661
+ average: () => average,
5662
+ bitwiseAnd: () => bitwiseAnd,
5663
+ bitwiseInvert: () => invert,
5664
+ bitwiseLeftShift: () => leftShift,
5665
+ bitwiseNot: () => invert,
5666
+ bitwiseOr: () => bitwiseOr,
5667
+ bitwiseRightShift: () => rightShift,
5668
+ bitwiseXor: () => bitwiseXor,
5549
5669
  bool: () => bool,
5550
5670
  broadcastArrays: () => broadcastArrays,
5551
5671
  broadcastShapes: () => broadcastShapes,
@@ -5556,11 +5676,13 @@ __export(numpy_exports, {
5556
5676
  columnStack: () => columnStack,
5557
5677
  concatenate: () => concatenate,
5558
5678
  convolve: () => convolve,
5679
+ copysign: () => copysign,
5559
5680
  corrcoef: () => corrcoef,
5560
5681
  correlate: () => correlate,
5561
5682
  cos: () => cos,
5562
5683
  cosh: () => cosh,
5563
5684
  cov: () => cov,
5685
+ cross: () => cross,
5564
5686
  cumsum: () => cumsum,
5565
5687
  cumulativeSum: () => cumsum,
5566
5688
  deg2rad: () => deg2rad,
@@ -5596,7 +5718,6 @@ __export(numpy_exports, {
5596
5718
  fullLike: () => fullLike$1,
5597
5719
  greater: () => greater,
5598
5720
  greaterEqual: () => greaterEqual,
5599
- hamming: () => hamming,
5600
5721
  hann: () => hann,
5601
5722
  heaviside: () => heaviside,
5602
5723
  hstack: () => hstack,
@@ -5606,12 +5727,14 @@ __export(numpy_exports, {
5606
5727
  inf: () => inf,
5607
5728
  inner: () => inner,
5608
5729
  int32: () => int32,
5730
+ invert: () => invert,
5609
5731
  isfinite: () => isfinite,
5610
5732
  isinf: () => isinf,
5611
5733
  isnan: () => isnan,
5612
5734
  isneginf: () => isneginf,
5613
5735
  isposinf: () => isposinf,
5614
5736
  ldexp: () => ldexp,
5737
+ leftShift: () => leftShift,
5615
5738
  less: () => less,
5616
5739
  lessEqual: () => lessEqual,
5617
5740
  linalg: () => numpy_linalg_exports,
@@ -5620,9 +5743,14 @@ __export(numpy_exports, {
5620
5743
  log10: () => log10,
5621
5744
  log1p: () => log1p,
5622
5745
  log2: () => log2,
5746
+ logicalAnd: () => logicalAnd,
5747
+ logicalNot: () => logicalNot,
5748
+ logicalOr: () => logicalOr,
5749
+ logicalXor: () => logicalXor,
5623
5750
  logspace: () => logspace,
5624
5751
  matmul: () => matmul,
5625
5752
  matrixTranspose: () => matrixTranspose,
5753
+ matvec: () => matvec,
5626
5754
  max: () => max,
5627
5755
  maximum: () => maximum,
5628
5756
  mean: () => mean,
@@ -5655,6 +5783,9 @@ __export(numpy_exports, {
5655
5783
  remainder: () => remainder,
5656
5784
  repeat: () => repeat,
5657
5785
  reshape: () => reshape,
5786
+ rightShift: () => rightShift,
5787
+ rint: () => rint,
5788
+ round: () => round,
5658
5789
  shape: () => shape,
5659
5790
  sign: () => sign,
5660
5791
  sin: () => sin,
@@ -5687,6 +5818,7 @@ __export(numpy_exports, {
5687
5818
  var_: () => var_,
5688
5819
  vdot: () => vdot,
5689
5820
  vecdot: () => vecdot,
5821
+ vecmat: () => vecmat,
5690
5822
  vstack: () => vstack,
5691
5823
  where: () => where,
5692
5824
  zeros: () => zeros,
@@ -5750,6 +5882,60 @@ const notEqual = notEqual$1;
5750
5882
  const greaterEqual = greaterEqual$1;
5751
5883
  /** @function Compare two arrays element-wise. */
5752
5884
  const lessEqual = lessEqual$1;
5885
+ /** Compute element-wise logical AND. */
5886
+ function logicalAnd(x, y) {
5887
+ return astype(x, require_backend.DType.Bool).mul(astype(y, require_backend.DType.Bool));
5888
+ }
5889
+ /** Compute element-wise logical OR. */
5890
+ function logicalOr(x, y) {
5891
+ return astype(x, require_backend.DType.Bool).add(astype(y, require_backend.DType.Bool));
5892
+ }
5893
+ /** Compute element-wise logical XOR. */
5894
+ function logicalXor(x, y) {
5895
+ return notEqual(astype(x, require_backend.DType.Bool), astype(y, require_backend.DType.Bool));
5896
+ }
5897
+ /** Compute element-wise logical NOT. */
5898
+ function logicalNot(x) {
5899
+ return notEqual(astype(x, require_backend.DType.Bool), true);
5900
+ }
5901
+ /** Compute element-wise bitwise AND. */
5902
+ function bitwiseAnd(x, y) {
5903
+ return bitCombine(x, y, "and");
5904
+ }
5905
+ /** Compute element-wise bitwise OR. */
5906
+ function bitwiseOr(x, y) {
5907
+ return bitCombine(x, y, "or");
5908
+ }
5909
+ /** Compute element-wise bitwise XOR. */
5910
+ function bitwiseXor(x, y) {
5911
+ return bitCombine(x, y, "xor");
5912
+ }
5913
+ /** Compute element-wise bitwise NOT (inversion). */
5914
+ function invert(x) {
5915
+ const arr = fudgeArray(x);
5916
+ let allOnes;
5917
+ switch (arr.dtype) {
5918
+ case require_backend.DType.Bool:
5919
+ allOnes = true;
5920
+ break;
5921
+ case require_backend.DType.Uint32:
5922
+ allOnes = 4294967295;
5923
+ break;
5924
+ case require_backend.DType.Int32:
5925
+ allOnes = -1;
5926
+ break;
5927
+ default: throw new TypeError(`invert: unsupported dtype ${arr.dtype}`);
5928
+ }
5929
+ return bitCombine(arr, allOnes, "xor");
5930
+ }
5931
+ /** Compute element-wise left bit shift. */
5932
+ function leftShift(x, y) {
5933
+ return bitShift(x, y, "shl");
5934
+ }
5935
+ /** Compute element-wise right bit shift. */
5936
+ function rightShift(x, y) {
5937
+ return bitShift(x, y, "shr");
5938
+ }
5753
5939
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
5754
5940
  const where = where$1;
5755
5941
  /**
@@ -5857,6 +6043,34 @@ function mean(a, axis = null, opts) {
5857
6043
  return fudgeArray(a).mean(axis, opts);
5858
6044
  }
5859
6045
  /**
6046
+ * Compute the weighted average along the specified axis.
6047
+ *
6048
+ * If no axis is specified, mean is computed along all the axes. The weights
6049
+ * should have shape matching that of `a`, or if an axis is specified, it should
6050
+ * match the shape along those axes.
6051
+ */
6052
+ function average(a, axis = null, opts) {
6053
+ a = fudgeArray(a);
6054
+ if (opts?.weights == null) return mean(a, axis, opts);
6055
+ const weights = fudgeArray(opts.weights);
6056
+ axis = require_backend.normalizeAxis(axis, ndim(a));
6057
+ const wShape = weights.shape;
6058
+ const aShape = a.shape;
6059
+ if (require_backend.deepEqual(wShape, aShape)) {
6060
+ const scl = sum(weights.ref, axis, opts);
6061
+ return sum(multiply(a, weights), axis, opts).div(scl);
6062
+ } else if (axis.length === 1 && wShape.length === 1 && wShape[0] === aShape[axis[0]]) {
6063
+ const broadcastShape = aShape.map((_, i) => i === axis[0] ? wShape[0] : 1);
6064
+ const wReshaped = reshape(weights, broadcastShape);
6065
+ const scl = sum(wReshaped.ref, axis, opts);
6066
+ return sum(multiply(a, wReshaped), axis, opts).div(scl);
6067
+ } else {
6068
+ weights.dispose();
6069
+ a.dispose();
6070
+ throw new Error(`average: weights shape ${JSON.stringify(wShape)} is not compatible with array shape ${JSON.stringify(aShape)} and axis ${JSON.stringify(axis)}`);
6071
+ }
6072
+ }
6073
+ /**
5860
6074
  * Returns the indices of the minimum values along an axis.
5861
6075
  *
5862
6076
  * By default, index is into the flatted array, otherwise it is along the
@@ -6260,20 +6474,63 @@ function take(a, indices, axis = null) {
6260
6474
  axis = require_backend.checkAxis(axis, ndim(a));
6261
6475
  return gather(a, [indices], [axis], axis);
6262
6476
  }
6263
- /** Return if two arrays are element-wise equal within a tolerance. */
6477
+ /**
6478
+ * Return if two arrays are element-wise equal within a tolerance.
6479
+ *
6480
+ * The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
6481
+ * NaN values comparing equal if `equalNaN` is true.
6482
+ */
6264
6483
  function allclose(actual, expected, options) {
6265
- const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
6484
+ const { rtol = 1e-5, atol = 1e-7, equalNaN = false } = options ?? {};
6266
6485
  const x = array(actual);
6267
6486
  const y = array(expected);
6268
6487
  if (!require_backend.deepEqual(x.shape, y.shape)) return false;
6269
6488
  const xData = x.dataSync();
6270
6489
  const yData = y.dataSync();
6271
6490
  for (let i = 0; i < xData.length; i++) {
6272
- if (isNaN(xData[i]) !== isNaN(yData[i])) return false;
6491
+ if (equalNaN ? isNaN(xData[i]) !== isNaN(yData[i]) : isNaN(xData[i]) || isNaN(yData[i])) return false;
6273
6492
  if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
6274
6493
  }
6275
6494
  return true;
6276
6495
  }
6496
+ /**
6497
+ * Check if two arrays are element-wise equal.
6498
+ *
6499
+ * Returns False if the arrays have different shapes. If `equalNaN` is True,
6500
+ * NaNs in the same position are considered equal.
6501
+ */
6502
+ function arrayEqual(a1, a2, opts) {
6503
+ a1 = fudgeArray(a1);
6504
+ a2 = fudgeArray(a2);
6505
+ if (!require_backend.deepEqual(a1.shape, a2.shape)) {
6506
+ a1.dispose();
6507
+ a2.dispose();
6508
+ return array(false);
6509
+ }
6510
+ if (opts?.equalNaN) {
6511
+ const nanMask = isnan(a1.ref).mul(isnan(a2.ref));
6512
+ return where(nanMask, true, equal(a1, a2)).all();
6513
+ }
6514
+ return equal(a1, a2).all();
6515
+ }
6516
+ /**
6517
+ * Check if two arrays are element-wise equal after broadcasting.
6518
+ *
6519
+ * Unlike `arrayEqual`, this allows inputs with different but
6520
+ * broadcast-compatible shapes.
6521
+ */
6522
+ function arrayEquiv(a1, a2) {
6523
+ a1 = fudgeArray(a1);
6524
+ a2 = fudgeArray(a2);
6525
+ try {
6526
+ const [b1, b2] = broadcastArrays(a1, a2);
6527
+ return equal(b1, b2).all();
6528
+ } catch {
6529
+ a1.dispose();
6530
+ a2.dispose();
6531
+ return array(false);
6532
+ }
6533
+ }
6277
6534
  /** Matrix product of two arrays. */
6278
6535
  function matmul(x, y) {
6279
6536
  if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
@@ -6287,6 +6544,16 @@ function matmul(x, y) {
6287
6544
  rhsBatchDims: require_backend.range(-2 - numBatchDims, -2)
6288
6545
  });
6289
6546
  }
6547
+ /** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
6548
+ function matvec(x1, x2) {
6549
+ if (ndim(x1) < 2 || ndim(x2) < 1) throw new Error("matvec: x1 must be at least 2D and x2 at least 1D");
6550
+ return einsum("...mn,...n->...m", x1, x2);
6551
+ }
6552
+ /** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
6553
+ function vecmat(x1, x2) {
6554
+ if (ndim(x1) < 1 || ndim(x2) < 2) throw new Error("vecmat: x1 must be at least 1D and x2 at least 2D");
6555
+ return einsum("...n,...nm->...m", x1, x2);
6556
+ }
6290
6557
  /** Dot product of two arrays. */
6291
6558
  function dot$1(x, y) {
6292
6559
  if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
@@ -6445,6 +6712,49 @@ function outer(x, y) {
6445
6712
  y = ravel(y);
6446
6713
  return multiply(x.reshape([x.shape[0], 1]), y);
6447
6714
  }
6715
+ /**
6716
+ * @function Compute the cross product of two arrays.
6717
+ *
6718
+ * Supports 2D (scalar result) and 3D cross products, with optional axis
6719
+ * arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
6720
+ */
6721
+ const cross = jit$1(function cross$2(a, b, { axisa = -1, axisb = -1, axisc = -1, axis } = {}) {
6722
+ if (axis !== void 0) {
6723
+ axisa = axis;
6724
+ axisb = axis;
6725
+ axisc = axis;
6726
+ }
6727
+ axisa = require_backend.checkAxis(axisa, ndim(a));
6728
+ axisb = require_backend.checkAxis(axisb, ndim(b));
6729
+ a = moveaxis$1(a, axisa, -1);
6730
+ b = moveaxis$1(b, axisb, -1);
6731
+ const da = a.shape.at(-1);
6732
+ const db = b.shape.at(-1);
6733
+ if (da !== 2 && da !== 3 || db !== 2 && db !== 3) throw new Error(`cross: incompatible dimensions for cross product (got ${da} and ${db})`);
6734
+ if (da === 2 && db === 2) {
6735
+ const [a0$1, a1$1] = split$1(a, 2, -1);
6736
+ const [b0$1, b1$1] = split$1(b, 2, -1);
6737
+ return squeeze(a0$1.mul(b1$1).sub(a1$1.mul(b0$1)), -1);
6738
+ }
6739
+ if (da === 2) {
6740
+ const zeroShape = [...a.shape.slice(0, -1), 1];
6741
+ a = concatenate([a, zeros(zeroShape)], -1);
6742
+ }
6743
+ if (db === 2) {
6744
+ const zeroShape = [...b.shape.slice(0, -1), 1];
6745
+ b = concatenate([b, zeros(zeroShape)], -1);
6746
+ }
6747
+ const [a0, a1, a2] = split$1(a, 3, -1);
6748
+ const [b0, b1, b2] = split$1(b, 3, -1);
6749
+ const c0 = a1.ref.mul(b2.ref).sub(a2.ref.mul(b1.ref));
6750
+ const c1 = a2.mul(b0.ref).sub(a0.ref.mul(b2));
6751
+ const c2 = a0.mul(b1).sub(a1.mul(b0));
6752
+ return moveaxis$1(concatenate([
6753
+ c0,
6754
+ c1,
6755
+ c2
6756
+ ], -1), -1, axisc);
6757
+ }, { staticArgnums: [2] });
6448
6758
  /** Vector dot product of two arrays along a given axis. */
6449
6759
  function vecdot(x, y, { axis } = {}) {
6450
6760
  const xaxis = require_backend.checkAxis(axis ?? -1, ndim(x));
@@ -6541,16 +6851,15 @@ function sign(x) {
6541
6851
  x = fudgeArray(x);
6542
6852
  return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
6543
6853
  }
6544
- /** @function Return element-wise positive values of the input (no-op). */
6545
- const positive = fudgeArray;
6546
6854
  /**
6547
- * Return the Hamming window of size M, a taper with a weighted cosine bell.
6548
- *
6549
- * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
6855
+ * @function
6856
+ * Return the value with the magnitude of x and the sign of y, element-wise.
6550
6857
  */
6551
- function hamming(M) {
6552
- return cos(linspace(0, 2 * Math.PI, M)).mul(-.46).add(.54);
6553
- }
6858
+ const copysign = jit$1(function copysign$1(x, y) {
6859
+ return absolute(x).mul(sign(y));
6860
+ });
6861
+ /** @function Return element-wise positive values of the input (no-op). */
6862
+ const positive = fudgeArray;
6554
6863
  /**
6555
6864
  * Return the Hann window of size M, a taper with a weighted cosine bell.
6556
6865
  *
@@ -6696,6 +7005,27 @@ function trunc(x) {
6696
7005
  return idiv(x, 1);
6697
7006
  }
6698
7007
  /**
7008
+ * @function
7009
+ * Round to the given number of decimals.
7010
+ *
7011
+ * Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
7012
+ */
7013
+ const round = jit$1(function round$1(a, decimals = 0) {
7014
+ if (decimals === 0) return rint(a);
7015
+ const factor = 10 ** decimals;
7016
+ return rint(a.mul(factor)).mul(1 / factor);
7017
+ }, { staticArgnums: [1] });
7018
+ /**
7019
+ * @function
7020
+ * Round to the nearest integer, with ties going to the nearest even integer.
7021
+ */
7022
+ const rint = jit$1(function rint$1(x) {
7023
+ const rounded = floor(x.ref.add(.5));
7024
+ const half = x.ref.sub(floor(x)).equal(.5);
7025
+ const odd = remainder(rounded.ref, 2).notEqual(0);
7026
+ return where(half.mul(odd), rounded.ref.sub(1), rounded);
7027
+ });
7028
+ /**
6699
7029
  * Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
6700
7030
  *
6701
7031
  * This is the inverse of `frexp()`.
@@ -7023,6 +7353,7 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
7023
7353
  //#region src/library/lax.ts
7024
7354
  var lax_exports = {};
7025
7355
  __export(lax_exports, {
7356
+ bitcastConvertType: () => bitcastConvertType,
7026
7357
  conv: () => conv,
7027
7358
  convGeneralDilated: () => convGeneralDilated,
7028
7359
  convTranspose: () => convTranspose,
@@ -7036,6 +7367,10 @@ __export(lax_exports, {
7036
7367
  topK: () => topK
7037
7368
  });
7038
7369
  const JsArray = globalThis.Array;
7370
+ /** Elementwise bitcast an array into a new dtype. */
7371
+ function bitcastConvertType(x, newDtype) {
7372
+ return fudgeArray(x).view(newDtype);
7373
+ }
7039
7374
  /**
7040
7375
  * General dot product/contraction operator.
7041
7376
  *
@@ -7767,7 +8102,9 @@ function getK01(key$1) {
7767
8102
  function key(seed) {
7768
8103
  seed = array(seed, { dtype: require_backend.DType.Uint32 });
7769
8104
  if (seed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seed.shape} - use jax.vmap for batching.`);
7770
- return stack([0, seed]);
8105
+ const key$1 = stack([0, seed]);
8106
+ if (key$1 instanceof Array$1) key$1._realizeSource();
8107
+ return key$1;
7771
8108
  }
7772
8109
  /** Splits a PRNG key into `num` new keys by adding a leading axis. */
7773
8110
  function split(key$1, num = 2) {
@@ -7962,6 +8299,11 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
7962
8299
 
7963
8300
  //#endregion
7964
8301
  //#region src/index.ts
8302
+ /** @namespace */
8303
+ const profiler = {
8304
+ startTrace: require_backend.startTrace,
8305
+ stopTrace: require_backend.stopTrace
8306
+ };
7965
8307
  /**
7966
8308
  * @function
7967
8309
  * Compute the forward-mode Jacobian-vector product for a function.
@@ -8130,6 +8472,7 @@ exports.blockUntilReady = blockUntilReady;
8130
8472
  exports.defaultDevice = require_backend.defaultDevice;
8131
8473
  exports.devicePut = devicePut;
8132
8474
  exports.devices = require_backend.devices;
8475
+ exports.getWebGPUDevice = require_backend.getWebGPUDevice;
8133
8476
  exports.grad = grad;
8134
8477
  exports.hessian = hessian;
8135
8478
  exports.init = require_backend.init;
@@ -8158,6 +8501,7 @@ Object.defineProperty(exports, 'numpy', {
8158
8501
  return numpy_exports;
8159
8502
  }
8160
8503
  });
8504
+ exports.profiler = profiler;
8161
8505
  Object.defineProperty(exports, 'random', {
8162
8506
  enumerable: true,
8163
8507
  get: function () {