tensorgrad 0.0.1 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. package/README.md +7 -9
  2. package/dist/adam.d.ts +14 -2
  3. package/dist/adam.d.ts.map +1 -1
  4. package/dist/adam.js +19 -8
  5. package/dist/adam.js.map +1 -1
  6. package/dist/buffers.d.ts +1 -0
  7. package/dist/buffers.d.ts.map +1 -1
  8. package/dist/buffers.js +12 -1
  9. package/dist/buffers.js.map +1 -1
  10. package/dist/capture.d.ts +3 -0
  11. package/dist/capture.d.ts.map +1 -0
  12. package/dist/capture.js +33 -0
  13. package/dist/capture.js.map +1 -0
  14. package/dist/codegen.js +4 -2
  15. package/dist/codegen.js.map +1 -1
  16. package/dist/compile.d.ts +33 -5
  17. package/dist/compile.d.ts.map +1 -1
  18. package/dist/compile.js +96 -11
  19. package/dist/compile.js.map +1 -1
  20. package/dist/index.d.ts +5 -3
  21. package/dist/index.d.ts.map +1 -1
  22. package/dist/index.js +4 -2
  23. package/dist/index.js.map +1 -1
  24. package/dist/ir.d.ts +2 -0
  25. package/dist/ir.d.ts.map +1 -1
  26. package/dist/ir.js +1 -1
  27. package/dist/ir.js.map +1 -1
  28. package/dist/module.d.ts +30 -4
  29. package/dist/module.d.ts.map +1 -1
  30. package/dist/module.js +39 -13
  31. package/dist/module.js.map +1 -1
  32. package/dist/nn.d.ts +19 -0
  33. package/dist/nn.d.ts.map +1 -0
  34. package/dist/nn.js +60 -0
  35. package/dist/nn.js.map +1 -0
  36. package/dist/ops.d.ts +1 -1
  37. package/dist/ops.d.ts.map +1 -1
  38. package/dist/ops.js +2 -2
  39. package/dist/ops.js.map +1 -1
  40. package/dist/runtime.d.ts +79 -4
  41. package/dist/runtime.d.ts.map +1 -1
  42. package/dist/runtime.js +153 -19
  43. package/dist/runtime.js.map +1 -1
  44. package/dist/trace.d.ts +1 -0
  45. package/dist/trace.d.ts.map +1 -1
  46. package/dist/trace.js +12 -0
  47. package/dist/trace.js.map +1 -1
  48. package/package.json +1 -2
  49. package/src/adam.ts +31 -10
  50. package/src/buffers.ts +14 -1
  51. package/src/capture.ts +36 -0
  52. package/src/codegen.ts +4 -2
  53. package/src/compile.ts +112 -13
  54. package/src/index.ts +5 -3
  55. package/src/ir.ts +10 -4
  56. package/src/module.ts +75 -11
  57. package/src/nn.ts +59 -0
  58. package/src/ops.ts +2 -2
  59. package/src/runtime.ts +260 -22
  60. package/src/trace.ts +13 -0
  61. package/SPEC.md +0 -293
package/dist/compile.js CHANGED
@@ -12,7 +12,7 @@ import { appendGrad } from './grad.js';
12
12
  import { appendAdam } from './adam.js';
13
13
  import { planBuffers } from './buffers.js';
14
14
  import { emitKernels } from './codegen.js';
15
- import { createRuntime } from './runtime.js';
15
+ import { createRuntime, createForwardRuntime } from './runtime.js';
16
16
  import { Module, materializeParams } from './module.js';
17
17
  /** Trace + autograd + buffer-plan + codegen, without touching WebGPU. */
18
18
  export function compileToIR(traceFn) {
@@ -30,9 +30,13 @@ export async function compile(traceFn, opts = {}) {
30
30
  return Object.assign(runtime, { ir });
31
31
  }
32
32
  /**
33
- * Compile a Module-based model. The forward function takes the materialized
34
- * model and returns the loss tensor (typically by also calling tensorInput
35
- * for tokens/targets/masks inside).
33
+ * Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
34
+ * model instance itself: compilation mutates the tree (every `ParamSentinel`
35
+ * field becomes a real `Tensor`), so the instance is consumed and shouldn't be
36
+ * referenced afterwards. Re-call the factory if you need a fresh tree.
37
+ *
38
+ * The forward function takes the materialized model and returns the loss
39
+ * tensor.
36
40
  *
37
41
  * Walks the module tree to materialize params with auto-derived names, then
38
42
  * runs trace → grad → adam → buffer plan → codegen → runtime.
@@ -41,36 +45,117 @@ export async function compile(traceFn, opts = {}) {
41
45
  * internal step count and injects the bias-corrected `lrt` scalar each call;
42
46
  * users don't need to provide it themselves.
43
47
  */
44
- export async function compileModule(model, forward, opts = {}) {
48
+ export async function compileModule(modelFactory, forward, opts = {}) {
45
49
  const inputDecls = opts.inputs ?? [];
46
- let paramTensors = {};
50
+ const model = modelFactory();
51
+ let materialized = { tensors: {}, initFns: {} };
47
52
  const graph = trace(() => {
48
- paramTensors = materializeParams(model);
53
+ materialized = materializeParams(model);
49
54
  const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'));
50
55
  return forward(model, ...inputTensors);
51
56
  });
52
57
  const { paramGrads, loss } = appendGrad(graph);
53
58
  let adamResult;
54
59
  if (opts.adam) {
55
- adamResult = appendAdam(graph, paramGrads, paramTensors, opts.adam);
60
+ adamResult = appendAdam(graph, paramGrads, materialized.tensors, opts.adam);
56
61
  }
57
62
  const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? []);
58
63
  const kernels = emitKernels(graph, plan);
59
64
  const lossBufferId = plan.tensorToBuffer.get(loss.id);
60
65
  const runtime = await createRuntime(plan, kernels, lossBufferId, opts);
61
66
  // If Adam is enabled, wrap step() to track the step count and supply lrt.
67
+ // Wrap resetOptimizerState() too, so a reset zeros m/v *and* the bias-correction
68
+ // counter — otherwise the next step would skip Adam's warmup phase.
62
69
  if (adamResult) {
63
70
  const { lrtInputName, config } = adamResult;
64
71
  let t = 0;
65
72
  const lrtBuf = new Float32Array(1);
66
73
  const innerStep = runtime.step.bind(runtime);
67
- runtime.step = async (inputs) => {
74
+ const innerReset = runtime.resetOptimizerState.bind(runtime);
75
+ const wrappedStep = (inputs, opts) => {
68
76
  t++;
69
77
  lrtBuf[0] = config.lr * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t));
70
- return innerStep({ ...inputs, [lrtInputName]: lrtBuf });
78
+ const merged = { ...inputs, [lrtInputName]: lrtBuf };
79
+ return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged);
80
+ };
81
+ runtime.step = wrappedStep;
82
+ runtime.resetOptimizerState = () => {
83
+ t = 0;
84
+ innerReset();
71
85
  };
72
86
  }
87
+ const { initFns } = materialized;
88
+ const uploadInitialParams = () => {
89
+ const out = {};
90
+ for (const [name, bufId] of plan.paramsByName) {
91
+ const shape = plan.buffers[bufId].shape;
92
+ const size = shape.reduce((a, b) => a * b, 1);
93
+ const initFn = initFns[name];
94
+ if (!initFn)
95
+ throw new Error(`uploadInitialParams: no init for param '${name}'`);
96
+ out[name] = initFn(size, shape);
97
+ }
98
+ runtime.uploadParams(out);
99
+ };
73
100
  const ir = { graph, paramGrads, loss, plan, kernels };
74
- return Object.assign(runtime, { ir });
101
+ return Object.assign(runtime, { ir, uploadInitialParams });
102
+ }
103
+ // ============================================================================
104
+ // Forward-only compile
105
+ // ============================================================================
106
+ /**
107
+ * Compile a Module-based model in forward-only mode (no autograd, no Adam).
108
+ * The forward function returns the output tensor (e.g., logits) instead of a
109
+ * scalar loss; runtime exposes `run(inputs)` returning the full output as a
110
+ * `Float32Array`.
111
+ *
112
+ * **Sharing params with a training compile.** Pass `opts.sharedParams =
113
+ * trainCompiled.params` to bind this graph's param buffers to an existing
114
+ * training runtime's GPU buffers — every train step is then immediately
115
+ * visible to `run()` calls here, no copies. The forward graph's
116
+ * `uploadInitialParams()` skips any param covered by `sharedParams`.
117
+ *
118
+ * Typical use: a B=1 inference graph alongside a B=512 training graph,
119
+ * built from the same `Module` factory.
120
+ */
121
+ export async function compileForward(modelFactory, forward, opts = {}) {
122
+ const inputDecls = opts.inputs ?? [];
123
+ const model = modelFactory();
124
+ let materialized = { tensors: {}, initFns: {} };
125
+ const graph = trace(() => {
126
+ materialized = materializeParams(model);
127
+ const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'));
128
+ return forward(model, ...inputTensors);
129
+ });
130
+ const plan = planBuffers(graph, /* paramGrads */ {});
131
+ const kernels = emitKernels(graph, plan);
132
+ const outputTensor = graph.tensors[graph.outputs[0]];
133
+ const outputBufferId = plan.tensorToBuffer.get(outputTensor.id);
134
+ const runtime = await createForwardRuntime(plan, kernels, outputBufferId, opts);
135
+ const sharedParams = opts.sharedParams;
136
+ const { initFns } = materialized;
137
+ const uploadInitialParams = () => {
138
+ const out = {};
139
+ let needsUpload = false;
140
+ for (const [name, bufId] of plan.paramsByName) {
141
+ // Skip params covered by sharedParams — those are owned by the providing
142
+ // compile and already initialized there.
143
+ if (sharedParams?.has(name))
144
+ continue;
145
+ const shape = plan.buffers[bufId].shape;
146
+ const size = shape.reduce((a, b) => a * b, 1);
147
+ const initFn = initFns[name];
148
+ if (!initFn)
149
+ throw new Error(`uploadInitialParams: no init for param '${name}'`);
150
+ out[name] = initFn(size, shape);
151
+ needsUpload = true;
152
+ }
153
+ if (needsUpload)
154
+ runtime.uploadParams(out, { partial: !!sharedParams });
155
+ };
156
+ // CompiledIR.loss is the field name; for forward-only, it carries the user's
157
+ // returned tensor (e.g., logits). Same shape conceptually; just no autograd.
158
+ const ir = { graph, paramGrads: {}, loss: outputTensor, plan, kernels };
159
+ return Object.assign(runtime, { ir, uploadInitialParams });
75
160
  }
76
161
  //# sourceMappingURL=compile.js.map
