tensorgrad 0.0.2 → 0.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +7 -9
- package/dist/adam.d.ts +22 -5
- package/dist/adam.d.ts.map +1 -1
- package/dist/adam.js +42 -10
- package/dist/adam.js.map +1 -1
- package/dist/buffers.d.ts +1 -0
- package/dist/buffers.d.ts.map +1 -1
- package/dist/buffers.js +12 -1
- package/dist/buffers.js.map +1 -1
- package/dist/capture.d.ts +3 -0
- package/dist/capture.d.ts.map +1 -0
- package/dist/capture.js +33 -0
- package/dist/capture.js.map +1 -0
- package/dist/codegen.js +16 -5
- package/dist/codegen.js.map +1 -1
- package/dist/compile.d.ts +33 -5
- package/dist/compile.d.ts.map +1 -1
- package/dist/compile.js +106 -14
- package/dist/compile.js.map +1 -1
- package/dist/index.d.ts +5 -3
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +4 -2
- package/dist/index.js.map +1 -1
- package/dist/ir.d.ts +2 -0
- package/dist/ir.d.ts.map +1 -1
- package/dist/ir.js +1 -1
- package/dist/ir.js.map +1 -1
- package/dist/module.d.ts +30 -4
- package/dist/module.d.ts.map +1 -1
- package/dist/module.js +39 -13
- package/dist/module.js.map +1 -1
- package/dist/nn.d.ts +19 -0
- package/dist/nn.d.ts.map +1 -0
- package/dist/nn.js +60 -0
- package/dist/nn.js.map +1 -0
- package/dist/ops.d.ts +1 -1
- package/dist/ops.d.ts.map +1 -1
- package/dist/ops.js +18 -1
- package/dist/ops.js.map +1 -1
- package/dist/runtime.d.ts +79 -4
- package/dist/runtime.d.ts.map +1 -1
- package/dist/runtime.js +153 -19
- package/dist/runtime.js.map +1 -1
- package/dist/trace.d.ts +1 -0
- package/dist/trace.d.ts.map +1 -1
- package/dist/trace.js +12 -0
- package/dist/trace.js.map +1 -1
- package/package.json +1 -2
- package/src/adam.ts +65 -14
- package/src/buffers.ts +14 -1
- package/src/capture.ts +36 -0
- package/src/codegen.ts +16 -5
- package/src/compile.ts +122 -16
- package/src/index.ts +5 -3
- package/src/ir.ts +20 -4
- package/src/module.ts +75 -11
- package/src/nn.ts +59 -0
- package/src/ops.ts +26 -3
- package/src/runtime.ts +260 -22
- package/src/trace.ts +13 -0
- package/SPEC.md +0 -293
package/src/adam.ts
CHANGED
@@ -15,6 +15,12 @@
 // `sqrt(1-b2^t)/(1-b1^t)`; that's why convergence isn't affected by the
 // first-step warmup that bias-correction-free Adam suffers.
 //
+// **Static vs scheduled lr.** When `config.lr` is a number, decayShrink is
+// baked into the kernel as a literal. When it's a function `(step) => lr`,
+// decayShrink for decayed params becomes a per-step scalar input that the
+// runtime updates each call (computed from the current step's lr). lrt is
+// always per-step; the bias-correction factor changes every step regardless.
+//
 // Returns writeback declarations the buffer planner uses to wire up the
 // "after step, copy the new value into the persistent home" path. m and v
 // are state_inputs (zero-initialized, persistent across steps); the param
@@ -27,7 +33,11 @@ import { traceInto, stateInput, tensorInput } from './trace.js'
 import { adamUpdateM, adamUpdateV, adamUpdateP } from './ops.js'
 
 export interface AdamConfig {
-
+  /** Constant scalar (e.g., `0.005`) or a per-step schedule function
+   * `(step) => lr`. Schedule fn lets the user implement linear/cosine decay
+   * or warmup; first call passes `step=1`. Decay-shrink (AdamW) updates
+   * per-step automatically when this is a function. */
+  lr: number | ((step: number) => number)
   b1?: number // default 0.9
   b2?: number // default 0.999
   eps?: number // default 1e-8
@@ -42,14 +52,30 @@ export interface AdamConfig {
   decayFilter?: (paramName: string) => boolean
 }
 
+/** Resolved hyperparameters: lr is the schedule fn (constants are wrapped). */
+export interface AdamResolvedConfig {
+  lr: (step: number) => number
+  b1: number
+  b2: number
+  eps: number
+  weightDecay: number
+  decayFilter: (name: string) => boolean
+  /** True iff the user supplied an lr function (vs a constant). When false,
+   * decayShrink is baked at compile time and never updated. */
+  lrIsScheduled: boolean
+}
+
 export interface AdamResult {
   /** Writebacks the buffer planner should wire into the runtime. */
   writebacks: WritebackDecl[]
   /** Name of the per-step scalar tensor_input. The runtime fills this each call
    * with `lr * sqrt(1-b2^t)/(1-b1^t)` (Adam's bias-corrected effective LR). */
   lrtInputName: string
-  /**
-
+  /** Name of the per-step decayShrink scalar tensor_input, or null when lr is
+   * static (decayShrink baked into the kernel) or no params are decayed. */
+  decayShrinkInputName: string | null
+  /** Hyperparameters as captured (so the runtime can compute lrt and decayShrink). */
+  config: AdamResolvedConfig
 }
 
 /**
@@ -70,22 +96,40 @@ export function appendAdam(
   paramTensors: Record<string, Tensor>,
   config: AdamConfig,
 ): AdamResult {
-  const
-
+  const lrIsScheduled = typeof config.lr === 'function'
+  const lrFn = lrIsScheduled
+    ? config.lr as (step: number) => number
+    : (() => config.lr as number)
+  const initialLr = lrFn(1)
+  const fullConfig: AdamResolvedConfig = {
+    lr: lrFn,
     b1: config.b1 ?? 0.9,
     b2: config.b2 ?? 0.999,
     eps: config.eps ?? 1e-8,
     weightDecay: config.weightDecay ?? 0,
     decayFilter: config.decayFilter ?? (() => true),
+    lrIsScheduled,
   }
   const writebacks: WritebackDecl[] = []
   const lrtInputName = '_adam_lrt'
+  // Tensor input for runtime-updated decayShrink (only created when lr is a
+  // schedule fn AND at least one param will receive weight decay).
+  let decayShrinkInputName: string | null = null
 
   return traceInto(graph, () => {
-    // One scalar lrt input shared by every adam_update_p call. Runtime supplies
-    // it per step as `lr * sqrt(1-b2^t) / (1-b1^t)`.
     const lrt = tensorInput(lrtInputName, [], 'f32')
 
+    // Decide up-front whether we need a runtime decayShrink scalar. Only does
+    // something when both (a) lr varies per step and (b) some param is decayed.
+    const needsDynamicShrink = lrIsScheduled
+      && fullConfig.weightDecay > 0
+      && Object.keys(paramGrads).some(name => fullConfig.decayFilter(name))
+    let decayShrinkScalar: Tensor | null = null
+    if (needsDynamicShrink) {
+      decayShrinkInputName = '_adam_decay_shrink'
+      decayShrinkScalar = tensorInput(decayShrinkInputName, [], 'f32')
+    }
+
     for (const name of Object.keys(paramGrads)) {
       const p = paramTensors[name]
       const g = paramGrads[name]
@@ -95,12 +139,19 @@ export function appendAdam(
       const mState = stateInput(`adam_m_${name}`, p.shape, 'f32', 0)
       const vState = stateInput(`adam_v_${name}`, p.shape, 'f32', 0)
 
-      //
-      //
-      //
-
-
-
+      // Choose the decayShrink form per param:
+      // - non-decayed params: literal 1 (kernel multiply folds out).
+      // - decayed + static lr: literal `1 - lr * wd` baked at compile.
+      // - decayed + scheduled lr: tensor input updated per step.
+      const isDecayed = fullConfig.weightDecay > 0 && fullConfig.decayFilter(name)
+      let decayShrink: number | Tensor
+      if (!isDecayed) {
+        decayShrink = 1
+      } else if (decayShrinkScalar !== null) {
+        decayShrink = decayShrinkScalar
+      } else {
+        decayShrink = 1 - initialLr * fullConfig.weightDecay
+      }
 
       // Three fused kernels per parameter — one for each of m_new / v_new / p_new.
       const newM = adamUpdateM(mState, g, fullConfig.b1)
@@ -111,6 +162,6 @@
       writebacks.push({ source: newV, destName: `adam_v_${name}`, destKind: 'state' })
       writebacks.push({ source: newP, destName: name, destKind: 'param' })
     }
-    return { writebacks, lrtInputName, config: fullConfig }
+    return { writebacks, lrtInputName, decayShrinkInputName, config: fullConfig }
   })
 }
package/src/buffers.ts
CHANGED
@@ -47,6 +47,7 @@ export interface BufferPlan {
   inputsByName: Map<string, number> // name -> buffer id
   paramGradsByName: Map<string, number> // name -> buffer id
   statesByName: Map<string, number> // name -> buffer id (persistent state homes)
+  capturesByName: Map<string, number> // name -> buffer id (activation captures)
   outputBufferIds: number[] // graph.outputs mapped through
   /** End-of-step writebacks (Adam updates for params, m, v, etc.) */
   writebacks: Writeback[]
@@ -169,5 +170,17 @@ export function planBuffers(
     return { source: sourceBufId, dest: destBufId, bytes: sourceSpec.byteSize }
   })
 
-
+  // Resolve graph.captures (name -> tensor id) to (name -> buffer id).
+  // No pinning needed at the planner level: each tensor already has its own
+  // buffer (see "v1 strategy" comment at top — no pooling yet).
+  const capturesByName = new Map<string, number>()
+  for (const [name, tensorId] of graph.captures) {
+    const bufId = tensorToBuffer.get(tensorId)
+    if (bufId === undefined) {
+      throw new Error(`planBuffers: capture '${name}' references unknown tensor #${tensorId}`)
+    }
+    capturesByName.set(name, bufId)
+  }
+
+  return { buffers, tensorToBuffer, paramsByName, inputsByName, paramGradsByName, statesByName, capturesByName, outputBufferIds, writebacks }
 }
package/src/capture.ts
ADDED
@@ -0,0 +1,36 @@
+// Activation capture — opt-in readback of intermediate tensors at training step.
+//
+// Usage (inside the user's forward pass):
+//
+//   import { capture } from 'tensorgrad'
+//
+//   function attentionFwd(p, x) {
+//     const scores = mul(matmulBatched(q, kT), SCALE_QK)
+//     const attn = capture(`attn.${layerIdx}`, softmaxCausalLast(scores))
+//     return matmulBatched(attn, v)
+//   }
+//
+// Pass-through return type: `capture(name, t)` returns `t` unchanged so it
+// inlines at the point of computation. Behind the scenes it registers `t.id`
+// against `name` on the current graph; runtime exposes the registered tensors
+// via `step(inputs, { withCaptures: true })`.
+//
+// Outside the user's forward trace (during `appendGrad` / `appendAdam`'s
+// `traceInto` re-entry), `capture()` is a no-op — gradient and optimizer
+// internals shouldn't accidentally publish themselves to the UI.
+
+import type { Tensor } from './ir.js'
+import { currentGraph, isCaptureEnabled } from './trace.js'
+
+export function capture<T extends Tensor>(name: string, t: T): T {
+  if (!isCaptureEnabled()) return t
+  const g = currentGraph()
+  if (g.captures.has(name)) {
+    throw new Error(
+      `capture: name '${name}' already registered. Use unique names ` +
+      `(e.g. \`attn.\${layerIdx}\`) when capturing across a loop.`,
+    )
+  }
+  g.captures.set(name, t.id)
+  return t
+}
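
For context, a hedged sketch of the consumer side of `capture()`. The `step(inputs, { withCaptures: true })` shape is taken from the wrapper added in compile.ts below; the input names and the `'attn.0'` key are illustrative.

```ts
// Assumes `compiled` came from compileModule(...) and the forward pass called
// capture('attn.0', ...) somewhere; input names are examples only.
const result = await compiled.step(
  { tokens: batchTokens, targets: batchTargets },
  { withCaptures: true },
)
if (typeof result !== 'number') {
  const attn0: Float32Array = result.captures['attn.0']
  console.log('loss =', result.loss, 'attn.0 head:', attn0.slice(0, 4))
}
```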
package/src/codegen.ts
CHANGED
@@ -557,23 +557,34 @@ fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
     case 'adam_update_p': {
       // p_new = decayShrink * p - lrt[0] * m_new / (sqrt(v_new) + eps).
       // lrt is supplied per-step from CPU (already includes bias correction).
-      // decayShrink
-      //
+      // decayShrink is either baked as a literal (no schedule, fixed lr) or
+      // bound as a per-step scalar input (when the user supplies an lr
+      // schedule via `adam: { lr: (step) => ... }`). When literal=1 the WGSL
+      // compiler folds the multiply away.
       const out = tof(op.out)
       const total = shapeSize(out.shape)
+      const dynamicShrink = op.decayShrinkTensor !== null
+      const shrinkExpr = dynamicShrink ? 'decayShrink[0]' : wgslLiteral(op.decayShrink, 'f32')
+      const shrinkBinding = dynamicShrink
+        ? `@group(0) @binding(4) var<storage, read> decayShrink : array<f32>;\n` +
+          `@group(0) @binding(5) var<storage, read_write> out : array<f32>;`
+        : `@group(0) @binding(4) var<storage, read_write> out : array<f32>;`
       const wgsl = `
 @group(0) @binding(0) var<storage, read> p : array<f32>;
 @group(0) @binding(1) var<storage, read> mNew : array<f32>;
 @group(0) @binding(2) var<storage, read> vNew : array<f32>;
 @group(0) @binding(3) var<storage, read> lrt : array<f32>;
-
+${shrinkBinding}
 @compute @workgroup_size(${WG_SIZE})
 fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
   let i = gid.x + gid.y * 16776960u;
   if (i >= ${total}u) { return; }
-  out[i] = ${
+  out[i] = ${shrinkExpr} * p[i] - lrt[0] * mNew[i] / (sqrt(vNew[i]) + ${wgslLiteral(op.eps, 'f32')});
 }`.trim()
-
+      const bindings = dynamicShrink
+        ? [buf(op.p), buf(op.mNew), buf(op.vNew), buf(op.lrt), buf(op.decayShrinkTensor!), buf(op.out)]
+        : [buf(op.p), buf(op.mNew), buf(op.vNew), buf(op.lrt), buf(op.out)]
+      return { opIndex, opKind: op.kind, wgsl, bindings, threads: total, workgroupSize: WG_SIZE }
     }
 
     case 'sum_to_shape': {
package/src/compile.ts
CHANGED
@@ -14,7 +14,7 @@ import { appendGrad, type GradResult } from './grad.js'
 import { appendAdam, type AdamConfig } from './adam.js'
 import { planBuffers, type BufferPlan } from './buffers.js'
 import { emitKernels, type KernelSpec } from './codegen.js'
-import { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js'
+import { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts } from './runtime.js'
 import { Module, materializeParams } from './module.js'
 
 /** Declares one input tensor of the model's forward function. Order matches
@@ -65,10 +65,19 @@ export interface CompileModuleOptions extends RuntimeOpts {
   adam?: AdamConfig
 }
 
+export interface CompileForwardOptions extends RuntimeOpts {
+  /** Per-step data inputs to the forward function. */
+  inputs?: InputDecl[]
+}
+
 /**
- * Compile a Module-based model.
- * model
- *
+ * Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
+ * model instance itself: compilation mutates the tree (every `ParamSentinel`
+ * field becomes a real `Tensor`), so the instance is consumed and shouldn't be
+ * referenced afterwards. Re-call the factory if you need a fresh tree.
+ *
+ * The forward function takes the materialized model and returns the loss
+ * tensor.
 *
 * Walks the module tree to materialize params with auto-derived names, then
 * runs trace → grad → adam → buffer plan → codegen → runtime.
@@ -78,14 +87,15 @@ export interface CompileModuleOptions extends RuntimeOpts {
 * users don't need to provide it themselves.
 */
 export async function compileModule<M extends Module>(
-
+  modelFactory: () => M,
   forward: (m: M, ...inputs: Tensor[]) => Tensor,
   opts: CompileModuleOptions = {},
-): Promise<CompiledRuntime & { ir: CompiledIR }> {
+): Promise<CompiledRuntime & { ir: CompiledIR; uploadInitialParams: () => void }> {
   const inputDecls = opts.inputs ?? []
-
+  const model = modelFactory()
+  let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {} }
   const graph = trace(() => {
-
+    materialized = materializeParams(model)
    const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
    return forward(model, ...inputTensors)
  })
@@ -94,7 +104,7 @@ export async function compileModule<M extends Module>(
 
   let adamResult: ReturnType<typeof appendAdam> | undefined
   if (opts.adam) {
-    adamResult = appendAdam(graph, paramGrads,
+    adamResult = appendAdam(graph, paramGrads, materialized.tensors, opts.adam)
   }
 
   const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? [])
@@ -102,19 +112,115 @@ export async function compileModule<M extends Module>(
   const lossBufferId = plan.tensorToBuffer.get(loss.id)!
   const runtime = await createRuntime(plan, kernels, lossBufferId, opts)
 
-  // If Adam is enabled, wrap step() to track the step count and supply lrt
+  // If Adam is enabled, wrap step() to track the step count and supply lrt
+  // (and optionally decayShrink, when the user passed a per-step lr schedule).
+  // Wrap resetOptimizerState() too, so a reset zeros m/v *and* the bias-correction
+  // counter — otherwise the next step would skip Adam's warmup phase.
   if (adamResult) {
-    const { lrtInputName, config } = adamResult
+    const { lrtInputName, decayShrinkInputName, config } = adamResult
     let t = 0
     const lrtBuf = new Float32Array(1)
-    const
-    runtime.step
+    const decayShrinkBuf = decayShrinkInputName ? new Float32Array(1) : null
+    const innerStep = runtime.step.bind(runtime) as CompiledRuntime['step']
+    const innerReset = runtime.resetOptimizerState.bind(runtime)
+    const wrappedStep = (
+      inputs: Record<string, Int32Array | Float32Array>,
+      opts?: { withCaptures?: boolean },
+    ): Promise<number | { loss: number; captures: Record<string, Float32Array> }> => {
       t++
-
-
+      const lrNow = config.lr(t)
+      lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t))
+      const merged: Record<string, Int32Array | Float32Array> = { ...inputs, [lrtInputName]: lrtBuf }
+      if (decayShrinkBuf && decayShrinkInputName) {
+        decayShrinkBuf[0] = 1 - lrNow * config.weightDecay
+        merged[decayShrinkInputName] = decayShrinkBuf
+      }
+      return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged)
+    }
+    runtime.step = wrappedStep as CompiledRuntime['step']
+    runtime.resetOptimizerState = () => {
+      t = 0
+      innerReset()
+    }
+  }
+
+  const { initFns } = materialized
+  const uploadInitialParams = () => {
+    const out: Record<string, Float32Array> = {}
+    for (const [name, bufId] of plan.paramsByName) {
+      const shape = plan.buffers[bufId]!.shape
+      const size = shape.reduce((a, b) => a * b, 1)
+      const initFn = initFns[name]
+      if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
+      out[name] = initFn(size, shape)
     }
+    runtime.uploadParams(out)
   }
 
   const ir: CompiledIR = { graph, paramGrads, loss, plan, kernels }
-  return Object.assign(runtime, { ir })
+  return Object.assign(runtime, { ir, uploadInitialParams })
+}
+
+// ============================================================================
+// Forward-only compile
+// ============================================================================
+
+/**
+ * Compile a Module-based model in forward-only mode (no autograd, no Adam).
+ * The forward function returns the output tensor (e.g., logits) instead of a
+ * scalar loss; runtime exposes `run(inputs)` returning the full output as a
+ * `Float32Array`.
+ *
+ * **Sharing params with a training compile.** Pass `opts.sharedParams =
+ * trainCompiled.params` to bind this graph's param buffers to an existing
+ * training runtime's GPU buffers — every train step is then immediately
+ * visible to `run()` calls here, no copies. The forward graph's
+ * `uploadInitialParams()` skips any param covered by `sharedParams`.
+ *
+ * Typical use: a B=1 inference graph alongside a B=512 training graph,
+ * built from the same `Module` factory.
+ */
+export async function compileForward<M extends Module>(
+  modelFactory: () => M,
+  forward: (m: M, ...inputs: Tensor[]) => Tensor,
+  opts: CompileForwardOptions = {},
+): Promise<CompiledForward & { ir: CompiledIR; uploadInitialParams: () => void }> {
+  const inputDecls = opts.inputs ?? []
+  const model = modelFactory()
+  let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {} }
+  const graph = trace(() => {
+    materialized = materializeParams(model)
+    const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
+    return forward(model, ...inputTensors)
+  })
+
+  const plan = planBuffers(graph, /* paramGrads */ {})
+  const kernels = emitKernels(graph, plan)
+  const outputTensor = graph.tensors[graph.outputs[0]!]!
+  const outputBufferId = plan.tensorToBuffer.get(outputTensor.id)!
+  const runtime = await createForwardRuntime(plan, kernels, outputBufferId, opts)
+
+  const sharedParams = opts.sharedParams
+  const { initFns } = materialized
+  const uploadInitialParams = () => {
+    const out: Record<string, Float32Array> = {}
+    let needsUpload = false
+    for (const [name, bufId] of plan.paramsByName) {
+      // Skip params covered by sharedParams — those are owned by the providing
+      // compile and already initialized there.
+      if (sharedParams?.has(name)) continue
+      const shape = plan.buffers[bufId]!.shape
+      const size = shape.reduce((a, b) => a * b, 1)
+      const initFn = initFns[name]
+      if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
+      out[name] = initFn(size, shape)
      needsUpload = true
+    }
+    if (needsUpload) runtime.uploadParams(out, { partial: !!sharedParams })
+  }
+
+  // CompiledIR.loss is the field name; for forward-only, it carries the user's
+  // returned tensor (e.g., logits). Same shape conceptually; just no autograd.
+  const ir: CompiledIR = { graph, paramGrads: {}, loss: outputTensor, plan, kernels }
+  return Object.assign(runtime, { ir, uploadInitialParams })
 }
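
Putting the two entry points together, a hedged sketch of the intended flow. The `TinyModel` class, `myLoss` helper, `m.forward`, and the input shapes are illustrative; `compileModule`, `compileForward`, `uploadInitialParams`, and `sharedParams` are the pieces introduced in this diff.

```ts
import { compileModule, compileForward } from 'tensorgrad'

const train = await compileModule(
  () => new TinyModel(),                 // factory, not an instance: compile consumes it
  (m, x, y) => myLoss(m, x, y),          // returns the scalar loss tensor
  {
    inputs: [{ name: 'x', shape: [512, 784] }, { name: 'y', shape: [512] }],
    adam: { lr: 5e-3 },
  },
)
train.uploadInitialParams()              // applies each param's init spec once

const infer = await compileForward(
  () => new TinyModel(),                 // same factory, B=1 graph
  (m, x) => m.forward(x),                // returns logits, no loss
  {
    inputs: [{ name: 'x', shape: [1, 784] }],
    sharedParams: train.params,          // bind to the training GPU buffers
  },
)
infer.uploadInitialParams()              // skips everything covered by sharedParams
```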
package/src/index.ts
CHANGED
@@ -6,6 +6,7 @@
 export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js'
 export { ShapeError } from './shape.js'
 export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js'
+export { capture } from './capture.js'
 export {
   // Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
   add, sub, mul, div,
@@ -35,6 +36,7 @@ export { appendGrad, type GradResult } from './grad.js'
 export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
 export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
 export { emitKernels, type KernelSpec } from './codegen.js'
-export { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js'
-export { compile, compileToIR, compileModule, type CompiledIR, type CompileModuleOptions, type InputDecl } from './compile.js'
-export { Module, materializeParams } from './module.js'
+export { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type StepOptions, type StepWithCaptures, type RunOptions, type RunWithCaptures } from './runtime.js'
+export { compile, compileToIR, compileModule, compileForward, type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type InputDecl } from './compile.js'
+export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
+export * as nn from './nn.js'
package/src/ir.ts
CHANGED
@@ -113,9 +113,21 @@ export type OpNode =
   // `lrt` is a scalar tensor (provided as a tensor_input updated per step) that
   // already includes Adam's bias-correction factor: lrt = lr * sqrt(1-b2^t) / (1-b1^t).
   // `decayShrink` is the decoupled-weight-decay factor (Loshchilov & Hutter,
-  // "AdamW")
-  //
-
+  // "AdamW"): 1 - lr * weightDecay when the param is being decayed, 1 otherwise.
+  // It can be either a compile-time literal (number) for fixed-lr training, or a
+  // tensor id pointing at a scalar input that the runtime updates per step (used
+  // when the user supplies an lr schedule via `adam: { lr: (step) => ... }`).
+  | {
+      kind: 'adam_update_p'
+      out: number
+      p: number
+      mNew: number
+      vNew: number
+      lrt: number
+      eps: number
+      decayShrink: number // literal (used when decayShrinkTensor is null)
+      decayShrinkTensor: number | null // tensor id of a scalar input; takes precedence when set
+    }
 
   // ---- Slicing / broadcasting / autograd infrastructure -------------------
   // Slice [start, end) along the last axis. Output shape: input shape with
@@ -141,10 +153,14 @@ export interface Graph {
   // Names of tensors that should be exposed as outputs of the compiled function.
   // Set by the trace driver; for a loss function, this is `[lossTensor]`.
   readonly outputs: number[]
+  // Tensors registered for activation readback via `capture(name, t)`.
+  // Keyed by user-supplied name; insertion order preserved. Empty when no
+  // captures registered (the common training case — zero overhead).
+  readonly captures: Map<string, number>
 }
 
 export function makeGraph(): Graph {
-  return { ops: [], tensors: [], outputs: [] }
+  return { ops: [], tensors: [], outputs: [], captures: new Map() }
 }
 
 // Internal: register a fresh tensor in the graph and return its id.
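
To make the two `decayShrink` modes concrete, hedged example nodes are shown below (the tensor ids are made up; the field set matches the `adam_update_p` variant added above).

```ts
import type { OpNode } from 'tensorgrad'

// Fixed lr: the AdamW factor 1 - lr * weightDecay is baked in as a literal.
const staticShrink: OpNode = {
  kind: 'adam_update_p',
  out: 42, p: 7, mNew: 40, vNew: 41, lrt: 12,
  eps: 1e-8,
  decayShrink: 1 - 0.005 * 0.1,
  decayShrinkTensor: null,
}

// Scheduled lr: the factor lives in a scalar tensor input the runtime
// refreshes each step; the literal is ignored once the tensor id is set.
const scheduledShrink: OpNode = {
  kind: 'adam_update_p',
  out: 42, p: 7, mNew: 40, vNew: 41, lrt: 12,
  eps: 1e-8,
  decayShrink: 1,
  decayShrinkTensor: 13,
}
```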
package/src/module.ts
CHANGED
@@ -6,8 +6,8 @@
 //   W: Tensor; b: Tensor
 //   constructor(inDim: number, outDim: number) {
 //     super()
-//     this.W = this.param([inDim, outDim])
-//     this.b = this.param([outDim])
+//     this.W = this.param([inDim, outDim])   // randn, scale 0.02
+//     this.b = this.param([outDim], { init: 'zeros' })
 //   }
 // }
 // class Block extends Module {
@@ -28,6 +28,54 @@
 import type { Tensor, Shape, Dtype } from './ir.js'
 import { paramInput } from './trace.js'
 
+// ============================================================================
+// Init metadata
+// ============================================================================
+
+/** How a parameter's initial values are produced.
+ * - `'randn'` — Gaussian, with `scale` (default 0.02). The common case for
+ *   weight matrices and embeddings.
+ * - `'zeros'` — fill with 0. Common for biases and LayerNorm beta.
+ * - `'ones'` — fill with 1. Common for LayerNorm gain.
+ * - Custom function — receives total element count and shape, returns the
+ *   Float32Array. Use for fan-in scaling or any non-standard scheme.
+ */
+export type InitSpec =
+  | 'randn'
+  | 'zeros'
+  | 'ones'
+  | ((size: number, shape: readonly number[]) => Float32Array)
+
+export interface ParamOptions {
+  dtype?: Dtype
+  /** Init kind. Default: `'randn'`. */
+  init?: InitSpec
+  /** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
+  scale?: number
+}
+
+type InitFn = (size: number, shape: readonly number[]) => Float32Array
+
+function boxMuller(): number {
+  return Math.sqrt(-2 * Math.log(Math.max(1e-10, Math.random()))) * Math.cos(2 * Math.PI * Math.random())
+}
+
+function resolveInit(opts: ParamOptions | undefined): InitFn {
+  const init = opts?.init ?? 'randn'
+  if (init === 'randn') {
+    const scale = opts?.scale ?? 0.02
+    return (size) => {
+      const arr = new Float32Array(size)
+      for (let i = 0; i < size; i++) arr[i] = boxMuller() * scale
+      return arr
+    }
+  }
+  if (init === 'zeros') return (size) => new Float32Array(size)
+  if (init === 'ones') return (size) => { const a = new Float32Array(size); a.fill(1); return a }
+  if (typeof init === 'function') return init
+  throw new Error(`Unknown init: ${String(init)}`)
+}
+
 // ============================================================================
 // Internals: param sentinel
 // ============================================================================
@@ -38,7 +86,11 @@ import { paramInput } from './trace.js'
 // only valid post-materialization (which is always before forward runs).
 
 class ParamSentinel {
-  constructor(
+  constructor(
+    public readonly shape: Shape,
+    public readonly dtype: Dtype,
+    public readonly initFn: InitFn,
+  ) {}
 }
 
 // ============================================================================
@@ -52,11 +104,13 @@ export abstract class Module {
    * that gets replaced with a real Tensor at compile time.
    *
    * The parameter's name is auto-derived from its property path in the model
-   * tree (e.g. `layers.0.attn.W_q`).
+   * tree (e.g. `layers.0.attn.W_q`). Init metadata travels with the param;
+   * call `compiled.uploadInitialParams()` to apply it after compile.
    */
-  protected param(shape: Shape,
+  protected param(shape: Shape, opts?: ParamOptions): Tensor {
+    const dtype = opts?.dtype ?? 'f32'
     // Lie to TypeScript: the sentinel becomes a Tensor at materialize time.
-    return new ParamSentinel(shape, dtype) as unknown as Tensor
+    return new ParamSentinel(shape, dtype, resolveInit(opts)) as unknown as Tensor
   }
 }
 
@@ -64,23 +118,33 @@ export abstract class Module {
 // Tree walking
 // ============================================================================
 
+export interface MaterializedParams {
+  /** Map from auto-derived path (e.g. `layers.0.attn.W_q`) to its Tensor. */
+  tensors: Record<string, Tensor>
+  /** Init function per param path. Used by `uploadInitialParams`. */
+  initFns: Record<string, InitFn>
+}
+
 /**
  * Walk the module tree and replace every ParamSentinel with a real Tensor
  * created via `paramInput(autoName, ...)`. Must be called inside an active
  * trace context (paramInput appends to the current graph).
  *
- * Returns
+ * Returns the param tensors keyed by path, plus init functions for use by
+ * `uploadInitialParams`.
  */
-export function materializeParams(root: Module):
-  const
+export function materializeParams(root: Module): MaterializedParams {
+  const tensors: Record<string, Tensor> = {}
+  const initFns: Record<string, InitFn> = {}
   visit(root, '', (path, val, owner, key) => {
     if (val instanceof ParamSentinel) {
      const t = paramInput(path, val.shape, val.dtype)
      ;(owner as any)[key] = t
-
+      tensors[path] = t
+      initFns[path] = val.initFn
    }
  })
-  return
+  return { tensors, initFns }
 }
 
 // ----------------------------------------------------------------------------
|