tensorgrad 0.0.11 → 0.0.12

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/compile.ts CHANGED
// Top-level compile(): trace → autograd → buffer plan → codegen → runtime.
//
// Two entry points:
//   * `compile(traceFn)`        — low-level. User declares params via
//                                 paramInput() inside the trace.
//   * `compileModule(model, …)` — high-level. User defines the model as a
//                                 Module tree; the library auto-discovers
//                                 params, traces the forward, appends grad
//                                 and Adam, and returns a runtime.

import type { Tensor, Shape, Dtype } from './ir.js'
import { trace, tensorInput } from './trace.js'
import { appendGrad, type GradResult } from './grad.js'
import { appendAdam, type AdamConfig } from './adam.js'
import { planBuffers, type BufferPlan } from './buffers.js'
import { emitKernels, type KernelSpec } from './codegen.js'
import { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts } from './runtime.js'
import { Module, materializeParams } from './module.js'

/** Declares one input tensor of the model's forward function. The name is the
 * key in the `inputs:` Record at compile time and the key on the `step()`/
 * `run()` data object at runtime. */
export interface InputDecl {
  shape: Shape
  dtype?: Dtype
}

/** Inputs declaration: a Record from input name to its shape/dtype. The name
 * doubles as the key the forward fn destructures and the key the runtime
 * expects in `step({...})` / `run({...})`. */
export type InputDecls = Record<string, InputDecl>

/** Maps an `InputDecls` Record to its forward-time tensor counterpart —
 * same keys, each value is a Tensor. Used to type the forward function's
 * `inputs` argument from the declared shape Record. */
export type InputsTensors<I extends InputDecls> = { [K in keyof I]: Tensor }

/** Forward function shape: takes the materialized model and a Record of
 * named input tensors (matching the declared `inputs:` keys), returns the
 * output tensor (loss for compileModule; logits/etc. for compileForward).
 * The second generic flows from the inputs declaration so destructuring
 * the input record stays typed. */
export type ForwardFn<M extends Module, I extends InputDecls = InputDecls> =
  (m: M, inputs: InputsTensors<I>) => Tensor

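// Sketch (illustrative, not part of this file): wiring a typed forward fn to
// an inputs declaration. `MyModel` and its `loss` method are hypothetical;
// the point is that the keys destructured in the forward fn are checked
// against the declared `inputs:` keys:
//
//   const decls = {
//     x: { shape: [32, 16] },
//     y: { shape: [32], dtype: 'i32' },
//   } satisfies InputDecls
//   const forward: ForwardFn<MyModel, typeof decls> =
//     (m, { x, y }) => m.loss(x, y)
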
export interface CompiledIR {
  graph: GradResult['graph']
  paramGrads: GradResult['paramGrads']
  loss: Tensor
  plan: BufferPlan
  kernels: KernelSpec[]
}

/** Trace + autograd + buffer-plan + codegen, without touching WebGPU. */
export function compileToIR(traceFn: () => Tensor): CompiledIR {
  const graph = trace(traceFn)
  const { paramGrads, loss } = appendGrad(graph)
  const plan = planBuffers(graph, paramGrads)
  const kernels = emitKernels(graph, plan)
  return { graph, paramGrads, loss, plan, kernels }
}

/** Full compile pipeline. Browser-only because it creates a GPUDevice. */
export async function compile(traceFn: () => Tensor, opts: RuntimeOpts = {}): Promise<CompiledRuntime & { ir: CompiledIR }> {
  const ir = compileToIR(traceFn)
  const lossBufferId = ir.plan.tensorToBuffer.get(ir.loss.id)!
  const runtime = await createRuntime(ir.plan, ir.kernels, lossBufferId, opts)
  return Object.assign(runtime, { ir })
}

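// Sketch (illustrative): because `compileToIR` never touches WebGPU, it can
// inspect the emitted WGSL in Node or in tests. `someScalarLoss` is a
// hypothetical trace fn returning a scalar Tensor:
//
//   const ir = compileToIR(() => someScalarLoss())
//   for (const k of ir.kernels) if (k.wgsl) console.log(k.wgsl)
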
// ============================================================================
// Module-aware compile
// ============================================================================

export interface CompileModuleOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
  /** Per-step data inputs to the forward function, keyed by name. The forward
   * fn destructures these out of its second argument; runtime calls to
   * `step()` / `run()` pass typed arrays under the same keys. */
  inputs?: I
  /** Adam hyperparameters. If omitted, no optimizer is appended (forward-only). */
  adam?: AdamConfig
}

export interface CompileForwardOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
  /** Per-step data inputs to the forward function, keyed by name. */
  inputs?: I
}

/** Forward-only compile options as taken by the `compileForward` *method* on
 * a training runtime — no `device` (inherited) and no `sharedParams`
 * (auto-supplied from the train graph's params). */
export interface CompileForwardMethodOptions<I extends InputDecls = InputDecls> {
  inputs?: I
}

/** Returned by `compileModule`. Adds training-graph extras (auto-init, reset,
 * sibling-graph compile) on top of the base runtime. */
export interface CompiledModule<M extends Module> extends CompiledRuntime {
  ir: CompiledIR
  /** Number of dispatchable kernels (excludes leaf no-ops). */
  kernelCount: number
  /** Re-initialize all params from their declared init specs and zero the
   * optimizer state. Use to start training over without recompiling. */
  reset(): void
  /** Compile a sibling forward-only graph (e.g., a B=1 inference graph or a
   * B=N held-out eval graph) that shares this runtime's device and param
   * buffers. Pass the forward fn (typically distinct from your loss fn —
   * it returns logits, not a scalar) and any shape changes via `inputs`.
   * Auto-initialization is a no-op since params are shared. */
  compileForward<I extends InputDecls>(
    forward: ForwardFn<M, I>,
    opts?: CompileForwardMethodOptions<I>,
  ): Promise<CompiledForwardModule>
}

/** Returned by `compileForward` (and by the `compileForward` method). */
export interface CompiledForwardModule extends CompiledForward {
  ir: CompiledIR
  /** Number of dispatchable kernels (excludes leaf no-ops). */
  kernelCount: number
}

/**
 * Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
 * model instance itself: compilation mutates the tree (every `ParamSentinel`
 * field becomes a real `Tensor`), so the instance is consumed and shouldn't be
 * referenced afterwards. Re-call the factory if you need a fresh tree.
 *
 * The forward function takes the materialized model and a Record of named
 * input tensors, returns the loss tensor. Inputs are matched by name with the
 * `inputs:` declaration:
 *
 *   inputs: {
 *     tokens:  { shape: [B, T], dtype: 'i32' },
 *     targets: { shape: [B, T], dtype: 'i32' },
 *   }
 *   forward: (m, { tokens, targets }) => …
 *
 * Walks the module tree to materialize params with auto-derived names, then
 * runs trace → grad → adam → buffer plan → codegen → runtime. Initial
 * parameter values are uploaded automatically before this function returns;
 * call `reset()` later to re-randomize.
 *
 * If `opts.adam` is set, the runtime's `step()` automatically tracks an
 * internal step count and injects the bias-corrected `lrt` scalar each call;
 * users don't need to provide it themselves.
 */
export async function compileModule<M extends Module, I extends InputDecls = InputDecls>(
  modelFactory: () => M,
  forward: ForwardFn<M, I>,
  opts: CompileModuleOptions<I> = {},
): Promise<CompiledModule<M>> {
  const { runtime, materialized, plan, kernels, ir } = await buildModuleRuntime(
    modelFactory, forward, opts, /* sharedParams */ undefined, /* withGrad */ true,
  )

  // If Adam is enabled, wrap step() to track the step count and supply lrt
  // (and optionally decayShrink, when the user passed a per-step lr schedule).
  // Wrap resetOptimizerState() too, so a reset zeros m/v *and* the bias-correction
  // counter — otherwise the next step would skip Adam's warmup phase.
  if (opts.adam) {
    wrapStepForAdam(runtime, opts.adam, ir)
  }

  // Auto-upload initial param values. Always wanted at this entry point —
  // training runtimes own their params and need them randomized before step 1.
  uploadInitialParams(plan, materialized.initFns, runtime, /* sharedParams */ undefined)

  const kernelCount = kernels.filter(k => k.wgsl).length

  const reset = () => {
    uploadInitialParams(plan, materialized.initFns, runtime, undefined)
    runtime.resetOptimizerState()
  }

  const compileForwardMethod = async <J extends InputDecls>(
    forwardFn: ForwardFn<M, J>,
    fOpts: CompileForwardMethodOptions<J> = {},
  ): Promise<CompiledForwardModule> => {
    return compileForward<M, J>(modelFactory, forwardFn, {
      ...fOpts,
      device: runtime.device,
      sharedParams: runtime.params,
    })
  }

  return Object.assign(runtime, {
    ir,
    kernelCount,
    reset,
    compileForward: compileForwardMethod,
  })
}

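// Sketch (illustrative, not part of this file): end-to-end training setup.
// `GPT`, its `loss` method, and the exact AdamConfig fields are assumptions;
// only the `inputs:` / `step()` key matching comes from the docs above:
//
//   const rt = await compileModule(() => new GPT(), (m, { tokens, targets }) =>
//     m.loss(tokens, targets), {
//       inputs: {
//         tokens:  { shape: [B, T], dtype: 'i32' },
//         targets: { shape: [B, T], dtype: 'i32' },
//       },
//       adam: { lr: 1e-3 },  // hypothetical config shape
//     })
//   rt.step({ tokens: tokensI32, targets: targetsI32 })  // lrt injected automatically
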
// ============================================================================
// Forward-only compile
// ============================================================================

/**
 * Compile a Module-based model in forward-only mode (no autograd, no Adam).
 * The forward function returns the output tensor (e.g., logits) instead of a
 * scalar loss; runtime exposes `run(inputs)` returning the full output as a
 * `Float32Array`.
 *
 * **Prefer the `compileForward` method on a training runtime** when both
 * graphs use the same Module class — it auto-supplies `device` and
 * `sharedParams`. This standalone form is for forward-only models with no
 * training graph at all, or for sharing params across a different model.
 *
 * **Sharing params with a training compile.** Pass `opts.sharedParams =
 * trainCompiled.params` to bind this graph's param buffers to an existing
 * training runtime's GPU buffers — every train step is then immediately
 * visible to `run()` calls here, no copies.
 *
 * Initial param values are uploaded automatically for params *not* covered
 * by `sharedParams` (those are owned by the sibling compile).
 */
export async function compileForward<M extends Module, I extends InputDecls = InputDecls>(
  modelFactory: () => M,
  forward: ForwardFn<M, I>,
  opts: CompileForwardOptions<I> = {},
): Promise<CompiledForwardModule> {
  const sharedParams = opts.sharedParams
  const { runtime, materialized, plan, kernels, ir } = await buildModuleRuntime(
    modelFactory, forward, opts, sharedParams, /* withGrad */ false,
  )

  // Auto-upload initial values for any params this graph owns. With
  // `sharedParams` covering everything, this is a no-op.
  uploadInitialParams(plan, materialized.initFns, runtime, sharedParams)

  const kernelCount = kernels.filter(k => k.wgsl).length
  return Object.assign(runtime, { ir, kernelCount })
}

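// Sketch (illustrative): a sibling eval graph sharing the training runtime's
// param buffers via the method form, which auto-supplies `device` and
// `sharedParams`. A `logits` method on the Module is an assumption:
//
//   const evalRt = await rt.compileForward((m, { tokens }) => m.logits(tokens), {
//     inputs: { tokens: { shape: [1, T], dtype: 'i32' } },
//   })
//   const out = evalRt.run({ tokens: promptI32 })  // Float32Array; sees the latest step()
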
// ============================================================================
// Internals
// ============================================================================

type InitFn = (size: number, shape: readonly number[]) => Float32Array

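// Sketch (illustrative): the shape an init function takes, here a scaled
// uniform init. Real init fns come from the Module tree via
// materializeParams(); this one is hypothetical:
//
//   const uniformInit: InitFn = (size, shape) => {
//     const scale = 1 / Math.sqrt(shape[0] ?? 1)
//     return Float32Array.from({ length: size }, () => (Math.random() * 2 - 1) * scale)
//   }
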
interface BuiltRuntime {
  runtime: CompiledRuntime
  materialized: ReturnType<typeof materializeParams>
  plan: BufferPlan
  kernels: KernelSpec[]
  ir: CompiledIR
}

/** Shared body of compileModule + compileForward. The training and forward
 * pipelines diverge only in (a) whether grad/Adam are appended and (b)
 * whether the output buffer is the loss scalar or the user's returned
 * tensor — both come out of the same trace and codegen path. */
async function buildModuleRuntime<M extends Module, I extends InputDecls>(
  modelFactory: () => M,
  forward: ForwardFn<M, I>,
  opts: CompileModuleOptions<I> | CompileForwardOptions<I>,
  sharedParams: Map<string, GPUBuffer> | undefined,
  withGrad: boolean,
): Promise<BuiltRuntime> {
  const inputDecls: InputDecls = opts.inputs ?? {}
  const model = modelFactory()
  let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {}, decayFlags: {} }
  const graph = trace(() => {
    materialized = materializeParams(model)
    const inputTensors: Record<string, Tensor> = {}
    for (const [name, decl] of Object.entries(inputDecls)) {
      inputTensors[name] = tensorInput(name, decl.shape, decl.dtype ?? 'f32')
    }
    return forward(model, inputTensors as InputsTensors<I>)
  })

  let paramGrads: GradResult['paramGrads'] = {}
  let outputTensor: Tensor
  let adamWritebacks: ReturnType<typeof appendAdam>['writebacks'] = []

  if (withGrad) {
    const gradResult = appendGrad(graph)
    paramGrads = gradResult.paramGrads
    outputTensor = gradResult.loss
    const adamCfg = (opts as CompileModuleOptions).adam
    if (adamCfg) {
      const adamResult = appendAdam(graph, paramGrads, materialized.tensors, adamCfg, materialized.decayFlags)
      adamWritebacks = adamResult.writebacks
      // Stash adam result on the graph so wrapStepForAdam can find it.
      ;(graph as Graph & { __adam?: ReturnType<typeof appendAdam> }).__adam = adamResult
    }
  } else {
    outputTensor = graph.tensors[graph.outputs[0]!]!
  }

  const plan = planBuffers(graph, paramGrads, adamWritebacks)
  const kernels = emitKernels(graph, plan)
  const outputBufferId = plan.tensorToBuffer.get(outputTensor.id)!
  // exactOptionalPropertyTypes: only include sharedParams when defined.
  const runtimeOpts: RuntimeOpts = sharedParams
    ? { ...opts, sharedParams }
    : { ...opts }
  const runtime = withGrad
    ? await createRuntime(plan, kernels, outputBufferId, runtimeOpts)
    : await createForwardRuntime(plan, kernels, outputBufferId, runtimeOpts)

  const ir: CompiledIR = { graph, paramGrads, loss: outputTensor, plan, kernels }
  return { runtime: runtime as CompiledRuntime, materialized, plan, kernels, ir }
}

type Graph = ReturnType<typeof trace>

function wrapStepForAdam(runtime: CompiledRuntime, adamCfg: AdamConfig, ir: CompiledIR): void {
  const adamResult = (ir.graph as Graph & { __adam?: ReturnType<typeof appendAdam> }).__adam!
  const { lrtInputName, decayShrinkInputName, config } = adamResult
  let t = 0
  const lrtBuf = new Float32Array(1)
  const decayShrinkBuf = decayShrinkInputName ? new Float32Array(1) : null
  const innerStep = runtime.step.bind(runtime) as CompiledRuntime['step']
  const innerReset = runtime.resetOptimizerState.bind(runtime)
  const wrappedStep = ((
    inputs: Record<string, Int32Array | Float32Array>,
    opts?: { withCaptures?: boolean },
  ) => {
    t++
    const lrNow = config.lr(t)
    lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t))
    const merged: Record<string, Int32Array | Float32Array> = { ...inputs, [lrtInputName]: lrtBuf }
    if (decayShrinkBuf && decayShrinkInputName) {
      decayShrinkBuf[0] = 1 - lrNow * config.weightDecay
      merged[decayShrinkInputName] = decayShrinkBuf
    }
    return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged)
  }) as CompiledRuntime['step']
  runtime.step = wrappedStep
  runtime.resetOptimizerState = () => {
    t = 0
    innerReset()
  }
  void adamCfg // deliberately unused: the normalized config is read from adamResult
}

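// Worked example of the bias correction above, assuming the conventional
// b1 = 0.9, b2 = 0.999 and a flat lr(t) = 1e-3:
//
//   t = 1:   lrt = 1e-3 * sqrt(1 - 0.999) / (1 - 0.9) ≈ 3.16e-4
//   t = 10:  lrt ≈ 1.53e-4
//   t → ∞:   lrt → 1e-3   (the correction factor approaches 1)
//
// Resetting t to 0 (see resetOptimizerState above) restores this schedule.
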
/** Build a Record<paramName, Float32Array> by running each param's init
 * function against its shape and uploading them to the runtime. Skips any
 * param covered by `sharedParams` (those are owned by a sibling compile). */
function uploadInitialParams(
  plan: BufferPlan,
  initFns: Record<string, InitFn>,
  runtime: CompiledRuntime | CompiledForward,
  sharedParams: Map<string, GPUBuffer> | undefined,
): void {
  const out: Record<string, Float32Array> = {}
  for (const [name, bufId] of plan.paramsByName) {
    if (sharedParams?.has(name)) continue
    const shape = plan.buffers[bufId]!.shape
    const size = shape.reduce((a, b) => a * b, 1)
    const initFn = initFns[name]
    if (!initFn) throw new Error(`compile: no init for param '${name}'`)
    out[name] = initFn(size, shape)
  }
  if (Object.keys(out).length > 0) runtime.uploadParams(out, { partial: !!sharedParams })
}