@jax-js/jax 0.1.8 → 0.1.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +46 -29
- package/dist/{backend-nEolvdLv.js → backend-Ctqs8la1.js} +122 -15
- package/dist/{backend-B3foXiV_.cjs → backend-DMauYnfl.cjs} +157 -14
- package/dist/index.cjs +331 -46
- package/dist/index.d.cts +175 -31
- package/dist/index.d.ts +175 -31
- package/dist/index.js +331 -47
- package/dist/{webgl-DweKSWEm.js → webgl-CvQ1QBX1.js} +1 -1
- package/dist/{webgl-DIIbKJ0G.cjs → webgl-kvVt7-T7.cjs} +1 -1
- package/dist/{webgpu-BykvF26B.cjs → webgpu-DMSx7a6M.cjs} +160 -15
- package/dist/{webgpu-B96vzWGE.js → webgpu-v_W_-oKw.js} +160 -15
- package/package.json +5 -16
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-Ctqs8la1.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -807,6 +807,11 @@ var Tracer = class Tracer {
|
|
|
807
807
|
if (this.dtype === dtype) return this;
|
|
808
808
|
return cast(this, dtype);
|
|
809
809
|
}
|
|
810
|
+
/** Return a bitwise cast of the array, viewed as a new dtype. */
|
|
811
|
+
view(dtype) {
|
|
812
|
+
if (!dtype || dtype === this.dtype) return this;
|
|
813
|
+
return bitcast(this, dtype);
|
|
814
|
+
}
|
|
810
815
|
/** Subtract an array from this one. */
|
|
811
816
|
sub(other) {
|
|
812
817
|
return this.add(neg(other));
|
|
@@ -889,18 +894,25 @@ var Tracer = class Tracer {
|
|
|
889
894
|
return sort$1(this.transpose(perm)).transpose(invertPermutation(perm));
|
|
890
895
|
}
|
|
891
896
|
/**
|
|
892
|
-
* Return the indices that would sort an array.
|
|
893
|
-
* sorting algorithm; it
|
|
897
|
+
* Return the indices that would sort an array. Unlike `sort`, this is
|
|
898
|
+
* guaranteed to be a stable sorting algorithm; it always returns the smaller
|
|
899
|
+
* index first in event of ties.
|
|
894
900
|
*
|
|
895
901
|
* See `jax.numpy.argsort` for full docs.
|
|
896
902
|
*/
|
|
897
903
|
argsort(axis = -1) {
|
|
898
904
|
axis = checkAxis(axis, this.ndim);
|
|
899
|
-
if (axis === this.ndim - 1)
|
|
905
|
+
if (axis === this.ndim - 1) {
|
|
906
|
+
const [y$1, yi$1] = argsort$1(this);
|
|
907
|
+
y$1.dispose();
|
|
908
|
+
return yi$1;
|
|
909
|
+
}
|
|
900
910
|
const perm = range(this.ndim);
|
|
901
911
|
perm.splice(axis, 1);
|
|
902
912
|
perm.push(axis);
|
|
903
|
-
|
|
913
|
+
const [y, yi] = argsort$1(this.transpose(perm));
|
|
914
|
+
y.dispose();
|
|
915
|
+
return yi.transpose(invertPermutation(perm));
|
|
904
916
|
}
|
|
905
917
|
/**
|
|
906
918
|
* Slice an array along one or more axes.
|
|
@@ -1617,7 +1629,7 @@ const abstractEvalRules = {
|
|
|
1617
1629
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1618
1630
|
},
|
|
1619
1631
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
1620
|
-
if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
1632
|
+
if (x.dtype !== dtype && (x.dtype === DType.Bool || dtype === DType.Bool)) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
1621
1633
|
if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
1622
1634
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1623
1635
|
},
|
|
@@ -3039,8 +3051,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3039
3051
|
return [x.#unary(AluOp.Cast, dtype)];
|
|
3040
3052
|
},
|
|
3041
3053
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
3042
|
-
if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
3043
3054
|
if (x.dtype === dtype) return [x];
|
|
3055
|
+
if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
3044
3056
|
if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
3045
3057
|
if (x.#source instanceof AluExp) return [x.#unary(AluOp.Bitcast, dtype)];
|
|
3046
3058
|
else {
|
|
@@ -3381,32 +3393,26 @@ function fullInternal(aval, fillValue, device) {
|
|
|
3381
3393
|
committed: device != void 0
|
|
3382
3394
|
});
|
|
3383
3395
|
}
|
|
3384
|
-
function zerosLike$1(val,
|
|
3385
|
-
return fullLike(val, 0,
|
|
3396
|
+
function zerosLike$1(val, opts) {
|
|
3397
|
+
return fullLike(val, 0, opts);
|
|
3386
3398
|
}
|
|
3387
|
-
function onesLike$1(val,
|
|
3388
|
-
return fullLike(val, 1,
|
|
3399
|
+
function onesLike$1(val, opts) {
|
|
3400
|
+
return fullLike(val, 1, opts);
|
|
3389
3401
|
}
|
|
3390
|
-
function fullLike(val, fillValue, dtype) {
|
|
3402
|
+
function fullLike(val, fillValue, { dtype, shape: shape$1, device } = {}) {
|
|
3391
3403
|
const aval = getAval(val);
|
|
3392
3404
|
if (val instanceof Tracer) val.dispose();
|
|
3393
3405
|
if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
|
|
3394
|
-
const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
|
|
3395
|
-
return fullInternal(sa, fillValue);
|
|
3406
|
+
const sa = new ShapedArray(shape$1 ?? aval.shape, dtype ?? aval.dtype, aval.weakType && dtype === void 0);
|
|
3407
|
+
return fullInternal(sa, fillValue, device);
|
|
3396
3408
|
}
|
|
3397
3409
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
3398
|
-
function zeros(shape$1,
|
|
3399
|
-
return full(shape$1, 0,
|
|
3400
|
-
dtype,
|
|
3401
|
-
device
|
|
3402
|
-
});
|
|
3410
|
+
function zeros(shape$1, opts) {
|
|
3411
|
+
return full(shape$1, 0, opts);
|
|
3403
3412
|
}
|
|
3404
3413
|
/** Return a new array of given shape and type, filled with ones. */
|
|
3405
|
-
function ones(shape$1,
|
|
3406
|
-
return full(shape$1, 1,
|
|
3407
|
-
dtype,
|
|
3408
|
-
device
|
|
3409
|
-
});
|
|
3414
|
+
function ones(shape$1, opts) {
|
|
3415
|
+
return full(shape$1, 1, opts);
|
|
3410
3416
|
}
|
|
3411
3417
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
3412
3418
|
function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
@@ -4141,6 +4147,7 @@ const jvpRules = {
|
|
|
4141
4147
|
},
|
|
4142
4148
|
[Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
|
|
4143
4149
|
const x = triangularSolve$1(a.ref, b, { unitDiagonal });
|
|
4150
|
+
da = unitDiagonal ? triu(da, 1) : triu(da);
|
|
4144
4151
|
const dax = batchMatmulT(da, x.ref);
|
|
4145
4152
|
const rhsT = db.sub(mT(dax));
|
|
4146
4153
|
const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
|
|
@@ -5216,6 +5223,7 @@ function ifft(a, axis = -1) {
|
|
|
5216
5223
|
var numpy_linalg_exports = {};
|
|
5217
5224
|
__export(numpy_linalg_exports, {
|
|
5218
5225
|
cholesky: () => cholesky,
|
|
5226
|
+
cross: () => cross$1,
|
|
5219
5227
|
det: () => det,
|
|
5220
5228
|
diagonal: () => diagonal,
|
|
5221
5229
|
inv: () => inv,
|
|
@@ -5246,6 +5254,19 @@ function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
|
|
|
5246
5254
|
if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
|
|
5247
5255
|
return cholesky$1(a, { upper });
|
|
5248
5256
|
}
|
|
5257
|
+
/**
|
|
5258
|
+
* Compute the cross-product of two 3D vectors.
|
|
5259
|
+
*
|
|
5260
|
+
* This is a simpler and less flexible version of `jax.numpy.cross()`.
|
|
5261
|
+
* Both inputs must have size 3 along the specified axis.
|
|
5262
|
+
*/
|
|
5263
|
+
function cross$1(x1, x2, axis = -1) {
|
|
5264
|
+
const a1 = checkAxis(axis, ndim(x1));
|
|
5265
|
+
const a2 = checkAxis(axis, ndim(x2));
|
|
5266
|
+
if (shape(x1)[a1] !== 3) throw new Error(`linalg.cross: x1 must have size 3 along axis ${axis}, got ${shape(x1)[a1]}`);
|
|
5267
|
+
if (shape(x2)[a2] !== 3) throw new Error(`linalg.cross: x2 must have size 3 along axis ${axis}, got ${shape(x2)[a2]}`);
|
|
5268
|
+
return cross(x1, x2, { axis });
|
|
5269
|
+
}
|
|
5249
5270
|
/** Compute the determinant of a square matrix (batched). */
|
|
5250
5271
|
function det(a) {
|
|
5251
5272
|
a = fudgeArray(a);
|
|
@@ -5261,7 +5282,7 @@ function det(a) {
|
|
|
5261
5282
|
function inv(a) {
|
|
5262
5283
|
a = fudgeArray(a);
|
|
5263
5284
|
const n = checkSquare("inv", a);
|
|
5264
|
-
return solve(a, eye(n));
|
|
5285
|
+
return solve(a, eye(n, void 0, { dtype: a.dtype }));
|
|
5265
5286
|
}
|
|
5266
5287
|
/**
|
|
5267
5288
|
* Return the least-squares solution to a linear equation.
|
|
@@ -5295,7 +5316,7 @@ function lstsq(a, b) {
|
|
|
5295
5316
|
lower: true,
|
|
5296
5317
|
transposeA: true
|
|
5297
5318
|
});
|
|
5298
|
-
return matmul(at, llb
|
|
5319
|
+
return matmul(at, llb);
|
|
5299
5320
|
} else {
|
|
5300
5321
|
const ata = matmul(at.ref, a);
|
|
5301
5322
|
const l = cholesky(ata, { symmetrizeInput: false });
|
|
@@ -5318,8 +5339,9 @@ function matrixPower(a, n) {
|
|
|
5318
5339
|
a = fudgeArray(a);
|
|
5319
5340
|
const m = checkSquare("matrixPower", a);
|
|
5320
5341
|
if (n === 0) {
|
|
5342
|
+
const dtype = a.dtype;
|
|
5321
5343
|
a.dispose();
|
|
5322
|
-
return broadcastTo(eye(m), a.shape);
|
|
5344
|
+
return broadcastTo(eye(m, void 0, { dtype }), a.shape);
|
|
5323
5345
|
}
|
|
5324
5346
|
if (n < 0) {
|
|
5325
5347
|
a = inv(a);
|
|
@@ -5386,7 +5408,7 @@ function solve(a, b) {
|
|
|
5386
5408
|
lower: true,
|
|
5387
5409
|
unitDiagonal: true
|
|
5388
5410
|
});
|
|
5389
|
-
let x = triangularSolve(lu$2, LPb
|
|
5411
|
+
let x = triangularSolve(lu$2, LPb, {
|
|
5390
5412
|
leftSide: true,
|
|
5391
5413
|
lower: false
|
|
5392
5414
|
});
|
|
@@ -5501,13 +5523,17 @@ __export(numpy_exports, {
|
|
|
5501
5523
|
argmax: () => argmax,
|
|
5502
5524
|
argmin: () => argmin,
|
|
5503
5525
|
argsort: () => argsort,
|
|
5526
|
+
around: () => round,
|
|
5504
5527
|
array: () => array,
|
|
5528
|
+
arrayEqual: () => arrayEqual,
|
|
5529
|
+
arrayEquiv: () => arrayEquiv,
|
|
5505
5530
|
asin: () => asin,
|
|
5506
5531
|
asinh: () => arcsinh,
|
|
5507
5532
|
astype: () => astype,
|
|
5508
5533
|
atan: () => atan,
|
|
5509
5534
|
atan2: () => atan2,
|
|
5510
5535
|
atanh: () => arctanh,
|
|
5536
|
+
average: () => average,
|
|
5511
5537
|
bool: () => bool,
|
|
5512
5538
|
broadcastArrays: () => broadcastArrays,
|
|
5513
5539
|
broadcastShapes: () => broadcastShapes,
|
|
@@ -5518,11 +5544,13 @@ __export(numpy_exports, {
|
|
|
5518
5544
|
columnStack: () => columnStack,
|
|
5519
5545
|
concatenate: () => concatenate,
|
|
5520
5546
|
convolve: () => convolve,
|
|
5547
|
+
copysign: () => copysign,
|
|
5521
5548
|
corrcoef: () => corrcoef,
|
|
5522
5549
|
correlate: () => correlate,
|
|
5523
5550
|
cos: () => cos,
|
|
5524
5551
|
cosh: () => cosh,
|
|
5525
5552
|
cov: () => cov,
|
|
5553
|
+
cross: () => cross,
|
|
5526
5554
|
cumsum: () => cumsum,
|
|
5527
5555
|
cumulativeSum: () => cumsum,
|
|
5528
5556
|
deg2rad: () => deg2rad,
|
|
@@ -5558,7 +5586,6 @@ __export(numpy_exports, {
|
|
|
5558
5586
|
fullLike: () => fullLike$1,
|
|
5559
5587
|
greater: () => greater,
|
|
5560
5588
|
greaterEqual: () => greaterEqual,
|
|
5561
|
-
hamming: () => hamming,
|
|
5562
5589
|
hann: () => hann,
|
|
5563
5590
|
heaviside: () => heaviside,
|
|
5564
5591
|
hstack: () => hstack,
|
|
@@ -5582,9 +5609,14 @@ __export(numpy_exports, {
|
|
|
5582
5609
|
log10: () => log10,
|
|
5583
5610
|
log1p: () => log1p,
|
|
5584
5611
|
log2: () => log2,
|
|
5612
|
+
logicalAnd: () => logicalAnd,
|
|
5613
|
+
logicalNot: () => logicalNot,
|
|
5614
|
+
logicalOr: () => logicalOr,
|
|
5615
|
+
logicalXor: () => logicalXor,
|
|
5585
5616
|
logspace: () => logspace,
|
|
5586
5617
|
matmul: () => matmul,
|
|
5587
5618
|
matrixTranspose: () => matrixTranspose,
|
|
5619
|
+
matvec: () => matvec,
|
|
5588
5620
|
max: () => max,
|
|
5589
5621
|
maximum: () => maximum,
|
|
5590
5622
|
mean: () => mean,
|
|
@@ -5617,6 +5649,8 @@ __export(numpy_exports, {
|
|
|
5617
5649
|
remainder: () => remainder,
|
|
5618
5650
|
repeat: () => repeat,
|
|
5619
5651
|
reshape: () => reshape,
|
|
5652
|
+
rint: () => rint,
|
|
5653
|
+
round: () => round,
|
|
5620
5654
|
shape: () => shape,
|
|
5621
5655
|
sign: () => sign,
|
|
5622
5656
|
sin: () => sin,
|
|
@@ -5649,6 +5683,7 @@ __export(numpy_exports, {
|
|
|
5649
5683
|
var_: () => var_,
|
|
5650
5684
|
vdot: () => vdot,
|
|
5651
5685
|
vecdot: () => vecdot,
|
|
5686
|
+
vecmat: () => vecmat,
|
|
5652
5687
|
vstack: () => vstack,
|
|
5653
5688
|
where: () => where,
|
|
5654
5689
|
zeros: () => zeros,
|
|
@@ -5712,6 +5747,22 @@ const notEqual = notEqual$1;
|
|
|
5712
5747
|
const greaterEqual = greaterEqual$1;
|
|
5713
5748
|
/** @function Compare two arrays element-wise. */
|
|
5714
5749
|
const lessEqual = lessEqual$1;
|
|
5750
|
+
/** Compute element-wise logical AND. */
|
|
5751
|
+
function logicalAnd(x, y) {
|
|
5752
|
+
return astype(x, DType.Bool).mul(astype(y, DType.Bool));
|
|
5753
|
+
}
|
|
5754
|
+
/** Compute element-wise logical OR. */
|
|
5755
|
+
function logicalOr(x, y) {
|
|
5756
|
+
return astype(x, DType.Bool).add(astype(y, DType.Bool));
|
|
5757
|
+
}
|
|
5758
|
+
/** Compute element-wise logical XOR. */
|
|
5759
|
+
function logicalXor(x, y) {
|
|
5760
|
+
return notEqual(astype(x, DType.Bool), astype(y, DType.Bool));
|
|
5761
|
+
}
|
|
5762
|
+
/** Compute element-wise logical NOT. */
|
|
5763
|
+
function logicalNot(x) {
|
|
5764
|
+
return notEqual(astype(x, DType.Bool), true);
|
|
5765
|
+
}
|
|
5715
5766
|
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
5716
5767
|
const where = where$1;
|
|
5717
5768
|
/**
|
|
@@ -5819,6 +5870,34 @@ function mean(a, axis = null, opts) {
|
|
|
5819
5870
|
return fudgeArray(a).mean(axis, opts);
|
|
5820
5871
|
}
|
|
5821
5872
|
/**
|
|
5873
|
+
* Compute the weighted average along the specified axis.
|
|
5874
|
+
*
|
|
5875
|
+
* If no axis is specified, mean is computed along all the axes. The weights
|
|
5876
|
+
* should have shape matching that of `a`, or if an axis is specified, it should
|
|
5877
|
+
* match the shape along those axes.
|
|
5878
|
+
*/
|
|
5879
|
+
function average(a, axis = null, opts) {
|
|
5880
|
+
a = fudgeArray(a);
|
|
5881
|
+
if (opts?.weights == null) return mean(a, axis, opts);
|
|
5882
|
+
const weights = fudgeArray(opts.weights);
|
|
5883
|
+
axis = normalizeAxis(axis, ndim(a));
|
|
5884
|
+
const wShape = weights.shape;
|
|
5885
|
+
const aShape = a.shape;
|
|
5886
|
+
if (deepEqual(wShape, aShape)) {
|
|
5887
|
+
const scl = sum(weights.ref, axis, opts);
|
|
5888
|
+
return sum(multiply(a, weights), axis, opts).div(scl);
|
|
5889
|
+
} else if (axis.length === 1 && wShape.length === 1 && wShape[0] === aShape[axis[0]]) {
|
|
5890
|
+
const broadcastShape = aShape.map((_, i) => i === axis[0] ? wShape[0] : 1);
|
|
5891
|
+
const wReshaped = reshape(weights, broadcastShape);
|
|
5892
|
+
const scl = sum(wReshaped.ref, axis, opts);
|
|
5893
|
+
return sum(multiply(a, wReshaped), axis, opts).div(scl);
|
|
5894
|
+
} else {
|
|
5895
|
+
weights.dispose();
|
|
5896
|
+
a.dispose();
|
|
5897
|
+
throw new Error(`average: weights shape ${JSON.stringify(wShape)} is not compatible with array shape ${JSON.stringify(aShape)} and axis ${JSON.stringify(axis)}`);
|
|
5898
|
+
}
|
|
5899
|
+
}
|
|
5900
|
+
/**
|
|
5822
5901
|
* Returns the indices of the minimum values along an axis.
|
|
5823
5902
|
*
|
|
5824
5903
|
* By default, index is into the flatted array, otherwise it is along the
|
|
@@ -6197,8 +6276,9 @@ function sort(a, axis = -1) {
|
|
|
6197
6276
|
return fudgeArray(a).sort(axis);
|
|
6198
6277
|
}
|
|
6199
6278
|
/**
|
|
6200
|
-
* Return indices that would sort an array.
|
|
6201
|
-
* algorithm; it
|
|
6279
|
+
* Return indices that would sort an array. Unlike `sort`, this is guaranteed to
|
|
6280
|
+
* be a stable sorting algorithm; it always returns the smaller index first in
|
|
6281
|
+
* event of ties.
|
|
6202
6282
|
*
|
|
6203
6283
|
* Returns an array of `int32` indices.
|
|
6204
6284
|
*
|
|
@@ -6221,20 +6301,63 @@ function take(a, indices, axis = null) {
|
|
|
6221
6301
|
axis = checkAxis(axis, ndim(a));
|
|
6222
6302
|
return gather(a, [indices], [axis], axis);
|
|
6223
6303
|
}
|
|
6224
|
-
/**
|
|
6304
|
+
/**
|
|
6305
|
+
* Return if two arrays are element-wise equal within a tolerance.
|
|
6306
|
+
*
|
|
6307
|
+
* The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
|
|
6308
|
+
* NaN values comparing equal if `equalNaN` is true.
|
|
6309
|
+
*/
|
|
6225
6310
|
function allclose(actual, expected, options) {
|
|
6226
|
-
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
6311
|
+
const { rtol = 1e-5, atol = 1e-7, equalNaN = false } = options ?? {};
|
|
6227
6312
|
const x = array(actual);
|
|
6228
6313
|
const y = array(expected);
|
|
6229
6314
|
if (!deepEqual(x.shape, y.shape)) return false;
|
|
6230
6315
|
const xData = x.dataSync();
|
|
6231
6316
|
const yData = y.dataSync();
|
|
6232
6317
|
for (let i = 0; i < xData.length; i++) {
|
|
6233
|
-
if (isNaN(xData[i]) !== isNaN(yData[i])) return false;
|
|
6318
|
+
if (equalNaN ? isNaN(xData[i]) !== isNaN(yData[i]) : isNaN(xData[i]) || isNaN(yData[i])) return false;
|
|
6234
6319
|
if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
|
|
6235
6320
|
}
|
|
6236
6321
|
return true;
|
|
6237
6322
|
}
|
|
6323
|
+
/**
|
|
6324
|
+
* Check if two arrays are element-wise equal.
|
|
6325
|
+
*
|
|
6326
|
+
* Returns False if the arrays have different shapes. If `equalNaN` is True,
|
|
6327
|
+
* NaNs in the same position are considered equal.
|
|
6328
|
+
*/
|
|
6329
|
+
function arrayEqual(a1, a2, opts) {
|
|
6330
|
+
a1 = fudgeArray(a1);
|
|
6331
|
+
a2 = fudgeArray(a2);
|
|
6332
|
+
if (!deepEqual(a1.shape, a2.shape)) {
|
|
6333
|
+
a1.dispose();
|
|
6334
|
+
a2.dispose();
|
|
6335
|
+
return array(false);
|
|
6336
|
+
}
|
|
6337
|
+
if (opts?.equalNaN) {
|
|
6338
|
+
const nanMask = isnan(a1.ref).mul(isnan(a2.ref));
|
|
6339
|
+
return where(nanMask, true, equal(a1, a2)).all();
|
|
6340
|
+
}
|
|
6341
|
+
return equal(a1, a2).all();
|
|
6342
|
+
}
|
|
6343
|
+
/**
|
|
6344
|
+
* Check if two arrays are element-wise equal after broadcasting.
|
|
6345
|
+
*
|
|
6346
|
+
* Unlike `arrayEqual`, this allows inputs with different but
|
|
6347
|
+
* broadcast-compatible shapes.
|
|
6348
|
+
*/
|
|
6349
|
+
function arrayEquiv(a1, a2) {
|
|
6350
|
+
a1 = fudgeArray(a1);
|
|
6351
|
+
a2 = fudgeArray(a2);
|
|
6352
|
+
try {
|
|
6353
|
+
const [b1, b2] = broadcastArrays(a1, a2);
|
|
6354
|
+
return equal(b1, b2).all();
|
|
6355
|
+
} catch {
|
|
6356
|
+
a1.dispose();
|
|
6357
|
+
a2.dispose();
|
|
6358
|
+
return array(false);
|
|
6359
|
+
}
|
|
6360
|
+
}
|
|
6238
6361
|
/** Matrix product of two arrays. */
|
|
6239
6362
|
function matmul(x, y) {
|
|
6240
6363
|
if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
|
|
@@ -6248,6 +6371,16 @@ function matmul(x, y) {
|
|
|
6248
6371
|
rhsBatchDims: range(-2 - numBatchDims, -2)
|
|
6249
6372
|
});
|
|
6250
6373
|
}
|
|
6374
|
+
/** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
|
|
6375
|
+
function matvec(x1, x2) {
|
|
6376
|
+
if (ndim(x1) < 2 || ndim(x2) < 1) throw new Error("matvec: x1 must be at least 2D and x2 at least 1D");
|
|
6377
|
+
return einsum("...mn,...n->...m", x1, x2);
|
|
6378
|
+
}
|
|
6379
|
+
/** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
|
|
6380
|
+
function vecmat(x1, x2) {
|
|
6381
|
+
if (ndim(x1) < 1 || ndim(x2) < 2) throw new Error("vecmat: x1 must be at least 1D and x2 at least 2D");
|
|
6382
|
+
return einsum("...n,...nm->...m", x1, x2);
|
|
6383
|
+
}
|
|
6251
6384
|
/** Dot product of two arrays. */
|
|
6252
6385
|
function dot$1(x, y) {
|
|
6253
6386
|
if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
|
|
@@ -6406,6 +6539,49 @@ function outer(x, y) {
|
|
|
6406
6539
|
y = ravel(y);
|
|
6407
6540
|
return multiply(x.reshape([x.shape[0], 1]), y);
|
|
6408
6541
|
}
|
|
6542
|
+
/**
|
|
6543
|
+
* @function Compute the cross product of two arrays.
|
|
6544
|
+
*
|
|
6545
|
+
* Supports 2D (scalar result) and 3D cross products, with optional axis
|
|
6546
|
+
* arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
|
|
6547
|
+
*/
|
|
6548
|
+
const cross = jit$1(function cross$2(a, b, { axisa = -1, axisb = -1, axisc = -1, axis } = {}) {
|
|
6549
|
+
if (axis !== void 0) {
|
|
6550
|
+
axisa = axis;
|
|
6551
|
+
axisb = axis;
|
|
6552
|
+
axisc = axis;
|
|
6553
|
+
}
|
|
6554
|
+
axisa = checkAxis(axisa, ndim(a));
|
|
6555
|
+
axisb = checkAxis(axisb, ndim(b));
|
|
6556
|
+
a = moveaxis$1(a, axisa, -1);
|
|
6557
|
+
b = moveaxis$1(b, axisb, -1);
|
|
6558
|
+
const da = a.shape.at(-1);
|
|
6559
|
+
const db = b.shape.at(-1);
|
|
6560
|
+
if (da !== 2 && da !== 3 || db !== 2 && db !== 3) throw new Error(`cross: incompatible dimensions for cross product (got ${da} and ${db})`);
|
|
6561
|
+
if (da === 2 && db === 2) {
|
|
6562
|
+
const [a0$1, a1$1] = split$1(a, 2, -1);
|
|
6563
|
+
const [b0$1, b1$1] = split$1(b, 2, -1);
|
|
6564
|
+
return squeeze(a0$1.mul(b1$1).sub(a1$1.mul(b0$1)), -1);
|
|
6565
|
+
}
|
|
6566
|
+
if (da === 2) {
|
|
6567
|
+
const zeroShape = [...a.shape.slice(0, -1), 1];
|
|
6568
|
+
a = concatenate([a, zeros(zeroShape)], -1);
|
|
6569
|
+
}
|
|
6570
|
+
if (db === 2) {
|
|
6571
|
+
const zeroShape = [...b.shape.slice(0, -1), 1];
|
|
6572
|
+
b = concatenate([b, zeros(zeroShape)], -1);
|
|
6573
|
+
}
|
|
6574
|
+
const [a0, a1, a2] = split$1(a, 3, -1);
|
|
6575
|
+
const [b0, b1, b2] = split$1(b, 3, -1);
|
|
6576
|
+
const c0 = a1.ref.mul(b2.ref).sub(a2.ref.mul(b1.ref));
|
|
6577
|
+
const c1 = a2.mul(b0.ref).sub(a0.ref.mul(b2));
|
|
6578
|
+
const c2 = a0.mul(b1).sub(a1.mul(b0));
|
|
6579
|
+
return moveaxis$1(concatenate([
|
|
6580
|
+
c0,
|
|
6581
|
+
c1,
|
|
6582
|
+
c2
|
|
6583
|
+
], -1), -1, axisc);
|
|
6584
|
+
}, { staticArgnums: [2] });
|
|
6409
6585
|
/** Vector dot product of two arrays along a given axis. */
|
|
6410
6586
|
function vecdot(x, y, { axis } = {}) {
|
|
6411
6587
|
const xaxis = checkAxis(axis ?? -1, ndim(x));
|
|
@@ -6500,18 +6676,17 @@ function absolute(x) {
|
|
|
6500
6676
|
/** Return an element-wise indication of sign of the input. */
|
|
6501
6677
|
function sign(x) {
|
|
6502
6678
|
x = fudgeArray(x);
|
|
6503
|
-
return where(notEqual(x.ref, 0), where(less(x
|
|
6679
|
+
return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
|
|
6504
6680
|
}
|
|
6505
|
-
/** @function Return element-wise positive values of the input (no-op). */
|
|
6506
|
-
const positive = fudgeArray;
|
|
6507
6681
|
/**
|
|
6508
|
-
*
|
|
6509
|
-
*
|
|
6510
|
-
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
6682
|
+
* @function
|
|
6683
|
+
* Return the value with the magnitude of x and the sign of y, element-wise.
|
|
6511
6684
|
*/
|
|
6512
|
-
function
|
|
6513
|
-
return
|
|
6514
|
-
}
|
|
6685
|
+
const copysign = jit$1(function copysign$1(x, y) {
|
|
6686
|
+
return absolute(x).mul(sign(y));
|
|
6687
|
+
});
|
|
6688
|
+
/** @function Return element-wise positive values of the input (no-op). */
|
|
6689
|
+
const positive = fudgeArray;
|
|
6515
6690
|
/**
|
|
6516
6691
|
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
6517
6692
|
*
|
|
@@ -6657,6 +6832,27 @@ function trunc(x) {
|
|
|
6657
6832
|
return idiv(x, 1);
|
|
6658
6833
|
}
|
|
6659
6834
|
/**
|
|
6835
|
+
* @function
|
|
6836
|
+
* Round to the given number of decimals.
|
|
6837
|
+
*
|
|
6838
|
+
* Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
|
|
6839
|
+
*/
|
|
6840
|
+
const round = jit$1(function round$1(a, decimals = 0) {
|
|
6841
|
+
if (decimals === 0) return rint(a);
|
|
6842
|
+
const factor = 10 ** decimals;
|
|
6843
|
+
return rint(a.mul(factor)).mul(1 / factor);
|
|
6844
|
+
}, { staticArgnums: [1] });
|
|
6845
|
+
/**
|
|
6846
|
+
* @function
|
|
6847
|
+
* Round to the nearest integer, with ties going to the nearest even integer.
|
|
6848
|
+
*/
|
|
6849
|
+
const rint = jit$1(function rint$1(x) {
|
|
6850
|
+
const rounded = floor(x.ref.add(.5));
|
|
6851
|
+
const half = x.ref.sub(floor(x)).equal(.5);
|
|
6852
|
+
const odd = remainder(rounded.ref, 2).notEqual(0);
|
|
6853
|
+
return where(half.mul(odd), rounded.ref.sub(1), rounded);
|
|
6854
|
+
});
|
|
6855
|
+
/**
|
|
6660
6856
|
* Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
|
|
6661
6857
|
*
|
|
6662
6858
|
* This is the inverse of `frexp()`.
|
|
@@ -6984,6 +7180,7 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
|
|
|
6984
7180
|
//#region src/library/lax.ts
|
|
6985
7181
|
var lax_exports = {};
|
|
6986
7182
|
__export(lax_exports, {
|
|
7183
|
+
bitcastConvertType: () => bitcastConvertType,
|
|
6987
7184
|
conv: () => conv,
|
|
6988
7185
|
convGeneralDilated: () => convGeneralDilated,
|
|
6989
7186
|
convTranspose: () => convTranspose,
|
|
@@ -6993,9 +7190,14 @@ __export(lax_exports, {
|
|
|
6993
7190
|
erfc: () => erfc,
|
|
6994
7191
|
linalg: () => lax_linalg_exports,
|
|
6995
7192
|
reduceWindow: () => reduceWindow,
|
|
6996
|
-
stopGradient: () => stopGradient$1
|
|
7193
|
+
stopGradient: () => stopGradient$1,
|
|
7194
|
+
topK: () => topK
|
|
6997
7195
|
});
|
|
6998
7196
|
const JsArray = globalThis.Array;
|
|
7197
|
+
/** Elementwise bitcast an array into a new dtype. */
|
|
7198
|
+
function bitcastConvertType(x, newDtype) {
|
|
7199
|
+
return fudgeArray(x).view(newDtype);
|
|
7200
|
+
}
|
|
6999
7201
|
/**
|
|
7000
7202
|
* General dot product/contraction operator.
|
|
7001
7203
|
*
|
|
@@ -7217,6 +7419,39 @@ function erfc(x) {
|
|
|
7217
7419
|
function stopGradient$1(x) {
|
|
7218
7420
|
return stopGradient(x);
|
|
7219
7421
|
}
|
|
7422
|
+
/**
|
|
7423
|
+
* Returns top `k` values and their indices along the specified axis of operand.
|
|
7424
|
+
*
|
|
7425
|
+
* This is a _stable_ algorithm: If two elements are equal, the lower-index
|
|
7426
|
+
* element appears first.
|
|
7427
|
+
*
|
|
7428
|
+
* @returns A tuple of `(values, indices)`, where `values` and `indices` have
|
|
7429
|
+
* the same shape as `x`, except along `axis` where they have size `k`.
|
|
7430
|
+
*/
|
|
7431
|
+
function topK(x, k, axis = -1) {
|
|
7432
|
+
x = fudgeArray(x);
|
|
7433
|
+
axis = checkAxis(axis, x.ndim);
|
|
7434
|
+
const size$1 = x.shape[axis];
|
|
7435
|
+
if (k < 0 || k > size$1) throw new Error(`topK: k must be in the range [0, ${size$1}], got ${k}`);
|
|
7436
|
+
if (k === 0) {
|
|
7437
|
+
const outShape = x.shape.slice();
|
|
7438
|
+
outShape[axis] = 0;
|
|
7439
|
+
const y$1 = zerosLike$1(x.ref, { shape: outShape });
|
|
7440
|
+
const yi$1 = zerosLike$1(x, {
|
|
7441
|
+
dtype: DType.Int32,
|
|
7442
|
+
shape: outShape
|
|
7443
|
+
});
|
|
7444
|
+
return [y$1, yi$1];
|
|
7445
|
+
}
|
|
7446
|
+
x = flip$1(x, [axis]);
|
|
7447
|
+
x = moveaxis(x, axis, -1);
|
|
7448
|
+
const [y, yi] = argsort$1(x);
|
|
7449
|
+
const extract = (a) => {
|
|
7450
|
+
a = a.slice(...rep(a.ndim - 1, []), [-k]);
|
|
7451
|
+
return flip$1(moveaxis(a, -1, axis), [axis]);
|
|
7452
|
+
};
|
|
7453
|
+
return [extract(y), extract(yi.neg().add(size$1 - 1))];
|
|
7454
|
+
}
|
|
7220
7455
|
|
|
7221
7456
|
//#endregion
|
|
7222
7457
|
//#region src/library/nn.ts
|
|
@@ -7408,7 +7643,7 @@ const gelu = jit$1(function gelu$1(x, opts) {
|
|
|
7408
7643
|
if (opts?.approximate ?? true) {
|
|
7409
7644
|
const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
|
|
7410
7645
|
return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
|
|
7411
|
-
} else return x.ref.mul(.5).mul(erfc$1(negative(x.
|
|
7646
|
+
} else return x.ref.mul(.5).mul(erfc$1(negative(x.mul(Math.SQRT1_2))));
|
|
7412
7647
|
}, { staticArgnums: [1] });
|
|
7413
7648
|
/**
|
|
7414
7649
|
* Gated linear unit (GLU) activation function.
|
|
@@ -7666,6 +7901,7 @@ var random_exports = {};
|
|
|
7666
7901
|
__export(random_exports, {
|
|
7667
7902
|
bernoulli: () => bernoulli,
|
|
7668
7903
|
bits: () => bits,
|
|
7904
|
+
categorical: () => categorical,
|
|
7669
7905
|
cauchy: () => cauchy,
|
|
7670
7906
|
exponential: () => exponential,
|
|
7671
7907
|
gumbel: () => gumbel,
|
|
@@ -7693,7 +7929,9 @@ function getK01(key$1) {
|
|
|
7693
7929
|
function key(seed) {
|
|
7694
7930
|
seed = array(seed, { dtype: DType.Uint32 });
|
|
7695
7931
|
if (seed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seed.shape} - use jax.vmap for batching.`);
|
|
7696
|
-
|
|
7932
|
+
const key$1 = stack([0, seed]);
|
|
7933
|
+
if (key$1 instanceof Array$1) key$1._realizeSource();
|
|
7934
|
+
return key$1;
|
|
7697
7935
|
}
|
|
7698
7936
|
/** Splits a PRNG key into `num` new keys by adding a leading axis. */
|
|
7699
7937
|
function split(key$1, num = 2) {
|
|
@@ -7737,6 +7975,47 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
|
7737
7975
|
}
|
|
7738
7976
|
/**
|
|
7739
7977
|
* @function
|
|
7978
|
+
* Sample random values from categorical distributions.
|
|
7979
|
+
*
|
|
7980
|
+
* Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
|
|
7981
|
+
* trick for sampling without replacement.
|
|
7982
|
+
*
|
|
7983
|
+
* Note: Sampling without replacement currently uses argsort and slices the last
|
|
7984
|
+
* k elements. This should be replaced with a more efficient topK implementation.
|
|
7985
|
+
*
|
|
7986
|
+
* - `key` - PRNG key
|
|
7987
|
+
* - `logits` - Unnormalized log probabilities of the categorical distribution(s).
|
|
7988
|
+
* `softmax(logits, axis)` gives the corresponding probabilities.
|
|
7989
|
+
* - `axis` - Axis along which logits belong to the same categorical distribution.
|
|
7990
|
+
* - `shape` - Result batch shape. Must be broadcast-compatible with
|
|
7991
|
+
* `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
|
|
7992
|
+
* - `replace` - If true (default), sample with replacement. If false, sample
|
|
7993
|
+
* without replacement (each category can only be selected once per batch).
|
|
7994
|
+
* @returns A random array with int dtype and shape given by `shape` if provided,
|
|
7995
|
+
* otherwise `logits.shape` with `axis` removed.
|
|
7996
|
+
*/
|
|
7997
|
+
const categorical = jit$1(function categorical$1(key$1, logits, { axis = -1, shape: shape$1, replace = true } = {}) {
|
|
7998
|
+
logits = fudgeArray(logits);
|
|
7999
|
+
axis = checkAxis(axis, logits.ndim);
|
|
8000
|
+
const numCategories = logits.shape[axis];
|
|
8001
|
+
const batchShape = logits.shape.toSpliced(axis, 1);
|
|
8002
|
+
if (shape$1 === void 0) shape$1 = batchShape;
|
|
8003
|
+
else if (!deepEqual(generalBroadcast(shape$1, batchShape), shape$1)) throw new Error(`Shape ${shape$1} is not broadcast-compatible with batch shape ${batchShape}.`);
|
|
8004
|
+
const shapePrefix = shape$1.slice(0, shape$1.length - batchShape.length);
|
|
8005
|
+
if (replace) {
|
|
8006
|
+
const noise = gumbel(key$1, [...shapePrefix, ...logits.shape]);
|
|
8007
|
+
return argmax(noise.add(logits), axis + shapePrefix.length);
|
|
8008
|
+
} else {
|
|
8009
|
+
const k = shapePrefix.reduce((a, b) => a * b, 1);
|
|
8010
|
+
if (k > numCategories) throw new Error(`Number of samples without replacement (${k}) cannot exceed number of categories (${numCategories}).`);
|
|
8011
|
+
const noise = gumbel(key$1, logits.shape);
|
|
8012
|
+
const [values, indices] = topK(noise.add(logits), k, axis);
|
|
8013
|
+
values.dispose();
|
|
8014
|
+
return indices.reshape(shape$1);
|
|
8015
|
+
}
|
|
8016
|
+
}, { staticArgnums: [2] });
|
|
8017
|
+
/**
|
|
8018
|
+
* @function
|
|
7740
8019
|
* Sample from a Cauchy distribution with location 0 and scale 1.
|
|
7741
8020
|
*
|
|
7742
8021
|
* Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
|
|
@@ -7847,6 +8126,11 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
|
|
|
7847
8126
|
|
|
7848
8127
|
//#endregion
|
|
7849
8128
|
//#region src/index.ts
|
|
8129
|
+
/** @namespace */
|
|
8130
|
+
const profiler = {
|
|
8131
|
+
startTrace,
|
|
8132
|
+
stopTrace
|
|
8133
|
+
};
|
|
7850
8134
|
/**
|
|
7851
8135
|
* @function
|
|
7852
8136
|
* Compute the forward-mode Jacobian-vector product for a function.
|
|
@@ -8007,4 +8291,4 @@ async function devicePut(x, device) {
|
|
|
8007
8291
|
}
|
|
8008
8292
|
|
|
8009
8293
|
//#endregion
|
|
8010
|
-
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
8294
|
+
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, profiler, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-
|
|
1
|
+
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-Ctqs8la1.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgl/builtins.ts
|
|
4
4
|
const threefrySrc = `
|