tensorgrad 0.0.15 → 0.0.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.d.ts +154 -193
- package/dist/index.js +2208 -39
- package/dist/index.js.map +7 -1
- package/dist/worker.debug.js +553 -0
- package/package.json +60 -58
- package/src/adam.ts +69 -15
- package/src/compile.ts +334 -156
- package/src/index.ts +8 -4
- package/src/module.ts +72 -34
- package/src/worker-protocol.ts +183 -0
- package/src/worker-proxy.ts +76 -0
- package/src/worker.ts +281 -0
- package/dist/adam.js +0 -111
- package/dist/adam.js.map +0 -1
- package/dist/buffers.js +0 -120
- package/dist/buffers.js.map +0 -1
- package/dist/capture.js +0 -33
- package/dist/capture.js.map +0 -1
- package/dist/codegen.js +0 -724
- package/dist/codegen.js.map +0 -1
- package/dist/compile.js +0 -184
- package/dist/compile.js.map +0 -1
- package/dist/grad.js +0 -380
- package/dist/grad.js.map +0 -1
- package/dist/ir.js +0 -60
- package/dist/ir.js.map +0 -1
- package/dist/module.js +0 -155
- package/dist/module.js.map +0 -1
- package/dist/nn.js +0 -135
- package/dist/nn.js.map +0 -1
- package/dist/ops.js +0 -326
- package/dist/ops.js.map +0 -1
- package/dist/runtime.js +0 -402
- package/dist/runtime.js.map +0 -1
- package/dist/shape.js +0 -259
- package/dist/shape.js.map +0 -1
- package/dist/trace.js +0 -100
- package/dist/trace.js.map +0 -1
package/src/module.ts
CHANGED
|
@@ -32,30 +32,47 @@ import { paramInput } from './trace.js'
|
|
|
32
32
|
// Init metadata
|
|
33
33
|
// ============================================================================
|
|
34
34
|
|
|
35
|
-
/** How a parameter's initial values are produced.
|
|
36
|
-
*
|
|
37
|
-
*
|
|
38
|
-
*
|
|
39
|
-
*
|
|
40
|
-
* -
|
|
41
|
-
*
|
|
35
|
+
/**
 * Describes how a parameter's initial values are generated.
 *
 * Every variant is plain data (no closures): initial values are produced on
 * the main thread at compile time and then shipped across the worker
 * boundary. The `init` helpers build the object variants.
 *
 * Shorthand strings:
 *   - `'randn'` — normal distribution with std 0.02 (typical weight-matrix init).
 *   - `'zeros'` — every element 0 (biases, LayerNorm beta).
 *   - `'ones'`  — every element 1 (LayerNorm gain).
 *
 * Object variants:
 *   - `{ kind: 'randn', scale }`   — normal distribution with std `scale`.
 *   - `{ kind: 'kaiming', gain? }` — std `gain / sqrt(fan_in)` where
 *     `fan_in = shape[0]`; `gain` defaults to `sqrt(2)` (suits ReLU).
 *   - `{ kind: 'literal', data }`  — explicit Float32Array whose length must
 *     equal the parameter's element count.
 */
export type InitSpec =
  // String shorthands
  | 'randn'
  | 'zeros'
  | 'ones'
  // Object variants
  | { readonly kind: 'randn'; readonly scale: number }
  | { readonly kind: 'kaiming'; readonly gain?: number }
  | { readonly kind: 'literal'; readonly data: Float32Array }
|
|
58
|
+
|
|
59
|
+
/** Builders for the object variants of `InitSpec`. */
export const init = {
  /** Normal init with an explicit std; `scale` defaults to 0.02. */
  randn(opts: { scale?: number } = {}): InitSpec {
    const scale = opts.scale ?? 0.02
    return { kind: 'randn', scale }
  },
  /** Kaiming init. `gain` is left off the spec when not supplied, so the
   *  resolver applies its own default. */
  kaiming(opts: { gain?: number } = {}): InitSpec {
    if (opts.gain === undefined) return { kind: 'kaiming' }
    return { kind: 'kaiming', gain: opts.gain }
  },
  /** Literal init from an explicit buffer (length must match the param). */
  literal(data: Float32Array): InitSpec {
    return { kind: 'literal', data }
  },
}
|
|
48
66
|
|
|
49
67
|
/** Per-parameter options accepted by `Module.param`. */
export interface ParamOptions {
  /** Element type. Default: `'f32'`. */
  dtype?: Dtype
  /** How to fill the initial values. Default: `'randn'` (std 0.02). */
  init?: InitSpec
  /**
   * Whether AdamW (when `weightDecay > 0`) applies decoupled weight decay to
   * this parameter. When omitted, decay is on for randn/kaiming/literal init
   * (weight matrices, embeddings) and off for zeros/ones (biases, LayerNorm
   * gains). Set it explicitly to force either way; this covers the common
   * case that `adam.decayFilter` used to handle.
   */
  decay?: boolean
}
|
|
61
78
|
|
|
@@ -65,31 +82,52 @@ function boxMuller(): number {
|
|
|
65
82
|
return Math.sqrt(-2 * Math.log(Math.max(1e-10, Math.random()))) * Math.cos(2 * Math.PI * Math.random())
|
|
66
83
|
}
|
|
67
84
|
|
|
68
|
-
function
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
return
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
85
|
+
/** Build an InitFn that fills `size` elements with `boxMuller() * scale`,
 *  i.e. samples from a normal distribution with std `scale`. */
