tensorgrad 0.0.14 → 0.0.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/worker.ts ADDED
@@ -0,0 +1,281 @@
1
+ // Worker entry point. Holds the GPUDevice + CompiledRuntime for one or more
2
+ // graphs and proxies main-thread requests via postMessage. See
3
+ // specs/WorkerArchitecture.md for the rationale.
4
+ //
5
+ // Keep this file dependency-free of anything DOM-y: it bundles into a Blob
6
+ // URL and runs in a Web Worker context where `window`/`document` don't
7
+ // exist. WebGPU IS available in workers (Chrome 113+, Safari 17.4+).
8
+
9
+ import { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js'
10
+ import { resolveLR, type LRSchedule } from './adam.js'
11
+ import type { Req, Res, WireIR, WireAdamConfig, WireError } from './worker-protocol.js'
12
+ import { wireError } from './worker-protocol.js'
13
+
14
+ // ----------------------------------------------------------------------------
15
+ // Per-graph state
16
+ // ----------------------------------------------------------------------------
17
+
18
/** Host-side Adam bookkeeping for one training graph. */
interface AdamState {
  /** Optimizer configuration exactly as received over the wire. */
  config: WireAdamConfig
  /** Step counter. Starts at 0; `injectAdamScalars` bumps it before each step. */
  t: number
  /** One-element scratch array carrying the current step's `lrt` scalar. */
  lrtBuf: Float32Array
  /** One-element scratch array for the decay-shrink scalar, or null when the
   *  config names no decay-shrink input. */
  decayShrinkBuf: Float32Array | null
}

/** Everything the worker remembers about one compiled graph. */
interface GraphSlot {
  /** The compiled runtime that executes this graph. */
  runtime: CompiledRuntime
  /** Parameter names, taken from the plan's `paramsByName` keys. */
  paramNames: readonly string[]
  /** Copy of the runtime's output shape. */
  outputShape: number[]
  /** Number of kernels that carry WGSL source. */
  kernelCount: number
  /** Capture name -> shape of the buffer backing that capture. */
  captureShapes: Record<string, number[]>
  /** Adam state when this is a training graph, otherwise null. The step
   *  handler uses it to fill in the per-step lrt and decayShrink scalars. */
  adam: AdamState | null
}

/** Registry of live graphs, keyed by the graph id chosen by the main thread. */
const graphs = new Map<number, GraphSlot>()
37
+
38
// Worker holds one device shared across all graphs (sibling forward graphs
// must share param GPUBuffers, which means sharing a device).
let device: GPUDevice | null = null

// In-flight device acquisition, if any. Message handlers run concurrently
// (the dispatcher is async), so several callers can reach `ensureDevice`
// before the first request resolves; they must all wait on the same request
// instead of each creating their own device.
let devicePending: Promise<GPUDevice> | null = null

/** Ask WebGPU for an adapter and then a device.
 *
 * @throws Error when WebGPU is not exposed in this environment or when no
 *         adapter is returned. */
async function requestSharedDevice(): Promise<GPUDevice> {
  if (typeof navigator === 'undefined' || !navigator.gpu) {
    throw new Error('tensorgrad worker: WebGPU not available in this environment')
  }
  const adapter = await navigator.gpu.requestAdapter()
  if (!adapter) throw new Error('tensorgrad worker: no WebGPU adapter')
  return adapter.requestDevice()
}

/**
 * Return the worker's single GPUDevice, acquiring it on first use.
 *
 * Concurrent callers share one acquisition, so exactly one device is ever
 * created. A failed acquisition is not cached: the next call tries again.
 *
 * @returns the shared device
 * @throws Error when WebGPU or an adapter is unavailable
 */
async function ensureDevice(): Promise<GPUDevice> {
  if (device) return device
  if (!devicePending) {
    devicePending = requestSharedDevice()
      .then(acquired => {
        device = acquired
        return acquired
      })
      .finally(() => {
        // Success is remembered in `device`; failure must stay retryable.
        devicePending = null
      })
  }
  return devicePending
}
52
+
53
+ // ----------------------------------------------------------------------------
54
+ // Request handlers
55
+ // ----------------------------------------------------------------------------
56
+
57
/**
 * Compile a graph on the shared device, upload its initial params, and
 * register it under `payload.graphId`.
 *
 * The first graph output is used as the runtime's output. When `payload.adam`
 * is non-null the slot also gets fresh Adam bookkeeping.
 *
 * NOTE(review): an existing slot under the same graphId is replaced without
 * being destroyed — confirm callers never reuse ids.
 *
 * @returns metadata describing the compiled graph (param names, output shape,
 *          number of WGSL kernels, capture shapes)
 * @throws Error when the graph has no output, the output tensor has no buffer
 *         in the plan, or a capture points at a missing buffer; any error
 *         from device acquisition, compilation, or param upload is rethrown
 */
async function handleCreateRuntime(payload: {
  graphId: number
  ir: WireIR
  initialParams: Record<string, Float32Array>
  adam: WireAdamConfig | null
}): Promise<{ paramNames: string[]; outputShape: number[]; kernelCount: number; captureShapes: Record<string, number[]> }> {
  const dev = await ensureDevice()
  const { graph, plan, kernels } = payload.ir

  // Fail with a clear message instead of passing `undefined` to the runtime.
  // Compare against undefined explicitly: 0 is a legitimate id.
  const outputTensorId = graph.outputs[0]
  if (outputTensorId === undefined) {
    throw new Error(`createRuntime: graph ${payload.graphId} has no outputs`)
  }
  const outputBufferId = plan.tensorToBuffer.get(outputTensorId)
  if (outputBufferId === undefined) {
    throw new Error(`createRuntime: output tensor #${outputTensorId} has no buffer in the plan`)
  }

  const opts: RuntimeOpts = { device: dev }
  const runtime = await createRuntime(plan, kernels, outputBufferId, opts)

  let slot: GraphSlot
  try {
    // Upload initial params.
    if (Object.keys(payload.initialParams).length > 0) {
      runtime.uploadParams(payload.initialParams)
    }

    // Capture shape metadata for return.
    const captureShapes: Record<string, number[]> = {}
    for (const [name, bufId] of plan.capturesByName) {
      const spec = plan.buffers[bufId]
      if (!spec) {
        throw new Error(`createRuntime: capture '${name}' references missing buffer #${bufId}`)
      }
      captureShapes[name] = [...spec.shape]
    }

    slot = {
      runtime,
      paramNames: [...plan.paramsByName.keys()],
      outputShape: [...runtime.outputShape],
      kernelCount: kernels.filter(k => k.wgsl).length,
      captureShapes,
      adam: payload.adam ? createAdamState(payload.adam) : null,
    }
  } catch (e) {
    // The runtime is not registered yet, so nothing else could ever destroy
    // it; release it here rather than leaking it.
    runtime.destroy()
    throw e
  }
  graphs.set(payload.graphId, slot)

  return {
    paramNames: [...slot.paramNames],
    outputShape: slot.outputShape,
    kernelCount: slot.kernelCount,
    captureShapes: slot.captureShapes,
  }
}
98
+
99
/**
 * Compile a forward-only graph that shares its parameters with an already
 * registered parent graph, and register it under `payload.graphId`.
 *
 * No params are uploaded: the runtime is created with the parent runtime's
 * `params` as `sharedParams`. The slot never carries Adam state.
 *
 * @returns metadata describing the compiled graph (param names, output shape,
 *          number of WGSL kernels, capture shapes)
 * @throws Error when the parent graph is unknown, the graph has no output,
 *         the output tensor has no buffer in the plan, or a capture points at
 *         a missing buffer
 */
async function handleCompileForward(payload: {
  graphId: number
  parentGraphId: number
  ir: WireIR
}): Promise<{ paramNames: string[]; outputShape: number[]; kernelCount: number; captureShapes: Record<string, number[]> }> {
  const dev = await ensureDevice()
  const parent = graphs.get(payload.parentGraphId)
  if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`)

  const { graph, plan, kernels } = payload.ir

  // Fail with a clear message instead of passing `undefined` to the runtime.
  // Compare against undefined explicitly: 0 is a legitimate id.
  const outputTensorId = graph.outputs[0]
  if (outputTensorId === undefined) {
    throw new Error(`compileForward: graph ${payload.graphId} has no outputs`)
  }
  const outputBufferId = plan.tensorToBuffer.get(outputTensorId)
  if (outputBufferId === undefined) {
    throw new Error(`compileForward: output tensor #${outputTensorId} has no buffer in the plan`)
  }

  const opts: RuntimeOpts = { device: dev, sharedParams: parent.runtime.params }
  const runtime = await createRuntime(plan, kernels, outputBufferId, opts)
  // No initial-param upload — sharedParams covers everything.

  const captureShapes: Record<string, number[]> = {}
  for (const [name, bufId] of plan.capturesByName) {
    const spec = plan.buffers[bufId]
    if (!spec) {
      throw new Error(`compileForward: capture '${name}' references missing buffer #${bufId}`)
    }
    captureShapes[name] = [...spec.shape]
  }

  const slot: GraphSlot = {
    runtime,
    paramNames: [...plan.paramsByName.keys()],
    outputShape: [...runtime.outputShape],
    kernelCount: kernels.filter(k => k.wgsl).length,
    captureShapes,
    adam: null,
  }
  graphs.set(payload.graphId, slot)

  return {
    paramNames: [...slot.paramNames],
    outputShape: slot.outputShape,
    kernelCount: slot.kernelCount,
    captureShapes: slot.captureShapes,
  }
}
137
+
138
/** Build fresh Adam bookkeeping for a training graph: step counter at zero,
 *  a one-element scratch array for `lrt`, and a second one for decay-shrink
 *  only when the config names a decay-shrink input. */
