@jax-js/jax 0.1.8 → 0.1.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +46 -29
- package/dist/{backend-nEolvdLv.js → backend-Ctqs8la1.js} +122 -15
- package/dist/{backend-B3foXiV_.cjs → backend-DMauYnfl.cjs} +157 -14
- package/dist/index.cjs +331 -46
- package/dist/index.d.cts +175 -31
- package/dist/index.d.ts +175 -31
- package/dist/index.js +331 -47
- package/dist/{webgl-DweKSWEm.js → webgl-CvQ1QBX1.js} +1 -1
- package/dist/{webgl-DIIbKJ0G.cjs → webgl-kvVt7-T7.cjs} +1 -1
- package/dist/{webgpu-BykvF26B.cjs → webgpu-DMSx7a6M.cjs} +160 -15
- package/dist/{webgpu-B96vzWGE.js → webgpu-v_W_-oKw.js} +160 -15
- package/package.json +5 -16
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-DMauYnfl.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -838,6 +838,11 @@ var Tracer = class Tracer {
|
|
|
838
838
|
if (this.dtype === dtype) return this;
|
|
839
839
|
return cast(this, dtype);
|
|
840
840
|
}
|
|
841
|
+
/** Return a bitwise cast of the array, viewed as a new dtype. */
|
|
842
|
+
view(dtype) {
|
|
843
|
+
if (!dtype || dtype === this.dtype) return this;
|
|
844
|
+
return bitcast(this, dtype);
|
|
845
|
+
}
|
|
841
846
|
/** Subtract an array from this one. */
|
|
842
847
|
sub(other) {
|
|
843
848
|
return this.add(neg(other));
|
|
@@ -920,18 +925,25 @@ var Tracer = class Tracer {
|
|
|
920
925
|
return sort$1(this.transpose(perm)).transpose(require_backend.invertPermutation(perm));
|
|
921
926
|
}
|
|
922
927
|
/**
|
|
923
|
-
* Return the indices that would sort an array.
|
|
924
|
-
* sorting algorithm; it
|
|
928
|
+
* Return the indices that would sort an array. Unlike `sort`, this is
|
|
929
|
+
* guaranteed to be a stable sorting algorithm; it always returns the smaller
|
|
930
|
+
* index first in event of ties.
|
|
925
931
|
*
|
|
926
932
|
* See `jax.numpy.argsort` for full docs.
|
|
927
933
|
*/
|
|
928
934
|
argsort(axis = -1) {
|
|
929
935
|
axis = require_backend.checkAxis(axis, this.ndim);
|
|
930
|
-
if (axis === this.ndim - 1)
|
|
936
|
+
if (axis === this.ndim - 1) {
|
|
937
|
+
const [y$1, yi$1] = argsort$1(this);
|
|
938
|
+
y$1.dispose();
|
|
939
|
+
return yi$1;
|
|
940
|
+
}
|
|
931
941
|
const perm = require_backend.range(this.ndim);
|
|
932
942
|
perm.splice(axis, 1);
|
|
933
943
|
perm.push(axis);
|
|
934
|
-
|
|
944
|
+
const [y, yi] = argsort$1(this.transpose(perm));
|
|
945
|
+
y.dispose();
|
|
946
|
+
return yi.transpose(require_backend.invertPermutation(perm));
|
|
935
947
|
}
|
|
936
948
|
/**
|
|
937
949
|
* Slice an array along one or more axes.
|
|
@@ -1652,7 +1664,7 @@ const abstractEvalRules = {
|
|
|
1652
1664
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1653
1665
|
},
|
|
1654
1666
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
1655
|
-
if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
1667
|
+
if (x.dtype !== dtype && (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool)) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
1656
1668
|
if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
1657
1669
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1658
1670
|
},
|
|
@@ -3074,8 +3086,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3074
3086
|
return [x.#unary(require_backend.AluOp.Cast, dtype)];
|
|
3075
3087
|
},
|
|
3076
3088
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
3077
|
-
if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
3078
3089
|
if (x.dtype === dtype) return [x];
|
|
3090
|
+
if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
3079
3091
|
if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
3080
3092
|
if (x.#source instanceof require_backend.AluExp) return [x.#unary(require_backend.AluOp.Bitcast, dtype)];
|
|
3081
3093
|
else {
|
|
@@ -3416,32 +3428,26 @@ function fullInternal(aval, fillValue, device) {
|
|
|
3416
3428
|
committed: device != void 0
|
|
3417
3429
|
});
|
|
3418
3430
|
}
|
|
3419
|
-
function zerosLike$1(val,
|
|
3420
|
-
return fullLike(val, 0,
|
|
3431
|
+
function zerosLike$1(val, opts) {
|
|
3432
|
+
return fullLike(val, 0, opts);
|
|
3421
3433
|
}
|
|
3422
|
-
function onesLike$1(val,
|
|
3423
|
-
return fullLike(val, 1,
|
|
3434
|
+
function onesLike$1(val, opts) {
|
|
3435
|
+
return fullLike(val, 1, opts);
|
|
3424
3436
|
}
|
|
3425
|
-
function fullLike(val, fillValue, dtype) {
|
|
3437
|
+
function fullLike(val, fillValue, { dtype, shape: shape$1, device } = {}) {
|
|
3426
3438
|
const aval = getAval(val);
|
|
3427
3439
|
if (val instanceof Tracer) val.dispose();
|
|
3428
3440
|
if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
|
|
3429
|
-
const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
|
|
3430
|
-
return fullInternal(sa, fillValue);
|
|
3441
|
+
const sa = new ShapedArray(shape$1 ?? aval.shape, dtype ?? aval.dtype, aval.weakType && dtype === void 0);
|
|
3442
|
+
return fullInternal(sa, fillValue, device);
|
|
3431
3443
|
}
|
|
3432
3444
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
3433
|
-
function zeros(shape$1,
|
|
3434
|
-
return full(shape$1, 0,
|
|
3435
|
-
dtype,
|
|
3436
|
-
device
|
|
3437
|
-
});
|
|
3445
|
+
function zeros(shape$1, opts) {
|
|
3446
|
+
return full(shape$1, 0, opts);
|
|
3438
3447
|
}
|
|
3439
3448
|
/** Return a new array of given shape and type, filled with ones. */
|
|
3440
|
-
function ones(shape$1,
|
|
3441
|
-
return full(shape$1, 1,
|
|
3442
|
-
dtype,
|
|
3443
|
-
device
|
|
3444
|
-
});
|
|
3449
|
+
function ones(shape$1, opts) {
|
|
3450
|
+
return full(shape$1, 1, opts);
|
|
3445
3451
|
}
|
|
3446
3452
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
3447
3453
|
function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
@@ -4178,6 +4184,7 @@ const jvpRules = {
|
|
|
4178
4184
|
},
|
|
4179
4185
|
[Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
|
|
4180
4186
|
const x = triangularSolve$1(a.ref, b, { unitDiagonal });
|
|
4187
|
+
da = unitDiagonal ? triu(da, 1) : triu(da);
|
|
4181
4188
|
const dax = batchMatmulT(da, x.ref);
|
|
4182
4189
|
const rhsT = db.sub(mT(dax));
|
|
4183
4190
|
const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
|
|
@@ -5253,6 +5260,7 @@ function ifft(a, axis = -1) {
|
|
|
5253
5260
|
var numpy_linalg_exports = {};
|
|
5254
5261
|
__export(numpy_linalg_exports, {
|
|
5255
5262
|
cholesky: () => cholesky,
|
|
5263
|
+
cross: () => cross$1,
|
|
5256
5264
|
det: () => det,
|
|
5257
5265
|
diagonal: () => diagonal,
|
|
5258
5266
|
inv: () => inv,
|
|
@@ -5283,6 +5291,19 @@ function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
|
|
|
5283
5291
|
if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
|
|
5284
5292
|
return cholesky$1(a, { upper });
|
|
5285
5293
|
}
|
|
5294
|
+
/**
|
|
5295
|
+
* Compute the cross-product of two 3D vectors.
|
|
5296
|
+
*
|
|
5297
|
+
* This is a simpler and less flexible version of `jax.numpy.cross()`.
|
|
5298
|
+
* Both inputs must have size 3 along the specified axis.
|
|
5299
|
+
*/
|
|
5300
|
+
function cross$1(x1, x2, axis = -1) {
|
|
5301
|
+
const a1 = require_backend.checkAxis(axis, ndim(x1));
|
|
5302
|
+
const a2 = require_backend.checkAxis(axis, ndim(x2));
|
|
5303
|
+
if (shape(x1)[a1] !== 3) throw new Error(`linalg.cross: x1 must have size 3 along axis ${axis}, got ${shape(x1)[a1]}`);
|
|
5304
|
+
if (shape(x2)[a2] !== 3) throw new Error(`linalg.cross: x2 must have size 3 along axis ${axis}, got ${shape(x2)[a2]}`);
|
|
5305
|
+
return cross(x1, x2, { axis });
|
|
5306
|
+
}
|
|
5286
5307
|
/** Compute the determinant of a square matrix (batched). */
|
|
5287
5308
|
function det(a) {
|
|
5288
5309
|
a = fudgeArray(a);
|
|
@@ -5298,7 +5319,7 @@ function det(a) {
|
|
|
5298
5319
|
function inv(a) {
|
|
5299
5320
|
a = fudgeArray(a);
|
|
5300
5321
|
const n = checkSquare("inv", a);
|
|
5301
|
-
return solve(a, eye(n));
|
|
5322
|
+
return solve(a, eye(n, void 0, { dtype: a.dtype }));
|
|
5302
5323
|
}
|
|
5303
5324
|
/**
|
|
5304
5325
|
* Return the least-squares solution to a linear equation.
|
|
@@ -5332,7 +5353,7 @@ function lstsq(a, b) {
|
|
|
5332
5353
|
lower: true,
|
|
5333
5354
|
transposeA: true
|
|
5334
5355
|
});
|
|
5335
|
-
return matmul(at, llb
|
|
5356
|
+
return matmul(at, llb);
|
|
5336
5357
|
} else {
|
|
5337
5358
|
const ata = matmul(at.ref, a);
|
|
5338
5359
|
const l = cholesky(ata, { symmetrizeInput: false });
|
|
@@ -5355,8 +5376,9 @@ function matrixPower(a, n) {
|
|
|
5355
5376
|
a = fudgeArray(a);
|
|
5356
5377
|
const m = checkSquare("matrixPower", a);
|
|
5357
5378
|
if (n === 0) {
|
|
5379
|
+
const dtype = a.dtype;
|
|
5358
5380
|
a.dispose();
|
|
5359
|
-
return broadcastTo(eye(m), a.shape);
|
|
5381
|
+
return broadcastTo(eye(m, void 0, { dtype }), a.shape);
|
|
5360
5382
|
}
|
|
5361
5383
|
if (n < 0) {
|
|
5362
5384
|
a = inv(a);
|
|
@@ -5423,7 +5445,7 @@ function solve(a, b) {
|
|
|
5423
5445
|
lower: true,
|
|
5424
5446
|
unitDiagonal: true
|
|
5425
5447
|
});
|
|
5426
|
-
let x = triangularSolve(lu$2, LPb
|
|
5448
|
+
let x = triangularSolve(lu$2, LPb, {
|
|
5427
5449
|
leftSide: true,
|
|
5428
5450
|
lower: false
|
|
5429
5451
|
});
|
|
@@ -5538,13 +5560,17 @@ __export(numpy_exports, {
|
|
|
5538
5560
|
argmax: () => argmax,
|
|
5539
5561
|
argmin: () => argmin,
|
|
5540
5562
|
argsort: () => argsort,
|
|
5563
|
+
around: () => round,
|
|
5541
5564
|
array: () => array,
|
|
5565
|
+
arrayEqual: () => arrayEqual,
|
|
5566
|
+
arrayEquiv: () => arrayEquiv,
|
|
5542
5567
|
asin: () => asin,
|
|
5543
5568
|
asinh: () => arcsinh,
|
|
5544
5569
|
astype: () => astype,
|
|
5545
5570
|
atan: () => atan,
|
|
5546
5571
|
atan2: () => atan2,
|
|
5547
5572
|
atanh: () => arctanh,
|
|
5573
|
+
average: () => average,
|
|
5548
5574
|
bool: () => bool,
|
|
5549
5575
|
broadcastArrays: () => broadcastArrays,
|
|
5550
5576
|
broadcastShapes: () => broadcastShapes,
|
|
@@ -5555,11 +5581,13 @@ __export(numpy_exports, {
|
|
|
5555
5581
|
columnStack: () => columnStack,
|
|
5556
5582
|
concatenate: () => concatenate,
|
|
5557
5583
|
convolve: () => convolve,
|
|
5584
|
+
copysign: () => copysign,
|
|
5558
5585
|
corrcoef: () => corrcoef,
|
|
5559
5586
|
correlate: () => correlate,
|
|
5560
5587
|
cos: () => cos,
|
|
5561
5588
|
cosh: () => cosh,
|
|
5562
5589
|
cov: () => cov,
|
|
5590
|
+
cross: () => cross,
|
|
5563
5591
|
cumsum: () => cumsum,
|
|
5564
5592
|
cumulativeSum: () => cumsum,
|
|
5565
5593
|
deg2rad: () => deg2rad,
|
|
@@ -5595,7 +5623,6 @@ __export(numpy_exports, {
|
|
|
5595
5623
|
fullLike: () => fullLike$1,
|
|
5596
5624
|
greater: () => greater,
|
|
5597
5625
|
greaterEqual: () => greaterEqual,
|
|
5598
|
-
hamming: () => hamming,
|
|
5599
5626
|
hann: () => hann,
|
|
5600
5627
|
heaviside: () => heaviside,
|
|
5601
5628
|
hstack: () => hstack,
|
|
@@ -5619,9 +5646,14 @@ __export(numpy_exports, {
|
|
|
5619
5646
|
log10: () => log10,
|
|
5620
5647
|
log1p: () => log1p,
|
|
5621
5648
|
log2: () => log2,
|
|
5649
|
+
logicalAnd: () => logicalAnd,
|
|
5650
|
+
logicalNot: () => logicalNot,
|
|
5651
|
+
logicalOr: () => logicalOr,
|
|
5652
|
+
logicalXor: () => logicalXor,
|
|
5622
5653
|
logspace: () => logspace,
|
|
5623
5654
|
matmul: () => matmul,
|
|
5624
5655
|
matrixTranspose: () => matrixTranspose,
|
|
5656
|
+
matvec: () => matvec,
|
|
5625
5657
|
max: () => max,
|
|
5626
5658
|
maximum: () => maximum,
|
|
5627
5659
|
mean: () => mean,
|
|
@@ -5654,6 +5686,8 @@ __export(numpy_exports, {
|
|
|
5654
5686
|
remainder: () => remainder,
|
|
5655
5687
|
repeat: () => repeat,
|
|
5656
5688
|
reshape: () => reshape,
|
|
5689
|
+
rint: () => rint,
|
|
5690
|
+
round: () => round,
|
|
5657
5691
|
shape: () => shape,
|
|
5658
5692
|
sign: () => sign,
|
|
5659
5693
|
sin: () => sin,
|
|
@@ -5686,6 +5720,7 @@ __export(numpy_exports, {
|
|
|
5686
5720
|
var_: () => var_,
|
|
5687
5721
|
vdot: () => vdot,
|
|
5688
5722
|
vecdot: () => vecdot,
|
|
5723
|
+
vecmat: () => vecmat,
|
|
5689
5724
|
vstack: () => vstack,
|
|
5690
5725
|
where: () => where,
|
|
5691
5726
|
zeros: () => zeros,
|
|
@@ -5749,6 +5784,22 @@ const notEqual = notEqual$1;
|
|
|
5749
5784
|
const greaterEqual = greaterEqual$1;
|
|
5750
5785
|
/** @function Compare two arrays element-wise. */
|
|
5751
5786
|
const lessEqual = lessEqual$1;
|
|
5787
|
+
/** Compute element-wise logical AND. */
|
|
5788
|
+
function logicalAnd(x, y) {
|
|
5789
|
+
return astype(x, require_backend.DType.Bool).mul(astype(y, require_backend.DType.Bool));
|
|
5790
|
+
}
|
|
5791
|
+
/** Compute element-wise logical OR. */
|
|
5792
|
+
function logicalOr(x, y) {
|
|
5793
|
+
return astype(x, require_backend.DType.Bool).add(astype(y, require_backend.DType.Bool));
|
|
5794
|
+
}
|
|
5795
|
+
/** Compute element-wise logical XOR. */
|
|
5796
|
+
function logicalXor(x, y) {
|
|
5797
|
+
return notEqual(astype(x, require_backend.DType.Bool), astype(y, require_backend.DType.Bool));
|
|
5798
|
+
}
|
|
5799
|
+
/** Compute element-wise logical NOT. */
|
|
5800
|
+
function logicalNot(x) {
|
|
5801
|
+
return notEqual(astype(x, require_backend.DType.Bool), true);
|
|
5802
|
+
}
|
|
5752
5803
|
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
5753
5804
|
const where = where$1;
|
|
5754
5805
|
/**
|
|
@@ -5856,6 +5907,34 @@ function mean(a, axis = null, opts) {
|
|
|
5856
5907
|
return fudgeArray(a).mean(axis, opts);
|
|
5857
5908
|
}
|
|
5858
5909
|
/**
|
|
5910
|
+
* Compute the weighted average along the specified axis.
|
|
5911
|
+
*
|
|
5912
|
+
* If no axis is specified, mean is computed along all the axes. The weights
|
|
5913
|
+
* should have shape matching that of `a`, or if an axis is specified, it should
|
|
5914
|
+
* match the shape along those axes.
|
|
5915
|
+
*/
|
|
5916
|
+
function average(a, axis = null, opts) {
|
|
5917
|
+
a = fudgeArray(a);
|
|
5918
|
+
if (opts?.weights == null) return mean(a, axis, opts);
|
|
5919
|
+
const weights = fudgeArray(opts.weights);
|
|
5920
|
+
axis = require_backend.normalizeAxis(axis, ndim(a));
|
|
5921
|
+
const wShape = weights.shape;
|
|
5922
|
+
const aShape = a.shape;
|
|
5923
|
+
if (require_backend.deepEqual(wShape, aShape)) {
|
|
5924
|
+
const scl = sum(weights.ref, axis, opts);
|
|
5925
|
+
return sum(multiply(a, weights), axis, opts).div(scl);
|
|
5926
|
+
} else if (axis.length === 1 && wShape.length === 1 && wShape[0] === aShape[axis[0]]) {
|
|
5927
|
+
const broadcastShape = aShape.map((_, i) => i === axis[0] ? wShape[0] : 1);
|
|
5928
|
+
const wReshaped = reshape(weights, broadcastShape);
|
|
5929
|
+
const scl = sum(wReshaped.ref, axis, opts);
|
|
5930
|
+
return sum(multiply(a, wReshaped), axis, opts).div(scl);
|
|
5931
|
+
} else {
|
|
5932
|
+
weights.dispose();
|
|
5933
|
+
a.dispose();
|
|
5934
|
+
throw new Error(`average: weights shape ${JSON.stringify(wShape)} is not compatible with array shape ${JSON.stringify(aShape)} and axis ${JSON.stringify(axis)}`);
|
|
5935
|
+
}
|
|
5936
|
+
}
|
|
5937
|
+
/**
|
|
5859
5938
|
* Returns the indices of the minimum values along an axis.
|
|
5860
5939
|
*
|
|
5861
5940
|
* By default, index is into the flatted array, otherwise it is along the
|
|
@@ -6234,8 +6313,9 @@ function sort(a, axis = -1) {
|
|
|
6234
6313
|
return fudgeArray(a).sort(axis);
|
|
6235
6314
|
}
|
|
6236
6315
|
/**
|
|
6237
|
-
* Return indices that would sort an array.
|
|
6238
|
-
* algorithm; it
|
|
6316
|
+
* Return indices that would sort an array. Unlike `sort`, this is guaranteed to
|
|
6317
|
+
* be a stable sorting algorithm; it always returns the smaller index first in
|
|
6318
|
+
* event of ties.
|
|
6239
6319
|
*
|
|
6240
6320
|
* Returns an array of `int32` indices.
|
|
6241
6321
|
*
|
|
@@ -6258,20 +6338,63 @@ function take(a, indices, axis = null) {
|
|
|
6258
6338
|
axis = require_backend.checkAxis(axis, ndim(a));
|
|
6259
6339
|
return gather(a, [indices], [axis], axis);
|
|
6260
6340
|
}
|
|
6261
|
-
/**
|
|
6341
|
+
/**
|
|
6342
|
+
* Return if two arrays are element-wise equal within a tolerance.
|
|
6343
|
+
*
|
|
6344
|
+
* The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
|
|
6345
|
+
* NaN values comparing equal if `equalNaN` is true.
|
|
6346
|
+
*/
|
|
6262
6347
|
function allclose(actual, expected, options) {
|
|
6263
|
-
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
6348
|
+
const { rtol = 1e-5, atol = 1e-7, equalNaN = false } = options ?? {};
|
|
6264
6349
|
const x = array(actual);
|
|
6265
6350
|
const y = array(expected);
|
|
6266
6351
|
if (!require_backend.deepEqual(x.shape, y.shape)) return false;
|
|
6267
6352
|
const xData = x.dataSync();
|
|
6268
6353
|
const yData = y.dataSync();
|
|
6269
6354
|
for (let i = 0; i < xData.length; i++) {
|
|
6270
|
-
if (isNaN(xData[i]) !== isNaN(yData[i])) return false;
|
|
6355
|
+
if (equalNaN ? isNaN(xData[i]) !== isNaN(yData[i]) : isNaN(xData[i]) || isNaN(yData[i])) return false;
|
|
6271
6356
|
if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
|
|
6272
6357
|
}
|
|
6273
6358
|
return true;
|
|
6274
6359
|
}
|
|
6360
|
+
/**
|
|
6361
|
+
* Check if two arrays are element-wise equal.
|
|
6362
|
+
*
|
|
6363
|
+
* Returns False if the arrays have different shapes. If `equalNaN` is True,
|
|
6364
|
+
* NaNs in the same position are considered equal.
|
|
6365
|
+
*/
|
|
6366
|
+
function arrayEqual(a1, a2, opts) {
|
|
6367
|
+
a1 = fudgeArray(a1);
|
|
6368
|
+
a2 = fudgeArray(a2);
|
|
6369
|
+
if (!require_backend.deepEqual(a1.shape, a2.shape)) {
|
|
6370
|
+
a1.dispose();
|
|
6371
|
+
a2.dispose();
|
|
6372
|
+
return array(false);
|
|
6373
|
+
}
|
|
6374
|
+
if (opts?.equalNaN) {
|
|
6375
|
+
const nanMask = isnan(a1.ref).mul(isnan(a2.ref));
|
|
6376
|
+
return where(nanMask, true, equal(a1, a2)).all();
|
|
6377
|
+
}
|
|
6378
|
+
return equal(a1, a2).all();
|
|
6379
|
+
}
|
|
6380
|
+
/**
|
|
6381
|
+
* Check if two arrays are element-wise equal after broadcasting.
|
|
6382
|
+
*
|
|
6383
|
+
* Unlike `arrayEqual`, this allows inputs with different but
|
|
6384
|
+
* broadcast-compatible shapes.
|
|
6385
|
+
*/
|
|
6386
|
+
function arrayEquiv(a1, a2) {
|
|
6387
|
+
a1 = fudgeArray(a1);
|
|
6388
|
+
a2 = fudgeArray(a2);
|
|
6389
|
+
try {
|
|
6390
|
+
const [b1, b2] = broadcastArrays(a1, a2);
|
|
6391
|
+
return equal(b1, b2).all();
|
|
6392
|
+
} catch {
|
|
6393
|
+
a1.dispose();
|
|
6394
|
+
a2.dispose();
|
|
6395
|
+
return array(false);
|
|
6396
|
+
}
|
|
6397
|
+
}
|
|
6275
6398
|
/** Matrix product of two arrays. */
|
|
6276
6399
|
function matmul(x, y) {
|
|
6277
6400
|
if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
|
|
@@ -6285,6 +6408,16 @@ function matmul(x, y) {
|
|
|
6285
6408
|
rhsBatchDims: require_backend.range(-2 - numBatchDims, -2)
|
|
6286
6409
|
});
|
|
6287
6410
|
}
|
|
6411
|
+
/** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
|
|
6412
|
+
function matvec(x1, x2) {
|
|
6413
|
+
if (ndim(x1) < 2 || ndim(x2) < 1) throw new Error("matvec: x1 must be at least 2D and x2 at least 1D");
|
|
6414
|
+
return einsum("...mn,...n->...m", x1, x2);
|
|
6415
|
+
}
|
|
6416
|
+
/** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
|
|
6417
|
+
function vecmat(x1, x2) {
|
|
6418
|
+
if (ndim(x1) < 1 || ndim(x2) < 2) throw new Error("vecmat: x1 must be at least 1D and x2 at least 2D");
|
|
6419
|
+
return einsum("...n,...nm->...m", x1, x2);
|
|
6420
|
+
}
|
|
6288
6421
|
/** Dot product of two arrays. */
|
|
6289
6422
|
function dot$1(x, y) {
|
|
6290
6423
|
if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
|
|
@@ -6443,6 +6576,49 @@ function outer(x, y) {
|
|
|
6443
6576
|
y = ravel(y);
|
|
6444
6577
|
return multiply(x.reshape([x.shape[0], 1]), y);
|
|
6445
6578
|
}
|
|
6579
|
+
/**
|
|
6580
|
+
* @function Compute the cross product of two arrays.
|
|
6581
|
+
*
|
|
6582
|
+
* Supports 2D (scalar result) and 3D cross products, with optional axis
|
|
6583
|
+
* arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
|
|
6584
|
+
*/
|
|
6585
|
+
const cross = jit$1(function cross$2(a, b, { axisa = -1, axisb = -1, axisc = -1, axis } = {}) {
|
|
6586
|
+
if (axis !== void 0) {
|
|
6587
|
+
axisa = axis;
|
|
6588
|
+
axisb = axis;
|
|
6589
|
+
axisc = axis;
|
|
6590
|
+
}
|
|
6591
|
+
axisa = require_backend.checkAxis(axisa, ndim(a));
|
|
6592
|
+
axisb = require_backend.checkAxis(axisb, ndim(b));
|
|
6593
|
+
a = moveaxis$1(a, axisa, -1);
|
|
6594
|
+
b = moveaxis$1(b, axisb, -1);
|
|
6595
|
+
const da = a.shape.at(-1);
|
|
6596
|
+
const db = b.shape.at(-1);
|
|
6597
|
+
if (da !== 2 && da !== 3 || db !== 2 && db !== 3) throw new Error(`cross: incompatible dimensions for cross product (got ${da} and ${db})`);
|
|
6598
|
+
if (da === 2 && db === 2) {
|
|
6599
|
+
const [a0$1, a1$1] = split$1(a, 2, -1);
|
|
6600
|
+
const [b0$1, b1$1] = split$1(b, 2, -1);
|
|
6601
|
+
return squeeze(a0$1.mul(b1$1).sub(a1$1.mul(b0$1)), -1);
|
|
6602
|
+
}
|
|
6603
|
+
if (da === 2) {
|
|
6604
|
+
const zeroShape = [...a.shape.slice(0, -1), 1];
|
|
6605
|
+
a = concatenate([a, zeros(zeroShape)], -1);
|
|
6606
|
+
}
|
|
6607
|
+
if (db === 2) {
|
|
6608
|
+
const zeroShape = [...b.shape.slice(0, -1), 1];
|
|
6609
|
+
b = concatenate([b, zeros(zeroShape)], -1);
|
|
6610
|
+
}
|
|
6611
|
+
const [a0, a1, a2] = split$1(a, 3, -1);
|
|
6612
|
+
const [b0, b1, b2] = split$1(b, 3, -1);
|
|
6613
|
+
const c0 = a1.ref.mul(b2.ref).sub(a2.ref.mul(b1.ref));
|
|
6614
|
+
const c1 = a2.mul(b0.ref).sub(a0.ref.mul(b2));
|
|
6615
|
+
const c2 = a0.mul(b1).sub(a1.mul(b0));
|
|
6616
|
+
return moveaxis$1(concatenate([
|
|
6617
|
+
c0,
|
|
6618
|
+
c1,
|
|
6619
|
+
c2
|
|
6620
|
+
], -1), -1, axisc);
|
|
6621
|
+
}, { staticArgnums: [2] });
|
|
6446
6622
|
/** Vector dot product of two arrays along a given axis. */
|
|
6447
6623
|
function vecdot(x, y, { axis } = {}) {
|
|
6448
6624
|
const xaxis = require_backend.checkAxis(axis ?? -1, ndim(x));
|
|
@@ -6537,18 +6713,17 @@ function absolute(x) {
|
|
|
6537
6713
|
/** Return an element-wise indication of sign of the input. */
|
|
6538
6714
|
function sign(x) {
|
|
6539
6715
|
x = fudgeArray(x);
|
|
6540
|
-
return where(notEqual(x.ref, 0), where(less(x
|
|
6716
|
+
return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
|
|
6541
6717
|
}
|
|
6542
|
-
/** @function Return element-wise positive values of the input (no-op). */
|
|
6543
|
-
const positive = fudgeArray;
|
|
6544
6718
|
/**
|
|
6545
|
-
*
|
|
6546
|
-
*
|
|
6547
|
-
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
6719
|
+
* @function
|
|
6720
|
+
* Return the value with the magnitude of x and the sign of y, element-wise.
|
|
6548
6721
|
*/
|
|
6549
|
-
function
|
|
6550
|
-
return
|
|
6551
|
-
}
|
|
6722
|
+
const copysign = jit$1(function copysign$1(x, y) {
|
|
6723
|
+
return absolute(x).mul(sign(y));
|
|
6724
|
+
});
|
|
6725
|
+
/** @function Return element-wise positive values of the input (no-op). */
|
|
6726
|
+
const positive = fudgeArray;
|
|
6552
6727
|
/**
|
|
6553
6728
|
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
6554
6729
|
*
|
|
@@ -6694,6 +6869,27 @@ function trunc(x) {
|
|
|
6694
6869
|
return idiv(x, 1);
|
|
6695
6870
|
}
|
|
6696
6871
|
/**
|
|
6872
|
+
* @function
|
|
6873
|
+
* Round to the given number of decimals.
|
|
6874
|
+
*
|
|
6875
|
+
* Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
|
|
6876
|
+
*/
|
|
6877
|
+
const round = jit$1(function round$1(a, decimals = 0) {
|
|
6878
|
+
if (decimals === 0) return rint(a);
|
|
6879
|
+
const factor = 10 ** decimals;
|
|
6880
|
+
return rint(a.mul(factor)).mul(1 / factor);
|
|
6881
|
+
}, { staticArgnums: [1] });
|
|
6882
|
+
/**
|
|
6883
|
+
* @function
|
|
6884
|
+
* Round to the nearest integer, with ties going to the nearest even integer.
|
|
6885
|
+
*/
|
|
6886
|
+
const rint = jit$1(function rint$1(x) {
|
|
6887
|
+
const rounded = floor(x.ref.add(.5));
|
|
6888
|
+
const half = x.ref.sub(floor(x)).equal(.5);
|
|
6889
|
+
const odd = remainder(rounded.ref, 2).notEqual(0);
|
|
6890
|
+
return where(half.mul(odd), rounded.ref.sub(1), rounded);
|
|
6891
|
+
});
|
|
6892
|
+
/**
|
|
6697
6893
|
* Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
|
|
6698
6894
|
*
|
|
6699
6895
|
* This is the inverse of `frexp()`.
|
|
@@ -7021,6 +7217,7 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
|
|
|
7021
7217
|
//#region src/library/lax.ts
|
|
7022
7218
|
var lax_exports = {};
|
|
7023
7219
|
__export(lax_exports, {
|
|
7220
|
+
bitcastConvertType: () => bitcastConvertType,
|
|
7024
7221
|
conv: () => conv,
|
|
7025
7222
|
convGeneralDilated: () => convGeneralDilated,
|
|
7026
7223
|
convTranspose: () => convTranspose,
|
|
@@ -7030,9 +7227,14 @@ __export(lax_exports, {
|
|
|
7030
7227
|
erfc: () => erfc,
|
|
7031
7228
|
linalg: () => lax_linalg_exports,
|
|
7032
7229
|
reduceWindow: () => reduceWindow,
|
|
7033
|
-
stopGradient: () => stopGradient$1
|
|
7230
|
+
stopGradient: () => stopGradient$1,
|
|
7231
|
+
topK: () => topK
|
|
7034
7232
|
});
|
|
7035
7233
|
const JsArray = globalThis.Array;
|
|
7234
|
+
/** Elementwise bitcast an array into a new dtype. */
|
|
7235
|
+
function bitcastConvertType(x, newDtype) {
|
|
7236
|
+
return fudgeArray(x).view(newDtype);
|
|
7237
|
+
}
|
|
7036
7238
|
/**
|
|
7037
7239
|
* General dot product/contraction operator.
|
|
7038
7240
|
*
|
|
@@ -7254,6 +7456,39 @@ function erfc(x) {
|
|
|
7254
7456
|
function stopGradient$1(x) {
|
|
7255
7457
|
return stopGradient(x);
|
|
7256
7458
|
}
|
|
7459
|
+
/**
|
|
7460
|
+
* Returns top `k` values and their indices along the specified axis of operand.
|
|
7461
|
+
*
|
|
7462
|
+
* This is a _stable_ algorithm: If two elements are equal, the lower-index
|
|
7463
|
+
* element appears first.
|
|
7464
|
+
*
|
|
7465
|
+
* @returns A tuple of `(values, indices)`, where `values` and `indices` have
|
|
7466
|
+
* the same shape as `x`, except along `axis` where they have size `k`.
|
|
7467
|
+
*/
|
|
7468
|
+
function topK(x, k, axis = -1) {
|
|
7469
|
+
x = fudgeArray(x);
|
|
7470
|
+
axis = require_backend.checkAxis(axis, x.ndim);
|
|
7471
|
+
const size$1 = x.shape[axis];
|
|
7472
|
+
if (k < 0 || k > size$1) throw new Error(`topK: k must be in the range [0, ${size$1}], got ${k}`);
|
|
7473
|
+
if (k === 0) {
|
|
7474
|
+
const outShape = x.shape.slice();
|
|
7475
|
+
outShape[axis] = 0;
|
|
7476
|
+
const y$1 = zerosLike$1(x.ref, { shape: outShape });
|
|
7477
|
+
const yi$1 = zerosLike$1(x, {
|
|
7478
|
+
dtype: require_backend.DType.Int32,
|
|
7479
|
+
shape: outShape
|
|
7480
|
+
});
|
|
7481
|
+
return [y$1, yi$1];
|
|
7482
|
+
}
|
|
7483
|
+
x = flip$1(x, [axis]);
|
|
7484
|
+
x = moveaxis(x, axis, -1);
|
|
7485
|
+
const [y, yi] = argsort$1(x);
|
|
7486
|
+
const extract = (a) => {
|
|
7487
|
+
a = a.slice(...require_backend.rep(a.ndim - 1, []), [-k]);
|
|
7488
|
+
return flip$1(moveaxis(a, -1, axis), [axis]);
|
|
7489
|
+
};
|
|
7490
|
+
return [extract(y), extract(yi.neg().add(size$1 - 1))];
|
|
7491
|
+
}
|
|
7257
7492
|
|
|
7258
7493
|
//#endregion
|
|
7259
7494
|
//#region src/library/nn.ts
|
|
@@ -7445,7 +7680,7 @@ const gelu = jit$1(function gelu$1(x, opts) {
|
|
|
7445
7680
|
if (opts?.approximate ?? true) {
|
|
7446
7681
|
const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
|
|
7447
7682
|
return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
|
|
7448
|
-
} else return x.ref.mul(.5).mul(erfc$1(negative(x.
|
|
7683
|
+
} else return x.ref.mul(.5).mul(erfc$1(negative(x.mul(Math.SQRT1_2))));
|
|
7449
7684
|
}, { staticArgnums: [1] });
|
|
7450
7685
|
/**
|
|
7451
7686
|
* Gated linear unit (GLU) activation function.
|
|
@@ -7703,6 +7938,7 @@ var random_exports = {};
|
|
|
7703
7938
|
__export(random_exports, {
|
|
7704
7939
|
bernoulli: () => bernoulli,
|
|
7705
7940
|
bits: () => bits,
|
|
7941
|
+
categorical: () => categorical,
|
|
7706
7942
|
cauchy: () => cauchy,
|
|
7707
7943
|
exponential: () => exponential,
|
|
7708
7944
|
gumbel: () => gumbel,
|
|
@@ -7730,7 +7966,9 @@ function getK01(key$1) {
|
|
|
7730
7966
|
function key(seed) {
|
|
7731
7967
|
seed = array(seed, { dtype: require_backend.DType.Uint32 });
|
|
7732
7968
|
if (seed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seed.shape} - use jax.vmap for batching.`);
|
|
7733
|
-
|
|
7969
|
+
const key$1 = stack([0, seed]);
|
|
7970
|
+
if (key$1 instanceof Array$1) key$1._realizeSource();
|
|
7971
|
+
return key$1;
|
|
7734
7972
|
}
|
|
7735
7973
|
/** Splits a PRNG key into `num` new keys by adding a leading axis. */
|
|
7736
7974
|
function split(key$1, num = 2) {
|
|
@@ -7774,6 +8012,47 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
|
7774
8012
|
}
|
|
7775
8013
|
/**
|
|
7776
8014
|
* @function
|
|
8015
|
+
* Sample random values from categorical distributions.
|
|
8016
|
+
*
|
|
8017
|
+
* Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
|
|
8018
|
+
* trick for sampling without replacement.
|
|
8019
|
+
*
|
|
8020
|
+
* Note: Sampling without replacement currently uses argsort and slices the last
|
|
8021
|
+
* k elements. This should be replaced with a more efficient topK implementation.
|
|
8022
|
+
*
|
|
8023
|
+
* - `key` - PRNG key
|
|
8024
|
+
* - `logits` - Unnormalized log probabilities of the categorical distribution(s).
|
|
8025
|
+
* `softmax(logits, axis)` gives the corresponding probabilities.
|
|
8026
|
+
* - `axis` - Axis along which logits belong to the same categorical distribution.
|
|
8027
|
+
* - `shape` - Result batch shape. Must be broadcast-compatible with
|
|
8028
|
+
* `logits.shape` with `axis` removed. Default is `logits.shape` with `axis` removed.
|
|
8029
|
+
* - `replace` - If true (default), sample with replacement. If false, sample
|
|
8030
|
+
* without replacement (each category can only be selected once per batch).
|
|
8031
|
+
* @returns A random array with int dtype and shape given by `shape` if provided,
|
|
8032
|
+
* otherwise `logits.shape` with `axis` removed.
|
|
8033
|
+
*/
|
|
8034
|
+
const categorical = jit$1(function categorical$1(key$1, logits, { axis = -1, shape: shape$1, replace = true } = {}) {
|
|
8035
|
+
logits = fudgeArray(logits);
|
|
8036
|
+
axis = require_backend.checkAxis(axis, logits.ndim);
|
|
8037
|
+
const numCategories = logits.shape[axis];
|
|
8038
|
+
const batchShape = logits.shape.toSpliced(axis, 1);
|
|
8039
|
+
if (shape$1 === void 0) shape$1 = batchShape;
|
|
8040
|
+
else if (!require_backend.deepEqual(require_backend.generalBroadcast(shape$1, batchShape), shape$1)) throw new Error(`Shape ${shape$1} is not broadcast-compatible with batch shape ${batchShape}.`);
|
|
8041
|
+
const shapePrefix = shape$1.slice(0, shape$1.length - batchShape.length);
|
|
8042
|
+
if (replace) {
|
|
8043
|
+
const noise = gumbel(key$1, [...shapePrefix, ...logits.shape]);
|
|
8044
|
+
return argmax(noise.add(logits), axis + shapePrefix.length);
|
|
8045
|
+
} else {
|
|
8046
|
+
const k = shapePrefix.reduce((a, b) => a * b, 1);
|
|
8047
|
+
if (k > numCategories) throw new Error(`Number of samples without replacement (${k}) cannot exceed number of categories (${numCategories}).`);
|
|
8048
|
+
const noise = gumbel(key$1, logits.shape);
|
|
8049
|
+
const [values, indices] = topK(noise.add(logits), k, axis);
|
|
8050
|
+
values.dispose();
|
|
8051
|
+
return indices.reshape(shape$1);
|
|
8052
|
+
}
|
|
8053
|
+
}, { staticArgnums: [2] });
|
|
8054
|
+
/**
|
|
8055
|
+
* @function
|
|
7777
8056
|
* Sample from a Cauchy distribution with location 0 and scale 1.
|
|
7778
8057
|
*
|
|
7779
8058
|
* Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
|
|
@@ -7884,6 +8163,11 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
|
|
|
7884
8163
|
|
|
7885
8164
|
//#endregion
|
|
7886
8165
|
//#region src/index.ts
|
|
8166
|
+
/** @namespace */
|
|
8167
|
+
const profiler = {
|
|
8168
|
+
startTrace: require_backend.startTrace,
|
|
8169
|
+
stopTrace: require_backend.stopTrace
|
|
8170
|
+
};
|
|
7887
8171
|
/**
|
|
7888
8172
|
* @function
|
|
7889
8173
|
* Compute the forward-mode Jacobian-vector product for a function.
|
|
@@ -8080,6 +8364,7 @@ Object.defineProperty(exports, 'numpy', {
|
|
|
8080
8364
|
return numpy_exports;
|
|
8081
8365
|
}
|
|
8082
8366
|
});
|
|
8367
|
+
exports.profiler = profiler;
|
|
8083
8368
|
Object.defineProperty(exports, 'random', {
|
|
8084
8369
|
enumerable: true,
|
|
8085
8370
|
get: function () {
|