function randnFn(scale: number): InitFn {
  return (size) => Float32Array.from({ length: size }, () => boxMuller() * scale)
}
|
|
92
|
+
|
|
93
|
+
/**
 * Compile-time-only: turn an `InitSpec` into the closure that produces the
 * initial Float32Array for a parameter of a given size/shape. Runs on the
 * main thread before initial values are transferred to the worker.
 *
 * @param spec Init description; `undefined` means the default `'randn'`
 *             (std 0.02).
 * @returns An InitFn `(size, shape) => Float32Array`.
 * @throws Error `Unknown init: …` when `spec` is not a recognized InitSpec
 *         (possible from JS callers or stale configs). For `literal` specs,
 *         the returned closure throws when called with a mismatched size.
 */
function resolveInit(spec: InitSpec | undefined): InitFn {
  if (!spec || spec === 'randn') return randnFn(0.02)
  if (spec === 'zeros') return (size) => new Float32Array(size)
  if (spec === 'ones') return (size) => new Float32Array(size).fill(1)
  switch (spec.kind) {
    case 'randn':
      return randnFn(spec.scale)
    case 'kaiming': {
      const gain = spec.gain ?? Math.sqrt(2)
      return (size, shape) => {
        // fan_in is the leading dim; rank-0 shapes fall back to the size.
        const fanIn = shape[0] ?? size
        const std = gain / Math.sqrt(fanIn)
        const arr = new Float32Array(size)
        for (let i = 0; i < size; i++) arr[i] = boxMuller() * std
        return arr
      }
    }
    case 'literal': {
      const data = spec.data
      return (size) => {
        if (data.length !== size) {
          throw new Error(`init.literal: data length ${data.length} doesn't match param size ${size}`)
        }
        // Copy so the caller's array is neither aliased nor detached when the
        // initial values are transferred to the worker.
        return new Float32Array(data)
      }
    }
    default: {
      // Unreachable for well-typed callers, but JS callers can still pass e.g.
      // 'normal' or { kind: 'xavier' }. Without this branch the function would
      // return undefined and fail later, far from the bad spec.
      const bad: unknown = spec
      const label = typeof bad === 'object' && bad !== null
        ? `{ kind: ${String((bad as { kind?: unknown }).kind)} }`
        : String(bad)
      throw new Error(`Unknown init: ${label}`)
    }
  }
}
|
|
83
123
|
|
|
84
|
-
/** Resolve the decay default for a param.
|
|
85
|
-
*
|
|
86
|
-
* (
|
|
87
|
-
* inits are weight-shaped (Kaiming etc.). Explicit `decay: false` overrides. */
|
|
124
|
+
/**
 * Decide whether AdamW weight decay applies to a parameter.
 *
 * An explicit `opts.decay` always wins. Otherwise `'zeros'`/`'ones'` inits
 * (biases, LayerNorm gains) opt out. Every other init — including the
 * default `'randn'` and the randn/kaiming/literal object variants — opts in.
 */
function resolveDecay(opts: ParamOptions | undefined): boolean {
  const explicit = opts?.decay
  if (explicit !== undefined) return explicit
  switch (opts?.init) {
    case 'zeros':
    case 'ones':
      return false
    default:
      return true
  }
}
|
|
94
132
|
|
|
95
133
|
// ============================================================================
|
|
@@ -127,7 +165,7 @@ export abstract class Module {
|
|
|
127
165
|
protected param(shape: Shape, opts?: ParamOptions): Tensor {
  // Record everything materialize needs: shape, dtype (default 'f32'), the
  // resolved init closure, and the weight-decay decision.
  const sentinel = new ParamSentinel(
    shape,
    opts?.dtype ?? 'f32',
    resolveInit(opts?.init),
    resolveDecay(opts),
  )
  // Lie to TypeScript: the sentinel becomes a Tensor at materialize time.
  return sentinel as unknown as Tensor
}
|
|
132
170
|
}
|
|
133
171
|
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
// Wire format for the main-thread ↔ worker postMessage channel.
|
|
2
|
+
//
|
|
3
|
+
// All requests carry a numeric `id` assigned by the main thread; responses
|
|
4
|
+
// echo it back so the proxy can match concurrent in-flight calls. Every
|
|
5
|
+
// response is either `{ ok: true, result }` or `{ ok: false, error }`.
|
|
6
|
+
// Errors carry serialized name/message/stack so the proxy can reconstitute
|
|
7
|
+
// an Error carrying the original `name` (match on `err.name`; `instanceof` checks against subclasses such as ShapeError do not survive the trip).
|
|
8
|
+
//
|
|
9
|
+
// Inputs (typed arrays) and outputs (typed arrays, captures) are transferred
|
|
10
|
+
// rather than copied — see the per-request notes for which fields go on the
|
|
11
|
+
// transfer list. A single worker may host multiple compiled graphs (a train
|
|
12
|
+
// graph plus sibling forward graphs); each has a `graphId` issued by the
|
|
13
|
+
// main thread at compile time.
|
|
14
|
+
|
|
15
|
+
import type { Graph } from './ir.js'
|
|
16
|
+
import type { BufferPlan } from './buffers.js'
|
|
17
|
+
import type { KernelSpec } from './codegen.js'
|
|
18
|
+
import type { LRSchedule } from './adam.js'
|
|
19
|
+
|
|
20
|
+
// ============================================================================
|
|
21
|
+
// Serializable config (subset of AdamResolvedConfig that crosses the wire).
|
|
22
|
+
// `decayFilter` (a function, used only at compile time) is NOT part of this —
|
|
23
|
+
// the per-param decay decision is already baked into the IR by appendAdam
|
|
24
|
+
// before the IR ships to the worker.
|
|
25
|
+
// ============================================================================
|
|
26
|
+
|
|
27
|
+
export interface WireAdamConfig {
  /** Learning-rate schedule; the worker evaluates it every step via `resolveLR`. */
  lr: LRSchedule
  /** First-moment decay rate (bias correction uses `1 - b1^t`). */
  b1: number
  /** Second-moment decay rate (bias correction uses `sqrt(1 - b2^t)`). */
  b2: number
  /** Adam epsilon. NOTE(review): not read by the worker in this version. */
  eps: number
  /** Decoupled weight-decay coefficient; per step the worker feeds
   *  `1 - lr_t * weightDecay` as the decay-shrink scalar. */
  weightDecay: number
  /** Whether `lr` varies per step. NOTE(review): not read by the worker here. */
  lrIsScheduled: boolean
  /**
   * Graph input names of the per-step scalars the worker fills before each
   * step: the bias-corrected learning rate (`_adam_lrt`), and the optional
   * decay-shrink factor (`_adam_decay_shrink`, `null` when absent). Mirrors
   * AdamResult, so the worker never has to re-derive them.
   */
  lrtInputName: string
  decayShrinkInputName: string | null
}
|
|
40
|
+
|
|
41
|
+
/**
 * Compiled program as shipped to the worker: IR graph, buffer plan and
 * generated kernels. Mirrors CompiledIR minus the `loss` tensor; the worker
 * takes its output from `graph.outputs[0]` instead.
 */
