@jax-js/jax 0.1.9 → 0.1.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-BId79r5b.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DZvR7mZV.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -333,6 +333,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
333
333
  Primitive$1["Mod"] = "mod";
334
334
  Primitive$1["Min"] = "min";
335
335
  Primitive$1["Max"] = "max";
336
+ Primitive$1["BitCombine"] = "bit_combine";
337
+ Primitive$1["BitShift"] = "bit_shift";
336
338
  Primitive$1["Neg"] = "neg";
337
339
  Primitive$1["Reciprocal"] = "reciprocal";
338
340
  Primitive$1["Floor"] = "floor";
@@ -406,6 +408,12 @@ function min$1(x, y) {
406
408
  function max$1(x, y) {
407
409
  return bind1(Primitive.Max, [x, y]);
408
410
  }
411
/**
 * Bind the `BitCombine` primitive on `x` and `y`.
 *
 * `op` selects the combination; the bitwise wrappers in this module pass
 * "and", "or" or "xor".
 */
function bitCombine(x, y, op) {
  const params = { op };
  return bind1(Primitive.BitCombine, [x, y], params);
}
/**
 * Bind the `BitShift` primitive, shifting `x` by `y`.
 *
 * `op` selects the direction; the shift wrappers in this module pass "shl" or
 * "shr".
 */
function bitShift(x, y, op) {
  const params = { op };
  return bind1(Primitive.BitShift, [x, y], params);
}
409
417
  function neg(x) {
410
418
  return bind1(Primitive.Neg, [x]);
411
419
  }
@@ -807,6 +815,11 @@ var Tracer = class Tracer {
807
815
  if (this.dtype === dtype) return this;
808
816
  return cast(this, dtype);
809
817
  }
818
+ /** Return a bitwise cast of the array, viewed as a new dtype. */
819
+ view(dtype) {
820
+ if (!dtype || dtype === this.dtype) return this;
821
+ return bitcast(this, dtype);
822
+ }
810
823
  /** Subtract an array from this one. */
811
824
  sub(other) {
812
825
  return this.add(neg(other));
@@ -1615,6 +1628,16 @@ const abstractEvalRules = {
1615
1628
  [Primitive.Mod]: binopAbstractEval,
1616
1629
  [Primitive.Min]: binopAbstractEval,
1617
1630
  [Primitive.Max]: binopAbstractEval,
1631
+ [Primitive.BitCombine]([x, y]) {
1632
+ const aval = promoteAvals(x, y);
1633
+ if (isFloatDtype(aval.dtype)) throw new TypeError(`bitwise operations require integer or boolean inputs, got ${aval.dtype}`);
1634
+ return [aval];
1635
+ },
1636
+ [Primitive.BitShift]([x, y]) {
1637
+ const shape$1 = generalBroadcast(x.shape, y.shape);
1638
+ if (isFloatDtype(x.dtype) || isFloatDtype(y.dtype) || x.dtype === DType.Bool || y.dtype === DType.Bool) throw new TypeError(`bit shift operations require integer inputs, got ${x} and ${y}`);
1639
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
1640
+ },
1618
1641
  [Primitive.Neg]: vectorizedUnopAbstractEval,
1619
1642
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
1620
1643
  [Primitive.Floor]: vectorizedUnopAbstractEval,
@@ -1624,7 +1647,7 @@ const abstractEvalRules = {
1624
1647
  return [new ShapedArray(x.shape, dtype, false)];
1625
1648
  },
1626
1649
  [Primitive.Bitcast]([x], { dtype }) {
1627
- if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
1650
+ if (x.dtype !== dtype && (x.dtype === DType.Bool || dtype === DType.Bool)) throw new TypeError("Bitcast to/from bool is not allowed");
1628
1651
  if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
1629
1652
  return [new ShapedArray(x.shape, dtype, false)];
1630
1653
  },
@@ -2150,6 +2173,8 @@ const jitRules = {
2150
2173
  [Primitive.Mod]: broadcastedJit(([a, b]) => AluExp.mod(a, b)),
2151
2174
  [Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
2152
2175
  [Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
2176
+ [Primitive.BitCombine]: broadcastedJit(([a, b], { op }) => AluExp.bitCombine(a, b, op)),
2177
+ [Primitive.BitShift]: broadcastedJit(([a, b], { op }) => AluExp.bitShift(a, b, op)),
2153
2178
  [Primitive.Neg]: unopJit((a) => AluExp.sub(AluExp.const(a.dtype, 0), a)),
2154
2179
  [Primitive.Reciprocal]: unopJit(AluExp.reciprocal),
2155
2180
  [Primitive.Floor]: unopJit(AluExp.floor),
@@ -2342,7 +2367,9 @@ function splitGraphDataflow(backend, jaxpr) {
2342
2367
  case Primitive.Idiv:
2343
2368
  case Primitive.Mod:
2344
2369
  case Primitive.Min:
2345
- case Primitive.Max: {
2370
+ case Primitive.Max:
2371
+ case Primitive.BitCombine:
2372
+ case Primitive.BitShift: {
2346
2373
  const otherInput = nextEqn.inputs.find((v) => v !== outVar);
2347
2374
  if (otherInput instanceof Lit || deepEqual(generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
2348
2375
  head = usages[0];
@@ -2981,6 +3008,42 @@ var Array$1 = class Array$1 extends Tracer {
2981
3008
  return dtypedArray(this.dtype, buf);
2982
3009
  }
2983
3010
  /**
3011
+ * Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
3012
+ *
3013
+ * Only available on the WebGPU backend. The array's memory is still managed
3014
+ * by jax-js, and it will be freed when the buffer is no longer in use. You
3015
+ * _should not_ mutate the buffer's contents.
3016
+ *
3017
+ * Note that the GPU buffer may be slightly larger than the array's size; it
3018
+ * will always be aligned to 4 bytes.
3019
+ */
3020
+ async gpuBuffer() {
3021
+ if (this.device !== "webgpu") throw new Error(`gpuBuffer() is only available on WebGPU backend`);
3022
+ this.#realize();
3023
+ const pending = this.#pending;
3024
+ if (pending) {
3025
+ await Promise.all(pending.map((p) => p.prepare()));
3026
+ for (const p of pending) p.submit();
3027
+ }
3028
+ const backend = this.#backend;
3029
+ const { buffer } = backend.buffers.get(this.#source);
3030
+ this.dispose();
3031
+ return buffer;
3032
+ }
3033
+ /** Synchronous version of `Array.gpuBuffer()`. */
3034
+ gpuBufferSync() {
3035
+ if (this.device !== "webgpu") throw new Error(`gpuBufferSync() is only available on WebGPU backend`);
3036
+ this.#realize();
3037
+ for (const p of this.#pending) {
3038
+ p.prepareSync();
3039
+ p.submit();
3040
+ }
3041
+ const backend = this.#backend;
3042
+ const { buffer } = backend.buffers.get(this.#source);
3043
+ this.dispose();
3044
+ return buffer;
3045
+ }
3046
+ /**
2984
3047
  * Convert this array into a JavaScript object.
2985
3048
  *
2986
3049
  * This is a blocking operation that will compile all of the shaders and wait
@@ -3027,6 +3090,14 @@ var Array$1 = class Array$1 extends Tracer {
3027
3090
  [Primitive.Max]([x, y]) {
3028
3091
  return [x.#binary(AluOp.Max, y)];
3029
3092
  },
3093
+ [Primitive.BitCombine]([x, y], { op }) {
3094
+ const custom = (src) => AluExp.bitCombine(src[0], src[1], op);
3095
+ return [Array$1.#naryCustom("bit_combine", custom, [x, y])];
3096
+ },
3097
+ [Primitive.BitShift]([x, y], { op }) {
3098
+ const custom = (src) => AluExp.bitShift(src[0], src[1], op);
3099
+ return [Array$1.#naryCustom("bit_shift", custom, [x, y], { dtypeOverride: [void 0, y.dtype] })];
3100
+ },
3030
3101
  [Primitive.Neg]([x]) {
3031
3102
  return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
3032
3103
  },
@@ -3046,8 +3117,8 @@ var Array$1 = class Array$1 extends Tracer {
3046
3117
  return [x.#unary(AluOp.Cast, dtype)];
3047
3118
  },
3048
3119
  [Primitive.Bitcast]([x], { dtype }) {
3049
- if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
3050
3120
  if (x.dtype === dtype) return [x];
3121
+ if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
3051
3122
  if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
3052
3123
  if (x.#source instanceof AluExp) return [x.#unary(AluOp.Bitcast, dtype)];
3053
3124
  else {
@@ -3718,6 +3789,8 @@ const vmapRules = {
3718
3789
  [Primitive.Mod]: broadcastBatcher(Primitive.Mod),
3719
3790
  [Primitive.Min]: broadcastBatcher(Primitive.Min),
3720
3791
  [Primitive.Max]: broadcastBatcher(Primitive.Max),
3792
+ [Primitive.BitCombine]: broadcastBatcher(Primitive.BitCombine),
3793
+ [Primitive.BitShift]: broadcastBatcher(Primitive.BitShift),
3721
3794
  [Primitive.Neg]: unopBatcher(Primitive.Neg),
3722
3795
  [Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
3723
3796
  [Primitive.Floor]: unopBatcher(Primitive.Floor),
@@ -4040,6 +4113,8 @@ const jvpRules = {
4040
4113
  [Primitive.Max]([x, y], [dx, dy]) {
4041
4114
  return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
4042
4115
  },
4116
+ [Primitive.BitCombine]: zeroTangentsJvp(Primitive.BitCombine),
4117
+ [Primitive.BitShift]: zeroTangentsJvp(Primitive.BitShift),
4043
4118
  [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
4044
4119
  [Primitive.Reciprocal]([x], [dx]) {
4045
4120
  const xRecip = reciprocal$1(x.ref);
@@ -4142,6 +4217,7 @@ const jvpRules = {
4142
4217
  },
4143
4218
  [Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
4144
4219
  const x = triangularSolve$1(a.ref, b, { unitDiagonal });
4220
+ da = unitDiagonal ? triu(da, 1) : triu(da);
4145
4221
  const dax = batchMatmulT(da, x.ref);
4146
4222
  const rhsT = db.sub(mT(dax));
4147
4223
  const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
@@ -5217,6 +5293,7 @@ function ifft(a, axis = -1) {
5217
5293
  var numpy_linalg_exports = {};
5218
5294
  __export(numpy_linalg_exports, {
5219
5295
  cholesky: () => cholesky,
5296
+ cross: () => cross$1,
5220
5297
  det: () => det,
5221
5298
  diagonal: () => diagonal,
5222
5299
  inv: () => inv,
@@ -5229,7 +5306,8 @@ __export(numpy_linalg_exports, {
5229
5306
  solve: () => solve,
5230
5307
  tensordot: () => tensordot,
5231
5308
  trace: () => trace,
5232
- vecdot: () => vecdot
5309
+ vecdot: () => vecdot,
5310
+ vectorNorm: () => vectorNorm
5233
5311
  });
5234
5312
  function checkSquare(name, a) {
5235
5313
  if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
@@ -5247,6 +5325,19 @@ function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
5247
5325
  if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
5248
5326
  return cholesky$1(a, { upper });
5249
5327
  }
5328
/**
 * Compute the cross-product of two 3D vectors.
 *
 * This is a simpler and less flexible version of `jax.numpy.cross()`.
 * Both inputs must have size 3 along the specified axis, which is used for
 * both operands and for the output.
 *
 * @throws {Error} If either input does not have size 3 along `axis`.
 */
function cross$1(x1, x2, axis = -1) {
  // Resolve the axis for both inputs before checking any sizes.
  const [ax1, ax2] = [x1, x2].map((v) => checkAxis(axis, ndim(v)));
  const n1 = shape(x1)[ax1];
  if (n1 !== 3) throw new Error(`linalg.cross: x1 must have size 3 along axis ${axis}, got ${n1}`);
  const n2 = shape(x2)[ax2];
  if (n2 !== 3) throw new Error(`linalg.cross: x2 must have size 3 along axis ${axis}, got ${n2}`);
  return cross(x1, x2, { axis });
}
5250
5341
  /** Compute the determinant of a square matrix (batched). */
5251
5342
  function det(a) {
5252
5343
  a = fudgeArray(a);
@@ -5262,7 +5353,7 @@ function det(a) {
5262
5353
  function inv(a) {
5263
5354
  a = fudgeArray(a);
5264
5355
  const n = checkSquare("inv", a);
5265
- return solve(a, eye(n));
5356
+ return solve(a, eye(n, void 0, { dtype: a.dtype }));
5266
5357
  }
5267
5358
  /**
5268
5359
  * Return the least-squares solution to a linear equation.
@@ -5319,8 +5410,9 @@ function matrixPower(a, n) {
5319
5410
  a = fudgeArray(a);
5320
5411
  const m = checkSquare("matrixPower", a);
5321
5412
  if (n === 0) {
5413
+ const dtype = a.dtype;
5322
5414
  a.dispose();
5323
- return broadcastTo(eye(m), a.shape);
5415
+ return broadcastTo(eye(m, void 0, { dtype }), a.shape);
5324
5416
  }
5325
5417
  if (n < 0) {
5326
5418
  a = inv(a);
@@ -5394,6 +5486,23 @@ function solve(a, b) {
5394
5486
  if (bIs1d) x = squeeze(x, -1);
5395
5487
  return x;
5396
5488
  }
5489
/**
 * Compute the vector norm of an array.
 *
 * @param x - Input array.
 * @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
 * @param axis - Axis/axes to reduce over (default: all axes).
 * @param keepdims - Whether to keep reduced dimensions as size 1.
 * @returns The norm of `x`, reduced over the given axes.
 *
 * - `ord = Infinity`: max of absolute values.
 * - `ord = -Infinity`: min of absolute values.
 * - `ord = 0`: count of non-zero entries, in the dtype of `x`.
 * - otherwise: `sum(|x| ** ord) ** (1 / ord)`.
 */
function vectorNorm(x, { ord = 2, axis = null, keepdims = false } = {}) {
  x = fudgeArray(x);
  // Capture the dtype before `x` is consumed by its first operation; the
  // ord=0 branch previously read `x.dtype` after `x.notEqual(0)` took it.
  const dtype = x.dtype;
  if (ord === Infinity) return max(absolute(x), axis, { keepdims });
  if (ord === -Infinity) return min(absolute(x), axis, { keepdims });
  if (ord === 0) return x.notEqual(0).astype(dtype).sum(axis, { keepdims });
  return power(power(absolute(x), ord).sum(axis, { keepdims }), 1 / ord);
}
5397
5506
 
5398
5507
  //#endregion
5399
5508
  //#region src/library/numpy/dtype-info.ts
@@ -5502,13 +5611,24 @@ __export(numpy_exports, {
5502
5611
  argmax: () => argmax,
5503
5612
  argmin: () => argmin,
5504
5613
  argsort: () => argsort,
5614
+ around: () => round,
5505
5615
  array: () => array,
5616
+ arrayEqual: () => arrayEqual,
5617
+ arrayEquiv: () => arrayEquiv,
5506
5618
  asin: () => asin,
5507
5619
  asinh: () => arcsinh,
5508
5620
  astype: () => astype,
5509
5621
  atan: () => atan,
5510
5622
  atan2: () => atan2,
5511
5623
  atanh: () => arctanh,
5624
+ average: () => average,
5625
+ bitwiseAnd: () => bitwiseAnd,
5626
+ bitwiseInvert: () => invert,
5627
+ bitwiseLeftShift: () => leftShift,
5628
+ bitwiseNot: () => invert,
5629
+ bitwiseOr: () => bitwiseOr,
5630
+ bitwiseRightShift: () => rightShift,
5631
+ bitwiseXor: () => bitwiseXor,
5512
5632
  bool: () => bool,
5513
5633
  broadcastArrays: () => broadcastArrays,
5514
5634
  broadcastShapes: () => broadcastShapes,
@@ -5519,11 +5639,13 @@ __export(numpy_exports, {
5519
5639
  columnStack: () => columnStack,
5520
5640
  concatenate: () => concatenate,
5521
5641
  convolve: () => convolve,
5642
+ copysign: () => copysign,
5522
5643
  corrcoef: () => corrcoef,
5523
5644
  correlate: () => correlate,
5524
5645
  cos: () => cos,
5525
5646
  cosh: () => cosh,
5526
5647
  cov: () => cov,
5648
+ cross: () => cross,
5527
5649
  cumsum: () => cumsum,
5528
5650
  cumulativeSum: () => cumsum,
5529
5651
  deg2rad: () => deg2rad,
@@ -5559,7 +5681,6 @@ __export(numpy_exports, {
5559
5681
  fullLike: () => fullLike$1,
5560
5682
  greater: () => greater,
5561
5683
  greaterEqual: () => greaterEqual,
5562
- hamming: () => hamming,
5563
5684
  hann: () => hann,
5564
5685
  heaviside: () => heaviside,
5565
5686
  hstack: () => hstack,
@@ -5569,12 +5690,14 @@ __export(numpy_exports, {
5569
5690
  inf: () => inf,
5570
5691
  inner: () => inner,
5571
5692
  int32: () => int32,
5693
+ invert: () => invert,
5572
5694
  isfinite: () => isfinite,
5573
5695
  isinf: () => isinf,
5574
5696
  isnan: () => isnan,
5575
5697
  isneginf: () => isneginf,
5576
5698
  isposinf: () => isposinf,
5577
5699
  ldexp: () => ldexp,
5700
+ leftShift: () => leftShift,
5578
5701
  less: () => less,
5579
5702
  lessEqual: () => lessEqual,
5580
5703
  linalg: () => numpy_linalg_exports,
@@ -5583,9 +5706,14 @@ __export(numpy_exports, {
5583
5706
  log10: () => log10,
5584
5707
  log1p: () => log1p,
5585
5708
  log2: () => log2,
5709
+ logicalAnd: () => logicalAnd,
5710
+ logicalNot: () => logicalNot,
5711
+ logicalOr: () => logicalOr,
5712
+ logicalXor: () => logicalXor,
5586
5713
  logspace: () => logspace,
5587
5714
  matmul: () => matmul,
5588
5715
  matrixTranspose: () => matrixTranspose,
5716
+ matvec: () => matvec,
5589
5717
  max: () => max,
5590
5718
  maximum: () => maximum,
5591
5719
  mean: () => mean,
@@ -5618,6 +5746,9 @@ __export(numpy_exports, {
5618
5746
  remainder: () => remainder,
5619
5747
  repeat: () => repeat,
5620
5748
  reshape: () => reshape,
5749
+ rightShift: () => rightShift,
5750
+ rint: () => rint,
5751
+ round: () => round,
5621
5752
  shape: () => shape,
5622
5753
  sign: () => sign,
5623
5754
  sin: () => sin,
@@ -5650,6 +5781,7 @@ __export(numpy_exports, {
5650
5781
  var_: () => var_,
5651
5782
  vdot: () => vdot,
5652
5783
  vecdot: () => vecdot,
5784
+ vecmat: () => vecmat,
5653
5785
  vstack: () => vstack,
5654
5786
  where: () => where,
5655
5787
  zeros: () => zeros,
@@ -5713,6 +5845,60 @@ const notEqual = notEqual$1;
5713
5845
  const greaterEqual = greaterEqual$1;
5714
5846
  /** @function Compare two arrays element-wise. */
5715
5847
  const lessEqual = lessEqual$1;
5848
/** Compute element-wise logical AND; both inputs are first cast to bool. */
function logicalAnd(x, y) {
  const lhs = astype(x, DType.Bool);
  const rhs = astype(y, DType.Bool);
  return lhs.mul(rhs);
}
/** Compute element-wise logical OR; both inputs are first cast to bool. */
function logicalOr(x, y) {
  const lhs = astype(x, DType.Bool);
  const rhs = astype(y, DType.Bool);
  return lhs.add(rhs);
}
/** Compute element-wise logical XOR: true where the bool casts of the inputs differ. */
function logicalXor(x, y) {
  const lhs = astype(x, DType.Bool);
  const rhs = astype(y, DType.Bool);
  return notEqual(lhs, rhs);
}
/** Compute element-wise logical NOT: true where the bool cast of `x` is false. */
function logicalNot(x) {
  const asBool = astype(x, DType.Bool);
  return notEqual(asBool, true);
}
5864
/** Compute element-wise bitwise AND (integer or boolean inputs). */
function bitwiseAnd(x, y) {
  const op = "and";
  return bitCombine(x, y, op);
}
/** Compute element-wise bitwise OR (integer or boolean inputs). */
function bitwiseOr(x, y) {
  const op = "or";
  return bitCombine(x, y, op);
}
/** Compute element-wise bitwise XOR (integer or boolean inputs). */
function bitwiseXor(x, y) {
  const op = "xor";
  return bitCombine(x, y, op);
}
5876
/**
 * Compute element-wise bitwise NOT (inversion).
 *
 * Implemented as XOR with an all-ones value of the input's dtype. Only bool,
 * uint32 and int32 inputs are supported.
 *
 * @throws {TypeError} For any other dtype; the input array is released first.
 */
function invert(x) {
  const arr = fudgeArray(x);
  let allOnes;
  switch (arr.dtype) {
    case DType.Bool:
      allOnes = true;
      break;
    case DType.Uint32:
      allOnes = 4294967295; // 2 ** 32 - 1
      break;
    case DType.Int32:
      allOnes = -1; // all bits set in two's complement
      break;
    default: {
      const dtype = arr.dtype;
      // This function owns `arr` (like `average`), so release it before
      // bailing out instead of leaking it.
      arr.dispose();
      throw new TypeError(`invert: unsupported dtype ${dtype}`);
    }
  }
  return bitCombine(arr, allOnes, "xor");
}
5894
/** Compute element-wise left bit shift of `x` by `y` (integer inputs). */
function leftShift(x, y) {
  const direction = "shl";
  return bitShift(x, y, direction);
}
/** Compute element-wise right bit shift of `x` by `y` (integer inputs). */
function rightShift(x, y) {
  const direction = "shr";
  return bitShift(x, y, direction);
}
5716
5902
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
5717
5903
  const where = where$1;
5718
5904
  /**
@@ -5820,6 +6006,34 @@ function mean(a, axis = null, opts) {
5820
6006
  return fudgeArray(a).mean(axis, opts);
5821
6007
  }
5822
6008
  /**
6009
+ * Compute the weighted average along the specified axis.
6010
+ *
6011
+ * If no axis is specified, mean is computed along all the axes. The weights
6012
+ * should have shape matching that of `a`, or if an axis is specified, it should
6013
+ * match the shape along those axes.
6014
+ */
6015
+ function average(a, axis = null, opts) {
6016
+ a = fudgeArray(a);
6017
+ if (opts?.weights == null) return mean(a, axis, opts);
6018
+ const weights = fudgeArray(opts.weights);
6019
+ axis = normalizeAxis(axis, ndim(a));
6020
+ const wShape = weights.shape;
6021
+ const aShape = a.shape;
6022
+ if (deepEqual(wShape, aShape)) {
6023
+ const scl = sum(weights.ref, axis, opts);
6024
+ return sum(multiply(a, weights), axis, opts).div(scl);
6025
+ } else if (axis.length === 1 && wShape.length === 1 && wShape[0] === aShape[axis[0]]) {
6026
+ const broadcastShape = aShape.map((_, i) => i === axis[0] ? wShape[0] : 1);
6027
+ const wReshaped = reshape(weights, broadcastShape);
6028
+ const scl = sum(wReshaped.ref, axis, opts);
6029
+ return sum(multiply(a, wReshaped), axis, opts).div(scl);
6030
+ } else {
6031
+ weights.dispose();
6032
+ a.dispose();
6033
+ throw new Error(`average: weights shape ${JSON.stringify(wShape)} is not compatible with array shape ${JSON.stringify(aShape)} and axis ${JSON.stringify(axis)}`);
6034
+ }
6035
+ }
6036
+ /**
5823
6037
  * Returns the indices of the minimum values along an axis.
5824
6038
  *
5825
6039
  * By default, index is into the flatted array, otherwise it is along the
@@ -6223,20 +6437,63 @@ function take(a, indices, axis = null) {
6223
6437
  axis = checkAxis(axis, ndim(a));
6224
6438
  return gather(a, [indices], [axis], axis);
6225
6439
  }
6226
- /** Return if two arrays are element-wise equal within a tolerance. */
6440
/**
 * Return if two arrays are element-wise equal within a tolerance.
 *
 * The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
 * NaN values comparing equal if `equalNaN` is true. Arrays of different shapes
 * are never close. Returns a JS boolean (data is read back synchronously).
 */
function allclose(actual, expected, options) {
  const { rtol = 1e-5, atol = 1e-7, equalNaN = false } = options ?? {};
  const x = array(actual);
  const y = array(expected);
  if (!deepEqual(x.shape, y.shape)) return false;
  const xs = x.dataSync();
  const ys = y.dataSync();
  for (let i = 0; i < xs.length; i++) {
    const u = xs[i];
    const v = ys[i];
    // With equalNaN, NaN must line up with NaN; otherwise any NaN fails.
    const nanMismatch = equalNaN ? isNaN(u) !== isNaN(v) : isNaN(u) || isNaN(v);
    if (nanMismatch) return false;
    const tolerance = atol + rtol * Math.abs(v);
    if (Math.abs(u - v) > tolerance) return false;
  }
  return true;
}
6459
/**
 * Check if two arrays are element-wise equal.
 *
 * Returns False if the arrays have different shapes. If `equalNaN` is True,
 * NaNs in the same position are considered equal. The result is an array
 * (from `.all()` or `array(false)`), not a JS boolean.
 */
function arrayEqual(a1, a2, opts) {
  a1 = fudgeArray(a1);
  a2 = fudgeArray(a2);
  const shapesMatch = deepEqual(a1.shape, a2.shape);
  if (!shapesMatch) {
    a1.dispose();
    a2.dispose();
    return array(false);
  }
  let elementwise;
  if (!opts?.equalNaN) {
    elementwise = equal(a1, a2);
  } else {
    // Positions where both inputs are NaN count as equal.
    const bothNaN = isnan(a1.ref).mul(isnan(a2.ref));
    elementwise = where(bothNaN, true, equal(a1, a2));
  }
  return elementwise.all();
}
6479
/**
 * Check if two arrays are element-wise equal after broadcasting.
 *
 * Unlike `arrayEqual`, this allows inputs with different but
 * broadcast-compatible shapes. Any failure while broadcasting or comparing
 * (e.g. incompatible shapes) is reported as `array(false)` after releasing
 * both inputs.
 */
function arrayEquiv(a1, a2) {
  a1 = fudgeArray(a1);
  a2 = fudgeArray(a2);
  let result;
  try {
    const [lhs, rhs] = broadcastArrays(a1, a2);
    result = equal(lhs, rhs).all();
  } catch {
    a1.dispose();
    a2.dispose();
    result = array(false);
  }
  return result;
}
6240
6497
  /** Matrix product of two arrays. */
6241
6498
  function matmul(x, y) {
6242
6499
  if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
@@ -6250,6 +6507,16 @@ function matmul(x, y) {
6250
6507
  rhsBatchDims: range(-2 - numBatchDims, -2)
6251
6508
  });
6252
6509
  }
6510
/** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
function matvec(x1, x2) {
  const shapesOk = ndim(x1) >= 2 && ndim(x2) >= 1;
  if (!shapesOk) throw new Error("matvec: x1 must be at least 2D and x2 at least 1D");
  const spec = "...mn,...n->...m";
  return einsum(spec, x1, x2);
}
/** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
function vecmat(x1, x2) {
  const shapesOk = ndim(x1) >= 1 && ndim(x2) >= 2;
  if (!shapesOk) throw new Error("vecmat: x1 must be at least 1D and x2 at least 2D");
  const spec = "...n,...nm->...m";
  return einsum(spec, x1, x2);
}
6253
6520
  /** Dot product of two arrays. */
6254
6521
  function dot$1(x, y) {
6255
6522
  if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
@@ -6408,6 +6675,49 @@ function outer(x, y) {
6408
6675
  y = ravel(y);
6409
6676
  return multiply(x.reshape([x.shape[0], 1]), y);
6410
6677
  }
6678
/**
 * @function Compute the cross product of two arrays.
 *
 * Supports 2D (scalar result) and 3D cross products, with optional axis
 * arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
 *
 * When only one operand is a 2-vector, it is treated as a 3-vector with a zero
 * third component. The zero padding uses that operand's own dtype, so integer
 * inputs are not promoted by the padding.
 *
 * @throws {Error} If either operand does not have size 2 or 3 along its axis.
 */
const cross = jit$1(function cross$2(a, b, { axisa = -1, axisb = -1, axisc = -1, axis } = {}) {
  if (axis !== void 0) {
    axisa = axis;
    axisb = axis;
    axisc = axis;
  }
  axisa = checkAxis(axisa, ndim(a));
  axisb = checkAxis(axisb, ndim(b));
  // Move each operand's vector axis to the end.
  a = moveaxis$1(a, axisa, -1);
  b = moveaxis$1(b, axisb, -1);
  const da = a.shape.at(-1);
  const db = b.shape.at(-1);
  if ((da !== 2 && da !== 3) || (db !== 2 && db !== 3)) throw new Error(`cross: incompatible dimensions for cross product (got ${da} and ${db})`);
  if (da === 2 && db === 2) {
    // 2D case: only the z-component a0*b1 - a1*b0 is returned (axisc unused).
    const [x0, x1] = split$1(a, 2, -1);
    const [y0, y1] = split$1(b, 2, -1);
    return squeeze(x0.mul(y1).sub(x1.mul(y0)), -1);
  }
  if (da === 2) {
    const zeroShape = [...a.shape.slice(0, -1), 1];
    a = concatenate([a, zeros(zeroShape, { dtype: a.dtype })], -1);
  }
  if (db === 2) {
    const zeroShape = [...b.shape.slice(0, -1), 1];
    b = concatenate([b, zeros(zeroShape, { dtype: b.dtype })], -1);
  }
  const [a0, a1, a2] = split$1(a, 3, -1);
  const [b0, b1, b2] = split$1(b, 3, -1);
  // Each component is consumed exactly once overall; earlier uses take `.ref`.
  const c0 = a1.ref.mul(b2.ref).sub(a2.ref.mul(b1.ref));
  const c1 = a2.mul(b0.ref).sub(a0.ref.mul(b2));
  const c2 = a0.mul(b1).sub(a1.mul(b0));
  return moveaxis$1(concatenate([
    c0,
    c1,
    c2
  ], -1), -1, axisc);
}, { staticArgnums: [2] });
6411
6721
  /** Vector dot product of two arrays along a given axis. */
6412
6722
  function vecdot(x, y, { axis } = {}) {
6413
6723
  const xaxis = checkAxis(axis ?? -1, ndim(x));
@@ -6504,16 +6814,15 @@ function sign(x) {
6504
6814
  x = fudgeArray(x);
6505
6815
  return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
6506
6816
  }
6507
- /** @function Return element-wise positive values of the input (no-op). */
6508
- const positive = fudgeArray;
6509
6817
  /**
6510
- * Return the Hamming window of size M, a taper with a weighted cosine bell.
6511
- *
6512
- * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
6818
+ * @function
6819
+ * Return the value with the magnitude of x and the sign of y, element-wise.
6513
6820
  */
6514
- function hamming(M) {
6515
- return cos(linspace(0, 2 * Math.PI, M)).mul(-.46).add(.54);
6516
- }
6821
const copysign = jit$1(function copysign$1(x, y) {
  // Multiply |x| by +1/-1 chosen from y. Using sign(y) here was wrong: sign(0)
  // is 0, which zeroed the result (copysign(3, 0) must be 3). A zero y
  // (including -0, whose sign bit is not inspected) yields +|x|.
  return absolute(x).mul(where(less(y, 0), -1, 1));
});
6824
+ /** @function Return element-wise positive values of the input (no-op). */
6825
+ const positive = fudgeArray;
6517
6826
  /**
6518
6827
  * Return the Hann window of size M, a taper with a weighted cosine bell.
6519
6828
  *
@@ -6659,6 +6968,27 @@ function trunc(x) {
6659
6968
  return idiv(x, 1);
6660
6969
  }
6661
6970
  /**
6971
+ * @function
6972
+ * Round to the given number of decimals.
6973
+ *
6974
+ * Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
6975
+ */
6976
+ const round = jit$1(function round$1(a, decimals = 0) {
6977
+ if (decimals === 0) return rint(a);
6978
+ const factor = 10 ** decimals;
6979
+ return rint(a.mul(factor)).mul(1 / factor);
6980
+ }, { staticArgnums: [1] });
6981
/**
 * @function
 * Round to the nearest integer, with ties going to the nearest even integer.
 *
 * Computes floor(x + 0.5), then steps back by one wherever `x` was an exact
 * tie and that candidate came out odd.
 */
const rint = jit$1(function rint$1(x) {
  // Candidate: round half up.
  const candidate = floor(x.ref.add(0.5));
  // Exact ties have a fractional part of exactly one half.
  const isTie = x.ref.sub(floor(x)).equal(0.5);
  // A tie that rounded up to an odd integer moves back to the even neighbor.
  const isOdd = remainder(candidate.ref, 2).notEqual(0);
  const stepBack = isTie.mul(isOdd);
  return where(stepBack, candidate.ref.sub(1), candidate);
});
6991
+ /**
6662
6992
  * Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
6663
6993
  *
6664
6994
  * This is the inverse of `frexp()`.
@@ -6986,6 +7316,7 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
6986
7316
  //#region src/library/lax.ts
6987
7317
  var lax_exports = {};
6988
7318
  __export(lax_exports, {
7319
+ bitcastConvertType: () => bitcastConvertType,
6989
7320
  conv: () => conv,
6990
7321
  convGeneralDilated: () => convGeneralDilated,
6991
7322
  convTranspose: () => convTranspose,
@@ -6999,6 +7330,10 @@ __export(lax_exports, {
6999
7330
  topK: () => topK
7000
7331
  });
7001
7332
  const JsArray = globalThis.Array;
7333
/**
 * Elementwise bitcast an array into a new dtype.
 *
 * Equivalent to `Array.view(newDtype)` on the input converted with
 * `fudgeArray`.
 */
function bitcastConvertType(x, newDtype) {
  const arr = fudgeArray(x);
  return arr.view(newDtype);
}
7002
7337
  /**
7003
7338
  * General dot product/contraction operator.
7004
7339
  *
@@ -7730,7 +8065,9 @@ function getK01(key$1) {
7730
8065
  function key(seed) {
7731
8066
  seed = array(seed, { dtype: DType.Uint32 });
7732
8067
  if (seed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seed.shape} - use jax.vmap for batching.`);
7733
- return stack([0, seed]);
8068
+ const key$1 = stack([0, seed]);
8069
+ if (key$1 instanceof Array$1) key$1._realizeSource();
8070
+ return key$1;
7734
8071
  }
7735
8072
  /** Splits a PRNG key into `num` new keys by adding a leading axis. */
7736
8073
  function split(key$1, num = 2) {
@@ -7925,6 +8262,11 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
7925
8262
 
7926
8263
  //#endregion
7927
8264
  //#region src/index.ts
8265
/**
 * @namespace
 * Tracing controls: groups `startTrace` and `stopTrace` imported from the
 * backend module.
 */
const profiler = { startTrace, stopTrace };
7928
8270
  /**
7929
8271
  * @function
7930
8272
  * Compute the forward-mode Jacobian-vector product for a function.
@@ -8085,4 +8427,4 @@ async function devicePut(x, device) {
8085
8427
  }
8086
8428
 
8087
8429
  //#endregion
8088
- export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
8430
+ export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, getWebGPUDevice, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, profiler, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
@@ -1,4 +1,4 @@
1
- import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-BId79r5b.js";
1
+ import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-DZvR7mZV.js";
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -458,6 +458,12 @@ function generateExpression(exp, args, inputDtypes) {
458
458
  else source = `min(${a}, ${b})`;
459
459
  else if (op === AluOp.Max) if (dtype === DType.Bool) source = `(${a} || ${b})`;
460
460
  else source = `max(${a}, ${b})`;
461
+ else if (op === AluOp.BitCombine) {
462
+ let infix = arg === "and" ? "&" : arg === "or" ? "|" : "^";
463
+ if (dtype === DType.Bool) infix = infix + infix;
464
+ source = `(${a} ${infix} ${b})`;
465
+ } else if (op === AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
466
+ else source = `(${a} >> ${b})`;
461
467
  } else if (AluGroup.Compare.has(op)) {
462
468
  const a = gen(src[0]);
463
469
  const b = gen(src[1]);