tensorgrad 0.0.1 → 0.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +7 -9
- package/dist/adam.d.ts +14 -2
- package/dist/adam.d.ts.map +1 -1
- package/dist/adam.js +19 -8
- package/dist/adam.js.map +1 -1
- package/dist/buffers.d.ts +1 -0
- package/dist/buffers.d.ts.map +1 -1
- package/dist/buffers.js +12 -1
- package/dist/buffers.js.map +1 -1
- package/dist/capture.d.ts +3 -0
- package/dist/capture.d.ts.map +1 -0
- package/dist/capture.js +33 -0
- package/dist/capture.js.map +1 -0
- package/dist/codegen.js +4 -2
- package/dist/codegen.js.map +1 -1
- package/dist/compile.d.ts +33 -5
- package/dist/compile.d.ts.map +1 -1
- package/dist/compile.js +96 -11
- package/dist/compile.js.map +1 -1
- package/dist/index.d.ts +5 -3
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +4 -2
- package/dist/index.js.map +1 -1
- package/dist/ir.d.ts +2 -0
- package/dist/ir.d.ts.map +1 -1
- package/dist/ir.js +1 -1
- package/dist/ir.js.map +1 -1
- package/dist/module.d.ts +30 -4
- package/dist/module.d.ts.map +1 -1
- package/dist/module.js +39 -13
- package/dist/module.js.map +1 -1
- package/dist/nn.d.ts +19 -0
- package/dist/nn.d.ts.map +1 -0
- package/dist/nn.js +60 -0
- package/dist/nn.js.map +1 -0
- package/dist/ops.d.ts +1 -1
- package/dist/ops.d.ts.map +1 -1
- package/dist/ops.js +2 -2
- package/dist/ops.js.map +1 -1
- package/dist/runtime.d.ts +79 -4
- package/dist/runtime.d.ts.map +1 -1
- package/dist/runtime.js +153 -19
- package/dist/runtime.js.map +1 -1
- package/dist/trace.d.ts +1 -0
- package/dist/trace.d.ts.map +1 -1
- package/dist/trace.js +12 -0
- package/dist/trace.js.map +1 -1
- package/package.json +1 -2
- package/src/adam.ts +31 -10
- package/src/buffers.ts +14 -1
- package/src/capture.ts +36 -0
- package/src/codegen.ts +4 -2
- package/src/compile.ts +112 -13
- package/src/index.ts +5 -3
- package/src/ir.ts +10 -4
- package/src/module.ts +75 -11
- package/src/nn.ts +59 -0
- package/src/ops.ts +2 -2
- package/src/runtime.ts +260 -22
- package/src/trace.ts +13 -0
- package/SPEC.md +0 -293
package/src/adam.ts
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
// Adam optimizer, in-graph.
|
|
1
|
+
// Adam / AdamW optimizer, in-graph.
|
|
2
2
|
//
|
|
3
3
|
// `appendAdam` extends a graph that already has a forward pass + autograd-emitted
|
|
4
4
|
// backward (i.e., has paramGrads from `appendGrad`) with the Adam update math.
|
|
@@ -6,12 +6,14 @@
|
|
|
6
6
|
// Per parameter P with gradient g:
|
|
7
7
|
// m_new = b1 * m + (1 - b1) * g
|
|
8
8
|
// v_new = b2 * v + (1 - b2) * g²
|
|
9
|
-
// p_new = p -
|
|
9
|
+
// p_new = decayShrink * p - lrt * m_new / (sqrt(v_new) + eps)
|
|
10
10
|
//
|
|
11
|
-
//
|
|
12
|
-
//
|
|
13
|
-
//
|
|
14
|
-
//
|
|
11
|
+
// `decayShrink = 1 - lr * weightDecay` when the param is being decayed
|
|
12
|
+
// (Loshchilov & Hutter, "AdamW") and 1 otherwise — at which point the
|
|
13
|
+
// multiply folds out and you're left with plain Adam. `lrt` is supplied
|
|
14
|
+
// per-step from CPU and includes the bias-correction factor
|
|
15
|
+
// `sqrt(1-b2^t)/(1-b1^t)`; that's why convergence isn't affected by the
|
|
16
|
+
// first-step warmup that bias-correction-free Adam suffers.
|
|
15
17
|
//
|
|
16
18
|
// Returns writeback declarations the buffer planner uses to wire up the
|
|
17
19
|
// "after step, copy the new value into the persistent home" path. m and v
|
|
@@ -29,6 +31,15 @@ export interface AdamConfig {
|
|
|
29
31
|
b1?: number // default 0.9
|
|
30
32
|
b2?: number // default 0.999
|
|
31
33
|
eps?: number // default 1e-8
|
|
34
|
+
/** AdamW: decoupled weight decay coefficient. Default 0 (plain Adam).
|
|
35
|
+
* When non-zero, every step shrinks each decayed param by a factor of
|
|
36
|
+
* `1 - lr * weightDecay` before the gradient update. */
|
|
37
|
+
weightDecay?: number
|
|
38
|
+
/** Filter deciding which params get weight decay. Only consulted when
|
|
39
|
+
* weightDecay > 0. Default: decay every param. Override for the standard
|
|
40
|
+
* transformer convention (decay weights/embeddings, skip biases + LN gains).
|
|
41
|
+
* Example: `(name) => name.includes('.W') || name.endsWith('_emb')`. */
|
|
42
|
+
decayFilter?: (paramName: string) => boolean
|
|
32
43
|
}
|
|
33
44
|
|
|
34
45
|
export interface AdamResult {
|
|
@@ -38,7 +49,7 @@ export interface AdamResult {
|
|
|
38
49
|
* with `lr * sqrt(1-b2^t)/(1-b1^t)` (Adam's bias-corrected effective LR). */
|
|
39
50
|
lrtInputName: string
|
|
40
51
|
/** Hyperparameters as captured (so the runtime can compute lrt). */
|
|
41
|
-
config: Required<AdamConfig
|
|
52
|
+
config: Required<Omit<AdamConfig, 'decayFilter'>> & { decayFilter: (name: string) => boolean }
|
|
42
53
|
}
|
|
43
54
|
|
|
44
55
|
/**
|
|
@@ -50,7 +61,8 @@ export interface AdamResult {
|
|
|
50
61
|
* @param paramTensors param name -> the param's leaf Tensor (the param_input).
|
|
51
62
|
* Needed because the param_input lives in the graph but we
|
|
52
63
|
* don't have a direct map by name in `Graph` — caller passes it.
|
|
53
|
-
* @param config Adam hyperparameters
|
|
64
|
+
* @param config Adam hyperparameters. Set `weightDecay > 0` for AdamW; an
|
|
65
|
+
* optional `decayFilter` selects which params receive decay.
|
|
54
66
|
*/
|
|
55
67
|
export function appendAdam(
|
|
56
68
|
graph: Graph,
|
|
@@ -58,11 +70,13 @@ export function appendAdam(
|
|
|
58
70
|
paramTensors: Record<string, Tensor>,
|
|
59
71
|
config: AdamConfig,
|
|
60
72
|
): AdamResult {
|
|
61
|
-
const fullConfig
|
|
73
|
+
const fullConfig = {
|
|
62
74
|
lr: config.lr,
|
|
63
75
|
b1: config.b1 ?? 0.9,
|
|
64
76
|
b2: config.b2 ?? 0.999,
|
|
65
77
|
eps: config.eps ?? 1e-8,
|
|
78
|
+
weightDecay: config.weightDecay ?? 0,
|
|
79
|
+
decayFilter: config.decayFilter ?? (() => true),
|
|
66
80
|
}
|
|
67
81
|
const writebacks: WritebackDecl[] = []
|
|
68
82
|
const lrtInputName = '_adam_lrt'
|
|
@@ -81,10 +95,17 @@ export function appendAdam(
|
|
|
81
95
|
const mState = stateInput(`adam_m_${name}`, p.shape, 'f32', 0)
|
|
82
96
|
const vState = stateInput(`adam_v_${name}`, p.shape, 'f32', 0)
|
|
83
97
|
|
|
98
|
+
// decayShrink baked at compile time. 1.0 for plain Adam (no extra cost
|
|
99
|
+
// — the WGSL compiler folds the constant multiply); 1 - lr * weightDecay
|
|
100
|
+
// for the params the filter selects.
|
|
101
|
+
const decayShrink = (fullConfig.weightDecay > 0 && fullConfig.decayFilter(name))
|
|
102
|
+
? 1 - fullConfig.lr * fullConfig.weightDecay
|
|
103
|
+
: 1
|
|
104
|
+
|
|
84
105
|
// Three fused kernels per parameter — one for each of m_new / v_new / p_new.
|
|
85
106
|
const newM = adamUpdateM(mState, g, fullConfig.b1)
|
|
86
107
|
const newV = adamUpdateV(vState, g, fullConfig.b2)
|
|
87
|
-
const newP = adamUpdateP(p, newM, newV, lrt, fullConfig.eps)
|
|
108
|
+
const newP = adamUpdateP(p, newM, newV, lrt, fullConfig.eps, decayShrink)
|
|
88
109
|
|
|
89
110
|
writebacks.push({ source: newM, destName: `adam_m_${name}`, destKind: 'state' })
|
|
90
111
|
writebacks.push({ source: newV, destName: `adam_v_${name}`, destKind: 'state' })
|
package/src/buffers.ts
CHANGED
|
@@ -47,6 +47,7 @@ export interface BufferPlan {
|
|
|
47
47
|
inputsByName: Map<string, number> // name -> buffer id
|
|
48
48
|
paramGradsByName: Map<string, number> // name -> buffer id
|
|
49
49
|
statesByName: Map<string, number> // name -> buffer id (persistent state homes)
|
|
50
|
+
capturesByName: Map<string, number> // name -> buffer id (activation captures)
|
|
50
51
|
outputBufferIds: number[] // graph.outputs mapped through
|
|
51
52
|
/** End-of-step writebacks (Adam updates for params, m, v, etc.) */
|
|
52
53
|
writebacks: Writeback[]
|
|
@@ -169,5 +170,17 @@ export function planBuffers(
|
|
|
169
170
|
return { source: sourceBufId, dest: destBufId, bytes: sourceSpec.byteSize }
|
|
170
171
|
})
|
|
171
172
|
|
|
172
|
-
|
|
173
|
+
// Resolve graph.captures (name -> tensor id) to (name -> buffer id).
|
|
174
|
+
// No pinning needed at the planner level: each tensor already has its own
|
|
175
|
+
// buffer (see "v1 strategy" comment at top — no pooling yet).
|
|
176
|
+
const capturesByName = new Map<string, number>()
|
|
177
|
+
for (const [name, tensorId] of graph.captures) {
|
|
178
|
+
const bufId = tensorToBuffer.get(tensorId)
|
|
179
|
+
if (bufId === undefined) {
|
|
180
|
+
throw new Error(`planBuffers: capture '${name}' references unknown tensor #${tensorId}`)
|
|
181
|
+
}
|
|
182
|
+
capturesByName.set(name, bufId)
|
|
183
|
+
}
|
|
184
|
+
|
|
185
|
+
return { buffers, tensorToBuffer, paramsByName, inputsByName, paramGradsByName, statesByName, capturesByName, outputBufferIds, writebacks }
|
|
173
186
|
}
|
package/src/capture.ts
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
// Activation capture — opt-in readback of intermediate tensors at training step.
|
|
2
|
+
//
|
|
3
|
+
// Usage (inside the user's forward pass):
|
|
4
|
+
//
|
|
5
|
+
// import { capture } from 'tensorgrad'
|
|
6
|
+
//
|
|
7
|
+
// function attentionFwd(p, x) {
|
|
8
|
+
// const scores = mul(matmulBatched(q, kT), SCALE_QK)
|
|
9
|
+
// const attn = capture(`attn.${layerIdx}`, softmaxCausalLast(scores))
|
|
10
|
+
// return matmulBatched(attn, v)
|
|
11
|
+
// }
|
|
12
|
+
//
|
|
13
|
+
// Pass-through return type: `capture(name, t)` returns `t` unchanged so it
|
|
14
|
+
// inlines at the point of computation. Behind the scenes it registers `t.id`
|
|
15
|
+
// against `name` on the current graph; runtime exposes the registered tensors
|
|
16
|
+
// via `step(inputs, { withCaptures: true })`.
|
|
17
|
+
//
|
|
18
|
+
// Outside the user's forward trace (during `appendGrad` / `appendAdam`'s
|
|
19
|
+
// `traceInto` re-entry), `capture()` is a no-op — gradient and optimizer
|
|
20
|
+
// internals shouldn't accidentally publish themselves to the UI.
|
|
21
|
+
|
|
22
|
+
import type { Tensor } from './ir.js'
|
|
23
|
+
import { currentGraph, isCaptureEnabled } from './trace.js'
|
|
24
|
+
|
|
25
|
+
export function capture<T extends Tensor>(name: string, t: T): T {
|
|
26
|
+
if (!isCaptureEnabled()) return t
|
|
27
|
+
const g = currentGraph()
|
|
28
|
+
if (g.captures.has(name)) {
|
|
29
|
+
throw new Error(
|
|
30
|
+
`capture: name '${name}' already registered. Use unique names ` +
|
|
31
|
+
`(e.g. \`attn.\${layerIdx}\`) when capturing across a loop.`,
|
|
32
|
+
)
|
|
33
|
+
}
|
|
34
|
+
g.captures.set(name, t.id)
|
|
35
|
+
return t
|
|
36
|
+
}
|
package/src/codegen.ts
CHANGED
|
@@ -555,8 +555,10 @@ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
|
555
555
|
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.v), buf(op.g), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
556
556
|
}
|
|
557
557
|
case 'adam_update_p': {
|
|
558
|
-
// p_new = p - lrt[0] * m_new / (sqrt(v_new) + eps).
|
|
558
|
+
// p_new = decayShrink * p - lrt[0] * m_new / (sqrt(v_new) + eps).
|
|
559
559
|
// lrt is supplied per-step from CPU (already includes bias correction).
|
|
560
|
+
// decayShrink encodes AdamW's decoupled weight decay; when no decay is
|
|
561
|
+
// requested it's exactly 1.0 and the WGSL compiler folds the multiply away.
|
|
560
562
|
const out = tof(op.out)
|
|
561
563
|
const total = shapeSize(out.shape)
|
|
562
564
|
const wgsl = `
|
|
@@ -569,7 +571,7 @@ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
|
569
571
|
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
|
|
570
572
|
let i = gid.x + gid.y * 16776960u;
|
|
571
573
|
if (i >= ${total}u) { return; }
|
|
572
|
-
out[i] = p[i] - lrt[0] * mNew[i] / (sqrt(vNew[i]) + ${wgslLiteral(op.eps, 'f32')});
|
|
574
|
+
out[i] = ${wgslLiteral(op.decayShrink, 'f32')} * p[i] - lrt[0] * mNew[i] / (sqrt(vNew[i]) + ${wgslLiteral(op.eps, 'f32')});
|
|
573
575
|
}`.trim()
|
|
574
576
|
return { opIndex, opKind: op.kind, wgsl, bindings: [buf(op.p), buf(op.mNew), buf(op.vNew), buf(op.lrt), buf(op.out)], threads: total, workgroupSize: WG_SIZE }
|
|
575
577
|
}
|
package/src/compile.ts
CHANGED
|
@@ -14,7 +14,7 @@ import { appendGrad, type GradResult } from './grad.js'
|
|
|
14
14
|
import { appendAdam, type AdamConfig } from './adam.js'
|
|
15
15
|
import { planBuffers, type BufferPlan } from './buffers.js'
|
|
16
16
|
import { emitKernels, type KernelSpec } from './codegen.js'
|
|
17
|
-
import { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js'
|
|
17
|
+
import { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts } from './runtime.js'
|
|
18
18
|
import { Module, materializeParams } from './module.js'
|
|
19
19
|
|
|
20
20
|
/** Declares one input tensor of the model's forward function. Order matches
|
|
@@ -65,10 +65,19 @@ export interface CompileModuleOptions extends RuntimeOpts {
|
|
|
65
65
|
adam?: AdamConfig
|
|
66
66
|
}
|
|
67
67
|
|
|
68
|
+
export interface CompileForwardOptions extends RuntimeOpts {
|
|
69
|
+
/** Per-step data inputs to the forward function. */
|
|
70
|
+
inputs?: InputDecl[]
|
|
71
|
+
}
|
|
72
|
+
|
|
68
73
|
/**
|
|
69
|
-
* Compile a Module-based model.
|
|
70
|
-
* model
|
|
71
|
-
*
|
|
74
|
+
* Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
|
|
75
|
+
* model instance itself: compilation mutates the tree (every `ParamSentinel`
|
|
76
|
+
* field becomes a real `Tensor`), so the instance is consumed and shouldn't be
|
|
77
|
+
* referenced afterwards. Re-call the factory if you need a fresh tree.
|
|
78
|
+
*
|
|
79
|
+
* The forward function takes the materialized model and returns the loss
|
|
80
|
+
* tensor.
|
|
72
81
|
*
|
|
73
82
|
* Walks the module tree to materialize params with auto-derived names, then
|
|
74
83
|
* runs trace → grad → adam → buffer plan → codegen → runtime.
|
|
@@ -78,14 +87,15 @@ export interface CompileModuleOptions extends RuntimeOpts {
|
|
|
78
87
|
* users don't need to provide it themselves.
|
|
79
88
|
*/
|
|
80
89
|
export async function compileModule<M extends Module>(
|
|
81
|
-
|
|
90
|
+
modelFactory: () => M,
|
|
82
91
|
forward: (m: M, ...inputs: Tensor[]) => Tensor,
|
|
83
92
|
opts: CompileModuleOptions = {},
|
|
84
|
-
): Promise<CompiledRuntime & { ir: CompiledIR }> {
|
|
93
|
+
): Promise<CompiledRuntime & { ir: CompiledIR; uploadInitialParams: () => void }> {
|
|
85
94
|
const inputDecls = opts.inputs ?? []
|
|
86
|
-
|
|
95
|
+
const model = modelFactory()
|
|
96
|
+
let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {} }
|
|
87
97
|
const graph = trace(() => {
|
|
88
|
-
|
|
98
|
+
materialized = materializeParams(model)
|
|
89
99
|
const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
|
|
90
100
|
return forward(model, ...inputTensors)
|
|
91
101
|
})
|
|
@@ -94,7 +104,7 @@ export async function compileModule<M extends Module>(
|
|
|
94
104
|
|
|
95
105
|
let adamResult: ReturnType<typeof appendAdam> | undefined
|
|
96
106
|
if (opts.adam) {
|
|
97
|
-
adamResult = appendAdam(graph, paramGrads,
|
|
107
|
+
adamResult = appendAdam(graph, paramGrads, materialized.tensors, opts.adam)
|
|
98
108
|
}
|
|
99
109
|
|
|
100
110
|
const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? [])
|
|
@@ -103,18 +113,107 @@ export async function compileModule<M extends Module>(
|
|
|
103
113
|
const runtime = await createRuntime(plan, kernels, lossBufferId, opts)
|
|
104
114
|
|
|
105
115
|
// If Adam is enabled, wrap step() to track the step count and supply lrt.
|
|
116
|
+
// Wrap resetOptimizerState() too, so a reset zeros m/v *and* the bias-correction
|
|
117
|
+
// counter — otherwise the next step would skip Adam's warmup phase.
|
|
106
118
|
if (adamResult) {
|
|
107
119
|
const { lrtInputName, config } = adamResult
|
|
108
120
|
let t = 0
|
|
109
121
|
const lrtBuf = new Float32Array(1)
|
|
110
|
-
const innerStep = runtime.step.bind(runtime)
|
|
111
|
-
|
|
122
|
+
const innerStep = runtime.step.bind(runtime) as CompiledRuntime['step']
|
|
123
|
+
const innerReset = runtime.resetOptimizerState.bind(runtime)
|
|
124
|
+
const wrappedStep = (
|
|
125
|
+
inputs: Record<string, Int32Array | Float32Array>,
|
|
126
|
+
opts?: { withCaptures?: boolean },
|
|
127
|
+
): Promise<number | { loss: number; captures: Record<string, Float32Array> }> => {
|
|
112
128
|
t++
|
|
113
129
|
lrtBuf[0] = config.lr * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t))
|
|
114
|
-
|
|
130
|
+
const merged = { ...inputs, [lrtInputName]: lrtBuf }
|
|
131
|
+
return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged)
|
|
132
|
+
}
|
|
133
|
+
runtime.step = wrappedStep as CompiledRuntime['step']
|
|
134
|
+
runtime.resetOptimizerState = () => {
|
|
135
|
+
t = 0
|
|
136
|
+
innerReset()
|
|
137
|
+
}
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
const { initFns } = materialized
|
|
141
|
+
const uploadInitialParams = () => {
|
|
142
|
+
const out: Record<string, Float32Array> = {}
|
|
143
|
+
for (const [name, bufId] of plan.paramsByName) {
|
|
144
|
+
const shape = plan.buffers[bufId]!.shape
|
|
145
|
+
const size = shape.reduce((a, b) => a * b, 1)
|
|
146
|
+
const initFn = initFns[name]
|
|
147
|
+
if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
|
|
148
|
+
out[name] = initFn(size, shape)
|
|
115
149
|
}
|
|
150
|
+
runtime.uploadParams(out)
|
|
116
151
|
}
|
|
117
152
|
|
|
118
153
|
const ir: CompiledIR = { graph, paramGrads, loss, plan, kernels }
|
|
119
|
-
return Object.assign(runtime, { ir })
|
|
154
|
+
return Object.assign(runtime, { ir, uploadInitialParams })
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
// ============================================================================
|
|
158
|
+
// Forward-only compile
|
|
159
|
+
// ============================================================================
|
|
160
|
+
|
|
161
|
+
/**
|
|
162
|
+
* Compile a Module-based model in forward-only mode (no autograd, no Adam).
|
|
163
|
+
* The forward function returns the output tensor (e.g., logits) instead of a
|
|
164
|
+
* scalar loss; runtime exposes `run(inputs)` returning the full output as a
|
|
165
|
+
* `Float32Array`.
|
|
166
|
+
*
|
|
167
|
+
* **Sharing params with a training compile.** Pass `opts.sharedParams =
|
|
168
|
+
* trainCompiled.params` to bind this graph's param buffers to an existing
|
|
169
|
+
* training runtime's GPU buffers — every train step is then immediately
|
|
170
|
+
* visible to `run()` calls here, no copies. The forward graph's
|
|
171
|
+
* `uploadInitialParams()` skips any param covered by `sharedParams`.
|
|
172
|
+
*
|
|
173
|
+
* Typical use: a B=1 inference graph alongside a B=512 training graph,
|
|
174
|
+
* built from the same `Module` factory.
|
|
175
|
+
*/
|
|
176
|
+
export async function compileForward<M extends Module>(
|
|
177
|
+
modelFactory: () => M,
|
|
178
|
+
forward: (m: M, ...inputs: Tensor[]) => Tensor,
|
|
179
|
+
opts: CompileForwardOptions = {},
|
|
180
|
+
): Promise<CompiledForward & { ir: CompiledIR; uploadInitialParams: () => void }> {
|
|
181
|
+
const inputDecls = opts.inputs ?? []
|
|
182
|
+
const model = modelFactory()
|
|
183
|
+
let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {} }
|
|
184
|
+
const graph = trace(() => {
|
|
185
|
+
materialized = materializeParams(model)
|
|
186
|
+
const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
|
|
187
|
+
return forward(model, ...inputTensors)
|
|
188
|
+
})
|
|
189
|
+
|
|
190
|
+
const plan = planBuffers(graph, /* paramGrads */ {})
|
|
191
|
+
const kernels = emitKernels(graph, plan)
|
|
192
|
+
const outputTensor = graph.tensors[graph.outputs[0]!]!
|
|
193
|
+
const outputBufferId = plan.tensorToBuffer.get(outputTensor.id)!
|
|
194
|
+
const runtime = await createForwardRuntime(plan, kernels, outputBufferId, opts)
|
|
195
|
+
|
|
196
|
+
const sharedParams = opts.sharedParams
|
|
197
|
+
const { initFns } = materialized
|
|
198
|
+
const uploadInitialParams = () => {
|
|
199
|
+
const out: Record<string, Float32Array> = {}
|
|
200
|
+
let needsUpload = false
|
|
201
|
+
for (const [name, bufId] of plan.paramsByName) {
|
|
202
|
+
// Skip params covered by sharedParams — those are owned by the providing
|
|
203
|
+
// compile and already initialized there.
|
|
204
|
+
if (sharedParams?.has(name)) continue
|
|
205
|
+
const shape = plan.buffers[bufId]!.shape
|
|
206
|
+
const size = shape.reduce((a, b) => a * b, 1)
|
|
207
|
+
const initFn = initFns[name]
|
|
208
|
+
if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
|
|
209
|
+
out[name] = initFn(size, shape)
|
|
210
|
+
needsUpload = true
|
|
211
|
+
}
|
|
212
|
+
if (needsUpload) runtime.uploadParams(out, { partial: !!sharedParams })
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
// CompiledIR.loss is the field name; for forward-only, it carries the user's
|
|
216
|
+
// returned tensor (e.g., logits). Same shape conceptually; just no autograd.
|
|
217
|
+
const ir: CompiledIR = { graph, paramGrads: {}, loss: outputTensor, plan, kernels }
|
|
218
|
+
return Object.assign(runtime, { ir, uploadInitialParams })
|
|
120
219
|
}
|
package/src/index.ts
CHANGED
|
@@ -6,6 +6,7 @@
|
|
|
6
6
|
export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js'
|
|
7
7
|
export { ShapeError } from './shape.js'
|
|
8
8
|
export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js'
|
|
9
|
+
export { capture } from './capture.js'
|
|
9
10
|
export {
|
|
10
11
|
// Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
|
|
11
12
|
add, sub, mul, div,
|
|
@@ -35,6 +36,7 @@ export { appendGrad, type GradResult } from './grad.js'
|
|
|
35
36
|
export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
|
|
36
37
|
export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
|
|
37
38
|
export { emitKernels, type KernelSpec } from './codegen.js'
|
|
38
|
-
export { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js'
|
|
39
|
-
export { compile, compileToIR, compileModule, type CompiledIR, type CompileModuleOptions, type InputDecl } from './compile.js'
|
|
40
|
-
export { Module, materializeParams } from './module.js'
|
|
39
|
+
export { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type StepOptions, type StepWithCaptures, type RunOptions, type RunWithCaptures } from './runtime.js'
|
|
40
|
+
export { compile, compileToIR, compileModule, compileForward, type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type InputDecl } from './compile.js'
|
|
41
|
+
export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
|
|
42
|
+
export * as nn from './nn.js'
|
package/src/ir.ts
CHANGED
|
@@ -109,11 +109,13 @@ export type OpNode =
|
|
|
109
109
|
// update into ~12 element-wise dispatches per param.
|
|
110
110
|
| { kind: 'adam_update_m'; out: number; m: number; g: number; b1: number }
|
|
111
111
|
| { kind: 'adam_update_v'; out: number; v: number; g: number; b2: number }
|
|
112
|
-
// adam_update_p: p_new = p - lrt[0] * m_new / (sqrt(v_new) + eps).
|
|
112
|
+
// adam_update_p: p_new = decayShrink * p - lrt[0] * m_new / (sqrt(v_new) + eps).
|
|
113
113
|
// `lrt` is a scalar tensor (provided as a tensor_input updated per step) that
|
|
114
114
|
// already includes Adam's bias-correction factor: lrt = lr * sqrt(1-b2^t) / (1-b1^t).
|
|
115
|
-
//
|
|
116
|
-
|
|
115
|
+
// `decayShrink` is the decoupled-weight-decay factor (Loshchilov & Hutter,
|
|
116
|
+
// "AdamW") baked at compile time: 1 - lr * weightDecay when the param is being
|
|
117
|
+
// decayed, 1 otherwise. eps and decayShrink are both baked into the kernel.
|
|
118
|
+
| { kind: 'adam_update_p'; out: number; p: number; mNew: number; vNew: number; lrt: number; eps: number; decayShrink: number }
|
|
117
119
|
|
|
118
120
|
// ---- Slicing / broadcasting / autograd infrastructure -------------------
|
|
119
121
|
// Slice [start, end) along the last axis. Output shape: input shape with
|
|
@@ -139,10 +141,14 @@ export interface Graph {
|
|
|
139
141
|
// Names of tensors that should be exposed as outputs of the compiled function.
|
|
140
142
|
// Set by the trace driver; for a loss function, this is `[lossTensor]`.
|
|
141
143
|
readonly outputs: number[]
|
|
144
|
+
// Tensors registered for activation readback via `capture(name, t)`.
|
|
145
|
+
// Keyed by user-supplied name; insertion order preserved. Empty when no
|
|
146
|
+
// captures registered (the common training case — zero overhead).
|
|
147
|
+
readonly captures: Map<string, number>
|
|
142
148
|
}
|
|
143
149
|
|
|
144
150
|
export function makeGraph(): Graph {
|
|
145
|
-
return { ops: [], tensors: [], outputs: [] }
|
|
151
|
+
return { ops: [], tensors: [], outputs: [], captures: new Map() }
|
|
146
152
|
}
|
|
147
153
|
|
|
148
154
|
// Internal: register a fresh tensor in the graph and return its id.
|
package/src/module.ts
CHANGED
|
@@ -6,8 +6,8 @@
|
|
|
6
6
|
// W: Tensor; b: Tensor
|
|
7
7
|
// constructor(inDim: number, outDim: number) {
|
|
8
8
|
// super()
|
|
9
|
-
// this.W = this.param([inDim, outDim])
|
|
10
|
-
// this.b = this.param([outDim])
|
|
9
|
+
// this.W = this.param([inDim, outDim]) // randn, scale 0.02
|
|
10
|
+
// this.b = this.param([outDim], { init: 'zeros' })
|
|
11
11
|
// }
|
|
12
12
|
// }
|
|
13
13
|
// class Block extends Module {
|
|
@@ -28,6 +28,54 @@
|
|
|
28
28
|
import type { Tensor, Shape, Dtype } from './ir.js'
|
|
29
29
|
import { paramInput } from './trace.js'
|
|
30
30
|
|
|
31
|
+
// ============================================================================
|
|
32
|
+
// Init metadata
|
|
33
|
+
// ============================================================================
|
|
34
|
+
|
|
35
|
+
/** How a parameter's initial values are produced.
|
|
36
|
+
* - `'randn'` — Gaussian, with `scale` (default 0.02). The common case for
|
|
37
|
+
* weight matrices and embeddings.
|
|
38
|
+
* - `'zeros'` — fill with 0. Common for biases and LayerNorm beta.
|
|
39
|
+
* - `'ones'` — fill with 1. Common for LayerNorm gain.
|
|
40
|
+
* - Custom function — receives total element count and shape, returns the
|
|
41
|
+
* Float32Array. Use for fan-in scaling or any non-standard scheme.
|
|
42
|
+
*/
|
|
43
|
+
export type InitSpec =
|
|
44
|
+
| 'randn'
|
|
45
|
+
| 'zeros'
|
|
46
|
+
| 'ones'
|
|
47
|
+
| ((size: number, shape: readonly number[]) => Float32Array)
|
|
48
|
+
|
|
49
|
+
export interface ParamOptions {
|
|
50
|
+
dtype?: Dtype
|
|
51
|
+
/** Init kind. Default: `'randn'`. */
|
|
52
|
+
init?: InitSpec
|
|
53
|
+
/** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
|
|
54
|
+
scale?: number
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
type InitFn = (size: number, shape: readonly number[]) => Float32Array
|
|
58
|
+
|
|
59
|
+
function boxMuller(): number {
|
|
60
|
+
return Math.sqrt(-2 * Math.log(Math.max(1e-10, Math.random()))) * Math.cos(2 * Math.PI * Math.random())
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
function resolveInit(opts: ParamOptions | undefined): InitFn {
|
|
64
|
+
const init = opts?.init ?? 'randn'
|
|
65
|
+
if (init === 'randn') {
|
|
66
|
+
const scale = opts?.scale ?? 0.02
|
|
67
|
+
return (size) => {
|
|
68
|
+
const arr = new Float32Array(size)
|
|
69
|
+
for (let i = 0; i < size; i++) arr[i] = boxMuller() * scale
|
|
70
|
+
return arr
|
|
71
|
+
}
|
|
72
|
+
}
|
|
73
|
+
if (init === 'zeros') return (size) => new Float32Array(size)
|
|
74
|
+
if (init === 'ones') return (size) => { const a = new Float32Array(size); a.fill(1); return a }
|
|
75
|
+
if (typeof init === 'function') return init
|
|
76
|
+
throw new Error(`Unknown init: ${String(init)}`)
|
|
77
|
+
}
|
|
78
|
+
|
|
31
79
|
// ============================================================================
|
|
32
80
|
// Internals: param sentinel
|
|
33
81
|
// ============================================================================
|
|
@@ -38,7 +86,11 @@ import { paramInput } from './trace.js'
|
|
|
38
86
|
// only valid post-materialization (which is always before forward runs).
|
|
39
87
|
|
|
40
88
|
class ParamSentinel {
|
|
41
|
-
constructor(
|
|
89
|
+
constructor(
|
|
90
|
+
public readonly shape: Shape,
|
|
91
|
+
public readonly dtype: Dtype,
|
|
92
|
+
public readonly initFn: InitFn,
|
|
93
|
+
) {}
|
|
42
94
|
}
|
|
43
95
|
|
|
44
96
|
// ============================================================================
|
|
@@ -52,11 +104,13 @@ export abstract class Module {
|
|
|
52
104
|
* that gets replaced with a real Tensor at compile time.
|
|
53
105
|
*
|
|
54
106
|
* The parameter's name is auto-derived from its property path in the model
|
|
55
|
-
* tree (e.g. `layers.0.attn.W_q`).
|
|
107
|
+
* tree (e.g. `layers.0.attn.W_q`). Init metadata travels with the param;
|
|
108
|
+
* call `compiled.uploadInitialParams()` to apply it after compile.
|
|
56
109
|
*/
|
|
57
|
-
protected param(shape: Shape,
|
|
110
|
+
protected param(shape: Shape, opts?: ParamOptions): Tensor {
|
|
111
|
+
const dtype = opts?.dtype ?? 'f32'
|
|
58
112
|
// Lie to TypeScript: the sentinel becomes a Tensor at materialize time.
|
|
59
|
-
return new ParamSentinel(shape, dtype) as unknown as Tensor
|
|
113
|
+
return new ParamSentinel(shape, dtype, resolveInit(opts)) as unknown as Tensor
|
|
60
114
|
}
|
|
61
115
|
}
|
|
62
116
|
|
|
@@ -64,23 +118,33 @@ export abstract class Module {
|
|
|
64
118
|
// Tree walking
|
|
65
119
|
// ============================================================================
|
|
66
120
|
|
|
121
|
+
export interface MaterializedParams {
|
|
122
|
+
/** Map from auto-derived path (e.g. `layers.0.attn.W_q`) to its Tensor. */
|
|
123
|
+
tensors: Record<string, Tensor>
|
|
124
|
+
/** Init function per param path. Used by `uploadInitialParams`. */
|
|
125
|
+
initFns: Record<string, InitFn>
|
|
126
|
+
}
|
|
127
|
+
|
|
67
128
|
/**
|
|
68
129
|
* Walk the module tree and replace every ParamSentinel with a real Tensor
|
|
69
130
|
* created via `paramInput(autoName, ...)`. Must be called inside an active
|
|
70
131
|
* trace context (paramInput appends to the current graph).
|
|
71
132
|
*
|
|
72
|
-
* Returns
|
|
133
|
+
* Returns the param tensors keyed by path, plus init functions for use by
|
|
134
|
+
* `uploadInitialParams`.
|
|
73
135
|
*/
|
|
74
|
-
export function materializeParams(root: Module):
|
|
75
|
-
const
|
|
136
|
+
export function materializeParams(root: Module): MaterializedParams {
|
|
137
|
+
const tensors: Record<string, Tensor> = {}
|
|
138
|
+
const initFns: Record<string, InitFn> = {}
|
|
76
139
|
visit(root, '', (path, val, owner, key) => {
|
|
77
140
|
if (val instanceof ParamSentinel) {
|
|
78
141
|
const t = paramInput(path, val.shape, val.dtype)
|
|
79
142
|
;(owner as any)[key] = t
|
|
80
|
-
|
|
143
|
+
tensors[path] = t
|
|
144
|
+
initFns[path] = val.initFn
|
|
81
145
|
}
|
|
82
146
|
})
|
|
83
|
-
return
|
|
147
|
+
return { tensors, initFns }
|
|
84
148
|
}
|
|
85
149
|
|
|
86
150
|
// ----------------------------------------------------------------------------
|
package/src/nn.ts
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
// Standard "batteries-included" Module subclasses for the most common layers.
|
|
2
|
+
//
|
|
3
|
+
// JAX-style: each class declares its params (and their init); the forward is a
|
|
4
|
+
// plain function the user calls with `(module, x)`. No subclassing, no method
|
|
5
|
+
// dispatch — keeps the autograd-traced computation visible at the call site.
|
|
6
|
+
//
|
|
7
|
+
// Import as a namespace:
|
|
8
|
+
//
|
|
9
|
+
// import { nn } from 'tensorgrad'
|
|
10
|
+
// class Block extends Module {
|
|
11
|
+
// ln = new nn.LayerNorm(D)
|
|
12
|
+
// ffn = new nn.Linear(D, 4 * D)
|
|
13
|
+
// }
|
|
14
|
+
// const y = nn.linearFwd(p.ffn, nn.layerNormFwd(p.ln, x))
|
|
15
|
+
|
|
16
|
+
import { Module } from './module.js'
|
|
17
|
+
import type { Tensor } from './ir.js'
|
|
18
|
+
import { add, matmul, sub, mul, div, sqrt, meanLast } from './ops.js'
|
|
19
|
+
|
|
20
|
+
// ----------------------------------------------------------------------------
|
|
21
|
+
// Linear: y = x @ W (+ b)
|
|
22
|
+
// ----------------------------------------------------------------------------
|
|
23
|
+
|
|
24
|
+
export class Linear extends Module {
|
|
25
|
+
W: Tensor
|
|
26
|
+
b: Tensor | null
|
|
27
|
+
constructor(public readonly inDim: number, public readonly outDim: number, withBias = true) {
|
|
28
|
+
super()
|
|
29
|
+
this.W = this.param([inDim, outDim]) // randn, scale 0.02
|
|
30
|
+
this.b = withBias ? this.param([outDim], { init: 'zeros' }) : null
|
|
31
|
+
}
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
export function linearFwd(p: Linear, x: Tensor): Tensor {
|
|
35
|
+
const out = matmul(x, p.W)
|
|
36
|
+
return p.b ? add(out, p.b) : out
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
// ----------------------------------------------------------------------------
|
|
40
|
+
// LayerNorm — normalizes over the last axis. eps defaults to 1e-5.
|
|
41
|
+
// ----------------------------------------------------------------------------
|
|
42
|
+
|
|
43
|
+
export class LayerNorm extends Module {
|
|
44
|
+
g: Tensor
|
|
45
|
+
b: Tensor
|
|
46
|
+
constructor(public readonly d: number, public readonly eps: number = 1e-5) {
|
|
47
|
+
super()
|
|
48
|
+
this.g = this.param([d], { init: 'ones' })
|
|
49
|
+
this.b = this.param([d], { init: 'zeros' })
|
|
50
|
+
}
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
export function layerNormFwd(p: LayerNorm, x: Tensor): Tensor {
|
|
54
|
+
const m = meanLast(x)
|
|
55
|
+
const c = sub(x, m)
|
|
56
|
+
const v = meanLast(mul(c, c))
|
|
57
|
+
const stdev = sqrt(add(v, p.eps))
|
|
58
|
+
return add(mul(div(c, stdev), p.g), p.b)
|
|
59
|
+
}
|
package/src/ops.ts
CHANGED
|
@@ -297,7 +297,7 @@ export function adamUpdateV(v: Tensor, g: Tensor, b2: number): Tensor {
|
|
|
297
297
|
return addOp(currentGraph(), 'adam_update_v', v.shape, 'f32', site, { v: v.id, g: g.id, b2 })
|
|
298
298
|
}
|
|
299
299
|
|
|
300
|
-
export function adamUpdateP(p: Tensor, mNew: Tensor, vNew: Tensor, lrt: Tensor, eps: number): Tensor {
|
|
300
|
+
export function adamUpdateP(p: Tensor, mNew: Tensor, vNew: Tensor, lrt: Tensor, eps: number, decayShrink: number = 1): Tensor {
|
|
301
301
|
const site = captureSite('adamUpdateP')
|
|
302
302
|
if (p.dtype !== 'f32') throw new ShapeError(`adamUpdateP: requires f32`, site)
|
|
303
303
|
if (lrt.dtype !== 'f32' || lrt.shape.length !== 0) {
|
|
@@ -307,5 +307,5 @@ export function adamUpdateP(p: Tensor, mNew: Tensor, vNew: Tensor, lrt: Tensor,
|
|
|
307
307
|
throw new ShapeError(`adamUpdateP: p/mNew shape mismatch`, site)
|
|
308
308
|
}
|
|
309
309
|
return addOp(currentGraph(), 'adam_update_p', p.shape, 'f32', site,
|
|
310
|
-
{ p: p.id, mNew: mNew.id, vNew: vNew.id, lrt: lrt.id, eps })
|
|
310
|
+
{ p: p.id, mNew: mNew.id, vNew: vNew.id, lrt: lrt.id, eps, decayShrink })
|
|
311
311
|
}
|