tensorgrad 0.0.2 → 0.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. package/README.md +7 -9
  2. package/dist/adam.d.ts +22 -5
  3. package/dist/adam.d.ts.map +1 -1
  4. package/dist/adam.js +42 -10
  5. package/dist/adam.js.map +1 -1
  6. package/dist/buffers.d.ts +1 -0
  7. package/dist/buffers.d.ts.map +1 -1
  8. package/dist/buffers.js +12 -1
  9. package/dist/buffers.js.map +1 -1
  10. package/dist/capture.d.ts +3 -0
  11. package/dist/capture.d.ts.map +1 -0
  12. package/dist/capture.js +33 -0
  13. package/dist/capture.js.map +1 -0
  14. package/dist/codegen.js +16 -5
  15. package/dist/codegen.js.map +1 -1
  16. package/dist/compile.d.ts +33 -5
  17. package/dist/compile.d.ts.map +1 -1
  18. package/dist/compile.js +106 -14
  19. package/dist/compile.js.map +1 -1
  20. package/dist/index.d.ts +5 -3
  21. package/dist/index.d.ts.map +1 -1
  22. package/dist/index.js +4 -2
  23. package/dist/index.js.map +1 -1
  24. package/dist/ir.d.ts +2 -0
  25. package/dist/ir.d.ts.map +1 -1
  26. package/dist/ir.js +1 -1
  27. package/dist/ir.js.map +1 -1
  28. package/dist/module.d.ts +30 -4
  29. package/dist/module.d.ts.map +1 -1
  30. package/dist/module.js +39 -13
  31. package/dist/module.js.map +1 -1
  32. package/dist/nn.d.ts +19 -0
  33. package/dist/nn.d.ts.map +1 -0
  34. package/dist/nn.js +60 -0
  35. package/dist/nn.js.map +1 -0
  36. package/dist/ops.d.ts +1 -1
  37. package/dist/ops.d.ts.map +1 -1
  38. package/dist/ops.js +18 -1
  39. package/dist/ops.js.map +1 -1
  40. package/dist/runtime.d.ts +79 -4
  41. package/dist/runtime.d.ts.map +1 -1
  42. package/dist/runtime.js +153 -19
  43. package/dist/runtime.js.map +1 -1
  44. package/dist/trace.d.ts +1 -0
  45. package/dist/trace.d.ts.map +1 -1
  46. package/dist/trace.js +12 -0
  47. package/dist/trace.js.map +1 -1
  48. package/package.json +1 -2
  49. package/src/adam.ts +65 -14
  50. package/src/buffers.ts +14 -1
  51. package/src/capture.ts +36 -0
  52. package/src/codegen.ts +16 -5
  53. package/src/compile.ts +122 -16
  54. package/src/index.ts +5 -3
  55. package/src/ir.ts +20 -4
  56. package/src/module.ts +75 -11
  57. package/src/nn.ts +59 -0
  58. package/src/ops.ts +26 -3
  59. package/src/runtime.ts +260 -22
  60. package/src/trace.ts +13 -0
  61. package/SPEC.md +0 -293
package/src/adam.ts CHANGED
@@ -15,6 +15,12 @@
  // `sqrt(1-b2^t)/(1-b1^t)`; that's why convergence isn't affected by the
  // first-step warmup that bias-correction-free Adam suffers.
  //
+ // **Static vs scheduled lr.** When `config.lr` is a number, decayShrink is
+ // baked into the kernel as a literal. When it's a function `(step) => lr`,
+ // decayShrink for decayed params becomes a per-step scalar input that the
+ // runtime updates each call (computed from the current step's lr). lrt is
+ // always per-step; the bias-correction factor changes every step regardless.
+ //
  // Returns writeback declarations the buffer planner uses to wire up the
  // "after step, copy the new value into the persistent home" path. m and v
  // are state_inputs (zero-initialized, persistent across steps); the param
@@ -27,7 +33,11 @@ import { traceInto, stateInput, tensorInput } from './trace.js'
  import { adamUpdateM, adamUpdateV, adamUpdateP } from './ops.js'

  export interface AdamConfig {
-   lr: number
+   /** Constant scalar (e.g., `0.005`) or a per-step schedule function
+    * `(step) => lr`. Schedule fn lets the user implement linear/cosine decay
+    * or warmup; first call passes `step=1`. Decay-shrink (AdamW) updates
+    * per-step automatically when this is a function. */
+   lr: number | ((step: number) => number)
    b1?: number // default 0.9
    b2?: number // default 0.999
    eps?: number // default 1e-8
@@ -42,14 +52,30 @@ export interface AdamConfig {
    decayFilter?: (paramName: string) => boolean
  }

+ /** Resolved hyperparameters: lr is the schedule fn (constants are wrapped). */
+ export interface AdamResolvedConfig {
+   lr: (step: number) => number
+   b1: number
+   b2: number
+   eps: number
+   weightDecay: number
+   decayFilter: (name: string) => boolean
+   /** True iff the user supplied an lr function (vs a constant). When false,
+    * decayShrink is baked at compile time and never updated. */
+   lrIsScheduled: boolean
+ }
+
  export interface AdamResult {
    /** Writebacks the buffer planner should wire into the runtime. */
    writebacks: WritebackDecl[]
    /** Name of the per-step scalar tensor_input. The runtime fills this each call
     * with `lr * sqrt(1-b2^t)/(1-b1^t)` (Adam's bias-corrected effective LR). */
    lrtInputName: string
-   /** Hyperparameters as captured (so the runtime can compute lrt). */
-   config: Required<Omit<AdamConfig, 'decayFilter'>> & { decayFilter: (name: string) => boolean }
+   /** Name of the per-step decayShrink scalar tensor_input, or null when lr is
+    * static (decayShrink baked into the kernel) or no params are decayed. */
+   decayShrinkInputName: string | null
+   /** Hyperparameters as captured (so the runtime can compute lrt and decayShrink). */
+   config: AdamResolvedConfig
  }

  /**
@@ -70,22 +96,40 @@ export function appendAdam(
    paramTensors: Record<string, Tensor>,
    config: AdamConfig,
  ): AdamResult {
-   const fullConfig = {
-     lr: config.lr,
+   const lrIsScheduled = typeof config.lr === 'function'
+   const lrFn = lrIsScheduled
+     ? config.lr as (step: number) => number
+     : (() => config.lr as number)
+   const initialLr = lrFn(1)
+   const fullConfig: AdamResolvedConfig = {
+     lr: lrFn,
      b1: config.b1 ?? 0.9,
      b2: config.b2 ?? 0.999,
      eps: config.eps ?? 1e-8,
      weightDecay: config.weightDecay ?? 0,
      decayFilter: config.decayFilter ?? (() => true),
+     lrIsScheduled,
    }
    const writebacks: WritebackDecl[] = []
    const lrtInputName = '_adam_lrt'
+   // Tensor input for runtime-updated decayShrink (only created when lr is a
+   // schedule fn AND at least one param will receive weight decay).
+   let decayShrinkInputName: string | null = null

    return traceInto(graph, () => {
-     // One scalar lrt input shared by every adam_update_p call. Runtime supplies
-     // it per step as `lr * sqrt(1-b2^t) / (1-b1^t)`.
      const lrt = tensorInput(lrtInputName, [], 'f32')

+     // Decide up-front whether we need a runtime decayShrink scalar. Only does
+     // something when both (a) lr varies per step and (b) some param is decayed.
+     const needsDynamicShrink = lrIsScheduled
+       && fullConfig.weightDecay > 0
+       && Object.keys(paramGrads).some(name => fullConfig.decayFilter(name))
+     let decayShrinkScalar: Tensor | null = null
+     if (needsDynamicShrink) {
+       decayShrinkInputName = '_adam_decay_shrink'
+       decayShrinkScalar = tensorInput(decayShrinkInputName, [], 'f32')
+     }
+
      for (const name of Object.keys(paramGrads)) {
        const p = paramTensors[name]
        const g = paramGrads[name]
@@ -95,12 +139,19 @@
        const mState = stateInput(`adam_m_${name}`, p.shape, 'f32', 0)
        const vState = stateInput(`adam_v_${name}`, p.shape, 'f32', 0)

-       // decayShrink baked at compile time. 1.0 for plain Adam (no extra cost
-       // the WGSL compiler folds the constant multiply); 1 - lr * weightDecay
-       // for the params the filter selects.
-       const decayShrink = (fullConfig.weightDecay > 0 && fullConfig.decayFilter(name))
-         ? 1 - fullConfig.lr * fullConfig.weightDecay
-         : 1
+       // Choose the decayShrink form per param:
+       // - non-decayed params: literal 1 (kernel multiply folds out).
+       // - decayed + static lr: literal `1 - lr * wd` baked at compile.
+       // - decayed + scheduled lr: tensor input updated per step.
+       const isDecayed = fullConfig.weightDecay > 0 && fullConfig.decayFilter(name)
+       let decayShrink: number | Tensor
+       if (!isDecayed) {
+         decayShrink = 1
+       } else if (decayShrinkScalar !== null) {
+         decayShrink = decayShrinkScalar
+       } else {
+         decayShrink = 1 - initialLr * fullConfig.weightDecay
+       }

        // Three fused kernels per parameter — one for each of m_new / v_new / p_new.
        const newM = adamUpdateM(mState, g, fullConfig.b1)
@@ -111,6 +162,6 @@
        writebacks.push({ source: newV, destName: `adam_v_${name}`, destKind: 'state' })
        writebacks.push({ source: newP, destName: name, destKind: 'param' })
      }
-     return { writebacks, lrtInputName, config: fullConfig }
+     return { writebacks, lrtInputName, decayShrinkInputName, config: fullConfig }
    })
  }
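Note: a minimal sketch of what the new `lr: number | ((step: number) => number)` form enables, based on the `AdamConfig` change above. The model, forward function, and schedule constants are illustrative, not part of the package.

```ts
import { compileModule } from 'tensorgrad'

// Hypothetical warmup + cosine schedule; any (step) => number works.
const WARMUP = 100, TOTAL = 10_000, PEAK = 3e-4
const lrSchedule = (step: number): number => {
  if (step < WARMUP) return PEAK * (step / WARMUP)          // linear warmup
  const t = Math.min(1, (step - WARMUP) / (TOTAL - WARMUP)) // 0..1 afterwards
  return PEAK * 0.5 * (1 + Math.cos(Math.PI * t))           // cosine decay
}

// With a schedule plus weightDecay, appendAdam emits the `_adam_decay_shrink`
// scalar input, and the wrapped step() refreshes it every call.
const compiled = await compileModule(() => new MyModel(), forwardFn, {
  adam: { lr: lrSchedule, weightDecay: 0.1 },
})
```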
package/src/buffers.ts CHANGED
@@ -47,6 +47,7 @@ export interface BufferPlan {
    inputsByName: Map<string, number> // name -> buffer id
    paramGradsByName: Map<string, number> // name -> buffer id
    statesByName: Map<string, number> // name -> buffer id (persistent state homes)
+   capturesByName: Map<string, number> // name -> buffer id (activation captures)
    outputBufferIds: number[] // graph.outputs mapped through
    /** End-of-step writebacks (Adam updates for params, m, v, etc.) */
    writebacks: Writeback[]
@@ -169,5 +170,17 @@ export function planBuffers(
      return { source: sourceBufId, dest: destBufId, bytes: sourceSpec.byteSize }
    })

-   return { buffers, tensorToBuffer, paramsByName, inputsByName, paramGradsByName, statesByName, outputBufferIds, writebacks }
+   // Resolve graph.captures (name -> tensor id) to (name -> buffer id).
+   // No pinning needed at the planner level: each tensor already has its own
+   // buffer (see "v1 strategy" comment at top — no pooling yet).
+   const capturesByName = new Map<string, number>()
+   for (const [name, tensorId] of graph.captures) {
+     const bufId = tensorToBuffer.get(tensorId)
+     if (bufId === undefined) {
+       throw new Error(`planBuffers: capture '${name}' references unknown tensor #${tensorId}`)
+     }
+     capturesByName.set(name, bufId)
+   }
+
+   return { buffers, tensorToBuffer, paramsByName, inputsByName, paramGradsByName, statesByName, capturesByName, outputBufferIds, writebacks }
  }
package/src/capture.ts ADDED
@@ -0,0 +1,36 @@
+ // Activation capture — opt-in readback of intermediate tensors at training step.
+ //
+ // Usage (inside the user's forward pass):
+ //
+ //   import { capture } from 'tensorgrad'
+ //
+ //   function attentionFwd(p, x) {
+ //     const scores = mul(matmulBatched(q, kT), SCALE_QK)
+ //     const attn = capture(`attn.${layerIdx}`, softmaxCausalLast(scores))
+ //     return matmulBatched(attn, v)
+ //   }
+ //
+ // Pass-through return type: `capture(name, t)` returns `t` unchanged so it
+ // inlines at the point of computation. Behind the scenes it registers `t.id`
+ // against `name` on the current graph; runtime exposes the registered tensors
+ // via `step(inputs, { withCaptures: true })`.
+ //
+ // Outside the user's forward trace (during `appendGrad` / `appendAdam`'s
+ // `traceInto` re-entry), `capture()` is a no-op — gradient and optimizer
+ // internals shouldn't accidentally publish themselves to the UI.
+
+ import type { Tensor } from './ir.js'
+ import { currentGraph, isCaptureEnabled } from './trace.js'
+
+ export function capture<T extends Tensor>(name: string, t: T): T {
+   if (!isCaptureEnabled()) return t
+   const g = currentGraph()
+   if (g.captures.has(name)) {
+     throw new Error(
+       `capture: name '${name}' already registered. Use unique names ` +
+       `(e.g. \`attn.\${layerIdx}\`) when capturing across a loop.`,
+     )
+   }
+   g.captures.set(name, t.id)
+   return t
+ }
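The capture path end to end, per the comments above: register during the forward pass, then opt in at step time. A sketch assuming a compiled trainer `compiled` and hypothetical input arrays; the `number | { loss, captures }` return shape matches the wrapped `step` in the compile.ts diff below.

```ts
const result = await compiled.step(
  { tokens: tokenBatch, targets: targetBatch },
  { withCaptures: true },
)
if (typeof result !== 'number') {
  console.log('loss =', result.loss)
  // One Float32Array per capture(name, t) call made during tracing,
  // keyed by the registered names (e.g. 'attn.0', 'attn.1', ...).
  for (const [name, values] of Object.entries(result.captures)) {
    console.log(name, '->', values.length, 'values')
  }
}
```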
package/src/codegen.ts CHANGED
@@ -557,23 +557,34 @@ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
      case 'adam_update_p': {
        // p_new = decayShrink * p - lrt[0] * m_new / (sqrt(v_new) + eps).
        // lrt is supplied per-step from CPU (already includes bias correction).
-       // decayShrink encodes AdamW's decoupled weight decay; when no decay is
-       // requested it's exactly 1.0 and the WGSL compiler folds the multiply away.
+       // decayShrink is either baked as a literal (no schedule, fixed lr) or
+       // bound as a per-step scalar input (when the user supplies an lr
+       // schedule via `adam: { lr: (step) => ... }`). When literal=1 the WGSL
+       // compiler folds the multiply away.
        const out = tof(op.out)
        const total = shapeSize(out.shape)
+       const dynamicShrink = op.decayShrinkTensor !== null
+       const shrinkExpr = dynamicShrink ? 'decayShrink[0]' : wgslLiteral(op.decayShrink, 'f32')
+       const shrinkBinding = dynamicShrink
+         ? `@group(0) @binding(4) var<storage, read> decayShrink : array<f32>;\n` +
+           `@group(0) @binding(5) var<storage, read_write> out : array<f32>;`
+         : `@group(0) @binding(4) var<storage, read_write> out : array<f32>;`
        const wgsl = `
  @group(0) @binding(0) var<storage, read> p : array<f32>;
  @group(0) @binding(1) var<storage, read> mNew : array<f32>;
  @group(0) @binding(2) var<storage, read> vNew : array<f32>;
  @group(0) @binding(3) var<storage, read> lrt : array<f32>;
- @group(0) @binding(4) var<storage, read_write> out : array<f32>;
+ ${shrinkBinding}
  @compute @workgroup_size(${WG_SIZE})
  fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
    let i = gid.x + gid.y * 16776960u;
    if (i >= ${total}u) { return; }
-   out[i] = ${wgslLiteral(op.decayShrink, 'f32')} * p[i] - lrt[0] * mNew[i] / (sqrt(vNew[i]) + ${wgslLiteral(op.eps, 'f32')});
+   out[i] = ${shrinkExpr} * p[i] - lrt[0] * mNew[i] / (sqrt(vNew[i]) + ${wgslLiteral(op.eps, 'f32')});
  }`.trim()
-       return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.p), buf(op.mNew), buf(op.vNew), buf(op.lrt), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
+       const bindings = dynamicShrink
+         ? [buf(op.p), buf(op.mNew), buf(op.vNew), buf(op.lrt), buf(op.decayShrinkTensor!), buf(op.out)]
+         : [buf(op.p), buf(op.mNew), buf(op.vNew), buf(op.lrt), buf(op.out)]
+       return { opIndex, opKind: op.kind, wgsl, bindings, threads: total, workgroupSize: WG_SIZE }
      }

      case 'sum_to_shape': {
package/src/compile.ts CHANGED
@@ -14,7 +14,7 @@ import { appendGrad, type GradResult } from './grad.js'
  import { appendAdam, type AdamConfig } from './adam.js'
  import { planBuffers, type BufferPlan } from './buffers.js'
  import { emitKernels, type KernelSpec } from './codegen.js'
- import { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js'
+ import { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts } from './runtime.js'
  import { Module, materializeParams } from './module.js'

  /** Declares one input tensor of the model's forward function. Order matches
@@ -65,10 +65,19 @@ export interface CompileModuleOptions extends RuntimeOpts {
    adam?: AdamConfig
  }

+ export interface CompileForwardOptions extends RuntimeOpts {
+   /** Per-step data inputs to the forward function. */
+   inputs?: InputDecl[]
+ }
+
  /**
-  * Compile a Module-based model. The forward function takes the materialized
-  * model and returns the loss tensor (typically by also calling tensorInput
-  * for tokens/targets/masks inside).
+  * Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
+  * model instance itself: compilation mutates the tree (every `ParamSentinel`
+  * field becomes a real `Tensor`), so the instance is consumed and shouldn't be
+  * referenced afterwards. Re-call the factory if you need a fresh tree.
+  *
+  * The forward function takes the materialized model and returns the loss
+  * tensor.
   *
   * Walks the module tree to materialize params with auto-derived names, then
   * runs trace → grad → adam → buffer plan → codegen → runtime.
@@ -78,14 +87,15 @@ export interface CompileModuleOptions extends RuntimeOpts {
   * users don't need to provide it themselves.
   */
  export async function compileModule<M extends Module>(
-   model: M,
+   modelFactory: () => M,
    forward: (m: M, ...inputs: Tensor[]) => Tensor,
    opts: CompileModuleOptions = {},
- ): Promise<CompiledRuntime & { ir: CompiledIR }> {
+ ): Promise<CompiledRuntime & { ir: CompiledIR; uploadInitialParams: () => void }> {
    const inputDecls = opts.inputs ?? []
-   let paramTensors: Record<string, Tensor> = {}
+   const model = modelFactory()
+   let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {} }
    const graph = trace(() => {
-     paramTensors = materializeParams(model)
+     materialized = materializeParams(model)
      const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
      return forward(model, ...inputTensors)
    })
@@ -94,7 +104,7 @@

    let adamResult: ReturnType<typeof appendAdam> | undefined
    if (opts.adam) {
-     adamResult = appendAdam(graph, paramGrads, paramTensors, opts.adam)
+     adamResult = appendAdam(graph, paramGrads, materialized.tensors, opts.adam)
    }

    const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? [])
@@ -102,19 +112,115 @@
    const lossBufferId = plan.tensorToBuffer.get(loss.id)!
    const runtime = await createRuntime(plan, kernels, lossBufferId, opts)

-   // If Adam is enabled, wrap step() to track the step count and supply lrt.
+   // If Adam is enabled, wrap step() to track the step count and supply lrt
+   // (and optionally decayShrink, when the user passed a per-step lr schedule).
+   // Wrap resetOptimizerState() too, so a reset zeros m/v *and* the bias-correction
+   // counter — otherwise the next step would skip Adam's warmup phase.
    if (adamResult) {
-     const { lrtInputName, config } = adamResult
+     const { lrtInputName, decayShrinkInputName, config } = adamResult
      let t = 0
      const lrtBuf = new Float32Array(1)
-     const innerStep = runtime.step.bind(runtime)
-     runtime.step = async (inputs) => {
+     const decayShrinkBuf = decayShrinkInputName ? new Float32Array(1) : null
+     const innerStep = runtime.step.bind(runtime) as CompiledRuntime['step']
+     const innerReset = runtime.resetOptimizerState.bind(runtime)
+     const wrappedStep = (
+       inputs: Record<string, Int32Array | Float32Array>,
+       opts?: { withCaptures?: boolean },
+     ): Promise<number | { loss: number; captures: Record<string, Float32Array> }> => {
        t++
-       lrtBuf[0] = config.lr * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t))
-       return innerStep({ ...inputs, [lrtInputName]: lrtBuf })
+       const lrNow = config.lr(t)
+       lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t))
+       const merged: Record<string, Int32Array | Float32Array> = { ...inputs, [lrtInputName]: lrtBuf }
+       if (decayShrinkBuf && decayShrinkInputName) {
+         decayShrinkBuf[0] = 1 - lrNow * config.weightDecay
+         merged[decayShrinkInputName] = decayShrinkBuf
+       }
+       return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged)
+     }
+     runtime.step = wrappedStep as CompiledRuntime['step']
+     runtime.resetOptimizerState = () => {
+       t = 0
+       innerReset()
+     }
+   }
+
+   const { initFns } = materialized
+   const uploadInitialParams = () => {
+     const out: Record<string, Float32Array> = {}
+     for (const [name, bufId] of plan.paramsByName) {
+       const shape = plan.buffers[bufId]!.shape
+       const size = shape.reduce((a, b) => a * b, 1)
+       const initFn = initFns[name]
+       if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
+       out[name] = initFn(size, shape)
      }
+     runtime.uploadParams(out)
    }

    const ir: CompiledIR = { graph, paramGrads, loss, plan, kernels }
-   return Object.assign(runtime, { ir })
+   return Object.assign(runtime, { ir, uploadInitialParams })
+ }
+
+ // ============================================================================
+ // Forward-only compile
+ // ============================================================================
+
+ /**
+  * Compile a Module-based model in forward-only mode (no autograd, no Adam).
+  * The forward function returns the output tensor (e.g., logits) instead of a
+  * scalar loss; runtime exposes `run(inputs)` returning the full output as a
+  * `Float32Array`.
+  *
+  * **Sharing params with a training compile.** Pass `opts.sharedParams =
+  * trainCompiled.params` to bind this graph's param buffers to an existing
+  * training runtime's GPU buffers — every train step is then immediately
+  * visible to `run()` calls here, no copies. The forward graph's
+  * `uploadInitialParams()` skips any param covered by `sharedParams`.
+  *
+  * Typical use: a B=1 inference graph alongside a B=512 training graph,
+  * built from the same `Module` factory.
+  */
+ export async function compileForward<M extends Module>(
+   modelFactory: () => M,
+   forward: (m: M, ...inputs: Tensor[]) => Tensor,
+   opts: CompileForwardOptions = {},
+ ): Promise<CompiledForward & { ir: CompiledIR; uploadInitialParams: () => void }> {
+   const inputDecls = opts.inputs ?? []
+   const model = modelFactory()
+   let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {} }
+   const graph = trace(() => {
+     materialized = materializeParams(model)
+     const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
+     return forward(model, ...inputTensors)
+   })
+
+   const plan = planBuffers(graph, /* paramGrads */ {})
+   const kernels = emitKernels(graph, plan)
+   const outputTensor = graph.tensors[graph.outputs[0]!]!
+   const outputBufferId = plan.tensorToBuffer.get(outputTensor.id)!
+   const runtime = await createForwardRuntime(plan, kernels, outputBufferId, opts)
+
+   const sharedParams = opts.sharedParams
+   const { initFns } = materialized
+   const uploadInitialParams = () => {
+     const out: Record<string, Float32Array> = {}
+     let needsUpload = false
+     for (const [name, bufId] of plan.paramsByName) {
+       // Skip params covered by sharedParams — those are owned by the providing
+       // compile and already initialized there.
+       if (sharedParams?.has(name)) continue
+       const shape = plan.buffers[bufId]!.shape
+       const size = shape.reduce((a, b) => a * b, 1)
+       const initFn = initFns[name]
+       if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
+       out[name] = initFn(size, shape)
+       needsUpload = true
+     }
+     if (needsUpload) runtime.uploadParams(out, { partial: !!sharedParams })
+   }
+
+   // CompiledIR.loss is the field name; for forward-only, it carries the user's
+   // returned tensor (e.g., logits). Same shape conceptually; just no autograd.
+   const ir: CompiledIR = { graph, paramGrads: {}, loss: outputTensor, plan, kernels }
+   return Object.assign(runtime, { ir, uploadInitialParams })
  }
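How the two entry points are meant to compose, per the `compileForward` doc comment above: one training graph and one forward-only graph built from the same factory, sharing GPU param buffers. The model, forward functions, and input declarations here are hypothetical.

```ts
const train = await compileModule(() => new TinyGPT(), lossForward, {
  inputs: [{ name: 'tokens', shape: [512, 64] }],
  adam: { lr: 1e-3, weightDecay: 0.1 },
})
train.uploadInitialParams() // apply each param's init metadata once

const infer = await compileForward(() => new TinyGPT(), logitsForward, {
  inputs: [{ name: 'tokens', shape: [1, 64] }],
  sharedParams: train.params, // inference reads the live training weights
})
infer.uploadInitialParams() // skips every param covered by sharedParams
```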
package/src/index.ts CHANGED
@@ -6,6 +6,7 @@
  export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js'
  export { ShapeError } from './shape.js'
  export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js'
+ export { capture } from './capture.js'
  export {
    // Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
    add, sub, mul, div,
@@ -35,6 +36,7 @@ export { appendGrad, type GradResult } from './grad.js'
  export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
  export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
  export { emitKernels, type KernelSpec } from './codegen.js'
- export { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js'
- export { compile, compileToIR, compileModule, type CompiledIR, type CompileModuleOptions, type InputDecl } from './compile.js'
- export { Module, materializeParams } from './module.js'
+ export { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type StepOptions, type StepWithCaptures, type RunOptions, type RunWithCaptures } from './runtime.js'
+ export { compile, compileToIR, compileModule, compileForward, type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type InputDecl } from './compile.js'
+ export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
+ export * as nn from './nn.js'
package/src/ir.ts CHANGED
@@ -113,9 +113,21 @@ export type OpNode =
    // `lrt` is a scalar tensor (provided as a tensor_input updated per step) that
    // already includes Adam's bias-correction factor: lrt = lr * sqrt(1-b2^t) / (1-b1^t).
    // `decayShrink` is the decoupled-weight-decay factor (Loshchilov & Hutter,
-   // "AdamW") baked at compile time: 1 - lr * weightDecay when the param is being
-   // decayed, 1 otherwise. eps and decayShrink are both baked into the kernel.
-   | { kind: 'adam_update_p'; out: number; p: number; mNew: number; vNew: number; lrt: number; eps: number; decayShrink: number }
+   // "AdamW"): 1 - lr * weightDecay when the param is being decayed, 1 otherwise.
+   // It can be either a compile-time literal (number) for fixed-lr training, or a
+   // tensor id pointing at a scalar input that the runtime updates per step (used
+   // when the user supplies an lr schedule via `adam: { lr: (step) => ... }`).
+   | {
+       kind: 'adam_update_p'
+       out: number
+       p: number
+       mNew: number
+       vNew: number
+       lrt: number
+       eps: number
+       decayShrink: number // literal (used when decayShrinkTensor is null)
+       decayShrinkTensor: number | null // tensor id of a scalar input; takes precedence when set
+     }

    // ---- Slicing / broadcasting / autograd infrastructure -------------------
    // Slice [start, end) along the last axis. Output shape: input shape with
@@ -141,10 +153,14 @@ export interface Graph {
    // Names of tensors that should be exposed as outputs of the compiled function.
    // Set by the trace driver; for a loss function, this is `[lossTensor]`.
    readonly outputs: number[]
+   // Tensors registered for activation readback via `capture(name, t)`.
+   // Keyed by user-supplied name; insertion order preserved. Empty when no
+   // captures registered (the common training case — zero overhead).
+   readonly captures: Map<string, number>
  }

  export function makeGraph(): Graph {
-   return { ops: [], tensors: [], outputs: [] }
+   return { ops: [], tensors: [], outputs: [], captures: new Map() }
  }

  // Internal: register a fresh tensor in the graph and return its id.
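For illustration, the two shapes the widened `adam_update_p` node can now take (tensor ids are hypothetical):

```ts
import type { OpNode } from 'tensorgrad'

// Fixed lr: decayShrink baked as a literal, no extra kernel binding.
const baked: OpNode = {
  kind: 'adam_update_p', out: 12, p: 3, mNew: 10, vNew: 11, lrt: 2,
  eps: 1e-8, decayShrink: 1 - 0.005 * 0.1, decayShrinkTensor: null,
}

// Scheduled lr: decayShrink read from scalar input tensor #4 every step;
// the literal field is ignored when decayShrinkTensor is set.
const scheduled: OpNode = {
  kind: 'adam_update_p', out: 12, p: 3, mNew: 10, vNew: 11, lrt: 2,
  eps: 1e-8, decayShrink: 1, decayShrinkTensor: 4,
}
```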
package/src/module.ts CHANGED
@@ -6,8 +6,8 @@
  //   W: Tensor; b: Tensor
  //   constructor(inDim: number, outDim: number) {
  //     super()
- //     this.W = this.param([inDim, outDim])
- //     this.b = this.param([outDim])
+ //     this.W = this.param([inDim, outDim]) // randn, scale 0.02
+ //     this.b = this.param([outDim], { init: 'zeros' })
  //   }
  // }
  // class Block extends Module {
@@ -28,6 +28,54 @@
  import type { Tensor, Shape, Dtype } from './ir.js'
  import { paramInput } from './trace.js'

+ // ============================================================================
+ // Init metadata
+ // ============================================================================
+
+ /** How a parameter's initial values are produced.
+  * - `'randn'` — Gaussian, with `scale` (default 0.02). The common case for
+  *   weight matrices and embeddings.
+  * - `'zeros'` — fill with 0. Common for biases and LayerNorm beta.
+  * - `'ones'` — fill with 1. Common for LayerNorm gain.
+  * - Custom function — receives total element count and shape, returns the
+  *   Float32Array. Use for fan-in scaling or any non-standard scheme.
+  */
+ export type InitSpec =
+   | 'randn'
+   | 'zeros'
+   | 'ones'
+   | ((size: number, shape: readonly number[]) => Float32Array)
+
+ export interface ParamOptions {
+   dtype?: Dtype
+   /** Init kind. Default: `'randn'`. */
+   init?: InitSpec
+   /** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
+   scale?: number
+ }
+
+ type InitFn = (size: number, shape: readonly number[]) => Float32Array
+
+ function boxMuller(): number {
+   return Math.sqrt(-2 * Math.log(Math.max(1e-10, Math.random()))) * Math.cos(2 * Math.PI * Math.random())
+ }
+
+ function resolveInit(opts: ParamOptions | undefined): InitFn {
+   const init = opts?.init ?? 'randn'
+   if (init === 'randn') {
+     const scale = opts?.scale ?? 0.02
+     return (size) => {
+       const arr = new Float32Array(size)
+       for (let i = 0; i < size; i++) arr[i] = boxMuller() * scale
+       return arr
+     }
+   }
+   if (init === 'zeros') return (size) => new Float32Array(size)
+   if (init === 'ones') return (size) => { const a = new Float32Array(size); a.fill(1); return a }
+   if (typeof init === 'function') return init
+   throw new Error(`Unknown init: ${String(init)}`)
+ }
+
  // ============================================================================
  // Internals: param sentinel
  // ============================================================================
@@ -38,7 +86,11 @@ import { paramInput } from './trace.js'
  // only valid post-materialization (which is always before forward runs).

  class ParamSentinel {
-   constructor(public readonly shape: Shape, public readonly dtype: Dtype) {}
+   constructor(
+     public readonly shape: Shape,
+     public readonly dtype: Dtype,
+     public readonly initFn: InitFn,
+   ) {}
  }

  // ============================================================================
@@ -52,11 +104,13 @@ export abstract class Module {
     * that gets replaced with a real Tensor at compile time.
     *
     * The parameter's name is auto-derived from its property path in the model
-    * tree (e.g. `layers.0.attn.W_q`).
+    * tree (e.g. `layers.0.attn.W_q`). Init metadata travels with the param;
+    * call `compiled.uploadInitialParams()` to apply it after compile.
     */
-   protected param(shape: Shape, dtype: Dtype = 'f32'): Tensor {
+   protected param(shape: Shape, opts?: ParamOptions): Tensor {
+     const dtype = opts?.dtype ?? 'f32'
      // Lie to TypeScript: the sentinel becomes a Tensor at materialize time.
-     return new ParamSentinel(shape, dtype) as unknown as Tensor
+     return new ParamSentinel(shape, dtype, resolveInit(opts)) as unknown as Tensor
    }
  }

@@ -64,23 +118,33 @@
  // Tree walking
  // ============================================================================

+ export interface MaterializedParams {
+   /** Map from auto-derived path (e.g. `layers.0.attn.W_q`) to its Tensor. */
+   tensors: Record<string, Tensor>
+   /** Init function per param path. Used by `uploadInitialParams`. */
+   initFns: Record<string, InitFn>
+ }
+
  /**
   * Walk the module tree and replace every ParamSentinel with a real Tensor
   * created via `paramInput(autoName, ...)`. Must be called inside an active
   * trace context (paramInput appends to the current graph).
   *
-  * Returns a flat record of `{ path: tensor }` for every materialized param.
+  * Returns the param tensors keyed by path, plus init functions for use by
+  * `uploadInitialParams`.
   */
- export function materializeParams(root: Module): Record<string, Tensor> {
-   const out: Record<string, Tensor> = {}
+ export function materializeParams(root: Module): MaterializedParams {
+   const tensors: Record<string, Tensor> = {}
+   const initFns: Record<string, InitFn> = {}
    visit(root, '', (path, val, owner, key) => {
      if (val instanceof ParamSentinel) {
        const t = paramInput(path, val.shape, val.dtype)
        ;(owner as any)[key] = t
-       out[path] = t
+       tensors[path] = t
+       initFns[path] = val.initFn
      }
    })
-   return out
+   return { tensors, initFns }
  }

  // ----------------------------------------------------------------------------
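Putting the init metadata to use: a sketch of modules mixing the built-in specs with a custom fan-in init via the `ParamOptions` declared above. Class and field names are illustrative.

```ts
import { Module, type Tensor } from 'tensorgrad'

class LayerNorm extends Module {
  gain: Tensor; beta: Tensor
  constructor(dim: number) {
    super()
    this.gain = this.param([dim], { init: 'ones' })  // gain starts at 1
    this.beta = this.param([dim], { init: 'zeros' }) // shift starts at 0
  }
}

class Linear extends Module {
  W: Tensor
  constructor(inDim: number, outDim: number) {
    super()
    // Custom InitSpec: fan-in scaled Gaussian instead of the default scale 0.02.
    this.W = this.param([inDim, outDim], {
      init: (size, shape) => {
        const std = 1 / Math.sqrt(shape[0]!)
        const a = new Float32Array(size)
        for (let i = 0; i < size; i++) {
          // Box–Muller, mirroring the module's own boxMuller helper.
          const u = Math.max(1e-10, Math.random())
          a[i] = std * Math.sqrt(-2 * Math.log(u)) * Math.cos(2 * Math.PI * Math.random())
        }
        return a
      },
    })
  }
}
```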