@jax-js/jax 0.1.6 → 0.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-Dx6Ob2D1.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-nEolvdLv.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -306,11 +306,11 @@ function map(fn, tree, ...rest) {
306
306
  }
307
307
  /** Take a reference of every array in a tree. */
308
308
  function ref(tree) {
309
- return map((x) => x.ref, tree);
309
+ return map((x) => x instanceof Tracer ? x.ref : x, tree);
310
310
  }
311
311
  /** Dispose every array in a tree. */
312
312
  function dispose(tree) {
313
- if (tree) map((x) => x.dispose(), tree);
313
+ if (tree) map((x) => x instanceof Tracer ? x.dispose() : void 0, tree);
314
314
  }
315
315
 
316
316
  //#endregion
@@ -584,14 +584,20 @@ function shrink(x, slice) {
584
584
  }
585
585
  function pad$1(x, width) {
586
586
  const nd = ndim$1(x);
587
- if (typeof width === "number") width = [[width, width]];
588
- else if (isNumberPair(width)) width = [width];
589
- else if (!Array.isArray(width) || !width.every(isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
590
- if (width.length === 1) {
591
- const [w0, w1] = width[0];
592
- width = rep(nd, () => [w0, w1]);
593
- } else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
594
- return bind1(Primitive.Pad, [x], { width });
587
+ let w;
588
+ if (typeof width === "number") w = [[width, width]];
589
+ else if (isNumberPair(width)) w = [width];
590
+ else if (!Array.isArray(width)) {
591
+ const indicesAndPairs = Object.entries(width);
592
+ w = rep(nd, [0, 0]);
593
+ for (const [k, v] of indicesAndPairs) w[checkAxis(parseInt(k), nd)] = v;
594
+ } else if (!width.every(isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
595
+ else w = width;
596
+ if (w.length === 1) {
597
+ const [w0, w1] = w[0];
598
+ w = rep(nd, () => [w0, w1]);
599
+ } else if (w.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${w.length}`);
600
+ return bind1(Primitive.Pad, [x], { width: w });
595
601
  }
596
602
  function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
597
603
  const as = getShape(a);
@@ -767,6 +773,22 @@ var Tracer = class Tracer {
767
773
  const result = reduce(this.astype(castDtype), AluOp.Add, axis, opts);
768
774
  return result.mul(1 / n).astype(originalDtype);
769
775
  }
776
+ /** Minimum of the elements of the array along a given axis. */
777
+ min(axis = null, opts) {
778
+ return reduce(this, AluOp.Min, axis, opts);
779
+ }
780
+ /** Maximum of the elements of the array along a given axis. */
781
+ max(axis = null, opts) {
782
+ return reduce(this, AluOp.Max, axis, opts);
783
+ }
784
+ /** Test whether all array elements along a given axis evaluate to true. */
785
+ all(axis = null, opts) {
786
+ return this.astype(DType.Bool).min(axis, opts);
787
+ }
788
+ /** Test whether any array element along a given axis evaluates to true. */
789
+ any(axis = null, opts) {
790
+ return this.astype(DType.Bool).max(axis, opts);
791
+ }
770
792
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
771
793
  transpose(perm) {
772
794
  return transpose$1(this, perm);
@@ -5270,6 +5292,7 @@ function lstsq(a, b) {
5270
5292
  });
5271
5293
  const llb = triangularSolve(l, lb, {
5272
5294
  leftSide: true,
5295
+ lower: true,
5273
5296
  transposeA: true
5274
5297
  });
5275
5298
  return matmul(at, llb.ref);
@@ -5283,6 +5306,7 @@ function lstsq(a, b) {
5283
5306
  });
5284
5307
  const llb = triangularSolve(l, lb, {
5285
5308
  leftSide: true,
5309
+ lower: true,
5286
5310
  transposeA: true
5287
5311
  });
5288
5312
  return llb;
@@ -5570,6 +5594,7 @@ __export(numpy_exports, {
5570
5594
  moveaxis: () => moveaxis$1,
5571
5595
  multiply: () => multiply,
5572
5596
  nan: () => nan,
5597
+ nanToNum: () => nanToNum,
5573
5598
  ndim: () => ndim,
5574
5599
  negative: () => negative,
5575
5600
  notEqual: () => notEqual,
@@ -5767,24 +5792,22 @@ function max(a, axis = null, opts) {
5767
5792
  return reduce(a, AluOp.Max, axis, opts);
5768
5793
  }
5769
5794
  /**
5770
- * Test whether all array elements along a given axis evaluate to True.
5795
+ * Test whether any array element along a given axis evaluates to True.
5771
5796
  *
5772
5797
  * Returns a boolean array with the same shape as `a` with the specified axis
5773
5798
  * removed. If axis is None, returns a scalar.
5774
5799
  */
5775
- function all(a, axis = null, opts) {
5776
- a = fudgeArray(a).astype(DType.Bool);
5777
- return min(a, axis, opts);
5800
+ function any(a, axis = null, opts) {
5801
+ return fudgeArray(a).any(axis, opts);
5778
5802
  }
5779
5803
  /**
5780
- * Test whether any array element along a given axis evaluates to True.
5804
+ * Test whether all array elements along a given axis evaluate to True.
5781
5805
  *
5782
5806
  * Returns a boolean array with the same shape as `a` with the specified axis
5783
5807
  * removed. If axis is None, returns a scalar.
5784
5808
  */
5785
- function any(a, axis = null, opts) {
5786
- a = fudgeArray(a).astype(DType.Bool);
5787
- return max(a, axis, opts);
5809
+ function all(a, axis = null, opts) {
5810
+ return fudgeArray(a).all(axis, opts);
5788
5811
  }
5789
5812
  /** Return the peak-to-peak range along a given axis (`max - min`). */
5790
5813
  function ptp(a, axis = null, opts) {
@@ -5885,7 +5908,7 @@ function split$1(a, indicesOrSections, axis = 0) {
5885
5908
  const partSize = size$1 / indicesOrSections;
5886
5909
  sizes = rep(indicesOrSections, partSize);
5887
5910
  } else {
5888
- const indices = indicesOrSections;
5911
+ const indices = indicesOrSections.map((i) => i < 0 ? i + size$1 : i);
5889
5912
  sizes = [indices[0]];
5890
5913
  for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
5891
5914
  sizes.push(size$1 - indices[indices.length - 1]);
@@ -6833,6 +6856,21 @@ function isposinf(x) {
6833
6856
  return isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
6834
6857
  }
6835
6858
  /**
6859
+ * Replace NaN and infinite entries in an array.
6860
+ *
6861
+ * By default, NaNs are replaced with `0.0`, and infinities are substituted
6862
+ * with the corresponding maximum or minimum finite values.
6863
+ */
6864
+ function nanToNum(x, { nan: nan$1 = 0, posinf = null, neginf = null } = {}) {
6865
+ x = fudgeArray(x);
6866
+ x = where(isnan(x.ref), nan$1, x);
6867
+ posinf ??= isFloatDtype(x.dtype) ? finfo(x.dtype).max : iinfo(x.dtype).max;
6868
+ neginf ??= isFloatDtype(x.dtype) ? finfo(x.dtype).min : iinfo(x.dtype).min;
6869
+ x = where(isposinf(x.ref), posinf, x);
6870
+ x = where(isneginf(x.ref), neginf, x);
6871
+ return x;
6872
+ }
6873
+ /**
6836
6874
  * @function
6837
6875
  * Test element-wise for finite values (not infinity or NaN).
6838
6876
  */
@@ -6930,7 +6968,10 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
6930
6968
  b = fudgeArray(b);
6931
6969
  if (!leftSide) transposeA = !transposeA;
6932
6970
  else b = moveaxis$1(b, -2, -1);
6933
- if (transposeA) a = moveaxis$1(a, -2, -1);
6971
+ if (transposeA) {
6972
+ a = moveaxis$1(a, -2, -1);
6973
+ lower = !lower;
6974
+ }
6934
6975
  let x = triangularSolve$1(a, b, {
6935
6976
  lower,
6936
6977
  unitDiagonal
@@ -7541,8 +7582,6 @@ function oneHot(x, numClasses) {
7541
7582
  * `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
7542
7583
  */
7543
7584
  function dotProductAttention(query, key$1, value, opts = {}) {
7544
- if (opts.querySeqLengths !== void 0 || opts.keyValueSeqLengths !== void 0) throw new Error("Sequence length masking is not yet implemented");
7545
- if (opts.localWindowSize !== void 0) throw new Error("Local attention is not yet implemented");
7546
7585
  query = fudgeArray(query);
7547
7586
  key$1 = fudgeArray(key$1);
7548
7587
  value = fudgeArray(value);
@@ -7580,6 +7619,38 @@ function dotProductAttention(query, key$1, value, opts = {}) {
7580
7619
  const causalMask = tri(L, S, 0, { dtype: DType.Bool });
7581
7620
  scores = where(causalMask, scores, -Infinity);
7582
7621
  }
7622
+ if (opts.localWindowSize !== void 0) {
7623
+ const [before, after] = typeof opts.localWindowSize === "number" ? [opts.localWindowSize, opts.localWindowSize] : opts.localWindowSize;
7624
+ if (before < 0 || after < 0 || !Number.isInteger(before) || !Number.isInteger(after)) throw new Error(`dotProductAttention: localWindowSize values must be non-negative, got ${opts.localWindowSize}`);
7625
+ const localMask = tri(L, S, after, { dtype: DType.Bool }).mul(tri(L, S, -before - 1, { dtype: DType.Bool }).notEqual(true));
7626
+ scores = where(localMask, scores, -Infinity);
7627
+ }
7628
+ if (opts.querySeqLengths !== void 0) {
7629
+ const sl = expandDims(opts.querySeqLengths, [
7630
+ -1,
7631
+ -2,
7632
+ -3
7633
+ ]);
7634
+ scores = where(arange(L).reshape([
7635
+ 1,
7636
+ 1,
7637
+ L,
7638
+ 1
7639
+ ]).less(sl), scores, -Infinity);
7640
+ }
7641
+ if (opts.keyValueSeqLengths !== void 0) {
7642
+ const sl = expandDims(opts.keyValueSeqLengths, [
7643
+ -1,
7644
+ -2,
7645
+ -3
7646
+ ]);
7647
+ scores = where(arange(S).reshape([
7648
+ 1,
7649
+ 1,
7650
+ 1,
7651
+ S
7652
+ ]).less(sl), scores, -Infinity);
7653
+ }
7583
7654
  const attn = softmax(scores, -1);
7584
7655
  const out = einsum("BNLS,BSNH->BLNH", attn, value);
7585
7656
  return isRank3 ? out.reshape([
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-D7s-Retx.cjs');
1
+ const require_backend = require('./backend-B3foXiV_.cjs');
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-Dx6Ob2D1.js";
1
+ import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-nEolvdLv.js";
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-Dx6Ob2D1.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-nEolvdLv.js";
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-D7s-Retx.cjs');
1
+ const require_backend = require('./backend-B3foXiV_.cjs');
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.1.6",
3
+ "version": "0.1.8",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",
@@ -44,6 +44,8 @@
44
44
  "eslint": "^9.31.0",
45
45
  "eslint-plugin-import": "^2.32.0",
46
46
  "globals": "^16.0.0",
47
+ "husky": "^9.1.7",
48
+ "lint-staged": "^16.2.7",
47
49
  "playwright": "~1.52.0",
48
50
  "prettier": "^3.6.2",
49
51
  "prettier-plugin-svelte": "^3.4.0",
@@ -74,6 +76,15 @@
74
76
  ],
75
77
  "proseWrap": "always"
76
78
  },
79
+ "lint-staged": {
80
+ "*.{ts,tsx,js,jsx}": [
81
+ "eslint --fix",
82
+ "prettier --write"
83
+ ],
84
+ "*.{json,md,yml,yaml,css,svelte,html}": [
85
+ "prettier --write"
86
+ ]
87
+ },
77
88
  "scripts": {
78
89
  "build": "tsdown",
79
90
  "build:watch": "TSDOWN_WATCH_MODE=1 tsdown",