tensorgrad 0.0.14 → 0.0.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/compile.ts CHANGED
@@ -7,15 +7,43 @@
 // Module tree; the library auto-discovers
 // params, traces the forward, appends grad
 // and Adam, and returns a runtime.
+//
+// As of the worker-architecture refactor: compile-time work (trace, autograd,
+// buffer planning, codegen) runs on the main thread. createRuntime and all
+// dispatch/mapAsync work runs in a Web Worker spawned per top-level compile;
+// the returned `CompiledModule` is a thin proxy over the worker channel.
+// See specs/WorkerArchitecture.md.
 
 import type { Tensor, Shape, Dtype } from './ir.js'
 import { trace, tensorInput } from './trace.js'
 import { appendGrad, type GradResult } from './grad.js'
-import { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
+import {
+  appendAdam, resolveLR,
+  type AdamConfig, type AdamResult, type AdamResolvedConfig,
+} from './adam.js'
 import { planBuffers, type BufferPlan } from './buffers.js'
 import { emitKernels, type KernelSpec } from './codegen.js'
-import { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts } from './runtime.js'
+import {
+  Captures, type RunResult, type StepResult, type RunOptions, type UploadParamsOptions,
+} from './runtime.js'
 import { Module, materializeParams, type MaterializedParams } from './module.js'
+import { WorkerProxy } from './worker-proxy.js'
+import {
+  transferablesOfRecord,
+  type Req, type WireIR, type WireAdamConfig,
+  type CreateRuntimeResult, type CompileForwardResult,
+  type StepResultWire, type RunResultWire, type DownloadParamsResult,
+} from './worker-protocol.js'
+
+// `__WORKER_SOURCE__` is replaced at build time by scripts/build.mjs with the
+// stringified contents of the bundled src/worker.ts. Declared here so TS is
+// happy; substituted as a string literal by esbuild's `define` during
+// `npm run build:js`. See scripts/build.mjs.
+declare const __WORKER_SOURCE__: string
+
+// ============================================================================
+// Public types
+// ============================================================================
 
 /** Declares one input tensor of the model's forward function. The name is the
  * key in the `inputs:` Record at compile time and the key on the `step()`/
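To make the `__WORKER_SOURCE__` mechanism concrete, here is a minimal sketch of the two-pass build the comment above describes. The esbuild calls are real API, but the entry points, output settings, and overall script shape are assumptions: scripts/build.mjs itself is not part of this diff.

```ts
// Hypothetical reconstruction of what scripts/build.mjs likely does.
import { build } from 'esbuild'

// Pass 1: bundle the worker entry in memory (write: false returns outputFiles).
const worker = await build({
  entryPoints: ['src/worker.ts'],
  bundle: true,
  write: false,
  format: 'iife', // assumption: the worker bundle is self-executing
})
const workerSource = worker.outputFiles[0].text

// Pass 2: bundle the library, replacing the `declare`d constant with a string
// literal. `define` values must be valid JS expressions, hence JSON.stringify.
await build({
  entryPoints: ['src/index.ts'],
  bundle: true,
  format: 'esm',
  outdir: 'dist',
  define: { __WORKER_SOURCE__: JSON.stringify(workerSource) },
})
```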
@@ -25,21 +53,14 @@ export interface InputDecl {
   dtype?: Dtype
 }
 
-/** Inputs declaration: a Record from input name to its shape/dtype. The name
- * doubles as the key the forward fn destructures and the key the runtime
- * expects in `step({...})` / `run({...})`. */
+/** Inputs declaration: a Record from input name to its shape/dtype. */
 export type InputDecls = Record<string, InputDecl>
 
 /** Maps an `InputDecls` Record to its forward-time tensor counterpart —
- * same keys, each value is a Tensor. Used to type the forward function's
- * `inputs` argument from the declared shape Record. */
+ * same keys, each value is a Tensor. */
 export type InputsTensors<I extends InputDecls> = { [K in keyof I]: Tensor }
 
-/** Forward function shape: takes the materialized model and a Record of
- * named input tensors (matching the declared `inputs:` keys), returns the
- * output tensor (loss for compileModule; logits/etc. for compileForward).
- * The second generic flows from the inputs declaration so destructuring
- * the input record stays typed. */
+/** Forward function shape. */
 export type ForwardFn<M extends Module, I extends InputDecls = InputDecls> =
   (m: M, inputs: InputsTensors<I>) => Tensor
 
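To illustrate how these three types compose, a short sketch. The model class, its `loss` method, and the concrete shapes are hypothetical, and `Tensor`/`Module` are assumed to be re-exported from the package root:

```ts
import { Module, type Tensor, type InputDecls, type ForwardFn } from 'tensorgrad'

// Hypothetical model: only the type relationships matter here.
declare class MLP extends Module {
  loss(tokens: Tensor, targets: Tensor): Tensor
}

const decls = {
  tokens: { shape: [32, 128], dtype: 'i32' },
  targets: { shape: [32, 128], dtype: 'i32' },
} satisfies InputDecls

// InputsTensors<typeof decls> maps both keys to Tensor, so destructuring the
// second argument is checked against the declared input names.
const forward: ForwardFn<MLP, typeof decls> = (m, { tokens, targets }) =>
  m.loss(tokens, targets)
```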
@@ -60,75 +81,86 @@ export function compileToIR(traceFn: () => Tensor): CompiledIR {
   return { graph, paramGrads, loss, plan, kernels }
 }
 
-/** Full compile pipeline. Browser-only because it creates a GPUDevice. */
-export async function compile(traceFn: () => Tensor, opts: RuntimeOpts = {}): Promise<CompiledRuntime & { ir: CompiledIR }> {
-  const ir = compileToIR(traceFn)
-  const lossBufferId = ir.plan.tensorToBuffer.get(ir.loss.id)!
-  const runtime = await createRuntime(ir.plan, ir.kernels, lossBufferId, opts)
-  return Object.assign(runtime, { ir })
-}
-
 // ============================================================================
-// Module-aware compile
+// CompiledModule / CompiledForwardModule — main-thread proxy surface
 // ============================================================================
 
-export interface CompileModuleOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
-  /** Per-step data inputs to the forward function, keyed by name. The forward
-   * fn destructures these out of its second argument; runtime calls to
-   * `step()` / `run()` pass typed arrays under the same keys. */
+export interface CompileModuleOptions<I extends InputDecls = InputDecls> {
   inputs?: I
-  /** Adam hyperparameters. If omitted, no optimizer is appended (forward-only). */
   adam?: AdamConfig
 }
 
-export interface CompileForwardOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
-  /** Per-step data inputs to the forward function, keyed by name. */
+export interface CompileForwardOptions<I extends InputDecls = InputDecls> {
   inputs?: I
 }
 
-/** Forward-only compile options as taken by the `compileForward` *method* on
- * a training runtime — no `device` (inherited) and no `sharedParams`
- * (auto-supplied from the train graph's params). */
 export interface CompileForwardMethodOptions<I extends InputDecls = InputDecls> {
   inputs?: I
 }
 
-/** Returned by `compileModule`. Adds training-graph extras (auto-init, reset,
- * sibling-graph compile) on top of the base runtime. */
-export interface CompiledModule<M extends Module> extends CompiledRuntime {
-  ir: CompiledIR
-  /** Number of dispatchable kernels (excludes leaf no-ops). */
-  kernelCount: number
-  /** Re-initialize all params from their declared init specs and zero the
-   * optimizer state. Use to start training over without recompiling. */
-  reset(): void
-  /** Compile a sibling forward-only graph (e.g., a B=1 inference graph or a
-   * B=N held-out eval graph) that shares this runtime's device and param
-   * buffers. Pass the forward fn (typically distinct from your loss fn —
-   * it returns logits, not a scalar) and any shape changes via `inputs`.
-   * Auto-initialization is a no-op since params are shared. */
+/** Returned by `compileModule`. Proxies all GPU work to a worker held
+ * internally; user code awaits Promises and never sees the worker. */
+export interface CompiledModule<M extends Module> {
+  readonly ir: CompiledIR
+  readonly kernelCount: number
+  readonly outputShape: readonly number[]
+  /** Names of the model's parameters, in materialization order. The actual
+   * GPUBuffers live in the worker; use `downloadParams()` for values. */
+  readonly paramNames: readonly string[]
+
+  step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
+  step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
+
+  run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
+  run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunResult>
+
+  uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): Promise<void>
+  downloadParams(): Promise<Record<string, Float32Array>>
+  downloadParamGrads(): Promise<Record<string, Float32Array>>
+
+  /** Re-initialize all params + zero optimizer state. */
+  reset(): Promise<void>
+  resetOptimizerState(): Promise<void>
+
+  /** Compile a sibling forward-only graph that shares this runtime's worker
+   * (and therefore its param GPUBuffers). */
   compileForward<I extends InputDecls>(
     forward: ForwardFn<M, I>,
     opts?: CompileForwardMethodOptions<I>,
   ): Promise<CompiledForwardModule>
+
+  /** Free the runtime's GPU resources and terminate the worker. */
+  destroy(): void
 }
 
 /** Returned by `compileForward` (and by the `compileForward` method). */
-export interface CompiledForwardModule extends CompiledForward {
-  ir: CompiledIR
-  /** Number of dispatchable kernels (excludes leaf no-ops). */
-  kernelCount: number
+export interface CompiledForwardModule {
+  readonly ir: CompiledIR
+  readonly kernelCount: number
+  readonly outputShape: readonly number[]
+  readonly paramNames: readonly string[]
+
+  run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
+  run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunResult>
+
+  uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): Promise<void>
+  downloadParams(): Promise<Record<string, Float32Array>>
+
+  destroy(): void
 }
 
+// ============================================================================
+// compileModule / compileForward
+// ============================================================================
+
 /**
  * Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
  * model instance itself: compilation mutates the tree (every `ParamSentinel`
  * field becomes a real `Tensor`), so the instance is consumed and shouldn't be
- * referenced afterwards. Re-call the factory if you need a fresh tree.
+ * referenced afterwards.
  *
  * The forward function takes the materialized model and a Record of named
- * input tensors, returns the loss tensor. Inputs are matched by name with the
- * `inputs:` declaration:
+ * input tensors, returns the loss tensor:
  *
  *   inputs: {
  *     tokens: { shape: [B, T], dtype: 'i32' },
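A usage sketch of the proxy surface defined in the hunk above. The model class, batch data, shapes, and the Adam hyperparameter field are hypothetical; the method calls follow the `CompiledModule` interface as shown:

```ts
import { compileModule, Module, type Tensor } from 'tensorgrad'

declare class GPT extends Module { loss(tokens: Tensor, targets: Tensor): Tensor }
declare const batchTokens: Int32Array
declare const batchTargets: Int32Array

const compiled = await compileModule(
  () => new GPT(), // factory, not an instance: compilation consumes the tree
  (m, { tokens, targets }) => m.loss(tokens, targets),
  {
    inputs: {
      tokens: { shape: [8, 64], dtype: 'i32' },
      targets: { shape: [8, 64], dtype: 'i32' },
    },
    adam: { lr: 1e-3 }, // field name is an assumption about AdamConfig
  },
)

for (let step = 0; step < 100; step++) {
  // Resolves when the worker replies; the loss arrives as a plain number.
  const loss = await compiled.step({ tokens: batchTokens, targets: batchTargets })
  if (!Number.isFinite(loss)) break
}

const params = await compiled.downloadParams() // Record<name, Float32Array>
compiled.destroy() // frees GPU resources and terminates the internal worker
```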
@@ -136,20 +168,16 @@ export interface CompiledForwardModule {
  *   }
  *   forward: (m, { tokens, targets }) => …
 *
- * Walks the module tree to materialize params with auto-derived names, then
- * runs trace → grad → adam → buffer plan → codegen → runtime. Initial
- * parameter values are uploaded automatically before this function returns;
- * call `reset()` later to re-randomize.
- *
- * If `opts.adam` is set, the runtime's `step()` automatically tracks an
- * internal step count and injects the bias-corrected `lrt` scalar each call;
- * users don't need to provide it themselves.
+ * Returns a `CompiledModule` proxy. All GPU work (createRuntime, step, run,
+ * mapAsync) happens in an internal worker; calls return Promises that resolve
+ * when the worker replies.
 */
 export async function compileModule<M extends Module, I extends InputDecls = InputDecls>(
   modelFactory: () => M,
   forward: ForwardFn<M, I>,
   opts: CompileModuleOptions<I> = {},
 ): Promise<CompiledModule<M>> {
+  // ---- Compile-time work (main thread) ------------------------------------
   const { graph, materialized } = traceModule(modelFactory, forward, opts.inputs ?? {})
   const { paramGrads, loss } = appendGrad(graph)
   const adamResult = opts.adam
@@ -158,55 +186,40 @@ export async function compileModule<M extends Module, I extends InputDecls = Inp
 
   const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? [])
   const kernels = emitKernels(graph, plan)
-  const lossBufferId = plan.tensorToBuffer.get(loss.id)!
-  const runtime = await createRuntime(plan, kernels, lossBufferId, opts)
-
-  if (adamResult) wrapStepForAdam(runtime, adamResult)
-  uploadInitialParams(plan, materialized.initFns, runtime, /* sharedParams */ undefined)
-
   const ir: CompiledIR = { graph, paramGrads, loss, plan, kernels }
-  const kernelCount = countKernels(kernels)
 
-  const reset = () => {
-    uploadInitialParams(plan, materialized.initFns, runtime, undefined)
-    runtime.resetOptimizerState()
+  // Initial params: resolve init shapes to Float32Arrays now (main thread).
+  // These transfer (zero-copy) to the worker as part of createRuntime.
+  const initialParams = buildInitialParams(plan, materialized.initFns)
+
+  // ---- Spawn worker, send IR + initial params -----------------------------
+  const proxy = new WorkerProxy(__WORKER_SOURCE__)
+  const wireIR: WireIR = { graph, plan, kernels }
+  const wireAdam = adamResult ? wireAdamConfig(adamResult) : null
+  const transfers = transferablesOfRecord(initialParams)
+
+  let meta: CreateRuntimeResult
+  try {
+    meta = await proxy.request<CreateRuntimeResult>(
+      { kind: 'createRuntime', payload: { graphId: 0, ir: wireIR, initialParams, adam: wireAdam } },
+      transfers,
+    )
+  } catch (e) {
+    proxy.terminate()
+    throw e
   }
 
-  const compileForwardMethod = <J extends InputDecls>(
-    forwardFn: ForwardFn<M, J>,
-    fOpts: CompileForwardMethodOptions<J> = {},
-  ): Promise<CompiledForwardModule> =>
-    compileForward<M, J>(modelFactory, forwardFn, {
-      ...fOpts,
-      device: runtime.device,
-      sharedParams: runtime.params,
-    })
-
-  return Object.assign(runtime, { ir, kernelCount, reset, compileForward: compileForwardMethod })
+  return new CompiledModuleProxy<M>(
+    proxy, /* graphId */ 0, ir, meta, modelFactory,
+    /* initFns */ materialized.initFns,
+    /* nextGraphId */ { v: 1 },
+  )
 }
 
-// ============================================================================
-// Forward-only compile
-// ============================================================================
-
 /**
- * Compile a Module-based model in forward-only mode (no autograd, no Adam).
- * The forward function returns the output tensor (e.g., logits) instead of a
- * scalar loss; runtime exposes `run(inputs)` returning the full output as a
- * `Float32Array`.
- *
- * **Prefer the `compileForward` method on a training runtime** when both
- * graphs use the same Module class — it auto-supplies `device` and
- * `sharedParams`. This standalone form is for forward-only models with no
- * training graph at all, or for sharing params across a different model.
- *
- * **Sharing params with a training compile.** Pass `opts.sharedParams =
- * trainCompiled.params` to bind this graph's param buffers to an existing
- * training runtime's GPU buffers — every train step is then immediately
- * visible to `run()` calls here, no copies.
- *
- * Initial param values are uploaded automatically for params *not* covered
- * by `sharedParams` (those are owned by the sibling compile).
+ * Forward-only compile. Spawns its own worker. For sibling graphs that share
+ * params with a training graph, prefer the `compileForward` method on the
+ * CompiledModule returned by `compileModule()`.
 */
 export async function compileForward<M extends Module, I extends InputDecls = InputDecls>(
   modelFactory: () => M,
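worker-proxy.ts is not shown in this diff. The following is a sketch of the promise-per-request channel its call sites imply: `request` resolves with the worker's reply, `send` is fire-and-forget, and the worker is spawned from the bundled source string. Every wire field beyond the `kind`/`payload` shapes visible above is an assumption.

```ts
// Hypothetical channel consistent with how compile.ts uses WorkerProxy.
class WorkerProxySketch {
  private readonly worker: Worker
  private nextId = 0
  private readonly pending = new Map<
    number,
    { resolve: (v: unknown) => void; reject: (e: unknown) => void }
  >()

  constructor(source: string) {
    // Spawn the worker from the bundled source string (__WORKER_SOURCE__).
    const url = URL.createObjectURL(new Blob([source], { type: 'text/javascript' }))
    this.worker = new Worker(url)
    this.worker.onmessage = (ev: MessageEvent) => {
      const { id, ok, value, error } = ev.data // reply shape is assumed
      const p = this.pending.get(id)
      if (!p) return
      this.pending.delete(id)
      if (ok) p.resolve(value)
      else p.reject(new Error(String(error)))
    }
  }

  request<T>(req: { kind: string; payload: unknown }, transfer: Transferable[] = []): Promise<T> {
    const id = this.nextId++
    return new Promise<T>((resolve, reject) => {
      this.pending.set(id, { resolve: resolve as (v: unknown) => void, reject })
      this.worker.postMessage({ id, ...req }, transfer)
    })
  }

  send(req: { kind: string; payload: unknown }): void {
    this.worker.postMessage(req)
  }

  terminate(): void { this.worker.terminate() }
}
```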
@@ -215,16 +228,195 @@ export async function compileForward<M extends Module, I extends InputDecls = In
 ): Promise<CompiledForwardModule> {
   const { graph, materialized } = traceModule(modelFactory, forward, opts.inputs ?? {})
   const outputTensor = graph.tensors[graph.outputs[0]!]!
-
   const plan = planBuffers(graph, /* paramGrads */ {})
   const kernels = emitKernels(graph, plan)
-  const outputBufferId = plan.tensorToBuffer.get(outputTensor.id)!
-  const runtime = await createForwardRuntime(plan, kernels, outputBufferId, opts)
+  const ir: CompiledIR = { graph, paramGrads: {}, loss: outputTensor, plan, kernels }
 
-  uploadInitialParams(plan, materialized.initFns, runtime, opts.sharedParams)
+  const initialParams = buildInitialParams(plan, materialized.initFns)
+  const proxy = new WorkerProxy(__WORKER_SOURCE__)
+  const wireIR: WireIR = { graph, plan, kernels }
+  const transfers = transferablesOfRecord(initialParams)
+
+  let meta: CreateRuntimeResult
+  try {
+    meta = await proxy.request<CreateRuntimeResult>(
+      { kind: 'createRuntime', payload: { graphId: 0, ir: wireIR, initialParams, adam: null } },
+      transfers,
+    )
+  } catch (e) {
+    proxy.terminate()
+    throw e
+  }
 
-  const ir: CompiledIR = { graph, paramGrads: {}, loss: outputTensor, plan, kernels }
-  return Object.assign(runtime, { ir, kernelCount: countKernels(kernels) })
+  return new CompiledForwardModuleProxy(proxy, /* graphId */ 0, ir, meta, /* ownsWorker */ true)
+}
+
+// ============================================================================
+// Proxy implementations
+// ============================================================================
+
+class CompiledModuleProxy<M extends Module> implements CompiledModule<M> {
+  constructor(
+    private readonly proxy: WorkerProxy,
+    private readonly graphId: number,
+    public readonly ir: CompiledIR,
+    private readonly meta: CreateRuntimeResult,
+    private readonly modelFactory: () => M,
+    /** Init closures captured from materializeParams at compile time. Used
+     * by reset() to regenerate initial param values. */
+    private readonly initFns: Record<string, InitFn>,
+    private readonly nextGraphId: { v: number },
+  ) {}
+
+  get kernelCount(): number { return this.meta.kernelCount }
+  get outputShape(): readonly number[] { return this.meta.outputShape }
+  get paramNames(): readonly string[] { return this.meta.paramNames }
+
+  step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
+  step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
+  async step(
+    inputs: Record<string, Int32Array | Float32Array>,
+    opts?: { withCaptures?: boolean },
+  ): Promise<number | StepResult> {
+    // Note: inputs are copied (not transferred) into the worker. Callers
+    // commonly reuse the same TypedArray as a scratch buffer across step()
+    // calls; transferring would detach it. The copy cost is small relative
+    // to a training step's GPU work.
+    const r = await this.proxy.request<StepResultWire>(
+      { kind: 'step', payload: { graphId: this.graphId, inputs, withCaptures: opts?.withCaptures === true } },
+    )
+    if (opts?.withCaptures) {
+      return { loss: r.loss, captures: makeCaptures(r.captures, this.meta.captureShapes) }
+    }
+    return r.loss
+  }
+
+  run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
+  run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunResult>
+  async run(
+    inputs: Record<string, Int32Array | Float32Array>,
+    opts?: { withCaptures?: boolean },
+  ): Promise<Float32Array | RunResult> {
+    // Inputs copied (see note in step()).
+    const r = await this.proxy.request<RunResultWire>(
+      { kind: 'run', payload: { graphId: this.graphId, inputs, withCaptures: opts?.withCaptures === true } },
+    )
+    if (opts?.withCaptures) {
+      return { output: r.output, captures: makeCaptures(r.captures, this.meta.captureShapes) }
+    }
+    return r.output
+  }
+
+  uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): Promise<void> {
+    // Params copied (see note in step()) — caller's Float32Arrays stay valid.
+    return this.proxy.request<null>(
+      { kind: 'uploadParams', payload: { graphId: this.graphId, params, partial: !!opts?.partial } },
+    ).then(() => undefined)
+  }
+
+  async downloadParams(): Promise<Record<string, Float32Array>> {
+    const r = await this.proxy.request<DownloadParamsResult>(
+      { kind: 'downloadParams', payload: { graphId: this.graphId } },
+    )
+    return r.params
+  }
+
+  async downloadParamGrads(): Promise<Record<string, Float32Array>> {
+    const r = await this.proxy.request<DownloadParamsResult>(
+      { kind: 'downloadParamGrads', payload: { graphId: this.graphId } },
+    )
+    return r.params
+  }
+
+  async reset(): Promise<void> {
+    // Re-init main-thread, upload, then reset Adam state on worker. Two
+    // round-trips but reset() is rare. The init closures were captured at
+    // compile time and stashed on the proxy.
+    const initialParams = buildInitialParams(this.ir.plan, this.initFns)
+    await this.uploadParams(initialParams)
+    await this.resetOptimizerState()
+  }
+
+  resetOptimizerState(): Promise<void> {
+    return this.proxy.request<null>(
+      { kind: 'resetOptimizer', payload: { graphId: this.graphId } },
+    ).then(() => undefined)
+  }
+
+  async compileForward<I extends InputDecls>(
+    forward: ForwardFn<M, I>,
+    opts: CompileForwardMethodOptions<I> = {},
+  ): Promise<CompiledForwardModule> {
+    const { graph, materialized: _materialized } = traceModule(this.modelFactory, forward, opts.inputs ?? {})
+    const outputTensor = graph.tensors[graph.outputs[0]!]!
+    const plan = planBuffers(graph, /* paramGrads */ {})
+    const kernels = emitKernels(graph, plan)
+    const ir: CompiledIR = { graph, paramGrads: {}, loss: outputTensor, plan, kernels }
+
+    const childGraphId = this.nextGraphId.v++
+    const wireIR: WireIR = { graph, plan, kernels }
+
+    const meta = await this.proxy.request<CompileForwardResult>(
+      { kind: 'compileForward', payload: { graphId: childGraphId, parentGraphId: this.graphId, ir: wireIR } },
+    )
+
+    return new CompiledForwardModuleProxy(this.proxy, childGraphId, ir, meta, /* ownsWorker */ false)
+  }
+
+  destroy(): void {
+    // Fire-and-forget destroy; postMessage ordering ensures the worker
+    // processes any in-flight requests before we terminate it.
+    this.proxy.send({ kind: 'destroy', payload: { graphId: this.graphId } })
+    this.proxy.terminate()
+  }
+}
+
+class CompiledForwardModuleProxy implements CompiledForwardModule {
+  constructor(
+    private readonly proxy: WorkerProxy,
+    private readonly graphId: number,
+    public readonly ir: CompiledIR,
+    private readonly meta: CompileForwardResult | CreateRuntimeResult,
+    private readonly ownsWorker: boolean,
+  ) {}
+
+  get kernelCount(): number { return this.meta.kernelCount }
+  get outputShape(): readonly number[] { return this.meta.outputShape }
+  get paramNames(): readonly string[] { return this.meta.paramNames }
+
+  run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
+  run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunResult>
+  async run(
+    inputs: Record<string, Int32Array | Float32Array>,
+    opts?: { withCaptures?: boolean },
+  ): Promise<Float32Array | RunResult> {
+    // Inputs copied; caller's TypedArrays stay valid.
+    const r = await this.proxy.request<RunResultWire>(
+      { kind: 'run', payload: { graphId: this.graphId, inputs, withCaptures: opts?.withCaptures === true } },
+    )
+    if (opts?.withCaptures) {
+      return { output: r.output, captures: makeCaptures(r.captures, this.meta.captureShapes) }
+    }
+    return r.output
+  }
+
+  uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): Promise<void> {
+    return this.proxy.request<null>(
+      { kind: 'uploadParams', payload: { graphId: this.graphId, params, partial: !!opts?.partial } },
+    ).then(() => undefined)
+  }
+
+  async downloadParams(): Promise<Record<string, Float32Array>> {
+    const r = await this.proxy.request<DownloadParamsResult>(
+      { kind: 'downloadParams', payload: { graphId: this.graphId } },
+    )
+    return r.params
+  }
+
+  destroy(): void {
+    this.proxy.send({ kind: 'destroy', payload: { graphId: this.graphId } })
+    if (this.ownsWorker) this.proxy.terminate()
+  }
 }
 
 // ============================================================================
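A sibling-graph sketch against the two proxy classes above: the child forward graph is registered in the parent's worker under a fresh graphId, so it binds the same param GPUBuffers. The model class, its `logits` method, and the shapes are hypothetical:

```ts
import { Module, type Tensor, type CompiledModule } from 'tensorgrad'

declare class GPT extends Module {
  loss(tokens: Tensor, targets: Tensor): Tensor
  logits(tokens: Tensor): Tensor
}
declare const compiled: CompiledModule<GPT> // from compileModule, as earlier
declare const batch: { tokens: Int32Array; targets: Int32Array }
declare const promptTokens: Int32Array

// Child graph: compiled inside the parent's worker (ownsWorker = false).
const infer = await compiled.compileForward(
  (m, { tokens }) => m.logits(tokens),
  { inputs: { tokens: { shape: [1, 64], dtype: 'i32' } } },
)

await compiled.step(batch)                                // updates shared params
const logits = await infer.run({ tokens: promptTokens })  // sees that update

infer.destroy()    // child: releases its graph, keeps the shared worker alive
compiled.destroy() // owner: terminates the worker for both graphs
```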
@@ -255,58 +447,46 @@ function traceModule<M extends Module, I extends InputDecls>(
   return { graph, materialized }
 }
 
-const countKernels = (kernels: KernelSpec[]): number => kernels.filter(k => k.wgsl).length
-
-/** Wrap the runtime's step() to inject Adam's per-step `lrt` (bias-corrected
- * effective LR) and, when the user supplied a per-step lr schedule, the
- * decayShrink scalar. Also wraps resetOptimizerState() so a reset zeros
- * Adam's m/v *and* the bias-correction step counter — otherwise the next
- * step would skip Adam's warmup phase. */
-function wrapStepForAdam(runtime: CompiledRuntime, adamResult: AdamResult): void {
-  const { lrtInputName, decayShrinkInputName, config } = adamResult
-  let t = 0
-  const lrtBuf = new Float32Array(1)
-  const decayShrinkBuf = decayShrinkInputName ? new Float32Array(1) : null
-  const innerStep = runtime.step.bind(runtime) as CompiledRuntime['step']
-  const innerReset = runtime.resetOptimizerState.bind(runtime)
-  const wrappedStep = ((
-    inputs: Record<string, Int32Array | Float32Array>,
-    opts?: { withCaptures?: boolean },
-  ) => {
-    t++
-    const lrNow = config.lr(t)
-    lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t))
-    const merged: Record<string, Int32Array | Float32Array> = { ...inputs, [lrtInputName]: lrtBuf }
-    if (decayShrinkBuf && decayShrinkInputName) {
-      decayShrinkBuf[0] = 1 - lrNow * config.weightDecay
-      merged[decayShrinkInputName] = decayShrinkBuf
-    }
-    return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged)
-  }) as CompiledRuntime['step']
-  runtime.step = wrappedStep
-  runtime.resetOptimizerState = () => {
-    t = 0
-    innerReset()
-  }
-}
-
-/** Build a Record<paramName, Float32Array> by running each param's init
- * function against its shape and uploading them to the runtime. Skips any
- * param covered by `sharedParams` (those are owned by a sibling compile). */
-function uploadInitialParams(
-  plan: BufferPlan,
-  initFns: Record<string, InitFn>,
-  runtime: CompiledRuntime | CompiledForward,
-  sharedParams: Map<string, GPUBuffer> | undefined,
-): void {
+/** Run each param's init function against its declared shape to produce the
+ * initial Float32Arrays. Runs main-thread before transfer to the worker. */
+function buildInitialParams(plan: BufferPlan, initFns: Record<string, InitFn>): Record<string, Float32Array> {
   const out: Record<string, Float32Array> = {}
   for (const [name, bufId] of plan.paramsByName) {
-    if (sharedParams?.has(name)) continue
     const shape = plan.buffers[bufId]!.shape
     const size = shape.reduce((a, b) => a * b, 1)
     const initFn = initFns[name]
     if (!initFn) throw new Error(`compile: no init for param '${name}'`)
     out[name] = initFn(size, shape)
   }
-  if (Object.keys(out).length > 0) runtime.uploadParams(out, { partial: !!sharedParams })
+  return out
+}
+
+/** Subset of AdamResolvedConfig that crosses the wire (drops decayFilter,
+ * which is only used at compile time). */
+function wireAdamConfig(r: AdamResult): WireAdamConfig {
+  const c: AdamResolvedConfig = r.config
+  return {
+    lr: c.lr,
+    b1: c.b1,
+    b2: c.b2,
+    eps: c.eps,
+    weightDecay: c.weightDecay,
+    lrIsScheduled: c.lrIsScheduled,
+    lrtInputName: r.lrtInputName,
+    decayShrinkInputName: r.decayShrinkInputName,
+  }
 }
+
+/** Wrap a worker-returned `Record<name, Float32Array>` in a Captures instance
+ * using the static capture shapes captured at compile time. */
+function makeCaptures(
+  captures: Record<string, Float32Array> | null,
+  captureShapes: Record<string, number[]>,
+): Captures {
+  const data = new Map<string, Float32Array>()
+  if (captures) {
+    for (const [name, arr] of Object.entries(captures)) data.set(name, arr)
+  }
+  return new Captures(captureShapes, data)
+}
+
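The step counting and scalar injection that the removed wrapStepForAdam did on the main thread now have to happen worker-side, driven by the WireAdamConfig above. For reference, these are the formulas the removed wrapper used (and which the worker presumably reproduces); the function names here are illustrative only:

```ts
// From the removed wrapStepForAdam: the bias-corrected effective learning
// rate injected under lrtInputName on step t (1-based) ...
function lrtAt(t: number, lr: (t: number) => number, b1: number, b2: number): number {
  return lr(t) * Math.sqrt(1 - Math.pow(b2, t)) / (1 - Math.pow(b1, t))
}

// ... and the weight-decay shrink factor injected under decayShrinkInputName
// when a schedule plus weight decay are in play.
function decayShrinkAt(t: number, lr: (t: number) => number, weightDecay: number): number {
  return 1 - lr(t) * weightDecay
}
```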
package/src/index.ts CHANGED
@@ -33,15 +33,19 @@ export {
 // adam.ts can import them) but aren't part of the public API — `add`/`mul`
 // overload on JS numbers, `where` subsumes the rest.
 export { appendGrad, type GradResult } from './grad.js'
-export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
+export { appendAdam, lr, resolveLR, type AdamConfig, type AdamResult, type LRSchedule } from './adam.js'
 export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
 export { emitKernels, type KernelSpec } from './codegen.js'
-export { createRuntime, createForwardRuntime, Captures, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type RunOptions, type StepResult, type RunResult } from './runtime.js'
+// Runtime types: only the user-facing pieces. CompiledRuntime/CompiledForward
+// (worker-internal) and createRuntime/createForwardRuntime aren't part of the
+// public API — users get CompiledModule/CompiledForwardModule (proxies) from
+// compileModule/compileForward instead.
+export { Captures, type RunOptions, type StepResult, type RunResult, type UploadParamsOptions } from './runtime.js'
 export {
-  compile, compileToIR, compileModule, compileForward,
+  compileToIR, compileModule, compileForward,
   type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type CompileForwardMethodOptions,
   type CompiledModule, type CompiledForwardModule,
  type InputDecl, type InputDecls, type InputsTensors, type ForwardFn,
 } from './compile.js'
-export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
+export { Module, materializeParams, init, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
 export * as nn from './nn.js'
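Taken together, the index.ts changes define the migration surface for 0.0.14 callers. A summary sketch; only the removals, additions, and signature changes visible in this diff are asserted:

```ts
// 0.0.14 → 0.0.16 call-site changes grounded in this diff:
//
// - Top-level `compile` is gone; use compileModule / compileForward.
// - createRuntime, createForwardRuntime, CompiledRuntime, CompiledForward,
//   and RuntimeOpts are no longer exported (worker-internal now).
// - compileModule/compileForward options no longer extend RuntimeOpts, so
//   device/sharedParams are gone; share params via the compileForward
//   *method* on a CompiledModule instead.
// - CompiledModule.reset() is now async: `await compiled.reset()`.
// - CompiledModule/CompiledForwardModule gain destroy(), which frees GPU
//   resources and terminates the internal worker.

// Newly exported names:
import { lr, resolveLR, init, type LRSchedule, type UploadParamsOptions } from 'tensorgrad'
```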