tensorgrad 0.0.15 → 0.0.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/adam.js DELETED
@@ -1,111 +0,0 @@
1
- // Adam / AdamW optimizer, in-graph.
2
- //
3
- // `appendAdam` extends a graph that already has a forward pass + autograd-emitted
4
- // backward (i.e., has paramGrads from `appendGrad`) with the Adam update math.
5
- //
6
- // Per parameter P with gradient g:
7
- // m_new = b1 * m + (1 - b1) * g
8
- // v_new = b2 * v + (1 - b2) * g²
9
- // p_new = decayShrink * p - lrt * m_new / (sqrt(v_new) + eps)
10
- //
11
- // `decayShrink = 1 - lr * weightDecay` when the param is being decayed
12
- // (Loshchilov & Hutter, "AdamW") and 1 otherwise — at which point the
13
- // multiply folds out and you're left with plain Adam. `lrt` is supplied
14
- // per-step from CPU and includes the bias-correction factor
15
- // `sqrt(1-b2^t)/(1-b1^t)`; that's why convergence isn't affected by the
16
- // first-step warmup that bias-correction-free Adam suffers.
17
- //
18
- // **Static vs scheduled lr.** When `config.lr` is a number, decayShrink is
19
- // baked into the kernel as a literal. When it's a function `(step) => lr`,
20
- // decayShrink for decayed params becomes a per-step scalar input that the
21
- // runtime updates each call (computed from the current step's lr). lrt is
22
- // always per-step; the bias-correction factor changes every step regardless.
23
- //
24
- // Returns writeback declarations the buffer planner uses to wire up the
25
- // "after step, copy the new value into the persistent home" path. m and v
26
- // are state_inputs (zero-initialized, persistent across steps); the param
27
- // updates are aliased back to the param buffers.
28
- import { traceInto, stateInput, tensorInput } from './trace.js';
29
- import { adamUpdateM, adamUpdateV, adamUpdateP } from './ops.js';
30
- /**
31
- * Append Adam update ops to `graph`. Must be called inside an active trace
32
- * context (or after a trace, since traceInto re-enters the graph).
33
- *
34
- * @param graph the graph (already containing forward + backward)
35
- * @param paramGrads param name -> gradient tensor (output of `appendGrad`)
36
- * @param paramTensors param name -> the param's leaf Tensor (the param_input).
37
- * Needed because the param_input lives in the graph but we
38
- * don't have a direct map by name in `Graph` — caller passes it.
39
- * @param config Adam hyperparameters. Set `weightDecay > 0` for AdamW; an
40
- * optional `decayFilter` selects which params receive decay.
41
- */
42
- export function appendAdam(graph, paramGrads, paramTensors, config,
43
- /** Per-param decay flags from `materializeParams`. When supplied, overrides
44
- * `config.decayFilter` for any name in the map; falls back to `decayFilter`
45
- * for names not present (e.g., for low-level callers using `compile()`
46
- * directly without a Module). */
47
- decayFlags) {
48
- const lrIsScheduled = typeof config.lr === 'function';
49
- const lrFn = lrIsScheduled
50
- ? config.lr
51
- : (() => config.lr);
52
- const initialLr = lrFn(1);
53
- const fullConfig = {
54
- lr: lrFn,
55
- b1: config.b1 ?? 0.9,
56
- b2: config.b2 ?? 0.999,
57
- eps: config.eps ?? 1e-8,
58
- weightDecay: config.weightDecay ?? 0,
59
- decayFilter: config.decayFilter ?? (() => true),
60
- lrIsScheduled,
61
- };
62
- const writebacks = [];
63
- const lrtInputName = '_adam_lrt';
64
- // Tensor input for runtime-updated decayShrink (only created when lr is a
65
- // schedule fn AND at least one param will receive weight decay).
66
- let decayShrinkInputName = null;
67
- return traceInto(graph, () => {
68
- const lrt = tensorInput(lrtInputName, [], 'f32');
69
- // Up-front: which params receive weight decay? Per-param decayFlags (set
70
- // by Module.param's options) wins; falls back to decayFilter for names
71
- // not in the map. Empty when weightDecay = 0 so the rest of the function
72
- // can just ask "is this name in the set?".
73
- const decayedNames = new Set(fullConfig.weightDecay > 0
74
- ? Object.keys(paramGrads).filter(name => (decayFlags && name in decayFlags) ? decayFlags[name] : fullConfig.decayFilter(name))
75
- : []);
76
- // We only need a runtime decayShrink scalar when lr varies per step AND
77
- // at least one param is being decayed. Otherwise the value is constant
78
- // and bakes into the kernel as a literal.
79
- let decayShrinkScalar = null;
80
- if (lrIsScheduled && decayedNames.size > 0) {
81
- decayShrinkInputName = '_adam_decay_shrink';
82
- decayShrinkScalar = tensorInput(decayShrinkInputName, [], 'f32');
83
- }
84
- for (const name of Object.keys(paramGrads)) {
85
- const p = paramTensors[name];
86
- const g = paramGrads[name];
87
- if (!p)
88
- throw new Error(`appendAdam: missing param tensor for '${name}'`);
89
- if (!g)
90
- throw new Error(`appendAdam: missing gradient for '${name}'`);
91
- const mState = stateInput(`adam_m_${name}`, p.shape, 'f32', 0);
92
- const vState = stateInput(`adam_v_${name}`, p.shape, 'f32', 0);
93
- // Choose the decayShrink form per param:
94
- // - non-decayed params: literal 1 (kernel multiply folds out).
95
- // - decayed + scheduled lr: tensor input updated per step.
96
- // - decayed + static lr: literal `1 - lr * wd` baked at compile.
97
- const decayShrink = !decayedNames.has(name) ? 1
98
- : decayShrinkScalar !== null ? decayShrinkScalar
99
- : 1 - initialLr * fullConfig.weightDecay;
100
- // Three fused kernels per parameter — one for each of m_new / v_new / p_new.
101
- const newM = adamUpdateM(mState, g, fullConfig.b1);
102
- const newV = adamUpdateV(vState, g, fullConfig.b2);
103
- const newP = adamUpdateP(p, newM, newV, lrt, fullConfig.eps, decayShrink);
104
- writebacks.push({ source: newM, destName: `adam_m_${name}`, destKind: 'state' });
105
- writebacks.push({ source: newV, destName: `adam_v_${name}`, destKind: 'state' });
106
- writebacks.push({ source: newP, destName: name, destKind: 'param' });
107
- }
108
- return { writebacks, lrtInputName, decayShrinkInputName, config: fullConfig };
109
- });
110
- }
111
- //# sourceMappingURL=adam.js.map
package/dist/adam.js.map DELETED
@@ -1 +0,0 @@
1
- {"version":3,"file":"adam.js","sourceRoot":"","sources":["../src/adam.ts"],"names":[],"mappings":"AAAA,oCAAoC;AACpC,EAAE;AACF,kFAAkF;AAClF,+EAA+E;AAC/E,EAAE;AACF,mCAAmC;AACnC,kCAAkC;AAClC,mCAAmC;AACnC,gEAAgE;AAChE,EAAE;AACF,uEAAuE;AACvE,sEAAsE;AACtE,wEAAwE;AACxE,4DAA4D;AAC5D,wEAAwE;AACxE,4DAA4D;AAC5D,EAAE;AACF,2EAA2E;AAC3E,2EAA2E;AAC3E,0EAA0E;AAC1E,0EAA0E;AAC1E,6EAA6E;AAC7E,EAAE;AACF,wEAAwE;AACxE,0EAA0E;AAC1E,0EAA0E;AAC1E,iDAAiD;AAKjD,OAAO,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,MAAM,YAAY,CAAA;AAC/D,OAAO,EAAE,WAAW,EAAE,WAAW,EAAE,WAAW,EAAE,MAAM,UAAU,CAAA;AAgDhE;;;;;;;;;;;GAWG;AACH,MAAM,UAAU,UAAU,CACxB,KAAY,EACZ,UAAkC,EAClC,YAAoC,EACpC,MAAkB;AAClB;;;kCAGkC;AAClC,UAAoC;IAEpC,MAAM,aAAa,GAAG,OAAO,MAAM,CAAC,EAAE,KAAK,UAAU,CAAA;IACrD,MAAM,IAAI,GAAG,aAAa;QACxB,CAAC,CAAC,MAAM,CAAC,EAA8B;QACvC,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,EAAY,CAAC,CAAA;IAC/B,MAAM,SAAS,GAAG,IAAI,CAAC,CAAC,CAAC,CAAA;IACzB,MAAM,UAAU,GAAuB;QACrC,EAAE,EAAE,IAAI;QACR,EAAE,EAAE,MAAM,CAAC,EAAE,IAAI,GAAG;QACpB,EAAE,EAAE,MAAM,CAAC,EAAE,IAAI,KAAK;QACtB,GAAG,EAAE,MAAM,CAAC,GAAG,IAAI,IAAI;QACvB,WAAW,EAAE,MAAM,CAAC,WAAW,IAAI,CAAC;QACpC,WAAW,EAAE,MAAM,CAAC,WAAW,IAAI,CAAC,GAAG,EAAE,CAAC,IAAI,CAAC;QAC/C,aAAa;KACd,CAAA;IACD,MAAM,UAAU,GAAoB,EAAE,CAAA;IACtC,MAAM,YAAY,GAAG,WAAW,CAAA;IAChC,0EAA0E;IAC1E,iEAAiE;IACjE,IAAI,oBAAoB,GAAkB,IAAI,CAAA;IAE9C,OAAO,SAAS,CAAC,KAAK,EAAE,GAAG,EAAE;QAC3B,MAAM,GAAG,GAAG,WAAW,CAAC,YAAY,EAAE,EAAE,EAAE,KAAK,CAAC,CAAA;QAEhD,yEAAyE;QACzE,uEAAuE;QACvE,yEAAyE;QACzE,2CAA2C;QAC3C,MAAM,YAAY,GAAG,IAAI,GAAG,CAC1B,UAAU,CAAC,WAAW,GAAG,CAAC;YACxB,CAAC,CAAC,MAAM,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC,MAAM,CAAC,IAAI,CAAC,EAAE,CACpC,CAAC,UAAU,IAAI,IAAI,IAAI,UAAU,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,IAAI,CAAE,CAAC,CAAC,CAAC,UAAU,CAAC,WAAW,CAAC,IAAI,CAAC,CAAC;YAC1F,CAAC,CAAC,EAAE,CACP,CAAA;QAED,wEAAwE;QACxE,uEAAuE;QACvE,0CAA0C;QAC1C,IAAI,iBAAiB,GAAkB,IAAI,CAAA;QAC3C,IAAI,aAAa,IAAI,YAAY,CAAC,IAAI,GAAG,CAAC,EAAE,CAAC;YAC3C,oBAAoB,GAAG,oBAAoB,CAAA;YAC3C,iBAAiB,GAAG,WAAW,CAAC,oBAAoB,EAAE,EAAE,EAAE,KAAK,CAAC,CAAA;QAClE,CAAC;QAED
,KAAK,MAAM,IAAI,IAAI,MAAM,CAAC,IAAI,CAAC,UAAU,CAAC,EAAE,CAAC;YAC3C,MAAM,CAAC,GAAG,YAAY,CAAC,IAAI,CAAC,CAAA;YAC5B,MAAM,CAAC,GAAG,UAAU,CAAC,IAAI,CAAC,CAAA;YAC1B,IAAI,CAAC,CAAC;gBAAE,MAAM,IAAI,KAAK,CAAC,yCAAyC,IAAI,GAAG,CAAC,CAAA;YACzE,IAAI,CAAC,CAAC;gBAAE,MAAM,IAAI,KAAK,CAAC,qCAAqC,IAAI,GAAG,CAAC,CAAA;YAErE,MAAM,MAAM,GAAG,UAAU,CAAC,UAAU,IAAI,EAAE,EAAE,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,CAAC,CAAC,CAAA;YAC9D,MAAM,MAAM,GAAG,UAAU,CAAC,UAAU,IAAI,EAAE,EAAE,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,CAAC,CAAC,CAAA;YAE9D,yCAAyC;YACzC,iEAAiE;YACjE,6DAA6D;YAC7D,mEAAmE;YACnE,MAAM,WAAW,GACf,CAAC,YAAY,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC;gBAC3B,CAAC,CAAC,iBAAiB,KAAK,IAAI,CAAC,CAAC,CAAC,iBAAiB;oBAChD,CAAC,CAAC,CAAC,GAAG,SAAS,GAAG,UAAU,CAAC,WAAW,CAAA;YAE1C,6EAA6E;YAC7E,MAAM,IAAI,GAAG,WAAW,CAAC,MAAM,EAAE,CAAC,EAAE,UAAU,CAAC,EAAE,CAAC,CAAA;YAClD,MAAM,IAAI,GAAG,WAAW,CAAC,MAAM,EAAE,CAAC,EAAE,UAAU,CAAC,EAAE,CAAC,CAAA;YAClD,MAAM,IAAI,GAAG,WAAW,CAAC,CAAC,EAAE,IAAI,EAAE,IAAI,EAAE,GAAG,EAAE,UAAU,CAAC,GAAG,EAAE,WAAW,CAAC,CAAA;YAEzE,UAAU,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE,UAAU,IAAI,EAAE,EAAE,QAAQ,EAAE,OAAO,EAAE,CAAC,CAAA;YAChF,UAAU,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE,UAAU,IAAI,EAAE,EAAE,QAAQ,EAAE,OAAO,EAAE,CAAC,CAAA;YAChF,UAAU,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE,IAAI,EAAc,QAAQ,EAAE,OAAO,EAAE,CAAC,CAAA;QAClF,CAAC;QACD,OAAO,EAAE,UAAU,EAAE,YAAY,EAAE,oBAAoB,EAAE,MAAM,EAAE,UAAU,EAAE,CAAA;IAC/E,CAAC,CAAC,CAAA;AACJ,CAAC"}
package/dist/buffers.js DELETED
@@ -1,120 +0,0 @@
1
- // Buffer planning: walk a Graph and decide which GPU buffer each Tensor maps to.
2
- //
3
- // v1 strategy: one GPU buffer per IR Tensor. Static shapes mean every buffer's
4
- // size is known at compile time and lifetimes don't overlap between steps —
5
- // so no pooling needed. Total memory is the sum of every intermediate tensor.
6
- // For our transformer at B=256: ~30 MB of activations + grads. Easily fits.
7
- //
8
- // Categorization is what the runtime cares about:
9
- // * param — uploaded by user via uploadParams; persistent across steps
10
- // * param_grad — written each step by the backward pass; readable for inspection
11
- // * tensor_input — uploaded each step (tokens, targets, masks)
12
- // * intermediate — produced by an op; lifetime = within a single step
13
- // * output — special intermediate that should be made readable (loss)
14
- import { shapeSize } from './shape.js';
15
- const dtypeBytes = { f32: 4, i32: 4, bool: 4 };
16
- /**
17
- * Build a BufferPlan from a graph + the param-grad map produced by appendGrad.
18
- * @param graph the full graph (forward + backward + any optimizer ops)
19
- * @param paramGrads map from param name -> the Tensor that holds its gradient
20
- * @param writebackDecls list of end-of-step writebacks (e.g. from appendAdam).
21
- * Empty when there's no optimizer in the graph.
22
- */
23
- export function planBuffers(graph, paramGrads, writebackDecls = []) {
24
- const buffers = [];
25
- const tensorToBuffer = new Map();
26
- const paramsByName = new Map();
27
- const inputsByName = new Map();
28
- const paramGradsByName = new Map();
29
- const statesByName = new Map();
30
- // Build a quick reverse map: tensorId -> param name (for grads).
31
- const gradTensorIdToName = new Map();
32
- for (const [name, tensor] of Object.entries(paramGrads)) {
33
- gradTensorIdToName.set(tensor.id, name);
34
- }
35
- // ...and tensorId -> param/input op (so we can name the buffer correctly).
36
- const opByOutId = new Map();
37
- for (const op of graph.ops)
38
- opByOutId.set(op.out, op);
39
- const outputSet = new Set(graph.outputs);
40
- // Walk all tensors in id order. Categorize each.
41
- for (const t of graph.tensors) {
42
- const op = opByOutId.get(t.id);
43
- let kind = 'intermediate';
44
- let name = null;
45
- let initValue;
46
- if (op?.kind === 'param_input') {
47
- kind = 'param';
48
- name = op.name;
49
- }
50
- else if (op?.kind === 'tensor_input') {
51
- kind = 'tensor_input';
52
- name = op.name;
53
- }
54
- else if (op?.kind === 'state_input') {
55
- kind = 'state';
56
- name = op.name;
57
- initValue = op.initValue;
58
- }
59
- else if (gradTensorIdToName.has(t.id)) {
60
- kind = 'param_grad';
61
- name = gradTensorIdToName.get(t.id);
62
- }
63
- else if (outputSet.has(t.id)) {
64
- kind = 'output';
65
- }
66
- const spec = {
67
- id: t.id,
68
- byteSize: Math.max(4, shapeSize(t.shape) * dtypeBytes[t.dtype]),
69
- dtype: t.dtype,
70
- shape: t.shape,
71
- kind,
72
- name,
73
- ...(initValue !== undefined ? { initValue } : {}),
74
- };
75
- buffers.push(spec);
76
- tensorToBuffer.set(t.id, t.id); // 1:1 for v1
77
- if (kind === 'param')
78
- paramsByName.set(name, t.id);
79
- if (kind === 'tensor_input')
80
- inputsByName.set(name, t.id);
81
- if (kind === 'param_grad')
82
- paramGradsByName.set(name, t.id);
83
- if (kind === 'state')
84
- statesByName.set(name, t.id);
85
- }
86
- const outputBufferIds = graph.outputs.map(id => tensorToBuffer.get(id));
87
- // Resolve writeback declarations to (source, dest) buffer-id pairs.
88
- const writebacks = writebackDecls.map(decl => {
89
- const sourceBufId = tensorToBuffer.get(decl.source.id);
90
- if (sourceBufId === undefined) {
91
- throw new Error(`planBuffers: writeback source tensor #${decl.source.id} not in graph`);
92
- }
93
- const destBufId = decl.destKind === 'param'
94
- ? paramsByName.get(decl.destName)
95
- : statesByName.get(decl.destName);
96
- if (destBufId === undefined) {
97
- throw new Error(`planBuffers: writeback dest ${decl.destKind}:'${decl.destName}' not found`);
98
- }
99
- const sourceSpec = buffers[sourceBufId];
100
- const destSpec = buffers[destBufId];
101
- if (sourceSpec.byteSize !== destSpec.byteSize) {
102
- throw new Error(`planBuffers: writeback size mismatch for ${decl.destKind}:'${decl.destName}' ` +
103
- `(source ${sourceSpec.byteSize} bytes vs dest ${destSpec.byteSize})`);
104
- }
105
- return { source: sourceBufId, dest: destBufId, bytes: sourceSpec.byteSize };
106
- });
107
- // Resolve graph.captures (name -> tensor id) to (name -> buffer id).
108
- // No pinning needed at the planner level: each tensor already has its own
109
- // buffer (see "v1 strategy" comment at top — no pooling yet).
110
- const capturesByName = new Map();
111
- for (const [name, tensorId] of graph.captures) {
112
- const bufId = tensorToBuffer.get(tensorId);
113
- if (bufId === undefined) {
114
- throw new Error(`planBuffers: capture '${name}' references unknown tensor #${tensorId}`);
115
- }
116
- capturesByName.set(name, bufId);
117
- }
118
- return { buffers, tensorToBuffer, paramsByName, inputsByName, paramGradsByName, statesByName, capturesByName, outputBufferIds, writebacks };
119
- }
120
- //# sourceMappingURL=buffers.js.map
package/dist/buffers.js.map DELETED
@@ -1 +0,0 @@
1
- {"version":3,"file":"buffers.js","sourceRoot":"","sources":["../src/buffers.ts"],"names":[],"mappings":"AAAA,iFAAiF;AACjF,EAAE;AACF,+EAA+E;AAC/E,4EAA4E;AAC5E,8EAA8E;AAC9E,4EAA4E;AAC5E,EAAE;AACF,kDAAkD;AAClD,gFAAgF;AAChF,qFAAqF;AACrF,iEAAiE;AACjE,wEAAwE;AACxE,8EAA8E;AAG9E,OAAO,EAAE,SAAS,EAAE,MAAM,YAAY,CAAA;AAyCtC,MAAM,UAAU,GAA0B,EAAE,GAAG,EAAE,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,IAAI,EAAE,CAAC,EAAE,CAAA;AAcrE;;;;;;GAMG;AACH,MAAM,UAAU,WAAW,CACzB,KAAY,EACZ,UAAkC,EAClC,iBAAkC,EAAE;IAEpC,MAAM,OAAO,GAAiB,EAAE,CAAA;IAChC,MAAM,cAAc,GAAG,IAAI,GAAG,EAAkB,CAAA;IAChD,MAAM,YAAY,GAAG,IAAI,GAAG,EAAkB,CAAA;IAC9C,MAAM,YAAY,GAAG,IAAI,GAAG,EAAkB,CAAA;IAC9C,MAAM,gBAAgB,GAAG,IAAI,GAAG,EAAkB,CAAA;IAClD,MAAM,YAAY,GAAG,IAAI,GAAG,EAAkB,CAAA;IAE9C,iEAAiE;IACjE,MAAM,kBAAkB,GAAG,IAAI,GAAG,EAAkB,CAAA;IACpD,KAAK,MAAM,CAAC,IAAI,EAAE,MAAM,CAAC,IAAI,MAAM,CAAC,OAAO,CAAC,UAAU,CAAC,EAAE,CAAC;QACxD,kBAAkB,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,EAAE,IAAI,CAAC,CAAA;IACzC,CAAC;IACD,2EAA2E;IAC3E,MAAM,SAAS,GAAG,IAAI,GAAG,EAAkB,CAAA;IAC3C,KAAK,MAAM,EAAE,IAAI,KAAK,CAAC,GAAG;QAAE,SAAS,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,EAAE,CAAC,CAAA;IAErD,MAAM,SAAS,GAAG,IAAI,GAAG,CAAC,KAAK,CAAC,OAAO,CAAC,CAAA;IAExC,iDAAiD;IACjD,KAAK,MAAM,CAAC,IAAI,KAAK,CAAC,OAAO,EAAE,CAAC;QAC9B,MAAM,EAAE,GAAG,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAA;QAC9B,IAAI,IAAI,GAAuB,cAAc,CAAA;QAC7C,IAAI,IAAI,GAAkB,IAAI,CAAA;QAC9B,IAAI,SAA6B,CAAA;QAEjC,IAAI,EAAE,EAAE,IAAI,KAAK,aAAa,EAAE,CAAC;YAC/B,IAAI,GAAG,OAAO,CAAA;YACd,IAAI,GAAG,EAAE,CAAC,IAAI,CAAA;QAChB,CAAC;aAAM,IAAI,EAAE,EAAE,IAAI,KAAK,cAAc,EAAE,CAAC;YACvC,IAAI,GAAG,cAAc,CAAA;YACrB,IAAI,GAAG,EAAE,CAAC,IAAI,CAAA;QAChB,CAAC;aAAM,IAAI,EAAE,EAAE,IAAI,KAAK,aAAa,EAAE,CAAC;YACtC,IAAI,GAAG,OAAO,CAAA;YACd,IAAI,GAAG,EAAE,CAAC,IAAI,CAAA;YACd,SAAS,GAAG,EAAE,CAAC,SAAS,CAAA;QAC1B,CAAC;aAAM,IAAI,kBAAkB,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC;YACxC,IAAI,GAAG,YAAY,CAAA;YACnB,IAAI,GAAG,kBAAkB,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAE,CAAA;QACtC,CAAC;aAAM,IAAI,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC;YAC/B,IAAI,GAAG,QAA
Q,CAAA;QACjB,CAAC;QAED,MAAM,IAAI,GAAe;YACvB,EAAE,EAAE,CAAC,CAAC,EAAE;YACR,QAAQ,EAAE,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,KAAK,CAAC,GAAG,UAAU,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;YAC/D,KAAK,EAAE,CAAC,CAAC,KAAK;YACd,KAAK,EAAE,CAAC,CAAC,KAAK;YACd,IAAI;YACJ,IAAI;YACJ,GAAG,CAAC,SAAS,KAAK,SAAS,CAAC,CAAC,CAAC,EAAE,SAAS,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC;SAClD,CAAA;QACD,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,CAAA;QAClB,cAAc,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA,CAAE,aAAa;QAE7C,IAAI,IAAI,KAAK,OAAO;YAAE,YAAY,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;QACnD,IAAI,IAAI,KAAK,cAAc;YAAE,YAAY,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;QAC1D,IAAI,IAAI,KAAK,YAAY;YAAE,gBAAgB,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;QAC5D,IAAI,IAAI,KAAK,OAAO;YAAE,YAAY,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;IACrD,CAAC;IAED,MAAM,eAAe,GAAG,KAAK,CAAC,OAAO,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,cAAc,CAAC,GAAG,CAAC,EAAE,CAAE,CAAC,CAAA;IAExE,oEAAoE;IACpE,MAAM,UAAU,GAAgB,cAAc,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE;QACxD,MAAM,WAAW,GAAG,cAAc,CAAC,GAAG,CAAC,IAAI,CAAC,MAAM,CAAC,EAAE,CAAC,CAAA;QACtD,IAAI,WAAW,KAAK,SAAS,EAAE,CAAC;YAC9B,MAAM,IAAI,KAAK,CAAC,yCAAyC,IAAI,CAAC,MAAM,CAAC,EAAE,eAAe,CAAC,CAAA;QACzF,CAAC;QACD,MAAM,SAAS,GAAG,IAAI,CAAC,QAAQ,KAAK,OAAO;YACzC,CAAC,CAAC,YAAY,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC;YACjC,CAAC,CAAC,YAAY,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAA;QACnC,IAAI,SAAS,KAAK,SAAS,EAAE,CAAC;YAC5B,MAAM,IAAI,KAAK,CAAC,+BAA+B,IAAI,CAAC,QAAQ,KAAK,IAAI,CAAC,QAAQ,aAAa,CAAC,CAAA;QAC9F,CAAC;QACD,MAAM,UAAU,GAAG,OAAO,CAAC,WAAW,CAAE,CAAA;QACxC,MAAM,QAAQ,GAAG,OAAO,CAAC,SAAS,CAAE,CAAA;QACpC,IAAI,UAAU,CAAC,QAAQ,KAAK,QAAQ,CAAC,QAAQ,EAAE,CAAC;YAC9C,MAAM,IAAI,KAAK,CACb,4CAA4C,IAAI,CAAC,QAAQ,KAAK,IAAI,CAAC,QAAQ,IAAI;gBAC/E,WAAW,UAAU,CAAC,QAAQ,kBAAkB,QAAQ,CAAC,QAAQ,GAAG,CACrE,CAAA;QACH,CAAC;QACD,OAAO,EAAE,MAAM,EAAE,WAAW,EAAE,IAAI,EAAE,SAAS,EAAE,KAAK,EAAE,UAAU,CAAC,QAAQ,EAAE,CAAA;IAC7E,CAAC,CAAC,CAAA;IAEF,qEAAqE;IACrE,0EAA0E;IAC1E,8DAA8D;IAC9D,MAAM,cAAc,GAAG,IAAI,GAAG,EAAkB,CAAA;IAChD,K
AAK,MAAM,CAAC,IAAI,EAAE,QAAQ,CAAC,IAAI,KAAK,CAAC,QAAQ,EAAE,CAAC;QAC9C,MAAM,KAAK,GAAG,cAAc,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAA;QAC1C,IAAI,KAAK,KAAK,SAAS,EAAE,CAAC;YACxB,MAAM,IAAI,KAAK,CAAC,yBAAyB,IAAI,gCAAgC,QAAQ,EAAE,CAAC,CAAA;QAC1F,CAAC;QACD,cAAc,CAAC,GAAG,CAAC,IAAI,EAAE,KAAK,CAAC,CAAA;IACjC,CAAC;IAED,OAAO,EAAE,OAAO,EAAE,cAAc,EAAE,YAAY,EAAE,YAAY,EAAE,gBAAgB,EAAE,YAAY,EAAE,cAAc,EAAE,eAAe,EAAE,UAAU,EAAE,CAAA;AAC7I,CAAC"}
package/dist/capture.js DELETED
@@ -1,33 +0,0 @@
1
- // Activation capture — opt-in readback of intermediate tensors at training step.
2
- //
3
- // Usage (inside the user's forward pass):
4
- //
5
- // import { capture } from 'tensorgrad'
6
- //
7
- // function attentionFwd(p, x) {
8
- // const scores = mul(matmulBatched(q, kT), SCALE_QK)
9
- // const attn = capture(`attn.${layerIdx}`, softmaxCausalLast(scores))
10
- // return matmulBatched(attn, v)
11
- // }
12
- //
13
- // Pass-through return type: `capture(name, t)` returns `t` unchanged so it
14
- // inlines at the point of computation. Behind the scenes it registers `t.id`
15
- // against `name` on the current graph; runtime exposes the registered tensors
16
- // via `step(inputs, { withCaptures: true })`.
17
- //
18
- // Outside the user's forward trace (during `appendGrad` / `appendAdam`'s
19
- // `traceInto` re-entry), `capture()` is a no-op — gradient and optimizer
20
- // internals shouldn't accidentally publish themselves to the UI.
21
- import { currentGraph, isCaptureEnabled } from './trace.js';
22
- export function capture(name, t) {
23
- if (!isCaptureEnabled())
24
- return t;
25
- const g = currentGraph();
26
- if (g.captures.has(name)) {
27
- throw new Error(`capture: name '${name}' already registered. Use unique names ` +
28
- `(e.g. \`attn.\${layerIdx}\`) when capturing across a loop.`);
29
- }
30
- g.captures.set(name, t.id);
31
- return t;
32
- }
33
- //# sourceMappingURL=capture.js.map
package/dist/capture.js.map DELETED
@@ -1 +0,0 @@
1
- {"version":3,"file":"capture.js","sourceRoot":"","sources":["../src/capture.ts"],"names":[],"mappings":"AAAA,iFAAiF;AACjF,EAAE;AACF,0CAA0C;AAC1C,EAAE;AACF,yCAAyC;AACzC,EAAE;AACF,kCAAkC;AAClC,yDAAyD;AACzD,0EAA0E;AAC1E,oCAAoC;AACpC,MAAM;AACN,EAAE;AACF,2EAA2E;AAC3E,6EAA6E;AAC7E,8EAA8E;AAC9E,8CAA8C;AAC9C,EAAE;AACF,yEAAyE;AACzE,yEAAyE;AACzE,iEAAiE;AAGjE,OAAO,EAAE,YAAY,EAAE,gBAAgB,EAAE,MAAM,YAAY,CAAA;AAE3D,MAAM,UAAU,OAAO,CAAmB,IAAY,EAAE,CAAI;IAC1D,IAAI,CAAC,gBAAgB,EAAE;QAAE,OAAO,CAAC,CAAA;IACjC,MAAM,CAAC,GAAG,YAAY,EAAE,CAAA;IACxB,IAAI,CAAC,CAAC,QAAQ,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC;QACzB,MAAM,IAAI,KAAK,CACb,kBAAkB,IAAI,yCAAyC;YAC/D,4DAA4D,CAC7D,CAAA;IACH,CAAC;IACD,CAAC,CAAC,QAAQ,CAAC,GAAG,CAAC,IAAI,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;IAC1B,OAAO,CAAC,CAAA;AACV,CAAC"}