tensorgrad 0.0.9 → 0.0.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -44,7 +44,7 @@ function forward(m: MLP, x: Tensor): Tensor {
44
44
  return linear(m.l3, relu(linear(m.l2, relu(linear(m.l1, x)))))
45
45
  }
46
46
 
47
- function loss(m: MLP, x: Tensor, y: Tensor): Tensor {
47
+ function loss(m: MLP, { x, y }: { x: Tensor; y: Tensor }): Tensor {
48
48
  const diff = sub(forward(m, x), y)
49
49
  return mul(sumLast(reshape(mul(diff, diff), [B])), 1 / B)
50
50
  }
@@ -52,13 +52,13 @@ function loss(m: MLP, x: Tensor, y: Tensor): Tensor {
52
52
  const B = 256
53
53
  const compiled = await compileModule(() => new MLP(), loss, {
54
54
  adam: { lr: 0.005 },
55
- inputs: [
56
- { name: 'x', shape: [B, 1], dtype: 'f32' },
57
- { name: 'y', shape: [B, 1], dtype: 'f32' },
58
- ],
55
+ inputs: {
56
+ x: { shape: [B, 1], dtype: 'f32' },
57
+ y: { shape: [B, 1], dtype: 'f32' },
58
+ },
59
59
  })
60
60
 
61
- compiled.uploadInitialParams() // applies the per-param init declared above
61
+ // Initial params are uploaded automatically — no manual step needed.
62
62
 
63
63
  for (let step = 0; step < 1000; step++) {
64
64
  const { x, y } = generateBatch()
package/dist/compile.d.ts CHANGED
@@ -5,14 +5,29 @@ import { type BufferPlan } from './buffers.js';
5
5
  import { type KernelSpec } from './codegen.js';
6
6
  import { type CompiledRuntime, type CompiledForward, type RuntimeOpts } from './runtime.js';
7
7
  import { Module } from './module.js';
8
- /** Declares one input tensor of the model's forward function. Order matches
9
- * the function's parameter list (after `model`). The `name` is used at
10
- * runtime to upload data via `step({ [name]: data })`. */
8
+ /** Declares one input tensor of the model's forward function. The name is the
9
+ * key in the `inputs:` Record at compile time and the key on the `step()`/
10
+ * `run()` data object at runtime. */
11
11
  export interface InputDecl {
12
- name: string;
13
12
  shape: Shape;
14
13
  dtype?: Dtype;
15
14
  }
15
+ /** Inputs declaration: a Record from input name to its shape/dtype. The name
16
+ * doubles as the key the forward fn destructures and the key the runtime
17
+ * expects in `step({...})` / `run({...})`. */
18
+ export type InputDecls = Record<string, InputDecl>;
19
+ /** Maps an `InputDecls` Record to its forward-time tensor counterpart —
20
+ * same keys, each value is a Tensor. Used to type the forward function's
21
+ * `inputs` argument from the declared shape Record. */
22
+ export type InputsTensors<I extends InputDecls> = {
23
+ [K in keyof I]: Tensor;
24
+ };
25
+ /** Forward function shape: takes the materialized model and a Record of
26
+ * named input tensors (matching the declared `inputs:` keys), and returns the
27
+ * output tensor (loss for compileModule; logits/etc. for compileForward).
28
+ * The second generic flows from the inputs declaration so destructuring
29
+ * the input record stays typed. */
30
+ export type ForwardFn<M extends Module, I extends InputDecls = InputDecls> = (m: M, inputs: InputsTensors<I>) => Tensor;
16
31
  export interface CompiledIR {
17
32
  graph: GradResult['graph'];
18
33
  paramGrads: GradResult['paramGrads'];
@@ -26,18 +41,45 @@ export declare function compileToIR(traceFn: () => Tensor): CompiledIR;
26
41
  export declare function compile(traceFn: () => Tensor, opts?: RuntimeOpts): Promise<CompiledRuntime & {
27
42
  ir: CompiledIR;
28
43
  }>;
29
- export interface CompileModuleOptions extends RuntimeOpts {
30
- /** Per-step data inputs to the forward function. Order matches the forward
31
- * function's parameters (after the model). e.g. for
32
- * `(model, tokens, targets, mask) => loss`, inputs is
33
- * `[{name:'tokens',...}, {name:'targets',...}, {name:'mask',...}]`. */
34
- inputs?: InputDecl[];
44
+ export interface CompileModuleOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
45
+ /** Per-step data inputs to the forward function, keyed by name. The forward
46
+ * fn destructures these out of its second argument; runtime calls to
47
+ * `step()` / `run()` pass typed arrays under the same keys. */
48
+ inputs?: I;
35
49
  /** Adam hyperparameters. If omitted, no optimizer is appended (forward-only). */
36
50
  adam?: AdamConfig;
37
51
  }
38
- export interface CompileForwardOptions extends RuntimeOpts {
39
- /** Per-step data inputs to the forward function. */
40
- inputs?: InputDecl[];
52
+ export interface CompileForwardOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
53
+ /** Per-step data inputs to the forward function, keyed by name. */
54
+ inputs?: I;
55
+ }
56
+ /** Forward-only compile options as taken by the `compileForward` *method* on
57
+ * a training runtime — no `device` (inherited) and no `sharedParams`
58
+ * (auto-supplied from the train graph's params). */
59
+ export interface CompileForwardMethodOptions<I extends InputDecls = InputDecls> {
60
+ inputs?: I;
61
+ }
62
+ /** Returned by `compileModule`. Adds training-graph extras (auto-init, reset,
63
+ * sibling-graph compile) on top of the base runtime. */
64
+ export interface CompiledModule<M extends Module> extends CompiledRuntime {
65
+ ir: CompiledIR;
66
+ /** Number of dispatchable kernels (excludes leaf no-ops). */
67
+ kernelCount: number;
68
+ /** Re-initialize all params from their declared init specs and zero the
69
+ * optimizer state. Use to start training over without recompiling. */
70
+ reset(): void;
71
+ /** Compile a sibling forward-only graph (e.g., a B=1 inference graph or a
72
+ * B=N held-out eval graph) that shares this runtime's device and param
73
+ * buffers. Pass the forward fn (typically distinct from your loss fn —
74
+ * it returns logits, not a scalar) and any shape changes via `inputs`.
75
+ * Auto-initialization is a no-op since params are shared. */
76
+ compileForward<I extends InputDecls>(forward: ForwardFn<M, I>, opts?: CompileForwardMethodOptions<I>): Promise<CompiledForwardModule>;
77
+ }
78
+ /** Returned by `compileForward` (and by the `compileForward` method). */
79
+ export interface CompiledForwardModule extends CompiledForward {
80
+ ir: CompiledIR;
81
+ /** Number of dispatchable kernels (excludes leaf no-ops). */
82
+ kernelCount: number;
41
83
  }
42
84
  /**
43
85
  * Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
@@ -45,37 +87,44 @@ export interface CompileForwardOptions extends RuntimeOpts {
45
87
  * field becomes a real `Tensor`), so the instance is consumed and shouldn't be
46
88
  * referenced afterwards. Re-call the factory if you need a fresh tree.
47
89
  *
48
- * The forward function takes the materialized model and returns the loss
49
- * tensor.
90
+ * The forward function takes the materialized model and a Record of named
91
+ * input tensors, and returns the loss tensor. Inputs are matched by name with the
92
+ * `inputs:` declaration:
93
+ *
94
+ * inputs: {
95
+ * tokens: { shape: [B, T], dtype: 'i32' },
96
+ * targets: { shape: [B, T], dtype: 'i32' },
97
+ * }
98
+ * forward: (m, { tokens, targets }) => …
50
99
  *
51
100
  * Walks the module tree to materialize params with auto-derived names, then
52
- * runs trace → grad → adam → buffer plan → codegen → runtime.
101
+ * runs trace → grad → adam → buffer plan → codegen → runtime. Initial
102
+ * parameter values are uploaded automatically before this function returns;
103
+ * call `reset()` later to re-randomize.
53
104
  *
54
105
  * If `opts.adam` is set, the runtime's `step()` automatically tracks an
55
106
  * internal step count and injects the bias-corrected `lrt` scalar each call;
56
107
  * users don't need to provide it themselves.
57
108
  */
58
- export declare function compileModule<M extends Module>(modelFactory: () => M, forward: (m: M, ...inputs: Tensor[]) => Tensor, opts?: CompileModuleOptions): Promise<CompiledRuntime & {
59
- ir: CompiledIR;
60
- uploadInitialParams: () => void;
61
- }>;
109
+ export declare function compileModule<M extends Module, I extends InputDecls = InputDecls>(modelFactory: () => M, forward: ForwardFn<M, I>, opts?: CompileModuleOptions<I>): Promise<CompiledModule<M>>;
62
110
  /**
63
111
  * Compile a Module-based model in forward-only mode (no autograd, no Adam).
64
112
  * The forward function returns the output tensor (e.g., logits) instead of a
65
113
  * scalar loss; runtime exposes `run(inputs)` returning the full output as a
66
114
  * `Float32Array`.
67
115
  *
116
+ * **Prefer the `compileForward` method on a training runtime** when both
117
+ * graphs use the same Module class — it auto-supplies `device` and
118
+ * `sharedParams`. This standalone form is for forward-only models with no
119
+ * training graph at all, or for sharing params across a different model.
120
+ *
68
121
  * **Sharing params with a training compile.** Pass `opts.sharedParams =
69
122
  * trainCompiled.params` to bind this graph's param buffers to an existing
70
123
  * training runtime's GPU buffers — every train step is then immediately
71
- * visible to `run()` calls here, no copies. The forward graph's
72
- * `uploadInitialParams()` skips any param covered by `sharedParams`.
124
+ * visible to `run()` calls here, no copies.
73
125
  *
74
- * Typical use: a B=1 inference graph alongside a B=512 training graph,
75
- * built from the same `Module` factory.
126
+ * Initial param values are uploaded automatically for params *not* covered
127
+ * by `sharedParams` (those are owned by the sibling compile).
76
128
  */
77
- export declare function compileForward<M extends Module>(modelFactory: () => M, forward: (m: M, ...inputs: Tensor[]) => Tensor, opts?: CompileForwardOptions): Promise<CompiledForward & {
78
- ir: CompiledIR;
79
- uploadInitialParams: () => void;
80
- }>;
129
+ export declare function compileForward<M extends Module, I extends InputDecls = InputDecls>(modelFactory: () => M, forward: ForwardFn<M, I>, opts?: CompileForwardOptions<I>): Promise<CompiledForwardModule>;
81
130
  //# sourceMappingURL=compile.d.ts.map
@@ -1 +1 @@
1
- {"version":3,"file":"compile.d.ts","sourceRoot":"","sources":["../src/compile.ts"],"names":[],"mappings":"AAUA,OAAO,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,SAAS,CAAA;AAEnD,OAAO,EAAc,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACvD,OAAO,EAAc,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACvD,OAAO,EAAe,KAAK,UAAU,EAAE,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAe,KAAK,UAAU,EAAE,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAuC,KAAK,eAAe,EAAE,KAAK,eAAe,EAAE,KAAK,WAAW,EAAE,MAAM,cAAc,CAAA;AAChI,OAAO,EAAE,MAAM,EAAqB,MAAM,aAAa,CAAA;AAEvD;;2DAE2D;AAC3D,MAAM,WAAW,SAAS;IACxB,IAAI,EAAE,MAAM,CAAA;IACZ,KAAK,EAAE,KAAK,CAAA;IACZ,KAAK,CAAC,EAAE,KAAK,CAAA;CACd;AAED,MAAM,WAAW,UAAU;IACzB,KAAK,EAAE,UAAU,CAAC,OAAO,CAAC,CAAA;IAC1B,UAAU,EAAE,UAAU,CAAC,YAAY,CAAC,CAAA;IACpC,IAAI,EAAE,MAAM,CAAA;IACZ,IAAI,EAAE,UAAU,CAAA;IAChB,OAAO,EAAE,UAAU,EAAE,CAAA;CACtB;AAED,yEAAyE;AACzE,wBAAgB,WAAW,CAAC,OAAO,EAAE,MAAM,MAAM,GAAG,UAAU,CAM7D;AAED,0EAA0E;AAC1E,wBAAsB,OAAO,CAAC,OAAO,EAAE,MAAM,MAAM,EAAE,IAAI,GAAE,WAAgB,GAAG,OAAO,CAAC,eAAe,GAAG;IAAE,EAAE,EAAE,UAAU,CAAA;CAAE,CAAC,CAK1H;AAMD,MAAM,WAAW,oBAAqB,SAAQ,WAAW;IACvD;;;4EAGwE;IACxE,MAAM,CAAC,EAAE,SAAS,EAAE,CAAA;IACpB,iFAAiF;IACjF,IAAI,CAAC,EAAE,UAAU,CAAA;CAClB;AAED,MAAM,WAAW,qBAAsB,SAAQ,WAAW;IACxD,oDAAoD;IACpD,MAAM,CAAC,EAAE,SAAS,EAAE,CAAA;CACrB;AAED;;;;;;;;;;;;;;;GAeG;AACH,wBAAsB,aAAa,CAAC,CAAC,SAAS,MAAM,EAClD,YAAY,EAAE,MAAM,CAAC,EACrB,OAAO,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,MAAM,EAAE,MAAM,EAAE,KAAK,MAAM,EAC9C,IAAI,GAAE,oBAAyB,GAC9B,OAAO,CAAC,eAAe,GAAG;IAAE,EAAE,EAAE,UAAU,CAAC;IAAC,mBAAmB,EAAE,MAAM,IAAI,CAAA;CAAE,CAAC,CA6DhF;AA4BD;;;;;;;;;;;;;;GAcG;AACH,wBAAsB,cAAc,CAAC,CAAC,SAAS,MAAM,EACnD,YAAY,EAAE,MAAM,CAAC,EACrB,OAAO,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,MAAM,EAAE,MAAM,EAAE,KAAK,MAAM,EAC9C,IAAI,GAAE,qBAA0B,GAC/B,OAAO,CAAC,eAAe,GAAG;IAAE,EAAE,EAAE,UAAU,CAAC;IAAC,mBAAmB,EAAE,MAAM,IAAI,CAAA;CAAE,CAAC,CA0BhF"}
1
+ {"version":3,"file":"compile.d.ts","sourceRoot":"","sources":["../src/compile.ts"],"names":[],"mappings":"AAUA,OAAO,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,SAAS,CAAA;AAEnD,OAAO,EAAc,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACvD,OAAO,EAAc,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACvD,OAAO,EAAe,KAAK,UAAU,EAAE,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAe,KAAK,UAAU,EAAE,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAuC,KAAK,eAAe,EAAE,KAAK,eAAe,EAAE,KAAK,WAAW,EAAE,MAAM,cAAc,CAAA;AAChI,OAAO,EAAE,MAAM,EAAqB,MAAM,aAAa,CAAA;AAEvD;;sCAEsC;AACtC,MAAM,WAAW,SAAS;IACxB,KAAK,EAAE,KAAK,CAAA;IACZ,KAAK,CAAC,EAAE,KAAK,CAAA;CACd;AAED;;+CAE+C;AAC/C,MAAM,MAAM,UAAU,GAAG,MAAM,CAAC,MAAM,EAAE,SAAS,CAAC,CAAA;AAElD;;wDAEwD;AACxD,MAAM,MAAM,aAAa,CAAC,CAAC,SAAS,UAAU,IAAI;KAAG,CAAC,IAAI,MAAM,CAAC,GAAG,MAAM;CAAE,CAAA;AAE5E;;;;oCAIoC;AACpC,MAAM,MAAM,SAAS,CAAC,CAAC,SAAS,MAAM,EAAE,CAAC,SAAS,UAAU,GAAG,UAAU,IACvE,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,EAAE,aAAa,CAAC,CAAC,CAAC,KAAK,MAAM,CAAA;AAE5C,MAAM,WAAW,UAAU;IACzB,KAAK,EAAE,UAAU,CAAC,OAAO,CAAC,CAAA;IAC1B,UAAU,EAAE,UAAU,CAAC,YAAY,CAAC,CAAA;IACpC,IAAI,EAAE,MAAM,CAAA;IACZ,IAAI,EAAE,UAAU,CAAA;IAChB,OAAO,EAAE,UAAU,EAAE,CAAA;CACtB;AAED,yEAAyE;AACzE,wBAAgB,WAAW,CAAC,OAAO,EAAE,MAAM,MAAM,GAAG,UAAU,CAM7D;AAED,0EAA0E;AAC1E,wBAAsB,OAAO,CAAC,OAAO,EAAE,MAAM,MAAM,EAAE,IAAI,GAAE,WAAgB,GAAG,OAAO,CAAC,eAAe,GAAG;IAAE,EAAE,EAAE,UAAU,CAAA;CAAE,CAAC,CAK1H;AAMD,MAAM,WAAW,oBAAoB,CAAC,CAAC,SAAS,UAAU,GAAG,UAAU,CAAE,SAAQ,WAAW;IAC1F;;oEAEgE;IAChE,MAAM,CAAC,EAAE,CAAC,CAAA;IACV,iFAAiF;IACjF,IAAI,CAAC,EAAE,UAAU,CAAA;CAClB;AAED,MAAM,WAAW,qBAAqB,CAAC,CAAC,SAAS,UAAU,GAAG,UAAU,CAAE,SAAQ,WAAW;IAC3F,mEAAmE;IACnE,MAAM,CAAC,EAAE,CAAC,CAAA;CACX;AAED;;qDAEqD;AACrD,MAAM,WAAW,2BAA2B,CAAC,CAAC,SAAS,UAAU,GAAG,UAAU;IAC5E,MAAM,CAAC,EAAE,CAAC,CAAA;CACX;AAED;yDACyD;AACzD,MAAM,WAAW,cAAc,CAAC,CAAC,SAAS,MAAM,CAAE,SAAQ,eAAe;IACvE,EAAE,EAAE,UAAU,CAAA;IACd,6DAA6D;IAC7D,WAAW,EAAE,MAAM,CAAA;IACnB;2EACuE;IACvE,KAAK,IAAI,IAAI,CAAA;IACb;;;;kEAI8D;IAC9D,cAAc,CAAC,CAAC,SAAS,UAAU,EACjC,OAAO,EAAE,SAAS,CAAC,CAAC,EAAE,CAAC,CAAC,EACxB,IAAI,CAAC,EAAE,2BAA2B,CAAC,CAAC,CA
AC,GACpC,OAAO,CAAC,qBAAqB,CAAC,CAAA;CAClC;AAED,yEAAyE;AACzE,MAAM,WAAW,qBAAsB,SAAQ,eAAe;IAC5D,EAAE,EAAE,UAAU,CAAA;IACd,6DAA6D;IAC7D,WAAW,EAAE,MAAM,CAAA;CACpB;AAED;;;;;;;;;;;;;;;;;;;;;;;;GAwBG;AACH,wBAAsB,aAAa,CAAC,CAAC,SAAS,MAAM,EAAE,CAAC,SAAS,UAAU,GAAG,UAAU,EACrF,YAAY,EAAE,MAAM,CAAC,EACrB,OAAO,EAAE,SAAS,CAAC,CAAC,EAAE,CAAC,CAAC,EACxB,IAAI,GAAE,oBAAoB,CAAC,CAAC,CAAM,GACjC,OAAO,CAAC,cAAc,CAAC,CAAC,CAAC,CAAC,CAyC5B;AAMD;;;;;;;;;;;;;;;;;;GAkBG;AACH,wBAAsB,cAAc,CAAC,CAAC,SAAS,MAAM,EAAE,CAAC,SAAS,UAAU,GAAG,UAAU,EACtF,YAAY,EAAE,MAAM,CAAC,EACrB,OAAO,EAAE,SAAS,CAAC,CAAC,EAAE,CAAC,CAAC,EACxB,IAAI,GAAE,qBAAqB,CAAC,CAAC,CAAM,GAClC,OAAO,CAAC,qBAAqB,CAAC,CAYhC"}
package/dist/compile.js CHANGED
@@ -35,82 +35,55 @@ export async function compile(traceFn, opts = {}) {
35
35
  * field becomes a real `Tensor`), so the instance is consumed and shouldn't be
36
36
  * referenced afterwards. Re-call the factory if you need a fresh tree.
37
37
  *
38
- * The forward function takes the materialized model and returns the loss
39
- * tensor.
38
+ * The forward function takes the materialized model and a Record of named
39
+ * input tensors, and returns the loss tensor. Inputs are matched by name with the
40
+ * `inputs:` declaration:
41
+ *
42
+ * inputs: {
43
+ * tokens: { shape: [B, T], dtype: 'i32' },
44
+ * targets: { shape: [B, T], dtype: 'i32' },
45
+ * }
46
+ * forward: (m, { tokens, targets }) => …
40
47
  *
41
48
  * Walks the module tree to materialize params with auto-derived names, then
42
- * runs trace → grad → adam → buffer plan → codegen → runtime.
49
+ * runs trace → grad → adam → buffer plan → codegen → runtime. Initial
50
+ * parameter values are uploaded automatically before this function returns;
51
+ * call `reset()` later to re-randomize.
43
52
  *
44
53
  * If `opts.adam` is set, the runtime's `step()` automatically tracks an
45
54
  * internal step count and injects the bias-corrected `lrt` scalar each call;
46
55
  * users don't need to provide it themselves.
47
56
  */
48
57
  export async function compileModule(modelFactory, forward, opts = {}) {
49
- const inputDecls = opts.inputs ?? [];
50
- const model = modelFactory();
51
- let materialized = { tensors: {}, initFns: {}, decayFlags: {} };
52
- const graph = trace(() => {
53
- materialized = materializeParams(model);
54
- const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'));
55
- return forward(model, ...inputTensors);
56
- });
57
- const { paramGrads, loss } = appendGrad(graph);
58
- let adamResult;
59
- if (opts.adam) {
60
- adamResult = appendAdam(graph, paramGrads, materialized.tensors, opts.adam, materialized.decayFlags);
61
- }
62
- const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? []);
63
- const kernels = emitKernels(graph, plan);
64
- const lossBufferId = plan.tensorToBuffer.get(loss.id);
65
- const runtime = await createRuntime(plan, kernels, lossBufferId, opts);
58
+ const { runtime, materialized, plan, kernels, ir } = await buildModuleRuntime(modelFactory, forward, opts, /* sharedParams */ undefined, /* withGrad */ true);
66
59
  // If Adam is enabled, wrap step() to track the step count and supply lrt
67
60
  // (and optionally decayShrink, when the user passed a per-step lr schedule).
68
61
  // Wrap resetOptimizerState() too, so a reset zeros m/v *and* the bias-correction
69
62
  // counter — otherwise the next step would skip Adam's warmup phase.
70
- if (adamResult) {
71
- const { lrtInputName, decayShrinkInputName, config } = adamResult;
72
- let t = 0;
73
- const lrtBuf = new Float32Array(1);
74
- const decayShrinkBuf = decayShrinkInputName ? new Float32Array(1) : null;
75
- const innerStep = runtime.step.bind(runtime);
76
- const innerReset = runtime.resetOptimizerState.bind(runtime);
77
- const wrappedStep = (inputs, opts) => {
78
- t++;
79
- const lrNow = config.lr(t);
80
- lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t));
81
- const merged = { ...inputs, [lrtInputName]: lrtBuf };
82
- if (decayShrinkBuf && decayShrinkInputName) {
83
- decayShrinkBuf[0] = 1 - lrNow * config.weightDecay;
84
- merged[decayShrinkInputName] = decayShrinkBuf;
85
- }
86
- return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged);
87
- };
88
- runtime.step = wrappedStep;
89
- runtime.resetOptimizerState = () => {
90
- t = 0;
91
- innerReset();
92
- };
63
+ if (opts.adam) {
64
+ wrapStepForAdam(runtime, opts.adam, ir);
93
65
  }
94
- const uploadInitialParams = () => {
95
- const out = buildInitialParamUploads(plan, materialized.initFns);
96
- runtime.uploadParams(out);
66
+ // Auto-upload initial param values. Always wanted at this entry point —
67
+ // training runtimes own their params and need them randomized before step 1.
68
+ uploadInitialParams(plan, materialized.initFns, runtime, /* sharedParams */ undefined);
69
+ const kernelCount = kernels.filter(k => k.wgsl).length;
70
+ const reset = () => {
71
+ uploadInitialParams(plan, materialized.initFns, runtime, undefined);
72
+ runtime.resetOptimizerState();
97
73
  };
98
- const ir = { graph, paramGrads, loss, plan, kernels };
99
- return Object.assign(runtime, { ir, uploadInitialParams });
100
- }
101
- function buildInitialParamUploads(plan, initFns, sharedParams) {
102
- const out = {};
103
- for (const [name, bufId] of plan.paramsByName) {
104
- if (sharedParams?.has(name))
105
- continue;
106
- const shape = plan.buffers[bufId].shape;
107
- const size = shape.reduce((a, b) => a * b, 1);
108
- const initFn = initFns[name];
109
- if (!initFn)
110
- throw new Error(`uploadInitialParams: no init for param '${name}'`);
111
- out[name] = initFn(size, shape);
112
- }
113
- return out;
74
+ const compileForwardMethod = async (forwardFn, fOpts = {}) => {
75
+ return compileForward(modelFactory, forwardFn, {
76
+ ...fOpts,
77
+ device: runtime.device,
78
+ sharedParams: runtime.params,
79
+ });
80
+ };
81
+ return Object.assign(runtime, {
82
+ ir,
83
+ kernelCount,
84
+ reset,
85
+ compileForward: compileForwardMethod,
86
+ });
114
87
  }
115
88
  // ============================================================================
116
89
  // Forward-only compile
@@ -121,38 +94,116 @@ function buildInitialParamUploads(plan, initFns, sharedParams) {
121
94
  * scalar loss; runtime exposes `run(inputs)` returning the full output as a
122
95
  * `Float32Array`.
123
96
  *
97
+ * **Prefer the `compileForward` method on a training runtime** when both
98
+ * graphs use the same Module class — it auto-supplies `device` and
99
+ * `sharedParams`. This standalone form is for forward-only models with no
100
+ * training graph at all, or for sharing params across a different model.
101
+ *
124
102
  * **Sharing params with a training compile.** Pass `opts.sharedParams =
125
103
  * trainCompiled.params` to bind this graph's param buffers to an existing
126
104
  * training runtime's GPU buffers — every train step is then immediately
127
- * visible to `run()` calls here, no copies. The forward graph's
128
- * `uploadInitialParams()` skips any param covered by `sharedParams`.
105
+ * visible to `run()` calls here, no copies.
129
106
  *
130
- * Typical use: a B=1 inference graph alongside a B=512 training graph,
131
- * built from the same `Module` factory.
107
+ * Initial param values are uploaded automatically for params *not* covered
108
+ * by `sharedParams` (those are owned by the sibling compile).
132
109
  */
133
110
  export async function compileForward(modelFactory, forward, opts = {}) {
134
- const inputDecls = opts.inputs ?? [];
111
+ const sharedParams = opts.sharedParams;
112
+ const { runtime, materialized, plan, kernels, ir } = await buildModuleRuntime(modelFactory, forward, opts, sharedParams, /* withGrad */ false);
113
+ // Auto-upload initial values for any params this graph owns. With
114
+ // `sharedParams` covering everything, this is a no-op.
115
+ uploadInitialParams(plan, materialized.initFns, runtime, sharedParams);
116
+ const kernelCount = kernels.filter(k => k.wgsl).length;
117
+ return Object.assign(runtime, { ir, kernelCount });
118
+ }
119
+ /** Shared body of compileModule + compileForward. The training and forward
120
+ * pipelines diverge only in (a) whether grad/Adam are appended and (b)
121
+ * whether the output buffer is the loss scalar or the user's returned
122
+ * tensor — both come out of the same trace and codegen path. */
123
+ async function buildModuleRuntime(modelFactory, forward, opts, sharedParams, withGrad) {
124
+ const inputDecls = opts.inputs ?? {};
135
125
  const model = modelFactory();
136
126
  let materialized = { tensors: {}, initFns: {}, decayFlags: {} };
137
127
  const graph = trace(() => {
138
128
  materialized = materializeParams(model);
139
- const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'));
140
- return forward(model, ...inputTensors);
129
+ const inputTensors = {};
130
+ for (const [name, decl] of Object.entries(inputDecls)) {
131
+ inputTensors[name] = tensorInput(name, decl.shape, decl.dtype ?? 'f32');
132
+ }
133
+ return forward(model, inputTensors);
141
134
  });
142
- const plan = planBuffers(graph, /* paramGrads */ {});
135
+ let paramGrads = {};
136
+ let outputTensor;
137
+ let adamWritebacks = [];
138
+ if (withGrad) {
139
+ const gradResult = appendGrad(graph);
140
+ paramGrads = gradResult.paramGrads;
141
+ outputTensor = gradResult.loss;
142
+ const adamCfg = opts.adam;
143
+ if (adamCfg) {
144
+ const adamResult = appendAdam(graph, paramGrads, materialized.tensors, adamCfg, materialized.decayFlags);
145
+ adamWritebacks = adamResult.writebacks;
146
+ graph.__adam = adamResult;
147
+ }
148
+ }
149
+ else {
150
+ outputTensor = graph.tensors[graph.outputs[0]];
151
+ }
152
+ const plan = planBuffers(graph, paramGrads, adamWritebacks);
143
153
  const kernels = emitKernels(graph, plan);
144
- const outputTensor = graph.tensors[graph.outputs[0]];
145
154
  const outputBufferId = plan.tensorToBuffer.get(outputTensor.id);
146
- const runtime = await createForwardRuntime(plan, kernels, outputBufferId, opts);
147
- const sharedParams = opts.sharedParams;
148
- const uploadInitialParams = () => {
149
- const out = buildInitialParamUploads(plan, materialized.initFns, sharedParams);
150
- if (Object.keys(out).length > 0)
151
- runtime.uploadParams(out, { partial: !!sharedParams });
155
+ // exactOptionalPropertyTypes: only include sharedParams when defined.
156
+ const runtimeOpts = sharedParams
157
+ ? { ...opts, sharedParams }
158
+ : { ...opts };
159
+ const runtime = withGrad
160
+ ? await createRuntime(plan, kernels, outputBufferId, runtimeOpts)
161
+ : await createForwardRuntime(plan, kernels, outputBufferId, runtimeOpts);
162
+ const ir = { graph, paramGrads, loss: outputTensor, plan, kernels };
163
+ return { runtime: runtime, materialized, plan, kernels, ir };
164
+ }
165
+ function wrapStepForAdam(runtime, adamCfg, ir) {
166
+ const adamResult = ir.graph.__adam;
167
+ const { lrtInputName, decayShrinkInputName, config } = adamResult;
168
+ let t = 0;
169
+ const lrtBuf = new Float32Array(1);
170
+ const decayShrinkBuf = decayShrinkInputName ? new Float32Array(1) : null;
171
+ const innerStep = runtime.step.bind(runtime);
172
+ const innerReset = runtime.resetOptimizerState.bind(runtime);
173
+ const wrappedStep = ((inputs, opts) => {
174
+ t++;
175
+ const lrNow = config.lr(t);
176
+ lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t));
177
+ const merged = { ...inputs, [lrtInputName]: lrtBuf };
178
+ if (decayShrinkBuf && decayShrinkInputName) {
179
+ decayShrinkBuf[0] = 1 - lrNow * config.weightDecay;
180
+ merged[decayShrinkInputName] = decayShrinkBuf;
181
+ }
182
+ return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged);
183
+ });
184
+ runtime.step = wrappedStep;
185
+ runtime.resetOptimizerState = () => {
186
+ t = 0;
187
+ innerReset();
152
188
  };
153
- // CompiledIR.loss is the field name; for forward-only, it carries the user's
154
- // returned tensor (e.g., logits). Same shape conceptually; just no autograd.
155
- const ir = { graph, paramGrads: {}, loss: outputTensor, plan, kernels };
156
- return Object.assign(runtime, { ir, uploadInitialParams });
189
+ void adamCfg;
190
+ }
191
+ /** Build a Record<paramName, Float32Array> by running each param's init
192
+ * function against its shape, then upload the result to the runtime. Skips any
193
+ * param covered by `sharedParams` (those are owned by a sibling compile). */
194
+ function uploadInitialParams(plan, initFns, runtime, sharedParams) {
195
+ const out = {};
196
+ for (const [name, bufId] of plan.paramsByName) {
197
+ if (sharedParams?.has(name))
198
+ continue;
199
+ const shape = plan.buffers[bufId].shape;
200
+ const size = shape.reduce((a, b) => a * b, 1);
201
+ const initFn = initFns[name];
202
+ if (!initFn)
203
+ throw new Error(`compile: no init for param '${name}'`);
204
+ out[name] = initFn(size, shape);
205
+ }
206
+ if (Object.keys(out).length > 0)
207
+ runtime.uploadParams(out, { partial: !!sharedParams });
157
208
  }
158
209
  //# sourceMappingURL=compile.js.map
@@ -1 +1 @@
1
- {"version":3,"file":"compile.js","sourceRoot":"","sources":["../src/compile.ts"],"names":[],"mappings":"AAAA,2EAA2E;AAC3E,EAAE;AACF,oBAAoB;AACpB,sEAAsE;AACtE,iEAAiE;AACjE,0EAA0E;AAC1E,0EAA0E;AAC1E,2EAA2E;AAC3E,mEAAmE;AAGnE,OAAO,EAAE,KAAK,EAAE,WAAW,EAAE,MAAM,YAAY,CAAA;AAC/C,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAgE,MAAM,cAAc,CAAA;AAChI,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAAE,MAAM,aAAa,CAAA;AAmBvD,yEAAyE;AACzE,MAAM,UAAU,WAAW,CAAC,OAAqB;IAC/C,MAAM,KAAK,GAAG,KAAK,CAAC,OAAO,CAAC,CAAA;IAC5B,MAAM,EAAE,UAAU,EAAE,IAAI,EAAE,GAAG,UAAU,CAAC,KAAK,CAAC,CAAA;IAC9C,MAAM,IAAI,GAAG,WAAW,CAAC,KAAK,EAAE,UAAU,CAAC,CAAA;IAC3C,MAAM,OAAO,GAAG,WAAW,CAAC,KAAK,EAAE,IAAI,CAAC,CAAA;IACxC,OAAO,EAAE,KAAK,EAAE,UAAU,EAAE,IAAI,EAAE,IAAI,EAAE,OAAO,EAAE,CAAA;AACnD,CAAC;AAED,0EAA0E;AAC1E,MAAM,CAAC,KAAK,UAAU,OAAO,CAAC,OAAqB,EAAE,OAAoB,EAAE;IACzE,MAAM,EAAE,GAAG,WAAW,CAAC,OAAO,CAAC,CAAA;IAC/B,MAAM,YAAY,GAAG,EAAE,CAAC,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAE,CAAA;IAC5D,MAAM,OAAO,GAAG,MAAM,aAAa,CAAC,EAAE,CAAC,IAAI,EAAE,EAAE,CAAC,OAAO,EAAE,YAAY,EAAE,IAAI,CAAC,CAAA;IAC5E,OAAO,MAAM,CAAC,MAAM,CAAC,OAAO,EAAE,EAAE,EAAE,EAAE,CAAC,CAAA;AACvC,CAAC;AAqBD;;;;;;;;;;;;;;;GAeG;AACH,MAAM,CAAC,KAAK,UAAU,aAAa,CACjC,YAAqB,EACrB,OAA8C,EAC9C,OAA6B,EAAE;IAE/B,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,IAAI,EAAE,CAAA;IACpC,MAAM,KAAK,GAAG,YAAY,EAAE,CAAA;IAC5B,IAAI,YAAY,GAAyC,EAAE,OAAO,EAAE,EAAE,EAAE,OAAO,EAAE,EAAE,EAAE,UAAU,EAAE,EAAE,EAAE,CAAA;IACrG,MAAM,KAAK,GAAG,KAAK,CAAC,GAAG,EAAE;QACvB,YAAY,GAAG,iBAAiB,CAAC,KAAK,CAAC,CAAA;QACvC,MAAM,YAAY,GAAG,UAAU,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,WAAW,CAAC,CAAC,CAAC,IAAI,EAAE,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,IAAI,KAAK,CAAC,CAAC,CAAA;QACxF,OAAO,OAAO,CAAC,KAAK,EAAE,GAAG,YAAY,CAAC,CAAA;IACxC,CAAC,CAAC,CAAA;IAEF,MAAM,EAAE,UAAU,EAAE,IAAI,EAAE,GAAG,UAAU,CAAC,KAAK,CAAC,CAAA;IAE9C,IAAI,UAAqD,CAAA;IACzD,IAAI,IAAI,CAAC,IAAI,EAAE,CAAC;QACd,UAA
U,GAAG,UAAU,CAAC,KAAK,EAAE,UAAU,EAAE,YAAY,CAAC,OAAO,EAAE,IAAI,CAAC,IAAI,EAAE,YAAY,CAAC,UAAU,CAAC,CAAA;IACtG,CAAC;IAED,MAAM,IAAI,GAAG,WAAW,CAAC,KAAK,EAAE,UAAU,EAAE,UAAU,EAAE,UAAU,IAAI,EAAE,CAAC,CAAA;IACzE,MAAM,OAAO,GAAG,WAAW,CAAC,KAAK,EAAE,IAAI,CAAC,CAAA;IACxC,MAAM,YAAY,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAE,CAAA;IACtD,MAAM,OAAO,GAAG,MAAM,aAAa,CAAC,IAAI,EAAE,OAAO,EAAE,YAAY,EAAE,IAAI,CAAC,CAAA;IAEtE,yEAAyE;IACzE,6EAA6E;IAC7E,iFAAiF;IACjF,oEAAoE;IACpE,IAAI,UAAU,EAAE,CAAC;QACf,MAAM,EAAE,YAAY,EAAE,oBAAoB,EAAE,MAAM,EAAE,GAAG,UAAU,CAAA;QACjE,IAAI,CAAC,GAAG,CAAC,CAAA;QACT,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,CAAC,CAAC,CAAA;QAClC,MAAM,cAAc,GAAG,oBAAoB,CAAC,CAAC,CAAC,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,CAAA;QACxE,MAAM,SAAS,GAAG,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,OAAO,CAA4B,CAAA;QACvE,MAAM,UAAU,GAAG,OAAO,CAAC,mBAAmB,CAAC,IAAI,CAAC,OAAO,CAAC,CAAA;QAC5D,MAAM,WAAW,GAAG,CAClB,MAAiD,EACjD,IAAiC,EAC2C,EAAE;YAC9E,CAAC,EAAE,CAAA;YACH,MAAM,KAAK,GAAG,MAAM,CAAC,EAAE,CAAC,CAAC,CAAC,CAAA;YAC1B,MAAM,CAAC,CAAC,CAAC,GAAG,KAAK,GAAG,IAAI,CAAC,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAA;YACxF,MAAM,MAAM,GAA8C,EAAE,GAAG,MAAM,EAAE,CAAC,YAAY,CAAC,EAAE,MAAM,EAAE,CAAA;YAC/F,IAAI,cAAc,IAAI,oBAAoB,EAAE,CAAC;gBAC3C,cAAc,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,KAAK,GAAG,MAAM,CAAC,WAAW,CAAA;gBAClD,MAAM,CAAC,oBAAoB,CAAC,GAAG,cAAc,CAAA;YAC/C,CAAC;YACD,OAAO,IAAI,EAAE,YAAY,CAAC,CAAC,CAAC,SAAS,CAAC,MAAM,EAAE,EAAE,YAAY,EAAE,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,MAAM,CAAC,CAAA;QAC3F,CAAC,CAAA;QACD,OAAO,CAAC,IAAI,GAAG,WAAsC,CAAA;QACrD,OAAO,CAAC,mBAAmB,GAAG,GAAG,EAAE;YACjC,CAAC,GAAG,CAAC,CAAA;YACL,UAAU,EAAE,CAAA;QACd,CAAC,CAAA;IACH,CAAC;IAED,MAAM,mBAAmB,GAAG,GAAG,EAAE;QAC/B,MAAM,GAAG,GAAG,wBAAwB,CAAC,IAAI,EAAE,YAAY,CAAC,OAAO,CAAC,CAAA;QAChE,OAAO,CAAC,YAAY,CAAC,GAAG,CAAC,CAAA;IAC3B,CAAC,CAAA;IAED,MAAM,EAAE,GAAe,EAAE,KAAK,EAAE,UAAU,EAAE,IAAI,EAAE,IAAI,EAAE,OAAO,EAAE,CAAA;IACjE,OAAO,MAAM,CAAC,MAAM,CAAC,O
AAO,EAAE,EAAE,EAAE,EAAE,mBAAmB,EAAE,CAAC,CAAA;AAC5D,CAAC;AAOD,SAAS,wBAAwB,CAC/B,IAAgB,EAChB,OAA+B,EAC/B,YAAqC;IAErC,MAAM,GAAG,GAAiC,EAAE,CAAA;IAC5C,KAAK,MAAM,CAAC,IAAI,EAAE,KAAK,CAAC,IAAI,IAAI,CAAC,YAAY,EAAE,CAAC;QAC9C,IAAI,YAAY,EAAE,GAAG,CAAC,IAAI,CAAC;YAAE,SAAQ;QACrC,MAAM,KAAK,GAAG,IAAI,CAAC,OAAO,CAAC,KAAK,CAAE,CAAC,KAAK,CAAA;QACxC,MAAM,IAAI,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC,CAAA;QAC7C,MAAM,MAAM,GAAG,OAAO,CAAC,IAAI,CAAC,CAAA;QAC5B,IAAI,CAAC,MAAM;YAAE,MAAM,IAAI,KAAK,CAAC,2CAA2C,IAAI,GAAG,CAAC,CAAA;QAChF,GAAG,CAAC,IAAI,CAAC,GAAG,MAAM,CAAC,IAAI,EAAE,KAAK,CAAC,CAAA;IACjC,CAAC;IACD,OAAO,GAAG,CAAA;AACZ,CAAC;AAED,+EAA+E;AAC/E,uBAAuB;AACvB,+EAA+E;AAE/E;;;;;;;;;;;;;;GAcG;AACH,MAAM,CAAC,KAAK,UAAU,cAAc,CAClC,YAAqB,EACrB,OAA8C,EAC9C,OAA8B,EAAE;IAEhC,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,IAAI,EAAE,CAAA;IACpC,MAAM,KAAK,GAAG,YAAY,EAAE,CAAA;IAC5B,IAAI,YAAY,GAAyC,EAAE,OAAO,EAAE,EAAE,EAAE,OAAO,EAAE,EAAE,EAAE,UAAU,EAAE,EAAE,EAAE,CAAA;IACrG,MAAM,KAAK,GAAG,KAAK,CAAC,GAAG,EAAE;QACvB,YAAY,GAAG,iBAAiB,CAAC,KAAK,CAAC,CAAA;QACvC,MAAM,YAAY,GAAG,UAAU,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,WAAW,CAAC,CAAC,CAAC,IAAI,EAAE,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,IAAI,KAAK,CAAC,CAAC,CAAA;QACxF,OAAO,OAAO,CAAC,KAAK,EAAE,GAAG,YAAY,CAAC,CAAA;IACxC,CAAC,CAAC,CAAA;IAEF,MAAM,IAAI,GAAG,WAAW,CAAC,KAAK,EAAE,gBAAgB,CAAC,EAAE,CAAC,CAAA;IACpD,MAAM,OAAO,GAAG,WAAW,CAAC,KAAK,EAAE,IAAI,CAAC,CAAA;IACxC,MAAM,YAAY,GAAG,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC,CAAE,CAAE,CAAA;IACtD,MAAM,cAAc,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,YAAY,CAAC,EAAE,CAAE,CAAA;IAChE,MAAM,OAAO,GAAG,MAAM,oBAAoB,CAAC,IAAI,EAAE,OAAO,EAAE,cAAc,EAAE,IAAI,CAAC,CAAA;IAE/E,MAAM,YAAY,GAAG,IAAI,CAAC,YAAY,CAAA;IACtC,MAAM,mBAAmB,GAAG,GAAG,EAAE;QAC/B,MAAM,GAAG,GAAG,wBAAwB,CAAC,IAAI,EAAE,YAAY,CAAC,OAAO,EAAE,YAAY,CAAC,CAAA;QAC9E,IAAI,MAAM,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,MAAM,GAAG,CAAC;YAAE,OAAO,CAAC,YAAY,CAAC,GAAG,EAAE,EAAE,OAAO,EAAE,CAAC,CAAC,YAAY,EAAE,CAAC,CAAA;IACzF,CAAC,CAAA;IAED,6EAA6E;IAC7E,6EAA6E;IAC7E,MAAM,EAAE,G
AAe,EAAE,KAAK,EAAE,UAAU,EAAE,EAAE,EAAE,IAAI,EAAE,YAAY,EAAE,IAAI,EAAE,OAAO,EAAE,CAAA;IACnF,OAAO,MAAM,CAAC,MAAM,CAAC,OAAO,EAAE,EAAE,EAAE,EAAE,mBAAmB,EAAE,CAAC,CAAA;AAC5D,CAAC"}
1
+ {"version":3,"file":"compile.js","sourceRoot":"","sources":["../src/compile.ts"],"names":[],"mappings":"AAAA,2EAA2E;AAC3E,EAAE;AACF,oBAAoB;AACpB,sEAAsE;AACtE,iEAAiE;AACjE,0EAA0E;AAC1E,0EAA0E;AAC1E,2EAA2E;AAC3E,mEAAmE;AAGnE,OAAO,EAAE,KAAK,EAAE,WAAW,EAAE,MAAM,YAAY,CAAA;AAC/C,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAgE,MAAM,cAAc,CAAA;AAChI,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAAE,MAAM,aAAa,CAAA;AAoCvD,yEAAyE;AACzE,MAAM,UAAU,WAAW,CAAC,OAAqB;IAC/C,MAAM,KAAK,GAAG,KAAK,CAAC,OAAO,CAAC,CAAA;IAC5B,MAAM,EAAE,UAAU,EAAE,IAAI,EAAE,GAAG,UAAU,CAAC,KAAK,CAAC,CAAA;IAC9C,MAAM,IAAI,GAAG,WAAW,CAAC,KAAK,EAAE,UAAU,CAAC,CAAA;IAC3C,MAAM,OAAO,GAAG,WAAW,CAAC,KAAK,EAAE,IAAI,CAAC,CAAA;IACxC,OAAO,EAAE,KAAK,EAAE,UAAU,EAAE,IAAI,EAAE,IAAI,EAAE,OAAO,EAAE,CAAA;AACnD,CAAC;AAED,0EAA0E;AAC1E,MAAM,CAAC,KAAK,UAAU,OAAO,CAAC,OAAqB,EAAE,OAAoB,EAAE;IACzE,MAAM,EAAE,GAAG,WAAW,CAAC,OAAO,CAAC,CAAA;IAC/B,MAAM,YAAY,GAAG,EAAE,CAAC,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAE,CAAA;IAC5D,MAAM,OAAO,GAAG,MAAM,aAAa,CAAC,EAAE,CAAC,IAAI,EAAE,EAAE,CAAC,OAAO,EAAE,YAAY,EAAE,IAAI,CAAC,CAAA;IAC5E,OAAO,MAAM,CAAC,MAAM,CAAC,OAAO,EAAE,EAAE,EAAE,EAAE,CAAC,CAAA;AACvC,CAAC;AAsDD;;;;;;;;;;;;;;;;;;;;;;;;GAwBG;AACH,MAAM,CAAC,KAAK,UAAU,aAAa,CACjC,YAAqB,EACrB,OAAwB,EACxB,OAAgC,EAAE;IAElC,MAAM,EAAE,OAAO,EAAE,YAAY,EAAE,IAAI,EAAE,OAAO,EAAE,EAAE,EAAE,GAAG,MAAM,kBAAkB,CAC3E,YAAY,EAAE,OAAO,EAAE,IAAI,EAAE,kBAAkB,CAAC,SAAS,EAAE,cAAc,CAAC,IAAI,CAC/E,CAAA;IAED,yEAAyE;IACzE,6EAA6E;IAC7E,iFAAiF;IACjF,oEAAoE;IACpE,IAAI,IAAI,CAAC,IAAI,EAAE,CAAC;QACd,eAAe,CAAC,OAAO,EAAE,IAAI,CAAC,IAAI,EAAE,EAAE,CAAC,CAAA;IACzC,CAAC;IAED,wEAAwE;IACxE,6EAA6E;IAC7E,mBAAmB,CAAC,IAAI,EAAE,YAAY,CAAC,OAAO,EAAE,OAAO,EAAE,kBAAkB,CAAC,SAAS,CAAC,CAAA;IAEtF,MAAM,WAAW,GAAG,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,MAAM,CAAA;IAEtD,MAAM,KAAK,GAAG,GAAG,EAAE;QACjB,mBAAmB,CAAC,IAAI,EAAE,YAAY,CAAC,OAAO,EA
AE,OAAO,EAAE,SAAS,CAAC,CAAA;QACnE,OAAO,CAAC,mBAAmB,EAAE,CAAA;IAC/B,CAAC,CAAA;IAED,MAAM,oBAAoB,GAAG,KAAK,EAChC,SAA0B,EAC1B,QAAwC,EAAE,EACV,EAAE;QAClC,OAAO,cAAc,CAAO,YAAY,EAAE,SAAS,EAAE;YACnD,GAAG,KAAK;YACR,MAAM,EAAE,OAAO,CAAC,MAAM;YACtB,YAAY,EAAE,OAAO,CAAC,MAAM;SAC7B,CAAC,CAAA;IACJ,CAAC,CAAA;IAED,OAAO,MAAM,CAAC,MAAM,CAAC,OAAO,EAAE;QAC5B,EAAE;QACF,WAAW;QACX,KAAK;QACL,cAAc,EAAE,oBAAoB;KACrC,CAAC,CAAA;AACJ,CAAC;AAED,+EAA+E;AAC/E,uBAAuB;AACvB,+EAA+E;AAE/E;;;;;;;;;;;;;;;;;;GAkBG;AACH,MAAM,CAAC,KAAK,UAAU,cAAc,CAClC,YAAqB,EACrB,OAAwB,EACxB,OAAiC,EAAE;IAEnC,MAAM,YAAY,GAAG,IAAI,CAAC,YAAY,CAAA;IACtC,MAAM,EAAE,OAAO,EAAE,YAAY,EAAE,IAAI,EAAE,OAAO,EAAE,EAAE,EAAE,GAAG,MAAM,kBAAkB,CAC3E,YAAY,EAAE,OAAO,EAAE,IAAI,EAAE,YAAY,EAAE,cAAc,CAAC,KAAK,CAChE,CAAA;IAED,kEAAkE;IAClE,uDAAuD;IACvD,mBAAmB,CAAC,IAAI,EAAE,YAAY,CAAC,OAAO,EAAE,OAAO,EAAE,YAAY,CAAC,CAAA;IAEtE,MAAM,WAAW,GAAG,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,MAAM,CAAA;IACtD,OAAO,MAAM,CAAC,MAAM,CAAC,OAAO,EAAE,EAAE,EAAE,EAAE,WAAW,EAAE,CAAC,CAAA;AACpD,CAAC;AAgBD;;;iEAGiE;AACjE,KAAK,UAAU,kBAAkB,CAC/B,YAAqB,EACrB,OAAwB,EACxB,IAAwD,EACxD,YAAgD,EAChD,QAAiB;IAEjB,MAAM,UAAU,GAAe,IAAI,CAAC,MAAM,IAAI,EAAE,CAAA;IAChD,MAAM,KAAK,GAAG,YAAY,EAAE,CAAA;IAC5B,IAAI,YAAY,GAAyC,EAAE,OAAO,EAAE,EAAE,EAAE,OAAO,EAAE,EAAE,EAAE,UAAU,EAAE,EAAE,EAAE,CAAA;IACrG,MAAM,KAAK,GAAG,KAAK,CAAC,GAAG,EAAE;QACvB,YAAY,GAAG,iBAAiB,CAAC,KAAK,CAAC,CAAA;QACvC,MAAM,YAAY,GAA2B,EAAE,CAAA;QAC/C,KAAK,MAAM,CAAC,IAAI,EAAE,IAAI,CAAC,IAAI,MAAM,CAAC,OAAO,CAAC,UAAU,CAAC,EAAE,CAAC;YACtD,YAAY,CAAC,IAAI,CAAC,GAAG,WAAW,CAAC,IAAI,EAAE,IAAI,CAAC,KAAK,EAAE,IAAI,CAAC,KAAK,IAAI,KAAK,CAAC,CAAA;QACzE,CAAC;QACD,OAAO,OAAO,CAAC,KAAK,EAAE,YAAgC,CAAC,CAAA;IACzD,CAAC,CAAC,CAAA;IAEF,IAAI,UAAU,GAA6B,EAAE,CAAA;IAC7C,IAAI,YAAoB,CAAA;IACxB,IAAI,cAAc,GAAgD,EAAE,CAAA;IAEpE,IAAI,QAAQ,EAAE,CAAC;QACb,MAAM,UAAU,GAAG,UAAU,CAAC,KAAK,CAAC,CAAA;QACpC,UAAU,GAAG,UAAU,CAAC,UAAU,CAAA;QAClC,YAAY,GAAG,UAAU,CAAC,IAAI,CAAA;QAC9B,MAAM,OAAO,GAAI,IAA6B,CAAC,IAAI,CAAA;QACnD,IAAI,OAAO,EAAE,CAAC;YACZ,MAAM,UAAU,
GAAG,UAAU,CAAC,KAAK,EAAE,UAAU,EAAE,YAAY,CAAC,OAAO,EAAE,OAAO,EAAE,YAAY,CAAC,UAAU,CAAC,CAAA;YACxG,cAAc,GAAG,UAAU,CAAC,UAAU,CAErC;YAAC,KAA4D,CAAC,MAAM,GAAG,UAAU,CAAA;QACpF,CAAC;IACH,CAAC;SAAM,CAAC;QACN,YAAY,GAAG,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC,CAAE,CAAE,CAAA;IAClD,CAAC;IAED,MAAM,IAAI,GAAG,WAAW,CAAC,KAAK,EAAE,UAAU,EAAE,cAAc,CAAC,CAAA;IAC3D,MAAM,OAAO,GAAG,WAAW,CAAC,KAAK,EAAE,IAAI,CAAC,CAAA;IACxC,MAAM,cAAc,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,YAAY,CAAC,EAAE,CAAE,CAAA;IAChE,sEAAsE;IACtE,MAAM,WAAW,GAAgB,YAAY;QAC3C,CAAC,CAAC,EAAE,GAAG,IAAI,EAAE,YAAY,EAAE;QAC3B,CAAC,CAAC,EAAE,GAAG,IAAI,EAAE,CAAA;IACf,MAAM,OAAO,GAAG,QAAQ;QACtB,CAAC,CAAC,MAAM,aAAa,CAAC,IAAI,EAAE,OAAO,EAAE,cAAc,EAAE,WAAW,CAAC;QACjE,CAAC,CAAC,MAAM,oBAAoB,CAAC,IAAI,EAAE,OAAO,EAAE,cAAc,EAAE,WAAW,CAAC,CAAA;IAE1E,MAAM,EAAE,GAAe,EAAE,KAAK,EAAE,UAAU,EAAE,IAAI,EAAE,YAAY,EAAE,IAAI,EAAE,OAAO,EAAE,CAAA;IAC/E,OAAO,EAAE,OAAO,EAAE,OAA0B,EAAE,YAAY,EAAE,IAAI,EAAE,OAAO,EAAE,EAAE,EAAE,CAAA;AACjF,CAAC;AAID,SAAS,eAAe,CAAC,OAAwB,EAAE,OAAmB,EAAE,EAAc;IACpF,MAAM,UAAU,GAAI,EAAE,CAAC,KAA4D,CAAC,MAAO,CAAA;IAC3F,MAAM,EAAE,YAAY,EAAE,oBAAoB,EAAE,MAAM,EAAE,GAAG,UAAU,CAAA;IACjE,IAAI,CAAC,GAAG,CAAC,CAAA;IACT,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,CAAC,CAAC,CAAA;IAClC,MAAM,cAAc,GAAG,oBAAoB,CAAC,CAAC,CAAC,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,CAAA;IACxE,MAAM,SAAS,GAAG,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,OAAO,CAA4B,CAAA;IACvE,MAAM,UAAU,GAAG,OAAO,CAAC,mBAAmB,CAAC,IAAI,CAAC,OAAO,CAAC,CAAA;IAC5D,MAAM,WAAW,GAAG,CAAC,CACnB,MAAiD,EACjD,IAAiC,EACjC,EAAE;QACF,CAAC,EAAE,CAAA;QACH,MAAM,KAAK,GAAG,MAAM,CAAC,EAAE,CAAC,CAAC,CAAC,CAAA;QAC1B,MAAM,CAAC,CAAC,CAAC,GAAG,KAAK,GAAG,IAAI,CAAC,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAA;QACxF,MAAM,MAAM,GAA8C,EAAE,GAAG,MAAM,EAAE,CAAC,YAAY,CAAC,EAAE,MAAM,EAAE,CAAA;QAC/F,IAAI,cAAc,IAAI,oBAAoB,EAAE,CAAC;YAC3C,cAAc,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,KAAK,GAAG,MAAM,CAAC,WAAW,CAAA;YAClD,MAAM,CAAC,oBAAoB,CAAC,GAAG,cAAc,C
AAA;QAC/C,CAAC;QACD,OAAO,IAAI,EAAE,YAAY,CAAC,CAAC,CAAC,SAAS,CAAC,MAAM,EAAE,EAAE,YAAY,EAAE,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,MAAM,CAAC,CAAA;IAC3F,CAAC,CAA4B,CAAA;IAC7B,OAAO,CAAC,IAAI,GAAG,WAAW,CAAA;IAC1B,OAAO,CAAC,mBAAmB,GAAG,GAAG,EAAE;QACjC,CAAC,GAAG,CAAC,CAAA;QACL,UAAU,EAAE,CAAA;IACd,CAAC,CAAA;IACD,KAAK,OAAO,CAAA;AACd,CAAC;AAED;;8EAE8E;AAC9E,SAAS,mBAAmB,CAC1B,IAAgB,EAChB,OAA+B,EAC/B,OAA0C,EAC1C,YAAgD;IAEhD,MAAM,GAAG,GAAiC,EAAE,CAAA;IAC5C,KAAK,MAAM,CAAC,IAAI,EAAE,KAAK,CAAC,IAAI,IAAI,CAAC,YAAY,EAAE,CAAC;QAC9C,IAAI,YAAY,EAAE,GAAG,CAAC,IAAI,CAAC;YAAE,SAAQ;QACrC,MAAM,KAAK,GAAG,IAAI,CAAC,OAAO,CAAC,KAAK,CAAE,CAAC,KAAK,CAAA;QACxC,MAAM,IAAI,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC,CAAA;QAC7C,MAAM,MAAM,GAAG,OAAO,CAAC,IAAI,CAAC,CAAA;QAC5B,IAAI,CAAC,MAAM;YAAE,MAAM,IAAI,KAAK,CAAC,+BAA+B,IAAI,GAAG,CAAC,CAAA;QACpE,GAAG,CAAC,IAAI,CAAC,GAAG,MAAM,CAAC,IAAI,EAAE,KAAK,CAAC,CAAA;IACjC,CAAC;IACD,IAAI,MAAM,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,MAAM,GAAG,CAAC;QAAE,OAAO,CAAC,YAAY,CAAC,GAAG,EAAE,EAAE,OAAO,EAAE,CAAC,CAAC,YAAY,EAAE,CAAC,CAAA;AACzF,CAAC"}
package/dist/index.d.ts CHANGED
@@ -7,8 +7,8 @@ export { appendGrad, type GradResult } from './grad.js';
7
7
  export { appendAdam, type AdamConfig, type AdamResult } from './adam.js';
8
8
  export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js';
9
9
  export { emitKernels, type KernelSpec } from './codegen.js';
10
- export { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type StepOptions, type StepWithCaptures, type RunOptions, type RunWithCaptures } from './runtime.js';
11
- export { compile, compileToIR, compileModule, compileForward, type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type InputDecl } from './compile.js';
10
+ export { createRuntime, createForwardRuntime, Captures, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type RunOptions, type StepResult, type RunResult } from './runtime.js';
11
+ export { compile, compileToIR, compileModule, compileForward, type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type CompileForwardMethodOptions, type CompiledModule, type CompiledForwardModule, type InputDecl, type InputDecls, type InputsTensors, type ForwardFn, } from './compile.js';
12
12
  export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js';
13
13
  export * as nn from './nn.js';
14
14
  //# sourceMappingURL=index.d.ts.map
@@ -1 +1 @@
1
- {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAKA,YAAY,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,QAAQ,EAAE,MAAM,SAAS,CAAA;AAC5E,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAA;AACtC,OAAO,EAEL,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAElB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,EAE3B,IAAI,EAAE,OAAO,EAAE,KAAK,EAEpB,QAAQ,EAAE,OAAO,EAAE,MAAM,EAEzB,OAAO,EAAE,SAAS,EAAE,QAAQ,EAE5B,MAAM,EAAE,aAAa,EAErB,MAAM,EAAE,MAAM,EAAE,SAAS,EAEzB,iBAAiB,EAAE,cAAc,EAAE,WAAW,EAE9C,cAAc,GACf,MAAM,UAAU,CAAA;AAMjB,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,SAAS,EAAE,KAAK,aAAa,EAAE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAE,KAAK,eAAe,EAAE,KAAK,eAAe,EAAE,KAAK,WAAW,EAAE,KAAK,WAAW,EAAE,KAAK,gBAAgB,EAAE,KAAK,UAAU,EAAE,KAAK,eAAe,EAAE,MAAM,cAAc,CAAA;AAChN,OAAO,EAAE,OAAO,EAAE,WAAW,EAAE,aAAa,EAAE,cAAc,EAAE,KAAK,UAAU,EAAE,KAAK,oBAAoB,EAAE,KAAK,qBAAqB,EAAE,KAAK,SAAS,EAAE,MAAM,cAAc,CAAA;AAC1K,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAAE,KAAK,QAAQ,EAAE,KAAK,YAAY,EAAE,KAAK,kBAAkB,EAAE,MAAM,aAAa,CAAA;AAClH,OAAO,KAAK,EAAE,MAAM,SAAS,CAAA"}
1
+ {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAKA,YAAY,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,QAAQ,EAAE,MAAM,SAAS,CAAA;AAC5E,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAA;AACtC,OAAO,EAEL,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAElB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,EAE3B,IAAI,EAAE,OAAO,EAAE,KAAK,EAEpB,QAAQ,EAAE,OAAO,EAAE,MAAM,EAEzB,OAAO,EAAE,SAAS,EAAE,QAAQ,EAE5B,MAAM,EAAE,aAAa,EAErB,MAAM,EAAE,MAAM,EAAE,SAAS,EAEzB,iBAAiB,EAAE,cAAc,EAAE,WAAW,EAE9C,cAAc,GACf,MAAM,UAAU,CAAA;AAMjB,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,SAAS,EAAE,KAAK,aAAa,EAAE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAE,QAAQ,EAAE,KAAK,eAAe,EAAE,KAAK,eAAe,EAAE,KAAK,WAAW,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,SAAS,EAAE,MAAM,cAAc,CAAA;AAC5L,OAAO,EACL,OAAO,EAAE,WAAW,EAAE,aAAa,EAAE,cAAc,EACnD,KAAK,UAAU,EAAE,KAAK,oBAAoB,EAAE,KAAK,qBAAqB,EAAE,KAAK,2BAA2B,EACxG,KAAK,cAAc,EAAE,KAAK,qBAAqB,EAC/C,KAAK,SAAS,EAAE,KAAK,UAAU,EAAE,KAAK,aAAa,EAAE,KAAK,SAAS,GACpE,MAAM,cAAc,CAAA;AACrB,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAAE,KAAK,QAAQ,EAAE,KAAK,YAAY,EAAE,KAAK,kBAAkB,EAAE,MAAM,aAAa,CAAA;AAClH,OAAO,KAAK,EAAE,MAAM,SAAS,CAAA"}
package/dist/index.js CHANGED
@@ -32,8 +32,8 @@ export { appendGrad } from './grad.js';
32
32
  export { appendAdam } from './adam.js';
33
33
  export { planBuffers } from './buffers.js';
34
34
  export { emitKernels } from './codegen.js';
35
- export { createRuntime, createForwardRuntime } from './runtime.js';
36
- export { compile, compileToIR, compileModule, compileForward } from './compile.js';
35
+ export { createRuntime, createForwardRuntime, Captures } from './runtime.js';
36
+ export { compile, compileToIR, compileModule, compileForward, } from './compile.js';
37
37
  export { Module, materializeParams } from './module.js';
38
38
  export * as nn from './nn.js';
39
39
  //# sourceMappingURL=index.js.map
package/dist/index.js.map CHANGED
@@ -1 +1 @@
1
- {"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,+CAA+C;AAC/C,EAAE;AACF,8EAA8E;AAC9E,6CAA6C;AAG7C,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAA;AACtC,OAAO;AACL,qFAAqF;AACrF,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG;AAClB,qBAAqB;AACrB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI;AAC3B,uBAAuB;AACvB,IAAI,EAAE,OAAO,EAAE,KAAK;AACpB,yEAAyE;AACzE,QAAQ,EAAE,OAAO,EAAE,MAAM;AACzB,YAAY;AACZ,OAAO,EAAE,SAAS,EAAE,QAAQ;AAC5B,iBAAiB;AACjB,MAAM,EAAE,aAAa;AACrB,qBAAqB;AACrB,MAAM,EAAE,MAAM,EAAE,SAAS;AACzB,4CAA4C;AAC5C,iBAAiB,EAAE,cAAc,EAAE,WAAW;AAC9C,UAAU;AACV,cAAc,GACf,MAAM,UAAU,CAAA;AAEjB,sFAAsF;AACtF,8EAA8E;AAC9E,2EAA2E;AAC3E,qDAAqD;AACrD,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAoC,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAwE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAgJ,MAAM,cAAc,CAAA;AAChN,OAAO,EAAE,OAAO,EAAE,WAAW,EAAE,aAAa,EAAE,cAAc,EAA0F,MAAM,cAAc,CAAA;AAC1K,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAA6D,MAAM,aAAa,CAAA;AAClH,OAAO,KAAK,EAAE,MAAM,SAAS,CAAA"}
1
+ {"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,+CAA+C;AAC/C,EAAE;AACF,8EAA8E;AAC9E,6CAA6C;AAG7C,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAA;AACtC,OAAO;AACL,qFAAqF;AACrF,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG;AAClB,qBAAqB;AACrB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI;AAC3B,uBAAuB;AACvB,IAAI,EAAE,OAAO,EAAE,KAAK;AACpB,yEAAyE;AACzE,QAAQ,EAAE,OAAO,EAAE,MAAM;AACzB,YAAY;AACZ,OAAO,EAAE,SAAS,EAAE,QAAQ;AAC5B,iBAAiB;AACjB,MAAM,EAAE,aAAa;AACrB,qBAAqB;AACrB,MAAM,EAAE,MAAM,EAAE,SAAS;AACzB,4CAA4C;AAC5C,iBAAiB,EAAE,cAAc,EAAE,WAAW;AAC9C,UAAU;AACV,cAAc,GACf,MAAM,UAAU,CAAA;AAEjB,sFAAsF;AACtF,8EAA8E;AAC9E,2EAA2E;AAC3E,qDAAqD;AACrD,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAoC,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAwE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAE,QAAQ,EAAkH,MAAM,cAAc,CAAA;AAC5L,OAAO,EACL,OAAO,EAAE,WAAW,EAAE,aAAa,EAAE,cAAc,GAIpD,MAAM,cAAc,CAAA;AACrB,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAA6D,MAAM,aAAa,CAAA;AAClH,OAAO,KAAK,EAAE,MAAM,SAAS,CAAA"}
package/dist/nn.d.ts CHANGED
@@ -1,35 +1,38 @@
1
1
  import { Module } from './module.js';
2
2
  import type { Tensor } from './ir.js';
3
+ import type { Captures } from './runtime.js';
4
+ export interface LinearOptions {
5
+ /** Include a bias term (default true). */
6
+ bias?: boolean;
7
+ }
3
8
  export declare class Linear extends Module {
4
9
  readonly inDim: number;
5
10
  readonly outDim: number;
6
11
  W: Tensor;
7
12
  b: Tensor | null;
8
- constructor(inDim: number, outDim: number, withBias?: boolean);
13
+ constructor(inDim: number, outDim: number, opts?: LinearOptions);
14
+ fwd(x: Tensor): Tensor;
9
15
  }
10
- export declare function linearFwd(p: Linear, x: Tensor): Tensor;
11
16
  export declare class LayerNorm extends Module {
12
17
  readonly d: number;
13
18
  readonly eps: number;
14
19
  g: Tensor;
15
20
  b: Tensor;
16
21
  constructor(d: number, eps?: number);
22
+ fwd(x: Tensor): Tensor;
17
23
  }
18
- export declare function layerNormFwd(p: LayerNorm, x: Tensor): Tensor;
19
24
  /** [..., T, D] → [..., H, T, D/H]. Folds the standard
20
25
  * `transpose(reshape(x, [..., T, H, d]), [..., H, T, d])` pattern into one
21
26
  * call. Last dim of `x` must divide evenly by `nHeads`. */
22
27
  export declare function splitHeads(x: Tensor, nHeads: number): Tensor;
23
28
  /** Inverse of `splitHeads`: [..., H, T, d] → [..., T, H*d]. */
24
29
  export declare function mergeHeads(x: Tensor): Tensor;
25
- /** Slice a flat capture readback of shape `[H, ..., ...]` into one
26
- * Float32Array per head. The leading axis is treated as the head axis;
27
- * pass the shape from `compiled.captureShapes[name]`. Result: `H` arrays,
28
- * each holding the row-major data for that head (size = product of trailing
29
- * axes). For B>1 graphs, prefix the result by the batch — this helper
30
- * assumes the leading axis is heads, which matches how `splitHeads` lays
31
- * out captures at B=1 (the typical capture-readback shape). */
32
- export declare function unsplitHeads(flat: Float32Array, shape: readonly number[]): Float32Array[];
30
+ /** Slice a captured tensor named `name` into one Float32Array per head, using
31
+ * the static shape registered at compile time. The leading axis is treated as
32
+ * heads (matching `splitHeads` layout at B=1); a leading singleton batch is
33
+ * stripped if present so callers can pass capture names directly. Throws if
34
+ * the capture isn't registered or wasn't read back this call. */
35
+ export declare function unsplitHeads(captures: Captures, name: string): Float32Array[];
33
36
  /** Per-position cross-entropy along the last (vocab) axis: returns
34
37
  * `-log p(target)` at each position. `logits` is `[..., V]`; `targets` is
35
38
  * `[...]` of i32; result is `[...]` (one rank less than logits). The user
package/dist/nn.d.ts.map CHANGED
@@ -1 +1 @@
1
- {"version":3,"file":"nn.d.ts","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAeA,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AACpC,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,SAAS,CAAA;AASrC,qBAAa,MAAO,SAAQ,MAAM;aAGJ,KAAK,EAAE,MAAM;aAAkB,MAAM,EAAE,MAAM;IAFzE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,GAAG,IAAI,CAAA;gBACY,KAAK,EAAE,MAAM,EAAkB,MAAM,EAAE,MAAM,EAAE,QAAQ,UAAO;CAK3F;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAGtD;AAMD,qBAAa,SAAU,SAAQ,MAAM;aAGP,CAAC,EAAE,MAAM;aAAkB,GAAG,EAAE,MAAM;IAFlE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,CAAA;gBACmB,CAAC,EAAE,MAAM,EAAkB,GAAG,GAAE,MAAa;CAK1E;AAED,wBAAgB,YAAY,CAAC,CAAC,EAAE,SAAS,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAM5D;AAOD;;4DAE4D;AAC5D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAa5D;AAED,+DAA+D;AAC/D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAW5C;AAED;;;;;;gEAMgE;AAChE,wBAAgB,YAAY,CAAC,IAAI,EAAE,YAAY,EAAE,KAAK,EAAE,SAAS,MAAM,EAAE,GAAG,YAAY,EAAE,CAezF;AAMD;;;;+EAI+E;AAC/E,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM,GAAG,MAAM,CASxE"}
1
+ {"version":3,"file":"nn.d.ts","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAaA,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AACpC,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,SAAS,CAAA;AAIrC,OAAO,KAAK,EAAE,QAAQ,EAAE,MAAM,cAAc,CAAA;AAM5C,MAAM,WAAW,aAAa;IAC5B,0CAA0C;IAC1C,IAAI,CAAC,EAAE,OAAO,CAAA;CACf;AAED,qBAAa,MAAO,SAAQ,MAAM;aAGJ,KAAK,EAAE,MAAM;aAAkB,MAAM,EAAE,MAAM;IAFzE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,GAAG,IAAI,CAAA;gBACY,KAAK,EAAE,MAAM,EAAkB,MAAM,EAAE,MAAM,EAAE,IAAI,GAAE,aAAkB;IAKnG,GAAG,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM;CAIvB;AAMD,qBAAa,SAAU,SAAQ,MAAM;aAGP,CAAC,EAAE,MAAM;aAAkB,GAAG,EAAE,MAAM;IAFlE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,CAAA;gBACmB,CAAC,EAAE,MAAM,EAAkB,GAAG,GAAE,MAAa;IAKzE,GAAG,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM;CAOvB;AAOD;;4DAE4D;AAC5D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAa5D;AAED,+DAA+D;AAC/D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAW5C;AAED;;;;kEAIkE;AAClE,wBAAgB,YAAY,CAAC,QAAQ,EAAE,QAAQ,EAAE,IAAI,EAAE,MAAM,GAAG,YAAY,EAAE,CAiB7E;AAMD;;;;+EAI+E;AAC/E,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM,GAAG,MAAM,CASxE"}