tensorgrad 0.0.9 → 0.0.11

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects each version's contents exactly as they appear in its public registry.
package/src/compile.ts CHANGED
@@ -17,15 +17,32 @@ import { emitKernels, type KernelSpec } from './codegen.js'
17
17
  import { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts } from './runtime.js'
18
18
  import { Module, materializeParams } from './module.js'
19
19
 
20
- /** Declares one input tensor of the model's forward function. Order matches
21
- * the function's parameter list (after `model`). The `name` is used at
22
- * runtime to upload data via `step({ [name]: data })`. */
20
+ /** Declares one input tensor of the model's forward function. The name is the
21
+ * key in the `inputs:` Record at compile time and the key on the `step()`/
22
+ * `run()` data object at runtime. */
23
23
  export interface InputDecl {
24
- name: string
25
24
  shape: Shape
26
25
  dtype?: Dtype
27
26
  }
28
27
 
28
+ /** Inputs declaration: a Record from input name to its shape/dtype. The name
29
+ * doubles as the key the forward fn destructures and the key the runtime
30
+ * expects in `step({...})` / `run({...})`. */
31
+ export type InputDecls = Record<string, InputDecl>
32
+
33
+ /** Maps an `InputDecls` Record to its forward-time tensor counterpart —
34
+ * same keys, each value is a Tensor. Used to type the forward function's
35
+ * `inputs` argument from the declared shape Record. */
36
+ export type InputsTensors<I extends InputDecls> = { [K in keyof I]: Tensor }
37
+
38
+ /** Forward function shape: takes the materialized model and a Record of
39
+ * named input tensors (matching the declared `inputs:` keys), returns the
40
+ * output tensor (loss for compileModule; logits/etc. for compileForward).
41
+ * The second generic flows from the inputs declaration so destructuring
42
+ * the input record stays typed. */
43
+ export type ForwardFn<M extends Module, I extends InputDecls = InputDecls> =
44
+ (m: M, inputs: InputsTensors<I>) => Tensor
45
+
29
46
  export interface CompiledIR {
30
47
  graph: GradResult['graph']
31
48
  paramGrads: GradResult['paramGrads']
@@ -55,19 +72,52 @@ export async function compile(traceFn: () => Tensor, opts: RuntimeOpts = {}): Pr
55
72
  // Module-aware compile
56
73
  // ============================================================================
57
74
 
58
- export interface CompileModuleOptions extends RuntimeOpts {
59
- /** Per-step data inputs to the forward function. Order matches the forward
60
- * function's parameters (after the model). e.g. for
61
- * `(model, tokens, targets, mask) => loss`, inputs is
62
- * `[{name:'tokens',...}, {name:'targets',...}, {name:'mask',...}]`. */
63
- inputs?: InputDecl[]
75
+ export interface CompileModuleOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
76
+ /** Per-step data inputs to the forward function, keyed by name. The forward
77
+ * fn destructures these out of its second argument; runtime calls to
78
+ * `step()` / `run()` pass typed arrays under the same keys. */
79
+ inputs?: I
64
80
  /** Adam hyperparameters. If omitted, no optimizer is appended (forward-only). */
65
81
  adam?: AdamConfig
66
82
  }
67
83
 
68
- export interface CompileForwardOptions extends RuntimeOpts {
69
- /** Per-step data inputs to the forward function. */
70
- inputs?: InputDecl[]
84
+ export interface CompileForwardOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
85
+ /** Per-step data inputs to the forward function, keyed by name. */
86
+ inputs?: I
87
+ }
88
+
89
+ /** Forward-only compile options as taken by the `compileForward` *method* on
90
+ * a training runtime — no `device` (inherited) and no `sharedParams`
91
+ * (auto-supplied from the train graph's params). */
92
+ export interface CompileForwardMethodOptions<I extends InputDecls = InputDecls> {
93
+ inputs?: I
94
+ }
95
+
96
+ /** Returned by `compileModule`. Adds training-graph extras (auto-init, reset,
97
+ * sibling-graph compile) on top of the base runtime. */
98
+ export interface CompiledModule<M extends Module> extends CompiledRuntime {
99
+ ir: CompiledIR
100
+ /** Number of dispatchable kernels (excludes leaf no-ops). */
101
+ kernelCount: number
102
+ /** Re-initialize all params from their declared init specs and zero the
103
+ * optimizer state. Use to start training over without recompiling. */
104
+ reset(): void
105
+ /** Compile a sibling forward-only graph (e.g., a B=1 inference graph or a
106
+ * B=N held-out eval graph) that shares this runtime's device and param
107
+ * buffers. Pass the forward fn (typically distinct from your loss fn —
108
+ * it returns logits, not a scalar) and any shape changes via `inputs`.
109
+ * Auto-initialization is a no-op since params are shared. */
110
+ compileForward<I extends InputDecls>(
111
+ forward: ForwardFn<M, I>,
112
+ opts?: CompileForwardMethodOptions<I>,
113
+ ): Promise<CompiledForwardModule>
114
+ }
115
+
116
+ /** Returned by `compileForward` (and by the `compileForward` method). */
117
+ export interface CompiledForwardModule extends CompiledForward {
118
+ ir: CompiledIR
119
+ /** Number of dispatchable kernels (excludes leaf no-ops). */
120
+ kernelCount: number
71
121
  }
72
122
 
73
123
  /**
@@ -76,103 +126,70 @@ export interface CompileForwardOptions extends RuntimeOpts {
76
126
  * field becomes a real `Tensor`), so the instance is consumed and shouldn't be
77
127
  * referenced afterwards. Re-call the factory if you need a fresh tree.
78
128
  *
79
- * The forward function takes the materialized model and returns the loss
80
- * tensor.
129
+ * The forward function takes the materialized model and a Record of named
130
+ * input tensors, returns the loss tensor. Inputs are matched by name with the
131
+ * `inputs:` declaration:
132
+ *
133
+ * inputs: {
134
+ * tokens: { shape: [B, T], dtype: 'i32' },
135
+ * targets: { shape: [B, T], dtype: 'i32' },
136
+ * }
137
+ * forward: (m, { tokens, targets }) => …
81
138
  *
82
139
  * Walks the module tree to materialize params with auto-derived names, then
83
- * runs trace → grad → adam → buffer plan → codegen → runtime.
140
+ * runs trace → grad → adam → buffer plan → codegen → runtime. Initial
141
+ * parameter values are uploaded automatically before this function returns;
142
+ * call `reset()` later to re-randomize.
84
143
  *
85
144
  * If `opts.adam` is set, the runtime's `step()` automatically tracks an
86
145
  * internal step count and injects the bias-corrected `lrt` scalar each call;
87
146
  * users don't need to provide it themselves.
88
147
  */
89
- export async function compileModule<M extends Module>(
148
+ export async function compileModule<M extends Module, I extends InputDecls = InputDecls>(
90
149
  modelFactory: () => M,
91
- forward: (m: M, ...inputs: Tensor[]) => Tensor,
92
- opts: CompileModuleOptions = {},
93
- ): Promise<CompiledRuntime & { ir: CompiledIR; uploadInitialParams: () => void }> {
94
- const inputDecls = opts.inputs ?? []
95
- const model = modelFactory()
96
- let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {}, decayFlags: {} }
97
- const graph = trace(() => {
98
- materialized = materializeParams(model)
99
- const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
100
- return forward(model, ...inputTensors)
101
- })
102
-
103
- const { paramGrads, loss } = appendGrad(graph)
104
-
105
- let adamResult: ReturnType<typeof appendAdam> | undefined
106
- if (opts.adam) {
107
- adamResult = appendAdam(graph, paramGrads, materialized.tensors, opts.adam, materialized.decayFlags)
108
- }
109
-
110
- const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? [])
111
- const kernels = emitKernels(graph, plan)
112
- const lossBufferId = plan.tensorToBuffer.get(loss.id)!
113
- const runtime = await createRuntime(plan, kernels, lossBufferId, opts)
150
+ forward: ForwardFn<M, I>,
151
+ opts: CompileModuleOptions<I> = {},
152
+ ): Promise<CompiledModule<M>> {
153
+ const { runtime, materialized, plan, kernels, ir } = await buildModuleRuntime(
154
+ modelFactory, forward, opts, /* sharedParams */ undefined, /* withGrad */ true,
155
+ )
114
156
 
115
157
  // If Adam is enabled, wrap step() to track the step count and supply lrt
116
158
  // (and optionally decayShrink, when the user passed a per-step lr schedule).
117
159
  // Wrap resetOptimizerState() too, so a reset zeros m/v *and* the bias-correction
118
160
  // counter — otherwise the next step would skip Adam's warmup phase.
119
- if (adamResult) {
120
- const { lrtInputName, decayShrinkInputName, config } = adamResult
121
- let t = 0
122
- const lrtBuf = new Float32Array(1)
123
- const decayShrinkBuf = decayShrinkInputName ? new Float32Array(1) : null
124
- const innerStep = runtime.step.bind(runtime) as CompiledRuntime['step']
125
- const innerReset = runtime.resetOptimizerState.bind(runtime)
126
- const wrappedStep = (
127
- inputs: Record<string, Int32Array | Float32Array>,
128
- opts?: { withCaptures?: boolean },
129
- ): Promise<number | { loss: number; captures: Record<string, Float32Array> }> => {
130
- t++
131
- const lrNow = config.lr(t)
132
- lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t))
133
- const merged: Record<string, Int32Array | Float32Array> = { ...inputs, [lrtInputName]: lrtBuf }
134
- if (decayShrinkBuf && decayShrinkInputName) {
135
- decayShrinkBuf[0] = 1 - lrNow * config.weightDecay
136
- merged[decayShrinkInputName] = decayShrinkBuf
137
- }
138
- return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged)
139
- }
140
- runtime.step = wrappedStep as CompiledRuntime['step']
141
- runtime.resetOptimizerState = () => {
142
- t = 0
143
- innerReset()
144
- }
161
+ if (opts.adam) {
162
+ wrapStepForAdam(runtime, opts.adam, ir)
145
163
  }
