tensorgrad 0.0.2 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. package/README.md +7 -9
  2. package/dist/buffers.d.ts +1 -0
  3. package/dist/buffers.d.ts.map +1 -1
  4. package/dist/buffers.js +12 -1
  5. package/dist/buffers.js.map +1 -1
  6. package/dist/capture.d.ts +3 -0
  7. package/dist/capture.d.ts.map +1 -0
  8. package/dist/capture.js +33 -0
  9. package/dist/capture.js.map +1 -0
  10. package/dist/compile.d.ts +33 -5
  11. package/dist/compile.d.ts.map +1 -1
  12. package/dist/compile.js +96 -11
  13. package/dist/compile.js.map +1 -1
  14. package/dist/index.d.ts +5 -3
  15. package/dist/index.d.ts.map +1 -1
  16. package/dist/index.js +4 -2
  17. package/dist/index.js.map +1 -1
  18. package/dist/ir.d.ts +1 -0
  19. package/dist/ir.d.ts.map +1 -1
  20. package/dist/ir.js +1 -1
  21. package/dist/ir.js.map +1 -1
  22. package/dist/module.d.ts +30 -4
  23. package/dist/module.d.ts.map +1 -1
  24. package/dist/module.js +39 -13
  25. package/dist/module.js.map +1 -1
  26. package/dist/nn.d.ts +19 -0
  27. package/dist/nn.d.ts.map +1 -0
  28. package/dist/nn.js +60 -0
  29. package/dist/nn.js.map +1 -0
  30. package/dist/runtime.d.ts +79 -4
  31. package/dist/runtime.d.ts.map +1 -1
  32. package/dist/runtime.js +153 -19
  33. package/dist/runtime.js.map +1 -1
  34. package/dist/trace.d.ts +1 -0
  35. package/dist/trace.d.ts.map +1 -1
  36. package/dist/trace.js +12 -0
  37. package/dist/trace.js.map +1 -1
  38. package/package.json +1 -2
  39. package/src/buffers.ts +14 -1
  40. package/src/capture.ts +36 -0
  41. package/src/compile.ts +112 -13
  42. package/src/index.ts +5 -3
  43. package/src/ir.ts +5 -1
  44. package/src/module.ts +75 -11
  45. package/src/nn.ts +59 -0
  46. package/src/runtime.ts +260 -22
  47. package/src/trace.ts +13 -0
  48. package/SPEC.md +0 -293
package/dist/trace.js CHANGED
@@ -16,6 +16,10 @@
16
16
  import { makeGraph, addOp, captureSite } from './ir.js';
17
17
  // Module-local: the graph being built right now, or null if no trace is active.
18
18
  let _current = null;
