tensorgrad 0.0.15 → 0.0.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.d.ts +154 -193
- package/dist/index.js +2208 -39
- package/dist/index.js.map +7 -1
- package/dist/worker.debug.js +553 -0
- package/package.json +60 -58
- package/src/adam.ts +69 -15
- package/src/compile.ts +334 -156
- package/src/index.ts +8 -4
- package/src/module.ts +72 -34
- package/src/worker-protocol.ts +183 -0
- package/src/worker-proxy.ts +76 -0
- package/src/worker.ts +281 -0
- package/dist/adam.js +0 -111
- package/dist/adam.js.map +0 -1
- package/dist/buffers.js +0 -120
- package/dist/buffers.js.map +0 -1
- package/dist/capture.js +0 -33
- package/dist/capture.js.map +0 -1
- package/dist/codegen.js +0 -724
- package/dist/codegen.js.map +0 -1
- package/dist/compile.js +0 -184
- package/dist/compile.js.map +0 -1
- package/dist/grad.js +0 -380
- package/dist/grad.js.map +0 -1
- package/dist/ir.js +0 -60
- package/dist/ir.js.map +0 -1
- package/dist/module.js +0 -155
- package/dist/module.js.map +0 -1
- package/dist/nn.js +0 -135
- package/dist/nn.js.map +0 -1
- package/dist/ops.js +0 -326
- package/dist/ops.js.map +0 -1
- package/dist/runtime.js +0 -402
- package/dist/runtime.js.map +0 -1
- package/dist/shape.js +0 -259
- package/dist/shape.js.map +0 -1
- package/dist/trace.js +0 -100
- package/dist/trace.js.map +0 -1
package/src/compile.ts
CHANGED
|
@@ -7,15 +7,43 @@
|
|
|
7
7
|
// Module tree; the library auto-discovers
|
|
8
8
|
// params, traces the forward, appends grad
|
|
9
9
|
// and Adam, and returns a runtime.
|
|
10
|
+
//
|
|
11
|
+
// As of the worker-architecture refactor: compile-time work (trace, autograd,
|
|
12
|
+
// buffer planning, codegen) runs on the main thread. createRuntime and all
|
|
13
|
+
// dispatch/mapAsync work runs in a Web Worker spawned per top-level compile;
|
|
14
|
+
// the returned `CompiledModule` is a thin proxy over the worker channel.
|
|
15
|
+
// See specs/WorkerArchitecture.md.
|
|
10
16
|
|
|
11
17
|
import type { Tensor, Shape, Dtype } from './ir.js'
|
|
12
18
|
import { trace, tensorInput } from './trace.js'
|
|
13
19
|
import { appendGrad, type GradResult } from './grad.js'
|
|
14
|
-
import {
|
|
20
|
+
import {
|
|
21
|
+
appendAdam, resolveLR,
|
|
22
|
+
type AdamConfig, type AdamResult, type AdamResolvedConfig,
|
|
23
|
+
} from './adam.js'
|
|
15
24
|
import { planBuffers, type BufferPlan } from './buffers.js'
|
|
16
25
|
import { emitKernels, type KernelSpec } from './codegen.js'
|
|
17
|
-
import {
|
|
26
|
+
import {
|
|
27
|
+
Captures, type RunResult, type StepResult, type RunOptions, type UploadParamsOptions,
|
|
28
|
+
} from './runtime.js'
|
|
18
29
|
import { Module, materializeParams, type MaterializedParams } from './module.js'
|
|
30
|
+
import { WorkerProxy } from './worker-proxy.js'
|
|
31
|
+
import {
|
|
32
|
+
transferablesOfRecord,
|
|
33
|
+
type Req, type WireIR, type WireAdamConfig,
|
|
34
|
+
type CreateRuntimeResult, type CompileForwardResult,
|
|
35
|
+
type StepResultWire, type RunResultWire, type DownloadParamsResult,
|
|
36
|
+
} from './worker-protocol.js'
|
|
37
|
+
|
|
38
|
+
// `__WORKER_SOURCE__` is replaced at build time by scripts/build.mjs with the
|
|
39
|
+
// stringified contents of the bundled src/worker.ts. Declared here so TS is
|
|
40
|
+
// happy; substituted as a string literal by esbuild's `define` during
|
|
41
|
+
// `npm run build:js`. See scripts/build.mjs.
|
|
42
|
+
declare const __WORKER_SOURCE__: string
|
|
43
|
+
|
|
44
|
+
// ============================================================================
|
|
45
|
+
// Public types
|
|
46
|
+
// ============================================================================
|
|
19
47
|
|
|
20
48
|
/** Declares one input tensor of the model's forward function. The name is the
|
|
21
49
|
* key in the `inputs:` Record at compile time and the key on the `step()`/
|
|
@@ -25,21 +53,14 @@ export interface InputDecl {
|
|
|
25
53
|
dtype?: Dtype
|
|
26
54
|
}
|
|
27
55
|
|
|
28
|
-
/** Inputs declaration: a Record from input name to its shape/dtype.
 *
 * Each key is used in two places: it is the property the forward function
 * destructures from its second argument, and it is the key callers use when
 * passing typed arrays to `step({...})` / `run({...})`. */
export type InputDecls = Record<string, InputDecl>
|
|
32
58
|
|
|
33
59
|
/** Maps an `InputDecls` Record to its forward-time tensor counterpart —
 * same keys, each value is a Tensor. This is the type of the `inputs`
 * argument the forward function receives while tracing. */
export type InputsTensors<I extends InputDecls> = { [K in keyof I]: Tensor }
|
|
37
62
|
|
|
38
|
-
/** Forward function shape.
 *
 * Receives the materialized model and the named input tensors (keys match
 * the declared `inputs:`), and returns the output tensor: the loss for
 * `compileModule`, or whatever the graph should produce (logits, etc.) for
 * `compileForward`. The `I` parameter flows from the inputs declaration so
 * destructuring the input record stays typed. */
export type ForwardFn<M extends Module, I extends InputDecls = InputDecls> =
  (m: M, inputs: InputsTensors<I>) => Tensor
|
|
45
66
|
|
|
@@ -60,75 +81,86 @@ export function compileToIR(traceFn: () => Tensor): CompiledIR {
|
|
|
60
81
|
return { graph, paramGrads, loss, plan, kernels }
|
|
61
82
|
}
|
|
62
83
|
|
|
63
|
-
/** Full compile pipeline. Browser-only because it creates a GPUDevice. */
|
|
64
|
-
export async function compile(traceFn: () => Tensor, opts: RuntimeOpts = {}): Promise<CompiledRuntime & { ir: CompiledIR }> {
|
|
65
|
-
const ir = compileToIR(traceFn)
|
|
66
|
-
const lossBufferId = ir.plan.tensorToBuffer.get(ir.loss.id)!
|
|
67
|
-
const runtime = await createRuntime(ir.plan, ir.kernels, lossBufferId, opts)
|
|
68
|
-
return Object.assign(runtime, { ir })
|
|
69
|
-
}
|
|
70
|
-
|
|
71
84
|
// ============================================================================
|
|
72
|
-
//
|
|
85
|
+
// CompiledModule / CompiledForwardModule — main-thread proxy surface
|
|
73
86
|
// ============================================================================
|
|
74
87
|
|
|
75
|
-
/** Options for `compileModule`. */
export interface CompileModuleOptions<I extends InputDecls = InputDecls> {
  /** Per-step data inputs to the forward function, keyed by name. The forward
   * fn destructures these out of its second argument; `step()` / `run()` take
   * typed arrays under the same keys. */
  inputs?: I
  /** Adam hyperparameters. If omitted, no optimizer update is appended to the
   * graph; gradients are still computed and can be read back with
   * `downloadParamGrads()`. */
  adam?: AdamConfig
}
|
|
83
92
|
|
|
84
|
-
/** Options for the standalone `compileForward()` function, which compiles a
 * forward-only graph into its own worker with freshly initialized params. */
export interface CompileForwardOptions<I extends InputDecls = InputDecls> {
  /** Per-step data inputs to the forward function, keyed by name; `run()`
   * takes typed arrays under the same keys. */
  inputs?: I
}
|
|
88
96
|
|
|
89
|
-
/** Options for the `compileForward` *method* on a `CompiledModule`. The
 * sibling graph is compiled into the parent's worker and reuses the parent's
 * params, so the only thing to declare is the forward function's inputs. */
export interface CompileForwardMethodOptions<I extends InputDecls = InputDecls> {
  /** Per-step data inputs to the forward function, keyed by name; `run()`
   * takes typed arrays under the same keys. */
  inputs?: I
}
|
|
95
100
|
|
|
96
|
-
/** Returned by `compileModule`. Proxies all GPU work to a worker held
 * internally; user code awaits Promises and never sees the worker.
 *
 * Typed arrays passed in (`inputs`, `params`) are copied into the worker, not
 * transferred, so callers may keep reusing them as scratch buffers. */
export interface CompiledModule<M extends Module> {
  /** Main-thread copy of the compiled IR (graph, param grads, loss tensor,
   * buffer plan, kernels). */
  readonly ir: CompiledIR
  /** Kernel count reported by the worker when the runtime was created. */
  readonly kernelCount: number
  /** Output shape reported by the worker when the runtime was created. */
  readonly outputShape: readonly number[]
  /** Names of the model's parameters, in materialization order. The actual
   * GPUBuffers live in the worker; use `downloadParams()` for values. */
  readonly paramNames: readonly string[]

  /** Execute the compiled training graph once (forward, gradients and, when
   * `adam` was configured, the optimizer update) and resolve to the loss.
   * With `{ withCaptures: true }` resolve to a `StepResult` instead. The third
   * overload accepts a non-literal flag and resolves to either form. */
  step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
  step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
  step(
    inputs: Record<string, Int32Array | Float32Array>,
    opts?: { withCaptures?: boolean },
  ): Promise<number | StepResult>

  /** Execute the compiled graph and resolve to its output values. With
   * `{ withCaptures: true }` resolve to a `RunResult` instead. The third
   * overload accepts a non-literal flag and resolves to either form. */
  run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
  run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunResult>
  run(
    inputs: Record<string, Int32Array | Float32Array>,
    opts?: { withCaptures?: boolean },
  ): Promise<Float32Array | RunResult>

  /** Write parameter values, keyed by param name, into the worker.
   * `opts.partial` is forwarded to the worker (presumably permits uploading a
   * subset of params — confirm against `UploadParamsOptions`). */
  uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): Promise<void>
  /** Read back current parameter values, keyed by param name. */
  downloadParams(): Promise<Record<string, Float32Array>>
  /** Read back parameter gradients, keyed by param name (presumably those of
   * the most recent `step()`). */
  downloadParamGrads(): Promise<Record<string, Float32Array>>

  /** Re-initialize all params by re-running their init functions, then reset
   * the optimizer state. */
  reset(): Promise<void>
  /** Reset the optimizer state held by the worker; params are not touched. */
  resetOptimizerState(): Promise<void>

  /** Compile a sibling forward-only graph that shares this runtime's worker
   * (and therefore its param GPUBuffers). The sibling stops working once this
   * module is destroyed, because `destroy()` terminates the shared worker. */
  compileForward<I extends InputDecls>(
    forward: ForwardFn<M, I>,
    opts?: CompileForwardMethodOptions<I>,
  ): Promise<CompiledForwardModule>

  /** Free the runtime's GPU resources and terminate the worker. Do not call
   * any other method afterwards. */
  destroy(): void
}
|
|
115
135
|
|
|
116
136
|
/** Returned by `compileForward` (and by the `compileForward` method).
 *
 * Typed arrays passed in are copied into the worker, not transferred, so the
 * caller's arrays stay valid. */
export interface CompiledForwardModule {
  /** Main-thread copy of the compiled IR; `loss` holds the graph's output. */
  readonly ir: CompiledIR
  /** Kernel count reported by the worker for this graph. */
  readonly kernelCount: number
  /** Output shape reported by the worker for this graph. */
  readonly outputShape: readonly number[]
  /** Names of the params this graph reads. */
  readonly paramNames: readonly string[]

  /** Execute the forward graph and resolve to its output values. With
   * `{ withCaptures: true }` resolve to a `RunResult` instead. The third
   * overload accepts a non-literal flag and resolves to either form. */
  run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
  run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunResult>
  run(
    inputs: Record<string, Int32Array | Float32Array>,
    opts?: { withCaptures?: boolean },
  ): Promise<Float32Array | RunResult>

  /** Write parameter values, keyed by param name, into the worker. */
  uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): Promise<void>
  /** Read back current parameter values, keyed by param name. */
  downloadParams(): Promise<Record<string, Float32Array>>

  /** Release this graph. A standalone module also terminates its worker; a
   * sibling created by `CompiledModule.compileForward` leaves the parent's
   * worker running. Do not call any other method afterwards. */
  destroy(): void
}
|
|
122
151
|
|
|
152
|
+
// ============================================================================
|
|
153
|
+
// compileModule / compileForward
|
|
154
|
+
// ============================================================================
|
|
155
|
+
|
|
123
156
|
/**
|
|
124
157
|
* Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
|
|
125
158
|
* model instance itself: compilation mutates the tree (every `ParamSentinel`
|
|
126
159
|
* field becomes a real `Tensor`), so the instance is consumed and shouldn't be
|
|
127
|
-
* referenced afterwards.
|
|
160
|
+
* referenced afterwards.
|
|
128
161
|
*
|
|
129
162
|
* The forward function takes the materialized model and a Record of named
|
|
130
|
-
* input tensors, returns the loss tensor
|
|
131
|
-
* `inputs:` declaration:
|
|
163
|
+
* input tensors, returns the loss tensor:
|
|
132
164
|
*
|
|
133
165
|
* inputs: {
|
|
134
166
|
* tokens: { shape: [B, T], dtype: 'i32' },
|
|
@@ -136,20 +168,16 @@ export interface CompiledForwardModule extends CompiledForward {
|
|
|
136
168
|
* }
|
|
137
169
|
* forward: (m, { tokens, targets }) => …
|
|
138
170
|
*
|
|
139
|
-
*
|
|
140
|
-
*
|
|
141
|
-
*
|
|
142
|
-
* call `reset()` later to re-randomize.
|
|
143
|
-
*
|
|
144
|
-
* If `opts.adam` is set, the runtime's `step()` automatically tracks an
|
|
145
|
-
* internal step count and injects the bias-corrected `lrt` scalar each call;
|
|
146
|
-
* users don't need to provide it themselves.
|
|
171
|
+
* Returns a `CompiledModule` proxy. All GPU work (createRuntime, step, run,
|
|
172
|
+
* mapAsync) happens in an internal worker; calls return Promises that resolve
|
|
173
|
+
* when the worker replies.
|
|
147
174
|
*/
|
|
148
175
|
export async function compileModule<M extends Module, I extends InputDecls = InputDecls>(
|
|
149
176
|
modelFactory: () => M,
|
|
150
177
|
forward: ForwardFn<M, I>,
|
|
151
178
|
opts: CompileModuleOptions<I> = {},
|
|
152
179
|
): Promise<CompiledModule<M>> {
|
|
180
|
+
// ---- Compile-time work (main thread) ------------------------------------
|
|
153
181
|
const { graph, materialized } = traceModule(modelFactory, forward, opts.inputs ?? {})
|
|
154
182
|
const { paramGrads, loss } = appendGrad(graph)
|
|
155
183
|
const adamResult = opts.adam
|
|
@@ -158,55 +186,40 @@ export async function compileModule<M extends Module, I extends InputDecls = Inp
|
|
|
158
186
|
|
|
159
187
|
const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? [])
|
|
160
188
|
const kernels = emitKernels(graph, plan)
|
|
161
|
-
const lossBufferId = plan.tensorToBuffer.get(loss.id)!
|
|
162
|
-
const runtime = await createRuntime(plan, kernels, lossBufferId, opts)
|
|
163
|
-
|
|
164
|
-
if (adamResult) wrapStepForAdam(runtime, adamResult)
|
|
165
|
-
uploadInitialParams(plan, materialized.initFns, runtime, /* sharedParams */ undefined)
|
|
166
|
-
|
|
167
189
|
const ir: CompiledIR = { graph, paramGrads, loss, plan, kernels }
|
|
168
|
-
const kernelCount = countKernels(kernels)
|
|
169
190
|
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
191
|
+
// Initial params: resolve init shapes to Float32Arrays now (main thread).
|
|
192
|
+
// These transfer (zero-copy) to the worker as part of createRuntime.
|
|
193
|
+
const initialParams = buildInitialParams(plan, materialized.initFns)
|
|
194
|
+
|
|
195
|
+
// ---- Spawn worker, send IR + initial params -----------------------------
|
|
196
|
+
const proxy = new WorkerProxy(__WORKER_SOURCE__)
|
|
197
|
+
const wireIR: WireIR = { graph, plan, kernels }
|
|
198
|
+
const wireAdam = adamResult ? wireAdamConfig(adamResult) : null
|
|
199
|
+
const transfers = transferablesOfRecord(initialParams)
|
|
200
|
+
|
|
201
|
+
let meta: CreateRuntimeResult
|
|
202
|
+
try {
|
|
203
|
+
meta = await proxy.request<CreateRuntimeResult>(
|
|
204
|
+
{ kind: 'createRuntime', payload: { graphId: 0, ir: wireIR, initialParams, adam: wireAdam } },
|
|
205
|
+
transfers,
|
|
206
|
+
)
|
|
207
|
+
} catch (e) {
|
|
208
|
+
proxy.terminate()
|
|
209
|
+
throw e
|
|
173
210
|
}
|
|
174
211
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
...fOpts,
|
|
181
|
-
device: runtime.device,
|
|
182
|
-
sharedParams: runtime.params,
|
|
183
|
-
})
|
|
184
|
-
|
|
185
|
-
return Object.assign(runtime, { ir, kernelCount, reset, compileForward: compileForwardMethod })
|
|
212
|
+
return new CompiledModuleProxy<M>(
|
|
213
|
+
proxy, /* graphId */ 0, ir, meta, modelFactory,
|
|
214
|
+
/* initFns */ materialized.initFns,
|
|
215
|
+
/* nextGraphId */ { v: 1 },
|
|
216
|
+
)
|
|
186
217
|
}
|
|
187
218
|
|
|
188
|
-
// ============================================================================
|
|
189
|
-
// Forward-only compile
|
|
190
|
-
// ============================================================================
|
|
191
|
-
|
|
192
219
|
/**
|
|
193
|
-
*
|
|
194
|
-
*
|
|
195
|
-
*
|
|
196
|
-
* `Float32Array`.
|
|
197
|
-
*
|
|
198
|
-
* **Prefer the `compileForward` method on a training runtime** when both
|
|
199
|
-
* graphs use the same Module class — it auto-supplies `device` and
|
|
200
|
-
* `sharedParams`. This standalone form is for forward-only models with no
|
|
201
|
-
* training graph at all, or for sharing params across a different model.
|
|
202
|
-
*
|
|
203
|
-
* **Sharing params with a training compile.** Pass `opts.sharedParams =
|
|
204
|
-
* trainCompiled.params` to bind this graph's param buffers to an existing
|
|
205
|
-
* training runtime's GPU buffers — every train step is then immediately
|
|
206
|
-
* visible to `run()` calls here, no copies.
|
|
207
|
-
*
|
|
208
|
-
* Initial param values are uploaded automatically for params *not* covered
|
|
209
|
-
* by `sharedParams` (those are owned by the sibling compile).
|
|
220
|
+
* Forward-only compile. Spawns its own worker. For sibling graphs that share
|
|
221
|
+
* params with a training graph, prefer the `compileForward` method on the
|
|
222
|
+
* CompiledModule returned by `compileModule()`.
|
|
210
223
|
*/
|
|
211
224
|
export async function compileForward<M extends Module, I extends InputDecls = InputDecls>(
|
|
212
225
|
modelFactory: () => M,
|
|
@@ -215,16 +228,195 @@ export async function compileForward<M extends Module, I extends InputDecls = In
|
|
|
215
228
|
): Promise<CompiledForwardModule> {
|
|
216
229
|
const { graph, materialized } = traceModule(modelFactory, forward, opts.inputs ?? {})
|
|
217
230
|
const outputTensor = graph.tensors[graph.outputs[0]!]!
|
|
218
|
-
|
|
219
231
|
const plan = planBuffers(graph, /* paramGrads */ {})
|
|
220
232
|
const kernels = emitKernels(graph, plan)
|
|
221
|
-
const
|
|
222
|
-
const runtime = await createForwardRuntime(plan, kernels, outputBufferId, opts)
|
|
233
|
+
const ir: CompiledIR = { graph, paramGrads: {}, loss: outputTensor, plan, kernels }
|
|
223
234
|
|
|
224
|
-
|
|
235
|
+
const initialParams = buildInitialParams(plan, materialized.initFns)
|
|
236
|
+
const proxy = new WorkerProxy(__WORKER_SOURCE__)
|
|
237
|
+
const wireIR: WireIR = { graph, plan, kernels }
|
|
238
|
+
const transfers = transferablesOfRecord(initialParams)
|
|
239
|
+
|
|
240
|
+
let meta: CreateRuntimeResult
|
|
241
|
+
try {
|
|
242
|
+
meta = await proxy.request<CreateRuntimeResult>(
|
|
243
|
+
{ kind: 'createRuntime', payload: { graphId: 0, ir: wireIR, initialParams, adam: null } },
|
|
244
|
+
transfers,
|
|
245
|
+
)
|
|
246
|
+
} catch (e) {
|
|
247
|
+
proxy.terminate()
|
|
248
|
+
throw e
|
|
249
|
+
}
|
|
225
250
|
|
|
226
|
-
|
|
227
|
-
|
|
251
|
+
return new CompiledForwardModuleProxy(proxy, /* graphId */ 0, ir, meta, /* ownsWorker */ true)
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
// ============================================================================
|
|
255
|
+
// Proxy implementations
|
|
256
|
+
// ============================================================================
|
|
257
|
+
|
|
258
|
+
/** Main-thread implementation of `CompiledModule`: every operation is a
 * request to the worker that owns the GPU runtime for `graphId`. */
class CompiledModuleProxy<M extends Module> implements CompiledModule<M> {
  /** Set by destroy(). Afterwards the worker may already be terminated, so
   * further calls are rejected here instead of being posted to it. */
  private destroyed = false

  constructor(
    private readonly proxy: WorkerProxy,
    private readonly graphId: number,
    public readonly ir: CompiledIR,
    private readonly meta: CreateRuntimeResult,
    private readonly modelFactory: () => M,
    /** Init closures captured from materializeParams at compile time. Used
     * by reset() to regenerate initial param values. */
    private readonly initFns: Record<string, InitFn>,
    /** Shared counter that hands out graph ids for sibling forward graphs
     * compiled onto this worker. */
    private readonly nextGraphId: { v: number },
  ) {}

  get kernelCount(): number { return this.meta.kernelCount }
  get outputShape(): readonly number[] { return this.meta.outputShape }
  get paramNames(): readonly string[] { return this.meta.paramNames }

  step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
  step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepResult>
  step(
    inputs: Record<string, Int32Array | Float32Array>,
    opts?: { withCaptures?: boolean },
  ): Promise<number | StepResult>
  async step(
    inputs: Record<string, Int32Array | Float32Array>,
    opts?: { withCaptures?: boolean },
  ): Promise<number | StepResult> {
    this.assertAlive()
    // Inputs are copied (not transferred) into the worker. Callers commonly
    // reuse the same TypedArray as a scratch buffer across step() calls;
    // transferring would detach it. The copy cost is small relative to a
    // training step's GPU work.
    const r = await this.proxy.request<StepResultWire>(
      { kind: 'step', payload: { graphId: this.graphId, inputs, withCaptures: opts?.withCaptures === true } },
    )
    if (opts?.withCaptures) {
      return { loss: r.loss, captures: makeCaptures(r.captures, this.meta.captureShapes) }
    }
    return r.loss
  }

  run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
  run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunResult>
  run(
    inputs: Record<string, Int32Array | Float32Array>,
    opts?: { withCaptures?: boolean },
  ): Promise<Float32Array | RunResult>
  async run(
    inputs: Record<string, Int32Array | Float32Array>,
    opts?: { withCaptures?: boolean },
  ): Promise<Float32Array | RunResult> {
    this.assertAlive()
    // Inputs are copied, not transferred, for the same reason as in step().
    const r = await this.proxy.request<RunResultWire>(
      { kind: 'run', payload: { graphId: this.graphId, inputs, withCaptures: opts?.withCaptures === true } },
    )
    if (opts?.withCaptures) {
      return { output: r.output, captures: makeCaptures(r.captures, this.meta.captureShapes) }
    }
    return r.output
  }

  async uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): Promise<void> {
    this.assertAlive()
    // Params are copied, not transferred, so the caller's Float32Arrays stay valid.
    await this.proxy.request<null>(
      { kind: 'uploadParams', payload: { graphId: this.graphId, params, partial: !!opts?.partial } },
    )
  }

  async downloadParams(): Promise<Record<string, Float32Array>> {
    this.assertAlive()
    const r = await this.proxy.request<DownloadParamsResult>(
      { kind: 'downloadParams', payload: { graphId: this.graphId } },
    )
    return r.params
  }

  async downloadParamGrads(): Promise<Record<string, Float32Array>> {
    this.assertAlive()
    const r = await this.proxy.request<DownloadParamsResult>(
      { kind: 'downloadParamGrads', payload: { graphId: this.graphId } },
    )
    return r.params
  }

  async reset(): Promise<void> {
    this.assertAlive()
    // Re-run the init closures captured at compile time on the main thread,
    // upload the result, then reset the optimizer state on the worker. That is
    // two round-trips, which is acceptable because reset() is rare.
    const initialParams = buildInitialParams(this.ir.plan, this.initFns)
    await this.uploadParams(initialParams)
    await this.resetOptimizerState()
  }

  async resetOptimizerState(): Promise<void> {
    this.assertAlive()
    await this.proxy.request<null>(
      { kind: 'resetOptimizer', payload: { graphId: this.graphId } },
    )
  }

  async compileForward<I extends InputDecls>(
    forward: ForwardFn<M, I>,
    opts: CompileForwardMethodOptions<I> = {},
  ): Promise<CompiledForwardModule> {
    this.assertAlive()
    // Trace/plan/codegen on the main thread. The sibling graph's params are
    // resolved by the worker against this graph's buffers, so the
    // materialized init functions are not needed here.
    const { graph } = traceModule(this.modelFactory, forward, opts.inputs ?? {})
    const outputTensor = graph.tensors[graph.outputs[0]!]!
    const plan = planBuffers(graph, /* paramGrads */ {})
    const kernels = emitKernels(graph, plan)
    const ir: CompiledIR = { graph, paramGrads: {}, loss: outputTensor, plan, kernels }

    const childGraphId = this.nextGraphId.v++
    const wireIR: WireIR = { graph, plan, kernels }

    const meta = await this.proxy.request<CompileForwardResult>(
      { kind: 'compileForward', payload: { graphId: childGraphId, parentGraphId: this.graphId, ir: wireIR } },
    )

    return new CompiledForwardModuleProxy(this.proxy, childGraphId, ir, meta, /* ownsWorker */ false)
  }

  destroy(): void {
    if (this.destroyed) return
    this.destroyed = true
    // Best effort: ask the worker to release this graph, then terminate it.
    // Worker.terminate() discards queued tasks without running them, so the
    // 'destroy' message and any still-pending request may never be processed;
    // terminating the worker is what actually tears the runtime down. Whether
    // pending Promises reject depends on WorkerProxy. Sibling forward graphs
    // share this worker and stop working as well.
    this.proxy.send({ kind: 'destroy', payload: { graphId: this.graphId } })
    this.proxy.terminate()
  }

  /** Throw (i.e. reject, inside async methods) once destroy() has run. */
  private assertAlive(): void {
    if (this.destroyed) throw new Error('tensorgrad: CompiledModule used after destroy()')
  }
}
|
|
373
|
+
|
|
374
|
+
/** Main-thread implementation of `CompiledForwardModule`. Either owns a
 * dedicated worker (standalone `compileForward`) or borrows the parent
 * CompiledModule's worker (the `compileForward` method). */