export interface WireIR {
  /** IR graph; `outputs[0]` is the tensor the runtime reads back. */
  graph: Graph
  /** Buffer plan (tensor → buffer map, params and captures by name). */
  plan: BufferPlan
  /** Kernel specs generated for the plan. */
  kernels: KernelSpec[]
}
|
|
48
|
+
|
|
49
|
+
// ============================================================================
|
|
50
|
+
// Requests (main → worker)
|
|
51
|
+
// ============================================================================
|
|
52
|
+
|
|
53
|
+
/** Payload type for each request `kind`, in one table. */
interface ReqPayloads {
  createRuntime: CreateRuntimePayload
  compileForward: CompileForwardPayload
  step: StepPayload
  run: RunPayload
  uploadParams: UploadParamsPayload
  downloadParams: { graphId: number }
  downloadParamGrads: { graphId: number }
  resetOptimizer: { graphId: number }
  destroy: { graphId: number }
}

/** Any main → worker request: one `{ id, kind, payload }` variant per entry of
 *  `ReqPayloads`, so switching on `kind` narrows `payload`. */
export type Req = {
  [K in keyof ReqPayloads]: { id: number; kind: K; payload: ReqPayloads[K] }
}[keyof ReqPayloads]
|
|
63
|
+
|
|
64
|
+
/**
 * `createRuntime` request body: builds a graph that owns its parameters
 * (normally the training graph). For a fresh worker this is graphId 0.
 */
export interface CreateRuntimePayload {
  /** Id the main thread assigns to the new graph. */
  graphId: number
  /** Compiled IR to instantiate. */
  ir: WireIR
  /** Initial values keyed by param name. Their buffers are transferred, so
   *  the main thread can no longer read them after postMessage. */
  initialParams: Record<string, Float32Array>
  /** Adam config when training; `null` for forward-only compiles. */
  adam: WireAdamConfig | null
}
|
|
74
|
+
|
|
75
|
+
/**
 * `compileForward` request body: builds a forward-only sibling graph that
 * reuses the parameter GPUBuffers of `parentGraphId` (typically the training
 * graph, id 0) instead of uploading its own.
 */
export interface CompileForwardPayload {
  /** Id for the new sibling graph. */
  graphId: number
  /** Existing graph whose params are shared. */
  parentGraphId: number
  /** Compiled forward IR. */
  ir: WireIR
}
|
|
82
|
+
|
|
83
|
+
/**
 * `step` request body: one training step. Input buffers are transferred, so
 * the caller's typed arrays are detached after postMessage.
 */
export interface StepPayload {
  /** Training graph to step. */
  graphId: number
  /** Graph inputs keyed by name. */
  inputs: Record<string, Int32Array | Float32Array>
  /** Also read back and return the graph's captures. */
  withCaptures: boolean
}
|
|
90
|
+
|
|
91
|
+
/** `run` request body: a forward-only pass. Input buffers are transferred
 *  (detached on the main thread), exactly as for `step`. */
export interface RunPayload {
  /** Graph to run. */
  graphId: number
  /** Graph inputs keyed by name. */
  inputs: Record<string, Int32Array | Float32Array>
  /** Also read back and return the graph's captures. */
  withCaptures: boolean
}
|
|
97
|
+
|
|
98
|
+
/** `uploadParams` request body: overwrite parameter values on the GPU. */
export interface UploadParamsPayload {
  /** Graph whose params are written. */
  graphId: number
  /** New values keyed by param name; buffers are transferred. */
  params: Record<string, Float32Array>
  /** Forwarded as `{ partial }` to the runtime's uploadParams.
   *  NOTE(review): presumably permits uploading only a subset — confirm. */
  partial: boolean
}
|
|
103
|
+
|
|
104
|
+
// ============================================================================
|
|
105
|
+
// Responses (worker → main)
|
|
106
|
+
// ============================================================================
|
|
107
|
+
|
|
108
|
+
/** Reply to request `id`: the handler's result on success, or a serialized
 *  error. Discriminate on `ok`. */
export type Res<R = unknown> =
  | { id: number; ok: true; result: R }   // handler resolved
  | { id: number; ok: false; error: WireError } // handler threw
