@jax-js/jax 0.1.10 → 0.1.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +4 -1
- package/dist/{backend-Ctqs8la1.js → backend-DZvR7mZV.js} +730 -21
- package/dist/{backend-DMauYnfl.cjs → backend-DlYlOYqN.cjs} +735 -20
- package/dist/index.cjs +140 -3
- package/dist/index.d.cts +66 -3
- package/dist/index.d.ts +66 -3
- package/dist/index.js +140 -4
- package/dist/{webgl-CvQ1QBX1.js → webgl-D8-14NzA.js} +7 -1
- package/dist/{webgl-kvVt7-T7.cjs → webgl-Ovaaa-Qx.cjs} +7 -1
- package/dist/{webgpu-v_W_-oKw.js → webgpu-Dg8FpYrH.js} +6 -1
- package/dist/{webgpu-DMSx7a6M.cjs → webgpu-uU9nnttc.cjs} +6 -1
- package/package.json +1 -1
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DZvR7mZV.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -333,6 +333,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
333
333
|
Primitive$1["Mod"] = "mod";
|
|
334
334
|
Primitive$1["Min"] = "min";
|
|
335
335
|
Primitive$1["Max"] = "max";
|
|
336
|
+
Primitive$1["BitCombine"] = "bit_combine";
|
|
337
|
+
Primitive$1["BitShift"] = "bit_shift";
|
|
336
338
|
Primitive$1["Neg"] = "neg";
|
|
337
339
|
Primitive$1["Reciprocal"] = "reciprocal";
|
|
338
340
|
Primitive$1["Floor"] = "floor";
|
|
@@ -406,6 +408,12 @@ function min$1(x, y) {
|
|
|
406
408
|
/** Element-wise maximum of two operands, dispatched through the `Max` primitive. */
function max$1(lhs, rhs) {
  const operands = [lhs, rhs];
  return bind1(Primitive.Max, operands);
}
/**
 * Combine two operands bit by bit via the `BitCombine` primitive.
 *
 * @param op - Combination kind; callers in this file pass "and", "or" or "xor".
 */
function bitCombine(lhs, rhs, op) {
  const params = { op };
  return bind1(Primitive.BitCombine, [lhs, rhs], params);
}
/**
 * Shift the bits of `lhs` by `rhs` via the `BitShift` primitive.
 *
 * @param op - Shift direction; callers in this file pass "shl" or "shr".
 */
function bitShift(lhs, rhs, op) {
  const params = { op };
  return bind1(Primitive.BitShift, [lhs, rhs], params);
}
/** Element-wise negation, dispatched through the `Neg` primitive. */
function neg(operand) {
  const operands = [operand];
  return bind1(Primitive.Neg, operands);
}
|
|
@@ -1620,6 +1628,16 @@ const abstractEvalRules = {
|
|
|
1620
1628
|
[Primitive.Mod]: binopAbstractEval,
|
|
1621
1629
|
[Primitive.Min]: binopAbstractEval,
|
|
1622
1630
|
[Primitive.Max]: binopAbstractEval,
|
|
1631
|
+
[Primitive.BitCombine]([x, y]) {
|
|
1632
|
+
const aval = promoteAvals(x, y);
|
|
1633
|
+
if (isFloatDtype(aval.dtype)) throw new TypeError(`bitwise operations require integer or boolean inputs, got ${aval.dtype}`);
|
|
1634
|
+
return [aval];
|
|
1635
|
+
},
|
|
1636
|
+
[Primitive.BitShift]([x, y]) {
|
|
1637
|
+
const shape$1 = generalBroadcast(x.shape, y.shape);
|
|
1638
|
+
if (isFloatDtype(x.dtype) || isFloatDtype(y.dtype) || x.dtype === DType.Bool || y.dtype === DType.Bool) throw new TypeError(`bit shift operations require integer inputs, got ${x} and ${y}`);
|
|
1639
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
1640
|
+
},
|
|
1623
1641
|
[Primitive.Neg]: vectorizedUnopAbstractEval,
|
|
1624
1642
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
1625
1643
|
[Primitive.Floor]: vectorizedUnopAbstractEval,
|
|
@@ -2155,6 +2173,8 @@ const jitRules = {
|
|
|
2155
2173
|
[Primitive.Mod]: broadcastedJit(([a, b]) => AluExp.mod(a, b)),
|
|
2156
2174
|
[Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
|
|
2157
2175
|
[Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
|
|
2176
|
+
[Primitive.BitCombine]: broadcastedJit(([a, b], { op }) => AluExp.bitCombine(a, b, op)),
|
|
2177
|
+
[Primitive.BitShift]: broadcastedJit(([a, b], { op }) => AluExp.bitShift(a, b, op)),
|
|
2158
2178
|
[Primitive.Neg]: unopJit((a) => AluExp.sub(AluExp.const(a.dtype, 0), a)),
|
|
2159
2179
|
[Primitive.Reciprocal]: unopJit(AluExp.reciprocal),
|
|
2160
2180
|
[Primitive.Floor]: unopJit(AluExp.floor),
|
|
@@ -2347,7 +2367,9 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2347
2367
|
case Primitive.Idiv:
|
|
2348
2368
|
case Primitive.Mod:
|
|
2349
2369
|
case Primitive.Min:
|
|
2350
|
-
case Primitive.Max:
|
|
2370
|
+
case Primitive.Max:
|
|
2371
|
+
case Primitive.BitCombine:
|
|
2372
|
+
case Primitive.BitShift: {
|
|
2351
2373
|
const otherInput = nextEqn.inputs.find((v) => v !== outVar);
|
|
2352
2374
|
if (otherInput instanceof Lit || deepEqual(generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
|
|
2353
2375
|
head = usages[0];
|
|
@@ -2986,6 +3008,42 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2986
3008
|
return dtypedArray(this.dtype, buf);
|
|
2987
3009
|
}
|
|
2988
3010
|
/**
|
|
3011
|
+
* Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
|
|
3012
|
+
*
|
|
3013
|
+
* Only available on the WebGPU backend. The array's memory is still managed
|
|
3014
|
+
* by jax-js, and it will be freed when the buffer is no longer in use. You
|
|
3015
|
+
* _should not_ mutate the buffer's contents.
|
|
3016
|
+
*
|
|
3017
|
+
* Note that the GPU buffer may be slightly larger than the array's size; it
|
|
3018
|
+
* will always be aligned to 4 bytes.
|
|
3019
|
+
*/
|
|
3020
|
+
async gpuBuffer() {
|
|
3021
|
+
if (this.device !== "webgpu") throw new Error(`gpuBuffer() is only available on WebGPU backend`);
|
|
3022
|
+
this.#realize();
|
|
3023
|
+
const pending = this.#pending;
|
|
3024
|
+
if (pending) {
|
|
3025
|
+
await Promise.all(pending.map((p) => p.prepare()));
|
|
3026
|
+
for (const p of pending) p.submit();
|
|
3027
|
+
}
|
|
3028
|
+
const backend = this.#backend;
|
|
3029
|
+
const { buffer } = backend.buffers.get(this.#source);
|
|
3030
|
+
this.dispose();
|
|
3031
|
+
return buffer;
|
|
3032
|
+
}
|
|
3033
|
+
/** Synchronous version of `Array.gpuBuffer()`. */
|
|
3034
|
+
gpuBufferSync() {
|
|
3035
|
+
if (this.device !== "webgpu") throw new Error(`gpuBufferSync() is only available on WebGPU backend`);
|
|
3036
|
+
this.#realize();
|
|
3037
|
+
for (const p of this.#pending) {
|
|
3038
|
+
p.prepareSync();
|
|
3039
|
+
p.submit();
|
|
3040
|
+
}
|
|
3041
|
+
const backend = this.#backend;
|
|
3042
|
+
const { buffer } = backend.buffers.get(this.#source);
|
|
3043
|
+
this.dispose();
|
|
3044
|
+
return buffer;
|
|
3045
|
+
}
|
|
3046
|
+
/**
|
|
2989
3047
|
* Convert this array into a JavaScript object.
|
|
2990
3048
|
*
|
|
2991
3049
|
* This is a blocking operation that will compile all of the shaders and wait
|
|
@@ -3032,6 +3090,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3032
3090
|
[Primitive.Max]([x, y]) {
|
|
3033
3091
|
return [x.#binary(AluOp.Max, y)];
|
|
3034
3092
|
},
|
|
3093
|
+
[Primitive.BitCombine]([x, y], { op }) {
|
|
3094
|
+
const custom = (src) => AluExp.bitCombine(src[0], src[1], op);
|
|
3095
|
+
return [Array$1.#naryCustom("bit_combine", custom, [x, y])];
|
|
3096
|
+
},
|
|
3097
|
+
[Primitive.BitShift]([x, y], { op }) {
|
|
3098
|
+
const custom = (src) => AluExp.bitShift(src[0], src[1], op);
|
|
3099
|
+
return [Array$1.#naryCustom("bit_shift", custom, [x, y], { dtypeOverride: [void 0, y.dtype] })];
|
|
3100
|
+
},
|
|
3035
3101
|
[Primitive.Neg]([x]) {
|
|
3036
3102
|
return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
|
|
3037
3103
|
},
|
|
@@ -3723,6 +3789,8 @@ const vmapRules = {
|
|
|
3723
3789
|
[Primitive.Mod]: broadcastBatcher(Primitive.Mod),
|
|
3724
3790
|
[Primitive.Min]: broadcastBatcher(Primitive.Min),
|
|
3725
3791
|
[Primitive.Max]: broadcastBatcher(Primitive.Max),
|
|
3792
|
+
[Primitive.BitCombine]: broadcastBatcher(Primitive.BitCombine),
|
|
3793
|
+
[Primitive.BitShift]: broadcastBatcher(Primitive.BitShift),
|
|
3726
3794
|
[Primitive.Neg]: unopBatcher(Primitive.Neg),
|
|
3727
3795
|
[Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
|
|
3728
3796
|
[Primitive.Floor]: unopBatcher(Primitive.Floor),
|
|
@@ -4045,6 +4113,8 @@ const jvpRules = {
|
|
|
4045
4113
|
[Primitive.Max]([x, y], [dx, dy]) {
|
|
4046
4114
|
return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
|
|
4047
4115
|
},
|
|
4116
|
+
[Primitive.BitCombine]: zeroTangentsJvp(Primitive.BitCombine),
|
|
4117
|
+
[Primitive.BitShift]: zeroTangentsJvp(Primitive.BitShift),
|
|
4048
4118
|
[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
|
|
4049
4119
|
[Primitive.Reciprocal]([x], [dx]) {
|
|
4050
4120
|
const xRecip = reciprocal$1(x.ref);
|
|
@@ -5236,7 +5306,8 @@ __export(numpy_linalg_exports, {
|
|
|
5236
5306
|
solve: () => solve,
|
|
5237
5307
|
tensordot: () => tensordot,
|
|
5238
5308
|
trace: () => trace,
|
|
5239
|
-
vecdot: () => vecdot
|
|
5309
|
+
vecdot: () => vecdot,
|
|
5310
|
+
vectorNorm: () => vectorNorm
|
|
5240
5311
|
});
|
|
5241
5312
|
function checkSquare(name, a) {
|
|
5242
5313
|
if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
|
|
@@ -5415,6 +5486,23 @@ function solve(a, b) {
|
|
|
5415
5486
|
if (bIs1d) x = squeeze(x, -1);
|
|
5416
5487
|
return x;
|
|
5417
5488
|
}
|
|
5489
|
+
/**
 * Compute the vector norm of an array.
 *
 * @param x - Input array (anything accepted by `fudgeArray`).
 * @param ord - Order of the norm (default 2):
 *   - `Infinity` gives max(|x|);
 *   - `-Infinity` gives min(|x|);
 *   - `0` counts nonzero entries (cast back to the input dtype);
 *   - any other real `p` gives `sum(|x|^p)^(1/p)`.
 * @param axis - Axis/axes to reduce over (default: all axes).
 * @param keepdims - Whether to keep reduced dimensions as size 1.
 * @returns The norm of `x`, reduced over the given axes.
 * @throws {Error} If `ord` is not a number or is NaN.
 */
function vectorNorm(x, { ord = 2, axis = null, keepdims = false } = {}) {
  // Reject invalid orders up front. Otherwise they would silently flow into
  // `power` and produce garbage (e.g. NaN) instead of a clear error.
  if (typeof ord !== "number" || Number.isNaN(ord)) {
    throw new Error(`vectorNorm: ord must be a number, got ${String(ord)}`);
  }
  const arr = fudgeArray(x);
  // Capture the dtype before any operation consumes `arr`: array ops take
  // ownership of their inputs (cf. the `.ref` pattern used elsewhere).
  const dtype = arr.dtype;
  if (ord === Infinity) return max(absolute(arr), axis, { keepdims });
  if (ord === -Infinity) return min(absolute(arr), axis, { keepdims });
  if (ord === 0) return arr.notEqual(0).astype(dtype).sum(axis, { keepdims });
  return power(power(absolute(arr), ord).sum(axis, { keepdims }), 1 / ord);
}
|
|
5418
5506
|
|
|
5419
5507
|
//#endregion
|
|
5420
5508
|
//#region src/library/numpy/dtype-info.ts
|
|
@@ -5534,6 +5622,13 @@ __export(numpy_exports, {
|
|
|
5534
5622
|
atan2: () => atan2,
|
|
5535
5623
|
atanh: () => arctanh,
|
|
5536
5624
|
average: () => average,
|
|
5625
|
+
bitwiseAnd: () => bitwiseAnd,
|
|
5626
|
+
bitwiseInvert: () => invert,
|
|
5627
|
+
bitwiseLeftShift: () => leftShift,
|
|
5628
|
+
bitwiseNot: () => invert,
|
|
5629
|
+
bitwiseOr: () => bitwiseOr,
|
|
5630
|
+
bitwiseRightShift: () => rightShift,
|
|
5631
|
+
bitwiseXor: () => bitwiseXor,
|
|
5537
5632
|
bool: () => bool,
|
|
5538
5633
|
broadcastArrays: () => broadcastArrays,
|
|
5539
5634
|
broadcastShapes: () => broadcastShapes,
|
|
@@ -5595,12 +5690,14 @@ __export(numpy_exports, {
|
|
|
5595
5690
|
inf: () => inf,
|
|
5596
5691
|
inner: () => inner,
|
|
5597
5692
|
int32: () => int32,
|
|
5693
|
+
invert: () => invert,
|
|
5598
5694
|
isfinite: () => isfinite,
|
|
5599
5695
|
isinf: () => isinf,
|
|
5600
5696
|
isnan: () => isnan,
|
|
5601
5697
|
isneginf: () => isneginf,
|
|
5602
5698
|
isposinf: () => isposinf,
|
|
5603
5699
|
ldexp: () => ldexp,
|
|
5700
|
+
leftShift: () => leftShift,
|
|
5604
5701
|
less: () => less,
|
|
5605
5702
|
lessEqual: () => lessEqual,
|
|
5606
5703
|
linalg: () => numpy_linalg_exports,
|
|
@@ -5649,6 +5746,7 @@ __export(numpy_exports, {
|
|
|
5649
5746
|
remainder: () => remainder,
|
|
5650
5747
|
repeat: () => repeat,
|
|
5651
5748
|
reshape: () => reshape,
|
|
5749
|
+
rightShift: () => rightShift,
|
|
5652
5750
|
rint: () => rint,
|
|
5653
5751
|
round: () => round,
|
|
5654
5752
|
shape: () => shape,
|
|
@@ -5763,6 +5861,44 @@ function logicalXor(x, y) {
|
|
|
5763
5861
|
/** Element-wise logical NOT: true wherever `x`, cast to bool, is false. */
function logicalNot(x) {
  const asBool = astype(x, DType.Bool);
  return notEqual(asBool, true);
}
/** Compute element-wise bitwise AND. */
function bitwiseAnd(a, b) {
  const kind = "and";
  return bitCombine(a, b, kind);
}
/** Compute element-wise bitwise OR. */
function bitwiseOr(a, b) {
  const kind = "or";
  return bitCombine(a, b, kind);
}
/** Compute element-wise bitwise XOR. */
function bitwiseXor(a, b) {
  const kind = "xor";
  return bitCombine(a, b, kind);
}
/**
 * Compute element-wise bitwise NOT (inversion).
 *
 * Implemented as XOR with an all-ones mask for the input dtype: `true` for
 * bool, 0xFFFFFFFF for uint32 and -1 for int32. Any other dtype throws a
 * TypeError.
 */
function invert(x) {
  const arr = fudgeArray(x);
  const { dtype } = arr;
  let mask;
  if (dtype === DType.Bool) mask = true;
  else if (dtype === DType.Uint32) mask = 0xffffffff;
  else if (dtype === DType.Int32) mask = -1;
  else throw new TypeError(`invert: unsupported dtype ${arr.dtype}`);
  return bitCombine(arr, mask, "xor");
}
/** Compute element-wise left bit shift of `a` by `b`. */
function leftShift(a, b) {
  const direction = "shl";
  return bitShift(a, b, direction);
}
/** Compute element-wise right bit shift of `a` by `b`. */
function rightShift(a, b) {
  const direction = "shr";
  return bitShift(a, b, direction);
}
/** @function Element-wise select: yields `x` where `cond` holds, otherwise `y`. */
const where = where$1;
|
|
5768
5904
|
/**
|
|
@@ -8291,4 +8427,4 @@ async function devicePut(x, device) {
|
|
|
8291
8427
|
}
|
|
8292
8428
|
|
|
8293
8429
|
//#endregion
|
|
8294
|
-
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, profiler, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
8430
|
+
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, getWebGPUDevice, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, profiler, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-
|
|
1
|
+
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-DZvR7mZV.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgl/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -458,6 +458,12 @@ function generateExpression(exp, args, inputDtypes) {
|
|
|
458
458
|
else source = `min(${a}, ${b})`;
|
|
459
459
|
else if (op === AluOp.Max) if (dtype === DType.Bool) source = `(${a} || ${b})`;
|
|
460
460
|
else source = `max(${a}, ${b})`;
|
|
461
|
+
else if (op === AluOp.BitCombine) {
|
|
462
|
+
let infix = arg === "and" ? "&" : arg === "or" ? "|" : "^";
|
|
463
|
+
if (dtype === DType.Bool) infix = infix + infix;
|
|
464
|
+
source = `(${a} ${infix} ${b})`;
|
|
465
|
+
} else if (op === AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
|
|
466
|
+
else source = `(${a} >> ${b})`;
|
|
461
467
|
} else if (AluGroup.Compare.has(op)) {
|
|
462
468
|
const a = gen(src[0]);
|
|
463
469
|
const b = gen(src[1]);
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-
|
|
1
|
+
const require_backend = require('./backend-DlYlOYqN.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgl/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -458,6 +458,12 @@ function generateExpression(exp, args, inputDtypes) {
|
|
|
458
458
|
else source = `min(${a}, ${b})`;
|
|
459
459
|
else if (op === require_backend.AluOp.Max) if (dtype === require_backend.DType.Bool) source = `(${a} || ${b})`;
|
|
460
460
|
else source = `max(${a}, ${b})`;
|
|
461
|
+
else if (op === require_backend.AluOp.BitCombine) {
|
|
462
|
+
let infix = arg === "and" ? "&" : arg === "or" ? "|" : "^";
|
|
463
|
+
if (dtype === require_backend.DType.Bool) infix = infix + infix;
|
|
464
|
+
source = `(${a} ${infix} ${b})`;
|
|
465
|
+
} else if (op === require_backend.AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
|
|
466
|
+
else source = `(${a} >> ${b})`;
|
|
461
467
|
} else if (require_backend.AluGroup.Compare.has(op)) {
|
|
462
468
|
const a = gen(src[0]);
|
|
463
469
|
const b = gen(src[1]);
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-DZvR7mZV.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -1100,6 +1100,11 @@ function pipelineSource(device, kernel) {
|
|
|
1100
1100
|
else source = `min(${strip1(a)}, ${strip1(b)})`;
|
|
1101
1101
|
else if (op === AluOp.Max) if (dtype === DType.Bool) source = `(${a} || ${b})`;
|
|
1102
1102
|
else source = `max(${strip1(a)}, ${strip1(b)})`;
|
|
1103
|
+
else if (op === AluOp.BitCombine) if (arg === "and") source = `(${a} & ${b})`;
|
|
1104
|
+
else if (arg === "or") source = `(${a} | ${b})`;
|
|
1105
|
+
else source = dtype === DType.Bool ? `(${a} != ${b})` : `(${a} ^ ${b})`;
|
|
1106
|
+
else if (op === AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
|
|
1107
|
+
else source = `(${a} >> ${b})`;
|
|
1103
1108
|
else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
|
|
1104
1109
|
else if (op === AluOp.Cmpne) if (isFloatDtype(src[0].dtype)) {
|
|
1105
1110
|
const x = isGensym(a) ? a : gensym();
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-
|
|
1
|
+
const require_backend = require('./backend-DlYlOYqN.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -1100,6 +1100,11 @@ function pipelineSource(device, kernel) {
|
|
|
1100
1100
|
else source = `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
|
|
1101
1101
|
else if (op === require_backend.AluOp.Max) if (dtype === require_backend.DType.Bool) source = `(${a} || ${b})`;
|
|
1102
1102
|
else source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
|
|
1103
|
+
else if (op === require_backend.AluOp.BitCombine) if (arg === "and") source = `(${a} & ${b})`;
|
|
1104
|
+
else if (arg === "or") source = `(${a} | ${b})`;
|
|
1105
|
+
else source = dtype === require_backend.DType.Bool ? `(${a} != ${b})` : `(${a} ^ ${b})`;
|
|
1106
|
+
else if (op === require_backend.AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
|
|
1107
|
+
else source = `(${a} >> ${b})`;
|
|
1103
1108
|
else if (op === require_backend.AluOp.Cmplt) source = `(${a} < ${b})`;
|
|
1104
1109
|
else if (op === require_backend.AluOp.Cmpne) if (require_backend.isFloatDtype(src[0].dtype)) {
|
|
1105
1110
|
const x = isGensym(a) ? a : gensym();
|