@jax-js/jax 0.1.10 → 0.1.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-DMauYnfl.cjs');
33
+ const require_backend = require('./backend-DlYlOYqN.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
@@ -364,6 +364,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
364
364
  Primitive$1["Mod"] = "mod";
365
365
  Primitive$1["Min"] = "min";
366
366
  Primitive$1["Max"] = "max";
367
+ Primitive$1["BitCombine"] = "bit_combine";
368
+ Primitive$1["BitShift"] = "bit_shift";
367
369
  Primitive$1["Neg"] = "neg";
368
370
  Primitive$1["Reciprocal"] = "reciprocal";
369
371
  Primitive$1["Floor"] = "floor";
@@ -437,6 +439,12 @@ function min$1(x, y) {
437
439
  function max$1(x, y) {
438
440
  return bind1(Primitive.Max, [x, y]);
439
441
  }
442
+ function bitCombine(x, y, op) {
443
+ return bind1(Primitive.BitCombine, [x, y], { op });
444
+ }
445
+ function bitShift(x, y, op) {
446
+ return bind1(Primitive.BitShift, [x, y], { op });
447
+ }
440
448
  function neg(x) {
441
449
  return bind1(Primitive.Neg, [x]);
442
450
  }
@@ -1655,6 +1663,16 @@ const abstractEvalRules = {
1655
1663
  [Primitive.Mod]: binopAbstractEval,
1656
1664
  [Primitive.Min]: binopAbstractEval,
1657
1665
  [Primitive.Max]: binopAbstractEval,
1666
+ [Primitive.BitCombine]([x, y]) {
1667
+ const aval = promoteAvals(x, y);
1668
+ if (require_backend.isFloatDtype(aval.dtype)) throw new TypeError(`bitwise operations require integer or boolean inputs, got ${aval.dtype}`);
1669
+ return [aval];
1670
+ },
1671
+ [Primitive.BitShift]([x, y]) {
1672
+ const shape$1 = require_backend.generalBroadcast(x.shape, y.shape);
1673
+ if (require_backend.isFloatDtype(x.dtype) || require_backend.isFloatDtype(y.dtype) || x.dtype === require_backend.DType.Bool || y.dtype === require_backend.DType.Bool) throw new TypeError(`bit shift operations require integer inputs, got ${x} and ${y}`);
1674
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
1675
+ },
1658
1676
  [Primitive.Neg]: vectorizedUnopAbstractEval,
1659
1677
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
1660
1678
  [Primitive.Floor]: vectorizedUnopAbstractEval,
@@ -2190,6 +2208,8 @@ const jitRules = {
2190
2208
  [Primitive.Mod]: broadcastedJit(([a, b]) => require_backend.AluExp.mod(a, b)),
2191
2209
  [Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
2192
2210
  [Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
2211
+ [Primitive.BitCombine]: broadcastedJit(([a, b], { op }) => require_backend.AluExp.bitCombine(a, b, op)),
2212
+ [Primitive.BitShift]: broadcastedJit(([a, b], { op }) => require_backend.AluExp.bitShift(a, b, op)),
2193
2213
  [Primitive.Neg]: unopJit((a) => require_backend.AluExp.sub(require_backend.AluExp.const(a.dtype, 0), a)),
2194
2214
  [Primitive.Reciprocal]: unopJit(require_backend.AluExp.reciprocal),
2195
2215
  [Primitive.Floor]: unopJit(require_backend.AluExp.floor),
@@ -2382,7 +2402,9 @@ function splitGraphDataflow(backend, jaxpr) {
2382
2402
  case Primitive.Idiv:
2383
2403
  case Primitive.Mod:
2384
2404
  case Primitive.Min:
2385
- case Primitive.Max: {
2405
+ case Primitive.Max:
2406
+ case Primitive.BitCombine:
2407
+ case Primitive.BitShift: {
2386
2408
  const otherInput = nextEqn.inputs.find((v) => v !== outVar);
2387
2409
  if (otherInput instanceof Lit || require_backend.deepEqual(require_backend.generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
2388
2410
  head = usages[0];
@@ -3021,6 +3043,42 @@ var Array$1 = class Array$1 extends Tracer {
3021
3043
  return require_backend.dtypedArray(this.dtype, buf);
3022
3044
  }
3023
3045
  /**
3046
+ * Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
3047
+ *
3048
+ * Only available on the WebGPU backend. The array's memory is still managed
3049
+ * by jax-js, and it will be freed when the buffer is no longer in use. You
3050
+ * _should not_ mutate the buffer's contents.
3051
+ *
3052
+ * Note that the GPU buffer may be slightly larger than the array's size; it
3053
+ * will always be aligned to 4 bytes.
3054
+ */
3055
+ async gpuBuffer() {
3056
+ if (this.device !== "webgpu") throw new Error(`gpuBuffer() is only available on WebGPU backend`);
3057
+ this.#realize();
3058
+ const pending = this.#pending;
3059
+ if (pending) {
3060
+ await Promise.all(pending.map((p) => p.prepare()));
3061
+ for (const p of pending) p.submit();
3062
+ }
3063
+ const backend = this.#backend;
3064
+ const { buffer } = backend.buffers.get(this.#source);
3065
+ this.dispose();
3066
+ return buffer;
3067
+ }
3068
+ /** Synchronous version of `Array.gpuBuffer()`. */
3069
+ gpuBufferSync() {
3070
+ if (this.device !== "webgpu") throw new Error(`gpuBufferSync() is only available on WebGPU backend`);
3071
+ this.#realize();
3072
+ for (const p of this.#pending) {
3073
+ p.prepareSync();
3074
+ p.submit();
3075
+ }
3076
+ const backend = this.#backend;
3077
+ const { buffer } = backend.buffers.get(this.#source);
3078
+ this.dispose();
3079
+ return buffer;
3080
+ }
3081
+ /**
3024
3082
  * Convert this array into a JavaScript object.
3025
3083
  *
3026
3084
  * This is a blocking operation that will compile all of the shaders and wait
@@ -3067,6 +3125,14 @@ var Array$1 = class Array$1 extends Tracer {
3067
3125
  [Primitive.Max]([x, y]) {
3068
3126
  return [x.#binary(require_backend.AluOp.Max, y)];
3069
3127
  },
3128
+ [Primitive.BitCombine]([x, y], { op }) {
3129
+ const custom = (src) => require_backend.AluExp.bitCombine(src[0], src[1], op);
3130
+ return [Array$1.#naryCustom("bit_combine", custom, [x, y])];
3131
+ },
3132
+ [Primitive.BitShift]([x, y], { op }) {
3133
+ const custom = (src) => require_backend.AluExp.bitShift(src[0], src[1], op);
3134
+ return [Array$1.#naryCustom("bit_shift", custom, [x, y], { dtypeOverride: [void 0, y.dtype] })];
3135
+ },
3070
3136
  [Primitive.Neg]([x]) {
3071
3137
  return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
3072
3138
  },
@@ -3759,6 +3825,8 @@ const vmapRules = {
3759
3825
  [Primitive.Mod]: broadcastBatcher(Primitive.Mod),
3760
3826
  [Primitive.Min]: broadcastBatcher(Primitive.Min),
3761
3827
  [Primitive.Max]: broadcastBatcher(Primitive.Max),
3828
+ [Primitive.BitCombine]: broadcastBatcher(Primitive.BitCombine),
3829
+ [Primitive.BitShift]: broadcastBatcher(Primitive.BitShift),
3762
3830
  [Primitive.Neg]: unopBatcher(Primitive.Neg),
3763
3831
  [Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
3764
3832
  [Primitive.Floor]: unopBatcher(Primitive.Floor),
@@ -4082,6 +4150,8 @@ const jvpRules = {
4082
4150
  [Primitive.Max]([x, y], [dx, dy]) {
4083
4151
  return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
4084
4152
  },
4153
+ [Primitive.BitCombine]: zeroTangentsJvp(Primitive.BitCombine),
4154
+ [Primitive.BitShift]: zeroTangentsJvp(Primitive.BitShift),
4085
4155
  [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
4086
4156
  [Primitive.Reciprocal]([x], [dx]) {
4087
4157
  const xRecip = reciprocal$1(x.ref);
@@ -5273,7 +5343,8 @@ __export(numpy_linalg_exports, {
5273
5343
  solve: () => solve,
5274
5344
  tensordot: () => tensordot,
5275
5345
  trace: () => trace,
5276
- vecdot: () => vecdot
5346
+ vecdot: () => vecdot,
5347
+ vectorNorm: () => vectorNorm
5277
5348
  });
5278
5349
  function checkSquare(name, a) {
5279
5350
  if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
@@ -5452,6 +5523,23 @@ function solve(a, b) {
5452
5523
  if (bIs1d) x = squeeze(x, -1);
5453
5524
  return x;
5454
5525
  }
5526
+ /**
5527
+ * Compute the vector norm of an array.
5528
+ *
5529
+ * @param x - Input array.
5530
+ * @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
5531
+ * @param axis - Axis/axes to reduce over (default: all axes).
5532
+ * @param keepdims - Whether to keep reduced dimensions as size 1.
5533
+ * @returns The norm of `x`, reduced over the given axes.
5534
+ */
5535
+ function vectorNorm(x, { ord = 2, axis = null, keepdims = false } = {}) {
5536
+ x = fudgeArray(x);
5537
+ const ax = axis ?? null;
5538
+ if (ord === Infinity) return max(absolute(x), ax, { keepdims });
5539
+ else if (ord === -Infinity) return min(absolute(x), ax, { keepdims });
5540
+ else if (ord === 0) return x.notEqual(0).astype(x.dtype).sum(ax, { keepdims });
5541
+ else return power(power(absolute(x), ord).sum(ax, { keepdims }), 1 / ord);
5542
+ }
5455
5543
 
5456
5544
  //#endregion
5457
5545
  //#region src/library/numpy/dtype-info.ts
@@ -5571,6 +5659,13 @@ __export(numpy_exports, {
5571
5659
  atan2: () => atan2,
5572
5660
  atanh: () => arctanh,
5573
5661
  average: () => average,
5662
+ bitwiseAnd: () => bitwiseAnd,
5663
+ bitwiseInvert: () => invert,
5664
+ bitwiseLeftShift: () => leftShift,
5665
+ bitwiseNot: () => invert,
5666
+ bitwiseOr: () => bitwiseOr,
5667
+ bitwiseRightShift: () => rightShift,
5668
+ bitwiseXor: () => bitwiseXor,
5574
5669
  bool: () => bool,
5575
5670
  broadcastArrays: () => broadcastArrays,
5576
5671
  broadcastShapes: () => broadcastShapes,
@@ -5632,12 +5727,14 @@ __export(numpy_exports, {
5632
5727
  inf: () => inf,
5633
5728
  inner: () => inner,
5634
5729
  int32: () => int32,
5730
+ invert: () => invert,
5635
5731
  isfinite: () => isfinite,
5636
5732
  isinf: () => isinf,
5637
5733
  isnan: () => isnan,
5638
5734
  isneginf: () => isneginf,
5639
5735
  isposinf: () => isposinf,
5640
5736
  ldexp: () => ldexp,
5737
+ leftShift: () => leftShift,
5641
5738
  less: () => less,
5642
5739
  lessEqual: () => lessEqual,
5643
5740
  linalg: () => numpy_linalg_exports,
@@ -5686,6 +5783,7 @@ __export(numpy_exports, {
5686
5783
  remainder: () => remainder,
5687
5784
  repeat: () => repeat,
5688
5785
  reshape: () => reshape,
5786
+ rightShift: () => rightShift,
5689
5787
  rint: () => rint,
5690
5788
  round: () => round,
5691
5789
  shape: () => shape,
@@ -5800,6 +5898,44 @@ function logicalXor(x, y) {
5800
5898
  function logicalNot(x) {
5801
5899
  return notEqual(astype(x, require_backend.DType.Bool), true);
5802
5900
  }
5901
+ /** Compute element-wise bitwise AND. */
5902
+ function bitwiseAnd(x, y) {
5903
+ return bitCombine(x, y, "and");
5904
+ }
5905
+ /** Compute element-wise bitwise OR. */
5906
+ function bitwiseOr(x, y) {
5907
+ return bitCombine(x, y, "or");
5908
+ }
5909
+ /** Compute element-wise bitwise XOR. */
5910
+ function bitwiseXor(x, y) {
5911
+ return bitCombine(x, y, "xor");
5912
+ }
5913
+ /** Compute element-wise bitwise NOT (inversion). */
5914
+ function invert(x) {
5915
+ const arr = fudgeArray(x);
5916
+ let allOnes;
5917
+ switch (arr.dtype) {
5918
+ case require_backend.DType.Bool:
5919
+ allOnes = true;
5920
+ break;
5921
+ case require_backend.DType.Uint32:
5922
+ allOnes = 4294967295;
5923
+ break;
5924
+ case require_backend.DType.Int32:
5925
+ allOnes = -1;
5926
+ break;
5927
+ default: throw new TypeError(`invert: unsupported dtype ${arr.dtype}`);
5928
+ }
5929
+ return bitCombine(arr, allOnes, "xor");
5930
+ }
5931
+ /** Compute element-wise left bit shift. */
5932
+ function leftShift(x, y) {
5933
+ return bitShift(x, y, "shl");
5934
+ }
5935
+ /** Compute element-wise right bit shift. */
5936
+ function rightShift(x, y) {
5937
+ return bitShift(x, y, "shr");
5938
+ }
5803
5939
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
5804
5940
  const where = where$1;
5805
5941
  /**
@@ -8336,6 +8472,7 @@ exports.blockUntilReady = blockUntilReady;
8336
8472
  exports.defaultDevice = require_backend.defaultDevice;
8337
8473
  exports.devicePut = devicePut;
8338
8474
  exports.devices = require_backend.devices;
8475
+ exports.getWebGPUDevice = require_backend.getWebGPUDevice;
8339
8476
  exports.grad = grad;
8340
8477
  exports.hessian = hessian;
8341
8478
  exports.init = require_backend.init;
package/dist/index.d.cts CHANGED
@@ -232,6 +232,8 @@ declare class AluExp implements FpHashable {
232
232
  static cast(dtype: DType, a: AluExp): AluExp;
233
233
  static bitcast(dtype: DType, a: AluExp): AluExp;
234
234
  static threefry2x32(k0: AluExp, k1: AluExp, c0: AluExp, c1: AluExp, mode?: "xor" | 0 | 1): AluExp;
235
+ static bitCombine(a: AluExp, b: AluExp, mode: "and" | "or" | "xor"): AluExp;
236
+ static bitShift(a: AluExp, b: AluExp, mode: "shl" | "shr"): AluExp;
235
237
  static cmplt(a: AluExp, b: AluExp): AluExp;
236
238
  static cmpne(a: AluExp, b: AluExp): AluExp;
237
239
  static where(cond: AluExp, a: AluExp, b: AluExp): AluExp;
@@ -323,6 +325,11 @@ declare enum AluOp {
323
325
  Reciprocal = "Reciprocal",
324
326
  Cast = "Cast",
325
327
  Bitcast = "Bitcast",
328
+ BitCombine = "BitCombine",
329
+ // arg = 'or' | 'and' | 'xor'
330
+ BitInvert = "BitInvert",
331
+ BitShift = "BitShift",
332
+ // arg = 'shl' | 'shr'
326
333
  Cmplt = "Cmplt",
327
334
  Cmpne = "Cmpne",
328
335
  Where = "Where",
@@ -546,6 +553,11 @@ declare class Executable<T = any> {
546
553
  source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
547
554
  data: T);
548
555
  }
556
+ /**
557
+ * If the WebGPU backend has been initialized, return the `GPUDevice` that this
558
+ * backend runs on. This is useful for sharing buffers.
559
+ */
560
+ declare function getWebGPUDevice(): GPUDevice;
549
561
  declare namespace tree_d_exports {
550
562
  export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
551
563
  }
@@ -719,6 +731,8 @@ declare enum Primitive {
719
731
  // uses sign of numerator, C-style, matches JS but not Python
720
732
  Min = "min",
721
733
  Max = "max",
734
+ BitCombine = "bit_combine",
735
+ BitShift = "bit_shift",
722
736
  Neg = "neg",
723
737
  Reciprocal = "reciprocal",
724
738
  Floor = "floor",
@@ -767,6 +781,12 @@ declare enum Primitive {
767
781
  Jit = "jit",
768
782
  }
769
783
  interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
784
+ [Primitive.BitCombine]: {
785
+ op: "and" | "or" | "xor";
786
+ };
787
+ [Primitive.BitShift]: {
788
+ op: "shl" | "shr";
789
+ };
770
790
  [Primitive.Cast]: {
771
791
  dtype: DType;
772
792
  };
@@ -1194,6 +1214,19 @@ declare class Array extends Tracer {
1194
1214
  * recommended for performance reasons, as it will block rendering.
1195
1215
  */
1196
1216
  dataSync(): DataArray;
1217
+ /**
1218
+ * Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
1219
+ *
1220
+ * Only available on the WebGPU backend. The array's memory is still managed
1221
+ * by jax-js, and it will be freed when the buffer is no longer in use. You
1222
+ * _should not_ mutate the buffer's contents.
1223
+ *
1224
+ * Note that the GPU buffer may be slightly larger than the array's size; it
1225
+ * will always be aligned to 4 bytes.
1226
+ */
1227
+ gpuBuffer(): Promise<GPUBuffer>;
1228
+ /** Synchronous version of `Array.gpuBuffer()`. */
1229
+ gpuBufferSync(): GPUBuffer;
1197
1230
  /**
1198
1231
  * Convert this array into a JavaScript object.
1199
1232
  *
@@ -1571,7 +1604,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
1571
1604
  */
1572
1605
  declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
1573
1606
  declare namespace numpy_linalg_d_exports {
1574
- export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
1607
+ export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot, vectorNorm };
1575
1608
  }
1576
1609
  /**
1577
1610
  * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
@@ -1626,6 +1659,24 @@ declare function slogdet(a: ArrayLike): [Array, Array];
1626
1659
  * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
1627
1660
  */
1628
1661
  declare function solve(a: ArrayLike, b: ArrayLike): Array;
1662
+ /**
1663
+ * Compute the vector norm of an array.
1664
+ *
1665
+ * @param x - Input array.
1666
+ * @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
1667
+ * @param axis - Axis/axes to reduce over (default: all axes).
1668
+ * @param keepdims - Whether to keep reduced dimensions as size 1.
1669
+ * @returns The norm of `x`, reduced over the given axes.
1670
+ */
1671
+ declare function vectorNorm(x: ArrayLike, {
1672
+ ord,
1673
+ axis,
1674
+ keepdims
1675
+ }?: {
1676
+ ord?: number;
1677
+ axis?: number | number[] | null;
1678
+ keepdims?: boolean;
1679
+ }): Array;
1629
1680
  //#endregion
1630
1681
  //#region src/library/numpy/dtype-info.d.ts
1631
1682
  /** @inline */
@@ -1679,7 +1730,7 @@ type IInfo = Readonly<{
1679
1730
  /** Machine limits for integer types. */
1680
1731
  declare function iinfo(dtype: DType): IInfo;
1681
1732
  declare namespace numpy_d_exports {
1682
- export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
1733
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bitwiseAnd, invert as bitwiseInvert, leftShift as bitwiseLeftShift, invert as bitwiseNot, bitwiseOr, rightShift as bitwiseRightShift, bitwiseXor, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, invert, isfinite, isinf, isnan, isneginf, isposinf, ldexp, leftShift, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rightShift, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
1683
1734
  }
1684
1735
  declare const float32 = DType.Float32;
1685
1736
  declare const int32 = DType.Int32;
@@ -1747,6 +1798,18 @@ declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
1747
1798
  declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
1748
1799
  /** Compute element-wise logical NOT. */
1749
1800
  declare function logicalNot(x: ArrayLike): Array;
1801
+ /** Compute element-wise bitwise AND. */
1802
+ declare function bitwiseAnd(x: ArrayLike, y: ArrayLike): Array;
1803
+ /** Compute element-wise bitwise OR. */
1804
+ declare function bitwiseOr(x: ArrayLike, y: ArrayLike): Array;
1805
+ /** Compute element-wise bitwise XOR. */
1806
+ declare function bitwiseXor(x: ArrayLike, y: ArrayLike): Array;
1807
+ /** Compute element-wise bitwise NOT (inversion). */
1808
+ declare function invert(x: ArrayLike): Array;
1809
+ /** Compute element-wise left bit shift. */
1810
+ declare function leftShift(x: ArrayLike, y: ArrayLike): Array;
1811
+ /** Compute element-wise right bit shift. */
1812
+ declare function rightShift(x: ArrayLike, y: ArrayLike): Array;
1750
1813
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1751
1814
  declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
1752
1815
  /**
@@ -2958,4 +3021,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
2958
3021
  */
2959
3022
  declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
2960
3023
  //#endregion
2961
- export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
3024
+ export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, getWebGPUDevice, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
package/dist/index.d.ts CHANGED
@@ -229,6 +229,8 @@ declare class AluExp implements FpHashable {
229
229
  static cast(dtype: DType, a: AluExp): AluExp;
230
230
  static bitcast(dtype: DType, a: AluExp): AluExp;
231
231
  static threefry2x32(k0: AluExp, k1: AluExp, c0: AluExp, c1: AluExp, mode?: "xor" | 0 | 1): AluExp;
232
+ static bitCombine(a: AluExp, b: AluExp, mode: "and" | "or" | "xor"): AluExp;
233
+ static bitShift(a: AluExp, b: AluExp, mode: "shl" | "shr"): AluExp;
232
234
  static cmplt(a: AluExp, b: AluExp): AluExp;
233
235
  static cmpne(a: AluExp, b: AluExp): AluExp;
234
236
  static where(cond: AluExp, a: AluExp, b: AluExp): AluExp;
@@ -320,6 +322,11 @@ declare enum AluOp {
320
322
  Reciprocal = "Reciprocal",
321
323
  Cast = "Cast",
322
324
  Bitcast = "Bitcast",
325
+ BitCombine = "BitCombine",
326
+ // arg = 'or' | 'and' | 'xor'
327
+ BitInvert = "BitInvert",
328
+ BitShift = "BitShift",
329
+ // arg = 'shl' | 'shr'
323
330
  Cmplt = "Cmplt",
324
331
  Cmpne = "Cmpne",
325
332
  Where = "Where",
@@ -543,6 +550,11 @@ declare class Executable<T = any> {
543
550
  source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
544
551
  data: T);
545
552
  }
553
+ /**
554
+ * If the WebGPU backend has been initialized, return the `GPUDevice` that this
555
+ * backend runs on. This is useful for sharing buffers.
556
+ */
557
+ declare function getWebGPUDevice(): GPUDevice;
546
558
  declare namespace tree_d_exports {
547
559
  export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
548
560
  }
@@ -716,6 +728,8 @@ declare enum Primitive {
716
728
  // uses sign of numerator, C-style, matches JS but not Python
717
729
  Min = "min",
718
730
  Max = "max",
731
+ BitCombine = "bit_combine",
732
+ BitShift = "bit_shift",
719
733
  Neg = "neg",
720
734
  Reciprocal = "reciprocal",
721
735
  Floor = "floor",
@@ -764,6 +778,12 @@ declare enum Primitive {
764
778
  Jit = "jit",
765
779
  }
766
780
  interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
781
+ [Primitive.BitCombine]: {
782
+ op: "and" | "or" | "xor";
783
+ };
784
+ [Primitive.BitShift]: {
785
+ op: "shl" | "shr";
786
+ };
767
787
  [Primitive.Cast]: {
768
788
  dtype: DType;
769
789
  };
@@ -1191,6 +1211,19 @@ declare class Array extends Tracer {
1191
1211
  * recommended for performance reasons, as it will block rendering.
1192
1212
  */
1193
1213
  dataSync(): DataArray;
1214
+ /**
1215
+ * Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
1216
+ *
1217
+ * Only available on the WebGPU backend. The array's memory is still managed
1218
+ * by jax-js, and it will be freed when the buffer is no longer in use. You
1219
+ * _should not_ mutate the buffer's contents.
1220
+ *
1221
+ * Note that the GPU buffer may be slightly larger than the array's size; it
1222
+ * will always be aligned to 4 bytes.
1223
+ */
1224
+ gpuBuffer(): Promise<GPUBuffer>;
1225
+ /** Synchronous version of `Array.gpuBuffer()`. */
1226
+ gpuBufferSync(): GPUBuffer;
1194
1227
  /**
1195
1228
  * Convert this array into a JavaScript object.
1196
1229
  *
@@ -1568,7 +1601,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
1568
1601
  */
1569
1602
  declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
1570
1603
  declare namespace numpy_linalg_d_exports {
1571
- export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
1604
+ export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot, vectorNorm };
1572
1605
  }
1573
1606
  /**
1574
1607
  * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
@@ -1623,6 +1656,24 @@ declare function slogdet(a: ArrayLike): [Array, Array];
1623
1656
  * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
1624
1657
  */
1625
1658
  declare function solve(a: ArrayLike, b: ArrayLike): Array;
1659
+ /**
1660
+ * Compute the vector norm of an array.
1661
+ *
1662
+ * @param x - Input array.
1663
+ * @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
1664
+ * @param axis - Axis/axes to reduce over (default: all axes).
1665
+ * @param keepdims - Whether to keep reduced dimensions as size 1.
1666
+ * @returns The norm of `x`, reduced over the given axes.
1667
+ */
1668
+ declare function vectorNorm(x: ArrayLike, {
1669
+ ord,
1670
+ axis,
1671
+ keepdims
1672
+ }?: {
1673
+ ord?: number;
1674
+ axis?: number | number[] | null;
1675
+ keepdims?: boolean;
1676
+ }): Array;
1626
1677
  //#endregion
1627
1678
  //#region src/library/numpy/dtype-info.d.ts
1628
1679
  /** @inline */
@@ -1676,7 +1727,7 @@ type IInfo = Readonly<{
1676
1727
  /** Machine limits for integer types. */
1677
1728
  declare function iinfo(dtype: DType): IInfo;
1678
1729
  declare namespace numpy_d_exports {
1679
- export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
1730
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bitwiseAnd, invert as bitwiseInvert, leftShift as bitwiseLeftShift, invert as bitwiseNot, bitwiseOr, rightShift as bitwiseRightShift, bitwiseXor, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, invert, isfinite, isinf, isnan, isneginf, isposinf, ldexp, leftShift, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rightShift, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
1680
1731
  }
1681
1732
  declare const float32 = DType.Float32;
1682
1733
  declare const int32 = DType.Int32;
@@ -1744,6 +1795,18 @@ declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
1744
1795
  declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
1745
1796
  /** Compute element-wise logical NOT. */
1746
1797
  declare function logicalNot(x: ArrayLike): Array;
1798
+ /** Compute element-wise bitwise AND. */
1799
+ declare function bitwiseAnd(x: ArrayLike, y: ArrayLike): Array;
1800
+ /** Compute element-wise bitwise OR. */
1801
+ declare function bitwiseOr(x: ArrayLike, y: ArrayLike): Array;
1802
+ /** Compute element-wise bitwise XOR. */
1803
+ declare function bitwiseXor(x: ArrayLike, y: ArrayLike): Array;
1804
+ /** Compute element-wise bitwise NOT (inversion). */
1805
+ declare function invert(x: ArrayLike): Array;
1806
+ /** Compute element-wise left bit shift. */
1807
+ declare function leftShift(x: ArrayLike, y: ArrayLike): Array;
1808
+ /** Compute element-wise right bit shift. */
1809
+ declare function rightShift(x: ArrayLike, y: ArrayLike): Array;
1747
1810
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
1748
1811
  declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
1749
1812
  /**
@@ -2955,4 +3018,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
2955
3018
  */
2956
3019
  declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
2957
3020
  //#endregion
2958
- export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
3021
+ export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, getWebGPUDevice, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };