@jax-js/jax 0.1.2 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +11 -32
- package/dist/{backend-BqymqzuU.js → backend-BY8wlLEl.js} +58 -20
- package/dist/{backend-DeVfWEFS.cjs → backend-CmaidnkQ.cjs} +58 -20
- package/dist/index.cjs +298 -134
- package/dist/index.d.cts +21 -5
- package/dist/index.d.ts +21 -5
- package/dist/index.js +298 -134
- package/dist/{webgpu-CcGP160M.cjs → webgpu-BVns4DbI.cjs} +14 -6
- package/dist/{webgpu-BGuG58KZ.js → webgpu-C9iAP5h5.js} +14 -6
- package/package.json +1 -1
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-
|
|
1
|
+
const require_backend = require('./backend-CmaidnkQ.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -284,6 +284,8 @@ var WebGPUBackend = class {
|
|
|
284
284
|
});
|
|
285
285
|
}
|
|
286
286
|
dispatch(exe, inputs, outputs) {
|
|
287
|
+
if (inputs.length !== exe.kernel.nargs) throw new Error(`webgpu: dispatch with ${inputs.length} inputs, expected ${exe.kernel.nargs}`);
|
|
288
|
+
if (exe.kernel.size === 0) return;
|
|
287
289
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
288
290
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
289
291
|
pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
|
|
@@ -362,12 +364,12 @@ function pipelineSource(device, kernel) {
|
|
|
362
364
|
else if (line === popIndent) indent = indent.slice(0, -2);
|
|
363
365
|
else shader.push(line ? indent + line : line);
|
|
364
366
|
};
|
|
365
|
-
if (tune.exp.some((exp) => exp.dtype === require_backend.DType.Float16) ||
|
|
367
|
+
if (tune.exp.some((exp) => exp.dtype === require_backend.DType.Float16) || tune.epilogue?.some((exp) => exp.dtype === require_backend.DType.Float16)) {
|
|
366
368
|
if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
|
|
367
369
|
emit("enable f16;");
|
|
368
370
|
}
|
|
369
371
|
emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
|
|
370
|
-
const distinctOps = require_backend.mapSetUnion(tune.exp.distinctOps(),
|
|
372
|
+
const distinctOps = require_backend.mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
|
|
371
373
|
if (distinctOps.has(require_backend.AluOp.Threefry2x32)) emit(threefrySrc);
|
|
372
374
|
if (distinctOps.has(require_backend.AluOp.Erf) || distinctOps.has(require_backend.AluOp.Erfc)) emit(erfSrc);
|
|
373
375
|
emit("");
|
|
@@ -375,6 +377,9 @@ function pipelineSource(device, kernel) {
|
|
|
375
377
|
tune.exp.fold((exp) => {
|
|
376
378
|
if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
|
|
377
379
|
});
|
|
380
|
+
tune.epilogue?.fold((exp) => {
|
|
381
|
+
if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
|
|
382
|
+
});
|
|
378
383
|
for (let i = 0; i < nargs; i++) {
|
|
379
384
|
const ty = dtypeToWgsl(usedArgs[i] ?? require_backend.DType.Float32, true);
|
|
380
385
|
emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
|
|
@@ -398,7 +403,7 @@ function pipelineSource(device, kernel) {
|
|
|
398
403
|
let gensymCount = 0;
|
|
399
404
|
const gensym = () => `alu${gensymCount++}`;
|
|
400
405
|
const isGensym = (text) => text.match(/^alu[0-9]+$/);
|
|
401
|
-
|
|
406
|
+
if (args.length > 0) emit(args.map((arg) => `_ = &${arg};`).join(" "));
|
|
402
407
|
const references = /* @__PURE__ */ new Map();
|
|
403
408
|
const seen = /* @__PURE__ */ new Set();
|
|
404
409
|
const countReferences = (exp) => {
|
|
@@ -530,7 +535,10 @@ function pipelineSource(device, kernel) {
|
|
|
530
535
|
const exp = tune.outputIdxExp.substitute({ upcast: require_backend.AluExp.i32(i) });
|
|
531
536
|
outputIdxExps.push(exp.simplify(cache));
|
|
532
537
|
countReferences(outputIdxExps[i]);
|
|
533
|
-
fusionExps.push(
|
|
538
|
+
fusionExps.push(tune.epilogue.substitute({
|
|
539
|
+
acc: require_backend.AluExp.variable(re.dtype, acc[i]),
|
|
540
|
+
upcast: require_backend.AluExp.i32(i)
|
|
541
|
+
}).simplify(cache));
|
|
534
542
|
countReferences(fusionExps[i]);
|
|
535
543
|
}
|
|
536
544
|
for (let i = 0; i < upcast; i++) {
|
|
@@ -652,4 +660,4 @@ async function compileError(shaderModule, scope, code) {
|
|
|
652
660
|
|
|
653
661
|
//#endregion
|
|
654
662
|
exports.WebGPUBackend = WebGPUBackend;
|
|
655
|
-
//# sourceMappingURL=webgpu-
|
|
663
|
+
//# sourceMappingURL=webgpu-BVns4DbI.cjs.map
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, mapSetUnion, strip1, tuneWebgpu } from "./backend-
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, mapSetUnion, strip1, tuneWebgpu } from "./backend-BY8wlLEl.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -284,6 +284,8 @@ var WebGPUBackend = class {
|
|
|
284
284
|
});
|
|
285
285
|
}
|
|
286
286
|
dispatch(exe, inputs, outputs) {
|
|
287
|
+
if (inputs.length !== exe.kernel.nargs) throw new Error(`webgpu: dispatch with ${inputs.length} inputs, expected ${exe.kernel.nargs}`);
|
|
288
|
+
if (exe.kernel.size === 0) return;
|
|
287
289
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
288
290
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
289
291
|
pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
|
|
@@ -362,12 +364,12 @@ function pipelineSource(device, kernel) {
|
|
|
362
364
|
else if (line === popIndent) indent = indent.slice(0, -2);
|
|
363
365
|
else shader.push(line ? indent + line : line);
|
|
364
366
|
};
|
|
365
|
-
if (tune.exp.some((exp) => exp.dtype === DType.Float16) ||
|
|
367
|
+
if (tune.exp.some((exp) => exp.dtype === DType.Float16) || tune.epilogue?.some((exp) => exp.dtype === DType.Float16)) {
|
|
366
368
|
if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
|
|
367
369
|
emit("enable f16;");
|
|
368
370
|
}
|
|
369
371
|
emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
|
|
370
|
-
const distinctOps = mapSetUnion(tune.exp.distinctOps(),
|
|
372
|
+
const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
|
|
371
373
|
if (distinctOps.has(AluOp.Threefry2x32)) emit(threefrySrc);
|
|
372
374
|
if (distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) emit(erfSrc);
|
|
373
375
|
emit("");
|
|
@@ -375,6 +377,9 @@ function pipelineSource(device, kernel) {
|
|
|
375
377
|
tune.exp.fold((exp) => {
|
|
376
378
|
if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
|
|
377
379
|
});
|
|
380
|
+
tune.epilogue?.fold((exp) => {
|
|
381
|
+
if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
|
|
382
|
+
});
|
|
378
383
|
for (let i = 0; i < nargs; i++) {
|
|
379
384
|
const ty = dtypeToWgsl(usedArgs[i] ?? DType.Float32, true);
|
|
380
385
|
emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
|
|
@@ -398,7 +403,7 @@ function pipelineSource(device, kernel) {
|
|
|
398
403
|
let gensymCount = 0;
|
|
399
404
|
const gensym = () => `alu${gensymCount++}`;
|
|
400
405
|
const isGensym = (text) => text.match(/^alu[0-9]+$/);
|
|
401
|
-
|
|
406
|
+
if (args.length > 0) emit(args.map((arg) => `_ = &${arg};`).join(" "));
|
|
402
407
|
const references = /* @__PURE__ */ new Map();
|
|
403
408
|
const seen = /* @__PURE__ */ new Set();
|
|
404
409
|
const countReferences = (exp) => {
|
|
@@ -530,7 +535,10 @@ function pipelineSource(device, kernel) {
|
|
|
530
535
|
const exp = tune.outputIdxExp.substitute({ upcast: AluExp.i32(i) });
|
|
531
536
|
outputIdxExps.push(exp.simplify(cache));
|
|
532
537
|
countReferences(outputIdxExps[i]);
|
|
533
|
-
fusionExps.push(
|
|
538
|
+
fusionExps.push(tune.epilogue.substitute({
|
|
539
|
+
acc: AluExp.variable(re.dtype, acc[i]),
|
|
540
|
+
upcast: AluExp.i32(i)
|
|
541
|
+
}).simplify(cache));
|
|
534
542
|
countReferences(fusionExps[i]);
|
|
535
543
|
}
|
|
536
544
|
for (let i = 0; i < upcast; i++) {
|
|
@@ -652,4 +660,4 @@ async function compileError(shaderModule, scope, code) {
|
|
|
652
660
|
|
|
653
661
|
//#endregion
|
|
654
662
|
export { WebGPUBackend };
|
|
655
|
-
//# sourceMappingURL=webgpu-
|
|
663
|
+
//# sourceMappingURL=webgpu-C9iAP5h5.js.map
|