tensorgrad 0.0.1 → 0.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +7 -9
- package/dist/adam.d.ts +14 -2
- package/dist/adam.d.ts.map +1 -1
- package/dist/adam.js +19 -8
- package/dist/adam.js.map +1 -1
- package/dist/buffers.d.ts +1 -0
- package/dist/buffers.d.ts.map +1 -1
- package/dist/buffers.js +12 -1
- package/dist/buffers.js.map +1 -1
- package/dist/capture.d.ts +3 -0
- package/dist/capture.d.ts.map +1 -0
- package/dist/capture.js +33 -0
- package/dist/capture.js.map +1 -0
- package/dist/codegen.js +4 -2
- package/dist/codegen.js.map +1 -1
- package/dist/compile.d.ts +33 -5
- package/dist/compile.d.ts.map +1 -1
- package/dist/compile.js +96 -11
- package/dist/compile.js.map +1 -1
- package/dist/index.d.ts +5 -3
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +4 -2
- package/dist/index.js.map +1 -1
- package/dist/ir.d.ts +2 -0
- package/dist/ir.d.ts.map +1 -1
- package/dist/ir.js +1 -1
- package/dist/ir.js.map +1 -1
- package/dist/module.d.ts +30 -4
- package/dist/module.d.ts.map +1 -1
- package/dist/module.js +39 -13
- package/dist/module.js.map +1 -1
- package/dist/nn.d.ts +19 -0
- package/dist/nn.d.ts.map +1 -0
- package/dist/nn.js +60 -0
- package/dist/nn.js.map +1 -0
- package/dist/ops.d.ts +1 -1
- package/dist/ops.d.ts.map +1 -1
- package/dist/ops.js +2 -2
- package/dist/ops.js.map +1 -1
- package/dist/runtime.d.ts +79 -4
- package/dist/runtime.d.ts.map +1 -1
- package/dist/runtime.js +153 -19
- package/dist/runtime.js.map +1 -1
- package/dist/trace.d.ts +1 -0
- package/dist/trace.d.ts.map +1 -1
- package/dist/trace.js +12 -0
- package/dist/trace.js.map +1 -1
- package/package.json +1 -2
- package/src/adam.ts +31 -10
- package/src/buffers.ts +14 -1
- package/src/capture.ts +36 -0
- package/src/codegen.ts +4 -2
- package/src/compile.ts +112 -13
- package/src/index.ts +5 -3
- package/src/ir.ts +10 -4
- package/src/module.ts +75 -11
- package/src/nn.ts +59 -0
- package/src/ops.ts +2 -2
- package/src/runtime.ts +260 -22
- package/src/trace.ts +13 -0
- package/SPEC.md +0 -293
package/src/runtime.ts
CHANGED
|
@@ -11,9 +11,47 @@ import type { KernelSpec } from './codegen.js'
|
|
|
11
11
|
// Provided by the browser per WebGPU spec; declare just what we use.
|
|
12
12
|
declare const GPUMapMode: { readonly READ: number; readonly WRITE: number }
|
|
13
13
|
|
|
14
|
+
export interface UploadParamsOptions {
|
|
15
|
+
/** Skip the "missing param" check, allowing the caller to update only some
|
|
16
|
+
* params and leave the rest at their current GPU values. Extra (unknown)
|
|
17
|
+
* keys are still rejected — that's always a typo. Default: false. */
|
|
18
|
+
partial?: boolean
|
|
19
|
+
}
|
|
20
|
+
|
|
21
|
+
export interface StepOptions {
|
|
22
|
+
/** Read back tensors registered via `capture(name, t)` during the trace.
|
|
23
|
+
* When false/unset, `step()` returns just the loss number; staging buffers
|
|
24
|
+
* for captures are not allocated. When true, returns `{ loss, captures }`
|
|
25
|
+
* and lazily allocates one staging buffer per capture on first use. */
|
|
26
|
+
withCaptures?: boolean
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
export interface StepWithCaptures {
|
|
30
|
+
loss: number
|
|
31
|
+
captures: Record<string, Float32Array>
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
export interface RunOptions {
|
|
35
|
+
/** Read back tensors registered via `capture(name, t)` during the trace.
|
|
36
|
+
* When false/unset, `run()` returns just the output Float32Array.
|
|
37
|
+
* When true, returns `{ output, captures }`. */
|
|
38
|
+
withCaptures?: boolean
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
export interface RunWithCaptures {
|
|
42
|
+
output: Float32Array
|
|
43
|
+
captures: Record<string, Float32Array>
|
|
44
|
+
}
|
|
45
|
+
|
|
14
46
|
export interface CompiledRuntime {
|
|
15
|
-
/**
|
|
16
|
-
|
|
47
|
+
/** Map of param name -> the underlying GPUBuffer. Pass to a sibling compile
|
|
48
|
+
* via `sharedParams` to share without copies — every step on this runtime
|
|
49
|
+
* is immediately visible to anyone reading these buffers. */
|
|
50
|
+
params: Map<string, GPUBuffer>
|
|
51
|
+
/** Upload parameter Float32Arrays to their GPU buffers. By default, requires
|
|
52
|
+
* *all* params to be present; throws on any unknown or missing key. Pass
|
|
53
|
+
* `{ partial: true }` to skip the missing-key check. */
|
|
54
|
+
uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void
|
|
17
55
|
/** Read all parameters back as Float32Arrays — used for UI panels. */
|
|
18
56
|
downloadParams(): Promise<Record<string, Float32Array>>
|
|
19
57
|
/** Read all parameter gradients back. Mostly for verification / debugging. */
|
|
@@ -22,17 +60,51 @@ export interface CompiledRuntime {
|
|
|
22
60
|
* One full forward+backward step.
|
|
23
61
|
* 1. Uploads `inputs` (tokens, targets, masks) to input buffers.
|
|
24
62
|
* 2. Dispatches every kernel in order.
|
|
25
|
-
* 3. Reads back the loss scalar.
|
|
26
|
-
*
|
|
63
|
+
* 3. Reads back the loss scalar (and any registered captures, if requested).
|
|
64
|
+
* Default returns the loss as a JS number; with `{ withCaptures: true }`
|
|
65
|
+
* returns `{ loss, captures }` where `captures` is keyed by the names passed
|
|
66
|
+
* to `capture(...)` during the trace.
|
|
27
67
|
*/
|
|
28
68
|
step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
|
|
69
|
+
step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepWithCaptures>
|
|
70
|
+
step(inputs: Record<string, Int32Array | Float32Array>, opts: StepOptions): Promise<number | StepWithCaptures>
|
|
71
|
+
/** Like `step()` but returns the full output Float32Array instead of just
|
|
72
|
+
* its first element. For training graphs this is rarely useful (the output
|
|
73
|
+
* *is* a scalar loss); it's the primary API for forward-only compiles. */
|
|
74
|
+
run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
|
|
75
|
+
run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
|
|
76
|
+
run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
|
|
77
|
+
/** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
|
|
78
|
+
* `uploadInitialParams()` for a full training reset without recompile. */
|
|
79
|
+
resetOptimizerState(): void
|
|
29
80
|
/** Free GPU resources. */
|
|
30
81
|
destroy(): void
|
|
31
82
|
}
|
|
32
83
|
|
|
84
|
+
/** Forward-only compiled runtime — produced by `compileForward`. No optimizer,
|
|
85
|
+
* no backward. Returns the output tensor (not just a scalar) per `run()` call. */
|
|
86
|
+
export interface CompiledForward {
|
|
87
|
+
params: Map<string, GPUBuffer>
|
|
88
|
+
uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void
|
|
89
|
+
downloadParams(): Promise<Record<string, Float32Array>>
|
|
90
|
+
/** Forward-only dispatch. Returns the graph's output tensor as a Float32Array
|
|
91
|
+
* (the user's returned tensor from the forward function, in row-major order).
|
|
92
|
+
* With `{ withCaptures: true }`, returns `{ output, captures }`. */
|
|
93
|
+
run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
|
|
94
|
+
run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
|
|
95
|
+
run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
|
|
96
|
+
destroy(): void
|
|
97
|
+
}
|
|
98
|
+
|
|
33
99
|
export interface RuntimeOpts {
|
|
34
100
|
/** Pre-acquired GPUDevice. If omitted, runtime requests its own. */
|
|
35
101
|
device?: GPUDevice
|
|
102
|
+
/** External param buffers to bind in place of allocating fresh ones, keyed
|
|
103
|
+
* by param name. Used to share params between a training compile and a
|
|
104
|
+
* sibling forward-only compile (e.g., a B=1 inference graph). When a name
|
|
105
|
+
* is in this map, the runtime reuses the provided GPUBuffer; otherwise it
|
|
106
|
+
* allocates as usual. */
|
|
107
|
+
sharedParams?: Map<string, GPUBuffer>
|
|
36
108
|
}
|
|
37
109
|
|
|
38
110
|
// Inlined numeric values (per WebGPU spec) so this module is importable in Node
|
|
@@ -52,8 +124,23 @@ export async function createRuntime(
|
|
|
52
124
|
|
|
53
125
|
// ---- Allocate one GPUBuffer per BufferSpec --------------------------------
|
|
54
126
|
// State buffers also get filled with their initValue at allocation time.
|
|
127
|
+
// Param buffers may be supplied externally via opts.sharedParams; in that
|
|
128
|
+
// case we reuse the provided GPUBuffer instead of allocating, and the
|
|
129
|
+
// sibling compile that owns it is responsible for upload + lifetime.
|
|
55
130
|
const buffers = new Map<number, GPUBuffer>()
|
|
131
|
+
const sharedParams = opts.sharedParams
|
|
56
132
|
for (const spec of plan.buffers) {
|
|
133
|
+
if (spec.kind === 'param' && sharedParams?.has(spec.name!)) {
|
|
134
|
+
const shared = sharedParams.get(spec.name!)!
|
|
135
|
+
if (shared.size !== spec.byteSize) {
|
|
136
|
+
throw new Error(
|
|
137
|
+
`sharedParams: size mismatch for '${spec.name}' — supplied ${shared.size} bytes, ` +
|
|
138
|
+
`compiled graph expects ${spec.byteSize}.`,
|
|
139
|
+
)
|
|
140
|
+
}
|
|
141
|
+
buffers.set(spec.id, shared)
|
|
142
|
+
continue
|
|
143
|
+
}
|
|
57
144
|
const buf = device.createBuffer({
|
|
58
145
|
size: spec.byteSize,
|
|
59
146
|
usage: STORAGE_RW,
|
|
@@ -69,6 +156,13 @@ export async function createRuntime(
|
|
|
69
156
|
queue.writeBuffer(buf, 0, init as unknown as BufferSource)
|
|
70
157
|
}
|
|
71
158
|
}
|
|
159
|
+
// Track which params are externally owned — those are skipped on destroy().
|
|
160
|
+
const ownedBufferIds = new Set<number>()
|
|
161
|
+
for (const spec of plan.buffers) {
|
|
162
|
+
if (!(spec.kind === 'param' && sharedParams?.has(spec.name!))) {
|
|
163
|
+
ownedBufferIds.add(spec.id)
|
|
164
|
+
}
|
|
165
|
+
}
|
|
72
166
|
|
|
73
167
|
// ---- Compile pipelines per kernel; cache by WGSL source -------------------
|
|
74
168
|
// Push an error scope around each shader+pipeline creation so we can surface
|
|
@@ -127,12 +221,45 @@ export async function createRuntime(
|
|
|
127
221
|
})
|
|
128
222
|
})
|
|
129
223
|
|
|
130
|
-
// ----
|
|
131
|
-
|
|
132
|
-
|
|
224
|
+
// ---- Output readback staging buffer ---------------------------------------
|
|
225
|
+
// `outputBufferId` is the graph's main output (loss for training, the user's
|
|
226
|
+
// returned tensor for forward-only). step() reads back its first element;
|
|
227
|
+
// run() reads back the full Float32Array.
|
|
228
|
+
const outputSpec = plan.buffers[lossBufferId]!
|
|
229
|
+
const outputReadback = device.createBuffer({ size: outputSpec.byteSize, usage: READBACK })
|
|
230
|
+
|
|
231
|
+
// ---- Capture readback staging buffers (lazy) ------------------------------
|
|
232
|
+
// Allocated on first `step({ withCaptures: true })` call and reused across
|
|
233
|
+
// subsequent calls. When the graph has no captures registered or when the
|
|
234
|
+
// caller never opts in, no extra GPU memory is allocated.
|
|
235
|
+
let captureStagings: Map<string, GPUBuffer> | null = null
|
|
236
|
+
function ensureCaptureStagings(): Map<string, GPUBuffer> {
|
|
237
|
+
if (captureStagings) return captureStagings
|
|
238
|
+
captureStagings = new Map()
|
|
239
|
+
for (const [name, bufId] of plan.capturesByName) {
|
|
240
|
+
const spec = plan.buffers[bufId]!
|
|
241
|
+
const staging = device.createBuffer({ size: spec.byteSize, usage: READBACK, label: `cap-${name}` })
|
|
242
|
+
captureStagings.set(name, staging)
|
|
243
|
+
}
|
|
244
|
+
return captureStagings
|
|
245
|
+
}
|
|
133
246
|
|
|
134
|
-
// ---- step()
|
|
135
|
-
|
|
247
|
+
// ---- dispatch() — shared core for step() and run() -----------------------
|
|
248
|
+
// Uploads inputs, dispatches all kernels (in order), queues writebacks, copies
|
|
249
|
+
// the output buffer into its staging, optionally copies captures into theirs,
|
|
250
|
+
// submits, and reads back. Returns the full output Float32Array; step() takes
|
|
251
|
+
// [0] for scalar loss, run() returns it whole.
|
|
252
|
+
async function dispatch(
|
|
253
|
+
inputs: Record<string, Int32Array | Float32Array>,
|
|
254
|
+
wantCaptures: boolean,
|
|
255
|
+
): Promise<{ output: Float32Array; captures: Record<string, Float32Array> | null }> {
|
|
256
|
+
if (wantCaptures && plan.capturesByName.size === 0) {
|
|
257
|
+
throw new Error(
|
|
258
|
+
`withCaptures=true but no capture(...) calls were registered during ` +
|
|
259
|
+
`the trace. Add capture('name', tensor) inside your forward pass for ` +
|
|
260
|
+
`the intermediates you want read back.`,
|
|
261
|
+
)
|
|
262
|
+
}
|
|
136
263
|
for (const [name, bufId] of plan.inputsByName) {
|
|
137
264
|
const data = inputs[name]
|
|
138
265
|
if (!data) throw new Error(`tensorgrad: missing input '${name}'`)
|
|
@@ -165,26 +292,93 @@ export async function createRuntime(
|
|
|
165
292
|
pass.dispatchWorkgroups(wgX, wgY, 1)
|
|
166
293
|
pass.end()
|
|
167
294
|
}
|
|
168
|
-
// After all dispatches: writebacks (Adam state, updated params).
|
|
169
|
-
//
|
|
170
|
-
// all kernel dispatches.
|
|
295
|
+
// After all dispatches: writebacks (Adam state, updated params). Empty for
|
|
296
|
+
// forward-only compiles.
|
|
171
297
|
for (const wb of plan.writebacks) {
|
|
172
298
|
encoder.copyBufferToBuffer(buffers.get(wb.source)!, 0, buffers.get(wb.dest)!, 0, wb.bytes)
|
|
173
299
|
}
|
|
174
|
-
encoder.copyBufferToBuffer(buffers.get(lossBufferId)!, 0,
|
|
300
|
+
encoder.copyBufferToBuffer(buffers.get(lossBufferId)!, 0, outputReadback, 0, outputSpec.byteSize)
|
|
301
|
+
// Capture readbacks (only when opted in). Queued before submit so they
|
|
302
|
+
// observe the same kernel outputs as the main output.
|
|
303
|
+
let stagings: Map<string, GPUBuffer> | null = null
|
|
304
|
+
if (wantCaptures) {
|
|
305
|
+
stagings = ensureCaptureStagings()
|
|
306
|
+
for (const [name, bufId] of plan.capturesByName) {
|
|
307
|
+
const spec = plan.buffers[bufId]!
|
|
308
|
+
encoder.copyBufferToBuffer(buffers.get(bufId)!, 0, stagings.get(name)!, 0, spec.byteSize)
|
|
309
|
+
}
|
|
310
|
+
}
|
|
175
311
|
queue.submit([encoder.finish()])
|
|
176
312
|
|
|
177
|
-
await
|
|
178
|
-
const
|
|
179
|
-
|
|
180
|
-
|
|
313
|
+
await outputReadback.mapAsync(GPUMapMode.READ)
|
|
314
|
+
const output = new Float32Array(outputReadback.getMappedRange().slice(0))
|
|
315
|
+
outputReadback.unmap()
|
|
316
|
+
|
|
317
|
+
if (!wantCaptures) return { output, captures: null }
|
|
318
|
+
|
|
319
|
+
const captures: Record<string, Float32Array> = {}
|
|
320
|
+
for (const [name, staging] of stagings!) {
|
|
321
|
+
await staging.mapAsync(GPUMapMode.READ)
|
|
322
|
+
captures[name] = new Float32Array(staging.getMappedRange().slice(0))
|
|
323
|
+
staging.unmap()
|
|
324
|
+
}
|
|
325
|
+
return { output, captures }
|
|
326
|
+
}
|
|
327
|
+
|
|
328
|
+
// ---- step() — training-mode wrapper, returns scalar [0] of output ---------
|
|
329
|
+
function step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>
|
|
330
|
+
function step(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<StepWithCaptures>
|
|
331
|
+
function step(inputs: Record<string, Int32Array | Float32Array>, opts: StepOptions): Promise<number | StepWithCaptures>
|
|
332
|
+
async function step(
|
|
333
|
+
inputs: Record<string, Int32Array | Float32Array>,
|
|
334
|
+
opts?: StepOptions,
|
|
335
|
+
): Promise<number | StepWithCaptures> {
|
|
336
|
+
const r = await dispatch(inputs, opts?.withCaptures === true)
|
|
337
|
+
if (opts?.withCaptures) return { loss: r.output[0]!, captures: r.captures! }
|
|
338
|
+
return r.output[0]!
|
|
339
|
+
}
|
|
340
|
+
|
|
341
|
+
// ---- run() — forward-mode wrapper, returns full output Float32Array -------
|
|
342
|
+
function run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>
|
|
343
|
+
function run(inputs: Record<string, Int32Array | Float32Array>, opts: { withCaptures: true }): Promise<RunWithCaptures>
|
|
344
|
+
function run(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunWithCaptures>
|
|
345
|
+
async function run(
|
|
346
|
+
inputs: Record<string, Int32Array | Float32Array>,
|
|
347
|
+
opts?: RunOptions,
|
|
348
|
+
): Promise<Float32Array | RunWithCaptures> {
|
|
349
|
+
const r = await dispatch(inputs, opts?.withCaptures === true)
|
|
350
|
+
if (opts?.withCaptures) return { output: r.output, captures: r.captures! }
|
|
351
|
+
return r.output
|
|
181
352
|
}
|
|
182
353
|
|
|
183
354
|
// ---- uploadParams ---------------------------------------------------------
|
|
184
|
-
function uploadParams(params: Record<string, Float32Array
|
|
355
|
+
function uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions) {
|
|
356
|
+
const partial = opts?.partial ?? false
|
|
357
|
+
for (const name of Object.keys(params)) {
|
|
358
|
+
if (!plan.paramsByName.has(name)) {
|
|
359
|
+
throw new Error(
|
|
360
|
+
`uploadParams: unknown param '${name}'. ` +
|
|
361
|
+
`Known: ${[...plan.paramsByName.keys()].sort().join(', ')}`,
|
|
362
|
+
)
|
|
363
|
+
}
|
|
364
|
+
}
|
|
365
|
+
if (!partial) {
|
|
366
|
+
for (const name of plan.paramsByName.keys()) {
|
|
367
|
+
if (!(name in params)) {
|
|
368
|
+
throw new Error(
|
|
369
|
+
`uploadParams: missing param '${name}'. ` +
|
|
370
|
+
`Pass { partial: true } if you mean to update only some params.`,
|
|
371
|
+
)
|
|
372
|
+
}
|
|
373
|
+
}
|
|
374
|
+
}
|
|
185
375
|
for (const [name, bufId] of plan.paramsByName) {
|
|
186
376
|
const data = params[name]
|
|
187
377
|
if (!data) continue
|
|
378
|
+
const expected = plan.buffers[bufId]!.byteSize / 4
|
|
379
|
+
if (data.length !== expected) {
|
|
380
|
+
throw new Error(`uploadParams: '${name}' has ${data.length} elements, expected ${expected}`)
|
|
381
|
+
}
|
|
188
382
|
queue.writeBuffer(buffers.get(bufId)!, 0, data as unknown as BufferSource)
|
|
189
383
|
}
|
|
190
384
|
}
|
|
@@ -210,15 +404,59 @@ export async function createRuntime(
|
|
|
210
404
|
return out
|
|
211
405
|
}
|
|
212
406
|
|
|
407
|
+
function resetOptimizerState() {
|
|
408
|
+
for (const spec of plan.buffers) {
|
|
409
|
+
if (spec.kind !== 'state') continue
|
|
410
|
+
const elements = spec.byteSize / 4
|
|
411
|
+
const init = spec.dtype === 'f32'
|
|
412
|
+
? new Float32Array(elements).fill(spec.initValue ?? 0)
|
|
413
|
+
: new Int32Array(elements).fill(Math.trunc(spec.initValue ?? 0))
|
|
414
|
+
queue.writeBuffer(buffers.get(spec.id)!, 0, init as unknown as BufferSource)
|
|
415
|
+
}
|
|
416
|
+
}
|
|
417
|
+
|
|
418
|
+
// Build the params map AFTER buffer allocation so it points at the actual
|
|
419
|
+
// GPUBuffers (shared or freshly allocated).
|
|
420
|
+
const params = new Map<string, GPUBuffer>()
|
|
421
|
+
for (const [name, bufId] of plan.paramsByName) {
|
|
422
|
+
params.set(name, buffers.get(bufId)!)
|
|
423
|
+
}
|
|
424
|
+
|
|
425
|
+
const destroy = () => {
|
|
426
|
+
for (const [id, b] of buffers) {
|
|
427
|
+
if (ownedBufferIds.has(id)) b.destroy()
|
|
428
|
+
}
|
|
429
|
+
outputReadback.destroy()
|
|
430
|
+
if (captureStagings) for (const b of captureStagings.values()) b.destroy()
|
|
431
|
+
}
|
|
432
|
+
|
|
213
433
|
return {
|
|
434
|
+
params,
|
|
214
435
|
uploadParams,
|
|
215
436
|
downloadParams: () => downloadFromMap(plan.paramsByName),
|
|
216
437
|
downloadParamGrads: () => downloadFromMap(plan.paramGradsByName),
|
|
217
438
|
step,
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
439
|
+
run,
|
|
440
|
+
resetOptimizerState,
|
|
441
|
+
destroy,
|
|
442
|
+
}
|
|
443
|
+
}
|
|
444
|
+
|
|
445
|
+
/** Same machinery as `createRuntime`, narrower public API: no step,
|
|
446
|
+
* no resetOptimizerState, no downloadParamGrads. Used by `compileForward`. */
|
|
447
|
+
export async function createForwardRuntime(
|
|
448
|
+
plan: BufferPlan,
|
|
449
|
+
kernels: KernelSpec[],
|
|
450
|
+
outputBufferId: number,
|
|
451
|
+
opts: RuntimeOpts = {},
|
|
452
|
+
): Promise<CompiledForward> {
|
|
453
|
+
const full = await createRuntime(plan, kernels, outputBufferId, opts)
|
|
454
|
+
return {
|
|
455
|
+
params: full.params,
|
|
456
|
+
uploadParams: full.uploadParams,
|
|
457
|
+
downloadParams: full.downloadParams,
|
|
458
|
+
run: full.run,
|
|
459
|
+
destroy: full.destroy,
|
|
222
460
|
}
|
|
223
461
|
}
|
|
224
462
|
|
package/src/trace.ts
CHANGED
|
@@ -19,6 +19,10 @@ import { makeGraph, addOp, captureSite } from './ir.js'
|
|
|
19
19
|
|
|
20
20
|
// Module-local: the graph being built right now, or null if no trace is active.
|
|
21
21
|
let _current: Graph | null = null
|
|
22
|
+
// Module-local: whether `capture(name, t)` calls should register on the current
|
|
23
|
+
// graph. True only during the user's forward trace; false during `traceInto`
|
|
24
|
+
// (autograd / optimizer ops shouldn't accidentally publish gradient tensors).
|
|
25
|
+
let _captureEnabled = false
|
|
22
26
|
|
|
23
27
|
export function currentGraph(): Graph {
|
|
24
28
|
if (!_current) {
|
|
@@ -30,6 +34,10 @@ export function currentGraph(): Graph {
|
|
|
30
34
|
return _current
|
|
31
35
|
}
|
|
32
36
|
|
|
37
|
+
export function isCaptureEnabled(): boolean {
|
|
38
|
+
return _captureEnabled
|
|
39
|
+
}
|
|
40
|
+
|
|
33
41
|
// Run `fn` with a fresh graph as the current one; capture and return the graph.
|
|
34
42
|
// `fn` must return the tensor (or array of tensors) to mark as graph outputs.
|
|
35
43
|
export function trace(fn: () => Tensor | Tensor[]): Graph {
|
|
@@ -38,6 +46,7 @@ export function trace(fn: () => Tensor | Tensor[]): Graph {
|
|
|
38
46
|
}
|
|
39
47
|
const g = makeGraph()
|
|
40
48
|
_current = g
|
|
49
|
+
_captureEnabled = true
|
|
41
50
|
try {
|
|
42
51
|
const result = fn()
|
|
43
52
|
const outputs = Array.isArray(result) ? result : [result]
|
|
@@ -46,6 +55,7 @@ export function trace(fn: () => Tensor | Tensor[]): Graph {
|
|
|
46
55
|
}
|
|
47
56
|
} finally {
|
|
48
57
|
_current = null
|
|
58
|
+
_captureEnabled = false
|
|
49
59
|
}
|
|
50
60
|
return g
|
|
51
61
|
}
|
|
@@ -53,12 +63,15 @@ export function trace(fn: () => Tensor | Tensor[]): Graph {
|
|
|
53
63
|
// Re-enter an existing graph to append more ops. Used by autograd to add
|
|
54
64
|
// backward ops to a graph that's already been traced. `fn` runs with the
|
|
55
65
|
// supplied graph as the current one; any ops it calls append to that graph.
|
|
66
|
+
// Capture is intentionally disabled here — backward / optimizer rules
|
|
67
|
+
// shouldn't publish their internal tensors via `capture()`.
|
|
56
68
|
// Returns whatever `fn` returns.
|
|
57
69
|
export function traceInto<T>(g: Graph, fn: () => T): T {
|
|
58
70
|
if (_current) {
|
|
59
71
|
throw new Error('tensorgrad: traceInto() called while another trace is active')
|
|
60
72
|
}
|
|
61
73
|
_current = g
|
|
74
|
+
// _captureEnabled stays false (default) — explicit, but not toggled.
|
|
62
75
|
try {
|
|
63
76
|
return fn()
|
|
64
77
|
} finally {
|