class CompiledForwardModuleProxy implements CompiledForwardModule {
  /** Set by destroy(); later calls are rejected before reaching the worker. */
  private destroyed = false

  constructor(
    private readonly proxy: WorkerProxy,
    private readonly graphId: number,
    public readonly ir: CompiledIR,
    private readonly meta: CompileForwardResult | CreateRuntimeResult,
    /** True when this graph spawned its own worker; false for siblings whose
     * worker belongs to a parent CompiledModule. */
    private readonly ownsWorker: boolean,
  ) {}

  get kernelCount(): number { return this.meta.kernelCount }
  get outputShape(): readonly number[] { return this.meta.outputShape }
  get paramNames(): readonly string[] { return this.meta.paramNames }

  run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
  run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunResult>
  run(
    inputs: Record<string, Int32Array | Float32Array>,
    opts?: { withCaptures?: boolean },
  ): Promise<Float32Array | RunResult>
  async run(
    inputs: Record<string, Int32Array | Float32Array>,
    opts?: { withCaptures?: boolean },
  ): Promise<Float32Array | RunResult> {
    this.assertAlive()
    // Inputs are copied, not transferred; the caller's TypedArrays stay valid.
    const r = await this.proxy.request<RunResultWire>(
      { kind: 'run', payload: { graphId: this.graphId, inputs, withCaptures: opts?.withCaptures === true } },
    )
    if (opts?.withCaptures) {
      return { output: r.output, captures: makeCaptures(r.captures, this.meta.captureShapes) }
    }
    return r.output
  }

