tensorgrad 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +121 -0
- package/SPEC.md +293 -0
- package/dist/adam.d.ts +31 -0
- package/dist/adam.d.ts.map +1 -0
- package/dist/adam.js +66 -0
- package/dist/adam.js.map +1 -0
- package/dist/buffers.d.ts +56 -0
- package/dist/buffers.d.ts.map +1 -0
- package/dist/buffers.js +114 -0
- package/dist/buffers.js.map +1 -0
- package/dist/codegen.d.ts +23 -0
- package/dist/codegen.d.ts.map +1 -0
- package/dist/codegen.js +709 -0
- package/dist/codegen.js.map +1 -0
- package/dist/compile.d.ts +53 -0
- package/dist/compile.d.ts.map +1 -0
- package/dist/compile.js +76 -0
- package/dist/compile.js.map +1 -0
- package/dist/grad.d.ts +8 -0
- package/dist/grad.d.ts.map +1 -0
- package/dist/grad.js +404 -0
- package/dist/grad.js.map +1 -0
- package/dist/index.d.ts +12 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +37 -0
- package/dist/index.js.map +1 -0
- package/dist/ir.d.ts +204 -0
- package/dist/ir.d.ts.map +1 -0
- package/dist/ir.js +60 -0
- package/dist/ir.js.map +1 -0
- package/dist/module.d.ts +21 -0
- package/dist/module.d.ts.map +1 -0
- package/dist/module.js +113 -0
- package/dist/module.js.map +1 -0
- package/dist/ops.d.ts +35 -0
- package/dist/ops.d.ts.map +1 -0
- package/dist/ops.js +270 -0
- package/dist/ops.js.map +1 -0
- package/dist/runtime.d.ts +26 -0
- package/dist/runtime.d.ts.map +1 -0
- package/dist/runtime.js +190 -0
- package/dist/runtime.js.map +1 -0
- package/dist/shape.d.ts +24 -0
- package/dist/shape.d.ts.map +1 -0
- package/dist/shape.js +259 -0
- package/dist/shape.js.map +1 -0
- package/dist/trace.d.ts +8 -0
- package/dist/trace.d.ts.map +1 -0
- package/dist/trace.js +93 -0
- package/dist/trace.js.map +1 -0
- package/package.json +62 -0
- package/src/adam.ts +95 -0
- package/src/buffers.ts +173 -0
- package/src/codegen.ts +758 -0
- package/src/compile.ts +120 -0
- package/src/grad.ts +459 -0
- package/src/index.ts +40 -0
- package/src/ir.ts +197 -0
- package/src/module.ts +126 -0
- package/src/ops.ts +311 -0
- package/src/runtime.ts +232 -0
- package/src/shape.ts +263 -0
- package/src/trace.ts +101 -0
package/dist/buffers.js
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
// Buffer planning: walk a Graph and decide which GPU buffer each Tensor maps to.
|
|
2
|
+
//
|
|
3
|
+
// v1 strategy: one GPU buffer per IR Tensor. Static shapes mean every buffer's
|
|
4
|
+
// size is known at compile time and lifetimes don't overlap between steps —
|
|
5
|
+
// so no pooling needed. Total memory is the sum of every intermediate tensor.
|
|
6
|
+
// For our transformer at B=256: ~30 MB of activations + grads. Easily fits.
|
|
7
|
+
//
|
|
8
|
+
// Categorization is what the runtime cares about:
|
|
9
|
+
// * param — uploaded by user via uploadParams; persistent across steps
|
|
10
|
+
// * param_grad — written each step by the backward pass; readable for inspection
|
|
11
|
+
// * tensor_input — uploaded each step (tokens, targets, masks)
|
|
12
|
+
// * intermediate — produced by an op; lifetime = within a single step
|
|
13
|
+
// * output — special intermediate that should be made readable (loss)
|
|
14
|
+
// Bytes per element for every supported dtype (bools occupy a full 4-byte word).
const dtypeBytes = { f32: 4, i32: 4, bool: 4 };
/** Element count of a tensor with the given shape; an empty shape is a scalar (1). */
function shapeSize(shape) {
    return shape.reduce((count, dim) => count * dim, 1);
}
|
|
21
|
+
/**
 * Build a BufferPlan from a graph + the param-grad map produced by appendGrad.
 * @param graph the full graph (forward + backward + any optimizer ops)
 * @param paramGrads map from param name -> the Tensor that holds its gradient
 * @param writebackDecls list of end-of-step writebacks (e.g. from appendAdam).
 *        Empty when there's no optimizer in the graph.
 */
