@jax-js/jax 0.1.8 → 0.1.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +46 -29
- package/dist/{backend-nEolvdLv.js → backend-Ctqs8la1.js} +122 -15
- package/dist/{backend-B3foXiV_.cjs → backend-DMauYnfl.cjs} +157 -14
- package/dist/index.cjs +331 -46
- package/dist/index.d.cts +175 -31
- package/dist/index.d.ts +175 -31
- package/dist/index.js +331 -47
- package/dist/{webgl-DweKSWEm.js → webgl-CvQ1QBX1.js} +1 -1
- package/dist/{webgl-DIIbKJ0G.cjs → webgl-kvVt7-T7.cjs} +1 -1
- package/dist/{webgpu-BykvF26B.cjs → webgpu-DMSx7a6M.cjs} +160 -15
- package/dist/{webgpu-B96vzWGE.js → webgpu-v_W_-oKw.js} +160 -15
- package/package.json +5 -16
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-B3foXiV_.cjs');
|
|
1
|
+
const require_backend = require('./backend-DMauYnfl.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -152,6 +152,7 @@ var SyncReader = class SyncReader {
|
|
|
152
152
|
this.device = device;
|
|
153
153
|
}
|
|
154
154
|
#init() {
|
|
155
|
+
if (typeof OffscreenCanvas === "undefined") throw new Error("OffscreenCanvas is not available in this environment, so you cannot read data from WebGPU synchronously. Consider using the async API.");
|
|
155
156
|
const makeCanvas = () => new OffscreenCanvas(SyncReader.width, SyncReader.height);
|
|
156
157
|
this.deviceStorage = SyncReader.alphaModes.map(makeCanvas);
|
|
157
158
|
this.deviceContexts = this.deviceStorage.map((canvas, i) => {
|
|
@@ -247,6 +248,10 @@ function bitonicSortUniform(pass) {
|
|
|
247
248
|
* `2^(step+1)` with multiple workgroups. This doesn't use shared memory.
|
|
248
249
|
*
|
|
249
250
|
* The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
|
|
251
|
+
*
|
|
252
|
+
* If `outputIndices` is true, the shader also tracks the original indices of
|
|
253
|
+
* the sorted elements (argsort) and outputs them to a separate buffer. This
|
|
254
|
+
* also makes the sorting algorithm stable.
|
|
250
255
|
*/
|
|
251
256
|
function bitonicSortShader(device, dtype, n, batches, outputIndices) {
|
|
252
257
|
const ty = dtypeToWgsl(dtype, true);
|
|
@@ -286,14 +291,21 @@ ${require_backend.isFloatDtype(dtype) ? `
|
|
|
286
291
|
fn compare_and_swap(i: u32, j: u32) {
|
|
287
292
|
let val_i = shared_vals[i];
|
|
288
293
|
let val_j = shared_vals[j];
|
|
289
|
-
|
|
294
|
+
${outputIndices ? `
|
|
295
|
+
if (
|
|
296
|
+
compare(val_j, val_i) ||
|
|
297
|
+
(!compare(val_i, val_j) && shared_idx[j] < shared_idx[i])
|
|
298
|
+
) {
|
|
290
299
|
shared_vals[i] = val_j;
|
|
291
300
|
shared_vals[j] = val_i;
|
|
292
|
-
${outputIndices ? `
|
|
293
301
|
let tmp_idx = shared_idx[i];
|
|
294
302
|
shared_idx[i] = shared_idx[j];
|
|
295
|
-
shared_idx[j] = tmp_idx
|
|
296
|
-
}
|
|
303
|
+
shared_idx[j] = tmp_idx;
|
|
304
|
+
}` : `
|
|
305
|
+
if (compare(val_j, val_i)) {
|
|
306
|
+
shared_vals[i] = val_j;
|
|
307
|
+
shared_vals[j] = val_i;
|
|
308
|
+
}`}
|
|
297
309
|
}
|
|
298
310
|
|
|
299
311
|
@compute @workgroup_size(${workgroupSize})
|
|
@@ -370,13 +382,17 @@ ${outputIndices ? `
|
|
|
370
382
|
if (j < ${n}u) {
|
|
371
383
|
let val_i = output[base + i];
|
|
372
384
|
let val_j = output[base + j];
|
|
373
|
-
|
|
385
|
+
${outputIndices ? `
|
|
386
|
+
let idx_i = output_idx[base + i];
|
|
387
|
+
let idx_j = output_idx[base + j];
|
|
388
|
+
if (compare(val_j, val_i) || (!compare(val_i, val_j) && idx_j < idx_i)) {
|
|
374
389
|
output[base + i] = val_j;
|
|
375
390
|
output[base + j] = val_i;
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
391
|
+
output_idx[base + i] = idx_j;
|
|
392
|
+
output_idx[base + j] = idx_i;` : `
|
|
393
|
+
if (compare(val_j, val_i)) {
|
|
394
|
+
output[base + i] = val_j;
|
|
395
|
+
output[base + j] = val_i;`}
|
|
380
396
|
}
|
|
381
397
|
}
|
|
382
398
|
}
|
|
@@ -713,6 +729,120 @@ function createRoutineShader(device, routine) {
|
|
|
713
729
|
}
|
|
714
730
|
}
|
|
715
731
|
|
|
732
|
+
//#endregion
|
|
733
|
+
//#region src/backend/webgpu/tracing.ts
|
|
734
|
+
const MAX_TIMESTAMP_QUERIES = 4096;
|
|
735
|
+
const activeBatch = /* @__PURE__ */ new WeakMap();
|
|
736
|
+
function createTracingBatch(device) {
|
|
737
|
+
return {
|
|
738
|
+
querySet: device.createQuerySet({
|
|
739
|
+
type: "timestamp",
|
|
740
|
+
count: MAX_TIMESTAMP_QUERIES
|
|
741
|
+
}),
|
|
742
|
+
resolve: device.createBuffer({
|
|
743
|
+
size: MAX_TIMESTAMP_QUERIES * 8,
|
|
744
|
+
usage: GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC
|
|
745
|
+
}),
|
|
746
|
+
dst: device.createBuffer({
|
|
747
|
+
size: MAX_TIMESTAMP_QUERIES * 8,
|
|
748
|
+
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST
|
|
749
|
+
}),
|
|
750
|
+
nextIndex: 0,
|
|
751
|
+
entries: []
|
|
752
|
+
};
|
|
753
|
+
}
|
|
754
|
+
function acquireTracingSlot(device) {
|
|
755
|
+
if (!device.features.has("timestamp-query")) return void 0;
|
|
756
|
+
let batch = activeBatch.get(device);
|
|
757
|
+
if (batch && batch.nextIndex >= MAX_TIMESTAMP_QUERIES) {
|
|
758
|
+
flushTracingBatch(device, batch);
|
|
759
|
+
batch = void 0;
|
|
760
|
+
}
|
|
761
|
+
if (!batch) {
|
|
762
|
+
batch = createTracingBatch(device);
|
|
763
|
+
activeBatch.set(device, batch);
|
|
764
|
+
require_backend.onFlushTrace(() => {
|
|
765
|
+
const b = activeBatch.get(device);
|
|
766
|
+
if (b && b.entries.length > 0) flushTracingBatch(device, b);
|
|
767
|
+
activeBatch.delete(device);
|
|
768
|
+
});
|
|
769
|
+
}
|
|
770
|
+
const beginIndex = batch.nextIndex;
|
|
771
|
+
const endIndex = beginIndex + 1;
|
|
772
|
+
batch.nextIndex += 2;
|
|
773
|
+
return {
|
|
774
|
+
batch,
|
|
775
|
+
beginIndex,
|
|
776
|
+
endIndex
|
|
777
|
+
};
|
|
778
|
+
}
|
|
779
|
+
/**
|
|
780
|
+
* If tracing is active, acquire a slot for timestamp queries.
|
|
781
|
+
*
|
|
782
|
+
* Returns undefined if tracing is not active or the device doesn't support
|
|
783
|
+
* timestamp queries.
|
|
784
|
+
*/
|
|
785
|
+
function maybeAcquireTracingSlot(device) {
|
|
786
|
+
if (!require_backend.isTracing()) return void 0;
|
|
787
|
+
return acquireTracingSlot(device);
|
|
788
|
+
}
|
|
789
|
+
/**
|
|
790
|
+
* Record a tracing entry for a pipeline dispatch and schedule an auto-flush.
|
|
791
|
+
*/
|
|
792
|
+
function recordTrace(device, slot, source, numPasses, wgslSource) {
|
|
793
|
+
const info = require_backend.traceSourceInfo(source);
|
|
794
|
+
info.properties.push(["passes", `${numPasses}`]);
|
|
795
|
+
info.properties.push(["source", wgslSource]);
|
|
796
|
+
slot.batch.entries.push({
|
|
797
|
+
...info,
|
|
798
|
+
beginIndex: slot.beginIndex,
|
|
799
|
+
endIndex: slot.endIndex
|
|
800
|
+
});
|
|
801
|
+
scheduleAutoFlush(device);
|
|
802
|
+
}
|
|
803
|
+
/**
|
|
804
|
+
* If the active batch has pending entries, flush and replace it so traces
|
|
805
|
+
* are emitted without waiting for the batch to fill or stopTrace().
|
|
806
|
+
*
|
|
807
|
+
* Called after each dispatch records its entry via a microtask so that
|
|
808
|
+
* synchronous back-to-back dispatches are still batched together.
|
|
809
|
+
*/
|
|
810
|
+
function scheduleAutoFlush(device) {
|
|
811
|
+
queueMicrotask(() => {
|
|
812
|
+
const batch = activeBatch.get(device);
|
|
813
|
+
if (batch && batch.entries.length > 0) {
|
|
814
|
+
flushTracingBatch(device, batch);
|
|
815
|
+
activeBatch.set(device, createTracingBatch(device));
|
|
816
|
+
}
|
|
817
|
+
});
|
|
818
|
+
}
|
|
819
|
+
function flushTracingBatch(device, batch) {
|
|
820
|
+
if (batch.entries.length === 0) return;
|
|
821
|
+
const usedQueries = batch.nextIndex;
|
|
822
|
+
const encoder = device.createCommandEncoder();
|
|
823
|
+
encoder.resolveQuerySet(batch.querySet, 0, usedQueries, batch.resolve, 0);
|
|
824
|
+
encoder.copyBufferToBuffer(batch.resolve, 0, batch.dst, 0, usedQueries * 8);
|
|
825
|
+
device.queue.submit([encoder.finish()]);
|
|
826
|
+
const { entries } = batch;
|
|
827
|
+
batch.dst.mapAsync(GPUMapMode.READ).then(() => {
|
|
828
|
+
try {
|
|
829
|
+
const times = new BigInt64Array(batch.dst.getMappedRange());
|
|
830
|
+
const anchorGpuNs = times[entries[entries.length - 1].endIndex];
|
|
831
|
+
const anchorCpuMs = performance.now();
|
|
832
|
+
for (const entry of entries) {
|
|
833
|
+
const startMs = anchorCpuMs + Number(times[entry.beginIndex] - anchorGpuNs) / 1e6;
|
|
834
|
+
const endMs = anchorCpuMs + Number(times[entry.endIndex] - anchorGpuNs) / 1e6;
|
|
835
|
+
require_backend.emitTrace("webgpu", entry, startMs, endMs);
|
|
836
|
+
}
|
|
837
|
+
} finally {
|
|
838
|
+
batch.dst.unmap();
|
|
839
|
+
batch.querySet.destroy();
|
|
840
|
+
batch.resolve.destroy();
|
|
841
|
+
batch.dst.destroy();
|
|
842
|
+
}
|
|
843
|
+
});
|
|
844
|
+
}
|
|
845
|
+
|
|
716
846
|
//#endregion
|
|
717
847
|
//#region src/backend/webgpu.ts
|
|
718
848
|
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
@@ -857,7 +987,7 @@ var WebGPUBackend = class {
|
|
|
857
987
|
dispatch(exe, inputs, outputs) {
|
|
858
988
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
859
989
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
860
|
-
pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
|
|
990
|
+
pipelineSubmit(this.device, exe, inputBuffers, outputBuffers);
|
|
861
991
|
}
|
|
862
992
|
#getBuffer(slot) {
|
|
863
993
|
const buffer = this.buffers.get(slot);
|
|
@@ -995,8 +1125,16 @@ function pipelineSource(device, kernel) {
|
|
|
995
1125
|
else if (op === require_backend.AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
996
1126
|
else if (op === require_backend.AluOp.Floor) source = `floor(${require_backend.strip1(a)})`;
|
|
997
1127
|
else if (op === require_backend.AluOp.Ceil) source = `ceil(${require_backend.strip1(a)})`;
|
|
998
|
-
else if (op === require_backend.AluOp.Cast)
|
|
999
|
-
|
|
1128
|
+
else if (op === require_backend.AluOp.Cast) {
|
|
1129
|
+
const srcTy = dtypeToWgsl(src[0].dtype);
|
|
1130
|
+
const dstTy = dtypeToWgsl(dtype);
|
|
1131
|
+
if (require_backend.isFloatDtype(src[0].dtype) && !(require_backend.isFloatDtype(dtype) || dtype === require_backend.DType.Bool)) {
|
|
1132
|
+
const maxVal = maxValueWgsl(dtype);
|
|
1133
|
+
const x = isGensym(a) ? a : gensym();
|
|
1134
|
+
if (x !== a) emit(`let ${x}: ${srcTy} = ${require_backend.strip1(a)};`);
|
|
1135
|
+
source = `select(${dstTy}(${x}), ${maxVal}, ${x} >= ${srcTy}(${maxVal}))`;
|
|
1136
|
+
} else source = `${dstTy}(${require_backend.strip1(a)})`;
|
|
1137
|
+
} else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
|
|
1000
1138
|
}
|
|
1001
1139
|
else if (op === require_backend.AluOp.Where) source = `select(${require_backend.strip1(gen(src[2]))}, ${require_backend.strip1(gen(src[1]))}, ${require_backend.strip1(gen(src[0]))})`;
|
|
1002
1140
|
else if (op === require_backend.AluOp.Threefry2x32) {
|
|
@@ -1099,12 +1237,14 @@ function pipelineSource(device, kernel) {
|
|
|
1099
1237
|
passes: [{ grid: [gridX, gridY] }]
|
|
1100
1238
|
};
|
|
1101
1239
|
}
|
|
1102
|
-
function pipelineSubmit(device, pipelines, inputs, outputs) {
|
|
1240
|
+
function pipelineSubmit(device, exe, inputs, outputs) {
|
|
1241
|
+
const { data: pipelines, source } = exe;
|
|
1103
1242
|
const commandEncoder = device.createCommandEncoder();
|
|
1104
1243
|
for (const { pipeline,...shader } of pipelines) {
|
|
1105
1244
|
if (inputs.length !== shader.numInputs || outputs.length !== shader.numOutputs) throw new Error(`webgpu: expected ${shader.numInputs} inputs and ${shader.numOutputs} outputs, got ${inputs.length} inputs and ${outputs.length} outputs`);
|
|
1106
1245
|
const filteredPasses = shader.passes.filter(({ grid }) => require_backend.prod(grid) > 0);
|
|
1107
1246
|
if (filteredPasses.length === 0) continue;
|
|
1247
|
+
const slot = maybeAcquireTracingSlot(device);
|
|
1108
1248
|
const bindGroup = device.createBindGroup({
|
|
1109
1249
|
layout: pipeline.getBindGroupLayout(0),
|
|
1110
1250
|
entries: [...inputs.map((buffer, i) => ({
|
|
@@ -1134,13 +1274,18 @@ function pipelineSubmit(device, pipelines, inputs, outputs) {
|
|
|
1134
1274
|
}
|
|
1135
1275
|
for (let i = 0; i < filteredPasses.length; i++) {
|
|
1136
1276
|
const { grid } = filteredPasses[i];
|
|
1137
|
-
const passEncoder = commandEncoder.beginComputePass();
|
|
1277
|
+
const passEncoder = commandEncoder.beginComputePass({ timestampWrites: slot ? {
|
|
1278
|
+
querySet: slot.batch.querySet,
|
|
1279
|
+
beginningOfPassWriteIndex: i === 0 ? slot.beginIndex : void 0,
|
|
1280
|
+
endOfPassWriteIndex: i === filteredPasses.length - 1 ? slot.endIndex : void 0
|
|
1281
|
+
} : void 0 });
|
|
1138
1282
|
passEncoder.setPipeline(pipeline);
|
|
1139
1283
|
passEncoder.setBindGroup(0, bindGroup);
|
|
1140
1284
|
if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
|
|
1141
1285
|
passEncoder.dispatchWorkgroups(grid[0], grid[1]);
|
|
1142
1286
|
passEncoder.end();
|
|
1143
1287
|
}
|
|
1288
|
+
if (slot) recordTrace(device, slot, source, filteredPasses.length, shader.code);
|
|
1144
1289
|
}
|
|
1145
1290
|
device.queue.submit([commandEncoder.finish()]);
|
|
1146
1291
|
}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-nEolvdLv.js";
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-Ctqs8la1.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -152,6 +152,7 @@ var SyncReader = class SyncReader {
|
|
|
152
152
|
this.device = device;
|
|
153
153
|
}
|
|
154
154
|
#init() {
|
|
155
|
+
if (typeof OffscreenCanvas === "undefined") throw new Error("OffscreenCanvas is not available in this environment, so you cannot read data from WebGPU synchronously. Consider using the async API.");
|
|
155
156
|
const makeCanvas = () => new OffscreenCanvas(SyncReader.width, SyncReader.height);
|
|
156
157
|
this.deviceStorage = SyncReader.alphaModes.map(makeCanvas);
|
|
157
158
|
this.deviceContexts = this.deviceStorage.map((canvas, i) => {
|
|
@@ -247,6 +248,10 @@ function bitonicSortUniform(pass) {
|
|
|
247
248
|
* `2^(step+1)` with multiple workgroups. This doesn't use shared memory.
|
|
248
249
|
*
|
|
249
250
|
* The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
|
|
251
|
+
*
|
|
252
|
+
* If `outputIndices` is true, the shader also tracks the original indices of
|
|
253
|
+
* the sorted elements (argsort) and outputs them to a separate buffer. This
|
|
254
|
+
* also makes the sorting algorithm stable.
|
|
250
255
|
*/
|
|
251
256
|
function bitonicSortShader(device, dtype, n, batches, outputIndices) {
|
|
252
257
|
const ty = dtypeToWgsl(dtype, true);
|
|
@@ -286,14 +291,21 @@ ${isFloatDtype(dtype) ? `
|
|
|
286
291
|
fn compare_and_swap(i: u32, j: u32) {
|
|
287
292
|
let val_i = shared_vals[i];
|
|
288
293
|
let val_j = shared_vals[j];
|
|
289
|
-
|
|
294
|
+
${outputIndices ? `
|
|
295
|
+
if (
|
|
296
|
+
compare(val_j, val_i) ||
|
|
297
|
+
(!compare(val_i, val_j) && shared_idx[j] < shared_idx[i])
|
|
298
|
+
) {
|
|
290
299
|
shared_vals[i] = val_j;
|
|
291
300
|
shared_vals[j] = val_i;
|
|
292
|
-
${outputIndices ? `
|
|
293
301
|
let tmp_idx = shared_idx[i];
|
|
294
302
|
shared_idx[i] = shared_idx[j];
|
|
295
|
-
shared_idx[j] = tmp_idx
|
|
296
|
-
}
|
|
303
|
+
shared_idx[j] = tmp_idx;
|
|
304
|
+
}` : `
|
|
305
|
+
if (compare(val_j, val_i)) {
|
|
306
|
+
shared_vals[i] = val_j;
|
|
307
|
+
shared_vals[j] = val_i;
|
|
308
|
+
}`}
|
|
297
309
|
}
|
|
298
310
|
|
|
299
311
|
@compute @workgroup_size(${workgroupSize})
|
|
@@ -370,13 +382,17 @@ ${outputIndices ? `
|
|
|
370
382
|
if (j < ${n}u) {
|
|
371
383
|
let val_i = output[base + i];
|
|
372
384
|
let val_j = output[base + j];
|
|
373
|
-
|
|
385
|
+
${outputIndices ? `
|
|
386
|
+
let idx_i = output_idx[base + i];
|
|
387
|
+
let idx_j = output_idx[base + j];
|
|
388
|
+
if (compare(val_j, val_i) || (!compare(val_i, val_j) && idx_j < idx_i)) {
|
|
374
389
|
output[base + i] = val_j;
|
|
375
390
|
output[base + j] = val_i;
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
391
|
+
output_idx[base + i] = idx_j;
|
|
392
|
+
output_idx[base + j] = idx_i;` : `
|
|
393
|
+
if (compare(val_j, val_i)) {
|
|
394
|
+
output[base + i] = val_j;
|
|
395
|
+
output[base + j] = val_i;`}
|
|
380
396
|
}
|
|
381
397
|
}
|
|
382
398
|
}
|
|
@@ -713,6 +729,120 @@ function createRoutineShader(device, routine) {
|
|
|
713
729
|
}
|
|
714
730
|
}
|
|
715
731
|
|
|
732
|
+
//#endregion
|
|
733
|
+
//#region src/backend/webgpu/tracing.ts
|
|
734
|
+
const MAX_TIMESTAMP_QUERIES = 4096;
|
|
735
|
+
const activeBatch = /* @__PURE__ */ new WeakMap();
|
|
736
|
+
function createTracingBatch(device) {
|
|
737
|
+
return {
|
|
738
|
+
querySet: device.createQuerySet({
|
|
739
|
+
type: "timestamp",
|
|
740
|
+
count: MAX_TIMESTAMP_QUERIES
|
|
741
|
+
}),
|
|
742
|
+
resolve: device.createBuffer({
|
|
743
|
+
size: MAX_TIMESTAMP_QUERIES * 8,
|
|
744
|
+
usage: GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC
|
|
745
|
+
}),
|
|
746
|
+
dst: device.createBuffer({
|
|
747
|
+
size: MAX_TIMESTAMP_QUERIES * 8,
|
|
748
|
+
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST
|
|
749
|
+
}),
|
|
750
|
+
nextIndex: 0,
|
|
751
|
+
entries: []
|
|
752
|
+
};
|
|
753
|
+
}
|
|
754
|
+
function acquireTracingSlot(device) {
|
|
755
|
+
if (!device.features.has("timestamp-query")) return void 0;
|
|
756
|
+
let batch = activeBatch.get(device);
|
|
757
|
+
if (batch && batch.nextIndex >= MAX_TIMESTAMP_QUERIES) {
|
|
758
|
+
flushTracingBatch(device, batch);
|
|
759
|
+
batch = void 0;
|
|
760
|
+
}
|
|
761
|
+
if (!batch) {
|
|
762
|
+
batch = createTracingBatch(device);
|
|
763
|
+
activeBatch.set(device, batch);
|
|
764
|
+
onFlushTrace(() => {
|
|
765
|
+
const b = activeBatch.get(device);
|
|
766
|
+
if (b && b.entries.length > 0) flushTracingBatch(device, b);
|
|
767
|
+
activeBatch.delete(device);
|
|
768
|
+
});
|
|
769
|
+
}
|
|
770
|
+
const beginIndex = batch.nextIndex;
|
|
771
|
+
const endIndex = beginIndex + 1;
|
|
772
|
+
batch.nextIndex += 2;
|
|
773
|
+
return {
|
|
774
|
+
batch,
|
|
775
|
+
beginIndex,
|
|
776
|
+
endIndex
|
|
777
|
+
};
|
|
778
|
+
}
|
|
779
|
+
/**
|
|
780
|
+
* If tracing is active, acquire a slot for timestamp queries.
|
|
781
|
+
*
|
|
782
|
+
* Returns undefined if tracing is not active or the device doesn't support
|
|
783
|
+
* timestamp queries.
|
|
784
|
+
*/
|
|
785
|
+
function maybeAcquireTracingSlot(device) {
|
|
786
|
+
if (!isTracing()) return void 0;
|
|
787
|
+
return acquireTracingSlot(device);
|
|
788
|
+
}
|
|
789
|
+
/**
|
|
790
|
+
* Record a tracing entry for a pipeline dispatch and schedule an auto-flush.
|
|
791
|
+
*/
|
|
792
|
+
function recordTrace(device, slot, source, numPasses, wgslSource) {
|
|
793
|
+
const info = traceSourceInfo(source);
|
|
794
|
+
info.properties.push(["passes", `${numPasses}`]);
|
|
795
|
+
info.properties.push(["source", wgslSource]);
|
|
796
|
+
slot.batch.entries.push({
|
|
797
|
+
...info,
|
|
798
|
+
beginIndex: slot.beginIndex,
|
|
799
|
+
endIndex: slot.endIndex
|
|
800
|
+
});
|
|
801
|
+
scheduleAutoFlush(device);
|
|
802
|
+
}
|
|
803
|
+
/**
|
|
804
|
+
* If the active batch has pending entries, flush and replace it so traces
|
|
805
|
+
* are emitted without waiting for the batch to fill or stopTrace().
|
|
806
|
+
*
|
|
807
|
+
* Called after each dispatch records its entry via a microtask so that
|
|
808
|
+
* synchronous back-to-back dispatches are still batched together.
|
|
809
|
+
*/
|
|
810
|
+
function scheduleAutoFlush(device) {
|
|
811
|
+
queueMicrotask(() => {
|
|
812
|
+
const batch = activeBatch.get(device);
|
|
813
|
+
if (batch && batch.entries.length > 0) {
|
|
814
|
+
flushTracingBatch(device, batch);
|
|
815
|
+
activeBatch.set(device, createTracingBatch(device));
|
|
816
|
+
}
|
|
817
|
+
});
|
|
818
|
+
}
|
|
819
|
+
function flushTracingBatch(device, batch) {
|
|
820
|
+
if (batch.entries.length === 0) return;
|
|
821
|
+
const usedQueries = batch.nextIndex;
|
|
822
|
+
const encoder = device.createCommandEncoder();
|
|
823
|
+
encoder.resolveQuerySet(batch.querySet, 0, usedQueries, batch.resolve, 0);
|
|
824
|
+
encoder.copyBufferToBuffer(batch.resolve, 0, batch.dst, 0, usedQueries * 8);
|
|
825
|
+
device.queue.submit([encoder.finish()]);
|
|
826
|
+
const { entries } = batch;
|
|
827
|
+
batch.dst.mapAsync(GPUMapMode.READ).then(() => {
|
|
828
|
+
try {
|
|
829
|
+
const times = new BigInt64Array(batch.dst.getMappedRange());
|
|
830
|
+
const anchorGpuNs = times[entries[entries.length - 1].endIndex];
|
|
831
|
+
const anchorCpuMs = performance.now();
|
|
832
|
+
for (const entry of entries) {
|
|
833
|
+
const startMs = anchorCpuMs + Number(times[entry.beginIndex] - anchorGpuNs) / 1e6;
|
|
834
|
+
const endMs = anchorCpuMs + Number(times[entry.endIndex] - anchorGpuNs) / 1e6;
|
|
835
|
+
emitTrace("webgpu", entry, startMs, endMs);
|
|
836
|
+
}
|
|
837
|
+
} finally {
|
|
838
|
+
batch.dst.unmap();
|
|
839
|
+
batch.querySet.destroy();
|
|
840
|
+
batch.resolve.destroy();
|
|
841
|
+
batch.dst.destroy();
|
|
842
|
+
}
|
|
843
|
+
});
|
|
844
|
+
}
|
|
845
|
+
|
|
716
846
|
//#endregion
|
|
717
847
|
//#region src/backend/webgpu.ts
|
|
718
848
|
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
@@ -857,7 +987,7 @@ var WebGPUBackend = class {
|
|
|
857
987
|
dispatch(exe, inputs, outputs) {
|
|
858
988
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
859
989
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
860
|
-
pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
|
|
990
|
+
pipelineSubmit(this.device, exe, inputBuffers, outputBuffers);
|
|
861
991
|
}
|
|
862
992
|
#getBuffer(slot) {
|
|
863
993
|
const buffer = this.buffers.get(slot);
|
|
@@ -995,8 +1125,16 @@ function pipelineSource(device, kernel) {
|
|
|
995
1125
|
else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
996
1126
|
else if (op === AluOp.Floor) source = `floor(${strip1(a)})`;
|
|
997
1127
|
else if (op === AluOp.Ceil) source = `ceil(${strip1(a)})`;
|
|
998
|
-
else if (op === AluOp.Cast)
|
|
999
|
-
|
|
1128
|
+
else if (op === AluOp.Cast) {
|
|
1129
|
+
const srcTy = dtypeToWgsl(src[0].dtype);
|
|
1130
|
+
const dstTy = dtypeToWgsl(dtype);
|
|
1131
|
+
if (isFloatDtype(src[0].dtype) && !(isFloatDtype(dtype) || dtype === DType.Bool)) {
|
|
1132
|
+
const maxVal = maxValueWgsl(dtype);
|
|
1133
|
+
const x = isGensym(a) ? a : gensym();
|
|
1134
|
+
if (x !== a) emit(`let ${x}: ${srcTy} = ${strip1(a)};`);
|
|
1135
|
+
source = `select(${dstTy}(${x}), ${maxVal}, ${x} >= ${srcTy}(${maxVal}))`;
|
|
1136
|
+
} else source = `${dstTy}(${strip1(a)})`;
|
|
1137
|
+
} else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
|
|
1000
1138
|
}
|
|
1001
1139
|
else if (op === AluOp.Where) source = `select(${strip1(gen(src[2]))}, ${strip1(gen(src[1]))}, ${strip1(gen(src[0]))})`;
|
|
1002
1140
|
else if (op === AluOp.Threefry2x32) {
|
|
@@ -1099,12 +1237,14 @@ function pipelineSource(device, kernel) {
|
|
|
1099
1237
|
passes: [{ grid: [gridX, gridY] }]
|
|
1100
1238
|
};
|
|
1101
1239
|
}
|
|
1102
|
-
function pipelineSubmit(device, pipelines, inputs, outputs) {
|
|
1240
|
+
function pipelineSubmit(device, exe, inputs, outputs) {
|
|
1241
|
+
const { data: pipelines, source } = exe;
|
|
1103
1242
|
const commandEncoder = device.createCommandEncoder();
|
|
1104
1243
|
for (const { pipeline,...shader } of pipelines) {
|
|
1105
1244
|
if (inputs.length !== shader.numInputs || outputs.length !== shader.numOutputs) throw new Error(`webgpu: expected ${shader.numInputs} inputs and ${shader.numOutputs} outputs, got ${inputs.length} inputs and ${outputs.length} outputs`);
|
|
1106
1245
|
const filteredPasses = shader.passes.filter(({ grid }) => prod(grid) > 0);
|
|
1107
1246
|
if (filteredPasses.length === 0) continue;
|
|
1247
|
+
const slot = maybeAcquireTracingSlot(device);
|
|
1108
1248
|
const bindGroup = device.createBindGroup({
|
|
1109
1249
|
layout: pipeline.getBindGroupLayout(0),
|
|
1110
1250
|
entries: [...inputs.map((buffer, i) => ({
|
|
@@ -1134,13 +1274,18 @@ function pipelineSubmit(device, pipelines, inputs, outputs) {
|
|
|
1134
1274
|
}
|
|
1135
1275
|
for (let i = 0; i < filteredPasses.length; i++) {
|
|
1136
1276
|
const { grid } = filteredPasses[i];
|
|
1137
|
-
const passEncoder = commandEncoder.beginComputePass();
|
|
1277
|
+
const passEncoder = commandEncoder.beginComputePass({ timestampWrites: slot ? {
|
|
1278
|
+
querySet: slot.batch.querySet,
|
|
1279
|
+
beginningOfPassWriteIndex: i === 0 ? slot.beginIndex : void 0,
|
|
1280
|
+
endOfPassWriteIndex: i === filteredPasses.length - 1 ? slot.endIndex : void 0
|
|
1281
|
+
} : void 0 });
|
|
1138
1282
|
passEncoder.setPipeline(pipeline);
|
|
1139
1283
|
passEncoder.setBindGroup(0, bindGroup);
|
|
1140
1284
|
if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
|
|
1141
1285
|
passEncoder.dispatchWorkgroups(grid[0], grid[1]);
|
|
1142
1286
|
passEncoder.end();
|
|
1143
1287
|
}
|
|
1288
|
+
if (slot) recordTrace(device, slot, source, filteredPasses.length, shader.code);
|
|
1144
1289
|
}
|
|
1145
1290
|
device.queue.submit([commandEncoder.finish()]);
|
|
1146
1291
|
}
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@jax-js/jax",
|
|
3
|
-
"version": "0.1.8",
|
|
3
|
+
"version": "0.1.10",
|
|
4
4
|
"description": "Numerical computing and ML in the browser",
|
|
5
5
|
"keywords": [
|
|
6
6
|
"machine learning",
|
|
@@ -38,15 +38,13 @@
|
|
|
38
38
|
"devDependencies": {
|
|
39
39
|
"@eslint/js": "^9.31.0",
|
|
40
40
|
"@types/debug": "^4.1.12",
|
|
41
|
-
"@vitest/browser-playwright": "^4.0
|
|
42
|
-
"@vitest/coverage-v8": "4.0
|
|
41
|
+
"@vitest/browser-playwright": "^4.1.0",
|
|
42
|
+
"@vitest/coverage-v8": "^4.1.0",
|
|
43
43
|
"@webgpu/types": "^0.1.68",
|
|
44
44
|
"eslint": "^9.31.0",
|
|
45
45
|
"eslint-plugin-import": "^2.32.0",
|
|
46
46
|
"globals": "^16.0.0",
|
|
47
|
-
"
|
|
48
|
-
"lint-staged": "^16.2.7",
|
|
49
|
-
"playwright": "~1.52.0",
|
|
47
|
+
"playwright": "~1.58.2",
|
|
50
48
|
"prettier": "^3.6.2",
|
|
51
49
|
"prettier-plugin-svelte": "^3.4.0",
|
|
52
50
|
"tsdown": "^0.13.2",
|
|
@@ -55,7 +53,7 @@
|
|
|
55
53
|
"typedoc-theme-fresh": "^0.2.3",
|
|
56
54
|
"typescript": "~5.9.3",
|
|
57
55
|
"typescript-eslint": "^8.46.4",
|
|
58
|
-
"vitest": "^4.0
|
|
56
|
+
"vitest": "^4.1.0"
|
|
59
57
|
},
|
|
60
58
|
"engines": {
|
|
61
59
|
"pnpm": ">=10.0.0"
|
|
@@ -76,15 +74,6 @@
|
|
|
76
74
|
],
|
|
77
75
|
"proseWrap": "always"
|
|
78
76
|
},
|
|
79
|
-
"lint-staged": {
|
|
80
|
-
"*.{ts,tsx,js,jsx}": [
|
|
81
|
-
"eslint --fix",
|
|
82
|
-
"prettier --write"
|
|
83
|
-
],
|
|
84
|
-
"*.{json,md,yml,yaml,css,svelte,html}": [
|
|
85
|
-
"prettier --write"
|
|
86
|
-
]
|
|
87
|
-
},
|
|
88
77
|
"scripts": {
|
|
89
78
|
"build": "tsdown",
|
|
90
79
|
"build:watch": "TSDOWN_WATCH_MODE=1 tsdown",
|