|
|
111
|
+
|
|
112
|
+
/** Structured-clone-safe snapshot of a thrown error. */
export interface WireError {
  /** Error name (e.g. `'ShapeError'`), for matching on the receiving side. */
  name: string
  /** Error message. */
  message: string
  /** Worker-side stack trace, or `''` when none was available. */
  stack: string
}
|
|
117
|
+
|
|
118
|
+
// Per-request result shapes:
|
|
119
|
+
|
|
120
|
+
/** Metadata describing a freshly built training/standalone graph. */
export interface CreateRuntimeResult {
  /** Parameter names from the buffer plan. */
  paramNames: string[]
  /** Shape of the graph's output. */
  outputShape: number[]
  /** Number of kernels that carry WGSL source. */
  kernelCount: number
  /** Shape of every named capture. */
  captureShapes: Record<string, number[]>
}
|
|
126
|
+
|
|
127
|
+
/** Metadata describing a freshly built forward-only sibling graph (same
 *  fields as the createRuntime result). */
export interface CompileForwardResult {
  /** Parameter names from the buffer plan (shared with the parent graph). */
  paramNames: string[]
  /** Shape of the graph's output. */
  outputShape: number[]
  /** Number of kernels that carry WGSL source. */
  kernelCount: number
  /** Shape of every named capture. */
  captureShapes: Record<string, number[]>
}
|
|
133
|
+
|
|
134
|
+
/** Reply to `step`: the scalar loss, plus captures (buffers transferred back)
 *  when `withCaptures` was set — otherwise `captures` is `null`. */
export interface StepResultWire {
  /** Loss value of this step. */
  loss: number
  /** Capture values by name, or `null` when not requested. */
  captures: Record<string, Float32Array> | null
}
|
|
140
|
+
|
|
141
|
+
/** Reply to `run`: the output values, plus captures when `withCaptures` was
 *  set (otherwise `captures` is `null`). */
export interface RunResultWire {
  /** Graph output, flattened. */
  output: Float32Array
  /** Capture values by name, or `null` when not requested. */
  captures: Record<string, Float32Array> | null
}
|
|
147
|
+
|
|
148
|
+
/** Reply to `downloadParams` / `downloadParamGrads`. */
export interface DownloadParamsResult {
  /** Values keyed by param name; buffers are transferred to the main thread. */
  params: Record<string, Float32Array>
}
|
|
151
|
+
|
|
152
|
+
// ============================================================================
|
|
153
|
+
// Transfer-list helpers
|
|
154
|
+
// ============================================================================
|
|
155
|
+
|
|
156
|
+
/**
 * Collect the distinct ArrayBuffers backing a record of typed arrays, for use
 * as a `postMessage` transfer list. The values stay in the record; only their
 * backing buffers move (and are detached on the sending side).
 *
 * - Each buffer is listed once, in first-seen order. Several views may share
 *   one buffer, and a transfer list naming the same buffer twice makes
 *   `postMessage` throw a `DataCloneError`.
 * - Backings that are not plain ArrayBuffers (e.g. SharedArrayBuffer) are
 *   skipped: they cannot be transferred, and they are shared rather than
 *   copied anyway.
 */
export function transferablesOfRecord(
  rec: Record<string, Int32Array | Float32Array>,
): ArrayBuffer[] {
  const unique = new Set<ArrayBuffer>()
  for (const view of Object.values(rec)) {
    const buf = view.buffer
    if (buf instanceof ArrayBuffer) unique.add(buf)
  }
  return [...unique]
}
|
|
166
|
+
|
|
167
|
+
/**
 * Serialize a thrown value into a structured-clone-safe `WireError`.
 *
 * - `Error` instances keep their name, message and stack.
 * - Error-like objects with a string `message` that do not inherit from Error
 *   are kept too. The main case is WebGPU's GPUValidationError /
 *   GPUOutOfMemoryError, which would otherwise collapse to "[object …]". Their
 *   name comes from `name` if present, else the constructor name.
 * - Anything else becomes `{ name: 'Error', message: String(e), stack: '' }`.
 *
 * Only the name survives the trip; `instanceof` against subclasses (e.g.
 * ShapeError) does not.
 */
export function wireError(e: unknown): WireError {
  if (e instanceof Error) {
    return { name: e.name, message: e.message, stack: e.stack ?? '' }
  }
  if (typeof e === 'object' && e !== null && typeof (e as { message?: unknown }).message === 'string') {
    const errLike = e as { name?: unknown; message: string; stack?: unknown; constructor?: { name?: unknown } }
    const ctorName = errLike.constructor?.name
    let name = 'Error'
    if (typeof errLike.name === 'string') {
      name = errLike.name
    } else if (typeof ctorName === 'string' && ctorName !== '' && ctorName !== 'Object') {
      name = ctorName
    }
    return {
      name,
      message: errLike.message,
      stack: typeof errLike.stack === 'string' ? errLike.stack : '',
    }
  }
  return { name: 'Error', message: String(e), stack: '' }
}
|
|
176
|
+
|
|
177
|
+
/**
 * Rebuild an `Error` on the receiving (main) side from its wire shape.
 *
 * The result is always a plain `Error` carrying the original `name` and
 * `message`; callers match on `err.name`. The worker-side stack replaces the
 * local one when present. When the wire stack is empty (the worker threw a
 * non-Error), the locally generated stack is kept so some location
 * information survives.
 */