146
164
 
147
- const uploadInitialParams = () => {
148
- const out = buildInitialParamUploads(plan, materialized.initFns)
149
- runtime.uploadParams(out)
150
- }
165
+ // Auto-upload initial param values. Always wanted at this entry point —
166
+ // training runtimes own their params and need them randomized before step 1.
167
+ uploadInitialParams(plan, materialized.initFns, runtime, /* sharedParams */ undefined)
151
168
 
152
- const ir: CompiledIR = { graph, paramGrads, loss, plan, kernels }
153
- return Object.assign(runtime, { ir, uploadInitialParams })
154
- }
169
+ const kernelCount = kernels.filter(k => k.wgsl).length
155
170
 
156
- // Build a Record<paramName, Float32Array> by running each param's init
157
- // function against its shape. Shared by compileModule and compileForward.
158
- // `sharedParams`, when supplied, skips any name it covers (those are owned
159
- // by the sibling compile and already initialized there).
160
- type InitFn = (size: number, shape: readonly number[]) => Float32Array
161
- function buildInitialParamUploads(
162
- plan: BufferPlan,
163
- initFns: Record<string, InitFn>,
164
- sharedParams?: Map<string, GPUBuffer>,
165
- ): Record<string, Float32Array> {
166
- const out: Record<string, Float32Array> = {}
167
- for (const [name, bufId] of plan.paramsByName) {
168
- if (sharedParams?.has(name)) continue
169
- const shape = plan.buffers[bufId]!.shape
170
- const size = shape.reduce((a, b) => a * b, 1)
171
- const initFn = initFns[name]
172
- if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
173
- out[name] = initFn(size, shape)
171
+ const reset = () => {
172
+ uploadInitialParams(plan, materialized.initFns, runtime, undefined)
173
+ runtime.resetOptimizerState()
174
+ }
175
+
176
+ const compileForwardMethod = async <J extends InputDecls>(
177
+ forwardFn: ForwardFn<M, J>,
178
+ fOpts: CompileForwardMethodOptions<J> = {},
179
+ ): Promise<CompiledForwardModule> => {
180
+ return compileForward<M, J>(modelFactory, forwardFn, {
181
+ ...fOpts,
182
+ device: runtime.device,
183
+ sharedParams: runtime.params,
184
+ })
174
185
  }
175
- return out
186
+
187
+ return Object.assign(runtime, {
188
+ ir,
189
+ kernelCount,
190
+ reset,
191
+ compileForward: compileForwardMethod,
192
+ })
176
193
  }
177
194
 
178
195
  // ============================================================================
@@ -185,43 +202,157 @@ function buildInitialParamUploads(
185
202
  * scalar loss; runtime exposes `run(inputs)` returning the full output as a
186
203
  * `Float32Array`.
187
204
  *
205
+ * **Prefer the `compileForward` method on a training runtime** when both
206
+ * graphs use the same Module class — it auto-supplies `device` and
207
+ * `sharedParams`. This standalone form is for forward-only models with no
208
+ * training graph at all, or for sharing params across a different model.
209
+ *
188
210
  * **Sharing params with a training compile.** Pass `opts.sharedParams =
189
211
  * trainCompiled.params` to bind this graph's param buffers to an existing
190
212
  * training runtime's GPU buffers — every train step is then immediately
191
- * visible to `run()` calls here, no copies. The forward graph's
192
- * `uploadInitialParams()` skips any param covered by `sharedParams`.
213
+ * visible to `run()` calls here, no copies.
193
214
  *
194
- * Typical use: a B=1 inference graph alongside a B=512 training graph,
195
- * built from the same `Module` factory.
215
+ * Initial param values are uploaded automatically for params *not* covered
216
+ * by `sharedParams` (those are owned by the sibling compile).
196
217
  */
197
- export async function compileForward<M extends Module>(
218
+ export async function compileForward<M extends Module, I extends InputDecls = InputDecls>(
219
+ modelFactory: () => M,
220
+ forward: ForwardFn<M, I>,
221
+ opts: CompileForwardOptions<I> = {},
222
+ ): Promise<CompiledForwardModule> {
223
+ const sharedParams = opts.sharedParams
224
+ const { runtime, materialized, plan, kernels, ir } = await buildModuleRuntime(
225
+ modelFactory, forward, opts, sharedParams, /* withGrad */ false,
226
+ )
227
+
228
+ // Auto-upload initial values for any params this graph owns. With
229
+ // `sharedParams` covering everything, this is a no-op.
230
+ uploadInitialParams(plan, materialized.initFns, runtime, sharedParams)
231
+
232
+ const kernelCount = kernels.filter(k => k.wgsl).length
233
+ return Object.assign(runtime, { ir, kernelCount })
234
+ }
235
+
236
+ // ============================================================================
237
+ // Internals
238
+ // ============================================================================
239
+
240
+ type InitFn = (size: number, shape: readonly number[]) => Float32Array
241
+
242
+ interface BuiltRuntime {
243
+ runtime: CompiledRuntime
244
+ materialized: ReturnType<typeof materializeParams>
245
+ plan: BufferPlan
246
+ kernels: KernelSpec[]
247
+ ir: CompiledIR
248
+ }
249
+
250
+ /** Shared body of compileModule + compileForward. The training and forward
251
+ * pipelines diverge only in (a) whether grad/Adam are appended and (b)
252
+ * whether the output buffer is the loss scalar or the user's returned
253
+ * tensor — both come out of the same trace and codegen path. */
254
+ async function buildModuleRuntime<M extends Module, I extends InputDecls>(
198
255
  modelFactory: () => M,
199
- forward: (m: M, ...inputs: Tensor[]) => Tensor,
200
- opts: CompileForwardOptions = {},
201
- ): Promise<CompiledForward & { ir: CompiledIR; uploadInitialParams: () => void }> {
202
- const inputDecls = opts.inputs ?? []
256
+ forward: ForwardFn<M, I>,
257
+ opts: CompileModuleOptions<I> | CompileForwardOptions<I>,
258
+ sharedParams: Map<string, GPUBuffer> | undefined,
259
+ withGrad: boolean,
260
+ ): Promise<BuiltRuntime> {
261
+ const inputDecls: InputDecls = opts.inputs ?? {}
203
262
  const model = modelFactory()
204
263
  let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {}, decayFlags: {} }
205
264
  const graph = trace(() => {
206
265
  materialized = materializeParams(model)
207
- const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
208
- return forward(model, ...inputTensors)
266
+ const inputTensors: Record<string, Tensor> = {}
267
+ for (const [name, decl] of Object.entries(inputDecls)) {
268
+ inputTensors[name] = tensorInput(name, decl.shape, decl.dtype ?? 'f32')
269
+ }
270
+ return forward(model, inputTensors as InputsTensors<I>)
209
271
  })
210
272
 
211
- const plan = planBuffers(graph, /* paramGrads */ {})
273
+ let paramGrads: GradResult['paramGrads'] = {}
274
+ let outputTensor: Tensor
275
+ let adamWritebacks: ReturnType<typeof appendAdam>['writebacks'] = []
276
+
277
+ if (withGrad) {
278
+ const gradResult = appendGrad(graph)
279
+ paramGrads = gradResult.paramGrads
280
+ outputTensor = gradResult.loss
281
+ const adamCfg = (opts as CompileModuleOptions).adam
282
+ if (adamCfg) {
283
+ const adamResult = appendAdam(graph, paramGrads, materialized.tensors, adamCfg, materialized.decayFlags)
284
+ adamWritebacks = adamResult.writebacks
285
+ // Stash adam result on the graph so wrapStepForAdam can find it.
286
+ ;(graph as Graph & { __adam?: ReturnType<typeof appendAdam> }).__adam = adamResult
287
+ }
288
+ } else {
289
+ outputTensor = graph.tensors[graph.outputs[0]!]!
290
+ }
291
+
292
+ const plan = planBuffers(graph, paramGrads, adamWritebacks)
212
293
  const kernels = emitKernels(graph, plan)
213
- const outputTensor = graph.tensors[graph.outputs[0]!]!
214
294
  const outputBufferId = plan.tensorToBuffer.get(outputTensor.id)!
215
- const runtime = await createForwardRuntime(plan, kernels, outputBufferId, opts)
295
+ // exactOptionalPropertyTypes: only include sharedParams when defined.
296
+ const runtimeOpts: RuntimeOpts = sharedParams
297
+ ? { ...opts, sharedParams }
298
+ : { ...opts }
299
+ const runtime = withGrad
300
+ ? await createRuntime(plan, kernels, outputBufferId, runtimeOpts)
301
+ : await createForwardRuntime(plan, kernels, outputBufferId, runtimeOpts)
216
302
 
217
- const sharedParams = opts.sharedParams
218
- const uploadInitialParams = () => {
219
- const out = buildInitialParamUploads(plan, materialized.initFns, sharedParams)
220
- if (Object.keys(out).length > 0) runtime.uploadParams(out, { partial: !!sharedParams })
303
+ const ir: CompiledIR = { graph, paramGrads, loss: outputTensor, plan, kernels }
304
+ return { runtime: runtime as CompiledRuntime, materialized, plan, kernels, ir }
305
+ }
306
+
307
+ type Graph = ReturnType<typeof trace>
308
+
309
+ function wrapStepForAdam(runtime: CompiledRuntime, adamCfg: AdamConfig, ir: CompiledIR): void {
310
+ const adamResult = (ir.graph as Graph & { __adam?: ReturnType<typeof appendAdam> }).__adam!
311
+ const { lrtInputName, decayShrinkInputName, config } = adamResult
312
+ let t = 0
313
+ const lrtBuf = new Float32Array(1)
314
+ const decayShrinkBuf = decayShrinkInputName ? new Float32Array(1) : null
315
+ const innerStep = runtime.step.bind(runtime) as CompiledRuntime['step']
316
+ const innerReset = runtime.resetOptimizerState.bind(runtime)
317
+ const wrappedStep = ((
318
+ inputs: Record<string, Int32Array | Float32Array>,
319
+ opts?: { withCaptures?: boolean },
320
+ ) => {
321
+ t++
322
+ const lrNow = config.lr(t)
323
+ lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t))
324
+ const merged: Record<string, Int32Array | Float32Array> = { ...inputs, [lrtInputName]: lrtBuf }
325
+ if (decayShrinkBuf && decayShrinkInputName) {
326
+ decayShrinkBuf[0] = 1 - lrNow * config.weightDecay
327
+ merged[decayShrinkInputName] = decayShrinkBuf
328
+ }
329
+ return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged)
330
+ }) as CompiledRuntime['step']
331
+ runtime.step = wrappedStep
332
+ runtime.resetOptimizerState = () => {
333
+ t = 0
334
+ innerReset()
221
335
  }
336
+ void adamCfg
337
+ }
222
338
 