function createAdamState(cfg: WireAdamConfig): AdamState {
  const wantsDecayShrink = Boolean(cfg.decayShrinkInputName)
  const state: AdamState = {
    config: cfg,
    t: 0,
    lrtBuf: new Float32Array(1),
    decayShrinkBuf: wantsDecayShrink ? new Float32Array(1) : null,
  }
  return state
}
146
+
147
/** Advance the Adam step counter and return an inputs map that additionally
 *  carries this step's `lrt` scalar (learning rate times the bias-correction
 *  factor) and, when configured, the decay-shrink scalar.
 *
 *  Graphs without Adam state get their inputs back untouched. The scalar
 *  arrays live on the slot and are overwritten on every call, so no per-step
 *  allocation happens beyond the shallow copy of the inputs map. */
function injectAdamScalars(slot: GraphSlot, inputs: Record<string, Int32Array | Float32Array>): Record<string, Int32Array | Float32Array> {
  const adam = slot.adam
  if (!adam) return inputs

  adam.t++
  const step = adam.t
  const { config } = adam
  const lr = resolveLR(config.lr as LRSchedule, step)

  // lrt = lr * sqrt(1 - b2^t) / (1 - b1^t)
  const correctionNumerator = Math.sqrt(1 - Math.pow(config.b2, step))
  const correctionDenominator = 1 - Math.pow(config.b1, step)
  adam.lrtBuf[0] = lr * correctionNumerator / correctionDenominator

  const withScalars: Record<string, Int32Array | Float32Array> = { ...inputs }
  withScalars[config.lrtInputName] = adam.lrtBuf

  const decayBuf = adam.decayShrinkBuf
  const decayName = config.decayShrinkInputName
  if (decayBuf && decayName) {
    decayBuf[0] = 1 - lr * config.weightDecay
    withScalars[decayName] = decayBuf
  }
  return withScalars
}
163
+
164
/** Run one step of the given graph, after merging Adam's per-step scalars
 *  into the inputs. Returns the loss, plus the captured tensors as a plain
 *  record when `withCaptures` is set (null otherwise).
 *
 *  Throws when the graph id is unknown. */
async function handleStep(payload: {
  graphId: number
  inputs: Record<string, Int32Array | Float32Array>
  withCaptures: boolean
}): Promise<{ loss: number; captures: Record<string, Float32Array> | null }> {
  const slot = mustGet(payload.graphId)
  const stepInputs = injectAdamScalars(slot, payload.inputs)

  if (!payload.withCaptures) {
    const loss = await slot.runtime.step(stepInputs)
    return { loss, captures: null }
  }

  const outcome = await slot.runtime.step(stepInputs, { withCaptures: true })
  const captures = capturesToRecord(outcome.captures, slot.captureShapes)
  return { loss: outcome.loss, captures }
}
178
+
179
/** Run the given graph on `inputs` without touching any optimizer state.
 *  Returns the output array, plus the captured tensors as a plain record
 *  when `withCaptures` is set (null otherwise).
 *
 *  Throws when the graph id is unknown. */
