@jax-js/jax 0.1.0 → 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/{backend-DwIAd0AG.js → backend-BqymqzuU.js} +194 -73
- package/dist/{backend-FtkbO6pI.cjs → backend-DeVfWEFS.cjs} +194 -73
- package/dist/index.cjs +2725 -2206
- package/dist/index.d.cts +964 -844
- package/dist/index.d.ts +964 -844
- package/dist/index.js +2698 -2179
- package/dist/{webgpu-LGi2A3mS.js → webgpu-BGuG58KZ.js} +20 -13
- package/dist/{webgpu-BE7zA_01.cjs → webgpu-CcGP160M.cjs} +20 -13
- package/package.json +1 -1
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, mapSetUnion, strip1, tuneWebgpu } from "./backend-
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, mapSetUnion, strip1, tuneWebgpu } from "./backend-BqymqzuU.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -323,7 +323,7 @@ function dtypeToWgsl(dtype, storage = false) {
|
|
|
323
323
|
case DType.Uint32: return "u32";
|
|
324
324
|
case DType.Float32: return "f32";
|
|
325
325
|
case DType.Float16: return "f16";
|
|
326
|
-
default: throw new Error(`Unsupported dtype: ${dtype}`);
|
|
326
|
+
default: throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
|
|
327
327
|
}
|
|
328
328
|
}
|
|
329
329
|
function constToWgsl(dtype, value) {
|
|
@@ -397,6 +397,7 @@ function pipelineSource(device, kernel) {
|
|
|
397
397
|
}
|
|
398
398
|
let gensymCount = 0;
|
|
399
399
|
const gensym = () => `alu${gensymCount++}`;
|
|
400
|
+
const isGensym = (text) => text.match(/^alu[0-9]+$/);
|
|
400
401
|
for (let i = 0; i < args.length; i++) if (!usedArgs[i]) emit(`_ = &${args[i]};`);
|
|
401
402
|
const references = /* @__PURE__ */ new Map();
|
|
402
403
|
const seen = /* @__PURE__ */ new Set();
|
|
@@ -425,24 +426,30 @@ function pipelineSource(device, kernel) {
|
|
|
425
426
|
else if (op === AluOp.Min) source = `min(${strip1(a)}, ${strip1(b)})`;
|
|
426
427
|
else if (op === AluOp.Max) source = `max(${strip1(a)}, ${strip1(b)})`;
|
|
427
428
|
else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
|
|
428
|
-
else if (op === AluOp.Cmpne)
|
|
429
|
+
else if (op === AluOp.Cmpne) if (isFloatDtype(src[0].dtype)) {
|
|
430
|
+
const x = isGensym(a) ? a : gensym();
|
|
431
|
+
if (x !== a) emit(`let ${x} = ${a};`);
|
|
432
|
+
source = `(${x} != ${b} || min(${x}, ${dtypeToWgsl(src[0].dtype)}(inf())) != ${x})`;
|
|
433
|
+
} else source = `(${a} != ${b})`;
|
|
429
434
|
} else if (AluGroup.Unary.has(op)) if (op === AluOp.Reciprocal && src[0].op === AluOp.Sqrt) {
|
|
430
435
|
const a = gen(src[0].src[0]);
|
|
431
436
|
source = `inverseSqrt(${a})`;
|
|
432
437
|
} else {
|
|
433
438
|
const a = gen(src[0]);
|
|
434
|
-
if (op === AluOp.Sin) source = `sin(${a})`;
|
|
435
|
-
else if (op === AluOp.Cos) source = `cos(${a})`;
|
|
436
|
-
else if (op === AluOp.Asin) source = `asin(${a})`;
|
|
437
|
-
else if (op === AluOp.Atan) source = `atan(${a})`;
|
|
438
|
-
else if (op === AluOp.Exp) source = `exp(${a})`;
|
|
439
|
-
else if (op === AluOp.Log) source = `log(${a})`;
|
|
439
|
+
if (op === AluOp.Sin) source = `sin(${strip1(a)})`;
|
|
440
|
+
else if (op === AluOp.Cos) source = `cos(${strip1(a)})`;
|
|
441
|
+
else if (op === AluOp.Asin) source = `asin(${strip1(a)})`;
|
|
442
|
+
else if (op === AluOp.Atan) source = `atan(${strip1(a)})`;
|
|
443
|
+
else if (op === AluOp.Exp) source = `exp(${strip1(a)})`;
|
|
444
|
+
else if (op === AluOp.Log) source = `log(${strip1(a)})`;
|
|
440
445
|
else if (op === AluOp.Erf || op === AluOp.Erfc) {
|
|
441
446
|
const funcName = op === AluOp.Erf ? "erf" : "erfc";
|
|
442
|
-
if (dtype !== DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${a})))`;
|
|
443
|
-
else source = `${funcName}(${a})`;
|
|
444
|
-
} else if (op === AluOp.Sqrt) source = `sqrt(${a})`;
|
|
447
|
+
if (dtype !== DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${strip1(a)})))`;
|
|
448
|
+
else source = `${funcName}(${strip1(a)})`;
|
|
449
|
+
} else if (op === AluOp.Sqrt) source = `sqrt(${strip1(a)})`;
|
|
445
450
|
else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
451
|
+
else if (op === AluOp.Floor) source = `floor(${strip1(a)})`;
|
|
452
|
+
else if (op === AluOp.Ceil) source = `ceil(${strip1(a)})`;
|
|
446
453
|
else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${strip1(a)})`;
|
|
447
454
|
else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
|
|
448
455
|
}
|
|
@@ -645,4 +652,4 @@ async function compileError(shaderModule, scope, code) {
|
|
|
645
652
|
|
|
646
653
|
//#endregion
|
|
647
654
|
export { WebGPUBackend };
|
|
648
|
-
//# sourceMappingURL=webgpu-
|
|
655
|
+
//# sourceMappingURL=webgpu-BGuG58KZ.js.map
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-
|
|
1
|
+
const require_backend = require('./backend-DeVfWEFS.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -323,7 +323,7 @@ function dtypeToWgsl(dtype, storage = false) {
|
|
|
323
323
|
case require_backend.DType.Uint32: return "u32";
|
|
324
324
|
case require_backend.DType.Float32: return "f32";
|
|
325
325
|
case require_backend.DType.Float16: return "f16";
|
|
326
|
-
default: throw new Error(`Unsupported dtype: ${dtype}`);
|
|
326
|
+
default: throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
|
|
327
327
|
}
|
|
328
328
|
}
|
|
329
329
|
function constToWgsl(dtype, value) {
|
|
@@ -397,6 +397,7 @@ function pipelineSource(device, kernel) {
|
|
|
397
397
|
}
|
|
398
398
|
let gensymCount = 0;
|
|
399
399
|
const gensym = () => `alu${gensymCount++}`;
|
|
400
|
+
const isGensym = (text) => text.match(/^alu[0-9]+$/);
|
|
400
401
|
for (let i = 0; i < args.length; i++) if (!usedArgs[i]) emit(`_ = &${args[i]};`);
|
|
401
402
|
const references = /* @__PURE__ */ new Map();
|
|
402
403
|
const seen = /* @__PURE__ */ new Set();
|
|
@@ -425,24 +426,30 @@ function pipelineSource(device, kernel) {
|
|
|
425
426
|
else if (op === require_backend.AluOp.Min) source = `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
|
|
426
427
|
else if (op === require_backend.AluOp.Max) source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
|
|
427
428
|
else if (op === require_backend.AluOp.Cmplt) source = `(${a} < ${b})`;
|
|
428
|
-
else if (op === require_backend.AluOp.Cmpne)
|
|
429
|
+
else if (op === require_backend.AluOp.Cmpne) if (require_backend.isFloatDtype(src[0].dtype)) {
|
|
430
|
+
const x = isGensym(a) ? a : gensym();
|
|
431
|
+
if (x !== a) emit(`let ${x} = ${a};`);
|
|
432
|
+
source = `(${x} != ${b} || min(${x}, ${dtypeToWgsl(src[0].dtype)}(inf())) != ${x})`;
|
|
433
|
+
} else source = `(${a} != ${b})`;
|
|
429
434
|
} else if (require_backend.AluGroup.Unary.has(op)) if (op === require_backend.AluOp.Reciprocal && src[0].op === require_backend.AluOp.Sqrt) {
|
|
430
435
|
const a = gen(src[0].src[0]);
|
|
431
436
|
source = `inverseSqrt(${a})`;
|
|
432
437
|
} else {
|
|
433
438
|
const a = gen(src[0]);
|
|
434
|
-
if (op === require_backend.AluOp.Sin) source = `sin(${a})`;
|
|
435
|
-
else if (op === require_backend.AluOp.Cos) source = `cos(${a})`;
|
|
436
|
-
else if (op === require_backend.AluOp.Asin) source = `asin(${a})`;
|
|
437
|
-
else if (op === require_backend.AluOp.Atan) source = `atan(${a})`;
|
|
438
|
-
else if (op === require_backend.AluOp.Exp) source = `exp(${a})`;
|
|
439
|
-
else if (op === require_backend.AluOp.Log) source = `log(${a})`;
|
|
439
|
+
if (op === require_backend.AluOp.Sin) source = `sin(${require_backend.strip1(a)})`;
|
|
440
|
+
else if (op === require_backend.AluOp.Cos) source = `cos(${require_backend.strip1(a)})`;
|
|
441
|
+
else if (op === require_backend.AluOp.Asin) source = `asin(${require_backend.strip1(a)})`;
|
|
442
|
+
else if (op === require_backend.AluOp.Atan) source = `atan(${require_backend.strip1(a)})`;
|
|
443
|
+
else if (op === require_backend.AluOp.Exp) source = `exp(${require_backend.strip1(a)})`;
|
|
444
|
+
else if (op === require_backend.AluOp.Log) source = `log(${require_backend.strip1(a)})`;
|
|
440
445
|
else if (op === require_backend.AluOp.Erf || op === require_backend.AluOp.Erfc) {
|
|
441
446
|
const funcName = op === require_backend.AluOp.Erf ? "erf" : "erfc";
|
|
442
|
-
if (dtype !== require_backend.DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${a})))`;
|
|
443
|
-
else source = `${funcName}(${a})`;
|
|
444
|
-
} else if (op === require_backend.AluOp.Sqrt) source = `sqrt(${a})`;
|
|
447
|
+
if (dtype !== require_backend.DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${require_backend.strip1(a)})))`;
|
|
448
|
+
else source = `${funcName}(${require_backend.strip1(a)})`;
|
|
449
|
+
} else if (op === require_backend.AluOp.Sqrt) source = `sqrt(${require_backend.strip1(a)})`;
|
|
445
450
|
else if (op === require_backend.AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
451
|
+
else if (op === require_backend.AluOp.Floor) source = `floor(${require_backend.strip1(a)})`;
|
|
452
|
+
else if (op === require_backend.AluOp.Ceil) source = `ceil(${require_backend.strip1(a)})`;
|
|
446
453
|
else if (op === require_backend.AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${require_backend.strip1(a)})`;
|
|
447
454
|
else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
|
|
448
455
|
}
|
|
@@ -645,4 +652,4 @@ async function compileError(shaderModule, scope, code) {
|
|
|
645
652
|
|
|
646
653
|
//#endregion
|
|
647
654
|
exports.WebGPUBackend = WebGPUBackend;
|
|
648
|
-
//# sourceMappingURL=webgpu-
|
|
655
|
+
//# sourceMappingURL=webgpu-CcGP160M.cjs.map
|