223
- // CompiledIR.loss is the field name; for forward-only, it carries the user's
224
- // returned tensor (e.g., logits). Same shape conceptually; just no autograd.
225
- const ir: CompiledIR = { graph, paramGrads: {}, loss: outputTensor, plan, kernels }
226
- return Object.assign(runtime, { ir, uploadInitialParams })
339
+ /** Build a Record<paramName, Float32Array> by running each param's init
340
+ * function against its shape and uploading them to the runtime. Skips any
341
+ * param covered by `sharedParams` (those are owned by a sibling compile). */
342
+ function uploadInitialParams(
343
+ plan: BufferPlan,
344
+ initFns: Record<string, InitFn>,
345
+ runtime: CompiledRuntime | CompiledForward,
346
+ sharedParams: Map<string, GPUBuffer> | undefined,
347
+ ): void {
348
+ const out: Record<string, Float32Array> = {}
349
+ for (const [name, bufId] of plan.paramsByName) {
350
+ if (sharedParams?.has(name)) continue
351
+ const shape = plan.buffers[bufId]!.shape
352
+ const size = shape.reduce((a, b) => a * b, 1)
353
+ const initFn = initFns[name]
354
+ if (!initFn) throw new Error(`compile: no init for param '${name}'`)
355
+ out[name] = initFn(size, shape)
356
+ }
357
+ if (Object.keys(out).length > 0) runtime.uploadParams(out, { partial: !!sharedParams })
227
358
  }
