tensorgrad 0.0.15 → 0.0.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.d.ts +154 -193
- package/dist/index.js +2208 -39
- package/dist/index.js.map +7 -1
- package/dist/worker.debug.js +553 -0
- package/package.json +60 -58
- package/src/adam.ts +69 -15
- package/src/compile.ts +334 -156
- package/src/index.ts +8 -4
- package/src/module.ts +72 -34
- package/src/worker-protocol.ts +183 -0
- package/src/worker-proxy.ts +76 -0
- package/src/worker.ts +281 -0
- package/dist/adam.js +0 -111
- package/dist/adam.js.map +0 -1
- package/dist/buffers.js +0 -120
- package/dist/buffers.js.map +0 -1
- package/dist/capture.js +0 -33
- package/dist/capture.js.map +0 -1
- package/dist/codegen.js +0 -724
- package/dist/codegen.js.map +0 -1
- package/dist/compile.js +0 -184
- package/dist/compile.js.map +0 -1
- package/dist/grad.js +0 -380
- package/dist/grad.js.map +0 -1
- package/dist/ir.js +0 -60
- package/dist/ir.js.map +0 -1
- package/dist/module.js +0 -155
- package/dist/module.js.map +0 -1
- package/dist/nn.js +0 -135
- package/dist/nn.js.map +0 -1
- package/dist/ops.js +0 -326
- package/dist/ops.js.map +0 -1
- package/dist/runtime.js +0 -402
- package/dist/runtime.js.map +0 -1
- package/dist/shape.js +0 -259
- package/dist/shape.js.map +0 -1
- package/dist/trace.js +0 -100
- package/dist/trace.js.map +0 -1
package/dist/adam.js
DELETED
|
@@ -1,111 +0,0 @@
|
|
|
1
|
-
// Adam / AdamW optimizer, in-graph.
|
|
2
|
-
//
|
|
3
|
-
// `appendAdam` extends a graph that already has a forward pass + autograd-emitted
|
|
4
|
-
// backward (i.e., has paramGrads from `appendGrad`) with the Adam update math.
|
|
5
|
-
//
|
|
6
|
-
// Per parameter P with gradient g:
|
|
7
|
-
// m_new = b1 * m + (1 - b1) * g
|
|
8
|
-
// v_new = b2 * v + (1 - b2) * g²
|
|
9
|
-
// p_new = decayShrink * p - lrt * m_new / (sqrt(v_new) + eps)
|
|
10
|
-
//
|
|
11
|
-
// `decayShrink = 1 - lr * weightDecay` when the param is being decayed
|
|
12
|
-
// (Loshchilov & Hutter, "AdamW") and 1 otherwise — at which point the
|
|
13
|
-
// multiply folds out and you're left with plain Adam. `lrt` is supplied
|
|
14
|
-
// per-step from CPU and includes the bias-correction factor
|
|
15
|
-
// `sqrt(1-b2^t)/(1-b1^t)`; that's why convergence isn't affected by the
|
|
16
|
-
// first-step warmup that bias-correction-free Adam suffers.
|
|
17
|
-
//
|
|
18
|
-
// **Static vs scheduled lr.** When `config.lr` is a number, decayShrink is
|
|
19
|
-
// baked into the kernel as a literal. When it's a function `(step) => lr`,
|
|
20
|
-
// decayShrink for decayed params becomes a per-step scalar input that the
|
|
21
|
-
// runtime updates each call (computed from the current step's lr). lrt is
|
|
22
|
-
// always per-step; the bias-correction factor changes every step regardless.
|
|
23
|
-
//
|
|
24
|
-
// Returns writeback declarations the buffer planner uses to wire up the
|
|
25
|
-
// "after step, copy the new value into the persistent home" path. m and v
|
|
26
|
-
// are state_inputs (zero-initialized, persistent across steps); the param
|
|
27
|
-
// updates are aliased back to the param buffers.
|
|
28
|
-
import { traceInto, stateInput, tensorInput } from './trace.js';
|
|
29
|
-
import { adamUpdateM, adamUpdateV, adamUpdateP } from './ops.js';
|
|
30
|
-
/**
 * Append Adam / AdamW update ops to `graph`. Must be called inside an active
 * trace context (or after a trace, since `traceInto` re-enters the graph).
 *
 * For every name in `paramGrads` this declares two f32 state inputs
 * (`adam_m_<name>` and `adam_v_<name>`, shaped like the param, init value 0),
 * emits the three fused update ops (`adamUpdateM`, `adamUpdateV`,
 * `adamUpdateP`) and records three writebacks: the new moments go back to
 * their state slots and the new param value goes back to the param itself.
 *
 * @param graph         the graph (already containing forward + backward)
 * @param paramGrads    param name -> gradient tensor (output of `appendGrad`)
 * @param paramTensors  param name -> the param's leaf Tensor (the param_input);
 *                      the caller passes it because `Graph` has no by-name map.
 * @param config        Adam hyperparameters. `lr` is a number or a
 *                      `(step) => lr` schedule. Defaults: `b1 = 0.9`,
 *                      `b2 = 0.999`, `eps = 1e-8`, `weightDecay = 0`,
 *                      `decayFilter = () => true`. Set `weightDecay > 0` for AdamW.
 * @param decayFlags    optional per-param decay flags. A name that is an OWN
 *                      key of this object uses its flag; every other name
 *                      falls back to `config.decayFilter`.
 * @returns the value `traceInto` yields for the callback, which builds
 *          `{ writebacks, lrtInputName, decayShrinkInputName, config }`.
 *          `decayShrinkInputName` stays `null` unless lr is scheduled and at
 *          least one param is decayed.
 * @throws Error when a name in `paramGrads` has no param tensor or no gradient.
 */