  async uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): Promise<void> {
    this.assertAlive()
    await this.proxy.request<null>(
      { kind: 'uploadParams', payload: { graphId: this.graphId, params, partial: !!opts?.partial } },
    )
  }

  async downloadParams(): Promise<Record<string, Float32Array>> {
    this.assertAlive()
    const r = await this.proxy.request<DownloadParamsResult>(
      { kind: 'downloadParams', payload: { graphId: this.graphId } },
    )
    return r.params
  }

  destroy(): void {
    if (this.destroyed) return
    this.destroyed = true
    this.proxy.send({ kind: 'destroy', payload: { graphId: this.graphId } })
    // Only a standalone module owns its worker. A sibling leaves the parent's
    // worker running. If the parent is destroyed first, the shared worker is
    // terminated underneath this object, which it cannot detect.
    if (this.ownsWorker) this.proxy.terminate()
  }

  /** Throw (i.e. reject, inside async methods) once destroy() has run. */
  private assertAlive(): void {
    if (this.destroyed) throw new Error('tensorgrad: CompiledForwardModule used after destroy()')
  }
}
|
|
229
421
|
|
|
230
422
|
// ============================================================================
|
|
@@ -255,60 +447,46 @@ function traceModule<M extends Module, I extends InputDecls>(
|
|
|
255
447
|
return { graph, materialized }
|
|
256
448
|
}
|
|
257
449
|
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
* effective LR) and, when the user supplied a per-step lr schedule, the
|
|
262
|
-
* decayShrink scalar. Also wraps resetOptimizerState() so a reset zeros
|
|
263
|
-
* Adam's m/v *and* the bias-correction step counter — otherwise the next
|
|
264
|
-
* step would skip Adam's warmup phase. */
|
|
265
|
-
function wrapStepForAdam(runtime: CompiledRuntime, adamResult: AdamResult): void {
|
|
266
|
-
const { lrtInputName, decayShrinkInputName, config } = adamResult
|
|
267
|
-
let t = 0
|
|
268
|
-
const lrtBuf = new Float32Array(1)
|
|
269
|
-
const decayShrinkBuf = decayShrinkInputName ? new Float32Array(1) : null
|
|
270
|
-
const innerStep = runtime.step.bind(runtime) as CompiledRuntime['step']
|
|
271
|
-
const innerReset = runtime.resetOptimizerState.bind(runtime)
|
|
272
|
-
const wrappedStep = ((
|
|
273
|
-
inputs: Record<string, Int32Array | Float32Array>,
|
|
274
|
-
opts?: { withCaptures?: boolean; readLoss?: boolean },
|
|
275
|
-
) => {
|
|
276
|
-
t++
|
|
277
|
-
const lrNow = config.lr(t)
|
|
278
|
-
lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t))
|
|
279
|
-
const merged: Record<string, Int32Array | Float32Array> = { ...inputs, [lrtInputName]: lrtBuf }
|
|
280
|
-
if (decayShrinkBuf && decayShrinkInputName) {
|
|
281
|
-
decayShrinkBuf[0] = 1 - lrNow * config.weightDecay
|
|
282
|
-
merged[decayShrinkInputName] = decayShrinkBuf
|
|
283
|
-
}
|
|
284
|
-
if (opts?.readLoss === false) return innerStep(merged, { readLoss: false })
|
|
285
|
-
if (opts?.withCaptures) return innerStep(merged, { withCaptures: true })
|
|
286
|
-
return innerStep(merged)
|
|
287
|
-
}) as CompiledRuntime['step']
|
|
288
|
-
runtime.step = wrappedStep
|
|
289
|
-
runtime.resetOptimizerState = () => {
|
|
290
|
-
t = 0
|
|
291
|
-
innerReset()
|
|
292
|
-
}
|
|
293
|
-
}
|
|
294
|
-
|
|
295
|
-
/** Build a Record<paramName, Float32Array> by running each param's init
|
|
296
|
-
* function against its shape and uploading them to the runtime. Skips any
|
|
297
|
-
* param covered by `sharedParams` (those are owned by a sibling compile). */
|
|
298
|
-
function uploadInitialParams(
|
|
299
|
-
plan: BufferPlan,
|
|
300
|
-
initFns: Record<string, InitFn>,
|
|
301
|
-
runtime: CompiledRuntime | CompiledForward,
|
|
302
|
-
sharedParams: Map<string, GPUBuffer> | undefined,
|
|
303
|
-
): void {
|
|
450
|
+
/** Run each param's init function against its planned buffer shape to produce
 * the initial Float32Arrays, in `plan.paramsByName` order. Runs on the main
 * thread; compileModule/compileForward then transfer the results' buffers to
 * the worker.
 *
 * Throws if a param has no init function, or if an init function returns an
 * array whose length differs from the param's element count.
 *
 * Because the results are transferred, each one must exclusively own a plain
 * ArrayBuffer. Three kinds of result are therefore copied:
 * - a view into a larger buffer;
 * - one backed by a non-ArrayBuffer (e.g. shared) buffer;
 * - one whose buffer was already used by an earlier param.
 * Copying prevents duplicate transferables (DataCloneError) and detaching
 * unrelated memory. An init function that hands back an array it keeps a
 * reference to will still see that array detached after transfer. */
function buildInitialParams(plan: BufferPlan, initFns: Record<string, InitFn>): Record<string, Float32Array> {
  const out: Record<string, Float32Array> = {}
  const seenBuffers = new Set<ArrayBufferLike>()
  for (const [name, bufId] of plan.paramsByName) {
    const shape = plan.buffers[bufId]!.shape
    const size = shape.reduce((a, b) => a * b, 1)
    const initFn = initFns[name]
    if (!initFn) throw new Error(`compile: no init for param '${name}'`)
    let values = initFn(size, shape)
    if (values.length !== size) {
      throw new Error(`compile: init for param '${name}' returned ${values.length} values, expected ${size}`)
    }
    const ownsWholeBuffer =
      values.buffer instanceof ArrayBuffer &&
      values.byteOffset === 0 &&
      values.byteLength === values.buffer.byteLength
    if (!ownsWholeBuffer || seenBuffers.has(values.buffer)) values = values.slice()
    seenBuffers.add(values.buffer)
    out[name] = values
  }
  return out
}
|
|
463
|
+
|
|
464
|
+
/** Project an AdamResult onto the WireAdamConfig posted to the worker: the
 * resolved hyperparameters plus the names of the per-step scalar inputs.
 * Anything else on the resolved config (e.g. decayFilter, which is only needed
 * at compile time) stays on the main thread. */
function wireAdamConfig(r: AdamResult): WireAdamConfig {
  const { lr, b1, b2, eps, weightDecay, lrIsScheduled }: AdamResolvedConfig = r.config
  const { lrtInputName, decayShrinkInputName } = r
  return { lr, b1, b2, eps, weightDecay, lrIsScheduled, lrtInputName, decayShrinkInputName }
}
|
|
479
|
+
|
|
480
|
+
/** Build a Captures instance from the capture arrays the worker sent back
 * (or none, when it sent `null`) and the capture shapes recorded at compile
 * time. Entries keep the order in which the worker's record lists them. */
function makeCaptures(
  captures: Record<string, Float32Array> | null,
  captureShapes: Record<string, number[]>,
): Captures {
  const entries: [string, Float32Array][] = captures ? Object.entries(captures) : []
  return new Captures(captureShapes, new Map<string, Float32Array>(entries))
}
|
|
492
|
+
|
package/src/index.ts
CHANGED
|
@@ -33,15 +33,19 @@ export {
|
|
|
33
33
|
// adam.ts can import them) but aren't part of the public API — `add`/`mul`
|
|
34
34
|
// overload on JS numbers, `where` subsumes the rest.
|
|
35
35
|
export { appendGrad, type GradResult } from './grad.js'
|
|
36
|
-
export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
|
|
36
|
+
export { appendAdam, lr, resolveLR, type AdamConfig, type AdamResult, type LRSchedule } from './adam.js'
|
|
37
37
|
export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
|
|
38
38
|
export { emitKernels, type KernelSpec } from './codegen.js'
|
|
39
|
-
|
|
39
|
+
// Runtime types: only the user-facing pieces. CompiledRuntime/CompiledForward
|
|
40
|
+
// (worker-internal) and createRuntime/createForwardRuntime aren't part of the
|
|
41
|
+
// public API — users get CompiledModule/CompiledForwardModule (proxies) from
|
|
42
|
+
// compileModule/compileForward instead.
|
|
43
|
+
export { Captures, type RunOptions, type StepResult, type RunResult, type UploadParamsOptions } from './runtime.js'
|
|
40
44
|
export {
|
|
41
|
-
|
|
45
|
+
compileToIR, compileModule, compileForward,
|
|
42
46
|
type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type CompileForwardMethodOptions,
|
|
43
47
|
type CompiledModule, type CompiledForwardModule,
|
|
44
48
|
type InputDecl, type InputDecls, type InputsTensors, type ForwardFn,
|
|
45
49
|
} from './compile.js'
|
|
46
|
-
export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
|
|
50
|
+
export { Module, materializeParams, init, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
|
|
47
51
|
export * as nn from './nn.js'
|