@@ -1 +1 @@
1
- {"version":3,"file":"compile.js","sourceRoot":"","sources":["../src/compile.ts"],"names":[],"mappings":"AAAA,2EAA2E;AAC3E,EAAE;AACF,oBAAoB;AACpB,sEAAsE;AACtE,iEAAiE;AACjE,0EAA0E;AAC1E,0EAA0E;AAC1E,2EAA2E;AAC3E,mEAAmE;AAGnE,OAAO,EAAE,KAAK,EAAE,WAAW,EAAE,MAAM,YAAY,CAAA;AAC/C,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAA0C,MAAM,cAAc,CAAA;AACpF,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAAE,MAAM,aAAa,CAAA;AAmBvD,yEAAyE;AACzE,MAAM,UAAU,WAAW,CAAC,OAAqB;IAC/C,MAAM,KAAK,GAAG,KAAK,CAAC,OAAO,CAAC,CAAA;IAC5B,MAAM,EAAE,UAAU,EAAE,IAAI,EAAE,GAAG,UAAU,CAAC,KAAK,CAAC,CAAA;IAC9C,MAAM,IAAI,GAAG,WAAW,CAAC,KAAK,EAAE,UAAU,CAAC,CAAA;IAC3C,MAAM,OAAO,GAAG,WAAW,CAAC,KAAK,EAAE,IAAI,CAAC,CAAA;IACxC,OAAO,EAAE,KAAK,EAAE,UAAU,EAAE,IAAI,EAAE,IAAI,EAAE,OAAO,EAAE,CAAA;AACnD,CAAC;AAED,0EAA0E;AAC1E,MAAM,CAAC,KAAK,UAAU,OAAO,CAAC,OAAqB,EAAE,OAAoB,EAAE;IACzE,MAAM,EAAE,GAAG,WAAW,CAAC,OAAO,CAAC,CAAA;IAC/B,MAAM,YAAY,GAAG,EAAE,CAAC,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAE,CAAA;IAC5D,MAAM,OAAO,GAAG,MAAM,aAAa,CAAC,EAAE,CAAC,IAAI,EAAE,EAAE,CAAC,OAAO,EAAE,YAAY,EAAE,IAAI,CAAC,CAAA;IAC5E,OAAO,MAAM,CAAC,MAAM,CAAC,OAAO,EAAE,EAAE,EAAE,EAAE,CAAC,CAAA;AACvC,CAAC;AAgBD;;;;;;;;;;;GAWG;AACH,MAAM,CAAC,KAAK,UAAU,aAAa,CACjC,KAAQ,EACR,OAA8C,EAC9C,OAA6B,EAAE;IAE/B,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,IAAI,EAAE,CAAA;IACpC,IAAI,YAAY,GAA2B,EAAE,CAAA;IAC7C,MAAM,KAAK,GAAG,KAAK,CAAC,GAAG,EAAE;QACvB,YAAY,GAAG,iBAAiB,CAAC,KAAK,CAAC,CAAA;QACvC,MAAM,YAAY,GAAG,UAAU,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,WAAW,CAAC,CAAC,CAAC,IAAI,EAAE,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,IAAI,KAAK,CAAC,CAAC,CAAA;QACxF,OAAO,OAAO,CAAC,KAAK,EAAE,GAAG,YAAY,CAAC,CAAA;IACxC,CAAC,CAAC,CAAA;IAEF,MAAM,EAAE,UAAU,EAAE,IAAI,EAAE,GAAG,UAAU,CAAC,KAAK,CAAC,CAAA;IAE9C,IAAI,UAAqD,CAAA;IACzD,IAAI,IAAI,CAAC,IAAI,EAAE,CAAC;QACd,UAAU,GAAG,UAAU,CAAC,KAAK,EAAE,UAAU,EAAE,YAAY,EAAE,IAAI,CAAC,IAAI,CAAC,CAAA;IACrE,CAAC;IAED,MAAM,IAAI,GAAG,WAAW,CAAC,KAAK,EAAE,UAAU,EAAE,UAAU,EAAE,UAAU,IAAI,EAAE,CAAC,CAAA;IACzE,MAAM,OAAO,GAAG,WAAW,CAAC,KAAK,EAAE,IAAI,CAAC,CAAA;IACxC,MAAM,YAAY,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAE,CAAA;IACtD,MAAM,OAAO,GAAG,MAAM,aAAa,CAAC,IAAI,EAAE,OAAO,EAAE,YAAY,EAAE,IAAI,CAAC,CAAA;IAEtE,0EAA0E;IAC1E,IAAI,UAAU,EAAE,CAAC;QACf,MAAM,EAAE,YAAY,EAAE,MAAM,EAAE,GAAG,UAAU,CAAA;QAC3C,IAAI,CAAC,GAAG,CAAC,CAAA;QACT,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,CAAC,CAAC,CAAA;QAClC,MAAM,SAAS,GAAG,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,OAAO,CAAC,CAAA;QAC5C,OAAO,CAAC,IAAI,GAAG,KAAK,EAAE,MAAM,EAAE,EAAE;YAC9B,CAAC,EAAE,CAAA;YACH,MAAM,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,EAAE,GAAG,IAAI,CAAC,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAA;YAC5F,OAAO,SAAS,CAAC,EAAE,GAAG,MAAM,EAAE,CAAC,YAAY,CAAC,EAAE,MAAM,EAAE,CAAC,CAAA;QACzD,CAAC,CAAA;IACH,CAAC;IAED,MAAM,EAAE,GAAe,EAAE,KAAK,EAAE,UAAU,EAAE,IAAI,EAAE,IAAI,EAAE,OAAO,EAAE,CAAA;IACjE,OAAO,MAAM,CAAC,MAAM,CAAC,OAAO,EAAE,EAAE,EAAE,EAAE,CAAC,CAAA;AACvC,CAAC"}
1
+ {"version":3,"file":"compile.js","sourceRoot":"","sources":["../src/compile.ts"],"names":[],"mappings":"AAAA,2EAA2E;AAC3E,EAAE;AACF,oBAAoB;AACpB,sEAAsE;AACtE,iEAAiE;AACjE,0EAA0E;AAC1E,0EAA0E;AAC1E,2EAA2E;AAC3E,mEAAmE;AAGnE,OAAO,EAAE,KAAK,EAAE,WAAW,EAAE,MAAM,YAAY,CAAA;AAC/C,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAgE,MAAM,cAAc,CAAA;AAChI,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAAE,MAAM,aAAa,CAAA;AAmBvD,yEAAyE;AACzE,MAAM,UAAU,WAAW,CAAC,OAAqB;IAC/C,MAAM,KAAK,GAAG,KAAK,CAAC,OAAO,CAAC,CAAA;IAC5B,MAAM,EAAE,UAAU,EAAE,IAAI,EAAE,GAAG,UAAU,CAAC,KAAK,CAAC,CAAA;IAC9C,MAAM,IAAI,GAAG,WAAW,CAAC,KAAK,EAAE,UAAU,CAAC,CAAA;IAC3C,MAAM,OAAO,GAAG,WAAW,CAAC,KAAK,EAAE,IAAI,CAAC,CAAA;IACxC,OAAO,EAAE,KAAK,EAAE,UAAU,EAAE,IAAI,EAAE,IAAI,EAAE,OAAO,EAAE,CAAA;AACnD,CAAC;AAED,0EAA0E;AAC1E,MAAM,CAAC,KAAK,UAAU,OAAO,CAAC,OAAqB,EAAE,OAAoB,EAAE;IACzE,MAAM,EAAE,GAAG,WAAW,CAAC,OAAO,CAAC,CAAA;IAC/B,MAAM,YAAY,GAAG,EAAE,CAAC,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAE,CAAA;IAC5D,MAAM,OAAO,GAAG,MAAM,aAAa,CAAC,EAAE,CAAC,IAAI,EAAE,EAAE,CAAC,OAAO,EAAE,YAAY,EAAE,IAAI,CAAC,CAAA;IAC5E,OAAO,MAAM,CAAC,MAAM,CAAC,OAAO,EAAE,EAAE,EAAE,EAAE,CAAC,CAAA;AACvC,CAAC;AAqBD;;;;;;;;;;;;;;;GAeG;AACH,MAAM,CAAC,KAAK,UAAU,aAAa,CACjC,YAAqB,EACrB,OAA8C,EAC9C,OAA6B,EAAE;IAE/B,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,IAAI,EAAE,CAAA;IACpC,MAAM,KAAK,GAAG,YAAY,EAAE,CAAA;IAC5B,IAAI,YAAY,GAAyC,EAAE,OAAO,EAAE,EAAE,EAAE,OAAO,EAAE,EAAE,EAAE,CAAA;IACrF,MAAM,KAAK,GAAG,KAAK,CAAC,GAAG,EAAE;QACvB,YAAY,GAAG,iBAAiB,CAAC,KAAK,CAAC,CAAA;QACvC,MAAM,YAAY,GAAG,UAAU,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,WAAW,CAAC,CAAC,CAAC,IAAI,EAAE,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,IAAI,KAAK,CAAC,CAAC,CAAA;QACxF,OAAO,OAAO,CAAC,KAAK,EAAE,GAAG,YAAY,CAAC,CAAA;IACxC,CAAC,CAAC,CAAA;IAEF,MAAM,EAAE,UAAU,EAAE,IAAI,EAAE,GAAG,UAAU,CAAC,KAAK,CAAC,CAAA;IAE9C,IAAI,UAAqD,CAAA;IACzD,IAAI,IAAI,CAAC,IAAI,EAAE,CAAC;QACd,UAAU,GAAG,UAAU,CAAC,KAAK,EAAE,UAAU,EAAE,YAAY,CAAC,OAAO,EAAE,IAAI,CAAC,IAAI,CAAC,CAAA;IAC7E,CAAC;IAED,MAAM,IAAI,GAAG,WAAW,CAAC,KAAK,EAAE,UAAU,EAAE,UAAU,EAAE,UAAU,IAAI,EAAE,CAAC,CAAA;IACzE,MAAM,OAAO,GAAG,WAAW,CAAC,KAAK,EAAE,IAAI,CAAC,CAAA;IACxC,MAAM,YAAY,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAE,CAAA;IACtD,MAAM,OAAO,GAAG,MAAM,aAAa,CAAC,IAAI,EAAE,OAAO,EAAE,YAAY,EAAE,IAAI,CAAC,CAAA;IAEtE,0EAA0E;IAC1E,iFAAiF;IACjF,oEAAoE;IACpE,IAAI,UAAU,EAAE,CAAC;QACf,MAAM,EAAE,YAAY,EAAE,MAAM,EAAE,GAAG,UAAU,CAAA;QAC3C,IAAI,CAAC,GAAG,CAAC,CAAA;QACT,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,CAAC,CAAC,CAAA;QAClC,MAAM,SAAS,GAAG,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,OAAO,CAA4B,CAAA;QACvE,MAAM,UAAU,GAAG,OAAO,CAAC,mBAAmB,CAAC,IAAI,CAAC,OAAO,CAAC,CAAA;QAC5D,MAAM,WAAW,GAAG,CAClB,MAAiD,EACjD,IAAiC,EAC2C,EAAE;YAC9E,CAAC,EAAE,CAAA;YACH,MAAM,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,EAAE,GAAG,IAAI,CAAC,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAA;YAC5F,MAAM,MAAM,GAAG,EAAE,GAAG,MAAM,EAAE,CAAC,YAAY,CAAC,EAAE,MAAM,EAAE,CAAA;YACpD,OAAO,IAAI,EAAE,YAAY,CAAC,CAAC,CAAC,SAAS,CAAC,MAAM,EAAE,EAAE,YAAY,EAAE,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,MAAM,CAAC,CAAA;QAC3F,CAAC,CAAA;QACD,OAAO,CAAC,IAAI,GAAG,WAAsC,CAAA;QACrD,OAAO,CAAC,mBAAmB,GAAG,GAAG,EAAE;YACjC,CAAC,GAAG,CAAC,CAAA;YACL,UAAU,EAAE,CAAA;QACd,CAAC,CAAA;IACH,CAAC;IAED,MAAM,EAAE,OAAO,EAAE,GAAG,YAAY,CAAA;IAChC,MAAM,mBAAmB,GAAG,GAAG,EAAE;QAC/B,MAAM,GAAG,GAAiC,EAAE,CAAA;QAC5C,KAAK,MAAM,CAAC,IAAI,EAAE,KAAK,CAAC,IAAI,IAAI,CAAC,YAAY,EAAE,CAAC;YAC9C,MAAM,KAAK,GAAG,IAAI,CAAC,OAAO,CAAC,KAAK,CAAE,CAAC,KAAK,CAAA;YACxC,MAAM,IAAI,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC,CAAA;YAC7C,MAAM,MAAM,GAAG,OAAO,CAAC,IAAI,CAAC,CAAA;YAC5B,IAAI,CAAC,MAAM;gBAAE,MAAM,IAAI,KAAK,CAAC,2CAA2C,IAAI,GAAG,CAAC,CAAA;YAChF,GAAG,CAAC,IAAI,CAAC,GAAG,MAAM,CAAC,IAAI,EAAE,KAAK,CAAC,CAAA;QACjC,CAAC;QACD,OAAO,CAAC,YAAY,CAAC,GAAG,CAAC,CAAA;IAC3B,CAAC,CAAA;IAED,MAAM,EAAE,GAAe,EAAE,KAAK,EAAE,UAAU,EAAE,IAAI,EAAE,IAAI,EAAE,OAAO,EAAE,CAAA;IACjE,OAAO,MAAM,CAAC,MAAM,CAAC,OAAO,EAAE,EAAE,EAAE,EAAE,mBAAmB,EAAE,CAAC,CAAA;AAC5D,CAAC;AAED,+EAA+E;AAC/E,uBAAuB;AACvB,+EAA+E;AAE/E;;;;;;;;;;;;;;GAcG;AACH,MAAM,CAAC,KAAK,UAAU,cAAc,CAClC,YAAqB,EACrB,OAA8C,EAC9C,OAA8B,EAAE;IAEhC,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,IAAI,EAAE,CAAA;IACpC,MAAM,KAAK,GAAG,YAAY,EAAE,CAAA;IAC5B,IAAI,YAAY,GAAyC,EAAE,OAAO,EAAE,EAAE,EAAE,OAAO,EAAE,EAAE,EAAE,CAAA;IACrF,MAAM,KAAK,GAAG,KAAK,CAAC,GAAG,EAAE;QACvB,YAAY,GAAG,iBAAiB,CAAC,KAAK,CAAC,CAAA;QACvC,MAAM,YAAY,GAAG,UAAU,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,WAAW,CAAC,CAAC,CAAC,IAAI,EAAE,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,IAAI,KAAK,CAAC,CAAC,CAAA;QACxF,OAAO,OAAO,CAAC,KAAK,EAAE,GAAG,YAAY,CAAC,CAAA;IACxC,CAAC,CAAC,CAAA;IAEF,MAAM,IAAI,GAAG,WAAW,CAAC,KAAK,EAAE,gBAAgB,CAAC,EAAE,CAAC,CAAA;IACpD,MAAM,OAAO,GAAG,WAAW,CAAC,KAAK,EAAE,IAAI,CAAC,CAAA;IACxC,MAAM,YAAY,GAAG,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC,CAAE,CAAE,CAAA;IACtD,MAAM,cAAc,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,YAAY,CAAC,EAAE,CAAE,CAAA;IAChE,MAAM,OAAO,GAAG,MAAM,oBAAoB,CAAC,IAAI,EAAE,OAAO,EAAE,cAAc,EAAE,IAAI,CAAC,CAAA;IAE/E,MAAM,YAAY,GAAG,IAAI,CAAC,YAAY,CAAA;IACtC,MAAM,EAAE,OAAO,EAAE,GAAG,YAAY,CAAA;IAChC,MAAM,mBAAmB,GAAG,GAAG,EAAE;QAC/B,MAAM,GAAG,GAAiC,EAAE,CAAA;QAC5C,IAAI,WAAW,GAAG,KAAK,CAAA;QACvB,KAAK,MAAM,CAAC,IAAI,EAAE,KAAK,CAAC,IAAI,IAAI,CAAC,YAAY,EAAE,CAAC;YAC9C,yEAAyE;YACzE,yCAAyC;YACzC,IAAI,YAAY,EAAE,GAAG,CAAC,IAAI,CAAC;gBAAE,SAAQ;YACrC,MAAM,KAAK,GAAG,IAAI,CAAC,OAAO,CAAC,KAAK,CAAE,CAAC,KAAK,CAAA;YACxC,MAAM,IAAI,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC,CAAA;YAC7C,MAAM,MAAM,GAAG,OAAO,CAAC,IAAI,CAAC,CAAA;YAC5B,IAAI,CAAC,MAAM;gBAAE,MAAM,IAAI,KAAK,CAAC,2CAA2C,IAAI,GAAG,CAAC,CAAA;YAChF,GAAG,CAAC,IAAI,CAAC,GAAG,MAAM,CAAC,IAAI,EAAE,KAAK,CAAC,CAAA;YAC/B,WAAW,GAAG,IAAI,CAAA;QACpB,CAAC;QACD,IAAI,WAAW;YAAE,OAAO,CAAC,YAAY,CAAC,GAAG,EAAE,EAAE,OAAO,EAAE,CAAC,CAAC,YAAY,EAAE,CAAC,CAAA;IACzE,CAAC,CAAA;IAED,6EAA6E;IAC7E,6EAA6E;IAC7E,MAAM,EAAE,GAAe,EAAE,KAAK,EAAE,UAAU,EAAE,EAAE,EAAE,IAAI,EAAE,YAAY,EAAE,IAAI,EAAE,OAAO,EAAE,CAAA;IACnF,OAAO,MAAM,CAAC,MAAM,CAAC,OAAO,EAAE,EAAE,EAAE,EAAE,mBAAmB,EAAE,CAAC,CAAA;AAC5D,CAAC"}
package/dist/index.d.ts CHANGED
@@ -1,12 +1,14 @@
1
1
  export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js';
