@jax-js/jax 0.1.10 → 0.1.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +7 -2
- package/dist/{backend-Ctqs8la1.js → backend-DI-V78Rk.js} +732 -21
- package/dist/{backend-DMauYnfl.cjs → backend-x-6vqzIM.cjs} +737 -20
- package/dist/index.cjs +372 -20
- package/dist/index.d.cts +172 -4
- package/dist/index.d.ts +172 -4
- package/dist/index.js +372 -21
- package/dist/{webgl-CvQ1QBX1.js → webgl-BhsnpeB0.js} +7 -1
- package/dist/{webgl-kvVt7-T7.cjs → webgl-CD3WK_Me.cjs} +7 -1
- package/dist/{webgpu-v_W_-oKw.js → webgpu-C2kLdkUh.js} +299 -149
- package/dist/{webgpu-DMSx7a6M.cjs → webgpu-C4S8Uq9e.cjs} +299 -149
- package/package.json +1 -1
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DI-V78Rk.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -209,7 +209,7 @@ __export(tree_exports, {
|
|
|
209
209
|
structure: () => structure,
|
|
210
210
|
unflatten: () => unflatten
|
|
211
211
|
});
|
|
212
|
-
const JsArray$
|
|
212
|
+
const JsArray$3 = globalThis.Array;
|
|
213
213
|
let NodeType = /* @__PURE__ */ function(NodeType$1) {
|
|
214
214
|
NodeType$1["Array"] = "Array";
|
|
215
215
|
NodeType$1["Object"] = "Object";
|
|
@@ -257,7 +257,7 @@ function flatten(tree) {
|
|
|
257
257
|
return [leaves$1, treedef];
|
|
258
258
|
}
|
|
259
259
|
function _flatten(tree, leaves$1) {
|
|
260
|
-
if (JsArray$
|
|
260
|
+
if (JsArray$3.isArray(tree)) {
|
|
261
261
|
const childTrees = tree.map((c) => _flatten(c, leaves$1));
|
|
262
262
|
return new JsTreeDef(NodeType.Array, null, childTrees);
|
|
263
263
|
} else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
|
|
@@ -333,6 +333,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
333
333
|
Primitive$1["Mod"] = "mod";
|
|
334
334
|
Primitive$1["Min"] = "min";
|
|
335
335
|
Primitive$1["Max"] = "max";
|
|
336
|
+
Primitive$1["BitCombine"] = "bit_combine";
|
|
337
|
+
Primitive$1["BitShift"] = "bit_shift";
|
|
336
338
|
Primitive$1["Neg"] = "neg";
|
|
337
339
|
Primitive$1["Reciprocal"] = "reciprocal";
|
|
338
340
|
Primitive$1["Floor"] = "floor";
|
|
@@ -406,6 +408,12 @@ function min$1(x, y) {
|
|
|
406
408
|
function max$1(x, y) {
|
|
407
409
|
return bind1(Primitive.Max, [x, y]);
|
|
408
410
|
}
|
|
411
|
+
function bitCombine(x, y, op) {
|
|
412
|
+
return bind1(Primitive.BitCombine, [x, y], { op });
|
|
413
|
+
}
|
|
414
|
+
function bitShift(x, y, op) {
|
|
415
|
+
return bind1(Primitive.BitShift, [x, y], { op });
|
|
416
|
+
}
|
|
409
417
|
function neg(x) {
|
|
410
418
|
return bind1(Primitive.Neg, [x]);
|
|
411
419
|
}
|
|
@@ -1620,6 +1628,16 @@ const abstractEvalRules = {
|
|
|
1620
1628
|
[Primitive.Mod]: binopAbstractEval,
|
|
1621
1629
|
[Primitive.Min]: binopAbstractEval,
|
|
1622
1630
|
[Primitive.Max]: binopAbstractEval,
|
|
1631
|
+
[Primitive.BitCombine]([x, y]) {
|
|
1632
|
+
const aval = promoteAvals(x, y);
|
|
1633
|
+
if (isFloatDtype(aval.dtype)) throw new TypeError(`bitwise operations require integer or boolean inputs, got ${aval.dtype}`);
|
|
1634
|
+
return [aval];
|
|
1635
|
+
},
|
|
1636
|
+
[Primitive.BitShift]([x, y]) {
|
|
1637
|
+
const shape$1 = generalBroadcast(x.shape, y.shape);
|
|
1638
|
+
if (isFloatDtype(x.dtype) || isFloatDtype(y.dtype) || x.dtype === DType.Bool || y.dtype === DType.Bool) throw new TypeError(`bit shift operations require integer inputs, got ${x} and ${y}`);
|
|
1639
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
1640
|
+
},
|
|
1623
1641
|
[Primitive.Neg]: vectorizedUnopAbstractEval,
|
|
1624
1642
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
1625
1643
|
[Primitive.Floor]: vectorizedUnopAbstractEval,
|
|
@@ -2155,6 +2173,8 @@ const jitRules = {
|
|
|
2155
2173
|
[Primitive.Mod]: broadcastedJit(([a, b]) => AluExp.mod(a, b)),
|
|
2156
2174
|
[Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
|
|
2157
2175
|
[Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
|
|
2176
|
+
[Primitive.BitCombine]: broadcastedJit(([a, b], { op }) => AluExp.bitCombine(a, b, op)),
|
|
2177
|
+
[Primitive.BitShift]: broadcastedJit(([a, b], { op }) => AluExp.bitShift(a, b, op)),
|
|
2158
2178
|
[Primitive.Neg]: unopJit((a) => AluExp.sub(AluExp.const(a.dtype, 0), a)),
|
|
2159
2179
|
[Primitive.Reciprocal]: unopJit(AluExp.reciprocal),
|
|
2160
2180
|
[Primitive.Floor]: unopJit(AluExp.floor),
|
|
@@ -2347,7 +2367,9 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2347
2367
|
case Primitive.Idiv:
|
|
2348
2368
|
case Primitive.Mod:
|
|
2349
2369
|
case Primitive.Min:
|
|
2350
|
-
case Primitive.Max:
|
|
2370
|
+
case Primitive.Max:
|
|
2371
|
+
case Primitive.BitCombine:
|
|
2372
|
+
case Primitive.BitShift: {
|
|
2351
2373
|
const otherInput = nextEqn.inputs.find((v) => v !== outVar);
|
|
2352
2374
|
if (otherInput instanceof Lit || deepEqual(generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
|
|
2353
2375
|
head = usages[0];
|
|
@@ -2438,7 +2460,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2438
2460
|
|
|
2439
2461
|
//#endregion
|
|
2440
2462
|
//#region src/frontend/array.ts
|
|
2441
|
-
const JsArray$
|
|
2463
|
+
const JsArray$2 = globalThis.Array;
|
|
2442
2464
|
const inlineArrayLimit = 128;
|
|
2443
2465
|
/** Version of pureArray with fudged types. */
|
|
2444
2466
|
const fudgeArray = pureArray;
|
|
@@ -2878,6 +2900,15 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2878
2900
|
this.#check();
|
|
2879
2901
|
const indices = unravelAlu(this.#st.shape, AluVar.gidx);
|
|
2880
2902
|
if (this.#source instanceof AluExp) {
|
|
2903
|
+
let resolvedSource;
|
|
2904
|
+
if (this.#st.contiguous && this.#st.size < inlineArrayLimit && (resolvedSource = this.#source.resolve()) !== void 0) {
|
|
2905
|
+
const byteLength = this.#st.size * byteWidth(this.#dtype);
|
|
2906
|
+
const initialData = new Uint8Array(byteLength);
|
|
2907
|
+
dtypedArray(this.#dtype, initialData).fill(resolvedSource);
|
|
2908
|
+
this.#source = this.#backend.malloc(byteLength, initialData);
|
|
2909
|
+
this.#st = ShapeTracker.fromShape(this.shape);
|
|
2910
|
+
return;
|
|
2911
|
+
}
|
|
2881
2912
|
const exp$2 = accessorAluExp(this.#source, this.#st, indices);
|
|
2882
2913
|
const kernel = new Kernel(0, this.#st.size, exp$2);
|
|
2883
2914
|
const output = this.#backend.malloc(kernel.bytes);
|
|
@@ -2986,6 +3017,42 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2986
3017
|
return dtypedArray(this.dtype, buf);
|
|
2987
3018
|
}
|
|
2988
3019
|
/**
|
|
3020
|
+
* Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
|
|
3021
|
+
*
|
|
3022
|
+
* Only available on the WebGPU backend. The array's memory is still managed
|
|
3023
|
+
* by jax-js, and it will be freed when the buffer is no longer in use. You
|
|
3024
|
+
* _should not_ mutate the buffer's contents.
|
|
3025
|
+
*
|
|
3026
|
+
* Note that the GPU buffer may be slightly larger than the array's size; it
|
|
3027
|
+
* will always be aligned to 4 bytes.
|
|
3028
|
+
*/
|
|
3029
|
+
async gpuBuffer() {
|
|
3030
|
+
if (this.device !== "webgpu") throw new Error(`gpuBuffer() is only available on WebGPU backend`);
|
|
3031
|
+
this.#realize();
|
|
3032
|
+
const pending = this.#pending;
|
|
3033
|
+
if (pending) {
|
|
3034
|
+
await Promise.all(pending.map((p) => p.prepare()));
|
|
3035
|
+
for (const p of pending) p.submit();
|
|
3036
|
+
}
|
|
3037
|
+
const backend = this.#backend;
|
|
3038
|
+
const { buffer } = backend.buffers.get(this.#source);
|
|
3039
|
+
this.dispose();
|
|
3040
|
+
return buffer;
|
|
3041
|
+
}
|
|
3042
|
+
/** Synchronous version of `Array.gpuBuffer()`. */
|
|
3043
|
+
gpuBufferSync() {
|
|
3044
|
+
if (this.device !== "webgpu") throw new Error(`gpuBufferSync() is only available on WebGPU backend`);
|
|
3045
|
+
this.#realize();
|
|
3046
|
+
for (const p of this.#pending) {
|
|
3047
|
+
p.prepareSync();
|
|
3048
|
+
p.submit();
|
|
3049
|
+
}
|
|
3050
|
+
const backend = this.#backend;
|
|
3051
|
+
const { buffer } = backend.buffers.get(this.#source);
|
|
3052
|
+
this.dispose();
|
|
3053
|
+
return buffer;
|
|
3054
|
+
}
|
|
3055
|
+
/**
|
|
2989
3056
|
* Convert this array into a JavaScript object.
|
|
2990
3057
|
*
|
|
2991
3058
|
* This is a blocking operation that will compile all of the shaders and wait
|
|
@@ -3032,6 +3099,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3032
3099
|
[Primitive.Max]([x, y]) {
|
|
3033
3100
|
return [x.#binary(AluOp.Max, y)];
|
|
3034
3101
|
},
|
|
3102
|
+
[Primitive.BitCombine]([x, y], { op }) {
|
|
3103
|
+
const custom = (src) => AluExp.bitCombine(src[0], src[1], op);
|
|
3104
|
+
return [Array$1.#naryCustom("bit_combine", custom, [x, y])];
|
|
3105
|
+
},
|
|
3106
|
+
[Primitive.BitShift]([x, y], { op }) {
|
|
3107
|
+
const custom = (src) => AluExp.bitShift(src[0], src[1], op);
|
|
3108
|
+
return [Array$1.#naryCustom("bit_shift", custom, [x, y], { dtypeOverride: [void 0, y.dtype] })];
|
|
3109
|
+
},
|
|
3035
3110
|
[Primitive.Neg]([x]) {
|
|
3036
3111
|
return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
|
|
3037
3112
|
},
|
|
@@ -3284,7 +3359,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
3284
3359
|
if (!shape$1) {
|
|
3285
3360
|
shape$1 = [];
|
|
3286
3361
|
let cur = values;
|
|
3287
|
-
while (JsArray$
|
|
3362
|
+
while (JsArray$2.isArray(cur)) {
|
|
3288
3363
|
shape$1.push(cur.length);
|
|
3289
3364
|
cur = cur[0];
|
|
3290
3365
|
}
|
|
@@ -3723,6 +3798,8 @@ const vmapRules = {
|
|
|
3723
3798
|
[Primitive.Mod]: broadcastBatcher(Primitive.Mod),
|
|
3724
3799
|
[Primitive.Min]: broadcastBatcher(Primitive.Min),
|
|
3725
3800
|
[Primitive.Max]: broadcastBatcher(Primitive.Max),
|
|
3801
|
+
[Primitive.BitCombine]: broadcastBatcher(Primitive.BitCombine),
|
|
3802
|
+
[Primitive.BitShift]: broadcastBatcher(Primitive.BitShift),
|
|
3726
3803
|
[Primitive.Neg]: unopBatcher(Primitive.Neg),
|
|
3727
3804
|
[Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
|
|
3728
3805
|
[Primitive.Floor]: unopBatcher(Primitive.Floor),
|
|
@@ -4045,6 +4122,8 @@ const jvpRules = {
|
|
|
4045
4122
|
[Primitive.Max]([x, y], [dx, dy]) {
|
|
4046
4123
|
return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
|
|
4047
4124
|
},
|
|
4125
|
+
[Primitive.BitCombine]: zeroTangentsJvp(Primitive.BitCombine),
|
|
4126
|
+
[Primitive.BitShift]: zeroTangentsJvp(Primitive.BitShift),
|
|
4048
4127
|
[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
|
|
4049
4128
|
[Primitive.Reciprocal]([x], [dx]) {
|
|
4050
4129
|
const xRecip = reciprocal$1(x.ref);
|
|
@@ -4162,7 +4241,7 @@ const jvpRules = {
|
|
|
4162
4241
|
return [[L], [dL]];
|
|
4163
4242
|
},
|
|
4164
4243
|
[Primitive.LU]([a], [da]) {
|
|
4165
|
-
const [luMatrix, pivots, permutation] = lu$1(a);
|
|
4244
|
+
const [luMatrix, pivots, permutation$1] = lu$1(a);
|
|
4166
4245
|
const [m, n] = a.shape.slice(-2);
|
|
4167
4246
|
const k = Math.min(m, n);
|
|
4168
4247
|
const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
|
|
@@ -4174,7 +4253,7 @@ const jvpRules = {
|
|
|
4174
4253
|
const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
|
|
4175
4254
|
const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
|
|
4176
4255
|
const U = uPadded.add(uEye);
|
|
4177
|
-
const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
|
|
4256
|
+
const P = permutation$1.ref.reshape([...permutation$1.shape, 1]).equal(arange(m)).astype(da.dtype);
|
|
4178
4257
|
const pda = batchMatmulT(P, mT(da));
|
|
4179
4258
|
const la = mT(triangularSolve$1(L.ref, mT(pda), {
|
|
4180
4259
|
lower: true,
|
|
@@ -4186,11 +4265,11 @@ const jvpRules = {
|
|
|
4186
4265
|
return [[
|
|
4187
4266
|
luMatrix,
|
|
4188
4267
|
pivots,
|
|
4189
|
-
permutation
|
|
4268
|
+
permutation$1
|
|
4190
4269
|
], [
|
|
4191
4270
|
lDot.add(uDot),
|
|
4192
4271
|
zerosLike$1(pivots.ref),
|
|
4193
|
-
zerosLike$1(permutation.ref)
|
|
4272
|
+
zerosLike$1(permutation$1.ref)
|
|
4194
4273
|
]];
|
|
4195
4274
|
},
|
|
4196
4275
|
[Primitive.Jit](primals, tangents, { name, jaxpr }) {
|
|
@@ -5236,7 +5315,8 @@ __export(numpy_linalg_exports, {
|
|
|
5236
5315
|
solve: () => solve,
|
|
5237
5316
|
tensordot: () => tensordot,
|
|
5238
5317
|
trace: () => trace,
|
|
5239
|
-
vecdot: () => vecdot
|
|
5318
|
+
vecdot: () => vecdot,
|
|
5319
|
+
vectorNorm: () => vectorNorm
|
|
5240
5320
|
});
|
|
5241
5321
|
function checkSquare(name, a) {
|
|
5242
5322
|
if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
|
|
@@ -5271,8 +5351,8 @@ function cross$1(x1, x2, axis = -1) {
|
|
|
5271
5351
|
function det(a) {
|
|
5272
5352
|
a = fudgeArray(a);
|
|
5273
5353
|
const n = checkSquare("det", a);
|
|
5274
|
-
const [lu$2, pivots, permutation] = lu(a);
|
|
5275
|
-
permutation.dispose();
|
|
5354
|
+
const [lu$2, pivots, permutation$1] = lu(a);
|
|
5355
|
+
permutation$1.dispose();
|
|
5276
5356
|
const parity = pivots.notEqual(arange(n)).astype(int32).sum(-1).mod(2);
|
|
5277
5357
|
const sign$1 = parity.mul(-2).add(1);
|
|
5278
5358
|
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
@@ -5361,8 +5441,8 @@ function matrixPower(a, n) {
|
|
|
5361
5441
|
function slogdet(a) {
|
|
5362
5442
|
a = fudgeArray(a);
|
|
5363
5443
|
const n = checkSquare("slogdet", a);
|
|
5364
|
-
const [lu$2, pivots, permutation] = lu(a);
|
|
5365
|
-
permutation.dispose();
|
|
5444
|
+
const [lu$2, pivots, permutation$1] = lu(a);
|
|
5445
|
+
permutation$1.dispose();
|
|
5366
5446
|
let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
|
|
5367
5447
|
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
5368
5448
|
parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
|
|
@@ -5400,9 +5480,9 @@ function solve(a, b) {
|
|
|
5400
5480
|
n,
|
|
5401
5481
|
m
|
|
5402
5482
|
]);
|
|
5403
|
-
const [lu$2, pivots, permutation] = lu(a);
|
|
5483
|
+
const [lu$2, pivots, permutation$1] = lu(a);
|
|
5404
5484
|
pivots.dispose();
|
|
5405
|
-
const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
|
|
5485
|
+
const P = arange(n).equal(permutation$1.reshape([...permutation$1.shape, 1])).astype(b.dtype);
|
|
5406
5486
|
const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
|
|
5407
5487
|
leftSide: true,
|
|
5408
5488
|
lower: true,
|
|
@@ -5415,6 +5495,23 @@ function solve(a, b) {
|
|
|
5415
5495
|
if (bIs1d) x = squeeze(x, -1);
|
|
5416
5496
|
return x;
|
|
5417
5497
|
}
|
|
5498
|
+
/**
|
|
5499
|
+
* Compute the vector norm of an array.
|
|
5500
|
+
*
|
|
5501
|
+
* @param x - Input array.
|
|
5502
|
+
* @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
|
|
5503
|
+
* @param axis - Axis/axes to reduce over (default: all axes).
|
|
5504
|
+
* @param keepdims - Whether to keep reduced dimensions as size 1.
|
|
5505
|
+
* @returns The norm of `x`, reduced over the given axes.
|
|
5506
|
+
*/
|
|
5507
|
+
function vectorNorm(x, { ord = 2, axis = null, keepdims = false } = {}) {
|
|
5508
|
+
x = fudgeArray(x);
|
|
5509
|
+
const ax = axis ?? null;
|
|
5510
|
+
if (ord === Infinity) return max(absolute(x), ax, { keepdims });
|
|
5511
|
+
else if (ord === -Infinity) return min(absolute(x), ax, { keepdims });
|
|
5512
|
+
else if (ord === 0) return x.notEqual(0).astype(x.dtype).sum(ax, { keepdims });
|
|
5513
|
+
else return power(power(absolute(x), ord).sum(ax, { keepdims }), 1 / ord);
|
|
5514
|
+
}
|
|
5418
5515
|
|
|
5419
5516
|
//#endregion
|
|
5420
5517
|
//#region src/library/numpy/dtype-info.ts
|
|
@@ -5534,6 +5631,13 @@ __export(numpy_exports, {
|
|
|
5534
5631
|
atan2: () => atan2,
|
|
5535
5632
|
atanh: () => arctanh,
|
|
5536
5633
|
average: () => average,
|
|
5634
|
+
bitwiseAnd: () => bitwiseAnd,
|
|
5635
|
+
bitwiseInvert: () => invert,
|
|
5636
|
+
bitwiseLeftShift: () => leftShift,
|
|
5637
|
+
bitwiseNot: () => invert,
|
|
5638
|
+
bitwiseOr: () => bitwiseOr,
|
|
5639
|
+
bitwiseRightShift: () => rightShift,
|
|
5640
|
+
bitwiseXor: () => bitwiseXor,
|
|
5537
5641
|
bool: () => bool,
|
|
5538
5642
|
broadcastArrays: () => broadcastArrays,
|
|
5539
5643
|
broadcastShapes: () => broadcastShapes,
|
|
@@ -5595,12 +5699,14 @@ __export(numpy_exports, {
|
|
|
5595
5699
|
inf: () => inf,
|
|
5596
5700
|
inner: () => inner,
|
|
5597
5701
|
int32: () => int32,
|
|
5702
|
+
invert: () => invert,
|
|
5598
5703
|
isfinite: () => isfinite,
|
|
5599
5704
|
isinf: () => isinf,
|
|
5600
5705
|
isnan: () => isnan,
|
|
5601
5706
|
isneginf: () => isneginf,
|
|
5602
5707
|
isposinf: () => isposinf,
|
|
5603
5708
|
ldexp: () => ldexp,
|
|
5709
|
+
leftShift: () => leftShift,
|
|
5604
5710
|
less: () => less,
|
|
5605
5711
|
lessEqual: () => lessEqual,
|
|
5606
5712
|
linalg: () => numpy_linalg_exports,
|
|
@@ -5649,6 +5755,7 @@ __export(numpy_exports, {
|
|
|
5649
5755
|
remainder: () => remainder,
|
|
5650
5756
|
repeat: () => repeat,
|
|
5651
5757
|
reshape: () => reshape,
|
|
5758
|
+
rightShift: () => rightShift,
|
|
5652
5759
|
rint: () => rint,
|
|
5653
5760
|
round: () => round,
|
|
5654
5761
|
shape: () => shape,
|
|
@@ -5763,6 +5870,44 @@ function logicalXor(x, y) {
|
|
|
5763
5870
|
function logicalNot(x) {
|
|
5764
5871
|
return notEqual(astype(x, DType.Bool), true);
|
|
5765
5872
|
}
|
|
5873
|
+
/** Compute element-wise bitwise AND. */
|
|
5874
|
+
function bitwiseAnd(x, y) {
|
|
5875
|
+
return bitCombine(x, y, "and");
|
|
5876
|
+
}
|
|
5877
|
+
/** Compute element-wise bitwise OR. */
|
|
5878
|
+
function bitwiseOr(x, y) {
|
|
5879
|
+
return bitCombine(x, y, "or");
|
|
5880
|
+
}
|
|
5881
|
+
/** Compute element-wise bitwise XOR. */
|
|
5882
|
+
function bitwiseXor(x, y) {
|
|
5883
|
+
return bitCombine(x, y, "xor");
|
|
5884
|
+
}
|
|
5885
|
+
/** Compute element-wise bitwise NOT (inversion). */
|
|
5886
|
+
function invert(x) {
|
|
5887
|
+
const arr = fudgeArray(x);
|
|
5888
|
+
let allOnes;
|
|
5889
|
+
switch (arr.dtype) {
|
|
5890
|
+
case DType.Bool:
|
|
5891
|
+
allOnes = true;
|
|
5892
|
+
break;
|
|
5893
|
+
case DType.Uint32:
|
|
5894
|
+
allOnes = 4294967295;
|
|
5895
|
+
break;
|
|
5896
|
+
case DType.Int32:
|
|
5897
|
+
allOnes = -1;
|
|
5898
|
+
break;
|
|
5899
|
+
default: throw new TypeError(`invert: unsupported dtype ${arr.dtype}`);
|
|
5900
|
+
}
|
|
5901
|
+
return bitCombine(arr, allOnes, "xor");
|
|
5902
|
+
}
|
|
5903
|
+
/** Compute element-wise left bit shift. */
|
|
5904
|
+
function leftShift(x, y) {
|
|
5905
|
+
return bitShift(x, y, "shl");
|
|
5906
|
+
}
|
|
5907
|
+
/** Compute element-wise right bit shift. */
|
|
5908
|
+
function rightShift(x, y) {
|
|
5909
|
+
return bitShift(x, y, "shr");
|
|
5910
|
+
}
|
|
5766
5911
|
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
5767
5912
|
const where = where$1;
|
|
5768
5913
|
/**
|
|
@@ -7193,7 +7338,7 @@ __export(lax_exports, {
|
|
|
7193
7338
|
stopGradient: () => stopGradient$1,
|
|
7194
7339
|
topK: () => topK
|
|
7195
7340
|
});
|
|
7196
|
-
const JsArray = globalThis.Array;
|
|
7341
|
+
const JsArray$1 = globalThis.Array;
|
|
7197
7342
|
/** Elementwise bitcast an array into a new dtype. */
|
|
7198
7343
|
function bitcastConvertType(x, newDtype) {
|
|
7199
7344
|
return fudgeArray(x).view(newDtype);
|
|
@@ -7380,7 +7525,7 @@ function convTransposePadding(k, s, padding) {
|
|
|
7380
7525
|
} else if (padding === "VALID") {
|
|
7381
7526
|
padLen = k + s - 2 + Math.max(k - s, 0);
|
|
7382
7527
|
pad1 = k - 1;
|
|
7383
|
-
} else if (JsArray.isArray(padding)) {
|
|
7528
|
+
} else if (JsArray$1.isArray(padding)) {
|
|
7384
7529
|
const pads = [k - 1 - padding[0], k - 1 - padding[1]];
|
|
7385
7530
|
pad1 = pads[0];
|
|
7386
7531
|
padLen = pads[0] + pads[1];
|
|
@@ -7899,19 +8044,34 @@ function dotProductAttention(query, key$1, value, opts = {}) {
|
|
|
7899
8044
|
//#region src/library/random.ts
|
|
7900
8045
|
var random_exports = {};
|
|
7901
8046
|
__export(random_exports, {
|
|
8047
|
+
ball: () => ball,
|
|
7902
8048
|
bernoulli: () => bernoulli,
|
|
7903
8049
|
bits: () => bits,
|
|
7904
8050
|
categorical: () => categorical,
|
|
7905
8051
|
cauchy: () => cauchy,
|
|
8052
|
+
choice: () => choice,
|
|
8053
|
+
doubleSidedMaxwell: () => doubleSidedMaxwell,
|
|
7906
8054
|
exponential: () => exponential,
|
|
8055
|
+
geometric: () => geometric,
|
|
7907
8056
|
gumbel: () => gumbel,
|
|
7908
8057
|
key: () => key,
|
|
7909
8058
|
laplace: () => laplace,
|
|
8059
|
+
logistic: () => logistic,
|
|
8060
|
+
lognormal: () => lognormal,
|
|
8061
|
+
maxwell: () => maxwell,
|
|
7910
8062
|
multivariateNormal: () => multivariateNormal,
|
|
7911
8063
|
normal: () => normal,
|
|
8064
|
+
pareto: () => pareto,
|
|
8065
|
+
permutation: () => permutation,
|
|
8066
|
+
rademacher: () => rademacher,
|
|
8067
|
+
randint: () => randint,
|
|
8068
|
+
rayleigh: () => rayleigh,
|
|
7912
8069
|
split: () => split,
|
|
7913
|
-
|
|
8070
|
+
triangular: () => triangular,
|
|
8071
|
+
uniform: () => uniform,
|
|
8072
|
+
weibullMin: () => weibullMin
|
|
7914
8073
|
});
|
|
8074
|
+
const JsArray = globalThis.Array;
|
|
7915
8075
|
function validateKeyShape(key$1, scalar = false) {
|
|
7916
8076
|
if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
|
|
7917
8077
|
if (key$1.shape[key$1.shape.length - 1] !== 2) throw new Error(`Invalid key shape: ${key$1.shape}. Expected last dimension to be 2.`);
|
|
@@ -7964,6 +8124,21 @@ const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxv
|
|
|
7964
8124
|
else return rand.mul(maxval - minval).add(minval);
|
|
7965
8125
|
}, { staticArgnums: [1, 2] });
|
|
7966
8126
|
/**
|
|
8127
|
+
* @function
|
|
8128
|
+
* Sample points uniformly from the Euclidean unit ball in `d` dimensions.
|
|
8129
|
+
*
|
|
8130
|
+
* Only the Euclidean `p=2` case is currently supported.
|
|
8131
|
+
*/
|
|
8132
|
+
const ball = jit$1(function ball$1(key$1, d, { p = 2, shape: shape$1 = [] } = {}) {
|
|
8133
|
+
if (!Number.isInteger(d) || d <= 0) throw new Error(`ball: dimension must be a positive integer, got ${d}`);
|
|
8134
|
+
if (p !== 2) throw new Error("ball: only the Euclidean p=2 case is supported");
|
|
8135
|
+
const [k1, k2] = split(key$1, 2);
|
|
8136
|
+
const z = normal(k1, [...shape$1, d]);
|
|
8137
|
+
const norm = sqrt(z.ref.mul(z.ref).sum(-1, { keepdims: true }));
|
|
8138
|
+
const radius = exp(log(uniform(k2, [...shape$1, 1])).mul(1 / d));
|
|
8139
|
+
return z.div(norm).mul(radius);
|
|
8140
|
+
}, { staticArgnums: [1, 2] });
|
|
8141
|
+
/**
|
|
7967
8142
|
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
7968
8143
|
*
|
|
7969
8144
|
* Returns a random Boolean array with the specified shape. `p` can be an array
|
|
@@ -8025,6 +8200,57 @@ const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
|
|
|
8025
8200
|
return tan(u.sub(.5).mul(Math.PI));
|
|
8026
8201
|
}, { staticArgnums: [1] });
|
|
8027
8202
|
/**
|
|
8203
|
+
* Sample from a population with optional replacement and optional probabilities.
|
|
8204
|
+
*
|
|
8205
|
+
* This implements the common JAX-compatible cases: integer populations and
|
|
8206
|
+
* array populations along `axis`. Probabilities `p`, if provided, are sampled
|
|
8207
|
+
* via `categorical(log(p))`.
|
|
8208
|
+
*/
|
|
8209
|
+
function choice(key$1, a, { shape: shape$1 = [], replace = true, p, axis = 0 } = {}) {
|
|
8210
|
+
let n;
|
|
8211
|
+
let values = null;
|
|
8212
|
+
if (typeof a === "number") {
|
|
8213
|
+
if (!Number.isInteger(a) || a < 0) throw new Error(`choice: a must be a non-negative integer, got ${a}`);
|
|
8214
|
+
n = a;
|
|
8215
|
+
} else {
|
|
8216
|
+
values = fudgeArray(a);
|
|
8217
|
+
axis = checkAxis(axis, values.ndim);
|
|
8218
|
+
n = values.shape[axis];
|
|
8219
|
+
}
|
|
8220
|
+
let indices;
|
|
8221
|
+
if (p !== void 0) indices = categorical(key$1, log(p), {
|
|
8222
|
+
shape: shape$1,
|
|
8223
|
+
replace
|
|
8224
|
+
});
|
|
8225
|
+
else if (replace) indices = randint(key$1, {
|
|
8226
|
+
minval: 0,
|
|
8227
|
+
maxval: n,
|
|
8228
|
+
shape: shape$1
|
|
8229
|
+
});
|
|
8230
|
+
else {
|
|
8231
|
+
const k = shape$1.reduce((acc, x) => acc * x, 1);
|
|
8232
|
+
if (k > n) throw new Error(`Number of samples without replacement (${k}) cannot exceed population size (${n}).`);
|
|
8233
|
+
indices = permutation(key$1, n).slice([0, k]).reshape(shape$1);
|
|
8234
|
+
}
|
|
8235
|
+
if (values === null) return indices;
|
|
8236
|
+
const index = JsArray(axis).fill([]);
|
|
8237
|
+
index.push(indices);
|
|
8238
|
+
return values.slice(...index);
|
|
8239
|
+
}
|
|
8240
|
+
/**
|
|
8241
|
+
* @function
|
|
8242
|
+
* Sample double-sided Maxwell random values with the provided location and scale.
|
|
8243
|
+
*/
|
|
8244
|
+
const doubleSidedMaxwell = jit$1(function doubleSidedMaxwell$1(key$1, loc, scale, shape$1 = []) {
|
|
8245
|
+
loc = fudgeArray(loc);
|
|
8246
|
+
scale = fudgeArray(scale);
|
|
8247
|
+
const [k1, k2] = split(key$1, 2);
|
|
8248
|
+
return rademacher(k1, {
|
|
8249
|
+
shape: shape$1,
|
|
8250
|
+
dtype: DType.Float32
|
|
8251
|
+
}).mul(maxwell(k2, shape$1)).mul(scale).add(loc);
|
|
8252
|
+
}, { staticArgnums: [3] });
|
|
8253
|
+
/**
|
|
8028
8254
|
* @function
|
|
8029
8255
|
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
8030
8256
|
*/
|
|
@@ -8034,6 +8260,14 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
|
8034
8260
|
}, { staticArgnums: [1] });
|
|
8035
8261
|
/**
|
|
8036
8262
|
* @function
|
|
8263
|
+
* Sample geometric random values: the number of trials until first success.
|
|
8264
|
+
*/
|
|
8265
|
+
const geometric = jit$1(function geometric$1(key$1, p, { shape: shape$1 = [], dtype = DType.Int32 } = {}) {
|
|
8266
|
+
p = fudgeArray(p);
|
|
8267
|
+
return floor(log1p(negative(uniform(key$1, shape$1))).div(log1p(negative(p)))).add(1).astype(dtype);
|
|
8268
|
+
}, { staticArgnums: [2] });
|
|
8269
|
+
/**
|
|
8270
|
+
* @function
|
|
8037
8271
|
* Sample from a Gumbel distribution with location 0 and scale 1.
|
|
8038
8272
|
*
|
|
8039
8273
|
* Uses inverse transform sampling: `x = -log(-log(u))` where u ~ Uniform(0, 1).
|
|
@@ -8058,6 +8292,32 @@ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
|
|
|
8058
8292
|
}, { staticArgnums: [1] });
|
|
8059
8293
|
/**
|
|
8060
8294
|
* @function
|
|
8295
|
+
* Sample from a logistic distribution with location 0 and scale 1.
|
|
8296
|
+
*
|
|
8297
|
+
* Uses inverse transform sampling: `x = log(u) - log(1-u)`.
|
|
8298
|
+
*/
|
|
8299
|
+
const logistic = jit$1(function logistic$1(key$1, shape$1 = []) {
|
|
8300
|
+
const u = uniform(key$1, shape$1);
|
|
8301
|
+
return log(u.ref).sub(log1p(negative(u)));
|
|
8302
|
+
}, { staticArgnums: [1] });
|
|
8303
|
+
/**
|
|
8304
|
+
* @function
|
|
8305
|
+
* Sample log-normal random values: `exp(sigma * normal(key, shape))`.
|
|
8306
|
+
*/
|
|
8307
|
+
const lognormal = jit$1(function lognormal$1(key$1, sigma = 1, shape$1 = []) {
|
|
8308
|
+
sigma = fudgeArray(sigma);
|
|
8309
|
+
return exp(normal(key$1, shape$1).mul(sigma));
|
|
8310
|
+
}, { staticArgnums: [2] });
|
|
8311
|
+
/**
|
|
8312
|
+
* @function
|
|
8313
|
+
* Sample Maxwell-distributed random values.
|
|
8314
|
+
*/
|
|
8315
|
+
const maxwell = jit$1(function maxwell$1(key$1, shape$1 = []) {
|
|
8316
|
+
const z = normal(key$1, [...shape$1, 3]);
|
|
8317
|
+
return sqrt(z.ref.mul(z).sum(-1));
|
|
8318
|
+
}, { staticArgnums: [1] });
|
|
8319
|
+
/**
|
|
8320
|
+
* @function
|
|
8061
8321
|
* Sample multivariate normal random values with given mean and covariance.
|
|
8062
8322
|
*
|
|
8063
8323
|
* The values are returned with the given shape, along with the final dimension
|
|
@@ -8098,6 +8358,97 @@ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
|
|
|
8098
8358
|
const theta = u2.mul(2 * Math.PI);
|
|
8099
8359
|
return radius.mul(cos(theta));
|
|
8100
8360
|
}, { staticArgnums: [1] });
|
|
8361
|
+
/**
|
|
8362
|
+
* @function
|
|
8363
|
+
* Sample from a Pareto distribution with shape parameter `b` and support [1, ∞).
|
|
8364
|
+
*/
|
|
8365
|
+
const pareto = jit$1(function pareto$1(key$1, b, shape$1 = []) {
|
|
8366
|
+
b = fudgeArray(b);
|
|
8367
|
+
return exp(exponential(key$1, shape$1).div(b));
|
|
8368
|
+
}, { staticArgnums: [2] });
|
|
8369
|
+
/**
|
|
8370
|
+
* Return a random permutation of an integer range or of an array along `axis`.
|
|
8371
|
+
*/
|
|
8372
|
+
function permutation(key$1, x, axis = 0) {
|
|
8373
|
+
if (typeof x === "number") {
|
|
8374
|
+
if (!Number.isInteger(x) || x < 0) throw new Error(`permutation: x must be a non-negative integer, got ${x}`);
|
|
8375
|
+
return argsort(uniform(key$1, [x])).astype(DType.Int32);
|
|
8376
|
+
}
|
|
8377
|
+
const arr = fudgeArray(x);
|
|
8378
|
+
axis = checkAxis(axis, arr.ndim);
|
|
8379
|
+
const perm = permutation(key$1, arr.shape[axis]);
|
|
8380
|
+
const index = JsArray(axis).fill([]);
|
|
8381
|
+
index.push(perm);
|
|
8382
|
+
return arr.slice(...index);
|
|
8383
|
+
}
|
|
8384
|
+
/**
|
|
8385
|
+
* @function
|
|
8386
|
+
* Sample Rademacher random values, uniformly from {-1, 1}.
|
|
8387
|
+
*/
|
|
8388
|
+
const rademacher = jit$1(function rademacher$1(key$1, { shape: shape$1 = [], dtype = DType.Int32 } = {}) {
|
|
8389
|
+
if (dtype === DType.Uint32 || dtype === DType.Bool) throw new Error(`rademacher: unsupported dtype ${dtype}`);
|
|
8390
|
+
const one = array(1, {
|
|
8391
|
+
dtype,
|
|
8392
|
+
device: key$1.device
|
|
8393
|
+
});
|
|
8394
|
+
const minusOne = array(-1, {
|
|
8395
|
+
dtype,
|
|
8396
|
+
device: key$1.device
|
|
8397
|
+
});
|
|
8398
|
+
return where(bernoulli(key$1, .5, shape$1), one, minusOne);
|
|
8399
|
+
}, { staticArgnums: [1] });
|
|
8400
|
+
/**
* @function
* Sample integer values uniformly from `[minval, maxval)`.
*
* This uses modulo reduction of uniform 32-bit random bits. For ranges that do
* not divide 2^32, this introduces a very small modulo bias.
*
* @param key$1 - PRNG key array.
* @param options - Static options (jit `staticArgnums: [1]`).
* @param options.minval - Inclusive lower bound (integer).
* @param options.maxval - Exclusive upper bound (integer, > minval).
* @param options.shape - Output shape, defaults to `[]` (scalar).
* @param options.dtype - `int32` (default) or `uint32`.
* @throws {Error} If the bounds are not integers, are empty, or do not fit
*   in the requested dtype (int32: `[-2^31, 2^31)`, uint32: `[0, 2^32)`).
*/
const randint = jit$1(function randint$1(key$1, { minval, maxval, shape: shape$1 = [], dtype = DType.Int32 }) {
	if (!Number.isInteger(minval) || !Number.isInteger(maxval)) throw new Error("randint: minval and maxval must be integers");
	if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
	if (dtype !== DType.Int32 && dtype !== DType.Uint32) throw new Error(`randint: dtype must be int32 or uint32, got ${dtype}`);
	if (dtype === DType.Uint32 && minval < 0) throw new Error("randint: uint32 dtype requires minval >= 0");
	// minval must be representable; maxval is exclusive so it may sit one past
	// the largest representable value. Anything wider would silently wrap.
	const [lowest, highestExclusive] = dtype === DType.Int32 ? [-(2 ** 31), 2 ** 31] : [0, 2 ** 32];
	if (minval < lowest || maxval > highestExclusive) throw new Error(`randint: range [${minval}, ${maxval}) does not fit in ${dtype}`);
	const range$1 = maxval - minval;
	const raw = bits(key$1, shape$1);
	// A full 2^32 range is not representable as a 32-bit modulus (it would
	// degenerate to a modulo by zero); every bit pattern is already in range.
	const reduced = range$1 === 2 ** 32 ? raw : raw.mod(range$1);
	return reduced.astype(dtype).add(minval);
}, { staticArgnums: [1] });
|
|
8415
|
+
/**
* @function
* Sample Rayleigh random values with the provided scale parameter.
*
* Computes `scale * sqrt(2 * E)` where `E` is a standard exponential sample.
*
* @param key$1 - PRNG key array.
* @param scale - Scale parameter (number or array), defaults to `1`.
* @param shape$1 - Output shape (static, jit `staticArgnums: [2]`), defaults to `[]`.
*/
const rayleigh = jit$1(function rayleigh$1(key$1, scale = 1, shape$1 = []) {
	const sigma = fudgeArray(scale);
	const twiceExponential = exponential(key$1, shape$1).mul(2);
	return sqrt(twiceExponential).mul(sigma);
}, { staticArgnums: [2] });
|
|
8423
|
+
/**
* @function
* Sample triangular random values on `[left, right]` with the given mode.
*
* Uses inverse-CDF sampling from one uniform draw `u`. When `u` falls below
* the CDF value at the mode, the sample comes from the rising side
* (`left + sqrt(u * width * (mode - left))`). Otherwise it comes from the
* falling side (`right - sqrt((1 - u) * width * (right - mode))`).
* The bounds are not validated here.
*
* @param key$1 - PRNG key array.
* @param left - Lower limit (number or array).
* @param mode - Peak location (number or array).
* @param right - Upper limit (number or array).
* @param shape$1 - Output shape (static, jit `staticArgnums: [4]`), defaults to `[]`.
*/
const triangular = jit$1(function triangular$1(key$1, left, mode, right, shape$1 = []) {
	const lo = fudgeArray(left);
	const peak = fudgeArray(mode);
	const hi = fudgeArray(right);
	const u = uniform(key$1, shape$1);
	// Each array below is consumed exactly once; `.ref` keeps extra uses alive.
	const width = hi.ref.sub(lo.ref);
	const leftSpan = peak.ref.sub(lo.ref);
	const rightSpan = hi.ref.sub(peak);
	// CDF evaluated at the mode.
	const modeCdf = leftSpan.ref.div(width.ref);
	const onRisingSide = u.ref.less(modeCdf);
	const rising = lo.add(sqrt(u.ref.mul(width.ref).mul(leftSpan)));
	const falling = hi.sub(sqrt(negative(u).add(1).mul(width).mul(rightSpan)));
	return where(onRisingSide, rising, falling);
}, { staticArgnums: [4] });
|
|
8441
|
+
/**
* @function
* Sample Weibull minimum random values.
*
* Uses `scale * exponential(key) ** (1 / concentration)`, with the power
* evaluated in log space as `exp(log(E) / concentration)`.
*
* @param key$1 - PRNG key array.
* @param scale - Scale parameter (number or array).
* @param concentration - Concentration (shape) parameter (number or array).
* @param shape$1 - Output shape (static, jit `staticArgnums: [3]`), defaults to `[]`.
*/
const weibullMin = jit$1(function weibullMin$1(key$1, scale, concentration, shape$1 = []) {
	const lambda = fudgeArray(scale);
	const k = fudgeArray(concentration);
	const logE = log(exponential(key$1, shape$1));
	const powered = exp(logE.div(k));
	return lambda.mul(powered);
}, { staticArgnums: [3] });
|
|
8101
8452
|
|
|
8102
8453
|
//#endregion
|
|
8103
8454
|
//#region src/library/scipy-special.ts
|
|
@@ -8291,4 +8642,4 @@ async function devicePut(x, device) {
|
|
|
8291
8642
|
}
|
|
8292
8643
|
|
|
8293
8644
|
//#endregion
|
|
8294
|
-
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, profiler, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
8645
|
+
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, getWebGPUDevice, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, profiler, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-
|
|
1
|
+
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-DI-V78Rk.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgl/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -458,6 +458,12 @@ function generateExpression(exp, args, inputDtypes) {
|
|
|
458
458
|
else source = `min(${a}, ${b})`;
|
|
459
459
|
else if (op === AluOp.Max) if (dtype === DType.Bool) source = `(${a} || ${b})`;
|
|
460
460
|
else source = `max(${a}, ${b})`;
|
|
461
|
+
else if (op === AluOp.BitCombine) {
|
|
462
|
+
let infix = arg === "and" ? "&" : arg === "or" ? "|" : "^";
|
|
463
|
+
if (dtype === DType.Bool) infix = infix + infix;
|
|
464
|
+
source = `(${a} ${infix} ${b})`;
|
|
465
|
+
} else if (op === AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
|
|
466
|
+
else source = `(${a} >> ${b})`;
|
|
461
467
|
} else if (AluGroup.Compare.has(op)) {
|
|
462
468
|
const a = gen(src[0]);
|
|
463
469
|
const b = gen(src[1]);
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-
|
|
1
|
+
const require_backend = require('./backend-x-6vqzIM.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgl/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -458,6 +458,12 @@ function generateExpression(exp, args, inputDtypes) {
|
|
|
458
458
|
else source = `min(${a}, ${b})`;
|
|
459
459
|
else if (op === require_backend.AluOp.Max) if (dtype === require_backend.DType.Bool) source = `(${a} || ${b})`;
|
|
460
460
|
else source = `max(${a}, ${b})`;
|
|
461
|
+
else if (op === require_backend.AluOp.BitCombine) {
|
|
462
|
+
let infix = arg === "and" ? "&" : arg === "or" ? "|" : "^";
|
|
463
|
+
if (dtype === require_backend.DType.Bool) infix = infix + infix;
|
|
464
|
+
source = `(${a} ${infix} ${b})`;
|
|
465
|
+
} else if (op === require_backend.AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
|
|
466
|
+
else source = `(${a} >> ${b})`;
|
|
461
467
|
} else if (require_backend.AluGroup.Compare.has(op)) {
|
|
462
468
|
const a = gen(src[0]);
|
|
463
469
|
const b = gen(src[1]);
|