async function handleRun(payload: {
  graphId: number
  inputs: Record<string, Int32Array | Float32Array>
  withCaptures: boolean
}): Promise<{ output: Float32Array; captures: Record<string, Float32Array> | null }> {
  const slot = mustGet(payload.graphId)

  if (!payload.withCaptures) {
    const output = await slot.runtime.run(payload.inputs)
    return { output, captures: null }
  }

  const outcome = await slot.runtime.run(payload.inputs, { withCaptures: true })
  const captures = capturesToRecord(outcome.captures, slot.captureShapes)
  return { output: outcome.output, captures }
}
192
+
193
/** Flatten a captures object into a plain `name -> Float32Array` record that
 *  can be posted back to the main thread.
 *
 *  The names to copy come from the keys of `shapes` (the slot's capture-shape
 *  table), not from `captures.names()`; a name present in `shapes` but absent
 *  from `captures` is skipped. The shape values themselves are not read. */
function capturesToRecord(
  captures: { get(name: string): Float32Array; has(name: string): boolean; names(): string[] },
  shapes: Record<string, number[]>,
): Record<string, Float32Array> {
  const record: Record<string, Float32Array> = {}
  const expectedNames = Object.keys(shapes)
  for (const captureName of expectedNames) {
    if (!captures.has(captureName)) continue
    record[captureName] = captures.get(captureName)
  }
  return record
}
207
+
208
/** Upload parameter values into the given graph's runtime, forwarding the
 *  `partial` flag unchanged. Throws when the graph id is unknown. */
function handleUploadParams(payload: {
  graphId: number
  params: Record<string, Float32Array>
  partial: boolean
}): void {
  const { runtime } = mustGet(payload.graphId)
  const uploadOpts = { partial: payload.partial }
  runtime.uploadParams(payload.params, uploadOpts)
}
216
+
217
/** Download the given graph's parameters from its runtime.
 *  Throws when the graph id is unknown. */
async function handleDownloadParams(payload: { graphId: number }): Promise<{ params: Record<string, Float32Array> }> {
  const { runtime } = mustGet(payload.graphId)
  const params = await runtime.downloadParams()
  return { params }
}
221
+
222
/** Download the given graph's parameter gradients from its runtime. The
 *  result is returned under the `params` key, mirroring handleDownloadParams.
 *  Throws when the graph id is unknown. */
async function handleDownloadParamGrads(payload: { graphId: number }): Promise<{ params: Record<string, Float32Array> }> {
  const { runtime } = mustGet(payload.graphId)
  const grads = await runtime.downloadParamGrads()
  return { params: grads }
}
226
+
227
/** Reset the runtime's optimizer state and, for training graphs, rewind the
 *  Adam step counter to zero. Throws when the graph id is unknown. */
function handleResetOptimizer(payload: { graphId: number }): void {
  const { runtime, adam } = mustGet(payload.graphId)
  runtime.resetOptimizerState()
  if (adam) {
    adam.t = 0
  }
}
232
+
233
/** Destroy a graph's runtime and drop it from the registry. Unlike the other
 *  handlers this is a silent no-op for an unknown graph id. */
function handleDestroy(payload: { graphId: number }): void {
  const { graphId } = payload
  const slot = graphs.get(graphId)
  if (slot) {
    slot.runtime.destroy()
    graphs.delete(graphId)
  }
}
239
+
240
/** Look up a registered graph, throwing when the id is unknown. */
function mustGet(graphId: number): GraphSlot {
  const found = graphs.get(graphId)
  if (found) return found
  throw new Error(`tensorgrad worker: graph ${graphId} not found`)
}
245
+
246
+ // ----------------------------------------------------------------------------
247
+ // Message dispatch
248
+ // ----------------------------------------------------------------------------
249
+
250
/**
 * Request dispatcher. Routes each incoming request to its handler and posts
 * back exactly one reply carrying the request's `id`: `{ ok: true, result }`
 * on success, `{ ok: false, error }` when the handler (or the success reply
 * itself) throws.
 *
 * Result arrays are transferred rather than copied. The transfer list is
 * de-duplicated first because listing the same ArrayBuffer twice makes
 * `postMessage` throw a DataCloneError, which would discard the result of
 * work that has already run.
 */
self.onmessage = async (ev: MessageEvent<Req>) => {
  const req = ev.data
  try {
    let result: unknown
    let transferList: ArrayBuffer[] = []
    switch (req.kind) {
      case 'createRuntime':
        result = await handleCreateRuntime(req.payload)
        break
      case 'compileForward':
        result = await handleCompileForward(req.payload)
        break
      case 'step': {
        const r = await handleStep(req.payload)
        result = r
        transferList = collectTransfers(r.captures)
        break
      }
      case 'run': {
        const r = await handleRun(req.payload)
        result = r
        transferList = [r.output.buffer as ArrayBuffer, ...collectTransfers(r.captures)]
        break
      }
      case 'uploadParams':
        handleUploadParams(req.payload)
        result = null
        break
      case 'downloadParams': {
        const r = await handleDownloadParams(req.payload)
        result = r
        transferList = collectTransfers(r.params)
        break
      }
      case 'downloadParamGrads': {
        const r = await handleDownloadParamGrads(req.payload)
        result = r
        transferList = collectTransfers(r.params)
        break
      }
      case 'resetOptimizer':
        handleResetOptimizer(req.payload)
        result = null
        break
      case 'destroy':
        handleDestroy(req.payload)
        result = null
        break
      default:
        throw new Error(`unknown request kind: ${(req as { kind: string }).kind}`)
    }
    const reply: Res = { id: req.id, ok: true, result }
    // Several views may share one buffer (e.g. an output that is also
    // captured); each buffer may appear in the transfer list only once.
    self.postMessage(reply, { transfer: [...new Set(transferList)] })
  } catch (e) {
    const error: WireError = wireError(e)
    const reply: Res = { id: req.id, ok: false, error }
    self.postMessage(reply)
  }
}
275
+
276
/**
 * Gather the ArrayBuffers backing every array in `rec`, for use as a
 * `postMessage` transfer list.
 *
 * Each buffer is listed once, in first-seen order, even when several arrays
 * are views over the same buffer: a transfer list containing duplicates makes
 * `postMessage` throw a DataCloneError.
 *
 * @param rec name -> array record; null/undefined yields an empty list
 * @returns the distinct backing buffers
 */