19
+ // Module-local: whether `capture(name, t)` calls should register on the current
20
+ // graph. True only during the user's forward trace; false during `traceInto`
21
+ // (autograd / optimizer ops shouldn't accidentally publish gradient tensors).
22
+ let _captureEnabled = false;
19
23
  export function currentGraph() {
20
24
  if (!_current) {
21
25
  throw new Error('tensorgrad: ops can only be called inside trace(). ' +
@@ -23,6 +27,9 @@ export function currentGraph() {
23
27
  }
24
28
  return _current;
25
29
  }
30
// Whether `capture()` should register tensors right now: trace() sets the flag
// just before running the user's forward fn and clears it in its `finally`;
// traceInto() never sets it, so ops appended there cannot publish captures.
+ export function isCaptureEnabled() {
31
+ return _captureEnabled;
32
+ }
26
33
  // Run `fn` with a fresh graph as the current one; capture and return the graph.
27
34
  // `fn` must return the tensor (or array of tensors) to mark as graph outputs.
28
35
  export function trace(fn) {
@@ -31,6 +38,7 @@ export function trace(fn) {
31
38
  }
32
39
  const g = makeGraph();
33
40
  _current = g;
41
+ _captureEnabled = true;
34
42
  try {
35
43
  const result = fn();
36
44
  const outputs = Array.isArray(result) ? result : [result];
@@ -41,18 +49,22 @@ export function trace(fn) {
41
49
  }
42
50
  finally {
43
51
  _current = null;
52
+ _captureEnabled = false;
44
53
  }
45
54
  return g;
46
55
  }
47
56
  // Re-enter an existing graph to append more ops. Used by autograd to add
48
57
  // backward ops to a graph that's already been traced. `fn` runs with the
49
58
  // supplied graph as the current one; any ops it calls append to that graph.
59
+ // Capture is intentionally disabled here — backward / optimizer rules
60
+ // shouldn't publish their internal tensors via `capture()`.
50
61
  // Returns whatever `fn` returns.
51
62
  export function traceInto(g, fn) {
52
63
  if (_current) {
53
64
  throw new Error('tensorgrad: traceInto() called while another trace is active');
54
65
  }
55
66
  _current = g;
67
+ // _captureEnabled stays false (default) — explicit, but not toggled.
56
68
  try {
57
69
  return fn();
58
70
  }
package/dist/trace.js.map CHANGED
@@ -1 +1 @@
1
- {"version":3,"file":"trace.js","sourceRoot":"","sources":["../src/trace.ts"],"names":[],"mappings":"AAAA,6EAA6E;AAC7E,2EAA2E;AAC3E,EAAE;AACF,SAAS;AACT,EAAE;AACF,gCAAgC;AAChC,gDAAgD;AAChD,+CAA+C;AAC/C,mDAAmD;AACnD,6BAA6B;AAC7B,OAAO;AACP,EAAE;AACF,+EAA+E;AAC/E,+EAA+E;AAC/E,+CAA+C;AAG/C,OAAO,EAAE,SAAS,EAAE,KAAK,EAAE,WAAW,EAAE,MAAM,SAAS,CAAA;AAEvD,gFAAgF;AAChF,IAAI,QAAQ,GAAiB,IAAI,CAAA;AAEjC,MAAM,UAAU,YAAY;IAC1B,IAAI,CAAC,QAAQ,EAAE,CAAC;QACd,MAAM,IAAI,KAAK,CACb,qDAAqD;YACrD,2CAA2C,CAC5C,CAAA;IACH,CAAC;IACD,OAAO,QAAQ,CAAA;AACjB,CAAC;AAED,gFAAgF;AAChF,8EAA8E;AAC9E,MAAM,UAAU,KAAK,CAAC,EAA2B;IAC/C,IAAI,QAAQ,EAAE,CAAC;QACb,MAAM,IAAI,KAAK,CAAC,6CAA6C,CAAC,CAAA;IAChE,CAAC;IACD,MAAM,CAAC,GAAG,SAAS,EAAE,CAAA;IACrB,QAAQ,GAAG,CAAC,CAAA;IACZ,IAAI,CAAC;QACH,MAAM,MAAM,GAAG,EAAE,EAAE,CAAA;QACnB,MAAM,OAAO,GAAG,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAA;QACzD,KAAK,MAAM,CAAC,IAAI,OAAO,EAAE,CAAC;YACxB,CAAC;YAAC,CAAC,CAAC,OAAoB,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAA;QACrC,CAAC;IACH,CAAC;YAAS,CAAC;QACT,QAAQ,GAAG,IAAI,CAAA;IACjB,CAAC;IACD,OAAO,CAAC,CAAA;AACV,CAAC;AAED,yEAAyE;AACzE,yEAAyE;AACzE,4EAA4E;AAC5E,iCAAiC;AACjC,MAAM,UAAU,SAAS,CAAI,CAAQ,EAAE,EAAW;IAChD,IAAI,QAAQ,EAAE,CAAC;QACb,MAAM,IAAI,KAAK,CAAC,8DAA8D,CAAC,CAAA;IACjF,CAAC;IACD,QAAQ,GAAG,CAAC,CAAA;IACZ,IAAI,CAAC;QACH,OAAO,EAAE,EAAE,CAAA;IACb,CAAC;YAAS,CAAC;QACT,QAAQ,GAAG,IAAI,CAAA;IACjB,CAAC;AACH,CAAC;AAED,+EAA+E;AAC/E,6EAA6E;AAC7E,4EAA4E;AAC5E,sBAAsB;AAEtB,MAAM,UAAU,UAAU,CAAC,IAAY,EAAE,KAAY,EAAE,QAAe,KAAK;IACzE,MAAM,CAAC,GAAG,YAAY,EAAE,CAAA;IACxB,IAAI,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,IAAI,KAAK,aAAa,IAAI,EAAE,CAAC,IAAI,KAAK,cAAc,CAAC,IAAI,EAAE,CAAC,IAAI,KAAK,IAAI,CAAC,EAAE,CAAC;QACpG,MAAM,IAAI,KAAK,CAAC,2BAA2B,IAAI,8BAA8B,CAAC,CAAA;IAChF,CAAC;IACD,MAAM,IAAI,GAAG,WAAW,CAAC,YAAY,CAAC,CAAA;IACtC,OAAO,KAAK,CAAC,CAAC,EAAE,aAAa,EAAE,KAAK,EAAE,KAAK,EAAE,IAAI,EAAE,EAAE,IAAI,EAAS,CAAC,CAAA;AACrE,CAAC;AAED,MAAM,UAAU,WAAW,CAAC,IAAY,EAAE,KAAY,EAAE,QAAe,KAAK;IAC1E,MAAM,C
AAC,GAAG,YAAY,EAAE,CAAA;IACxB,IAAI,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,IAAI,KAAK,aAAa,IAAI,EAAE,CAAC,IAAI,KAAK,cAAc,CAAC,IAAI,EAAE,CAAC,IAAI,KAAK,IAAI,CAAC,EAAE,CAAC;QACpG,MAAM,IAAI,KAAK,CAAC,2BAA2B,IAAI,8BAA8B,CAAC,CAAA;IAChF,CAAC;IACD,MAAM,IAAI,GAAG,WAAW,CAAC,aAAa,CAAC,CAAA;IACvC,OAAO,KAAK,CAAC,CAAC,EAAE,cAAc,EAAE,KAAK,EAAE,KAAK,EAAE,IAAI,EAAE,EAAE,IAAI,EAAS,CAAC,CAAA;AACtE,CAAC;AAED,uFAAuF;AACvF,mFAAmF;AACnF,MAAM,UAAU,UAAU,CAAC,IAAY,EAAE,KAAY,EAAE,QAAe,KAAK,EAAE,SAAS,GAAG,CAAC;IACxF,MAAM,CAAC,GAAG,YAAY,EAAE,CAAA;IACxB,IAAI,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,IAAI,KAAK,aAAa,IAAI,EAAE,CAAC,IAAI,KAAK,IAAI,CAAC,EAAE,CAAC;QACpE,MAAM,IAAI,KAAK,CAAC,2BAA2B,IAAI,8BAA8B,CAAC,CAAA;IAChF,CAAC;IACD,MAAM,IAAI,GAAG,WAAW,CAAC,YAAY,CAAC,CAAA;IACtC,OAAO,KAAK,CAAC,CAAC,EAAE,aAAa,EAAE,KAAK,EAAE,KAAK,EAAE,IAAI,EAAE,EAAE,IAAI,EAAE,SAAS,EAAS,CAAC,CAAA;AAChF,CAAC"}
1
+ {"version":3,"file":"trace.js","sourceRoot":"","sources":["../src/trace.ts"],"names":[],"mappings":"AAAA,6EAA6E;AAC7E,2EAA2E;AAC3E,EAAE;AACF,SAAS;AACT,EAAE;AACF,gCAAgC;AAChC,gDAAgD;AAChD,+CAA+C;AAC/C,mDAAmD;AACnD,6BAA6B;AAC7B,OAAO;AACP,EAAE;AACF,+EAA+E;AAC/E,+EAA+E;AAC/E,+CAA+C;AAG/C,OAAO,EAAE,SAAS,EAAE,KAAK,EAAE,WAAW,EAAE,MAAM,SAAS,CAAA;AAEvD,gFAAgF;AAChF,IAAI,QAAQ,GAAiB,IAAI,CAAA;AACjC,gFAAgF;AAChF,6EAA6E;AAC7E,8EAA8E;AAC9E,IAAI,eAAe,GAAG,KAAK,CAAA;AAE3B,MAAM,UAAU,YAAY;IAC1B,IAAI,CAAC,QAAQ,EAAE,CAAC;QACd,MAAM,IAAI,KAAK,CACb,qDAAqD;YACrD,2CAA2C,CAC5C,CAAA;IACH,CAAC;IACD,OAAO,QAAQ,CAAA;AACjB,CAAC;AAED,MAAM,UAAU,gBAAgB;IAC9B,OAAO,eAAe,CAAA;AACxB,CAAC;AAED,gFAAgF;AAChF,8EAA8E;AAC9E,MAAM,UAAU,KAAK,CAAC,EAA2B;IAC/C,IAAI,QAAQ,EAAE,CAAC;QACb,MAAM,IAAI,KAAK,CAAC,6CAA6C,CAAC,CAAA;IAChE,CAAC;IACD,MAAM,CAAC,GAAG,SAAS,EAAE,CAAA;IACrB,QAAQ,GAAG,CAAC,CAAA;IACZ,eAAe,GAAG,IAAI,CAAA;IACtB,IAAI,CAAC;QACH,MAAM,MAAM,GAAG,EAAE,EAAE,CAAA;QACnB,MAAM,OAAO,GAAG,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAA;QACzD,KAAK,MAAM,CAAC,IAAI,OAAO,EAAE,CAAC;YACxB,CAAC;YAAC,CAAC,CAAC,OAAoB,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAA;QACrC,CAAC;IACH,CAAC;YAAS,CAAC;QACT,QAAQ,GAAG,IAAI,CAAA;QACf,eAAe,GAAG,KAAK,CAAA;IACzB,CAAC;IACD,OAAO,CAAC,CAAA;AACV,CAAC;AAED,yEAAyE;AACzE,yEAAyE;AACzE,4EAA4E;AAC5E,sEAAsE;AACtE,4DAA4D;AAC5D,iCAAiC;AACjC,MAAM,UAAU,SAAS,CAAI,CAAQ,EAAE,EAAW;IAChD,IAAI,QAAQ,EAAE,CAAC;QACb,MAAM,IAAI,KAAK,CAAC,8DAA8D,CAAC,CAAA;IACjF,CAAC;IACD,QAAQ,GAAG,CAAC,CAAA;IACZ,qEAAqE;IACrE,IAAI,CAAC;QACH,OAAO,EAAE,EAAE,CAAA;IACb,CAAC;YAAS,CAAC;QACT,QAAQ,GAAG,IAAI,CAAA;IACjB,CAAC;AACH,CAAC;AAED,+EAA+E;AAC/E,6EAA6E;AAC7E,4EAA4E;AAC5E,sBAAsB;AAEtB,MAAM,UAAU,UAAU,CAAC,IAAY,EAAE,KAAY,EAAE,QAAe,KAAK;IACzE,MAAM,CAAC,GAAG,YAAY,EAAE,CAAA;IACxB,IAAI,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,IAAI,KAAK,aAAa,IAAI,EAAE,CAAC,IAAI,KAAK,cAAc,CAAC,IAAI,EAAE,CAAC,IAAI,KAAK,IAAI,CAAC,EAAE,CAAC;QACpG,MAAM,IAAI,KAAK,CAAC,2BAA2B,IAAI,8BAA8B,CAAC,CAAA;IAChF,CAAC;IACD,
MAAM,IAAI,GAAG,WAAW,CAAC,YAAY,CAAC,CAAA;IACtC,OAAO,KAAK,CAAC,CAAC,EAAE,aAAa,EAAE,KAAK,EAAE,KAAK,EAAE,IAAI,EAAE,EAAE,IAAI,EAAS,CAAC,CAAA;AACrE,CAAC;AAED,MAAM,UAAU,WAAW,CAAC,IAAY,EAAE,KAAY,EAAE,QAAe,KAAK;IAC1E,MAAM,CAAC,GAAG,YAAY,EAAE,CAAA;IACxB,IAAI,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,IAAI,KAAK,aAAa,IAAI,EAAE,CAAC,IAAI,KAAK,cAAc,CAAC,IAAI,EAAE,CAAC,IAAI,KAAK,IAAI,CAAC,EAAE,CAAC;QACpG,MAAM,IAAI,KAAK,CAAC,2BAA2B,IAAI,8BAA8B,CAAC,CAAA;IAChF,CAAC;IACD,MAAM,IAAI,GAAG,WAAW,CAAC,aAAa,CAAC,CAAA;IACvC,OAAO,KAAK,CAAC,CAAC,EAAE,cAAc,EAAE,KAAK,EAAE,KAAK,EAAE,IAAI,EAAE,EAAE,IAAI,EAAS,CAAC,CAAA;AACtE,CAAC;AAED,uFAAuF;AACvF,mFAAmF;AACnF,MAAM,UAAU,UAAU,CAAC,IAAY,EAAE,KAAY,EAAE,QAAe,KAAK,EAAE,SAAS,GAAG,CAAC;IACxF,MAAM,CAAC,GAAG,YAAY,EAAE,CAAA;IACxB,IAAI,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,IAAI,KAAK,aAAa,IAAI,EAAE,CAAC,IAAI,KAAK,IAAI,CAAC,EAAE,CAAC;QACpE,MAAM,IAAI,KAAK,CAAC,2BAA2B,IAAI,8BAA8B,CAAC,CAAA;IAChF,CAAC;IACD,MAAM,IAAI,GAAG,WAAW,CAAC,YAAY,CAAC,CAAA;IACtC,OAAO,KAAK,CAAC,CAAC,EAAE,aAAa,EAAE,KAAK,EAAE,KAAK,EAAE,IAAI,EAAE,EAAE,IAAI,EAAE,SAAS,EAAS,CAAC,CAAA;AAChF,CAAC"}
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "tensorgrad",
3
- "version": "0.0.2",
3
+ "version": "0.0.4",
4
4
  "description": "Tiny TypeScript-native tensor library with autograd, compiling to WebGPU. Train small models in the browser without hand-writing kernels.",
5
5
  "license": "MIT",
6
6
  "author": "Ben Albahari",
@@ -45,7 +45,6 @@
45
45
  "files": [
46
46
  "dist",
47
47
  "src",
48
- "SPEC.md",
49
48
  "README.md",
50
49
  "LICENSE"
51
50
  ],
package/src/buffers.ts CHANGED
@@ -47,6 +47,7 @@ export interface BufferPlan {
47
47
  inputsByName: Map<string, number> // name -> buffer id
48
48
  paramGradsByName: Map<string, number> // name -> buffer id
49
49
  statesByName: Map<string, number> // name -> buffer id (persistent state homes)
50
+ capturesByName: Map<string, number> // name -> buffer id (activation captures)
50
51
  outputBufferIds: number[] // graph.outputs mapped through
51
52
  /** End-of-step writebacks (Adam updates for params, m, v, etc.) */
52
53
  writebacks: Writeback[]
@@ -169,5 +170,17 @@ export function planBuffers(
169
170
  return { source: sourceBufId, dest: destBufId, bytes: sourceSpec.byteSize }
170
171
  })
171
172
 
172
- return { buffers, tensorToBuffer, paramsByName, inputsByName, paramGradsByName, statesByName, outputBufferIds, writebacks }
173
+ // Resolve graph.captures (name -> tensor id) to (name -> buffer id).
174
+ // No pinning needed at the planner level: each tensor already has its own
175
+ // buffer (see "v1 strategy" comment at top — no pooling yet).
176
+ const capturesByName = new Map<string, number>()
177
+ for (const [name, tensorId] of graph.captures) {
178
+ const bufId = tensorToBuffer.get(tensorId)
179
+ if (bufId === undefined) {
180
+ throw new Error(`planBuffers: capture '${name}' references unknown tensor #${tensorId}`)
181
+ }
182
+ capturesByName.set(name, bufId)
183
+ }
184
+
185
+ return { buffers, tensorToBuffer, paramsByName, inputsByName, paramGradsByName, statesByName, capturesByName, outputBufferIds, writebacks }
173
186
  }
package/src/capture.ts ADDED
@@ -0,0 +1,36 @@
1
+ // Activation capture — opt-in readback of intermediate tensors at training step.
2
+ //
3
+ // Usage (inside the user's forward pass):
4
+ //
5
+ // import { capture } from 'tensorgrad'
6
+ //
7
+ // function attentionFwd(p, x) {
8
+ // const scores = mul(matmulBatched(q, kT), SCALE_QK)
9
+ // const attn = capture(`attn.${layerIdx}`, softmaxCausalLast(scores))
10
+ // return matmulBatched(attn, v)
11
+ // }
12
+ //
13
+ // Pass-through return type: `capture(name, t)` returns `t` unchanged so it
14
+ // inlines at the point of computation. Behind the scenes it registers `t.id`
15
+ // against `name` on the current graph; runtime exposes the registered tensors
16
+ // via `step(inputs, { withCaptures: true })`.
17
+ //
18
+ // Outside the user's forward trace (during `appendGrad` / `appendAdam`'s
19
+ // `traceInto` re-entry), `capture()` is a no-op — gradient and optimizer
20
+ // internals shouldn't accidentally publish themselves to the UI.
21
+
22
+ import type { Tensor } from './ir.js'
23
+ import { currentGraph, isCaptureEnabled } from './trace.js'
24
+
25
+ export function capture<T extends Tensor>(name: string, t: T): T {
26
+ if (!isCaptureEnabled()) return t
27
+ const g = currentGraph()
28
+ if (g.captures.has(name)) {
29
+ throw new Error(
30
+ `capture: name '${name}' already registered. Use unique names ` +
31
+ `(e.g. \`attn.\${layerIdx}\`) when capturing across a loop.`,
32
+ )
33
+ }
34
+ g.captures.set(name, t.id)
35
+ return t
36
+ }
package/src/compile.ts CHANGED
@@ -14,7 +14,7 @@ import { appendGrad, type GradResult } from './grad.js'
14
14
  import { appendAdam, type AdamConfig } from './adam.js'
15
15
  import { planBuffers, type BufferPlan } from './buffers.js'
16
16
  import { emitKernels, type KernelSpec } from './codegen.js'
17
- import { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js'
17
+ import { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts } from './runtime.js'
18
18
  import { Module, materializeParams } from './module.js'
19
19
 
20
20
  /** Declares one input tensor of the model's forward function. Order matches
@@ -65,10 +65,19 @@ export interface CompileModuleOptions extends RuntimeOpts {
65
65
  adam?: AdamConfig
66
66
  }
67
67
 
68
/**
 * Options accepted by `compileForward`. Every `RuntimeOpts` field is inherited
 * (including the `sharedParams` binding that `compileForward` reads); the only
 * addition is the list of per-step data inputs.
 */
export interface CompileForwardOptions extends RuntimeOpts {
  /** Data inputs of the forward function, passed to it positionally in this order. */
  inputs?: InputDecl[]
}
72
+
68
73
  /**
69
- * Compile a Module-based model. The forward function takes the materialized
70
- * model and returns the loss tensor (typically by also calling tensorInput
71
- * for tokens/targets/masks inside).
74
+ * Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
75
+ * model instance itself: compilation mutates the tree (every `ParamSentinel`
76
+ * field becomes a real `Tensor`), so the instance is consumed and shouldn't be
77
+ * referenced afterwards. Re-call the factory if you need a fresh tree.
78
+ *
79
+ * The forward function takes the materialized model and returns the loss
80
+ * tensor.
72
81
  *
73
82
  * Walks the module tree to materialize params with auto-derived names, then
74
83
  * runs trace → grad → adam → buffer plan → codegen → runtime.
@@ -78,14 +87,15 @@ export interface CompileModuleOptions extends RuntimeOpts {
78
87
  * users don't need to provide it themselves.
79
88
  */
80
89
  export async function compileModule<M extends Module>(
81
- model: M,
90
+ modelFactory: () => M,
82
91
  forward: (m: M, ...inputs: Tensor[]) => Tensor,
83
92
  opts: CompileModuleOptions = {},
84
- ): Promise<CompiledRuntime & { ir: CompiledIR }> {
93
+ ): Promise<CompiledRuntime & { ir: CompiledIR; uploadInitialParams: () => void }> {
85
94
  const inputDecls = opts.inputs ?? []
86
- let paramTensors: Record<string, Tensor> = {}
95
+ const model = modelFactory()
96
+ let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {} }
87
97
  const graph = trace(() => {
88
- paramTensors = materializeParams(model)
98
+ materialized = materializeParams(model)
89
99
  const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
90
100
  return forward(model, ...inputTensors)
91
101
  })
@@ -94,7 +104,7 @@ export async function compileModule<M extends Module>(
94
104
 
95
105
  let adamResult: ReturnType<typeof appendAdam> | undefined
96
106
  if (opts.adam) {
97
- adamResult = appendAdam(graph, paramGrads, paramTensors, opts.adam)
107
+ adamResult = appendAdam(graph, paramGrads, materialized.tensors, opts.adam)
98
108
  }
99
109
 
100
110
  const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? [])
@@ -103,18 +113,107 @@ export async function compileModule<M extends Module>(
103
113
  const runtime = await createRuntime(plan, kernels, lossBufferId, opts)
104
114
 
105
115
  // If Adam is enabled, wrap step() to track the step count and supply lrt.
116
+ // Wrap resetOptimizerState() too, so a reset zeros m/v *and* the bias-correction
117
+ // counter — otherwise the next step would skip Adam's warmup phase.
106
118
  if (adamResult) {
107
119
  const { lrtInputName, config } = adamResult
108
120
  let t = 0
109
121
  const lrtBuf = new Float32Array(1)
110
- const innerStep = runtime.step.bind(runtime)
111
- runtime.step = async (inputs) => {
122
+ const innerStep = runtime.step.bind(runtime) as CompiledRuntime['step']
123
+ const innerReset = runtime.resetOptimizerState.bind(runtime)
124
+ const wrappedStep = (
125
+ inputs: Record<string, Int32Array | Float32Array>,
126
+ opts?: { withCaptures?: boolean },
127
+ ): Promise<number | { loss: number; captures: Record<string, Float32Array> }> => {
112
128
  t++
113
129
  lrtBuf[0] = config.lr * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t))
114
- return innerStep({ ...inputs, [lrtInputName]: lrtBuf })
130
+ const merged = { ...inputs, [lrtInputName]: lrtBuf }
131
+ return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged)
132
+ }
133
+ runtime.step = wrappedStep as CompiledRuntime['step']
134
+ runtime.resetOptimizerState = () => {
135
+ t = 0
136
+ innerReset()
137
+ }
138
+ }
139
+
140
+ const { initFns } = materialized
141
+ const uploadInitialParams = () => {
142
+ const out: Record<string, Float32Array> = {}
143
+ for (const [name, bufId] of plan.paramsByName) {
144
+ const shape = plan.buffers[bufId]!.shape
145
+ const size = shape.reduce((a, b) => a * b, 1)
146
+ const initFn = initFns[name]
147
+ if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
148
+ out[name] = initFn(size, shape)
115
149
  }
150
+ runtime.uploadParams(out)
116
151
  }
117
152
 
118
153
  const ir: CompiledIR = { graph, paramGrads, loss, plan, kernels }
119
- return Object.assign(runtime, { ir })
154
+ return Object.assign(runtime, { ir, uploadInitialParams })
155
+ }
156
+
157
+ // ============================================================================
158
+ // Forward-only compile
159
+ // ============================================================================
160
+
161
/**
 * Compile a Module-based model in forward-only mode (no autograd, no Adam).
 * The forward function returns the output tensor (e.g., logits) instead of a
 * scalar loss; runtime exposes `run(inputs)` returning the full output as a
 * `Float32Array`.
 *
 * As with `compileModule`, pass a *factory* `() => new Model()`: compilation
 * replaces the model's param sentinels with real Tensors, consuming the
 * instance the factory returns.
 *
 * **Sharing params with a training compile.** Pass `opts.sharedParams =
 * trainCompiled.params` to bind this graph's param buffers to an existing
 * training runtime's GPU buffers — every train step is then immediately
 * visible to `run()` calls here, no copies. The forward graph's
 * `uploadInitialParams()` skips any param covered by `sharedParams`.
 *
 * Typical use: a B=1 inference graph alongside a B=512 training graph,
 * built from the same `Module` factory.
 *
 * @throws Error if the traced forward produced no output tensor, or if the
 *   buffer planner did not assign the output tensor a buffer.
 */
export async function compileForward<M extends Module>(
  modelFactory: () => M,
  forward: (m: M, ...inputs: Tensor[]) => Tensor,
  opts: CompileForwardOptions = {},
): Promise<CompiledForward & { ir: CompiledIR; uploadInitialParams: () => void }> {
  const inputDecls = opts.inputs ?? []
  const model = modelFactory()
  let materialized: ReturnType<typeof materializeParams> = { tensors: {}, initFns: {} }
  const graph = trace(() => {
    // Params must be materialized inside the trace: paramInput appends to the
    // currently active graph.
    materialized = materializeParams(model)
    const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'))
    return forward(model, ...inputTensors)
  })

  // Forward-only: no param gradients, no optimizer writebacks.
  const plan = planBuffers(graph, /* paramGrads */ {})
  const kernels = emitKernels(graph, plan)

  // Resolve the traced output to its buffer with explicit checks rather than
  // non-null assertions, so a malformed graph fails here with a clear message
  // instead of as an opaque crash inside the runtime.
  const outputTensorId = graph.outputs[0]
  const outputTensor = outputTensorId === undefined ? undefined : graph.tensors[outputTensorId]
  if (!outputTensor) {
    throw new Error('compileForward: forward function produced no output tensor')
  }
  const outputBufferId = plan.tensorToBuffer.get(outputTensor.id)
  if (outputBufferId === undefined) {
    throw new Error(`compileForward: output tensor #${outputTensor.id} was not assigned a buffer`)
  }
  const runtime = await createForwardRuntime(plan, kernels, outputBufferId, opts)

  const sharedParams = opts.sharedParams
  const { initFns } = materialized
  const uploadInitialParams = () => {
    const out: Record<string, Float32Array> = {}
    let needsUpload = false
    for (const [name, bufId] of plan.paramsByName) {
      // Skip params covered by sharedParams — those are owned by the providing
      // compile and already initialized there.
      if (sharedParams?.has(name)) continue
      const shape = plan.buffers[bufId]!.shape
      const size = shape.reduce((a, b) => a * b, 1)
      const initFn = initFns[name]
      if (!initFn) throw new Error(`uploadInitialParams: no init for param '${name}'`)
      out[name] = initFn(size, shape)
      needsUpload = true
    }
    // With sharedParams present, `out` covers only the non-shared params.
    if (needsUpload) runtime.uploadParams(out, { partial: !!sharedParams })
  }

  // CompiledIR.loss is the field name; for forward-only, it carries the user's
  // returned tensor (e.g., logits). Same shape conceptually; just no autograd.
  const ir: CompiledIR = { graph, paramGrads: {}, loss: outputTensor, plan, kernels }
  return Object.assign(runtime, { ir, uploadInitialParams })
}
package/src/index.ts CHANGED
@@ -6,6 +6,7 @@
6
6
  export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js'
7
7
  export { ShapeError } from './shape.js'
8
8
  export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js'
9
+ export { capture } from './capture.js'
9
10
  export {
10
11
  // Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
11
12
  add, sub, mul, div,
@@ -35,6 +36,7 @@ export { appendGrad, type GradResult } from './grad.js'
35
36
  export { appendAdam, type AdamConfig, type AdamResult } from './adam.js'
36
37
  export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js'
37
38
  export { emitKernels, type KernelSpec } from './codegen.js'
38
- export { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js'
39
- export { compile, compileToIR, compileModule, type CompiledIR, type CompileModuleOptions, type InputDecl } from './compile.js'
40
- export { Module, materializeParams } from './module.js'
39
+ export { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type StepOptions, type StepWithCaptures, type RunOptions, type RunWithCaptures } from './runtime.js'
40
+ export { compile, compileToIR, compileModule, compileForward, type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type InputDecl } from './compile.js'
41
+ export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js'
42
+ export * as nn from './nn.js'
package/src/ir.ts CHANGED
@@ -141,10 +141,14 @@ export interface Graph {
141
141
  // Names of tensors that should be exposed as outputs of the compiled function.
142
142
  // Set by the trace driver; for a loss function, this is `[lossTensor]`.
143
143
  readonly outputs: number[]
144
+ // Tensors registered for activation readback via `capture(name, t)`.
145
+ // Keyed by user-supplied name; insertion order preserved. Empty when no
146
+ // captures registered (the common training case — zero overhead).
147
+ readonly captures: Map<string, number>
144
148
  }
145
149
 
146
150
  // Fresh, empty graph. `captures` starts as an empty Map, so a graph whose
  // trace never calls capture() carries no captured tensors.
  export function makeGraph(): Graph {
147
- return { ops: [], tensors: [], outputs: [] }
151
+ return { ops: [], tensors: [], outputs: [], captures: new Map() }
148
152
  }
149
153
 
150
154
  // Internal: register a fresh tensor in the graph and return its id.
package/src/module.ts CHANGED
@@ -6,8 +6,8 @@
6
6
  // W: Tensor; b: Tensor
7
7
  // constructor(inDim: number, outDim: number) {
8
8
  // super()
9
- // this.W = this.param([inDim, outDim])
10
- // this.b = this.param([outDim])
9
+ // this.W = this.param([inDim, outDim]) // randn, scale 0.02
10
+ // this.b = this.param([outDim], { init: 'zeros' })
11
11
  // }
12
12
  // }
13
13
  // class Block extends Module {
@@ -28,6 +28,54 @@
28
28
  import type { Tensor, Shape, Dtype } from './ir.js'
29
29
  import { paramInput } from './trace.js'
30
30
 
31
// ============================================================================
// Init metadata
// ============================================================================

/** How a parameter's initial values are produced.
 *  - `'randn'` — Gaussian, with `scale` (default 0.02). The common case for
 *    weight matrices and embeddings.
 *  - `'zeros'` — fill with 0. Common for biases and LayerNorm beta.
 *  - `'ones'`  — fill with 1. Common for LayerNorm gain.
 *  - Custom function — receives total element count and shape, and must
 *    return a Float32Array of exactly `size` elements (checked each time the
 *    init runs). Use for fan-in scaling or any non-standard scheme.
 */
export type InitSpec =
  | 'randn'
  | 'zeros'
  | 'ones'
  | ((size: number, shape: readonly number[]) => Float32Array)

export interface ParamOptions {
  dtype?: Dtype
  /** Init kind. Default: `'randn'`. */
  init?: InitSpec
  /** Std dev for `'randn'`. Default 0.02; must be finite. Ignored for non-randn init. */
  scale?: number
}

type InitFn = (size: number, shape: readonly number[]) => Float32Array

/** Two independent standard-normal samples from one Box–Muller transform.
 *  `u1` is drawn from (0, 1] (`1 - Math.random()`), so `log(u1)` is always
 *  finite — no clamping, and no truncation of the distribution's tail. */
function boxMullerPair(): [number, number] {
  const radius = Math.sqrt(-2 * Math.log(1 - Math.random()))
  const theta = 2 * Math.PI * Math.random()
  return [radius * Math.cos(theta), radius * Math.sin(theta)]
}

/** Turn a param's `ParamOptions` into a concrete init function.
 *
 *  Errors that can be detected up front (a non-finite `scale`, an unknown
 *  init kind) throw immediately, i.e. at `param()` time. A custom init's
 *  output can only be checked once it runs, so the returned wrapper checks it.
 *
 *  @throws Error on a non-finite `scale` or an unknown init kind; the returned
 *    function throws if a custom init yields a non-Float32Array or the wrong length.
 */
function resolveInit(opts: ParamOptions | undefined): InitFn {
  const init = opts?.init ?? 'randn'
  if (init === 'randn') {
    const scale = opts?.scale ?? 0.02
    if (!Number.isFinite(scale)) {
      throw new Error(`randn init: scale must be a finite number, got ${scale}`)
    }
    return (size) => {
      const arr = new Float32Array(size)
      // Use both Box–Muller outputs; for an odd size the final spare sample
      // is simply dropped.
      for (let i = 0; i < size; i += 2) {
        const [z0, z1] = boxMullerPair()
        arr[i] = z0 * scale
        if (i + 1 < size) arr[i + 1] = z1 * scale
      }
      return arr
    }
  }
  if (init === 'zeros') return (size) => new Float32Array(size)
  if (init === 'ones') return (size) => new Float32Array(size).fill(1)
  if (typeof init === 'function') {
    return (size, shape) => {
      const got: unknown = init(size, shape)
      if (!(got instanceof Float32Array) || got.length !== size) {
        const what = got instanceof Float32Array ? `a Float32Array of length ${got.length}` : typeof got
        throw new Error(
          `Custom init returned ${what}; expected a Float32Array of length ${size} ` +
            `for shape [${shape.join(', ')}]`,
        )
      }
      return got
    }
  }
  throw new Error(`Unknown init: ${String(init)}`)
}
78
+
31
79
  // ============================================================================
32
80
  // Internals: param sentinel
33
81
  // ============================================================================
@@ -38,7 +86,11 @@ import { paramInput } from './trace.js'
38
86
  // only valid post-materialization (which is always before forward runs).
39
87
 
40
88
  // Placeholder returned by `Module.param()`: holds the declared shape, dtype
  // and resolved init function until `materializeParams` swaps it for a real
  // Tensor (via `paramInput`) and records `initFn` for uploadInitialParams.
  class ParamSentinel {
41
- constructor(public readonly shape: Shape, public readonly dtype: Dtype) {}
89
+ constructor(
90
+ public readonly shape: Shape,
91
+ public readonly dtype: Dtype,
92
+ public readonly initFn: InitFn,
93
+ ) {}
42
94
  }
43
95
 
44
96
  // ============================================================================
@@ -52,11 +104,13 @@ export abstract class Module {
52
104
  * that gets replaced with a real Tensor at compile time.
53
105
  *
54
106
  * The parameter's name is auto-derived from its property path in the model
55
- * tree (e.g. `layers.0.attn.W_q`).
107
+ * tree (e.g. `layers.0.attn.W_q`). Init metadata travels with the param;
108
+ * call `compiled.uploadInitialParams()` to apply it after compile.
56
109
  */
57
protected param(shape: Shape, opts?: ParamOptions | Dtype): Tensor {
  // Back-compat with the 0.0.2 signature `param(shape, dtype = 'f32')`: a bare
  // dtype string is treated as `{ dtype }`. Without this, a legacy caller
  // passing e.g. 'i32' (plain JS, or TS through a cast) would silently get an
  // f32 param with randn init, because a string has no `.dtype` / `.init`.
  const options: ParamOptions | undefined = typeof opts === 'string' ? { dtype: opts } : opts
  const dtype = options?.dtype ?? 'f32'
  // Lie to TypeScript: the sentinel becomes a Tensor at materialize time.
  return new ParamSentinel(shape, dtype, resolveInit(options)) as unknown as Tensor
}
61
115
  }
62
116
 
@@ -64,23 +118,33 @@ export abstract class Module {
64
118
  // Tree walking
65
119
  // ============================================================================
66
120
 
121
/**
 * Result of `materializeParams`: two records keyed by the same auto-derived
 * param paths (e.g. `layers.0.attn.W_q`).
 */
export interface MaterializedParams {
  /** The Tensor created (via `paramInput`) for each param path. */
  tensors: Record<string, Tensor>
  /** The init function each param was declared with; consumed by `uploadInitialParams`. */
  initFns: Record<string, InitFn>
}
127
+
67
128
  /**
68
129
  * Walk the module tree and replace every ParamSentinel with a real Tensor
69
130
  * created via `paramInput(autoName, ...)`. Must be called inside an active
70
131
  * trace context (paramInput appends to the current graph).
71
132
  *
72
- * Returns a flat record of `{ path: tensor }` for every materialized param.
133
+ * Returns the param tensors keyed by path, plus init functions for use by
134
+ * `uploadInitialParams`.
73
135
  */
74
- export function materializeParams(root: Module): Record<string, Tensor> {
75
- const out: Record<string, Tensor> = {}
136
+ export function materializeParams(root: Module): MaterializedParams {
137
+ const tensors: Record<string, Tensor> = {}
138
+ const initFns: Record<string, InitFn> = {}
76
139
  // NOTE(review): `visit` is defined further down this file (not visible here);
  // presumably it walks the tree recursively, passing each field's dotted path
  // plus its owner object and key — confirm.
  visit(root, '', (path, val, owner, key) => {
77
140
  if (val instanceof ParamSentinel) {
78
141
  const t = paramInput(path, val.shape, val.dtype)
  // Overwrite the sentinel on its owner so forward() sees the real Tensor;
  // this mutates the model instance that was passed in.
79
142
  ;(owner as any)[key] = t
80
- out[path] = t
143
+ tensors[path] = t
144
+ initFns[path] = val.initFn
81
145
  }
82
146
  })
83
- return out
147
+ return { tensors, initFns }
84
148
  }
85
149
 
86
150
  // ----------------------------------------------------------------------------
package/src/nn.ts ADDED
@@ -0,0 +1,59 @@
1
+ // Standard "batteries-included" Module subclasses for the most common layers.
2
+ //
3
+ // JAX-style: each class declares its params (and their init); the forward is a
4
+ // plain function the user calls with `(module, x)`. No subclassing, no method
5
+ // dispatch — keeps the autograd-traced computation visible at the call site.
6
+ //
7
+ // Import as a namespace:
8
+ //
9
+ // import { nn } from 'tensorgrad'
10
+ // class Block extends Module {
11
+ // ln = new nn.LayerNorm(D)
12
+ // ffn = new nn.Linear(D, 4 * D)
13
+ // }
14
+ // const y = nn.linearFwd(p.ffn, nn.layerNormFwd(p.ln, x))
15
+
16
+ import { Module } from './module.js'
17
+ import type { Tensor } from './ir.js'
18
+ import { add, matmul, sub, mul, div, sqrt, meanLast } from './ops.js'
19
+
20
+ // ----------------------------------------------------------------------------
21
+ // Linear: y = x @ W (+ b)
22
+ // ----------------------------------------------------------------------------
23
+
24
/**
 * Dense-layer parameters for `y = x @ W (+ b)`; apply them with `linearFwd`.
 *
 * - `W`: shape `[inDim, outDim]`, default `param()` init (randn, scale 0.02).
 * - `b`: shape `[outDim]`, zeros — or `null` when `withBias` is false.
 */
export class Linear extends Module {
  W: Tensor
  b: Tensor | null

  constructor(
    public readonly inDim: number,
    public readonly outDim: number,
    withBias = true,
  ) {
    super()
    this.W = this.param([inDim, outDim])
    if (withBias) {
      this.b = this.param([outDim], { init: 'zeros' })
    } else {
      this.b = null
    }
  }
}
33
+
34
/**
 * Apply a `Linear` layer: `x @ p.W`, plus `p.b` when the layer has a bias.
 */
export function linearFwd(p: Linear, x: Tensor): Tensor {
  const projected = matmul(x, p.W)
  if (!p.b) return projected
  return add(projected, p.b)
}
38
+
39
+ // ----------------------------------------------------------------------------
40
+ // LayerNorm — normalizes over the last axis. eps defaults to 1e-5.
41
+ // ----------------------------------------------------------------------------
42
+
43
/**
 * LayerNorm parameters for a trailing axis of size `d`; apply them with
 * `layerNormFwd`. Gain `g` starts at ones and shift `b` at zeros (both `[d]`),
 * so a freshly initialized layer only normalizes. `eps` (default 1e-5) is
 * added to the variance before the square root.
 */
export class LayerNorm extends Module {
  g: Tensor
  b: Tensor

  constructor(
    public readonly d: number,
    public readonly eps: number = 1e-5,
  ) {
    super()
    const gain = this.param([d], { init: 'ones' })
    const shift = this.param([d], { init: 'zeros' })
    this.g = gain
    this.b = shift
  }
}
52
+
53
/**
 * Apply a `LayerNorm`: `(x - mean) / sqrt(var + eps) * g + b`, where mean and
 * (biased) variance are taken with `meanLast`, i.e. over the last axis.
 *
 * NOTE(review): relies on `sub`/`div` broadcasting meanLast's reduced result
 * back across the last axis — confirm the reduction's output shape in ops.ts.
 */
export function layerNormFwd(p: LayerNorm, x: Tensor): Tensor {
  const centered = sub(x, meanLast(x))
  const variance = meanLast(mul(centered, centered))
  const normalized = div(centered, sqrt(add(variance, p.eps)))
  return add(mul(normalized, p.g), p.b)
}