tensorgrad 0.0.11 → 0.0.13
- package/README.md +119 -119
- package/dist/buffers.js +1 -6
- package/dist/buffers.js.map +1 -1
- package/dist/codegen.js +30 -28
- package/dist/codegen.js.map +1 -1
- package/dist/compile.js +39 -68
- package/dist/compile.js.map +1 -1
- package/dist/grad.js +1 -14
- package/dist/grad.js.map +1 -1
- package/dist/index.d.ts +740 -14
- package/dist/runtime.js +9 -11
- package/dist/runtime.js.map +1 -1
- package/dist/trace.js +8 -13
- package/dist/trace.js.map +1 -1
- package/package.json +67 -61
- package/src/buffers.ts +1 -6
- package/src/codegen.ts +31 -28
- package/src/compile.ts +45 -91
- package/src/grad.ts +1 -11
- package/src/index.ts +47 -47
- package/src/runtime.ts +520 -515
- package/src/trace.ts +12 -9
- package/dist/adam.d.ts +0 -65
- package/dist/adam.d.ts.map +0 -1
- package/dist/buffers.d.ts +0 -57
- package/dist/buffers.d.ts.map +0 -1
- package/dist/capture.d.ts +0 -3
- package/dist/capture.d.ts.map +0 -1
- package/dist/codegen.d.ts +0 -23
- package/dist/codegen.d.ts.map +0 -1
- package/dist/compile.d.ts +0 -130
- package/dist/compile.d.ts.map +0 -1
- package/dist/grad.d.ts +0 -8
- package/dist/grad.d.ts.map +0 -1
- package/dist/index.d.ts.map +0 -1
- package/dist/ir.d.ts +0 -207
- package/dist/ir.d.ts.map +0 -1
- package/dist/module.d.ts +0 -55
- package/dist/module.d.ts.map +0 -1
- package/dist/nn.d.ts +0 -42
- package/dist/nn.d.ts.map +0 -1
- package/dist/ops.d.ts +0 -48
- package/dist/ops.d.ts.map +0 -1
- package/dist/runtime.d.ts +0 -108
- package/dist/runtime.d.ts.map +0 -1
- package/dist/shape.d.ts +0 -24
- package/dist/shape.d.ts.map +0 -1
- package/dist/trace.d.ts +0 -9
- package/dist/trace.d.ts.map +0 -1
package/src/compile.ts
CHANGED
@@ -11,11 +11,11 @@
 import type { Tensor, Shape, Dtype } from './ir.js'
 import { trace, tensorInput } from './trace.js'
 import { appendGrad, type GradResult } from './grad.js'
-import { appendAdam, type AdamConfig } from './adam.js'
+import { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
 import { planBuffers, type BufferPlan } from './buffers.js'
 import { emitKernels, type KernelSpec } from './codegen.js'
 import { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts } from './runtime.js'
-import { Module, materializeParams } from './module.js'
+import { Module, materializeParams, type MaterializedParams } from './module.js'
 
 /** Declares one input tensor of the model's forward function. The name is the
  * key in the `inputs:` Record at compile time and the key on the `step()`/
@@ -150,46 +150,39 @@ export async function compileModule<M extends Module, I extends InputDecls = Inp
   forward: ForwardFn<M, I>,
   opts: CompileModuleOptions<I> = {},
 ): Promise<CompiledModule<M>> {
-  const {
-
-
-
-
-
-
-
-
-  }
+  const { graph, materialized } = traceModule(modelFactory, forward, opts.inputs ?? {})
+  const { paramGrads, loss } = appendGrad(graph)
+  const adamResult = opts.adam
+    ? appendAdam(graph, paramGrads, materialized.tensors, opts.adam, materialized.decayFlags)
+    : undefined
+
+  const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? [])
+  const kernels = emitKernels(graph, plan)
+  const lossBufferId = plan.tensorToBuffer.get(loss.id)!
+  const runtime = await createRuntime(plan, kernels, lossBufferId, opts)
 
-
-  // training runtimes own their params and need them randomized before step 1.
+  if (adamResult) wrapStepForAdam(runtime, adamResult)
   uploadInitialParams(plan, materialized.initFns, runtime, /* sharedParams */ undefined)
 
-  const
+  const ir: CompiledIR = { graph, paramGrads, loss, plan, kernels }
+  const kernelCount = countKernels(kernels)
 
   const reset = () => {
     uploadInitialParams(plan, materialized.initFns, runtime, undefined)
     runtime.resetOptimizerState()
   }
 
-  const compileForwardMethod =
+  const compileForwardMethod = <J extends InputDecls>(
     forwardFn: ForwardFn<M, J>,
     fOpts: CompileForwardMethodOptions<J> = {},
-  ): Promise<CompiledForwardModule> =>
-
+  ): Promise<CompiledForwardModule> =>
+    compileForward<M, J>(modelFactory, forwardFn, {
       ...fOpts,
       device: runtime.device,
      sharedParams: runtime.params,
     })
-  }
 
-  return Object.assign(runtime, {
-    ir,
-    kernelCount,
-    reset,
-    compileForward: compileForwardMethod,
-  })
+  return Object.assign(runtime, { ir, kernelCount, reset, compileForward: compileForwardMethod })
 }
 
 // ============================================================================
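For orientation, a hedged usage sketch of the rewritten `compileModule` pipeline above. Only the call structure and the returned members (`ir`, `kernelCount`, `reset`, `compileForward`) come from this diff; the model class, the input-decl shape (`{ shape: [...] }`), and the `AdamConfig` field (`lr`) are assumptions.

```ts
import { compileModule, Module, type Tensor } from 'tensorgrad'

// Hypothetical stand-ins: the real Module API for declaring params and the
// real loss construction are not shown in this diff.
declare class MyModel extends Module { loss(x: Tensor): Tensor }

const compiled = await compileModule(
  () => new MyModel(),                  // modelFactory: traced with a fresh model
  (model, { x }) => model.loss(x),      // ForwardFn: must return the loss tensor
  {
    inputs: { x: { shape: [32, 16] } }, // keys become step() input keys (shape field assumed)
    adam: { lr: 1e-3 },                 // presence of `adam` appends optimizer kernels and wraps step()
  },
)
compiled.reset()                        // re-uploads init params and resets optimizer state
```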
@@ -220,47 +213,37 @@ export async function compileForward<M extends Module, I extends InputDecls = In
   forward: ForwardFn<M, I>,
   opts: CompileForwardOptions<I> = {},
 ): Promise<CompiledForwardModule> {
-  const
-  const
-    modelFactory, forward, opts, sharedParams, /* withGrad */ false,
-  )
+  const { graph, materialized } = traceModule(modelFactory, forward, opts.inputs ?? {})
+  const outputTensor = graph.tensors[graph.outputs[0]!]!
 
-
-
-
+  const plan = planBuffers(graph, /* paramGrads */ {})
+  const kernels = emitKernels(graph, plan)
+  const outputBufferId = plan.tensorToBuffer.get(outputTensor.id)!
+  const runtime = await createForwardRuntime(plan, kernels, outputBufferId, opts)
+
+  uploadInitialParams(plan, materialized.initFns, runtime, opts.sharedParams)
 
-  const
-  return Object.assign(runtime, { ir, kernelCount })
+  const ir: CompiledIR = { graph, paramGrads: {}, loss: outputTensor, plan, kernels }
+  return Object.assign(runtime, { ir, kernelCount: countKernels(kernels) })
 }
 
 // ============================================================================
 // Internals
 // ============================================================================
 
+type Graph = ReturnType<typeof trace>
 type InitFn = (size: number, shape: readonly number[]) => Float32Array
 
-
-
-
-
-  kernels: KernelSpec[]
-  ir: CompiledIR
-}
-
-/** Shared body of compileModule + compileForward. The training and forward
- * pipelines diverge only in (a) whether grad/Adam are appended and (b)
- * whether the output buffer is the loss scalar or the user's returned
- * tensor — both come out of the same trace and codegen path. */
-async function buildModuleRuntime<M extends Module, I extends InputDecls>(
+/** Trace the forward function with a fresh model + tensor inputs and capture
+ * the materialized params. Shared by both compile entry points; everything
+ * past this point (grad/adam/buffer plan/runtime) diverges. */
+function traceModule<M extends Module, I extends InputDecls>(
   modelFactory: () => M,
   forward: ForwardFn<M, I>,
-
-
-  withGrad: boolean,
-): Promise<BuiltRuntime> {
-  const inputDecls: InputDecls = opts.inputs ?? {}
+  inputDecls: InputDecls,
+): { graph: Graph; materialized: MaterializedParams } {
   const model = modelFactory()
-  let materialized:
+  let materialized: MaterializedParams = { tensors: {}, initFns: {}, decayFlags: {} }
   const graph = trace(() => {
     materialized = materializeParams(model)
     const inputTensors: Record<string, Tensor> = {}
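A companion sketch for the forward-only path, continuing the `compileModule` sketch above and following the `compileForwardMethod` wiring from the previous hunk (`device` and `sharedParams` forwarded from the training runtime). The forward function and input shape remain placeholders.

```ts
// Compile an inference-only graph that shares the trained runtime's param
// buffers instead of re-initializing them.
const infer = await compiled.compileForward(
  (model, { x }) => model.loss(x),       // any forward fn over the same model type
  { inputs: { x: { shape: [1, 16] } } }, // input-decl shape field assumed, as above
)
// Per the hunk above, `infer` carries { ir, kernelCount } but no grad or Adam
// kernels: paramGrads is {} and the output buffer is the returned tensor.
```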
@@ -269,45 +252,17 @@ async function buildModuleRuntime<M extends Module, I extends InputDecls>(
     }
     return forward(model, inputTensors as InputsTensors<I>)
   })
-
-  let paramGrads: GradResult['paramGrads'] = {}
-  let outputTensor: Tensor
-  let adamWritebacks: ReturnType<typeof appendAdam>['writebacks'] = []
-
-  if (withGrad) {
-    const gradResult = appendGrad(graph)
-    paramGrads = gradResult.paramGrads
-    outputTensor = gradResult.loss
-    const adamCfg = (opts as CompileModuleOptions).adam
-    if (adamCfg) {
-      const adamResult = appendAdam(graph, paramGrads, materialized.tensors, adamCfg, materialized.decayFlags)
-      adamWritebacks = adamResult.writebacks
-      // Stash adam result on the graph so wrapStepForAdam can find it.
-      ;(graph as Graph & { __adam?: ReturnType<typeof appendAdam> }).__adam = adamResult
-    }
-  } else {
-    outputTensor = graph.tensors[graph.outputs[0]!]!
-  }
-
-  const plan = planBuffers(graph, paramGrads, adamWritebacks)
-  const kernels = emitKernels(graph, plan)
-  const outputBufferId = plan.tensorToBuffer.get(outputTensor.id)!
-  // exactOptionalPropertyTypes: only include sharedParams when defined.
-  const runtimeOpts: RuntimeOpts = sharedParams
-    ? { ...opts, sharedParams }
-    : { ...opts }
-  const runtime = withGrad
-    ? await createRuntime(plan, kernels, outputBufferId, runtimeOpts)
-    : await createForwardRuntime(plan, kernels, outputBufferId, runtimeOpts)
-
-  const ir: CompiledIR = { graph, paramGrads, loss: outputTensor, plan, kernels }
-  return { runtime: runtime as CompiledRuntime, materialized, plan, kernels, ir }
+  return { graph, materialized }
 }
 
-
+const countKernels = (kernels: KernelSpec[]): number => kernels.filter(k => k.wgsl).length
 
-
-
+/** Wrap the runtime's step() to inject Adam's per-step `lrt` (bias-corrected
+ * effective LR) and, when the user supplied a per-step lr schedule, the
+ * decayShrink scalar. Also wraps resetOptimizerState() so a reset zeros
+ * Adam's m/v *and* the bias-correction step counter — otherwise the next
+ * step would skip Adam's warmup phase. */
+function wrapStepForAdam(runtime: CompiledRuntime, adamResult: AdamResult): void {
   const { lrtInputName, decayShrinkInputName, config } = adamResult
   let t = 0
   const lrtBuf = new Float32Array(1)
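The "bias-corrected effective LR" in the `wrapStepForAdam` docstring is the standard Adam correction. A minimal standalone sketch; the beta names and defaults are the usual Adam ones, not read from this diff.

```ts
// lrt = lr * sqrt(1 - beta2^t) / (1 - beta1^t): folds Adam's m-hat and v-hat
// bias corrections into a single per-step scalar the runtime can upload.
function effectiveLr(lr: number, t: number, beta1 = 0.9, beta2 = 0.999): number {
  return lr * Math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
}
// effectiveLr(1e-3, 1) ≈ 3.16e-4; as t grows, the correction converges to lr,
// which is why resetting t alongside m/v matters (see docstring above).
```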
@@ -333,7 +288,6 @@ function wrapStepForAdam(runtime: CompiledRuntime, adamCfg: AdamConfig, ir: Comp
     t = 0
     innerReset()
   }
-  void adamCfg
 }
 
 /** Build a Record<paramName, Float32Array> by running each param's init
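The `InitFn` signature `(size, shape) => Float32Array` appears verbatim in the hunk above. A hypothetical init to make it concrete; the fan-in scaling and Box-Muller sampling are illustrative choices, not the library's defaults.

```ts
const gaussianInit = (size: number, shape: readonly number[]): Float32Array => {
  const fanIn = shape.length > 1 ? shape[shape.length - 1]! : 1
  const out = new Float32Array(size)
  for (let i = 0; i < size; i++) {
    // Box-Muller: two uniforms -> one standard normal sample, scaled by 1/sqrt(fanIn)
    const u = 1 - Math.random() // in (0, 1], avoids log(0)
    const v = Math.random()
    out[i] = Math.sqrt(-2 * Math.log(u)) * Math.cos(2 * Math.PI * v) / Math.sqrt(fanIn)
  }
  return out
}
```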
|
package/src/grad.ts
CHANGED
|
@@ -25,6 +25,7 @@ import {
|
|
|
25
25
|
sumLast, where,
|
|
26
26
|
} from './ops.js'
|
|
27
27
|
import { traceInto } from './trace.js'
|
|
28
|
+
import { shapesEqual } from './shape.js'
|
|
28
29
|
|
|
29
30
|
// ============================================================================
|
|
30
31
|
// Public API
|
|
@@ -121,11 +122,6 @@ function unbroadcast(cotan: Tensor, toShape: Shape): Tensor {
   return sumToShape(cotan, toShape)
 }
 
-function shapesEqual(a: Shape, b: Shape): boolean {
-  if (a.length !== b.length) return false
-  for (let i = 0; i < a.length; i++) if (a[i] !== b[i]) return false
-  return true
-}
 
 // ============================================================================
 // Transpose rules
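Context for `unbroadcast` above: a forward broadcast duplicates values, so its transpose sums the duplicates back, which is why the cotangent is reduced with `sumToShape`. A plain-array sketch of the [2,3] to [1,3] case, not the library's tensor code:

```ts
function sumRows(cotan: number[][]): number[] {
  // Sum over the broadcast (row) axis so the gradient matches the [1,3] input.
  const out = new Array<number>(cotan[0]?.length ?? 0).fill(0)
  for (const row of cotan) row.forEach((v, j) => { out[j] += v })
  return out
}
console.log(sumRows([[1, 2, 3], [4, 5, 6]])) // [5, 7, 9]
```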
@@ -435,12 +431,6 @@ function runTransposeRule(
 // Helpers
 // ============================================================================
 
-function identityPerm(rank: number): number[] {
-  const p: number[] = new Array(rank)
-  for (let i = 0; i < rank; i++) p[i] = i
-  return p
-}
-
 function invertPerm(perm: readonly number[]): number[] {
   const inv: number[] = new Array(perm.length)
   for (let i = 0; i < perm.length; i++) inv[perm[i]!] = i
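A worked example for the surviving `invertPerm` helper: the gradient of a transpose flows back through the inverse permutation, and `inv[perm[i]] = i` defines that inverse.

```ts
const perm = [2, 0, 1] // transpose sent source axis 2 to output axis 0, ...
const inv = [1, 2, 0]  // invertPerm(perm): output axis 0 came from source axis 2
console.assert(perm.every((p, i) => inv[p] === i)) // inv[perm[i]] === i holds
```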
package/src/index.ts
CHANGED
@@ -1,47 +1,47 @@
-// Public surface. Bulb code imports from here.
-//
-// Phase 1 exports: IR types, op surface, trace driver. Autograd (Phase 2) and
-// codegen / compile() (Phase 3+) come later.
-
-export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js'
-export { ShapeError } from './shape.js'
-export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js'
-export { capture } from './capture.js'
-export {
-  // Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
-  add, sub, mul, div,
-  // Element-wise unary
-  sqrt, rsqrt, log, exp, relu,
-  // Comparisons + select
-  less, greater, where,
-  // Reductions over the last axis (other axes via reshape/transpose first)
-  meanLast, sumLast, sumAll,
-  // Shape ops
-  reshape, transpose, swapAxes,
-  // Linear algebra
-  matmul, matmulBatched,
-  // Indexing / casting
-  oneHot, arange, embedding,
-  // ML primitives — fused for the transformer
-  softmaxCausalLast, logSoftmaxLast, whereCausal,
-  // Slicing
-  sliceLastRange,
-} from './ops.js'
-
-// Note: addScalar/mulScalar/broadcastTo/sumToShape/constScalar/reluGrad/adam_update_*
-// are autograd/optimizer building blocks. They live in ops.ts (so grad.ts and
-// adam.ts can import them) but aren't part of the public API — `add`/`mul`
-// overload on JS numbers, `where` subsumes the rest.
-export { appendGrad, type GradResult } from './grad.js'
-export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
-export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
-export { emitKernels, type KernelSpec } from './codegen.js'
-export { createRuntime, createForwardRuntime, Captures, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type RunOptions, type StepResult, type RunResult } from './runtime.js'
-export {
-  compile, compileToIR, compileModule, compileForward,
-  type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type CompileForwardMethodOptions,
-  type CompiledModule, type CompiledForwardModule,
-  type InputDecl, type InputDecls, type InputsTensors, type ForwardFn,
-} from './compile.js'
-export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
-export * as nn from './nn.js'
+// Public surface. Bulb code imports from here.
+//
+// Phase 1 exports: IR types, op surface, trace driver. Autograd (Phase 2) and
+// codegen / compile() (Phase 3+) come later.
+
+export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js'
+export { ShapeError } from './shape.js'
+export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js'
+export { capture } from './capture.js'
+export {
+  // Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
+  add, sub, mul, div,
+  // Element-wise unary
+  sqrt, rsqrt, log, exp, relu,
+  // Comparisons + select
+  less, greater, where,
+  // Reductions over the last axis (other axes via reshape/transpose first)
+  meanLast, sumLast, sumAll,
+  // Shape ops
+  reshape, transpose, swapAxes,
+  // Linear algebra
+  matmul, matmulBatched,
+  // Indexing / casting
+  oneHot, arange, embedding,
+  // ML primitives — fused for the transformer
+  softmaxCausalLast, logSoftmaxLast, whereCausal,
+  // Slicing
+  sliceLastRange,
+} from './ops.js'
+
+// Note: addScalar/mulScalar/broadcastTo/sumToShape/constScalar/reluGrad/adam_update_*
+// are autograd/optimizer building blocks. They live in ops.ts (so grad.ts and
+// adam.ts can import them) but aren't part of the public API — `add`/`mul`
+// overload on JS numbers, `where` subsumes the rest.
+export { appendGrad, type GradResult } from './grad.js'
+export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
+export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
+export { emitKernels, type KernelSpec } from './codegen.js'
+export { createRuntime, createForwardRuntime, Captures, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type RunOptions, type StepResult, type RunResult } from './runtime.js'
+export {
+  compile, compileToIR, compileModule, compileForward,
+  type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type CompileForwardMethodOptions,
+  type CompiledModule, type CompiledForwardModule,
+  type InputDecl, type InputDecls, type InputsTensors, type ForwardFn,
+} from './compile.js'
+export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
+export * as nn from './nn.js'
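To make the overload note concrete: the public binops take a plain JS number as the second operand, so the scalar helpers can stay internal. A hedged sketch (`tensorInput`'s exact signature is an assumption here; only the overload behavior is stated in the comment above):

```ts
import { trace, tensorInput, add, mul, where, less } from 'tensorgrad'

const graph = trace(() => {
  const x = tensorInput('x', [2, 3])               // (name, shape) signature assumed
  const leaky = where(less(x, 0), mul(x, 0.01), x) // `where` subsumes clamp-style helpers
  return add(leaky, 1)                             // number overload, no public addScalar needed
})
```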