function collectTransfers(rec: Record<string, Float32Array> | null | undefined): ArrayBuffer[] {
  if (!rec) return []
  const unique = new Set<ArrayBuffer>()
  for (const view of Object.values(rec)) unique.add(view.buffer as ArrayBuffer)
  return [...unique]
}
package/dist/adam.js DELETED
@@ -1,111 +0,0 @@
1
- // Adam / AdamW optimizer, in-graph.
2
- //
3
- // `appendAdam` extends a graph that already has a forward pass + autograd-emitted
4
- // backward (i.e., has paramGrads from `appendGrad`) with the Adam update math.
5
- //
6
- // Per parameter P with gradient g:
7
- // m_new = b1 * m + (1 - b1) * g
8
- // v_new = b2 * v + (1 - b2) * g²
9
- // p_new = decayShrink * p - lrt * m_new / (sqrt(v_new) + eps)
10
- //
11
- // `decayShrink = 1 - lr * weightDecay` when the param is being decayed
12
- // (Loshchilov & Hutter, "AdamW") and 1 otherwise — at which point the
13
- // multiply folds out and you're left with plain Adam. `lrt` is supplied
14
- // per-step from CPU and includes the bias-correction factor
15
- // `sqrt(1-b2^t)/(1-b1^t)`; that's why convergence isn't affected by the
16
- // first-step warmup that bias-correction-free Adam suffers.
17
- //
18
- // **Static vs scheduled lr.** When `config.lr` is a number, decayShrink is
19
- // baked into the kernel as a literal. When it's a function `(step) => lr`,
20
- // decayShrink for decayed params becomes a per-step scalar input that the
21
- // runtime updates each call (computed from the current step's lr). lrt is
22
- // always per-step; the bias-correction factor changes every step regardless.
23
- //
24
- // Returns writeback declarations the buffer planner uses to wire up the
25
- // "after step, copy the new value into the persistent home" path. m and v
26
- // are state_inputs (zero-initialized, persistent across steps); the param
27
- // updates are aliased back to the param buffers.
28
- import { traceInto, stateInput, tensorInput } from './trace.js';
29
- import { adamUpdateM, adamUpdateV, adamUpdateP } from './ops.js';
30
/**
 * Extend `graph` with the Adam / AdamW update ops for every parameter that
 * has a gradient. Re-enters the graph through `traceInto`.
 *
 * Per parameter this emits two zero-initialised state inputs (`adam_m_<name>`
 * and `adam_v_<name>`), the three update ops (m, v, p), and three writeback
 * declarations that route the new values back to the states and the param.
 *
 * @param graph        the graph (already containing forward + backward)
 * @param paramGrads   param name -> gradient tensor
 * @param paramTensors param name -> the param's leaf tensor
 * @param config       Adam hyperparameters; `lr` is a number or a
 *                     `(step) => lr` function. Defaults: b1 0.9, b2 0.999,
 *                     eps 1e-8, weightDecay 0, decayFilter accepts every name.
 * @param decayFlags   optional per-param decay flags; a name present in this
 *                     map overrides `config.decayFilter` for that name
 * @returns `{ writebacks, lrtInputName, decayShrinkInputName, config }`, where
 *          `decayShrinkInputName` is null unless lr is a function and at
 *          least one param is decayed
 * @throws Error when a name in `paramGrads` has no param tensor or no gradient
 */
