@jax-js/jax 0.1.9 → 0.1.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +35 -19
- package/dist/{backend-BId79r5b.js → backend-DZvR7mZV.js} +831 -26
- package/dist/{backend-DpI0riom.cjs → backend-DlYlOYqN.cjs} +872 -25
- package/dist/index.cjs +364 -20
- package/dist/index.d.cts +175 -11
- package/dist/index.d.ts +175 -11
- package/dist/index.js +363 -21
- package/dist/{webgl-DnGrclTz.js → webgl-D8-14NzA.js} +7 -1
- package/dist/{webgl-C5NjXc1p.cjs → webgl-Ovaaa-Qx.cjs} +7 -1
- package/dist/{webgpu-AN0cG_nB.js → webgpu-Dg8FpYrH.js} +141 -6
- package/dist/{webgpu-CdjiJSa7.cjs → webgpu-uU9nnttc.cjs} +141 -6
- package/package.json +5 -16
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-DpI0riom.cjs');
|
|
1
|
+
const require_backend = require('./backend-DlYlOYqN.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgl/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -458,6 +458,12 @@ function generateExpression(exp, args, inputDtypes) {
|
|
|
458
458
|
else source = `min(${a}, ${b})`;
|
|
459
459
|
else if (op === require_backend.AluOp.Max) if (dtype === require_backend.DType.Bool) source = `(${a} || ${b})`;
|
|
460
460
|
else source = `max(${a}, ${b})`;
|
|
461
|
+
else if (op === require_backend.AluOp.BitCombine) {
|
|
462
|
+
let infix = arg === "and" ? "&" : arg === "or" ? "|" : "^";
|
|
463
|
+
if (dtype === require_backend.DType.Bool) infix = infix + infix;
|
|
464
|
+
source = `(${a} ${infix} ${b})`;
|
|
465
|
+
} else if (op === require_backend.AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
|
|
466
|
+
else source = `(${a} >> ${b})`;
|
|
461
467
|
} else if (require_backend.AluGroup.Compare.has(op)) {
|
|
462
468
|
const a = gen(src[0]);
|
|
463
469
|
const b = gen(src[1]);
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-BId79r5b.js";
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-DZvR7mZV.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -152,6 +152,7 @@ var SyncReader = class SyncReader {
|
|
|
152
152
|
this.device = device;
|
|
153
153
|
}
|
|
154
154
|
#init() {
|
|
155
|
+
if (typeof OffscreenCanvas === "undefined") throw new Error("OffscreenCanvas is not available in this environment, so you cannot read data from WebGPU synchronously. Consider using the async API.");
|
|
155
156
|
const makeCanvas = () => new OffscreenCanvas(SyncReader.width, SyncReader.height);
|
|
156
157
|
this.deviceStorage = SyncReader.alphaModes.map(makeCanvas);
|
|
157
158
|
this.deviceContexts = this.deviceStorage.map((canvas, i) => {
|
|
@@ -728,6 +729,120 @@ function createRoutineShader(device, routine) {
|
|
|
728
729
|
}
|
|
729
730
|
}
|
|
730
731
|
|
|
732
|
+
//#endregion
|
|
733
|
+
//#region src/backend/webgpu/tracing.ts
|
|
734
|
+
/**
 * Number of timestamp queries allocated per tracing batch. Each traced
 * dispatch consumes two of them (one begin index, one end index).
 */
const MAX_TIMESTAMP_QUERIES = 4096;

/**
 * Tracing batch currently accepting new slots, keyed by device. A WeakMap is
 * used so an entry does not keep its device alive.
 */
const activeBatch = /* @__PURE__ */ new WeakMap();
|
|
736
|
+
/**
 * Allocate the GPU objects backing one batch of timestamp queries.
 *
 * The batch owns a timestamp query set, a buffer the queries are resolved
 * into (`resolve`), and a mappable buffer the results are copied to for
 * readback (`dst`). Both buffers hold 8 bytes per query.
 *
 * @param device Device used to create the query set and buffers.
 * @returns A fresh batch with no slots handed out and no recorded entries.
 */
function createTracingBatch(device) {
  const bufferBytes = MAX_TIMESTAMP_QUERIES * 8;
  const querySet = device.createQuerySet({
    type: "timestamp",
    count: MAX_TIMESTAMP_QUERIES,
  });
  const resolve = device.createBuffer({
    size: bufferBytes,
    usage: GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC,
  });
  const dst = device.createBuffer({
    size: bufferBytes,
    usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
  });
  return { querySet, resolve, dst, nextIndex: 0, entries: [] };
}
|
|
754
|
+
/**
 * Reserve a begin/end pair of timestamp query indices for one dispatch.
 *
 * If the device's active batch has handed out all of its queries, it is
 * flushed and a new batch takes its place. Whenever a batch is created here,
 * a flush callback is registered via `onFlushTrace` that flushes (or
 * releases) whatever batch is active for the device at that point.
 *
 * @param device Device the dispatch will run on.
 * @returns `{ batch, beginIndex, endIndex }`, or `undefined` when the device
 *   does not have the "timestamp-query" feature.
 */
function acquireTracingSlot(device) {
  if (!device.features.has("timestamp-query")) return void 0;
  let batch = activeBatch.get(device);
  if (batch && batch.nextIndex >= MAX_TIMESTAMP_QUERIES) {
    flushTracingBatch(device, batch);
    batch = void 0;
  }
  if (!batch) {
    batch = createTracingBatch(device);
    activeBatch.set(device, batch);
    onFlushTrace(() => {
      const pending = activeBatch.get(device);
      if (pending) {
        if (pending.entries.length > 0) {
          // flushTracingBatch() destroys the batch's GPU objects after readback.
          flushTracingBatch(device, pending);
        } else {
          // Nothing was recorded, so flushTracingBatch() would return early
          // and never free these objects; release them explicitly instead.
          pending.querySet.destroy();
          pending.resolve.destroy();
          pending.dst.destroy();
        }
      }
      activeBatch.delete(device);
    });
  }
  const beginIndex = batch.nextIndex;
  const endIndex = beginIndex + 1;
  batch.nextIndex += 2;
  return {
    batch,
    beginIndex,
    endIndex,
  };
}
|
|
779
|
+
/**
 * Tracing-aware wrapper around `acquireTracingSlot`.
 *
 * Yields a timestamp-query slot only while tracing is active; otherwise (or
 * when the device cannot do timestamp queries) the result is undefined.
 */
function maybeAcquireTracingSlot(device) {
  return isTracing() ? acquireTracingSlot(device) : void 0;
}
|
|
789
|
+
/**
 * Append a trace entry for one pipeline dispatch to the slot's batch, then
 * schedule an auto-flush for the device.
 *
 * The entry carries the info derived from `source`, extended with the pass
 * count and the WGSL text as properties, plus the slot's query indices.
 */
function recordTrace(device, slot, source, numPasses, wgslSource) {
  const info = traceSourceInfo(source);
  info.properties.push(["passes", `${numPasses}`]);
  info.properties.push(["source", wgslSource]);

  const { batch, beginIndex, endIndex } = slot;
  const entry = { ...info, beginIndex, endIndex };
  batch.entries.push(entry);

  scheduleAutoFlush(device);
}
|
|
803
|
+
/**
 * Queue a microtask that flushes the device's active batch if it has any
 * recorded entries, installing a fresh batch in its place.
 *
 * Deferring to a microtask lets dispatches issued back-to-back in the same
 * synchronous run share one batch, while still emitting traces without
 * waiting for the batch to fill up or for stopTrace().
 */
function scheduleAutoFlush(device) {
  const flushIfPending = () => {
    const pending = activeBatch.get(device);
    if (!pending || pending.entries.length === 0) return;
    flushTracingBatch(device, pending);
    activeBatch.set(device, createTracingBatch(device));
  };
  queueMicrotask(flushIfPending);
}
|
|
819
|
+
/**
 * Resolve a batch's timestamp queries, read them back, and emit one trace
 * per recorded entry. Does nothing for a batch without entries.
 *
 * The readback is asynchronous: this function returns right after submitting
 * the resolve/copy commands, and traces are emitted once the destination
 * buffer has been mapped. The batch's query set and buffers are destroyed
 * afterwards, whether the readback succeeded or not.
 *
 * @param device Device that owns the batch's resources.
 * @param batch Batch as produced by `createTracingBatch`.
 */
function flushTracingBatch(device, batch) {
  if (batch.entries.length === 0) return;
  const usedQueries = batch.nextIndex;
  const encoder = device.createCommandEncoder();
  encoder.resolveQuerySet(batch.querySet, 0, usedQueries, batch.resolve, 0);
  // Each resolved query occupies 8 bytes.
  encoder.copyBufferToBuffer(batch.resolve, 0, batch.dst, 0, usedQueries * 8);
  device.queue.submit([encoder.finish()]);
  const { entries } = batch;
  const release = () => {
    batch.querySet.destroy();
    batch.resolve.destroy();
    batch.dst.destroy();
  };
  batch.dst.mapAsync(GPUMapMode.READ).then(
    () => {
      try {
        const times = new BigInt64Array(batch.dst.getMappedRange());
        // Anchor the GPU timeline to the CPU clock: the end of the last entry
        // is treated as "now", and every other timestamp is placed relative
        // to it (nanoseconds converted to milliseconds).
        const anchorGpuNs = times[entries[entries.length - 1].endIndex];
        const anchorCpuMs = performance.now();
        for (const entry of entries) {
          const startMs = anchorCpuMs + Number(times[entry.beginIndex] - anchorGpuNs) / 1e6;
          const endMs = anchorCpuMs + Number(times[entry.endIndex] - anchorGpuNs) / 1e6;
          emitTrace("webgpu", entry, startMs, endMs);
        }
      } finally {
        batch.dst.unmap();
        release();
      }
    },
    (err) => {
      // Mapping failed: the buffer is not mapped, so skip unmap(), but still
      // free the GPU objects and report the error instead of leaving an
      // unhandled rejection behind.
      release();
      console.warn("webgpu: failed to read back timestamp queries, dropping trace batch", err);
    },
  );
}
|
|
845
|
+
|
|
731
846
|
//#endregion
|
|
732
847
|
//#region src/backend/webgpu.ts
|
|
733
848
|
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
@@ -872,7 +987,7 @@ var WebGPUBackend = class {
|
|
|
872
987
|
dispatch(exe, inputs, outputs) {
|
|
873
988
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
874
989
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
875
|
-
pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
|
|
990
|
+
pipelineSubmit(this.device, exe, inputBuffers, outputBuffers);
|
|
876
991
|
}
|
|
877
992
|
#getBuffer(slot) {
|
|
878
993
|
const buffer = this.buffers.get(slot);
|
|
@@ -985,6 +1100,11 @@ function pipelineSource(device, kernel) {
|
|
|
985
1100
|
else source = `min(${strip1(a)}, ${strip1(b)})`;
|
|
986
1101
|
else if (op === AluOp.Max) if (dtype === DType.Bool) source = `(${a} || ${b})`;
|
|
987
1102
|
else source = `max(${strip1(a)}, ${strip1(b)})`;
|
|
1103
|
+
else if (op === AluOp.BitCombine) if (arg === "and") source = `(${a} & ${b})`;
|
|
1104
|
+
else if (arg === "or") source = `(${a} | ${b})`;
|
|
1105
|
+
else source = dtype === DType.Bool ? `(${a} != ${b})` : `(${a} ^ ${b})`;
|
|
1106
|
+
else if (op === AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
|
|
1107
|
+
else source = `(${a} >> ${b})`;
|
|
988
1108
|
else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
|
|
989
1109
|
else if (op === AluOp.Cmpne) if (isFloatDtype(src[0].dtype)) {
|
|
990
1110
|
const x = isGensym(a) ? a : gensym();
|
|
@@ -1010,8 +1130,16 @@ function pipelineSource(device, kernel) {
|
|
|
1010
1130
|
else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
1011
1131
|
else if (op === AluOp.Floor) source = `floor(${strip1(a)})`;
|
|
1012
1132
|
else if (op === AluOp.Ceil) source = `ceil(${strip1(a)})`;
|
|
1013
|
-
else if (op === AluOp.Cast)
|
|
1014
|
-
|
|
1133
|
+
else if (op === AluOp.Cast) {
|
|
1134
|
+
const srcTy = dtypeToWgsl(src[0].dtype);
|
|
1135
|
+
const dstTy = dtypeToWgsl(dtype);
|
|
1136
|
+
if (isFloatDtype(src[0].dtype) && !(isFloatDtype(dtype) || dtype === DType.Bool)) {
|
|
1137
|
+
const maxVal = maxValueWgsl(dtype);
|
|
1138
|
+
const x = isGensym(a) ? a : gensym();
|
|
1139
|
+
if (x !== a) emit(`let ${x}: ${srcTy} = ${strip1(a)};`);
|
|
1140
|
+
source = `select(${dstTy}(${x}), ${maxVal}, ${x} >= ${srcTy}(${maxVal}))`;
|
|
1141
|
+
} else source = `${dstTy}(${strip1(a)})`;
|
|
1142
|
+
} else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
|
|
1015
1143
|
}
|
|
1016
1144
|
else if (op === AluOp.Where) source = `select(${strip1(gen(src[2]))}, ${strip1(gen(src[1]))}, ${strip1(gen(src[0]))})`;
|
|
1017
1145
|
else if (op === AluOp.Threefry2x32) {
|
|
@@ -1114,12 +1242,14 @@ function pipelineSource(device, kernel) {
|
|
|
1114
1242
|
passes: [{ grid: [gridX, gridY] }]
|
|
1115
1243
|
};
|
|
1116
1244
|
}
|
|
1117
|
-
function pipelineSubmit(device, pipelines, inputs, outputs) {
|
|
1245
|
+
function pipelineSubmit(device, exe, inputs, outputs) {
|
|
1246
|
+
const { data: pipelines, source } = exe;
|
|
1118
1247
|
const commandEncoder = device.createCommandEncoder();
|
|
1119
1248
|
for (const { pipeline,...shader } of pipelines) {
|
|
1120
1249
|
if (inputs.length !== shader.numInputs || outputs.length !== shader.numOutputs) throw new Error(`webgpu: expected ${shader.numInputs} inputs and ${shader.numOutputs} outputs, got ${inputs.length} inputs and ${outputs.length} outputs`);
|
|
1121
1250
|
const filteredPasses = shader.passes.filter(({ grid }) => prod(grid) > 0);
|
|
1122
1251
|
if (filteredPasses.length === 0) continue;
|
|
1252
|
+
const slot = maybeAcquireTracingSlot(device);
|
|
1123
1253
|
const bindGroup = device.createBindGroup({
|
|
1124
1254
|
layout: pipeline.getBindGroupLayout(0),
|
|
1125
1255
|
entries: [...inputs.map((buffer, i) => ({
|
|
@@ -1149,13 +1279,18 @@ function pipelineSubmit(device, pipelines, inputs, outputs) {
|
|
|
1149
1279
|
}
|
|
1150
1280
|
for (let i = 0; i < filteredPasses.length; i++) {
|
|
1151
1281
|
const { grid } = filteredPasses[i];
|
|
1152
|
-
const passEncoder = commandEncoder.beginComputePass();
|
|
1282
|
+
const passEncoder = commandEncoder.beginComputePass({ timestampWrites: slot ? {
|
|
1283
|
+
querySet: slot.batch.querySet,
|
|
1284
|
+
beginningOfPassWriteIndex: i === 0 ? slot.beginIndex : void 0,
|
|
1285
|
+
endOfPassWriteIndex: i === filteredPasses.length - 1 ? slot.endIndex : void 0
|
|
1286
|
+
} : void 0 });
|
|
1153
1287
|
passEncoder.setPipeline(pipeline);
|
|
1154
1288
|
passEncoder.setBindGroup(0, bindGroup);
|
|
1155
1289
|
if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
|
|
1156
1290
|
passEncoder.dispatchWorkgroups(grid[0], grid[1]);
|
|
1157
1291
|
passEncoder.end();
|
|
1158
1292
|
}
|
|
1293
|
+
if (slot) recordTrace(device, slot, source, filteredPasses.length, shader.code);
|
|
1159
1294
|
}
|
|
1160
1295
|
device.queue.submit([commandEncoder.finish()]);
|
|
1161
1296
|
}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-DpI0riom.cjs');
|
|
1
|
+
const require_backend = require('./backend-DlYlOYqN.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -152,6 +152,7 @@ var SyncReader = class SyncReader {
|
|
|
152
152
|
this.device = device;
|
|
153
153
|
}
|
|
154
154
|
#init() {
|
|
155
|
+
if (typeof OffscreenCanvas === "undefined") throw new Error("OffscreenCanvas is not available in this environment, so you cannot read data from WebGPU synchronously. Consider using the async API.");
|
|
155
156
|
const makeCanvas = () => new OffscreenCanvas(SyncReader.width, SyncReader.height);
|
|
156
157
|
this.deviceStorage = SyncReader.alphaModes.map(makeCanvas);
|
|
157
158
|
this.deviceContexts = this.deviceStorage.map((canvas, i) => {
|
|
@@ -728,6 +729,120 @@ function createRoutineShader(device, routine) {
|
|
|
728
729
|
}
|
|
729
730
|
}
|
|
730
731
|
|
|
732
|
+
//#endregion
|
|
733
|
+
//#region src/backend/webgpu/tracing.ts
|
|
734
|
+
/**
 * Number of timestamp queries allocated per tracing batch. Each traced
 * dispatch consumes two of them (one begin index, one end index).
 */
const MAX_TIMESTAMP_QUERIES = 4096;

/**
 * Tracing batch currently accepting new slots, keyed by device. A WeakMap is
 * used so an entry does not keep its device alive.
 */
const activeBatch = /* @__PURE__ */ new WeakMap();
|
|
736
|
+
/**
 * Allocate the GPU objects backing one batch of timestamp queries.
 *
 * The batch owns a timestamp query set, a buffer the queries are resolved
 * into (`resolve`), and a mappable buffer the results are copied to for
 * readback (`dst`). Both buffers hold 8 bytes per query.
 *
 * @param device Device used to create the query set and buffers.
 * @returns A fresh batch with no slots handed out and no recorded entries.
 */
function createTracingBatch(device) {
  const bufferBytes = MAX_TIMESTAMP_QUERIES * 8;
  const querySet = device.createQuerySet({
    type: "timestamp",
    count: MAX_TIMESTAMP_QUERIES,
  });
  const resolve = device.createBuffer({
    size: bufferBytes,
    usage: GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC,
  });
  const dst = device.createBuffer({
    size: bufferBytes,
    usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
  });
  return { querySet, resolve, dst, nextIndex: 0, entries: [] };
}
|
|
754
|
+
/**
 * Reserve a begin/end pair of timestamp query indices for one dispatch.
 *
 * If the device's active batch has handed out all of its queries, it is
 * flushed and a new batch takes its place. Whenever a batch is created here,
 * a flush callback is registered via `require_backend.onFlushTrace` that
 * flushes (or releases) whatever batch is active for the device at that point.
 *
 * @param device Device the dispatch will run on.
 * @returns `{ batch, beginIndex, endIndex }`, or `undefined` when the device
 *   does not have the "timestamp-query" feature.
 */
function acquireTracingSlot(device) {
  if (!device.features.has("timestamp-query")) return void 0;
  let batch = activeBatch.get(device);
  if (batch && batch.nextIndex >= MAX_TIMESTAMP_QUERIES) {
    flushTracingBatch(device, batch);
    batch = void 0;
  }
  if (!batch) {
    batch = createTracingBatch(device);
    activeBatch.set(device, batch);
    require_backend.onFlushTrace(() => {
      const pending = activeBatch.get(device);
      if (pending) {
        if (pending.entries.length > 0) {
          // flushTracingBatch() destroys the batch's GPU objects after readback.
          flushTracingBatch(device, pending);
        } else {
          // Nothing was recorded, so flushTracingBatch() would return early
          // and never free these objects; release them explicitly instead.
          pending.querySet.destroy();
          pending.resolve.destroy();
          pending.dst.destroy();
        }
      }
      activeBatch.delete(device);
    });
  }
  const beginIndex = batch.nextIndex;
  const endIndex = beginIndex + 1;
  batch.nextIndex += 2;
  return {
    batch,
    beginIndex,
    endIndex,
  };
}
|
|
779
|
+
/**
 * Tracing-aware wrapper around `acquireTracingSlot`.
 *
 * Yields a timestamp-query slot only while tracing is active; otherwise (or
 * when the device cannot do timestamp queries) the result is undefined.
 */
function maybeAcquireTracingSlot(device) {
  return require_backend.isTracing() ? acquireTracingSlot(device) : void 0;
}
|
|
789
|
+
/**
 * Append a trace entry for one pipeline dispatch to the slot's batch, then
 * schedule an auto-flush for the device.
 *
 * The entry carries the info derived from `source`, extended with the pass
 * count and the WGSL text as properties, plus the slot's query indices.
 */
function recordTrace(device, slot, source, numPasses, wgslSource) {
  const info = require_backend.traceSourceInfo(source);
  info.properties.push(["passes", `${numPasses}`]);
  info.properties.push(["source", wgslSource]);

  const { batch, beginIndex, endIndex } = slot;
  const entry = { ...info, beginIndex, endIndex };
  batch.entries.push(entry);

  scheduleAutoFlush(device);
}
|
|
803
|
+
/**
 * Queue a microtask that flushes the device's active batch if it has any
 * recorded entries, installing a fresh batch in its place.
 *
 * Deferring to a microtask lets dispatches issued back-to-back in the same
 * synchronous run share one batch, while still emitting traces without
 * waiting for the batch to fill up or for stopTrace().
 */
function scheduleAutoFlush(device) {
  const flushIfPending = () => {
    const pending = activeBatch.get(device);
    if (!pending || pending.entries.length === 0) return;
    flushTracingBatch(device, pending);
    activeBatch.set(device, createTracingBatch(device));
  };
  queueMicrotask(flushIfPending);
}
|
|
819
|
+
/**
 * Resolve a batch's timestamp queries, read them back, and emit one trace
 * per recorded entry. Does nothing for a batch without entries.
 *
 * The readback is asynchronous: this function returns right after submitting
 * the resolve/copy commands, and traces are emitted once the destination
 * buffer has been mapped. The batch's query set and buffers are destroyed
 * afterwards, whether the readback succeeded or not.
 *
 * @param device Device that owns the batch's resources.
 * @param batch Batch as produced by `createTracingBatch`.
 */
function flushTracingBatch(device, batch) {
  if (batch.entries.length === 0) return;
  const usedQueries = batch.nextIndex;
  const encoder = device.createCommandEncoder();
  encoder.resolveQuerySet(batch.querySet, 0, usedQueries, batch.resolve, 0);
  // Each resolved query occupies 8 bytes.
  encoder.copyBufferToBuffer(batch.resolve, 0, batch.dst, 0, usedQueries * 8);
  device.queue.submit([encoder.finish()]);
  const { entries } = batch;
  const release = () => {
    batch.querySet.destroy();
    batch.resolve.destroy();
    batch.dst.destroy();
  };
  batch.dst.mapAsync(GPUMapMode.READ).then(
    () => {
      try {
        const times = new BigInt64Array(batch.dst.getMappedRange());
        // Anchor the GPU timeline to the CPU clock: the end of the last entry
        // is treated as "now", and every other timestamp is placed relative
        // to it (nanoseconds converted to milliseconds).
        const anchorGpuNs = times[entries[entries.length - 1].endIndex];
        const anchorCpuMs = performance.now();
        for (const entry of entries) {
          const startMs = anchorCpuMs + Number(times[entry.beginIndex] - anchorGpuNs) / 1e6;
          const endMs = anchorCpuMs + Number(times[entry.endIndex] - anchorGpuNs) / 1e6;
          require_backend.emitTrace("webgpu", entry, startMs, endMs);
        }
      } finally {
        batch.dst.unmap();
        release();
      }
    },
    (err) => {
      // Mapping failed: the buffer is not mapped, so skip unmap(), but still
      // free the GPU objects and report the error instead of leaving an
      // unhandled rejection behind.
      release();
      console.warn("webgpu: failed to read back timestamp queries, dropping trace batch", err);
    },
  );
}
|
|
845
|
+
|
|
731
846
|
//#endregion
|
|
732
847
|
//#region src/backend/webgpu.ts
|
|
733
848
|
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
@@ -872,7 +987,7 @@ var WebGPUBackend = class {
|
|
|
872
987
|
dispatch(exe, inputs, outputs) {
|
|
873
988
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
874
989
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
875
|
-
pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
|
|
990
|
+
pipelineSubmit(this.device, exe, inputBuffers, outputBuffers);
|
|
876
991
|
}
|
|
877
992
|
#getBuffer(slot) {
|
|
878
993
|
const buffer = this.buffers.get(slot);
|
|
@@ -985,6 +1100,11 @@ function pipelineSource(device, kernel) {
|
|
|
985
1100
|
else source = `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
|
|
986
1101
|
else if (op === require_backend.AluOp.Max) if (dtype === require_backend.DType.Bool) source = `(${a} || ${b})`;
|
|
987
1102
|
else source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
|
|
1103
|
+
else if (op === require_backend.AluOp.BitCombine) if (arg === "and") source = `(${a} & ${b})`;
|
|
1104
|
+
else if (arg === "or") source = `(${a} | ${b})`;
|
|
1105
|
+
else source = dtype === require_backend.DType.Bool ? `(${a} != ${b})` : `(${a} ^ ${b})`;
|
|
1106
|
+
else if (op === require_backend.AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
|
|
1107
|
+
else source = `(${a} >> ${b})`;
|
|
988
1108
|
else if (op === require_backend.AluOp.Cmplt) source = `(${a} < ${b})`;
|
|
989
1109
|
else if (op === require_backend.AluOp.Cmpne) if (require_backend.isFloatDtype(src[0].dtype)) {
|
|
990
1110
|
const x = isGensym(a) ? a : gensym();
|
|
@@ -1010,8 +1130,16 @@ function pipelineSource(device, kernel) {
|
|
|
1010
1130
|
else if (op === require_backend.AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
1011
1131
|
else if (op === require_backend.AluOp.Floor) source = `floor(${require_backend.strip1(a)})`;
|
|
1012
1132
|
else if (op === require_backend.AluOp.Ceil) source = `ceil(${require_backend.strip1(a)})`;
|
|
1013
|
-
else if (op === require_backend.AluOp.Cast)
|
|
1014
|
-
|
|
1133
|
+
else if (op === require_backend.AluOp.Cast) {
|
|
1134
|
+
const srcTy = dtypeToWgsl(src[0].dtype);
|
|
1135
|
+
const dstTy = dtypeToWgsl(dtype);
|
|
1136
|
+
if (require_backend.isFloatDtype(src[0].dtype) && !(require_backend.isFloatDtype(dtype) || dtype === require_backend.DType.Bool)) {
|
|
1137
|
+
const maxVal = maxValueWgsl(dtype);
|
|
1138
|
+
const x = isGensym(a) ? a : gensym();
|
|
1139
|
+
if (x !== a) emit(`let ${x}: ${srcTy} = ${require_backend.strip1(a)};`);
|
|
1140
|
+
source = `select(${dstTy}(${x}), ${maxVal}, ${x} >= ${srcTy}(${maxVal}))`;
|
|
1141
|
+
} else source = `${dstTy}(${require_backend.strip1(a)})`;
|
|
1142
|
+
} else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
|
|
1015
1143
|
}
|
|
1016
1144
|
else if (op === require_backend.AluOp.Where) source = `select(${require_backend.strip1(gen(src[2]))}, ${require_backend.strip1(gen(src[1]))}, ${require_backend.strip1(gen(src[0]))})`;
|
|
1017
1145
|
else if (op === require_backend.AluOp.Threefry2x32) {
|
|
@@ -1114,12 +1242,14 @@ function pipelineSource(device, kernel) {
|
|
|
1114
1242
|
passes: [{ grid: [gridX, gridY] }]
|
|
1115
1243
|
};
|
|
1116
1244
|
}
|
|
1117
|
-
function pipelineSubmit(device, pipelines, inputs, outputs) {
|
|
1245
|
+
function pipelineSubmit(device, exe, inputs, outputs) {
|
|
1246
|
+
const { data: pipelines, source } = exe;
|
|
1118
1247
|
const commandEncoder = device.createCommandEncoder();
|
|
1119
1248
|
for (const { pipeline,...shader } of pipelines) {
|
|
1120
1249
|
if (inputs.length !== shader.numInputs || outputs.length !== shader.numOutputs) throw new Error(`webgpu: expected ${shader.numInputs} inputs and ${shader.numOutputs} outputs, got ${inputs.length} inputs and ${outputs.length} outputs`);
|
|
1121
1250
|
const filteredPasses = shader.passes.filter(({ grid }) => require_backend.prod(grid) > 0);
|
|
1122
1251
|
if (filteredPasses.length === 0) continue;
|
|
1252
|
+
const slot = maybeAcquireTracingSlot(device);
|
|
1123
1253
|
const bindGroup = device.createBindGroup({
|
|
1124
1254
|
layout: pipeline.getBindGroupLayout(0),
|
|
1125
1255
|
entries: [...inputs.map((buffer, i) => ({
|
|
@@ -1149,13 +1279,18 @@ function pipelineSubmit(device, pipelines, inputs, outputs) {
|
|
|
1149
1279
|
}
|
|
1150
1280
|
for (let i = 0; i < filteredPasses.length; i++) {
|
|
1151
1281
|
const { grid } = filteredPasses[i];
|
|
1152
|
-
const passEncoder = commandEncoder.beginComputePass();
|
|
1282
|
+
const passEncoder = commandEncoder.beginComputePass({ timestampWrites: slot ? {
|
|
1283
|
+
querySet: slot.batch.querySet,
|
|
1284
|
+
beginningOfPassWriteIndex: i === 0 ? slot.beginIndex : void 0,
|
|
1285
|
+
endOfPassWriteIndex: i === filteredPasses.length - 1 ? slot.endIndex : void 0
|
|
1286
|
+
} : void 0 });
|
|
1153
1287
|
passEncoder.setPipeline(pipeline);
|
|
1154
1288
|
passEncoder.setBindGroup(0, bindGroup);
|
|
1155
1289
|
if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
|
|
1156
1290
|
passEncoder.dispatchWorkgroups(grid[0], grid[1]);
|
|
1157
1291
|
passEncoder.end();
|
|
1158
1292
|
}
|
|
1293
|
+
if (slot) recordTrace(device, slot, source, filteredPasses.length, shader.code);
|
|
1159
1294
|
}
|
|
1160
1295
|
device.queue.submit([commandEncoder.finish()]);
|
|
1161
1296
|
}
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@jax-js/jax",
|
|
3
|
-
"version": "0.1.
|
|
3
|
+
"version": "0.1.11",
|
|
4
4
|
"description": "Numerical computing and ML in the browser",
|
|
5
5
|
"keywords": [
|
|
6
6
|
"machine learning",
|
|
@@ -38,15 +38,13 @@
|
|
|
38
38
|
"devDependencies": {
|
|
39
39
|
"@eslint/js": "^9.31.0",
|
|
40
40
|
"@types/debug": "^4.1.12",
|
|
41
|
-
"@vitest/browser-playwright": "^4.0
|
|
42
|
-
"@vitest/coverage-v8": "4.0
|
|
41
|
+
"@vitest/browser-playwright": "^4.1.0",
|
|
42
|
+
"@vitest/coverage-v8": "^4.1.0",
|
|
43
43
|
"@webgpu/types": "^0.1.68",
|
|
44
44
|
"eslint": "^9.31.0",
|
|
45
45
|
"eslint-plugin-import": "^2.32.0",
|
|
46
46
|
"globals": "^16.0.0",
|
|
47
|
-
"
|
|
48
|
-
"lint-staged": "^16.2.7",
|
|
49
|
-
"playwright": "~1.52.0",
|
|
47
|
+
"playwright": "~1.58.2",
|
|
50
48
|
"prettier": "^3.6.2",
|
|
51
49
|
"prettier-plugin-svelte": "^3.4.0",
|
|
52
50
|
"tsdown": "^0.13.2",
|
|
@@ -55,7 +53,7 @@
|
|
|
55
53
|
"typedoc-theme-fresh": "^0.2.3",
|
|
56
54
|
"typescript": "~5.9.3",
|
|
57
55
|
"typescript-eslint": "^8.46.4",
|
|
58
|
-
"vitest": "^4.0
|
|
56
|
+
"vitest": "^4.1.0"
|
|
59
57
|
},
|
|
60
58
|
"engines": {
|
|
61
59
|
"pnpm": ">=10.0.0"
|
|
@@ -76,15 +74,6 @@
|
|
|
76
74
|
],
|
|
77
75
|
"proseWrap": "always"
|
|
78
76
|
},
|
|
79
|
-
"lint-staged": {
|
|
80
|
-
"*.{ts,tsx,js,jsx}": [
|
|
81
|
-
"eslint --fix",
|
|
82
|
-
"prettier --write"
|
|
83
|
-
],
|
|
84
|
-
"*.{json,md,yml,yaml,css,svelte,html}": [
|
|
85
|
-
"prettier --write"
|
|
86
|
-
]
|
|
87
|
-
},
|
|
88
77
|
"scripts": {
|
|
89
78
|
"build": "tsdown",
|
|
90
79
|
"build:watch": "TSDOWN_WATCH_MODE=1 tsdown",
|