@jax-js/jax 0.1.1 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +11 -32
- package/dist/{backend-CoVtc9dx.js → backend-BY8wlLEl.js} +88 -25
- package/dist/{backend-BbrKEB18.cjs → backend-CmaidnkQ.cjs} +88 -25
- package/dist/index.cjs +2901 -2252
- package/dist/index.d.cts +1101 -979
- package/dist/index.d.ts +1101 -979
- package/dist/index.js +2892 -2243
- package/dist/{webgpu-DGYNVHma.cjs → webgpu-BVns4DbI.cjs} +25 -15
- package/dist/{webgpu-B3UVme6n.js → webgpu-C9iAP5h5.js} +25 -15
- package/package.json +13 -21
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-BbrKEB18.cjs');
|
|
1
|
+
const require_backend = require('./backend-CmaidnkQ.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -284,6 +284,8 @@ var WebGPUBackend = class {
|
|
|
284
284
|
});
|
|
285
285
|
}
|
|
286
286
|
dispatch(exe, inputs, outputs) {
|
|
287
|
+
if (inputs.length !== exe.kernel.nargs) throw new Error(`webgpu: dispatch with ${inputs.length} inputs, expected ${exe.kernel.nargs}`);
|
|
288
|
+
if (exe.kernel.size === 0) return;
|
|
287
289
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
288
290
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
289
291
|
pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
|
|
@@ -362,12 +364,12 @@ function pipelineSource(device, kernel) {
|
|
|
362
364
|
else if (line === popIndent) indent = indent.slice(0, -2);
|
|
363
365
|
else shader.push(line ? indent + line : line);
|
|
364
366
|
};
|
|
365
|
-
if (tune.exp.some((exp) => exp.dtype === require_backend.DType.Float16) ||
|
|
367
|
+
if (tune.exp.some((exp) => exp.dtype === require_backend.DType.Float16) || tune.epilogue?.some((exp) => exp.dtype === require_backend.DType.Float16)) {
|
|
366
368
|
if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
|
|
367
369
|
emit("enable f16;");
|
|
368
370
|
}
|
|
369
371
|
emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
|
|
370
|
-
const distinctOps = require_backend.mapSetUnion(tune.exp.distinctOps(),
|
|
372
|
+
const distinctOps = require_backend.mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
|
|
371
373
|
if (distinctOps.has(require_backend.AluOp.Threefry2x32)) emit(threefrySrc);
|
|
372
374
|
if (distinctOps.has(require_backend.AluOp.Erf) || distinctOps.has(require_backend.AluOp.Erfc)) emit(erfSrc);
|
|
373
375
|
emit("");
|
|
@@ -375,6 +377,9 @@ function pipelineSource(device, kernel) {
|
|
|
375
377
|
tune.exp.fold((exp) => {
|
|
376
378
|
if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
|
|
377
379
|
});
|
|
380
|
+
tune.epilogue?.fold((exp) => {
|
|
381
|
+
if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
|
|
382
|
+
});
|
|
378
383
|
for (let i = 0; i < nargs; i++) {
|
|
379
384
|
const ty = dtypeToWgsl(usedArgs[i] ?? require_backend.DType.Float32, true);
|
|
380
385
|
emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
|
|
@@ -398,7 +403,7 @@ function pipelineSource(device, kernel) {
|
|
|
398
403
|
let gensymCount = 0;
|
|
399
404
|
const gensym = () => `alu${gensymCount++}`;
|
|
400
405
|
const isGensym = (text) => text.match(/^alu[0-9]+$/);
|
|
401
|
-
|
|
406
|
+
if (args.length > 0) emit(args.map((arg) => `_ = &${arg};`).join(" "));
|
|
402
407
|
const references = /* @__PURE__ */ new Map();
|
|
403
408
|
const seen = /* @__PURE__ */ new Set();
|
|
404
409
|
const countReferences = (exp) => {
|
|
@@ -436,18 +441,20 @@ function pipelineSource(device, kernel) {
|
|
|
436
441
|
source = `inverseSqrt(${a})`;
|
|
437
442
|
} else {
|
|
438
443
|
const a = gen(src[0]);
|
|
439
|
-
if (op === require_backend.AluOp.Sin) source = `sin(${a})`;
|
|
440
|
-
else if (op === require_backend.AluOp.Cos) source = `cos(${a})`;
|
|
441
|
-
else if (op === require_backend.AluOp.Asin) source = `asin(${a})`;
|
|
442
|
-
else if (op === require_backend.AluOp.Atan) source = `atan(${a})`;
|
|
443
|
-
else if (op === require_backend.AluOp.Exp) source = `exp(${a})`;
|
|
444
|
-
else if (op === require_backend.AluOp.Log) source = `log(${a})`;
|
|
444
|
+
if (op === require_backend.AluOp.Sin) source = `sin(${require_backend.strip1(a)})`;
|
|
445
|
+
else if (op === require_backend.AluOp.Cos) source = `cos(${require_backend.strip1(a)})`;
|
|
446
|
+
else if (op === require_backend.AluOp.Asin) source = `asin(${require_backend.strip1(a)})`;
|
|
447
|
+
else if (op === require_backend.AluOp.Atan) source = `atan(${require_backend.strip1(a)})`;
|
|
448
|
+
else if (op === require_backend.AluOp.Exp) source = `exp(${require_backend.strip1(a)})`;
|
|
449
|
+
else if (op === require_backend.AluOp.Log) source = `log(${require_backend.strip1(a)})`;
|
|
445
450
|
else if (op === require_backend.AluOp.Erf || op === require_backend.AluOp.Erfc) {
|
|
446
451
|
const funcName = op === require_backend.AluOp.Erf ? "erf" : "erfc";
|
|
447
|
-
if (dtype !== require_backend.DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${a})))`;
|
|
448
|
-
else source = `${funcName}(${a})`;
|
|
449
|
-
} else if (op === require_backend.AluOp.Sqrt) source = `sqrt(${a})`;
|
|
452
|
+
if (dtype !== require_backend.DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${require_backend.strip1(a)})))`;
|
|
453
|
+
else source = `${funcName}(${require_backend.strip1(a)})`;
|
|
454
|
+
} else if (op === require_backend.AluOp.Sqrt) source = `sqrt(${require_backend.strip1(a)})`;
|
|
450
455
|
else if (op === require_backend.AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
456
|
+
else if (op === require_backend.AluOp.Floor) source = `floor(${require_backend.strip1(a)})`;
|
|
457
|
+
else if (op === require_backend.AluOp.Ceil) source = `ceil(${require_backend.strip1(a)})`;
|
|
451
458
|
else if (op === require_backend.AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${require_backend.strip1(a)})`;
|
|
452
459
|
else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
|
|
453
460
|
}
|
|
@@ -528,7 +535,10 @@ function pipelineSource(device, kernel) {
|
|
|
528
535
|
const exp = tune.outputIdxExp.substitute({ upcast: require_backend.AluExp.i32(i) });
|
|
529
536
|
outputIdxExps.push(exp.simplify(cache));
|
|
530
537
|
countReferences(outputIdxExps[i]);
|
|
531
|
-
fusionExps.push(
|
|
538
|
+
fusionExps.push(tune.epilogue.substitute({
|
|
539
|
+
acc: require_backend.AluExp.variable(re.dtype, acc[i]),
|
|
540
|
+
upcast: require_backend.AluExp.i32(i)
|
|
541
|
+
}).simplify(cache));
|
|
532
542
|
countReferences(fusionExps[i]);
|
|
533
543
|
}
|
|
534
544
|
for (let i = 0; i < upcast; i++) {
|
|
@@ -650,4 +660,4 @@ async function compileError(shaderModule, scope, code) {
|
|
|
650
660
|
|
|
651
661
|
//#endregion
|
|
652
662
|
exports.WebGPUBackend = WebGPUBackend;
|
|
653
|
-
//# sourceMappingURL=webgpu-DGYNVHma.cjs.map
|
|
663
|
+
//# sourceMappingURL=webgpu-BVns4DbI.cjs.map
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, mapSetUnion, strip1, tuneWebgpu } from "./backend-CoVtc9dx.js";
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, mapSetUnion, strip1, tuneWebgpu } from "./backend-BY8wlLEl.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -284,6 +284,8 @@ var WebGPUBackend = class {
|
|
|
284
284
|
});
|
|
285
285
|
}
|
|
286
286
|
dispatch(exe, inputs, outputs) {
|
|
287
|
+
if (inputs.length !== exe.kernel.nargs) throw new Error(`webgpu: dispatch with ${inputs.length} inputs, expected ${exe.kernel.nargs}`);
|
|
288
|
+
if (exe.kernel.size === 0) return;
|
|
287
289
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
288
290
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
289
291
|
pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
|
|
@@ -362,12 +364,12 @@ function pipelineSource(device, kernel) {
|
|
|
362
364
|
else if (line === popIndent) indent = indent.slice(0, -2);
|
|
363
365
|
else shader.push(line ? indent + line : line);
|
|
364
366
|
};
|
|
365
|
-
if (tune.exp.some((exp) => exp.dtype === DType.Float16) ||
|
|
367
|
+
if (tune.exp.some((exp) => exp.dtype === DType.Float16) || tune.epilogue?.some((exp) => exp.dtype === DType.Float16)) {
|
|
366
368
|
if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
|
|
367
369
|
emit("enable f16;");
|
|
368
370
|
}
|
|
369
371
|
emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
|
|
370
|
-
const distinctOps = mapSetUnion(tune.exp.distinctOps(),
|
|
372
|
+
const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
|
|
371
373
|
if (distinctOps.has(AluOp.Threefry2x32)) emit(threefrySrc);
|
|
372
374
|
if (distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) emit(erfSrc);
|
|
373
375
|
emit("");
|
|
@@ -375,6 +377,9 @@ function pipelineSource(device, kernel) {
|
|
|
375
377
|
tune.exp.fold((exp) => {
|
|
376
378
|
if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
|
|
377
379
|
});
|
|
380
|
+
tune.epilogue?.fold((exp) => {
|
|
381
|
+
if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
|
|
382
|
+
});
|
|
378
383
|
for (let i = 0; i < nargs; i++) {
|
|
379
384
|
const ty = dtypeToWgsl(usedArgs[i] ?? DType.Float32, true);
|
|
380
385
|
emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
|
|
@@ -398,7 +403,7 @@ function pipelineSource(device, kernel) {
|
|
|
398
403
|
let gensymCount = 0;
|
|
399
404
|
const gensym = () => `alu${gensymCount++}`;
|
|
400
405
|
const isGensym = (text) => text.match(/^alu[0-9]+$/);
|
|
401
|
-
|
|
406
|
+
if (args.length > 0) emit(args.map((arg) => `_ = &${arg};`).join(" "));
|
|
402
407
|
const references = /* @__PURE__ */ new Map();
|
|
403
408
|
const seen = /* @__PURE__ */ new Set();
|
|
404
409
|
const countReferences = (exp) => {
|
|
@@ -436,18 +441,20 @@ function pipelineSource(device, kernel) {
|
|
|
436
441
|
source = `inverseSqrt(${a})`;
|
|
437
442
|
} else {
|
|
438
443
|
const a = gen(src[0]);
|
|
439
|
-
if (op === AluOp.Sin) source = `sin(${a})`;
|
|
440
|
-
else if (op === AluOp.Cos) source = `cos(${a})`;
|
|
441
|
-
else if (op === AluOp.Asin) source = `asin(${a})`;
|
|
442
|
-
else if (op === AluOp.Atan) source = `atan(${a})`;
|
|
443
|
-
else if (op === AluOp.Exp) source = `exp(${a})`;
|
|
444
|
-
else if (op === AluOp.Log) source = `log(${a})`;
|
|
444
|
+
if (op === AluOp.Sin) source = `sin(${strip1(a)})`;
|
|
445
|
+
else if (op === AluOp.Cos) source = `cos(${strip1(a)})`;
|
|
446
|
+
else if (op === AluOp.Asin) source = `asin(${strip1(a)})`;
|
|
447
|
+
else if (op === AluOp.Atan) source = `atan(${strip1(a)})`;
|
|
448
|
+
else if (op === AluOp.Exp) source = `exp(${strip1(a)})`;
|
|
449
|
+
else if (op === AluOp.Log) source = `log(${strip1(a)})`;
|
|
445
450
|
else if (op === AluOp.Erf || op === AluOp.Erfc) {
|
|
446
451
|
const funcName = op === AluOp.Erf ? "erf" : "erfc";
|
|
447
|
-
if (dtype !== DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${a})))`;
|
|
448
|
-
else source = `${funcName}(${a})`;
|
|
449
|
-
} else if (op === AluOp.Sqrt) source = `sqrt(${a})`;
|
|
452
|
+
if (dtype !== DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${strip1(a)})))`;
|
|
453
|
+
else source = `${funcName}(${strip1(a)})`;
|
|
454
|
+
} else if (op === AluOp.Sqrt) source = `sqrt(${strip1(a)})`;
|
|
450
455
|
else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
456
|
+
else if (op === AluOp.Floor) source = `floor(${strip1(a)})`;
|
|
457
|
+
else if (op === AluOp.Ceil) source = `ceil(${strip1(a)})`;
|
|
451
458
|
else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${strip1(a)})`;
|
|
452
459
|
else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
|
|
453
460
|
}
|
|
@@ -528,7 +535,10 @@ function pipelineSource(device, kernel) {
|
|
|
528
535
|
const exp = tune.outputIdxExp.substitute({ upcast: AluExp.i32(i) });
|
|
529
536
|
outputIdxExps.push(exp.simplify(cache));
|
|
530
537
|
countReferences(outputIdxExps[i]);
|
|
531
|
-
fusionExps.push(
|
|
538
|
+
fusionExps.push(tune.epilogue.substitute({
|
|
539
|
+
acc: AluExp.variable(re.dtype, acc[i]),
|
|
540
|
+
upcast: AluExp.i32(i)
|
|
541
|
+
}).simplify(cache));
|
|
532
542
|
countReferences(fusionExps[i]);
|
|
533
543
|
}
|
|
534
544
|
for (let i = 0; i < upcast; i++) {
|
|
@@ -650,4 +660,4 @@ async function compileError(shaderModule, scope, code) {
|
|
|
650
660
|
|
|
651
661
|
//#endregion
|
|
652
662
|
export { WebGPUBackend };
|
|
653
|
-
//# sourceMappingURL=webgpu-B3UVme6n.js.map
|
|
663
|
+
//# sourceMappingURL=webgpu-C9iAP5h5.js.map
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@jax-js/jax",
|
|
3
|
-
"version": "0.1.1",
|
|
3
|
+
"version": "0.1.3",
|
|
4
4
|
"description": "Numerical computing and ML in the browser",
|
|
5
5
|
"keywords": [
|
|
6
6
|
"machine learning",
|
|
@@ -35,19 +35,6 @@
|
|
|
35
35
|
"types": "dist/index.d.ts",
|
|
36
36
|
"module": "dist/index.js",
|
|
37
37
|
"license": "MIT",
|
|
38
|
-
"scripts": {
|
|
39
|
-
"build": "tsdown",
|
|
40
|
-
"build:watch": "TSDOWN_WATCH_MODE=1 tsdown",
|
|
41
|
-
"check": "tsc",
|
|
42
|
-
"docs": "tsx typedoc.ts",
|
|
43
|
-
"format": "prettier --write .",
|
|
44
|
-
"format:check": "prettier --check .",
|
|
45
|
-
"lint": "eslint",
|
|
46
|
-
"test": "vitest",
|
|
47
|
-
"test:coverage": "vitest run --coverage && open coverage/index.html",
|
|
48
|
-
"prepublishOnly": "pnpm build",
|
|
49
|
-
"postpublish": "git tag jax/v$npm_package_version && git push --tags"
|
|
50
|
-
},
|
|
51
38
|
"devDependencies": {
|
|
52
39
|
"@eslint/js": "^9.31.0",
|
|
53
40
|
"@types/debug": "^4.1.12",
|
|
@@ -68,15 +55,9 @@
|
|
|
68
55
|
"typescript-eslint": "^8.46.4",
|
|
69
56
|
"vitest": "^4.0.9"
|
|
70
57
|
},
|
|
71
|
-
"packageManager": "pnpm@10.22.0",
|
|
72
58
|
"engines": {
|
|
73
59
|
"pnpm": ">=10.0.0"
|
|
74
60
|
},
|
|
75
|
-
"pnpm": {
|
|
76
|
-
"overrides": {
|
|
77
|
-
"@tensorflow/tfjs-core>@webgpu/types": "^0.1.68"
|
|
78
|
-
}
|
|
79
|
-
},
|
|
80
61
|
"prettier": {
|
|
81
62
|
"plugins": [
|
|
82
63
|
"prettier-plugin-svelte"
|
|
@@ -92,5 +73,16 @@
|
|
|
92
73
|
}
|
|
93
74
|
],
|
|
94
75
|
"proseWrap": "always"
|
|
76
|
+
},
|
|
77
|
+
"scripts": {
|
|
78
|
+
"build": "tsdown",
|
|
79
|
+
"build:watch": "TSDOWN_WATCH_MODE=1 tsdown",
|
|
80
|
+
"check": "tsc",
|
|
81
|
+
"docs": "tsx typedoc.ts",
|
|
82
|
+
"format": "prettier --write .",
|
|
83
|
+
"format:check": "prettier --check .",
|
|
84
|
+
"lint": "eslint",
|
|
85
|
+
"test": "vitest",
|
|
86
|
+
"test:coverage": "vitest run --coverage && open coverage/index.html"
|
|
95
87
|
}
|
|
96
|
-
}
|
|
88
|
+
}
|