tensorgrad 0.0.15 → 0.0.17
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +253 -119
- package/dist/index.d.ts +154 -193
- package/dist/index.js +2208 -39
- package/dist/index.js.map +7 -1
- package/dist/worker.debug.js +569 -0
- package/package.json +60 -58
- package/src/adam.ts +69 -15
- package/src/compile.ts +334 -156
- package/src/index.ts +8 -4
- package/src/module.ts +72 -34
- package/src/runtime.ts +56 -31
- package/src/worker-protocol.ts +183 -0
- package/src/worker-proxy.ts +76 -0
- package/src/worker.ts +281 -0
- package/dist/adam.js +0 -111
- package/dist/adam.js.map +0 -1
- package/dist/buffers.js +0 -120
- package/dist/buffers.js.map +0 -1
- package/dist/capture.js +0 -33
- package/dist/capture.js.map +0 -1
- package/dist/codegen.js +0 -724
- package/dist/codegen.js.map +0 -1
- package/dist/compile.js +0 -184
- package/dist/compile.js.map +0 -1
- package/dist/grad.js +0 -380
- package/dist/grad.js.map +0 -1
- package/dist/ir.js +0 -60
- package/dist/ir.js.map +0 -1
- package/dist/module.js +0 -155
- package/dist/module.js.map +0 -1
- package/dist/nn.js +0 -135
- package/dist/nn.js.map +0 -1
- package/dist/ops.js +0 -326
- package/dist/ops.js.map +0 -1
- package/dist/runtime.js +0 -402
- package/dist/runtime.js.map +0 -1
- package/dist/shape.js +0 -259
- package/dist/shape.js.map +0 -1
- package/dist/trace.js +0 -100
- package/dist/trace.js.map +0 -1
package/src/worker.ts
ADDED
@@ -0,0 +1,281 @@
// Worker entry point. Holds the GPUDevice + CompiledRuntime for one or more
// graphs and proxies main-thread requests via postMessage. See
// specs/WorkerArchitecture.md for the rationale.
//
// Keep this file dependency-free of anything DOM-y: it bundles into a Blob
// URL and runs in a Web Worker context where `window`/`document` don't
// exist. WebGPU IS available in workers (Chrome 113+, Safari 17.4+).

import { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js'
import { resolveLR, type LRSchedule } from './adam.js'
import type { Req, Res, WireIR, WireAdamConfig, WireError } from './worker-protocol.js'
import { wireError } from './worker-protocol.js'

// ----------------------------------------------------------------------------
// Per-graph state
// ----------------------------------------------------------------------------

interface GraphSlot {
  runtime: CompiledRuntime
  paramNames: readonly string[]
  outputShape: number[]
  kernelCount: number
  captureShapes: Record<string, number[]>
  /** Adam state for this graph, if it's a training graph. The wrapped step
   * uses these to populate the per-step lrt and decayShrink scalars. */
  adam: AdamState | null
}

interface AdamState {
  config: WireAdamConfig
  t: number
  lrtBuf: Float32Array
  decayShrinkBuf: Float32Array | null
}

const graphs = new Map<number, GraphSlot>()

// Worker holds one device shared across all graphs (sibling forward graphs
// must share param GPUBuffers, which means sharing a device).
let device: GPUDevice | null = null

async function ensureDevice(): Promise<GPUDevice> {
  if (device) return device
  if (typeof navigator === 'undefined' || !navigator.gpu) {
    throw new Error('tensorgrad worker: WebGPU not available in this environment')
  }
  const adapter = await navigator.gpu.requestAdapter()
  if (!adapter) throw new Error('tensorgrad worker: no WebGPU adapter')
  device = await adapter.requestDevice()
  return device
}

// ----------------------------------------------------------------------------
// Request handlers
// ----------------------------------------------------------------------------

async function handleCreateRuntime(payload: {
  graphId: number
  ir: WireIR
  initialParams: Record<string, Float32Array>
  adam: WireAdamConfig | null
}): Promise<{ paramNames: string[]; outputShape: number[]; kernelCount: number; captureShapes: Record<string, number[]> }> {
  const dev = await ensureDevice()
  const { graph, plan, kernels } = payload.ir
  const outputTensorId = graph.outputs[0]!
  const outputBufferId = plan.tensorToBuffer.get(outputTensorId)!
  const opts: RuntimeOpts = { device: dev }
  const runtime = await createRuntime(plan, kernels, outputBufferId, opts)

  // Upload initial params.
  if (Object.keys(payload.initialParams).length > 0) {
    runtime.uploadParams(payload.initialParams)
  }

  // Capture shape metadata for return.
  const captureShapes: Record<string, number[]> = {}
  for (const [name, bufId] of plan.capturesByName) {
    captureShapes[name] = [...plan.buffers[bufId]!.shape]
  }

  const slot: GraphSlot = {
    runtime,
    paramNames: [...plan.paramsByName.keys()],
    outputShape: [...runtime.outputShape],
    kernelCount: kernels.filter(k => k.wgsl).length,
    captureShapes,
    adam: payload.adam ? createAdamState(payload.adam) : null,
  }
  graphs.set(payload.graphId, slot)

  return {
    paramNames: [...slot.paramNames],
    outputShape: slot.outputShape,
    kernelCount: slot.kernelCount,
    captureShapes: slot.captureShapes,
  }
}

async function handleCompileForward(payload: {
  graphId: number
  parentGraphId: number
  ir: WireIR
}): Promise<{ paramNames: string[]; outputShape: number[]; kernelCount: number; captureShapes: Record<string, number[]> }> {
  const dev = await ensureDevice()
  const parent = graphs.get(payload.parentGraphId)
  if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`)

  const { graph, plan, kernels } = payload.ir
  const outputTensorId = graph.outputs[0]!
  const outputBufferId = plan.tensorToBuffer.get(outputTensorId)!
  const opts: RuntimeOpts = { device: dev, sharedParams: parent.runtime.params }
  const runtime = await createRuntime(plan, kernels, outputBufferId, opts)
  // No initial-param upload — sharedParams covers everything.

  const captureShapes: Record<string, number[]> = {}
  for (const [name, bufId] of plan.capturesByName) {
    captureShapes[name] = [...plan.buffers[bufId]!.shape]
  }

  const slot: GraphSlot = {
    runtime,
    paramNames: [...plan.paramsByName.keys()],
    outputShape: [...runtime.outputShape],
    kernelCount: kernels.filter(k => k.wgsl).length,
    captureShapes,
    adam: null,
  }
  graphs.set(payload.graphId, slot)

  return {
    paramNames: [...slot.paramNames],
    outputShape: slot.outputShape,
    kernelCount: slot.kernelCount,
    captureShapes: slot.captureShapes,
  }
}

function createAdamState(cfg: WireAdamConfig): AdamState {
  return {
    config: cfg,
    t: 0,
    lrtBuf: new Float32Array(1),
    decayShrinkBuf: cfg.decayShrinkInputName ? new Float32Array(1) : null,
  }
}

/** Inject Adam's per-step lrt + decayShrink scalars into the inputs map.
 * Called before every step on a training graph. The buffers are reused
 * across steps to avoid allocation. */
function injectAdamScalars(slot: GraphSlot, inputs: Record<string, Int32Array | Float32Array>): Record<string, Int32Array | Float32Array> {
  const a = slot.adam
  if (!a) return inputs
  a.t++
  const lrNow = resolveLR(a.config.lr as LRSchedule, a.t)
  a.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(a.config.b2, a.t)) / (1 - Math.pow(a.config.b1, a.t))
  const merged: Record<string, Int32Array | Float32Array> = { ...inputs, [a.config.lrtInputName]: a.lrtBuf }
  if (a.decayShrinkBuf && a.config.decayShrinkInputName) {
    a.decayShrinkBuf[0] = 1 - lrNow * a.config.weightDecay
    merged[a.config.decayShrinkInputName] = a.decayShrinkBuf
  }
  return merged
}

async function handleStep(payload: {
  graphId: number
  inputs: Record<string, Int32Array | Float32Array>
  withCaptures: boolean
}): Promise<{ loss: number; captures: Record<string, Float32Array> | null }> {
  const slot = mustGet(payload.graphId)
  const merged = injectAdamScalars(slot, payload.inputs)
  if (payload.withCaptures) {
    const r = await slot.runtime.step(merged, { withCaptures: true })
    return { loss: r.loss, captures: capturesToRecord(r.captures, slot.captureShapes) }
  }
  const loss = await slot.runtime.step(merged)
  return { loss, captures: null }
}

async function handleRun(payload: {
  graphId: number
  inputs: Record<string, Int32Array | Float32Array>
  withCaptures: boolean
}): Promise<{ output: Float32Array; captures: Record<string, Float32Array> | null }> {
  const slot = mustGet(payload.graphId)
  if (payload.withCaptures) {
    const r = await slot.runtime.run(payload.inputs, { withCaptures: true })
    return { output: r.output, captures: capturesToRecord(r.captures, slot.captureShapes) }
  }
  const output = await slot.runtime.run(payload.inputs)
  return { output, captures: null }
}

/** Captures (a class instance with a private Map) → a plain Record so the
 * worker can transfer Float32Arrays back without serializing the class. */
function capturesToRecord(
  captures: { get(name: string): Float32Array; has(name: string): boolean; names(): string[] },
  // captureShapes available but not used directly — capture names from
  // shapes in case captures.names() is filtered (it isn't, but be safe).
  shapes: Record<string, number[]>,
): Record<string, Float32Array> {
  const out: Record<string, Float32Array> = {}
  for (const name of Object.keys(shapes)) {
    if (captures.has(name)) out[name] = captures.get(name)
  }
  return out
}

function handleUploadParams(payload: {
  graphId: number
  params: Record<string, Float32Array>
  partial: boolean
}): void {
  const slot = mustGet(payload.graphId)
  slot.runtime.uploadParams(payload.params, { partial: payload.partial })
}

async function handleDownloadParams(payload: { graphId: number }): Promise<{ params: Record<string, Float32Array> }> {
  const slot = mustGet(payload.graphId)
  return { params: await slot.runtime.downloadParams() }
}

async function handleDownloadParamGrads(payload: { graphId: number }): Promise<{ params: Record<string, Float32Array> }> {
  const slot = mustGet(payload.graphId)
  return { params: await slot.runtime.downloadParamGrads() }
}

function handleResetOptimizer(payload: { graphId: number }): void {
  const slot = mustGet(payload.graphId)
  slot.runtime.resetOptimizerState()
  if (slot.adam) slot.adam.t = 0
}

function handleDestroy(payload: { graphId: number }): void {
  const slot = graphs.get(payload.graphId)
  if (!slot) return
  slot.runtime.destroy()
  graphs.delete(payload.graphId)
}

function mustGet(graphId: number): GraphSlot {
  const slot = graphs.get(graphId)
  if (!slot) throw new Error(`tensorgrad worker: graph ${graphId} not found`)
  return slot
}

// ----------------------------------------------------------------------------
// Message dispatch
// ----------------------------------------------------------------------------

self.onmessage = async (ev: MessageEvent<Req>) => {
  const req = ev.data
  try {
    let result: unknown
    let transferList: ArrayBuffer[] = []
    switch (req.kind) {
      case 'createRuntime': result = await handleCreateRuntime(req.payload); break
      case 'compileForward': result = await handleCompileForward(req.payload); break
      case 'step': result = await handleStep(req.payload); transferList = collectTransfers((result as any).captures); break
      case 'run': { const r = await handleRun(req.payload); result = r; transferList = [r.output.buffer as ArrayBuffer, ...collectTransfers(r.captures)]; break }
      case 'uploadParams': handleUploadParams(req.payload); result = null; break
      case 'downloadParams': { const r = await handleDownloadParams(req.payload); result = r; transferList = collectTransfers(r.params); break }
      case 'downloadParamGrads': { const r = await handleDownloadParamGrads(req.payload); result = r; transferList = collectTransfers(r.params); break }
      case 'resetOptimizer': handleResetOptimizer(req.payload); result = null; break
      case 'destroy': handleDestroy(req.payload); result = null; break
      default: throw new Error(`unknown request kind: ${(req as { kind: string }).kind}`)
    }
    const reply: Res = { id: req.id, ok: true, result }
    self.postMessage(reply, { transfer: transferList })
  } catch (e) {
    const error: WireError = wireError(e)
    const reply: Res = { id: req.id, ok: false, error }
    self.postMessage(reply)
  }
}

function collectTransfers(rec: Record<string, Float32Array> | null | undefined): ArrayBuffer[] {
  if (!rec) return []
  const out: ArrayBuffer[] = []
  for (const v of Object.values(rec)) out.push(v.buffer as ArrayBuffer)
  return out
}
package/dist/adam.js
DELETED
@@ -1,111 +0,0 @@
// Adam / AdamW optimizer, in-graph.
//
// `appendAdam` extends a graph that already has a forward pass + autograd-emitted
// backward (i.e., has paramGrads from `appendGrad`) with the Adam update math.
//
// Per parameter P with gradient g:
//   m_new = b1 * m + (1 - b1) * g
//   v_new = b2 * v + (1 - b2) * g²
//   p_new = decayShrink * p - lrt * m_new / (sqrt(v_new) + eps)
//
// `decayShrink = 1 - lr * weightDecay` when the param is being decayed
// (Loshchilov & Hutter, "AdamW") and 1 otherwise — at which point the
// multiply folds out and you're left with plain Adam. `lrt` is supplied
// per-step from CPU and includes the bias-correction factor
// `sqrt(1-b2^t)/(1-b1^t)`; that's why convergence isn't affected by the
// first-step warmup that bias-correction-free Adam suffers.
//
// **Static vs scheduled lr.** When `config.lr` is a number, decayShrink is
// baked into the kernel as a literal. When it's a function `(step) => lr`,
// decayShrink for decayed params becomes a per-step scalar input that the
// runtime updates each call (computed from the current step's lr). lrt is
// always per-step; the bias-correction factor changes every step regardless.
//
// Returns writeback declarations the buffer planner uses to wire up the
// "after step, copy the new value into the persistent home" path. m and v
// are state_inputs (zero-initialized, persistent across steps); the param
// updates are aliased back to the param buffers.
import { traceInto, stateInput, tensorInput } from './trace.js';
import { adamUpdateM, adamUpdateV, adamUpdateP } from './ops.js';
/**
 * Append Adam update ops to `graph`. Must be called inside an active trace
 * context (or after a trace, since traceInto re-enters the graph).
 *
 * @param graph the graph (already containing forward + backward)
 * @param paramGrads param name -> gradient tensor (output of `appendGrad`)
 * @param paramTensors param name -> the param's leaf Tensor (the param_input).
 *   Needed because the param_input lives in the graph but we
 *   don't have a direct map by name in `Graph` — caller passes it.
 * @param config Adam hyperparameters. Set `weightDecay > 0` for AdamW; an
 *   optional `decayFilter` selects which params receive decay.
 */
export function appendAdam(graph, paramGrads, paramTensors, config,
/** Per-param decay flags from `materializeParams`. When supplied, overrides
 * `config.decayFilter` for any name in the map; falls back to `decayFilter`
 * for names not present (e.g., for low-level callers using `compile()`
 * directly without a Module). */
decayFlags) {
    const lrIsScheduled = typeof config.lr === 'function';
    const lrFn = lrIsScheduled
        ? config.lr
        : (() => config.lr);
    const initialLr = lrFn(1);
    const fullConfig = {
        lr: lrFn,
        b1: config.b1 ?? 0.9,
        b2: config.b2 ?? 0.999,
        eps: config.eps ?? 1e-8,
        weightDecay: config.weightDecay ?? 0,
        decayFilter: config.decayFilter ?? (() => true),
        lrIsScheduled,
    };
    const writebacks = [];
    const lrtInputName = '_adam_lrt';
    // Tensor input for runtime-updated decayShrink (only created when lr is a
    // schedule fn AND at least one param will receive weight decay).
    let decayShrinkInputName = null;
    return traceInto(graph, () => {
        const lrt = tensorInput(lrtInputName, [], 'f32');
        // Up-front: which params receive weight decay? Per-param decayFlags (set
        // by Module.param's options) wins; falls back to decayFilter for names
        // not in the map. Empty when weightDecay = 0 so the rest of the function
        // can just ask "is this name in the set?".
        const decayedNames = new Set(fullConfig.weightDecay > 0
            ? Object.keys(paramGrads).filter(name => (decayFlags && name in decayFlags) ? decayFlags[name] : fullConfig.decayFilter(name))
            : []);
        // We only need a runtime decayShrink scalar when lr varies per step AND
        // at least one param is being decayed. Otherwise the value is constant
        // and bakes into the kernel as a literal.
        let decayShrinkScalar = null;
        if (lrIsScheduled && decayedNames.size > 0) {
            decayShrinkInputName = '_adam_decay_shrink';
            decayShrinkScalar = tensorInput(decayShrinkInputName, [], 'f32');
        }
        for (const name of Object.keys(paramGrads)) {
            const p = paramTensors[name];
            const g = paramGrads[name];
            if (!p)
                throw new Error(`appendAdam: missing param tensor for '${name}'`);
            if (!g)
                throw new Error(`appendAdam: missing gradient for '${name}'`);
            const mState = stateInput(`adam_m_${name}`, p.shape, 'f32', 0);
            const vState = stateInput(`adam_v_${name}`, p.shape, 'f32', 0);
            // Choose the decayShrink form per param:
            //  - non-decayed params: literal 1 (kernel multiply folds out).
            //  - decayed + scheduled lr: tensor input updated per step.
            //  - decayed + static lr: literal `1 - lr * wd` baked at compile.
            const decayShrink = !decayedNames.has(name) ? 1
                : decayShrinkScalar !== null ? decayShrinkScalar
                    : 1 - initialLr * fullConfig.weightDecay;
            // Three fused kernels per parameter — one for each of m_new / v_new / p_new.
            const newM = adamUpdateM(mState, g, fullConfig.b1);
            const newV = adamUpdateV(vState, g, fullConfig.b2);
            const newP = adamUpdateP(p, newM, newV, lrt, fullConfig.eps, decayShrink);
            writebacks.push({ source: newM, destName: `adam_m_${name}`, destKind: 'state' });
            writebacks.push({ source: newV, destName: `adam_v_${name}`, destKind: 'state' });
            writebacks.push({ source: newP, destName: name, destKind: 'param' });
        }
        return { writebacks, lrtInputName, decayShrinkInputName, config: fullConfig };
    });
}
//# sourceMappingURL=adam.js.map
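The header comment above fully specifies the update; as a plain-TypeScript cross-check of that math (and of the `lrt` scalar that src/worker.ts's injectAdamScalars feeds in each step), here is a scalar reference version. It is illustrative only and not part of the package; the function name is an assumption, and the hyperparameter defaults are taken from the fallbacks in appendAdam above.

// Scalar reference of the documented Adam/AdamW update, illustrative only.
// m/v carry no explicit bias correction because lrt folds it in, exactly as
// the comment block above describes.
function adamStepScalar(
  p: number, g: number, m: number, v: number, t: number, lr: number,
  b1 = 0.9, b2 = 0.999, eps = 1e-8, weightDecay = 0,
): { p: number; m: number; v: number } {
  const mNew = b1 * m + (1 - b1) * g
  const vNew = b2 * v + (1 - b2) * g * g
  // Per-step scalar supplied from CPU: lr * sqrt(1 - b2^t) / (1 - b1^t).
  const lrt = lr * Math.sqrt(1 - Math.pow(b2, t)) / (1 - Math.pow(b1, t))
  // AdamW shrink for decayed params; 1 when weightDecay is 0 (plain Adam).
  const decayShrink = weightDecay > 0 ? 1 - lr * weightDecay : 1
  const pNew = decayShrink * p - lrt * mNew / (Math.sqrt(vNew) + eps)
  return { p: pNew, m: mNew, v: vNew }
}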
package/dist/adam.js.map
DELETED
@@ -1 +0,0 @@
{"version":3,"file":"adam.js","sourceRoot":"","sources":["../src/adam.ts"],"names":[],"mappings":"AAAA,oCAAoC;AACpC,EAAE;AACF,kFAAkF;AAClF,+EAA+E;AAC/E,EAAE;AACF,mCAAmC;AACnC,kCAAkC;AAClC,mCAAmC;AACnC,gEAAgE;AAChE,EAAE;AACF,uEAAuE;AACvE,sEAAsE;AACtE,wEAAwE;AACxE,4DAA4D;AAC5D,wEAAwE;AACxE,4DAA4D;AAC5D,EAAE;AACF,2EAA2E;AAC3E,2EAA2E;AAC3E,0EAA0E;AAC1E,0EAA0E;AAC1E,6EAA6E;AAC7E,EAAE;AACF,wEAAwE;AACxE,0EAA0E;AAC1E,0EAA0E;AAC1E,iDAAiD;AAKjD,OAAO,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,MAAM,YAAY,CAAA;AAC/D,OAAO,EAAE,WAAW,EAAE,WAAW,EAAE,WAAW,EAAE,MAAM,UAAU,CAAA;AAgDhE;;;;;;;;;;;GAWG;AACH,MAAM,UAAU,UAAU,CACxB,KAAY,EACZ,UAAkC,EAClC,YAAoC,EACpC,MAAkB;AAClB;;;kCAGkC;AAClC,UAAoC;IAEpC,MAAM,aAAa,GAAG,OAAO,MAAM,CAAC,EAAE,KAAK,UAAU,CAAA;IACrD,MAAM,IAAI,GAAG,aAAa;QACxB,CAAC,CAAC,MAAM,CAAC,EAA8B;QACvC,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,EAAY,CAAC,CAAA;IAC/B,MAAM,SAAS,GAAG,IAAI,CAAC,CAAC,CAAC,CAAA;IACzB,MAAM,UAAU,GAAuB;QACrC,EAAE,EAAE,IAAI;QACR,EAAE,EAAE,MAAM,CAAC,EAAE,IAAI,GAAG;QACpB,EAAE,EAAE,MAAM,CAAC,EAAE,IAAI,KAAK;QACtB,GAAG,EAAE,MAAM,CAAC,GAAG,IAAI,IAAI;QACvB,WAAW,EAAE,MAAM,CAAC,WAAW,IAAI,CAAC;QACpC,WAAW,EAAE,MAAM,CAAC,WAAW,IAAI,CAAC,GAAG,EAAE,CAAC,IAAI,CAAC;QAC/C,aAAa;KACd,CAAA;IACD,MAAM,UAAU,GAAoB,EAAE,CAAA;IACtC,MAAM,YAAY,GAAG,WAAW,CAAA;IAChC,0EAA0E;IAC1E,iEAAiE;IACjE,IAAI,oBAAoB,GAAkB,IAAI,CAAA;IAE9C,OAAO,SAAS,CAAC,KAAK,EAAE,GAAG,EAAE;QAC3B,MAAM,GAAG,GAAG,WAAW,CAAC,YAAY,EAAE,EAAE,EAAE,KAAK,CAAC,CAAA;QAEhD,yEAAyE;QACzE,uEAAuE;QACvE,yEAAyE;QACzE,2CAA2C;QAC3C,MAAM,YAAY,GAAG,IAAI,GAAG,CAC1B,UAAU,CAAC,WAAW,GAAG,CAAC;YACxB,CAAC,CAAC,MAAM,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC,MAAM,CAAC,IAAI,CAAC,EAAE,CACpC,CAAC,UAAU,IAAI,IAAI,IAAI,UAAU,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,IAAI,CAAE,CAAC,CAAC,CAAC,UAAU,CAAC,WAAW,CAAC,IAAI,CAAC,CAAC;YAC1F,CAAC,CAAC,EAAE,CACP,CAAA;QAED,wEAAwE;QACxE,uEAAuE;QACvE,0CAA0C;QAC1C,IAAI,iBAAiB,GAAkB,IAAI,CAAA;QAC3C,IAAI,aAAa,IAAI,YAAY,CAAC,IAAI,GAAG,CAAC,EAAE,CAAC;YAC3C,oBAAoB,GAAG,oBAAoB,CAAA;YAC3C,iBAAiB,GAAG,WAAW,CAAC,oBAAoB,EAAE,EAAE,EAAE,KAAK,CAAC,CAAA;QAClE,CAAC;QAED,KAAK,MAAM,IAAI,IAAI,MAAM,CAAC,IAAI,CAAC,UAAU,CAAC,EAAE,CAAC;YAC3C,MAAM,CAAC,GAAG,YAAY,CAAC,IAAI,CAAC,CAAA;YAC5B,MAAM,CAAC,GAAG,UAAU,CAAC,IAAI,CAAC,CAAA;YAC1B,IAAI,CAAC,CAAC;gBAAE,MAAM,IAAI,KAAK,CAAC,yCAAyC,IAAI,GAAG,CAAC,CAAA;YACzE,IAAI,CAAC,CAAC;gBAAE,MAAM,IAAI,KAAK,CAAC,qCAAqC,IAAI,GAAG,CAAC,CAAA;YAErE,MAAM,MAAM,GAAG,UAAU,CAAC,UAAU,IAAI,EAAE,EAAE,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,CAAC,CAAC,CAAA;YAC9D,MAAM,MAAM,GAAG,UAAU,CAAC,UAAU,IAAI,EAAE,EAAE,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,CAAC,CAAC,CAAA;YAE9D,yCAAyC;YACzC,iEAAiE;YACjE,6DAA6D;YAC7D,mEAAmE;YACnE,MAAM,WAAW,GACf,CAAC,YAAY,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC;gBAC3B,CAAC,CAAC,iBAAiB,KAAK,IAAI,CAAC,CAAC,CAAC,iBAAiB;oBAChD,CAAC,CAAC,CAAC,GAAG,SAAS,GAAG,UAAU,CAAC,WAAW,CAAA;YAE1C,6EAA6E;YAC7E,MAAM,IAAI,GAAG,WAAW,CAAC,MAAM,EAAE,CAAC,EAAE,UAAU,CAAC,EAAE,CAAC,CAAA;YAClD,MAAM,IAAI,GAAG,WAAW,CAAC,MAAM,EAAE,CAAC,EAAE,UAAU,CAAC,EAAE,CAAC,CAAA;YAClD,MAAM,IAAI,GAAG,WAAW,CAAC,CAAC,EAAE,IAAI,EAAE,IAAI,EAAE,GAAG,EAAE,UAAU,CAAC,GAAG,EAAE,WAAW,CAAC,CAAA;YAEzE,UAAU,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE,UAAU,IAAI,EAAE,EAAE,QAAQ,EAAE,OAAO,EAAE,CAAC,CAAA;YAChF,UAAU,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE,UAAU,IAAI,EAAE,EAAE,QAAQ,EAAE,OAAO,EAAE,CAAC,CAAA;YAChF,UAAU,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE,IAAI,EAAc,QAAQ,EAAE,OAAO,EAAE,CAAC,CAAA;QAClF,CAAC;QACD,OAAO,EAAE,UAAU,EAAE,YAAY,EAAE,oBAAoB,EAAE,MAAM,EAAE,UAAU,EAAE,CAAA;IAC/E,CAAC,CAAC,CAAA;AACJ,CAAC"}
package/dist/buffers.js
DELETED
@@ -1,120 +0,0 @@
// Buffer planning: walk a Graph and decide which GPU buffer each Tensor maps to.
//
// v1 strategy: one GPU buffer per IR Tensor. Static shapes mean every buffer's
// size is known at compile time and lifetimes don't overlap between steps —
// so no pooling needed. Total memory is the sum of every intermediate tensor.
// For our transformer at B=256: ~30 MB of activations + grads. Easily fits.
//
// Categorization is what the runtime cares about:
//   * param        — uploaded by user via uploadParams; persistent across steps
//   * param_grad   — written each step by the backward pass; readable for inspection
//   * tensor_input — uploaded each step (tokens, targets, masks)
//   * intermediate — produced by an op; lifetime = within a single step
//   * output       — special intermediate that should be made readable (loss)
import { shapeSize } from './shape.js';
const dtypeBytes = { f32: 4, i32: 4, bool: 4 };
/**
 * Build a BufferPlan from a graph + the param-grad map produced by appendGrad.
 * @param graph the full graph (forward + backward + any optimizer ops)
 * @param paramGrads map from param name -> the Tensor that holds its gradient
 * @param writebackDecls list of end-of-step writebacks (e.g. from appendAdam).
 *   Empty when there's no optimizer in the graph.
 */
export function planBuffers(graph, paramGrads, writebackDecls = []) {
    const buffers = [];
    const tensorToBuffer = new Map();
    const paramsByName = new Map();
    const inputsByName = new Map();
    const paramGradsByName = new Map();
    const statesByName = new Map();
    // Build a quick reverse map: tensorId -> param name (for grads).
    const gradTensorIdToName = new Map();
    for (const [name, tensor] of Object.entries(paramGrads)) {
        gradTensorIdToName.set(tensor.id, name);
    }
    // ...and tensorId -> param/input op (so we can name the buffer correctly).
    const opByOutId = new Map();
    for (const op of graph.ops)
        opByOutId.set(op.out, op);
    const outputSet = new Set(graph.outputs);
    // Walk all tensors in id order. Categorize each.
    for (const t of graph.tensors) {
        const op = opByOutId.get(t.id);
        let kind = 'intermediate';
        let name = null;
        let initValue;
        if (op?.kind === 'param_input') {
            kind = 'param';
            name = op.name;
        }
        else if (op?.kind === 'tensor_input') {
            kind = 'tensor_input';
            name = op.name;
        }
        else if (op?.kind === 'state_input') {
            kind = 'state';
            name = op.name;
            initValue = op.initValue;
        }
        else if (gradTensorIdToName.has(t.id)) {
            kind = 'param_grad';
            name = gradTensorIdToName.get(t.id);
        }
        else if (outputSet.has(t.id)) {
            kind = 'output';
        }
        const spec = {
            id: t.id,
            byteSize: Math.max(4, shapeSize(t.shape) * dtypeBytes[t.dtype]),
            dtype: t.dtype,
            shape: t.shape,
            kind,
            name,
            ...(initValue !== undefined ? { initValue } : {}),
        };
        buffers.push(spec);
        tensorToBuffer.set(t.id, t.id); // 1:1 for v1
        if (kind === 'param')
            paramsByName.set(name, t.id);
        if (kind === 'tensor_input')
            inputsByName.set(name, t.id);
        if (kind === 'param_grad')
            paramGradsByName.set(name, t.id);
        if (kind === 'state')
            statesByName.set(name, t.id);
    }
    const outputBufferIds = graph.outputs.map(id => tensorToBuffer.get(id));
    // Resolve writeback declarations to (source, dest) buffer-id pairs.
    const writebacks = writebackDecls.map(decl => {
        const sourceBufId = tensorToBuffer.get(decl.source.id);
        if (sourceBufId === undefined) {
            throw new Error(`planBuffers: writeback source tensor #${decl.source.id} not in graph`);
        }
        const destBufId = decl.destKind === 'param'
            ? paramsByName.get(decl.destName)
            : statesByName.get(decl.destName);
        if (destBufId === undefined) {
            throw new Error(`planBuffers: writeback dest ${decl.destKind}:'${decl.destName}' not found`);
        }
        const sourceSpec = buffers[sourceBufId];
        const destSpec = buffers[destBufId];
        if (sourceSpec.byteSize !== destSpec.byteSize) {
            throw new Error(`planBuffers: writeback size mismatch for ${decl.destKind}:'${decl.destName}' ` +
                `(source ${sourceSpec.byteSize} bytes vs dest ${destSpec.byteSize})`);
        }
        return { source: sourceBufId, dest: destBufId, bytes: sourceSpec.byteSize };
    });
    // Resolve graph.captures (name -> tensor id) to (name -> buffer id).
    // No pinning needed at the planner level: each tensor already has its own
    // buffer (see "v1 strategy" comment at top — no pooling yet).
    const capturesByName = new Map();
    for (const [name, tensorId] of graph.captures) {
        const bufId = tensorToBuffer.get(tensorId);
        if (bufId === undefined) {
            throw new Error(`planBuffers: capture '${name}' references unknown tensor #${tensorId}`);
        }
        capturesByName.set(name, bufId);
    }
    return { buffers, tensorToBuffer, paramsByName, inputsByName, paramGradsByName, statesByName, capturesByName, outputBufferIds, writebacks };
}
//# sourceMappingURL=buffers.js.map
package/dist/buffers.js.map
DELETED
@@ -1 +0,0 @@
{"version":3,"file":"buffers.js","sourceRoot":"","sources":["../src/buffers.ts"],"names":[],"mappings":"AAAA,iFAAiF;AACjF,EAAE;AACF,+EAA+E;AAC/E,4EAA4E;AAC5E,8EAA8E;AAC9E,4EAA4E;AAC5E,EAAE;AACF,kDAAkD;AAClD,gFAAgF;AAChF,qFAAqF;AACrF,iEAAiE;AACjE,wEAAwE;AACxE,8EAA8E;AAG9E,OAAO,EAAE,SAAS,EAAE,MAAM,YAAY,CAAA;AAyCtC,MAAM,UAAU,GAA0B,EAAE,GAAG,EAAE,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,IAAI,EAAE,CAAC,EAAE,CAAA;AAcrE;;;;;;GAMG;AACH,MAAM,UAAU,WAAW,CACzB,KAAY,EACZ,UAAkC,EAClC,iBAAkC,EAAE;IAEpC,MAAM,OAAO,GAAiB,EAAE,CAAA;IAChC,MAAM,cAAc,GAAG,IAAI,GAAG,EAAkB,CAAA;IAChD,MAAM,YAAY,GAAG,IAAI,GAAG,EAAkB,CAAA;IAC9C,MAAM,YAAY,GAAG,IAAI,GAAG,EAAkB,CAAA;IAC9C,MAAM,gBAAgB,GAAG,IAAI,GAAG,EAAkB,CAAA;IAClD,MAAM,YAAY,GAAG,IAAI,GAAG,EAAkB,CAAA;IAE9C,iEAAiE;IACjE,MAAM,kBAAkB,GAAG,IAAI,GAAG,EAAkB,CAAA;IACpD,KAAK,MAAM,CAAC,IAAI,EAAE,MAAM,CAAC,IAAI,MAAM,CAAC,OAAO,CAAC,UAAU,CAAC,EAAE,CAAC;QACxD,kBAAkB,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,EAAE,IAAI,CAAC,CAAA;IACzC,CAAC;IACD,2EAA2E;IAC3E,MAAM,SAAS,GAAG,IAAI,GAAG,EAAkB,CAAA;IAC3C,KAAK,MAAM,EAAE,IAAI,KAAK,CAAC,GAAG;QAAE,SAAS,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,EAAE,CAAC,CAAA;IAErD,MAAM,SAAS,GAAG,IAAI,GAAG,CAAC,KAAK,CAAC,OAAO,CAAC,CAAA;IAExC,iDAAiD;IACjD,KAAK,MAAM,CAAC,IAAI,KAAK,CAAC,OAAO,EAAE,CAAC;QAC9B,MAAM,EAAE,GAAG,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAA;QAC9B,IAAI,IAAI,GAAuB,cAAc,CAAA;QAC7C,IAAI,IAAI,GAAkB,IAAI,CAAA;QAC9B,IAAI,SAA6B,CAAA;QAEjC,IAAI,EAAE,EAAE,IAAI,KAAK,aAAa,EAAE,CAAC;YAC/B,IAAI,GAAG,OAAO,CAAA;YACd,IAAI,GAAG,EAAE,CAAC,IAAI,CAAA;QAChB,CAAC;aAAM,IAAI,EAAE,EAAE,IAAI,KAAK,cAAc,EAAE,CAAC;YACvC,IAAI,GAAG,cAAc,CAAA;YACrB,IAAI,GAAG,EAAE,CAAC,IAAI,CAAA;QAChB,CAAC;aAAM,IAAI,EAAE,EAAE,IAAI,KAAK,aAAa,EAAE,CAAC;YACtC,IAAI,GAAG,OAAO,CAAA;YACd,IAAI,GAAG,EAAE,CAAC,IAAI,CAAA;YACd,SAAS,GAAG,EAAE,CAAC,SAAS,CAAA;QAC1B,CAAC;aAAM,IAAI,kBAAkB,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC;YACxC,IAAI,GAAG,YAAY,CAAA;YACnB,IAAI,GAAG,kBAAkB,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAE,CAAA;QACtC,CAAC;aAAM,IAAI,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC;YAC/B,IAAI,GAAG,QAAQ,CAAA;QACjB,CAAC;QAED,MAAM,IAAI,GAAe;YACvB,EAAE,EAAE,CAAC,CAAC,EAAE;YACR,QAAQ,EAAE,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,KAAK,CAAC,GAAG,UAAU,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;YAC/D,KAAK,EAAE,CAAC,CAAC,KAAK;YACd,KAAK,EAAE,CAAC,CAAC,KAAK;YACd,IAAI;YACJ,IAAI;YACJ,GAAG,CAAC,SAAS,KAAK,SAAS,CAAC,CAAC,CAAC,EAAE,SAAS,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC;SAClD,CAAA;QACD,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,CAAA;QAClB,cAAc,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA,CAAE,aAAa;QAE7C,IAAI,IAAI,KAAK,OAAO;YAAE,YAAY,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;QACnD,IAAI,IAAI,KAAK,cAAc;YAAE,YAAY,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;QAC1D,IAAI,IAAI,KAAK,YAAY;YAAE,gBAAgB,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;QAC5D,IAAI,IAAI,KAAK,OAAO;YAAE,YAAY,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;IACrD,CAAC;IAED,MAAM,eAAe,GAAG,KAAK,CAAC,OAAO,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,cAAc,CAAC,GAAG,CAAC,EAAE,CAAE,CAAC,CAAA;IAExE,oEAAoE;IACpE,MAAM,UAAU,GAAgB,cAAc,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE;QACxD,MAAM,WAAW,GAAG,cAAc,CAAC,GAAG,CAAC,IAAI,CAAC,MAAM,CAAC,EAAE,CAAC,CAAA;QACtD,IAAI,WAAW,KAAK,SAAS,EAAE,CAAC;YAC9B,MAAM,IAAI,KAAK,CAAC,yCAAyC,IAAI,CAAC,MAAM,CAAC,EAAE,eAAe,CAAC,CAAA;QACzF,CAAC;QACD,MAAM,SAAS,GAAG,IAAI,CAAC,QAAQ,KAAK,OAAO;YACzC,CAAC,CAAC,YAAY,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC;YACjC,CAAC,CAAC,YAAY,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAA;QACnC,IAAI,SAAS,KAAK,SAAS,EAAE,CAAC;YAC5B,MAAM,IAAI,KAAK,CAAC,+BAA+B,IAAI,CAAC,QAAQ,KAAK,IAAI,CAAC,QAAQ,aAAa,CAAC,CAAA;QAC9F,CAAC;QACD,MAAM,UAAU,GAAG,OAAO,CAAC,WAAW,CAAE,CAAA;QACxC
,MAAM,QAAQ,GAAG,OAAO,CAAC,SAAS,CAAE,CAAA;QACpC,IAAI,UAAU,CAAC,QAAQ,KAAK,QAAQ,CAAC,QAAQ,EAAE,CAAC;YAC9C,MAAM,IAAI,KAAK,CACb,4CAA4C,IAAI,CAAC,QAAQ,KAAK,IAAI,CAAC,QAAQ,IAAI;gBAC/E,WAAW,UAAU,CAAC,QAAQ,kBAAkB,QAAQ,CAAC,QAAQ,GAAG,CACrE,CAAA;QACH,CAAC;QACD,OAAO,EAAE,MAAM,EAAE,WAAW,EAAE,IAAI,EAAE,SAAS,EAAE,KAAK,EAAE,UAAU,CAAC,QAAQ,EAAE,CAAA;IAC7E,CAAC,CAAC,CAAA;IAEF,qEAAqE;IACrE,0EAA0E;IAC1E,8DAA8D;IAC9D,MAAM,cAAc,GAAG,IAAI,GAAG,EAAkB,CAAA;IAChD,KAAK,MAAM,CAAC,IAAI,EAAE,QAAQ,CAAC,IAAI,KAAK,CAAC,QAAQ,EAAE,CAAC;QAC9C,MAAM,KAAK,GAAG,cAAc,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAA;QAC1C,IAAI,KAAK,KAAK,SAAS,EAAE,CAAC;YACxB,MAAM,IAAI,KAAK,CAAC,yBAAyB,IAAI,gCAAgC,QAAQ,EAAE,CAAC,CAAA;QAC1F,CAAC;QACD,cAAc,CAAC,GAAG,CAAC,IAAI,EAAE,KAAK,CAAC,CAAA;IACjC,CAAC;IAED,OAAO,EAAE,OAAO,EAAE,cAAc,EAAE,YAAY,EAAE,YAAY,EAAE,gBAAgB,EAAE,YAAY,EAAE,cAAc,EAAE,eAAe,EAAE,UAAU,EAAE,CAAA;AAC7I,CAAC"}
package/dist/capture.js
DELETED
@@ -1,33 +0,0 @@
// Activation capture — opt-in readback of intermediate tensors at training step.
//
// Usage (inside the user's forward pass):
//
//   import { capture } from 'tensorgrad'
//
//   function attentionFwd(p, x) {
//     const scores = mul(matmulBatched(q, kT), SCALE_QK)
//     const attn = capture(`attn.${layerIdx}`, softmaxCausalLast(scores))
//     return matmulBatched(attn, v)
//   }
//
// Pass-through return type: `capture(name, t)` returns `t` unchanged so it
// inlines at the point of computation. Behind the scenes it registers `t.id`
// against `name` on the current graph; runtime exposes the registered tensors
// via `step(inputs, { withCaptures: true })`.
//
// Outside the user's forward trace (during `appendGrad` / `appendAdam`'s
// `traceInto` re-entry), `capture()` is a no-op — gradient and optimizer
// internals shouldn't accidentally publish themselves to the UI.
import { currentGraph, isCaptureEnabled } from './trace.js';
export function capture(name, t) {
    if (!isCaptureEnabled())
        return t;
    const g = currentGraph();
    if (g.captures.has(name)) {
        throw new Error(`capture: name '${name}' already registered. Use unique names ` +
            `(e.g. \`attn.\${layerIdx}\`) when capturing across a loop.`);
    }
    g.captures.set(name, t.id);
    return t;
}
//# sourceMappingURL=capture.js.map
package/dist/capture.js.map
DELETED
@@ -1 +0,0 @@
{"version":3,"file":"capture.js","sourceRoot":"","sources":["../src/capture.ts"],"names":[],"mappings":"AAAA,iFAAiF;AACjF,EAAE;AACF,0CAA0C;AAC1C,EAAE;AACF,yCAAyC;AACzC,EAAE;AACF,kCAAkC;AAClC,yDAAyD;AACzD,0EAA0E;AAC1E,oCAAoC;AACpC,MAAM;AACN,EAAE;AACF,2EAA2E;AAC3E,6EAA6E;AAC7E,8EAA8E;AAC9E,8CAA8C;AAC9C,EAAE;AACF,yEAAyE;AACzE,yEAAyE;AACzE,iEAAiE;AAGjE,OAAO,EAAE,YAAY,EAAE,gBAAgB,EAAE,MAAM,YAAY,CAAA;AAE3D,MAAM,UAAU,OAAO,CAAmB,IAAY,EAAE,CAAI;IAC1D,IAAI,CAAC,gBAAgB,EAAE;QAAE,OAAO,CAAC,CAAA;IACjC,MAAM,CAAC,GAAG,YAAY,EAAE,CAAA;IACxB,IAAI,CAAC,CAAC,QAAQ,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC;QACzB,MAAM,IAAI,KAAK,CACb,kBAAkB,IAAI,yCAAyC;YAC/D,4DAA4D,CAC7D,CAAA;IACH,CAAC;IACD,CAAC,CAAC,QAAQ,CAAC,GAAG,CAAC,IAAI,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;IAC1B,OAAO,CAAC,CAAA;AACV,CAAC"}