export function appendAdam(graph, paramGrads, paramTensors, config, decayFlags) {
    const lrIsScheduled = typeof config.lr === 'function';
    const lrFn = lrIsScheduled
        ? config.lr
        : (() => config.lr);
    const fullConfig = {
        lr: lrFn,
        b1: config.b1 ?? 0.9,
        b2: config.b2 ?? 0.999,
        eps: config.eps ?? 1e-8,
        weightDecay: config.weightDecay ?? 0,
        decayFilter: config.decayFilter ?? (() => true),
        lrIsScheduled,
    };
    // Only a static lr can be baked into the kernel as a literal. A schedule
    // function is deliberately NOT invoked here: its value is supplied per
    // step through the decay-shrink tensor input instead.
    const staticDecayShrink = lrIsScheduled
        ? null
        : 1 - config.lr * fullConfig.weightDecay;
    // Own-property test: `name in decayFlags` would also match inherited
    // members such as `constructor` or `toString`, whose (truthy) values
    // would silently switch decay on for params with those names.
    const hasOwnFlag = (name) => decayFlags != null
        && Object.prototype.hasOwnProperty.call(decayFlags, name);
    const writebacks = [];
    const lrtInputName = '_adam_lrt';
    // Tensor input for runtime-updated decayShrink (only created when lr is a
    // schedule fn AND at least one param will receive weight decay).
    let decayShrinkInputName = null;
    return traceInto(graph, () => {
        const lrt = tensorInput(lrtInputName, [], 'f32');
        // Up-front: which params receive weight decay? Left empty when
        // weightDecay = 0 so the loop below can just ask "is this name in the set?".
        const decayedNames = new Set(fullConfig.weightDecay > 0
            ? Object.keys(paramGrads).filter((name) => (hasOwnFlag(name)
                ? decayFlags[name]
                : fullConfig.decayFilter(name)))
            : []);
        // A runtime decayShrink scalar is only needed when lr varies per step
        // AND something is decayed; otherwise the value is a constant.
        let decayShrinkScalar = null;
        if (lrIsScheduled && decayedNames.size > 0) {
            decayShrinkInputName = '_adam_decay_shrink';
            decayShrinkScalar = tensorInput(decayShrinkInputName, [], 'f32');
        }
        for (const name of Object.keys(paramGrads)) {
            const p = paramTensors[name];
            const g = paramGrads[name];
            if (!p) {
                throw new Error(`appendAdam: missing param tensor for '${name}'`);
            }
            if (!g) {
                throw new Error(`appendAdam: missing gradient for '${name}'`);
            }
            const mState = stateInput(`adam_m_${name}`, p.shape, 'f32', 0);
            const vState = stateInput(`adam_v_${name}`, p.shape, 'f32', 0);
            // Choose the decayShrink form per param:
            //   - non-decayed params: literal 1.
            //   - decayed + scheduled lr: tensor input updated per step.
            //   - decayed + static lr: literal `1 - lr * wd` baked at compile.
            const decayShrink = !decayedNames.has(name) ? 1
                : decayShrinkScalar !== null ? decayShrinkScalar
                    : staticDecayShrink;
            // Three fused kernels per parameter — one for each of m_new / v_new / p_new.
            const newM = adamUpdateM(mState, g, fullConfig.b1);
            const newV = adamUpdateV(vState, g, fullConfig.b2);
            const newP = adamUpdateP(p, newM, newV, lrt, fullConfig.eps, decayShrink);
            writebacks.push({ source: newM, destName: `adam_m_${name}`, destKind: 'state' });
            writebacks.push({ source: newV, destName: `adam_v_${name}`, destKind: 'state' });
            writebacks.push({ source: newP, destName: name, destKind: 'param' });
        }
        return { writebacks, lrtInputName, decayShrinkInputName, config: fullConfig };
    });
}
|
|
111
|
-
//# sourceMappingURL=adam.js.map
|
package/dist/adam.js.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"adam.js","sourceRoot":"","sources":["../src/adam.ts"],"names":[],"mappings":"AAAA,oCAAoC;AACpC,EAAE;AACF,kFAAkF;AAClF,+EAA+E;AAC/E,EAAE;AACF,mCAAmC;AACnC,kCAAkC;AAClC,mCAAmC;AACnC,gEAAgE;AAChE,EAAE;AACF,uEAAuE;AACvE,sEAAsE;AACtE,wEAAwE;AACxE,4DAA4D;AAC5D,wEAAwE;AACxE,4DAA4D;AAC5D,EAAE;AACF,2EAA2E;AAC3E,2EAA2E;AAC3E,0EAA0E;AAC1E,0EAA0E;AAC1E,6EAA6E;AAC7E,EAAE;AACF,wEAAwE;AACxE,0EAA0E;AAC1E,0EAA0E;AAC1E,iDAAiD;AAKjD,OAAO,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,MAAM,YAAY,CAAA;AAC/D,OAAO,EAAE,WAAW,EAAE,WAAW,EAAE,WAAW,EAAE,MAAM,UAAU,CAAA;AAgDhE;;;;;;;;;;;GAWG;AACH,MAAM,UAAU,UAAU,CACxB,KAAY,EACZ,UAAkC,EAClC,YAAoC,EACpC,MAAkB;AAClB;;;kCAGkC;AAClC,UAAoC;IAEpC,MAAM,aAAa,GAAG,OAAO,MAAM,CAAC,EAAE,KAAK,UAAU,CAAA;IACrD,MAAM,IAAI,GAAG,aAAa;QACxB,CAAC,CAAC,MAAM,CAAC,EAA8B;QACvC,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,EAAY,CAAC,CAAA;IAC/B,MAAM,SAAS,GAAG,IAAI,CAAC,CAAC,CAAC,CAAA;IACzB,MAAM,UAAU,GAAuB;QACrC,EAAE,EAAE,IAAI;QACR,EAAE,EAAE,MAAM,CAAC,EAAE,IAAI,GAAG;QACpB,EAAE,EAAE,MAAM,CAAC,EAAE,IAAI,KAAK;QACtB,GAAG,EAAE,MAAM,CAAC,GAAG,IAAI,IAAI;QACvB,WAAW,EAAE,MAAM,CAAC,WAAW,IAAI,CAAC;QACpC,WAAW,EAAE,MAAM,CAAC,WAAW,IAAI,CAAC,GAAG,EAAE,CAAC,IAAI,CAAC;QAC/C,aAAa;KACd,CAAA;IACD,MAAM,UAAU,GAAoB,EAAE,CAAA;IACtC,MAAM,YAAY,GAAG,WAAW,CAAA;IAChC,0EAA0E;IAC1E,iEAAiE;IACjE,IAAI,oBAAoB,GAAkB,IAAI,CAAA;IAE9C,OAAO,SAAS,CAAC,KAAK,EAAE,GAAG,EAAE;QAC3B,MAAM,GAAG,GAAG,WAAW,CAAC,YAAY,EAAE,EAAE,EAAE,KAAK,CAAC,CAAA;QAEhD,yEAAyE;QACzE,uEAAuE;QACvE,yEAAyE;QACzE,2CAA2C;QAC3C,MAAM,YAAY,GAAG,IAAI,GAAG,CAC1B,UAAU,CAAC,WAAW,GAAG,CAAC;YACxB,CAAC,CAAC,MAAM,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC,MAAM,CAAC,IAAI,CAAC,EAAE,CACpC,CAAC,UAAU,IAAI,IAAI,IAAI,UAAU,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,IAAI,CAAE,CAAC,CAAC,CAAC,UAAU,CAAC,WAAW,CAAC,IAAI,CAAC,CAAC;YAC1F,CAAC,CAAC,EAAE,CACP,CAAA;QAED,wEAAwE;QACxE,uEAAuE;QACvE,0CAA0C;QAC1C,IAAI,iBAAiB,GAAkB,IAAI,CAAA;QAC3C,IAAI,aAAa,IAAI,YAAY,CAAC,IAAI,GAAG,CAAC,EAAE,CAAC;YAC3C,oBAAoB,GAAG,oBAAoB,CAAA;YAC3C,iBAAiB,GAAG,WAAW,CAAC,oBAAoB,EAAE,EAAE,EAAE,KAAK,CAAC,CAAA;QAClE,CAAC;QAED,K
AAK,MAAM,IAAI,IAAI,MAAM,CAAC,IAAI,CAAC,UAAU,CAAC,EAAE,CAAC;YAC3C,MAAM,CAAC,GAAG,YAAY,CAAC,IAAI,CAAC,CAAA;YAC5B,MAAM,CAAC,GAAG,UAAU,CAAC,IAAI,CAAC,CAAA;YAC1B,IAAI,CAAC,CAAC;gBAAE,MAAM,IAAI,KAAK,CAAC,yCAAyC,IAAI,GAAG,CAAC,CAAA;YACzE,IAAI,CAAC,CAAC;gBAAE,MAAM,IAAI,KAAK,CAAC,qCAAqC,IAAI,GAAG,CAAC,CAAA;YAErE,MAAM,MAAM,GAAG,UAAU,CAAC,UAAU,IAAI,EAAE,EAAE,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,CAAC,CAAC,CAAA;YAC9D,MAAM,MAAM,GAAG,UAAU,CAAC,UAAU,IAAI,EAAE,EAAE,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,CAAC,CAAC,CAAA;YAE9D,yCAAyC;YACzC,iEAAiE;YACjE,6DAA6D;YAC7D,mEAAmE;YACnE,MAAM,WAAW,GACf,CAAC,YAAY,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC;gBAC3B,CAAC,CAAC,iBAAiB,KAAK,IAAI,CAAC,CAAC,CAAC,iBAAiB;oBAChD,CAAC,CAAC,CAAC,GAAG,SAAS,GAAG,UAAU,CAAC,WAAW,CAAA;YAE1C,6EAA6E;YAC7E,MAAM,IAAI,GAAG,WAAW,CAAC,MAAM,EAAE,CAAC,EAAE,UAAU,CAAC,EAAE,CAAC,CAAA;YAClD,MAAM,IAAI,GAAG,WAAW,CAAC,MAAM,EAAE,CAAC,EAAE,UAAU,CAAC,EAAE,CAAC,CAAA;YAClD,MAAM,IAAI,GAAG,WAAW,CAAC,CAAC,EAAE,IAAI,EAAE,IAAI,EAAE,GAAG,EAAE,UAAU,CAAC,GAAG,EAAE,WAAW,CAAC,CAAA;YAEzE,UAAU,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE,UAAU,IAAI,EAAE,EAAE,QAAQ,EAAE,OAAO,EAAE,CAAC,CAAA;YAChF,UAAU,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE,UAAU,IAAI,EAAE,EAAE,QAAQ,EAAE,OAAO,EAAE,CAAC,CAAA;YAChF,UAAU,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE,IAAI,EAAc,QAAQ,EAAE,OAAO,EAAE,CAAC,CAAA;QAClF,CAAC;QACD,OAAO,EAAE,UAAU,EAAE,YAAY,EAAE,oBAAoB,EAAE,MAAM,EAAE,UAAU,EAAE,CAAA;IAC/E,CAAC,CAAC,CAAA;AACJ,CAAC"}
|
package/dist/buffers.js
DELETED
|
@@ -1,120 +0,0 @@
|
|
|
1
|
-
// Buffer planning: walk a Graph and decide which GPU buffer each Tensor maps to.
|
|
2
|
-
//
|
|
3
|
-
// v1 strategy: one GPU buffer per IR Tensor. Static shapes mean every buffer's
|
|
4
|
-
// size is known at compile time and lifetimes don't overlap between steps —
|
|
5
|
-
// so no pooling needed. Total memory is the sum of every intermediate tensor.
|
|
6
|
-
// For our transformer at B=256: ~30 MB of activations + grads. Easily fits.
|
|
7
|
-
//
|
|
8
|
-
// Categorization is what the runtime cares about:
|
|
9
|
-
// * param — uploaded by user via uploadParams; persistent across steps
|
|
10
|
-
// * param_grad — written each step by the backward pass; readable for inspection
|
|
11
|
-
// * tensor_input — uploaded each step (tokens, targets, masks)
|
|
12
|
-
// * intermediate — produced by an op; lifetime = within a single step
|
|
13
|
-
// * output — special intermediate that should be made readable (loss)
|
|
14
|
-
import { shapeSize } from './shape.js';
|
|
15
|
-
// Bytes per element for each supported dtype (every entry is 4 bytes wide).
const dtypeBytes = { f32: 4, i32: 4, bool: 4 };
/**
 * Build a BufferPlan from a graph + the param-grad map produced by appendGrad.
 *
 * Every tensor in `graph.tensors` gets exactly one buffer spec whose id is the
 * tensor id (1:1 mapping, no pooling). Each spec is categorised, in priority
 * order, as `param`, `tensor_input` or `state` (from the op that produced it),
 * then `param_grad` (listed in `paramGrads`), then `output` (listed in
 * `graph.outputs`), and otherwise `intermediate`.
 *
 * @param graph           the full graph (forward + backward + any optimizer ops);
 *                        reads `tensors`, `ops`, `outputs` and `captures`.
 * @param paramGrads      map from param name -> the Tensor that holds its gradient
 * @param writebackDecls  list of end-of-step writebacks (e.g. from appendAdam).
 *                        Empty when there's no optimizer in the graph.
 * @returns `{ buffers, tensorToBuffer, paramsByName, inputsByName,
 *          paramGradsByName, statesByName, capturesByName, outputBufferIds,
 *          writebacks }`, where each writeback is `{ source, dest, bytes }`
 *          expressed in buffer ids.
 * @throws Error when a tensor has an unsupported dtype, when an output,
 *         capture or writeback references an unknown tensor / destination, or
 *         when a writeback's source and destination differ in byte size.
 */