export function reconstituteError(w: WireError): Error {
  const err = new Error(w.message)
  err.name = w.name
  if (w.stack) err.stack = w.stack
  return err
}
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
// Main-thread half of the worker channel: request/response correlation,
|
|
2
|
+
// promise wiring, error reconstitution. Knows nothing about Adam, captures,
|
|
3
|
+
// IR, etc. — just shuttles typed messages.
|
|
4
|
+
|
|
5
|
+
import type { Req, Res, WireError } from './worker-protocol.js'
|
|
6
|
+
import { reconstituteError } from './worker-protocol.js'
|
|
7
|
+
|
|
8
|
+
/** Settlement callbacks for one in-flight request, stored by request id. */
interface PendingHandlers {
  /** Fulfil the caller's promise with the reply's `result`. */
  resolve: (value: unknown) => void
  /** Reject the caller's promise (worker error, termination, …). */
  reject: (error: Error) => void
}
|
|
12
|
+
|
|
13
|
+
/**
 * Spawns a module Worker from an inlined source string and provides a typed
 * request/response channel over postMessage. One WorkerProxy = one Worker =
 * one GPUDevice on the worker side; sibling graphs share a WorkerProxy.
 *
 * Replies are matched to requests by a numeric `id`, so several requests may
 * be in flight at once. Failure modes:
 *   - a worker-side exception rejects only the matching request;
 *   - a synchronous postMessage failure (e.g. DataCloneError) rejects that
 *     request and forgets it, since no reply will ever arrive;
 *   - a worker `error` event rejects everything in flight and marks the
 *     channel broken, so later `request` calls reject immediately instead of
 *     waiting on a worker that may never answer;
 *   - `terminate()` rejects everything in flight and every later request.
 *
 * NOTE(review): `Omit<Req, 'id'>` does not distribute over the `Req` union,
 * so the compiler does not check that `payload` matches `kind` here.
 */
export class WorkerProxy {
  private worker: Worker
  private nextId = 1
  private pending = new Map<number, PendingHandlers>()
  private terminated = false
  /** Set by the worker `error` handler; later requests reject with it. */
  private broken: WireError | null = null

  constructor(workerSource: string) {
    const blob = new Blob([workerSource], { type: 'application/javascript' })
    const url = URL.createObjectURL(blob)
    this.worker = new Worker(url, { type: 'module' })
    // The Blob URL keeps the source alive as long as it's registered. The
    // constructor has already parsed the URL, so revoking right away is safe
    // and releases that memory.
    URL.revokeObjectURL(url)

    this.worker.onmessage = (ev: MessageEvent<Res>) => {
      const reply = ev.data
      const handlers = this.pending.get(reply.id)
      if (!handlers) return // stale reply (request already rejected); ignore
      this.pending.delete(reply.id)
      if (reply.ok) handlers.resolve(reply.result)
      else handlers.reject(reconstituteError(reply.error))
    }

    this.worker.onerror = (ev: ErrorEvent) => {
      const err = new Error(`tensorgrad worker error: ${ev.message || 'unknown'}`)
      const wire: WireError = { name: 'WorkerError', message: err.message, stack: err.stack ?? '' }
      // Reject everything in flight and make subsequent calls fail too.
      this.broken = wire
      this.rejectAll(() => reconstituteError(wire))
    }
  }

  /** Send a request and await its matching response. `transfer` lists the
   *  ArrayBuffers to move (zero-copy) into the worker. */
  request<R>(req: Omit<Req, 'id'>, transfer: ArrayBuffer[] = []): Promise<R> {
    if (this.terminated) return Promise.reject(new Error('tensorgrad: worker has been terminated'))
    if (this.broken) return Promise.reject(reconstituteError(this.broken))
    const id = this.nextId++
    return new Promise<R>((resolve, reject) => {
      this.pending.set(id, { resolve: resolve as (v: unknown) => void, reject })
      try {
        this.worker.postMessage({ ...req, id } as Req, transfer)
      } catch (e) {
        // The message never left, so no reply will come for this id: drop the
        // entry (it would otherwise leak) and surface the clone error.
        this.pending.delete(id)
        reject(e)
      }
    })
  }

  /** Fire-and-forget variant for cases where the caller doesn't need a reply
   *  (currently unused; keep for symmetry / future use). */
  send(req: Omit<Req, 'id'>, transfer: ArrayBuffer[] = []): void {
    if (this.terminated) return
    const id = this.nextId++
    this.worker.postMessage({ ...req, id } as Req, transfer)
  }

  /** Kill the worker and reject every in-flight request. Idempotent. */
  terminate(): void {
    if (this.terminated) return
    this.terminated = true
    this.worker.terminate()
    const err = new Error('tensorgrad: worker terminated')
    this.rejectAll(() => err)
  }

