@jax-js/jax 0.1.6 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +60 -7
- package/dist/{backend-D7s-Retx.cjs → backend-B3foXiV_.cjs} +4 -4
- package/dist/{backend-Dx6Ob2D1.js → backend-nEolvdLv.js} +4 -4
- package/dist/index.cjs +88 -22
- package/dist/index.d.cts +1561 -1538
- package/dist/index.d.ts +1561 -1538
- package/dist/index.js +88 -22
- package/dist/{webgl-CyfzNW8T.cjs → webgl-DIIbKJ0G.cjs} +1 -1
- package/dist/{webgl-CLLvzJlO.js → webgl-DweKSWEm.js} +1 -1
- package/dist/{webgpu-C-VfevQW.js → webgpu-B96vzWGE.js} +1 -1
- package/dist/{webgpu-rraa6dfz.cjs → webgpu-BykvF26B.cjs} +1 -1
- package/package.json +1 -1
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-Dx6Ob2D1.js";
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-nEolvdLv.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -306,11 +306,11 @@ function map(fn, tree, ...rest) {
|
|
|
306
306
|
}
|
|
307
307
|
/** Take a reference of every array in a tree. */
|
|
308
308
|
function ref(tree) {
|
|
309
|
-
return map((x) => x.ref, tree);
|
|
309
|
+
return map((x) => x instanceof Tracer ? x.ref : x, tree);
|
|
310
310
|
}
|
|
311
311
|
/** Dispose every array in a tree. */
|
|
312
312
|
function dispose(tree) {
|
|
313
|
-
if (tree) map((x) => x.dispose(), tree);
|
|
313
|
+
if (tree) map((x) => x instanceof Tracer ? x.dispose() : void 0, tree);
|
|
314
314
|
}
|
|
315
315
|
|
|
316
316
|
//#endregion
|
|
@@ -584,14 +584,20 @@ function shrink(x, slice) {
|
|
|
584
584
|
}
|
|
585
585
|
function pad$1(x, width) {
|
|
586
586
|
const nd = ndim$1(x);
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
else if (
|
|
590
|
-
if (width
|
|
591
|
-
const
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
587
|
+
let w;
|
|
588
|
+
if (typeof width === "number") w = [[width, width]];
|
|
589
|
+
else if (isNumberPair(width)) w = [width];
|
|
590
|
+
else if (!Array.isArray(width)) {
|
|
591
|
+
const indicesAndPairs = Object.entries(width);
|
|
592
|
+
w = rep(nd, [0, 0]);
|
|
593
|
+
for (const [k, v] of indicesAndPairs) w[checkAxis(parseInt(k), nd)] = v;
|
|
594
|
+
} else if (!width.every(isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
|
|
595
|
+
else w = width;
|
|
596
|
+
if (w.length === 1) {
|
|
597
|
+
const [w0, w1] = w[0];
|
|
598
|
+
w = rep(nd, () => [w0, w1]);
|
|
599
|
+
} else if (w.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${w.length}`);
|
|
600
|
+
return bind1(Primitive.Pad, [x], { width: w });
|
|
595
601
|
}
|
|
596
602
|
function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
597
603
|
const as = getShape(a);
|
|
@@ -767,6 +773,22 @@ var Tracer = class Tracer {
|
|
|
767
773
|
const result = reduce(this.astype(castDtype), AluOp.Add, axis, opts);
|
|
768
774
|
return result.mul(1 / n).astype(originalDtype);
|
|
769
775
|
}
|
|
776
|
+
/** Minimum of the elements of the array along a given axis. */
|
|
777
|
+
min(axis = null, opts) {
|
|
778
|
+
return reduce(this, AluOp.Min, axis, opts);
|
|
779
|
+
}
|
|
780
|
+
/** Maximum of the elements of the array along a given axis. */
|
|
781
|
+
max(axis = null, opts) {
|
|
782
|
+
return reduce(this, AluOp.Max, axis, opts);
|
|
783
|
+
}
|
|
784
|
+
/** Test whether all array elements along a given axis evaluate to true. */
|
|
785
|
+
all(axis = null, opts) {
|
|
786
|
+
return this.astype(DType.Bool).min(axis, opts);
|
|
787
|
+
}
|
|
788
|
+
/** Test whether any array element along a given axis evaluates to true. */
|
|
789
|
+
any(axis = null, opts) {
|
|
790
|
+
return this.astype(DType.Bool).max(axis, opts);
|
|
791
|
+
}
|
|
770
792
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
771
793
|
transpose(perm) {
|
|
772
794
|
return transpose$1(this, perm);
|
|
@@ -5570,6 +5592,7 @@ __export(numpy_exports, {
|
|
|
5570
5592
|
moveaxis: () => moveaxis$1,
|
|
5571
5593
|
multiply: () => multiply,
|
|
5572
5594
|
nan: () => nan,
|
|
5595
|
+
nanToNum: () => nanToNum,
|
|
5573
5596
|
ndim: () => ndim,
|
|
5574
5597
|
negative: () => negative,
|
|
5575
5598
|
notEqual: () => notEqual,
|
|
@@ -5767,24 +5790,22 @@ function max(a, axis = null, opts) {
|
|
|
5767
5790
|
return reduce(a, AluOp.Max, axis, opts);
|
|
5768
5791
|
}
|
|
5769
5792
|
/**
|
|
5770
|
-
* Test whether
|
|
5793
|
+
* Test whether any array element along a given axis evaluates to True.
|
|
5771
5794
|
*
|
|
5772
5795
|
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
5773
5796
|
* removed. If axis is None, returns a scalar.
|
|
5774
5797
|
*/
|
|
5775
|
-
function
|
|
5776
|
-
|
|
5777
|
-
return min(a, axis, opts);
|
|
5798
|
+
function any(a, axis = null, opts) {
|
|
5799
|
+
return fudgeArray(a).any(axis, opts);
|
|
5778
5800
|
}
|
|
5779
5801
|
/**
|
|
5780
|
-
* Test whether
|
|
5802
|
+
* Test whether all array elements along a given axis evaluate to True.
|
|
5781
5803
|
*
|
|
5782
5804
|
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
5783
5805
|
* removed. If axis is None, returns a scalar.
|
|
5784
5806
|
*/
|
|
5785
|
-
function
|
|
5786
|
-
|
|
5787
|
-
return max(a, axis, opts);
|
|
5807
|
+
function all(a, axis = null, opts) {
|
|
5808
|
+
return fudgeArray(a).all(axis, opts);
|
|
5788
5809
|
}
|
|
5789
5810
|
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
5790
5811
|
function ptp(a, axis = null, opts) {
|
|
@@ -5885,7 +5906,7 @@ function split$1(a, indicesOrSections, axis = 0) {
|
|
|
5885
5906
|
const partSize = size$1 / indicesOrSections;
|
|
5886
5907
|
sizes = rep(indicesOrSections, partSize);
|
|
5887
5908
|
} else {
|
|
5888
|
-
const indices = indicesOrSections;
|
|
5909
|
+
const indices = indicesOrSections.map((i) => i < 0 ? i + size$1 : i);
|
|
5889
5910
|
sizes = [indices[0]];
|
|
5890
5911
|
for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
|
|
5891
5912
|
sizes.push(size$1 - indices[indices.length - 1]);
|
|
@@ -6833,6 +6854,21 @@ function isposinf(x) {
|
|
|
6833
6854
|
return isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
|
|
6834
6855
|
}
|
|
6835
6856
|
/**
|
|
6857
|
+
* Replace NaN and infinite entries in an array.
|
|
6858
|
+
*
|
|
6859
|
+
* By default, NaNs are replaced with `0.0`, and infinities are substituted
|
|
6860
|
+
* with the corresponding maximum or minimum finite values.
|
|
6861
|
+
*/
|
|
6862
|
+
function nanToNum(x, { nan: nan$1 = 0, posinf = null, neginf = null } = {}) {
|
|
6863
|
+
x = fudgeArray(x);
|
|
6864
|
+
x = where(isnan(x.ref), nan$1, x);
|
|
6865
|
+
posinf ??= isFloatDtype(x.dtype) ? finfo(x.dtype).max : iinfo(x.dtype).max;
|
|
6866
|
+
neginf ??= isFloatDtype(x.dtype) ? finfo(x.dtype).min : iinfo(x.dtype).min;
|
|
6867
|
+
x = where(isposinf(x.ref), posinf, x);
|
|
6868
|
+
x = where(isneginf(x.ref), neginf, x);
|
|
6869
|
+
return x;
|
|
6870
|
+
}
|
|
6871
|
+
/**
|
|
6836
6872
|
* @function
|
|
6837
6873
|
* Test element-wise for finite values (not infinity or NaN).
|
|
6838
6874
|
*/
|
|
@@ -7541,8 +7577,6 @@ function oneHot(x, numClasses) {
|
|
|
7541
7577
|
* `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
|
|
7542
7578
|
*/
|
|
7543
7579
|
function dotProductAttention(query, key$1, value, opts = {}) {
|
|
7544
|
-
if (opts.querySeqLengths !== void 0 || opts.keyValueSeqLengths !== void 0) throw new Error("Sequence length masking is not yet implemented");
|
|
7545
|
-
if (opts.localWindowSize !== void 0) throw new Error("Local attention is not yet implemented");
|
|
7546
7580
|
query = fudgeArray(query);
|
|
7547
7581
|
key$1 = fudgeArray(key$1);
|
|
7548
7582
|
value = fudgeArray(value);
|
|
@@ -7580,6 +7614,38 @@ function dotProductAttention(query, key$1, value, opts = {}) {
|
|
|
7580
7614
|
const causalMask = tri(L, S, 0, { dtype: DType.Bool });
|
|
7581
7615
|
scores = where(causalMask, scores, -Infinity);
|
|
7582
7616
|
}
|
|
7617
|
+
if (opts.localWindowSize !== void 0) {
|
|
7618
|
+
const [before, after] = typeof opts.localWindowSize === "number" ? [opts.localWindowSize, opts.localWindowSize] : opts.localWindowSize;
|
|
7619
|
+
if (before < 0 || after < 0 || !Number.isInteger(before) || !Number.isInteger(after)) throw new Error(`dotProductAttention: localWindowSize values must be non-negative, got ${opts.localWindowSize}`);
|
|
7620
|
+
const localMask = tri(L, S, after, { dtype: DType.Bool }).mul(tri(L, S, -before - 1, { dtype: DType.Bool }).notEqual(true));
|
|
7621
|
+
scores = where(localMask, scores, -Infinity);
|
|
7622
|
+
}
|
|
7623
|
+
if (opts.querySeqLengths !== void 0) {
|
|
7624
|
+
const sl = expandDims(opts.querySeqLengths, [
|
|
7625
|
+
-1,
|
|
7626
|
+
-2,
|
|
7627
|
+
-3
|
|
7628
|
+
]);
|
|
7629
|
+
scores = where(arange(L).reshape([
|
|
7630
|
+
1,
|
|
7631
|
+
1,
|
|
7632
|
+
L,
|
|
7633
|
+
1
|
|
7634
|
+
]).less(sl), scores, -Infinity);
|
|
7635
|
+
}
|
|
7636
|
+
if (opts.keyValueSeqLengths !== void 0) {
|
|
7637
|
+
const sl = expandDims(opts.keyValueSeqLengths, [
|
|
7638
|
+
-1,
|
|
7639
|
+
-2,
|
|
7640
|
+
-3
|
|
7641
|
+
]);
|
|
7642
|
+
scores = where(arange(S).reshape([
|
|
7643
|
+
1,
|
|
7644
|
+
1,
|
|
7645
|
+
1,
|
|
7646
|
+
S
|
|
7647
|
+
]).less(sl), scores, -Infinity);
|
|
7648
|
+
}
|
|
7583
7649
|
const attn = softmax(scores, -1);
|
|
7584
7650
|
const out = einsum("BNLS,BSNH->BLNH", attn, value);
|
|
7585
7651
|
return isRank3 ? out.reshape([
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-Dx6Ob2D1.js";
|
|
1
|
+
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-nEolvdLv.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgl/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-Dx6Ob2D1.js";
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-nEolvdLv.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|