@jax-js/jax 0.1.10 → 0.1.11

This diff shows the contents of two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-Ctqs8la1.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DZvR7mZV.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -333,6 +333,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
333
333
  Primitive$1["Mod"] = "mod";
334
334
  Primitive$1["Min"] = "min";
335
335
  Primitive$1["Max"] = "max";
336
+ Primitive$1["BitCombine"] = "bit_combine";
337
+ Primitive$1["BitShift"] = "bit_shift";
336
338
  Primitive$1["Neg"] = "neg";
337
339
  Primitive$1["Reciprocal"] = "reciprocal";
338
340
  Primitive$1["Floor"] = "floor";
@@ -406,6 +408,12 @@ function min$1(x, y) {
406
408
  function max$1(x, y) {
407
409
  return bind1(Primitive.Max, [x, y]);
408
410
  }
411
/**
 * Bind the `bit_combine` primitive: combine `x` and `y` element-wise with the
 * bitwise operation named by `op` (callers in this file pass "and", "or" or
 * "xor").
 */
function bitCombine(x, y, op) {
  const params = { op };
  return bind1(Primitive.BitCombine, [x, y], params);
}

/**
 * Bind the `bit_shift` primitive: shift `x` by `y` element-wise in the
 * direction named by `op` (callers in this file pass "shl" or "shr").
 */
function bitShift(x, y, op) {
  const params = { op };
  return bind1(Primitive.BitShift, [x, y], params);
}
409
417
  function neg(x) {
410
418
  return bind1(Primitive.Neg, [x]);
411
419
  }
@@ -1620,6 +1628,16 @@ const abstractEvalRules = {
1620
1628
  [Primitive.Mod]: binopAbstractEval,
1621
1629
  [Primitive.Min]: binopAbstractEval,
1622
1630
  [Primitive.Max]: binopAbstractEval,
1631
+ [Primitive.BitCombine]([x, y]) {
1632
+ const aval = promoteAvals(x, y);
1633
+ if (isFloatDtype(aval.dtype)) throw new TypeError(`bitwise operations require integer or boolean inputs, got ${aval.dtype}`);
1634
+ return [aval];
1635
+ },
1636
+ [Primitive.BitShift]([x, y]) {
1637
+ const shape$1 = generalBroadcast(x.shape, y.shape);
1638
+ if (isFloatDtype(x.dtype) || isFloatDtype(y.dtype) || x.dtype === DType.Bool || y.dtype === DType.Bool) throw new TypeError(`bit shift operations require integer inputs, got ${x} and ${y}`);
1639
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
1640
+ },
1623
1641
  [Primitive.Neg]: vectorizedUnopAbstractEval,
1624
1642
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
1625
1643
  [Primitive.Floor]: vectorizedUnopAbstractEval,
@@ -2155,6 +2173,8 @@ const jitRules = {
2155
2173
  [Primitive.Mod]: broadcastedJit(([a, b]) => AluExp.mod(a, b)),
2156
2174
  [Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
2157
2175
  [Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
2176
+ [Primitive.BitCombine]: broadcastedJit(([a, b], { op }) => AluExp.bitCombine(a, b, op)),
2177
+ [Primitive.BitShift]: broadcastedJit(([a, b], { op }) => AluExp.bitShift(a, b, op)),
2158
2178
  [Primitive.Neg]: unopJit((a) => AluExp.sub(AluExp.const(a.dtype, 0), a)),
2159
2179
  [Primitive.Reciprocal]: unopJit(AluExp.reciprocal),
2160
2180
  [Primitive.Floor]: unopJit(AluExp.floor),
@@ -2347,7 +2367,9 @@ function splitGraphDataflow(backend, jaxpr) {
2347
2367
  case Primitive.Idiv:
2348
2368
  case Primitive.Mod:
2349
2369
  case Primitive.Min:
2350
- case Primitive.Max: {
2370
+ case Primitive.Max:
2371
+ case Primitive.BitCombine:
2372
+ case Primitive.BitShift: {
2351
2373
  const otherInput = nextEqn.inputs.find((v) => v !== outVar);
2352
2374
  if (otherInput instanceof Lit || deepEqual(generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
2353
2375
  head = usages[0];
@@ -2986,6 +3008,42 @@ var Array$1 = class Array$1 extends Tracer {
2986
3008
  return dtypedArray(this.dtype, buf);
2987
3009
  }
2988
3010
  /**
3011
+ * Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
3012
+ *
3013
+ * Only available on the WebGPU backend. The array's memory is still managed
3014
+ * by jax-js, and it will be freed when the buffer is no longer in use. You
3015
+ * _should not_ mutate the buffer's contents.
3016
+ *
3017
+ * Note that the GPU buffer may be slightly larger than the array's size; it
3018
+ * will always be aligned to 4 bytes.
3019
+ */
3020
+ async gpuBuffer() {
3021
+ if (this.device !== "webgpu") throw new Error(`gpuBuffer() is only available on WebGPU backend`);
3022
+ this.#realize();
3023
+ const pending = this.#pending;
3024
+ if (pending) {
3025
+ await Promise.all(pending.map((p) => p.prepare()));
3026
+ for (const p of pending) p.submit();
3027
+ }
3028
+ const backend = this.#backend;
3029
+ const { buffer } = backend.buffers.get(this.#source);
3030
+ this.dispose();
3031
+ return buffer;
3032
+ }
3033
+ /** Synchronous version of `Array.gpuBuffer()`. */
3034
+ gpuBufferSync() {
3035
+ if (this.device !== "webgpu") throw new Error(`gpuBufferSync() is only available on WebGPU backend`);
3036
+ this.#realize();
3037
+ for (const p of this.#pending) {
3038
+ p.prepareSync();
3039
+ p.submit();
3040
+ }
3041
+ const backend = this.#backend;
3042
+ const { buffer } = backend.buffers.get(this.#source);
3043
+ this.dispose();
3044
+ return buffer;
3045
+ }
3046
+ /**
2989
3047
  * Convert this array into a JavaScript object.
2990
3048
  *
2991
3049
  * This is a blocking operation that will compile all of the shaders and wait
@@ -3032,6 +3090,14 @@ var Array$1 = class Array$1 extends Tracer {
3032
3090
  [Primitive.Max]([x, y]) {
3033
3091
  return [x.#binary(AluOp.Max, y)];
3034
3092
  },
3093
+ [Primitive.BitCombine]([x, y], { op }) {
3094
+ const custom = (src) => AluExp.bitCombine(src[0], src[1], op);
3095
+ return [Array$1.#naryCustom("bit_combine", custom, [x, y])];
3096
+ },
3097
+ [Primitive.BitShift]([x, y], { op }) {
3098
+ const custom = (src) => AluExp.bitShift(src[0], src[1], op);
3099
+ return [Array$1.#naryCustom("bit_shift", custom, [x, y], { dtypeOverride: [void 0, y.dtype] })];
3100
+ },
3035
3101
  [Primitive.Neg]([x]) {
3036
3102
  return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
3037
3103
  },
@@ -3723,6 +3789,8 @@ const vmapRules = {
3723
3789
  [Primitive.Mod]: broadcastBatcher(Primitive.Mod),
3724
3790
  [Primitive.Min]: broadcastBatcher(Primitive.Min),
3725
3791
  [Primitive.Max]: broadcastBatcher(Primitive.Max),
3792
+ [Primitive.BitCombine]: broadcastBatcher(Primitive.BitCombine),
3793
+ [Primitive.BitShift]: broadcastBatcher(Primitive.BitShift),
3726
3794
  [Primitive.Neg]: unopBatcher(Primitive.Neg),
3727
3795
  [Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
3728
3796
  [Primitive.Floor]: unopBatcher(Primitive.Floor),
@@ -4045,6 +4113,8 @@ const jvpRules = {
4045
4113
  [Primitive.Max]([x, y], [dx, dy]) {
4046
4114
  return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
4047
4115
  },
4116
+ [Primitive.BitCombine]: zeroTangentsJvp(Primitive.BitCombine),
4117
+ [Primitive.BitShift]: zeroTangentsJvp(Primitive.BitShift),
4048
4118
  [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
4049
4119
  [Primitive.Reciprocal]([x], [dx]) {
4050
4120
  const xRecip = reciprocal$1(x.ref);
@@ -5236,7 +5306,8 @@ __export(numpy_linalg_exports, {
5236
5306
  solve: () => solve,
5237
5307
  tensordot: () => tensordot,
5238
5308
  trace: () => trace,
5239
- vecdot: () => vecdot
5309
+ vecdot: () => vecdot,
5310
+ vectorNorm: () => vectorNorm
5240
5311
  });
5241
5312
  function checkSquare(name, a) {
5242
5313
  if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
@@ -5415,6 +5486,23 @@ function solve(a, b) {
5415
5486
  if (bIs1d) x = squeeze(x, -1);
5416
5487
  return x;
5417
5488
  }
5489
/**
 * Compute the vector norm of an array.
 *
 * - `ord = Infinity`: maximum absolute value.
 * - `ord = -Infinity`: minimum absolute value.
 * - `ord = 0`: count of nonzero entries, computed after casting the
 *   nonzero mask to the input's dtype.
 * - any other `ord`: `sum(|x| ** ord) ** (1 / ord)`.
 *
 * @param x - Input array.
 * @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
 * @param axis - Axis/axes to reduce over (default: all axes).
 * @param keepdims - Whether to keep reduced dimensions as size 1.
 * @returns The norm of `x`, reduced over the given axes.
 */
function vectorNorm(x, { ord = 2, axis = null, keepdims = false } = {}) {
  x = fudgeArray(x);
  // Array ops in this library take ownership of their operands, so read the
  // dtype before `x.notEqual(0)` consumes `x` in the `ord === 0` branch.
  const dtype = x.dtype;
  if (ord === Infinity) return max(absolute(x), axis, { keepdims });
  if (ord === -Infinity) return min(absolute(x), axis, { keepdims });
  if (ord === 0) return x.notEqual(0).astype(dtype).sum(axis, { keepdims });
  return power(power(absolute(x), ord).sum(axis, { keepdims }), 1 / ord);
}
5418
5506
 
5419
5507
  //#endregion
5420
5508
  //#region src/library/numpy/dtype-info.ts
@@ -5534,6 +5622,13 @@ __export(numpy_exports, {
5534
5622
  atan2: () => atan2,
5535
5623
  atanh: () => arctanh,
5536
5624
  average: () => average,
5625
+ bitwiseAnd: () => bitwiseAnd,
5626
+ bitwiseInvert: () => invert,
5627
+ bitwiseLeftShift: () => leftShift,
5628
+ bitwiseNot: () => invert,
5629
+ bitwiseOr: () => bitwiseOr,
5630
+ bitwiseRightShift: () => rightShift,
5631
+ bitwiseXor: () => bitwiseXor,
5537
5632
  bool: () => bool,
5538
5633
  broadcastArrays: () => broadcastArrays,
5539
5634
  broadcastShapes: () => broadcastShapes,
@@ -5595,12 +5690,14 @@ __export(numpy_exports, {
5595
5690
  inf: () => inf,
5596
5691
  inner: () => inner,
5597
5692
  int32: () => int32,
5693
+ invert: () => invert,
5598
5694
  isfinite: () => isfinite,
5599
5695
  isinf: () => isinf,
5600
5696
  isnan: () => isnan,
5601
5697
  isneginf: () => isneginf,
5602
5698
  isposinf: () => isposinf,
5603
5699
  ldexp: () => ldexp,
5700
+ leftShift: () => leftShift,
5604
5701
  less: () => less,
5605
5702
  lessEqual: () => lessEqual,
5606
5703
  linalg: () => numpy_linalg_exports,
@@ -5649,6 +5746,7 @@ __export(numpy_exports, {
5649
5746
  remainder: () => remainder,
5650
5747
  repeat: () => repeat,
5651
5748
  reshape: () => reshape,
5749
+ rightShift: () => rightShift,
5652
5750
  rint: () => rint,
5653
5751
  round: () => round,
5654
5752
  shape: () => shape,
@@ -5763,6 +5861,44 @@ function logicalXor(x, y) {
5763
5861
  function logicalNot(x) {
5764
5862
  return notEqual(astype(x, DType.Bool), true);
5765
5863
  }
5864
/** Compute element-wise bitwise AND (integer or boolean inputs). */
function bitwiseAnd(x, y) {
  return bitCombine(x, y, "and");
}
/** Compute element-wise bitwise OR (integer or boolean inputs). */
function bitwiseOr(x, y) {
  return bitCombine(x, y, "or");
}
/** Compute element-wise bitwise XOR (integer or boolean inputs). */
function bitwiseXor(x, y) {
  return bitCombine(x, y, "xor");
}
/**
 * Compute element-wise bitwise NOT (inversion).
 *
 * Implemented as XOR with an all-ones mask of the input's dtype: `true` for
 * bool, `0xffffffff` for uint32, and `-1` for int32. Any other dtype throws a
 * `TypeError`.
 */
function invert(x) {
  const arr = fudgeArray(x);
  const { dtype } = arr;
  let mask;
  if (dtype === DType.Bool) mask = true;
  else if (dtype === DType.Uint32) mask = 0xffffffff;
  else if (dtype === DType.Int32) mask = -1;
  else throw new TypeError(`invert: unsupported dtype ${arr.dtype}`);
  return bitCombine(arr, mask, "xor");
}
/** Compute element-wise left bit shift (`x << y`). */
function leftShift(x, y) {
  return bitShift(x, y, "shl");
}
/** Compute element-wise right bit shift (`x >> y`). */
function rightShift(x, y) {
  return bitShift(x, y, "shr");
}
5766
5902
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
5767
5903
  const where = where$1;
5768
5904
  /**
@@ -8291,4 +8427,4 @@ async function devicePut(x, device) {
8291
8427
  }
8292
8428
 
8293
8429
  //#endregion
8294
- export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, profiler, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
8430
+ export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, getWebGPUDevice, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, profiler, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
@@ -1,4 +1,4 @@
1
- import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-Ctqs8la1.js";
1
+ import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-DZvR7mZV.js";
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -458,6 +458,12 @@ function generateExpression(exp, args, inputDtypes) {
458
458
  else source = `min(${a}, ${b})`;
459
459
  else if (op === AluOp.Max) if (dtype === DType.Bool) source = `(${a} || ${b})`;
460
460
  else source = `max(${a}, ${b})`;
461
+ else if (op === AluOp.BitCombine) {
462
+ let infix = arg === "and" ? "&" : arg === "or" ? "|" : "^";
463
+ if (dtype === DType.Bool) infix = infix + infix;
464
+ source = `(${a} ${infix} ${b})`;
465
+ } else if (op === AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
466
+ else source = `(${a} >> ${b})`;
461
467
  } else if (AluGroup.Compare.has(op)) {
462
468
  const a = gen(src[0]);
463
469
  const b = gen(src[1]);
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-DMauYnfl.cjs');
1
+ const require_backend = require('./backend-DlYlOYqN.cjs');
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -458,6 +458,12 @@ function generateExpression(exp, args, inputDtypes) {
458
458
  else source = `min(${a}, ${b})`;
459
459
  else if (op === require_backend.AluOp.Max) if (dtype === require_backend.DType.Bool) source = `(${a} || ${b})`;
460
460
  else source = `max(${a}, ${b})`;
461
+ else if (op === require_backend.AluOp.BitCombine) {
462
+ let infix = arg === "and" ? "&" : arg === "or" ? "|" : "^";
463
+ if (dtype === require_backend.DType.Bool) infix = infix + infix;
464
+ source = `(${a} ${infix} ${b})`;
465
+ } else if (op === require_backend.AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
466
+ else source = `(${a} >> ${b})`;
461
467
  } else if (require_backend.AluGroup.Compare.has(op)) {
462
468
  const a = gen(src[0]);
463
469
  const b = gen(src[1]);
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-Ctqs8la1.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-DZvR7mZV.js";
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -1100,6 +1100,11 @@ function pipelineSource(device, kernel) {
1100
1100
  else source = `min(${strip1(a)}, ${strip1(b)})`;
1101
1101
  else if (op === AluOp.Max) if (dtype === DType.Bool) source = `(${a} || ${b})`;
1102
1102
  else source = `max(${strip1(a)}, ${strip1(b)})`;
1103
+ else if (op === AluOp.BitCombine) if (arg === "and") source = `(${a} & ${b})`;
1104
+ else if (arg === "or") source = `(${a} | ${b})`;
1105
+ else source = dtype === DType.Bool ? `(${a} != ${b})` : `(${a} ^ ${b})`;
1106
+ else if (op === AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
1107
+ else source = `(${a} >> ${b})`;
1103
1108
  else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
1104
1109
  else if (op === AluOp.Cmpne) if (isFloatDtype(src[0].dtype)) {
1105
1110
  const x = isGensym(a) ? a : gensym();
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-DMauYnfl.cjs');
1
+ const require_backend = require('./backend-DlYlOYqN.cjs');
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -1100,6 +1100,11 @@ function pipelineSource(device, kernel) {
1100
1100
  else source = `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
1101
1101
  else if (op === require_backend.AluOp.Max) if (dtype === require_backend.DType.Bool) source = `(${a} || ${b})`;
1102
1102
  else source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
1103
+ else if (op === require_backend.AluOp.BitCombine) if (arg === "and") source = `(${a} & ${b})`;
1104
+ else if (arg === "or") source = `(${a} | ${b})`;
1105
+ else source = dtype === require_backend.DType.Bool ? `(${a} != ${b})` : `(${a} ^ ${b})`;
1106
+ else if (op === require_backend.AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
1107
+ else source = `(${a} >> ${b})`;
1103
1108
  else if (op === require_backend.AluOp.Cmplt) source = `(${a} < ${b})`;
1104
1109
  else if (op === require_backend.AluOp.Cmpne) if (require_backend.isFloatDtype(src[0].dtype)) {
1105
1110
  const x = isGensym(a) ? a : gensym();
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.1.10",
3
+ "version": "0.1.11",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",