@jax-js/jax 0.1.9 → 0.1.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +35 -19
- package/dist/{backend-BId79r5b.js → backend-DZvR7mZV.js} +831 -26
- package/dist/{backend-DpI0riom.cjs → backend-DlYlOYqN.cjs} +872 -25
- package/dist/index.cjs +364 -20
- package/dist/index.d.cts +175 -11
- package/dist/index.d.ts +175 -11
- package/dist/index.js +363 -21
- package/dist/{webgl-DnGrclTz.js → webgl-D8-14NzA.js} +7 -1
- package/dist/{webgl-C5NjXc1p.cjs → webgl-Ovaaa-Qx.cjs} +7 -1
- package/dist/{webgpu-AN0cG_nB.js → webgpu-Dg8FpYrH.js} +141 -6
- package/dist/{webgpu-CdjiJSa7.cjs → webgpu-uU9nnttc.cjs} +141 -6
- package/package.json +5 -16
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-DpI0riom.cjs');
|
|
33
|
+
const require_backend = require('./backend-DlYlOYqN.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -364,6 +364,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
364
364
|
Primitive$1["Mod"] = "mod";
|
|
365
365
|
Primitive$1["Min"] = "min";
|
|
366
366
|
Primitive$1["Max"] = "max";
|
|
367
|
+
Primitive$1["BitCombine"] = "bit_combine";
|
|
368
|
+
Primitive$1["BitShift"] = "bit_shift";
|
|
367
369
|
Primitive$1["Neg"] = "neg";
|
|
368
370
|
Primitive$1["Reciprocal"] = "reciprocal";
|
|
369
371
|
Primitive$1["Floor"] = "floor";
|
|
@@ -437,6 +439,12 @@ function min$1(x, y) {
|
|
|
437
439
|
function max$1(x, y) {
|
|
438
440
|
return bind1(Primitive.Max, [x, y]);
|
|
439
441
|
}
|
|
442
|
+
function bitCombine(x, y, op) {
|
|
443
|
+
return bind1(Primitive.BitCombine, [x, y], { op });
|
|
444
|
+
}
|
|
445
|
+
function bitShift(x, y, op) {
|
|
446
|
+
return bind1(Primitive.BitShift, [x, y], { op });
|
|
447
|
+
}
|
|
440
448
|
function neg(x) {
|
|
441
449
|
return bind1(Primitive.Neg, [x]);
|
|
442
450
|
}
|
|
@@ -838,6 +846,11 @@ var Tracer = class Tracer {
|
|
|
838
846
|
if (this.dtype === dtype) return this;
|
|
839
847
|
return cast(this, dtype);
|
|
840
848
|
}
|
|
849
|
+
/** Return a bitwise cast of the array, viewed as a new dtype. */
|
|
850
|
+
view(dtype) {
|
|
851
|
+
if (!dtype || dtype === this.dtype) return this;
|
|
852
|
+
return bitcast(this, dtype);
|
|
853
|
+
}
|
|
841
854
|
/** Subtract an array from this one. */
|
|
842
855
|
sub(other) {
|
|
843
856
|
return this.add(neg(other));
|
|
@@ -1650,6 +1663,16 @@ const abstractEvalRules = {
|
|
|
1650
1663
|
[Primitive.Mod]: binopAbstractEval,
|
|
1651
1664
|
[Primitive.Min]: binopAbstractEval,
|
|
1652
1665
|
[Primitive.Max]: binopAbstractEval,
|
|
1666
|
+
[Primitive.BitCombine]([x, y]) {
|
|
1667
|
+
const aval = promoteAvals(x, y);
|
|
1668
|
+
if (require_backend.isFloatDtype(aval.dtype)) throw new TypeError(`bitwise operations require integer or boolean inputs, got ${aval.dtype}`);
|
|
1669
|
+
return [aval];
|
|
1670
|
+
},
|
|
1671
|
+
[Primitive.BitShift]([x, y]) {
|
|
1672
|
+
const shape$1 = require_backend.generalBroadcast(x.shape, y.shape);
|
|
1673
|
+
if (require_backend.isFloatDtype(x.dtype) || require_backend.isFloatDtype(y.dtype) || x.dtype === require_backend.DType.Bool || y.dtype === require_backend.DType.Bool) throw new TypeError(`bit shift operations require integer inputs, got ${x} and ${y}`);
|
|
1674
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
1675
|
+
},
|
|
1653
1676
|
[Primitive.Neg]: vectorizedUnopAbstractEval,
|
|
1654
1677
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
1655
1678
|
[Primitive.Floor]: vectorizedUnopAbstractEval,
|
|
@@ -1659,7 +1682,7 @@ const abstractEvalRules = {
|
|
|
1659
1682
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1660
1683
|
},
|
|
1661
1684
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
1662
|
-
if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
1685
|
+
if (x.dtype !== dtype && (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool)) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
1663
1686
|
if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
1664
1687
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1665
1688
|
},
|
|
@@ -2185,6 +2208,8 @@ const jitRules = {
|
|
|
2185
2208
|
[Primitive.Mod]: broadcastedJit(([a, b]) => require_backend.AluExp.mod(a, b)),
|
|
2186
2209
|
[Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
|
|
2187
2210
|
[Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
|
|
2211
|
+
[Primitive.BitCombine]: broadcastedJit(([a, b], { op }) => require_backend.AluExp.bitCombine(a, b, op)),
|
|
2212
|
+
[Primitive.BitShift]: broadcastedJit(([a, b], { op }) => require_backend.AluExp.bitShift(a, b, op)),
|
|
2188
2213
|
[Primitive.Neg]: unopJit((a) => require_backend.AluExp.sub(require_backend.AluExp.const(a.dtype, 0), a)),
|
|
2189
2214
|
[Primitive.Reciprocal]: unopJit(require_backend.AluExp.reciprocal),
|
|
2190
2215
|
[Primitive.Floor]: unopJit(require_backend.AluExp.floor),
|
|
@@ -2377,7 +2402,9 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2377
2402
|
case Primitive.Idiv:
|
|
2378
2403
|
case Primitive.Mod:
|
|
2379
2404
|
case Primitive.Min:
|
|
2380
|
-
case Primitive.Max:
|
|
2405
|
+
case Primitive.Max:
|
|
2406
|
+
case Primitive.BitCombine:
|
|
2407
|
+
case Primitive.BitShift: {
|
|
2381
2408
|
const otherInput = nextEqn.inputs.find((v) => v !== outVar);
|
|
2382
2409
|
if (otherInput instanceof Lit || require_backend.deepEqual(require_backend.generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
|
|
2383
2410
|
head = usages[0];
|
|
@@ -3016,6 +3043,42 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3016
3043
|
return require_backend.dtypedArray(this.dtype, buf);
|
|
3017
3044
|
}
|
|
3018
3045
|
/**
|
|
3046
|
+
* Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
|
|
3047
|
+
*
|
|
3048
|
+
* Only available on the WebGPU backend. The array's memory is still managed
|
|
3049
|
+
* by jax-js, and it will be freed when the buffer is no longer in use. You
|
|
3050
|
+
* _should not_ mutate the buffer's contents.
|
|
3051
|
+
*
|
|
3052
|
+
* Note that the GPU buffer may be slightly larger than the array's size; it
|
|
3053
|
+
* will always be aligned to 4 bytes.
|
|
3054
|
+
*/
|
|
3055
|
+
async gpuBuffer() {
|
|
3056
|
+
if (this.device !== "webgpu") throw new Error(`gpuBuffer() is only available on WebGPU backend`);
|
|
3057
|
+
this.#realize();
|
|
3058
|
+
const pending = this.#pending;
|
|
3059
|
+
if (pending) {
|
|
3060
|
+
await Promise.all(pending.map((p) => p.prepare()));
|
|
3061
|
+
for (const p of pending) p.submit();
|
|
3062
|
+
}
|
|
3063
|
+
const backend = this.#backend;
|
|
3064
|
+
const { buffer } = backend.buffers.get(this.#source);
|
|
3065
|
+
this.dispose();
|
|
3066
|
+
return buffer;
|
|
3067
|
+
}
|
|
3068
|
+
/** Synchronous version of `Array.gpuBuffer()`. */
|
|
3069
|
+
gpuBufferSync() {
|
|
3070
|
+
if (this.device !== "webgpu") throw new Error(`gpuBufferSync() is only available on WebGPU backend`);
|
|
3071
|
+
this.#realize();
|
|
3072
|
+
for (const p of this.#pending) {
|
|
3073
|
+
p.prepareSync();
|
|
3074
|
+
p.submit();
|
|
3075
|
+
}
|
|
3076
|
+
const backend = this.#backend;
|
|
3077
|
+
const { buffer } = backend.buffers.get(this.#source);
|
|
3078
|
+
this.dispose();
|
|
3079
|
+
return buffer;
|
|
3080
|
+
}
|
|
3081
|
+
/**
|
|
3019
3082
|
* Convert this array into a JavaScript object.
|
|
3020
3083
|
*
|
|
3021
3084
|
* This is a blocking operation that will compile all of the shaders and wait
|
|
@@ -3062,6 +3125,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3062
3125
|
[Primitive.Max]([x, y]) {
|
|
3063
3126
|
return [x.#binary(require_backend.AluOp.Max, y)];
|
|
3064
3127
|
},
|
|
3128
|
+
[Primitive.BitCombine]([x, y], { op }) {
|
|
3129
|
+
const custom = (src) => require_backend.AluExp.bitCombine(src[0], src[1], op);
|
|
3130
|
+
return [Array$1.#naryCustom("bit_combine", custom, [x, y])];
|
|
3131
|
+
},
|
|
3132
|
+
[Primitive.BitShift]([x, y], { op }) {
|
|
3133
|
+
const custom = (src) => require_backend.AluExp.bitShift(src[0], src[1], op);
|
|
3134
|
+
return [Array$1.#naryCustom("bit_shift", custom, [x, y], { dtypeOverride: [void 0, y.dtype] })];
|
|
3135
|
+
},
|
|
3065
3136
|
[Primitive.Neg]([x]) {
|
|
3066
3137
|
return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
|
|
3067
3138
|
},
|
|
@@ -3081,8 +3152,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3081
3152
|
return [x.#unary(require_backend.AluOp.Cast, dtype)];
|
|
3082
3153
|
},
|
|
3083
3154
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
3084
|
-
if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
3085
3155
|
if (x.dtype === dtype) return [x];
|
|
3156
|
+
if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
3086
3157
|
if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
3087
3158
|
if (x.#source instanceof require_backend.AluExp) return [x.#unary(require_backend.AluOp.Bitcast, dtype)];
|
|
3088
3159
|
else {
|
|
@@ -3754,6 +3825,8 @@ const vmapRules = {
|
|
|
3754
3825
|
[Primitive.Mod]: broadcastBatcher(Primitive.Mod),
|
|
3755
3826
|
[Primitive.Min]: broadcastBatcher(Primitive.Min),
|
|
3756
3827
|
[Primitive.Max]: broadcastBatcher(Primitive.Max),
|
|
3828
|
+
[Primitive.BitCombine]: broadcastBatcher(Primitive.BitCombine),
|
|
3829
|
+
[Primitive.BitShift]: broadcastBatcher(Primitive.BitShift),
|
|
3757
3830
|
[Primitive.Neg]: unopBatcher(Primitive.Neg),
|
|
3758
3831
|
[Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
|
|
3759
3832
|
[Primitive.Floor]: unopBatcher(Primitive.Floor),
|
|
@@ -4077,6 +4150,8 @@ const jvpRules = {
|
|
|
4077
4150
|
[Primitive.Max]([x, y], [dx, dy]) {
|
|
4078
4151
|
return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
|
|
4079
4152
|
},
|
|
4153
|
+
[Primitive.BitCombine]: zeroTangentsJvp(Primitive.BitCombine),
|
|
4154
|
+
[Primitive.BitShift]: zeroTangentsJvp(Primitive.BitShift),
|
|
4080
4155
|
[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
|
|
4081
4156
|
[Primitive.Reciprocal]([x], [dx]) {
|
|
4082
4157
|
const xRecip = reciprocal$1(x.ref);
|
|
@@ -4179,6 +4254,7 @@ const jvpRules = {
|
|
|
4179
4254
|
},
|
|
4180
4255
|
[Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
|
|
4181
4256
|
const x = triangularSolve$1(a.ref, b, { unitDiagonal });
|
|
4257
|
+
da = unitDiagonal ? triu(da, 1) : triu(da);
|
|
4182
4258
|
const dax = batchMatmulT(da, x.ref);
|
|
4183
4259
|
const rhsT = db.sub(mT(dax));
|
|
4184
4260
|
const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
|
|
@@ -5254,6 +5330,7 @@ function ifft(a, axis = -1) {
|
|
|
5254
5330
|
var numpy_linalg_exports = {};
|
|
5255
5331
|
__export(numpy_linalg_exports, {
|
|
5256
5332
|
cholesky: () => cholesky,
|
|
5333
|
+
cross: () => cross$1,
|
|
5257
5334
|
det: () => det,
|
|
5258
5335
|
diagonal: () => diagonal,
|
|
5259
5336
|
inv: () => inv,
|
|
@@ -5266,7 +5343,8 @@ __export(numpy_linalg_exports, {
|
|
|
5266
5343
|
solve: () => solve,
|
|
5267
5344
|
tensordot: () => tensordot,
|
|
5268
5345
|
trace: () => trace,
|
|
5269
|
-
vecdot: () => vecdot
|
|
5346
|
+
vecdot: () => vecdot,
|
|
5347
|
+
vectorNorm: () => vectorNorm
|
|
5270
5348
|
});
|
|
5271
5349
|
function checkSquare(name, a) {
|
|
5272
5350
|
if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
|
|
@@ -5284,6 +5362,19 @@ function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
|
|
|
5284
5362
|
if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
|
|
5285
5363
|
return cholesky$1(a, { upper });
|
|
5286
5364
|
}
|
|
5365
|
+
/**
|
|
5366
|
+
* Compute the cross-product of two 3D vectors.
|
|
5367
|
+
*
|
|
5368
|
+
* This is a simpler and less flexible version of `jax.numpy.cross()`.
|
|
5369
|
+
* Both inputs must have size 3 along the specified axis.
|
|
5370
|
+
*/
|
|
5371
|
+
function cross$1(x1, x2, axis = -1) {
|
|
5372
|
+
const a1 = require_backend.checkAxis(axis, ndim(x1));
|
|
5373
|
+
const a2 = require_backend.checkAxis(axis, ndim(x2));
|
|
5374
|
+
if (shape(x1)[a1] !== 3) throw new Error(`linalg.cross: x1 must have size 3 along axis ${axis}, got ${shape(x1)[a1]}`);
|
|
5375
|
+
if (shape(x2)[a2] !== 3) throw new Error(`linalg.cross: x2 must have size 3 along axis ${axis}, got ${shape(x2)[a2]}`);
|
|
5376
|
+
return cross(x1, x2, { axis });
|
|
5377
|
+
}
|
|
5287
5378
|
/** Compute the determinant of a square matrix (batched). */
|
|
5288
5379
|
function det(a) {
|
|
5289
5380
|
a = fudgeArray(a);
|
|
@@ -5299,7 +5390,7 @@ function det(a) {
|
|
|
5299
5390
|
function inv(a) {
|
|
5300
5391
|
a = fudgeArray(a);
|
|
5301
5392
|
const n = checkSquare("inv", a);
|
|
5302
|
-
return solve(a, eye(n));
|
|
5393
|
+
return solve(a, eye(n, void 0, { dtype: a.dtype }));
|
|
5303
5394
|
}
|
|
5304
5395
|
/**
|
|
5305
5396
|
* Return the least-squares solution to a linear equation.
|
|
@@ -5356,8 +5447,9 @@ function matrixPower(a, n) {
|
|
|
5356
5447
|
a = fudgeArray(a);
|
|
5357
5448
|
const m = checkSquare("matrixPower", a);
|
|
5358
5449
|
if (n === 0) {
|
|
5450
|
+
const dtype = a.dtype;
|
|
5359
5451
|
a.dispose();
|
|
5360
|
-
return broadcastTo(eye(m), a.shape);
|
|
5452
|
+
return broadcastTo(eye(m, void 0, { dtype }), a.shape);
|
|
5361
5453
|
}
|
|
5362
5454
|
if (n < 0) {
|
|
5363
5455
|
a = inv(a);
|
|
@@ -5431,6 +5523,23 @@ function solve(a, b) {
|
|
|
5431
5523
|
if (bIs1d) x = squeeze(x, -1);
|
|
5432
5524
|
return x;
|
|
5433
5525
|
}
|
|
5526
|
+
/**
|
|
5527
|
+
* Compute the vector norm of an array.
|
|
5528
|
+
*
|
|
5529
|
+
* @param x - Input array.
|
|
5530
|
+
* @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
|
|
5531
|
+
* @param axis - Axis/axes to reduce over (default: all axes).
|
|
5532
|
+
* @param keepdims - Whether to keep reduced dimensions as size 1.
|
|
5533
|
+
* @returns The norm of `x`, reduced over the given axes.
|
|
5534
|
+
*/
|
|
5535
|
+
function vectorNorm(x, { ord = 2, axis = null, keepdims = false } = {}) {
|
|
5536
|
+
x = fudgeArray(x);
|
|
5537
|
+
const ax = axis ?? null;
|
|
5538
|
+
if (ord === Infinity) return max(absolute(x), ax, { keepdims });
|
|
5539
|
+
else if (ord === -Infinity) return min(absolute(x), ax, { keepdims });
|
|
5540
|
+
else if (ord === 0) return x.notEqual(0).astype(x.dtype).sum(ax, { keepdims });
|
|
5541
|
+
else return power(power(absolute(x), ord).sum(ax, { keepdims }), 1 / ord);
|
|
5542
|
+
}
|
|
5434
5543
|
|
|
5435
5544
|
//#endregion
|
|
5436
5545
|
//#region src/library/numpy/dtype-info.ts
|
|
@@ -5539,13 +5648,24 @@ __export(numpy_exports, {
|
|
|
5539
5648
|
argmax: () => argmax,
|
|
5540
5649
|
argmin: () => argmin,
|
|
5541
5650
|
argsort: () => argsort,
|
|
5651
|
+
around: () => round,
|
|
5542
5652
|
array: () => array,
|
|
5653
|
+
arrayEqual: () => arrayEqual,
|
|
5654
|
+
arrayEquiv: () => arrayEquiv,
|
|
5543
5655
|
asin: () => asin,
|
|
5544
5656
|
asinh: () => arcsinh,
|
|
5545
5657
|
astype: () => astype,
|
|
5546
5658
|
atan: () => atan,
|
|
5547
5659
|
atan2: () => atan2,
|
|
5548
5660
|
atanh: () => arctanh,
|
|
5661
|
+
average: () => average,
|
|
5662
|
+
bitwiseAnd: () => bitwiseAnd,
|
|
5663
|
+
bitwiseInvert: () => invert,
|
|
5664
|
+
bitwiseLeftShift: () => leftShift,
|
|
5665
|
+
bitwiseNot: () => invert,
|
|
5666
|
+
bitwiseOr: () => bitwiseOr,
|
|
5667
|
+
bitwiseRightShift: () => rightShift,
|
|
5668
|
+
bitwiseXor: () => bitwiseXor,
|
|
5549
5669
|
bool: () => bool,
|
|
5550
5670
|
broadcastArrays: () => broadcastArrays,
|
|
5551
5671
|
broadcastShapes: () => broadcastShapes,
|
|
@@ -5556,11 +5676,13 @@ __export(numpy_exports, {
|
|
|
5556
5676
|
columnStack: () => columnStack,
|
|
5557
5677
|
concatenate: () => concatenate,
|
|
5558
5678
|
convolve: () => convolve,
|
|
5679
|
+
copysign: () => copysign,
|
|
5559
5680
|
corrcoef: () => corrcoef,
|
|
5560
5681
|
correlate: () => correlate,
|
|
5561
5682
|
cos: () => cos,
|
|
5562
5683
|
cosh: () => cosh,
|
|
5563
5684
|
cov: () => cov,
|
|
5685
|
+
cross: () => cross,
|
|
5564
5686
|
cumsum: () => cumsum,
|
|
5565
5687
|
cumulativeSum: () => cumsum,
|
|
5566
5688
|
deg2rad: () => deg2rad,
|
|
@@ -5596,7 +5718,6 @@ __export(numpy_exports, {
|
|
|
5596
5718
|
fullLike: () => fullLike$1,
|
|
5597
5719
|
greater: () => greater,
|
|
5598
5720
|
greaterEqual: () => greaterEqual,
|
|
5599
|
-
hamming: () => hamming,
|
|
5600
5721
|
hann: () => hann,
|
|
5601
5722
|
heaviside: () => heaviside,
|
|
5602
5723
|
hstack: () => hstack,
|
|
@@ -5606,12 +5727,14 @@ __export(numpy_exports, {
|
|
|
5606
5727
|
inf: () => inf,
|
|
5607
5728
|
inner: () => inner,
|
|
5608
5729
|
int32: () => int32,
|
|
5730
|
+
invert: () => invert,
|
|
5609
5731
|
isfinite: () => isfinite,
|
|
5610
5732
|
isinf: () => isinf,
|
|
5611
5733
|
isnan: () => isnan,
|
|
5612
5734
|
isneginf: () => isneginf,
|
|
5613
5735
|
isposinf: () => isposinf,
|
|
5614
5736
|
ldexp: () => ldexp,
|
|
5737
|
+
leftShift: () => leftShift,
|
|
5615
5738
|
less: () => less,
|
|
5616
5739
|
lessEqual: () => lessEqual,
|
|
5617
5740
|
linalg: () => numpy_linalg_exports,
|
|
@@ -5620,9 +5743,14 @@ __export(numpy_exports, {
|
|
|
5620
5743
|
log10: () => log10,
|
|
5621
5744
|
log1p: () => log1p,
|
|
5622
5745
|
log2: () => log2,
|
|
5746
|
+
logicalAnd: () => logicalAnd,
|
|
5747
|
+
logicalNot: () => logicalNot,
|
|
5748
|
+
logicalOr: () => logicalOr,
|
|
5749
|
+
logicalXor: () => logicalXor,
|
|
5623
5750
|
logspace: () => logspace,
|
|
5624
5751
|
matmul: () => matmul,
|
|
5625
5752
|
matrixTranspose: () => matrixTranspose,
|
|
5753
|
+
matvec: () => matvec,
|
|
5626
5754
|
max: () => max,
|
|
5627
5755
|
maximum: () => maximum,
|
|
5628
5756
|
mean: () => mean,
|
|
@@ -5655,6 +5783,9 @@ __export(numpy_exports, {
|
|
|
5655
5783
|
remainder: () => remainder,
|
|
5656
5784
|
repeat: () => repeat,
|
|
5657
5785
|
reshape: () => reshape,
|
|
5786
|
+
rightShift: () => rightShift,
|
|
5787
|
+
rint: () => rint,
|
|
5788
|
+
round: () => round,
|
|
5658
5789
|
shape: () => shape,
|
|
5659
5790
|
sign: () => sign,
|
|
5660
5791
|
sin: () => sin,
|
|
@@ -5687,6 +5818,7 @@ __export(numpy_exports, {
|
|
|
5687
5818
|
var_: () => var_,
|
|
5688
5819
|
vdot: () => vdot,
|
|
5689
5820
|
vecdot: () => vecdot,
|
|
5821
|
+
vecmat: () => vecmat,
|
|
5690
5822
|
vstack: () => vstack,
|
|
5691
5823
|
where: () => where,
|
|
5692
5824
|
zeros: () => zeros,
|
|
@@ -5750,6 +5882,60 @@ const notEqual = notEqual$1;
|
|
|
5750
5882
|
const greaterEqual = greaterEqual$1;
|
|
5751
5883
|
/** @function Compare two arrays element-wise. */
|
|
5752
5884
|
const lessEqual = lessEqual$1;
|
|
5885
|
+
/** Compute element-wise logical AND. */
|
|
5886
|
+
function logicalAnd(x, y) {
|
|
5887
|
+
return astype(x, require_backend.DType.Bool).mul(astype(y, require_backend.DType.Bool));
|
|
5888
|
+
}
|
|
5889
|
+
/** Compute element-wise logical OR. */
|
|
5890
|
+
function logicalOr(x, y) {
|
|
5891
|
+
return astype(x, require_backend.DType.Bool).add(astype(y, require_backend.DType.Bool));
|
|
5892
|
+
}
|
|
5893
|
+
/** Compute element-wise logical XOR. */
|
|
5894
|
+
function logicalXor(x, y) {
|
|
5895
|
+
return notEqual(astype(x, require_backend.DType.Bool), astype(y, require_backend.DType.Bool));
|
|
5896
|
+
}
|
|
5897
|
+
/** Compute element-wise logical NOT. */
|
|
5898
|
+
function logicalNot(x) {
|
|
5899
|
+
return notEqual(astype(x, require_backend.DType.Bool), true);
|
|
5900
|
+
}
|
|
5901
|
+
/** Compute element-wise bitwise AND. */
|
|
5902
|
+
function bitwiseAnd(x, y) {
|
|
5903
|
+
return bitCombine(x, y, "and");
|
|
5904
|
+
}
|
|
5905
|
+
/** Compute element-wise bitwise OR. */
|
|
5906
|
+
function bitwiseOr(x, y) {
|
|
5907
|
+
return bitCombine(x, y, "or");
|
|
5908
|
+
}
|
|
5909
|
+
/** Compute element-wise bitwise XOR. */
|
|
5910
|
+
function bitwiseXor(x, y) {
|
|
5911
|
+
return bitCombine(x, y, "xor");
|
|
5912
|
+
}
|
|
5913
|
+
/** Compute element-wise bitwise NOT (inversion). */
|
|
5914
|
+
function invert(x) {
|
|
5915
|
+
const arr = fudgeArray(x);
|
|
5916
|
+
let allOnes;
|
|
5917
|
+
switch (arr.dtype) {
|
|
5918
|
+
case require_backend.DType.Bool:
|
|
5919
|
+
allOnes = true;
|
|
5920
|
+
break;
|
|
5921
|
+
case require_backend.DType.Uint32:
|
|
5922
|
+
allOnes = 4294967295;
|
|
5923
|
+
break;
|
|
5924
|
+
case require_backend.DType.Int32:
|
|
5925
|
+
allOnes = -1;
|
|
5926
|
+
break;
|
|
5927
|
+
default: throw new TypeError(`invert: unsupported dtype ${arr.dtype}`);
|
|
5928
|
+
}
|
|
5929
|
+
return bitCombine(arr, allOnes, "xor");
|
|
5930
|
+
}
|
|
5931
|
+
/** Compute element-wise left bit shift. */
|
|
5932
|
+
function leftShift(x, y) {
|
|
5933
|
+
return bitShift(x, y, "shl");
|
|
5934
|
+
}
|
|
5935
|
+
/** Compute element-wise right bit shift. */
|
|
5936
|
+
function rightShift(x, y) {
|
|
5937
|
+
return bitShift(x, y, "shr");
|
|
5938
|
+
}
|
|
5753
5939
|
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
5754
5940
|
const where = where$1;
|
|
5755
5941
|
/**
|
|
@@ -5857,6 +6043,34 @@ function mean(a, axis = null, opts) {
|
|
|
5857
6043
|
return fudgeArray(a).mean(axis, opts);
|
|
5858
6044
|
}
|
|
5859
6045
|
/**
|
|
6046
|
+
* Compute the weighted average along the specified axis.
|
|
6047
|
+
*
|
|
6048
|
+
* If no axis is specified, mean is computed along all the axes. The weights
|
|
6049
|
+
* should have shape matching that of `a`, or if an axis is specified, it should
|
|
6050
|
+
* match the shape along those axes.
|
|
6051
|
+
*/
|
|
6052
|
+
function average(a, axis = null, opts) {
|
|
6053
|
+
a = fudgeArray(a);
|
|
6054
|
+
if (opts?.weights == null) return mean(a, axis, opts);
|
|
6055
|
+
const weights = fudgeArray(opts.weights);
|
|
6056
|
+
axis = require_backend.normalizeAxis(axis, ndim(a));
|
|
6057
|
+
const wShape = weights.shape;
|
|
6058
|
+
const aShape = a.shape;
|
|
6059
|
+
if (require_backend.deepEqual(wShape, aShape)) {
|
|
6060
|
+
const scl = sum(weights.ref, axis, opts);
|
|
6061
|
+
return sum(multiply(a, weights), axis, opts).div(scl);
|
|
6062
|
+
} else if (axis.length === 1 && wShape.length === 1 && wShape[0] === aShape[axis[0]]) {
|
|
6063
|
+
const broadcastShape = aShape.map((_, i) => i === axis[0] ? wShape[0] : 1);
|
|
6064
|
+
const wReshaped = reshape(weights, broadcastShape);
|
|
6065
|
+
const scl = sum(wReshaped.ref, axis, opts);
|
|
6066
|
+
return sum(multiply(a, wReshaped), axis, opts).div(scl);
|
|
6067
|
+
} else {
|
|
6068
|
+
weights.dispose();
|
|
6069
|
+
a.dispose();
|
|
6070
|
+
throw new Error(`average: weights shape ${JSON.stringify(wShape)} is not compatible with array shape ${JSON.stringify(aShape)} and axis ${JSON.stringify(axis)}`);
|
|
6071
|
+
}
|
|
6072
|
+
}
|
|
6073
|
+
/**
|
|
5860
6074
|
* Returns the indices of the minimum values along an axis.
|
|
5861
6075
|
*
|
|
5862
6076
|
* By default, index is into the flatted array, otherwise it is along the
|
|
@@ -6260,20 +6474,63 @@ function take(a, indices, axis = null) {
|
|
|
6260
6474
|
axis = require_backend.checkAxis(axis, ndim(a));
|
|
6261
6475
|
return gather(a, [indices], [axis], axis);
|
|
6262
6476
|
}
|
|
6263
|
-
/**
|
|
6477
|
+
/**
|
|
6478
|
+
* Return if two arrays are element-wise equal within a tolerance.
|
|
6479
|
+
*
|
|
6480
|
+
* The formula used is `|actual - expected| <= atol + rtol * |expected|`, with
|
|
6481
|
+
* NaN values comparing equal if `equalNaN` is true.
|
|
6482
|
+
*/
|
|
6264
6483
|
function allclose(actual, expected, options) {
|
|
6265
|
-
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
6484
|
+
const { rtol = 1e-5, atol = 1e-7, equalNaN = false } = options ?? {};
|
|
6266
6485
|
const x = array(actual);
|
|
6267
6486
|
const y = array(expected);
|
|
6268
6487
|
if (!require_backend.deepEqual(x.shape, y.shape)) return false;
|
|
6269
6488
|
const xData = x.dataSync();
|
|
6270
6489
|
const yData = y.dataSync();
|
|
6271
6490
|
for (let i = 0; i < xData.length; i++) {
|
|
6272
|
-
if (isNaN(xData[i]) !== isNaN(yData[i])) return false;
|
|
6491
|
+
if (equalNaN ? isNaN(xData[i]) !== isNaN(yData[i]) : isNaN(xData[i]) || isNaN(yData[i])) return false;
|
|
6273
6492
|
if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
|
|
6274
6493
|
}
|
|
6275
6494
|
return true;
|
|
6276
6495
|
}
|
|
6496
|
+
/**
|
|
6497
|
+
* Check if two arrays are element-wise equal.
|
|
6498
|
+
*
|
|
6499
|
+
* Returns False if the arrays have different shapes. If `equalNaN` is True,
|
|
6500
|
+
* NaNs in the same position are considered equal.
|
|
6501
|
+
*/
|
|
6502
|
+
function arrayEqual(a1, a2, opts) {
|
|
6503
|
+
a1 = fudgeArray(a1);
|
|
6504
|
+
a2 = fudgeArray(a2);
|
|
6505
|
+
if (!require_backend.deepEqual(a1.shape, a2.shape)) {
|
|
6506
|
+
a1.dispose();
|
|
6507
|
+
a2.dispose();
|
|
6508
|
+
return array(false);
|
|
6509
|
+
}
|
|
6510
|
+
if (opts?.equalNaN) {
|
|
6511
|
+
const nanMask = isnan(a1.ref).mul(isnan(a2.ref));
|
|
6512
|
+
return where(nanMask, true, equal(a1, a2)).all();
|
|
6513
|
+
}
|
|
6514
|
+
return equal(a1, a2).all();
|
|
6515
|
+
}
|
|
6516
|
+
/**
|
|
6517
|
+
* Check if two arrays are element-wise equal after broadcasting.
|
|
6518
|
+
*
|
|
6519
|
+
* Unlike `arrayEqual`, this allows inputs with different but
|
|
6520
|
+
* broadcast-compatible shapes.
|
|
6521
|
+
*/
|
|
6522
|
+
function arrayEquiv(a1, a2) {
|
|
6523
|
+
a1 = fudgeArray(a1);
|
|
6524
|
+
a2 = fudgeArray(a2);
|
|
6525
|
+
try {
|
|
6526
|
+
const [b1, b2] = broadcastArrays(a1, a2);
|
|
6527
|
+
return equal(b1, b2).all();
|
|
6528
|
+
} catch {
|
|
6529
|
+
a1.dispose();
|
|
6530
|
+
a2.dispose();
|
|
6531
|
+
return array(false);
|
|
6532
|
+
}
|
|
6533
|
+
}
|
|
6277
6534
|
/** Matrix product of two arrays. */
|
|
6278
6535
|
function matmul(x, y) {
|
|
6279
6536
|
if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
|
|
@@ -6287,6 +6544,16 @@ function matmul(x, y) {
|
|
|
6287
6544
|
rhsBatchDims: require_backend.range(-2 - numBatchDims, -2)
|
|
6288
6545
|
});
|
|
6289
6546
|
}
|
|
6547
|
+
/** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
|
|
6548
|
+
function matvec(x1, x2) {
|
|
6549
|
+
if (ndim(x1) < 2 || ndim(x2) < 1) throw new Error("matvec: x1 must be at least 2D and x2 at least 1D");
|
|
6550
|
+
return einsum("...mn,...n->...m", x1, x2);
|
|
6551
|
+
}
|
|
6552
|
+
/** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
|
|
6553
|
+
function vecmat(x1, x2) {
|
|
6554
|
+
if (ndim(x1) < 1 || ndim(x2) < 2) throw new Error("vecmat: x1 must be at least 1D and x2 at least 2D");
|
|
6555
|
+
return einsum("...n,...nm->...m", x1, x2);
|
|
6556
|
+
}
|
|
6290
6557
|
/** Dot product of two arrays. */
|
|
6291
6558
|
function dot$1(x, y) {
|
|
6292
6559
|
if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
|
|
@@ -6445,6 +6712,49 @@ function outer(x, y) {
|
|
|
6445
6712
|
y = ravel(y);
|
|
6446
6713
|
return multiply(x.reshape([x.shape[0], 1]), y);
|
|
6447
6714
|
}
|
|
6715
|
+
/**
|
|
6716
|
+
* @function Compute the cross product of two arrays.
|
|
6717
|
+
*
|
|
6718
|
+
* Supports 2D (scalar result) and 3D cross products, with optional axis
|
|
6719
|
+
* arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
|
|
6720
|
+
*/
|
|
6721
|
+
const cross = jit$1(function cross$2(a, b, { axisa = -1, axisb = -1, axisc = -1, axis } = {}) {
|
|
6722
|
+
if (axis !== void 0) {
|
|
6723
|
+
axisa = axis;
|
|
6724
|
+
axisb = axis;
|
|
6725
|
+
axisc = axis;
|
|
6726
|
+
}
|
|
6727
|
+
axisa = require_backend.checkAxis(axisa, ndim(a));
|
|
6728
|
+
axisb = require_backend.checkAxis(axisb, ndim(b));
|
|
6729
|
+
a = moveaxis$1(a, axisa, -1);
|
|
6730
|
+
b = moveaxis$1(b, axisb, -1);
|
|
6731
|
+
const da = a.shape.at(-1);
|
|
6732
|
+
const db = b.shape.at(-1);
|
|
6733
|
+
if (da !== 2 && da !== 3 || db !== 2 && db !== 3) throw new Error(`cross: incompatible dimensions for cross product (got ${da} and ${db})`);
|
|
6734
|
+
if (da === 2 && db === 2) {
|
|
6735
|
+
const [a0$1, a1$1] = split$1(a, 2, -1);
|
|
6736
|
+
const [b0$1, b1$1] = split$1(b, 2, -1);
|
|
6737
|
+
return squeeze(a0$1.mul(b1$1).sub(a1$1.mul(b0$1)), -1);
|
|
6738
|
+
}
|
|
6739
|
+
if (da === 2) {
|
|
6740
|
+
const zeroShape = [...a.shape.slice(0, -1), 1];
|
|
6741
|
+
a = concatenate([a, zeros(zeroShape)], -1);
|
|
6742
|
+
}
|
|
6743
|
+
if (db === 2) {
|
|
6744
|
+
const zeroShape = [...b.shape.slice(0, -1), 1];
|
|
6745
|
+
b = concatenate([b, zeros(zeroShape)], -1);
|
|
6746
|
+
}
|
|
6747
|
+
const [a0, a1, a2] = split$1(a, 3, -1);
|
|
6748
|
+
const [b0, b1, b2] = split$1(b, 3, -1);
|
|
6749
|
+
const c0 = a1.ref.mul(b2.ref).sub(a2.ref.mul(b1.ref));
|
|
6750
|
+
const c1 = a2.mul(b0.ref).sub(a0.ref.mul(b2));
|
|
6751
|
+
const c2 = a0.mul(b1).sub(a1.mul(b0));
|
|
6752
|
+
return moveaxis$1(concatenate([
|
|
6753
|
+
c0,
|
|
6754
|
+
c1,
|
|
6755
|
+
c2
|
|
6756
|
+
], -1), -1, axisc);
|
|
6757
|
+
}, { staticArgnums: [2] });
|
|
6448
6758
|
/** Vector dot product of two arrays along a given axis. */
|
|
6449
6759
|
function vecdot(x, y, { axis } = {}) {
|
|
6450
6760
|
const xaxis = require_backend.checkAxis(axis ?? -1, ndim(x));
|
|
@@ -6541,16 +6851,15 @@ function sign(x) {
|
|
|
6541
6851
|
x = fudgeArray(x);
|
|
6542
6852
|
return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
|
|
6543
6853
|
}
|
|
6544
|
-
/** @function Return element-wise positive values of the input (no-op). */
|
|
6545
|
-
const positive = fudgeArray;
|
|
6546
6854
|
/**
|
|
6547
|
-
*
|
|
6548
|
-
*
|
|
6549
|
-
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
6855
|
+
* @function
|
|
6856
|
+
* Return the value with the magnitude of x and the sign of y, element-wise.
|
|
6550
6857
|
*/
|
|
6551
|
-
function
|
|
6552
|
-
return
|
|
6553
|
-
}
|
|
6858
|
+
const copysign = jit$1(function copysign$1(x, y) {
|
|
6859
|
+
return absolute(x).mul(sign(y));
|
|
6860
|
+
});
|
|
6861
|
+
/** @function Return element-wise positive values of the input (no-op). */
|
|
6862
|
+
const positive = fudgeArray;
|
|
6554
6863
|
/**
|
|
6555
6864
|
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
6556
6865
|
*
|
|
@@ -6696,6 +7005,27 @@ function trunc(x) {
|
|
|
6696
7005
|
return idiv(x, 1);
|
|
6697
7006
|
}
|
|
6698
7007
|
/**
|
|
7008
|
+
* @function
|
|
7009
|
+
* Round to the given number of decimals.
|
|
7010
|
+
*
|
|
7011
|
+
* Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
|
|
7012
|
+
*/
|
|
7013
|
+
const round = jit$1(function round$1(a, decimals = 0) {
|
|
7014
|
+
if (decimals === 0) return rint(a);
|
|
7015
|
+
const factor = 10 ** decimals;
|
|
7016
|
+
return rint(a.mul(factor)).mul(1 / factor);
|
|
7017
|
+
}, { staticArgnums: [1] });
|
|
7018
|
+
/**
|
|
7019
|
+
* @function
|
|
7020
|
+
* Round to the nearest integer, with ties going to the nearest even integer.
|
|
7021
|
+
*/
|
|
7022
|
+
const rint = jit$1(function rint$1(x) {
|
|
7023
|
+
const rounded = floor(x.ref.add(.5));
|
|
7024
|
+
const half = x.ref.sub(floor(x)).equal(.5);
|
|
7025
|
+
const odd = remainder(rounded.ref, 2).notEqual(0);
|
|
7026
|
+
return where(half.mul(odd), rounded.ref.sub(1), rounded);
|
|
7027
|
+
});
|
|
7028
|
+
/**
|
|
6699
7029
|
* Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
|
|
6700
7030
|
*
|
|
6701
7031
|
* This is the inverse of `frexp()`.
|
|
@@ -7023,6 +7353,7 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
|
|
|
7023
7353
|
//#region src/library/lax.ts
|
|
7024
7354
|
var lax_exports = {};
|
|
7025
7355
|
__export(lax_exports, {
|
|
7356
|
+
bitcastConvertType: () => bitcastConvertType,
|
|
7026
7357
|
conv: () => conv,
|
|
7027
7358
|
convGeneralDilated: () => convGeneralDilated,
|
|
7028
7359
|
convTranspose: () => convTranspose,
|
|
@@ -7036,6 +7367,10 @@ __export(lax_exports, {
|
|
|
7036
7367
|
topK: () => topK
|
|
7037
7368
|
});
|
|
7038
7369
|
const JsArray = globalThis.Array;
|
|
7370
|
+
/** Elementwise bitcast an array into a new dtype. */
|
|
7371
|
+
function bitcastConvertType(x, newDtype) {
|
|
7372
|
+
return fudgeArray(x).view(newDtype);
|
|
7373
|
+
}
|
|
7039
7374
|
/**
|
|
7040
7375
|
* General dot product/contraction operator.
|
|
7041
7376
|
*
|
|
@@ -7767,7 +8102,9 @@ function getK01(key$1) {
|
|
|
7767
8102
|
function key(seed) {
|
|
7768
8103
|
seed = array(seed, { dtype: require_backend.DType.Uint32 });
|
|
7769
8104
|
if (seed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seed.shape} - use jax.vmap for batching.`);
|
|
7770
|
-
|
|
8105
|
+
const key$1 = stack([0, seed]);
|
|
8106
|
+
if (key$1 instanceof Array$1) key$1._realizeSource();
|
|
8107
|
+
return key$1;
|
|
7771
8108
|
}
|
|
7772
8109
|
/** Splits a PRNG key into `num` new keys by adding a leading axis. */
|
|
7773
8110
|
function split(key$1, num = 2) {
|
|
@@ -7962,6 +8299,11 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
|
|
|
7962
8299
|
|
|
7963
8300
|
//#endregion
|
|
7964
8301
|
//#region src/index.ts
|
|
8302
|
+
/** @namespace */
|
|
8303
|
+
const profiler = {
|
|
8304
|
+
startTrace: require_backend.startTrace,
|
|
8305
|
+
stopTrace: require_backend.stopTrace
|
|
8306
|
+
};
|
|
7965
8307
|
/**
|
|
7966
8308
|
* @function
|
|
7967
8309
|
* Compute the forward-mode Jacobian-vector product for a function.
|
|
@@ -8130,6 +8472,7 @@ exports.blockUntilReady = blockUntilReady;
|
|
|
8130
8472
|
exports.defaultDevice = require_backend.defaultDevice;
|
|
8131
8473
|
exports.devicePut = devicePut;
|
|
8132
8474
|
exports.devices = require_backend.devices;
|
|
8475
|
+
exports.getWebGPUDevice = require_backend.getWebGPUDevice;
|
|
8133
8476
|
exports.grad = grad;
|
|
8134
8477
|
exports.hessian = hessian;
|
|
8135
8478
|
exports.init = require_backend.init;
|
|
@@ -8158,6 +8501,7 @@ Object.defineProperty(exports, 'numpy', {
|
|
|
8158
8501
|
return numpy_exports;
|
|
8159
8502
|
}
|
|
8160
8503
|
});
|
|
8504
|
+
exports.profiler = profiler;
|
|
8161
8505
|
Object.defineProperty(exports, 'random', {
|
|
8162
8506
|
enumerable: true,
|
|
8163
8507
|
get: function () {
|