@jax-js/jax 0.1.9 → 0.1.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-DpI0riom.cjs');
1
+ const require_backend = require('./backend-DlYlOYqN.cjs');
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -458,6 +458,12 @@ function generateExpression(exp, args, inputDtypes) {
458
458
  else source = `min(${a}, ${b})`;
459
459
  else if (op === require_backend.AluOp.Max) if (dtype === require_backend.DType.Bool) source = `(${a} || ${b})`;
460
460
  else source = `max(${a}, ${b})`;
461
+ else if (op === require_backend.AluOp.BitCombine) {
462
+ let infix = arg === "and" ? "&" : arg === "or" ? "|" : "^";
463
+ if (dtype === require_backend.DType.Bool) infix = infix + infix;
464
+ source = `(${a} ${infix} ${b})`;
465
+ } else if (op === require_backend.AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
466
+ else source = `(${a} >> ${b})`;
461
467
  } else if (require_backend.AluGroup.Compare.has(op)) {
462
468
  const a = gen(src[0]);
463
469
  const b = gen(src[1]);
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-BId79r5b.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-DZvR7mZV.js";
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -152,6 +152,7 @@ var SyncReader = class SyncReader {
152
152
  this.device = device;
153
153
  }
154
154
  #init() {
155
+ if (typeof OffscreenCanvas === "undefined") throw new Error("OffscreenCanvas is not available in this environment, so you cannot read data from WebGPU synchronously. Consider using the async API.");
155
156
  const makeCanvas = () => new OffscreenCanvas(SyncReader.width, SyncReader.height);
156
157
  this.deviceStorage = SyncReader.alphaModes.map(makeCanvas);
157
158
  this.deviceContexts = this.deviceStorage.map((canvas, i) => {
@@ -728,6 +729,120 @@ function createRoutineShader(device, routine) {
728
729
  }
729
730
  }
730
731
 
732
+ //#endregion
733
+ //#region src/backend/webgpu/tracing.ts
734
+ const MAX_TIMESTAMP_QUERIES = 4096;
735
+ const activeBatch = /* @__PURE__ */ new WeakMap();
736
+ function createTracingBatch(device) {
737
+ return {
738
+ querySet: device.createQuerySet({
739
+ type: "timestamp",
740
+ count: MAX_TIMESTAMP_QUERIES
741
+ }),
742
+ resolve: device.createBuffer({
743
+ size: MAX_TIMESTAMP_QUERIES * 8,
744
+ usage: GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC
745
+ }),
746
+ dst: device.createBuffer({
747
+ size: MAX_TIMESTAMP_QUERIES * 8,
748
+ usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST
749
+ }),
750
+ nextIndex: 0,
751
+ entries: []
752
+ };
753
+ }
754
+ function acquireTracingSlot(device) {
755
+ if (!device.features.has("timestamp-query")) return void 0;
756
+ let batch = activeBatch.get(device);
757
+ if (batch && batch.nextIndex >= MAX_TIMESTAMP_QUERIES) {
758
+ flushTracingBatch(device, batch);
759
+ batch = void 0;
760
+ }
761
+ if (!batch) {
762
+ batch = createTracingBatch(device);
763
+ activeBatch.set(device, batch);
764
+ onFlushTrace(() => {
765
+ const b = activeBatch.get(device);
766
+ if (b && b.entries.length > 0) flushTracingBatch(device, b);
767
+ activeBatch.delete(device);
768
+ });
769
+ }
770
+ const beginIndex = batch.nextIndex;
771
+ const endIndex = beginIndex + 1;
772
+ batch.nextIndex += 2;
773
+ return {
774
+ batch,
775
+ beginIndex,
776
+ endIndex
777
+ };
778
+ }
779
+ /**
780
+ * If tracing is active, acquire a slot for timestamp queries.
781
+ *
782
+ * Returns undefined if tracing is not active or the device doesn't support
783
+ * timestamp queries.
784
+ */
785
+ function maybeAcquireTracingSlot(device) {
786
+ if (!isTracing()) return void 0;
787
+ return acquireTracingSlot(device);
788
+ }
789
+ /**
790
+ * Record a tracing entry for a pipeline dispatch and schedule an auto-flush.
791
+ */
792
+ function recordTrace(device, slot, source, numPasses, wgslSource) {
793
+ const info = traceSourceInfo(source);
794
+ info.properties.push(["passes", `${numPasses}`]);
795
+ info.properties.push(["source", wgslSource]);
796
+ slot.batch.entries.push({
797
+ ...info,
798
+ beginIndex: slot.beginIndex,
799
+ endIndex: slot.endIndex
800
+ });
801
+ scheduleAutoFlush(device);
802
+ }
803
+ /**
804
+ * If the active batch has pending entries, flush and replace it so traces
805
+ * are emitted without waiting for the batch to fill or stopTrace().
806
+ *
807
+ * Called after each dispatch records its entry via a microtask so that
808
+ * synchronous back-to-back dispatches are still batched together.
809
+ */
810
+ function scheduleAutoFlush(device) {
811
+ queueMicrotask(() => {
812
+ const batch = activeBatch.get(device);
813
+ if (batch && batch.entries.length > 0) {
814
+ flushTracingBatch(device, batch);
815
+ activeBatch.set(device, createTracingBatch(device));
816
+ }
817
+ });
818
+ }
819
+ function flushTracingBatch(device, batch) {
820
+ if (batch.entries.length === 0) return;
821
+ const usedQueries = batch.nextIndex;
822
+ const encoder = device.createCommandEncoder();
823
+ encoder.resolveQuerySet(batch.querySet, 0, usedQueries, batch.resolve, 0);
824
+ encoder.copyBufferToBuffer(batch.resolve, 0, batch.dst, 0, usedQueries * 8);
825
+ device.queue.submit([encoder.finish()]);
826
+ const { entries } = batch;
827
+ batch.dst.mapAsync(GPUMapMode.READ).then(() => {
828
+ try {
829
+ const times = new BigInt64Array(batch.dst.getMappedRange());
830
+ const anchorGpuNs = times[entries[entries.length - 1].endIndex];
831
+ const anchorCpuMs = performance.now();
832
+ for (const entry of entries) {
833
+ const startMs = anchorCpuMs + Number(times[entry.beginIndex] - anchorGpuNs) / 1e6;
834
+ const endMs = anchorCpuMs + Number(times[entry.endIndex] - anchorGpuNs) / 1e6;
835
+ emitTrace("webgpu", entry, startMs, endMs);
836
+ }
837
+ } finally {
838
+ batch.dst.unmap();
839
+ batch.querySet.destroy();
840
+ batch.resolve.destroy();
841
+ batch.dst.destroy();
842
+ }
843
+ });
844
+ }
845
+
731
846
  //#endregion
732
847
  //#region src/backend/webgpu.ts
733
848
  /** Implementation of `Backend` that uses WebGPU in browsers. */
@@ -872,7 +987,7 @@ var WebGPUBackend = class {
872
987
  dispatch(exe, inputs, outputs) {
873
988
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
874
989
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
875
- pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
990
+ pipelineSubmit(this.device, exe, inputBuffers, outputBuffers);
876
991
  }
877
992
  #getBuffer(slot) {
878
993
  const buffer = this.buffers.get(slot);
@@ -985,6 +1100,11 @@ function pipelineSource(device, kernel) {
985
1100
  else source = `min(${strip1(a)}, ${strip1(b)})`;
986
1101
  else if (op === AluOp.Max) if (dtype === DType.Bool) source = `(${a} || ${b})`;
987
1102
  else source = `max(${strip1(a)}, ${strip1(b)})`;
1103
+ else if (op === AluOp.BitCombine) if (arg === "and") source = `(${a} & ${b})`;
1104
+ else if (arg === "or") source = `(${a} | ${b})`;
1105
+ else source = dtype === DType.Bool ? `(${a} != ${b})` : `(${a} ^ ${b})`;
1106
+ else if (op === AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
1107
+ else source = `(${a} >> ${b})`;
988
1108
  else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
989
1109
  else if (op === AluOp.Cmpne) if (isFloatDtype(src[0].dtype)) {
990
1110
  const x = isGensym(a) ? a : gensym();
@@ -1010,8 +1130,16 @@ function pipelineSource(device, kernel) {
1010
1130
  else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
1011
1131
  else if (op === AluOp.Floor) source = `floor(${strip1(a)})`;
1012
1132
  else if (op === AluOp.Ceil) source = `ceil(${strip1(a)})`;
1013
- else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${strip1(a)})`;
1014
- else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
1133
+ else if (op === AluOp.Cast) {
1134
+ const srcTy = dtypeToWgsl(src[0].dtype);
1135
+ const dstTy = dtypeToWgsl(dtype);
1136
+ if (isFloatDtype(src[0].dtype) && !(isFloatDtype(dtype) || dtype === DType.Bool)) {
1137
+ const maxVal = maxValueWgsl(dtype);
1138
+ const x = isGensym(a) ? a : gensym();
1139
+ if (x !== a) emit(`let ${x}: ${srcTy} = ${strip1(a)};`);
1140
+ source = `select(${dstTy}(${x}), ${maxVal}, ${x} >= ${srcTy}(${maxVal}))`;
1141
+ } else source = `${dstTy}(${strip1(a)})`;
1142
+ } else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
1015
1143
  }
1016
1144
  else if (op === AluOp.Where) source = `select(${strip1(gen(src[2]))}, ${strip1(gen(src[1]))}, ${strip1(gen(src[0]))})`;
1017
1145
  else if (op === AluOp.Threefry2x32) {
@@ -1114,12 +1242,14 @@ function pipelineSource(device, kernel) {
1114
1242
  passes: [{ grid: [gridX, gridY] }]
1115
1243
  };
1116
1244
  }
1117
- function pipelineSubmit(device, pipelines, inputs, outputs) {
1245
+ function pipelineSubmit(device, exe, inputs, outputs) {
1246
+ const { data: pipelines, source } = exe;
1118
1247
  const commandEncoder = device.createCommandEncoder();
1119
1248
  for (const { pipeline,...shader } of pipelines) {
1120
1249
  if (inputs.length !== shader.numInputs || outputs.length !== shader.numOutputs) throw new Error(`webgpu: expected ${shader.numInputs} inputs and ${shader.numOutputs} outputs, got ${inputs.length} inputs and ${outputs.length} outputs`);
1121
1250
  const filteredPasses = shader.passes.filter(({ grid }) => prod(grid) > 0);
1122
1251
  if (filteredPasses.length === 0) continue;
1252
+ const slot = maybeAcquireTracingSlot(device);
1123
1253
  const bindGroup = device.createBindGroup({
1124
1254
  layout: pipeline.getBindGroupLayout(0),
1125
1255
  entries: [...inputs.map((buffer, i) => ({
@@ -1149,13 +1279,18 @@ function pipelineSubmit(device, pipelines, inputs, outputs) {
1149
1279
  }
1150
1280
  for (let i = 0; i < filteredPasses.length; i++) {
1151
1281
  const { grid } = filteredPasses[i];
1152
- const passEncoder = commandEncoder.beginComputePass();
1282
+ const passEncoder = commandEncoder.beginComputePass({ timestampWrites: slot ? {
1283
+ querySet: slot.batch.querySet,
1284
+ beginningOfPassWriteIndex: i === 0 ? slot.beginIndex : void 0,
1285
+ endOfPassWriteIndex: i === filteredPasses.length - 1 ? slot.endIndex : void 0
1286
+ } : void 0 });
1153
1287
  passEncoder.setPipeline(pipeline);
1154
1288
  passEncoder.setBindGroup(0, bindGroup);
1155
1289
  if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
1156
1290
  passEncoder.dispatchWorkgroups(grid[0], grid[1]);
1157
1291
  passEncoder.end();
1158
1292
  }
1293
+ if (slot) recordTrace(device, slot, source, filteredPasses.length, shader.code);
1159
1294
  }
1160
1295
  device.queue.submit([commandEncoder.finish()]);
1161
1296
  }
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-DpI0riom.cjs');
1
+ const require_backend = require('./backend-DlYlOYqN.cjs');
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -152,6 +152,7 @@ var SyncReader = class SyncReader {
152
152
  this.device = device;
153
153
  }
154
154
  #init() {
155
+ if (typeof OffscreenCanvas === "undefined") throw new Error("OffscreenCanvas is not available in this environment, so you cannot read data from WebGPU synchronously. Consider using the async API.");
155
156
  const makeCanvas = () => new OffscreenCanvas(SyncReader.width, SyncReader.height);
156
157
  this.deviceStorage = SyncReader.alphaModes.map(makeCanvas);
157
158
  this.deviceContexts = this.deviceStorage.map((canvas, i) => {
@@ -728,6 +729,120 @@ function createRoutineShader(device, routine) {
728
729
  }
729
730
  }
730
731
 
732
+ //#endregion
733
+ //#region src/backend/webgpu/tracing.ts
734
+ const MAX_TIMESTAMP_QUERIES = 4096;
735
+ const activeBatch = /* @__PURE__ */ new WeakMap();
736
+ function createTracingBatch(device) {
737
+ return {
738
+ querySet: device.createQuerySet({
739
+ type: "timestamp",
740
+ count: MAX_TIMESTAMP_QUERIES
741
+ }),
742
+ resolve: device.createBuffer({
743
+ size: MAX_TIMESTAMP_QUERIES * 8,
744
+ usage: GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC
745
+ }),
746
+ dst: device.createBuffer({
747
+ size: MAX_TIMESTAMP_QUERIES * 8,
748
+ usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST
749
+ }),
750
+ nextIndex: 0,
751
+ entries: []
752
+ };
753
+ }
754
+ function acquireTracingSlot(device) {
755
+ if (!device.features.has("timestamp-query")) return void 0;
756
+ let batch = activeBatch.get(device);
757
+ if (batch && batch.nextIndex >= MAX_TIMESTAMP_QUERIES) {
758
+ flushTracingBatch(device, batch);
759
+ batch = void 0;
760
+ }
761
+ if (!batch) {
762
+ batch = createTracingBatch(device);
763
+ activeBatch.set(device, batch);
764
+ require_backend.onFlushTrace(() => {
765
+ const b = activeBatch.get(device);
766
+ if (b && b.entries.length > 0) flushTracingBatch(device, b);
767
+ activeBatch.delete(device);
768
+ });
769
+ }
770
+ const beginIndex = batch.nextIndex;
771
+ const endIndex = beginIndex + 1;
772
+ batch.nextIndex += 2;
773
+ return {
774
+ batch,
775
+ beginIndex,
776
+ endIndex
777
+ };
778
+ }
779
+ /**
780
+ * If tracing is active, acquire a slot for timestamp queries.
781
+ *
782
+ * Returns undefined if tracing is not active or the device doesn't support
783
+ * timestamp queries.
784
+ */
785
+ function maybeAcquireTracingSlot(device) {
786
+ if (!require_backend.isTracing()) return void 0;
787
+ return acquireTracingSlot(device);
788
+ }
789
+ /**
790
+ * Record a tracing entry for a pipeline dispatch and schedule an auto-flush.
791
+ */
792
+ function recordTrace(device, slot, source, numPasses, wgslSource) {
793
+ const info = require_backend.traceSourceInfo(source);
794
+ info.properties.push(["passes", `${numPasses}`]);
795
+ info.properties.push(["source", wgslSource]);
796
+ slot.batch.entries.push({
797
+ ...info,
798
+ beginIndex: slot.beginIndex,
799
+ endIndex: slot.endIndex
800
+ });
801
+ scheduleAutoFlush(device);
802
+ }
803
+ /**
804
+ * If the active batch has pending entries, flush and replace it so traces
805
+ * are emitted without waiting for the batch to fill or stopTrace().
806
+ *
807
+ * Called after each dispatch records its entry via a microtask so that
808
+ * synchronous back-to-back dispatches are still batched together.
809
+ */
810
+ function scheduleAutoFlush(device) {
811
+ queueMicrotask(() => {
812
+ const batch = activeBatch.get(device);
813
+ if (batch && batch.entries.length > 0) {
814
+ flushTracingBatch(device, batch);
815
+ activeBatch.set(device, createTracingBatch(device));
816
+ }
817
+ });
818
+ }
819
+ function flushTracingBatch(device, batch) {
820
+ if (batch.entries.length === 0) return;
821
+ const usedQueries = batch.nextIndex;
822
+ const encoder = device.createCommandEncoder();
823
+ encoder.resolveQuerySet(batch.querySet, 0, usedQueries, batch.resolve, 0);
824
+ encoder.copyBufferToBuffer(batch.resolve, 0, batch.dst, 0, usedQueries * 8);
825
+ device.queue.submit([encoder.finish()]);
826
+ const { entries } = batch;
827
+ batch.dst.mapAsync(GPUMapMode.READ).then(() => {
828
+ try {
829
+ const times = new BigInt64Array(batch.dst.getMappedRange());
830
+ const anchorGpuNs = times[entries[entries.length - 1].endIndex];
831
+ const anchorCpuMs = performance.now();
832
+ for (const entry of entries) {
833
+ const startMs = anchorCpuMs + Number(times[entry.beginIndex] - anchorGpuNs) / 1e6;
834
+ const endMs = anchorCpuMs + Number(times[entry.endIndex] - anchorGpuNs) / 1e6;
835
+ require_backend.emitTrace("webgpu", entry, startMs, endMs);
836
+ }
837
+ } finally {
838
+ batch.dst.unmap();
839
+ batch.querySet.destroy();
840
+ batch.resolve.destroy();
841
+ batch.dst.destroy();
842
+ }
843
+ });
844
+ }
845
+
731
846
  //#endregion
732
847
  //#region src/backend/webgpu.ts
733
848
  /** Implementation of `Backend` that uses WebGPU in browsers. */
@@ -872,7 +987,7 @@ var WebGPUBackend = class {
872
987
  dispatch(exe, inputs, outputs) {
873
988
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
874
989
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
875
- pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
990
+ pipelineSubmit(this.device, exe, inputBuffers, outputBuffers);
876
991
  }
877
992
  #getBuffer(slot) {
878
993
  const buffer = this.buffers.get(slot);
@@ -985,6 +1100,11 @@ function pipelineSource(device, kernel) {
985
1100
  else source = `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
986
1101
  else if (op === require_backend.AluOp.Max) if (dtype === require_backend.DType.Bool) source = `(${a} || ${b})`;
987
1102
  else source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
1103
+ else if (op === require_backend.AluOp.BitCombine) if (arg === "and") source = `(${a} & ${b})`;
1104
+ else if (arg === "or") source = `(${a} | ${b})`;
1105
+ else source = dtype === require_backend.DType.Bool ? `(${a} != ${b})` : `(${a} ^ ${b})`;
1106
+ else if (op === require_backend.AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
1107
+ else source = `(${a} >> ${b})`;
988
1108
  else if (op === require_backend.AluOp.Cmplt) source = `(${a} < ${b})`;
989
1109
  else if (op === require_backend.AluOp.Cmpne) if (require_backend.isFloatDtype(src[0].dtype)) {
990
1110
  const x = isGensym(a) ? a : gensym();
@@ -1010,8 +1130,16 @@ function pipelineSource(device, kernel) {
1010
1130
  else if (op === require_backend.AluOp.Reciprocal) source = `(1.0 / ${a})`;
1011
1131
  else if (op === require_backend.AluOp.Floor) source = `floor(${require_backend.strip1(a)})`;
1012
1132
  else if (op === require_backend.AluOp.Ceil) source = `ceil(${require_backend.strip1(a)})`;
1013
- else if (op === require_backend.AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${require_backend.strip1(a)})`;
1014
- else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
1133
+ else if (op === require_backend.AluOp.Cast) {
1134
+ const srcTy = dtypeToWgsl(src[0].dtype);
1135
+ const dstTy = dtypeToWgsl(dtype);
1136
+ if (require_backend.isFloatDtype(src[0].dtype) && !(require_backend.isFloatDtype(dtype) || dtype === require_backend.DType.Bool)) {
1137
+ const maxVal = maxValueWgsl(dtype);
1138
+ const x = isGensym(a) ? a : gensym();
1139
+ if (x !== a) emit(`let ${x}: ${srcTy} = ${require_backend.strip1(a)};`);
1140
+ source = `select(${dstTy}(${x}), ${maxVal}, ${x} >= ${srcTy}(${maxVal}))`;
1141
+ } else source = `${dstTy}(${require_backend.strip1(a)})`;
1142
+ } else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
1015
1143
  }
1016
1144
  else if (op === require_backend.AluOp.Where) source = `select(${require_backend.strip1(gen(src[2]))}, ${require_backend.strip1(gen(src[1]))}, ${require_backend.strip1(gen(src[0]))})`;
1017
1145
  else if (op === require_backend.AluOp.Threefry2x32) {
@@ -1114,12 +1242,14 @@ function pipelineSource(device, kernel) {
1114
1242
  passes: [{ grid: [gridX, gridY] }]
1115
1243
  };
1116
1244
  }
1117
- function pipelineSubmit(device, pipelines, inputs, outputs) {
1245
+ function pipelineSubmit(device, exe, inputs, outputs) {
1246
+ const { data: pipelines, source } = exe;
1118
1247
  const commandEncoder = device.createCommandEncoder();
1119
1248
  for (const { pipeline,...shader } of pipelines) {
1120
1249
  if (inputs.length !== shader.numInputs || outputs.length !== shader.numOutputs) throw new Error(`webgpu: expected ${shader.numInputs} inputs and ${shader.numOutputs} outputs, got ${inputs.length} inputs and ${outputs.length} outputs`);
1121
1250
  const filteredPasses = shader.passes.filter(({ grid }) => require_backend.prod(grid) > 0);
1122
1251
  if (filteredPasses.length === 0) continue;
1252
+ const slot = maybeAcquireTracingSlot(device);
1123
1253
  const bindGroup = device.createBindGroup({
1124
1254
  layout: pipeline.getBindGroupLayout(0),
1125
1255
  entries: [...inputs.map((buffer, i) => ({
@@ -1149,13 +1279,18 @@ function pipelineSubmit(device, pipelines, inputs, outputs) {
1149
1279
  }
1150
1280
  for (let i = 0; i < filteredPasses.length; i++) {
1151
1281
  const { grid } = filteredPasses[i];
1152
- const passEncoder = commandEncoder.beginComputePass();
1282
+ const passEncoder = commandEncoder.beginComputePass({ timestampWrites: slot ? {
1283
+ querySet: slot.batch.querySet,
1284
+ beginningOfPassWriteIndex: i === 0 ? slot.beginIndex : void 0,
1285
+ endOfPassWriteIndex: i === filteredPasses.length - 1 ? slot.endIndex : void 0
1286
+ } : void 0 });
1153
1287
  passEncoder.setPipeline(pipeline);
1154
1288
  passEncoder.setBindGroup(0, bindGroup);
1155
1289
  if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
1156
1290
  passEncoder.dispatchWorkgroups(grid[0], grid[1]);
1157
1291
  passEncoder.end();
1158
1292
  }
1293
+ if (slot) recordTrace(device, slot, source, filteredPasses.length, shader.code);
1159
1294
  }
1160
1295
  device.queue.submit([commandEncoder.finish()]);
1161
1296
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.1.9",
3
+ "version": "0.1.11",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",
@@ -38,15 +38,13 @@
38
38
  "devDependencies": {
39
39
  "@eslint/js": "^9.31.0",
40
40
  "@types/debug": "^4.1.12",
41
- "@vitest/browser-playwright": "^4.0.9",
42
- "@vitest/coverage-v8": "4.0.9",
41
+ "@vitest/browser-playwright": "^4.1.0",
42
+ "@vitest/coverage-v8": "^4.1.0",
43
43
  "@webgpu/types": "^0.1.68",
44
44
  "eslint": "^9.31.0",
45
45
  "eslint-plugin-import": "^2.32.0",
46
46
  "globals": "^16.0.0",
47
- "husky": "^9.1.7",
48
- "lint-staged": "^16.2.7",
49
- "playwright": "~1.52.0",
47
+ "playwright": "~1.58.2",
50
48
  "prettier": "^3.6.2",
51
49
  "prettier-plugin-svelte": "^3.4.0",
52
50
  "tsdown": "^0.13.2",
@@ -55,7 +53,7 @@
55
53
  "typedoc-theme-fresh": "^0.2.3",
56
54
  "typescript": "~5.9.3",
57
55
  "typescript-eslint": "^8.46.4",
58
- "vitest": "^4.0.9"
56
+ "vitest": "^4.1.0"
59
57
  },
60
58
  "engines": {
61
59
  "pnpm": ">=10.0.0"
@@ -76,15 +74,6 @@
76
74
  ],
77
75
  "proseWrap": "always"
78
76
  },
79
- "lint-staged": {
80
- "*.{ts,tsx,js,jsx}": [
81
- "eslint --fix",
82
- "prettier --write"
83
- ],
84
- "*.{json,md,yml,yaml,css,svelte,html}": [
85
- "prettier --write"
86
- ]
87
- },
88
77
  "scripts": {
89
78
  "build": "tsdown",
90
79
  "build:watch": "TSDOWN_WATCH_MODE=1 tsdown",