export function planBuffers(graph, paramGrads, writebackDecls = []) {
    const buffers = [];
    // Specs keyed by buffer id. Looking specs up here instead of indexing
    // `buffers` by id keeps writeback resolution correct even when tensor ids
    // are not a dense 0..n-1 range.
    const specById = new Map();
    const tensorToBuffer = new Map();
    const paramsByName = new Map();
    const inputsByName = new Map();
    const paramGradsByName = new Map();
    const statesByName = new Map();
    // Reverse map: tensorId -> param name (for grads).
    const gradTensorIdToName = new Map();
    for (const [name, tensor] of Object.entries(paramGrads)) {
        gradTensorIdToName.set(tensor.id, name);
    }
    // ...and tensorId -> producing op (so the buffer can be named correctly).
    const opByOutId = new Map();
    for (const op of graph.ops) {
        opByOutId.set(op.out, op);
    }
    const outputSet = new Set(graph.outputs);
    // Walk all tensors and categorize each.
    for (const t of graph.tensors) {
        const elemBytes = dtypeBytes[t.dtype];
        if (typeof elemBytes !== 'number') {
            // Without this guard the size below would silently become NaN.
            throw new Error(`planBuffers: tensor #${t.id} has unsupported dtype '${t.dtype}'`);
        }
        const op = opByOutId.get(t.id);
        let kind = 'intermediate';
        let name = null;
        let initValue;
        if (op?.kind === 'param_input') {
            kind = 'param';
            name = op.name;
        }
        else if (op?.kind === 'tensor_input') {
            kind = 'tensor_input';
            name = op.name;
        }
        else if (op?.kind === 'state_input') {
            kind = 'state';
            name = op.name;
            initValue = op.initValue;
        }
        else if (gradTensorIdToName.has(t.id)) {
            kind = 'param_grad';
            name = gradTensorIdToName.get(t.id);
        }
        else if (outputSet.has(t.id)) {
            kind = 'output';
        }
        const spec = {
            id: t.id,
            // Never plan a buffer smaller than 4 bytes.
            byteSize: Math.max(4, shapeSize(t.shape) * elemBytes),
            dtype: t.dtype,
            shape: t.shape,
            kind,
            name,
            ...(initValue !== undefined ? { initValue } : {}),
        };
        buffers.push(spec);
        specById.set(spec.id, spec);
        tensorToBuffer.set(t.id, t.id); // 1:1 for v1
        if (kind === 'param') {
            paramsByName.set(name, t.id);
        }
        if (kind === 'tensor_input') {
            inputsByName.set(name, t.id);
        }
        if (kind === 'param_grad') {
            paramGradsByName.set(name, t.id);
        }
        if (kind === 'state') {
            statesByName.set(name, t.id);
        }
    }
    // Outputs get the same validation as captures below, instead of leaking
    // `undefined` buffer ids into the plan.
    const outputBufferIds = graph.outputs.map((id) => {
        const bufId = tensorToBuffer.get(id);
        if (bufId === undefined) {
            throw new Error(`planBuffers: output references unknown tensor #${id}`);
        }
        return bufId;
    });
    // Resolve writeback declarations to (source, dest) buffer-id pairs.
    const writebacks = writebackDecls.map((decl) => {
        const sourceBufId = tensorToBuffer.get(decl.source.id);
        if (sourceBufId === undefined) {
            throw new Error(`planBuffers: writeback source tensor #${decl.source.id} not in graph`);
        }
        // Any destKind other than 'param' is resolved against the state map.
        const destBufId = decl.destKind === 'param'
            ? paramsByName.get(decl.destName)
            : statesByName.get(decl.destName);
        if (destBufId === undefined) {
            throw new Error(`planBuffers: writeback dest ${decl.destKind}:'${decl.destName}' not found`);
        }
        const sourceSpec = specById.get(sourceBufId);
        const destSpec = specById.get(destBufId);
        if (sourceSpec.byteSize !== destSpec.byteSize) {
            throw new Error(`planBuffers: writeback size mismatch for ${decl.destKind}:'${decl.destName}' ` +
                `(source ${sourceSpec.byteSize} bytes vs dest ${destSpec.byteSize})`);
        }
        return { source: sourceBufId, dest: destBufId, bytes: sourceSpec.byteSize };
    });
    // Resolve graph.captures (name -> tensor id) to (name -> buffer id).
    // No pinning is needed here: each tensor already has its own buffer.
    const capturesByName = new Map();
    for (const [name, tensorId] of graph.captures) {
        const bufId = tensorToBuffer.get(tensorId);
        if (bufId === undefined) {
            throw new Error(`planBuffers: capture '${name}' references unknown tensor #${tensorId}`);
        }
        capturesByName.set(name, bufId);
    }
    return { buffers, tensorToBuffer, paramsByName, inputsByName, paramGradsByName, statesByName, capturesByName, outputBufferIds, writebacks };
}
|
|
120
|
-
//# sourceMappingURL=buffers.js.map
|
package/dist/buffers.js.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"buffers.js","sourceRoot":"","sources":["../src/buffers.ts"],"names":[],"mappings":"AAAA,iFAAiF;AACjF,EAAE;AACF,+EAA+E;AAC/E,4EAA4E;AAC5E,8EAA8E;AAC9E,4EAA4E;AAC5E,EAAE;AACF,kDAAkD;AAClD,gFAAgF;AAChF,qFAAqF;AACrF,iEAAiE;AACjE,wEAAwE;AACxE,8EAA8E;AAG9E,OAAO,EAAE,SAAS,EAAE,MAAM,YAAY,CAAA;AAyCtC,MAAM,UAAU,GAA0B,EAAE,GAAG,EAAE,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,IAAI,EAAE,CAAC,EAAE,CAAA;AAcrE;;;;;;GAMG;AACH,MAAM,UAAU,WAAW,CACzB,KAAY,EACZ,UAAkC,EAClC,iBAAkC,EAAE;IAEpC,MAAM,OAAO,GAAiB,EAAE,CAAA;IAChC,MAAM,cAAc,GAAG,IAAI,GAAG,EAAkB,CAAA;IAChD,MAAM,YAAY,GAAG,IAAI,GAAG,EAAkB,CAAA;IAC9C,MAAM,YAAY,GAAG,IAAI,GAAG,EAAkB,CAAA;IAC9C,MAAM,gBAAgB,GAAG,IAAI,GAAG,EAAkB,CAAA;IAClD,MAAM,YAAY,GAAG,IAAI,GAAG,EAAkB,CAAA;IAE9C,iEAAiE;IACjE,MAAM,kBAAkB,GAAG,IAAI,GAAG,EAAkB,CAAA;IACpD,KAAK,MAAM,CAAC,IAAI,EAAE,MAAM,CAAC,IAAI,MAAM,CAAC,OAAO,CAAC,UAAU,CAAC,EAAE,CAAC;QACxD,kBAAkB,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,EAAE,IAAI,CAAC,CAAA;IACzC,CAAC;IACD,2EAA2E;IAC3E,MAAM,SAAS,GAAG,IAAI,GAAG,EAAkB,CAAA;IAC3C,KAAK,MAAM,EAAE,IAAI,KAAK,CAAC,GAAG;QAAE,SAAS,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,EAAE,CAAC,CAAA;IAErD,MAAM,SAAS,GAAG,IAAI,GAAG,CAAC,KAAK,CAAC,OAAO,CAAC,CAAA;IAExC,iDAAiD;IACjD,KAAK,MAAM,CAAC,IAAI,KAAK,CAAC,OAAO,EAAE,CAAC;QAC9B,MAAM,EAAE,GAAG,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAA;QAC9B,IAAI,IAAI,GAAuB,cAAc,CAAA;QAC7C,IAAI,IAAI,GAAkB,IAAI,CAAA;QAC9B,IAAI,SAA6B,CAAA;QAEjC,IAAI,EAAE,EAAE,IAAI,KAAK,aAAa,EAAE,CAAC;YAC/B,IAAI,GAAG,OAAO,CAAA;YACd,IAAI,GAAG,EAAE,CAAC,IAAI,CAAA;QAChB,CAAC;aAAM,IAAI,EAAE,EAAE,IAAI,KAAK,cAAc,EAAE,CAAC;YACvC,IAAI,GAAG,cAAc,CAAA;YACrB,IAAI,GAAG,EAAE,CAAC,IAAI,CAAA;QAChB,CAAC;aAAM,IAAI,EAAE,EAAE,IAAI,KAAK,aAAa,EAAE,CAAC;YACtC,IAAI,GAAG,OAAO,CAAA;YACd,IAAI,GAAG,EAAE,CAAC,IAAI,CAAA;YACd,SAAS,GAAG,EAAE,CAAC,SAAS,CAAA;QAC1B,CAAC;aAAM,IAAI,kBAAkB,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC;YACxC,IAAI,GAAG,YAAY,CAAA;YACnB,IAAI,GAAG,kBAAkB,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAE,CAAA;QACtC,CAAC;aAAM,IAAI,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC;YAC/B,IAAI,GAAG,QAAQ,
CAAA;QACjB,CAAC;QAED,MAAM,IAAI,GAAe;YACvB,EAAE,EAAE,CAAC,CAAC,EAAE;YACR,QAAQ,EAAE,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,KAAK,CAAC,GAAG,UAAU,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;YAC/D,KAAK,EAAE,CAAC,CAAC,KAAK;YACd,KAAK,EAAE,CAAC,CAAC,KAAK;YACd,IAAI;YACJ,IAAI;YACJ,GAAG,CAAC,SAAS,KAAK,SAAS,CAAC,CAAC,CAAC,EAAE,SAAS,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC;SAClD,CAAA;QACD,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,CAAA;QAClB,cAAc,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA,CAAE,aAAa;QAE7C,IAAI,IAAI,KAAK,OAAO;YAAE,YAAY,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;QACnD,IAAI,IAAI,KAAK,cAAc;YAAE,YAAY,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;QAC1D,IAAI,IAAI,KAAK,YAAY;YAAE,gBAAgB,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;QAC5D,IAAI,IAAI,KAAK,OAAO;YAAE,YAAY,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;IACrD,CAAC;IAED,MAAM,eAAe,GAAG,KAAK,CAAC,OAAO,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,cAAc,CAAC,GAAG,CAAC,EAAE,CAAE,CAAC,CAAA;IAExE,oEAAoE;IACpE,MAAM,UAAU,GAAgB,cAAc,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE;QACxD,MAAM,WAAW,GAAG,cAAc,CAAC,GAAG,CAAC,IAAI,CAAC,MAAM,CAAC,EAAE,CAAC,CAAA;QACtD,IAAI,WAAW,KAAK,SAAS,EAAE,CAAC;YAC9B,MAAM,IAAI,KAAK,CAAC,yCAAyC,IAAI,CAAC,MAAM,CAAC,EAAE,eAAe,CAAC,CAAA;QACzF,CAAC;QACD,MAAM,SAAS,GAAG,IAAI,CAAC,QAAQ,KAAK,OAAO;YACzC,CAAC,CAAC,YAAY,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC;YACjC,CAAC,CAAC,YAAY,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAA;QACnC,IAAI,SAAS,KAAK,SAAS,EAAE,CAAC;YAC5B,MAAM,IAAI,KAAK,CAAC,+BAA+B,IAAI,CAAC,QAAQ,KAAK,IAAI,CAAC,QAAQ,aAAa,CAAC,CAAA;QAC9F,CAAC;QACD,MAAM,UAAU,GAAG,OAAO,CAAC,WAAW,CAAE,CAAA;QACxC,MAAM,QAAQ,GAAG,OAAO,CAAC,SAAS,CAAE,CAAA;QACpC,IAAI,UAAU,CAAC,QAAQ,KAAK,QAAQ,CAAC,QAAQ,EAAE,CAAC;YAC9C,MAAM,IAAI,KAAK,CACb,4CAA4C,IAAI,CAAC,QAAQ,KAAK,IAAI,CAAC,QAAQ,IAAI;gBAC/E,WAAW,UAAU,CAAC,QAAQ,kBAAkB,QAAQ,CAAC,QAAQ,GAAG,CACrE,CAAA;QACH,CAAC;QACD,OAAO,EAAE,MAAM,EAAE,WAAW,EAAE,IAAI,EAAE,SAAS,EAAE,KAAK,EAAE,UAAU,CAAC,QAAQ,EAAE,CAAA;IAC7E,CAAC,CAAC,CAAA;IAEF,qEAAqE;IACrE,0EAA0E;IAC1E,8DAA8D;IAC9D,MAAM,cAAc,GAAG,IAAI,GAAG,EAAkB,CAAA;IAChD,KAA
K,MAAM,CAAC,IAAI,EAAE,QAAQ,CAAC,IAAI,KAAK,CAAC,QAAQ,EAAE,CAAC;QAC9C,MAAM,KAAK,GAAG,cAAc,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAA;QAC1C,IAAI,KAAK,KAAK,SAAS,EAAE,CAAC;YACxB,MAAM,IAAI,KAAK,CAAC,yBAAyB,IAAI,gCAAgC,QAAQ,EAAE,CAAC,CAAA;QAC1F,CAAC;QACD,cAAc,CAAC,GAAG,CAAC,IAAI,EAAE,KAAK,CAAC,CAAA;IACjC,CAAC;IAED,OAAO,EAAE,OAAO,EAAE,cAAc,EAAE,YAAY,EAAE,YAAY,EAAE,gBAAgB,EAAE,YAAY,EAAE,cAAc,EAAE,eAAe,EAAE,UAAU,EAAE,CAAA;AAC7I,CAAC"}
|
package/dist/capture.js
DELETED
|
@@ -1,33 +0,0 @@
|
|
|
1
|
-
// Activation capture — opt-in readback of intermediate tensors at training step.
|
|
2
|
-
//
|
|
3
|
-
// Usage (inside the user's forward pass):
|
|
4
|
-
//
|
|
5
|
-
// import { capture } from 'tensorgrad'
|
|
6
|
-
//
|
|
7
|
-
// function attentionFwd(p, x) {
|
|
8
|
-
// const scores = mul(matmulBatched(q, kT), SCALE_QK)
|
|
9
|
-
// const attn = capture(`attn.${layerIdx}`, softmaxCausalLast(scores))
|
|
10
|
-
// return matmulBatched(attn, v)
|
|
11
|
-
// }
|
|
12
|
-
//
|
|
13
|
-
// Pass-through return type: `capture(name, t)` returns `t` unchanged so it
|
|
14
|
-
// inlines at the point of computation. Behind the scenes it registers `t.id`
|
|
15
|
-
// against `name` on the current graph; runtime exposes the registered tensors
|
|
16
|
-
// via `step(inputs, { withCaptures: true })`.
|
|
17
|
-
//
|
|
18
|
-
// Outside the user's forward trace (during `appendGrad` / `appendAdam`'s
|
|
19
|
-
// `traceInto` re-entry), `capture()` is a no-op — gradient and optimizer
|
|
20
|
-
// internals shouldn't accidentally publish themselves to the UI.
|
|
21
|
-
import { currentGraph, isCaptureEnabled } from './trace.js';
|
|
22
|
-
/**
 * Register tensor `t` under `name` on the current graph and hand `t` back
 * unchanged, so the call can be inlined at the point of computation.
 *
 * When `isCaptureEnabled()` is false nothing is registered and `t` is simply
 * returned.
 *
 * @param name  key to register the tensor under; must not already be
 *              registered on the current graph.
 * @param t     the tensor to capture; only its `id` is recorded.
 * @returns `t`, untouched.
 * @throws Error when `name` is already registered on the current graph.
 */