package/src/index.ts CHANGED
@@ -36,7 +36,12 @@ export { appendGrad, type GradResult } from './grad.js'
36
36
  export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
37
37
  export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
38
38
  export { emitKernels, type KernelSpec } from './codegen.js'
39
- export { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type StepOptions, type StepWithCaptures, type RunOptions, type RunWithCaptures } from './runtime.js'
40
- export { compile, compileToIR, compileModule, compileForward, type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type InputDecl } from './compile.js'
39
+ export { createRuntime, createForwardRuntime, Captures, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type RunOptions, type StepResult, type RunResult } from './runtime.js'
40
+ export {
41
+ compile, compileToIR, compileModule, compileForward,
42
+ type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type CompileForwardMethodOptions,
43
+ type CompiledModule, type CompiledForwardModule,
44
+ type InputDecl, type InputDecls, type InputsTensors, type ForwardFn,
45
+ } from './compile.js'
41
46
  export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
42
47
  export * as nn from './nn.js'
package/src/nn.ts CHANGED
@@ -1,41 +1,44 @@
1
1
  // Standard "batteries-included" Module subclasses for the most common layers.
2
2
  //
3
- // JAX-style: each class declares its params (and their init); the forward is a
4
- // plain function the user calls with `(module, x)`. No subclassing, no method
5
- // dispatch keeps the autograd-traced computation visible at the call site.
6
- //
7
- // Import as a namespace:
3
+ // Each class declares its params and a `.fwd(x)` method that runs the forward
4
+ // computation. Forward methods are pure tensorgrad ops autograd traces
5
+ // through them just like any other call.
8
6
  //