2
2
  export { ShapeError } from './shape.js';
3
3
  export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js';
4
+ export { capture } from './capture.js';
4
5
  export { add, sub, mul, div, sqrt, rsqrt, log, exp, relu, less, greater, where, meanLast, sumLast, reshape, transpose, matmul, matmulBatched, oneHot, arange, softmaxCausalLast, logSoftmaxLast, whereCausal, sliceLastRange, } from './ops.js';
5
6
  export { appendGrad, type GradResult } from './grad.js';
6
7
  export { appendAdam, type AdamConfig, type AdamResult } from './adam.js';
7
8
  export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js';
8
9
  export { emitKernels, type KernelSpec } from './codegen.js';
9
- export { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js';
10
- export { compile, compileToIR, compileModule, type CompiledIR, type CompileModuleOptions, type InputDecl } from './compile.js';
11
- export { Module, materializeParams } from './module.js';
10
+ export { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type StepOptions, type StepWithCaptures, type RunOptions, type RunWithCaptures } from './runtime.js';
11
+ export { compile, compileToIR, compileModule, compileForward, type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type InputDecl } from './compile.js';
12
+ export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js';
13
+ export * as nn from './nn.js';
12
14
  //# sourceMappingURL=index.d.ts.map
@@ -1 +1 @@
1
- {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAKA,YAAY,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,QAAQ,EAAE,MAAM,SAAS,CAAA;AAC5E,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO,EAEL,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAElB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,EAE3B,IAAI,EAAE,OAAO,EAAE,KAAK,EAEpB,QAAQ,EAAE,OAAO,EAEjB,OAAO,EAAE,SAAS,EAElB,MAAM,EAAE,aAAa,EAErB,MAAM,EAAE,MAAM,EAEd,iBAAiB,EAAE,cAAc,EAAE,WAAW,EAE9C,cAAc,GACf,MAAM,UAAU,CAAA;AAMjB,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,SAAS,EAAE,KAAK,aAAa,EAAE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,KAAK,eAAe,EAAE,KAAK,WAAW,EAAE,MAAM,cAAc,CAAA;AACpF,OAAO,EAAE,OAAO,EAAE,WAAW,EAAE,aAAa,EAAE,KAAK,UAAU,EAAE,KAAK,oBAAoB,EAAE,KAAK,SAAS,EAAE,MAAM,cAAc,CAAA;AAC9H,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAAE,MAAM,aAAa,CAAA"}
1
+ {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAKA,YAAY,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,QAAQ,EAAE,MAAM,SAAS,CAAA;AAC5E,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAA;AACtC,OAAO,EAEL,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAElB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,EAE3B,IAAI,EAAE,OAAO,EAAE,KAAK,EAEpB,QAAQ,EAAE,OAAO,EAEjB,OAAO,EAAE,SAAS,EAElB,MAAM,EAAE,aAAa,EAErB,MAAM,EAAE,MAAM,EAEd,iBAAiB,EAAE,cAAc,EAAE,WAAW,EAE9C,cAAc,GACf,MAAM,UAAU,CAAA;AAMjB,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,SAAS,EAAE,KAAK,aAAa,EAAE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAE,KAAK,eAAe,EAAE,KAAK,eAAe,EAAE,KAAK,WAAW,EAAE,KAAK,WAAW,EAAE,KAAK,gBAAgB,EAAE,KAAK,UAAU,EAAE,KAAK,eAAe,EAAE,MAAM,cAAc,CAAA;AAChN,OAAO,EAAE,OAAO,EAAE,WAAW,EAAE,aAAa,EAAE,cAAc,EAAE,KAAK,UAAU,EAAE,KAAK,oBAAoB,EAAE,KAAK,qBAAqB,EAAE,KAAK,SAAS,EAAE,MAAM,cAAc,CAAA;AAC1K,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAAE,KAAK,QAAQ,EAAE,KAAK,YAAY,EAAE,KAAK,kBAAkB,EAAE,MAAM,aAAa,CAAA;AAClH,OAAO,KAAK,EAAE,MAAM,SAAS,CAAA"}
package/dist/index.js CHANGED
@@ -4,6 +4,7 @@
4
4
  // codegen / compile() (Phase 3+) come later.
5
5
  export { ShapeError } from './shape.js';
6
6
  export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js';
7
+ export { capture } from './capture.js';
7
8
  export {
8
9
  // Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
9
10
  add, sub, mul, div,
@@ -31,7 +32,8 @@ export { appendGrad } from './grad.js';
31
32
  export { appendAdam } from './adam.js';
32
33
  export { planBuffers } from './buffers.js';
33
34
  export { emitKernels } from './codegen.js';
34
- export { createRuntime } from './runtime.js';
35
- export { compile, compileToIR, compileModule } from './compile.js';
35
+ export { createRuntime, createForwardRuntime } from './runtime.js';
36
+ export { compile, compileToIR, compileModule, compileForward } from './compile.js';
36
37
  export { Module, materializeParams } from './module.js';
38
+ export * as nn from './nn.js';
37
39
  //# sourceMappingURL=index.js.map
package/dist/index.js.map CHANGED
@@ -1 +1 @@
1
- {"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,+CAA+C;AAC/C,EAAE;AACF,8EAA8E;AAC9E,6CAA6C;AAG7C,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO;AACL,qFAAqF;AACrF,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG;AAClB,qBAAqB;AACrB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI;AAC3B,uBAAuB;AACvB,IAAI,EAAE,OAAO,EAAE,KAAK;AACpB,yEAAyE;AACzE,QAAQ,EAAE,OAAO;AACjB,YAAY;AACZ,OAAO,EAAE,SAAS;AAClB,iBAAiB;AACjB,MAAM,EAAE,aAAa;AACrB,qBAAqB;AACrB,MAAM,EAAE,MAAM;AACd,4CAA4C;AAC5C,iBAAiB,EAAE,cAAc,EAAE,WAAW;AAC9C,UAAU;AACV,cAAc,GACf,MAAM,UAAU,CAAA;AAEjB,sFAAsF;AACtF,8EAA8E;AAC9E,2EAA2E;AAC3E,qDAAqD;AACrD,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAoC,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAwE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAA0C,MAAM,cAAc,CAAA;AACpF,OAAO,EAAE,OAAO,EAAE,WAAW,EAAE,aAAa,EAA8D,MAAM,cAAc,CAAA;AAC9H,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAAE,MAAM,aAAa,CAAA"}
1
+ {"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,+CAA+C;AAC/C,EAAE;AACF,8EAA8E;AAC9E,6CAA6C;AAG7C,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAA;AACtC,OAAO;AACL,qFAAqF;AACrF,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG;AAClB,qBAAqB;AACrB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI;AAC3B,uBAAuB;AACvB,IAAI,EAAE,OAAO,EAAE,KAAK;AACpB,yEAAyE;AACzE,QAAQ,EAAE,OAAO;AACjB,YAAY;AACZ,OAAO,EAAE,SAAS;AAClB,iBAAiB;AACjB,MAAM,EAAE,aAAa;AACrB,qBAAqB;AACrB,MAAM,EAAE,MAAM;AACd,4CAA4C;AAC5C,iBAAiB,EAAE,cAAc,EAAE,WAAW;AAC9C,UAAU;AACV,cAAc,GACf,MAAM,UAAU,CAAA;AAEjB,sFAAsF;AACtF,8EAA8E;AAC9E,2EAA2E;AAC3E,qDAAqD;AACrD,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAoC,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAwE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAgJ,MAAM,cAAc,CAAA;AAChN,OAAO,EAAE,OAAO,EAAE,WAAW,EAAE,aAAa,EAAE,cAAc,EAA0F,MAAM,cAAc,CAAA;AAC1K,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAA6D,MAAM,aAAa,CAAA;AAClH,OAAO,KAAK,EAAE,MAAM,SAAS,CAAA"}
package/dist/ir.d.ts CHANGED
@@ -162,6 +162,7 @@ export type OpNode = {
162
162
  vNew: number;
163
163
  lrt: number;
164
164
  eps: number;
165
+ decayShrink: number;
165
166
  } | {
166
167
  kind: 'slice_last_range';
167
168
  out: number;
@@ -193,6 +194,7 @@ export interface Graph {
193
194
  readonly ops: OpNode[];
194
195
  readonly tensors: Tensor[];
195
196
  readonly outputs: number[];
197
+ readonly captures: Map<string, number>;
196
198
  }
197
199
  export declare function makeGraph(): Graph;
198
200
  export declare function addTensor(g: Graph, shape: Shape, dtype: Dtype, source: number | null, site: CallSite | null): Tensor;
package/dist/ir.d.ts.map CHANGED
@@ -1 +1 @@
1
- {"version":3,"file":"ir.d.ts","sourceRoot":"","sources":["../src/ir.ts"],"names":[],"mappings":"AAcA,MAAM,MAAM,KAAK,GAAG,KAAK,GAAG,KAAK,GAAG,MAAM,CAAA;AAC1C,MAAM,MAAM,KAAK,GAAG,SAAS,MAAM,EAAE,CAAA;AAIrC,MAAM,WAAW,MAAM;IACrB,QAAQ,CAAC,EAAE,EAAE,MAAM,CAAA;IACnB,QAAQ,CAAC,KAAK,EAAE,KAAK,CAAA;IACrB,QAAQ,CAAC,KAAK,EAAE,KAAK,CAAA;IAErB,QAAQ,CAAC,MAAM,EAAE,MAAM,GAAG,IAAI,CAAA;IAG9B,QAAQ,CAAC,IAAI,EAAE,QAAQ,GAAG,IAAI,CAAA;CAC/B;AAED,MAAM,WAAW,QAAQ;IACvB,QAAQ,CAAC,MAAM,EAAE,MAAM,CAAA;IAEvB,QAAQ,CAAC,KAAK,EAAE,MAAM,CAAA;CACvB;AAQD,MAAM,MAAM,MAAM,GAGd;IAAE,IAAI,EAAE,aAAa,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAA;CAAE,GAElD;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAA;CAAE,GAInD;IAAE,IAAI,EAAE,aAAa,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,SAAS,EAAE,MAAM,CAAA;CAAE,GAGrE;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,YAAY,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,MAAM,EAAE,MAAM,CAAA;CAAE,GAC9D;IAAE,IAAI,EAAE,YAAY,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,MAAM,EAAE,MAAM,CAAA;CAAE,GAG9D;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACxC;IAAE,IAAI,EAAE,OAAO,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACzC;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvC;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvC;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAGxC;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAC7C;IAAE,IAAI,EAAE,UAAU,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAG5C;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,QAAQ,EAAE,KAAK,CAAA;CAAE,GAC5D;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,SAAS,MAAM,EAAE,CAAA;CAAE,GAMtE;IAAE,IAAI,EAAE,QAAQ,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAErD;IAAE,IAAI,EAAE,gBAAgB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAG7D;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,OAAO,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAC9E;IAAE,IAAI,EAAE,QAAQ,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAGxD;IAAE,IAAI,EAAE,qBAAqB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvD;IAAE,IAAI,EAAE,kBAAkB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAIpD;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,SAAS,EAAE,MAAM,CAAA;CAAE,GAKnE;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACnD;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAGtD;IAAE,IAAI,EAAE,OAAO,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAMlE;IAAE,IAAI,EAAE,eAAe,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,EAAE,EAAE,MAAM,CAAA;CAAE,GACxE;IAAE,IAAI,EAAE,eAAe,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,EAAE,EAAE,MAAM,CAAA;CAAE,GAKxE;IAAE,IAAI,EAAE,eAAe,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAA;CAAE,GAMvG;IAAE,IAAI,EAAE,kBAAkB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAA;CAAE,GAGhF;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,WAAW,EAAE,KAAK,CAAA;CAAE,GAGpE;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,WAAW,EAAE,KAAK,CAAA;CAAE,GAEpE;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAElE;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,EAAE,EAAE,MAAM,CAAA;CAAE,CAAA;AAI7D,MAAM,WAAW,KAAK;IACpB,QAAQ,CAAC,GAAG,EAAE,MAAM,EAAE,CAAA;IACtB,QAAQ,CAAC,OAAO,EAAE,MAAM,EAAE,CAAA;IAG1B,QAAQ,CAAC,OAAO,EAAE,MAAM,EAAE,CAAA;CAC3B;AAED,wBAAgB,SAAS,IAAI,KAAK,CAEjC;AAGD,wBAAgB,SAAS,CAAC,CAAC,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,MAAM,GAAG,IAAI,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,MAAM,CAKpH;AAMD,wBAAgB,KAAK,CAAC,CAAC,SAAS,MAAM,CAAC,MAAM,CAAC,EAC5C,CAAC,EAAE,KAAK,EACR,IAAI,EAAE,CAAC,EACP,KAAK,EAAE,KAAK,EACZ,KAAK,EAAE,KAAK,EACZ,IAAI,EAAE,QAAQ,GAAG,IAAI,EACrB,MAAM,EAAE,IAAI,CAAC,OAAO,CAAC,MAAM,EAAE;IAAE,IAAI,EAAE,CAAC,CAAA;CAAE,CAAC,EAAE,MAAM,GAAG,KAAK,CAAC,GACzD,MAAM,CAMR;AAID,wBAAgB,WAAW,CAAC,MAAM,EAAE,MAAM,GAAG,QAAQ,CAIpD;AAID,wBAAgB,UAAU,CAAC,IAAI,EAAE,QAAQ,GAAG,MAAM,CAYjD"}
1
+ {"version":3,"file":"ir.d.ts","sourceRoot":"","sources":["../src/ir.ts"],"names":[],"mappings":"AAcA,MAAM,MAAM,KAAK,GAAG,KAAK,GAAG,KAAK,GAAG,MAAM,CAAA;AAC1C,MAAM,MAAM,KAAK,GAAG,SAAS,MAAM,EAAE,CAAA;AAIrC,MAAM,WAAW,MAAM;IACrB,QAAQ,CAAC,EAAE,EAAE,MAAM,CAAA;IACnB,QAAQ,CAAC,KAAK,EAAE,KAAK,CAAA;IACrB,QAAQ,CAAC,KAAK,EAAE,KAAK,CAAA;IAErB,QAAQ,CAAC,MAAM,EAAE,MAAM,GAAG,IAAI,CAAA;IAG9B,QAAQ,CAAC,IAAI,EAAE,QAAQ,GAAG,IAAI,CAAA;CAC/B;AAED,MAAM,WAAW,QAAQ;IACvB,QAAQ,CAAC,MAAM,EAAE,MAAM,CAAA;IAEvB,QAAQ,CAAC,KAAK,EAAE,MAAM,CAAA;CACvB;AAQD,MAAM,MAAM,MAAM,GAGd;IAAE,IAAI,EAAE,aAAa,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAA;CAAE,GAElD;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAA;CAAE,GAInD;IAAE,IAAI,EAAE,aAAa,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,SAAS,EAAE,MAAM,CAAA;CAAE,GAGrE;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,YAAY,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,MAAM,EAAE,MAAM,CAAA;CAAE,GAC9D;IAAE,IAAI,EAAE,YAAY,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,MAAM,EAAE,MAAM,CAAA;CAAE,GAG9D;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACxC;IAAE,IAAI,EAAE,OAAO,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACzC;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvC;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvC;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAGxC;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAC7C;IAAE,IAAI,EAAE,UAAU,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAG5C;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,QAAQ,EAAE,KAAK,CAAA;CAAE,GAC5D;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,SAAS,MAAM,EAAE,CAAA;CAAE,GAMtE;IAAE,IAAI,EAAE,QAAQ,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAErD;IAAE,IAAI,EAAE,gBAAgB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAG7D;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,OAAO,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAC9E;IAAE,IAAI,EAAE,QAAQ,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAGxD;IAAE,IAAI,EAAE,qBAAqB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvD;IAAE,IAAI,EAAE,kBAAkB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAIpD;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,SAAS,EAAE,MAAM,CAAA;CAAE,GAKnE;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACnD;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAGtD;IAAE,IAAI,EAAE,OAAO,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAMlE;IAAE,IAAI,EAAE,eAAe,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,EAAE,EAAE,MAAM,CAAA;CAAE,GACxE;IAAE,IAAI,EAAE,eAAe,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,EAAE,EAAE,MAAM,CAAA;CAAE,GAOxE;IAAE,IAAI,EAAE,eAAe,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,WAAW,EAAE,MAAM,CAAA;CAAE,GAM5H;IAAE,IAAI,EAAE,kBAAkB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAA;CAAE,GAGhF;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,WAAW,EAAE,KAAK,CAAA;CAAE,GAGpE;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,WAAW,EAAE,KAAK,CAAA;CAAE,GAEpE;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAElE;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,EAAE,EAAE,MAAM,CAAA;CAAE,CAAA;AAI7D,MAAM,WAAW,KAAK;IACpB,QAAQ,CAAC,GAAG,EAAE,MAAM,EAAE,CAAA;IACtB,QAAQ,CAAC,OAAO,EAAE,MAAM,EAAE,CAAA;IAG1B,QAAQ,CAAC,OAAO,EAAE,MAAM,EAAE,CAAA;IAI1B,QAAQ,CAAC,QAAQ,EAAE,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;CACvC;AAED,wBAAgB,SAAS,IAAI,KAAK,CAEjC;AAGD,wBAAgB,SAAS,CAAC,CAAC,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,MAAM,GAAG,IAAI,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,MAAM,CAKpH;AAMD,wBAAgB,KAAK,CAAC,CAAC,SAAS,MAAM,CAAC,MAAM,CAAC,EAC5C,CAAC,EAAE,KAAK,EACR,IAAI,EAAE,CAAC,EACP,KAAK,EAAE,KAAK,EACZ,KAAK,EAAE,KAAK,EACZ,IAAI,EAAE,QAAQ,GAAG,IAAI,EACrB,MAAM,EAAE,IAAI,CAAC,OAAO,CAAC,MAAM,EAAE;IAAE,IAAI,EAAE,CAAC,CAAA;CAAE,CAAC,EAAE,MAAM,GAAG,KAAK,CAAC,GACzD,MAAM,CAMR;AAID,wBAAgB,WAAW,CAAC,MAAM,EAAE,MAAM,GAAG,QAAQ,CAIpD;AAID,wBAAgB,UAAU,CAAC,IAAI,EAAE,QAAQ,GAAG,MAAM,CAYjD"}
package/dist/ir.js CHANGED
@@ -12,7 +12,7 @@
12
12
  // Design intent: keep this file boring. No tracing logic, no shape inference,
13
13
  // no codegen — those live in their own modules and consume `Graph` / `OpNode`.
14
14
  export function makeGraph() {
15
- return { ops: [], tensors: [], outputs: [] };
15
+ return { ops: [], tensors: [], outputs: [], captures: new Map() };
16
16
  }
17
17
  // Internal: register a fresh tensor in the graph and return its id.
18
18
  export function addTensor(g, shape, dtype, source, site) {
package/dist/ir.js.map CHANGED
@@ -1 +1 @@
1
- {"version":3,"file":"ir.js","sourceRoot":"","sources":["../src/ir.ts"],"names":[],"mappings":"AAAA,uDAAuD;AACvD,EAAE;AACF,gFAAgF;AAChF,+EAA+E;AAC/E,8EAA8E;AAC9E,EAAE;AACF,0DAA0D;AAC1D,uCAAuC;AACvC,8EAA8E;AAC9E,wFAAwF;AACxF,EAAE;AACF,8EAA8E;AAC9E,+EAA+E;AAmI/E,MAAM,UAAU,SAAS;IACvB,OAAO,EAAE,GAAG,EAAE,EAAE,EAAE,OAAO,EAAE,EAAE,EAAE,OAAO,EAAE,EAAE,EAAE,CAAA;AAC9C,CAAC;AAED,oEAAoE;AACpE,MAAM,UAAU,SAAS,CAAC,CAAQ,EAAE,KAAY,EAAE,KAAY,EAAE,MAAqB,EAAE,IAAqB;IAC1G,MAAM,EAAE,GAAG,CAAC,CAAC,OAAO,CAAC,MAAM,CAAA;IAC3B,MAAM,CAAC,GAAW,EAAE,EAAE,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,IAAI,EAAE,CAAA;IACpD,CAAC,CAAC,OAAO,CAAC,IAAI,CAAC,CAAC,CAAC,CAAA;IACjB,OAAO,CAAC,CAAA;AACV,CAAC;AAED,kFAAkF;AAClF,0EAA0E;AAC1E,+EAA+E;AAC/E,kFAAkF;AAClF,MAAM,UAAU,KAAK,CACnB,CAAQ,EACR,IAAO,EACP,KAAY,EACZ,KAAY,EACZ,IAAqB,EACrB,MAA0D;IAE1D,MAAM,OAAO,GAAG,CAAC,CAAC,GAAG,CAAC,MAAM,CAAA;IAC5B,MAAM,GAAG,GAAG,SAAS,CAAC,CAAC,EAAE,KAAK,EAAE,KAAK,EAAE,OAAO,EAAE,IAAI,CAAC,CAAA;IACrD,MAAM,IAAI,GAAG,EAAE,IAAI,EAAE,GAAG,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,MAAM,EAAkC,CAAA;IAC7E,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,CAAA;IAChB,OAAO,GAAG,CAAA;AACZ,CAAC;AAED,0EAA0E;AAC1E,iFAAiF;AACjF,MAAM,UAAU,WAAW,CAAC,MAAc;IACxC,+EAA+E;IAC/E,MAAM,KAAK,GAAG,CAAC,IAAI,KAAK,EAAE,CAAC,CAAC,KAAK,IAAI,EAAE,CAAA;IACvC,OAAO,EAAE,MAAM,EAAE,KAAK,EAAE,CAAA;AAC1B,CAAC;AAED,8EAA8E;AAC9E,2DAA2D;AAC3D,MAAM,UAAU,UAAU,CAAC,IAAc;IACvC,MAAM,KAAK,GAAG,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,IAAI,CAAC,CAAA;IACpC,2EAA2E;IAC3E,iEAAiE;IACjE,MAAM,UAAU,GAAa,EAAE,CAAA;IAC/B,KAAK,MAAM,IAAI,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC;QAClC,IAAI,IAAI,CAAC,QAAQ,CAAC,kBAAkB,CAAC,IAAI,IAAI,CAAC,QAAQ,CAAC,qBAAqB,CAAC;YAAE,SAAQ;QACvF,UAAU,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,EAAE,CAAC,CAAA;QAC5B,IAAI,UAAU,CAAC,MAAM,IAAI,CAAC;YAAE,MAAK;IACnC,CAAC;IACD,IAAI,UAAU,CAAC,MAAM,KAAK,CAAC;QAAE,OAAO,IAAI,IAAI,CAAC,MAAM,yBAAyB,CAAA;IAC5E,OAAO,IAAI,IAAI,CAAC,MAAM,QAAQ,UAAU,CAAC,IAAI,CAAC,MAAM,CAAC,EAAE,CAAA;AACzD,CAAC"}
1
+ {"version":3,"file":"ir.js","sourceRoot":"","sources":["../src/ir.ts"],"names":[],"mappings":"AAAA,uDAAuD;AACvD,EAAE;AACF,gFAAgF;AAChF,+EAA+E;AAC/E,8EAA8E;AAC9E,EAAE;AACF,0DAA0D;AAC1D,uCAAuC;AACvC,8EAA8E;AAC9E,wFAAwF;AACxF,EAAE;AACF,8EAA8E;AAC9E,+EAA+E;AAyI/E,MAAM,UAAU,SAAS;IACvB,OAAO,EAAE,GAAG,EAAE,EAAE,EAAE,OAAO,EAAE,EAAE,EAAE,OAAO,EAAE,EAAE,EAAE,QAAQ,EAAE,IAAI,GAAG,EAAE,EAAE,CAAA;AACnE,CAAC;AAED,oEAAoE;AACpE,MAAM,UAAU,SAAS,CAAC,CAAQ,EAAE,KAAY,EAAE,KAAY,EAAE,MAAqB,EAAE,IAAqB;IAC1G,MAAM,EAAE,GAAG,CAAC,CAAC,OAAO,CAAC,MAAM,CAAA;IAC3B,MAAM,CAAC,GAAW,EAAE,EAAE,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,IAAI,EAAE,CAAA;IACpD,CAAC,CAAC,OAAO,CAAC,IAAI,CAAC,CAAC,CAAC,CAAA;IACjB,OAAO,CAAC,CAAA;AACV,CAAC;AAED,kFAAkF;AAClF,0EAA0E;AAC1E,+EAA+E;AAC/E,kFAAkF;AAClF,MAAM,UAAU,KAAK,CACnB,CAAQ,EACR,IAAO,EACP,KAAY,EACZ,KAAY,EACZ,IAAqB,EACrB,MAA0D;IAE1D,MAAM,OAAO,GAAG,CAAC,CAAC,GAAG,CAAC,MAAM,CAAA;IAC5B,MAAM,GAAG,GAAG,SAAS,CAAC,CAAC,EAAE,KAAK,EAAE,KAAK,EAAE,OAAO,EAAE,IAAI,CAAC,CAAA;IACrD,MAAM,IAAI,GAAG,EAAE,IAAI,EAAE,GAAG,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,MAAM,EAAkC,CAAA;IAC7E,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,CAAA;IAChB,OAAO,GAAG,CAAA;AACZ,CAAC;AAED,0EAA0E;AAC1E,iFAAiF;AACjF,MAAM,UAAU,WAAW,CAAC,MAAc;IACxC,+EAA+E;IAC/E,MAAM,KAAK,GAAG,CAAC,IAAI,KAAK,EAAE,CAAC,CAAC,KAAK,IAAI,EAAE,CAAA;IACvC,OAAO,EAAE,MAAM,EAAE,KAAK,EAAE,CAAA;AAC1B,CAAC;AAED,8EAA8E;AAC9E,2DAA2D;AAC3D,MAAM,UAAU,UAAU,CAAC,IAAc;IACvC,MAAM,KAAK,GAAG,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,IAAI,CAAC,CAAA;IACpC,2EAA2E;IAC3E,iEAAiE;IACjE,MAAM,UAAU,GAAa,EAAE,CAAA;IAC/B,KAAK,MAAM,IAAI,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC;QAClC,IAAI,IAAI,CAAC,QAAQ,CAAC,kBAAkB,CAAC,IAAI,IAAI,CAAC,QAAQ,CAAC,qBAAqB,CAAC;YAAE,SAAQ;QACvF,UAAU,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,EAAE,CAAC,CAAA;QAC5B,IAAI,UAAU,CAAC,MAAM,IAAI,CAAC;YAAE,MAAK;IACnC,CAAC;IACD,IAAI,UAAU,CAAC,MAAM,KAAK,CAAC;QAAE,OAAO,IAAI,IAAI,CAAC,MAAM,yBAAyB,CAAA;IAC5E,OAAO,IAAI,IAAI,CAAC,MAAM,QAAQ,UAAU,CAAC,IAAI,CAAC,MAAM,CAAC,EAAE,CAAA;AACzD,CAAC"}
package/dist/module.d.ts CHANGED
@@ -1,4 +1,21 @@
1
1
  import type { Tensor, Shape, Dtype } from './ir.js';
2
+ /** How a parameter's initial values are produced.
3
+ * - `'randn'` — Gaussian, with `scale` (default 0.02). The common case for
4
+ * weight matrices and embeddings.
5
+ * - `'zeros'` — fill with 0. Common for biases and LayerNorm beta.
6
+ * - `'ones'` — fill with 1. Common for LayerNorm gain.
7
+ * - Custom function — receives total element count and shape, returns the
8
+ * Float32Array. Use for fan-in scaling or any non-standard scheme.
9
+ */
10
+ export type InitSpec = 'randn' | 'zeros' | 'ones' | ((size: number, shape: readonly number[]) => Float32Array);
11
+ export interface ParamOptions {
12
+ dtype?: Dtype;
13
+ /** Init kind. Default: `'randn'`. */
14
+ init?: InitSpec;
15
+ /** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
16
+ scale?: number;
17
+ }
18
+ type InitFn = (size: number, shape: readonly number[]) => Float32Array;
2
19
  export declare abstract class Module {
3
20
  /**
4
21
  * Declare a learnable parameter at this module. Must be called from inside
@@ -6,16 +23,25 @@ export declare abstract class Module {
6
23
  * that gets replaced with a real Tensor at compile time.
7
24
  *
8
25
  * The parameter's name is auto-derived from its property path in the model
9
- * tree (e.g. `layers.0.attn.W_q`).
26
+ * tree (e.g. `layers.0.attn.W_q`). Init metadata travels with the param;
27
+ * call `compiled.uploadInitialParams()` to apply it after compile.
10
28
  */
11
- protected param(shape: Shape, dtype?: Dtype): Tensor;
29
+ protected param(shape: Shape, opts?: ParamOptions): Tensor;
30
+ }
31
+ export interface MaterializedParams {
32
+ /** Map from auto-derived path (e.g. `layers.0.attn.W_q`) to its Tensor. */
33
+ tensors: Record<string, Tensor>;
34
+ /** Init function per param path. Used by `uploadInitialParams`. */
35
+ initFns: Record<string, InitFn>;
12
36
  }
13
37
  /**
14
38
  * Walk the module tree and replace every ParamSentinel with a real Tensor
15
39
  * created via `paramInput(autoName, ...)`. Must be called inside an active
16
40
  * trace context (paramInput appends to the current graph).
17
41
  *
18
- * Returns a flat record of `{ path: tensor }` for every materialized param.
42
+ * Returns the param tensors keyed by path, plus init functions for use by
43
+ * `uploadInitialParams`.
19
44
  */
20
- export declare function materializeParams(root: Module): Record<string, Tensor>;
45
+ export declare function materializeParams(root: Module): MaterializedParams;
46
+ export {};
21
47
  //# sourceMappingURL=module.d.ts.map
@@ -1 +1 @@
1
- {"version":3,"file":"module.d.ts","sourceRoot":"","sources":["../src/module.ts"],"names":[],"mappings":"AA2BA,OAAO,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,SAAS,CAAA;AAoBnD,8BAAsB,MAAM;IAC1B;;;;;;;OAOG;IACH,SAAS,CAAC,KAAK,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM;CAI5D;AAMD;;;;;;GAMG;AACH,wBAAgB,iBAAiB,CAAC,IAAI,EAAE,MAAM,GAAG,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,CAUtE"}
1
+ {"version":3,"file":"module.d.ts","sourceRoot":"","sources":["../src/module.ts"],"names":[],"mappings":"AA2BA,OAAO,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,SAAS,CAAA;AAOnD;;;;;;;GAOG;AACH,MAAM,MAAM,QAAQ,GAChB,OAAO,GACP,OAAO,GACP,MAAM,GACN,CAAC,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,SAAS,MAAM,EAAE,KAAK,YAAY,CAAC,CAAA;AAE9D,MAAM,WAAW,YAAY;IAC3B,KAAK,CAAC,EAAE,KAAK,CAAA;IACb,qCAAqC;IACrC,IAAI,CAAC,EAAE,QAAQ,CAAA;IACf,uEAAuE;IACvE,KAAK,CAAC,EAAE,MAAM,CAAA;CACf;AAED,KAAK,MAAM,GAAG,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,SAAS,MAAM,EAAE,KAAK,YAAY,CAAA;AA2CtE,8BAAsB,MAAM;IAC1B;;;;;;;;OAQG;IACH,SAAS,CAAC,KAAK,CAAC,KAAK,EAAE,KAAK,EAAE,IAAI,CAAC,EAAE,YAAY,GAAG,MAAM;CAK3D;AAMD,MAAM,WAAW,kBAAkB;IACjC,2EAA2E;IAC3E,OAAO,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IAC/B,mEAAmE;IACnE,OAAO,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;CAChC;AAED;;;;;;;GAOG;AACH,wBAAgB,iBAAiB,CAAC,IAAI,EAAE,MAAM,GAAG,kBAAkB,CAYlE"}
package/dist/module.js CHANGED
@@ -6,8 +6,8 @@
6
6
  // W: Tensor; b: Tensor
7
7
  // constructor(inDim: number, outDim: number) {
8
8
  // super()
9
- // this.W = this.param([inDim, outDim])
10
- // this.b = this.param([outDim])
9
+ // this.W = this.param([inDim, outDim]) // randn, scale 0.02
10
+ // this.b = this.param([outDim], { init: 'zeros' })
11
11
  // }
12
12
  // }
13
13
  // class Block extends Module {
@@ -25,6 +25,28 @@
25
25
  // and writeback wiring. Forward functions are pure and stateless — they
26
26
  // take the materialized model and inputs, return a Tensor.
27
27
  import { paramInput } from './trace.js';
28
+ function boxMuller() {
29
+ return Math.sqrt(-2 * Math.log(Math.max(1e-10, Math.random()))) * Math.cos(2 * Math.PI * Math.random());
30
+ }
31
+ function resolveInit(opts) {
32
+ const init = opts?.init ?? 'randn';
33
+ if (init === 'randn') {
34
+ const scale = opts?.scale ?? 0.02;
35
+ return (size) => {
36
+ const arr = new Float32Array(size);
37
+ for (let i = 0; i < size; i++)
38
+ arr[i] = boxMuller() * scale;
39
+ return arr;
40
+ };
41
+ }
42
+ if (init === 'zeros')
43
+ return (size) => new Float32Array(size);
44
+ if (init === 'ones')
45
+ return (size) => { const a = new Float32Array(size); a.fill(1); return a; };
46
+ if (typeof init === 'function')
47
+ return init;
48
+ throw new Error(`Unknown init: ${String(init)}`);
49
+ }
28
50
  // ============================================================================
29
51
  // Internals: param sentinel
30
52
  // ============================================================================
@@ -36,9 +58,11 @@ import { paramInput } from './trace.js';
36
58
  class ParamSentinel {
37
59
  shape;
38
60
  dtype;
39
- constructor(shape, dtype) {
61
+ initFn;
62
+ constructor(shape, dtype, initFn) {
40
63
  this.shape = shape;
41
64
  this.dtype = dtype;
65
+ this.initFn = initFn;
42
66
  }
43
67
  }
44
68
  // ============================================================================
@@ -51,33 +75,35 @@ export class Module {
51
75
  * that gets replaced with a real Tensor at compile time.
52
76
  *
53
77
  * The parameter's name is auto-derived from its property path in the model
54
- * tree (e.g. `layers.0.attn.W_q`).
78
+ * tree (e.g. `layers.0.attn.W_q`). Init metadata travels with the param;
79
+ * call `compiled.uploadInitialParams()` to apply it after compile.
55
80
  */
56
- param(shape, dtype = 'f32') {
81
+ param(shape, opts) {
82
+ const dtype = opts?.dtype ?? 'f32';
57
83
  // Lie to TypeScript: the sentinel becomes a Tensor at materialize time.
58
- return new ParamSentinel(shape, dtype);
84
+ return new ParamSentinel(shape, dtype, resolveInit(opts));
59
85
  }
60
86
  }
61
- // ============================================================================
62
- // Tree walking
63
- // ============================================================================
64
87
  /**
65
88
  * Walk the module tree and replace every ParamSentinel with a real Tensor
66
89
  * created via `paramInput(autoName, ...)`. Must be called inside an active
67
90
  * trace context (paramInput appends to the current graph).
68
91
  *
69
- * Returns a flat record of `{ path: tensor }` for every materialized param.
92
+ * Returns the param tensors keyed by path, plus init functions for use by
93
+ * `uploadInitialParams`.
70
94
  */
71
95
  export function materializeParams(root) {
72
- const out = {};
96
+ const tensors = {};
97
+ const initFns = {};
73
98
  visit(root, '', (path, val, owner, key) => {
74
99
  if (val instanceof ParamSentinel) {
75
100
  const t = paramInput(path, val.shape, val.dtype);
76
101
  owner[key] = t;
77
- out[path] = t;
102
+ tensors[path] = t;
103
+ initFns[path] = val.initFn;
78
104
  }
79
105
  });
80
- return out;
106
+ return { tensors, initFns };
81
107
  }
82
108
  function visit(node, path, visitor) {
83
109
  if (node === null || node === undefined)
@@ -1 +1 @@
1
- {"version":3,"file":"module.js","sourceRoot":"","sources":["../src/module.ts"],"names":[],"mappings":"AAAA,6EAA6E;AAC7E,EAAE;AACF,+CAA+C;AAC/C,EAAE;AACF,kCAAkC;AAClC,2BAA2B;AAC3B,mDAAmD;AACnD,gBAAgB;AAChB,6CAA6C;AAC7C,sCAAsC;AACtC,QAAQ;AACR,MAAM;AACN,iCAAiC;AACjC,8BAA8B;AAC9B,+BAA+B;AAC/B,MAAM;AACN,iCAAiC;AACjC,mCAAmC;AACnC,+CAA+C;AAC/C,MAAM;AACN,EAAE;AACF,wEAAwE;AACxE,0EAA0E;AAC1E,0EAA0E;AAC1E,wEAAwE;AACxE,2DAA2D;AAG3D,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAEvC,+EAA+E;AAC/E,4BAA4B;AAC5B,+EAA+E;AAC/E,EAAE;AACF,6EAA6E;AAC7E,4EAA4E;AAC5E,2EAA2E;AAC3E,yEAAyE;AAEzE,MAAM,aAAa;IACW;IAA8B;IAA1D,YAA4B,KAAY,EAAkB,KAAY;QAA1C,UAAK,GAAL,KAAK,CAAO;QAAkB,UAAK,GAAL,KAAK,CAAO;IAAG,CAAC;CAC3E;AAED,+EAA+E;AAC/E,oBAAoB;AACpB,+EAA+E;AAE/E,MAAM,OAAgB,MAAM;IAC1B;;;;;;;OAOG;IACO,KAAK,CAAC,KAAY,EAAE,QAAe,KAAK;QAChD,wEAAwE;QACxE,OAAO,IAAI,aAAa,CAAC,KAAK,EAAE,KAAK,CAAsB,CAAA;IAC7D,CAAC;CACF;AAED,+EAA+E;AAC/E,eAAe;AACf,+EAA+E;AAE/E;;;;;;GAMG;AACH,MAAM,UAAU,iBAAiB,CAAC,IAAY;IAC5C,MAAM,GAAG,GAA2B,EAAE,CAAA;IACtC,KAAK,CAAC,IAAI,EAAE,EAAE,EAAE,CAAC,IAAI,EAAE,GAAG,EAAE,KAAK,EAAE,GAAG,EAAE,EAAE;QACxC,IAAI,GAAG,YAAY,aAAa,EAAE,CAAC;YACjC,MAAM,CAAC,GAAG,UAAU,CAAC,IAAI,EAAE,GAAG,CAAC,KAAK,EAAE,GAAG,CAAC,KAAK,CAAC,CAC/C;YAAC,KAAa,CAAC,GAAG,CAAC,GAAG,CAAC,CAAA;YACxB,GAAG,CAAC,IAAI,CAAC,GAAG,CAAC,CAAA;QACf,CAAC;IACH,CAAC,CAAC,CAAA;IACF,OAAO,GAAG,CAAA;AACZ,CAAC;AAaD,SAAS,KAAK,CAAC,IAAa,EAAE,IAAY,EAAE,OAAgB;IAC1D,IAAI,IAAI,KAAK,IAAI,IAAI,IAAI,KAAK,SAAS;QAAE,OAAM;IAC/C,IAAI,OAAO,IAAI,KAAK,QAAQ;QAAE,OAAM;IAEpC,IAAI,IAAI,YAAY,MAAM,EAAE,CAAC;QAC3B,KAAK,MAAM,GAAG,IAAI,MAAM,CAAC,IAAI,CAAC,IAAc,CAAC,EAAE,CAAC;YAC9C,MAAM,KAAK,GAAI,IAAY,CAAC,GAAG,CAAC,CAAA;YAChC,MAAM,SAAS,GAAG,IAAI,CAAC,CAAC,CAAC,GAAG,IAAI,IAAI,GAAG,EAAE,CAAC,CAAC,CAAC,GAAG,CAAA;YAC/C,UAAU,CAAC,KAAK,EAAE,SAAS,EAAE,IAAI,EAAE,GAAG,EAAE,OAAO,CAAC,CAAA;QAClD,CAAC;QACD,OAAM;IACR,CAAC;IACD,IAAI,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,EAAE,CAAC;QACxB,IAAI,CAAC,OAAO,CAAC,CAAC,IAAI,EAAE,CAAC,EAAE,EAAE;YACvB,MAAM,SAAS,GAAG,IAAI,CAAC,CAAC,CAAC,GAAG,IAAI,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAA;YACnD,UAAU,CAAC,IAAI,EAAE,SAAS,EAAE,IAAyB,EAAE,CAAC,EAAE,OAAO,CAAC,CAAA;QACpE,CAAC,CAAC,CAAA;QACF,OAAM;IACR,CAAC;IACD,2EAA2E;IAC3E,uBAAuB;AACzB,CAAC;AAED,SAAS,UAAU,CAAC,KAAc,EAAE,IAAY,EAAE,KAAa,EAAE,GAAoB,EAAE,OAAgB;IACrG,IAAI,KAAK,YAAY,MAAM,IAAI,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE,CAAC;QACpD,KAAK,CAAC,KAAK,EAAE,IAAI,EAAE,OAAO,CAAC,CAAA;IAC7B,CAAC;SAAM,CAAC;QACN,OAAO,CAAC,IAAI,EAAE,KAAK,EAAE,KAAK,EAAE,GAAG,CAAC,CAAA;IAClC,CAAC;AACH,CAAC"}
1
+ {"version":3,"file":"module.js","sourceRoot":"","sources":["../src/module.ts"],"names":[],"mappings":"AAAA,6EAA6E;AAC7E,EAAE;AACF,+CAA+C;AAC/C,EAAE;AACF,kCAAkC;AAClC,2BAA2B;AAC3B,mDAAmD;AACnD,gBAAgB;AAChB,gFAAgF;AAChF,yDAAyD;AACzD,QAAQ;AACR,MAAM;AACN,iCAAiC;AACjC,8BAA8B;AAC9B,+BAA+B;AAC/B,MAAM;AACN,iCAAiC;AACjC,mCAAmC;AACnC,+CAA+C;AAC/C,MAAM;AACN,EAAE;AACF,wEAAwE;AACxE,0EAA0E;AAC1E,0EAA0E;AAC1E,wEAAwE;AACxE,2DAA2D;AAG3D,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AA8BvC,SAAS,SAAS;IAChB,OAAO,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,GAAG,CAAC,KAAK,EAAE,IAAI,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,EAAE,GAAG,IAAI,CAAC,MAAM,EAAE,CAAC,CAAA;AACzG,CAAC;AAED,SAAS,WAAW,CAAC,IAA8B;IACjD,MAAM,IAAI,GAAG,IAAI,EAAE,IAAI,IAAI,OAAO,CAAA;IAClC,IAAI,IAAI,KAAK,OAAO,EAAE,CAAC;QACrB,MAAM,KAAK,GAAG,IAAI,EAAE,KAAK,IAAI,IAAI,CAAA;QACjC,OAAO,CAAC,IAAI,EAAE,EAAE;YACd,MAAM,GAAG,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,CAAA;YAClC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE;gBAAE,GAAG,CAAC,CAAC,CAAC,GAAG,SAAS,EAAE,GAAG,KAAK,CAAA;YAC3D,OAAO,GAAG,CAAA;QACZ,CAAC,CAAA;IACH,CAAC;IACD,IAAI,IAAI,KAAK,OAAO;QAAE,OAAO,CAAC,IAAI,EAAE,EAAE,CAAC,IAAI,YAAY,CAAC,IAAI,CAAC,CAAA;IAC7D,IAAI,IAAI,KAAK,MAAM;QAAE,OAAO,CAAC,IAAI,EAAE,EAAE,GAAG,MAAM,CAAC,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAA,CAAC,CAAC,CAAA;IAC/F,IAAI,OAAO,IAAI,KAAK,UAAU;QAAE,OAAO,IAAI,CAAA;IAC3C,MAAM,IAAI,KAAK,CAAC,iBAAiB,MAAM,CAAC,IAAI,CAAC,EAAE,CAAC,CAAA;AAClD,CAAC;AAED,+EAA+E;AAC/E,4BAA4B;AAC5B,+EAA+E;AAC/E,EAAE;AACF,6EAA6E;AAC7E,4EAA4E;AAC5E,2EAA2E;AAC3E,yEAAyE;AAEzE,MAAM,aAAa;IAEC;IACA;IACA;IAHlB,YACkB,KAAY,EACZ,KAAY,EACZ,MAAc;QAFd,UAAK,GAAL,KAAK,CAAO;QACZ,UAAK,GAAL,KAAK,CAAO;QACZ,WAAM,GAAN,MAAM,CAAQ;IAC7B,CAAC;CACL;AAED,+EAA+E;AAC/E,oBAAoB;AACpB,+EAA+E;AAE/E,MAAM,OAAgB,MAAM;IAC1B;;;;;;;;OAQG;IACO,KAAK,CAAC,KAAY,EAAE,IAAmB;QAC/C,MAAM,KAAK,GAAG,IAAI,EAAE,KAAK,IAAI,KAAK,CAAA;QAClC,wEAAwE;QACxE,OAAO,IAAI,aAAa,CAAC,KAAK,EAAE,KAAK,EAAE,WAAW,CAAC,IAAI,CAAC,CAAsB,CAAA;IAChF,CAAC;CACF;AAaD;;;;;;;GAOG;AACH,MAAM,UAAU,iBAAiB,CAAC,IAAY;IAC5C,MAAM,OAAO,GAA2B,EAAE,CAAA;IAC1C,MAAM,OAAO,GAA2B,EAAE,CAAA;IAC1C,KAAK,CAAC,IAAI,EAAE,EAAE,EAAE,CAAC,IAAI,EAAE,GAAG,EAAE,KAAK,EAAE,GAAG,EAAE,EAAE;QACxC,IAAI,GAAG,YAAY,aAAa,EAAE,CAAC;YACjC,MAAM,CAAC,GAAG,UAAU,CAAC,IAAI,EAAE,GAAG,CAAC,KAAK,EAAE,GAAG,CAAC,KAAK,CAAC,CAC/C;YAAC,KAAa,CAAC,GAAG,CAAC,GAAG,CAAC,CAAA;YACxB,OAAO,CAAC,IAAI,CAAC,GAAG,CAAC,CAAA;YACjB,OAAO,CAAC,IAAI,CAAC,GAAG,GAAG,CAAC,MAAM,CAAA;QAC5B,CAAC;IACH,CAAC,CAAC,CAAA;IACF,OAAO,EAAE,OAAO,EAAE,OAAO,EAAE,CAAA;AAC7B,CAAC;AAaD,SAAS,KAAK,CAAC,IAAa,EAAE,IAAY,EAAE,OAAgB;IAC1D,IAAI,IAAI,KAAK,IAAI,IAAI,IAAI,KAAK,SAAS;QAAE,OAAM;IAC/C,IAAI,OAAO,IAAI,KAAK,QAAQ;QAAE,OAAM;IAEpC,IAAI,IAAI,YAAY,MAAM,EAAE,CAAC;QAC3B,KAAK,MAAM,GAAG,IAAI,MAAM,CAAC,IAAI,CAAC,IAAc,CAAC,EAAE,CAAC;YAC9C,MAAM,KAAK,GAAI,IAAY,CAAC,GAAG,CAAC,CAAA;YAChC,MAAM,SAAS,GAAG,IAAI,CAAC,CAAC,CAAC,GAAG,IAAI,IAAI,GAAG,EAAE,CAAC,CAAC,CAAC,GAAG,CAAA;YAC/C,UAAU,CAAC,KAAK,EAAE,SAAS,EAAE,IAAI,EAAE,GAAG,EAAE,OAAO,CAAC,CAAA;QAClD,CAAC;QACD,OAAM;IACR,CAAC;IACD,IAAI,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,EAAE,CAAC;QACxB,IAAI,CAAC,OAAO,CAAC,CAAC,IAAI,EAAE,CAAC,EAAE,EAAE;YACvB,MAAM,SAAS,GAAG,IAAI,CAAC,CAAC,CAAC,GAAG,IAAI,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAA;YACnD,UAAU,CAAC,IAAI,EAAE,SAAS,EAAE,IAAyB,EAAE,CAAC,EAAE,OAAO,CAAC,CAAA;QACpE,CAAC,CAAC,CAAA;QACF,OAAM;IACR,CAAC;IACD,2EAA2E;IAC3E,uBAAuB;AACzB,CAAC;AAED,SAAS,UAAU,CAAC,KAAc,EAAE,IAAY,EAAE,KAAa,EAAE,GAAoB,EAAE,OAAgB;IACrG,IAAI,KAAK,YAAY,MAAM,IAAI,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE,CAAC;QACpD,KAAK,CAAC,KAAK,EAAE,IAAI,EAAE,OAAO,CAAC,CAAA;IAC7B,CAAC;SAAM,CAAC;QACN,OAAO,CAAC,IAAI,EAAE,KAAK,EAAE,KAAK,EAAE,GAAG,CAAC,CAAA;IAClC,CAAC;AACH,CAAC"}
package/dist/nn.d.ts ADDED
@@ -0,0 +1,19 @@
1
+ import { Module } from './module.js';
2
+ import type { Tensor } from './ir.js';
3
+ export declare class Linear extends Module {
4
+ readonly inDim: number;
5
+ readonly outDim: number;
6
+ W: Tensor;
7
+ b: Tensor | null;
8
+ constructor(inDim: number, outDim: number, withBias?: boolean);
9
+ }
10
+ export declare function linearFwd(p: Linear, x: Tensor): Tensor;
11
+ export declare class LayerNorm extends Module {
12
+ readonly d: number;
13
+ readonly eps: number;
14
+ g: Tensor;
15
+ b: Tensor;
16
+ constructor(d: number, eps?: number);
17
+ }
18
+ export declare function layerNormFwd(p: LayerNorm, x: Tensor): Tensor;
19
+ //# sourceMappingURL=nn.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"nn.d.ts","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAeA,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AACpC,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,SAAS,CAAA;AAOrC,qBAAa,MAAO,SAAQ,MAAM;aAGJ,KAAK,EAAE,MAAM;aAAkB,MAAM,EAAE,MAAM;IAFzE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,GAAG,IAAI,CAAA;gBACY,KAAK,EAAE,MAAM,EAAkB,MAAM,EAAE,MAAM,EAAE,QAAQ,UAAO;CAK3F;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAGtD;AAMD,qBAAa,SAAU,SAAQ,MAAM;aAGP,CAAC,EAAE,MAAM;aAAkB,GAAG,EAAE,MAAM;IAFlE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,CAAA;gBACmB,CAAC,EAAE,MAAM,EAAkB,GAAG,GAAE,MAAa;CAK1E;AAED,wBAAgB,YAAY,CAAC,CAAC,EAAE,SAAS,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAM5D"}
package/dist/nn.js ADDED
@@ -0,0 +1,60 @@
1
+ // Standard "batteries-included" Module subclasses for the most common layers.
2
+ //
3
+ // JAX-style: each class declares its params (and their init); the forward is a
4
+ // plain function the user calls with `(module, x)`. No subclassing, no method
5
+ // dispatch — keeps the autograd-traced computation visible at the call site.
6
+ //
7
+ // Import as a namespace:
8
+ //
9
+ // import { nn } from 'tensorgrad'
10
+ // class Block extends Module {
11
+ // ln = new nn.LayerNorm(D)
12
+ // ffn = new nn.Linear(D, 4 * D)
13
+ // }
14
+ // const y = nn.linearFwd(p.ffn, nn.layerNormFwd(p.ln, x))
15
+ import { Module } from './module.js';
16
+ import { add, matmul, sub, mul, div, sqrt, meanLast } from './ops.js';
17
+ // ----------------------------------------------------------------------------
18
+ // Linear: y = x @ W (+ b)
19
+ // ----------------------------------------------------------------------------
20
+ export class Linear extends Module {
21
+ inDim;
22
+ outDim;
23
+ W;
24
+ b;
25
+ constructor(inDim, outDim, withBias = true) {
26
+ super();
27
+ this.inDim = inDim;
28
+ this.outDim = outDim;
29
+ this.W = this.param([inDim, outDim]); // randn, scale 0.02
30
+ this.b = withBias ? this.param([outDim], { init: 'zeros' }) : null;
31
+ }
32
+ }
33
+ export function linearFwd(p, x) {
34
+ const out = matmul(x, p.W);
35
+ return p.b ? add(out, p.b) : out;
36
+ }
37
+ // ----------------------------------------------------------------------------
38
+ // LayerNorm — normalizes over the last axis. eps defaults to 1e-5.
39
+ // ----------------------------------------------------------------------------
40
+ export class LayerNorm extends Module {
41
+ d;
42
+ eps;
43
+ g;
44
+ b;
45
+ constructor(d, eps = 1e-5) {
46
+ super();
47
+ this.d = d;
48
+ this.eps = eps;
49
+ this.g = this.param([d], { init: 'ones' });
50
+ this.b = this.param([d], { init: 'zeros' });
51
+ }
52
+ }
53
+ export function layerNormFwd(p, x) {
54
+ const m = meanLast(x);
55
+ const c = sub(x, m);
56
+ const v = meanLast(mul(c, c));
57
+ const stdev = sqrt(add(v, p.eps));
58
+ return add(mul(div(c, stdev), p.g), p.b);
59
+ }
60
+ //# sourceMappingURL=nn.js.map
package/dist/nn.js.map ADDED
@@ -0,0 +1 @@
1
+ {"version":3,"file":"nn.js","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAAA,8EAA8E;AAC9E,EAAE;AACF,+EAA+E;AAC/E,8EAA8E;AAC9E,6EAA6E;AAC7E,EAAE;AACF,yBAAyB;AACzB,EAAE;AACF,oCAAoC;AACpC,iCAAiC;AACjC,gCAAgC;AAChC,oCAAoC;AACpC,MAAM;AACN,4DAA4D;AAE5D,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AAEpC,OAAO,EAAE,GAAG,EAAE,MAAM,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,EAAE,QAAQ,EAAE,MAAM,UAAU,CAAA;AAErE,+EAA+E;AAC/E,0BAA0B;AAC1B,+EAA+E;AAE/E,MAAM,OAAO,MAAO,SAAQ,MAAM;IAGJ;IAA+B;IAF3D,CAAC,CAAQ;IACT,CAAC,CAAe;IAChB,YAA4B,KAAa,EAAkB,MAAc,EAAE,QAAQ,GAAG,IAAI;QACxF,KAAK,EAAE,CAAA;QADmB,UAAK,GAAL,KAAK,CAAQ;QAAkB,WAAM,GAAN,MAAM,CAAQ;QAEvE,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC,CAAA,CAAsB,oBAAoB;QAC9E,IAAI,CAAC,CAAC,GAAG,QAAQ,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,MAAM,CAAC,EAAE,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,CAAA;IACpE,CAAC;CACF;AAED,MAAM,UAAU,SAAS,CAAC,CAAS,EAAE,CAAS;IAC5C,MAAM,GAAG,GAAG,MAAM,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAA;IAC1B,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAA;AAClC,CAAC;AAED,+EAA+E;AAC/E,mEAAmE;AACnE,+EAA+E;AAE/E,MAAM,OAAO,SAAU,SAAQ,MAAM;IAGP;IAA2B;IAFvD,CAAC,CAAQ;IACT,CAAC,CAAQ;IACT,YAA4B,CAAS,EAAkB,MAAc,IAAI;QACvE,KAAK,EAAE,CAAA;QADmB,MAAC,GAAD,CAAC,CAAQ;QAAkB,QAAG,GAAH,GAAG,CAAe;QAEvE,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,IAAI,EAAE,MAAM,EAAE,CAAC,CAAA;QAC1C,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAA;IAC7C,CAAC;CACF;AAED,MAAM,UAAU,YAAY,CAAC,CAAY,EAAE,CAAS;IAClD,MAAM,CAAC,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAA;IACrB,MAAM,CAAC,GAAG,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAA;IACnB,MAAM,CAAC,GAAG,QAAQ,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAA;IAC7B,MAAM,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,CAAC,CAAC,CAAA;IACjC,OAAO,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAA;AAC1C,CAAC"}
package/dist/ops.d.ts CHANGED
@@ -31,5 +31,5 @@ export declare function where(cond: Tensor, a: Tensor, b: Tensor): Tensor;
31
31
  export declare function reluGrad(x: Tensor, dy: Tensor): Tensor;
32
32
  export declare function adamUpdateM(m: Tensor, g: Tensor, b1: number): Tensor;
33
33
  export declare function adamUpdateV(v: Tensor, g: Tensor, b2: number): Tensor;
34
- export declare function adamUpdateP(p: Tensor, mNew: Tensor, vNew: Tensor, lrt: Tensor, eps: number): Tensor;
34
+ export declare function adamUpdateP(p: Tensor, mNew: Tensor, vNew: Tensor, lrt: Tensor, eps: number, decayShrink?: number): Tensor;
35
35
  //# sourceMappingURL=ops.d.ts.map
package/dist/ops.d.ts.map CHANGED
@@ -1 +1 @@
1
- {"version":3,"file":"ops.d.ts","sourceRoot":"","sources":["../src/ops.ts"],"names":[],"mappings":"AAWA,OAAO,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAU,MAAM,SAAS,CAAA;AAmC3D,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAMzD;AAQD,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAG3D;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAG3D;AAYD,eAAO,MAAM,IAAI,GAAK,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,KAAK,GAAI,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,GAAG,GAAM,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,GAAG,GAAM,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,IAAI,GAAK,GAAG,MAAM,KAAG,MAA2B,CAAA;AAO7D,wBAAgB,QAAQ,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAK1C;AAED,wBAAgB,OAAO,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAKzC;AAMD,wBAAgB,OAAO,CAAC,CAAC,EAAE,MAAM,EAAE,QAAQ,EAAE,KAAK,GAAG,MAAM,CAI1D;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,SAAS,MAAM,EAAE,GAAG,MAAM,CAIpE;AAMD,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAOnD;AAED,wBAAgB,aAAa,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAO1D;AAMD,wBAAgB,MAAM,CAAC,OAAO,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAOnF;AAGD,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAM9D;AASD,wBAAgB,iBAAiB,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAKnD;AAGD,wBAAgB,cAAc,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAIhD;AAMD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,SAAS,EAAE,MAAM,GAAG,MAAM,CAKhE;AAQD,wBAAgB,cAAc,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,GAAG,MAAM,CAI5E;AAOD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,WAAW,EAAE,KAAK,GAAG,MAAM,CAIjE;AAED,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,EAAE,WAAW,EAAE,KAAK,GAAG,MAAM,CAIhE;AAOD,wBAAgB,WAAW,CAAC,KAAK,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAGvE;AAWD,eAAO,MAAM,IAAI,GAAO,GAAG,MAAM,EAAE,GAAG,MAAM,KAAG,MAAqD,CAAA;AACpG,eAAO,MAAM,OAAO,GAAI,GAAG,MAAM,EAAE,GAAG,MAAM,KAAG,MAAqD,CAAA;AAGpG,wBAAgB,KAAK,CAAC,IAAI,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAMhE;AAID,wBAAgB,QAAQ,CAAC,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOtD;AAMD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOpE;AAED,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOpE;AAED,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,GAAG,MAAM,CAWnG"}
1
+ {"version":3,"file":"ops.d.ts","sourceRoot":"","sources":["../src/ops.ts"],"names":[],"mappings":"AAWA,OAAO,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAU,MAAM,SAAS,CAAA;AAmC3D,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAMzD;AAQD,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAG3D;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAG3D;AAYD,eAAO,MAAM,IAAI,GAAK,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,KAAK,GAAI,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,GAAG,GAAM,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,GAAG,GAAM,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,IAAI,GAAK,GAAG,MAAM,KAAG,MAA2B,CAAA;AAO7D,wBAAgB,QAAQ,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAK1C;AAED,wBAAgB,OAAO,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAKzC;AAMD,wBAAgB,OAAO,CAAC,CAAC,EAAE,MAAM,EAAE,QAAQ,EAAE,KAAK,GAAG,MAAM,CAI1D;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,SAAS,MAAM,EAAE,GAAG,MAAM,CAIpE;AAMD,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAOnD;AAED,wBAAgB,aAAa,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAO1D;AAMD,wBAAgB,MAAM,CAAC,OAAO,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAOnF;AAGD,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAM9D;AASD,wBAAgB,iBAAiB,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAKnD;AAGD,wBAAgB,cAAc,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAIhD;AAMD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,SAAS,EAAE,MAAM,GAAG,MAAM,CAKhE;AAQD,wBAAgB,cAAc,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,GAAG,MAAM,CAI5E;AAOD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,WAAW,EAAE,KAAK,GAAG,MAAM,CAIjE;AAED,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,EAAE,WAAW,EAAE,KAAK,GAAG,MAAM,CAIhE;AAOD,wBAAgB,WAAW,CAAC,KAAK,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAGvE;AAWD,eAAO,MAAM,IAAI,GAAO,GAAG,MAAM,EAAE,GAAG,MAAM,KAAG,MAAqD,CAAA;AACpG,eAAO,MAAM,OAAO,GAAI,GAAG,MAAM,EAAE,GAAG,MAAM,KAAG,MAAqD,CAAA;AAGpG,wBAAgB,KAAK,CAAC,IAAI,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAMhE;AAID,wBAAgB,QAAQ,CAAC,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOtD;AAMD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOpE;AAED,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOpE;AAED,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,EAAE,WAAW,GAAE,MAAU,GAAG,MAAM,CAW5H"}
package/dist/ops.js CHANGED
@@ -255,7 +255,7 @@ export function adamUpdateV(v, g, b2) {
255
255
  }
256
256
  return addOp(currentGraph(), 'adam_update_v', v.shape, 'f32', site, { v: v.id, g: g.id, b2 });
257
257
  }
258
- export function adamUpdateP(p, mNew, vNew, lrt, eps) {
258
+ export function adamUpdateP(p, mNew, vNew, lrt, eps, decayShrink = 1) {
259
259
  const site = captureSite('adamUpdateP');
260
260
  if (p.dtype !== 'f32')
261
261
  throw new ShapeError(`adamUpdateP: requires f32`, site);
@@ -265,6 +265,6 @@ export function adamUpdateP(p, mNew, vNew, lrt, eps) {
265
265
  if (p.shape.length !== mNew.shape.length || p.shape.some((d, i) => d !== mNew.shape[i])) {
266
266
  throw new ShapeError(`adamUpdateP: p/mNew shape mismatch`, site);
267
267
  }
268
- return addOp(currentGraph(), 'adam_update_p', p.shape, 'f32', site, { p: p.id, mNew: mNew.id, vNew: vNew.id, lrt: lrt.id, eps });
268
+ return addOp(currentGraph(), 'adam_update_p', p.shape, 'f32', site, { p: p.id, mNew: mNew.id, vNew: vNew.id, lrt: lrt.id, eps, decayShrink });
269
269
  }
270
270
  //# sourceMappingURL=ops.js.map