  /** Reject and forget every in-flight request. */
  private rejectAll(makeError: () => Error): void {
    const handlers = [...this.pending.values()]
    this.pending.clear()
    for (const h of handlers) h.reject(makeError())
  }
}
|
package/src/worker.ts
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
// Worker entry point. Holds the GPUDevice + CompiledRuntime for one or more
|
|
2
|
+
// graphs and proxies main-thread requests via postMessage. See
|
|
3
|
+
// specs/WorkerArchitecture.md for the rationale.
|
|
4
|
+
//
|
|
5
|
+
// Keep this file dependency-free of anything DOM-y: it bundles into a Blob
|
|
6
|
+
// URL and runs in a Web Worker context where `window`/`document` don't
|
|
7
|
+
// exist. WebGPU IS available in workers (Chrome 113+, Safari 17.4+).
|
|
8
|
+
|
|
9
|
+
import { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js'
|
|
10
|
+
import { resolveLR, type LRSchedule } from './adam.js'
|
|
11
|
+
import type { Req, Res, WireIR, WireAdamConfig, WireError } from './worker-protocol.js'
|
|
12
|
+
import { wireError } from './worker-protocol.js'
|
|
13
|
+
|
|
14
|
+
// ----------------------------------------------------------------------------
|
|
15
|
+
// Per-graph state
|
|
16
|
+
// ----------------------------------------------------------------------------
|
|
17
|
+
|
|
18
|
+
/** Everything the worker keeps for one compiled graph. */
interface GraphSlot {
  /** Runtime that executes the graph on the worker's GPUDevice. */
  runtime: CompiledRuntime
  /** Parameter names from the buffer plan. */
  paramNames: readonly string[]
  /** Shape of the graph's output buffer. */
  outputShape: number[]
  /** Number of kernels that carry WGSL source. */
  kernelCount: number
  /** Shape of each named capture buffer. */
  captureShapes: Record<string, number[]>
  /** Adam bookkeeping for a training graph (`null` otherwise); each step uses
   *  it to fill the lrt / decayShrink scalar inputs. */
  adam: AdamState | null
}
|
|
28
|
+
|
|
29
|
+
/** Mutable Adam bookkeeping for one training graph. */
interface AdamState {
  /** Optimizer config received with createRuntime. */
  config: WireAdamConfig
  /** Steps taken so far (0 before the first); drives bias correction. */
  t: number
  /** Reused 1-element buffer fed as the `lrtInputName` input. */
  lrtBuf: Float32Array
  /** Reused 1-element buffer for the decay-shrink input, or `null` when the
   *  config names no such input. */
  decayShrinkBuf: Float32Array | null
}
|
|
35
|
+
|
|
36
|
+
/** Graphs hosted by this worker, keyed by the main-thread-assigned graphId. */
const graphs: Map<number, GraphSlot> = new Map()
|
|
37
|
+
|
|
38
|
+
// Worker holds one device shared across all graphs (sibling forward graphs
// must share param GPUBuffers, which means sharing a device).
let device: GPUDevice | null = null
/** In-flight acquisition, so concurrent callers share one device request. */
let devicePromise: Promise<GPUDevice> | null = null

/**
 * Return the worker's GPUDevice, requesting it on first use.
 *
 * Message handlers are async, so while one awaits the next message can start.
 * Caching only the resolved device would let two early callers both see
 * `device === null` and create two devices. Graphs on different devices
 * cannot share param buffers. Caching the pending promise closes that window.
 * A failed acquisition is not cached, so a later call retries.
 *
 * @throws Error if WebGPU is unavailable or no adapter is found.
 */
async function ensureDevice(): Promise<GPUDevice> {
  if (device) return device
  if (!devicePromise) {
    devicePromise = acquireDevice().then(
      (d) => {
        device = d
        return d
      },
      (e: unknown) => {
        devicePromise = null
        throw e
      },
    )
  }
  return devicePromise
}

/** Request an adapter and then a device from WebGPU (no caching). */
async function acquireDevice(): Promise<GPUDevice> {
  if (typeof navigator === 'undefined' || !navigator.gpu) {
    throw new Error('tensorgrad worker: WebGPU not available in this environment')
  }
  const adapter = await navigator.gpu.requestAdapter()
  if (!adapter) throw new Error('tensorgrad worker: no WebGPU adapter')
  return adapter.requestDevice()
}
|
|
52
|
+
|
|
53
|
+
// ----------------------------------------------------------------------------
|
|
54
|
+
// Request handlers
|
|
55
|
+
// ----------------------------------------------------------------------------
|
|
56
|
+
|
|
57
|
+
/** Graph metadata returned to the main thread by createRuntime/compileForward. */
type GraphInfo = {
  paramNames: string[]
  outputShape: number[]
  kernelCount: number
  captureShapes: Record<string, number[]>
}

/**
 * Build a graph that owns its params (normally the training graph), upload
 * its initial param values, and register it under `payload.graphId`.
 *
 * @throws Error if the IR has no output or the output tensor has no buffer.
 *         If uploading the initial params (or registering) fails, the new
 *         runtime is destroyed before rethrowing so its GPU resources are
 *         not leaked.
 */
async function handleCreateRuntime(payload: {
  graphId: number
  ir: WireIR
  initialParams: Record<string, Float32Array>
  adam: WireAdamConfig | null
}): Promise<GraphInfo> {
  const dev = await ensureDevice()
  const { plan, kernels } = payload.ir
  const outputBufferId = outputBufferIdOf(payload.ir)
  const opts: RuntimeOpts = { device: dev }
  const runtime = await createRuntime(plan, kernels, outputBufferId, opts)
  try {
    if (Object.keys(payload.initialParams).length > 0) {
      runtime.uploadParams(payload.initialParams)
    }
    return registerGraph(payload.graphId, runtime, payload.ir, payload.adam ? createAdamState(payload.adam) : null)
  } catch (e) {
    runtime.destroy()
    throw e
  }
}

/**
 * Build a forward-only sibling graph that reuses the param buffers of
 * `payload.parentGraphId`, and register it under `payload.graphId`.
 *
 * @throws Error if the parent graph is unknown, the IR has no output, or the
 *         output tensor has no buffer.
 */
async function handleCompileForward(payload: {
  graphId: number
  parentGraphId: number
  ir: WireIR
}): Promise<GraphInfo> {
  const dev = await ensureDevice()
  const parent = graphs.get(payload.parentGraphId)
  if (!parent) throw new Error(`compileForward: parent graph ${payload.parentGraphId} not found`)

  const { plan, kernels } = payload.ir
  const outputBufferId = outputBufferIdOf(payload.ir)
  const opts: RuntimeOpts = { device: dev, sharedParams: parent.runtime.params }
  const runtime = await createRuntime(plan, kernels, outputBufferId, opts)
  // No initial-param upload — sharedParams covers everything.
  return registerGraph(payload.graphId, runtime, payload.ir, null)
}

/** Buffer id of the graph's first output tensor, with explicit errors instead
 *  of passing `undefined` on to createRuntime. */
function outputBufferIdOf(ir: WireIR) {
  const outputTensorId = ir.graph.outputs[0]
  if (outputTensorId === undefined) {
    throw new Error('tensorgrad worker: graph has no outputs')
  }
  const outputBufferId = ir.plan.tensorToBuffer.get(outputTensorId)
  if (outputBufferId === undefined) {
    throw new Error(`tensorgrad worker: no buffer planned for output tensor ${outputTensorId}`)
  }
  return outputBufferId
}

/**
 * Store `runtime` as graph `graphId` and describe it for the main thread.
 * NOTE(review): an existing slot with the same id is replaced without
 * destroying its runtime; ids are assigned by the main thread.
 */
function registerGraph(graphId: number, runtime: CompiledRuntime, ir: WireIR, adam: AdamState | null): GraphInfo {
  const { plan, kernels } = ir
  const captureShapes: Record<string, number[]> = {}
  for (const [name, bufId] of plan.capturesByName) {
    captureShapes[name] = [...plan.buffers[bufId]!.shape]
  }

  const slot: GraphSlot = {
    runtime,
    paramNames: [...plan.paramsByName.keys()],
    outputShape: [...runtime.outputShape],
    kernelCount: kernels.filter(k => k.wgsl).length,
    captureShapes,
    adam,
  }
  graphs.set(graphId, slot)

  return {
    paramNames: [...slot.paramNames],
    outputShape: slot.outputShape,
    kernelCount: slot.kernelCount,
    captureShapes: slot.captureShapes,
  }
}
|
|
137
|
+
|
|
138
|
+
/** Fresh Adam bookkeeping: step counter at 0, one reusable scalar buffer for
 *  lrt, and one for decay-shrink only when the config names that input. */
function createAdamState(cfg: WireAdamConfig): AdamState {
  const decayShrinkBuf = cfg.decayShrinkInputName ? new Float32Array(1) : null
  const state: AdamState = { config: cfg, t: 0, lrtBuf: new Float32Array(1), decayShrinkBuf }
  return state
}
|
|
146
|
+
|
|
147
|
+
/**
 * Advance the Adam step counter and return `inputs` plus the per-step scalars:
 *   - `lrtInputName` ← `lr_t * sqrt(1 - b2^t) / (1 - b1^t)` (bias-corrected lr)
 *   - `decayShrinkInputName` ← `1 - lr_t * weightDecay`, when configured.
 * Graphs without Adam state get `inputs` back as-is (same object). `inputs`
 * is never mutated; the scalar buffers are reused across steps.
 */
function injectAdamScalars(slot: GraphSlot, inputs: Record<string, Int32Array | Float32Array>): Record<string, Int32Array | Float32Array> {
  const adam = slot.adam
  if (!adam) return inputs

  const { config } = adam
  adam.t += 1
  const step = adam.t
  const schedule: LRSchedule = config.lr
  const lrNow = resolveLR(schedule, step)
  adam.lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(config.b2, step)) / (1 - Math.pow(config.b1, step))

  const scalars: Record<string, Float32Array> = { [config.lrtInputName]: adam.lrtBuf }
  const shrinkName = config.decayShrinkInputName
  if (adam.decayShrinkBuf && shrinkName) {
    adam.decayShrinkBuf[0] = 1 - lrNow * config.weightDecay
    scalars[shrinkName] = adam.decayShrinkBuf
  }
  return { ...inputs, ...scalars }
}
|
|
163
|
+
|
|
164
|
+
/** Run one training step on `payload.graphId` after injecting Adam's per-step
 *  scalars. Captures are read back only when `withCaptures` is set. */
