@jax-js/jax 0.1.8 → 0.1.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-B3foXiV_.cjs');
1
+ const require_backend = require('./backend-DMauYnfl.cjs');
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -152,6 +152,7 @@ var SyncReader = class SyncReader {
152
152
  this.device = device;
153
153
  }
154
154
  #init() {
155
+ if (typeof OffscreenCanvas === "undefined") throw new Error("OffscreenCanvas is not available in this environment, so you cannot read data from WebGPU synchronously. Consider using the async API.");
155
156
  const makeCanvas = () => new OffscreenCanvas(SyncReader.width, SyncReader.height);
156
157
  this.deviceStorage = SyncReader.alphaModes.map(makeCanvas);
157
158
  this.deviceContexts = this.deviceStorage.map((canvas, i) => {
@@ -247,6 +248,10 @@ function bitonicSortUniform(pass) {
247
248
  * `2^(step+1)` with multiple workgroups. This doesn't use shared memory.
248
249
  *
249
250
  * The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
251
+ *
252
+ * If `outputIndices` is true, the shader also tracks the original indices of
253
+ * the sorted elements (argsort) and outputs them to a separate buffer. This
254
+ * also makes the sorting algorithm stable.
250
255
  */
251
256
  function bitonicSortShader(device, dtype, n, batches, outputIndices) {
252
257
  const ty = dtypeToWgsl(dtype, true);
@@ -286,14 +291,21 @@ ${require_backend.isFloatDtype(dtype) ? `
286
291
  fn compare_and_swap(i: u32, j: u32) {
287
292
  let val_i = shared_vals[i];
288
293
  let val_j = shared_vals[j];
289
- if (compare(val_j, val_i)) {
294
+ ${outputIndices ? `
295
+ if (
296
+ compare(val_j, val_i) ||
297
+ (!compare(val_i, val_j) && shared_idx[j] < shared_idx[i])
298
+ ) {
290
299
  shared_vals[i] = val_j;
291
300
  shared_vals[j] = val_i;
292
- ${outputIndices ? `
293
301
  let tmp_idx = shared_idx[i];
294
302
  shared_idx[i] = shared_idx[j];
295
- shared_idx[j] = tmp_idx;` : ""}
296
- }
303
+ shared_idx[j] = tmp_idx;
304
+ }` : `
305
+ if (compare(val_j, val_i)) {
306
+ shared_vals[i] = val_j;
307
+ shared_vals[j] = val_i;
308
+ }`}
297
309
  }
298
310
 
299
311
  @compute @workgroup_size(${workgroupSize})
@@ -370,13 +382,17 @@ ${outputIndices ? `
370
382
  if (j < ${n}u) {
371
383
  let val_i = output[base + i];
372
384
  let val_j = output[base + j];
373
- if (compare(val_j, val_i)) {
385
+ ${outputIndices ? `
386
+ let idx_i = output_idx[base + i];
387
+ let idx_j = output_idx[base + j];
388
+ if (compare(val_j, val_i) || (!compare(val_i, val_j) && idx_j < idx_i)) {
374
389
  output[base + i] = val_j;
375
390
  output[base + j] = val_i;
376
- ${outputIndices ? `
377
- let tmp_idx = output_idx[base + i];
378
- output_idx[base + i] = output_idx[base + j];
379
- output_idx[base + j] = tmp_idx;` : ""}
391
+ output_idx[base + i] = idx_j;
392
+ output_idx[base + j] = idx_i;` : `
393
+ if (compare(val_j, val_i)) {
394
+ output[base + i] = val_j;
395
+ output[base + j] = val_i;`}
380
396
  }
381
397
  }
382
398
  }
@@ -713,6 +729,120 @@ function createRoutineShader(device, routine) {
713
729
  }
714
730
  }
715
731
 
732
+ //#endregion
733
+ //#region src/backend/webgpu/tracing.ts
734
+ const MAX_TIMESTAMP_QUERIES = 4096;
735
+ const activeBatch = /* @__PURE__ */ new WeakMap();
736
/**
 * Allocate the GPU objects backing one batch of timestamp queries.
 *
 * A batch consists of a timestamp query set with `MAX_TIMESTAMP_QUERIES`
 * slots, a buffer the queries get resolved into, and a mappable buffer the
 * resolved values are copied to for readback. Both buffers reserve 8 bytes
 * per query.
 *
 * @param device The GPU device used to create the query set and buffers.
 * @returns A fresh batch with no slots handed out (`nextIndex` is 0) and no
 *   recorded entries.
 */
function createTracingBatch(device) {
  const byteLength = MAX_TIMESTAMP_QUERIES * 8;
  const querySet = device.createQuerySet({
    type: "timestamp",
    count: MAX_TIMESTAMP_QUERIES,
  });
  const resolve = device.createBuffer({
    size: byteLength,
    usage: GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC,
  });
  const dst = device.createBuffer({
    size: byteLength,
    usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
  });
  return { querySet, resolve, dst, nextIndex: 0, entries: [] };
}
754
/**
 * Reserve a begin/end pair of timestamp query indices in the device's active
 * tracing batch, creating the batch on demand.
 *
 * @param device The GPU device the dispatch will run on.
 * @returns `{ batch, beginIndex, endIndex }`, or undefined when the device
 *   does not expose the "timestamp-query" feature.
 */
function acquireTracingSlot(device) {
  if (!device.features.has("timestamp-query")) return undefined;

  let batch = activeBatch.get(device);
  // A batch whose query set is used up is flushed and replaced below.
  if (batch?.nextIndex >= MAX_TIMESTAMP_QUERIES) {
    flushTracingBatch(device, batch);
    batch = undefined;
  }

  if (!batch) {
    batch = createTracingBatch(device);
    activeBatch.set(device, batch);
    // When a trace flush is requested, emit whatever is still pending for
    // this device and forget its batch.
    require_backend.onFlushTrace(() => {
      const pending = activeBatch.get(device);
      if (pending && pending.entries.length > 0) flushTracingBatch(device, pending);
      activeBatch.delete(device);
    });
  }

  // Each slot takes two consecutive query indices: begin, then end.
  const beginIndex = batch.nextIndex;
  batch.nextIndex = beginIndex + 2;
  return { batch, beginIndex, endIndex: beginIndex + 1 };
}
779
/**
 * Acquire a timestamp-query slot, but only while tracing is switched on.
 *
 * @param device The GPU device the dispatch will run on.
 * @returns The slot from `acquireTracingSlot`, or undefined when tracing is
 *   inactive or the device has no timestamp-query support.
 */
function maybeAcquireTracingSlot(device) {
  return require_backend.isTracing() ? acquireTracingSlot(device) : undefined;
}
789
/**
 * Append a trace entry for one pipeline dispatch to the slot's batch, then
 * schedule an auto-flush for the device.
 *
 * @param device The GPU device the dispatch was encoded for.
 * @param slot The slot returned by `acquireTracingSlot` for this dispatch.
 * @param source Passed to `traceSourceInfo` to build the entry's metadata.
 * @param numPasses Number of compute passes, stored as the "passes" property.
 * @param wgslSource Shader code, stored as the "source" property.
 */
function recordTrace(device, slot, source, numPasses, wgslSource) {
  const info = require_backend.traceSourceInfo(source);
  info.properties.push(["passes", `${numPasses}`], ["source", wgslSource]);

  const { batch, beginIndex, endIndex } = slot;
  batch.entries.push({ ...info, beginIndex, endIndex });

  scheduleAutoFlush(device);
}
803
/**
 * Queue a microtask that flushes the device's active batch if it has
 * recorded entries, and installs a fresh batch in its place.
 *
 * Deferring to a microtask means dispatches issued back-to-back in the same
 * synchronous run land in one batch, while traces are still emitted without
 * waiting for the batch to fill up or for stopTrace().
 *
 * @param device The GPU device whose active batch should be flushed.
 */
function scheduleAutoFlush(device) {
  queueMicrotask(() => {
    const pending = activeBatch.get(device);
    if (!pending || pending.entries.length === 0) return;
    flushTracingBatch(device, pending);
    activeBatch.set(device, createTracingBatch(device));
  });
}
819
/**
 * Resolve a batch's timestamp queries, read them back and emit one trace
 * event per recorded entry. The batch's GPU objects are destroyed once the
 * readback has finished, whether it succeeded or not.
 *
 * Does nothing (and releases nothing) when the batch has no entries.
 *
 * GPU timestamps are converted to CPU milliseconds by anchoring the end
 * timestamp of the last entry to `performance.now()` at readback time.
 *
 * @param device The GPU device that owns the batch.
 * @param batch The batch created by `createTracingBatch`.
 */
function flushTracingBatch(device, batch) {
  const { entries, querySet, resolve, dst } = batch;
  if (entries.length === 0) return;

  const usedQueries = batch.nextIndex;
  const encoder = device.createCommandEncoder();
  encoder.resolveQuerySet(querySet, 0, usedQueries, resolve, 0);
  encoder.copyBufferToBuffer(resolve, 0, dst, 0, usedQueries * 8);
  device.queue.submit([encoder.finish()]);

  const release = () => {
    querySet.destroy();
    resolve.destroy();
    dst.destroy();
  };

  dst.mapAsync(GPUMapMode.READ).then(
    () => {
      try {
        // Timestamp query results are unsigned 64-bit nanosecond values.
        const times = new BigUint64Array(dst.getMappedRange());
        const anchorGpuNs = times[entries[entries.length - 1].endIndex];
        const anchorCpuMs = performance.now();
        for (const entry of entries) {
          const startMs = anchorCpuMs + Number(times[entry.beginIndex] - anchorGpuNs) / 1e6;
          const endMs = anchorCpuMs + Number(times[entry.endIndex] - anchorGpuNs) / 1e6;
          require_backend.emitTrace("webgpu", entry, startMs, endMs);
        }
      } finally {
        dst.unmap();
        release();
      }
    },
    (error) => {
      // Mapping failed, so the timestamps are unreadable. Still free the GPU
      // objects and report the failure instead of leaving the rejection
      // unhandled.
      release();
      console.warn("webgpu: could not read back tracing timestamps; dropping batch", error);
    },
  );
}
845
+
716
846
  //#endregion
717
847
  //#region src/backend/webgpu.ts
718
848
  /** Implementation of `Backend` that uses WebGPU in browsers. */
@@ -857,7 +987,7 @@ var WebGPUBackend = class {
857
987
  dispatch(exe, inputs, outputs) {
858
988
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
859
989
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
860
- pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
990
+ pipelineSubmit(this.device, exe, inputBuffers, outputBuffers);
861
991
  }
862
992
  #getBuffer(slot) {
863
993
  const buffer = this.buffers.get(slot);
@@ -995,8 +1125,16 @@ function pipelineSource(device, kernel) {
995
1125
  else if (op === require_backend.AluOp.Reciprocal) source = `(1.0 / ${a})`;
996
1126
  else if (op === require_backend.AluOp.Floor) source = `floor(${require_backend.strip1(a)})`;
997
1127
  else if (op === require_backend.AluOp.Ceil) source = `ceil(${require_backend.strip1(a)})`;
998
- else if (op === require_backend.AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${require_backend.strip1(a)})`;
999
- else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
1128
+ else if (op === require_backend.AluOp.Cast) {
1129
+ const srcTy = dtypeToWgsl(src[0].dtype);
1130
+ const dstTy = dtypeToWgsl(dtype);
1131
+ if (require_backend.isFloatDtype(src[0].dtype) && !(require_backend.isFloatDtype(dtype) || dtype === require_backend.DType.Bool)) {
1132
+ const maxVal = maxValueWgsl(dtype);
1133
+ const x = isGensym(a) ? a : gensym();
1134
+ if (x !== a) emit(`let ${x}: ${srcTy} = ${require_backend.strip1(a)};`);
1135
+ source = `select(${dstTy}(${x}), ${maxVal}, ${x} >= ${srcTy}(${maxVal}))`;
1136
+ } else source = `${dstTy}(${require_backend.strip1(a)})`;
1137
+ } else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
1000
1138
  }
1001
1139
  else if (op === require_backend.AluOp.Where) source = `select(${require_backend.strip1(gen(src[2]))}, ${require_backend.strip1(gen(src[1]))}, ${require_backend.strip1(gen(src[0]))})`;
1002
1140
  else if (op === require_backend.AluOp.Threefry2x32) {
@@ -1099,12 +1237,14 @@ function pipelineSource(device, kernel) {
1099
1237
  passes: [{ grid: [gridX, gridY] }]
1100
1238
  };
1101
1239
  }
1102
- function pipelineSubmit(device, pipelines, inputs, outputs) {
1240
+ function pipelineSubmit(device, exe, inputs, outputs) {
1241
+ const { data: pipelines, source } = exe;
1103
1242
  const commandEncoder = device.createCommandEncoder();
1104
1243
  for (const { pipeline,...shader } of pipelines) {
1105
1244
  if (inputs.length !== shader.numInputs || outputs.length !== shader.numOutputs) throw new Error(`webgpu: expected ${shader.numInputs} inputs and ${shader.numOutputs} outputs, got ${inputs.length} inputs and ${outputs.length} outputs`);
1106
1245
  const filteredPasses = shader.passes.filter(({ grid }) => require_backend.prod(grid) > 0);
1107
1246
  if (filteredPasses.length === 0) continue;
1247
+ const slot = maybeAcquireTracingSlot(device);
1108
1248
  const bindGroup = device.createBindGroup({
1109
1249
  layout: pipeline.getBindGroupLayout(0),
1110
1250
  entries: [...inputs.map((buffer, i) => ({
@@ -1134,13 +1274,18 @@ function pipelineSubmit(device, pipelines, inputs, outputs) {
1134
1274
  }
1135
1275
  for (let i = 0; i < filteredPasses.length; i++) {
1136
1276
  const { grid } = filteredPasses[i];
1137
- const passEncoder = commandEncoder.beginComputePass();
1277
+ const passEncoder = commandEncoder.beginComputePass({ timestampWrites: slot ? {
1278
+ querySet: slot.batch.querySet,
1279
+ beginningOfPassWriteIndex: i === 0 ? slot.beginIndex : void 0,
1280
+ endOfPassWriteIndex: i === filteredPasses.length - 1 ? slot.endIndex : void 0
1281
+ } : void 0 });
1138
1282
  passEncoder.setPipeline(pipeline);
1139
1283
  passEncoder.setBindGroup(0, bindGroup);
1140
1284
  if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
1141
1285
  passEncoder.dispatchWorkgroups(grid[0], grid[1]);
1142
1286
  passEncoder.end();
1143
1287
  }
1288
+ if (slot) recordTrace(device, slot, source, filteredPasses.length, shader.code);
1144
1289
  }
1145
1290
  device.queue.submit([commandEncoder.finish()]);
1146
1291
  }
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-nEolvdLv.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-Ctqs8la1.js";
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -152,6 +152,7 @@ var SyncReader = class SyncReader {
152
152
  this.device = device;
153
153
  }
154
154
  #init() {
155
+ if (typeof OffscreenCanvas === "undefined") throw new Error("OffscreenCanvas is not available in this environment, so you cannot read data from WebGPU synchronously. Consider using the async API.");
155
156
  const makeCanvas = () => new OffscreenCanvas(SyncReader.width, SyncReader.height);
156
157
  this.deviceStorage = SyncReader.alphaModes.map(makeCanvas);
157
158
  this.deviceContexts = this.deviceStorage.map((canvas, i) => {
@@ -247,6 +248,10 @@ function bitonicSortUniform(pass) {
247
248
  * `2^(step+1)` with multiple workgroups. This doesn't use shared memory.
248
249
  *
249
250
  * The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
251
+ *
252
+ * If `outputIndices` is true, the shader also tracks the original indices of
253
+ * the sorted elements (argsort) and outputs them to a separate buffer. This
254
+ * also makes the sorting algorithm stable.
250
255
  */
251
256
  function bitonicSortShader(device, dtype, n, batches, outputIndices) {
252
257
  const ty = dtypeToWgsl(dtype, true);
@@ -286,14 +291,21 @@ ${isFloatDtype(dtype) ? `
286
291
  fn compare_and_swap(i: u32, j: u32) {
287
292
  let val_i = shared_vals[i];
288
293
  let val_j = shared_vals[j];
289
- if (compare(val_j, val_i)) {
294
+ ${outputIndices ? `
295
+ if (
296
+ compare(val_j, val_i) ||
297
+ (!compare(val_i, val_j) && shared_idx[j] < shared_idx[i])
298
+ ) {
290
299
  shared_vals[i] = val_j;
291
300
  shared_vals[j] = val_i;
292
- ${outputIndices ? `
293
301
  let tmp_idx = shared_idx[i];
294
302
  shared_idx[i] = shared_idx[j];
295
- shared_idx[j] = tmp_idx;` : ""}
296
- }
303
+ shared_idx[j] = tmp_idx;
304
+ }` : `
305
+ if (compare(val_j, val_i)) {
306
+ shared_vals[i] = val_j;
307
+ shared_vals[j] = val_i;
308
+ }`}
297
309
  }
298
310
 
299
311
  @compute @workgroup_size(${workgroupSize})
@@ -370,13 +382,17 @@ ${outputIndices ? `
370
382
  if (j < ${n}u) {
371
383
  let val_i = output[base + i];
372
384
  let val_j = output[base + j];
373
- if (compare(val_j, val_i)) {
385
+ ${outputIndices ? `
386
+ let idx_i = output_idx[base + i];
387
+ let idx_j = output_idx[base + j];
388
+ if (compare(val_j, val_i) || (!compare(val_i, val_j) && idx_j < idx_i)) {
374
389
  output[base + i] = val_j;
375
390
  output[base + j] = val_i;
376
- ${outputIndices ? `
377
- let tmp_idx = output_idx[base + i];
378
- output_idx[base + i] = output_idx[base + j];
379
- output_idx[base + j] = tmp_idx;` : ""}
391
+ output_idx[base + i] = idx_j;
392
+ output_idx[base + j] = idx_i;` : `
393
+ if (compare(val_j, val_i)) {
394
+ output[base + i] = val_j;
395
+ output[base + j] = val_i;`}
380
396
  }
381
397
  }
382
398
  }
@@ -713,6 +729,120 @@ function createRoutineShader(device, routine) {
713
729
  }
714
730
  }
715
731
 
732
+ //#endregion
733
+ //#region src/backend/webgpu/tracing.ts
734
+ const MAX_TIMESTAMP_QUERIES = 4096;
735
+ const activeBatch = /* @__PURE__ */ new WeakMap();
736
/**
 * Allocate the GPU objects backing one batch of timestamp queries.
 *
 * A batch consists of a timestamp query set with `MAX_TIMESTAMP_QUERIES`
 * slots, a buffer the queries get resolved into, and a mappable buffer the
 * resolved values are copied to for readback. Both buffers reserve 8 bytes
 * per query.
 *
 * @param device The GPU device used to create the query set and buffers.
 * @returns A fresh batch with no slots handed out (`nextIndex` is 0) and no
 *   recorded entries.
 */
function createTracingBatch(device) {
  const byteLength = MAX_TIMESTAMP_QUERIES * 8;
  const querySet = device.createQuerySet({
    type: "timestamp",
    count: MAX_TIMESTAMP_QUERIES,
  });
  const resolve = device.createBuffer({
    size: byteLength,
    usage: GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC,
  });
  const dst = device.createBuffer({
    size: byteLength,
    usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
  });
  return { querySet, resolve, dst, nextIndex: 0, entries: [] };
}
754
/**
 * Reserve a begin/end pair of timestamp query indices in the device's active
 * tracing batch, creating the batch on demand.
 *
 * @param device The GPU device the dispatch will run on.
 * @returns `{ batch, beginIndex, endIndex }`, or undefined when the device
 *   does not expose the "timestamp-query" feature.
 */
function acquireTracingSlot(device) {
  if (!device.features.has("timestamp-query")) return undefined;

  let batch = activeBatch.get(device);
  // A batch whose query set is used up is flushed and replaced below.
  if (batch?.nextIndex >= MAX_TIMESTAMP_QUERIES) {
    flushTracingBatch(device, batch);
    batch = undefined;
  }

  if (!batch) {
    batch = createTracingBatch(device);
    activeBatch.set(device, batch);
    // When a trace flush is requested, emit whatever is still pending for
    // this device and forget its batch.
    onFlushTrace(() => {
      const pending = activeBatch.get(device);
      if (pending && pending.entries.length > 0) flushTracingBatch(device, pending);
      activeBatch.delete(device);
    });
  }

  // Each slot takes two consecutive query indices: begin, then end.
  const beginIndex = batch.nextIndex;
  batch.nextIndex = beginIndex + 2;
  return { batch, beginIndex, endIndex: beginIndex + 1 };
}
779
/**
 * Acquire a timestamp-query slot, but only while tracing is switched on.
 *
 * @param device The GPU device the dispatch will run on.
 * @returns The slot from `acquireTracingSlot`, or undefined when tracing is
 *   inactive or the device has no timestamp-query support.
 */
function maybeAcquireTracingSlot(device) {
  return isTracing() ? acquireTracingSlot(device) : undefined;
}
789
/**
 * Append a trace entry for one pipeline dispatch to the slot's batch, then
 * schedule an auto-flush for the device.
 *
 * @param device The GPU device the dispatch was encoded for.
 * @param slot The slot returned by `acquireTracingSlot` for this dispatch.
 * @param source Passed to `traceSourceInfo` to build the entry's metadata.
 * @param numPasses Number of compute passes, stored as the "passes" property.
 * @param wgslSource Shader code, stored as the "source" property.
 */
function recordTrace(device, slot, source, numPasses, wgslSource) {
  const info = traceSourceInfo(source);
  info.properties.push(["passes", `${numPasses}`], ["source", wgslSource]);

  const { batch, beginIndex, endIndex } = slot;
  batch.entries.push({ ...info, beginIndex, endIndex });

  scheduleAutoFlush(device);
}
803
/**
 * Queue a microtask that flushes the device's active batch if it has
 * recorded entries, and installs a fresh batch in its place.
 *
 * Deferring to a microtask means dispatches issued back-to-back in the same
 * synchronous run land in one batch, while traces are still emitted without
 * waiting for the batch to fill up or for stopTrace().
 *
 * @param device The GPU device whose active batch should be flushed.
 */
function scheduleAutoFlush(device) {
  queueMicrotask(() => {
    const pending = activeBatch.get(device);
    if (!pending || pending.entries.length === 0) return;
    flushTracingBatch(device, pending);
    activeBatch.set(device, createTracingBatch(device));
  });
}
819
/**
 * Resolve a batch's timestamp queries, read them back and emit one trace
 * event per recorded entry. The batch's GPU objects are destroyed once the
 * readback has finished, whether it succeeded or not.
 *
 * Does nothing (and releases nothing) when the batch has no entries.
 *
 * GPU timestamps are converted to CPU milliseconds by anchoring the end
 * timestamp of the last entry to `performance.now()` at readback time.
 *
 * @param device The GPU device that owns the batch.
 * @param batch The batch created by `createTracingBatch`.
 */
function flushTracingBatch(device, batch) {
  const { entries, querySet, resolve, dst } = batch;
  if (entries.length === 0) return;

  const usedQueries = batch.nextIndex;
  const encoder = device.createCommandEncoder();
  encoder.resolveQuerySet(querySet, 0, usedQueries, resolve, 0);
  encoder.copyBufferToBuffer(resolve, 0, dst, 0, usedQueries * 8);
  device.queue.submit([encoder.finish()]);

  const release = () => {
    querySet.destroy();
    resolve.destroy();
    dst.destroy();
  };

  dst.mapAsync(GPUMapMode.READ).then(
    () => {
      try {
        // Timestamp query results are unsigned 64-bit nanosecond values.
        const times = new BigUint64Array(dst.getMappedRange());
        const anchorGpuNs = times[entries[entries.length - 1].endIndex];
        const anchorCpuMs = performance.now();
        for (const entry of entries) {
          const startMs = anchorCpuMs + Number(times[entry.beginIndex] - anchorGpuNs) / 1e6;
          const endMs = anchorCpuMs + Number(times[entry.endIndex] - anchorGpuNs) / 1e6;
          emitTrace("webgpu", entry, startMs, endMs);
        }
      } finally {
        dst.unmap();
        release();
      }
    },
    (error) => {
      // Mapping failed, so the timestamps are unreadable. Still free the GPU
      // objects and report the failure instead of leaving the rejection
      // unhandled.
      release();
      console.warn("webgpu: could not read back tracing timestamps; dropping batch", error);
    },
  );
}
845
+
716
846
  //#endregion
717
847
  //#region src/backend/webgpu.ts
718
848
  /** Implementation of `Backend` that uses WebGPU in browsers. */
@@ -857,7 +987,7 @@ var WebGPUBackend = class {
857
987
  dispatch(exe, inputs, outputs) {
858
988
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
859
989
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
860
- pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
990
+ pipelineSubmit(this.device, exe, inputBuffers, outputBuffers);
861
991
  }
862
992
  #getBuffer(slot) {
863
993
  const buffer = this.buffers.get(slot);
@@ -995,8 +1125,16 @@ function pipelineSource(device, kernel) {
995
1125
  else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
996
1126
  else if (op === AluOp.Floor) source = `floor(${strip1(a)})`;
997
1127
  else if (op === AluOp.Ceil) source = `ceil(${strip1(a)})`;
998
- else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${strip1(a)})`;
999
- else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
1128
+ else if (op === AluOp.Cast) {
1129
+ const srcTy = dtypeToWgsl(src[0].dtype);
1130
+ const dstTy = dtypeToWgsl(dtype);
1131
+ if (isFloatDtype(src[0].dtype) && !(isFloatDtype(dtype) || dtype === DType.Bool)) {
1132
+ const maxVal = maxValueWgsl(dtype);
1133
+ const x = isGensym(a) ? a : gensym();
1134
+ if (x !== a) emit(`let ${x}: ${srcTy} = ${strip1(a)};`);
1135
+ source = `select(${dstTy}(${x}), ${maxVal}, ${x} >= ${srcTy}(${maxVal}))`;
1136
+ } else source = `${dstTy}(${strip1(a)})`;
1137
+ } else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
1000
1138
  }
1001
1139
  else if (op === AluOp.Where) source = `select(${strip1(gen(src[2]))}, ${strip1(gen(src[1]))}, ${strip1(gen(src[0]))})`;
1002
1140
  else if (op === AluOp.Threefry2x32) {
@@ -1099,12 +1237,14 @@ function pipelineSource(device, kernel) {
1099
1237
  passes: [{ grid: [gridX, gridY] }]
1100
1238
  };
1101
1239
  }
1102
- function pipelineSubmit(device, pipelines, inputs, outputs) {
1240
+ function pipelineSubmit(device, exe, inputs, outputs) {
1241
+ const { data: pipelines, source } = exe;
1103
1242
  const commandEncoder = device.createCommandEncoder();
1104
1243
  for (const { pipeline,...shader } of pipelines) {
1105
1244
  if (inputs.length !== shader.numInputs || outputs.length !== shader.numOutputs) throw new Error(`webgpu: expected ${shader.numInputs} inputs and ${shader.numOutputs} outputs, got ${inputs.length} inputs and ${outputs.length} outputs`);
1106
1245
  const filteredPasses = shader.passes.filter(({ grid }) => prod(grid) > 0);
1107
1246
  if (filteredPasses.length === 0) continue;
1247
+ const slot = maybeAcquireTracingSlot(device);
1108
1248
  const bindGroup = device.createBindGroup({
1109
1249
  layout: pipeline.getBindGroupLayout(0),
1110
1250
  entries: [...inputs.map((buffer, i) => ({
@@ -1134,13 +1274,18 @@ function pipelineSubmit(device, pipelines, inputs, outputs) {
1134
1274
  }
1135
1275
  for (let i = 0; i < filteredPasses.length; i++) {
1136
1276
  const { grid } = filteredPasses[i];
1137
- const passEncoder = commandEncoder.beginComputePass();
1277
+ const passEncoder = commandEncoder.beginComputePass({ timestampWrites: slot ? {
1278
+ querySet: slot.batch.querySet,
1279
+ beginningOfPassWriteIndex: i === 0 ? slot.beginIndex : void 0,
1280
+ endOfPassWriteIndex: i === filteredPasses.length - 1 ? slot.endIndex : void 0
1281
+ } : void 0 });
1138
1282
  passEncoder.setPipeline(pipeline);
1139
1283
  passEncoder.setBindGroup(0, bindGroup);
1140
1284
  if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
1141
1285
  passEncoder.dispatchWorkgroups(grid[0], grid[1]);
1142
1286
  passEncoder.end();
1143
1287
  }
1288
+ if (slot) recordTrace(device, slot, source, filteredPasses.length, shader.code);
1144
1289
  }
1145
1290
  device.queue.submit([commandEncoder.finish()]);
1146
1291
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.1.8",
3
+ "version": "0.1.10",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",
@@ -38,15 +38,13 @@
38
38
  "devDependencies": {
39
39
  "@eslint/js": "^9.31.0",
40
40
  "@types/debug": "^4.1.12",
41
- "@vitest/browser-playwright": "^4.0.9",
42
- "@vitest/coverage-v8": "4.0.9",
41
+ "@vitest/browser-playwright": "^4.1.0",
42
+ "@vitest/coverage-v8": "^4.1.0",
43
43
  "@webgpu/types": "^0.1.68",
44
44
  "eslint": "^9.31.0",
45
45
  "eslint-plugin-import": "^2.32.0",
46
46
  "globals": "^16.0.0",
47
- "husky": "^9.1.7",
48
- "lint-staged": "^16.2.7",
49
- "playwright": "~1.52.0",
47
+ "playwright": "~1.58.2",
50
48
  "prettier": "^3.6.2",
51
49
  "prettier-plugin-svelte": "^3.4.0",
52
50
  "tsdown": "^0.13.2",
@@ -55,7 +53,7 @@
55
53
  "typedoc-theme-fresh": "^0.2.3",
56
54
  "typescript": "~5.9.3",
57
55
  "typescript-eslint": "^8.46.4",
58
- "vitest": "^4.0.9"
56
+ "vitest": "^4.1.0"
59
57
  },
60
58
  "engines": {
61
59
  "pnpm": ">=10.0.0"
@@ -76,15 +74,6 @@
76
74
  ],
77
75
  "proseWrap": "always"
78
76
  },
79
- "lint-staged": {
80
- "*.{ts,tsx,js,jsx}": [
81
- "eslint --fix",
82
- "prettier --write"
83
- ],
84
- "*.{json,md,yml,yaml,css,svelte,html}": [
85
- "prettier --write"
86
- ]
87
- },
88
77
  "scripts": {
89
78
  "build": "tsdown",
90
79
  "build:watch": "TSDOWN_WATCH_MODE=1 tsdown",