9
7
  // import { nn } from 'tensorgrad'
10
8
  // class Block extends Module {
11
9
  // ln = new nn.LayerNorm(D)
12
10
  // ffn = new nn.Linear(D, 4 * D)
13
11
  // }
14
- // const y = nn.linearFwd(p.ffn, nn.layerNormFwd(p.ln, x))
12
+ // const y = p.ffn.fwd(p.ln.fwd(x))
15
13
 
16
14
  import { Module } from './module.js'
17
15
  import type { Tensor } from './ir.js'
18
16
  import { add, matmul, sub, mul, div, sqrt, meanLast, sumLast, reshape, swapAxes, oneHot, logSoftmaxLast } from './ops.js'
19
17
  import { ShapeError } from './shape.js'
20
18
  import { captureSite } from './ir.js'
19
+ import type { Captures } from './runtime.js'
21
20
 
22
21
  // ----------------------------------------------------------------------------
23
22
  // Linear: y = x @ W (+ b)
24
23
  // ----------------------------------------------------------------------------
25
24
 
25
+ export interface LinearOptions {
26
+ /** Include a bias term (default true). */
27
+ bias?: boolean
28
+ }
29
+
26
30
  export class Linear extends Module {
27
31
  W: Tensor
28
32
  b: Tensor | null
29
- constructor(public readonly inDim: number, public readonly outDim: number, withBias = true) {
33
+ constructor(public readonly inDim: number, public readonly outDim: number, opts: LinearOptions = {}) {
30
34
  super()
31
35
  this.W = this.param([inDim, outDim]) // randn, scale 0.02
32
- this.b = withBias ? this.param([outDim], { init: 'zeros' }) : null
36
+ this.b = opts.bias === false ? null : this.param([outDim], { init: 'zeros' })
37
+ }
38
+ fwd(x: Tensor): Tensor {
39
+ const out = matmul(x, this.W)
40
+ return this.b ? add(out, this.b) : out
33
41
  }
34
- }
35
-
36
- export function linearFwd(p: Linear, x: Tensor): Tensor {
37
- const out = matmul(x, p.W)
38
- return p.b ? add(out, p.b) : out
39
42
  }
40
43
 
41
44
  // ----------------------------------------------------------------------------
@@ -50,14 +53,13 @@ export class LayerNorm extends Module {
50
53
  this.g = this.param([d], { init: 'ones' })
51
54
  this.b = this.param([d], { init: 'zeros' })
52
55
  }
53
- }
54
-
55
- export function layerNormFwd(p: LayerNorm, x: Tensor): Tensor {
56
- const m = meanLast(x)
57
- const c = sub(x, m)
58
- const v = meanLast(mul(c, c))
59
- const stdev = sqrt(add(v, p.eps))
60
- return add(mul(div(c, stdev), p.g), p.b)
56
+ fwd(x: Tensor): Tensor {
57
+ const m = meanLast(x)
58
+ const c = sub(x, m)
59
+ const v = meanLast(mul(c, c))
60
+ const stdev = sqrt(add(v, this.eps))
61
+ return add(mul(div(c, stdev), this.g), this.b)
62
+ }
61
63
  }
