@jax-js/jax 0.1.9 → 0.1.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +35 -19
- package/dist/{backend-BId79r5b.js → backend-DZvR7mZV.js} +831 -26
- package/dist/{backend-DpI0riom.cjs → backend-DlYlOYqN.cjs} +872 -25
- package/dist/index.cjs +364 -20
- package/dist/index.d.cts +175 -11
- package/dist/index.d.ts +175 -11
- package/dist/index.js +363 -21
- package/dist/{webgl-DnGrclTz.js → webgl-D8-14NzA.js} +7 -1
- package/dist/{webgl-C5NjXc1p.cjs → webgl-Ovaaa-Qx.cjs} +7 -1
- package/dist/{webgpu-AN0cG_nB.js → webgpu-Dg8FpYrH.js} +141 -6
- package/dist/{webgpu-CdjiJSa7.cjs → webgpu-uU9nnttc.cjs} +141 -6
- package/package.json +5 -16
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DZvR7mZV.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -333,6 +333,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
333
333
|
Primitive$1["Mod"] = "mod";
|
|
334
334
|
Primitive$1["Min"] = "min";
|
|
335
335
|
Primitive$1["Max"] = "max";
|
|
336
|
+
Primitive$1["BitCombine"] = "bit_combine";
|
|
337
|
+
Primitive$1["BitShift"] = "bit_shift";
|
|
336
338
|
Primitive$1["Neg"] = "neg";
|
|
337
339
|
Primitive$1["Reciprocal"] = "reciprocal";
|
|
338
340
|
Primitive$1["Floor"] = "floor";
|
|
@@ -406,6 +408,12 @@ function min$1(x, y) {
|
|
|
406
408
|
function max$1(x, y) {
	const operands = [x, y];
	return bind1(Primitive.Max, operands);
}
/** Bind the `BitCombine` primitive; `op` names the bitwise combination (e.g. "and", "or", "xor"). */
function bitCombine(x, y, op) {
	const params = { op };
	return bind1(Primitive.BitCombine, [x, y], params);
}
/** Bind the `BitShift` primitive; `op` names the shift direction (e.g. "shl", "shr"). */
function bitShift(x, y, op) {
	const params = { op };
	return bind1(Primitive.BitShift, [x, y], params);
}
function neg(x) {
	const operands = [x];
	return bind1(Primitive.Neg, operands);
}
|
|
@@ -807,6 +815,11 @@ var Tracer = class Tracer {
|
|
|
807
815
|
if (this.dtype === dtype) return this;
|
|
808
816
|
return cast(this, dtype);
|
|
809
817
|
}
|
|
818
|
+
/** Return a bitwise cast of the array, viewed as a new dtype. */
|
|
819
|
+
view(dtype) {
|
|
820
|
+
if (!dtype || dtype === this.dtype) return this;
|
|
821
|
+
return bitcast(this, dtype);
|
|
822
|
+
}
|
|
810
823
|
/** Subtract an array from this one. */
|
|
811
824
|
sub(other) {
|
|
812
825
|
return this.add(neg(other));
|
|
@@ -1615,6 +1628,16 @@ const abstractEvalRules = {
|
|
|
1615
1628
|
[Primitive.Mod]: binopAbstractEval,
|
|
1616
1629
|
[Primitive.Min]: binopAbstractEval,
|
|
1617
1630
|
[Primitive.Max]: binopAbstractEval,
|
|
1631
|
+
[Primitive.BitCombine]([x, y]) {
|
|
1632
|
+
const aval = promoteAvals(x, y);
|
|
1633
|
+
if (isFloatDtype(aval.dtype)) throw new TypeError(`bitwise operations require integer or boolean inputs, got ${aval.dtype}`);
|
|
1634
|
+
return [aval];
|
|
1635
|
+
},
|
|
1636
|
+
[Primitive.BitShift]([x, y]) {
|
|
1637
|
+
const shape$1 = generalBroadcast(x.shape, y.shape);
|
|
1638
|
+
if (isFloatDtype(x.dtype) || isFloatDtype(y.dtype) || x.dtype === DType.Bool || y.dtype === DType.Bool) throw new TypeError(`bit shift operations require integer inputs, got ${x} and ${y}`);
|
|
1639
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
1640
|
+
},
|
|
1618
1641
|
[Primitive.Neg]: vectorizedUnopAbstractEval,
|
|
1619
1642
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
1620
1643
|
[Primitive.Floor]: vectorizedUnopAbstractEval,
|
|
@@ -1624,7 +1647,7 @@ const abstractEvalRules = {
|
|
|
1624
1647
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1625
1648
|
},
|
|
1626
1649
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
1627
|
-
if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
1650
|
+
if (x.dtype !== dtype && (x.dtype === DType.Bool || dtype === DType.Bool)) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
1628
1651
|
if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
1629
1652
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1630
1653
|
},
|
|
@@ -2150,6 +2173,8 @@ const jitRules = {
|
|
|
2150
2173
|
[Primitive.Mod]: broadcastedJit(([a, b]) => AluExp.mod(a, b)),
|
|
2151
2174
|
[Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
|
|
2152
2175
|
[Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
|
|
2176
|
+
[Primitive.BitCombine]: broadcastedJit(([a, b], { op }) => AluExp.bitCombine(a, b, op)),
|
|
2177
|
+
[Primitive.BitShift]: broadcastedJit(([a, b], { op }) => AluExp.bitShift(a, b, op)),
|
|
2153
2178
|
[Primitive.Neg]: unopJit((a) => AluExp.sub(AluExp.const(a.dtype, 0), a)),
|
|
2154
2179
|
[Primitive.Reciprocal]: unopJit(AluExp.reciprocal),
|
|
2155
2180
|
[Primitive.Floor]: unopJit(AluExp.floor),
|
|
@@ -2342,7 +2367,9 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2342
2367
|
case Primitive.Idiv:
|
|
2343
2368
|
case Primitive.Mod:
|
|
2344
2369
|
case Primitive.Min:
|
|
2345
|
-
case Primitive.Max:
|
|
2370
|
+
case Primitive.Max:
|
|
2371
|
+
case Primitive.BitCombine:
|
|
2372
|
+
case Primitive.BitShift: {
|
|
2346
2373
|
const otherInput = nextEqn.inputs.find((v) => v !== outVar);
|
|
2347
2374
|
if (otherInput instanceof Lit || deepEqual(generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
|
|
2348
2375
|
head = usages[0];
|
|
@@ -2981,6 +3008,42 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2981
3008
|
return dtypedArray(this.dtype, buf);
|
|
2982
3009
|
}
|
|
2983
3010
|
/**
|
|
3011
|
+
* Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
|
|
3012
|
+
*
|
|
3013
|
+
* Only available on the WebGPU backend. The array's memory is still managed
|
|
3014
|
+
* by jax-js, and it will be freed when the buffer is no longer in use. You
|
|
3015
|
+
* _should not_ mutate the buffer's contents.
|
|
3016
|
+
*
|
|
3017
|
+
* Note that the GPU buffer may be slightly larger than the array's size; it
|
|
3018
|
+
* will always be aligned to 4 bytes.
|
|
3019
|
+
*/
|
|
3020
|
+
async gpuBuffer() {
|
|
3021
|
+
if (this.device !== "webgpu") throw new Error(`gpuBuffer() is only available on WebGPU backend`);
|
|
3022
|
+
this.#realize();
|
|
3023
|
+
const pending = this.#pending;
|
|
3024
|
+
if (pending) {
|
|
3025
|
+
await Promise.all(pending.map((p) => p.prepare()));
|
|
3026
|
+
for (const p of pending) p.submit();
|
|
3027
|
+
}
|
|
3028
|
+
const backend = this.#backend;
|
|
3029
|
+
const { buffer } = backend.buffers.get(this.#source);
|
|
3030
|
+
this.dispose();
|
|
3031
|
+
return buffer;
|
|
3032
|
+
}
|
|
3033
|
+
/** Synchronous version of `Array.gpuBuffer()`. */
|
|
3034
|
+
gpuBufferSync() {
|
|
3035
|
+
if (this.device !== "webgpu") throw new Error(`gpuBufferSync() is only available on WebGPU backend`);
|
|
3036
|
+
this.#realize();
|
|
3037
|
+
for (const p of this.#pending) {
|
|
3038
|
+
p.prepareSync();
|
|
3039
|
+
p.submit();
|
|
3040
|
+
}
|
|
3041
|
+
const backend = this.#backend;
|
|
3042
|
+
const { buffer } = backend.buffers.get(this.#source);
|
|
3043
|
+
this.dispose();
|
|
3044
|
+
return buffer;
|
|
3045
|
+
}
|
|
3046
|
+
/**
|
|
2984
3047
|
* Convert this array into a JavaScript object.
|
|
2985
3048
|
*
|
|
2986
3049
|
* This is a blocking operation that will compile all of the shaders and wait
|
|
@@ -3027,6 +3090,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3027
3090
|
[Primitive.Max]([x, y]) {
|
|
3028
3091
|
return [x.#binary(AluOp.Max, y)];
|
|
3029
3092
|
},
|
|
3093
|
+
[Primitive.BitCombine]([x, y], { op }) {
|
|
3094
|
+
const custom = (src) => AluExp.bitCombine(src[0], src[1], op);
|
|
3095
|
+
return [Array$1.#naryCustom("bit_combine", custom, [x, y])];
|
|
3096
|
+
},
|
|
3097
|
+
[Primitive.BitShift]([x, y], { op }) {
|
|
3098
|
+
const custom = (src) => AluExp.bitShift(src[0], src[1], op);
|
|
3099
|
+
return [Array$1.#naryCustom("bit_shift", custom, [x, y], { dtypeOverride: [void 0, y.dtype] })];
|
|
3100
|
+
},
|
|
3030
3101
|
[Primitive.Neg]([x]) {
|
|
3031
3102
|
return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
|
|
3032
3103
|
},
|
|
@@ -3046,8 +3117,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3046
3117
|
return [x.#unary(AluOp.Cast, dtype)];
|
|
3047
3118
|
},
|
|
3048
3119
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
3049
|
-
if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
3050
3120
|
if (x.dtype === dtype) return [x];
|
|
3121
|
+
if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
3051
3122
|
if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
3052
3123
|
if (x.#source instanceof AluExp) return [x.#unary(AluOp.Bitcast, dtype)];
|
|
3053
3124
|
else {
|
|
@@ -3718,6 +3789,8 @@ const vmapRules = {
|
|
|
3718
3789
|
[Primitive.Mod]: broadcastBatcher(Primitive.Mod),
|
|
3719
3790
|
[Primitive.Min]: broadcastBatcher(Primitive.Min),
|
|
3720
3791
|
[Primitive.Max]: broadcastBatcher(Primitive.Max),
|
|
3792
|
+
[Primitive.BitCombine]: broadcastBatcher(Primitive.BitCombine),
|
|
3793
|
+
[Primitive.BitShift]: broadcastBatcher(Primitive.BitShift),
|
|
3721
3794
|
[Primitive.Neg]: unopBatcher(Primitive.Neg),
|
|
3722
3795
|
[Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
|
|
3723
3796
|
[Primitive.Floor]: unopBatcher(Primitive.Floor),
|
|
@@ -4040,6 +4113,8 @@ const jvpRules = {
|
|
|
4040
4113
|
[Primitive.Max]([x, y], [dx, dy]) {
|
|
4041
4114
|
return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
|
|
4042
4115
|
},
|
|
4116
|
+
[Primitive.BitCombine]: zeroTangentsJvp(Primitive.BitCombine),
|
|
4117
|
+
[Primitive.BitShift]: zeroTangentsJvp(Primitive.BitShift),
|
|
4043
4118
|
[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
|
|
4044
4119
|
[Primitive.Reciprocal]([x], [dx]) {
|
|
4045
4120
|
const xRecip = reciprocal$1(x.ref);
|
|
@@ -4142,6 +4217,7 @@ const jvpRules = {
|
|
|
4142
4217
|
},
|
|
4143
4218
|
[Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
|
|
4144
4219
|
const x = triangularSolve$1(a.ref, b, { unitDiagonal });
|
|
4220
|
+
da = unitDiagonal ? triu(da, 1) : triu(da);
|
|
4145
4221
|
const dax = batchMatmulT(da, x.ref);
|
|
4146
4222
|
const rhsT = db.sub(mT(dax));
|
|
4147
4223
|
const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
|
|
@@ -5217,6 +5293,7 @@ function ifft(a, axis = -1) {
|
|
|
5217
5293
|
var numpy_linalg_exports = {};
|
|
5218
5294
|
__export(numpy_linalg_exports, {
|
|
5219
5295
|
cholesky: () => cholesky,
|
|
5296
|
+
cross: () => cross$1,
|
|
5220
5297
|
det: () => det,
|
|
5221
5298
|
diagonal: () => diagonal,
|
|
5222
5299
|
inv: () => inv,
|
|
@@ -5229,7 +5306,8 @@ __export(numpy_linalg_exports, {
|
|
|
5229
5306
|
solve: () => solve,
|
|
5230
5307
|
tensordot: () => tensordot,
|
|
5231
5308
|
trace: () => trace,
|
|
5232
|
-
vecdot: () => vecdot
|
|
5309
|
+
vecdot: () => vecdot,
|
|
5310
|
+
vectorNorm: () => vectorNorm
|
|
5233
5311
|
});
|
|
5234
5312
|
function checkSquare(name, a) {
|
|
5235
5313
|
if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
|
|
@@ -5247,6 +5325,19 @@ function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
|
|
|
5247
5325
|
if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
|
|
5248
5326
|
return cholesky$1(a, { upper });
|
|
5249
5327
|
}
|
|
5328
|
+
/**
 * Compute the cross-product of two 3D vectors.
 *
 * This is a simpler and less flexible version of `jax.numpy.cross()`: both
 * inputs must have size 3 along `axis` (default: the last axis), and that
 * same axis is used for the inputs and the output.
 *
 * @throws {Error} If either input does not have size 3 along `axis`.
 */
function cross$1(x1, x2, axis = -1) {
	// Validate the axis against both inputs first, then check sizes in order.
	const [ax1, ax2] = [x1, x2].map((arr) => checkAxis(axis, ndim(arr)));
	const n1 = shape(x1)[ax1];
	if (n1 !== 3) throw new Error(`linalg.cross: x1 must have size 3 along axis ${axis}, got ${n1}`);
	const n2 = shape(x2)[ax2];
	if (n2 !== 3) throw new Error(`linalg.cross: x2 must have size 3 along axis ${axis}, got ${n2}`);
	return cross(x1, x2, { axis });
}
|
|
5250
5341
|
/** Compute the determinant of a square matrix (batched). */
|
|
5251
5342
|
function det(a) {
|
|
5252
5343
|
a = fudgeArray(a);
|
|
@@ -5262,7 +5353,7 @@ function det(a) {
|
|
|
5262
5353
|
function inv(a) {
	a = fudgeArray(a);
	const n = checkSquare("inv", a);
	// Solve A·X = I, with the identity built in A's dtype.
	const identity = eye(n, void 0, { dtype: a.dtype });
	return solve(a, identity);
}
|
|
5267
5358
|
/**
|
|
5268
5359
|
* Return the least-squares solution to a linear equation.
|
|
@@ -5319,8 +5410,9 @@ function matrixPower(a, n) {
|
|
|
5319
5410
|
a = fudgeArray(a);
|
|
5320
5411
|
const m = checkSquare("matrixPower", a);
|
|
5321
5412
|
if (n === 0) {
|
|
5413
|
+
const dtype = a.dtype;
|
|
5322
5414
|
a.dispose();
|
|
5323
|
-
return broadcastTo(eye(m), a.shape);
|
|
5415
|
+
return broadcastTo(eye(m, void 0, { dtype }), a.shape);
|
|
5324
5416
|
}
|
|
5325
5417
|
if (n < 0) {
|
|
5326
5418
|
a = inv(a);
|
|
@@ -5394,6 +5486,23 @@ function solve(a, b) {
|
|
|
5394
5486
|
if (bIs1d) x = squeeze(x, -1);
|
|
5395
5487
|
return x;
|
|
5396
5488
|
}
|
|
5489
|
+
/**
 * Compute the vector norm of an array.
 *
 * - `ord = Infinity`: max of absolute values.
 * - `ord = -Infinity`: min of absolute values.
 * - `ord = 0`: count of non-zero entries, in the input's dtype.
 * - `ord = 1`: sum of absolute values.
 * - otherwise: `sum(|x| ** ord) ** (1 / ord)`.
 *
 * @param x - Input array.
 * @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
 * @param axis - Axis/axes to reduce over (default: all axes).
 * @param keepdims - Whether to keep reduced dimensions as size 1.
 * @returns The norm of `x`, reduced over the given axes.
 */
function vectorNorm(x, { ord = 2, axis = null, keepdims = false } = {}) {
	x = fudgeArray(x);
	// Read the dtype up front: the operations below consume `x`.
	const dtype = x.dtype;
	if (ord === Infinity) return max(absolute(x), axis, { keepdims });
	if (ord === -Infinity) return min(absolute(x), axis, { keepdims });
	if (ord === 0) return x.notEqual(0).astype(dtype).sum(axis, { keepdims });
	// Same values as the general formula (pow(v, 1) === v), without two power kernels.
	if (ord === 1) return absolute(x).sum(axis, { keepdims });
	return power(power(absolute(x), ord).sum(axis, { keepdims }), 1 / ord);
}
|
|
5397
5506
|
|
|
5398
5507
|
//#endregion
|
|
5399
5508
|
//#region src/library/numpy/dtype-info.ts
|
|
@@ -5502,13 +5611,24 @@ __export(numpy_exports, {
|
|
|
5502
5611
|
argmax: () => argmax,
|
|
5503
5612
|
argmin: () => argmin,
|
|
5504
5613
|
argsort: () => argsort,
|
|
5614
|
+
around: () => round,
|
|
5505
5615
|
array: () => array,
|
|
5616
|
+
arrayEqual: () => arrayEqual,
|
|
5617
|
+
arrayEquiv: () => arrayEquiv,
|
|
5506
5618
|
asin: () => asin,
|
|
5507
5619
|
asinh: () => arcsinh,
|
|
5508
5620
|
astype: () => astype,
|
|
5509
5621
|
atan: () => atan,
|
|
5510
5622
|
atan2: () => atan2,
|
|
5511
5623
|
atanh: () => arctanh,
|
|
5624
|
+
average: () => average,
|
|
5625
|
+
bitwiseAnd: () => bitwiseAnd,
|
|
5626
|
+
bitwiseInvert: () => invert,
|
|
5627
|
+
bitwiseLeftShift: () => leftShift,
|
|
5628
|
+
bitwiseNot: () => invert,
|
|
5629
|
+
bitwiseOr: () => bitwiseOr,
|
|
5630
|
+
bitwiseRightShift: () => rightShift,
|
|
5631
|
+
bitwiseXor: () => bitwiseXor,
|
|
5512
5632
|
bool: () => bool,
|
|
5513
5633
|
broadcastArrays: () => broadcastArrays,
|
|
5514
5634
|
broadcastShapes: () => broadcastShapes,
|
|
@@ -5519,11 +5639,13 @@ __export(numpy_exports, {
|
|
|
5519
5639
|
columnStack: () => columnStack,
|
|
5520
5640
|
concatenate: () => concatenate,
|
|
5521
5641
|
convolve: () => convolve,
|
|
5642
|
+
copysign: () => copysign,
|
|
5522
5643
|
corrcoef: () => corrcoef,
|
|
5523
5644
|
correlate: () => correlate,
|
|
5524
5645
|
cos: () => cos,
|
|
5525
5646
|
cosh: () => cosh,
|
|
5526
5647
|
cov: () => cov,
|
|
5648
|
+
cross: () => cross,
|
|
5527
5649
|
cumsum: () => cumsum,
|
|
5528
5650
|
cumulativeSum: () => cumsum,
|
|
5529
5651
|
deg2rad: () => deg2rad,
|
|
@@ -5559,7 +5681,6 @@ __export(numpy_exports, {
|
|
|
5559
5681
|
fullLike: () => fullLike$1,
|
|
5560
5682
|
greater: () => greater,
|
|
5561
5683
|
greaterEqual: () => greaterEqual,
|
|
5562
|
-
hamming: () => hamming,
|
|
5563
5684
|
hann: () => hann,
|
|
5564
5685
|
heaviside: () => heaviside,
|
|
5565
5686
|
hstack: () => hstack,
|
|
@@ -5569,12 +5690,14 @@ __export(numpy_exports, {
|
|
|
5569
5690
|
inf: () => inf,
|
|
5570
5691
|
inner: () => inner,
|
|
5571
5692
|
int32: () => int32,
|
|
5693
|
+
invert: () => invert,
|
|
5572
5694
|
isfinite: () => isfinite,
|
|
5573
5695
|
isinf: () => isinf,
|
|
5574
5696
|
isnan: () => isnan,
|
|
5575
5697
|
isneginf: () => isneginf,
|
|
5576
5698
|
isposinf: () => isposinf,
|
|
5577
5699
|
ldexp: () => ldexp,
|
|
5700
|
+
leftShift: () => leftShift,
|
|
5578
5701
|
less: () => less,
|
|
5579
5702
|
lessEqual: () => lessEqual,
|
|
5580
5703
|
linalg: () => numpy_linalg_exports,
|
|
@@ -5583,9 +5706,14 @@ __export(numpy_exports, {
|
|
|
5583
5706
|
log10: () => log10,
|
|
5584
5707
|
log1p: () => log1p,
|
|
5585
5708
|
log2: () => log2,
|
|
5709
|
+
logicalAnd: () => logicalAnd,
|
|
5710
|
+
logicalNot: () => logicalNot,
|
|
5711
|
+
logicalOr: () => logicalOr,
|
|
5712
|
+
logicalXor: () => logicalXor,
|
|
5586
5713
|
logspace: () => logspace,
|
|
5587
5714
|
matmul: () => matmul,
|
|
5588
5715
|
matrixTranspose: () => matrixTranspose,
|
|
5716
|
+
matvec: () => matvec,
|
|
5589
5717
|
max: () => max,
|
|
5590
5718
|
maximum: () => maximum,
|
|
5591
5719
|
mean: () => mean,
|
|
@@ -5618,6 +5746,9 @@ __export(numpy_exports, {
|
|
|
5618
5746
|
remainder: () => remainder,
|
|
5619
5747
|
repeat: () => repeat,
|
|
5620
5748
|
reshape: () => reshape,
|
|
5749
|
+
rightShift: () => rightShift,
|
|
5750
|
+
rint: () => rint,
|
|
5751
|
+
round: () => round,
|
|
5621
5752
|
shape: () => shape,
|
|
5622
5753
|
sign: () => sign,
|
|
5623
5754
|
sin: () => sin,
|
|
@@ -5650,6 +5781,7 @@ __export(numpy_exports, {
|
|
|
5650
5781
|
var_: () => var_,
|
|
5651
5782
|
vdot: () => vdot,
|
|
5652
5783
|
vecdot: () => vecdot,
|
|
5784
|
+
vecmat: () => vecmat,
|
|
5653
5785
|
vstack: () => vstack,
|
|
5654
5786
|
where: () => where,
|
|
5655
5787
|
zeros: () => zeros,
|
|
@@ -5713,6 +5845,60 @@ const notEqual = notEqual$1;
|
|
|
5713
5845
|
const greaterEqual = greaterEqual$1;
|
|
5714
5846
|
/** @function Compare two arrays element-wise. */
const lessEqual = lessEqual$1;
/** Compute element-wise logical AND. Both inputs are first cast to bool. */
function logicalAnd(x, y) {
	const lhs = astype(x, DType.Bool);
	const rhs = astype(y, DType.Bool);
	return lhs.mul(rhs);
}
/** Compute element-wise logical OR. Both inputs are first cast to bool. */
function logicalOr(x, y) {
	const lhs = astype(x, DType.Bool);
	const rhs = astype(y, DType.Bool);
	return lhs.add(rhs);
}
/** Compute element-wise logical XOR (bool inequality). */
function logicalXor(x, y) {
	const lhs = astype(x, DType.Bool);
	const rhs = astype(y, DType.Bool);
	return notEqual(lhs, rhs);
}
/** Compute element-wise logical NOT (bool value compared against `true`). */
function logicalNot(x) {
	const asBool = astype(x, DType.Bool);
	return notEqual(asBool, true);
}
|
|
5864
|
+
/** Compute element-wise bitwise AND (inputs must be integer or bool). */
function bitwiseAnd(x, y) {
	const op = "and";
	return bitCombine(x, y, op);
}
/** Compute element-wise bitwise OR (inputs must be integer or bool). */
function bitwiseOr(x, y) {
	const op = "or";
	return bitCombine(x, y, op);
}
/** Compute element-wise bitwise XOR (inputs must be integer or bool). */
function bitwiseXor(x, y) {
	const op = "xor";
	return bitCombine(x, y, op);
}
|
|
5876
|
+
/**
 * Compute element-wise bitwise NOT (inversion).
 *
 * Implemented as XOR with the all-ones value of the input's dtype.
 *
 * @throws {TypeError} If the dtype is not bool, uint32 or int32.
 */
function invert(x) {
	const arr = fudgeArray(x);
	const dtype = arr.dtype;
	let mask;
	if (dtype === DType.Bool) mask = true;
	else if (dtype === DType.Uint32) mask = 0xffffffff;
	else if (dtype === DType.Int32) mask = -1;
	else throw new TypeError(`invert: unsupported dtype ${arr.dtype}`);
	return bitCombine(arr, mask, "xor");
}
|
|
5894
|
+
/** Compute element-wise left bit shift of `x` by `y` (integer inputs only). */
function leftShift(x, y) {
	const op = "shl";
	return bitShift(x, y, op);
}
/** Compute element-wise right bit shift of `x` by `y` (integer inputs only). */
function rightShift(x, y) {
	const op = "shr";
	return bitShift(x, y, op);
}
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
const where = where$1;
|
|
5718
5904
|
/**
|
|
@@ -5820,6 +6006,34 @@ function mean(a, axis = null, opts) {
|
|
|
5820
6006
|
return fudgeArray(a).mean(axis, opts);
|
|
5821
6007
|
}
|
|
5822
6008
|
/**
 * Compute the weighted average along the specified axis.
 *
 * Without `opts.weights` this is simply `mean(a, axis, opts)`. Otherwise the
 * result is `sum(a * w, axis) / sum(w, axis)`, where `w` is one of:
 * - `weights` itself, when it has exactly the shape of `a`;
 * - a 1-D `weights` whose length matches `a` along the single reduction axis,
 *   reshaped so it broadcasts along that axis only.
 * If no axis is specified, the average is taken over all axes.
 *
 * @throws {Error} If the weights shape is not compatible with `a` and `axis`
 *   (both arrays are disposed before throwing).
 */
function average(a, axis = null, opts) {
	a = fudgeArray(a);
	if (opts?.weights == null) return mean(a, axis, opts);
	const weights = fudgeArray(opts.weights);
	axis = normalizeAxis(axis, ndim(a));
	const wShape = weights.shape;
	const aShape = a.shape;
	let w;
	if (deepEqual(wShape, aShape)) {
		w = weights;
	} else if (axis.length === 1 && wShape.length === 1 && wShape[0] === aShape[axis[0]]) {
		const [reduceAxis] = axis;
		w = reshape(weights, aShape.map((_, i) => i === reduceAxis ? wShape[0] : 1));
	} else {
		weights.dispose();
		a.dispose();
		throw new Error(`average: weights shape ${JSON.stringify(wShape)} is not compatible with array shape ${JSON.stringify(aShape)} and axis ${JSON.stringify(axis)}`);
	}
	const scl = sum(w.ref, axis, opts);
	return sum(multiply(a, w), axis, opts).div(scl);
}
|
|
6036
|
+
/**
|
|
5823
6037
|
* Returns the indices of the minimum values along an axis.
|
|
5824
6038
|
*
|
|
5825
6039
|
* By default, index is into the flatted array, otherwise it is along the
|
|
@@ -6223,20 +6437,63 @@ function take(a, indices, axis = null) {
|
|
|
6223
6437
|
axis = checkAxis(axis, ndim(a));
|
|
6224
6438
|
return gather(a, [indices], [axis], axis);
|
|
6225
6439
|
}
|
|
6226
|
-
/**
 * Return if two arrays are element-wise equal within a tolerance.
 *
 * For finite values the test is `|actual - expected| <= atol + rtol * |expected|`.
 * An infinity is only close to an infinity of the same sign, and NaN values
 * compare equal only if `equalNaN` is true. Returns false if the shapes differ.
 */
function allclose(actual, expected, options) {
	const { rtol = 1e-5, atol = 1e-7, equalNaN = false } = options ?? {};
	const x = array(actual);
	const y = array(expected);
	if (!deepEqual(x.shape, y.shape)) return false;
	const xData = x.dataSync();
	const yData = y.dataSync();
	for (let i = 0; i < xData.length; i++) {
		const a = xData[i];
		const b = yData[i];
		if (Number.isNaN(a) || Number.isNaN(b)) {
			// NaN only ever matches NaN, and only when requested.
			if (equalNaN && Number.isNaN(a) && Number.isNaN(b)) continue;
			return false;
		}
		// Exact matches, including same-signed infinities, are always close.
		if (a === b) continue;
		// Without this check `|1 - Infinity| > atol + rtol * Infinity` is false,
		// so any finite value (or the opposite infinity) would pass.
		if (!Number.isFinite(a) || !Number.isFinite(b)) return false;
		if (Math.abs(a - b) > atol + rtol * Math.abs(b)) return false;
	}
	return true;
}
|
|
6459
|
+
/**
 * Check if two arrays are element-wise equal.
 *
 * Returns a scalar boolean array: `false` if the arrays have different shapes
 * (both inputs are disposed in that case), otherwise whether every pair of
 * elements is equal. If `opts.equalNaN` is true, NaNs in the same position are
 * considered equal.
 */
function arrayEqual(a1, a2, opts) {
	a1 = fudgeArray(a1);
	a2 = fudgeArray(a2);
	const sameShape = deepEqual(a1.shape, a2.shape);
	if (!sameShape) {
		a1.dispose();
		a2.dispose();
		return array(false);
	}
	if (!opts?.equalNaN) return equal(a1, a2).all();
	// Positions where both inputs are NaN count as equal.
	const bothNaN = isnan(a1.ref).mul(isnan(a2.ref));
	return where(bothNaN, true, equal(a1, a2)).all();
}
|
|
6479
|
+
/**
 * Check if two arrays are element-wise equal after broadcasting.
 *
 * Unlike `arrayEqual`, this allows inputs with different but
 * broadcast-compatible shapes. If broadcasting or the comparison throws, both
 * inputs are disposed and a scalar `false` array is returned.
 */
function arrayEquiv(a1, a2) {
	const lhs = fudgeArray(a1);
	const rhs = fudgeArray(a2);
	try {
		const broadcasted = broadcastArrays(lhs, rhs);
		return equal(broadcasted[0], broadcasted[1]).all();
	} catch {
		for (const arr of [lhs, rhs]) arr.dispose();
		return array(false);
	}
}
|
|
6240
6497
|
/** Matrix product of two arrays. */
|
|
6241
6498
|
function matmul(x, y) {
|
|
6242
6499
|
if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
|
|
@@ -6250,6 +6507,16 @@ function matmul(x, y) {
|
|
|
6250
6507
|
rhsBatchDims: range(-2 - numBatchDims, -2)
|
|
6251
6508
|
});
|
|
6252
6509
|
}
|
|
6510
|
+
/** Matrix-vector product. x1 is [..., M, N], x2 is [..., N] → [..., M]. */
function matvec(x1, x2) {
	const ranksOk = ndim(x1) >= 2 && ndim(x2) >= 1;
	if (!ranksOk) throw new Error("matvec: x1 must be at least 2D and x2 at least 1D");
	return einsum("...mn,...n->...m", x1, x2);
}
/** Vector-matrix product. x1 is [..., N], x2 is [..., N, M] → [..., M]. */
function vecmat(x1, x2) {
	const ranksOk = ndim(x1) >= 1 && ndim(x2) >= 2;
	if (!ranksOk) throw new Error("vecmat: x1 must be at least 1D and x2 at least 2D");
	return einsum("...n,...nm->...m", x1, x2);
}
|
|
6253
6520
|
/** Dot product of two arrays. */
|
|
6254
6521
|
function dot$1(x, y) {
|
|
6255
6522
|
if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
|
|
@@ -6408,6 +6675,49 @@ function outer(x, y) {
|
|
|
6408
6675
|
y = ravel(y);
|
|
6409
6676
|
return multiply(x.reshape([x.shape[0], 1]), y);
|
|
6410
6677
|
}
|
|
6678
|
+
/**
 * @function Compute the cross product of two arrays.
 *
 * Supports 2D (scalar result) and 3D cross products, with optional axis
 * arguments. If `axis` is given, it overrides `axisa`, `axisb`, and `axisc`.
 * When only one operand is a 2-vector, it is zero-padded to 3 components in
 * its own dtype before the 3D product is taken.
 *
 * @throws {Error} If either input has a size other than 2 or 3 along its axis.
 */
const cross = jit$1(function cross$2(a, b, { axisa = -1, axisb = -1, axisc = -1, axis } = {}) {
	if (axis !== void 0) {
		axisa = axis;
		axisb = axis;
		axisc = axis;
	}
	axisa = checkAxis(axisa, ndim(a));
	axisb = checkAxis(axisb, ndim(b));
	a = moveaxis$1(a, axisa, -1);
	b = moveaxis$1(b, axisb, -1);
	const da = a.shape.at(-1);
	const db = b.shape.at(-1);
	if ((da !== 2 && da !== 3) || (db !== 2 && db !== 3)) throw new Error(`cross: incompatible dimensions for cross product (got ${da} and ${db})`);
	if (da === 2 && db === 2) {
		const [p0, p1] = split$1(a, 2, -1);
		const [q0, q1] = split$1(b, 2, -1);
		return squeeze(p0.mul(q1).sub(p1.mul(q0)), -1);
	}
	// Pad with zeros of the operand's own dtype rather than zeros()'s default
	// dtype, so the padding does not change the dtype of integer inputs.
	if (da === 2) a = concatenate([a, zeros([...a.shape.slice(0, -1), 1], { dtype: a.dtype })], -1);
	if (db === 2) b = concatenate([b, zeros([...b.shape.slice(0, -1), 1], { dtype: b.dtype })], -1);
	const [a0, a1, a2] = split$1(a, 3, -1);
	const [b0, b1, b2] = split$1(b, 3, -1);
	// Each component slice is consumed exactly once; `.ref` keeps the others alive.
	const c0 = a1.ref.mul(b2.ref).sub(a2.ref.mul(b1.ref));
	const c1 = a2.mul(b0.ref).sub(a0.ref.mul(b2));
	const c2 = a0.mul(b1).sub(a1.mul(b0));
	return moveaxis$1(concatenate([c0, c1, c2], -1), -1, axisc);
}, { staticArgnums: [2] });
|
|
6411
6721
|
/** Vector dot product of two arrays along a given axis. */
|
|
6412
6722
|
function vecdot(x, y, { axis } = {}) {
|
|
6413
6723
|
const xaxis = checkAxis(axis ?? -1, ndim(x));
|
|
@@ -6504,16 +6814,15 @@ function sign(x) {
|
|
|
6504
6814
|
x = fudgeArray(x);
|
|
6505
6815
|
return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
|
|
6506
6816
|
}
|
|
6507
|
-
/**
 * @function
 * Return the value with the magnitude of x and the sign of y, element-wise.
 *
 * The sign is taken from `y < 0`, so `y == 0` yields `+|x|`
 * (`copysign(3, 0) === 3`), and NaN `y` also yields `+|x|`.
 * NOTE(review): a negative-zero `y` is treated as positive here, whereas IEEE
 * `copysign` would return `-|x|`.
 */
const copysign = jit$1(function copysign$1(x, y) {
	// sign(y) is 0 for y == 0 and would erase the magnitude, so use ±1 instead.
	return absolute(x).mul(where(less(y, 0), -1, 1));
});
/** @function Return element-wise positive values of the input (no-op alias of `fudgeArray`). */
const positive = fudgeArray;
|
|
6517
6826
|
/**
|
|
6518
6827
|
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
6519
6828
|
*
|
|
@@ -6659,6 +6968,27 @@ function trunc(x) {
|
|
|
6659
6968
|
return idiv(x, 1);
|
|
6660
6969
|
}
|
|
6661
6970
|
/**
 * @function
 * Round to the given number of decimals.
 *
 * Uses banker's rounding (round half to even) to match NumPy/JAX behavior.
 * As in NumPy, the input is scaled by dividing (for `decimals > 0`) or
 * multiplying (for `decimals < 0`) by an integer power of ten. This is more
 * accurate than multiplying by a reciprocal such as `1 / 10`, which is not
 * exactly representable.
 */
const round = jit$1(function round$1(a, decimals = 0) {
	if (decimals === 0) return rint(a);
	if (decimals > 0) {
		const scale = 10 ** decimals;
		return rint(a.mul(scale)).div(scale);
	}
	const step = 10 ** -decimals;
	return rint(a.div(step)).mul(step);
}, { staticArgnums: [1] });
|
|
6981
|
+
/**
 * @function
 * Round to the nearest integer, with ties going to the nearest even integer.
 *
 * Works from `floor(x)` and the fractional part `x - floor(x)` rather than
 * `floor(x + 0.5)`. In float32 the addition itself can round: it would turn
 * 0.49999997 into 1, and an odd integer ≥ 2**23 into the next even integer.
 */
const rint = jit$1(function rint$1(x) {
	const lower = floor(x.ref);
	const frac = x.sub(lower.ref);
	const lowerIsOdd = remainder(lower.ref, 2).notEqual(0);
	// Round up past the midpoint, or exactly at it when the floor is odd (bool add = OR, mul = AND).
	const roundUp = greater(frac.ref, .5).add(frac.equal(.5).mul(lowerIsOdd));
	return where(roundUp, lower.ref.add(1), lower);
});
|
|
6991
|
+
/**
|
|
6662
6992
|
* Compute `x1 * 2 ** x2` as a standard multiplication and exponentiation.
|
|
6663
6993
|
*
|
|
6664
6994
|
* This is the inverse of `frexp()`.
|
|
@@ -6986,6 +7316,7 @@ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = f
|
|
|
6986
7316
|
//#region src/library/lax.ts
|
|
6987
7317
|
var lax_exports = {};
|
|
6988
7318
|
__export(lax_exports, {
|
|
7319
|
+
bitcastConvertType: () => bitcastConvertType,
|
|
6989
7320
|
conv: () => conv,
|
|
6990
7321
|
convGeneralDilated: () => convGeneralDilated,
|
|
6991
7322
|
convTranspose: () => convTranspose,
|
|
@@ -6999,6 +7330,10 @@ __export(lax_exports, {
|
|
|
6999
7330
|
topK: () => topK
|
|
7000
7331
|
});
|
|
7001
7332
|
const JsArray = globalThis.Array;
|
|
7333
|
+
/**
 * Elementwise bitcast an array into a new dtype.
 *
 * The input is first coerced to an array, then reinterpreted through its
 * `view(newDtype)` method.
 */
function bitcastConvertType(x, newDtype) {
	const source = fudgeArray(x);
	return source.view(newDtype);
}
|
|
7002
7337
|
/**
|
|
7003
7338
|
* General dot product/contraction operator.
|
|
7004
7339
|
*
|
|
@@ -7730,7 +8065,9 @@ function getK01(key$1) {
|
|
|
7730
8065
|
/**
 * Build a PRNG key from a scalar seed.
 *
 * The seed is converted to a uint32 array and must be 0-dimensional (batch
 * with `jax.vmap` instead). The key is the stacked pair `[0, seed]`.
 */
function key(seed) {
	const seedArray = array(seed, { dtype: DType.Uint32 });
	if (seedArray.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seedArray.shape} - use jax.vmap for batching.`);
	const newKey = stack([0, seedArray]);
	// Only concrete Array values (not e.g. traced values) get `_realizeSource()`.
	// NOTE(review): presumably this materializes the stacked key eagerly — confirm.
	if (newKey instanceof Array$1) newKey._realizeSource();
	return newKey;
}
|
|
7735
8072
|
/** Splits a PRNG key into `num` new keys by adding a leading axis. */
|
|
7736
8073
|
function split(key$1, num = 2) {
|
|
@@ -7925,6 +8262,11 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
|
|
|
7925
8262
|
|
|
7926
8263
|
//#endregion
|
|
7927
8264
|
//#region src/index.ts
|
|
8265
|
+
/**
 * @namespace
 * Profiling controls, published through the module's `profiler` export.
 * `startTrace` and `stopTrace` are forwarded unchanged from the backend module.
 */
const profiler = { startTrace, stopTrace };
|
|
7928
8270
|
/**
|
|
7929
8271
|
* @function
|
|
7930
8272
|
* Compute the forward-mode Jacobian-vector product for a function.
|
|
@@ -8085,4 +8427,4 @@ async function devicePut(x, device) {
|
|
|
8085
8427
|
}
|
|
8086
8428
|
|
|
8087
8429
|
//#endregion
|
|
8088
|
-
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
8430
|
+
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, getWebGPUDevice, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, profiler, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-
|
|
1
|
+
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-DZvR7mZV.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgl/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -458,6 +458,12 @@ function generateExpression(exp, args, inputDtypes) {
|
|
|
458
458
|
else source = `min(${a}, ${b})`;
|
|
459
459
|
else if (op === AluOp.Max) if (dtype === DType.Bool) source = `(${a} || ${b})`;
|
|
460
460
|
else source = `max(${a}, ${b})`;
|
|
461
|
+
else if (op === AluOp.BitCombine) {
|
|
462
|
+
let infix = arg === "and" ? "&" : arg === "or" ? "|" : "^";
|
|
463
|
+
if (dtype === DType.Bool) infix = infix + infix;
|
|
464
|
+
source = `(${a} ${infix} ${b})`;
|
|
465
|
+
} else if (op === AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
|
|
466
|
+
else source = `(${a} >> ${b})`;
|
|
461
467
|
} else if (AluGroup.Compare.has(op)) {
|
|
462
468
|
const a = gen(src[0]);
|
|
463
469
|
const b = gen(src[1]);
|