tensorgrad 0.0.15 → 0.0.17
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +253 -119
- package/dist/index.d.ts +154 -193
- package/dist/index.js +2208 -39
- package/dist/index.js.map +7 -1
- package/dist/worker.debug.js +569 -0
- package/package.json +60 -58
- package/src/adam.ts +69 -15
- package/src/compile.ts +334 -156
- package/src/index.ts +8 -4
- package/src/module.ts +72 -34
- package/src/runtime.ts +56 -31
- package/src/worker-protocol.ts +183 -0
- package/src/worker-proxy.ts +76 -0
- package/src/worker.ts +281 -0
- package/dist/adam.js +0 -111
- package/dist/adam.js.map +0 -1
- package/dist/buffers.js +0 -120
- package/dist/buffers.js.map +0 -1
- package/dist/capture.js +0 -33
- package/dist/capture.js.map +0 -1
- package/dist/codegen.js +0 -724
- package/dist/codegen.js.map +0 -1
- package/dist/compile.js +0 -184
- package/dist/compile.js.map +0 -1
- package/dist/grad.js +0 -380
- package/dist/grad.js.map +0 -1
- package/dist/ir.js +0 -60
- package/dist/ir.js.map +0 -1
- package/dist/module.js +0 -155
- package/dist/module.js.map +0 -1
- package/dist/nn.js +0 -135
- package/dist/nn.js.map +0 -1
- package/dist/ops.js +0 -326
- package/dist/ops.js.map +0 -1
- package/dist/runtime.js +0 -402
- package/dist/runtime.js.map +0 -1
- package/dist/shape.js +0 -259
- package/dist/shape.js.map +0 -1
- package/dist/trace.js +0 -100
- package/dist/trace.js.map +0 -1
package/src/module.ts
CHANGED
|
@@ -32,30 +32,47 @@ import { paramInput } from './trace.js'
|
|
|
32
32
|
// Init metadata
|
|
33
33
|
// ============================================================================
|
|
34
34
|
|
|
35
|
-
/** How a parameter's initial values are produced.
|
|
36
|
-
*
|
|
37
|
-
*
|
|
38
|
-
*
|
|
39
|
-
*
|
|
40
|
-
* -
|
|
41
|
-
*
|
|
35
|
+
/**
 * Recipe for a parameter's initial values. Every variant is plain,
 * structured-cloneable data (no closures) because the initial values are
 * produced at compile time and then cross the worker boundary. Build the
 * object variants with the `init` helpers below.
 *
 * Shorthand strings:
 *  - `'randn'` — Gaussian noise with std 0.02 (the common weight-matrix init).
 *  - `'zeros'` — every element 0 (biases, LayerNorm beta).
 *  - `'ones'`  — every element 1 (LayerNorm gain).
 *
 * Object variants:
 *  - `{ kind: 'randn', scale }`   — Gaussian with an explicit std.
 *  - `{ kind: 'kaiming', gain? }` — Gaussian with `std = gain / sqrt(fan_in)`
 *    where `fan_in = shape[0]`; `gain` defaults to `sqrt(2)` (good for ReLU).
 *  - `{ kind: 'literal', data }`  — exact values; `data.length` must equal
 *    the parameter's element count.
 */
export type InitSpec =
  | 'randn'
  | 'zeros'
  | 'ones'
  | { readonly kind: 'randn'; readonly scale: number }
  | { readonly kind: 'kaiming'; readonly gain?: number }
  | { readonly kind: 'literal'; readonly data: Float32Array }

/** Convenience constructors for the object-shaped `InitSpec` variants. */
export const init = {
  /** Gaussian init; `scale` (the std) falls back to 0.02 when omitted. */
  randn(opts: { scale?: number } = {}): InitSpec {
    const scale = opts.scale ?? 0.02
    return { kind: 'randn', scale }
  },
  /** Kaiming init. The `gain` key is left off entirely when not supplied so
   * the resolver applies its own default. */
  kaiming(opts: { gain?: number } = {}): InitSpec {
    if (opts.gain === undefined) return { kind: 'kaiming' }
    return { kind: 'kaiming', gain: opts.gain }
  },
  /** Explicit values. The array is stored by reference, not copied here. */
  literal(data: Float32Array): InitSpec {
    return { kind: 'literal', data }
  },
}

/** Per-parameter options accepted by `Module.param`. */
export interface ParamOptions {
  /** Element dtype. `Module.param` falls back to `'f32'` when omitted. */
  dtype?: Dtype
  /** How the initial values are produced. Default: `'randn'` (std 0.02). */
  init?: InitSpec
  /** Opt this param in to or out of AdamW's decoupled weight decay (only
   * relevant when `weightDecay > 0`). When omitted: `true` for
   * randn/kaiming/literal inits (weight matrices, embeddings) and `false`
   * for zeros/ones (biases, LN gains). Replaces `adam.decayFilter` for
   * the common case. */
  decay?: boolean
}
|
|
61
78
|
|
|
@@ -65,31 +82,52 @@ function boxMuller(): number {
|
|
|
65
82
|
return Math.sqrt(-2 * Math.log(Math.max(1e-10, Math.random()))) * Math.cos(2 * Math.PI * Math.random())
|
|
66
83
|
}
|
|
67
84
|
|
|
68
|
-
function
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
return
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
85
|
+
/** Init closure that fills `size` floats with Gaussian (Box–Muller) samples
 * multiplied by `scale`, i.e. noise with standard deviation `scale`. */
function randnFn(scale: number): InitFn {
  return (size) => {
    const arr = new Float32Array(size)
    for (let i = 0; i < size; i++) arr[i] = boxMuller() * scale
    return arr
  }
}

/** Human-readable description of an unrecognized init spec for error
 * messages. Avoids bare `String()`, which renders objects as
 * `[object Object]` and can itself throw (e.g. on null-prototype objects). */
function describeInitSpec(spec: unknown): string {
  if (typeof spec === 'function') {
    return 'a function (closure inits are no longer supported because initial values must be serializable; use init.literal(data) instead)'
  }
  if (typeof spec === 'object' && spec !== null && 'kind' in spec) {
    return `{ kind: ${String((spec as { kind: unknown }).kind)} }`
  }
  try {
    return String(spec)
  } catch {
    return Object.prototype.toString.call(spec)
  }
}

/** Compile-time-only: resolve an InitSpec into the closure that generates
 * the initial Float32Array for a parameter. Runs on the main thread before
 * initial values are transferred to the worker.
 *
 * `undefined` (and the `'randn'` shorthand) mean Gaussian with std 0.02.
 *
 * @throws Error immediately for a spec that matches no known variant
 *   (typos, untyped JS callers, legacy closure inits). Without this check
 *   such a spec would fall through and yield `undefined`, failing much
 *   later and far from the call site.
 * The returned closure for `literal` throws when `data.length` differs
 * from the parameter's element count. */
function resolveInit(spec: InitSpec | undefined): InitFn {
  if (!spec || spec === 'randn') return randnFn(0.02)
  if (spec === 'zeros') return (size) => new Float32Array(size)
  if (spec === 'ones') return (size) => { const a = new Float32Array(size); a.fill(1); return a }
  switch (spec.kind) {
    case 'randn': return randnFn(spec.scale)
    case 'kaiming': {
      const gain = spec.gain ?? Math.sqrt(2)
      return (size, shape) => {
        // fan_in is the leading dimension; a rank-0 shape falls back to size.
        const fanIn = shape[0] ?? size
        const std = gain / Math.sqrt(fanIn)
        const arr = new Float32Array(size)
        for (let i = 0; i < size; i++) arr[i] = boxMuller() * std
        return arr
      }
    }
    case 'literal': {
      const data = spec.data
      return (size) => {
        if (data.length !== size) {
          throw new Error(`init.literal: data length ${data.length} doesn't match param size ${size}`)
        }
        // Copy so later mutation of the caller's array can't leak in.
        return new Float32Array(data)
      }
    }
    default: {
      // Unreachable for well-typed callers (spec is `never` here).
      throw new Error(`Unknown init: ${describeInitSpec(spec)}`)
    }
  }
}

/** Resolve the decay default for a param. Weight-shaped inits (randn,
 * kaiming, literal) default to decay=true; ones/zeros default to false
 * (biases, LN gains). An explicit `decay` option always wins. */
function resolveDecay(opts: ParamOptions | undefined): boolean {
  if (opts?.decay !== undefined) return opts.decay
  const spec = opts?.init ?? 'randn'
  return spec !== 'zeros' && spec !== 'ones'
}
|
|
94
132
|
|
|
95
133
|
// ============================================================================
|
|
@@ -127,7 +165,7 @@ export abstract class Module {
|
|
|
127
165
|
protected param(shape: Shape, opts?: ParamOptions): Tensor {
  // Declare a parameter. Nothing is allocated here: a ParamSentinel records
  // shape, dtype (default 'f32'), the init closure and the decay flag, and
  // becomes a real Tensor when the module is materialized.
  const dtype = opts?.dtype ?? 'f32'
  // Resolve init before decay so an invalid init spec is reported first.
  const initFn = resolveInit(opts?.init)
  const decay = resolveDecay(opts)
  const sentinel = new ParamSentinel(shape, dtype, initFn, decay)
  // Deliberate type lie: callers treat the sentinel as a Tensor until the
  // materialize step swaps it out.
  return sentinel as unknown as Tensor
}
|
|
132
170
|
}
|
|
133
171
|
|
package/src/runtime.ts
CHANGED
|
@@ -348,42 +348,67 @@ export async function createRuntime(
|
|
|
348
348
|
queue.writeBuffer(buffers.get(bufId)!, 0, data as unknown as BufferSource)
|
|
349
349
|
}
|
|
350
350
|
|
|
351
|
-
|
|
352
|
-
for
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
pass.setPipeline(pipeline)
|
|
359
|
-
pass.setBindGroup(0, bindGroup)
|
|
360
|
-
// WebGPU caps each dispatch dimension at 65535 workgroups. Split into 2D
|
|
361
|
-
// when a kernel needs more than that on the X axis. Kernels compute their
|
|
362
|
-
// global index as `gid.x + gid.y * (65535 * workgroup_size)`, matching the
|
|
363
|
-
// stride we set here. For dispatches that fit in one row, gid.y is 0.
|
|
364
|
-
const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize))
|
|
365
|
-
const MAX_X = 65535
|
|
366
|
-
const wgX = Math.min(wgCount, MAX_X)
|
|
367
|
-
const wgY = Math.ceil(wgCount / MAX_X)
|
|
368
|
-
pass.dispatchWorkgroups(wgX, wgY, 1)
|
|
369
|
-
pass.end()
|
|
370
|
-
}
|
|
371
|
-
// After all dispatches: writebacks (Adam state, updated params). Empty for
|
|
372
|
-
// forward-only compiles.
|
|
373
|
-
for (const wb of plan.writebacks) {
|
|
374
|
-
encoder.copyBufferToBuffer(buffers.get(wb.source)!, 0, buffers.get(wb.dest)!, 0, wb.bytes)
|
|
375
|
-
}
|
|
376
|
-
encoder.copyBufferToBuffer(buffers.get(lossBufferId)!, 0, outputReadback, 0, outputSpec.byteSize)
|
|
377
|
-
// Capture readbacks (only when opted in). All captures concatenate into
|
|
378
|
-
// a single staging buffer so we mapAsync once instead of N times.
|
|
351
|
+
// Chunked submit. One queue.submit() of all 240 kernels monopolizes the
|
|
352
|
+
// GPU for the full step duration, blocking compositor frames the entire
|
|
353
|
+
// time. Splitting into chunks with an explicit GPU-drain await between
|
|
354
|
+
// them gives the compositor a slot at each chunk boundary. On graphs
|
|
355
|
+
// smaller than CHUNK_SIZE this collapses to a single submit (no
|
|
356
|
+
// overhead). See specs/WorkerArchitecture.md / mobile-jank investigation.
|
|
357
|
+
const CHUNK_SIZE = 32
|
|
379
358
|
let layout: CaptureLayout | null = null
|
|
380
359
|
if (wantCaptures) {
|
|
360
|
+
// Compute layout up front so the last chunk can append capture copies.
|
|
381
361
|
layout = ensureCaptureStaging()
|
|
382
|
-
|
|
383
|
-
|
|
362
|
+
}
|
|
363
|
+
|
|
364
|
+
let kernelIdx = 0
|
|
365
|
+
while (kernelIdx < kernels.length) {
|
|
366
|
+
const chunkEnd = Math.min(kernelIdx + CHUNK_SIZE, kernels.length)
|
|
367
|
+
const isLast = chunkEnd === kernels.length
|
|
368
|
+
const encoder = device.createCommandEncoder({
|
|
369
|
+
label: kernels.length > CHUNK_SIZE ? `tensorgrad-chunk-${kernelIdx}` : 'tensorgrad-step',
|
|
370
|
+
})
|
|
371
|
+
for (let i = kernelIdx; i < chunkEnd; i++) {
|
|
372
|
+
const k = kernels[i]!
|
|
373
|
+
if (!k.wgsl || k.threads === 0) continue
|
|
374
|
+
const pipeline = pipelines[i]!
|
|
375
|
+
const bindGroup = bindGroups[i]!
|
|
376
|
+
const pass = encoder.beginComputePass({ label: k.opKind })
|
|
377
|
+
pass.setPipeline(pipeline)
|
|
378
|
+
pass.setBindGroup(0, bindGroup)
|
|
379
|
+
// WebGPU caps each dispatch dimension at 65535 workgroups. Split into 2D
|
|
380
|
+
// when a kernel needs more than that on the X axis. Kernels compute their
|
|
381
|
+
// global index as `gid.x + gid.y * (65535 * workgroup_size)`, matching the
|
|
382
|
+
// stride we set here. For dispatches that fit in one row, gid.y is 0.
|
|
383
|
+
const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize))
|
|
384
|
+
const MAX_X = 65535
|
|
385
|
+
const wgX = Math.min(wgCount, MAX_X)
|
|
386
|
+
const wgY = Math.ceil(wgCount / MAX_X)
|
|
387
|
+
pass.dispatchWorkgroups(wgX, wgY, 1)
|
|
388
|
+
pass.end()
|
|
384
389
|
}
|
|
390
|
+
if (isLast) {
|
|
391
|
+
// Writebacks (Adam state, updated params; empty for forward-only) +
|
|
392
|
+
// output readback copy + capture readback copies all go into the
|
|
393
|
+
// final chunk so a single mapAsync below sees everything.
|
|
394
|
+
for (const wb of plan.writebacks) {
|
|
395
|
+
encoder.copyBufferToBuffer(buffers.get(wb.source)!, 0, buffers.get(wb.dest)!, 0, wb.bytes)
|
|
396
|
+
}
|
|
397
|
+
encoder.copyBufferToBuffer(buffers.get(lossBufferId)!, 0, outputReadback, 0, outputSpec.byteSize)
|
|
398
|
+
if (layout) {
|
|
399
|
+
for (const s of layout.slices) {
|
|
400
|
+
encoder.copyBufferToBuffer(buffers.get(s.bufId)!, 0, layout.buffer, s.offset, s.byteSize)
|
|
401
|
+
}
|
|
402
|
+
}
|
|
403
|
+
}
|
|
404
|
+
queue.submit([encoder.finish()])
|
|
405
|
+
if (!isLast) {
|
|
406
|
+
// Drain the chunk before queuing the next one. This is the moment
|
|
407
|
+
// the compositor can interleave its own frame work onto the GPU.
|
|
408
|
+
await queue.onSubmittedWorkDone()
|
|
409
|
+
}
|
|
410
|
+
kernelIdx = chunkEnd
|
|
385
411
|
}
|
|
386
|
-
queue.submit([encoder.finish()])
|
|
387
412
|
|
|
388
413
|
// readback=false: training fire-and-forget. The encoder still copied
|
|
389
414
|
// loss → outputReadback (and captures → staging), but we don't await
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
// Wire format for the main-thread ↔ worker postMessage channel.
|
|
2
|
+
//
|
|
3
|
+
// All requests carry a numeric `id` assigned by the main thread; responses
|
|
4
|
+
// echo it back so the proxy can match concurrent in-flight calls. Every
|
|
5
|
+
// response is either `{ ok: true, result }` or `{ ok: false, error }`.
|
|
6
|
+
// Errors carry serialized name/message/stack so the proxy can reconstitute
|
|
7
|
+
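// an Error with the original name/message/stack on the receiving side. The
// result is always a plain Error, so callers should match on `name` rather
// than on `instanceof` a subclass such as ShapeError.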
|
|
8
|
+
//
|
|
9
|
+
// Inputs (typed arrays) and outputs (typed arrays, captures) are transferred
|
|
10
|
+
// rather than copied — see the per-request notes for which fields go on the
|
|
11
|
+
// transfer list. A single worker may host multiple compiled graphs (a train
|
|
12
|
+
// graph plus sibling forward graphs); each has a `graphId` issued by the
|
|
13
|
+
// main thread at compile time.
|
|
14
|
+
|
|
15
|
+
import type { Graph } from './ir.js'
|
|
16
|
+
import type { BufferPlan } from './buffers.js'
|
|
17
|
+
import type { KernelSpec } from './codegen.js'
|
|
18
|
+
import type { LRSchedule } from './adam.js'
|
|
19
|
+
|
|
20
|
+
// ============================================================================
|
|
21
|
+
// Serializable config (subset of AdamResolvedConfig that crosses the wire).
|
|
22
|
+
// `decayFilter` (a function, used only at compile time) is NOT part of this —
|
|
23
|
+
// the per-param decay decision is already baked into the IR by appendAdam
|
|
24
|
+
// before the IR ships to the worker.
|
|
25
|
+
// ============================================================================
|
|
26
|
+
|
|
27
|
+
/** Adam settings shipped to the worker alongside a training graph. All
 * fields are plain data so the object survives structured cloning. */
export interface WireAdamConfig {
  /** Learning rate; see `LRSchedule` in adam.ts for the accepted forms. */
  lr: LRSchedule
  // Adam hyperparameters, named as in adam.ts.
  b1: number
  b2: number
  eps: number
  /** Decoupled weight-decay coefficient. Per the `ParamOptions.decay` docs,
   * per-param decay only matters when this is > 0. */
  weightDecay: number
  /** NOTE(review): presumably true when `lr` is a schedule rather than a
   * constant — confirm against adam.ts. */
  lrIsScheduled: boolean
  /** Names of the per-step scalar inputs the worker must populate before
   * every step: the `_adam_lrt` input, and `_adam_decay_shrink` when one
   * exists (`null` otherwise). Mirrors AdamResult so the worker can update
   * them without re-deriving. */
  lrtInputName: string
  decayShrinkInputName: string | null
}

/** Compile output that crosses to the worker: the same fields as CompiledIR
 * minus the `loss` tensor, which is carried by `graph.outputs[0]`. */
export interface WireIR {
  graph: Graph
  plan: BufferPlan
  kernels: KernelSpec[]
}
|
|
48
|
+
|
|
49
|
+
// ============================================================================
|
|
50
|
+
// Requests (main → worker)
|
|
51
|
+
// ============================================================================
|
|
52
|
+
|
|
53
|
+
/** Every request the main thread can send. `kind` selects the variant and
 * fixes its `payload` type; `id` is assigned by WorkerProxy and echoed back
 * in the matching `Res`. */
export type Req =
  | { id: number; kind: 'createRuntime'; payload: CreateRuntimePayload }
  | { id: number; kind: 'compileForward'; payload: CompileForwardPayload }
  | { id: number; kind: 'step'; payload: StepPayload }
  | { id: number; kind: 'run'; payload: RunPayload }
  | { id: number; kind: 'uploadParams'; payload: UploadParamsPayload }
  | { id: number; kind: 'downloadParams'; payload: { graphId: number } }
  | { id: number; kind: 'downloadParamGrads'; payload: { graphId: number } }
  | { id: number; kind: 'resetOptimizer'; payload: { graphId: number } }
  | { id: number; kind: 'destroy'; payload: { graphId: number } }

/** Build the training runtime. Always graphId=0 for a fresh worker. */
export interface CreateRuntimePayload {
  graphId: number
  ir: WireIR
  /** Initial param values per name. Transferred (zero-copy): the main
   * thread loses access to the backing buffers after postMessage. */
  initialParams: Record<string, Float32Array>
  /** Adam config when training; `null` (the field is always present) for
   * forward-only compiles. */
  adam: WireAdamConfig | null
}

/** Build a sibling forward-only graph that shares param buffers with an
 * existing graph (typically the training graph at graphId=0). */
export interface CompileForwardPayload {
  graphId: number
  /** Graph whose param buffers this one reuses. */
  parentGraphId: number
  ir: WireIR
}

/** One training step. Inputs are transferred: the caller's typed arrays
 * become detached after postMessage (a subarray detaches its whole backing
 * buffer, including any other views onto it). */
export interface StepPayload {
  graphId: number
  inputs: Record<string, Int32Array | Float32Array>
  /** Also return capture tensors in the result. */
  withCaptures: boolean
}

/** Forward-only run. Same transfer semantics as `step`. */
export interface RunPayload {
  graphId: number
  inputs: Record<string, Int32Array | Float32Array>
  /** Also return capture tensors in the result. */
  withCaptures: boolean
}

/** Overwrite param values on the worker. NOTE(review): `partial`
 * presumably allows uploading only a subset of params — confirm in
 * worker.ts. */
export interface UploadParamsPayload {
  graphId: number
  params: Record<string, Float32Array> // transferred (detached after post)
  partial: boolean
}
|
|
103
|
+
|
|
104
|
+
// ============================================================================
|
|
105
|
+
// Responses (worker → main)
|
|
106
|
+
// ============================================================================
|
|
107
|
+
|
|
108
|
+
/** Reply to a request: echoes the request `id`, then carries either the
 * result (`ok: true`) or a serialized error (`ok: false`). */
export type Res<R = unknown> =
  | { id: number; ok: true; result: R }
  | { id: number; ok: false; error: WireError }

/** Error flattened to plain strings for postMessage (see `wireError` /
 * `reconstituteError`). `stack` is `''` when none was available. */
export interface WireError {
  name: string
  message: string
  stack: string
}

// Per-request result shapes:

/** Result of `createRuntime`. */
export interface CreateRuntimeResult {
  paramNames: string[]
  outputShape: number[]
  kernelCount: number
  captureShapes: Record<string, number[]>
}

/** Result of `compileForward`; same shape as CreateRuntimeResult. */
export interface CompileForwardResult {
  paramNames: string[]
  outputShape: number[]
  kernelCount: number
  captureShapes: Record<string, number[]>
}

/** Result of `step`. `captures` is `null` unless the request set
 * `withCaptures`, in which case it holds one Float32Array per capture
 * name (all transferred back). */
export interface StepResultWire {
  loss: number
  captures: Record<string, Float32Array> | null
}

/** Result of `run`. `captures` is `null` unless the request set
 * `withCaptures`. */
export interface RunResultWire {
  output: Float32Array
  captures: Record<string, Float32Array> | null
}

/** Result of `downloadParams` (NOTE(review): presumably also used for
 * `downloadParamGrads` — confirm in worker.ts). */
export interface DownloadParamsResult {
  params: Record<string, Float32Array> // transferred back to the main thread
}
|
|
151
|
+
|
|
152
|
+
// ============================================================================
|
|
153
|
+
// Transfer-list helpers
|
|
154
|
+
// ============================================================================
|
|
155
|
+
|
|
156
|
+
/** Collect the distinct ArrayBuffers backing a Record of typed arrays, for
 * use as `postMessage`'s transfer list. The values themselves stay in the
 * Record; only their backing buffers move, so after the post every view
 * onto those buffers is detached (including views outside the Record that
 * share a buffer with one inside it).
 *
 * Each buffer is listed once even when several values are views onto it:
 * `postMessage` throws a DataCloneError for duplicate transferables.
 * SharedArrayBuffer-backed views are skipped, since shared memory cannot be
 * transferred (structured clone shares it instead). Order follows first
 * appearance in `Object.values(rec)`. */
export function transferablesOfRecord(
  rec: Record<string, Int32Array | Float32Array>,
): ArrayBuffer[] {
  const unique = new Set<ArrayBufferLike>()
  for (const view of Object.values(rec)) {
    const buf = view.buffer
    if (typeof SharedArrayBuffer !== 'undefined' && buf instanceof SharedArrayBuffer) continue
    unique.add(buf)
  }
  return Array.from(unique) as ArrayBuffer[]
}
|
|
166
|
+
|
|
167
|
+
/** Serialize a thrown value to a wire-friendly shape, preserving name,
 * message and stack so the receiving side can rebuild an Error that callers
 * can pattern-match by `name` (e.g. `ShapeError`).
 *
 * - `Error` instances: name/message/stack as-is (missing stack → `''`).
 * - Other objects with a string `message` (some platform errors, e.g.
 *   WebGPU's GPUError, carry one without extending Error): that message,
 *   plus string `name`/`stack` when present.
 * - Anything else: `String(e)`, falling back to `Object.prototype.toString`
 *   when `String()` itself throws (e.g. null-prototype objects), so
 *   serializing an error can never raise a new one. */
export function wireError(e: unknown): WireError {
  if (e instanceof Error) {
    return { name: e.name, message: e.message, stack: e.stack ?? '' }
  }
  if (typeof e === 'object' && e !== null) {
    const rec = e as { name?: unknown; message?: unknown; stack?: unknown }
    if (typeof rec.message === 'string') {
      return {
        name: typeof rec.name === 'string' ? rec.name : 'Error',
        message: rec.message,
        stack: typeof rec.stack === 'string' ? rec.stack : '',
      }
    }
  }
  let message: string
  try {
    message = String(e)
  } catch {
    message = Object.prototype.toString.call(e)
  }
  return { name: 'Error', message, stack: '' }
}
|
|
176
|
+
|
|
177
|
+
/** Reconstitute an Error from the wire shape on the receiving (main) side.
 * The result is always a plain `Error` carrying the original `name` and
 * `message`, so match on `name`, not `instanceof` a subclass.
 *
 * The worker-side stack is used when one was sent, since it points at the
 * code that actually threw. Non-Error throwables arrive with an empty
 * stack; the locally captured stack is kept in that case so the rejection
 * is still traceable instead of ending up with `stack === ''`. */
export function reconstituteError(w: WireError): Error {
  const err = new Error(w.message)
  err.name = w.name
  if (w.stack) err.stack = w.stack
  return err
}
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
// Main-thread half of the worker channel: request/response correlation,
|
|
2
|
+
// promise wiring, error reconstitution. Knows nothing about Adam, captures,
|
|
3
|
+
// IR, etc. — just shuttles typed messages.
|
|
4
|
+
|
|
5
|
+
import type { Req, Res, WireError } from './worker-protocol.js'
|
|
6
|
+
import { reconstituteError } from './worker-protocol.js'
|
|
7
|
+
|
|
8
|
+
interface PendingHandlers {
  resolve: (v: unknown) => void
  reject: (e: Error) => void
}

/** `Omit` applied to each member of a union separately. Plain
 * `Omit<Req, 'id'>` collapses the union into one object whose `kind` and
 * `payload` are independent unions, which lets a mismatched pair (e.g.
 * `kind: 'step'` with an upload payload) type-check. */
type WithoutId<T> = T extends unknown ? Omit<T, 'id'> : never

/** Spawn a worker from an inlined source string and provide a typed
 * request/response channel. One WorkerProxy = one Worker = one GPUDevice
 * on the worker side. Sibling graphs share the same WorkerProxy. */
export class WorkerProxy {
  private worker: Worker
  private nextId = 1
  private pending = new Map<number, PendingHandlers>()
  private terminated = false
  /** Set once the worker reports an uncaught error. Later requests reject
   * with it instead of waiting on a worker that may never answer. */
  private failure: WireError | null = null

  constructor(workerSource: string) {
    const blob = new Blob([workerSource], { type: 'application/javascript' })
    const url = URL.createObjectURL(blob)
    this.worker = new Worker(url, { type: 'module' })
    // The Blob URL keeps memory alive as long as it's referenced; revoke
    // once the worker has loaded its source. Browsers tolerate revoke
    // immediately after construction in practice.
    URL.revokeObjectURL(url)

    this.worker.onmessage = (ev: MessageEvent<Res>) => {
      const reply = ev.data
      // Only correlated replies are expected; ignore anything else.
      if (reply == null || typeof reply !== 'object') return
      const handlers = this.pending.get(reply.id)
      if (!handlers) return // stale reply (e.g. already rejected); ignore
      this.pending.delete(reply.id)
      if (reply.ok) handlers.resolve(reply.result)
      else handlers.reject(reconstituteError(reply.error))
    }

    this.worker.onerror = (ev: ErrorEvent) => {
      const err = new Error(`tensorgrad worker error: ${ev.message || 'unknown'}`)
      const wire: WireError = { name: 'WorkerError', message: err.message, stack: err.stack ?? '' }
      // Reject everything in flight and remember the failure so subsequent
      // calls fail too. Otherwise, e.g. if the worker script failed to load,
      // later requests would stay pending forever.
      this.failure = wire
      for (const handlers of this.pending.values()) handlers.reject(reconstituteError(wire))
      this.pending.clear()
    }
  }

  /** Send a request and await its matching response. `transfer` lists the
   * ArrayBuffers to move (zero-copy) into the worker.
   *
   * Rejects immediately after `terminate()` or after a worker error. Also
   * rejects, without leaving a dangling pending entry, when `postMessage`
   * throws synchronously (e.g. DataCloneError for a non-cloneable value or
   * an invalid transfer list). */
  request<R>(req: WithoutId<Req>, transfer: ArrayBuffer[] = []): Promise<R> {
    if (this.terminated) return Promise.reject(new Error('tensorgrad: worker has been terminated'))
    if (this.failure) return Promise.reject(reconstituteError(this.failure))
    const id = this.nextId++
    return new Promise<R>((resolve, reject) => {
      this.pending.set(id, { resolve: resolve as (v: unknown) => void, reject })
      try {
        this.worker.postMessage({ ...req, id } as Req, transfer)
      } catch (e) {
        // No reply will ever arrive for this id, so drop it now.
        this.pending.delete(id)
        reject(e instanceof Error ? e : new Error(String(e)))
      }
    })
  }

  /** Fire-and-forget variant for cases where the caller doesn't need a reply
   * (currently unused; keep for symmetry / future use). Silently does
   * nothing once the proxy is terminated or the worker has failed. */
  send(req: WithoutId<Req>, transfer: ArrayBuffer[] = []): void {
    if (this.terminated || this.failure) return
    const id = this.nextId++
    this.worker.postMessage({ ...req, id } as Req, transfer)
  }

  /** Terminate the worker and reject every in-flight request. Idempotent. */
  terminate(): void {
    if (this.terminated) return
    this.terminated = true
    this.worker.terminate()
    const err = new Error('tensorgrad: worker terminated')
    for (const handlers of this.pending.values()) handlers.reject(err)
    this.pending.clear()
  }
}
|