async function handleStep(payload: {
  graphId: number
  inputs: Record<string, Int32Array | Float32Array>
  withCaptures: boolean
}): Promise<{ loss: number; captures: Record<string, Float32Array> | null }> {
  const slot = mustGet(payload.graphId)
  const stepInputs = injectAdamScalars(slot, payload.inputs)
  if (!payload.withCaptures) {
    return { loss: await slot.runtime.step(stepInputs), captures: null }
  }
  const { loss, captures } = await slot.runtime.step(stepInputs, { withCaptures: true })
  return { loss, captures: capturesToRecord(captures, slot.captureShapes) }
}
|
|
178
|
+
|
|
179
|
+
/** Forward-only pass on `payload.graphId`. No optimizer scalars are involved;
 *  captures are read back only when `withCaptures` is set. */
async function handleRun(payload: {
  graphId: number
  inputs: Record<string, Int32Array | Float32Array>
  withCaptures: boolean
}): Promise<{ output: Float32Array; captures: Record<string, Float32Array> | null }> {
  const slot = mustGet(payload.graphId)
  if (!payload.withCaptures) {
    return { output: await slot.runtime.run(payload.inputs), captures: null }
  }
  const { output, captures } = await slot.runtime.run(payload.inputs, { withCaptures: true })
  return { output, captures: capturesToRecord(captures, slot.captureShapes) }
}
|
|
192
|
+
|
|
193
|
+
/**
 * Flatten the runtime's Captures object (a class instance) into a plain record
 * so its Float32Arrays can be posted, and transferred, back to the main thread.
 *
 * The graph's known capture names (`shapes` keys) drive the iteration; a name
 * is copied only when `captures.has(name)` reports it. `captures.names()` is
 * not consulted.
 */
function capturesToRecord(
  captures: { get(name: string): Float32Array; has(name: string): boolean; names(): string[] },
  shapes: Record<string, number[]>,
): Record<string, Float32Array> {
  const record: Record<string, Float32Array> = {}
  for (const captureName of Object.keys(shapes)) {
    if (!captures.has(captureName)) continue
    record[captureName] = captures.get(captureName)
  }
  return record
}
|
|
207
|
+
|
|
208
|
+
/** Write new parameter values into graph `payload.graphId`'s runtime,
 *  forwarding the `partial` flag unchanged. */
