@jax-js/jax 0.1.9 → 0.1.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +31 -18
- package/dist/{backend-BId79r5b.js → backend-Ctqs8la1.js} +107 -11
- package/dist/{backend-DpI0riom.cjs → backend-DMauYnfl.cjs} +142 -10
- package/dist/index.cjs +225 -18
- package/dist/index.d.cts +112 -11
- package/dist/index.d.ts +112 -11
- package/dist/index.js +225 -19
- package/dist/{webgl-DnGrclTz.js → webgl-CvQ1QBX1.js} +1 -1
- package/dist/{webgl-C5NjXc1p.cjs → webgl-kvVt7-T7.cjs} +1 -1
- package/dist/{webgpu-CdjiJSa7.cjs → webgpu-DMSx7a6M.cjs} +136 -6
- package/dist/{webgpu-AN0cG_nB.js → webgpu-v_W_-oKw.js} +136 -6
- package/package.json +5 -16
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-DMauYnfl.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -838,6 +838,11 @@ var Tracer = class Tracer {
|
|
|
838
838
|
if (this.dtype === dtype) return this;
|
|
839
839
|
return cast(this, dtype);
|
|
840
840
|
}
|
|
841
|
+
/** Return a bitwise cast of the array, viewed as a new dtype. */
|
|
842
|
+
view(dtype) {
|
|
843
|
+
if (!dtype || dtype === this.dtype) return this;
|
|
844
|
+
return bitcast(this, dtype);
|
|
845
|
+
}
|
|
841
846
|
/** Subtract an array from this one. */
|
|
842
847
|
sub(other) {
|
|
843
848
|
return this.add(neg(other));
|
|
@@ -1659,7 +1664,7 @@ const abstractEvalRules = {
|
|
|
1659
1664
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1660
1665
|
},
|
|
1661
1666
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
1662
|
-
if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
1667
|
+
if (x.dtype !== dtype && (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool)) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
1663
1668
|
if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
1664
1669
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1665
1670
|
},
|
|
@@ -3081,8 +3086,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3081
3086
|
return [x.#unary(require_backend.AluOp.Cast, dtype)];
|
|
3082
3087
|
},
|
|
3083
3088
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
3084
|
-
if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
3085
3089
|
if (x.dtype === dtype) return [x];
|
|
3090
|
+
if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
3086
3091
|
if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
3087
3092
|
if (x.#source instanceof require_backend.AluExp) return [x.#unary(require_backend.AluOp.Bitcast, dtype)];
|
|
3088
3093
|
else {
|
|
@@ -4179,6 +4184,7 @@ const jvpRules = {
|
|
|
4179
4184
|
},
|
|
4180
4185
|
[Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
|
|
4181
4186
|
const x = triangularSolve$1(a.ref, b, { unitDiagonal });
|
|
4187
|
+
da = unitDiagonal ? triu(da, 1) : triu(da);
|
|
4182
4188
|
const dax = batchMatmulT(da, x.ref);
|
|
4183
4189
|
const rhsT = db.sub(mT(dax));
|
|
4184
4190
|
const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
|
|
@@ -5254,6 +5260,7 @@ function ifft(a, axis = -1) {
|
|
|
5254
5260
|
var numpy_linalg_exports = {};
|
|
5255
5261
|
__export(numpy_linalg_exports, {
|
|
5256
5262
|
cholesky: () => cholesky,
|
|
5263
|
+
cross: () => cross$1,
|
|
5257
5264
|
det: () => det,
|
|
5258
5265
|
diagonal: () => diagonal,
|
|
5259
5266
|
inv: () => inv,
|
|
@@ -5284,6 +5291,19 @@ function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
|
|
|
5284
5291
|
if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
|
|
5285
5292
|
return cholesky$1(a, { upper });
|
|
5286
5293
|
}
|
|
5294
|
+
/**
|
|
5295
|
+
* Compute the cross-product of two 3D vectors.
|
|
5296
|
+
*
|
|
5297
|
+
* This is a simpler and less flexible version of `jax.numpy.cross()`.
|
|
5298
|
+
* Both inputs must have size 3 along the specified axis.
|
|
5299
|
+
*/
|
|
5300
|
+
function cross$1(x1, x2, axis = -1) {
|
|
5301
|
+
const a1 = require_backend.checkAxis(axis, ndim(x1));
|
|
5302
|
+
const a2 = require_backend.checkAxis(axis, ndim(x2));
|
|
5303
|
+
if (shape(x1)[a1] !== 3) throw new Error(`linalg.cross: x1 must have size 3 along axis ${axis}, got ${shape(x1)[a1]}`);
|
|
5304
|
+
if (shape(x2)[a2] !== 3) throw new Error(`linalg.cross: x2 must have size 3 along axis ${axis}, got ${shape(x2)[a2]}`);
|
|
5305
|
+
return cross(x1, x2, { axis });
|
|
5306
|
+
}
|
|
5287
5307
|
/** Compute the determinant of a square matrix (batched). */
|
|
5288
5308
|
function det(a) {
|
|
5289
5309
|
a = fudgeArray(a);
|
|
@@ -5299,7 +5319,7 @@ function det(a) {
|
|
|
5299
5319
|
function inv(a) {
|
|
5300
5320
|
a = fudgeArray(a);
|
|
5301
5321
|
const n = checkSquare("inv", a);
|
|
5302
|
-
return solve(a, eye(n));
|
|
5322
|
+
return solve(a, eye(n, void 0, { dtype: a.dtype }));
|
|
5303
5323
|
}
|
|
5304
5324
|
/**
|
|
5305
5325
|
* Return the least-squares solution to a linear equation.
|
|
@@ -5356,8 +5376,9 @@ function matrixPower(a, n) {
|
|
|
5356
5376
|
a = fudgeArray(a);
|
|
5357
5377
|
const m = checkSquare("matrixPower", a);
|
|
5358
5378
|
if (n === 0) {
|
|
5379
|
+
const dtype = a.dtype;
|
|
5359
5380
|
a.dispose();
|
|
5360
|
-
return broadcastTo(eye(m), a.shape);
|
|
5381
|
+
return broadcastTo(eye(m, void 0, { dtype }), a.shape);
|
|
5361
5382
|
}
|
|
5362
5383
|
if (n < 0) {
|
|
5363
5384
|
a = inv(a);
|
|
@@ -5539,13 +5560,17 @@ __export(numpy_exports, {
|
|
|
5539
5560
|
argmax: () => argmax,
|
|
5540
5561
|
argmin: () => argmin,
|
|
5541
5562
|
argsort: () => argsort,
|
|
5563
|
+
around: () => round,
|
|
5542
5564
|
array: () => array,
|
|
5565
|
+
arrayEqual: () => arrayEqual,
|
|
5566
|
+
arrayEquiv: () => arrayEquiv,
|
|
5543
5567
|
asin: () => asin,
|
|
5544
5568
|
asinh: () => arcsinh,
|
|
5545
5569
|
astype: () => astype,
|
|
5546
5570
|
atan: () => atan,
|
|
5547
5571
|
atan2: () => atan2,
|
|
5548
5572
|
atanh: () => arctanh,
|
|
5573
|
+
average: () => average,
|
|
5549
5574
|
bool: () => bool,
|
|
5550
5575
|
broadcastArrays: () => broadcastArrays,
|
|
5551
5576
|
broadcastShapes: () => broadcastShapes,
|
|
@@ -5556,11 +5581,13 @@ __export(numpy_exports, {
|
|
|
5556
5581
|
columnStack: () => columnStack,
|
|
5557
5582
|
concatenate: () => concatenate,
|
|
5558
5583
|
convolve: () => convolve,
|
|
5584
|
+
copysign: () => copysign,
|
|
5559
5585
|
corrcoef: () => corrcoef,
|
|
5560
5586
|
correlate: () => correlate,
|
|
5561
5587
|
cos: () => cos,
|
|
5562
5588
|
cosh: () => cosh,
|
|
5563
5589
|
cov: () => cov,
|
|
5590
|
+
cross: () => cross,
|
|
5564
5591
|
cumsum: () => cumsum,
|
|
5565
5592
|
cumulativeSum: () => cumsum,
|
|
5566
5593
|
deg2rad: () => deg2rad,
|
|
@@ -5596,7 +5623,6 @@ __export(numpy_exports, {
|
|
|
5596
5623
|
fullLike: () => fullLike$1,
|
|
5597
5624
|
greater: () => greater,
|
|
5598
5625
|
greaterEqual: () => greaterEqual,
|
|
5599
|
-
hamming: () => hamming,
|
|
5600
5626
|
hann: () => hann,
|
|
5601
5627
|
heaviside: () => heaviside,
|
|
5602
5628
|
hstack: () => hstack,
|
|
@@ -5620,9 +5646,14 @@ __export(numpy_exports, {
|
|
|
5620
5646
|
log10: () => log10,
|
|
5621
5647
|
log1p: () => log1p,
|
|
5622
5648
|
log2: () => log2,
|
|
5649
|
+
logicalAnd: () => logicalAnd,
|
|
5650
|
+
logicalNot: () => logicalNot,
|
|
5651
|
+
logicalOr: () => logicalOr,
|
|
5652
|
+
logicalXor: () => logicalXor,
|
|
5623
5653
|
logspace: () => logspace,
|
|
5624
5654
|
matmul: () => matmul,
|
|
5625
5655
|
matrixTranspose: () => matrixTranspose,
|
|
5656
|
+
matvec: () => matvec,
|
|
5626
5657
|
max: () => max,
|
|
5627
5658
|
maximum: () => maximum,
|
|
5628
5659
|
mean: () => mean,
|
|
@@ -5655,6 +5686,8 @@ __export(numpy_exports, {
|
|
|
5655
5686
|
remainder: () => remainder,
|
|
5656
5687
|
repeat: () => repeat,
|
|
5657
5688
|
reshape: () => reshape,
|
|
5689
|
+
rint: () => rint,
|
|
5690
|
+
round: () => round,
|
|
5658
5691
|
shape: () => shape,
|
|
5659
5692
|
sign: () => sign,
|
|
5660
5693
|
sin: () => sin,
|
|
@@ -5687,6 +5720,7 @@ __export(numpy_exports, {
|
|
|
5687
5720
|
var_: () => var_,
|
|
5688
5721
|
vdot: () => vdot,
|
|
5689
5722
|
vecdot: () => vecdot,
|
|
5723
|
+
vecmat: () => vecmat,
|
|
5690
5724
|
vstack: () => vstack,
|
|
5691
5725
|
where: () => where,
|
|
5692
5726
|
zeros: () => zeros,
|
|
@@ -5750,6 +5784,22 @@ const notEqual = notEqual$1;
|
|
|
5750
5784
|
const greaterEqual = greaterEqual$1;
|
|
5751
5785
|
/** @function Compare two arrays element-wise. */
|
|
5752
5786
|
const lessEqual = lessEqual$1;
|
|
5787
|
+
/** Compute element-wise logical AND. */
|
|
5788
|
+
function logicalAnd(x, y) {
|
|
5789
|
+
return astype(x, require_backend.DType.Bool).mul(astype(y, require_backend.DType.Bool));
|
|
5790
|
+
}
|
|
5791
|
+
/** Compute element-wise logical OR. */
|
|
5792
|
+
function logicalOr(x, y) {
|
|
5793
|
+
return astype(x, require_backend.DType.Bool).add(astype(y, require_backend.DType.Bool));
|
|
5794
|
+
}
|
|
5795
|
+
/** Compute element-wise logical XOR. */
|
|
5796
|
+
function logicalXor(x, y) {
|
|
5797
|
+
return notEqual(astype(x, require_backend.DType.Bool), astype(y, require_backend.DType.Bool));
|
|
5798
|
+
}
|
|
5799
|
+
/** Compute element-wise logical NOT. */
|
|
5800
|
+
function logicalNot(x) {
|
|
5801
|
+
return notEqual(astype(x, require_backend.DType.Bool), true);
|
|
5802
|
+
}
|
|
5753
5803
|
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
5754
5804
|
const where = where$1;
|
|
5755
5805
|
/**
|
|
@@ -5857,6 +5907,34 @@ function mean(a, axis = null, opts) {
|
|
|
5857
5907
|
return fudgeArray(a).mean(axis, opts);
|
|
5858
5908
|
}
|
|
5859
5909
|
/**
|
|
5910
|
+
* Compute the weighted average along the specified axis.
|
|
5911
|
+
*
|
|
5912
|
+
* If no axis is specified, mean is computed along all the axes. The weights
|
|
5913
|
+
* should have shape matching that of `a`, or if an axis is specified, it should
|
|
5914
|
+
* match the shape along those axes.
|
|
5915
|
+
*/
|
|
5916
|
+
function average(a, axis = null, opts) {
|
|
5917
|
+
a = fudgeArray(a);
|
|
5918
|
+
if (opts?.weights == null) return mean(a, axis, opts);
|
|
5919
|
+
const weights = fudgeArray(opts.weights);
|
|
5920
|
+
axis = require_backend.normalizeAxis(axis, ndim(a));
|
|
5921
|
+
const wShape = weights.shape;
|
|
5922
|
+
const aShape = a.shape;
|
|
5923
|
+
if (require_backend.deepEqual(wShape, aShape)) {
|
|
5924
|
+
const scl = sum(weights.ref, axis, opts);
|
|
5925
|
+
return sum(multiply(a, weights), axis, opts).div(scl);
|
|
5926
|
+
} else if (axis.length === 1 && wShape.length === 1 && wShape[0] === aShape[axis[0]]) {
|
|
5927
|
+
const broadcastShape = aShape.map((_, i) => i === axis[0] ? wShape[0] : 1);
|
|
5928
|
+
const wReshaped = reshape(weights, broadcastShape);
|
|
5929
|
+
const scl = sum(wReshaped.ref, axis, opts);
|
|
5930
|
+
return sum(multiply(a, wReshaped), axis, opts).div(scl);
|
|
5931
|
+
} else {
|
|
5932
|
+
weights.dispose();
|
|
5933
|
+
a.dispose();
|
|
5934
|
+
throw new Error(`average: weights shape ${JSON.stringify(wShape)} is not compatible with array shape ${JSON.stringify(aShape)} and axis ${JSON.stringify(axis)}`);
|
|
5935
|
+
}
|
|
5936
|
+
}
|
|
5937
|
+
/**
|
|
5860
5938
|
* Returns the indices of the minimum values along an axis.
|
|
5861
5939
|
*
|
|
5862
5940
|
* By default, index is into the flatted array, otherwise it is along the
|
|
@@ -6260,20 +6338,63 @@ function take(a, indices, axis = null) {
|
|
|
6260
6338
|
axis = require_backend.checkAxis(axis, ndim(a));
|
|
6261
6339
|
return gather(a, [indices], [axis], axis);
|
|
6262
6340
|
}
|
|
6263
|
-
/**
|
|
6341
|
+
/**
|
|
6342
|
+
* Return if two arrays are element-wise equal within a tolerance.
|
|
6343
|
+
*
|
|
6344
|
+
* The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
|
|
6345
|
+
* NaN values comparing equal if `equalNaN` is true.
|
|
6346
|
+
*/
|
|
6264
6347
|
function allclose(actual, expected, options) {
|
|
6265
|
-
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
6348
|
+
const { rtol = 1e-5, atol = 1e-7, equalNaN = false } = options ?? {};
|
|
6266
6349
|
const x = array(actual);
|
|
6267
6350
|
const y = array(expected);
|
|
6268
6351
|
if (!require_backend.deepEqual(x.shape, y.shape)) return false;
|
|
6269
6352
|
const xData = x.dataSync();
|
|
6270
6353
|
const yData = y.dataSync();
|
|
6271
6354
|
for (let i = 0; i < xData.length; i++) {
|
|
6272
|
-
if (isNaN(xData[i]) !== isNaN(yData[i])) return false;
|
|
6355
|
+
if (equalNaN ? isNaN(xData[i]) !== isNaN(yData[i]) : isNaN(xData[i]) || isNaN(yData[i])) return false;
|
|
6273
6356
|
if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
|
|
6274
6357
|
}
|
|
6275
6358
|
return true;
|
|
6276
6359
|
}
|
|
6360
|
+
/**
|
|
6361
|
+
* Check if two arrays are element-wise equal.
|
|
6362
|
+
*
|
|
6363
|
+
* Returns False if the arrays have different shapes. If `equalNaN` is True,
|
|
6364
|
+
* NaNs in the same position are considered equal.
|
|
6365
|
+
*/
|
|
6366
|
+
function arrayEqual(a1, a2, opts) {
|
|
6367
|
+
a1 = fudgeArray(a1);
|
|
6368
|
+
a2 = fudgeArray(a2);
|
|
6369
|
+
if (!require_backend.deepEqual(a1.shape, a2.shape)) {
|
|
6370
|
+
a1.dispose();
|
|
6371
|
+
a2.dispose();
|
|
6372
|
+
return array(false);
|
|
6373
|
+
}
|
|
6374
|
+
if (opts?.equalNaN) {
|
|
6375
|
+
const nanMask = isnan(a1.ref).mul(isnan(a2.ref));
|
|
6376
|
+
return where(nanMask, true, equal(a1, a2)).all();
|
|
6377
|
+
}
|
|
6378
|
+
return equal(a1, a2).all();
|
|
6379
|
+
}
|
|
6380
|
+
/**
|
|
6381
|
+
* Check if two arrays are element-wise equal after broadcasting.
|
|
6382
|
+
*
|
|
6383
|
+
* Unlike `arrayEqual`, this allows inputs with different but
|
|
6384
|
+
* broadcast-compatible shapes.
|
|
6385
|
+
*/
|
|
6386
|
+
function arrayEquiv(a1, a2) {
|
|
6387
|
+
a1 = fudgeArray(a1);
|
|
6388
|
+
a2 = fudgeArray(a2);
|
|
6389
|
+
try {
|
|
6390
|
+
const [b1, b2] = broadcastArrays(a1, a2);
|
|
6391
|
+
return equal(b1, b2).all();
|
|
6392
|
+
} catch {
|
|
6393
|
+
a1.dispose();
|
|
6394
|
+
a2.dispose();
|
|
6395
|
+
return array(false);
|
|
6396
|
+
}
|
|
6397
|
+
}
|
|
6277
6398
|
/** Matrix product of two arrays. */
|
|
6278
6399
|
function matmul(x, y) {
|
|
6279
6400
|
if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
|
|
@@ -6287,6 +6408,16 @@ function matmul(x, y) {
|
|
|
6287
6408
|
rhsBatchDims: require_backend.range(-2 - numBatchDims, -2)
|
|
6288
6409
|
});
|
|
6289
6410
|
}
|
|
6411
|
+
/** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
|
|
6412
|
+
function matvec(x1, x2) {
|
|
6413
|
+
if (ndim(x1) < 2 || ndim(x2) < 1) throw new Error("matvec: x1 must be at least 2D and x2 at least 1D");
|
|
6414
|
+
return einsum("...mn,...n->...m", x1, x2);
|
|
6415
|
+
}
|
|
6416
|
+
/** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
|
|
6417
|
+
function vecmat(x1, x2) {
|
|
6418
|
+
if (ndim(x1) < 1 || ndim(x2) < 2) throw new Error("vecmat: x1 must be at least 1D and x2 at least 2D");
|
|
6419
|
+
return einsum("...n,...nm->...m", x1, x2);
|
|
6420
|
+
}
|
|
6290
6421
|
/** Dot product of two arrays. */
|
|
6291
6422
|
function dot$1(x, y) {
|
|
6292
6423
|
if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
|
|
@@ -6445,6 +6576,49 @@ function outer(x, y) {
|
|
|
6445
6576
|
y = ravel(y);
|
|
6446
6577
|
return multiply(x.reshape([x.shape[0], 1]), y);
|
|
6447
6578
|
}
|
|
6579
|
+
/**
|
|
6580
|
+
* @function Compute the cross product of two arrays.
|
|
6581
|
+
*
|
|
6582
|
+
* Supports 2D (scalar result) and 3D cross products, with optional axis
|
|
6583
|
+
* arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
|
|
6584
|
+
*/
|
|
6585
|
+
const cross = jit$1(function cross$2(a, b, { axisa = -1, axisb = -1, axisc = -1, axis } = {}) {
|
|
6586
|
+
if (axis !== void 0) {
|
|
6587
|
+
axisa = axis;
|
|
6588
|
+
axisb = axis;
|
|
6589
|
+
axisc = axis;
|
|
6590
|
+
}
|
|
6591
|
+
axisa = require_backend.checkAxis(axisa, ndim(a));
|
|
6592
|
+
axisb = require_backend.checkAxis(axisb, ndim(b));
|
|
6593
|
+
a = moveaxis$1(a, axisa, -1);
|
|
6594
|
+
b = moveaxis$1(b, axisb, -1);
|
|
6595
|
+
const da = a.shape.at(-1);
|
|
6596
|
+
const db = b.shape.at(-1);
|
|
6597
|
+
if (da !== 2 && da !== 3 || db !== 2 && db !== 3) throw new Error(`cross: incompatible dimensions for cross product (got ${da} and ${db})`);
|
|
6598
|
+
if (da === 2 && db === 2) {
|
|
6599
|
+
const [a0$1, a1$1] = split$1(a, 2, -1);
|
|
6600
|
+
const [b0$1, b1$1] = split$1(b, 2, -1);
|
|
6601
|
+
return squeeze(a0$1.mul(b1$1).sub(a1$1.mul(b0$1)), -1);
|
|
6602
|
+
}
|
|
6603
|
+
if (da === 2) {
|
|
6604
|
+
const zeroShape = [...a.shape.slice(0, -1), 1];
|
|
6605
|
+
a = concatenate([a, zeros(zeroShape)], -1);
|
|
6606
|
+
}
|
|
6607
|
+
if (db === 2) {
|
|
6608
|
+
const zeroShape = [...b.shape.slice(0, -1), 1];
|
|
6609
|
+
b = concatenate([b, zeros(zeroShape)], -1);
|
|
6610
|
+
}
|
|
6611
|
+
const [a0, a1, a2] = split$1(a, 3, -1);
|
|
6612
|
+
const [b0, b1, b2] = split$1(b, 3, -1);
|
|
6613
|
+
const c0 = a1.ref.mul(b2.ref).sub(a2.ref.mul(b1.ref));
|
|
6614
|
+
const c1 = a2.mul(b0.ref).sub(a0.ref.mul(b2));
|
|
6615
|
+
const c2 = a0.mul(b1).sub(a1.mul(b0));
|
|
6616
|
+
return moveaxis$1(concatenate([
|
|
6617
|
+
c0,
|
|
6618
|
+
c1,
|
|
6619
|
+
c2
|
|
6620
|
+
], -1), -1, axisc);
|
|
6621
|
+
}, { staticArgnums: [2] });
|
|
6448
6622
|
/** Vector dot product of two arrays along a given axis. */
|
|
6449
6623
|
function vecdot(x, y, { axis } = {}) {
|
|
6450
6624
|
const xaxis = require_backend.checkAxis(axis ?? -1, ndim(x));
|
|
@@ -6541,16 +6715,15 @@ function sign(x) {
|
|
|
6541
6715
|
x = fudgeArray(x);
|
|
6542
6716
|
return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
|
|
6543
6717
|
}
|
|
6544
|
-
/** @function Return element-wise positive values of the input (no-op). */
|
|
6545
|
-
const positive = fudgeArray;
|
|
6546
6718
|
/**
|
|
6547
|
-
*
|
|
6548
|
-
*
|
|
6549
|
-
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
6719
|
+
* @function
|
|
6720
|
+
* Return the value with the magnitude of x and the sign of y, element-wise.
|
|
6550
6721
|
*/
|
|
6551
|
-
function
|
|
6552
|
-
return
|
|
6553
|
-
}
|
|
6722
|
+
const copysign = jit$1(function copysign$1(x, y) {
|
|
6723
|
+
return absolute(x).mul(sign(y));
|
|
6724
|
+
});
|
|
6725
|
+
/** @function Return element-wise positive values of the input (no-op). */
|
|
6726
|
+
const positive = fudgeArray;
|
|
6554
6727
|
/**
|
|
6555
6728
|
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
6556
6729
|
*
|
|
@@ -6696,6 +6869,27 @@ function trunc(x) {
|
|
|
6696
6869
|
return idiv(x, 1);
|
|
6697
6870
|
}
|
|
6698
6871
|
/**
|
|
6872
|
+
* @function
|
|
6873
|
+
* Round to the given number of decimals.
|
|
6874
|
+
*
|
|
6875
|
+
* Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
|
|
6876
|
+
*/
|
|
6877
|
+
const round = jit$1(function round$1(a, decimals = 0) {
|
|
6878
|
+
if (decimals === 0) return rint(a);
|
|
6879
|
+
const factor = 10 ** decimals;
|
|
6880
|
+
return rint(a.mul(factor)).mul(1 / factor);
|
|
6881
|
+
}, { staticArgnums: [1] });
|
|
6882
|
+
/**
|
|
6883
|
+
* @function
|
|
6884
|
+
* Round to the nearest integer, with ties going to the nearest even integer.
|
|
6885
|
+
*/
|
|
6886
|
+
const rint = jit$1(function rint$1(x) {
|
|
6887
|
+
const rounded = floor(x.ref.add(.5));
|
|
6888
|
+
const half = x.ref.sub(floor(x)).equal(.5);
|
|
6889
|
+
const odd = remainder(rounded.ref, 2).notEqual(0);
|
|
6890
|
+
return where(half.mul(odd), rounded.ref.sub(1), rounded);
|
|
6891
|
+
});
|
|
6892
|
+
/**
|
|
6699
6893
|
* Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
|
|
6700
6894
|
*
|
|
6701
6895
|
* This is the inverse of `frexp()`.
|
|
@@ -7023,6 +7217,7 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
|
|
|
7023
7217
|
//#region src/library/lax.ts
|
|
7024
7218
|
var lax_exports = {};
|
|
7025
7219
|
__export(lax_exports, {
|
|
7220
|
+
bitcastConvertType: () => bitcastConvertType,
|
|
7026
7221
|
conv: () => conv,
|
|
7027
7222
|
convGeneralDilated: () => convGeneralDilated,
|
|
7028
7223
|
convTranspose: () => convTranspose,
|
|
@@ -7036,6 +7231,10 @@ __export(lax_exports, {
|
|
|
7036
7231
|
topK: () => topK
|
|
7037
7232
|
});
|
|
7038
7233
|
const JsArray = globalThis.Array;
|
|
7234
|
+
/** Elementwise bitcast an array into a new dtype. */
|
|
7235
|
+
function bitcastConvertType(x, newDtype) {
|
|
7236
|
+
return fudgeArray(x).view(newDtype);
|
|
7237
|
+
}
|
|
7039
7238
|
/**
|
|
7040
7239
|
* General dot product/contraction operator.
|
|
7041
7240
|
*
|
|
@@ -7767,7 +7966,9 @@ function getK01(key$1) {
|
|
|
7767
7966
|
function key(seed) {
|
|
7768
7967
|
seed = array(seed, { dtype: require_backend.DType.Uint32 });
|
|
7769
7968
|
if (seed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seed.shape} - use jax.vmap for batching.`);
|
|
7770
|
-
|
|
7969
|
+
const key$1 = stack([0, seed]);
|
|
7970
|
+
if (key$1 instanceof Array$1) key$1._realizeSource();
|
|
7971
|
+
return key$1;
|
|
7771
7972
|
}
|
|
7772
7973
|
/** Splits a PRNG key into `num` new keys by adding a leading axis. */
|
|
7773
7974
|
function split(key$1, num = 2) {
|
|
@@ -7962,6 +8163,11 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
|
|
|
7962
8163
|
|
|
7963
8164
|
//#endregion
|
|
7964
8165
|
//#region src/index.ts
|
|
8166
|
+
/** @namespace */
|
|
8167
|
+
const profiler = {
|
|
8168
|
+
startTrace: require_backend.startTrace,
|
|
8169
|
+
stopTrace: require_backend.stopTrace
|
|
8170
|
+
};
|
|
7965
8171
|
/**
|
|
7966
8172
|
* @function
|
|
7967
8173
|
* Compute the forward-mode Jacobian-vector product for a function.
|
|
@@ -8158,6 +8364,7 @@ Object.defineProperty(exports, 'numpy', {
|
|
|
8158
8364
|
return numpy_exports;
|
|
8159
8365
|
}
|
|
8160
8366
|
});
|
|
8367
|
+
exports.profiler = profiler;
|
|
8161
8368
|
Object.defineProperty(exports, 'random', {
|
|
8162
8369
|
enumerable: true,
|
|
8163
8370
|
get: function () {
|
package/dist/index.d.cts
CHANGED
|
@@ -1004,6 +1004,8 @@ declare abstract class Tracer {
|
|
|
1004
1004
|
reshape(shape: number | number[]): this;
|
|
1005
1005
|
/** Copy the array and cast to a specified dtype. */
|
|
1006
1006
|
astype(dtype: DType): this;
|
|
1007
|
+
/** Return a bitwise cast of the array, viewed as a new dtype. */
|
|
1008
|
+
view(dtype?: DType): this;
|
|
1007
1009
|
/** Subtract an array from this one. */
|
|
1008
1010
|
sub(other: this | TracerValue): this;
|
|
1009
1011
|
/** Divide an array by this one. */
|
|
@@ -1427,8 +1429,10 @@ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
|
|
|
1427
1429
|
unitDiagonal?: boolean;
|
|
1428
1430
|
}): Array;
|
|
1429
1431
|
declare namespace lax_d_exports {
|
|
1430
|
-
export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
|
|
1432
|
+
export { DotDimensionNumbers, PaddingType, bitcastConvertType, conv, convGeneralDilated, convTranspose, convWithGeneralPadding, dot$1 as dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient, topK };
|
|
1431
1433
|
}
|
|
1434
|
+
/** Elementwise bitcast an array into a new dtype. */
|
|
1435
|
+
declare function bitcastConvertType(x: ArrayLike, newDtype: DType): Array;
|
|
1432
1436
|
/**
|
|
1433
1437
|
* Dimension numbers for general `dot()` primitive.
|
|
1434
1438
|
*
|
|
@@ -1567,7 +1571,7 @@ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
|
|
|
1567
1571
|
*/
|
|
1568
1572
|
declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
|
|
1569
1573
|
declare namespace numpy_linalg_d_exports {
|
|
1570
|
-
export { cholesky, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
|
|
1574
|
+
export { cholesky, cross$1 as cross, det, diagonal, inv, lstsq, matmul, matrixPower, matrixTranspose, outer, slogdet, solve, tensordot, trace, vecdot };
|
|
1571
1575
|
}
|
|
1572
1576
|
/**
|
|
1573
1577
|
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
@@ -1582,6 +1586,13 @@ declare function cholesky(a: ArrayLike, {
|
|
|
1582
1586
|
upper?: boolean;
|
|
1583
1587
|
symmetrizeInput?: boolean;
|
|
1584
1588
|
}): Array;
|
|
1589
|
+
/**
|
|
1590
|
+
* Compute the cross-product of two 3D vectors.
|
|
1591
|
+
*
|
|
1592
|
+
* This is a simpler and less flexible version of `jax.numpy.cross()`.
|
|
1593
|
+
* Both inputs must have size 3 along the specified axis.
|
|
1594
|
+
*/
|
|
1595
|
+
declare function cross$1(x1: ArrayLike, x2: ArrayLike, axis?: number): Array;
|
|
1585
1596
|
/** Compute the determinant of a square matrix (batched). */
|
|
1586
1597
|
declare function det(a: ArrayLike): Array;
|
|
1587
1598
|
/** Compute the inverse of a square matrix (batched). */
|
|
@@ -1668,7 +1679,7 @@ type IInfo = Readonly<{
|
|
|
1668
1679
|
/** Machine limits for integer types. */
|
|
1669
1680
|
declare function iinfo(dtype: DType): IInfo;
|
|
1670
1681
|
declare namespace numpy_d_exports {
|
|
1671
|
-
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual,
|
|
1682
|
+
export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, round as around, array, arrayEqual, arrayEquiv, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, average, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, copysign, corrcoef, correlate, cos, cosh, cov, cross, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, divmod, dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, finfo, flip, fliplr, flipud, float16, float32, float64, floor, floorDivide, fmod, frexp, full, fullLike, greater, greaterEqual, hann, heaviside, hstack, hypot, identity$1 as identity, iinfo, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, logicalAnd, logicalNot, logicalOr, logicalXor, logspace, matmul, matrixTranspose, matvec, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, nanToNum, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, rint, round, shape$1 as shape, sign, sin, sinc, sinh, size, sort, split$1 as split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vecmat, vstack, where, zeros, zerosLike };
|
|
1672
1683
|
}
|
|
1673
1684
|
declare const float32 = DType.Float32;
|
|
1674
1685
|
declare const int32 = DType.Int32;
|
|
@@ -1728,6 +1739,14 @@ declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
|
1728
1739
|
declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1729
1740
|
/** @function Compare two arrays element-wise. */
|
|
1730
1741
|
declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
1742
|
+
/** Compute element-wise logical AND. */
|
|
1743
|
+
declare function logicalAnd(x: ArrayLike, y: ArrayLike): Array;
|
|
1744
|
+
/** Compute element-wise logical OR. */
|
|
1745
|
+
declare function logicalOr(x: ArrayLike, y: ArrayLike): Array;
|
|
1746
|
+
/** Compute element-wise logical XOR. */
|
|
1747
|
+
declare function logicalXor(x: ArrayLike, y: ArrayLike): Array;
|
|
1748
|
+
/** Compute element-wise logical NOT. */
|
|
1749
|
+
declare function logicalNot(x: ArrayLike): Array;
|
|
1731
1750
|
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1732
1751
|
declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
|
|
1733
1752
|
/**
|
|
@@ -1812,6 +1831,16 @@ declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
|
1812
1831
|
declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1813
1832
|
/** Compute the average of the array elements along the specified axis. */
|
|
1814
1833
|
declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
1834
|
+
/**
|
|
1835
|
+
* Compute the weighted average along the specified axis.
|
|
1836
|
+
*
|
|
1837
|
+
* If no axis is specified, mean is computed along all the axes. The weights
|
|
1838
|
+
* should have shape matching that of `a`, or if an axis is specified, it should
|
|
1839
|
+
* match the shape along those axes.
|
|
1840
|
+
*/
|
|
1841
|
+
declare function average(a: ArrayLike, axis?: Axis, opts?: {
|
|
1842
|
+
weights?: ArrayLike;
|
|
1843
|
+
} & ReduceOpts): Array;
|
|
1815
1844
|
/**
|
|
1816
1845
|
* Returns the indices of the minimum values along an axis.
|
|
1817
1846
|
*
|
|
@@ -1983,13 +2012,39 @@ declare function argsort(a: ArrayLike, axis?: number): Array;
|
|
|
1983
2012
|
* numbered axis. By default, the flattened array is used.
|
|
1984
2013
|
*/
|
|
1985
2014
|
declare function take(a: ArrayLike, indices: ArrayLike, axis?: number | null): Array;
|
|
1986
|
-
/**
|
|
2015
|
+
/**
|
|
2016
|
+
* Return if two arrays are element-wise equal within a tolerance.
|
|
2017
|
+
*
|
|
2018
|
+
* The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
|
|
2019
|
+
* NaN values comparing equal if `equalNaN` is true.
|
|
2020
|
+
*/
|
|
1987
2021
|
declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
|
|
1988
2022
|
rtol?: number;
|
|
1989
2023
|
atol?: number;
|
|
2024
|
+
equalNaN?: boolean;
|
|
1990
2025
|
}): boolean;
|
|
2026
|
+
/**
|
|
2027
|
+
* Check if two arrays are element-wise equal.
|
|
2028
|
+
*
|
|
2029
|
+
* Returns False if the arrays have different shapes. If `equalNaN` is True,
|
|
2030
|
+
* NaNs in the same position are considered equal.
|
|
2031
|
+
*/
|
|
2032
|
+
declare function arrayEqual(a1: ArrayLike, a2: ArrayLike, opts?: {
|
|
2033
|
+
equalNaN?: boolean;
|
|
2034
|
+
}): Array;
|
|
2035
|
+
/**
|
|
2036
|
+
* Check if two arrays are element-wise equal after broadcasting.
|
|
2037
|
+
*
|
|
2038
|
+
* Unlike `arrayEqual`, this allows inputs with different but
|
|
2039
|
+
* broadcast-compatible shapes.
|
|
2040
|
+
*/
|
|
2041
|
+
declare function arrayEquiv(a1: ArrayLike, a2: ArrayLike): Array;
|
|
1991
2042
|
/** Matrix product of two arrays. */
|
|
1992
2043
|
declare function matmul(x: ArrayLike, y: ArrayLike): Array;
|
|
2044
|
+
/** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
|
|
2045
|
+
declare function matvec(x1: ArrayLike, x2: ArrayLike): Array;
|
|
2046
|
+
/** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
|
|
2047
|
+
declare function vecmat(x1: ArrayLike, x2: ArrayLike): Array;
|
|
1993
2048
|
/** Dot product of two arrays. */
|
|
1994
2049
|
declare function dot(x: ArrayLike, y: ArrayLike): Array;
|
|
1995
2050
|
/**
|
|
@@ -2042,6 +2097,18 @@ declare function inner(x: ArrayLike, y: ArrayLike): Array;
|
|
|
2042
2097
|
* be of shape `[x.size, y.size]`.
|
|
2043
2098
|
*/
|
|
2044
2099
|
declare function outer(x: ArrayLike, y: ArrayLike): Array;
|
|
2100
|
+
/**
|
|
2101
|
+
* @function Compute the cross product of two arrays.
|
|
2102
|
+
*
|
|
2103
|
+
* Supports 2D (scalar result) and 3D cross products, with optional axis
|
|
2104
|
+
* arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
|
|
2105
|
+
*/
|
|
2106
|
+
declare const cross: OwnedFunction<(a: ArrayLike, b: ArrayLike, args_2?: {
|
|
2107
|
+
axisa?: number | undefined;
|
|
2108
|
+
axisb?: number | undefined;
|
|
2109
|
+
axisc?: number | undefined;
|
|
2110
|
+
axis?: number | undefined;
|
|
2111
|
+
} | undefined) => Array>;
|
|
2045
2112
|
/** Vector dot product of two arrays along a given axis. */
|
|
2046
2113
|
declare function vecdot(x: ArrayLike, y: ArrayLike, {
|
|
2047
2114
|
axis
|
|
@@ -2087,14 +2154,13 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
|
|
|
2087
2154
|
declare function absolute(x: ArrayLike): Array;
|
|
2088
2155
|
/** Return an element-wise indication of sign of the input. */
|
|
2089
2156
|
declare function sign(x: ArrayLike): Array;
|
|
2090
|
-
/** @function Return element-wise positive values of the input (no-op). */
|
|
2091
|
-
declare const positive: (x: ArrayLike) => Array;
|
|
2092
2157
|
/**
|
|
2093
|
-
*
|
|
2094
|
-
*
|
|
2095
|
-
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
2158
|
+
* @function
|
|
2159
|
+
* Return the value with the magnitude of x and the sign of y, element-wise.
|
|
2096
2160
|
*/
|
|
2097
|
-
declare
|
|
2161
|
+
declare const copysign: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
2162
|
+
/** @function Return element-wise positive values of the input (no-op). */
|
|
2163
|
+
declare const positive: (x: ArrayLike) => Array;
|
|
2098
2164
|
/**
|
|
2099
2165
|
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
2100
2166
|
*
|
|
@@ -2189,6 +2255,18 @@ declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
|
|
|
2189
2255
|
declare function divmod(x: ArrayLike, y: ArrayLike): [Array, Array];
|
|
2190
2256
|
/** Round input to the nearest integer towards zero. */
|
|
2191
2257
|
declare function trunc(x: ArrayLike): Array;
|
|
2258
|
+
/**
|
|
2259
|
+
* @function
|
|
2260
|
+
* Round to the given number of decimals.
|
|
2261
|
+
*
|
|
2262
|
+
* Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
|
|
2263
|
+
*/
|
|
2264
|
+
declare const round: OwnedFunction<(a: ArrayLike, decimals?: number | undefined) => Array>;
|
|
2265
|
+
/**
|
|
2266
|
+
* @function
|
|
2267
|
+
* Round to the nearest integer, with ties going to the nearest even integer.
|
|
2268
|
+
*/
|
|
2269
|
+
declare const rint: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2192
2270
|
/**
|
|
2193
2271
|
* Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
|
|
2194
2272
|
*
|
|
@@ -2691,8 +2769,31 @@ declare namespace scipy_special_d_exports {
|
|
|
2691
2769
|
* The logit function, `logit(p) = log(p / (1-p))`.
|
|
2692
2770
|
*/
|
|
2693
2771
|
declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
|
|
2772
|
+
//#endregion
|
|
2773
|
+
//#region src/tracing.d.ts
|
|
2774
|
+
/**
|
|
2775
|
+
* Start collecting kernel traces.
|
|
2776
|
+
*
|
|
2777
|
+
* Traces appear in developer tools under the "Performance" tab, and they are
|
|
2778
|
+
* useful for measuring fine-grained kernel execution time.
|
|
2779
|
+
*/
|
|
2780
|
+
declare function startTrace(): void;
|
|
2781
|
+
/**
|
|
2782
|
+
* Stop collecting kernel traces.
|
|
2783
|
+
*
|
|
2784
|
+
* Traces appear in developer tools under the "Performance" tab, and they are
|
|
2785
|
+
* useful for measuring fine-grained kernel execution time.
|
|
2786
|
+
*/
|
|
2787
|
+
declare function stopTrace(): void;
|
|
2788
|
+
/** Check if tracing is currently enabled. */
|
|
2789
|
+
|
|
2694
2790
|
//#endregion
|
|
2695
2791
|
//#region src/index.d.ts
|
|
2792
|
+
/** @namespace */
|
|
2793
|
+
declare const profiler: {
|
|
2794
|
+
startTrace: typeof startTrace;
|
|
2795
|
+
stopTrace: typeof stopTrace;
|
|
2796
|
+
};
|
|
2696
2797
|
/**
|
|
2697
2798
|
* @function
|
|
2698
2799
|
* Compute the forward-mode Jacobian-vector product for a function.
|
|
@@ -2857,4 +2958,4 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
|
2857
2958
|
*/
|
|
2858
2959
|
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
2859
2960
|
//#endregion
|
|
2860
|
-
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
2961
|
+
export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, profiler, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|