tensorgrad 0.0.8 → 0.0.11
- package/README.md +6 -6
- package/dist/compile.d.ts +77 -28
- package/dist/compile.d.ts.map +1 -1
- package/dist/compile.js +132 -81
- package/dist/compile.js.map +1 -1
- package/dist/index.d.ts +2 -2
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +2 -2
- package/dist/index.js.map +1 -1
- package/dist/nn.d.ts +14 -11
- package/dist/nn.d.ts.map +1 -1
- package/dist/nn.js +28 -33
- package/dist/nn.js.map +1 -1
- package/dist/runtime.d.ts +33 -32
- package/dist/runtime.d.ts.map +1 -1
- package/dist/runtime.js +59 -12
- package/dist/runtime.js.map +1 -1
- package/package.json +1 -1
- package/src/compile.ts +245 -114
- package/src/index.ts +7 -2
- package/src/nn.ts +34 -32
- package/src/runtime.ts +86 -52
package/src/compile.ts
CHANGED
@@ -17,15 +17,32 @@ import { emitKernels, type KernelSpec } from './codegen.js'
 import { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts } from './runtime.js'
 import { Module, materializeParams } from './module.js'

-/** Declares one input tensor of the model's forward function.
- * the
- *
+/** Declares one input tensor of the model's forward function. The name is the
+ * key in the `inputs:` Record at compile time and the key on the `step()`/
+ * `run()` data object at runtime. */
 export interface InputDecl {
-  name: string
   shape: Shape
   dtype?: Dtype
 }

+/** Inputs declaration: a Record from input name to its shape/dtype. The name
+ * doubles as the key the forward fn destructures and the key the runtime
+ * expects in `step({...})` / `run({...})`. */
+export type InputDecls = Record<string, InputDecl>
+
+/** Maps an `InputDecls` Record to its forward-time tensor counterpart —
+ * same keys, each value is a Tensor. Used to type the forward function's
+ * `inputs` argument from the declared shape Record. */
+export type InputsTensors<I extends InputDecls> = { [K in keyof I]: Tensor }
+
+/** Forward function shape: takes the materialized model and a Record of
+ * named input tensors (matching the declared `inputs:` keys), returns the
+ * output tensor (loss for compileModule; logits/etc. for compileForward).
+ * The second generic flows from the inputs declaration so destructuring
+ * the input record stays typed. */
+export type ForwardFn<M extends Module, I extends InputDecls = InputDecls> =
+  (m: M, inputs: InputsTensors<I>) => Tensor
+
 export interface CompiledIR {
   graph: GradResult['graph']
   paramGrads: GradResult['paramGrads']
@@ -55,19 +72,52 @@ export async function compile(traceFn: () => Tensor, opts: RuntimeOpts = {}): Pr
 // Module-aware compile
 // ============================================================================

-export interface CompileModuleOptions extends RuntimeOpts {
-  /** Per-step data inputs to the forward function
-   *
-   * `(
-
-  inputs?: InputDecl[]
+export interface CompileModuleOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
+  /** Per-step data inputs to the forward function, keyed by name. The forward
+   * fn destructures these out of its second argument; runtime calls to
+   * `step()` / `run()` pass typed arrays under the same keys. */
+  inputs?: I
   /** Adam hyperparameters. If omitted, no optimizer is appended (forward-only). */
   adam?: AdamConfig
 }

-export interface CompileForwardOptions extends RuntimeOpts {
-  /** Per-step data inputs to the forward function. */
-  inputs?:
+export interface CompileForwardOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
+  /** Per-step data inputs to the forward function, keyed by name. */
+  inputs?: I
+}
+
+/** Forward-only compile options as taken by the `compileForward` *method* on
+ * a training runtime — no `device` (inherited) and no `sharedParams`
+ * (auto-supplied from the train graph's params). */
+export interface CompileForwardMethodOptions<I extends InputDecls = InputDecls> {
+  inputs?: I
+}
+
+/** Returned by `compileModule`. Adds training-graph extras (auto-init, reset,
+ * sibling-graph compile) on top of the base runtime. */
+export interface CompiledModule<M extends Module> extends CompiledRuntime {
+  ir: CompiledIR
+  /** Number of dispatchable kernels (excludes leaf no-ops). */
+  kernelCount: number
+  /** Re-initialize all params from their declared init specs and zero the
+   * optimizer state. Use to start training over without recompiling. */
+  reset(): void
+  /** Compile a sibling forward-only graph (e.g., a B=1 inference graph or a
+   * B=N held-out eval graph) that shares this runtime's device and param
+   * buffers. Pass the forward fn (typically distinct from your loss fn —
+   * it returns logits, not a scalar) and any shape changes via `inputs`.
+   * Auto-initialization is a no-op since params are shared. */
+  compileForward<I extends InputDecls>(
+    forward: ForwardFn<M, I>,
+    opts?: CompileForwardMethodOptions<I>,
+  ): Promise<CompiledForwardModule>
+}
+
+/** Returned by `compileForward` (and by the `compileForward` method). */
+export interface CompiledForwardModule extends CompiledForward {
+  ir: CompiledIR
+  /** Number of dispatchable kernels (excludes leaf no-ops). */
+  kernelCount: number
 }

 /**
@@ -76,103 +126,70 @@ export interface CompileForwardOptions extends RuntimeOpts {
  * field becomes a real `Tensor`), so the instance is consumed and shouldn't be
  * referenced afterwards. Re-call the factory if you need a fresh tree.
  *
- * The forward function takes the materialized model and
- * tensor.
+ * The forward function takes the materialized model and a Record of named
+ * input tensors, returns the loss tensor. Inputs are matched by name with the
+ * `inputs:` declaration:
+ *
+ *   inputs: {
+ *     tokens: { shape: [B, T], dtype: 'i32' },
+ *     targets: { shape: [B, T], dtype: 'i32' },
+ *   }
+ *   forward: (m, { tokens, targets }) => …
  *
  * Walks the module tree to materialize params with auto-derived names, then
- * runs trace → grad → adam → buffer plan → codegen → runtime.
+ * runs trace → grad → adam → buffer plan → codegen → runtime. Initial
+ * parameter values are uploaded automatically before this function returns;
+ * call `reset()` later to re-randomize.
  *
 * If `opts.adam` is set, the runtime's `step()` automatically tracks an
 * internal step count and injects the bias-corrected `lrt` scalar each call;
 * users don't need to provide it themselves.
 */
-export async function compileModule<M extends Module>(
+export async function compileModule<M extends Module, I extends InputDecls = InputDecls>(
   modelFactory: () => M,
-  forward:
-  opts: CompileModuleOptions = {},
-): Promise<
-  const
-
-
-  const graph = trace(() => {
-    materialized = materializeParams(model)
-    const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
-    return forward(model, ...inputTensors)
-  })
-
-  const { paramGrads, loss } = appendGrad(graph)
-
-  let adamResult: ReturnType<typeof appendAdam> | undefined
-  if (opts.adam) {
-    adamResult = appendAdam(graph, paramGrads, materialized.tensors, opts.adam, materialized.decayFlags)
-  }
-
-  const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? [])
-  const kernels = emitKernels(graph, plan)
-  const lossBufferId = plan.tensorToBuffer.get(loss.id)!
-  const runtime = await createRuntime(plan, kernels, lossBufferId, opts)
+  forward: ForwardFn<M, I>,
+  opts: CompileModuleOptions<I> = {},
+): Promise<CompiledModule<M>> {
+  const { runtime, materialized, plan, kernels, ir } = await buildModuleRuntime(
+    modelFactory, forward, opts, /* sharedParams */ undefined, /* withGrad */ true,
+  )

   // If Adam is enabled, wrap step() to track the step count and supply lrt
   // (and optionally decayShrink, when the user passed a per-step lr schedule).
   // Wrap resetOptimizerState() too, so a reset zeros m/v *and* the bias-correction
   // counter — otherwise the next step would skip Adam's warmup phase.
-  if (
-
-    let t = 0
-    const lrtBuf = new Float32Array(1)
-    const decayShrinkBuf = decayShrinkInputName ? new Float32Array(1) : null
-    const innerStep = runtime.step.bind(runtime) as CompiledRuntime['step']
-    const innerReset = runtime.resetOptimizerState.bind(runtime)
-    const wrappedStep = (
-      inputs: Record<string, Int32Array | Float32Array>,
-      opts?: { withCaptures?: boolean },
-    ): Promise<number | { loss: number; captures: Record<string, Float32Array> }> => {
-      t++
-      const lrNow = config.lr(t)
-      lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t))
-      const merged: Record<string, Int32Array | Float32Array> = { ...inputs, [lrtInputName]: lrtBuf }
-      if (decayShrinkBuf && decayShrinkInputName) {
-        decayShrinkBuf[0] = 1 - lrNow * config.weightDecay
-        merged[decayShrinkInputName] = decayShrinkBuf
-      }
-      return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged)
-    }
-    runtime.step = wrappedStep as CompiledRuntime['step']
-    runtime.resetOptimizerState = () => {
-      t = 0
-      innerReset()
-    }
+  if (opts.adam) {
+    wrapStepForAdam(runtime, opts.adam, ir)
   }

-
-
-
-  }
+  // Auto-upload initial param values. Always wanted at this entry point —
+  // training runtimes own their params and need them randomized before step 1.
+  uploadInitialParams(plan, materialized.initFns, runtime, /* sharedParams */ undefined)

-  const
-  return Object.assign(runtime, { ir, uploadInitialParams })
-}
+  const kernelCount = kernels.filter(k => k.wgsl).length

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-  const size = shape.reduce((a, b) => a * b, 1)
-  const initFn = initFns[name]
-  if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
-  out[name] = initFn(size, shape)
+  const reset = () => {
+    uploadInitialParams(plan, materialized.initFns, runtime, undefined)
+    runtime.resetOptimizerState()
+  }
+
+  const compileForwardMethod = async <J extends InputDecls>(
+    forwardFn: ForwardFn<M, J>,
+    fOpts: CompileForwardMethodOptions<J> = {},
+  ): Promise<CompiledForwardModule> => {
+    return compileForward<M, J>(modelFactory, forwardFn, {
+      ...fOpts,
+      device: runtime.device,
+      sharedParams: runtime.params,
+    })
   }
-
+
+  return Object.assign(runtime, {
+    ir,
+    kernelCount,
+    reset,
+    compileForward: compileForwardMethod,
+  })
 }

 // ============================================================================
@@ -185,43 +202,157 @@ function buildInitialParamUploads(
 * scalar loss; runtime exposes `run(inputs)` returning the full output as a
 * `Float32Array`.
 *
+ * **Prefer the `compileForward` method on a training runtime** when both
+ * graphs use the same Module class — it auto-supplies `device` and
+ * `sharedParams`. This standalone form is for forward-only models with no
+ * training graph at all, or for sharing params across a different model.
+ *
 * **Sharing params with a training compile.** Pass `opts.sharedParams =
 * trainCompiled.params` to bind this graph's param buffers to an existing
 * training runtime's GPU buffers — every train step is then immediately
- * visible to `run()` calls here, no copies.
- * `uploadInitialParams()` skips any param covered by `sharedParams`.
+ * visible to `run()` calls here, no copies.
 *
- *
- *
+ * Initial param values are uploaded automatically for params *not* covered
+ * by `sharedParams` (those are owned by the sibling compile).
 */
-export async function compileForward<M extends Module>(
+export async function compileForward<M extends Module, I extends InputDecls = InputDecls>(
+  modelFactory: () => M,
+  forward: ForwardFn<M, I>,
+  opts: CompileForwardOptions<I> = {},
+): Promise<CompiledForwardModule> {
+  const sharedParams = opts.sharedParams
+  const { runtime, materialized, plan, kernels, ir } = await buildModuleRuntime(
+    modelFactory, forward, opts, sharedParams, /* withGrad */ false,
+  )
+
+  // Auto-upload initial values for any params this graph owns. With
+  // `sharedParams` covering everything, this is a no-op.
+  uploadInitialParams(plan, materialized.initFns, runtime, sharedParams)
+
+  const kernelCount = kernels.filter(k => k.wgsl).length
+  return Object.assign(runtime, { ir, kernelCount })
+}
+
+// ============================================================================
+// Internals
+// ============================================================================
+
+type InitFn = (size: number, shape: readonly number[]) => Float32Array
+
+interface BuiltRuntime {
+  runtime: CompiledRuntime
+  materialized: ReturnType<typeof materializeParams>
+  plan: BufferPlan
+  kernels: KernelSpec[]
+  ir: CompiledIR
+}
+
+/** Shared body of compileModule + compileForward. The training and forward
+ * pipelines diverge only in (a) whether grad/Adam are appended and (b)
+ * whether the output buffer is the loss scalar or the user's returned
+ * tensor — both come out of the same trace and codegen path. */
+async function buildModuleRuntime<M extends Module, I extends InputDecls>(
   modelFactory: () => M,
-  forward:
-  opts:
-
-
+  forward: ForwardFn<M, I>,
+  opts: CompileModuleOptions<I> | CompileForwardOptions<I>,
+  sharedParams: Map<string, GPUBuffer> | undefined,
+  withGrad: boolean,
+): Promise<BuiltRuntime> {
+  const inputDecls: InputDecls = opts.inputs ?? {}
   const model = modelFactory()
   let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {}, decayFlags: {} }
   const graph = trace(() => {
     materialized = materializeParams(model)
-    const inputTensors
-
+    const inputTensors: Record<string, Tensor> = {}
+    for (const [name, decl] of Object.entries(inputDecls)) {
+      inputTensors[name] = tensorInput(name, decl.shape, decl.dtype ?? 'f32')
+    }
+    return forward(model, inputTensors as InputsTensors<I>)
   })

-
+  let paramGrads: GradResult['paramGrads'] = {}
+  let outputTensor: Tensor
+  let adamWritebacks: ReturnType<typeof appendAdam>['writebacks'] = []
+
+  if (withGrad) {
+    const gradResult = appendGrad(graph)
+    paramGrads = gradResult.paramGrads
+    outputTensor = gradResult.loss
+    const adamCfg = (opts as CompileModuleOptions).adam
+    if (adamCfg) {
+      const adamResult = appendAdam(graph, paramGrads, materialized.tensors, adamCfg, materialized.decayFlags)
+      adamWritebacks = adamResult.writebacks
+      // Stash adam result on the graph so wrapStepForAdam can find it.
+      ;(graph as Graph & { __adam?: ReturnType<typeof appendAdam> }).__adam = adamResult
+    }
+  } else {
+    outputTensor = graph.tensors[graph.outputs[0]!]!
+  }
+
+  const plan = planBuffers(graph, paramGrads, adamWritebacks)
   const kernels = emitKernels(graph, plan)
-  const outputTensor = graph.tensors[graph.outputs[0]!]!
   const outputBufferId = plan.tensorToBuffer.get(outputTensor.id)!
-
+  // exactOptionalPropertyTypes: only include sharedParams when defined.
+  const runtimeOpts: RuntimeOpts = sharedParams
+    ? { ...opts, sharedParams }
+    : { ...opts }
+  const runtime = withGrad
+    ? await createRuntime(plan, kernels, outputBufferId, runtimeOpts)
+    : await createForwardRuntime(plan, kernels, outputBufferId, runtimeOpts)

-  const
-
-
-
+  const ir: CompiledIR = { graph, paramGrads, loss: outputTensor, plan, kernels }
+  return { runtime: runtime as CompiledRuntime, materialized, plan, kernels, ir }
+}
+
+type Graph = ReturnType<typeof trace>
+
+function wrapStepForAdam(runtime: CompiledRuntime, adamCfg: AdamConfig, ir: CompiledIR): void {
+  const adamResult = (ir.graph as Graph & { __adam?: ReturnType<typeof appendAdam> }).__adam!
+  const { lrtInputName, decayShrinkInputName, config } = adamResult
+  let t = 0
+  const lrtBuf = new Float32Array(1)
+  const decayShrinkBuf = decayShrinkInputName ? new Float32Array(1) : null
+  const innerStep = runtime.step.bind(runtime) as CompiledRuntime['step']
+  const innerReset = runtime.resetOptimizerState.bind(runtime)
+  const wrappedStep = ((
+    inputs: Record<string, Int32Array | Float32Array>,
+    opts?: { withCaptures?: boolean },
+  ) => {
+    t++
+    const lrNow = config.lr(t)
+    lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t))
+    const merged: Record<string, Int32Array | Float32Array> = { ...inputs, [lrtInputName]: lrtBuf }
+    if (decayShrinkBuf && decayShrinkInputName) {
+      decayShrinkBuf[0] = 1 - lrNow * config.weightDecay
+      merged[decayShrinkInputName] = decayShrinkBuf
+    }
+    return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged)
+  }) as CompiledRuntime['step']
+  runtime.step = wrappedStep
+  runtime.resetOptimizerState = () => {
+    t = 0
+    innerReset()
   }
+  void adamCfg
+}

-
-
-
-
+/** Build a Record<paramName, Float32Array> by running each param's init
+ * function against its shape and uploading them to the runtime. Skips any
+ * param covered by `sharedParams` (those are owned by a sibling compile). */
+function uploadInitialParams(
+  plan: BufferPlan,
+  initFns: Record<string, InitFn>,
+  runtime: CompiledRuntime | CompiledForward,
+  sharedParams: Map<string, GPUBuffer> | undefined,
+): void {
+  const out: Record<string, Float32Array> = {}
+  for (const [name, bufId] of plan.paramsByName) {
+    if (sharedParams?.has(name)) continue
+    const shape = plan.buffers[bufId]!.shape
+    const size = shape.reduce((a, b) => a * b, 1)
+    const initFn = initFns[name]
+    if (!initFn) throw new Error(`compile: no init for param '${name}'`)
+    out[name] = initFn(size, shape)
+  }
+  if (Object.keys(out).length > 0) runtime.uploadParams(out, { partial: !!sharedParams })
 }
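To make the new surface concrete, here is a rough usage sketch pieced together from the docstrings above. The TinyModel class, the myLossFn helper, the Adam numbers, and the assumption that the Tensor type is re-exported from the package root are all illustrative; only the compileModule / step / reset / compileForward shapes come from the declarations in this diff.

// Hypothetical usage sketch; TinyModel, myLossFn, and the Adam values are placeholders.
import { compileModule, Module, nn, type Tensor } from 'tensorgrad'

declare function myLossFn(pred: Tensor, target: Tensor): Tensor // placeholder loss op

const B = 4, T = 8, D = 32

class TinyModel extends Module {
  ln = new nn.LayerNorm(D)
  head = new nn.Linear(D, D)
}

const train = await compileModule(
  () => new TinyModel(),
  // The second argument is typed from `inputs:` below, so this destructuring is checked.
  (m, { x, target }) => myLossFn(m.head.fwd(m.ln.fwd(x)), target),
  {
    inputs: {
      x: { shape: [B, T, D] },
      target: { shape: [B, T], dtype: 'i32' },
    },
    // Field names inferred from wrapStepForAdam (lr is a per-step schedule); values are made up.
    adam: { lr: () => 1e-3, b1: 0.9, b2: 0.999, weightDecay: 0.01 },
  },
)

// step() takes typed arrays under the same input names; the bias-corrected `lrt` is injected automatically.
const loss = await train.step({
  x: new Float32Array(B * T * D),
  target: new Int32Array(B * T),
})

// Sibling forward-only graph (B=1 inference) sharing this runtime's device and param buffers.
const infer = await train.compileForward(
  (m, { x }) => m.head.fwd(m.ln.fwd(x)),
  { inputs: { x: { shape: [1, T, D] } } },
)
const logits = await infer.run({ x: new Float32Array(T * D) })

train.reset() // re-initialize params and zero optimizer state without recompiling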
package/src/index.ts
CHANGED
@@ -36,7 +36,12 @@ export { appendGrad, type GradResult } from './grad.js'
 export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
 export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
 export { emitKernels, type KernelSpec } from './codegen.js'
-export { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type
-export {
+export { createRuntime, createForwardRuntime, Captures, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type RunOptions, type StepResult, type RunResult } from './runtime.js'
+export {
+  compile, compileToIR, compileModule, compileForward,
+  type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type CompileForwardMethodOptions,
+  type CompiledModule, type CompiledForwardModule,
+  type InputDecl, type InputDecls, type InputsTensors, type ForwardFn,
+} from './compile.js'
 export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
 export * as nn from './nn.js'
package/src/nn.ts
CHANGED
@@ -1,41 +1,44 @@
 // Standard "batteries-included" Module subclasses for the most common layers.
 //
-//
-//
-//
-//
-// Import as a namespace:
+// Each class declares its params and a `.fwd(x)` method that runs the forward
+// computation. Forward methods are pure tensorgrad ops — autograd traces
+// through them just like any other call.
 //
 //   import { nn } from 'tensorgrad'
 //   class Block extends Module {
 //     ln = new nn.LayerNorm(D)
 //     ffn = new nn.Linear(D, 4 * D)
 //   }
-//   const y =
+//   const y = p.ffn.fwd(p.ln.fwd(x))

 import { Module } from './module.js'
 import type { Tensor } from './ir.js'
 import { add, matmul, sub, mul, div, sqrt, meanLast, sumLast, reshape, swapAxes, oneHot, logSoftmaxLast } from './ops.js'
 import { ShapeError } from './shape.js'
 import { captureSite } from './ir.js'
+import type { Captures } from './runtime.js'

 // ----------------------------------------------------------------------------
 // Linear: y = x @ W (+ b)
 // ----------------------------------------------------------------------------

+export interface LinearOptions {
+  /** Include a bias term (default true). */
+  bias?: boolean
+}
+
 export class Linear extends Module {
   W: Tensor
   b: Tensor | null
-  constructor(public readonly inDim: number, public readonly outDim: number,
+  constructor(public readonly inDim: number, public readonly outDim: number, opts: LinearOptions = {}) {
     super()
     this.W = this.param([inDim, outDim]) // randn, scale 0.02
-    this.b =
+    this.b = opts.bias === false ? null : this.param([outDim], { init: 'zeros' })
+  }
+  fwd(x: Tensor): Tensor {
+    const out = matmul(x, this.W)
+    return this.b ? add(out, this.b) : out
   }
-}
-
-export function linearFwd(p: Linear, x: Tensor): Tensor {
-  const out = matmul(x, p.W)
-  return p.b ? add(out, p.b) : out
 }

 // ----------------------------------------------------------------------------
@@ -50,14 +53,13 @@ export class LayerNorm extends Module {
     this.g = this.param([d], { init: 'ones' })
     this.b = this.param([d], { init: 'zeros' })
   }
-
-
-
-
-
-
-
-    return add(mul(div(c, stdev), p.g), p.b)
+  fwd(x: Tensor): Tensor {
+    const m = meanLast(x)
+    const c = sub(x, m)
+    const v = meanLast(mul(c, c))
+    const stdev = sqrt(add(v, this.eps))
+    return add(mul(div(c, stdev), this.g), this.b)
+  }
 }

 // ----------------------------------------------------------------------------
@@ -97,26 +99,26 @@ export function mergeHeads(x: Tensor): Tensor {
   return reshape(swapped, [...lead, T, H * d])
 }

-/** Slice a
- *
- *
- *
- *
-
-
-
+/** Slice a captured tensor named `name` into one Float32Array per head, using
+ * the static shape registered at compile time. The leading axis is treated as
+ * heads (matching `splitHeads` layout at B=1); a leading singleton batch is
+ * stripped if present so callers can pass capture names directly. Throws if
+ * the capture isn't registered or wasn't read back this call. */
+export function unsplitHeads(captures: Captures, name: string): Float32Array[] {
+  const flat = captures.get(name)
+  const shape = captures.shapeOf(name)
   if (shape.length < 2) {
-    throw new Error(`unsplitHeads: shape needs >= 2 dims, got [${shape.join(', ')}]`)
+    throw new Error(`unsplitHeads: '${name}' shape needs >= 2 dims, got [${shape.join(', ')}]`)
   }
   // For inference graphs at B=1, captures have shape [1, H, ..., ...]. Strip
-  // the leading 1 if present so
+  // the leading 1 if present so the next axis is heads.
   const s = shape[0] === 1 ? shape.slice(1) : shape
   const H = s[0]!
   let stride = 1
   for (let i = 1; i < s.length; i++) stride *= s[i]!
   const expected = H * stride
   if (flat.length !== expected) {
-    throw new Error(`unsplitHeads:
+    throw new Error(`unsplitHeads: '${name}' length ${flat.length} doesn't match shape product ${expected}`)
   }
   return Array.from({ length: H }, (_, h) => flat.slice(h * stride, (h + 1) * stride))
 }