function handleUploadParams(payload: {
  graphId: number
  params: Record<string, Float32Array>
  partial: boolean
}): void {
  const { graphId, params, partial } = payload
  mustGet(graphId).runtime.uploadParams(params, { partial })
}
|
|
216
|
+
|
|
217
|
+
/** Read every parameter of graph `payload.graphId` back from the GPU. */
async function handleDownloadParams(payload: { graphId: number }): Promise<{ params: Record<string, Float32Array> }> {
  const { runtime } = mustGet(payload.graphId)
  const params = await runtime.downloadParams()
  return { params }
}
|
|
221
|
+
|
|
222
|
+
/** Read the parameter gradients of graph `payload.graphId` back from the GPU;
 *  returned under `params`, same shape as downloadParams. */
async function handleDownloadParamGrads(payload: { graphId: number }): Promise<{ params: Record<string, Float32Array> }> {
  const { runtime } = mustGet(payload.graphId)
  const grads = await runtime.downloadParamGrads()
  return { params: grads }
}
|
|
226
|
+
|
|
227
|
+
/** Reset the runtime's optimizer state and restart the Adam step counter, so
 *  the next step's bias correction starts again from t = 1. */
function handleResetOptimizer(payload: { graphId: number }): void {
  const slot = mustGet(payload.graphId)
  slot.runtime.resetOptimizerState()
  const adam = slot.adam
  if (adam) adam.t = 0
}
|
|
232
|
+
|
|
233
|
+
/**
 * Destroy graph `payload.graphId`'s runtime and forget the slot. Unknown ids
 * are ignored, so repeated destroys are harmless.
 * NOTE(review): destroying a parent while forward siblings still share its
 * params is not guarded here — confirm the runtime handles that.
 */
function handleDestroy(payload: { graphId: number }): void {
  const { graphId } = payload
  const slot = graphs.get(graphId)
  if (slot === undefined) return
  slot.runtime.destroy()
  graphs.delete(graphId)
}
|
|
239
|
+
|
|
240
|
+
/** Look up a registered graph, throwing a descriptive error if it is absent. */
function mustGet(graphId: number): GraphSlot {
  const found = graphs.get(graphId)
  if (found !== undefined) return found
  throw new Error(`tensorgrad worker: graph ${graphId} not found`)
}
|
|
245
|
+
|
|
246
|
+
// ----------------------------------------------------------------------------
|
|
247
|
+
// Message dispatch
|
|
248
|
+
// ----------------------------------------------------------------------------
|
|
249
|
+
|
|
250
|
+
/**
 * Request dispatcher. Every message carries an `id` that is echoed on the
 * reply. Handlers are awaited, so while one awaits the next message may start.
 * Any error — including a failed postMessage of the result — is caught and
 * sent back as `{ ok: false, error }` so the main thread can reject the
 * matching promise.
 *
 * Result buffers (run output, captures, downloaded params) are transferred,
 * not copied. The transfer list is de-duplicated because views sharing one
 * ArrayBuffer would otherwise make postMessage throw a DataCloneError.
 * Non-ArrayBuffer backings (e.g. SharedArrayBuffer), which cannot be
 * transferred, are dropped from the list.
 */
self.onmessage = async (ev: MessageEvent<Req>) => {
  const req = ev.data
  try {
    let result: unknown = null
    let transferList: ArrayBuffer[] = []
    switch (req.kind) {
      case 'createRuntime':
        result = await handleCreateRuntime(req.payload)
        break
      case 'compileForward':
        result = await handleCompileForward(req.payload)
        break
      case 'step': {
        const r = await handleStep(req.payload)
        result = r
        transferList = collectTransfers(r.captures)
        break
      }
      case 'run': {
        const r = await handleRun(req.payload)
        result = r
        transferList = [r.output.buffer as ArrayBuffer, ...collectTransfers(r.captures)]
        break
      }
      case 'uploadParams':
        handleUploadParams(req.payload)
        break
      case 'downloadParams': {
        const r = await handleDownloadParams(req.payload)
        result = r
        transferList = collectTransfers(r.params)
        break
      }
      case 'downloadParamGrads': {
        const r = await handleDownloadParamGrads(req.payload)
        result = r
        transferList = collectTransfers(r.params)
        break
      }
      case 'resetOptimizer':
        handleResetOptimizer(req.payload)
        break
      case 'destroy':
        handleDestroy(req.payload)
        break
      default:
        throw new Error(`unknown request kind: ${(req as { kind: string }).kind}`)
    }
    const transfer = [...new Set(transferList)].filter((buf) => buf instanceof ArrayBuffer)
    const reply: Res = { id: req.id, ok: true, result }
    self.postMessage(reply, { transfer })
  } catch (e) {
    const error: WireError = wireError(e)
    const reply: Res = { id: req.id, ok: false, error }
    self.postMessage(reply)
  }
}
|
|
275
|
+
|
|
276
|
+
/**
 * Distinct ArrayBuffers backing a record of Float32Arrays, for a postMessage
 * transfer list. `null`/`undefined` (e.g. a step without captures) yields [].
 * Buffers shared by several views are listed once, in first-seen order: a
 * duplicate entry makes postMessage throw a DataCloneError. Non-ArrayBuffer
 * backings (SharedArrayBuffer) are skipped because they cannot be transferred.
 */
function collectTransfers(rec: Record<string, Float32Array> | null | undefined): ArrayBuffer[] {
  if (!rec) return []
  const unique = new Set<ArrayBuffer>()
  for (const view of Object.values(rec)) {
    const buf = view.buffer
    if (buf instanceof ArrayBuffer) unique.add(buf)
  }
  return [...unique]
}
|