62
64
 
63
65
  // ----------------------------------------------------------------------------
@@ -97,26 +99,26 @@ export function mergeHeads(x: Tensor): Tensor {
97
99
  return reshape(swapped, [...lead, T, H * d])
98
100
  }
99
101
 
100
- /** Slice a flat capture readback of shape `[H, ..., ...]` into one
101
- * Float32Array per head. The leading axis is treated as the head axis;
102
- * pass the shape from `compiled.captureShapes[name]`. Result: `H` arrays,
103
- * each holding the row-major data for that head (size = product of trailing
104
- * axes). For B>1 graphs, prefix the result by the batch — this helper
105
- * assumes the leading axis is heads, which matches how `splitHeads` lays
106
- * out captures at B=1 (the typical capture-readback shape). */
107
- export function unsplitHeads(flat: Float32Array, shape: readonly number[]): Float32Array[] {
102
+ /** Slice a captured tensor named `name` into one Float32Array per head, using
103
+ * the static shape registered at compile time. The leading axis is treated as
104
+ * heads (matching `splitHeads` layout at B=1); a leading singleton batch is
105
+ * stripped if present so callers can pass capture names directly. Throws if
106
+ * the capture isn't registered or wasn't read back this call. */
107
+ export function unsplitHeads(captures: Captures, name: string): Float32Array[] {
108
+ const flat = captures.get(name)
109
+ const shape = captures.shapeOf(name)
108
110
  if (shape.length < 2) {
109
- throw new Error(`unsplitHeads: shape needs >= 2 dims, got [${shape.join(', ')}]`)
111
+ throw new Error(`unsplitHeads: '${name}' shape needs >= 2 dims, got [${shape.join(', ')}]`)
110
112
  }
111
113
  // For inference graphs at B=1, captures have shape [1, H, ..., ...]. Strip
112
- // the leading 1 if present so callers can pass captureShapes[name] directly.
114
+ // the leading 1 if present so the next axis is heads.
113
115
  const s = shape[0] === 1 ? shape.slice(1) : shape
114
116
  const H = s[0]!
115
117
  let stride = 1
116
118
  for (let i = 1; i < s.length; i++) stride *= s[i]!
117
119
  const expected = H * stride
118
120
  if (flat.length !== expected) {
119
- throw new Error(`unsplitHeads: flat length ${flat.length} doesn't match shape product ${expected}`)
121
+ throw new Error(`unsplitHeads: '${name}' length ${flat.length} doesn't match shape product ${expected}`)
120
122
  }
121
123
  return Array.from({ length: H }, (_, h) => flat.slice(h * stride, (h + 1) * stride))
122
124
  }