export function appendAdam(graph, paramGrads, paramTensors, config, decayFlags) {
    const lrIsScheduled = typeof config.lr === 'function';
    // A static lr is wrapped so the rest of the code can always call lr(step).
    const lrFn = lrIsScheduled ? config.lr : (() => config.lr);
    const initialLr = lrFn(1);
    const fullConfig = {
        lr: lrFn,
        b1: config.b1 ?? 0.9,
        b2: config.b2 ?? 0.999,
        eps: config.eps ?? 1e-8,
        weightDecay: config.weightDecay ?? 0,
        decayFilter: config.decayFilter ?? (() => true),
        lrIsScheduled,
    };
    const writebacks = [];
    const lrtInputName = '_adam_lrt';
    // Stays null unless a runtime-updated decayShrink input is created below.
    let decayShrinkInputName = null;
    const paramNames = Object.keys(paramGrads);
    // Explicit per-param flag wins; otherwise fall back to the filter.
    const receivesDecay = (name) => (decayFlags && name in decayFlags)
        ? decayFlags[name]
        : fullConfig.decayFilter(name);
    return traceInto(graph, () => {
        const lrt = tensorInput(lrtInputName, [], 'f32');
        // Names of the params that get weight decay; empty when weightDecay is 0.
        const decayedNames = new Set();
        if (fullConfig.weightDecay > 0) {
            for (const name of paramNames) {
                if (receivesDecay(name))
                    decayedNames.add(name);
            }
        }
        // A per-step decayShrink input is only needed when lr varies per step
        // and something is actually decayed; otherwise the value is constant.
        let decayShrinkScalar = null;
        if (lrIsScheduled && decayedNames.size > 0) {
            decayShrinkInputName = '_adam_decay_shrink';
            decayShrinkScalar = tensorInput(decayShrinkInputName, [], 'f32');
        }
        for (const name of paramNames) {
            const p = paramTensors[name];
            const g = paramGrads[name];
            if (!p)
                throw new Error(`appendAdam: missing param tensor for '${name}'`);
            if (!g)
                throw new Error(`appendAdam: missing gradient for '${name}'`);
            const mName = `adam_m_${name}`;
            const vName = `adam_v_${name}`;
            const mState = stateInput(mName, p.shape, 'f32', 0);
            const vState = stateInput(vName, p.shape, 'f32', 0);
            // decayShrink per param:
            //   not decayed            -> literal 1
            //   decayed + scheduled lr -> the per-step tensor input
            //   decayed + static lr    -> literal 1 - lr * weightDecay
            let decayShrink = 1;
            if (decayedNames.has(name)) {
                decayShrink = decayShrinkScalar !== null
                    ? decayShrinkScalar
                    : 1 - initialLr * fullConfig.weightDecay;
            }
            const newM = adamUpdateM(mState, g, fullConfig.b1);
            const newV = adamUpdateV(vState, g, fullConfig.b2);
            const newP = adamUpdateP(p, newM, newV, lrt, fullConfig.eps, decayShrink);
            writebacks.push(
                { source: newM, destName: mName, destKind: 'state' },
                { source: newV, destName: vName, destKind: 'state' },
                { source: newP, destName: name, destKind: 'param' },
            );
        }
        return { writebacks, lrtInputName, decayShrinkInputName, config: fullConfig };
    });
}
111
- //# sourceMappingURL=adam.js.map
package/dist/adam.js.map DELETED
@@ -1 +0,0 @@
1
- {"version":3,"file":"adam.js","sourceRoot":"","sources":["../src/adam.ts"],"names":[],"mappings":"AAAA,oCAAoC;AACpC,EAAE;AACF,kFAAkF;AAClF,+EAA+E;AAC/E,EAAE;AACF,mCAAmC;AACnC,kCAAkC;AAClC,mCAAmC;AACnC,gEAAgE;AAChE,EAAE;AACF,uEAAuE;AACvE,sEAAsE;AACtE,wEAAwE;AACxE,4DAA4D;AAC5D,wEAAwE;AACxE,4DAA4D;AAC5D,EAAE;AACF,2EAA2E;AAC3E,2EAA2E;AAC3E,0EAA0E;AAC1E,0EAA0E;AAC1E,6EAA6E;AAC7E,EAAE;AACF,wEAAwE;AACxE,0EAA0E;AAC1E,0EAA0E;AAC1E,iDAAiD;AAKjD,OAAO,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,MAAM,YAAY,CAAA;AAC/D,OAAO,EAAE,WAAW,EAAE,WAAW,EAAE,WAAW,EAAE,MAAM,UAAU,CAAA;AAgDhE;;;;;;;;;;;GAWG;AACH,MAAM,UAAU,UAAU,CACxB,KAAY,EACZ,UAAkC,EAClC,YAAoC,EACpC,MAAkB;AAClB;;;kCAGkC;AAClC,UAAoC;IAEpC,MAAM,aAAa,GAAG,OAAO,MAAM,CAAC,EAAE,KAAK,UAAU,CAAA;IACrD,MAAM,IAAI,GAAG,aAAa;QACxB,CAAC,CAAC,MAAM,CAAC,EAA8B;QACvC,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,EAAY,CAAC,CAAA;IAC/B,MAAM,SAAS,GAAG,IAAI,CAAC,CAAC,CAAC,CAAA;IACzB,MAAM,UAAU,GAAuB;QACrC,EAAE,EAAE,IAAI;QACR,EAAE,EAAE,MAAM,CAAC,EAAE,IAAI,GAAG;QACpB,EAAE,EAAE,MAAM,CAAC,EAAE,IAAI,KAAK;QACtB,GAAG,EAAE,MAAM,CAAC,GAAG,IAAI,IAAI;QACvB,WAAW,EAAE,MAAM,CAAC,WAAW,IAAI,CAAC;QACpC,WAAW,EAAE,MAAM,CAAC,WAAW,IAAI,CAAC,GAAG,EAAE,CAAC,IAAI,CAAC;QAC/C,aAAa;KACd,CAAA;IACD,MAAM,UAAU,GAAoB,EAAE,CAAA;IACtC,MAAM,YAAY,GAAG,WAAW,CAAA;IAChC,0EAA0E;IAC1E,iEAAiE;IACjE,IAAI,oBAAoB,GAAkB,IAAI,CAAA;IAE9C,OAAO,SAAS,CAAC,KAAK,EAAE,GAAG,EAAE;QAC3B,MAAM,GAAG,GAAG,WAAW,CAAC,YAAY,EAAE,EAAE,EAAE,KAAK,CAAC,CAAA;QAEhD,yEAAyE;QACzE,uEAAuE;QACvE,yEAAyE;QACzE,2CAA2C;QAC3C,MAAM,YAAY,GAAG,IAAI,GAAG,CAC1B,UAAU,CAAC,WAAW,GAAG,CAAC;YACxB,CAAC,CAAC,MAAM,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC,MAAM,CAAC,IAAI,CAAC,EAAE,CACpC,CAAC,UAAU,IAAI,IAAI,IAAI,UAAU,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,IAAI,CAAE,CAAC,CAAC,CAAC,UAAU,CAAC,WAAW,CAAC,IAAI,CAAC,CAAC;YAC1F,CAAC,CAAC,EAAE,CACP,CAAA;QAED,wEAAwE;QACxE,uEAAuE;QACvE,0CAA0C;QAC1C,IAAI,iBAAiB,GAAkB,IAAI,CAAA;QAC3C,IAAI,aAAa,IAAI,YAAY,CAAC,IAAI,GAAG,CAAC,EAAE,CAAC;YAC3C,oBAAoB,GAAG,oBAAoB,CAAA;YAC3C,iBAAiB,GAAG,WAAW,CAAC,oBAAoB,EAAE,EAAE,EAAE,KAAK,CAAC,CAAA;QAClE,CAAC;QAED
,KAAK,MAAM,IAAI,IAAI,MAAM,CAAC,IAAI,CAAC,UAAU,CAAC,EAAE,CAAC;YAC3C,MAAM,CAAC,GAAG,YAAY,CAAC,IAAI,CAAC,CAAA;YAC5B,MAAM,CAAC,GAAG,UAAU,CAAC,IAAI,CAAC,CAAA;YAC1B,IAAI,CAAC,CAAC;gBAAE,MAAM,IAAI,KAAK,CAAC,yCAAyC,IAAI,GAAG,CAAC,CAAA;YACzE,IAAI,CAAC,CAAC;gBAAE,MAAM,IAAI,KAAK,CAAC,qCAAqC,IAAI,GAAG,CAAC,CAAA;YAErE,MAAM,MAAM,GAAG,UAAU,CAAC,UAAU,IAAI,EAAE,EAAE,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,CAAC,CAAC,CAAA;YAC9D,MAAM,MAAM,GAAG,UAAU,CAAC,UAAU,IAAI,EAAE,EAAE,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,CAAC,CAAC,CAAA;YAE9D,yCAAyC;YACzC,iEAAiE;YACjE,6DAA6D;YAC7D,mEAAmE;YACnE,MAAM,WAAW,GACf,CAAC,YAAY,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC;gBAC3B,CAAC,CAAC,iBAAiB,KAAK,IAAI,CAAC,CAAC,CAAC,iBAAiB;oBAChD,CAAC,CAAC,CAAC,GAAG,SAAS,GAAG,UAAU,CAAC,WAAW,CAAA;YAE1C,6EAA6E;YAC7E,MAAM,IAAI,GAAG,WAAW,CAAC,MAAM,EAAE,CAAC,EAAE,UAAU,CAAC,EAAE,CAAC,CAAA;YAClD,MAAM,IAAI,GAAG,WAAW,CAAC,MAAM,EAAE,CAAC,EAAE,UAAU,CAAC,EAAE,CAAC,CAAA;YAClD,MAAM,IAAI,GAAG,WAAW,CAAC,CAAC,EAAE,IAAI,EAAE,IAAI,EAAE,GAAG,EAAE,UAAU,CAAC,GAAG,EAAE,WAAW,CAAC,CAAA;YAEzE,UAAU,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE,UAAU,IAAI,EAAE,EAAE,QAAQ,EAAE,OAAO,EAAE,CAAC,CAAA;YAChF,UAAU,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE,UAAU,IAAI,EAAE,EAAE,QAAQ,EAAE,OAAO,EAAE,CAAC,CAAA;YAChF,UAAU,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE,IAAI,EAAc,QAAQ,EAAE,OAAO,EAAE,CAAC,CAAA;QAClF,CAAC;QACD,OAAO,EAAE,UAAU,EAAE,YAAY,EAAE,oBAAoB,EAAE,MAAM,EAAE,UAAU,EAAE,CAAA;IAC/E,CAAC,CAAC,CAAA;AACJ,CAAC"}
package/dist/buffers.js DELETED
@@ -1,120 +0,0 @@
1
- // Buffer planning: walk a Graph and decide which GPU buffer each Tensor maps to.
2
- //
3
- // v1 strategy: one GPU buffer per IR Tensor. Static shapes mean every buffer's
4
- // size is known at compile time and lifetimes don't overlap between steps —
5
- // so no pooling needed. Total memory is the sum of every intermediate tensor.
6
- // For our transformer at B=256: ~30 MB of activations + grads. Easily fits.
7
- //
8
- // Categorization is what the runtime cares about:
9
- // * param — uploaded by user via uploadParams; persistent across steps
10
- // * param_grad — written each step by the backward pass; readable for inspection
11
- // * tensor_input — uploaded each step (tokens, targets, masks)
12
- // * intermediate — produced by an op; lifetime = within a single step
13
- // * output — special intermediate that should be made readable (loss)
14
- import { shapeSize } from './shape.js';
15
// Bytes per element for each supported dtype.
const dtypeBytes = { f32: 4, i32: 4, bool: 4 };
/**
 * Build a BufferPlan from a graph + the param-grad map produced by appendGrad.
 *
 * Every tensor gets its own buffer whose id equals the tensor id. Each buffer
 * is categorised as `param`, `tensor_input`, `state`, `param_grad`, `output`
 * or `intermediate`, in that order of precedence, and is at least 4 bytes.
 *
 * NOTE(review): writebacks look buffers up with `buffers[bufferId]`, which
 * assumes each tensor's id equals its position in `graph.tensors` — confirm.
 *
 * @param graph          the full graph (forward + backward + any optimizer ops)
 * @param paramGrads     map from param name -> the Tensor that holds its gradient
 * @param writebackDecls list of end-of-step writebacks (e.g. from appendAdam).
 *                       Empty when there's no optimizer in the graph.
 * @returns the plan: buffer specs, the tensor -> buffer map, by-name lookup
 *          maps, output buffer ids, and resolved writebacks
 * @throws Error when a tensor has an unsupported dtype, a writeback names an
 *         unknown source/destination or mismatched sizes, or a capture
 *         references an unknown tensor
 */
export function planBuffers(graph, paramGrads, writebackDecls = []) {
    const buffers = [];
    const tensorToBuffer = new Map();
    const paramsByName = new Map();
    const inputsByName = new Map();
    const paramGradsByName = new Map();
    const statesByName = new Map();
    // Reverse map: gradient tensor id -> param name.
    const gradTensorIdToName = new Map();
    for (const [name, tensor] of Object.entries(paramGrads)) {
        gradTensorIdToName.set(tensor.id, name);
    }
    // Output tensor id -> producing op, so input-like buffers can be named.
    const opByOutId = new Map();
    for (const op of graph.ops)
        opByOutId.set(op.out, op);
    const outputSet = new Set(graph.outputs);
    // Walk all tensors and categorize each.
    for (const t of graph.tensors) {
        const op = opByOutId.get(t.id);
        let kind = 'intermediate';
        let name = null;
        let initValue;
        if (op?.kind === 'param_input') {
            kind = 'param';
            name = op.name;
        }
        else if (op?.kind === 'tensor_input') {
            kind = 'tensor_input';
            name = op.name;
        }
        else if (op?.kind === 'state_input') {
            kind = 'state';
            name = op.name;
            initValue = op.initValue;
        }
        else if (gradTensorIdToName.has(t.id)) {
            kind = 'param_grad';
            name = gradTensorIdToName.get(t.id);
        }
        else if (outputSet.has(t.id)) {
            kind = 'output';
        }
        // An unknown dtype would otherwise turn byteSize into NaN silently
        // (undefined * n, and Math.max(4, NaN) is NaN).
        const elemBytes = dtypeBytes[t.dtype];
        if (typeof elemBytes !== 'number') {
            throw new Error(`planBuffers: tensor #${t.id} has unsupported dtype '${t.dtype}'`);
        }
        const spec = {
            id: t.id,
            // Never plan a buffer smaller than 4 bytes.
            byteSize: Math.max(4, shapeSize(t.shape) * elemBytes),
            dtype: t.dtype,
            shape: t.shape,
            kind,
            name,
            ...(initValue !== undefined ? { initValue } : {}),
        };
        buffers.push(spec);
        tensorToBuffer.set(t.id, t.id); // 1:1 for v1
        if (kind === 'param')
            paramsByName.set(name, t.id);
        if (kind === 'tensor_input')
            inputsByName.set(name, t.id);
        if (kind === 'param_grad')
            paramGradsByName.set(name, t.id);
        if (kind === 'state')
            statesByName.set(name, t.id);
    }
    const outputBufferIds = graph.outputs.map(id => tensorToBuffer.get(id));
    // Resolve writeback declarations to (source, dest) buffer-id pairs.
    const writebacks = writebackDecls.map(decl => {
        const sourceBufId = tensorToBuffer.get(decl.source.id);
        if (sourceBufId === undefined) {
            throw new Error(`planBuffers: writeback source tensor #${decl.source.id} not in graph`);
        }
        const destBufId = decl.destKind === 'param'
            ? paramsByName.get(decl.destName)
            : statesByName.get(decl.destName);
        if (destBufId === undefined) {
            throw new Error(`planBuffers: writeback dest ${decl.destKind}:'${decl.destName}' not found`);
        }
        const sourceSpec = buffers[sourceBufId];
        const destSpec = buffers[destBufId];
        // A writeback copies the whole buffer, so both ends must match in size.
        if (sourceSpec.byteSize !== destSpec.byteSize) {
            throw new Error(`planBuffers: writeback size mismatch for ${decl.destKind}:'${decl.destName}' ` +
                `(source ${sourceSpec.byteSize} bytes vs dest ${destSpec.byteSize})`);
        }
        return { source: sourceBufId, dest: destBufId, bytes: sourceSpec.byteSize };
    });
    // Resolve graph.captures (name -> tensor id) to (name -> buffer id).
    // No pinning needed at the planner level: each tensor already has its own
    // buffer.
    const capturesByName = new Map();
    for (const [name, tensorId] of graph.captures) {
        const bufId = tensorToBuffer.get(tensorId);
        if (bufId === undefined) {
            throw new Error(`planBuffers: capture '${name}' references unknown tensor #${tensorId}`);
        }
        capturesByName.set(name, bufId);
    }
    return { buffers, tensorToBuffer, paramsByName, inputsByName, paramGradsByName, statesByName, capturesByName, outputBufferIds, writebacks };
}
120
- //# sourceMappingURL=buffers.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"buffers.js","sourceRoot":"","sources":["../src/buffers.ts"],"names":[],"mappings":"AAAA,iFAAiF;AACjF,EAAE;AACF,+EAA+E;AAC/E,4EAA4E;AAC5E,8EAA8E;AAC9E,4EAA4E;AAC5E,EAAE;AACF,kDAAkD;AAClD,gFAAgF;AAChF,qFAAqF;AACrF,iEAAiE;AACjE,wEAAwE;AACxE,8EAA8E;AAG9E,OAAO,EAAE,SAAS,EAAE,MAAM,YAAY,CAAA;AAyCtC,MAAM,UAAU,GAA0B,EAAE,GAAG,EAAE,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,IAAI,EAAE,CAAC,EAAE,CAAA;AAcrE;;;;;;GAMG;AACH,MAAM,UAAU,WAAW,CACzB,KAAY,EACZ,UAAkC,EAClC,iBAAkC,EAAE;IAEpC,MAAM,OAAO,GAAiB,EAAE,CAAA;IAChC,MAAM,cAAc,GAAG,IAAI,GAAG,EAAkB,CAAA;IAChD,MAAM,YAAY,GAAG,IAAI,GAAG,EAAkB,CAAA;IAC9C,MAAM,YAAY,GAAG,IAAI,GAAG,EAAkB,CAAA;IAC9C,MAAM,gBAAgB,GAAG,IAAI,GAAG,EAAkB,CAAA;IAClD,MAAM,YAAY,GAAG,IAAI,GAAG,EAAkB,CAAA;IAE9C,iEAAiE;IACjE,MAAM,kBAAkB,GAAG,IAAI,GAAG,EAAkB,CAAA;IACpD,KAAK,MAAM,CAAC,IAAI,EAAE,MAAM,CAAC,IAAI,MAAM,CAAC,OAAO,CAAC,UAAU,CAAC,EAAE,CAAC;QACxD,kBAAkB,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,EAAE,IAAI,CAAC,CAAA;IACzC,CAAC;IACD,2EAA2E;IAC3E,MAAM,SAAS,GAAG,IAAI,GAAG,EAAkB,CAAA;IAC3C,KAAK,MAAM,EAAE,IAAI,KAAK,CAAC,GAAG;QAAE,SAAS,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,EAAE,CAAC,CAAA;IAErD,MAAM,SAAS,GAAG,IAAI,GAAG,CAAC,KAAK,CAAC,OAAO,CAAC,CAAA;IAExC,iDAAiD;IACjD,KAAK,MAAM,CAAC,IAAI,KAAK,CAAC,OAAO,EAAE,CAAC;QAC9B,MAAM,EAAE,GAAG,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAA;QAC9B,IAAI,IAAI,GAAuB,cAAc,CAAA;QAC7C,IAAI,IAAI,GAAkB,IAAI,CAAA;QAC9B,IAAI,SAA6B,CAAA;QAEjC,IAAI,EAAE,EAAE,IAAI,KAAK,aAAa,EAAE,CAAC;YAC/B,IAAI,GAAG,OAAO,CAAA;YACd,IAAI,GAAG,EAAE,CAAC,IAAI,CAAA;QAChB,CAAC;aAAM,IAAI,EAAE,EAAE,IAAI,KAAK,cAAc,EAAE,CAAC;YACvC,IAAI,GAAG,cAAc,CAAA;YACrB,IAAI,GAAG,EAAE,CAAC,IAAI,CAAA;QAChB,CAAC;aAAM,IAAI,EAAE,EAAE,IAAI,KAAK,aAAa,EAAE,CAAC;YACtC,IAAI,GAAG,OAAO,CAAA;YACd,IAAI,GAAG,EAAE,CAAC,IAAI,CAAA;YACd,SAAS,GAAG,EAAE,CAAC,SAAS,CAAA;QAC1B,CAAC;aAAM,IAAI,kBAAkB,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC;YACxC,IAAI,GAAG,YAAY,CAAA;YACnB,IAAI,GAAG,kBAAkB,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAE,CAAA;QACtC,CAAC;aAAM,IAAI,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC;YAC/B,IAAI,GAAG,QAA
Q,CAAA;QACjB,CAAC;QAED,MAAM,IAAI,GAAe;YACvB,EAAE,EAAE,CAAC,CAAC,EAAE;YACR,QAAQ,EAAE,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,KAAK,CAAC,GAAG,UAAU,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;YAC/D,KAAK,EAAE,CAAC,CAAC,KAAK;YACd,KAAK,EAAE,CAAC,CAAC,KAAK;YACd,IAAI;YACJ,IAAI;YACJ,GAAG,CAAC,SAAS,KAAK,SAAS,CAAC,CAAC,CAAC,EAAE,SAAS,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC;SAClD,CAAA;QACD,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,CAAA;QAClB,cAAc,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA,CAAE,aAAa;QAE7C,IAAI,IAAI,KAAK,OAAO;YAAE,YAAY,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;QACnD,IAAI,IAAI,KAAK,cAAc;YAAE,YAAY,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;QAC1D,IAAI,IAAI,KAAK,YAAY;YAAE,gBAAgB,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;QAC5D,IAAI,IAAI,KAAK,OAAO;YAAE,YAAY,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;IACrD,CAAC;IAED,MAAM,eAAe,GAAG,KAAK,CAAC,OAAO,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,cAAc,CAAC,GAAG,CAAC,EAAE,CAAE,CAAC,CAAA;IAExE,oEAAoE;IACpE,MAAM,UAAU,GAAgB,cAAc,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE;QACxD,MAAM,WAAW,GAAG,cAAc,CAAC,GAAG,CAAC,IAAI,CAAC,MAAM,CAAC,EAAE,CAAC,CAAA;QACtD,IAAI,WAAW,KAAK,SAAS,EAAE,CAAC;YAC9B,MAAM,IAAI,KAAK,CAAC,yCAAyC,IAAI,CAAC,MAAM,CAAC,EAAE,eAAe,CAAC,CAAA;QACzF,CAAC;QACD,MAAM,SAAS,GAAG,IAAI,CAAC,QAAQ,KAAK,OAAO;YACzC,CAAC,CAAC,YAAY,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC;YACjC,CAAC,CAAC,YAAY,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAA;QACnC,IAAI,SAAS,KAAK,SAAS,EAAE,CAAC;YAC5B,MAAM,IAAI,KAAK,CAAC,+BAA+B,IAAI,CAAC,QAAQ,KAAK,IAAI,CAAC,QAAQ,aAAa,CAAC,CAAA;QAC9F,CAAC;QACD,MAAM,UAAU,GAAG,OAAO,CAAC,WAAW,CAAE,CAAA;QACxC,MAAM,QAAQ,GAAG,OAAO,CAAC,SAAS,CAAE,CAAA;QACpC,IAAI,UAAU,CAAC,QAAQ,KAAK,QAAQ,CAAC,QAAQ,EAAE,CAAC;YAC9C,MAAM,IAAI,KAAK,CACb,4CAA4C,IAAI,CAAC,QAAQ,KAAK,IAAI,CAAC,QAAQ,IAAI;gBAC/E,WAAW,UAAU,CAAC,QAAQ,kBAAkB,QAAQ,CAAC,QAAQ,GAAG,CACrE,CAAA;QACH,CAAC;QACD,OAAO,EAAE,MAAM,EAAE,WAAW,EAAE,IAAI,EAAE,SAAS,EAAE,KAAK,EAAE,UAAU,CAAC,QAAQ,EAAE,CAAA;IAC7E,CAAC,CAAC,CAAA;IAEF,qEAAqE;IACrE,0EAA0E;IAC1E,8DAA8D;IAC9D,MAAM,cAAc,GAAG,IAAI,GAAG,EAAkB,CAAA;IAChD,K
AAK,MAAM,CAAC,IAAI,EAAE,QAAQ,CAAC,IAAI,KAAK,CAAC,QAAQ,EAAE,CAAC;QAC9C,MAAM,KAAK,GAAG,cAAc,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAA;QAC1C,IAAI,KAAK,KAAK,SAAS,EAAE,CAAC;YACxB,MAAM,IAAI,KAAK,CAAC,yBAAyB,IAAI,gCAAgC,QAAQ,EAAE,CAAC,CAAA;QAC1F,CAAC;QACD,cAAc,CAAC,GAAG,CAAC,IAAI,EAAE,KAAK,CAAC,CAAA;IACjC,CAAC;IAED,OAAO,EAAE,OAAO,EAAE,cAAc,EAAE,YAAY,EAAE,YAAY,EAAE,gBAAgB,EAAE,YAAY,EAAE,cAAc,EAAE,eAAe,EAAE,UAAU,EAAE,CAAA;AAC7I,CAAC"}
package/dist/capture.js DELETED
@@ -1,33 +0,0 @@
1
- // Activation capture — opt-in readback of intermediate tensors at training step.
2
- //
3
- // Usage (inside the user's forward pass):
4
- //
5
- // import { capture } from 'tensorgrad'
6
- //
7
- // function attentionFwd(p, x) {
8
- // const scores = mul(matmulBatched(q, kT), SCALE_QK)
9
- // const attn = capture(`attn.${layerIdx}`, softmaxCausalLast(scores))
10
- // return matmulBatched(attn, v)
11
- // }
12
- //
13
- // Pass-through return type: `capture(name, t)` returns `t` unchanged so it
14
- // inlines at the point of computation. Behind the scenes it registers `t.id`
15
- // against `name` on the current graph; runtime exposes the registered tensors
16
- // via `step(inputs, { withCaptures: true })`.
17
- //
18
- // Outside the user's forward trace (during `appendGrad` / `appendAdam`'s
19
- // `traceInto` re-entry), `capture()` is a no-op — gradient and optimizer
20
- // internals shouldn't accidentally publish themselves to the UI.
21
- import { currentGraph, isCaptureEnabled } from './trace.js';
22
/**
 * Register tensor `t` under `name` on the current graph and return `t`
 * unchanged, so the call can wrap an expression in place.
 *
 * When capturing is disabled the call does nothing and the current graph is
 * not consulted.
 *
 * @throws Error when `name` is already registered on the current graph
 */
export function capture(name, t) {
    if (isCaptureEnabled()) {
        const graph = currentGraph();
        const nameTaken = graph.captures.has(name);
        if (nameTaken) {
            throw new Error(`capture: name '${name}' already registered. Use unique names ` +
                `(e.g. \`attn.\${layerIdx}\`) when capturing across a loop.`);
        }
        graph.captures.set(name, t.id);
    }
    return t;
}
33
- //# sourceMappingURL=capture.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"capture.js","sourceRoot":"","sources":["../src/capture.ts"],"names":[],"mappings":"AAAA,iFAAiF;AACjF,EAAE;AACF,0CAA0C;AAC1C,EAAE;AACF,yCAAyC;AACzC,EAAE;AACF,kCAAkC;AAClC,yDAAyD;AACzD,0EAA0E;AAC1E,oCAAoC;AACpC,MAAM;AACN,EAAE;AACF,2EAA2E;AAC3E,6EAA6E;AAC7E,8EAA8E;AAC9E,8CAA8C;AAC9C,EAAE;AACF,yEAAyE;AACzE,yEAAyE;AACzE,iEAAiE;AAGjE,OAAO,EAAE,YAAY,EAAE,gBAAgB,EAAE,MAAM,YAAY,CAAA;AAE3D,MAAM,UAAU,OAAO,CAAmB,IAAY,EAAE,CAAI;IAC1D,IAAI,CAAC,gBAAgB,EAAE;QAAE,OAAO,CAAC,CAAA;IACjC,MAAM,CAAC,GAAG,YAAY,EAAE,CAAA;IACxB,IAAI,CAAC,CAAC,QAAQ,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC;QACzB,MAAM,IAAI,KAAK,CACb,kBAAkB,IAAI,yCAAyC;YAC/D,4DAA4D,CAC7D,CAAA;IACH,CAAC;IACD,CAAC,CAAC,QAAQ,CAAC,GAAG,CAAC,IAAI,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;IAC1B,OAAO,CAAC,CAAA;AACV,CAAC"}