@jax-js/jax 0.1.6 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-Dx6Ob2D1.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-nEolvdLv.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -306,11 +306,11 @@ function map(fn, tree, ...rest) {
306
306
  }
307
307
  /** Take a reference of every array in a tree. */
308
308
  function ref(tree) {
309
- return map((x) => x.ref, tree);
309
+ return map((x) => x instanceof Tracer ? x.ref : x, tree);
310
310
  }
311
311
  /** Dispose every array in a tree. */
312
312
  function dispose(tree) {
313
- if (tree) map((x) => x.dispose(), tree);
313
+ if (tree) map((x) => x instanceof Tracer ? x.dispose() : void 0, tree);
314
314
  }
315
315
 
316
316
  //#endregion
@@ -584,14 +584,20 @@ function shrink(x, slice) {
584
584
  }
585
585
  function pad$1(x, width) {
586
586
  const nd = ndim$1(x);
587
- if (typeof width === "number") width = [[width, width]];
588
- else if (isNumberPair(width)) width = [width];
589
- else if (!Array.isArray(width) || !width.every(isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
590
- if (width.length === 1) {
591
- const [w0, w1] = width[0];
592
- width = rep(nd, () => [w0, w1]);
593
- } else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
594
- return bind1(Primitive.Pad, [x], { width });
587
+ let w;
588
+ if (typeof width === "number") w = [[width, width]];
589
+ else if (isNumberPair(width)) w = [width];
590
+ else if (!Array.isArray(width)) {
591
+ const indicesAndPairs = Object.entries(width);
592
+ w = rep(nd, [0, 0]);
593
+ for (const [k, v] of indicesAndPairs) w[checkAxis(parseInt(k), nd)] = v;
594
+ } else if (!width.every(isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
595
+ else w = width;
596
+ if (w.length === 1) {
597
+ const [w0, w1] = w[0];
598
+ w = rep(nd, () => [w0, w1]);
599
+ } else if (w.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${w.length}`);
600
+ return bind1(Primitive.Pad, [x], { width: w });
595
601
  }
596
602
  function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
597
603
  const as = getShape(a);
@@ -767,6 +773,22 @@ var Tracer = class Tracer {
767
773
  const result = reduce(this.astype(castDtype), AluOp.Add, axis, opts);
768
774
  return result.mul(1 / n).astype(originalDtype);
769
775
  }
776
+ /** Minimum of the elements of the array along a given axis. */
777
+ min(axis = null, opts) {
778
+ return reduce(this, AluOp.Min, axis, opts);
779
+ }
780
+ /** Maximum of the elements of the array along a given axis. */
781
+ max(axis = null, opts) {
782
+ return reduce(this, AluOp.Max, axis, opts);
783
+ }
784
+ /** Test whether all array elements along a given axis evaluate to true. */
785
+ all(axis = null, opts) {
786
+ return this.astype(DType.Bool).min(axis, opts);
787
+ }
788
+ /** Test whether any array element along a given axis evaluates to true. */
789
+ any(axis = null, opts) {
790
+ return this.astype(DType.Bool).max(axis, opts);
791
+ }
770
792
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
771
793
  transpose(perm) {
772
794
  return transpose$1(this, perm);
@@ -5570,6 +5592,7 @@ __export(numpy_exports, {
5570
5592
  moveaxis: () => moveaxis$1,
5571
5593
  multiply: () => multiply,
5572
5594
  nan: () => nan,
5595
+ nanToNum: () => nanToNum,
5573
5596
  ndim: () => ndim,
5574
5597
  negative: () => negative,
5575
5598
  notEqual: () => notEqual,
@@ -5767,24 +5790,22 @@ function max(a, axis = null, opts) {
5767
5790
  return reduce(a, AluOp.Max, axis, opts);
5768
5791
  }
5769
5792
  /**
5770
- * Test whether all array elements along a given axis evaluate to True.
5793
+ * Test whether any array element along a given axis evaluates to True.
5771
5794
  *
5772
5795
  * Returns a boolean array with the same shape as `a` with the specified axis
5773
5796
  * removed. If axis is None, returns a scalar.
5774
5797
  */
5775
- function all(a, axis = null, opts) {
5776
- a = fudgeArray(a).astype(DType.Bool);
5777
- return min(a, axis, opts);
5798
+ function any(a, axis = null, opts) {
5799
+ return fudgeArray(a).any(axis, opts);
5778
5800
  }
5779
5801
  /**
5780
- * Test whether any array element along a given axis evaluates to True.
5802
+ * Test whether all array elements along a given axis evaluate to True.
5781
5803
  *
5782
5804
  * Returns a boolean array with the same shape as `a` with the specified axis
5783
5805
  * removed. If axis is None, returns a scalar.
5784
5806
  */
5785
- function any(a, axis = null, opts) {
5786
- a = fudgeArray(a).astype(DType.Bool);
5787
- return max(a, axis, opts);
5807
+ function all(a, axis = null, opts) {
5808
+ return fudgeArray(a).all(axis, opts);
5788
5809
  }
5789
5810
  /** Return the peak-to-peak range along a given axis (`max - min`). */
5790
5811
  function ptp(a, axis = null, opts) {
@@ -5885,7 +5906,7 @@ function split$1(a, indicesOrSections, axis = 0) {
5885
5906
  const partSize = size$1 / indicesOrSections;
5886
5907
  sizes = rep(indicesOrSections, partSize);
5887
5908
  } else {
5888
- const indices = indicesOrSections;
5909
+ const indices = indicesOrSections.map((i) => i < 0 ? i + size$1 : i);
5889
5910
  sizes = [indices[0]];
5890
5911
  for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
5891
5912
  sizes.push(size$1 - indices[indices.length - 1]);
@@ -6833,6 +6854,21 @@ function isposinf(x) {
6833
6854
  return isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
6834
6855
  }
6835
6856
  /**
6857
+ * Replace NaN and infinite entries in an array.
6858
+ *
6859
+ * By default, NaNs are replaced with `0.0`, and infinities are are substituted
6860
+ * with the corresponding maximum or minimum finite values.
6861
+ */
6862
+ function nanToNum(x, { nan: nan$1 = 0, posinf = null, neginf = null } = {}) {
6863
+ x = fudgeArray(x);
6864
+ x = where(isnan(x.ref), nan$1, x);
6865
+ posinf ??= isFloatDtype(x.dtype) ? finfo(x.dtype).max : iinfo(x.dtype).max;
6866
+ neginf ??= isFloatDtype(x.dtype) ? finfo(x.dtype).min : iinfo(x.dtype).min;
6867
+ x = where(isposinf(x.ref), posinf, x);
6868
+ x = where(isneginf(x.ref), neginf, x);
6869
+ return x;
6870
+ }
6871
+ /**
6836
6872
  * @function
6837
6873
  * Test element-wise for finite values (not infinity or NaN).
6838
6874
  */
@@ -7541,8 +7577,6 @@ function oneHot(x, numClasses) {
7541
7577
  * `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
7542
7578
  */
7543
7579
  function dotProductAttention(query, key$1, value, opts = {}) {
7544
- if (opts.querySeqLengths !== void 0 || opts.keyValueSeqLengths !== void 0) throw new Error("Sequence length masking is not yet implemented");
7545
- if (opts.localWindowSize !== void 0) throw new Error("Local attention is not yet implemented");
7546
7580
  query = fudgeArray(query);
7547
7581
  key$1 = fudgeArray(key$1);
7548
7582
  value = fudgeArray(value);
@@ -7580,6 +7614,38 @@ function dotProductAttention(query, key$1, value, opts = {}) {
7580
7614
  const causalMask = tri(L, S, 0, { dtype: DType.Bool });
7581
7615
  scores = where(causalMask, scores, -Infinity);
7582
7616
  }
7617
+ if (opts.localWindowSize !== void 0) {
7618
+ const [before, after] = typeof opts.localWindowSize === "number" ? [opts.localWindowSize, opts.localWindowSize] : opts.localWindowSize;
7619
+ if (before < 0 || after < 0 || !Number.isInteger(before) || !Number.isInteger(after)) throw new Error(`dotProductAttention: localWindowSize values must be non-negative, got ${opts.localWindowSize}`);
7620
+ const localMask = tri(L, S, after, { dtype: DType.Bool }).mul(tri(L, S, -before - 1, { dtype: DType.Bool }).notEqual(true));
7621
+ scores = where(localMask, scores, -Infinity);
7622
+ }
7623
+ if (opts.querySeqLengths !== void 0) {
7624
+ const sl = expandDims(opts.querySeqLengths, [
7625
+ -1,
7626
+ -2,
7627
+ -3
7628
+ ]);
7629
+ scores = where(arange(L).reshape([
7630
+ 1,
7631
+ 1,
7632
+ L,
7633
+ 1
7634
+ ]).less(sl), scores, -Infinity);
7635
+ }
7636
+ if (opts.keyValueSeqLengths !== void 0) {
7637
+ const sl = expandDims(opts.keyValueSeqLengths, [
7638
+ -1,
7639
+ -2,
7640
+ -3
7641
+ ]);
7642
+ scores = where(arange(S).reshape([
7643
+ 1,
7644
+ 1,
7645
+ 1,
7646
+ S
7647
+ ]).less(sl), scores, -Infinity);
7648
+ }
7583
7649
  const attn = softmax(scores, -1);
7584
7650
  const out = einsum("BNLS,BSNH->BLNH", attn, value);
7585
7651
  return isRank3 ? out.reshape([
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-D7s-Retx.cjs');
1
+ const require_backend = require('./backend-B3foXiV_.cjs');
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-Dx6Ob2D1.js";
1
+ import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-nEolvdLv.js";
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-Dx6Ob2D1.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-nEolvdLv.js";
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-D7s-Retx.cjs');
1
+ const require_backend = require('./backend-B3foXiV_.cjs');
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.1.6",
3
+ "version": "0.1.7",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",