tensorgrad 0.0.14 → 0.0.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.d.ts +154 -170
- package/dist/index.js +2208 -39
- package/dist/index.js.map +7 -1
- package/dist/worker.debug.js +553 -0
- package/package.json +60 -58
- package/src/adam.ts +69 -15
- package/src/compile.ts +334 -154
- package/src/index.ts +8 -4
- package/src/module.ts +72 -34
- package/src/runtime.ts +64 -11
- package/src/worker-protocol.ts +183 -0
- package/src/worker-proxy.ts +76 -0
- package/src/worker.ts +281 -0
- package/dist/adam.js +0 -111
- package/dist/adam.js.map +0 -1
- package/dist/buffers.js +0 -120
- package/dist/buffers.js.map +0 -1
- package/dist/capture.js +0 -33
- package/dist/capture.js.map +0 -1
- package/dist/codegen.js +0 -724
- package/dist/codegen.js.map +0 -1
- package/dist/compile.js +0 -180
- package/dist/compile.js.map +0 -1
- package/dist/grad.js +0 -380
- package/dist/grad.js.map +0 -1
- package/dist/ir.js +0 -60
- package/dist/ir.js.map +0 -1
- package/dist/module.js +0 -155
- package/dist/module.js.map +0 -1
- package/dist/nn.js +0 -135
- package/dist/nn.js.map +0 -1
- package/dist/ops.js +0 -326
- package/dist/ops.js.map +0 -1
- package/dist/runtime.js +0 -375
- package/dist/runtime.js.map +0 -1
- package/dist/shape.js +0 -259
- package/dist/shape.js.map +0 -1
- package/dist/trace.js +0 -100
- package/dist/trace.js.map +0 -1
package/src/module.ts
CHANGED
|
@@ -32,30 +32,47 @@ import { paramInput } from './trace.js'
|
|
|
32
32
|
// Init metadata
|
|
33
33
|
// ============================================================================
|
|
34
34
|
|
|
35
|
-
/** How a parameter's initial values are produced.
|
|
36
|
-
*
|
|
37
|
-
*
|
|
38
|
-
*
|
|
39
|
-
*
|
|
40
|
-
* -
|
|
41
|
-
*
|
|
35
|
+
/** How a parameter's initial values are produced. Serializable shape — no
|
|
36
|
+
* closures, since the initial values cross the worker boundary at compile
|
|
37
|
+
* time. Use the `init` helpers for ergonomic construction.
|
|
38
|
+
*
|
|
39
|
+
* String shorthands:
|
|
40
|
+
* - `'randn'` — Gaussian with std 0.02 (the common weight-matrix init).
|
|
41
|
+
* - `'zeros'` — fill with 0 (biases, LayerNorm beta).
|
|
42
|
+
* - `'ones'` — fill with 1 (LayerNorm gain).
|
|
43
|
+
*
|
|
44
|
+
* Object shapes:
|
|
45
|
+
* - `{ kind: 'randn', scale }` — randn with explicit std.
|
|
46
|
+
* - `{ kind: 'kaiming', gain? }` — `std = gain / sqrt(fan_in)`. Default
|
|
47
|
+
* gain `sqrt(2)` (good for ReLU). `fan_in = shape[0]`.
|
|
48
|
+
* - `{ kind: 'literal', data }` — explicit Float32Array; length must
|
|
49
|
+
* match the parameter's element count.
|
|
42
50
|
*/
|
|
43
51
|
export type InitSpec =
|
|
44
52
|
| 'randn'
|
|
45
53
|
| 'zeros'
|
|
46
54
|
| 'ones'
|
|
47
|
-
|
|
|
55
|
+
| { readonly kind: 'randn'; readonly scale: number }
|
|
56
|
+
| { readonly kind: 'kaiming'; readonly gain?: number }
|
|
57
|
+
| { readonly kind: 'literal'; readonly data: Float32Array }
|
|
58
|
+
|
|
59
|
+
/** Ergonomic constructors for InitSpec object shapes. */
|
|
60
|
+
export const init = {
|
|
61
|
+
randn: (opts: { scale?: number } = {}): InitSpec => ({ kind: 'randn', scale: opts.scale ?? 0.02 }),
|
|
62
|
+
kaiming: (opts: { gain?: number } = {}): InitSpec =>
|
|
63
|
+
opts.gain !== undefined ? { kind: 'kaiming', gain: opts.gain } : { kind: 'kaiming' },
|
|
64
|
+
literal: (data: Float32Array): InitSpec => ({ kind: 'literal', data }),
|
|
65
|
+
}
|
|
48
66
|
|
|
49
67
|
export interface ParamOptions {
|
|
50
68
|
dtype?: Dtype
|
|
51
|
-
/** Init
|
|
69
|
+
/** Init shape. Default: `'randn'` (std 0.02). */
|
|
52
70
|
init?: InitSpec
|
|
53
|
-
/** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
|
|
54
|
-
scale?: number
|
|
55
71
|
/** Whether AdamW (when `weightDecay > 0`) should apply decoupled weight
|
|
56
|
-
* decay to this param. Default: `true` for
|
|
57
|
-
* embeddings)
|
|
58
|
-
* to force or skip. Replaces `adam.decayFilter` for
|
|
72
|
+
* decay to this param. Default: `true` for randn/kaiming/literal init
|
|
73
|
+
* (weight matrices, embeddings); `false` for zeros/ones (biases, LN
|
|
74
|
+
* gains). Override to force or skip. Replaces `adam.decayFilter` for
|
|
75
|
+
* the common case. */
|
|
59
76
|
decay?: boolean
|
|
60
77
|
}
|
|
61
78
|
|
|
@@ -65,31 +82,52 @@ function boxMuller(): number {
|
|
|
65
82
|
return Math.sqrt(-2 * Math.log(Math.max(1e-10, Math.random()))) * Math.cos(2 * Math.PI * Math.random())
|
|
66
83
|
}
|
|
67
84
|
|
|
68
|
-
function
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
return
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
85
|
+
function randnFn(scale: number): InitFn {
|
|
86
|
+
return (size) => {
|
|
87
|
+
const arr = new Float32Array(size)
|
|
88
|
+
for (let i = 0; i < size; i++) arr[i] = boxMuller() * scale
|
|
89
|
+
return arr
|
|
90
|
+
}
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
/** Compile-time-only: resolve an InitSpec shape into the closure that
|
|
94
|
+
* generates the initial Float32Array for a given parameter shape. Runs
|
|
95
|
+
* on the main thread before initial values are transferred to the worker. */
|
|
96
|
+
function resolveInit(spec: InitSpec | undefined): InitFn {
|
|
97
|
+
if (!spec || spec === 'randn') return randnFn(0.02)
|
|
98
|
+
if (spec === 'zeros') return (size) => new Float32Array(size)
|
|
99
|
+
if (spec === 'ones') return (size) => { const a = new Float32Array(size); a.fill(1); return a }
|
|
100
|
+
switch (spec.kind) {
|
|
101
|
+
case 'randn': return randnFn(spec.scale)
|
|
102
|
+
case 'kaiming': {
|
|
103
|
+
const gain = spec.gain ?? Math.sqrt(2)
|
|
104
|
+
return (size, shape) => {
|
|
105
|
+
const fanIn = shape[0] ?? size
|
|
106
|
+
const std = gain / Math.sqrt(fanIn)
|
|
107
|
+
const arr = new Float32Array(size)
|
|
108
|
+
for (let i = 0; i < size; i++) arr[i] = boxMuller() * std
|
|
109
|
+
return arr
|
|
110
|
+
}
|
|
111
|
+
}
|
|
112
|
+
case 'literal': {
|
|
113
|
+
const data = spec.data
|
|
114
|
+
return (size) => {
|
|
115
|
+
if (data.length !== size) {
|
|
116
|
+
throw new Error(`init.literal: data length ${data.length} doesn't match param size ${size}`)
|
|
117
|
+
}
|
|
118
|
+
return new Float32Array(data)
|
|
119
|
+
}
|
|
76
120
|
}
|
|
77
121
|
}
|
|
78
|
-
if (init === 'zeros') return (size) => new Float32Array(size)
|
|
79
|
-
if (init === 'ones') return (size) => { const a = new Float32Array(size); a.fill(1); return a }
|
|
80
|
-
if (typeof init === 'function') return init
|
|
81
|
-
throw new Error(`Unknown init: ${String(init)}`)
|
|
82
122
|
}
|
|
83
123
|
|
|
84
|
-
/** Resolve the decay default for a param.
|
|
85
|
-
*
|
|
86
|
-
* (
|
|
87
|
-
* inits are weight-shaped (Kaiming etc.). Explicit `decay: false` overrides. */
|
|
124
|
+
/** Resolve the decay default for a param. Weight-shaped inits (randn,
|
|
125
|
+
* kaiming, literal) default to decay=true; ones/zeros default to false
|
|
126
|
+
* (biases, LN gains). Explicit `decay` opt overrides. */
|
|
88
127
|
function resolveDecay(opts: ParamOptions | undefined): boolean {
|
|
89
128
|
if (opts?.decay !== undefined) return opts.decay
|
|
90
|
-
const
|
|
91
|
-
|
|
92
|
-
return true // 'randn' or function
|
|
129
|
+
const spec = opts?.init ?? 'randn'
|
|
130
|
+
return spec !== 'zeros' && spec !== 'ones'
|
|
93
131
|
}
|
|
94
132
|
|
|
95
133
|
// ============================================================================
|
|
@@ -127,7 +165,7 @@ export abstract class Module {
|
|
|
127
165
|
protected param(shape: Shape, opts?: ParamOptions): Tensor {
|
|
128
166
|
const dtype = opts?.dtype ?? 'f32'
|
|
129
167
|
// Lie to TypeScript: the sentinel becomes a Tensor at materialize time.
|
|
130
|
-
return new ParamSentinel(shape, dtype, resolveInit(opts), resolveDecay(opts)) as unknown as Tensor
|
|
168
|
+
return new ParamSentinel(shape, dtype, resolveInit(opts?.init), resolveDecay(opts)) as unknown as Tensor
|
|
131
169
|
}
|
|
132
170
|
}
|
|
133
171
|
|
package/src/runtime.ts
CHANGED
|
@@ -69,6 +69,23 @@ export interface RunOptions {
|
|
|
69
69
|
withCaptures?: boolean
|
|
70
70
|
}
|
|
71
71
|
|
|
72
|
+
export interface StepOptions extends RunOptions {
|
|
73
|
+
/** If false, the training submit is queued but the JS thread does not
|
|
74
|
+
* await `mapAsync` of the loss buffer. Returns `void` immediately.
|
|
75
|
+
* Use `runtime.readLoss()` to read the latest loss explicitly when
|
|
76
|
+
* you want it (e.g., every Nth step for UI display).
|
|
77
|
+
*
|
|
78
|
+
* Why: each `mapAsync` round-trip is ~1 ms on desktop but 10–30 ms on
|
|
79
|
+
* Android Chrome. A training loop that awaits per step pays N × that
|
|
80
|
+
* on the main thread, which on mobile starves the OS compositor and
|
|
81
|
+
* causes visible UI sluggishness. With `readLoss: false` plus a
|
|
82
|
+
* `requestAnimationFrame` yield between steps, the main thread stays
|
|
83
|
+
* responsive while training runs at GPU speed.
|
|
84
|
+
*
|
|
85
|
+
* Implies `withCaptures: false`. Default: true. */
|
|
86
|
+
readLoss?: boolean
|
|
87
|
+
}
|
|
88
|
+
|
|
72
89
|
/** Common surface for both training and forward-only compiled runtimes. */
|
|
73
90
|
export interface CompiledBase {
|
|
74
91
|
/** The GPUDevice this runtime is bound to. Pass to sibling compiles to
|
|
@@ -112,11 +129,16 @@ export interface CompiledRuntime extends CompiledBase {
|
|
|
112
129
|
*/
|
|
113
130
|
step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
|
|
114
131
|
step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
|
|
115
|
-
step(inputs: Record<string, Int32Array | Float32Array>, opts:
|
|
132
|
+
step(inputs: Record<string, Int32Array | Float32Array>, opts: { readLoss: false }): Promise<void>
|
|
133
|
+
step(inputs: Record<string, Int32Array | Float32Array>, opts: StepOptions): Promise<number | StepResult | void>
|
|
116
134
|
/** Same dispatch as step() but returns the full output Float32Array — for
|
|
117
135
|
* training graphs the output is a scalar loss, so step() is usually more
|
|
118
136
|
* convenient. Provided for parity with `compileForward`. */
|
|
119
137
|
run: RunFn
|
|
138
|
+
/** Read the latest loss value from the GPU. Pair with `step({ readLoss: false })`
|
|
139
|
+
* for fire-and-forget training: every Nth iteration, call `readLoss()` for the
|
|
140
|
+
* UI, but most iterations don't pay the `mapAsync` cost. */
|
|
141
|
+
readLoss(): Promise<number>
|
|
120
142
|
/** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
|
|
121
143
|
* `uploadInitialParams()` for a full training reset without recompile. */
|
|
122
144
|
resetOptimizerState(): void
|
|
@@ -292,18 +314,21 @@ export async function createRuntime(
|
|
|
292
314
|
// run sequentially even when fired from independent async paths (e.g., a
|
|
293
315
|
// training loop's auxiliary `refreshPrediction()` + `writeDiagnostic()`).
|
|
294
316
|
let pending: Promise<unknown> = Promise.resolve()
|
|
317
|
+
type DispatchOpts = { wantCaptures: boolean; readback: boolean }
|
|
318
|
+
type DispatchResult = { output: Float32Array; captures: Map<string, Float32Array> } | null
|
|
295
319
|
async function dispatch(
|
|
296
320
|
inputs: Record<string, Int32Array | Float32Array>,
|
|
297
|
-
|
|
298
|
-
): Promise<
|
|
299
|
-
const turn = pending.catch(() => {}).then(() => dispatchUnsynchronized(inputs,
|
|
321
|
+
opts: DispatchOpts,
|
|
322
|
+
): Promise<DispatchResult> {
|
|
323
|
+
const turn = pending.catch(() => {}).then(() => dispatchUnsynchronized(inputs, opts))
|
|
300
324
|
pending = turn
|
|
301
325
|
return turn
|
|
302
326
|
}
|
|
303
327
|
async function dispatchUnsynchronized(
|
|
304
328
|
inputs: Record<string, Int32Array | Float32Array>,
|
|
305
|
-
|
|
306
|
-
): Promise<
|
|
329
|
+
opts: DispatchOpts,
|
|
330
|
+
): Promise<DispatchResult> {
|
|
331
|
+
const wantCaptures = opts.wantCaptures
|
|
307
332
|
if (wantCaptures && plan.capturesByName.size === 0) {
|
|
308
333
|
throw new Error(
|
|
309
334
|
`withCaptures=true but no capture(...) calls were registered during ` +
|
|
@@ -360,6 +385,12 @@ export async function createRuntime(
|
|
|
360
385
|
}
|
|
361
386
|
queue.submit([encoder.finish()])
|
|
362
387
|
|
|
388
|
+
// readback=false: training fire-and-forget. The encoder still copied
|
|
389
|
+
// loss → outputReadback (and captures → staging), but we don't await
|
|
390
|
+
// mapAsync. The caller can read the latest loss later via readLoss()
|
|
391
|
+
// when it actually wants to display it.
|
|
392
|
+
if (!opts.readback) return null
|
|
393
|
+
|
|
363
394
|
await outputReadback.mapAsync(GPUMapMode.READ)
|
|
364
395
|
const output = new Float32Array(outputReadback.getMappedRange().slice(0))
|
|
365
396
|
outputReadback.unmap()
|
|
@@ -381,16 +412,37 @@ export async function createRuntime(
|
|
|
381
412
|
// ---- step() — training-mode wrapper, returns scalar [0] of output ---------
|
|
382
413
|
function step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
|
|
383
414
|
function step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
|
|
384
|
-
function step(inputs: Record<string, Int32Array | Float32Array>, opts:
|
|
415
|
+
function step(inputs: Record<string, Int32Array | Float32Array>, opts: { readLoss: false }): Promise<void>
|
|
416
|
+
function step(inputs: Record<string, Int32Array | Float32Array>, opts: StepOptions): Promise<number | StepResult | void>
|
|
385
417
|
async function step(
|
|
386
418
|
inputs: Record<string, Int32Array | Float32Array>,
|
|
387
|
-
opts?:
|
|
388
|
-
): Promise<number | StepResult> {
|
|
389
|
-
|
|
419
|
+
opts?: StepOptions,
|
|
420
|
+
): Promise<number | StepResult | void> {
|
|
421
|
+
if (opts?.readLoss === false) {
|
|
422
|
+
await dispatch(inputs, { wantCaptures: false, readback: false })
|
|
423
|
+
return
|
|
424
|
+
}
|
|
425
|
+
const r = (await dispatch(inputs, { wantCaptures: opts?.withCaptures === true, readback: true }))!
|
|
390
426
|
if (opts?.withCaptures) return { loss: r.output[0]!, captures: new Captures(captureShapes, r.captures) }
|
|
391
427
|
return r.output[0]!
|
|
392
428
|
}
|
|
393
429
|
|
|
430
|
+
// ---- readLoss() — explicit late readback for fire-and-forget training -----
|
|
431
|
+
// Maps the output buffer (which step() always copies the latest loss into,
|
|
432
|
+
// even when readLoss:false) and returns the value. Goes through the same
|
|
433
|
+
// serialization chain as step()/run() so two readLoss() calls don't both
|
|
434
|
+
// try to mapAsync the same buffer.
|
|
435
|
+
async function readLoss(): Promise<number> {
|
|
436
|
+
const turn = pending.catch(() => {}).then(async () => {
|
|
437
|
+
await outputReadback.mapAsync(GPUMapMode.READ)
|
|
438
|
+
const v = new Float32Array(outputReadback.getMappedRange())[0]!
|
|
439
|
+
outputReadback.unmap()
|
|
440
|
+
return v
|
|
441
|
+
})
|
|
442
|
+
pending = turn
|
|
443
|
+
return turn
|
|
444
|
+
}
|
|
445
|
+
|
|
394
446
|
// ---- run() — forward-mode wrapper, returns Float32Array by default -------
|
|
395
447
|
// Same overloaded shape as step(): scalar-shaped result (here Float32Array,
|
|
396
448
|
// there a JS number) is the default; { ..., captures } is the opt-in form.
|
|
@@ -401,7 +453,7 @@ export async function createRuntime(
|
|
|
401
453
|
inputs: Record<string, Int32Array | Float32Array>,
|
|
402
454
|
opts?: RunOptions,
|
|
403
455
|
): Promise<Float32Array | RunResult> {
|
|
404
|
-
const r = await dispatch(inputs, opts?.withCaptures === true)
|
|
456
|
+
const r = (await dispatch(inputs, { wantCaptures: opts?.withCaptures === true, readback: true }))!
|
|
405
457
|
if (opts?.withCaptures) return { output: r.output, captures: new Captures(captureShapes, r.captures) }
|
|
406
458
|
return r.output
|
|
407
459
|
}
|
|
@@ -507,6 +559,7 @@ export async function createRuntime(
|
|
|
507
559
|
downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),
|
|
508
560
|
step,
|
|
509
561
|
run,
|
|
562
|
+
readLoss,
|
|
510
563
|
resetOptimizerState,
|
|
511
564
|
destroy,
|
|
512
565
|
}
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
// Wire format for the main-thread ↔ worker postMessage channel.
|
|
2
|
+
//
|
|
3
|
+
// All requests carry a numeric `id` assigned by the main thread; responses
|
|
4
|
+
// echo it back so the proxy can match concurrent in-flight calls. Every
|
|
5
|
+
// response is either `{ ok: true, result }` or `{ ok: false, error }`.
|
|
6
|
+
// Errors carry serialized name/message/stack so the proxy can reconstitute
|
|
7
|
+
// an Error with a working `instanceof` check on the receiving side.
|
|
8
|
+
//
|
|
9
|
+
// Inputs (typed arrays) and outputs (typed arrays, captures) are transferred
|
|
10
|
+
// rather than copied — see the per-request notes for which fields go on the
|
|
11
|
+
// transfer list. A single worker may host multiple compiled graphs (a train
|
|
12
|
+
// graph plus sibling forward graphs); each has a `graphId` issued by the
|
|
13
|
+
// main thread at compile time.
|
|
14
|
+
|
|
15
|
+
import type { Graph } from './ir.js'
|
|
16
|
+
import type { BufferPlan } from './buffers.js'
|
|
17
|
+
import type { KernelSpec } from './codegen.js'
|
|
18
|
+
import type { LRSchedule } from './adam.js'
|
|
19
|
+
|
|
20
|
+
// ============================================================================
|
|
21
|
+
// Serializable config (subset of AdamResolvedConfig that crosses the wire).
|
|
22
|
+
// `decayFilter` (a function, used only at compile time) is NOT part of this —
|
|
23
|
+
// the per-param decay decision is already baked into the IR by appendAdam
|
|
24
|
+
// before the IR ships to the worker.
|
|
25
|
+
// ============================================================================
|
|
26
|
+
|
|
27
|
+
export interface WireAdamConfig {
|
|
28
|
+
lr: LRSchedule
|
|
29
|
+
b1: number
|
|
30
|
+
b2: number
|
|
31
|
+
eps: number
|
|
32
|
+
weightDecay: number
|
|
33
|
+
lrIsScheduled: boolean
|
|
34
|
+
/** Names of the per-step scalar inputs the worker must populate before
|
|
35
|
+
* every step (`_adam_lrt`, optionally `_adam_decay_shrink`). Mirrors
|
|
36
|
+
* AdamResult so the worker can update them without re-deriving. */
|
|
37
|
+
lrtInputName: string
|
|
38
|
+
decayShrinkInputName: string | null
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
/** Compile output that crosses to the worker. Same fields as CompiledIR
|
|
42
|
+
* minus the `loss` tensor (carried by graph.outputs[0]). */
|
|
43
|
+
export interface WireIR {
|
|
44
|
+
graph: Graph
|
|
45
|
+
plan: BufferPlan
|
|
46
|
+
kernels: KernelSpec[]
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
// ============================================================================
|
|
50
|
+
// Requests (main → worker)
|
|
51
|
+
// ============================================================================
|
|
52
|
+
|
|
53
|
+
export type Req =
|
|
54
|
+
| { id: number; kind: 'createRuntime'; payload: CreateRuntimePayload }
|
|
55
|
+
| { id: number; kind: 'compileForward'; payload: CompileForwardPayload }
|
|
56
|
+
| { id: number; kind: 'step'; payload: StepPayload }
|
|
57
|
+
| { id: number; kind: 'run'; payload: RunPayload }
|
|
58
|
+
| { id: number; kind: 'uploadParams'; payload: UploadParamsPayload }
|
|
59
|
+
| { id: number; kind: 'downloadParams'; payload: { graphId: number } }
|
|
60
|
+
| { id: number; kind: 'downloadParamGrads'; payload: { graphId: number } }
|
|
61
|
+
| { id: number; kind: 'resetOptimizer'; payload: { graphId: number } }
|
|
62
|
+
| { id: number; kind: 'destroy'; payload: { graphId: number } }
|
|
63
|
+
|
|
64
|
+
/** Build the training runtime. Always graphId=0 for a fresh worker. */
|
|
65
|
+
export interface CreateRuntimePayload {
|
|
66
|
+
graphId: number
|
|
67
|
+
ir: WireIR
|
|
68
|
+
/** Initial param values per name. Transferred (zero-copy) — the main
|
|
69
|
+
* thread loses access after postMessage. */
|
|
70
|
+
initialParams: Record<string, Float32Array>
|
|
71
|
+
/** Adam config when training; absent for forward-only compiles. */
|
|
72
|
+
adam: WireAdamConfig | null
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
/** Build a sibling forward-only graph that shares param buffers with an
|
|
76
|
+
* existing graph (typically the training graph at graphId=0). */
|
|
77
|
+
export interface CompileForwardPayload {
|
|
78
|
+
graphId: number
|
|
79
|
+
parentGraphId: number
|
|
80
|
+
ir: WireIR
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
/** One training step. Inputs are transferred; the caller's typed arrays
|
|
84
|
+
* become detached after postMessage. */
|
|
85
|
+
export interface StepPayload {
|
|
86
|
+
graphId: number
|
|
87
|
+
inputs: Record<string, Int32Array | Float32Array>
|
|
88
|
+
withCaptures: boolean
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
/** Forward-only run. Same transfer semantics as `step`. */
|
|
92
|
+
export interface RunPayload {
|
|
93
|
+
graphId: number
|
|
94
|
+
inputs: Record<string, Int32Array | Float32Array>
|
|
95
|
+
withCaptures: boolean
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
export interface UploadParamsPayload {
|
|
99
|
+
graphId: number
|
|
100
|
+
params: Record<string, Float32Array> // transferred
|
|
101
|
+
partial: boolean
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
// ============================================================================
|
|
105
|
+
// Responses (worker → main)
|
|
106
|
+
// ============================================================================
|
|
107
|
+
|
|
108
|
+
export type Res<R = unknown> =
|
|
109
|
+
| { id: number; ok: true; result: R }
|
|
110
|
+
| { id: number; ok: false; error: WireError }
|
|
111
|
+
|
|
112
|
+
export interface WireError {
|
|
113
|
+
name: string
|
|
114
|
+
message: string
|
|
115
|
+
stack: string
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
// Per-request result shapes:
|
|
119
|
+
|
|
120
|
+
export interface CreateRuntimeResult {
|
|
121
|
+
paramNames: string[]
|
|
122
|
+
outputShape: number[]
|
|
123
|
+
kernelCount: number
|
|
124
|
+
captureShapes: Record<string, number[]>
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
export interface CompileForwardResult {
|
|
128
|
+
paramNames: string[]
|
|
129
|
+
outputShape: number[]
|
|
130
|
+
kernelCount: number
|
|
131
|
+
captureShapes: Record<string, number[]>
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
/** Step without `withCaptures` returns just `loss`. With captures, also
|
|
135
|
+
* populates `captures` (per-name Float32Array, all transferred back). */
|
|
136
|
+
export interface StepResultWire {
|
|
137
|
+
loss: number
|
|
138
|
+
captures: Record<string, Float32Array> | null
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
/** Run without `withCaptures` returns `{ output, captures: null }`.
|
|
142
|
+
* With captures, also populates `captures`. */
|
|
143
|
+
export interface RunResultWire {
|
|
144
|
+
output: Float32Array
|
|
145
|
+
captures: Record<string, Float32Array> | null
|
|
146
|
+
}
|
|
147
|
+
|
|
148
|
+
export interface DownloadParamsResult {
|
|
149
|
+
params: Record<string, Float32Array> // transferred
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
// ============================================================================
|
|
153
|
+
// Transfer-list helpers
|
|
154
|
+
// ============================================================================
|
|
155
|
+
|
|
156
|
+
/** Collect the underlying ArrayBuffers from a Record of typed arrays so we
|
|
157
|
+
* can pass them on `postMessage`'s transfer list. The values themselves
|
|
158
|
+
* stay in the Record; only their backing buffers move. */
|
|
159
|
+
export function transferablesOfRecord(
|
|
160
|
+
rec: Record<string, Int32Array | Float32Array>,
|
|
161
|
+
): ArrayBuffer[] {
|
|
162
|
+
const out: ArrayBuffer[] = []
|
|
163
|
+
for (const v of Object.values(rec)) out.push(v.buffer as ArrayBuffer)
|
|
164
|
+
return out
|
|
165
|
+
}
|
|
166
|
+
|
|
167
|
+
/** Serialize an Error to a wire-friendly shape, preserving stack + name so
|
|
168
|
+
* the receiving side can reconstitute an Error that an `instanceof`-aware
|
|
169
|
+
* caller (e.g., for `ShapeError`) can still pattern-match by name. */
|
|
170
|
+
export function wireError(e: unknown): WireError {
|
|
171
|
+
if (e instanceof Error) {
|
|
172
|
+
return { name: e.name, message: e.message, stack: e.stack ?? '' }
|
|
173
|
+
}
|
|
174
|
+
return { name: 'Error', message: String(e), stack: '' }
|
|
175
|
+
}
|
|
176
|
+
|
|
177
|
+
/** Reconstitute an Error from the wire shape on the receiving (main) side. */
|
|
178
|
+
export function reconstituteError(w: WireError): Error {
|
|
179
|
+
const err = new Error(w.message)
|
|
180
|
+
err.name = w.name
|
|
181
|
+
err.stack = w.stack
|
|
182
|
+
return err
|
|
183
|
+
}
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
// Main-thread half of the worker channel: request/response correlation,
|
|
2
|
+
// promise wiring, error reconstitution. Knows nothing about Adam, captures,
|
|
3
|
+
// IR, etc. — just shuttles typed messages.
|
|
4
|
+
|
|
5
|
+
import type { Req, Res, WireError } from './worker-protocol.js'
|
|
6
|
+
import { reconstituteError } from './worker-protocol.js'
|
|
7
|
+
|
|
8
|
+
interface PendingHandlers {
|
|
9
|
+
resolve: (v: unknown) => void
|
|
10
|
+
reject: (e: Error) => void
|
|
11
|
+
}
|
|
12
|
+
|
|
13
|
+
/** Spawn a worker from an inlined source string and provide a typed
|
|
14
|
+
* request/response channel. One WorkerProxy = one Worker = one GPUDevice
|
|
15
|
+
* on the worker side. Sibling graphs share the same WorkerProxy. */
|
|
16
|
+
export class WorkerProxy {
|
|
17
|
+
private worker: Worker
|
|
18
|
+
private nextId = 1
|
|
19
|
+
private pending = new Map<number, PendingHandlers>()
|
|
20
|
+
private terminated = false
|
|
21
|
+
|
|
22
|
+
constructor(workerSource: string) {
|
|
23
|
+
const blob = new Blob([workerSource], { type: 'application/javascript' })
|
|
24
|
+
const url = URL.createObjectURL(blob)
|
|
25
|
+
this.worker = new Worker(url, { type: 'module' })
|
|
26
|
+
// The Blob URL keeps memory alive as long as it's referenced; revoke
|
|
27
|
+
// once the worker has loaded its source. Browsers tolerate revoke
|
|
28
|
+
// immediately after construction in practice.
|
|
29
|
+
URL.revokeObjectURL(url)
|
|
30
|
+
|
|
31
|
+
this.worker.onmessage = (ev: MessageEvent<Res>) => {
|
|
32
|
+
const reply = ev.data
|
|
33
|
+
const handlers = this.pending.get(reply.id)
|
|
34
|
+
if (!handlers) return // stale reply; ignore
|
|
35
|
+
this.pending.delete(reply.id)
|
|
36
|
+
if (reply.ok) handlers.resolve(reply.result)
|
|
37
|
+
else handlers.reject(reconstituteError(reply.error))
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
this.worker.onerror = (ev: ErrorEvent) => {
|
|
41
|
+
const err = new Error(`tensorgrad worker error: ${ev.message || 'unknown'}`)
|
|
42
|
+
const wire: WireError = { name: 'WorkerError', message: err.message, stack: err.stack ?? '' }
|
|
43
|
+
// Reject everything in flight; subsequent calls will fail too.
|
|
44
|
+
for (const handlers of this.pending.values()) handlers.reject(reconstituteError(wire))
|
|
45
|
+
this.pending.clear()
|
|
46
|
+
}
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
/** Send a request and await its matching response. `transfer` lists the
|
|
50
|
+
* ArrayBuffers to move (zero-copy) into the worker. */
|
|
51
|
+
request<R>(req: Omit<Req, 'id'>, transfer: ArrayBuffer[] = []): Promise<R> {
|
|
52
|
+
if (this.terminated) return Promise.reject(new Error('tensorgrad: worker has been terminated'))
|
|
53
|
+
const id = this.nextId++
|
|
54
|
+
return new Promise<R>((resolve, reject) => {
|
|
55
|
+
this.pending.set(id, { resolve: resolve as (v: unknown) => void, reject })
|
|
56
|
+
this.worker.postMessage({ ...req, id } as Req, transfer)
|
|
57
|
+
})
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
/** Fire-and-forget variant for cases where the caller doesn't need a reply
|
|
61
|
+
* (currently unused; keep for symmetry / future use). */
|
|
62
|
+
send(req: Omit<Req, 'id'>, transfer: ArrayBuffer[] = []): void {
|
|
63
|
+
if (this.terminated) return
|
|
64
|
+
const id = this.nextId++
|
|
65
|
+
this.worker.postMessage({ ...req, id } as Req, transfer)
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
terminate(): void {
|
|
69
|
+
if (this.terminated) return
|
|
70
|
+
this.terminated = true
|
|
71
|
+
this.worker.terminate()
|
|
72
|
+
const err = new Error('tensorgrad: worker terminated')
|
|
73
|
+
for (const handlers of this.pending.values()) handlers.reject(err)
|
|
74
|
+
this.pending.clear()
|
|
75
|
+
}
|
|
76
|
+
}
|