tensorgrad 0.0.14 → 0.0.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.d.ts +154 -170
- package/dist/index.js +2208 -39
- package/dist/index.js.map +7 -1
- package/dist/worker.debug.js +553 -0
- package/package.json +60 -58
- package/src/adam.ts +69 -15
- package/src/compile.ts +334 -154
- package/src/index.ts +8 -4
- package/src/module.ts +72 -34
- package/src/runtime.ts +64 -11
- package/src/worker-protocol.ts +183 -0
- package/src/worker-proxy.ts +76 -0
- package/src/worker.ts +281 -0
- package/dist/adam.js +0 -111
- package/dist/adam.js.map +0 -1
- package/dist/buffers.js +0 -120
- package/dist/buffers.js.map +0 -1
- package/dist/capture.js +0 -33
- package/dist/capture.js.map +0 -1
- package/dist/codegen.js +0 -724
- package/dist/codegen.js.map +0 -1
- package/dist/compile.js +0 -180
- package/dist/compile.js.map +0 -1
- package/dist/grad.js +0 -380
- package/dist/grad.js.map +0 -1
- package/dist/ir.js +0 -60
- package/dist/ir.js.map +0 -1
- package/dist/module.js +0 -155
- package/dist/module.js.map +0 -1
- package/dist/nn.js +0 -135
- package/dist/nn.js.map +0 -1
- package/dist/ops.js +0 -326
- package/dist/ops.js.map +0 -1
- package/dist/runtime.js +0 -375
- package/dist/runtime.js.map +0 -1
- package/dist/shape.js +0 -259
- package/dist/shape.js.map +0 -1
- package/dist/trace.js +0 -100
- package/dist/trace.js.map +0 -1
|
@@ -0,0 +1,553 @@
|
|
|
1
|
+
// src/runtime.ts
// Read-only view over intermediate tensors captured during a run/step.
// `shapes` maps capture name -> shape array; `data` maps name -> Float32Array.
var Captures = class {
  shapes;
  data;
  constructor(shapes, data) {
    this.shapes = shapes;
    this.data = data;
  }
  // Return the captured values for `name`, or throw with a helpful listing.
  get(name) {
    const found = this.data.get(name);
    if (found) return found;
    const known = [...this.data.keys()].sort().join(", ");
    const detail = known ? `Known this call: ${known}` : `(call run/step with { withCaptures: true } to populate)`;
    throw new Error(`Captures.get: '${name}' not present. ${detail}`);
  }
  // Return the registered shape for `name`, or throw listing known names.
  shapeOf(name) {
    const shape = this.shapes[name];
    if (shape) return shape;
    const known = Object.keys(this.shapes).sort().join(", ") || "(none registered)";
    throw new Error(`Captures.shapeOf: '${name}' not registered. Known: ${known}`);
  }
  // True when data for `name` was read back on this call.
  has(name) {
    return this.data.has(name);
  }
  // Sorted list of capture names present in this call's data.
  names() {
    return [...this.data.keys()].sort();
  }
};
|
|
33
|
+
// GPUBufferUsage flag combinations (numeric values match the WebGPU spec):
// STORAGE (128) | COPY_DST (8) | COPY_SRC (4) — general compute buffers that
// shaders read/write and that can be uploaded to and copied out of.
var STORAGE_RW = 128 | 8 | 4;
// MAP_READ (1) | COPY_DST (8) — CPU-mappable staging buffers for readback.
var READBACK = 1 | 8;
|
|
35
|
+
// Build a GPU runtime for a compiled graph: allocate buffers, compile all WGSL
// kernels, and return the handle object ({ step, run, uploadParams, ... }).
//
// `plan` describes buffers/inputs/params/captures; `kernels` is the ordered
// kernel list; `lossBufferId` is the buffer holding the graph output.
// NOTE(review): `plan.buffers[bufId]` is indexed by buffer id — this assumes
// spec.id equals the array index; confirm against the compiler.
async function createRuntime(plan, kernels, lossBufferId, opts = {}) {
  // Reuse a caller-supplied device (e.g. the worker's shared one) or acquire our own.
  const device2 = opts.device ?? await acquireDevice();
  const queue = device2.queue;
  const buffers = /* @__PURE__ */ new Map();
  // Buffers we created (vs. borrowed via sharedParams) — only these are destroyed.
  const ownedBufferIds = /* @__PURE__ */ new Set();
  const sharedParams = opts.sharedParams;
  for (const spec of plan.buffers) {
    // Param buffers may be aliased from another runtime so weights stay shared.
    const shared = spec.kind === "param" ? sharedParams?.get(spec.name) : void 0;
    if (shared) {
      if (shared.size !== spec.byteSize) {
        throw new Error(
          `sharedParams: size mismatch for '${spec.name}' \u2014 supplied ${shared.size} bytes, compiled graph expects ${spec.byteSize}.`
        );
      }
      buffers.set(spec.id, shared);
      continue;
    }
    const buf = device2.createBuffer({
      size: spec.byteSize,
      usage: STORAGE_RW,
      label: spec.name ?? `t${spec.id}-${spec.kind}`
    });
    buffers.set(spec.id, buf);
    ownedBufferIds.add(spec.id);
    // Optimizer state starts at its declared init value (fillStateBuffer is
    // hoisted — declared further down in this function).
    if (spec.kind === "state") fillStateBuffer(spec, buf);
  }
  // Dedupe shader modules by WGSL text; kernels sharing code share a module.
  const moduleCache = /* @__PURE__ */ new Map();
  const pipelines = [];
  const probes = [];
  for (const k of kernels) {
    // Kernels with no WGSL are placeholders (e.g. pure copies); keep index alignment.
    if (!k.wgsl) {
      pipelines.push(null);
      continue;
    }
    let module = moduleCache.get(k.wgsl);
    if (!module) {
      module = device2.createShaderModule({ code: k.wgsl, label: k.opKind });
      moduleCache.set(k.wgsl, module);
    }
    // Wrap pipeline creation in an error scope so shader compile errors are
    // collected per-kernel instead of surfacing as uncaught validation errors.
    device2.pushErrorScope("validation");
    const pipeline = device2.createComputePipeline({
      layout: "auto",
      compute: { module, entryPoint: "main" },
      label: k.opKind
    });
    pipelines.push(pipeline);
    probes.push(device2.popErrorScope().then((err) => err ? { k, module, err } : null));
  }
  const probeResults = await Promise.all(probes);
  const failures = probeResults.filter((p) => p != null);
  if (failures.length > 0) {
    // Build one detailed report per failing kernel, including the WGSL source.
    const reports = [];
    for (const { k, module, err } of failures) {
      const info = await module.getCompilationInfo();
      const messages = info.messages.map((m) => ` L${m.lineNum}:${m.linePos} [${m.type}] ${m.message}`).join("\n");
      reports.push(
        `[shader compile error] ${k.opKind} (op #${k.opIndex}): ${err.message}
` + (messages || " (no compilation messages)") + `
--- WGSL ---
${k.wgsl}
-----------`
      );
    }
    console.error(reports.join("\n\n"));
    throw new Error(`tensorgrad: ${failures.length} shader(s) failed to compile (see console).`);
  }
  // One bind group per kernel; binding index == position in k.bindings.
  const bindGroups = kernels.map((k, i) => {
    const pipeline = pipelines[i];
    if (!pipeline) return null;
    return device2.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: k.bindings.map((bufId, idx) => ({
        binding: idx,
        resource: { buffer: buffers.get(bufId) }
      }))
    });
  });
  const outputSpec = plan.buffers[lossBufferId];
  // Persistent staging buffer for reading the graph output back to the CPU.
  const outputReadback = device2.createBuffer({ size: outputSpec.byteSize, usage: READBACK });
  let captureStaging = null;
  // Lazily allocate one packed staging buffer covering all captured tensors.
  function ensureCaptureStaging() {
    if (captureStaging) return captureStaging;
    let totalBytes = 0;
    const slices = [];
    for (const [name, bufId] of plan.capturesByName) {
      const spec = plan.buffers[bufId];
      slices.push({ name, bufId, offset: totalBytes, byteSize: spec.byteSize });
      totalBytes += spec.byteSize;
    }
    const buffer = device2.createBuffer({ size: totalBytes, usage: READBACK, label: "captures-staging" });
    captureStaging = { buffer, slices };
    return captureStaging;
  }
  // Serialization chain: each dispatch waits for the previous one so the single
  // readback buffer is never mapped by two calls at once. Errors in a prior
  // turn are deliberately swallowed here — they were already surfaced to that
  // turn's caller.
  let pending = Promise.resolve();
  async function dispatch(inputs, opts2) {
    const turn = pending.catch(() => {
    }).then(() => dispatchUnsynchronized(inputs, opts2));
    pending = turn;
    return turn;
  }
  // Core per-call work: upload inputs, encode all kernels + copies, submit,
  // and optionally map the output (and captures) back to the CPU.
  async function dispatchUnsynchronized(inputs, opts2) {
    const wantCaptures = opts2.wantCaptures;
    if (wantCaptures && plan.capturesByName.size === 0) {
      throw new Error(
        `withCaptures=true but no capture(...) calls were registered during the trace. Add capture('name', tensor) inside your forward pass for the intermediates you want read back.`
      );
    }
    for (const [name, bufId] of plan.inputsByName) {
      const data = inputs[name];
      if (!data) throw new Error(`tensorgrad: missing input '${name}'`);
      const expectedBytes = plan.buffers[bufId].byteSize;
      if (data.byteLength !== expectedBytes) {
        throw new Error(`tensorgrad: input '${name}' has ${data.byteLength} bytes, expected ${expectedBytes}`);
      }
      queue.writeBuffer(buffers.get(bufId), 0, data);
    }
    const encoder = device2.createCommandEncoder({ label: "tensorgrad-step" });
    for (let i = 0; i < kernels.length; i++) {
      const k = kernels[i];
      if (!k.wgsl || k.threads === 0) continue;
      const pipeline = pipelines[i];
      const bindGroup = bindGroups[i];
      const pass = encoder.beginComputePass({ label: k.opKind });
      pass.setPipeline(pipeline);
      pass.setBindGroup(0, bindGroup);
      const wgCount = Math.max(1, Math.ceil(k.threads / k.workgroupSize));
      // WebGPU caps each dispatch dimension at 65535 workgroups, so large
      // dispatches are folded into a 2D (x, y) grid. Presumably the generated
      // WGSL reconstructs the flat thread index from (x, y) — confirm in codegen.
      const MAX_X = 65535;
      const wgX = Math.min(wgCount, MAX_X);
      const wgY = Math.ceil(wgCount / MAX_X);
      pass.dispatchWorkgroups(wgX, wgY, 1);
      pass.end();
    }
    // Planned buffer-to-buffer copies (e.g. updated params/state written back).
    for (const wb of plan.writebacks) {
      encoder.copyBufferToBuffer(buffers.get(wb.source), 0, buffers.get(wb.dest), 0, wb.bytes);
    }
    // Always copy the output so a later readLoss() can map it without re-running.
    encoder.copyBufferToBuffer(buffers.get(lossBufferId), 0, outputReadback, 0, outputSpec.byteSize);
    let layout = null;
    if (wantCaptures) {
      layout = ensureCaptureStaging();
      for (const s of layout.slices) {
        encoder.copyBufferToBuffer(buffers.get(s.bufId), 0, layout.buffer, s.offset, s.byteSize);
      }
    }
    queue.submit([encoder.finish()]);
    if (!opts2.readback) return null;
    await outputReadback.mapAsync(GPUMapMode.READ);
    // slice(0) copies out of the mapped range so the data survives unmap().
    const output = new Float32Array(outputReadback.getMappedRange().slice(0));
    outputReadback.unmap();
    const captures = /* @__PURE__ */ new Map();
    if (layout) {
      await layout.buffer.mapAsync(GPUMapMode.READ);
      const range = layout.buffer.getMappedRange();
      for (const s of layout.slices) {
        captures.set(s.name, new Float32Array(range, s.offset, s.byteSize / 4).slice());
      }
      layout.buffer.unmap();
    }
    return { output, captures };
  }
  // One training step. With { readLoss: false } the dispatch is fire-and-forget
  // (loss can be fetched later via readLoss()); otherwise returns the scalar
  // loss, or { loss, captures } when withCaptures is requested.
  async function step(inputs, opts2) {
    if (opts2?.readLoss === false) {
      await dispatch(inputs, { wantCaptures: false, readback: false });
      return;
    }
    const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });
    if (opts2?.withCaptures) return { loss: r.output[0], captures: new Captures(captureShapes, r.captures) };
    return r.output[0];
  }
  // Map the output staging buffer after all queued dispatches settle and return
  // the first f32 — the loss from the most recent submitted step.
  async function readLoss() {
    const turn = pending.catch(() => {
    }).then(async () => {
      await outputReadback.mapAsync(GPUMapMode.READ);
      const v = new Float32Array(outputReadback.getMappedRange())[0];
      outputReadback.unmap();
      return v;
    });
    pending = turn;
    return turn;
  }
  // Forward-only evaluation: returns the full output tensor (and captures on request).
  async function run(inputs, opts2) {
    const r = await dispatch(inputs, { wantCaptures: opts2?.withCaptures === true, readback: true });
    if (opts2?.withCaptures) return { output: r.output, captures: new Captures(captureShapes, r.captures) };
    return r.output;
  }
  // Upload parameter tensors by name. Validates every supplied name, and —
  // unless { partial: true } — requires the full parameter set.
  function uploadParams(params2, opts2) {
    const partial = opts2?.partial ?? false;
    for (const name of Object.keys(params2)) {
      if (!plan.paramsByName.has(name)) {
        throw new Error(
          `uploadParams: unknown param '${name}'. Known: ${[...plan.paramsByName.keys()].sort().join(", ")}`
        );
      }
    }
    if (!partial) {
      for (const name of plan.paramsByName.keys()) {
        if (!(name in params2)) {
          throw new Error(
            `uploadParams: missing param '${name}'. Pass { partial: true } if you mean to update only some params.`
          );
        }
      }
    }
    for (const [name, bufId] of plan.paramsByName) {
      const data = params2[name];
      if (!data) continue;
      const expected = plan.buffers[bufId].byteSize / 4;
      if (data.length !== expected) {
        throw new Error(`uploadParams: '${name}' has ${data.length} elements, expected ${expected}`);
      }
      queue.writeBuffer(buffers.get(bufId), 0, data);
    }
  }
  // Copy every buffer named in `map` into transient staging buffers and read
  // them all back as { name: Float32Array }. Staging buffers are destroyed.
  async function downloadFromMap(map) {
    const stagings = [];
    const encoder = device2.createCommandEncoder({ label: "tensorgrad-download" });
    for (const [name, bufId] of map) {
      const spec = plan.buffers[bufId];
      const staging = device2.createBuffer({ size: spec.byteSize, usage: READBACK });
      encoder.copyBufferToBuffer(buffers.get(bufId), 0, staging, 0, spec.byteSize);
      stagings.push({ name, buf: staging, bytes: spec.byteSize });
    }
    queue.submit([encoder.finish()]);
    const out = {};
    for (const s of stagings) {
      await s.buf.mapAsync(GPUMapMode.READ);
      out[s.name] = new Float32Array(s.buf.getMappedRange().slice(0));
      s.buf.unmap();
      s.buf.destroy();
    }
    return out;
  }
  // Fill a state buffer with its declared init value (f32 or i32 by dtype).
  function fillStateBuffer(spec, target) {
    const elements = spec.byteSize / 4;
    const init = spec.dtype === "f32" ? new Float32Array(elements).fill(spec.initValue ?? 0) : new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0));
    queue.writeBuffer(target, 0, init);
  }
  // Re-initialize all optimizer-state buffers (e.g. Adam moments) in place.
  function resetOptimizerState() {
    for (const spec of plan.buffers) {
      if (spec.kind === "state") fillStateBuffer(spec, buffers.get(spec.id));
    }
  }
  // Expose the live param GPUBuffers so sibling runtimes can share weights.
  const params = /* @__PURE__ */ new Map();
  for (const [name, bufId] of plan.paramsByName) {
    params.set(name, buffers.get(bufId));
  }
  const captureShapes = {};
  for (const [name, bufId] of plan.capturesByName) {
    captureShapes[name] = [...plan.buffers[bufId].shape];
  }
  const outputShape = [...plan.buffers[lossBufferId].shape];
  // Destroy only buffers this runtime allocated — shared param buffers belong
  // to their original runtime.
  const destroy = () => {
    for (const [id, b] of buffers) {
      if (ownedBufferIds.has(id)) b.destroy();
    }
    outputReadback.destroy();
    if (captureStaging) captureStaging.buffer.destroy();
  };
  return {
    device: device2,
    params,
    outputShape,
    uploadParams,
    downloadParams: () => downloadFromMap(plan.paramsByName),
    downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),
    step,
    run,
    readLoss,
    resetOptimizerState,
    destroy
  };
}
|
|
306
|
+
// Request a fresh GPUDevice, throwing a descriptive error when the environment
// lacks WebGPU entirely or no adapter is available.
async function acquireDevice() {
  const hasWebGPU = typeof navigator !== "undefined" && Boolean(navigator.gpu);
  if (!hasWebGPU) {
    throw new Error("tensorgrad: WebGPU not available in this environment");
  }
  const adapter = await navigator.gpu.requestAdapter();
  if (adapter === null) {
    throw new Error("tensorgrad: no WebGPU adapter");
  }
  return adapter.requestDevice();
}
|
|
314
|
+
|
|
315
|
+
// src/adam.ts
|
|
316
|
+
// src/adam.ts — resolve a learning-rate schedule to a concrete LR for a
// (1-based) step count. A bare number is treated as a constant LR.
// Supported schedule kinds:
//   constant    — always `value`
//   linearDecay — linear ramp from `peak` to `final` over `steps`, then flat
//   cosineDecay — half-cosine ramp from `peak` to `final` over `steps`, then flat
//   warmup      — linear 0 → `peakLr` over `warmupSteps`, then defer to `after`
//                 with the step counter restarted at the end of warmup
// Throws on an unrecognized schedule kind.
function resolveLR(schedule, step) {
  if (typeof schedule === "number") return schedule;
  switch (schedule.kind) {
    case "constant":
      return schedule.value;
    case "linearDecay": {
      // Progress fraction, clamped at 1 so LR holds at `final` past the end.
      const f = Math.min(step / schedule.steps, 1);
      return schedule.peak + (schedule.final - schedule.peak) * f;
    }
    case "cosineDecay": {
      const f = Math.min(step / schedule.steps, 1);
      return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f));
    }
    case "warmup": {
      // At step === warmupSteps this returns exactly peakLr; afterwards the
      // wrapped schedule sees steps counted from the end of warmup.
      if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps);
      return resolveLR(schedule.after, step - schedule.warmupSteps);
    }
    default:
      // Previously an unknown kind fell out of the switch and returned
      // undefined, which would propagate NaN into the optimizer; fail fast.
      throw new Error(`resolveLR: unknown schedule kind '${schedule.kind}'`);
  }
}
|
|
335
|
+
|
|
336
|
+
// src/worker-protocol.ts
|
|
337
|
+
// src/worker-protocol.ts — flatten any thrown value into a structured-cloneable
// { name, message, stack } record for posting across the worker boundary.
function wireError(e) {
  if (!(e instanceof Error)) {
    return { name: "Error", message: String(e), stack: "" };
  }
  return { name: e.name, message: e.message, stack: e.stack ?? "" };
}
|
|
343
|
+
|
|
344
|
+
// src/worker.ts
|
|
345
|
+
// src/worker.ts — per-worker registry of live compiled graphs, keyed by the
// graphId assigned on the main-thread side.
var graphs = /* @__PURE__ */ new Map();
// Single GPUDevice shared by every graph in this worker (acquired lazily).
var device = null;
// Acquire and memoize the worker's GPUDevice; throws when WebGPU is missing
// or no adapter is available.
async function ensureDevice() {
  if (device) return device;
  if (typeof navigator === "undefined" || !navigator.gpu) {
    throw new Error("tensorgrad worker: WebGPU not available in this environment");
  }
  const adapter = await navigator.gpu.requestAdapter();
  if (!adapter) throw new Error("tensorgrad worker: no WebGPU adapter");
  device = await adapter.requestDevice();
  return device;
}
|
|
357
|
+
// Worker request: build a runtime for a compiled graph (serialized in
// payload.ir), optionally seed its params and attach Adam state, register it
// under payload.graphId, and return a structured-cloneable summary.
async function handleCreateRuntime(payload) {
  const dev = await ensureDevice();
  const { graph, plan, kernels } = payload.ir;
  // The graph's first output is treated as the loss/output buffer.
  const outputTensorId = graph.outputs[0];
  const outputBufferId = plan.tensorToBuffer.get(outputTensorId);
  const opts = { device: dev };
  const runtime = await createRuntime(plan, kernels, outputBufferId, opts);
  if (Object.keys(payload.initialParams).length > 0) {
    runtime.uploadParams(payload.initialParams);
  }
  const captureShapes = {};
  for (const [name, bufId] of plan.capturesByName) {
    captureShapes[name] = [...plan.buffers[bufId].shape];
  }
  const slot = {
    runtime,
    paramNames: [...plan.paramsByName.keys()],
    outputShape: [...runtime.outputShape],
    // Count only kernels that actually dispatch (have WGSL).
    kernelCount: kernels.filter((k) => k.wgsl).length,
    captureShapes,
    // Adam scalars are computed CPU-side per step when a config is supplied.
    adam: payload.adam ? createAdamState(payload.adam) : null
  };
  graphs.set(payload.graphId, slot);
  return {
    // Copy so the reply never aliases the slot's own array.
    paramNames: [...slot.paramNames],
    outputShape: slot.outputShape,
    kernelCount: slot.kernelCount,
    captureShapes: slot.captureShapes
  };
}
|
|
387
|
+
// Worker request: build a forward-only runtime that shares its parameter
// buffers with an existing parent graph (so inference sees the parent's
// latest trained weights without copies). No Adam state is attached.
async function handleCompileForward(payload) {
  const dev = await ensureDevice();
  const parent = graphs.get(payload.parentGraphId);
  if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`);
  const { graph, plan, kernels } = payload.ir;
  const outputTensorId = graph.outputs[0];
  const outputBufferId = plan.tensorToBuffer.get(outputTensorId);
  // sharedParams aliases the parent's live GPUBuffers by param name.
  const opts = { device: dev, sharedParams: parent.runtime.params };
  const runtime = await createRuntime(plan, kernels, outputBufferId, opts);
  const captureShapes = {};
  for (const [name, bufId] of plan.capturesByName) {
    captureShapes[name] = [...plan.buffers[bufId].shape];
  }
  const slot = {
    runtime,
    paramNames: [...plan.paramsByName.keys()],
    outputShape: [...runtime.outputShape],
    kernelCount: kernels.filter((k) => k.wgsl).length,
    captureShapes,
    adam: null
  };
  graphs.set(payload.graphId, slot);
  return {
    paramNames: [...slot.paramNames],
    outputShape: slot.outputShape,
    kernelCount: slot.kernelCount,
    captureShapes: slot.captureShapes
  };
}
|
|
416
|
+
// Initialize per-graph Adam bookkeeping held worker-side. `lrtBuf` is the
// single-element f32 buffer for the bias-corrected effective LR uploaded each
// step; `decayShrinkBuf` exists only when the config names a decay input.
function createAdamState(cfg) {
  const wantsDecayShrink = Boolean(cfg.decayShrinkInputName);
  return {
    config: cfg,
    t: 0,
    lrtBuf: new Float32Array(1),
    decayShrinkBuf: wantsDecayShrink ? new Float32Array(1) : null
  };
}
|
|
424
|
+
// Advance the slot's Adam step counter and splice the derived scalar inputs
// (bias-corrected LR and, when configured, the weight-decay shrink factor)
// into a copy of `inputs`. Returns `inputs` untouched when the graph has no
// Adam state.
function injectAdamScalars(slot, inputs) {
  const adam = slot.adam;
  if (!adam) return inputs;
  const cfg = adam.config;
  adam.t += 1;
  const t = adam.t;
  const lrNow = resolveLR(cfg.lr, t);
  // Bias-corrected effective LR: lr * sqrt(1 - b2^t) / (1 - b1^t).
  adam.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(cfg.b2, t)) / (1 - Math.pow(cfg.b1, t));
  const merged = { ...inputs, [cfg.lrtInputName]: adam.lrtBuf };
  if (adam.decayShrinkBuf && cfg.decayShrinkInputName) {
    adam.decayShrinkBuf[0] = 1 - lrNow * cfg.weightDecay;
    merged[cfg.decayShrinkInputName] = adam.decayShrinkBuf;
  }
  return merged;
}
|
|
437
|
+
// Worker request: run one training step, injecting Adam scalars first.
// Returns { loss, captures } where captures is a plain record (or null).
async function handleStep(payload) {
  const slot = mustGet(payload.graphId);
  const inputs = injectAdamScalars(slot, payload.inputs);
  if (!payload.withCaptures) {
    const loss = await slot.runtime.step(inputs);
    return { loss, captures: null };
  }
  const r = await slot.runtime.step(inputs, { withCaptures: true });
  return { loss: r.loss, captures: capturesToRecord(r.captures, slot.captureShapes) };
}
|
|
447
|
+
// Worker request: forward-only evaluation. Returns { output, captures },
// with captures flattened to a plain record (or null).
async function handleRun(payload) {
  const slot = mustGet(payload.graphId);
  if (!payload.withCaptures) {
    const output = await slot.runtime.run(payload.inputs);
    return { output, captures: null };
  }
  const r = await slot.runtime.run(payload.inputs, { withCaptures: true });
  return { output: r.output, captures: capturesToRecord(r.captures, slot.captureShapes) };
}
|
|
456
|
+
// Flatten a Map of captured tensors into a plain record (structured-cloneable
// for postMessage), keeping only names registered in `shapes` and present in
// this call's data.
function capturesToRecord(captures, shapes) {
  const entries = Object.keys(shapes)
    .filter((name) => captures.has(name))
    .map((name) => [name, captures.get(name)]);
  return Object.fromEntries(entries);
}
|
|
463
|
+
// Worker request: push parameter values into a registered graph's buffers.
function handleUploadParams(payload) {
  const { graphId, params, partial } = payload;
  mustGet(graphId).runtime.uploadParams(params, { partial });
}
|
|
467
|
+
// Worker request: read all parameter tensors back as { name: Float32Array }.
async function handleDownloadParams(payload) {
  const slot = mustGet(payload.graphId);
  const params = await slot.runtime.downloadParams();
  return { params };
}
|
|
471
|
+
// Worker request: read all parameter gradients back as { name: Float32Array }.
async function handleDownloadParamGrads(payload) {
  const slot = mustGet(payload.graphId);
  const params = await slot.runtime.downloadParamGrads();
  return { params };
}
|
|
475
|
+
// Worker request: zero GPU-side optimizer state and restart the CPU-side
// Adam step counter so bias correction begins fresh.
function handleResetOptimizer(payload) {
  const slot = mustGet(payload.graphId);
  slot.runtime.resetOptimizerState();
  if (slot.adam) {
    slot.adam.t = 0;
  }
}
|
|
480
|
+
// Worker request: release a graph's GPU resources and drop it from the
// registry. Silently a no-op for unknown ids (idempotent destroy).
function handleDestroy(payload) {
  const slot = graphs.get(payload.graphId);
  if (slot === undefined) return;
  slot.runtime.destroy();
  graphs.delete(payload.graphId);
}
|
|
486
|
+
// Look up a registered graph slot, throwing a descriptive error when absent.
function mustGet(graphId) {
  const slot = graphs.get(graphId);
  if (slot === undefined) {
    throw new Error(`tensorgrad worker: graph ${graphId} not found`);
  }
  return slot;
}
|
|
491
|
+
// Worker entry point: every request is { id, kind, payload }; every reply is
// { id, ok: true, result } or { id, ok: false, error } so the main-thread
// proxy can correlate by id. Large Float32Array results are moved via the
// transfer list instead of being copied.
self.onmessage = async (ev) => {
  const req = ev.data;
  try {
    let result;
    let transferList = [];
    switch (req.kind) {
      case "createRuntime":
        result = await handleCreateRuntime(req.payload);
        break;
      case "compileForward":
        result = await handleCompileForward(req.payload);
        break;
      case "step":
        result = await handleStep(req.payload);
        // captures is null without withCaptures; collectTransfers handles that.
        transferList = collectTransfers(result.captures);
        break;
      case "run": {
        const r = await handleRun(req.payload);
        result = r;
        // Transfer the output buffer plus any capture buffers.
        transferList = [r.output.buffer, ...collectTransfers(r.captures)];
        break;
      }
      case "uploadParams":
        handleUploadParams(req.payload);
        result = null;
        break;
      case "downloadParams": {
        const r = await handleDownloadParams(req.payload);
        result = r;
        transferList = collectTransfers(r.params);
        break;
      }
      case "downloadParamGrads": {
        const r = await handleDownloadParamGrads(req.payload);
        result = r;
        transferList = collectTransfers(r.params);
        break;
      }
      case "resetOptimizer":
        handleResetOptimizer(req.payload);
        result = null;
        break;
      case "destroy":
        handleDestroy(req.payload);
        result = null;
        break;
      default:
        throw new Error(`unknown request kind: ${req.kind}`);
    }
    const reply = { id: req.id, ok: true, result };
    self.postMessage(reply, { transfer: transferList });
  } catch (e) {
    // Any failure (including handler throws) is serialized and sent back so
    // the caller's promise rejects instead of the worker dying silently.
    const error = wireError(e);
    const reply = { id: req.id, ok: false, error };
    self.postMessage(reply);
  }
};
|
|
548
|
+
// Gather the underlying ArrayBuffers of every typed array in `rec` so they can
// be moved (not copied) across the worker boundary. Null/undefined -> [].
function collectTransfers(rec) {
  if (!rec) return [];
  return Object.values(rec).map((v) => v.buffer);
}
|
package/package.json
CHANGED
|
@@ -1,58 +1,60 @@
|
|
|
1
|
-
{
|
|
2
|
-
"name": "tensorgrad",
|
|
3
|
-
"version": "0.0.
|
|
4
|
-
"description": "Tiny TypeScript-native tensor library with autograd, compiling to WebGPU. Train small models in the browser without hand-writing kernels.",
|
|
5
|
-
"license": "MIT",
|
|
6
|
-
"author": "Ben Albahari",
|
|
7
|
-
"repository": {
|
|
8
|
-
"type": "git",
|
|
9
|
-
"url": "git+https://github.com/typebulb/tensorgrad.git"
|
|
10
|
-
},
|
|
11
|
-
"homepage": "https://github.com/typebulb/tensorgrad#readme",
|
|
12
|
-
"bugs": {
|
|
13
|
-
"url": "https://github.com/typebulb/tensorgrad/issues"
|
|
14
|
-
},
|
|
15
|
-
"keywords": [
|
|
16
|
-
"webgpu",
|
|
17
|
-
"machine-learning",
|
|
18
|
-
"autograd",
|
|
19
|
-
"tensor",
|
|
20
|
-
"neural-network",
|
|
21
|
-
"transformer",
|
|
22
|
-
"browser",
|
|
23
|
-
"typescript"
|
|
24
|
-
],
|
|
25
|
-
"type": "module",
|
|
26
|
-
"main": "./dist/index.js",
|
|
27
|
-
"module": "./dist/index.js",
|
|
28
|
-
"types": "./dist/index.d.ts",
|
|
29
|
-
"exports": {
|
|
30
|
-
".": {
|
|
31
|
-
"types": "./dist/index.d.ts",
|
|
32
|
-
"import": "./dist/index.js"
|
|
33
|
-
}
|
|
34
|
-
},
|
|
35
|
-
"sideEffects": false,
|
|
36
|
-
"files": [
|
|
37
|
-
"dist",
|
|
38
|
-
"src",
|
|
39
|
-
"README.md",
|
|
40
|
-
"LICENSE"
|
|
41
|
-
],
|
|
42
|
-
"scripts": {
|
|
43
|
-
"build:js": "
|
|
44
|
-
"build:
|
|
45
|
-
"build": "
|
|
46
|
-
"
|
|
47
|
-
"
|
|
48
|
-
"
|
|
49
|
-
"
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
"
|
|
54
|
-
"
|
|
55
|
-
"
|
|
56
|
-
"
|
|
57
|
-
|
|
58
|
-
|
|
1
|
+
{
|
|
2
|
+
"name": "tensorgrad",
|
|
3
|
+
"version": "0.0.16",
|
|
4
|
+
"description": "Tiny TypeScript-native tensor library with autograd, compiling to WebGPU. Train small models in the browser without hand-writing kernels.",
|
|
5
|
+
"license": "MIT",
|
|
6
|
+
"author": "Ben Albahari",
|
|
7
|
+
"repository": {
|
|
8
|
+
"type": "git",
|
|
9
|
+
"url": "git+https://github.com/typebulb/tensorgrad.git"
|
|
10
|
+
},
|
|
11
|
+
"homepage": "https://github.com/typebulb/tensorgrad#readme",
|
|
12
|
+
"bugs": {
|
|
13
|
+
"url": "https://github.com/typebulb/tensorgrad/issues"
|
|
14
|
+
},
|
|
15
|
+
"keywords": [
|
|
16
|
+
"webgpu",
|
|
17
|
+
"machine-learning",
|
|
18
|
+
"autograd",
|
|
19
|
+
"tensor",
|
|
20
|
+
"neural-network",
|
|
21
|
+
"transformer",
|
|
22
|
+
"browser",
|
|
23
|
+
"typescript"
|
|
24
|
+
],
|
|
25
|
+
"type": "module",
|
|
26
|
+
"main": "./dist/index.js",
|
|
27
|
+
"module": "./dist/index.js",
|
|
28
|
+
"types": "./dist/index.d.ts",
|
|
29
|
+
"exports": {
|
|
30
|
+
".": {
|
|
31
|
+
"types": "./dist/index.d.ts",
|
|
32
|
+
"import": "./dist/index.js"
|
|
33
|
+
}
|
|
34
|
+
},
|
|
35
|
+
"sideEffects": false,
|
|
36
|
+
"files": [
|
|
37
|
+
"dist",
|
|
38
|
+
"src",
|
|
39
|
+
"README.md",
|
|
40
|
+
"LICENSE"
|
|
41
|
+
],
|
|
42
|
+
"scripts": {
|
|
43
|
+
"build:js": "node scripts/build.mjs",
|
|
44
|
+
"build:js:watch": "node scripts/build.mjs --watch",
|
|
45
|
+
"build:types": "tsc -p tsconfig.types.json && rollup -c rollup.dts.config.mjs && rimraf dist/types-temp",
|
|
46
|
+
"build": "npm run build:js && npm run build:types",
|
|
47
|
+
"clean": "rimraf dist",
|
|
48
|
+
"typecheck": "tsc -p tsconfig.json --noEmit",
|
|
49
|
+
"test": "tsx test/smoke.ts",
|
|
50
|
+
"prepublishOnly": "npm run clean && npm run build"
|
|
51
|
+
},
|
|
52
|
+
"devDependencies": {
|
|
53
|
+
"esbuild": "^0.28.0",
|
|
54
|
+
"rimraf": "^6.0.1",
|
|
55
|
+
"rollup": "^4.0.0",
|
|
56
|
+
"rollup-plugin-dts": "^6.2.0",
|
|
57
|
+
"tsx": "*",
|
|
58
|
+
"typescript": "*"
|
|
59
|
+
}
|
|
60
|
+
}
|