export function planBuffers(graph, paramGrads, writebackDecls = []) {
    // Reverse lookup: gradient tensor id -> owning param name.
    const nameByGradId = new Map();
    for (const [paramName, gradTensor] of Object.entries(paramGrads)) {
        nameByGradId.set(gradTensor.id, paramName);
    }
    // Producer op for each tensor id, so param/input/state tensors can be
    // recognized and given the right buffer name.
    const producerById = new Map();
    for (const node of graph.ops) {
        producerById.set(node.out, node);
    }
    const outputIds = new Set(graph.outputs);

    // Decide which buffer category a tensor belongs to. Priority: producer op
    // kind, then gradient membership, then graph output, else intermediate.
    const classify = (tensor) => {
        const producer = producerById.get(tensor.id);
        switch (producer?.kind) {
            case 'param_input':
                return { kind: 'param', name: producer.name };
            case 'tensor_input':
                return { kind: 'tensor_input', name: producer.name };
            case 'state_input':
                return { kind: 'state', name: producer.name, initValue: producer.initValue };
        }
        if (nameByGradId.has(tensor.id)) {
            return { kind: 'param_grad', name: nameByGradId.get(tensor.id) };
        }
        if (outputIds.has(tensor.id)) {
            return { kind: 'output', name: null };
        }
        return { kind: 'intermediate', name: null };
    };

    const buffers = [];
    const tensorToBuffer = new Map();
    const paramsByName = new Map();
    const inputsByName = new Map();
    const paramGradsByName = new Map();
    const statesByName = new Map();

    // Walk all tensors in id order; one buffer spec per tensor.
    for (const tensor of graph.tensors) {
        const { kind, name, initValue } = classify(tensor);
        const spec = {
            id: tensor.id,
            // Clamped so even an empty shape gets one element's worth of bytes.
            byteSize: Math.max(4, shapeSize(tensor.shape) * dtypeBytes[tensor.dtype]),
            dtype: tensor.dtype,
            shape: tensor.shape,
            kind,
            name,
            ...(initValue !== undefined ? { initValue } : {}),
        };
        buffers.push(spec);
        tensorToBuffer.set(tensor.id, tensor.id); // 1:1 for v1

        if (kind === 'param')
            paramsByName.set(name, tensor.id);
        else if (kind === 'tensor_input')
            inputsByName.set(name, tensor.id);
        else if (kind === 'param_grad')
            paramGradsByName.set(name, tensor.id);
        else if (kind === 'state')
            statesByName.set(name, tensor.id);
    }

    const outputBufferIds = graph.outputs.map(id => tensorToBuffer.get(id));

    // Resolve writeback declarations to (source, dest) buffer-id pairs,
    // validating that both endpoints exist and agree on byte size.
    const writebacks = writebackDecls.map(decl => {
        const sourceBufId = tensorToBuffer.get(decl.source.id);
        if (sourceBufId === undefined) {
            throw new Error(`planBuffers: writeback source tensor #${decl.source.id} not in graph`);
        }
        const destLookup = decl.destKind === 'param' ? paramsByName : statesByName;
        const destBufId = destLookup.get(decl.destName);
        if (destBufId === undefined) {
            throw new Error(`planBuffers: writeback dest ${decl.destKind}:'${decl.destName}' not found`);
        }
        const sourceSpec = buffers[sourceBufId];
        const destSpec = buffers[destBufId];
        if (sourceSpec.byteSize !== destSpec.byteSize) {
            throw new Error(`planBuffers: writeback size mismatch for ${decl.destKind}:'${decl.destName}' ` +
                `(source ${sourceSpec.byteSize} bytes vs dest ${destSpec.byteSize})`);
        }
        return { source: sourceBufId, dest: destBufId, bytes: sourceSpec.byteSize };
    });

    return { buffers, tensorToBuffer, paramsByName, inputsByName, paramGradsByName, statesByName, outputBufferIds, writebacks };
}
|
|
114
|
+
//# sourceMappingURL=buffers.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"buffers.js","sourceRoot":"","sources":["../src/buffers.ts"],"names":[],"mappings":"AAAA,iFAAiF;AACjF,EAAE;AACF,+EAA+E;AAC/E,4EAA4E;AAC5E,8EAA8E;AAC9E,4EAA4E;AAC5E,EAAE;AACF,kDAAkD;AAClD,gFAAgF;AAChF,qFAAqF;AACrF,iEAAiE;AACjE,wEAAwE;AACxE,8EAA8E;AA0C9E,MAAM,UAAU,GAA0B,EAAE,GAAG,EAAE,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,IAAI,EAAE,CAAC,EAAE,CAAA;AAErE,SAAS,SAAS,CAAC,KAAY;IAC7B,IAAI,CAAC,GAAG,CAAC,CAAA;IACT,KAAK,MAAM,CAAC,IAAI,KAAK;QAAE,CAAC,IAAI,CAAC,CAAA;IAC7B,OAAO,CAAC,CAAA;AACV,CAAC;AAcD;;;;;;GAMG;AACH,MAAM,UAAU,WAAW,CACzB,KAAY,EACZ,UAAkC,EAClC,iBAAkC,EAAE;IAEpC,MAAM,OAAO,GAAiB,EAAE,CAAA;IAChC,MAAM,cAAc,GAAG,IAAI,GAAG,EAAkB,CAAA;IAChD,MAAM,YAAY,GAAG,IAAI,GAAG,EAAkB,CAAA;IAC9C,MAAM,YAAY,GAAG,IAAI,GAAG,EAAkB,CAAA;IAC9C,MAAM,gBAAgB,GAAG,IAAI,GAAG,EAAkB,CAAA;IAClD,MAAM,YAAY,GAAG,IAAI,GAAG,EAAkB,CAAA;IAE9C,iEAAiE;IACjE,MAAM,kBAAkB,GAAG,IAAI,GAAG,EAAkB,CAAA;IACpD,KAAK,MAAM,CAAC,IAAI,EAAE,MAAM,CAAC,IAAI,MAAM,CAAC,OAAO,CAAC,UAAU,CAAC,EAAE,CAAC;QACxD,kBAAkB,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,EAAE,IAAI,CAAC,CAAA;IACzC,CAAC;IACD,2EAA2E;IAC3E,MAAM,SAAS,GAAG,IAAI,GAAG,EAAkB,CAAA;IAC3C,KAAK,MAAM,EAAE,IAAI,KAAK,CAAC,GAAG;QAAE,SAAS,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,EAAE,CAAC,CAAA;IAErD,MAAM,SAAS,GAAG,IAAI,GAAG,CAAC,KAAK,CAAC,OAAO,CAAC,CAAA;IAExC,iDAAiD;IACjD,KAAK,MAAM,CAAC,IAAI,KAAK,CAAC,OAAO,EAAE,CAAC;QAC9B,MAAM,EAAE,GAAG,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAA;QAC9B,IAAI,IAAI,GAAuB,cAAc,CAAA;QAC7C,IAAI,IAAI,GAAkB,IAAI,CAAA;QAC9B,IAAI,SAA6B,CAAA;QAEjC,IAAI,EAAE,EAAE,IAAI,KAAK,aAAa,EAAE,CAAC;YAC/B,IAAI,GAAG,OAAO,CAAA;YACd,IAAI,GAAG,EAAE,CAAC,IAAI,CAAA;QAChB,CAAC;aAAM,IAAI,EAAE,EAAE,IAAI,KAAK,cAAc,EAAE,CAAC;YACvC,IAAI,GAAG,cAAc,CAAA;YACrB,IAAI,GAAG,EAAE,CAAC,IAAI,CAAA;QAChB,CAAC;aAAM,IAAI,EAAE,EAAE,IAAI,KAAK,aAAa,EAAE,CAAC;YACtC,IAAI,GAAG,OAAO,CAAA;YACd,IAAI,GAAG,EAAE,CAAC,IAAI,CAAA;YACd,SAAS,GAAG,EAAE,CAAC,SAAS,CAAA;QAC1B,CAAC;aAAM,IAAI,kBAAkB,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC;YACxC,IAAI,GAAG,YAAY,CAAA;YACnB,IAAI,GAAG,kBAAkB,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,C
AAE,CAAA;QACtC,CAAC;aAAM,IAAI,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC;YAC/B,IAAI,GAAG,QAAQ,CAAA;QACjB,CAAC;QAED,MAAM,IAAI,GAAe;YACvB,EAAE,EAAE,CAAC,CAAC,EAAE;YACR,QAAQ,EAAE,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC,KAAK,CAAC,GAAG,UAAU,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;YAC/D,KAAK,EAAE,CAAC,CAAC,KAAK;YACd,KAAK,EAAE,CAAC,CAAC,KAAK;YACd,IAAI;YACJ,IAAI;YACJ,GAAG,CAAC,SAAS,KAAK,SAAS,CAAC,CAAC,CAAC,EAAE,SAAS,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC;SAClD,CAAA;QACD,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,CAAA;QAClB,cAAc,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA,CAAE,aAAa;QAE7C,IAAI,IAAI,KAAK,OAAO;YAAE,YAAY,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;QACnD,IAAI,IAAI,KAAK,cAAc;YAAE,YAAY,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;QAC1D,IAAI,IAAI,KAAK,YAAY;YAAE,gBAAgB,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;QAC5D,IAAI,IAAI,KAAK,OAAO;YAAE,YAAY,CAAC,GAAG,CAAC,IAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAA;IACrD,CAAC;IAED,MAAM,eAAe,GAAG,KAAK,CAAC,OAAO,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,cAAc,CAAC,GAAG,CAAC,EAAE,CAAE,CAAC,CAAA;IAExE,oEAAoE;IACpE,MAAM,UAAU,GAAgB,cAAc,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE;QACxD,MAAM,WAAW,GAAG,cAAc,CAAC,GAAG,CAAC,IAAI,CAAC,MAAM,CAAC,EAAE,CAAC,CAAA;QACtD,IAAI,WAAW,KAAK,SAAS,EAAE,CAAC;YAC9B,MAAM,IAAI,KAAK,CAAC,yCAAyC,IAAI,CAAC,MAAM,CAAC,EAAE,eAAe,CAAC,CAAA;QACzF,CAAC;QACD,MAAM,SAAS,GAAG,IAAI,CAAC,QAAQ,KAAK,OAAO;YACzC,CAAC,CAAC,YAAY,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC;YACjC,CAAC,CAAC,YAAY,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAA;QACnC,IAAI,SAAS,KAAK,SAAS,EAAE,CAAC;YAC5B,MAAM,IAAI,KAAK,CAAC,+BAA+B,IAAI,CAAC,QAAQ,KAAK,IAAI,CAAC,QAAQ,aAAa,CAAC,CAAA;QAC9F,CAAC;QACD,MAAM,UAAU,GAAG,OAAO,CAAC,WAAW,CAAE,CAAA;QACxC,MAAM,QAAQ,GAAG,OAAO,CAAC,SAAS,CAAE,CAAA;QACpC,IAAI,UAAU,CAAC,QAAQ,KAAK,QAAQ,CAAC,QAAQ,EAAE,CAAC;YAC9C,MAAM,IAAI,KAAK,CACb,4CAA4C,IAAI,CAAC,QAAQ,KAAK,IAAI,CAAC,QAAQ,IAAI;gBAC/E,WAAW,UAAU,CAAC,QAAQ,kBAAkB,QAAQ,CAAC,QAAQ,GAAG,CACrE,CAAA;QACH,CAAC;QACD,OAAO,EAAE,MAAM,EAAE,WAAW,EAAE,IAAI,EAAE,SAAS,EAAE,KAAK,EAAE,UAAU,CAAC,QAAQ,EAAE,CAAA;IAC7E,CAA
C,CAAC,CAAA;IAEF,OAAO,EAAE,OAAO,EAAE,cAAc,EAAE,YAAY,EAAE,YAAY,EAAE,gBAAgB,EAAE,YAAY,EAAE,eAAe,EAAE,UAAU,EAAE,CAAA;AAC7H,CAAC"}
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import type { Graph, OpNode } from './ir.js';
import type { BufferPlan } from './buffers.js';
/**
 * Description of one generated compute kernel, paired 1:1 with an op in
 * `graph.ops`. Produced by {@link emitKernels}; consumed by the runtime to
 * create pipelines, bind groups, and dispatches.
 */
