tensorgrad 0.0.16 → 0.0.17

```diff
@@ -148,34 +148,50 @@ ${k.wgsl}
     }
     queue.writeBuffer(buffers.get(bufId), 0, data);
   }
-  const encoder = device2.createCommandEncoder({ label: "tensorgrad-step" });
-  for (let i = 0; i < kernels.length; i++) {
-    const k = kernels[i];
-    if (!k.wgsl || k.threads === 0) continue;
-    const pipeline = pipelines[i];
-    const bindGroup = bindGroups[i];
-    const pass = encoder.beginComputePass({ label: k.opKind });
-    pass.setPipeline(pipeline);
-    pass.setBindGroup(0, bindGroup);
-    const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));
-    const MAX_X = 65535;
-    const wgX = Math.min(wgCount, MAX_X);
-    const wgY = Math.ceil(wgCount / MAX_X);
-    pass.dispatchWorkgroups(wgX, wgY, 1);
-    pass.end();
-  }
-  for (const wb of plan.writebacks) {
-    encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);
-  }
-  encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);
+  const CHUNK_SIZE = 32;
   let layout = null;
   if (wantCaptures) {
     layout = ensureCaptureStaging();
-    for (const s of layout.slices) {
-      encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);
+  }
+  let kernelIdx = 0;
+  while (kernelIdx < kernels.length) {
+    const chunkEnd = Math.min(kernelIdx + CHUNK_SIZE, kernels.length);
+    const isLast = chunkEnd === kernels.length;
+    const encoder = device2.createCommandEncoder({
+      label: kernels.length > CHUNK_SIZE ? `tensorgrad-chunk-${kernelIdx}` : "tensorgrad-step"
+    });
+    for (let i = kernelIdx; i < chunkEnd; i++) {
+      const k = kernels[i];
+      if (!k.wgsl || k.threads === 0) continue;
+      const pipeline = pipelines[i];
+      const bindGroup = bindGroups[i];
+      const pass = encoder.beginComputePass({ label: k.opKind });
+      pass.setPipeline(pipeline);
+      pass.setBindGroup(0, bindGroup);
+      const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));
+      const MAX_X = 65535;
+      const wgX = Math.min(wgCount, MAX_X);
+      const wgY = Math.ceil(wgCount / MAX_X);
+      pass.dispatchWorkgroups(wgX, wgY, 1);
+      pass.end();
     }
+    if (isLast) {
+      for (const wb of plan.writebacks) {
+        encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);
+      }
+      encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);
+      if (layout) {
+        for (const s of layout.slices) {
+          encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);
+        }
+      }
+    }
+    queue.submit([encoder.finish()]);
+    if (!isLast) {
+      await queue.onSubmittedWorkDone();
+    }
+    kernelIdx = chunkEnd;
   }
-  queue.submit([encoder.finish()]);
   if (!opts2.readback) return null;
   await outputReadback.mapAsync(GPUMapMode.READ);
   const output = new Float32Array(outputReadback.getMappedRange().slice(0));
```
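
The per-kernel dispatch math in the hunk above is easy to miss among the chunking changes, so here is a minimal standalone sketch of the same 2D split, assuming a 256-thread workgroup. The names (`dispatchFlat`, `data`) and the doubling kernel body are illustrative only, not tensorgrad's actual codegen; only the split arithmetic mirrors the diff.

```ts
// Sketch: folding a flat N-thread kernel into a 2D dispatch, since WebGPU
// caps each dispatch dimension at 65535 workgroups. Hypothetical example,
// not tensorgrad source.
const WORKGROUP_SIZE = 256
const MAX_X = 65535

// WGSL rebuilds the flat index from the 2D grid with the stride
// 65535 * workgroup_size, matching what the host dispatches below.
const wgsl = /* wgsl */ `
@group(0) @binding(0) var<storage, read_write> data: array<f32>;

@compute @workgroup_size(${WORKGROUP_SIZE})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x + gid.y * ${MAX_X * WORKGROUP_SIZE}u;
  if (idx >= arrayLength(&data)) { return; }  // guard the over-dispatch tail
  data[idx] = data[idx] * 2.0;
}`

function dispatchFlat(pass: GPUComputePassEncoder, threads: number) {
  const wgCount = Math.max(1, Math.ceil(threads / WORKGROUP_SIZE))
  const wgX = Math.min(wgCount, MAX_X) // row width, capped at 65535
  const wgY = Math.ceil(wgCount / MAX_X) // extra rows only when wgCount > 65535
  pass.dispatchWorkgroups(wgX, wgY, 1)
}
```

When a dispatch fits in one row, `wgY` is 1 and `gid.y` is always 0, so the small-kernel case pays nothing for the 2D scheme.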
package/package.json CHANGED
```diff
@@ -1,6 +1,6 @@
 {
   "name": "tensorgrad",
-  "version": "0.0.16",
+  "version": "0.0.17",
   "description": "Tiny TypeScript-native tensor library with autograd, compiling to WebGPU. Train small models in the browser without hand-writing kernels.",
   "license": "MIT",
   "author": "Ben Albahari",
```
package/src/runtime.ts CHANGED
```diff
@@ -348,42 +348,67 @@ export async function createRuntime(
     queue.writeBuffer(buffers.get(bufId)!, 0, data as unknown as BufferSource)
   }
 
-  const encoder = device.createCommandEncoder({ label: 'tensorgrad-step' })
-  for (let i = 0; i < kernels.length; i++) {
-    const k = kernels[i]!
-    if (!k.wgsl || k.threads === 0) continue
-    const pipeline = pipelines[i]!
-    const bindGroup = bindGroups[i]!
-    const pass = encoder.beginComputePass({ label: k.opKind })
-    pass.setPipeline(pipeline)
-    pass.setBindGroup(0, bindGroup)
-    // WebGPU caps each dispatch dimension at 65535 workgroups. Split into 2D
-    // when a kernel needs more than that on the X axis. Kernels compute their
-    // global index as `gid.x + gid.y * (65535 * workgroup_size)`, matching the
-    // stride we set here. For dispatches that fit in one row, gid.y is 0.
-    const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize))
-    const MAX_X = 65535
-    const wgX = Math.min(wgCount, MAX_X)
-    const wgY = Math.ceil(wgCount / MAX_X)
-    pass.dispatchWorkgroups(wgX, wgY, 1)
-    pass.end()
-  }
-  // After all dispatches: writebacks (Adam state, updated params). Empty for
-  // forward-only compiles.
-  for (const wb of plan.writebacks) {
-    encoder.copyBufferToBuffer(buffers.get(wb.source)!, 0, buffers.get(wb.dest)!, 0, wb.bytes)
-  }
-  encoder.copyBufferToBuffer(buffers.get(lossBufferId)!, 0, outputReadback, 0, outputSpec.byteSize)
-  // Capture readbacks (only when opted in). All captures concatenate into
-  // a single staging buffer so we mapAsync once instead of N times.
+  // Chunked submit. One queue.submit() of all 240 kernels monopolizes the
+  // GPU for the full step duration, blocking compositor frames the entire
+  // time. Splitting into chunks with an explicit GPU-drain await between
+  // them gives the compositor a slot at each chunk boundary. On graphs
+  // smaller than CHUNK_SIZE this collapses to a single submit (no
+  // overhead). See specs/WorkerArchitecture.md / mobile-jank investigation.
+  const CHUNK_SIZE = 32
   let layout: CaptureLayout | null = null
   if (wantCaptures) {
+    // Compute layout up front so the last chunk can append capture copies.
    layout = ensureCaptureStaging()
-    for (const s of layout.slices) {
-      encoder.copyBufferToBuffer(buffers.get(s.bufId)!, 0, layout.buffer, s.offset, s.byteSize)
+  }
+
+  let kernelIdx = 0
+  while (kernelIdx < kernels.length) {
+    const chunkEnd = Math.min(kernelIdx + CHUNK_SIZE, kernels.length)
+    const isLast = chunkEnd === kernels.length
+    const encoder = device.createCommandEncoder({
+      label: kernels.length > CHUNK_SIZE ? `tensorgrad-chunk-${kernelIdx}` : 'tensorgrad-step',
+    })
+    for (let i = kernelIdx; i < chunkEnd; i++) {
+      const k = kernels[i]!
+      if (!k.wgsl || k.threads === 0) continue
+      const pipeline = pipelines[i]!
+      const bindGroup = bindGroups[i]!
+      const pass = encoder.beginComputePass({ label: k.opKind })
+      pass.setPipeline(pipeline)
+      pass.setBindGroup(0, bindGroup)
+      // WebGPU caps each dispatch dimension at 65535 workgroups. Split into 2D
+      // when a kernel needs more than that on the X axis. Kernels compute their
+      // global index as `gid.x + gid.y * (65535 * workgroup_size)`, matching the
+      // stride we set here. For dispatches that fit in one row, gid.y is 0.
+      const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize))
+      const MAX_X = 65535
+      const wgX = Math.min(wgCount, MAX_X)
+      const wgY = Math.ceil(wgCount / MAX_X)
+      pass.dispatchWorkgroups(wgX, wgY, 1)
+      pass.end()
     }
+    if (isLast) {
+      // Writebacks (Adam state, updated params; empty for forward-only) +
+      // output readback copy + capture readback copies all go into the
+      // final chunk so a single mapAsync below sees everything.
+      for (const wb of plan.writebacks) {
+        encoder.copyBufferToBuffer(buffers.get(wb.source)!, 0, buffers.get(wb.dest)!, 0, wb.bytes)
+      }
+      encoder.copyBufferToBuffer(buffers.get(lossBufferId)!, 0, outputReadback, 0, outputSpec.byteSize)
+      if (layout) {
+        for (const s of layout.slices) {
+          encoder.copyBufferToBuffer(buffers.get(s.bufId)!, 0, layout.buffer, s.offset, s.byteSize)
+        }
+      }
+    }
+    queue.submit([encoder.finish()])
+    if (!isLast) {
+      // Drain the chunk before queuing the next one. This is the moment
+      // the compositor can interleave its own frame work onto the GPU.
+      await queue.onSubmittedWorkDone()
+    }
+    kernelIdx = chunkEnd
   }
-  queue.submit([encoder.finish()])
 
   // readback=false: training fire-and-forget. The encoder still copied
   // loss → outputReadback (and captures → staging), but we don't await
```
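
For reference, the chunked-submit pattern this release introduces reduces to a small skeleton. In this sketch, `submitChunked` and `encodeKernel` are hypothetical names standing in for tensorgrad's pipeline and bind-group work; only `createCommandEncoder`, `queue.submit`, and `queue.onSubmittedWorkDone` are the WebGPU APIs the diff actually uses.

```ts
// Sketch of chunked command submission with a GPU drain between chunks.
// Hypothetical helper, not tensorgrad source.
async function submitChunked(
  device: GPUDevice,
  kernelCount: number,
  encodeKernel: (encoder: GPUCommandEncoder, i: number) => void,
  chunkSize = 32,
) {
  for (let start = 0; start < kernelCount; start += chunkSize) {
    const end = Math.min(start + chunkSize, kernelCount)
    const encoder = device.createCommandEncoder()
    for (let i = start; i < end; i++) encodeKernel(encoder, i)
    device.queue.submit([encoder.finish()])
    // Drain between chunks, but not after the last one: the await is the
    // window where the compositor can slot its frame work onto the GPU.
    if (end < kernelCount) await device.queue.onSubmittedWorkDone()
  }
}
```

The design trade-off, per the comment in runtime.ts, is throughput for responsiveness: each drain inserts a GPU bubble, but a graph with fewer kernels than the chunk size degenerates to a single uninterrupted submit, so small models pay no overhead. Keeping the writeback, loss-readback, and capture copies in the final chunk preserves the 0.0.16 behavior of a single `mapAsync` seeing all results.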