tensorgrad 0.0.12 → 0.0.13
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/buffers.js +1 -6
- package/dist/buffers.js.map +1 -1
- package/dist/codegen.js +30 -28
- package/dist/codegen.js.map +1 -1
- package/dist/compile.js +39 -68
- package/dist/compile.js.map +1 -1
- package/dist/grad.js +1 -14
- package/dist/grad.js.map +1 -1
- package/dist/index.d.ts +740 -14
- package/dist/runtime.js +6 -9
- package/dist/runtime.js.map +1 -1
- package/dist/trace.js +8 -13
- package/dist/trace.js.map +1 -1
- package/package.json +9 -3
- package/src/buffers.ts +1 -6
- package/src/codegen.ts +31 -28
- package/src/compile.ts +312 -358
- package/src/grad.ts +1 -11
- package/src/runtime.ts +6 -9
- package/src/trace.ts +12 -9
- package/dist/adam.d.ts +0 -65
- package/dist/adam.d.ts.map +0 -1
- package/dist/buffers.d.ts +0 -57
- package/dist/buffers.d.ts.map +0 -1
- package/dist/capture.d.ts +0 -3
- package/dist/capture.d.ts.map +0 -1
- package/dist/codegen.d.ts +0 -23
- package/dist/codegen.d.ts.map +0 -1
- package/dist/compile.d.ts +0 -130
- package/dist/compile.d.ts.map +0 -1
- package/dist/grad.d.ts +0 -8
- package/dist/grad.d.ts.map +0 -1
- package/dist/index.d.ts.map +0 -1
- package/dist/ir.d.ts +0 -207
- package/dist/ir.d.ts.map +0 -1
- package/dist/module.d.ts +0 -55
- package/dist/module.d.ts.map +0 -1
- package/dist/nn.d.ts +0 -42
- package/dist/nn.d.ts.map +0 -1
- package/dist/ops.d.ts +0 -48
- package/dist/ops.d.ts.map +0 -1
- package/dist/runtime.d.ts +0 -115
- package/dist/runtime.d.ts.map +0 -1
- package/dist/shape.d.ts +0 -24
- package/dist/shape.d.ts.map +0 -1
- package/dist/trace.d.ts +0 -9
- package/dist/trace.d.ts.map +0 -1
package/src/compile.ts
CHANGED
@@ -1,358 +1,312 @@
 // Top-level compile(): trace → autograd → buffer plan → codegen → runtime.
 //
 // Two entry points:
 //   * `compile(traceFn)` — low-level. User declares params via
 //     paramInput() inside the trace.
 //   * `compileModule(model, …)` — high-level. User defines the model as a
 //     Module tree; the library auto-discovers
 //     params, traces the forward, appends grad
 //     and Adam, and returns a runtime.

 import type { Tensor, Shape, Dtype } from './ir.js'
 import { trace, tensorInput } from './trace.js'
 import { appendGrad, type GradResult } from './grad.js'
-import { appendAdam, type AdamConfig } from './adam.js'
+import { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
 import { planBuffers, type BufferPlan } from './buffers.js'
 import { emitKernels, type KernelSpec } from './codegen.js'
 import { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts } from './runtime.js'
-import { Module, materializeParams } from './module.js'
+import { Module, materializeParams, type MaterializedParams } from './module.js'

 /** Declares one input tensor of the model's forward function. The name is the
  * key in the `inputs:` Record at compile time and the key on the `step()`/
  * `run()` data object at runtime. */
 export interface InputDecl {
   shape: Shape
   dtype?: Dtype
 }

 /** Inputs declaration: a Record from input name to its shape/dtype. The name
  * doubles as the key the forward fn destructures and the key the runtime
  * expects in `step({...})` / `run({...})`. */
 export type InputDecls = Record<string, InputDecl>

 /** Maps an `InputDecls` Record to its forward-time tensor counterpart —
  * same keys, each value is a Tensor. Used to type the forward function's
  * `inputs` argument from the declared shape Record. */
 export type InputsTensors<I extends InputDecls> = { [K in keyof I]: Tensor }

 /** Forward function shape: takes the materialized model and a Record of
  * named input tensors (matching the declared `inputs:` keys), returns the
  * output tensor (loss for compileModule; logits/etc. for compileForward).
  * The second generic flows from the inputs declaration so destructuring
  * the input record stays typed. */
 export type ForwardFn<M extends Module, I extends InputDecls = InputDecls> =
   (m: M, inputs: InputsTensors<I>) => Tensor

 export interface CompiledIR {
   graph: GradResult['graph']
   paramGrads: GradResult['paramGrads']
   loss: Tensor
   plan: BufferPlan
   kernels: KernelSpec[]
 }

 /** Trace + autograd + buffer-plan + codegen, without touching WebGPU. */
 export function compileToIR(traceFn: () => Tensor): CompiledIR {
   const graph = trace(traceFn)
   const { paramGrads, loss } = appendGrad(graph)
   const plan = planBuffers(graph, paramGrads)
   const kernels = emitKernels(graph, plan)
   return { graph, paramGrads, loss, plan, kernels }
 }

 /** Full compile pipeline. Browser-only because it creates a GPUDevice. */
 export async function compile(traceFn: () => Tensor, opts: RuntimeOpts = {}): Promise<CompiledRuntime & { ir: CompiledIR }> {
   const ir = compileToIR(traceFn)
   const lossBufferId = ir.plan.tensorToBuffer.get(ir.loss.id)!
   const runtime = await createRuntime(ir.plan, ir.kernels, lossBufferId, opts)
   return Object.assign(runtime, { ir })
 }

 // ============================================================================
 // Module-aware compile
 // ============================================================================

 export interface CompileModuleOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
   /** Per-step data inputs to the forward function, keyed by name. The forward
    * fn destructures these out of its second argument; runtime calls to
    * `step()` / `run()` pass typed arrays under the same keys. */
   inputs?: I
   /** Adam hyperparameters. If omitted, no optimizer is appended (forward-only). */
   adam?: AdamConfig
 }

 export interface CompileForwardOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
   /** Per-step data inputs to the forward function, keyed by name. */
   inputs?: I
 }

 /** Forward-only compile options as taken by the `compileForward` *method* on
  * a training runtime — no `device` (inherited) and no `sharedParams`
  * (auto-supplied from the train graph's params). */
 export interface CompileForwardMethodOptions<I extends InputDecls = InputDecls> {
   inputs?: I
 }

 /** Returned by `compileModule`. Adds training-graph extras (auto-init, reset,
  * sibling-graph compile) on top of the base runtime. */
 export interface CompiledModule<M extends Module> extends CompiledRuntime {
   ir: CompiledIR
   /** Number of dispatchable kernels (excludes leaf no-ops). */
   kernelCount: number
   /** Re-initialize all params from their declared init specs and zero the
    * optimizer state. Use to start training over without recompiling. */
   reset(): void
   /** Compile a sibling forward-only graph (e.g., a B=1 inference graph or a
    * B=N held-out eval graph) that shares this runtime's device and param
    * buffers. Pass the forward fn (typically distinct from your loss fn —
    * it returns logits, not a scalar) and any shape changes via `inputs`.
    * Auto-initialization is a no-op since params are shared. */
   compileForward<I extends InputDecls>(
     forward: ForwardFn<M, I>,
     opts?: CompileForwardMethodOptions<I>,
   ): Promise<CompiledForwardModule>
 }

 /** Returned by `compileForward` (and by the `compileForward` method). */
 export interface CompiledForwardModule extends CompiledForward {
   ir: CompiledIR
   /** Number of dispatchable kernels (excludes leaf no-ops). */
   kernelCount: number
 }

 /**
  * Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
  * model instance itself: compilation mutates the tree (every `ParamSentinel`
  * field becomes a real `Tensor`), so the instance is consumed and shouldn't be
  * referenced afterwards. Re-call the factory if you need a fresh tree.
  *
  * The forward function takes the materialized model and a Record of named
  * input tensors, returns the loss tensor. Inputs are matched by name with the
  * `inputs:` declaration:
  *
  *   inputs: {
  *     tokens:  { shape: [B, T], dtype: 'i32' },
  *     targets: { shape: [B, T], dtype: 'i32' },
  *   }
  *   forward: (m, { tokens, targets }) => …
  *
  * Walks the module tree to materialize params with auto-derived names, then
  * runs trace → grad → adam → buffer plan → codegen → runtime. Initial
  * parameter values are uploaded automatically before this function returns;
  * call `reset()` later to re-randomize.
  *
  * If `opts.adam` is set, the runtime's `step()` automatically tracks an
  * internal step count and injects the bias-corrected `lrt` scalar each call;
  * users don't need to provide it themselves.
  */
 export async function compileModule<M extends Module, I extends InputDecls = InputDecls>(
   modelFactory: () => M,
   forward: ForwardFn<M, I>,
   opts: CompileModuleOptions<I> = {},
 ): Promise<CompiledModule<M>> {
-  const {
⋮  (removed lines 154–312 are not rendered in this diff view)
+  const { graph, materialized } = traceModule(modelFactory, forward, opts.inputs ?? {})
+  const { paramGrads, loss } = appendGrad(graph)
+  const adamResult = opts.adam
+    ? appendAdam(graph, paramGrads, materialized.tensors, opts.adam, materialized.decayFlags)
+    : undefined
+
+  const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? [])
+  const kernels = emitKernels(graph, plan)
+  const lossBufferId = plan.tensorToBuffer.get(loss.id)!
+  const runtime = await createRuntime(plan, kernels, lossBufferId, opts)
+
+  if (adamResult) wrapStepForAdam(runtime, adamResult)
+  uploadInitialParams(plan, materialized.initFns, runtime, /* sharedParams */ undefined)
+
+  const ir: CompiledIR = { graph, paramGrads, loss, plan, kernels }
+  const kernelCount = countKernels(kernels)
+
+  const reset = () => {
+    uploadInitialParams(plan, materialized.initFns, runtime, undefined)
+    runtime.resetOptimizerState()
+  }
+
+  const compileForwardMethod = <J extends InputDecls>(
+    forwardFn: ForwardFn<M, J>,
+    fOpts: CompileForwardMethodOptions<J> = {},
+  ): Promise<CompiledForwardModule> =>
+    compileForward<M, J>(modelFactory, forwardFn, {
+      ...fOpts,
+      device: runtime.device,
+      sharedParams: runtime.params,
+    })
+
+  return Object.assign(runtime, { ir, kernelCount, reset, compileForward: compileForwardMethod })
+}
+
+// ============================================================================
+// Forward-only compile
+// ============================================================================
+
+/**
+ * Compile a Module-based model in forward-only mode (no autograd, no Adam).
+ * The forward function returns the output tensor (e.g., logits) instead of a
+ * scalar loss; runtime exposes `run(inputs)` returning the full output as a
+ * `Float32Array`.
+ *
+ * **Prefer the `compileForward` method on a training runtime** when both
+ * graphs use the same Module class — it auto-supplies `device` and
+ * `sharedParams`. This standalone form is for forward-only models with no
+ * training graph at all, or for sharing params across a different model.
+ *
+ * **Sharing params with a training compile.** Pass `opts.sharedParams =
+ * trainCompiled.params` to bind this graph's param buffers to an existing
+ * training runtime's GPU buffers — every train step is then immediately
+ * visible to `run()` calls here, no copies.
+ *
+ * Initial param values are uploaded automatically for params *not* covered
+ * by `sharedParams` (those are owned by the sibling compile).
+ */
+export async function compileForward<M extends Module, I extends InputDecls = InputDecls>(
+  modelFactory: () => M,
+  forward: ForwardFn<M, I>,
+  opts: CompileForwardOptions<I> = {},
+): Promise<CompiledForwardModule> {
+  const { graph, materialized } = traceModule(modelFactory, forward, opts.inputs ?? {})
+  const outputTensor = graph.tensors[graph.outputs[0]!]!
+
+  const plan = planBuffers(graph, /* paramGrads */ {})
+  const kernels = emitKernels(graph, plan)
+  const outputBufferId = plan.tensorToBuffer.get(outputTensor.id)!
+  const runtime = await createForwardRuntime(plan, kernels, outputBufferId, opts)
+
+  uploadInitialParams(plan, materialized.initFns, runtime, opts.sharedParams)
+
+  const ir: CompiledIR = { graph, paramGrads: {}, loss: outputTensor, plan, kernels }
+  return Object.assign(runtime, { ir, kernelCount: countKernels(kernels) })
+}
+
+// ============================================================================
+// Internals
+// ============================================================================
+
+type Graph = ReturnType<typeof trace>
+type InitFn = (size: number, shape: readonly number[]) => Float32Array
+
+/** Trace the forward function with a fresh model + tensor inputs and capture
+ * the materialized params. Shared by both compile entry points; everything
+ * past this point (grad/adam/buffer plan/runtime) diverges. */
+function traceModule<M extends Module, I extends InputDecls>(
+  modelFactory: () => M,
+  forward: ForwardFn<M, I>,
+  inputDecls: InputDecls,
+): { graph: Graph; materialized: MaterializedParams } {
+  const model = modelFactory()
+  let materialized: MaterializedParams = { tensors: {}, initFns: {}, decayFlags: {} }
+  const graph = trace(() => {
+    materialized = materializeParams(model)
+    const inputTensors: Record<string, Tensor> = {}
+    for (const [name, decl] of Object.entries(inputDecls)) {
+      inputTensors[name] = tensorInput(name, decl.shape, decl.dtype ?? 'f32')
+    }
+    return forward(model, inputTensors as InputsTensors<I>)
+  })
+  return { graph, materialized }
+}
+
+const countKernels = (kernels: KernelSpec[]): number => kernels.filter(k => k.wgsl).length
+
+/** Wrap the runtime's step() to inject Adam's per-step `lrt` (bias-corrected
+ * effective LR) and, when the user supplied a per-step lr schedule, the
+ * decayShrink scalar. Also wraps resetOptimizerState() so a reset zeros
+ * Adam's m/v *and* the bias-correction step counter — otherwise the next
+ * step would skip Adam's warmup phase. */
+function wrapStepForAdam(runtime: CompiledRuntime, adamResult: AdamResult): void {
+  const { lrtInputName, decayShrinkInputName, config } = adamResult
+  let t = 0
   const lrtBuf = new Float32Array(1)
   const decayShrinkBuf = decayShrinkInputName ? new Float32Array(1) : null
   const innerStep = runtime.step.bind(runtime) as CompiledRuntime['step']
   const innerReset = runtime.resetOptimizerState.bind(runtime)
   const wrappedStep = ((
     inputs: Record<string, Int32Array | Float32Array>,
     opts?: { withCaptures?: boolean },
   ) => {
     t++
     const lrNow = config.lr(t)
     lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t))
     const merged: Record<string, Int32Array | Float32Array> = { ...inputs, [lrtInputName]: lrtBuf }
     if (decayShrinkBuf && decayShrinkInputName) {
       decayShrinkBuf[0] = 1 - lrNow * config.weightDecay
       merged[decayShrinkInputName] = decayShrinkBuf
     }
     return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged)
   }) as CompiledRuntime['step']
   runtime.step = wrappedStep
   runtime.resetOptimizerState = () => {
     t = 0
     innerReset()
   }
-  void adamCfg
 }

 /** Build a Record<paramName, Float32Array> by running each param's init
  * function against its shape and uploading them to the runtime. Skips any
  * param covered by `sharedParams` (those are owned by a sibling compile). */
 function uploadInitialParams(
   plan: BufferPlan,
   initFns: Record<string, InitFn>,
   runtime: CompiledRuntime | CompiledForward,
   sharedParams: Map<string, GPUBuffer> | undefined,
 ): void {
   const out: Record<string, Float32Array> = {}
   for (const [name, bufId] of plan.paramsByName) {
     if (sharedParams?.has(name)) continue
     const shape = plan.buffers[bufId]!.shape
     const size = shape.reduce((a, b) => a * b, 1)
     const initFn = initFns[name]
     if (!initFn) throw new Error(`compile: no init for param '${name}'`)
     out[name] = initFn(size, shape)
   }
   if (Object.keys(out).length > 0) runtime.uploadParams(out, { partial: !!sharedParams })
 }
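
For orientation, here is a minimal usage sketch of the 0.0.13 module-aware surface, pieced together only from the doc comments and signatures in the compile.ts diff above. The model class TinyModel, its loss/logits methods, the shapes, the Adam hyperparameter values, and the assumption that the package root re-exports compileModule are all illustrative, not taken from the package.

// Sketch only: TinyModel is a hypothetical Module subclass with loss()/logits() methods.
import { compileModule } from 'tensorgrad'   // assumes the root re-exports compile.ts

const B = 8, T = 64

const trained = await compileModule(
  () => new TinyModel(),                                 // factory, not an instance: the tree is consumed
  (m, { tokens, targets }) => m.loss(tokens, targets),   // must return a scalar loss tensor
  {
    inputs: {
      tokens:  { shape: [B, T], dtype: 'i32' },
      targets: { shape: [B, T], dtype: 'i32' },
    },
    // Field names follow wrapStepForAdam (config.lr(t), b1, b2, weightDecay);
    // the exact AdamConfig shape lives in adam.ts, which this diff does not show.
    adam: { lr: () => 1e-3, b1: 0.9, b2: 0.999, weightDecay: 0 },
  },
)

// step() injects the bias-corrected `lrt` (and decayShrink) itself; callers
// only pass the declared data inputs as typed arrays under the same keys.
await trained.step({ tokens: new Int32Array(B * T), targets: new Int32Array(B * T) })

// Sibling forward-only graph (B = 1 inference) sharing the training params.
const infer = await trained.compileForward(
  (m, { tokens }) => m.logits(tokens),
  { inputs: { tokens: { shape: [1, T], dtype: 'i32' } } },
)
const logits = await infer.run({ tokens: new Int32Array(1 * T) })   // full output as a Float32Array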
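
The `lrt` scalar injected each step by wrapStepForAdam is the standard Adam bias correction folded into the learning rate, which is why resetOptimizerState() also resets the step counter: a fresh counter restores the damped early steps. A small stand-alone illustration of the same formula, with arbitrary example hyperparameters:

// lrt = lr(t) * sqrt(1 - b2^t) / (1 - b1^t), exactly as computed in wrapStepForAdam.
const lrt = (lr: number, b1: number, b2: number, t: number): number =>
  lr * Math.sqrt(1 - Math.pow(b2, t)) / (1 - Math.pow(b1, t))

console.log(lrt(1e-3, 0.9, 0.999, 1))     // ≈ 3.16e-4: early steps are damped while m/v are still biased toward 0
console.log(lrt(1e-3, 0.9, 0.999, 1000))  // ≈ 7.95e-4: climbs back toward lr as b1^t and b2^t decay to 0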