@jax-js/jax 0.1.6 → 0.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +73 -7
- package/dist/{backend-D7s-Retx.cjs → backend-B3foXiV_.cjs} +4 -4
- package/dist/{backend-Dx6Ob2D1.js → backend-nEolvdLv.js} +4 -4
- package/dist/index.cjs +94 -23
- package/dist/index.d.cts +1561 -1538
- package/dist/index.d.ts +1561 -1538
- package/dist/index.js +94 -23
- package/dist/{webgl-CyfzNW8T.cjs → webgl-DIIbKJ0G.cjs} +1 -1
- package/dist/{webgl-CLLvzJlO.js → webgl-DweKSWEm.js} +1 -1
- package/dist/{webgpu-C-VfevQW.js → webgpu-B96vzWGE.js} +1 -1
- package/dist/{webgpu-rraa6dfz.cjs → webgpu-BykvF26B.cjs} +1 -1
- package/package.json +12 -1
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-Dx6Ob2D1.js";
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-nEolvdLv.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -306,11 +306,11 @@ function map(fn, tree, ...rest) {
|
|
|
306
306
|
}
|
|
307
307
|
/** Take a reference of every array in a tree. */
|
|
308
308
|
function ref(tree) {
|
|
309
|
-
return map((x) => x.ref, tree);
|
|
309
|
+
return map((x) => x instanceof Tracer ? x.ref : x, tree);
|
|
310
310
|
}
|
|
311
311
|
/** Dispose every array in a tree. */
|
|
312
312
|
function dispose(tree) {
|
|
313
|
-
if (tree) map((x) => x.dispose(), tree);
|
|
313
|
+
if (tree) map((x) => x instanceof Tracer ? x.dispose() : void 0, tree);
|
|
314
314
|
}
|
|
315
315
|
|
|
316
316
|
//#endregion
|
|
@@ -584,14 +584,20 @@ function shrink(x, slice) {
|
|
|
584
584
|
}
|
|
585
585
|
function pad$1(x, width) {
|
|
586
586
|
const nd = ndim$1(x);
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
else if (
|
|
590
|
-
if (width
|
|
591
|
-
const
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
587
|
+
let w;
|
|
588
|
+
if (typeof width === "number") w = [[width, width]];
|
|
589
|
+
else if (isNumberPair(width)) w = [width];
|
|
590
|
+
else if (!Array.isArray(width)) {
|
|
591
|
+
const indicesAndPairs = Object.entries(width);
|
|
592
|
+
w = rep(nd, [0, 0]);
|
|
593
|
+
for (const [k, v] of indicesAndPairs) w[checkAxis(parseInt(k), nd)] = v;
|
|
594
|
+
} else if (!width.every(isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
|
|
595
|
+
else w = width;
|
|
596
|
+
if (w.length === 1) {
|
|
597
|
+
const [w0, w1] = w[0];
|
|
598
|
+
w = rep(nd, () => [w0, w1]);
|
|
599
|
+
} else if (w.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${w.length}`);
|
|
600
|
+
return bind1(Primitive.Pad, [x], { width: w });
|
|
595
601
|
}
|
|
596
602
|
function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
597
603
|
const as = getShape(a);
|
|
@@ -767,6 +773,22 @@ var Tracer = class Tracer {
|
|
|
767
773
|
const result = reduce(this.astype(castDtype), AluOp.Add, axis, opts);
|
|
768
774
|
return result.mul(1 / n).astype(originalDtype);
|
|
769
775
|
}
|
|
776
|
+
/** Minimum of the elements of the array along a given axis. */
|
|
777
|
+
min(axis = null, opts) {
|
|
778
|
+
return reduce(this, AluOp.Min, axis, opts);
|
|
779
|
+
}
|
|
780
|
+
/** Maximum of the elements of the array along a given axis. */
|
|
781
|
+
max(axis = null, opts) {
|
|
782
|
+
return reduce(this, AluOp.Max, axis, opts);
|
|
783
|
+
}
|
|
784
|
+
/** Test whether all array elements along a given axis evaluate to true. */
|
|
785
|
+
all(axis = null, opts) {
|
|
786
|
+
return this.astype(DType.Bool).min(axis, opts);
|
|
787
|
+
}
|
|
788
|
+
/** Test whether any array element along a given axis evaluates to true. */
|
|
789
|
+
any(axis = null, opts) {
|
|
790
|
+
return this.astype(DType.Bool).max(axis, opts);
|
|
791
|
+
}
|
|
770
792
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
771
793
|
transpose(perm) {
|
|
772
794
|
return transpose$1(this, perm);
|
|
@@ -5270,6 +5292,7 @@ function lstsq(a, b) {
|
|
|
5270
5292
|
});
|
|
5271
5293
|
const llb = triangularSolve(l, lb, {
|
|
5272
5294
|
leftSide: true,
|
|
5295
|
+
lower: true,
|
|
5273
5296
|
transposeA: true
|
|
5274
5297
|
});
|
|
5275
5298
|
return matmul(at, llb.ref);
|
|
@@ -5283,6 +5306,7 @@ function lstsq(a, b) {
|
|
|
5283
5306
|
});
|
|
5284
5307
|
const llb = triangularSolve(l, lb, {
|
|
5285
5308
|
leftSide: true,
|
|
5309
|
+
lower: true,
|
|
5286
5310
|
transposeA: true
|
|
5287
5311
|
});
|
|
5288
5312
|
return llb;
|
|
@@ -5570,6 +5594,7 @@ __export(numpy_exports, {
|
|
|
5570
5594
|
moveaxis: () => moveaxis$1,
|
|
5571
5595
|
multiply: () => multiply,
|
|
5572
5596
|
nan: () => nan,
|
|
5597
|
+
nanToNum: () => nanToNum,
|
|
5573
5598
|
ndim: () => ndim,
|
|
5574
5599
|
negative: () => negative,
|
|
5575
5600
|
notEqual: () => notEqual,
|
|
@@ -5767,24 +5792,22 @@ function max(a, axis = null, opts) {
|
|
|
5767
5792
|
return reduce(a, AluOp.Max, axis, opts);
|
|
5768
5793
|
}
|
|
5769
5794
|
/**
|
|
5770
|
-
* Test whether
|
|
5795
|
+
* Test whether any array element along a given axis evaluates to True.
|
|
5771
5796
|
*
|
|
5772
5797
|
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
5773
5798
|
* removed. If axis is None, returns a scalar.
|
|
5774
5799
|
*/
|
|
5775
|
-
function
|
|
5776
|
-
|
|
5777
|
-
return min(a, axis, opts);
|
|
5800
|
+
function any(a, axis = null, opts) {
|
|
5801
|
+
return fudgeArray(a).any(axis, opts);
|
|
5778
5802
|
}
|
|
5779
5803
|
/**
|
|
5780
|
-
* Test whether
|
|
5804
|
+
* Test whether all array elements along a given axis evaluate to True.
|
|
5781
5805
|
*
|
|
5782
5806
|
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
5783
5807
|
* removed. If axis is None, returns a scalar.
|
|
5784
5808
|
*/
|
|
5785
|
-
function
|
|
5786
|
-
|
|
5787
|
-
return max(a, axis, opts);
|
|
5809
|
+
function all(a, axis = null, opts) {
|
|
5810
|
+
return fudgeArray(a).all(axis, opts);
|
|
5788
5811
|
}
|
|
5789
5812
|
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
5790
5813
|
function ptp(a, axis = null, opts) {
|
|
@@ -5885,7 +5908,7 @@ function split$1(a, indicesOrSections, axis = 0) {
|
|
|
5885
5908
|
const partSize = size$1 / indicesOrSections;
|
|
5886
5909
|
sizes = rep(indicesOrSections, partSize);
|
|
5887
5910
|
} else {
|
|
5888
|
-
const indices = indicesOrSections;
|
|
5911
|
+
const indices = indicesOrSections.map((i) => i < 0 ? i + size$1 : i);
|
|
5889
5912
|
sizes = [indices[0]];
|
|
5890
5913
|
for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
|
|
5891
5914
|
sizes.push(size$1 - indices[indices.length - 1]);
|
|
@@ -6833,6 +6856,21 @@ function isposinf(x) {
|
|
|
6833
6856
|
return isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
|
|
6834
6857
|
}
|
|
6835
6858
|
/**
|
|
6859
|
+
* Replace NaN and infinite entries in an array.
|
|
6860
|
+
*
|
|
6861
|
+
* By default, NaNs are replaced with `0.0`, and infinities are are substituted
|
|
6862
|
+
* with the corresponding maximum or minimum finite values.
|
|
6863
|
+
*/
|
|
6864
|
+
function nanToNum(x, { nan: nan$1 = 0, posinf = null, neginf = null } = {}) {
|
|
6865
|
+
x = fudgeArray(x);
|
|
6866
|
+
x = where(isnan(x.ref), nan$1, x);
|
|
6867
|
+
posinf ??= isFloatDtype(x.dtype) ? finfo(x.dtype).max : iinfo(x.dtype).max;
|
|
6868
|
+
neginf ??= isFloatDtype(x.dtype) ? finfo(x.dtype).min : iinfo(x.dtype).min;
|
|
6869
|
+
x = where(isposinf(x.ref), posinf, x);
|
|
6870
|
+
x = where(isneginf(x.ref), neginf, x);
|
|
6871
|
+
return x;
|
|
6872
|
+
}
|
|
6873
|
+
/**
|
|
6836
6874
|
* @function
|
|
6837
6875
|
* Test element-wise for finite values (not infinity or NaN).
|
|
6838
6876
|
*/
|
|
@@ -6930,7 +6968,10 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
|
|
|
6930
6968
|
b = fudgeArray(b);
|
|
6931
6969
|
if (!leftSide) transposeA = !transposeA;
|
|
6932
6970
|
else b = moveaxis$1(b, -2, -1);
|
|
6933
|
-
if (transposeA)
|
|
6971
|
+
if (transposeA) {
|
|
6972
|
+
a = moveaxis$1(a, -2, -1);
|
|
6973
|
+
lower = !lower;
|
|
6974
|
+
}
|
|
6934
6975
|
let x = triangularSolve$1(a, b, {
|
|
6935
6976
|
lower,
|
|
6936
6977
|
unitDiagonal
|
|
@@ -7541,8 +7582,6 @@ function oneHot(x, numClasses) {
|
|
|
7541
7582
|
* `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
|
|
7542
7583
|
*/
|
|
7543
7584
|
function dotProductAttention(query, key$1, value, opts = {}) {
|
|
7544
|
-
if (opts.querySeqLengths !== void 0 || opts.keyValueSeqLengths !== void 0) throw new Error("Sequence length masking is not yet implemented");
|
|
7545
|
-
if (opts.localWindowSize !== void 0) throw new Error("Local attention is not yet implemented");
|
|
7546
7585
|
query = fudgeArray(query);
|
|
7547
7586
|
key$1 = fudgeArray(key$1);
|
|
7548
7587
|
value = fudgeArray(value);
|
|
@@ -7580,6 +7619,38 @@ function dotProductAttention(query, key$1, value, opts = {}) {
|
|
|
7580
7619
|
const causalMask = tri(L, S, 0, { dtype: DType.Bool });
|
|
7581
7620
|
scores = where(causalMask, scores, -Infinity);
|
|
7582
7621
|
}
|
|
7622
|
+
if (opts.localWindowSize !== void 0) {
|
|
7623
|
+
const [before, after] = typeof opts.localWindowSize === "number" ? [opts.localWindowSize, opts.localWindowSize] : opts.localWindowSize;
|
|
7624
|
+
if (before < 0 || after < 0 || !Number.isInteger(before) || !Number.isInteger(after)) throw new Error(`dotProductAttention: localWindowSize values must be non-negative, got ${opts.localWindowSize}`);
|
|
7625
|
+
const localMask = tri(L, S, after, { dtype: DType.Bool }).mul(tri(L, S, -before - 1, { dtype: DType.Bool }).notEqual(true));
|
|
7626
|
+
scores = where(localMask, scores, -Infinity);
|
|
7627
|
+
}
|
|
7628
|
+
if (opts.querySeqLengths !== void 0) {
|
|
7629
|
+
const sl = expandDims(opts.querySeqLengths, [
|
|
7630
|
+
-1,
|
|
7631
|
+
-2,
|
|
7632
|
+
-3
|
|
7633
|
+
]);
|
|
7634
|
+
scores = where(arange(L).reshape([
|
|
7635
|
+
1,
|
|
7636
|
+
1,
|
|
7637
|
+
L,
|
|
7638
|
+
1
|
|
7639
|
+
]).less(sl), scores, -Infinity);
|
|
7640
|
+
}
|
|
7641
|
+
if (opts.keyValueSeqLengths !== void 0) {
|
|
7642
|
+
const sl = expandDims(opts.keyValueSeqLengths, [
|
|
7643
|
+
-1,
|
|
7644
|
+
-2,
|
|
7645
|
+
-3
|
|
7646
|
+
]);
|
|
7647
|
+
scores = where(arange(S).reshape([
|
|
7648
|
+
1,
|
|
7649
|
+
1,
|
|
7650
|
+
1,
|
|
7651
|
+
S
|
|
7652
|
+
]).less(sl), scores, -Infinity);
|
|
7653
|
+
}
|
|
7583
7654
|
const attn = softmax(scores, -1);
|
|
7584
7655
|
const out = einsum("BNLS,BSNH->BLNH", attn, value);
|
|
7585
7656
|
return isRank3 ? out.reshape([
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-Dx6Ob2D1.js";
|
|
1
|
+
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-nEolvdLv.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgl/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-Dx6Ob2D1.js";
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-nEolvdLv.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@jax-js/jax",
|
|
3
|
-
"version": "0.1.
|
|
3
|
+
"version": "0.1.8",
|
|
4
4
|
"description": "Numerical computing and ML in the browser",
|
|
5
5
|
"keywords": [
|
|
6
6
|
"machine learning",
|
|
@@ -44,6 +44,8 @@
|
|
|
44
44
|
"eslint": "^9.31.0",
|
|
45
45
|
"eslint-plugin-import": "^2.32.0",
|
|
46
46
|
"globals": "^16.0.0",
|
|
47
|
+
"husky": "^9.1.7",
|
|
48
|
+
"lint-staged": "^16.2.7",
|
|
47
49
|
"playwright": "~1.52.0",
|
|
48
50
|
"prettier": "^3.6.2",
|
|
49
51
|
"prettier-plugin-svelte": "^3.4.0",
|
|
@@ -74,6 +76,15 @@
|
|
|
74
76
|
],
|
|
75
77
|
"proseWrap": "always"
|
|
76
78
|
},
|
|
79
|
+
"lint-staged": {
|
|
80
|
+
"*.{ts,tsx,js,jsx}": [
|
|
81
|
+
"eslint --fix",
|
|
82
|
+
"prettier --write"
|
|
83
|
+
],
|
|
84
|
+
"*.{json,md,yml,yaml,css,svelte,html}": [
|
|
85
|
+
"prettier --write"
|
|
86
|
+
]
|
|
87
|
+
},
|
|
77
88
|
"scripts": {
|
|
78
89
|
"build": "tsdown",
|
|
79
90
|
"build:watch": "TSDOWN_WATCH_MODE=1 tsdown",
|