export function capture(name, t) {
    if (isCaptureEnabled()) {
        const { captures } = currentGraph();
        const alreadyRegistered = captures.has(name);
        if (alreadyRegistered) {
            throw new Error(`capture: name '${name}' already registered. Use unique names ` +
                `(e.g. \`attn.\${layerIdx}\`) when capturing across a loop.`);
        }
        captures.set(name, t.id);
    }
    return t;
}
|
|
33
|
-
//# sourceMappingURL=capture.js.map
|
package/dist/capture.js.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"capture.js","sourceRoot":"","sources":["../src/capture.ts"],"names":[],"mappings":"AAAA,iFAAiF;AACjF,EAAE;AACF,0CAA0C;AAC1C,EAAE;AACF,yCAAyC;AACzC,EAAE;AACF,kCAAkC;AAClC,yDAAyD;AACzD,0EAA0E;AAC1E,oCAAoC;AACpC,MAAM;AACN,EAAE;AACF,2EAA2E;AAC3E,6EAA6E;AAC7E,8EAA8E;AAC9E,8CAA8C;AAC9C,EAAE;AACF,yEAAyE;AACzE,yEAAyE;AACzE,iEAAiE;AAGjE,OAAO,EAAE,YAAY,EAAE,gBAAgB,EAAE,MAAM,YAAY,CAAA;AAE3D,MAAM,UAAU,OAAO,CAAmB,IAAY,EAAE,CAAI;IAC1D,IAAI,CAAC,gBAAgB,EAAE;QAAE,OAAO,CAAC,CAAA;IACjC,MAAM,CAAC,GAAG,YAAY,EAAE,CAAA;IACxB,IAAI,CAAC,CAAC,QAAQ,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC;QACzB,MAAM,IAAI,KAAK,CACb,kBAAkB,IAAI,yCAAyC;YAC/D,4DAA4D,CAC7D,CAAA;IACH,CAAC;IACD,CAAC,CAAC,QAAQ,CAAC,GAAG,CAAC,IAAI,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;IAC1B,OAAO,CAAC,CAAA;AACV,CAAC"}
|