export interface KernelSpec {
    /** Index into graph.ops. */
    opIndex: number;
    /** Op kind (for debugging / pipeline cache key). */
    opKind: OpNode['kind'];
    /** Generated WGSL source. Empty string for "logical" ops with no kernel. */
    wgsl: string;
    /**
     * Buffer ids in binding-index order. The runtime creates a bind group with
     * these in @binding(0..N) on @group(0). Inputs come first (read), output last
     * (read_write).
     */
    bindings: number[];
    /** Number of threads to dispatch (1-D). 0 means "skip" (e.g. reshape no-op). */
    threads: number;
    /** Workgroup size; usually WG_SIZE. */
    workgroupSize: number;
}
/**
 * Generate a KernelSpec per compute op in graph.ops (in dispatch order).
 * @param graph the graph whose ops are lowered to WGSL
 * @param plan buffer plan mapping tensor ids to buffer ids (used to fill
 *        each spec's `bindings`)
 * @returns one KernelSpec per op, in the order the runtime should dispatch
 */
export declare function emitKernels(graph: Graph, plan: BufferPlan): KernelSpec[];
//# sourceMappingURL=codegen.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"codegen.d.ts","sourceRoot":"","sources":["../src/codegen.ts"],"names":[],"mappings":"AAYA,OAAO,KAAK,EAAE,KAAK,EAAE,MAAM,EAAiB,MAAM,SAAS,CAAA;AAC3D,OAAO,KAAK,EAAE,UAAU,EAAE,MAAM,cAAc,CAAA;AAO9C,MAAM,WAAW,UAAU;IACzB,4BAA4B;IAC5B,OAAO,EAAE,MAAM,CAAA;IACf,oDAAoD;IACpD,MAAM,EAAE,MAAM,CAAC,MAAM,CAAC,CAAA;IACtB,4EAA4E;IAC5E,IAAI,EAAE,MAAM,CAAA;IACZ;;;;OAIG;IACH,QAAQ,EAAE,MAAM,EAAE,CAAA;IAClB,gFAAgF;IAChF,OAAO,EAAE,MAAM,CAAA;IACf,uCAAuC;IACvC,aAAa,EAAE,MAAM,CAAA;CACtB;AAMD,6EAA6E;AAC7E,wBAAgB,WAAW,CAAC,KAAK,EAAE,KAAK,EAAE,IAAI,EAAE,UAAU,GAAG,UAAU,EAAE,CAQxE"}
|