tensorgrad 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +121 -0
- package/SPEC.md +293 -0
- package/dist/adam.d.ts +31 -0
- package/dist/adam.d.ts.map +1 -0
- package/dist/adam.js +66 -0
- package/dist/adam.js.map +1 -0
- package/dist/buffers.d.ts +56 -0
- package/dist/buffers.d.ts.map +1 -0
- package/dist/buffers.js +114 -0
- package/dist/buffers.js.map +1 -0
- package/dist/codegen.d.ts +23 -0
- package/dist/codegen.d.ts.map +1 -0
- package/dist/codegen.js +709 -0
- package/dist/codegen.js.map +1 -0
- package/dist/compile.d.ts +53 -0
- package/dist/compile.d.ts.map +1 -0
- package/dist/compile.js +76 -0
- package/dist/compile.js.map +1 -0
- package/dist/grad.d.ts +8 -0
- package/dist/grad.d.ts.map +1 -0
- package/dist/grad.js +404 -0
- package/dist/grad.js.map +1 -0
- package/dist/index.d.ts +12 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +37 -0
- package/dist/index.js.map +1 -0
- package/dist/ir.d.ts +204 -0
- package/dist/ir.d.ts.map +1 -0
- package/dist/ir.js +60 -0
- package/dist/ir.js.map +1 -0
- package/dist/module.d.ts +21 -0
- package/dist/module.d.ts.map +1 -0
- package/dist/module.js +113 -0
- package/dist/module.js.map +1 -0
- package/dist/ops.d.ts +35 -0
- package/dist/ops.d.ts.map +1 -0
- package/dist/ops.js +270 -0
- package/dist/ops.js.map +1 -0
- package/dist/runtime.d.ts +26 -0
- package/dist/runtime.d.ts.map +1 -0
- package/dist/runtime.js +190 -0
- package/dist/runtime.js.map +1 -0
- package/dist/shape.d.ts +24 -0
- package/dist/shape.d.ts.map +1 -0
- package/dist/shape.js +259 -0
- package/dist/shape.js.map +1 -0
- package/dist/trace.d.ts +8 -0
- package/dist/trace.d.ts.map +1 -0
- package/dist/trace.js +93 -0
- package/dist/trace.js.map +1 -0
- package/package.json +62 -0
- package/src/adam.ts +95 -0
- package/src/buffers.ts +173 -0
- package/src/codegen.ts +758 -0
- package/src/compile.ts +120 -0
- package/src/grad.ts +459 -0
- package/src/index.ts +40 -0
- package/src/ir.ts +197 -0
- package/src/module.ts +126 -0
- package/src/ops.ts +311 -0
- package/src/runtime.ts +232 -0
- package/src/shape.ts +263 -0
- package/src/trace.ts +101 -0
package/dist/shape.js
ADDED
@@ -0,0 +1,259 @@
// Shape inference and validation for each op kind.
//
// Every op in src/ops.ts validates its inputs and computes its output shape
// through helpers here. Errors throw with the captured call-site so the
// stack trace points at the user's line, not into the library.
//
// Broadcasting rules (deliberately limited):
//   * For element-wise binops (add/sub/mul/div), we support trailing-axis
//     broadcasting: the smaller operand's shape must be a suffix of the
//     larger's, with axes of size 1 broadcasting to any size. Examples
//     ALLOWED:  [B, T, D] op [D]       → [B, T, D]
//               [B, T, D] op [1, D]    → [B, T, D]
//               [B, T, D] op [B, T, D] → [B, T, D]
//     Examples REJECTED: [B, T, D] op [B]    (suffix mismatch)
//                        [B, T, D] op [T, D] when T != B (legal numpy, banned here)
// The restriction makes codegen and autograd much simpler and covers every
// broadcast pattern in our transformer (biases, layernorm gain/bias, masks).
import { formatSite } from './ir.js';
// ============================================================================
// Errors
// ============================================================================
export class ShapeError extends Error {
    constructor(message, site) {
        const formatted = site ? `${message}\n at ${formatSite(site)}` : message;
        super(formatted);
        this.name = 'ShapeError';
    }
}
function fail(message, site) {
    throw new ShapeError(message, site);
}
// ============================================================================
// Shape utilities
// ============================================================================
export function shapesEqual(a, b) {
    if (a.length !== b.length)
        return false;
    for (let i = 0; i < a.length; i++)
        if (a[i] !== b[i])
            return false;
    return true;
}
export function shapeSize(shape) {
    let n = 1;
    for (const d of shape)
        n *= d;
    return n;
}
export function showShape(shape) {
    return `[${shape.join(', ')}]`;
}
// Standard right-aligned NumPy-style broadcasting. Pad the shorter shape with
// leading 1s, then per-axis: equal dims unify, size-1 dims broadcast on either
// side, otherwise incompatible. Returns the resulting shape or null.
export function broadcastTrailing(a, b) {
    const rank = Math.max(a.length, b.length);
    const out = new Array(rank);
    for (let i = 0; i < rank; i++) {
        const ai = i - (rank - a.length);
        const bi = i - (rank - b.length);
        const av = ai < 0 ? 1 : a[ai];
        const bv = bi < 0 ? 1 : b[bi];
        if (av === bv)
            out[i] = av;
        else if (av === 1)
            out[i] = bv;
        else if (bv === 1)
            out[i] = av;
        else
            return null;
    }
    return out;
}
// ============================================================================
// Per-op shape rules
// ============================================================================
//
// Each rule takes the input shapes and returns the output shape, or throws.
// All rules accept a `site` for error attribution.
export function inferElementwiseBinop(opName, aShape, bShape, site) {
    const result = broadcastTrailing(aShape, bShape);
    if (!result) {
        fail(`${opName}: incompatible shapes ${showShape(aShape)} and ${showShape(bShape)}. ` +
            `Trailing-suffix broadcasting only — the smaller shape must be a suffix of the larger, ` +
            `with size-1 axes broadcasting to any size.`, site);
    }
    return result;
}
export function inferUnary(_opName, aShape, _site) {
    return aShape;
}
export function inferMeanLast(opName, aShape, site) {
    if (aShape.length === 0)
        fail(`${opName}: cannot reduce a 0-d tensor`, site);
    // keepdims=true: replace last axis with 1.
    return [...aShape.slice(0, -1), 1];
}
export function inferSumLast(opName, aShape, site) {
    if (aShape.length === 0)
        fail(`${opName}: cannot reduce a 0-d tensor`, site);
    // keepdims=false: drop the last axis.
    return aShape.slice(0, -1);
}
export function inferReshape(opName, aShape, newShape, site) {
    // Validate -1 placeholder (at most one allowed) and total size match.
    let inferIdx = -1;
    let knownSize = 1;
    for (let i = 0; i < newShape.length; i++) {
        const d = newShape[i];
        if (d === -1) {
            if (inferIdx !== -1)
                fail(`${opName}: at most one -1 dim allowed in newShape ${showShape(newShape)}`, site);
            inferIdx = i;
        }
        else if (d <= 0) {
            fail(`${opName}: invalid dim ${d} in newShape ${showShape(newShape)}`, site);
        }
        else {
            knownSize *= d;
        }
    }
    const totalIn = shapeSize(aShape);
    const out = [...newShape];
    if (inferIdx !== -1) {
        if (totalIn % knownSize !== 0) {
            fail(`${opName}: cannot reshape ${showShape(aShape)} (size ${totalIn}) to ${showShape(newShape)} — known dims multiply to ${knownSize}`, site);
        }
        out[inferIdx] = totalIn / knownSize;
    }
    else if (knownSize !== totalIn) {
        fail(`${opName}: size mismatch — input ${showShape(aShape)} has ${totalIn} elements but newShape ${showShape(newShape)} has ${knownSize}`, site);
    }
    return out;
}
export function inferTranspose(opName, aShape, perm, site) {
    if (perm.length !== aShape.length) {
        fail(`${opName}: perm length ${perm.length} must equal input rank ${aShape.length}`, site);
    }
    const seen = new Set();
    for (const p of perm) {
        if (p < 0 || p >= aShape.length)
            fail(`${opName}: perm index ${p} out of range for rank ${aShape.length}`, site);
        if (seen.has(p))
            fail(`${opName}: perm has duplicate index ${p}`, site);
        seen.add(p);
    }
    return perm.map(p => aShape[p]);
}
// matmul: a [..., M, K] · b [K, N] → [..., M, N]. b is unbatched.
export function inferMatmul(opName, aShape, bShape, site) {
    if (aShape.length < 2)
        fail(`${opName}: lhs must have rank >= 2, got ${showShape(aShape)}`, site);
    if (bShape.length !== 2)
        fail(`${opName}: rhs must have rank 2, got ${showShape(bShape)} — use matmulBatched for batched rhs`, site);
    const M = aShape[aShape.length - 2];
    const Ka = aShape[aShape.length - 1];
    const Kb = bShape[0];
    const N = bShape[1];
    if (Ka !== Kb)
        fail(`${opName}: inner dims don't match — ${showShape(aShape)} · ${showShape(bShape)} (last axis of lhs = ${Ka}, first axis of rhs = ${Kb})`, site);
    return [...aShape.slice(0, -2), M, N];
}
// matmul_batched: a [..., M, K] · b [..., K, N] → [..., M, N]. Both have leading batch dims.
export function inferMatmulBatched(opName, aShape, bShape, site) {
    if (aShape.length < 2 || bShape.length < 2) {
        fail(`${opName}: both inputs must have rank >= 2, got ${showShape(aShape)} and ${showShape(bShape)}`, site);
    }
    if (aShape.length !== bShape.length) {
        fail(`${opName}: ranks must match (got ${aShape.length} vs ${bShape.length}). Reshape if you need different batch dims.`, site);
    }
    const aBatch = aShape.slice(0, -2);
    const bBatch = bShape.slice(0, -2);
    for (let i = 0; i < aBatch.length; i++) {
        if (aBatch[i] !== bBatch[i]) {
            fail(`${opName}: batch dims must match — ${showShape(aShape)} vs ${showShape(bShape)}`, site);
        }
    }
    const M = aShape[aShape.length - 2];
    const Ka = aShape[aShape.length - 1];
    const Kb = bShape[bShape.length - 2];
    const N = bShape[bShape.length - 1];
    if (Ka !== Kb)
        fail(`${opName}: inner dims don't match — last axis of lhs = ${Ka}, second-to-last of rhs = ${Kb}`, site);
    return [...aBatch, M, N];
}
export function inferOneHot(opName, indicesShape, depth, site) {
    if (depth <= 0)
        fail(`${opName}: depth must be positive, got ${depth}`, site);
    return [...indicesShape, depth];
}
// where_causal preserves shape but requires the last two axes to be square.
export function inferWhereCausal(opName, aShape, site) {
    if (aShape.length < 2)
        fail(`${opName}: requires rank >= 2, got ${showShape(aShape)}`, site);
    const m = aShape[aShape.length - 2];
    const n = aShape[aShape.length - 1];
    if (m !== n)
        fail(`${opName}: last two axes must be equal (square mask), got ${showShape(aShape)}`, site);
    return aShape;
}
export function inferSliceLastRange(opName, aShape, start, end, site) {
    if (aShape.length === 0)
        fail(`${opName}: cannot slice 0-d tensor`, site);
    const last = aShape[aShape.length - 1];
    if (start < 0 || end > last || start >= end) {
        fail(`${opName}: invalid range [${start}, ${end}) for last axis of size ${last}`, site);
    }
    return [...aShape.slice(0, -1), end - start];
}
// broadcast_to: validate that `aShape` can broadcast to `targetShape` under
// right-aligned NumPy rules. Returns targetShape on success.
export function inferBroadcastTo(opName, aShape, targetShape, site) {
    if (aShape.length > targetShape.length) {
        fail(`${opName}: source rank ${aShape.length} > target rank ${targetShape.length}`, site);
    }
    const offset = targetShape.length - aShape.length;
    for (let i = 0; i < aShape.length; i++) {
        const av = aShape[i];
        const tv = targetShape[offset + i];
        if (av !== tv && av !== 1) {
            fail(`${opName}: cannot broadcast ${showShape(aShape)} to ${showShape(targetShape)} — axis ${i} (size ${av}) doesn't match target axis ${offset + i} (size ${tv}) and isn't 1`, site);
        }
    }
    return targetShape;
}
// sum_to_shape: validate that `targetShape` is a valid right-aligned reduction
// of `aShape` (i.e., aShape can have been produced by broadcasting targetShape).
export function inferSumToShape(opName, aShape, targetShape, site) {
    if (targetShape.length > aShape.length) {
        fail(`${opName}: target rank ${targetShape.length} > source rank ${aShape.length}`, site);
    }
    const offset = aShape.length - targetShape.length;
    for (let i = 0; i < targetShape.length; i++) {
        const av = aShape[offset + i];
        const tv = targetShape[i];
        if (av !== tv && tv !== 1) {
            fail(`${opName}: cannot sum-reduce ${showShape(aShape)} to ${showShape(targetShape)} — target axis ${i} (size ${tv}) must be 1 or match source`, site);
        }
    }
    return targetShape;
}
// Three-way broadcast for `where(cond, a, b)`. All three shapes must broadcast
// to a common shape under standard NumPy rules.
export function inferWhere(opName, condShape, aShape, bShape, site) {
    const ab = broadcastTrailing(aShape, bShape);
    if (!ab)
        fail(`${opName}: a/b incompatible: ${showShape(aShape)} vs ${showShape(bShape)}`, site);
    const result = broadcastTrailing(condShape, ab);
    if (!result)
        fail(`${opName}: cond ${showShape(condShape)} incompatible with broadcast(a, b) ${showShape(ab)}`, site);
    return result;
}
export function inferReluGrad(opName, xShape, dyShape, site) {
    if (!shapesEqual(xShape, dyShape)) {
        fail(`${opName}: x and dy must have matching shapes, got ${showShape(xShape)} and ${showShape(dyShape)}`, site);
    }
    return xShape;
}
//# sourceMappingURL=shape.js.map
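A quick illustration of the trailing-axis broadcasting contract documented at the top of shape.js. This sketch is not part of the package; the relative import path is an assumption for the example, and the call style is plain JS (the optional `site` argument is omitted).

// Hypothetical usage sketch of shape.js's broadcasting helpers.
import { broadcastTrailing, inferElementwiseBinop, ShapeError } from './shape.js';

const B = 4, T = 16, D = 64;

// Allowed: the smaller shape is a suffix of the larger (e.g. a bias add).
broadcastTrailing([B, T, D], [D]);    // → [4, 16, 64]
broadcastTrailing([B, T, D], [1, D]); // → [4, 16, 64]

// Rejected: [B] is not a suffix of [B, T, D] (4 vs 64 on the last axis),
// so the element-wise rule throws a ShapeError naming both shapes.
try {
  inferElementwiseBinop('add', [B, T, D], [B]);
} catch (e) {
  console.log(e instanceof ShapeError); // true
}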
package/dist/shape.js.map
ADDED
@@ -0,0 +1 @@
(generated source map; mappings omitted)
package/dist/trace.d.ts
ADDED
@@ -0,0 +1,8 @@
import type { Graph, Tensor, Shape, Dtype } from './ir.js';
export declare function currentGraph(): Graph;
export declare function trace(fn: () => Tensor | Tensor[]): Graph;
export declare function traceInto<T>(g: Graph, fn: () => T): T;
export declare function paramInput(name: string, shape: Shape, dtype?: Dtype): Tensor;
export declare function tensorInput(name: string, shape: Shape, dtype?: Dtype): Tensor;
export declare function stateInput(name: string, shape: Shape, dtype?: Dtype, initValue?: number): Tensor;
//# sourceMappingURL=trace.d.ts.map
package/dist/trace.d.ts.map
ADDED
@@ -0,0 +1 @@
(generated source map; mappings omitted)
package/dist/trace.js
ADDED
@@ -0,0 +1,93 @@
// Trace driver. Holds the "current graph" in module-local state so user code
// can call ops without threading a graph parameter through every function.
//
// Usage:
//
//   const graph = trace(() => {
//     const x = tensorInput('x', [B, T], 'i32')
//     const w = paramInput('w', [V, D], 'f32')
//     // ... user computation building tensors ...
//     return finalLossTensor
//   })
//
// `trace` is single-threaded and re-entrant only via nested calls (which share
// the outer graph — but we don't currently have a use for nesting). Calling an
// op outside a `trace(...)` block is an error.
import { makeGraph, addOp, captureSite } from './ir.js';
// Module-local: the graph being built right now, or null if no trace is active.
let _current = null;
export function currentGraph() {
    if (!_current) {
        throw new Error('tensorgrad: ops can only be called inside trace(). ' +
            'Did you forget to wrap your forward pass?');
    }
    return _current;
}
// Run `fn` with a fresh graph as the current one; capture and return the graph.
// `fn` must return the tensor (or array of tensors) to mark as graph outputs.
export function trace(fn) {
    if (_current) {
        throw new Error('tensorgrad: nested trace() is not supported');
    }
    const g = makeGraph();
    _current = g;
    try {
        const result = fn();
        const outputs = Array.isArray(result) ? result : [result];
        for (const t of outputs) {
            ;
            g.outputs.push(t.id);
        }
    }
    finally {
        _current = null;
    }
    return g;
}
// Re-enter an existing graph to append more ops. Used by autograd to add
// backward ops to a graph that's already been traced. `fn` runs with the
// supplied graph as the current one; any ops it calls append to that graph.
// Returns whatever `fn` returns.
export function traceInto(g, fn) {
    if (_current) {
        throw new Error('tensorgrad: traceInto() called while another trace is active');
    }
    _current = g;
    try {
        return fn();
    }
    finally {
        _current = null;
    }
}
// ---- Leaf tensor builders --------------------------------------------------
// Inputs are added to the graph as `param_input` or `tensor_input` op nodes.
// Their .source on the Tensor points at that node so codegen knows where to
// bind external data.
export function paramInput(name, shape, dtype = 'f32') {
    const g = currentGraph();
    if (g.ops.some(op => (op.kind === 'param_input' || op.kind === 'tensor_input') && op.name === name)) {
        throw new Error(`tensorgrad: input name '${name}' already used in this trace`);
    }
    const site = captureSite('paramInput');
    return addOp(g, 'param_input', shape, dtype, site, { name });
}
export function tensorInput(name, shape, dtype = 'f32') {
    const g = currentGraph();
    if (g.ops.some(op => (op.kind === 'param_input' || op.kind === 'tensor_input') && op.name === name)) {
        throw new Error(`tensorgrad: input name '${name}' already used in this trace`);
    }
    const site = captureSite('tensorInput');
    return addOp(g, 'tensor_input', shape, dtype, site, { name });
}
// Persistent state buffer. Allocated at compile time, zero-(or initValue-)initialized,
// and updated across step() calls via writebacks declared by the optimizer helper.
export function stateInput(name, shape, dtype = 'f32', initValue = 0) {
    const g = currentGraph();
    if (g.ops.some(op => op.kind === 'state_input' && op.name === name)) {
        throw new Error(`tensorgrad: state name '${name}' already used in this trace`);
    }
    const site = captureSite('stateInput');
    return addOp(g, 'state_input', shape, dtype, site, { name, initValue });
}
//# sourceMappingURL=trace.js.map
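Taken with the d.ts above, the intended call pattern looks roughly like this. A hedged sketch, not from the package: the relative import path is an assumption, and real code would combine the inputs with ops from ops.js before returning.

// Hypothetical sketch of the trace()/traceInto() contract described above.
import { trace, traceInto, tensorInput, paramInput } from './trace.js';

const g = trace(() => {
  const x = tensorInput('x', [2, 3], 'f32'); // bound by the runtime each step
  const w = paramInput('w', [3, 4], 'f32');  // trainable leaf
  // ... ops from ops.js would combine x and w here ...
  return [x, w];                             // marked as graph outputs
});
console.log(g.outputs.length); // 2

// Autograd-style re-entry: appends nodes to the already-traced graph.
traceInto(g, () => tensorInput('scale', [], 'f32'));

// Reusing an input name, or calling a builder with no active trace, throws:
// tensorInput('y', [1]); // Error: ops can only be called inside trace()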
package/dist/trace.js.map
ADDED
@@ -0,0 +1 @@
(generated source map; mappings omitted)
package/package.json
ADDED
@@ -0,0 +1,62 @@
{
  "name": "tensorgrad",
  "version": "0.0.1",
  "description": "Tiny TypeScript-native tensor library with autograd, compiling to WebGPU. Train small models in the browser without hand-writing kernels.",
  "license": "MIT",
  "author": "Ben Albahari",
  "repository": {
    "type": "git",
    "url": "git+https://github.com/typebulb/tensorgrad.git"
  },
  "homepage": "https://github.com/typebulb/tensorgrad#readme",
  "bugs": {
    "url": "https://github.com/typebulb/tensorgrad/issues"
  },
  "keywords": [
    "webgpu",
    "machine-learning",
    "autograd",
    "tensor",
    "neural-network",
    "transformer",
    "browser",
    "typescript"
  ],
  "type": "module",
  "main": "./src/index.ts",
  "types": "./src/index.ts",
  "exports": {
    ".": {
      "types": "./src/index.ts",
      "default": "./src/index.ts"
    }
  },
  "//": "Source-only exports for monorepo dev. publishConfig flips to dist/*.js for npm publish.",
  "publishConfig": {
    "main": "./dist/index.js",
    "types": "./dist/index.d.ts",
    "exports": {
      ".": {
        "types": "./dist/index.d.ts",
        "default": "./dist/index.js"
      }
    }
  },
  "files": [
    "dist",
    "src",
    "SPEC.md",
    "README.md",
    "LICENSE"
  ],
  "scripts": {
    "build": "tsc -p tsconfig.json",
    "typecheck": "tsc -p tsconfig.json --noEmit",
    "test": "tsx test/smoke.ts",
    "prepublishOnly": "tsc -p tsconfig.json"
  },
  "devDependencies": {
    "tsx": "*",
    "typescript": "*"
  }
}
package/src/adam.ts
ADDED
@@ -0,0 +1,95 @@
// Adam optimizer, in-graph.
//
// `appendAdam` extends a graph that already has a forward pass + autograd-emitted
// backward (i.e., has paramGrads from `appendGrad`) with the Adam update math.
//
// Per parameter P with gradient g:
//   m_new = b1 * m + (1 - b1) * g
//   v_new = b2 * v + (1 - b2) * g²
//   p_new = p - lr * m_new / (sqrt(v_new) + eps)
//
// This is "Adam without bias correction" — the `1 / (1 - β^t)` factors are
// dropped because computing them in-graph requires per-step uniforms or
// awkward exp/log tricks. In practice the omission only affects the first
// ~100 steps; convergence is unaffected.
//
// Returns writeback declarations the buffer planner uses to wire up the
// "after step, copy the new value into the persistent home" path. m and v
// are state_inputs (zero-initialized, persistent across steps); the param
// updates are aliased back to the param buffers.

import type { Tensor } from './ir.js'
import type { Graph } from './ir.js'
import type { WritebackDecl } from './buffers.js'
import { traceInto, stateInput, tensorInput } from './trace.js'
import { adamUpdateM, adamUpdateV, adamUpdateP } from './ops.js'

export interface AdamConfig {
  lr: number
  b1?: number   // default 0.9
  b2?: number   // default 0.999
  eps?: number  // default 1e-8
}

export interface AdamResult {
  /** Writebacks the buffer planner should wire into the runtime. */
  writebacks: WritebackDecl[]
  /** Name of the per-step scalar tensor_input. The runtime fills this each call
   * with `lr * sqrt(1-b2^t)/(1-b1^t)` (Adam's bias-corrected effective LR). */
  lrtInputName: string
  /** Hyperparameters as captured (so the runtime can compute lrt). */
  config: Required<AdamConfig>
}

/**
 * Append Adam update ops to `graph`. Must be called inside an active trace
 * context (or after a trace, since traceInto re-enters the graph).
 *
 * @param graph        the graph (already containing forward + backward)
 * @param paramGrads   param name -> gradient tensor (output of `appendGrad`)
 * @param paramTensors param name -> the param's leaf Tensor (the param_input).
 *                     Needed because the param_input lives in the graph but we
 *                     don't have a direct map by name in `Graph` — caller passes it.
 * @param config       Adam hyperparameters
 */
export function appendAdam(
  graph: Graph,
  paramGrads: Record<string, Tensor>,
  paramTensors: Record<string, Tensor>,
  config: AdamConfig,
): AdamResult {
  const fullConfig: Required<AdamConfig> = {
    lr: config.lr,
    b1: config.b1 ?? 0.9,
    b2: config.b2 ?? 0.999,
    eps: config.eps ?? 1e-8,
  }
  const writebacks: WritebackDecl[] = []
  const lrtInputName = '_adam_lrt'

  return traceInto(graph, () => {
    // One scalar lrt input shared by every adam_update_p call. Runtime supplies
    // it per step as `lr * sqrt(1-b2^t) / (1-b1^t)`.
    const lrt = tensorInput(lrtInputName, [], 'f32')

    for (const name of Object.keys(paramGrads)) {
      const p = paramTensors[name]
      const g = paramGrads[name]
      if (!p) throw new Error(`appendAdam: missing param tensor for '${name}'`)
      if (!g) throw new Error(`appendAdam: missing gradient for '${name}'`)

      const mState = stateInput(`adam_m_${name}`, p.shape, 'f32', 0)
      const vState = stateInput(`adam_v_${name}`, p.shape, 'f32', 0)

      // Three fused kernels per parameter — one for each of m_new / v_new / p_new.
      const newM = adamUpdateM(mState, g, fullConfig.b1)
      const newV = adamUpdateV(vState, g, fullConfig.b2)
      const newP = adamUpdateP(p, newM, newV, lrt, fullConfig.eps)

      writebacks.push({ source: newM, destName: `adam_m_${name}`, destKind: 'state' })
      writebacks.push({ source: newV, destName: `adam_v_${name}`, destKind: 'state' })
      writebacks.push({ source: newP, destName: name, destKind: 'param' })
    }
    return { writebacks, lrtInputName, config: fullConfig }
  })
}
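For reference, the per-step scalar described by `lrtInputName` works out as below. A sketch of what a runtime caller would compute before each step; the helper name is hypothetical, the formula is the one documented on AdamResult.

// Hypothetical helper: the bias-corrected effective learning rate a runtime
// would feed into the '_adam_lrt' tensor_input each step, per AdamResult's
// doc comment: lr * sqrt(1 - b2^t) / (1 - b1^t).
function adamLrt(cfg: { lr: number; b1: number; b2: number }, step: number): number {
  // `step` is the 1-based optimizer step count.
  return cfg.lr * Math.sqrt(1 - Math.pow(cfg.b2, step)) / (1 - Math.pow(cfg.b1, step))
}

// With the defaults (b1 = 0.9, b2 = 0.999) and lr = 1e-3:
//   adamLrt(cfg, 1)     ≈ 3.16e-4  (strong early damping)
//   adamLrt(cfg, 10000) ≈ 1.00e-3  (approaches lr as t grows)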