tensorgrad 0.0.11 → 0.0.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. package/README.md +119 -119
  2. package/dist/buffers.js +1 -6
  3. package/dist/buffers.js.map +1 -1
  4. package/dist/codegen.js +30 -28
  5. package/dist/codegen.js.map +1 -1
  6. package/dist/compile.js +39 -68
  7. package/dist/compile.js.map +1 -1
  8. package/dist/grad.js +1 -14
  9. package/dist/grad.js.map +1 -1
  10. package/dist/index.d.ts +740 -14
  11. package/dist/runtime.js +9 -11
  12. package/dist/runtime.js.map +1 -1
  13. package/dist/trace.js +8 -13
  14. package/dist/trace.js.map +1 -1
  15. package/package.json +67 -61
  16. package/src/buffers.ts +1 -6
  17. package/src/codegen.ts +31 -28
  18. package/src/compile.ts +45 -91
  19. package/src/grad.ts +1 -11
  20. package/src/index.ts +47 -47
  21. package/src/runtime.ts +520 -515
  22. package/src/trace.ts +12 -9
  23. package/dist/adam.d.ts +0 -65
  24. package/dist/adam.d.ts.map +0 -1
  25. package/dist/buffers.d.ts +0 -57
  26. package/dist/buffers.d.ts.map +0 -1
  27. package/dist/capture.d.ts +0 -3
  28. package/dist/capture.d.ts.map +0 -1
  29. package/dist/codegen.d.ts +0 -23
  30. package/dist/codegen.d.ts.map +0 -1
  31. package/dist/compile.d.ts +0 -130
  32. package/dist/compile.d.ts.map +0 -1
  33. package/dist/grad.d.ts +0 -8
  34. package/dist/grad.d.ts.map +0 -1
  35. package/dist/index.d.ts.map +0 -1
  36. package/dist/ir.d.ts +0 -207
  37. package/dist/ir.d.ts.map +0 -1
  38. package/dist/module.d.ts +0 -55
  39. package/dist/module.d.ts.map +0 -1
  40. package/dist/nn.d.ts +0 -42
  41. package/dist/nn.d.ts.map +0 -1
  42. package/dist/ops.d.ts +0 -48
  43. package/dist/ops.d.ts.map +0 -1
  44. package/dist/runtime.d.ts +0 -108
  45. package/dist/runtime.d.ts.map +0 -1
  46. package/dist/shape.d.ts +0 -24
  47. package/dist/shape.d.ts.map +0 -1
  48. package/dist/trace.d.ts +0 -9
  49. package/dist/trace.d.ts.map +0 -1
package/src/compile.ts CHANGED
@@ -11,11 +11,11 @@
 import type { Tensor, Shape, Dtype } from './ir.js'
 import { trace, tensorInput } from './trace.js'
 import { appendGrad, type GradResult } from './grad.js'
-import { appendAdam, type AdamConfig } from './adam.js'
+import { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
 import { planBuffers, type BufferPlan } from './buffers.js'
 import { emitKernels, type KernelSpec } from './codegen.js'
 import { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts } from './runtime.js'
-import { Module, materializeParams } from './module.js'
+import { Module, materializeParams, type MaterializedParams } from './module.js'
 
 /** Declares one input tensor of the model's forward function. The name is the
  * key in the `inputs:` Record at compile time and the key on the `step()`/
@@ -150,46 +150,39 @@ export async function compileModule<M extends Module, I extends InputDecls = Inp
   forward: ForwardFn<M, I>,
   opts: CompileModuleOptions<I> = {},
 ): Promise<CompiledModule<M>> {
-  const { runtime, materialized, plan, kernels, ir } = await buildModuleRuntime(
-    modelFactory, forward, opts, /* sharedParams */ undefined, /* withGrad */ true,
-  )
-
-  // If Adam is enabled, wrap step() to track the step count and supply lrt
-  // (and optionally decayShrink, when the user passed a per-step lr schedule).
-  // Wrap resetOptimizerState() too, so a reset zeros m/v *and* the bias-correction
-  // counter otherwise the next step would skip Adam's warmup phase.
-  if (opts.adam) {
-    wrapStepForAdam(runtime, opts.adam, ir)
-  }
+  const { graph, materialized } = traceModule(modelFactory, forward, opts.inputs ?? {})
+  const { paramGrads, loss } = appendGrad(graph)
+  const adamResult = opts.adam
+    ? appendAdam(graph, paramGrads, materialized.tensors, opts.adam, materialized.decayFlags)
+    : undefined
+
+  const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? [])
+  const kernels = emitKernels(graph, plan)
+  const lossBufferId = plan.tensorToBuffer.get(loss.id)!
+  const runtime = await createRuntime(plan, kernels, lossBufferId, opts)
 
-  // Auto-upload initial param values. Always wanted at this entry point —
-  // training runtimes own their params and need them randomized before step 1.
+  if (adamResult) wrapStepForAdam(runtime, adamResult)
   uploadInitialParams(plan, materialized.initFns, runtime, /* sharedParams */ undefined)
 
-  const kernelCount = kernels.filter(k => k.wgsl).length
+  const ir: CompiledIR = { graph, paramGrads, loss, plan, kernels }
+  const kernelCount = countKernels(kernels)
 
   const reset = () => {
     uploadInitialParams(plan, materialized.initFns, runtime, undefined)
     runtime.resetOptimizerState()
   }
 
-  const compileForwardMethod = async <J extends InputDecls>(
+  const compileForwardMethod = <J extends InputDecls>(
     forwardFn: ForwardFn<M, J>,
     fOpts: CompileForwardMethodOptions<J> = {},
-  ): Promise<CompiledForwardModule> => {
-    return compileForward<M, J>(modelFactory, forwardFn, {
+  ): Promise<CompiledForwardModule> =>
+    compileForward<M, J>(modelFactory, forwardFn, {
       ...fOpts,
      device: runtime.device,
       sharedParams: runtime.params,
     })
-  }
 
-  return Object.assign(runtime, {
-    ir,
-    kernelCount,
-    reset,
-    compileForward: compileForwardMethod,
-  })
+  return Object.assign(runtime, { ir, kernelCount, reset, compileForward: compileForwardMethod })
 }
 
 // ============================================================================
@@ -220,47 +213,37 @@ export async function compileForward<M extends Module, I extends InputDecls = In
   forward: ForwardFn<M, I>,
   opts: CompileForwardOptions<I> = {},
 ): Promise<CompiledForwardModule> {
-  const sharedParams = opts.sharedParams
-  const { runtime, materialized, plan, kernels, ir } = await buildModuleRuntime(
-    modelFactory, forward, opts, sharedParams, /* withGrad */ false,
-  )
+  const { graph, materialized } = traceModule(modelFactory, forward, opts.inputs ?? {})
+  const outputTensor = graph.tensors[graph.outputs[0]!]!
 
-  // Auto-upload initial values for any params this graph owns. With
-  // `sharedParams` covering everything, this is a no-op.
-  uploadInitialParams(plan, materialized.initFns, runtime, sharedParams)
+  const plan = planBuffers(graph, /* paramGrads */ {})
+  const kernels = emitKernels(graph, plan)
+  const outputBufferId = plan.tensorToBuffer.get(outputTensor.id)!
+  const runtime = await createForwardRuntime(plan, kernels, outputBufferId, opts)
+
+  uploadInitialParams(plan, materialized.initFns, runtime, opts.sharedParams)
 
-  const kernelCount = kernels.filter(k => k.wgsl).length
-  return Object.assign(runtime, { ir, kernelCount })
+  const ir: CompiledIR = { graph, paramGrads: {}, loss: outputTensor, plan, kernels }
+  return Object.assign(runtime, { ir, kernelCount: countKernels(kernels) })
 }
 
 // ============================================================================
 // Internals
 // ============================================================================
 
+type Graph = ReturnType<typeof trace>
 type InitFn = (size: number, shape: readonly number[]) => Float32Array
 
-interface BuiltRuntime {
-  runtime: CompiledRuntime
-  materialized: ReturnType<typeof materializeParams>
-  plan: BufferPlan
-  kernels: KernelSpec[]
-  ir: CompiledIR
-}
-
-/** Shared body of compileModule + compileForward. The training and forward
- * pipelines diverge only in (a) whether grad/Adam are appended and (b)
- * whether the output buffer is the loss scalar or the user's returned
- * tensor — both come out of the same trace and codegen path. */
-async function buildModuleRuntime<M extends Module, I extends InputDecls>(
+/** Trace the forward function with a fresh model + tensor inputs and capture
+ * the materialized params. Shared by both compile entry points; everything
+ * past this point (grad/adam/buffer plan/runtime) diverges. */
+function traceModule<M extends Module, I extends InputDecls>(
   modelFactory: () => M,
   forward: ForwardFn<M, I>,
-  opts: CompileModuleOptions<I> | CompileForwardOptions<I>,
-  sharedParams: Map<string, GPUBuffer> | undefined,
-  withGrad: boolean,
-): Promise<BuiltRuntime> {
-  const inputDecls: InputDecls = opts.inputs ?? {}
+  inputDecls: InputDecls,
+): { graph: Graph; materialized: MaterializedParams } {
  const model = modelFactory()
-  let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {}, decayFlags: {} }
+  let materialized: MaterializedParams = { tensors: {}, initFns: {}, decayFlags: {} }
   const graph = trace(() => {
     materialized = materializeParams(model)
     const inputTensors: Record<string, Tensor> = {}
@@ -269,45 +252,17 @@ async function buildModuleRuntime<M extends Module, I extends InputDecls>(
     }
     return forward(model, inputTensors as InputsTensors<I>)
   })
-
-  let paramGrads: GradResult['paramGrads'] = {}
-  let outputTensor: Tensor
-  let adamWritebacks: ReturnType<typeof appendAdam>['writebacks'] = []
-
-  if (withGrad) {
-    const gradResult = appendGrad(graph)
-    paramGrads = gradResult.paramGrads
-    outputTensor = gradResult.loss
-    const adamCfg = (opts as CompileModuleOptions).adam
-    if (adamCfg) {
-      const adamResult = appendAdam(graph, paramGrads, materialized.tensors, adamCfg, materialized.decayFlags)
-      adamWritebacks = adamResult.writebacks
-      // Stash adam result on the graph so wrapStepForAdam can find it.
-      ;(graph as Graph & { __adam?: ReturnType<typeof appendAdam> }).__adam = adamResult
-    }
-  } else {
-    outputTensor = graph.tensors[graph.outputs[0]!]!
-  }
-
-  const plan = planBuffers(graph, paramGrads, adamWritebacks)
-  const kernels = emitKernels(graph, plan)
-  const outputBufferId = plan.tensorToBuffer.get(outputTensor.id)!
-  // exactOptionalPropertyTypes: only include sharedParams when defined.
-  const runtimeOpts: RuntimeOpts = sharedParams
-    ? { ...opts, sharedParams }
-    : { ...opts }
-  const runtime = withGrad
-    ? await createRuntime(plan, kernels, outputBufferId, runtimeOpts)
-    : await createForwardRuntime(plan, kernels, outputBufferId, runtimeOpts)
-
-  const ir: CompiledIR = { graph, paramGrads, loss: outputTensor, plan, kernels }
-  return { runtime: runtime as CompiledRuntime, materialized, plan, kernels, ir }
+  return { graph, materialized }
 }
 
-type Graph = ReturnType<typeof trace>
+const countKernels = (kernels: KernelSpec[]): number => kernels.filter(k => k.wgsl).length
 
-function wrapStepForAdam(runtime: CompiledRuntime, adamCfg: AdamConfig, ir: CompiledIR): void {
-  const adamResult = (ir.graph as Graph & { __adam?: ReturnType<typeof appendAdam> }).__adam!
+/** Wrap the runtime's step() to inject Adam's per-step `lrt` (bias-corrected
+ * effective LR) and, when the user supplied a per-step lr schedule, the
+ * decayShrink scalar. Also wraps resetOptimizerState() so a reset zeros
+ * Adam's m/v *and* the bias-correction step counter — otherwise the next
+ * step would skip Adam's warmup phase. */
+function wrapStepForAdam(runtime: CompiledRuntime, adamResult: AdamResult): void {
   const { lrtInputName, decayShrinkInputName, config } = adamResult
   let t = 0
   const lrtBuf = new Float32Array(1)
@@ -333,7 +288,6 @@ function wrapStepForAdam(runtime: CompiledRuntime, adamCfg: AdamConfig, ir: Comp
     t = 0
     innerReset()
   }
-  void adamCfg
 }
 
 /** Build a Record<paramName, Float32Array> by running each param's init
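
The `lrt` named in `wrapStepForAdam`'s doc comment is Adam's bias-corrected effective learning rate, i.e. the per-step alpha_t from Kingma & Ba (2015), folded into one scalar so the GPU kernel can consume raw m/v. A minimal sketch of that scalar, assuming conventional `lr`/`beta1`/`beta2` names; tensorgrad's actual `AdamConfig` keys are not shown in this diff:

// Sketch only: the value a wrapper like wrapStepForAdam would write into its
// Float32Array(1) input before each step.
//   lrt = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
function effectiveLr(lr: number, beta1: number, beta2: number, t: number): number {
  // t is 1-based; this is why a reset must zero the counter too, or the first
  // step after reset would reuse a stale t and skip the bias correction.
  return lr * Math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
}

// e.g. effectiveLr(1e-3, 0.9, 0.999, 1) ≈ 3.16e-4; the factor approaches 1 as t grows.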
package/src/grad.ts CHANGED
@@ -25,6 +25,7 @@ import {
   sumLast, where,
 } from './ops.js'
 import { traceInto } from './trace.js'
+import { shapesEqual } from './shape.js'
 
 // ============================================================================
 // Public API
@@ -121,11 +122,6 @@ function unbroadcast(cotan: Tensor, toShape: Shape): Tensor {
   return sumToShape(cotan, toShape)
 }
 
-function shapesEqual(a: Shape, b: Shape): boolean {
-  if (a.length !== b.length) return false
-  for (let i = 0; i < a.length; i++) if (a[i] !== b[i]) return false
-  return true
-}
 
 // ============================================================================
 // Transpose rules
@@ -435,12 +431,6 @@ function runTransposeRule(
 // Helpers
 // ============================================================================
 
-function identityPerm(rank: number): number[] {
-  const p: number[] = new Array(rank)
-  for (let i = 0; i < rank; i++) p[i] = i
-  return p
-}
-
 function invertPerm(perm: readonly number[]): number[] {
   const inv: number[] = new Array(perm.length)
   for (let i = 0; i < perm.length; i++) inv[perm[i]!] = i
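
For orientation on the `shapesEqual` move: its grad.ts caller is `unbroadcast`, which reduces a cotangent back to an input's pre-broadcast shape. A sketch of the pair, reusing the exact body the diff removes; the early-return guard and the `sumToShape` stand-in are inferred from the names in the hunk above, not verified against shape.ts:

type Shape = readonly number[]
interface Tensor { shape: Shape }

// Identical to the helper deleted above (now exported from shape.ts).
function shapesEqual(a: Shape, b: Shape): boolean {
  if (a.length !== b.length) return false
  for (let i = 0; i < a.length; i++) if (a[i] !== b[i]) return false
  return true
}

// Assumed behavior: sums a tensor down to `to`, collapsing broadcast axes.
declare function sumToShape(t: Tensor, to: Shape): Tensor

function unbroadcast(cotan: Tensor, toShape: Shape): Tensor {
  if (shapesEqual(cotan.shape, toShape)) return cotan // no broadcasting happened
  return sumToShape(cotan, toShape)                   // fold broadcast axes back down
}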
package/src/index.ts CHANGED
@@ -1,47 +1,47 @@
-// Public surface. Bulb code imports from here.
-//
-// Phase 1 exports: IR types, op surface, trace driver. Autograd (Phase 2) and
-// codegen / compile() (Phase 3+) come later.
-
-export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js'
-export { ShapeError } from './shape.js'
-export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js'
-export { capture } from './capture.js'
-export {
-  // Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
-  add, sub, mul, div,
-  // Element-wise unary
-  sqrt, rsqrt, log, exp, relu,
-  // Comparisons + select
-  less, greater, where,
-  // Reductions over the last axis (other axes via reshape/transpose first)
-  meanLast, sumLast, sumAll,
-  // Shape ops
-  reshape, transpose, swapAxes,
-  // Linear algebra
-  matmul, matmulBatched,
-  // Indexing / casting
-  oneHot, arange, embedding,
-  // ML primitives — fused for the transformer
-  softmaxCausalLast, logSoftmaxLast, whereCausal,
-  // Slicing
-  sliceLastRange,
-} from './ops.js'
-
-// Note: addScalar/mulScalar/broadcastTo/sumToShape/constScalar/reluGrad/adam_update_*
-// are autograd/optimizer building blocks. They live in ops.ts (so grad.ts and
-// adam.ts can import them) but aren't part of the public API — `add`/`mul`
-// overload on JS numbers, `where` subsumes the rest.
-export { appendGrad, type GradResult } from './grad.js'
-export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
-export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
-export { emitKernels, type KernelSpec } from './codegen.js'
-export { createRuntime, createForwardRuntime, Captures, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type RunOptions, type StepResult, type RunResult } from './runtime.js'
-export {
-  compile, compileToIR, compileModule, compileForward,
-  type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type CompileForwardMethodOptions,
-  type CompiledModule, type CompiledForwardModule,
-  type InputDecl, type InputDecls, type InputsTensors, type ForwardFn,
-} from './compile.js'
-export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
-export * as nn from './nn.js'
+// Public surface. Bulb code imports from here.
+//
+// Phase 1 exports: IR types, op surface, trace driver. Autograd (Phase 2) and
+// codegen / compile() (Phase 3+) come later.
+
+export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js'
+export { ShapeError } from './shape.js'
+export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js'
+export { capture } from './capture.js'
+export {
+  // Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
+  add, sub, mul, div,
+  // Element-wise unary
+  sqrt, rsqrt, log, exp, relu,
+  // Comparisons + select
+  less, greater, where,
+  // Reductions over the last axis (other axes via reshape/transpose first)
+  meanLast, sumLast, sumAll,
+  // Shape ops
+  reshape, transpose, swapAxes,
+  // Linear algebra
+  matmul, matmulBatched,
+  // Indexing / casting
+  oneHot, arange, embedding,
+  // ML primitives — fused for the transformer
+  softmaxCausalLast, logSoftmaxLast, whereCausal,
+  // Slicing
+  sliceLastRange,
+} from './ops.js'
+
+// Note: addScalar/mulScalar/broadcastTo/sumToShape/constScalar/reluGrad/adam_update_*
+// are autograd/optimizer building blocks. They live in ops.ts (so grad.ts and
+// adam.ts can import them) but aren't part of the public API — `add`/`mul`
+// overload on JS numbers, `where` subsumes the rest.
+export { appendGrad, type GradResult } from './grad.js'
+export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
+export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
+export { emitKernels, type KernelSpec } from './codegen.js'
+export { createRuntime, createForwardRuntime, Captures, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type RunOptions, type StepResult, type RunResult } from './runtime.js'
+export {
+  compile, compileToIR, compileModule, compileForward,
+  type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type CompileForwardMethodOptions,
+  type CompiledModule, type CompiledForwardModule,
+  type InputDecl, type InputDecls, type InputsTensors, type ForwardFn,
+} from './compile.js'
+export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
+export * as nn from './nn.js'
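
The note in index.ts above ("`add`/`mul` overload on JS numbers") is why `addScalar`/`constScalar` stay internal. A self-contained sketch of that overload pattern; the `constScalar`/`emitBinop` declarations here are illustrative stand-ins, not tensorgrad's real IR helpers:

interface Tensor { shape: readonly number[] }

// Stand-ins for the ops.ts internals the note names.
declare function constScalar(value: number, shape: readonly number[]): Tensor
declare function emitBinop(op: 'add' | 'mul', a: Tensor, b: Tensor): Tensor

// One public binop serves both add(x, y) and add(x, 3): a bare number is
// lowered to a scalar-constant node broadcast to a's shape, so the public
// surface needs no separate addScalar export.
function add(a: Tensor, b: Tensor | number): Tensor {
  const rhs = typeof b === 'number' ? constScalar(b, a.shape) : b
  return emitBinop('add', a, rhs)
}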