@jax-js/jax 0.1.9 → 0.1.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +31 -18
- package/dist/{backend-BId79r5b.js → backend-Ctqs8la1.js} +107 -11
- package/dist/{backend-DpI0riom.cjs → backend-DMauYnfl.cjs} +142 -10
- package/dist/index.cjs +225 -18
- package/dist/index.d.cts +112 -11
- package/dist/index.d.ts +112 -11
- package/dist/index.js +225 -19
- package/dist/{webgl-DnGrclTz.js → webgl-CvQ1QBX1.js} +1 -1
- package/dist/{webgl-C5NjXc1p.cjs → webgl-kvVt7-T7.cjs} +1 -1
- package/dist/{webgpu-CdjiJSa7.cjs → webgpu-DMSx7a6M.cjs} +136 -6
- package/dist/{webgpu-AN0cG_nB.js → webgpu-v_W_-oKw.js} +136 -6
- package/package.json +5 -16
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-
|
|
1
|
+
const require_backend = require('./backend-DMauYnfl.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -152,6 +152,7 @@ var SyncReader = class SyncReader {
|
|
|
152
152
|
this.device = device;
|
|
153
153
|
}
|
|
154
154
|
#init() {
|
|
155
|
+
if (typeof OffscreenCanvas === "undefined") throw new Error("OffscreenCanvas is not available in this environment, so you cannot read data from WebGPU synchronously. Consider using the async API.");
|
|
155
156
|
const makeCanvas = () => new OffscreenCanvas(SyncReader.width, SyncReader.height);
|
|
156
157
|
this.deviceStorage = SyncReader.alphaModes.map(makeCanvas);
|
|
157
158
|
this.deviceContexts = this.deviceStorage.map((canvas, i) => {
|
|
@@ -728,6 +729,120 @@ function createRoutineShader(device, routine) {
|
|
|
728
729
|
}
|
|
729
730
|
}
|
|
730
731
|
|
|
732
|
+
//#endregion
|
|
733
|
+
//#region src/backend/webgpu/tracing.ts
|
|
734
|
+
/** Capacity of each timestamp query set; every traced pipeline consumes two queries (begin + end). */
const MAX_TIMESTAMP_QUERIES = 2 ** 12;
/** Tracing batch currently receiving timestamp writes, keyed by GPU device. */
const activeBatch = /* @__PURE__ */ new WeakMap();
|
|
736
|
+
/**
 * Allocate a fresh tracing batch for `device`.
 *
 * The batch holds a timestamp query set, a resolve buffer that receives the
 * resolved query values, and a mappable staging buffer the CPU reads back.
 * It starts with no reserved query indices and no recorded entries.
 */
function createTracingBatch(device) {
  // Each resolved timestamp occupies 8 bytes (64-bit value).
  const byteLength = MAX_TIMESTAMP_QUERIES * 8;
  const querySet = device.createQuerySet({ type: "timestamp", count: MAX_TIMESTAMP_QUERIES });
  const resolve = device.createBuffer({
    size: byteLength,
    usage: GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC,
  });
  const dst = device.createBuffer({
    size: byteLength,
    usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
  });
  return { querySet, resolve, dst, nextIndex: 0, entries: [] };
}
|
|
754
|
+
/**
 * Reserve a begin/end pair of timestamp-query indices in the device's active
 * tracing batch.
 *
 * A new batch is created when none is active or the current one is full.
 * Creating a batch also registers a stop-trace hook that flushes the active
 * batch, or destroys it if it recorded nothing.
 *
 * Returns undefined when the device lacks the "timestamp-query" feature;
 * otherwise `{ batch, beginIndex, endIndex }`.
 */
function acquireTracingSlot(device) {
  if (!device.features.has("timestamp-query")) return void 0;
  let batch = activeBatch.get(device);
  if (batch && batch.nextIndex >= MAX_TIMESTAMP_QUERIES) {
    // The full batch may still be referenced by a command encoder the caller
    // has not submitted yet (pipelineSubmit acquires one slot per pipeline and
    // submits once at the end). Resolving it synchronously would queue the
    // resolve ahead of those timestamp writes, so defer the flush until the
    // current synchronous work (including that submit) has finished.
    const fullBatch = batch;
    queueMicrotask(() => flushTracingBatch(device, fullBatch));
    batch = void 0;
  }
  if (!batch) {
    batch = createTracingBatch(device);
    activeBatch.set(device, batch);
    require_backend.onFlushTrace(() => {
      const b = activeBatch.get(device);
      if (b) {
        if (b.entries.length > 0) {
          flushTracingBatch(device, b);
        } else {
          // Nothing to read back: release the GPU objects instead of leaking them.
          b.querySet.destroy();
          b.resolve.destroy();
          b.dst.destroy();
        }
      }
      activeBatch.delete(device);
    });
  }
  const beginIndex = batch.nextIndex;
  const endIndex = beginIndex + 1;
  batch.nextIndex += 2;
  return {
    batch,
    beginIndex,
    endIndex
  };
}
|
|
779
|
+
/**
 * Reserve timestamp-query slots for one pipeline dispatch, but only while a
 * trace is being recorded.
 *
 * Yields undefined when tracing is off or when the device lacks
 * timestamp-query support.
 */
function maybeAcquireTracingSlot(device) {
  return require_backend.isTracing() ? acquireTracingSlot(device) : void 0;
}
|
|
789
|
+
/**
 * Append a trace entry for one pipeline dispatch to the slot's batch, then
 * queue an auto-flush for the device.
 *
 * The entry combines the source info with a "passes" count and the WGSL code,
 * tagged with the slot's begin/end query indices.
 */
function recordTrace(device, slot, source, numPasses, wgslSource) {
  const info = require_backend.traceSourceInfo(source);
  info.properties.push(["passes", `${numPasses}`], ["source", wgslSource]);
  const { batch, beginIndex, endIndex } = slot;
  batch.entries.push({ ...info, beginIndex, endIndex });
  scheduleAutoFlush(device);
}
|
|
803
|
+
/**
 * Queue a microtask that flushes the device's active batch if it has recorded
 * entries, and then installs a fresh batch in its place.
 *
 * Deferring to a microtask lets synchronous back-to-back dispatches share one
 * batch; any later microtasks queued in the same turn find the fresh, empty
 * batch and do nothing.
 */
function scheduleAutoFlush(device) {
  const flushPending = () => {
    const pending = activeBatch.get(device);
    if (!pending || pending.entries.length === 0) return;
    flushTracingBatch(device, pending);
    activeBatch.set(device, createTracingBatch(device));
  };
  queueMicrotask(flushPending);
}
|
|
819
|
+
/**
 * Resolve the batch's used timestamp queries, read them back, and emit one
 * "webgpu" trace event per recorded entry.
 *
 * GPU timestamps are placed on the performance.now() timeline by aligning the
 * last entry's end timestamp with performance.now() taken when the read-back
 * completes. This is an approximation that ignores read-back latency.
 *
 * The batch's query set and buffers are always destroyed afterwards. That
 * includes the cases where there is nothing to read back or the read-back
 * fails, so a batch must not be reused once it has been passed here.
 */
function flushTracingBatch(device, batch) {
  const release = () => {
    batch.querySet.destroy();
    batch.resolve.destroy();
    batch.dst.destroy();
  };
  if (batch.entries.length === 0) {
    release();
    return;
  }
  const usedQueries = batch.nextIndex;
  const encoder = device.createCommandEncoder();
  encoder.resolveQuerySet(batch.querySet, 0, usedQueries, batch.resolve, 0);
  // 8 bytes per resolved 64-bit timestamp.
  encoder.copyBufferToBuffer(batch.resolve, 0, batch.dst, 0, usedQueries * 8);
  device.queue.submit([encoder.finish()]);
  const { entries } = batch;
  batch.dst.mapAsync(GPUMapMode.READ).then(() => {
    try {
      const times = new BigInt64Array(batch.dst.getMappedRange());
      const anchorGpuNs = times[entries[entries.length - 1].endIndex];
      const anchorCpuMs = performance.now();
      for (const entry of entries) {
        const startMs = anchorCpuMs + Number(times[entry.beginIndex] - anchorGpuNs) / 1e6;
        const endMs = anchorCpuMs + Number(times[entry.endIndex] - anchorGpuNs) / 1e6;
        require_backend.emitTrace("webgpu", entry, startMs, endMs);
      }
    } finally {
      batch.dst.unmap();
    }
  }).catch((err) => {
    // Tracing is best-effort: report the failure (e.g. device loss) instead of
    // leaving an unhandled rejection.
    console.warn("webgpu: failed to flush tracing timestamps", err);
  }).finally(release);
}
|
|
845
|
+
|
|
731
846
|
//#endregion
|
|
732
847
|
//#region src/backend/webgpu.ts
|
|
733
848
|
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
@@ -872,7 +987,7 @@ var WebGPUBackend = class {
|
|
|
872
987
|
dispatch(exe, inputs, outputs) {
|
|
873
988
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
874
989
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
875
|
-
pipelineSubmit(this.device, exe
|
|
990
|
+
pipelineSubmit(this.device, exe, inputBuffers, outputBuffers);
|
|
876
991
|
}
|
|
877
992
|
#getBuffer(slot) {
|
|
878
993
|
const buffer = this.buffers.get(slot);
|
|
@@ -1010,8 +1125,16 @@ function pipelineSource(device, kernel) {
|
|
|
1010
1125
|
else if (op === require_backend.AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
1011
1126
|
else if (op === require_backend.AluOp.Floor) source = `floor(${require_backend.strip1(a)})`;
|
|
1012
1127
|
else if (op === require_backend.AluOp.Ceil) source = `ceil(${require_backend.strip1(a)})`;
|
|
1013
|
-
else if (op === require_backend.AluOp.Cast)
|
|
1014
|
-
|
|
1128
|
+
else if (op === require_backend.AluOp.Cast) {
|
|
1129
|
+
const srcTy = dtypeToWgsl(src[0].dtype);
|
|
1130
|
+
const dstTy = dtypeToWgsl(dtype);
|
|
1131
|
+
if (require_backend.isFloatDtype(src[0].dtype) && !(require_backend.isFloatDtype(dtype) || dtype === require_backend.DType.Bool)) {
|
|
1132
|
+
const maxVal = maxValueWgsl(dtype);
|
|
1133
|
+
const x = isGensym(a) ? a : gensym();
|
|
1134
|
+
if (x !== a) emit(`let ${x}: ${srcTy} = ${require_backend.strip1(a)};`);
|
|
1135
|
+
source = `select(${dstTy}(${x}), ${maxVal}, ${x} >= ${srcTy}(${maxVal}))`;
|
|
1136
|
+
} else source = `${dstTy}(${require_backend.strip1(a)})`;
|
|
1137
|
+
} else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
|
|
1015
1138
|
}
|
|
1016
1139
|
else if (op === require_backend.AluOp.Where) source = `select(${require_backend.strip1(gen(src[2]))}, ${require_backend.strip1(gen(src[1]))}, ${require_backend.strip1(gen(src[0]))})`;
|
|
1017
1140
|
else if (op === require_backend.AluOp.Threefry2x32) {
|
|
@@ -1114,12 +1237,14 @@ function pipelineSource(device, kernel) {
|
|
|
1114
1237
|
passes: [{ grid: [gridX, gridY] }]
|
|
1115
1238
|
};
|
|
1116
1239
|
}
|
|
1117
|
-
function pipelineSubmit(device,
|
|
1240
|
+
function pipelineSubmit(device, exe, inputs, outputs) {
|
|
1241
|
+
const { data: pipelines, source } = exe;
|
|
1118
1242
|
const commandEncoder = device.createCommandEncoder();
|
|
1119
1243
|
for (const { pipeline,...shader } of pipelines) {
|
|
1120
1244
|
if (inputs.length !== shader.numInputs || outputs.length !== shader.numOutputs) throw new Error(`webgpu: expected ${shader.numInputs} inputs and ${shader.numOutputs} outputs, got ${inputs.length} inputs and ${outputs.length} outputs`);
|
|
1121
1245
|
const filteredPasses = shader.passes.filter(({ grid }) => require_backend.prod(grid) > 0);
|
|
1122
1246
|
if (filteredPasses.length === 0) continue;
|
|
1247
|
+
const slot = maybeAcquireTracingSlot(device);
|
|
1123
1248
|
const bindGroup = device.createBindGroup({
|
|
1124
1249
|
layout: pipeline.getBindGroupLayout(0),
|
|
1125
1250
|
entries: [...inputs.map((buffer, i) => ({
|
|
@@ -1149,13 +1274,18 @@ function pipelineSubmit(device, pipelines, inputs, outputs) {
|
|
|
1149
1274
|
}
|
|
1150
1275
|
for (let i = 0; i < filteredPasses.length; i++) {
|
|
1151
1276
|
const { grid } = filteredPasses[i];
|
|
1152
|
-
const passEncoder = commandEncoder.beginComputePass(
|
|
1277
|
+
const passEncoder = commandEncoder.beginComputePass({ timestampWrites: slot ? {
|
|
1278
|
+
querySet: slot.batch.querySet,
|
|
1279
|
+
beginningOfPassWriteIndex: i === 0 ? slot.beginIndex : void 0,
|
|
1280
|
+
endOfPassWriteIndex: i === filteredPasses.length - 1 ? slot.endIndex : void 0
|
|
1281
|
+
} : void 0 });
|
|
1153
1282
|
passEncoder.setPipeline(pipeline);
|
|
1154
1283
|
passEncoder.setBindGroup(0, bindGroup);
|
|
1155
1284
|
if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
|
|
1156
1285
|
passEncoder.dispatchWorkgroups(grid[0], grid[1]);
|
|
1157
1286
|
passEncoder.end();
|
|
1158
1287
|
}
|
|
1288
|
+
if (slot) recordTrace(device, slot, source, filteredPasses.length, shader.code);
|
|
1159
1289
|
}
|
|
1160
1290
|
device.queue.submit([commandEncoder.finish()]);
|
|
1161
1291
|
}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-Ctqs8la1.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -152,6 +152,7 @@ var SyncReader = class SyncReader {
|
|
|
152
152
|
this.device = device;
|
|
153
153
|
}
|
|
154
154
|
#init() {
|
|
155
|
+
if (typeof OffscreenCanvas === "undefined") throw new Error("OffscreenCanvas is not available in this environment, so you cannot read data from WebGPU synchronously. Consider using the async API.");
|
|
155
156
|
const makeCanvas = () => new OffscreenCanvas(SyncReader.width, SyncReader.height);
|
|
156
157
|
this.deviceStorage = SyncReader.alphaModes.map(makeCanvas);
|
|
157
158
|
this.deviceContexts = this.deviceStorage.map((canvas, i) => {
|
|
@@ -728,6 +729,120 @@ function createRoutineShader(device, routine) {
|
|
|
728
729
|
}
|
|
729
730
|
}
|
|
730
731
|
|
|
732
|
+
//#endregion
|
|
733
|
+
//#region src/backend/webgpu/tracing.ts
|
|
734
|
+
/** Capacity of each timestamp query set; every traced pipeline consumes two queries (begin + end). */
const MAX_TIMESTAMP_QUERIES = 2 ** 12;
/** Tracing batch currently receiving timestamp writes, keyed by GPU device. */
const activeBatch = /* @__PURE__ */ new WeakMap();
|
|
736
|
+
/**
 * Allocate a fresh tracing batch for `device`.
 *
 * The batch holds a timestamp query set, a resolve buffer that receives the
 * resolved query values, and a mappable staging buffer the CPU reads back.
 * It starts with no reserved query indices and no recorded entries.
 */
function createTracingBatch(device) {
  // Each resolved timestamp occupies 8 bytes (64-bit value).
  const byteLength = MAX_TIMESTAMP_QUERIES * 8;
  const querySet = device.createQuerySet({ type: "timestamp", count: MAX_TIMESTAMP_QUERIES });
  const resolve = device.createBuffer({
    size: byteLength,
    usage: GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC,
  });
  const dst = device.createBuffer({
    size: byteLength,
    usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
  });
  return { querySet, resolve, dst, nextIndex: 0, entries: [] };
}
|
|
754
|
+
/**
 * Reserve a begin/end pair of timestamp-query indices in the device's active
 * tracing batch.
 *
 * A new batch is created when none is active or the current one is full.
 * Creating a batch also registers a stop-trace hook that flushes the active
 * batch, or destroys it if it recorded nothing.
 *
 * Returns undefined when the device lacks the "timestamp-query" feature;
 * otherwise `{ batch, beginIndex, endIndex }`.
 */
function acquireTracingSlot(device) {
  if (!device.features.has("timestamp-query")) return void 0;
  let batch = activeBatch.get(device);
  if (batch && batch.nextIndex >= MAX_TIMESTAMP_QUERIES) {
    // The full batch may still be referenced by a command encoder the caller
    // has not submitted yet (pipelineSubmit acquires one slot per pipeline and
    // submits once at the end). Resolving it synchronously would queue the
    // resolve ahead of those timestamp writes, so defer the flush until the
    // current synchronous work (including that submit) has finished.
    const fullBatch = batch;
    queueMicrotask(() => flushTracingBatch(device, fullBatch));
    batch = void 0;
  }
  if (!batch) {
    batch = createTracingBatch(device);
    activeBatch.set(device, batch);
    onFlushTrace(() => {
      const b = activeBatch.get(device);
      if (b) {
        if (b.entries.length > 0) {
          flushTracingBatch(device, b);
        } else {
          // Nothing to read back: release the GPU objects instead of leaking them.
          b.querySet.destroy();
          b.resolve.destroy();
          b.dst.destroy();
        }
      }
      activeBatch.delete(device);
    });
  }
  const beginIndex = batch.nextIndex;
  const endIndex = beginIndex + 1;
  batch.nextIndex += 2;
  return {
    batch,
    beginIndex,
    endIndex
  };
}
|
|
779
|
+
/**
 * Reserve timestamp-query slots for one pipeline dispatch, but only while a
 * trace is being recorded.
 *
 * Yields undefined when tracing is off or when the device lacks
 * timestamp-query support.
 */
function maybeAcquireTracingSlot(device) {
  return isTracing() ? acquireTracingSlot(device) : void 0;
}
|
|
789
|
+
/**
 * Append a trace entry for one pipeline dispatch to the slot's batch, then
 * queue an auto-flush for the device.
 *
 * The entry combines the source info with a "passes" count and the WGSL code,
 * tagged with the slot's begin/end query indices.
 */
function recordTrace(device, slot, source, numPasses, wgslSource) {
  const info = traceSourceInfo(source);
  info.properties.push(["passes", `${numPasses}`], ["source", wgslSource]);
  const { batch, beginIndex, endIndex } = slot;
  batch.entries.push({ ...info, beginIndex, endIndex });
  scheduleAutoFlush(device);
}
|
|
803
|
+
/**
 * Queue a microtask that flushes the device's active batch if it has recorded
 * entries, and then installs a fresh batch in its place.
 *
 * Deferring to a microtask lets synchronous back-to-back dispatches share one
 * batch; any later microtasks queued in the same turn find the fresh, empty
 * batch and do nothing.
 */
function scheduleAutoFlush(device) {
  const flushPending = () => {
    const pending = activeBatch.get(device);
    if (!pending || pending.entries.length === 0) return;
    flushTracingBatch(device, pending);
    activeBatch.set(device, createTracingBatch(device));
  };
  queueMicrotask(flushPending);
}
|
|
819
|
+
/**
 * Resolve the batch's used timestamp queries, read them back, and emit one
 * "webgpu" trace event per recorded entry.
 *
 * GPU timestamps are placed on the performance.now() timeline by aligning the
 * last entry's end timestamp with performance.now() taken when the read-back
 * completes. This is an approximation that ignores read-back latency.
 *
 * The batch's query set and buffers are always destroyed afterwards. That
 * includes the cases where there is nothing to read back or the read-back
 * fails, so a batch must not be reused once it has been passed here.
 */
function flushTracingBatch(device, batch) {
  const release = () => {
    batch.querySet.destroy();
    batch.resolve.destroy();
    batch.dst.destroy();
  };
  if (batch.entries.length === 0) {
    release();
    return;
  }
  const usedQueries = batch.nextIndex;
  const encoder = device.createCommandEncoder();
  encoder.resolveQuerySet(batch.querySet, 0, usedQueries, batch.resolve, 0);
  // 8 bytes per resolved 64-bit timestamp.
  encoder.copyBufferToBuffer(batch.resolve, 0, batch.dst, 0, usedQueries * 8);
  device.queue.submit([encoder.finish()]);
  const { entries } = batch;
  batch.dst.mapAsync(GPUMapMode.READ).then(() => {
    try {
      const times = new BigInt64Array(batch.dst.getMappedRange());
      const anchorGpuNs = times[entries[entries.length - 1].endIndex];
      const anchorCpuMs = performance.now();
      for (const entry of entries) {
        const startMs = anchorCpuMs + Number(times[entry.beginIndex] - anchorGpuNs) / 1e6;
        const endMs = anchorCpuMs + Number(times[entry.endIndex] - anchorGpuNs) / 1e6;
        emitTrace("webgpu", entry, startMs, endMs);
      }
    } finally {
      batch.dst.unmap();
    }
  }).catch((err) => {
    // Tracing is best-effort: report the failure (e.g. device loss) instead of
    // leaving an unhandled rejection.
    console.warn("webgpu: failed to flush tracing timestamps", err);
  }).finally(release);
}
|
|
845
|
+
|
|
731
846
|
//#endregion
|
|
732
847
|
//#region src/backend/webgpu.ts
|
|
733
848
|
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
@@ -872,7 +987,7 @@ var WebGPUBackend = class {
|
|
|
872
987
|
dispatch(exe, inputs, outputs) {
|
|
873
988
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
874
989
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
875
|
-
pipelineSubmit(this.device, exe
|
|
990
|
+
pipelineSubmit(this.device, exe, inputBuffers, outputBuffers);
|
|
876
991
|
}
|
|
877
992
|
#getBuffer(slot) {
|
|
878
993
|
const buffer = this.buffers.get(slot);
|
|
@@ -1010,8 +1125,16 @@ function pipelineSource(device, kernel) {
|
|
|
1010
1125
|
else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
1011
1126
|
else if (op === AluOp.Floor) source = `floor(${strip1(a)})`;
|
|
1012
1127
|
else if (op === AluOp.Ceil) source = `ceil(${strip1(a)})`;
|
|
1013
|
-
else if (op === AluOp.Cast)
|
|
1014
|
-
|
|
1128
|
+
else if (op === AluOp.Cast) {
|
|
1129
|
+
const srcTy = dtypeToWgsl(src[0].dtype);
|
|
1130
|
+
const dstTy = dtypeToWgsl(dtype);
|
|
1131
|
+
if (isFloatDtype(src[0].dtype) && !(isFloatDtype(dtype) || dtype === DType.Bool)) {
|
|
1132
|
+
const maxVal = maxValueWgsl(dtype);
|
|
1133
|
+
const x = isGensym(a) ? a : gensym();
|
|
1134
|
+
if (x !== a) emit(`let ${x}: ${srcTy} = ${strip1(a)};`);
|
|
1135
|
+
source = `select(${dstTy}(${x}), ${maxVal}, ${x} >= ${srcTy}(${maxVal}))`;
|
|
1136
|
+
} else source = `${dstTy}(${strip1(a)})`;
|
|
1137
|
+
} else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
|
|
1015
1138
|
}
|
|
1016
1139
|
else if (op === AluOp.Where) source = `select(${strip1(gen(src[2]))}, ${strip1(gen(src[1]))}, ${strip1(gen(src[0]))})`;
|
|
1017
1140
|
else if (op === AluOp.Threefry2x32) {
|
|
@@ -1114,12 +1237,14 @@ function pipelineSource(device, kernel) {
|
|
|
1114
1237
|
passes: [{ grid: [gridX, gridY] }]
|
|
1115
1238
|
};
|
|
1116
1239
|
}
|
|
1117
|
-
function pipelineSubmit(device,
|
|
1240
|
+
function pipelineSubmit(device, exe, inputs, outputs) {
|
|
1241
|
+
const { data: pipelines, source } = exe;
|
|
1118
1242
|
const commandEncoder = device.createCommandEncoder();
|
|
1119
1243
|
for (const { pipeline,...shader } of pipelines) {
|
|
1120
1244
|
if (inputs.length !== shader.numInputs || outputs.length !== shader.numOutputs) throw new Error(`webgpu: expected ${shader.numInputs} inputs and ${shader.numOutputs} outputs, got ${inputs.length} inputs and ${outputs.length} outputs`);
|
|
1121
1245
|
const filteredPasses = shader.passes.filter(({ grid }) => prod(grid) > 0);
|
|
1122
1246
|
if (filteredPasses.length === 0) continue;
|
|
1247
|
+
const slot = maybeAcquireTracingSlot(device);
|
|
1123
1248
|
const bindGroup = device.createBindGroup({
|
|
1124
1249
|
layout: pipeline.getBindGroupLayout(0),
|
|
1125
1250
|
entries: [...inputs.map((buffer, i) => ({
|
|
@@ -1149,13 +1274,18 @@ function pipelineSubmit(device, pipelines, inputs, outputs) {
|
|
|
1149
1274
|
}
|
|
1150
1275
|
for (let i = 0; i < filteredPasses.length; i++) {
|
|
1151
1276
|
const { grid } = filteredPasses[i];
|
|
1152
|
-
const passEncoder = commandEncoder.beginComputePass(
|
|
1277
|
+
const passEncoder = commandEncoder.beginComputePass({ timestampWrites: slot ? {
|
|
1278
|
+
querySet: slot.batch.querySet,
|
|
1279
|
+
beginningOfPassWriteIndex: i === 0 ? slot.beginIndex : void 0,
|
|
1280
|
+
endOfPassWriteIndex: i === filteredPasses.length - 1 ? slot.endIndex : void 0
|
|
1281
|
+
} : void 0 });
|
|
1153
1282
|
passEncoder.setPipeline(pipeline);
|
|
1154
1283
|
passEncoder.setBindGroup(0, bindGroup);
|
|
1155
1284
|
if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
|
|
1156
1285
|
passEncoder.dispatchWorkgroups(grid[0], grid[1]);
|
|
1157
1286
|
passEncoder.end();
|
|
1158
1287
|
}
|
|
1288
|
+
if (slot) recordTrace(device, slot, source, filteredPasses.length, shader.code);
|
|
1159
1289
|
}
|
|
1160
1290
|
device.queue.submit([commandEncoder.finish()]);
|
|
1161
1291
|
}
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@jax-js/jax",
|
|
3
|
-
"version": "0.1.
|
|
3
|
+
"version": "0.1.10",
|
|
4
4
|
"description": "Numerical computing and ML in the browser",
|
|
5
5
|
"keywords": [
|
|
6
6
|
"machine learning",
|
|
@@ -38,15 +38,13 @@
|
|
|
38
38
|
"devDependencies": {
|
|
39
39
|
"@eslint/js": "^9.31.0",
|
|
40
40
|
"@types/debug": "^4.1.12",
|
|
41
|
-
"@vitest/browser-playwright": "^4.0
|
|
42
|
-
"@vitest/coverage-v8": "4.0
|
|
41
|
+
"@vitest/browser-playwright": "^4.1.0",
|
|
42
|
+
"@vitest/coverage-v8": "^4.1.0",
|
|
43
43
|
"@webgpu/types": "^0.1.68",
|
|
44
44
|
"eslint": "^9.31.0",
|
|
45
45
|
"eslint-plugin-import": "^2.32.0",
|
|
46
46
|
"globals": "^16.0.0",
|
|
47
|
-
"
|
|
48
|
-
"lint-staged": "^16.2.7",
|
|
49
|
-
"playwright": "~1.52.0",
|
|
47
|
+
"playwright": "~1.58.2",
|
|
50
48
|
"prettier": "^3.6.2",
|
|
51
49
|
"prettier-plugin-svelte": "^3.4.0",
|
|
52
50
|
"tsdown": "^0.13.2",
|
|
@@ -55,7 +53,7 @@
|
|
|
55
53
|
"typedoc-theme-fresh": "^0.2.3",
|
|
56
54
|
"typescript": "~5.9.3",
|
|
57
55
|
"typescript-eslint": "^8.46.4",
|
|
58
|
-
"vitest": "^4.0
|
|
56
|
+
"vitest": "^4.1.0"
|
|
59
57
|
},
|
|
60
58
|
"engines": {
|
|
61
59
|
"pnpm": ">=10.0.0"
|
|
@@ -76,15 +74,6 @@
|
|
|
76
74
|
],
|
|
77
75
|
"proseWrap": "always"
|
|
78
76
|
},
|
|
79
|
-
"lint-staged": {
|
|
80
|
-
"*.{ts,tsx,js,jsx}": [
|
|
81
|
-
"eslint --fix",
|
|
82
|
-
"prettier --write"
|
|
83
|
-
],
|
|
84
|
-
"*.{json,md,yml,yaml,css,svelte,html}": [
|
|
85
|
-
"prettier --write"
|
|
86
|
-
]
|
|
87
|
-
},
|
|
88
77
|
"scripts": {
|
|
89
78
|
"build": "tsdown",
|
|
90
79
|
"build:watch": "TSDOWN_WATCH_MODE=1 tsdown",
|