tensorgrad 0.0.1 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. package/README.md +7 -9
  2. package/dist/adam.d.ts +14 -2
  3. package/dist/adam.d.ts.map +1 -1
  4. package/dist/adam.js +19 -8
  5. package/dist/adam.js.map +1 -1
  6. package/dist/buffers.d.ts +1 -0
  7. package/dist/buffers.d.ts.map +1 -1
  8. package/dist/buffers.js +12 -1
  9. package/dist/buffers.js.map +1 -1
  10. package/dist/capture.d.ts +3 -0
  11. package/dist/capture.d.ts.map +1 -0
  12. package/dist/capture.js +33 -0
  13. package/dist/capture.js.map +1 -0
  14. package/dist/codegen.js +4 -2
  15. package/dist/codegen.js.map +1 -1
  16. package/dist/compile.d.ts +33 -5
  17. package/dist/compile.d.ts.map +1 -1
  18. package/dist/compile.js +96 -11
  19. package/dist/compile.js.map +1 -1
  20. package/dist/index.d.ts +5 -3
  21. package/dist/index.d.ts.map +1 -1
  22. package/dist/index.js +4 -2
  23. package/dist/index.js.map +1 -1
  24. package/dist/ir.d.ts +2 -0
  25. package/dist/ir.d.ts.map +1 -1
  26. package/dist/ir.js +1 -1
  27. package/dist/ir.js.map +1 -1
  28. package/dist/module.d.ts +30 -4
  29. package/dist/module.d.ts.map +1 -1
  30. package/dist/module.js +39 -13
  31. package/dist/module.js.map +1 -1
  32. package/dist/nn.d.ts +19 -0
  33. package/dist/nn.d.ts.map +1 -0
  34. package/dist/nn.js +60 -0
  35. package/dist/nn.js.map +1 -0
  36. package/dist/ops.d.ts +1 -1
  37. package/dist/ops.d.ts.map +1 -1
  38. package/dist/ops.js +2 -2
  39. package/dist/ops.js.map +1 -1
  40. package/dist/runtime.d.ts +79 -4
  41. package/dist/runtime.d.ts.map +1 -1
  42. package/dist/runtime.js +153 -19
  43. package/dist/runtime.js.map +1 -1
  44. package/dist/trace.d.ts +1 -0
  45. package/dist/trace.d.ts.map +1 -1
  46. package/dist/trace.js +12 -0
  47. package/dist/trace.js.map +1 -1
  48. package/package.json +1 -2
  49. package/src/adam.ts +31 -10
  50. package/src/buffers.ts +14 -1
  51. package/src/capture.ts +36 -0
  52. package/src/codegen.ts +4 -2
  53. package/src/compile.ts +112 -13
  54. package/src/index.ts +5 -3
  55. package/src/ir.ts +10 -4
  56. package/src/module.ts +75 -11
  57. package/src/nn.ts +59 -0
  58. package/src/ops.ts +2 -2
  59. package/src/runtime.ts +260 -22
  60. package/src/trace.ts +13 -0
  61. package/SPEC.md +0 -293
package/src/adam.ts CHANGED
@@ -1,4 +1,4 @@
- // Adam optimizer, in-graph.
+ // Adam / AdamW optimizer, in-graph.
  //
  // `appendAdam` extends a graph that already has a forward pass + autograd-emitted
  // backward (i.e., has paramGrads from `appendGrad`) with the Adam update math.
@@ -6,12 +6,14 @@
  // Per parameter P with gradient g:
  //   m_new = b1 * m + (1 - b1) * g
  //   v_new = b2 * v + (1 - b2) * g²
- //   p_new = p - lr * m_new / (sqrt(v_new) + eps)
+ //   p_new = decayShrink * p - lrt * m_new / (sqrt(v_new) + eps)
  //
- // This is "Adam without bias correction" — the `1 / (1 - β^t)` factors are
- // dropped because computing them in-graph requires per-step uniforms or
- // awkward exp/log tricks. In practice the omission only affects the first
- // ~100 steps; convergence is unaffected.
+ // `decayShrink = 1 - lr * weightDecay` when the param is being decayed
+ // (Loshchilov & Hutter, "AdamW") and 1 otherwise, at which point the
+ // multiply folds out and you're left with plain Adam. `lrt` is supplied
+ // per-step from CPU and includes the bias-correction factor
+ // `sqrt(1-b2^t)/(1-b1^t)`; that's why convergence isn't affected by the
+ // first-step warmup that bias-correction-free Adam suffers.
  //
  // Returns writeback declarations the buffer planner uses to wire up the
  // "after step, copy the new value into the persistent home" path. m and v
@@ -29,6 +31,15 @@ export interface AdamConfig {
    b1?: number // default 0.9
    b2?: number // default 0.999
    eps?: number // default 1e-8
+   /** AdamW: decoupled weight decay coefficient. Default 0 (plain Adam).
+    * When non-zero, every step shrinks each decayed param by a factor of
+    * `1 - lr * weightDecay` before the gradient update. */
+   weightDecay?: number
+   /** Filter deciding which params get weight decay. Only consulted when
+    * weightDecay > 0. Default: decay every param. Override for the standard
+    * transformer convention (decay weights/embeddings, skip biases + LN gains).
+    * Example: `(name) => name.includes('.W') || name.endsWith('_emb')`. */
+   decayFilter?: (paramName: string) => boolean
  }

  export interface AdamResult {
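A minimal sketch of wiring these options up (the `.W` / `_emb` naming convention is hypothetical — it just illustrates the filter the comment suggests):

    const adam: AdamConfig = {
      lr: 3e-4,
      weightDecay: 0.01,   // enables AdamW's decoupled decay
      // Decay weight matrices and embeddings; skip biases and LayerNorm gains.
      decayFilter: (name) => name.includes('.W') || name.endsWith('_emb'),
    }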
@@ -38,7 +49,7 @@ export interface AdamResult {
   * with `lr * sqrt(1-b2^t)/(1-b1^t)` (Adam's bias-corrected effective LR). */
  lrtInputName: string
  /** Hyperparameters as captured (so the runtime can compute lrt). */
- config: Required<AdamConfig>
+ config: Required<Omit<AdamConfig, 'decayFilter'>> & { decayFilter: (name: string) => boolean }
  }

  /**
@@ -50,7 +61,8 @@ export interface AdamResult {
   * @param paramTensors param name -> the param's leaf Tensor (the param_input).
   *                     Needed because the param_input lives in the graph but we
   *                     don't have a direct map by name in `Graph` — caller passes it.
-  * @param config Adam hyperparameters
+  * @param config Adam hyperparameters. Set `weightDecay > 0` for AdamW; an
+  *               optional `decayFilter` selects which params receive decay.
   */
  export function appendAdam(
    graph: Graph,
@@ -58,11 +70,13 @@
    paramTensors: Record<string, Tensor>,
    config: AdamConfig,
  ): AdamResult {
-   const fullConfig: Required<AdamConfig> = {
+   const fullConfig = {
      lr: config.lr,
      b1: config.b1 ?? 0.9,
      b2: config.b2 ?? 0.999,
      eps: config.eps ?? 1e-8,
+     weightDecay: config.weightDecay ?? 0,
+     decayFilter: config.decayFilter ?? (() => true),
    }
    const writebacks: WritebackDecl[] = []
    const lrtInputName = '_adam_lrt'
@@ -81,10 +95,17 @@
    const mState = stateInput(`adam_m_${name}`, p.shape, 'f32', 0)
    const vState = stateInput(`adam_v_${name}`, p.shape, 'f32', 0)

+   // decayShrink baked at compile time. 1.0 for plain Adam (no extra cost
+   // — the WGSL compiler folds the constant multiply); 1 - lr * weightDecay
+   // for the params the filter selects.
+   const decayShrink = (fullConfig.weightDecay > 0 && fullConfig.decayFilter(name))
+     ? 1 - fullConfig.lr * fullConfig.weightDecay
+     : 1
+
    // Three fused kernels per parameter — one for each of m_new / v_new / p_new.
    const newM = adamUpdateM(mState, g, fullConfig.b1)
    const newV = adamUpdateV(vState, g, fullConfig.b2)
-   const newP = adamUpdateP(p, newM, newV, lrt, fullConfig.eps)
+   const newP = adamUpdateP(p, newM, newV, lrt, fullConfig.eps, decayShrink)

    writebacks.push({ source: newM, destName: `adam_m_${name}`, destKind: 'state' })
    writebacks.push({ source: newV, destName: `adam_v_${name}`, destKind: 'state' })
package/src/buffers.ts CHANGED
@@ -47,6 +47,7 @@ export interface BufferPlan {
    inputsByName: Map<string, number> // name -> buffer id
    paramGradsByName: Map<string, number> // name -> buffer id
    statesByName: Map<string, number> // name -> buffer id (persistent state homes)
+   capturesByName: Map<string, number> // name -> buffer id (activation captures)
    outputBufferIds: number[] // graph.outputs mapped through
    /** End-of-step writebacks (Adam updates for params, m, v, etc.) */
    writebacks: Writeback[]
@@ -169,5 +170,17 @@
      return { source: sourceBufId, dest: destBufId, bytes: sourceSpec.byteSize }
    })

-   return { buffers, tensorToBuffer, paramsByName, inputsByName, paramGradsByName, statesByName, outputBufferIds, writebacks }
+   // Resolve graph.captures (name -> tensor id) to (name -> buffer id).
+   // No pinning needed at the planner level: each tensor already has its own
+   // buffer (see "v1 strategy" comment at top — no pooling yet).
+   const capturesByName = new Map<string, number>()
+   for (const [name, tensorId] of graph.captures) {
+     const bufId = tensorToBuffer.get(tensorId)
+     if (bufId === undefined) {
+       throw new Error(`planBuffers: capture '${name}' references unknown tensor #${tensorId}`)
+     }
+     capturesByName.set(name, bufId)
+   }
+
+   return { buffers, tensorToBuffer, paramsByName, inputsByName, paramGradsByName, statesByName, capturesByName, outputBufferIds, writebacks }
  }
package/src/capture.ts ADDED
@@ -0,0 +1,36 @@
+ // Activation capture — opt-in readback of intermediate tensors at each training step.
+ //
+ // Usage (inside the user's forward pass):
+ //
+ //   import { capture } from 'tensorgrad'
+ //
+ //   function attentionFwd(p, x) {
+ //     const scores = mul(matmulBatched(q, kT), SCALE_QK)
+ //     const attn = capture(`attn.${layerIdx}`, softmaxCausalLast(scores))
+ //     return matmulBatched(attn, v)
+ //   }
+ //
+ // Pass-through return type: `capture(name, t)` returns `t` unchanged so it
+ // inlines at the point of computation. Behind the scenes it registers `t.id`
+ // against `name` on the current graph; runtime exposes the registered tensors
+ // via `step(inputs, { withCaptures: true })`.
+ //
+ // Outside the user's forward trace (during `appendGrad` / `appendAdam`'s
+ // `traceInto` re-entry), `capture()` is a no-op — gradient and optimizer
+ // internals shouldn't accidentally publish themselves to the UI.
+
+ import type { Tensor } from './ir.js'
+ import { currentGraph, isCaptureEnabled } from './trace.js'
+
+ export function capture<T extends Tensor>(name: string, t: T): T {
+   if (!isCaptureEnabled()) return t
+   const g = currentGraph()
+   if (g.captures.has(name)) {
+     throw new Error(
+       `capture: name '${name}' already registered. Use unique names ` +
+       `(e.g. \`attn.\${layerIdx}\`) when capturing across a loop.`,
+     )
+   }
+   g.captures.set(name, t.id)
+   return t
+ }
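On the consuming side, the training loop opts in per step. A sketch (the tensor buffers and `renderAttentionHeatmap` are hypothetical; the `{ withCaptures: true }` overload and the `{ loss, captures }` result shape are the ones referenced above and typed in compile.ts below):

    // tokens/targets: Int32Array | Float32Array batches prepared by the caller.
    const result = await compiled.step({ tokens, targets }, { withCaptures: true })
    if (typeof result !== 'number') {
      const { loss, captures } = result
      // e.g. captures['attn.0'] is the softmax output registered in attentionFwd
      renderAttentionHeatmap(captures['attn.0'])
      console.log('loss', loss)
    }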
package/src/codegen.ts CHANGED
@@ -555,8 +555,10 @@ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.v), buf(op.g), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
    }
    case 'adam_update_p': {
-     // p_new = p - lrt[0] * m_new / (sqrt(v_new) + eps).
+     // p_new = decayShrink * p - lrt[0] * m_new / (sqrt(v_new) + eps).
      // lrt is supplied per-step from CPU (already includes bias correction).
+     // decayShrink encodes AdamW's decoupled weight decay; when no decay is
+     // requested it's exactly 1.0 and the WGSL compiler folds the multiply away.
      const out = tof(op.out)
      const total = shapeSize(out.shape)
      const wgsl = `
@@ -569,7 +571,7 @@
  fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
    let i = gid.x + gid.y * 16776960u;
    if (i >= ${total}u) { return; }
-   out[i] = p[i] - lrt[0] * mNew[i] / (sqrt(vNew[i]) + ${wgslLiteral(op.eps, 'f32')});
+   out[i] = ${wgslLiteral(op.decayShrink, 'f32')} * p[i] - lrt[0] * mNew[i] / (sqrt(vNew[i]) + ${wgslLiteral(op.eps, 'f32')});
  }`.trim()
      return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.p), buf(op.mNew), buf(op.vNew), buf(op.lrt), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
    }
package/src/compile.ts CHANGED
@@ -14,7 +14,7 @@ import { appendGrad, type GradResult } from './grad.js'
  import { appendAdam, type AdamConfig } from './adam.js'
  import { planBuffers, type BufferPlan } from './buffers.js'
  import { emitKernels, type KernelSpec } from './codegen.js'
- import { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js'
+ import { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts } from './runtime.js'
  import { Module, materializeParams } from './module.js'

  /** Declares one input tensor of the model's forward function. Order matches
@@ -65,10 +65,19 @@ export interface CompileModuleOptions extends RuntimeOpts {
    adam?: AdamConfig
  }

+ export interface CompileForwardOptions extends RuntimeOpts {
+   /** Per-step data inputs to the forward function. */
+   inputs?: InputDecl[]
+ }
+
  /**
-  * Compile a Module-based model. The forward function takes the materialized
-  * model and returns the loss tensor (typically by also calling tensorInput
-  * for tokens/targets/masks inside).
+  * Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
+  * model instance itself: compilation mutates the tree (every `ParamSentinel`
+  * field becomes a real `Tensor`), so the instance is consumed and shouldn't be
+  * referenced afterwards. Re-call the factory if you need a fresh tree.
+  *
+  * The forward function takes the materialized model and returns the loss
+  * tensor.
   *
   * Walks the module tree to materialize params with auto-derived names, then
   * runs trace → grad → adam → buffer plan → codegen → runtime.
@@ -78,14 +87,15 @@ export interface CompileModuleOptions extends RuntimeOpts {
   * users don't need to provide it themselves.
   */
  export async function compileModule<M extends Module>(
-   model: M,
+   modelFactory: () => M,
    forward: (m: M, ...inputs: Tensor[]) => Tensor,
    opts: CompileModuleOptions = {},
- ): Promise<CompiledRuntime & { ir: CompiledIR }> {
+ ): Promise<CompiledRuntime & { ir: CompiledIR; uploadInitialParams: () => void }> {
    const inputDecls = opts.inputs ?? []
-   let paramTensors: Record<string, Tensor> = {}
+   const model = modelFactory()
+   let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {} }
    const graph = trace(() => {
-     paramTensors = materializeParams(model)
+     materialized = materializeParams(model)
      const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
      return forward(model, ...inputTensors)
    })
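A sketch of the call-site this signature change implies (the model class, loss function, shapes, and input buffers are hypothetical; the factory → `uploadInitialParams()` → `step()` flow is the one defined in this file):

    const compiled = await compileModule(
      () => new MyModel(),                                  // factory, not an instance
      (m, tokens, targets) => lossFwd(m, tokens, targets),  // returns the scalar loss tensor
      {
        inputs: [
          { name: 'tokens', shape: [B, T] },
          { name: 'targets', shape: [B, T] },
        ],
        adam: { lr: 3e-4, weightDecay: 0.01 },
      },
    )
    compiled.uploadInitialParams()   // apply each param's declared init
    const loss = await compiled.step({ tokens: tokensBuf, targets: targetsBuf })  // plain number without withCaptures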
@@ -94,7 +104,7 @@ export async function compileModule<M extends Module>(

    let adamResult: ReturnType<typeof appendAdam> | undefined
    if (opts.adam) {
-     adamResult = appendAdam(graph, paramGrads, paramTensors, opts.adam)
+     adamResult = appendAdam(graph, paramGrads, materialized.tensors, opts.adam)
    }

    const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? [])
@@ -103,18 +113,107 @@ export async function compileModule<M extends Module>(
    const runtime = await createRuntime(plan, kernels, lossBufferId, opts)

    // If Adam is enabled, wrap step() to track the step count and supply lrt.
+   // Wrap resetOptimizerState() too, so a reset zeros m/v *and* the bias-correction
+   // counter — otherwise the next step would skip Adam's warmup phase.
    if (adamResult) {
      const { lrtInputName, config } = adamResult
      let t = 0
      const lrtBuf = new Float32Array(1)
-     const innerStep = runtime.step.bind(runtime)
-     runtime.step = async (inputs) => {
+     const innerStep = runtime.step.bind(runtime) as CompiledRuntime['step']
+     const innerReset = runtime.resetOptimizerState.bind(runtime)
+     const wrappedStep = (
+       inputs: Record<string, Int32Array | Float32Array>,
+       opts?: { withCaptures?: boolean },
+     ): Promise<number | { loss: number; captures: Record<string, Float32Array> }> => {
        t++
        lrtBuf[0] = config.lr * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t))
-       return innerStep({ ...inputs, [lrtInputName]: lrtBuf })
+       const merged = { ...inputs, [lrtInputName]: lrtBuf }
+       return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged)
+     }
+     runtime.step = wrappedStep as CompiledRuntime['step']
+     runtime.resetOptimizerState = () => {
+       t = 0
+       innerReset()
+     }
+   }
+
+   const { initFns } = materialized
+   const uploadInitialParams = () => {
+     const out: Record<string, Float32Array> = {}
+     for (const [name, bufId] of plan.paramsByName) {
+       const shape = plan.buffers[bufId]!.shape
+       const size = shape.reduce((a, b) => a * b, 1)
+       const initFn = initFns[name]
+       if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
+       out[name] = initFn(size, shape)
      }
+     runtime.uploadParams(out)
    }

    const ir: CompiledIR = { graph, paramGrads, loss, plan, kernels }
-   return Object.assign(runtime, { ir })
+   return Object.assign(runtime, { ir, uploadInitialParams })
+ }
+
+ // ============================================================================
+ // Forward-only compile
+ // ============================================================================
+
+ /**
+  * Compile a Module-based model in forward-only mode (no autograd, no Adam).
+  * The forward function returns the output tensor (e.g., logits) instead of a
+  * scalar loss; runtime exposes `run(inputs)` returning the full output as a
+  * `Float32Array`.
+  *
+  * **Sharing params with a training compile.** Pass `opts.sharedParams =
+  * trainCompiled.params` to bind this graph's param buffers to an existing
+  * training runtime's GPU buffers — every train step is then immediately
+  * visible to `run()` calls here, no copies. The forward graph's
+  * `uploadInitialParams()` skips any param covered by `sharedParams`.
+  *
+  * Typical use: a B=1 inference graph alongside a B=512 training graph,
+  * built from the same `Module` factory.
+  */
+ export async function compileForward<M extends Module>(
+   modelFactory: () => M,
+   forward: (m: M, ...inputs: Tensor[]) => Tensor,
+   opts: CompileForwardOptions = {},
+ ): Promise<CompiledForward & { ir: CompiledIR; uploadInitialParams: () => void }> {
+   const inputDecls = opts.inputs ?? []
+   const model = modelFactory()
+   let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {} }
+   const graph = trace(() => {
+     materialized = materializeParams(model)
+     const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
+     return forward(model, ...inputTensors)
+   })
+
+   const plan = planBuffers(graph, /* paramGrads */ {})
+   const kernels = emitKernels(graph, plan)
+   const outputTensor = graph.tensors[graph.outputs[0]!]!
+   const outputBufferId = plan.tensorToBuffer.get(outputTensor.id)!
+   const runtime = await createForwardRuntime(plan, kernels, outputBufferId, opts)
+
+   const sharedParams = opts.sharedParams
+   const { initFns } = materialized
+   const uploadInitialParams = () => {
+     const out: Record<string, Float32Array> = {}
+     let needsUpload = false
+     for (const [name, bufId] of plan.paramsByName) {
+       // Skip params covered by sharedParams — those are owned by the providing
+       // compile and already initialized there.
+       if (sharedParams?.has(name)) continue
+       const shape = plan.buffers[bufId]!.shape
+       const size = shape.reduce((a, b) => a * b, 1)
+       const initFn = initFns[name]
+       if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
+       out[name] = initFn(size, shape)
+       needsUpload = true
+     }
+     if (needsUpload) runtime.uploadParams(out, { partial: !!sharedParams })
+   }
+
+   // CompiledIR.loss is the field name; for forward-only, it carries the user's
+   // returned tensor (e.g., logits). Same shape conceptually; just no autograd.
+   const ir: CompiledIR = { graph, paramGrads: {}, loss: outputTensor, plan, kernels }
+   return Object.assign(runtime, { ir, uploadInitialParams })
  }
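A sketch of the train/inference pairing the compileForward doc comment describes (model factory, forward functions, shapes, and input buffers are hypothetical; `sharedParams`, `run`, and `uploadInitialParams` are the pieces defined above):

    // Training graph (batch B) and a forward-only graph (batch 1) from the same factory.
    const train = await compileModule(() => new MyModel(), lossFwd, {
      inputs: [{ name: 'tokens', shape: [B, T] }],
      adam: { lr: 3e-4 },
    })
    train.uploadInitialParams()

    const infer = await compileForward(() => new MyModel(), logitsFwd, {
      inputs: [{ name: 'tokens', shape: [1, T] }],
      sharedParams: train.params,   // reuse the training graph's GPU param buffers
    })
    infer.uploadInitialParams()     // skips everything covered by sharedParams

    await train.step({ tokens: batchTokens })
    const logits = await infer.run({ tokens: promptTokens })  // sees the freshly updated params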
package/src/index.ts CHANGED
@@ -6,6 +6,7 @@
  export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js'
  export { ShapeError } from './shape.js'
  export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js'
+ export { capture } from './capture.js'
  export {
    // Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
    add, sub, mul, div,
@@ -35,6 +36,7 @@ export { appendGrad, type GradResult } from './grad.js'
  export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
  export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
  export { emitKernels, type KernelSpec } from './codegen.js'
- export { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js'
- export { compile, compileToIR, compileModule, type CompiledIR, type CompileModuleOptions, type InputDecl } from './compile.js'
- export { Module, materializeParams } from './module.js'
+ export { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type StepOptions, type StepWithCaptures, type RunOptions, type RunWithCaptures } from './runtime.js'
+ export { compile, compileToIR, compileModule, compileForward, type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type InputDecl } from './compile.js'
+ export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
+ export * as nn from './nn.js'
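The new public surface, as a consumer would import it (names taken straight from the export lines above; only the combination is illustrative):

    import { compileModule, compileForward, capture, Module, nn } from 'tensorgrad'
    import type { AdamConfig, CompileForwardOptions, InitSpec, ParamOptions } from 'tensorgrad'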
package/src/ir.ts CHANGED
@@ -109,11 +109,13 @@ export type OpNode =
    // update into ~12 element-wise dispatches per param.
    | { kind: 'adam_update_m'; out: number; m: number; g: number; b1: number }
    | { kind: 'adam_update_v'; out: number; v: number; g: number; b2: number }
-   // adam_update_p: p_new = p - lrt[0] * m_new / (sqrt(v_new) + eps).
+   // adam_update_p: p_new = decayShrink * p - lrt[0] * m_new / (sqrt(v_new) + eps).
    // `lrt` is a scalar tensor (provided as a tensor_input updated per step) that
    // already includes Adam's bias-correction factor: lrt = lr * sqrt(1-b2^t) / (1-b1^t).
-   // Only `eps` is baked in.
-   | { kind: 'adam_update_p'; out: number; p: number; mNew: number; vNew: number; lrt: number; eps: number }
+   // `decayShrink` is the decoupled-weight-decay factor (Loshchilov & Hutter,
+   // "AdamW") baked at compile time: 1 - lr * weightDecay when the param is being
+   // decayed, 1 otherwise. eps and decayShrink are both baked into the kernel.
+   | { kind: 'adam_update_p'; out: number; p: number; mNew: number; vNew: number; lrt: number; eps: number; decayShrink: number }

    // ---- Slicing / broadcasting / autograd infrastructure -------------------
    // Slice [start, end) along the last axis. Output shape: input shape with
@@ -139,10 +141,14 @@ export interface Graph {
    // Names of tensors that should be exposed as outputs of the compiled function.
    // Set by the trace driver; for a loss function, this is `[lossTensor]`.
    readonly outputs: number[]
+   // Tensors registered for activation readback via `capture(name, t)`.
+   // Keyed by user-supplied name; insertion order preserved. Empty when no
+   // captures registered (the common training case — zero overhead).
+   readonly captures: Map<string, number>
  }

  export function makeGraph(): Graph {
-   return { ops: [], tensors: [], outputs: [] }
+   return { ops: [], tensors: [], outputs: [], captures: new Map() }
  }

  // Internal: register a fresh tensor in the graph and return its id.
package/src/module.ts CHANGED
@@ -6,8 +6,8 @@
  //     W: Tensor; b: Tensor
  //     constructor(inDim: number, outDim: number) {
  //       super()
- //       this.W = this.param([inDim, outDim])
- //       this.b = this.param([outDim])
+ //       this.W = this.param([inDim, outDim]) // randn, scale 0.02
+ //       this.b = this.param([outDim], { init: 'zeros' })
  //     }
  //   }
  //   class Block extends Module {
@@ -28,6 +28,54 @@
  import type { Tensor, Shape, Dtype } from './ir.js'
  import { paramInput } from './trace.js'

+ // ============================================================================
+ // Init metadata
+ // ============================================================================
+
+ /** How a parameter's initial values are produced.
+  *  - `'randn'` — Gaussian, with `scale` (default 0.02). The common case for
+  *    weight matrices and embeddings.
+  *  - `'zeros'` — fill with 0. Common for biases and LayerNorm beta.
+  *  - `'ones'` — fill with 1. Common for LayerNorm gain.
+  *  - Custom function — receives total element count and shape, returns the
+  *    Float32Array. Use for fan-in scaling or any non-standard scheme.
+  */
+ export type InitSpec =
+   | 'randn'
+   | 'zeros'
+   | 'ones'
+   | ((size: number, shape: readonly number[]) => Float32Array)
+
+ export interface ParamOptions {
+   dtype?: Dtype
+   /** Init kind. Default: `'randn'`. */
+   init?: InitSpec
+   /** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
+   scale?: number
+ }
+
+ type InitFn = (size: number, shape: readonly number[]) => Float32Array
+
+ function boxMuller(): number {
+   return Math.sqrt(-2 * Math.log(Math.max(1e-10, Math.random()))) * Math.cos(2 * Math.PI * Math.random())
+ }
+
+ function resolveInit(opts: ParamOptions | undefined): InitFn {
+   const init = opts?.init ?? 'randn'
+   if (init === 'randn') {
+     const scale = opts?.scale ?? 0.02
+     return (size) => {
+       const arr = new Float32Array(size)
+       for (let i = 0; i < size; i++) arr[i] = boxMuller() * scale
+       return arr
+     }
+   }
+   if (init === 'zeros') return (size) => new Float32Array(size)
+   if (init === 'ones') return (size) => { const a = new Float32Array(size); a.fill(1); return a }
+   if (typeof init === 'function') return init
+   throw new Error(`Unknown init: ${String(init)}`)
+ }
+
  // ============================================================================
  // Internals: param sentinel
  // ============================================================================
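A sketch of how a Module subclass would use these options (the MLP layer is hypothetical; the `param` options and the `(size, shape) => Float32Array` custom-init signature are the ones declared above):

    import { Module, type Tensor } from 'tensorgrad'

    class MLP extends Module {
      W1: Tensor; b1: Tensor; W2: Tensor
      constructor(dIn: number, dHidden: number) {
        super()
        this.W1 = this.param([dIn, dHidden])                 // default: randn, scale 0.02
        this.b1 = this.param([dHidden], { init: 'zeros' })
        // Custom init: fan-in scaled Gaussian, supplied as a function.
        this.W2 = this.param([dHidden, dIn], {
          init: (size, shape) => {
            const std = 1 / Math.sqrt(shape[0]!)
            const a = new Float32Array(size)
            for (let i = 0; i < size; i++) {
              // inline Box-Muller N(0,1) sample
              a[i] = std * Math.sqrt(-2 * Math.log(Math.max(1e-10, Math.random()))) * Math.cos(2 * Math.PI * Math.random())
            }
            return a
          },
        })
      }
    }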
@@ -38,7 +86,11 @@ import { paramInput } from './trace.js'
  // only valid post-materialization (which is always before forward runs).

  class ParamSentinel {
-   constructor(public readonly shape: Shape, public readonly dtype: Dtype) {}
+   constructor(
+     public readonly shape: Shape,
+     public readonly dtype: Dtype,
+     public readonly initFn: InitFn,
+   ) {}
  }

  // ============================================================================
@@ -52,11 +104,13 @@ export abstract class Module {
   * that gets replaced with a real Tensor at compile time.
   *
   * The parameter's name is auto-derived from its property path in the model
-  * tree (e.g. `layers.0.attn.W_q`).
+  * tree (e.g. `layers.0.attn.W_q`). Init metadata travels with the param;
+  * call `compiled.uploadInitialParams()` to apply it after compile.
   */
-  protected param(shape: Shape, dtype: Dtype = 'f32'): Tensor {
+  protected param(shape: Shape, opts?: ParamOptions): Tensor {
+    const dtype = opts?.dtype ?? 'f32'
     // Lie to TypeScript: the sentinel becomes a Tensor at materialize time.
-    return new ParamSentinel(shape, dtype) as unknown as Tensor
+    return new ParamSentinel(shape, dtype, resolveInit(opts)) as unknown as Tensor
   }
  }

@@ -64,23 +118,33 @@
  // Tree walking
  // ============================================================================

+ export interface MaterializedParams {
+   /** Map from auto-derived path (e.g. `layers.0.attn.W_q`) to its Tensor. */
+   tensors: Record<string, Tensor>
+   /** Init function per param path. Used by `uploadInitialParams`. */
+   initFns: Record<string, InitFn>
+ }
+
  /**
   * Walk the module tree and replace every ParamSentinel with a real Tensor
   * created via `paramInput(autoName, ...)`. Must be called inside an active
   * trace context (paramInput appends to the current graph).
   *
-  * Returns a flat record of `{ path: tensor }` for every materialized param.
+  * Returns the param tensors keyed by path, plus init functions for use by
+  * `uploadInitialParams`.
   */
- export function materializeParams(root: Module): Record<string, Tensor> {
-   const out: Record<string, Tensor> = {}
+ export function materializeParams(root: Module): MaterializedParams {
+   const tensors: Record<string, Tensor> = {}
+   const initFns: Record<string, InitFn> = {}
    visit(root, '', (path, val, owner, key) => {
      if (val instanceof ParamSentinel) {
        const t = paramInput(path, val.shape, val.dtype)
        ;(owner as any)[key] = t
-       out[path] = t
+       tensors[path] = t
+       initFns[path] = val.initFn
      }
    })
-   return out
+   return { tensors, initFns }
  }

  // ----------------------------------------------------------------------------
package/src/nn.ts ADDED
@@ -0,0 +1,59 @@
+ // Standard "batteries-included" Module subclasses for the most common layers.
+ //
+ // JAX-style: each class declares its params (and their init); the forward is a
+ // plain function the user calls with `(module, x)`. No subclassing, no method
+ // dispatch — keeps the autograd-traced computation visible at the call site.
+ //
+ // Import as a namespace:
+ //
+ //   import { nn } from 'tensorgrad'
+ //   class Block extends Module {
+ //     ln = new nn.LayerNorm(D)
+ //     ffn = new nn.Linear(D, 4 * D)
+ //   }
+ //   const y = nn.linearFwd(p.ffn, nn.layerNormFwd(p.ln, x))
+
+ import { Module } from './module.js'
+ import type { Tensor } from './ir.js'
+ import { add, matmul, sub, mul, div, sqrt, meanLast } from './ops.js'
+
+ // ----------------------------------------------------------------------------
+ // Linear: y = x @ W (+ b)
+ // ----------------------------------------------------------------------------
+
+ export class Linear extends Module {
+   W: Tensor
+   b: Tensor | null
+   constructor(public readonly inDim: number, public readonly outDim: number, withBias = true) {
+     super()
+     this.W = this.param([inDim, outDim]) // randn, scale 0.02
+     this.b = withBias ? this.param([outDim], { init: 'zeros' }) : null
+   }
+ }
+
+ export function linearFwd(p: Linear, x: Tensor): Tensor {
+   const out = matmul(x, p.W)
+   return p.b ? add(out, p.b) : out
+ }
+
+ // ----------------------------------------------------------------------------
+ // LayerNorm — normalizes over the last axis. eps defaults to 1e-5.
+ // ----------------------------------------------------------------------------
+
+ export class LayerNorm extends Module {
+   g: Tensor
+   b: Tensor
+   constructor(public readonly d: number, public readonly eps: number = 1e-5) {
+     super()
+     this.g = this.param([d], { init: 'ones' })
+     this.b = this.param([d], { init: 'zeros' })
+   }
+ }
+
+ export function layerNormFwd(p: LayerNorm, x: Tensor): Tensor {
+   const m = meanLast(x)
+   const c = sub(x, m)
+   const v = meanLast(mul(c, c))
+   const stdev = sqrt(add(v, p.eps))
+   return add(mul(div(c, stdev), p.g), p.b)
+ }
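Pulling the pieces together — a sketch of the intended pattern (the Block shape and dimension D are illustrative; only `nn.Linear`, `nn.LayerNorm`, and their `*Fwd` functions from this file are assumed):

    import { Module, nn, type Tensor } from 'tensorgrad'

    const D = 256

    class Block extends Module {
      ln = new nn.LayerNorm(D)
      ffn = new nn.Linear(D, 4 * D)
      out = new nn.Linear(4 * D, D)
    }

    // Params live on the class; the traced computation stays a plain function.
    function blockFwd(p: Block, x: Tensor): Tensor {
      const h = nn.linearFwd(p.ffn, nn.layerNormFwd(p.ln, x))
      return nn.linearFwd(p.out, h)
    }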
package/src/ops.ts CHANGED
@@ -297,7 +297,7 @@ export function adamUpdateV(v: Tensor, g: Tensor, b2: number): Tensor {
    return addOp(currentGraph(), 'adam_update_v', v.shape, 'f32', site, { v: v.id, g: g.id, b2 })
  }

- export function adamUpdateP(p: Tensor, mNew: Tensor, vNew: Tensor, lrt: Tensor, eps: number): Tensor {
+ export function adamUpdateP(p: Tensor, mNew: Tensor, vNew: Tensor, lrt: Tensor, eps: number, decayShrink: number = 1): Tensor {
    const site = captureSite('adamUpdateP')
    if (p.dtype !== 'f32') throw new ShapeError(`adamUpdateP: requires f32`, site)
    if (lrt.dtype !== 'f32' || lrt.shape.length !== 0) {
@@ -307,5 +307,5 @@ export function adamUpdateP(p: Tensor, mNew: Tensor, vNew: Tensor, lrt: Tensor,
      throw new ShapeError(`adamUpdateP: p/mNew shape mismatch`, site)
    }
    return addOp(currentGraph(), 'adam_update_p', p.shape, 'f32', site,
-     { p: p.id, mNew: mNew.id, vNew: vNew.id, lrt: lrt.id, eps })
+     { p: p.id, mNew: mNew.id, vNew: vNew.id, lrt: lrt.id, eps, decayShrink })
  }