tensorgrad 0.0.9 → 0.0.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +119 -119
- package/dist/compile.d.ts +77 -28
- package/dist/compile.d.ts.map +1 -1
- package/dist/compile.js +132 -81
- package/dist/compile.js.map +1 -1
- package/dist/index.d.ts +2 -2
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +2 -2
- package/dist/index.js.map +1 -1
- package/dist/nn.d.ts +14 -11
- package/dist/nn.d.ts.map +1 -1
- package/dist/nn.js +28 -33
- package/dist/nn.js.map +1 -1
- package/dist/runtime.d.ts +35 -27
- package/dist/runtime.d.ts.map +1 -1
- package/dist/runtime.js +45 -10
- package/dist/runtime.js.map +1 -1
- package/package.json +61 -61
- package/src/compile.ts +358 -227
- package/src/index.ts +47 -42
- package/src/nn.ts +34 -32
- package/src/runtime.ts +523 -497
package/dist/compile.js
CHANGED
@@ -35,82 +35,55 @@ export async function compile(traceFn, opts = {}) {
  * field becomes a real `Tensor`), so the instance is consumed and shouldn't be
  * referenced afterwards. Re-call the factory if you need a fresh tree.
  *
- * The forward function takes the materialized model and
- * tensor.
+ * The forward function takes the materialized model and a Record of named
+ * input tensors, returns the loss tensor. Inputs are matched by name with the
+ * `inputs:` declaration:
+ *
+ *   inputs: {
+ *     tokens: { shape: [B, T], dtype: 'i32' },
+ *     targets: { shape: [B, T], dtype: 'i32' },
+ *   }
+ *   forward: (m, { tokens, targets }) => …
  *
  * Walks the module tree to materialize params with auto-derived names, then
- * runs trace → grad → adam → buffer plan → codegen → runtime.
+ * runs trace → grad → adam → buffer plan → codegen → runtime. Initial
+ * parameter values are uploaded automatically before this function returns;
+ * call `reset()` later to re-randomize.
  *
  * If `opts.adam` is set, the runtime's `step()` automatically tracks an
  * internal step count and injects the bias-corrected `lrt` scalar each call;
  * users don't need to provide it themselves.
  */
 export async function compileModule(modelFactory, forward, opts = {}) {
-    const
-    const model = modelFactory();
-    let materialized = { tensors: {}, initFns: {}, decayFlags: {} };
-    const graph = trace(() => {
-        materialized = materializeParams(model);
-        const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'));
-        return forward(model, ...inputTensors);
-    });
-    const { paramGrads, loss } = appendGrad(graph);
-    let adamResult;
-    if (opts.adam) {
-        adamResult = appendAdam(graph, paramGrads, materialized.tensors, opts.adam, materialized.decayFlags);
-    }
-    const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? []);
-    const kernels = emitKernels(graph, plan);
-    const lossBufferId = plan.tensorToBuffer.get(loss.id);
-    const runtime = await createRuntime(plan, kernels, lossBufferId, opts);
+    const { runtime, materialized, plan, kernels, ir } = await buildModuleRuntime(modelFactory, forward, opts, /* sharedParams */ undefined, /* withGrad */ true);
     // If Adam is enabled, wrap step() to track the step count and supply lrt
     // (and optionally decayShrink, when the user passed a per-step lr schedule).
     // Wrap resetOptimizerState() too, so a reset zeros m/v *and* the bias-correction
     // counter — otherwise the next step would skip Adam's warmup phase.
-    if (
-
-        let t = 0;
-        const lrtBuf = new Float32Array(1);
-        const decayShrinkBuf = decayShrinkInputName ? new Float32Array(1) : null;
-        const innerStep = runtime.step.bind(runtime);
-        const innerReset = runtime.resetOptimizerState.bind(runtime);
-        const wrappedStep = (inputs, opts) => {
-            t++;
-            const lrNow = config.lr(t);
-            lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t));
-            const merged = { ...inputs, [lrtInputName]: lrtBuf };
-            if (decayShrinkBuf && decayShrinkInputName) {
-                decayShrinkBuf[0] = 1 - lrNow * config.weightDecay;
-                merged[decayShrinkInputName] = decayShrinkBuf;
-            }
-            return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged);
-        };
-        runtime.step = wrappedStep;
-        runtime.resetOptimizerState = () => {
-            t = 0;
-            innerReset();
-        };
+    if (opts.adam) {
+        wrapStepForAdam(runtime, opts.adam, ir);
     }
-
-
-
+    // Auto-upload initial param values. Always wanted at this entry point —
+    // training runtimes own their params and need them randomized before step 1.
+    uploadInitialParams(plan, materialized.initFns, runtime, /* sharedParams */ undefined);
+    const kernelCount = kernels.filter(k => k.wgsl).length;
+    const reset = () => {
+        uploadInitialParams(plan, materialized.initFns, runtime, undefined);
+        runtime.resetOptimizerState();
     };
-    const
-
-
-
-
-
-
-
-
-
-
-
-            out[name] = initFn(size, shape);
-        }
-        return out;
+    const compileForwardMethod = async (forwardFn, fOpts = {}) => {
+        return compileForward(modelFactory, forwardFn, {
+            ...fOpts,
+            device: runtime.device,
+            sharedParams: runtime.params,
+        });
+    };
+    return Object.assign(runtime, {
+        ir,
+        kernelCount,
+        reset,
+        compileForward: compileForwardMethod,
+    });
 }
 // ============================================================================
 // Forward-only compile
@@ -121,38 +94,116 @@ function buildInitialParamUploads(plan, initFns, sharedParams) {
  * scalar loss; runtime exposes `run(inputs)` returning the full output as a
  * `Float32Array`.
  *
+ * **Prefer the `compileForward` method on a training runtime** when both
+ * graphs use the same Module class — it auto-supplies `device` and
+ * `sharedParams`. This standalone form is for forward-only models with no
+ * training graph at all, or for sharing params across a different model.
+ *
  * **Sharing params with a training compile.** Pass `opts.sharedParams =
  * trainCompiled.params` to bind this graph's param buffers to an existing
  * training runtime's GPU buffers — every train step is then immediately
- * visible to `run()` calls here, no copies.
- * `uploadInitialParams()` skips any param covered by `sharedParams`.
+ * visible to `run()` calls here, no copies.
  *
- *
- *
+ * Initial param values are uploaded automatically for params *not* covered
+ * by `sharedParams` (those are owned by the sibling compile).
  */
 export async function compileForward(modelFactory, forward, opts = {}) {
-    const
+    const sharedParams = opts.sharedParams;
+    const { runtime, materialized, plan, kernels, ir } = await buildModuleRuntime(modelFactory, forward, opts, sharedParams, /* withGrad */ false);
+    // Auto-upload initial values for any params this graph owns. With
+    // `sharedParams` covering everything, this is a no-op.
+    uploadInitialParams(plan, materialized.initFns, runtime, sharedParams);
+    const kernelCount = kernels.filter(k => k.wgsl).length;
+    return Object.assign(runtime, { ir, kernelCount });
+}
+/** Shared body of compileModule + compileForward. The training and forward
+ * pipelines diverge only in (a) whether grad/Adam are appended and (b)
+ * whether the output buffer is the loss scalar or the user's returned
+ * tensor — both come out of the same trace and codegen path. */
+async function buildModuleRuntime(modelFactory, forward, opts, sharedParams, withGrad) {
+    const inputDecls = opts.inputs ?? {};
     const model = modelFactory();
     let materialized = { tensors: {}, initFns: {}, decayFlags: {} };
     const graph = trace(() => {
         materialized = materializeParams(model);
-        const inputTensors =
-
+        const inputTensors = {};
+        for (const [name, decl] of Object.entries(inputDecls)) {
+            inputTensors[name] = tensorInput(name, decl.shape, decl.dtype ?? 'f32');
+        }
+        return forward(model, inputTensors);
     });
-
+    let paramGrads = {};
+    let outputTensor;
+    let adamWritebacks = [];
+    if (withGrad) {
+        const gradResult = appendGrad(graph);
+        paramGrads = gradResult.paramGrads;
+        outputTensor = gradResult.loss;
+        const adamCfg = opts.adam;
+        if (adamCfg) {
+            const adamResult = appendAdam(graph, paramGrads, materialized.tensors, adamCfg, materialized.decayFlags);
+            adamWritebacks = adamResult.writebacks;
+            graph.__adam = adamResult;
+        }
+    }
+    else {
+        outputTensor = graph.tensors[graph.outputs[0]];
+    }
+    const plan = planBuffers(graph, paramGrads, adamWritebacks);
     const kernels = emitKernels(graph, plan);
-    const outputTensor = graph.tensors[graph.outputs[0]];
     const outputBufferId = plan.tensorToBuffer.get(outputTensor.id);
-
-    const
-
-
-
-
+    // exactOptionalPropertyTypes: only include sharedParams when defined.
+    const runtimeOpts = sharedParams
+        ? { ...opts, sharedParams }
+        : { ...opts };
+    const runtime = withGrad
+        ? await createRuntime(plan, kernels, outputBufferId, runtimeOpts)
+        : await createForwardRuntime(plan, kernels, outputBufferId, runtimeOpts);
+    const ir = { graph, paramGrads, loss: outputTensor, plan, kernels };
+    return { runtime: runtime, materialized, plan, kernels, ir };
+}
+function wrapStepForAdam(runtime, adamCfg, ir) {
+    const adamResult = ir.graph.__adam;
+    const { lrtInputName, decayShrinkInputName, config } = adamResult;
+    let t = 0;
+    const lrtBuf = new Float32Array(1);
+    const decayShrinkBuf = decayShrinkInputName ? new Float32Array(1) : null;
+    const innerStep = runtime.step.bind(runtime);
+    const innerReset = runtime.resetOptimizerState.bind(runtime);
+    const wrappedStep = ((inputs, opts) => {
+        t++;
+        const lrNow = config.lr(t);
+        lrtBuf[0] = lrNow * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t));
+        const merged = { ...inputs, [lrtInputName]: lrtBuf };
+        if (decayShrinkBuf && decayShrinkInputName) {
+            decayShrinkBuf[0] = 1 - lrNow * config.weightDecay;
+            merged[decayShrinkInputName] = decayShrinkBuf;
+        }
+        return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged);
+    });
+    runtime.step = wrappedStep;
+    runtime.resetOptimizerState = () => {
+        t = 0;
+        innerReset();
     };
-
-
-
-
+    void adamCfg;
+}
+/** Build a Record<paramName, Float32Array> by running each param's init
+ * function against its shape and uploading them to the runtime. Skips any
+ * param covered by `sharedParams` (those are owned by a sibling compile). */
+function uploadInitialParams(plan, initFns, runtime, sharedParams) {
+    const out = {};
+    for (const [name, bufId] of plan.paramsByName) {
+        if (sharedParams?.has(name))
+            continue;
+        const shape = plan.buffers[bufId].shape;
+        const size = shape.reduce((a, b) => a * b, 1);
+        const initFn = initFns[name];
+        if (!initFn)
+            throw new Error(`compile: no init for param '${name}'`);
+        out[name] = initFn(size, shape);
+    }
+    if (Object.keys(out).length > 0)
+        runtime.uploadParams(out, { partial: !!sharedParams });
 }
 //# sourceMappingURL=compile.js.map
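The reworked `compileModule`/`compileForward` entry points above imply a calling pattern along these lines. This is a minimal sketch inferred from the doc comments and code in this diff, not from the package's published docs: `makeModel`, `m.loss`, `m.logits`, and the input buffers are hypothetical user code; the `adam` option simply mirrors the fields `wrapStepForAdam` reads (`lr(t)`, `b1`, `b2`, `weightDecay`), with `lr` written as a per-step schedule function (whether a plain number is also accepted isn't shown); and `step()` is assumed to accept typed-array inputs and return a promise.

import { compileModule } from 'tensorgrad'

const B = 8, T = 64

// makeModel(), m.loss() and m.logits() are hypothetical user code, stand-ins
// for whatever Module tree and tensorgrad ops the caller already has.
const rt = await compileModule(
  () => makeModel(),
  (m, { tokens, targets }) => m.loss(tokens, targets),        // scalar loss tensor
  {
    inputs: {
      tokens:  { shape: [B, T], dtype: 'i32' },
      targets: { shape: [B, T], dtype: 'i32' },
    },
    // assumed option shape: the fields read by wrapStepForAdam above
    adam: { lr: () => 3e-4, b1: 0.9, b2: 0.999, weightDecay: 0.01 },
  },
)

// step() injects the bias-corrected `lrt` scalar itself (and `decayShrink` when
// configured); reset() re-uploads fresh param values and clears the Adam state,
// matching the wrapped resetOptimizerState above.
await rt.step({ tokens: new Int32Array(B * T), targets: new Int32Array(B * T) })
rt.reset()

// Forward-only sibling graph reusing the training runtime's GPU param buffers,
// wired through device/sharedParams by compileForwardMethod above.
const gen = await rt.compileForward((m, { tokens }) => m.logits(tokens), {
  inputs: { tokens: { shape: [1, T], dtype: 'i32' } },
})
const logits = await gen.run({ tokens: new Int32Array(T) })   // Float32Array per the run() docs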
package/dist/compile.js.map
CHANGED
@@ -1 +1 @@
-
{"version":3,"file":"compile.js","sourceRoot":"","sources":["../src/compile.ts"],"names":[],"mappings":"AAAA,2EAA2E;AAC3E,EAAE;AACF,oBAAoB;AACpB,sEAAsE;AACtE,iEAAiE;AACjE,0EAA0E;AAC1E,0EAA0E;AAC1E,2EAA2E;AAC3E,mEAAmE;AAGnE,OAAO,EAAE,KAAK,EAAE,WAAW,EAAE,MAAM,YAAY,CAAA;AAC/C,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAgE,MAAM,cAAc,CAAA;AAChI,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAAE,MAAM,aAAa,CAAA;
+
{"version":3,"file":"compile.js","sourceRoot":"","sources":["../src/compile.ts"],"names":[],"mappings":"AAAA,2EAA2E;AAC3E,EAAE;AACF,oBAAoB;AACpB,sEAAsE;AACtE,iEAAiE;AACjE,0EAA0E;AAC1E,0EAA0E;AAC1E,2EAA2E;AAC3E,mEAAmE;AAGnE,OAAO,EAAE,KAAK,EAAE,WAAW,EAAE,MAAM,YAAY,CAAA;AAC/C,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAgE,MAAM,cAAc,CAAA;AAChI,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAAE,MAAM,aAAa,CAAA;AAoCvD,yEAAyE;AACzE,MAAM,UAAU,WAAW,CAAC,OAAqB;IAC/C,MAAM,KAAK,GAAG,KAAK,CAAC,OAAO,CAAC,CAAA;IAC5B,MAAM,EAAE,UAAU,EAAE,IAAI,EAAE,GAAG,UAAU,CAAC,KAAK,CAAC,CAAA;IAC9C,MAAM,IAAI,GAAG,WAAW,CAAC,KAAK,EAAE,UAAU,CAAC,CAAA;IAC3C,MAAM,OAAO,GAAG,WAAW,CAAC,KAAK,EAAE,IAAI,CAAC,CAAA;IACxC,OAAO,EAAE,KAAK,EAAE,UAAU,EAAE,IAAI,EAAE,IAAI,EAAE,OAAO,EAAE,CAAA;AACnD,CAAC;AAED,0EAA0E;AAC1E,MAAM,CAAC,KAAK,UAAU,OAAO,CAAC,OAAqB,EAAE,OAAoB,EAAE;IACzE,MAAM,EAAE,GAAG,WAAW,CAAC,OAAO,CAAC,CAAA;IAC/B,MAAM,YAAY,GAAG,EAAE,CAAC,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAE,CAAA;IAC5D,MAAM,OAAO,GAAG,MAAM,aAAa,CAAC,EAAE,CAAC,IAAI,EAAE,EAAE,CAAC,OAAO,EAAE,YAAY,EAAE,IAAI,CAAC,CAAA;IAC5E,OAAO,MAAM,CAAC,MAAM,CAAC,OAAO,EAAE,EAAE,EAAE,EAAE,CAAC,CAAA;AACvC,CAAC;AAsDD;;;;;;;;;;;;;;;;;;;;;;;;GAwBG;AACH,MAAM,CAAC,KAAK,UAAU,aAAa,CACjC,YAAqB,EACrB,OAAwB,EACxB,OAAgC,EAAE;IAElC,MAAM,EAAE,OAAO,EAAE,YAAY,EAAE,IAAI,EAAE,OAAO,EAAE,EAAE,EAAE,GAAG,MAAM,kBAAkB,CAC3E,YAAY,EAAE,OAAO,EAAE,IAAI,EAAE,kBAAkB,CAAC,SAAS,EAAE,cAAc,CAAC,IAAI,CAC/E,CAAA;IAED,yEAAyE;IACzE,6EAA6E;IAC7E,iFAAiF;IACjF,oEAAoE;IACpE,IAAI,IAAI,CAAC,IAAI,EAAE,CAAC;QACd,eAAe,CAAC,OAAO,EAAE,IAAI,CAAC,IAAI,EAAE,EAAE,CAAC,CAAA;IACzC,CAAC;IAED,wEAAwE;IACxE,6EAA6E;IAC7E,mBAAmB,CAAC,IAAI,EAAE,YAAY,CAAC,OAAO,EAAE,OAAO,EAAE,kBAAkB,CAAC,SAAS,CAAC,CAAA;IAEtF,MAAM,WAAW,GAAG,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,MAAM,CAAA;IAEtD,MAAM,KAAK,GAAG,GAAG,EAAE;QACjB,mBAAmB,CAAC,IAAI,EAAE,YAAY,CAAC,OAAO,EAAE,OAAO,EAAE,SAAS,CAAC,CAAA;QACnE,OAAO,CAAC,mBAAmB,EAAE,CAAA;IAC/B,CAAC,CAAA;IAED,MAAM,oBAAoB,GAAG,KAAK,EAChC,SAA0B,EAC1B,QAAwC,EAAE,EACV,EAAE;QAClC,OAAO,cAAc,CAAO,YAAY,EAAE,SAAS,EAAE;YACnD,GAAG,KAAK;YACR,MAAM,EAAE,OAAO,CAAC,MAAM;YACtB,YAAY,EAAE,OAAO,CAAC,MAAM;SAC7B,CAAC,CAAA;IACJ,CAAC,CAAA;IAED,OAAO,MAAM,CAAC,MAAM,CAAC,OAAO,EAAE;QAC5B,EAAE;QACF,WAAW;QACX,KAAK;QACL,cAAc,EAAE,oBAAoB;KACrC,CAAC,CAAA;AACJ,CAAC;AAED,+EAA+E;AAC/E,uBAAuB;AACvB,+EAA+E;AAE/E;;;;;;;;;;;;;;;;;;GAkBG;AACH,MAAM,CAAC,KAAK,UAAU,cAAc,CAClC,YAAqB,EACrB,OAAwB,EACxB,OAAiC,EAAE;IAEnC,MAAM,YAAY,GAAG,IAAI,CAAC,YAAY,CAAA;IACtC,MAAM,EAAE,OAAO,EAAE,YAAY,EAAE,IAAI,EAAE,OAAO,EAAE,EAAE,EAAE,GAAG,MAAM,kBAAkB,CAC3E,YAAY,EAAE,OAAO,EAAE,IAAI,EAAE,YAAY,EAAE,cAAc,CAAC,KAAK,CAChE,CAAA;IAED,kEAAkE;IAClE,uDAAuD;IACvD,mBAAmB,CAAC,IAAI,EAAE,YAAY,CAAC,OAAO,EAAE,OAAO,EAAE,YAAY,CAAC,CAAA;IAEtE,MAAM,WAAW,GAAG,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,MAAM,CAAA;IACtD,OAAO,MAAM,CAAC,MAAM,CAAC,OAAO,EAAE,EAAE,EAAE,EAAE,WAAW,EAAE,CAAC,CAAA;AACpD,CAAC;AAgBD;;;iEAGiE;AACjE,KAAK,UAAU,kBAAkB,CAC/B,YAAqB,EACrB,OAAwB,EACxB,IAAwD,EACxD,YAAgD,EAChD,QAAiB;IAEjB,MAAM,UAAU,GAAe,IAAI,CAAC,MAAM,IAAI,EAAE,CAAA;IAChD,MAAM,KAAK,GAAG,YAAY,EAAE,CAAA;IAC5B,IAAI,YAAY,GAAyC,EAAE,OAAO,EAAE,EAAE,EAAE,OAAO,EAAE,EAAE,EAAE,UAAU,EAAE,EAAE,EAAE,CAAA;IACrG,MAAM,KAAK,GAAG,KAAK,CAAC,GAAG,EAAE;QACvB,YAAY,GAAG,iBAAiB,CAAC,KAAK,CAAC,CAAA;QACvC,MAAM,YAAY,GAA2B,EAAE,CAAA;QAC/C,KAAK,MAAM,CAAC,IAAI,EAAE,IAAI,CAAC,IAAI,MAAM,CAAC,OAAO,CAAC,UAAU,CAAC,EAAE,CAAC;YACtD,YAAY,CAAC,IAAI,CAAC,GAAG,WAAW,CAAC,IAAI,EAAE,I
AAI,CAAC,KAAK,EAAE,IAAI,CAAC,KAAK,IAAI,KAAK,CAAC,CAAA;QACzE,CAAC;QACD,OAAO,OAAO,CAAC,KAAK,EAAE,YAAgC,CAAC,CAAA;IACzD,CAAC,CAAC,CAAA;IAEF,IAAI,UAAU,GAA6B,EAAE,CAAA;IAC7C,IAAI,YAAoB,CAAA;IACxB,IAAI,cAAc,GAAgD,EAAE,CAAA;IAEpE,IAAI,QAAQ,EAAE,CAAC;QACb,MAAM,UAAU,GAAG,UAAU,CAAC,KAAK,CAAC,CAAA;QACpC,UAAU,GAAG,UAAU,CAAC,UAAU,CAAA;QAClC,YAAY,GAAG,UAAU,CAAC,IAAI,CAAA;QAC9B,MAAM,OAAO,GAAI,IAA6B,CAAC,IAAI,CAAA;QACnD,IAAI,OAAO,EAAE,CAAC;YACZ,MAAM,UAAU,GAAG,UAAU,CAAC,KAAK,EAAE,UAAU,EAAE,YAAY,CAAC,OAAO,EAAE,OAAO,EAAE,YAAY,CAAC,UAAU,CAAC,CAAA;YACxG,cAAc,GAAG,UAAU,CAAC,UAAU,CAErC;YAAC,KAA4D,CAAC,MAAM,GAAG,UAAU,CAAA;QACpF,CAAC;IACH,CAAC;SAAM,CAAC;QACN,YAAY,GAAG,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC,CAAE,CAAE,CAAA;IAClD,CAAC;IAED,MAAM,IAAI,GAAG,WAAW,CAAC,KAAK,EAAE,UAAU,EAAE,cAAc,CAAC,CAAA;IAC3D,MAAM,OAAO,GAAG,WAAW,CAAC,KAAK,EAAE,IAAI,CAAC,CAAA;IACxC,MAAM,cAAc,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,YAAY,CAAC,EAAE,CAAE,CAAA;IAChE,sEAAsE;IACtE,MAAM,WAAW,GAAgB,YAAY;QAC3C,CAAC,CAAC,EAAE,GAAG,IAAI,EAAE,YAAY,EAAE;QAC3B,CAAC,CAAC,EAAE,GAAG,IAAI,EAAE,CAAA;IACf,MAAM,OAAO,GAAG,QAAQ;QACtB,CAAC,CAAC,MAAM,aAAa,CAAC,IAAI,EAAE,OAAO,EAAE,cAAc,EAAE,WAAW,CAAC;QACjE,CAAC,CAAC,MAAM,oBAAoB,CAAC,IAAI,EAAE,OAAO,EAAE,cAAc,EAAE,WAAW,CAAC,CAAA;IAE1E,MAAM,EAAE,GAAe,EAAE,KAAK,EAAE,UAAU,EAAE,IAAI,EAAE,YAAY,EAAE,IAAI,EAAE,OAAO,EAAE,CAAA;IAC/E,OAAO,EAAE,OAAO,EAAE,OAA0B,EAAE,YAAY,EAAE,IAAI,EAAE,OAAO,EAAE,EAAE,EAAE,CAAA;AACjF,CAAC;AAID,SAAS,eAAe,CAAC,OAAwB,EAAE,OAAmB,EAAE,EAAc;IACpF,MAAM,UAAU,GAAI,EAAE,CAAC,KAA4D,CAAC,MAAO,CAAA;IAC3F,MAAM,EAAE,YAAY,EAAE,oBAAoB,EAAE,MAAM,EAAE,GAAG,UAAU,CAAA;IACjE,IAAI,CAAC,GAAG,CAAC,CAAA;IACT,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,CAAC,CAAC,CAAA;IAClC,MAAM,cAAc,GAAG,oBAAoB,CAAC,CAAC,CAAC,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,CAAA;IACxE,MAAM,SAAS,GAAG,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,OAAO,CAA4B,CAAA;IACvE,MAAM,UAAU,GAAG,OAAO,CAAC,mBAAmB,CAAC,IAAI,CAAC,OAAO,CAAC,CAAA;IAC5D,MAAM,WAAW,GAAG,CAAC,CACnB,MAAiD,EACjD,IAAiC,EACjC,EAAE;QACF,CAAC,EAAE,CAAA;QACH,MAAM,KAAK,GAAG,MAAM,CAAC,EAAE,CAAC,CAAC,CAAC,CAAA;QAC1B,MAAM,CAAC,CAAC,CAAC,GAAG,KAAK,GAAG,IAAI,CAAC,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAA;QACxF,MAAM,MAAM,GAA8C,EAAE,GAAG,MAAM,EAAE,CAAC,YAAY,CAAC,EAAE,MAAM,EAAE,CAAA;QAC/F,IAAI,cAAc,IAAI,oBAAoB,EAAE,CAAC;YAC3C,cAAc,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,KAAK,GAAG,MAAM,CAAC,WAAW,CAAA;YAClD,MAAM,CAAC,oBAAoB,CAAC,GAAG,cAAc,CAAA;QAC/C,CAAC;QACD,OAAO,IAAI,EAAE,YAAY,CAAC,CAAC,CAAC,SAAS,CAAC,MAAM,EAAE,EAAE,YAAY,EAAE,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,MAAM,CAAC,CAAA;IAC3F,CAAC,CAA4B,CAAA;IAC7B,OAAO,CAAC,IAAI,GAAG,WAAW,CAAA;IAC1B,OAAO,CAAC,mBAAmB,GAAG,GAAG,EAAE;QACjC,CAAC,GAAG,CAAC,CAAA;QACL,UAAU,EAAE,CAAA;IACd,CAAC,CAAA;IACD,KAAK,OAAO,CAAA;AACd,CAAC;AAED;;8EAE8E;AAC9E,SAAS,mBAAmB,CAC1B,IAAgB,EAChB,OAA+B,EAC/B,OAA0C,EAC1C,YAAgD;IAEhD,MAAM,GAAG,GAAiC,EAAE,CAAA;IAC5C,KAAK,MAAM,CAAC,IAAI,EAAE,KAAK,CAAC,IAAI,IAAI,CAAC,YAAY,EAAE,CAAC;QAC9C,IAAI,YAAY,EAAE,GAAG,CAAC,IAAI,CAAC;YAAE,SAAQ;QACrC,MAAM,KAAK,GAAG,IAAI,CAAC,OAAO,CAAC,KAAK,CAAE,CAAC,KAAK,CAAA;QACxC,MAAM,IAAI,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC,CAAA;QAC7C,MAAM,MAAM,GAAG,OAAO,CAAC,IAAI,CAAC,CAAA;QAC5B,IAAI,CAAC,MAAM;YAAE,MAAM,IAAI,KAAK,CAAC,+BAA+B,IAAI,GAAG,CAAC,CAAA;QACpE,GAAG,CAAC,IAAI,CAAC,GAAG,MAAM,CAAC,IAAI,EAAE,KAAK,CAAC,CAAA;IACjC,CAAC;IACD,IAAI,MAAM,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,MAAM,GAAG,CAAC;QAAE,OAAO,CAAC,YAAY,CAAC,GAAG,EAAE,EAAE,OAAO,EAAE,CAAC,CAAC,YAAY,EAAE,CAAC,CAAA;AACzF,CAAC"}
package/dist/index.d.ts
CHANGED
@@ -7,8 +7,8 @@ export { appendGrad, type GradResult } from './grad.js';
 export { appendAdam, type AdamConfig, type AdamResult } from './adam.js';
 export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js';
 export { emitKernels, type KernelSpec } from './codegen.js';
-export { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type
-export { compile, compileToIR, compileModule, compileForward, type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type InputDecl } from './compile.js';
+export { createRuntime, createForwardRuntime, Captures, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type RunOptions, type StepResult, type RunResult } from './runtime.js';
+export { compile, compileToIR, compileModule, compileForward, type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type CompileForwardMethodOptions, type CompiledModule, type CompiledForwardModule, type InputDecl, type InputDecls, type InputsTensors, type ForwardFn, } from './compile.js';
 export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js';
 export * as nn from './nn.js';
 //# sourceMappingURL=index.d.ts.map
package/dist/index.d.ts.map
CHANGED
@@ -1 +1 @@
-
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAKA,YAAY,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,QAAQ,EAAE,MAAM,SAAS,CAAA;AAC5E,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAA;AACtC,OAAO,EAEL,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAElB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,EAE3B,IAAI,EAAE,OAAO,EAAE,KAAK,EAEpB,QAAQ,EAAE,OAAO,EAAE,MAAM,EAEzB,OAAO,EAAE,SAAS,EAAE,QAAQ,EAE5B,MAAM,EAAE,aAAa,EAErB,MAAM,EAAE,MAAM,EAAE,SAAS,EAEzB,iBAAiB,EAAE,cAAc,EAAE,WAAW,EAE9C,cAAc,GACf,MAAM,UAAU,CAAA;AAMjB,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,SAAS,EAAE,KAAK,aAAa,EAAE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAE,
+
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAKA,YAAY,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,QAAQ,EAAE,MAAM,SAAS,CAAA;AAC5E,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAA;AACtC,OAAO,EAEL,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAElB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,EAE3B,IAAI,EAAE,OAAO,EAAE,KAAK,EAEpB,QAAQ,EAAE,OAAO,EAAE,MAAM,EAEzB,OAAO,EAAE,SAAS,EAAE,QAAQ,EAE5B,MAAM,EAAE,aAAa,EAErB,MAAM,EAAE,MAAM,EAAE,SAAS,EAEzB,iBAAiB,EAAE,cAAc,EAAE,WAAW,EAE9C,cAAc,GACf,MAAM,UAAU,CAAA;AAMjB,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,SAAS,EAAE,KAAK,aAAa,EAAE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAE,QAAQ,EAAE,KAAK,eAAe,EAAE,KAAK,eAAe,EAAE,KAAK,WAAW,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,SAAS,EAAE,MAAM,cAAc,CAAA;AAC5L,OAAO,EACL,OAAO,EAAE,WAAW,EAAE,aAAa,EAAE,cAAc,EACnD,KAAK,UAAU,EAAE,KAAK,oBAAoB,EAAE,KAAK,qBAAqB,EAAE,KAAK,2BAA2B,EACxG,KAAK,cAAc,EAAE,KAAK,qBAAqB,EAC/C,KAAK,SAAS,EAAE,KAAK,UAAU,EAAE,KAAK,aAAa,EAAE,KAAK,SAAS,GACpE,MAAM,cAAc,CAAA;AACrB,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAAE,KAAK,QAAQ,EAAE,KAAK,YAAY,EAAE,KAAK,kBAAkB,EAAE,MAAM,aAAa,CAAA;AAClH,OAAO,KAAK,EAAE,MAAM,SAAS,CAAA"}
package/dist/index.js
CHANGED
@@ -32,8 +32,8 @@ export { appendGrad } from './grad.js';
 export { appendAdam } from './adam.js';
 export { planBuffers } from './buffers.js';
 export { emitKernels } from './codegen.js';
-export { createRuntime, createForwardRuntime } from './runtime.js';
-export { compile, compileToIR, compileModule, compileForward } from './compile.js';
+export { createRuntime, createForwardRuntime, Captures } from './runtime.js';
+export { compile, compileToIR, compileModule, compileForward, } from './compile.js';
 export { Module, materializeParams } from './module.js';
 export * as nn from './nn.js';
 //# sourceMappingURL=index.js.map
package/dist/index.js.map
CHANGED
@@ -1 +1 @@
-
{"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,+CAA+C;AAC/C,EAAE;AACF,8EAA8E;AAC9E,6CAA6C;AAG7C,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAA;AACtC,OAAO;AACL,qFAAqF;AACrF,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG;AAClB,qBAAqB;AACrB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI;AAC3B,uBAAuB;AACvB,IAAI,EAAE,OAAO,EAAE,KAAK;AACpB,yEAAyE;AACzE,QAAQ,EAAE,OAAO,EAAE,MAAM;AACzB,YAAY;AACZ,OAAO,EAAE,SAAS,EAAE,QAAQ;AAC5B,iBAAiB;AACjB,MAAM,EAAE,aAAa;AACrB,qBAAqB;AACrB,MAAM,EAAE,MAAM,EAAE,SAAS;AACzB,4CAA4C;AAC5C,iBAAiB,EAAE,cAAc,EAAE,WAAW;AAC9C,UAAU;AACV,cAAc,GACf,MAAM,UAAU,CAAA;AAEjB,sFAAsF;AACtF,8EAA8E;AAC9E,2EAA2E;AAC3E,qDAAqD;AACrD,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAoC,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAwE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,
+
{"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,+CAA+C;AAC/C,EAAE;AACF,8EAA8E;AAC9E,6CAA6C;AAG7C,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAA;AACtC,OAAO;AACL,qFAAqF;AACrF,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG;AAClB,qBAAqB;AACrB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI;AAC3B,uBAAuB;AACvB,IAAI,EAAE,OAAO,EAAE,KAAK;AACpB,yEAAyE;AACzE,QAAQ,EAAE,OAAO,EAAE,MAAM;AACzB,YAAY;AACZ,OAAO,EAAE,SAAS,EAAE,QAAQ;AAC5B,iBAAiB;AACjB,MAAM,EAAE,aAAa;AACrB,qBAAqB;AACrB,MAAM,EAAE,MAAM,EAAE,SAAS;AACzB,4CAA4C;AAC5C,iBAAiB,EAAE,cAAc,EAAE,WAAW;AAC9C,UAAU;AACV,cAAc,GACf,MAAM,UAAU,CAAA;AAEjB,sFAAsF;AACtF,8EAA8E;AAC9E,2EAA2E;AAC3E,qDAAqD;AACrD,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAoC,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAwE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAE,QAAQ,EAAkH,MAAM,cAAc,CAAA;AAC5L,OAAO,EACL,OAAO,EAAE,WAAW,EAAE,aAAa,EAAE,cAAc,GAIpD,MAAM,cAAc,CAAA;AACrB,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAA6D,MAAM,aAAa,CAAA;AAClH,OAAO,KAAK,EAAE,MAAM,SAAS,CAAA"}
package/dist/nn.d.ts
CHANGED
@@ -1,35 +1,38 @@
 import { Module } from './module.js';
 import type { Tensor } from './ir.js';
+import type { Captures } from './runtime.js';
+export interface LinearOptions {
+    /** Include a bias term (default true). */
+    bias?: boolean;
+}
 export declare class Linear extends Module {
     readonly inDim: number;
     readonly outDim: number;
     W: Tensor;
     b: Tensor | null;
-    constructor(inDim: number, outDim: number,
+    constructor(inDim: number, outDim: number, opts?: LinearOptions);
+    fwd(x: Tensor): Tensor;
 }
-export declare function linearFwd(p: Linear, x: Tensor): Tensor;
 export declare class LayerNorm extends Module {
     readonly d: number;
     readonly eps: number;
     g: Tensor;
     b: Tensor;
     constructor(d: number, eps?: number);
+    fwd(x: Tensor): Tensor;
 }
-export declare function layerNormFwd(p: LayerNorm, x: Tensor): Tensor;
 /** [..., T, D] → [..., H, T, D/H]. Folds the standard
  * `transpose(reshape(x, [..., T, H, d]), [..., H, T, d])` pattern into one
  * call. Last dim of `x` must divide evenly by `nHeads`. */
 export declare function splitHeads(x: Tensor, nHeads: number): Tensor;
 /** Inverse of `splitHeads`: [..., H, T, d] → [..., T, H*d]. */
 export declare function mergeHeads(x: Tensor): Tensor;
-/** Slice a
- *
- *
- *
- *
-
- * out captures at B=1 (the typical capture-readback shape). */
-export declare function unsplitHeads(flat: Float32Array, shape: readonly number[]): Float32Array[];
+/** Slice a captured tensor named `name` into one Float32Array per head, using
+ * the static shape registered at compile time. The leading axis is treated as
+ * heads (matching `splitHeads` layout at B=1); a leading singleton batch is
+ * stripped if present so callers can pass capture names directly. Throws if
+ * the capture isn't registered or wasn't read back this call. */
+export declare function unsplitHeads(captures: Captures, name: string): Float32Array[];
 /** Per-position cross-entropy along the last (vocab) axis: returns
  * `-log p(target)` at each position. `logits` is `[..., V]`; `targets` is
  * `[...]` of i32; result is `[...]` (one rank less than logits). The user
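With `linearFwd`/`layerNormFwd` folded into `.fwd()` methods and the new `LinearOptions` bias flag, layer composition now reads as plain method calls. A small sketch in the style of the Block example in nn.js below; the dimensions and the extra `proj` layer are illustrative, not from the package:

import { Module, nn } from 'tensorgrad'

const D = 512

class Block extends Module {
  ln = new nn.LayerNorm(D)
  ffn = new nn.Linear(D, 4 * D)
  proj = new nn.Linear(4 * D, D, { bias: false })   // LinearOptions added in this release

  // was linearFwd(this.ffn, layerNormFwd(this.ln, x)) in 0.0.9
  fwd(x) {
    return this.proj.fwd(this.ffn.fwd(this.ln.fwd(x)))
  }
}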
package/dist/nn.d.ts.map
CHANGED
@@ -1 +1 @@
-
{"version":3,"file":"nn.d.ts","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"
+
{"version":3,"file":"nn.d.ts","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAaA,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AACpC,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,SAAS,CAAA;AAIrC,OAAO,KAAK,EAAE,QAAQ,EAAE,MAAM,cAAc,CAAA;AAM5C,MAAM,WAAW,aAAa;IAC5B,0CAA0C;IAC1C,IAAI,CAAC,EAAE,OAAO,CAAA;CACf;AAED,qBAAa,MAAO,SAAQ,MAAM;aAGJ,KAAK,EAAE,MAAM;aAAkB,MAAM,EAAE,MAAM;IAFzE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,GAAG,IAAI,CAAA;gBACY,KAAK,EAAE,MAAM,EAAkB,MAAM,EAAE,MAAM,EAAE,IAAI,GAAE,aAAkB;IAKnG,GAAG,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM;CAIvB;AAMD,qBAAa,SAAU,SAAQ,MAAM;aAGP,CAAC,EAAE,MAAM;aAAkB,GAAG,EAAE,MAAM;IAFlE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,CAAA;gBACmB,CAAC,EAAE,MAAM,EAAkB,GAAG,GAAE,MAAa;IAKzE,GAAG,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM;CAOvB;AAOD;;4DAE4D;AAC5D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAa5D;AAED,+DAA+D;AAC/D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAW5C;AAED;;;;kEAIkE;AAClE,wBAAgB,YAAY,CAAC,QAAQ,EAAE,QAAQ,EAAE,IAAI,EAAE,MAAM,GAAG,YAAY,EAAE,CAiB7E;AAMD;;;;+EAI+E;AAC/E,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM,GAAG,MAAM,CASxE"}
package/dist/nn.js
CHANGED
@@ -1,40 +1,35 @@
 // Standard "batteries-included" Module subclasses for the most common layers.
 //
-//
-//
-//
-//
-// Import as a namespace:
+// Each class declares its params and a `.fwd(x)` method that runs the forward
+// computation. Forward methods are pure tensorgrad ops — autograd traces
+// through them just like any other call.
 //
 //   import { nn } from 'tensorgrad'
 //   class Block extends Module {
 //     ln = new nn.LayerNorm(D)
 //     ffn = new nn.Linear(D, 4 * D)
 //   }
-//   const y =
+//   const y = p.ffn.fwd(p.ln.fwd(x))
 import { Module } from './module.js';
 import { add, matmul, sub, mul, div, sqrt, meanLast, sumLast, reshape, swapAxes, oneHot, logSoftmaxLast } from './ops.js';
 import { ShapeError } from './shape.js';
 import { captureSite } from './ir.js';
-// ----------------------------------------------------------------------------
-// Linear: y = x @ W (+ b)
-// ----------------------------------------------------------------------------
 export class Linear extends Module {
     inDim;
     outDim;
     W;
     b;
-    constructor(inDim, outDim,
+    constructor(inDim, outDim, opts = {}) {
         super();
         this.inDim = inDim;
         this.outDim = outDim;
         this.W = this.param([inDim, outDim]); // randn, scale 0.02
-        this.b =
+        this.b = opts.bias === false ? null : this.param([outDim], { init: 'zeros' });
+    }
+    fwd(x) {
+        const out = matmul(x, this.W);
+        return this.b ? add(out, this.b) : out;
     }
-}
-export function linearFwd(p, x) {
-    const out = matmul(x, p.W);
-    return p.b ? add(out, p.b) : out;
 }
 // ----------------------------------------------------------------------------
 // LayerNorm — normalizes over the last axis. eps defaults to 1e-5.
@@ -51,13 +46,13 @@ export class LayerNorm extends Module {
         this.g = this.param([d], { init: 'ones' });
         this.b = this.param([d], { init: 'zeros' });
     }
-
-
-
-
-
-
-
+    fwd(x) {
+        const m = meanLast(x);
+        const c = sub(x, m);
+        const v = meanLast(mul(c, c));
+        const stdev = sqrt(add(v, this.eps));
+        return add(mul(div(c, stdev), this.g), this.b);
+    }
 }
 // ----------------------------------------------------------------------------
 // Multi-head attention shape helpers — split the last (model) axis into
@@ -95,19 +90,19 @@ export function mergeHeads(x) {
     const swapped = swapAxes(x, r - 3, r - 2);
     return reshape(swapped, [...lead, T, H * d]);
 }
-/** Slice a
- *
- *
- *
- *
-
-
-
+/** Slice a captured tensor named `name` into one Float32Array per head, using
+ * the static shape registered at compile time. The leading axis is treated as
+ * heads (matching `splitHeads` layout at B=1); a leading singleton batch is
+ * stripped if present so callers can pass capture names directly. Throws if
+ * the capture isn't registered or wasn't read back this call. */
+export function unsplitHeads(captures, name) {
+    const flat = captures.get(name);
+    const shape = captures.shapeOf(name);
     if (shape.length < 2) {
-        throw new Error(`unsplitHeads: shape needs >= 2 dims, got [${shape.join(', ')}]`);
+        throw new Error(`unsplitHeads: '${name}' shape needs >= 2 dims, got [${shape.join(', ')}]`);
     }
     // For inference graphs at B=1, captures have shape [1, H, ..., ...]. Strip
-    // the leading 1 if present so
+    // the leading 1 if present so the next axis is heads.
     const s = shape[0] === 1 ? shape.slice(1) : shape;
     const H = s[0];
     let stride = 1;
@@ -115,7 +110,7 @@ export function unsplitHeads(flat, shape) {
         stride *= s[i];
     const expected = H * stride;
     if (flat.length !== expected) {
-        throw new Error(`unsplitHeads:
+        throw new Error(`unsplitHeads: '${name}' length ${flat.length} doesn't match shape product ${expected}`);
     }
     return Array.from({ length: H }, (_, h) => flat.slice(h * stride, (h + 1) * stride));
 }
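`unsplitHeads` now takes the runtime `Captures` object plus a capture name instead of a raw `(flat, shape)` pair, so callers no longer thread shapes around by hand. A minimal sketch, assuming `captures` is the `Captures` instance produced by a `{ withCaptures: true }` step or run (how it is surfaced on the result isn't shown in this diff) and that `'blk0.attn'` is a hypothetical capture name:

import { nn } from 'tensorgrad'

// one Float32Array per head, sliced using the shape registered at compile time
const perHead = nn.unsplitHeads(captures, 'blk0.attn')
perHead.forEach((head, h) => console.log(`head ${h}: ${head.length} values`))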
package/dist/nn.js.map
CHANGED
@@ -1 +1 @@
-
{"version":3,"file":"nn.js","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAAA,8EAA8E;AAC9E,EAAE;AACF
+
{"version":3,"file":"nn.js","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAAA,8EAA8E;AAC9E,EAAE;AACF,8EAA8E;AAC9E,yEAAyE;AACzE,yCAAyC;AACzC,EAAE;AACF,oCAAoC;AACpC,iCAAiC;AACjC,gCAAgC;AAChC,oCAAoC;AACpC,MAAM;AACN,qCAAqC;AAErC,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AAEpC,OAAO,EAAE,GAAG,EAAE,MAAM,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,EAAE,QAAQ,EAAE,OAAO,EAAE,OAAO,EAAE,QAAQ,EAAE,MAAM,EAAE,cAAc,EAAE,MAAM,UAAU,CAAA;AACzH,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,WAAW,EAAE,MAAM,SAAS,CAAA;AAYrC,MAAM,OAAO,MAAO,SAAQ,MAAM;IAGJ;IAA+B;IAF3D,CAAC,CAAQ;IACT,CAAC,CAAe;IAChB,YAA4B,KAAa,EAAkB,MAAc,EAAE,OAAsB,EAAE;QACjG,KAAK,EAAE,CAAA;QADmB,UAAK,GAAL,KAAK,CAAQ;QAAkB,WAAM,GAAN,MAAM,CAAQ;QAEvE,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC,CAAA,CAAsB,oBAAoB;QAC9E,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,IAAI,KAAK,KAAK,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,MAAM,CAAC,EAAE,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAA;IAC/E,CAAC;IACD,GAAG,CAAC,CAAS;QACX,MAAM,GAAG,GAAG,MAAM,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC,CAAC,CAAA;QAC7B,OAAO,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAA;IACxC,CAAC;CACF;AAED,+EAA+E;AAC/E,mEAAmE;AACnE,+EAA+E;AAE/E,MAAM,OAAO,SAAU,SAAQ,MAAM;IAGP;IAA2B;IAFvD,CAAC,CAAQ;IACT,CAAC,CAAQ;IACT,YAA4B,CAAS,EAAkB,MAAc,IAAI;QACvE,KAAK,EAAE,CAAA;QADmB,MAAC,GAAD,CAAC,CAAQ;QAAkB,QAAG,GAAH,GAAG,CAAe;QAEvE,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,IAAI,EAAE,MAAM,EAAE,CAAC,CAAA;QAC1C,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAA;IAC7C,CAAC;IACD,GAAG,CAAC,CAAS;QACX,MAAM,CAAC,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAA;QACrB,MAAM,CAAC,GAAG,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAA;QACnB,MAAM,CAAC,GAAG,QAAQ,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAA;QAC7B,MAAM,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,GAAG,CAAC,CAAC,CAAA;QACpC,OAAO,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,EAAE,IAAI,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC,CAAC,CAAA;IAChD,CAAC;CACF;AAED,+EAA+E;AAC/E,wEAAwE;AACxE,gEAAgE;AAChE,+EAA+E;AAE/E;;4DAE4D;AAC5D,MAAM,UAAU,UAAU,CAAC,CAAS,EAAE,MAAc;IAClD,MAAM,IAAI,GAAG,WAAW,CAAC,YAAY,CAAC,CAAA;IACtC,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,MAAM,CAAA;IACxB,IAAI,CAAC,GAAG,CAAC;QAAE,MAAM,IAAI,UAAU,CAAC,uCAAuC,CAAC,EAAE,EAAE,IAAI,CAAC,CAAA;IACjF,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,IAAI,CAAC,GAAG,MAAM,KAAK,CAAC,EAAE,CAAC;QACrB,MAAM,IAAI,UAAU,CAAC,wBAAwB,CAAC,4BAA4B,MAAM,EAAE,EAAE,IAAI,CAAC,CAAA;IAC3F,CAAC;IACD,MAAM,IAAI,GAAG,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAA;IACpC,MAAM,QAAQ,GAAG,OAAO,CAAC,CAAC,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,MAAM,EAAE,CAAC,GAAG,MAAM,CAAC,CAAC,CAAA;IAC7D,2DAA2D;IAC3D,OAAO,QAAQ,CAAC,QAAQ,EAAE,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,MAAM,GAAG,CAAC,CAAC,CAAA;AACzD,CAAC;AAED,+DAA+D;AAC/D,MAAM,UAAU,UAAU,CAAC,CAAS;IAClC,MAAM,IAAI,GAAG,WAAW,CAAC,YAAY,CAAC,CAAA;IACtC,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,MAAM,CAAA;IACxB,IAAI,CAAC,GAAG,CAAC;QAAE,MAAM,IAAI,UAAU,CAAC,uCAAuC,CAAC,EAAE,EAAE,IAAI,CAAC,CAAA;IACjF,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAE,CAAA;IACzB,MAAM,IAAI,GAAG,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAA;IACpC,sEAAsE;IACtE,MAAM,OAAO,GAAG,QAAQ,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAA;IACzC,OAAO,OAAO,CAAC,OAAO,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,CAAC,GAAG,C
AAC,CAAC,CAAC,CAAA;AAC9C,CAAC;AAED;;;;kEAIkE;AAClE,MAAM,UAAU,YAAY,CAAC,QAAkB,EAAE,IAAY;IAC3D,MAAM,IAAI,GAAG,QAAQ,CAAC,GAAG,CAAC,IAAI,CAAC,CAAA;IAC/B,MAAM,KAAK,GAAG,QAAQ,CAAC,OAAO,CAAC,IAAI,CAAC,CAAA;IACpC,IAAI,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC;QACrB,MAAM,IAAI,KAAK,CAAC,kBAAkB,IAAI,iCAAiC,KAAK,CAAC,IAAI,CAAC,IAAI,CAAC,GAAG,CAAC,CAAA;IAC7F,CAAC;IACD,2EAA2E;IAC3E,sDAAsD;IACtD,MAAM,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,CAAA;IACjD,MAAM,CAAC,GAAG,CAAC,CAAC,CAAC,CAAE,CAAA;IACf,IAAI,MAAM,GAAG,CAAC,CAAA;IACd,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,MAAM,EAAE,CAAC,EAAE;QAAE,MAAM,IAAI,CAAC,CAAC,CAAC,CAAE,CAAA;IAClD,MAAM,QAAQ,GAAG,CAAC,GAAG,MAAM,CAAA;IAC3B,IAAI,IAAI,CAAC,MAAM,KAAK,QAAQ,EAAE,CAAC;QAC7B,MAAM,IAAI,KAAK,CAAC,kBAAkB,IAAI,YAAY,IAAI,CAAC,MAAM,gCAAgC,QAAQ,EAAE,CAAC,CAAA;IAC1G,CAAC;IACD,OAAO,KAAK,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,GAAG,MAAM,EAAE,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,CAAA;AACtF,CAAC;AAED,+EAA+E;AAC/E,eAAe;AACf,+EAA+E;AAE/E;;;;+EAI+E;AAC/E,MAAM,UAAU,gBAAgB,CAAC,MAAc,EAAE,OAAe;IAC9D,MAAM,IAAI,GAAG,WAAW,CAAC,kBAAkB,CAAC,CAAA;IAC5C,IAAI,OAAO,CAAC,KAAK,KAAK,KAAK,EAAE,CAAC;QAC5B,MAAM,IAAI,UAAU,CAAC,8CAA8C,OAAO,CAAC,KAAK,EAAE,EAAE,IAAI,CAAC,CAAA;IAC3F,CAAC;IACD,MAAM,KAAK,GAAG,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAE,CAAA;IACpD,MAAM,EAAE,GAAG,cAAc,CAAC,MAAM,CAAC,CAAA,CAAmC,WAAW;IAC/E,MAAM,QAAQ,GAAG,OAAO,CAAC,GAAG,CAAC,EAAE,EAAE,MAAM,CAAC,OAAO,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAA,CAAI,QAAQ;IAC5E,OAAO,GAAG,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAA;AAC1B,CAAC"}