@simulatte/doppler 0.1.5 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +23 -8
- package/package.json +7 -4
- package/src/config/kernels/kernel-ref-digests.js +39 -39
- package/src/config/kernels/registry.json +42 -2
- package/src/config/loader.js +31 -2
- package/src/config/merge.js +18 -0
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/required-inference-fields-contract-check.js +6 -0
- package/src/config/schema/inference-defaults.schema.js +3 -0
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.js +3 -0
- package/src/converter/rope-config.js +42 -0
- package/src/gpu/device.js +58 -0
- package/src/gpu/kernels/attention.js +98 -0
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/conv2d.js +1 -1
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/depthwise_conv2d.js +2 -1
- package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +2 -1
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/matmul.js +25 -0
- package/src/gpu/kernels/pixel_shuffle.js +1 -1
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.js +15 -2
- package/src/gpu/kernels/relu.wgsl +2 -1
- package/src/gpu/kernels/relu_f16.wgsl +2 -1
- package/src/gpu/kernels/repeat_channels.js +1 -1
- package/src/gpu/kernels/repeat_channels.wgsl +4 -5
- package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
- package/src/gpu/kernels/residual.js +44 -8
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +58 -6
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +11 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sana_linear_attention.js +1 -2
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +32 -14
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/transpose.js +15 -2
- package/src/gpu/kernels/transpose.wgsl +5 -6
- package/src/gpu/kernels/upsample2d.js +2 -1
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +16 -1
- package/src/inference/browser-harness.js +47 -1
- package/src/inference/pipelines/diffusion/pipeline.js +15 -6
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
- package/src/inference/pipelines/text/attention/record.js +11 -2
- package/src/inference/pipelines/text/attention/run.js +11 -2
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +4 -0
- package/src/inference/pipelines/text/config.js +68 -1
- package/src/inference/pipelines/text/execution-plan.js +23 -31
- package/src/inference/pipelines/text/execution-v0.js +29 -2
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +56 -9
- package/src/inference/pipelines/text/layer.js +11 -0
- package/src/inference/pipelines/text.js +4 -0
- package/src/inference/tokenizers/bundled.js +156 -33
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +142 -3
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +58 -3
- package/src/tooling/node-command-runner.js +15 -0
- package/src/tooling/node-webgpu.js +9 -87
- package/src/training/checkpoint-watch.d.ts +7 -0
- package/src/training/checkpoint-watch.js +106 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +12 -2
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +57 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +796 -0
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +453 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +29 -4
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +9 -9
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +539 -0
- package/src/version.js +1 -1
- package/tools/doppler-cli.js +137 -40
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
import { readFile } from 'node:fs/promises';
|
|
2
|
+
import { resolve } from 'node:path';
|
|
3
|
+
|
|
4
|
+
import { parseJsonl } from './datasets/jsonl.js';
|
|
5
|
+
|
|
6
|
+
// Tokenize free text into whitespace-separated tokens.
// null/undefined and whitespace-only input yield an empty array.
function asTokenSequence(text) {
  const trimmed = String(text ?? '').trim();
  if (trimmed === '') {
    return [];
  }
  return trimmed.split(/\s+/);
}
|
|
12
|
+
|
|
13
|
+
/**
 * Count character n-grams in `text` (Unicode code points via Array.from,
 * so astral characters count as one unit).
 *
 * @param {unknown} text - Source text; coerced to string and trimmed.
 * @param {number} n - N-gram order.
 * @returns {Map<string, number>} gram -> occurrence count; empty Map when
 *   the trimmed text is shorter than n.
 */
function extractCharacterNgrams(text, n) {
  const chars = Array.from(String(text ?? '').trim());
  const counts = new Map();
  const lastStart = chars.length - n;
  for (let start = 0; start <= lastStart; start += 1) {
    const gram = chars.slice(start, start + n).join('');
    counts.set(gram, (counts.get(gram) || 0) + 1);
  }
  return counts;
}
|
|
25
|
+
|
|
26
|
+
// Clipped overlap between two count maps: for every key in `source`, add
// min(sourceCount, targetCount). Keys absent from `target` contribute 0.
function countOverlap(source, target) {
  let total = 0;
  for (const [key, sourceCount] of source) {
    total += Math.min(sourceCount, target.get(key) || 0);
  }
  return total;
}
|
|
34
|
+
|
|
35
|
+
/**
 * Accumulate corpus-level BLEU statistics across all hypothesis/reference
 * pairs: clipped n-gram matches and possible n-gram counts per order, plus
 * the total hypothesis and reference token lengths.
 *
 * @param {unknown[]} hypotheses - Hypothesis texts (tokenized by whitespace).
 * @param {unknown[]} references - Reference texts, aligned by index.
 * @param {number} [maxOrder=4] - Highest n-gram order to collect.
 */
function computeBleuStats(hypotheses, references, maxOrder = 4) {
  // Token n-grams are joined with \u0001 so multi-token grams cannot collide.
  const tokenNgramCounts = (tokens, order) => {
    const counts = new Map();
    for (let start = 0; start + order <= tokens.length; start += 1) {
      const key = tokens.slice(start, start + order).join('\u0001');
      counts.set(key, (counts.get(key) || 0) + 1);
    }
    return counts;
  };

  const matchesByOrder = new Array(maxOrder).fill(0);
  const possibleByOrder = new Array(maxOrder).fill(0);
  let hypothesisLength = 0;
  let referenceLength = 0;

  hypotheses.forEach((rawHypothesis, pairIndex) => {
    const hypothesis = asTokenSequence(rawHypothesis);
    const reference = asTokenSequence(references[pairIndex]);
    hypothesisLength += hypothesis.length;
    referenceLength += reference.length;
    for (let order = 1; order <= maxOrder; order += 1) {
      const hypothesisCounts = tokenNgramCounts(hypothesis, order);
      const referenceCounts = tokenNgramCounts(reference, order);
      matchesByOrder[order - 1] += countOverlap(hypothesisCounts, referenceCounts);
      // Possible n-grams come from the hypothesis side only (BLEU precision).
      possibleByOrder[order - 1] += Math.max(0, hypothesis.length - order + 1);
    }
  });

  return {
    matchesByOrder,
    possibleByOrder,
    hypothesisLength,
    referenceLength,
  };
}
|
|
69
|
+
|
|
70
|
+
/**
 * Corpus BLEU with add-one smoothed per-order precisions and the standard
 * brevity penalty (exp(1 - ref/hyp) when the hypothesis is shorter).
 *
 * @param {unknown[]} hypotheses - Hypothesis texts.
 * @param {unknown[]} references - Reference texts, aligned by index.
 * @param {{ maxOrder?: number }} [options] - maxOrder defaults to 4.
 * @returns {{ score: number, brevityPenalty: number, precisions: number[],
 *   hypothesisLength: number, referenceLength: number }}
 * @throws {Error} when the inputs are not equally sized arrays.
 */
export function computeBleuScore(hypotheses, references, options = {}) {
  const maxOrder = Number.isInteger(options.maxOrder) && options.maxOrder > 0
    ? options.maxOrder
    : 4;
  if (!Array.isArray(hypotheses) || !Array.isArray(references) || hypotheses.length !== references.length) {
    throw new Error('computeBleuScore requires equally sized hypothesis and reference arrays.');
  }
  if (hypotheses.length === 0) {
    return {
      score: 0,
      brevityPenalty: 0,
      precisions: new Array(maxOrder).fill(0),
      hypothesisLength: 0,
      referenceLength: 0,
    };
  }

  const stats = computeBleuStats(hypotheses, references, maxOrder);
  // Add-one smoothing; an order with no possible n-grams scores 0 outright.
  const precisions = stats.possibleByOrder.map((possible, order) => (
    possible === 0 ? 0 : (stats.matchesByOrder[order] + 1) / (possible + 1)
  ));
  // Clamp before log so zero precisions stay finite.
  const precisionLogSum = precisions.reduce(
    (sum, precision) => sum + Math.log(Math.max(precision, 1e-16)),
    0
  );
  const brevityPenalty = stats.hypothesisLength > stats.referenceLength
    ? 1
    : Math.exp(1 - (stats.referenceLength / Math.max(stats.hypothesisLength, 1)));

  return {
    score: brevityPenalty * Math.exp(precisionLogSum / maxOrder),
    brevityPenalty,
    precisions,
    hypothesisLength: stats.hypothesisLength,
    referenceLength: stats.referenceLength,
  };
}
|
|
111
|
+
|
|
112
|
+
/**
 * chrF: character n-gram F-beta score averaged over orders 1..maxOrder.
 * Precision/recall for each order are corpus-level (counts pooled over all
 * pairs); orders with no n-grams on a side contribute 0 to that side's sum.
 *
 * @param {unknown[]} hypotheses - Hypothesis texts.
 * @param {unknown[]} references - Reference texts, aligned by index.
 * @param {{ maxOrder?: number, beta?: number }} [options] - Defaults: 6, 2.
 * @returns {{ score: number, precision: number, recall: number }}
 * @throws {Error} when the inputs are not equally sized arrays.
 */
export function computeChrfScore(hypotheses, references, options = {}) {
  const maxOrder = Number.isInteger(options.maxOrder) && options.maxOrder > 0
    ? options.maxOrder
    : 6;
  const beta = Number.isFinite(options.beta) && options.beta > 0 ? options.beta : 2;
  if (!Array.isArray(hypotheses) || !Array.isArray(references) || hypotheses.length !== references.length) {
    throw new Error('computeChrfScore requires equally sized hypothesis and reference arrays.');
  }
  if (hypotheses.length === 0) {
    return {
      score: 0,
      precision: 0,
      recall: 0,
    };
  }

  let precisionSum = 0;
  let recallSum = 0;
  for (let order = 1; order <= maxOrder; order += 1) {
    let overlap = 0;
    let hypothesisTotal = 0;
    let referenceTotal = 0;
    hypotheses.forEach((hypothesis, pairIndex) => {
      const hypothesisCounts = extractCharacterNgrams(hypothesis, order);
      const referenceCounts = extractCharacterNgrams(references[pairIndex], order);
      overlap += countOverlap(hypothesisCounts, referenceCounts);
      for (const count of hypothesisCounts.values()) {
        hypothesisTotal += count;
      }
      for (const count of referenceCounts.values()) {
        referenceTotal += count;
      }
    });
    if (hypothesisTotal > 0) {
      precisionSum += overlap / hypothesisTotal;
    }
    if (referenceTotal > 0) {
      recallSum += overlap / referenceTotal;
    }
  }

  const precision = precisionSum / maxOrder;
  const recall = recallSum / maxOrder;
  const betaSquared = beta * beta;
  const score = precision + recall === 0
    ? 0
    : ((1 + betaSquared) * precision * recall) / ((betaSquared * precision) + recall);
  return { score, precision, recall };
}
|
|
157
|
+
|
|
158
|
+
/**
 * Exact-match rate after trimming: a pair matches when the trimmed string
 * forms of hypothesis and reference are identical.
 *
 * @param {unknown[]} hypotheses - Hypothesis values.
 * @param {unknown[]} references - Reference values, aligned by index.
 * @returns {{ score: number, matches: number, total: number }}
 * @throws {Error} when the inputs are not equally sized arrays.
 */
export function computeExactMatch(hypotheses, references) {
  if (!Array.isArray(hypotheses) || !Array.isArray(references) || hypotheses.length !== references.length) {
    throw new Error('computeExactMatch requires equally sized hypothesis and reference arrays.');
  }
  const total = hypotheses.length;
  if (total === 0) {
    return { score: 0, matches: 0, total: 0 };
  }
  const normalize = (value) => String(value ?? '').trim();
  const matches = hypotheses.reduce(
    (count, hypothesis, index) => (
      normalize(hypothesis) === normalize(references[index]) ? count + 1 : count
    ),
    0
  );
  return {
    score: matches / total,
    matches,
    total,
  };
}
|
|
177
|
+
|
|
178
|
+
/**
 * Classification accuracy. Note the argument order is (labels, predictions);
 * the comparison itself is exact match after trimming, delegated to
 * computeExactMatch(predictions, labels).
 *
 * @returns {{ score: number, matches: number, total: number }}
 */
export function computeAccuracy(labels, predictions) {
  const result = computeExactMatch(predictions, labels);
  return result;
}
|
|
181
|
+
|
|
182
|
+
/**
 * Dispatch metric computation by eval kind.
 * - 'translation'      -> BLEU + chrF (primary: bleu)
 * - 'text_generation'  -> exact match (primary: exact_match)
 * - 'classification'   -> accuracy (primary: accuracy)
 * - 'retrieval'/'custom' and anything else throw.
 *
 * @param {unknown} evalKind - Kind string; trimmed before dispatch.
 * @param {unknown[]} hypotheses - Model outputs.
 * @param {unknown[]} references - Gold references, aligned by index.
 * @param {{ bleu?: object, chrf?: object }} [options] - Per-metric options.
 * @throws {Error} for retrieval/custom (not implemented) or unknown kinds.
 */
export function computeEvalMetrics(evalKind, hypotheses, references, options = {}) {
  const normalizedKind = String(evalKind || '').trim();
  switch (normalizedKind) {
    case 'translation': {
      const bleu = computeBleuScore(hypotheses, references, options.bleu || {});
      const chrf = computeChrfScore(hypotheses, references, options.chrf || {});
      return {
        bleu,
        chrf,
        primaryMetric: 'bleu',
        primaryScore: bleu.score,
      };
    }
    case 'text_generation': {
      const exactMatch = computeExactMatch(hypotheses, references);
      return {
        exactMatch,
        primaryMetric: 'exact_match',
        primaryScore: exactMatch.score,
      };
    }
    case 'classification': {
      // computeAccuracy takes (labels, predictions).
      const accuracy = computeAccuracy(references, hypotheses);
      return {
        accuracy,
        primaryMetric: 'accuracy',
        primaryScore: accuracy.score,
      };
    }
    case 'retrieval':
    case 'custom':
      throw new Error(`Eval kind "${normalizedKind}" requires a custom evaluator and is not yet implemented.`);
    default:
      throw new Error(`Unsupported eval kind "${normalizedKind}".`);
  }
}
|
|
215
|
+
|
|
216
|
+
/**
 * Load an eval dataset from disk. Files ending in '.json' are parsed as a
 * JSON array; every other extension is treated as JSONL via parseJsonl.
 *
 * @param {string} datasetPath - Path to the dataset; resolved to absolute.
 * @returns {Promise<{ absolutePath: string, rows: unknown[], raw: string }>}
 * @throws {Error} when the parsed content is not an array (or on read failure).
 */
export async function loadEvalDataset(datasetPath) {
  const absolutePath = resolve(String(datasetPath));
  const raw = await readFile(absolutePath, 'utf8');
  let rows;
  if (absolutePath.endsWith('.json')) {
    rows = JSON.parse(raw);
  } else {
    rows = parseJsonl(raw);
  }
  if (!Array.isArray(rows)) {
    throw new Error(`Eval dataset "${absolutePath}" must be a JSON array or JSONL file.`);
  }
  return { absolutePath, rows, raw };
}
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import { join } from 'node:path';
|
|
2
|
+
|
|
3
|
+
import { writeJsonArtifact, writeNdjsonRow } from './operator-artifacts.js';
|
|
4
|
+
|
|
5
|
+
// Extract a finite numeric metric from a scoreboard row. Checks the
// top-level row property first, then falls back to row.metrics[metric];
// returns null when neither holds a finite number.
function resolveComparableMetric(row, metric) {
  if (!row || typeof row !== 'object') {
    return null;
  }
  const metrics = row.metrics && typeof row.metrics === 'object' ? row.metrics : null;
  const candidates = [row[metric], metrics?.[metric]];
  for (const candidate of candidates) {
    if (typeof candidate === 'number' && Number.isFinite(candidate)) {
      return candidate;
    }
  }
  return null;
}
|
|
18
|
+
|
|
19
|
+
/**
 * Append a scoreboard row to `<scoreboardDir>/scoreboard.ndjson` and rewrite
 * the `latest.json` summary artifact.
 *
 * NOTE(review): the summary's `best` always mirrors the row just written —
 * no comparison against previously appended rows is performed, and
 * `selectionGoal` is stored but never used for selection. Confirm this is
 * the intended contract (the true "best" may live elsewhere).
 *
 * @param {string} scoreboardDir - Directory receiving scoreboard artifacts.
 * @param {Record<string, unknown>} row - The scoreboard row to append.
 * @param {{ selectionMetric?: string, selectionGoal?: string }} [options]
 * @returns {Promise<{ rowsPath: string, summaryPath: string }>}
 */
export async function appendScoreboardRow(scoreboardDir, row, options = {}) {
  const rowsPath = join(scoreboardDir, 'scoreboard.ndjson');
  await writeNdjsonRow(rowsPath, row);

  // Option values win over row-level hints; goal defaults to 'max'.
  const metric = String(options.selectionMetric || row.selectionMetric || row.primaryMetric || '').trim();
  const goal = String(options.selectionGoal || row.selectionGoal || 'max').trim();
  const comparable = resolveComparableMetric(row, metric);
  const best = comparable === null
    ? row
    : { ...row, selectionMetricValue: comparable };

  const summary = {
    artifactType: 'training_scoreboard',
    schemaVersion: 1,
    generatedAt: new Date().toISOString(),
    selectionMetric: metric || null,
    selectionGoal: goal,
    latest: row,
    best,
  };
  const summaryResult = await writeJsonArtifact(join(scoreboardDir, 'latest.json'), summary);
  return {
    rowsPath,
    summaryPath: summaryResult.path,
  };
}
|
package/src/training/runner.d.ts
CHANGED
|
@@ -90,6 +90,16 @@ export interface TrainingStepMetricsEntry {
|
|
|
90
90
|
export interface TrainingRunnerCallbacks {
|
|
91
91
|
onStep?: (entry: TrainingStepMetricsEntry) => Promise<void> | void;
|
|
92
92
|
onEpoch?: (entry: { epoch: number; steps: number; loss: number }) => Promise<void> | void;
|
|
93
|
+
onCheckpoint?: (entry: {
|
|
94
|
+
key: string;
|
|
95
|
+
defaultCheckpointKey: string | null;
|
|
96
|
+
path: string | null;
|
|
97
|
+
metadata: Record<string, unknown> | null;
|
|
98
|
+
payload: unknown;
|
|
99
|
+
step: number;
|
|
100
|
+
epoch: number;
|
|
101
|
+
batch: number;
|
|
102
|
+
}) => Promise<void> | void;
|
|
93
103
|
}
|
|
94
104
|
|
|
95
105
|
export interface TrainingRunnerOptions extends TrainingRunnerCallbacks {
|
|
@@ -106,6 +116,12 @@ export interface TrainingRunnerOptions extends TrainingRunnerCallbacks {
|
|
|
106
116
|
) => Promise<ClipMetrics>;
|
|
107
117
|
lossScaler?: DynamicLossScaler;
|
|
108
118
|
trainingObjective?: TrainingObjective;
|
|
119
|
+
resolveCheckpointKey?: (entry: {
|
|
120
|
+
defaultCheckpointKey: string | null;
|
|
121
|
+
step: number;
|
|
122
|
+
epoch: number;
|
|
123
|
+
batch: number;
|
|
124
|
+
}) => Promise<string> | string;
|
|
109
125
|
}
|
|
110
126
|
|
|
111
127
|
export interface TrainingRunOptions {
|
|
@@ -159,6 +175,9 @@ export declare class TrainingRunner {
|
|
|
159
175
|
lastArtifact: UlArtifactFinalizeResult | DistillArtifactFinalizeResult | null;
|
|
160
176
|
lastCheckpoint: {
|
|
161
177
|
key: string;
|
|
178
|
+
defaultKey?: string | null;
|
|
179
|
+
path?: string | null;
|
|
180
|
+
metadata?: Record<string, unknown> | null;
|
|
162
181
|
step: number;
|
|
163
182
|
epoch: number;
|
|
164
183
|
batch: number;
|
|
@@ -194,3 +213,36 @@ export declare function runTraining(
|
|
|
194
213
|
config: TrainingConfigSchema,
|
|
195
214
|
options?: TrainingRunOptions & TrainingRunnerOptions
|
|
196
215
|
): Promise<TrainingStepMetricsEntry[]>;
|
|
216
|
+
|
|
217
|
+
export declare function createTrainingCheckpointPayload(
|
|
218
|
+
model: {
|
|
219
|
+
loraParams?: () => Tensor[];
|
|
220
|
+
paramGroups?: () => Record<string, Tensor[]>;
|
|
221
|
+
},
|
|
222
|
+
optimizer: unknown,
|
|
223
|
+
context: {
|
|
224
|
+
step: number;
|
|
225
|
+
epoch: number;
|
|
226
|
+
batch: number;
|
|
227
|
+
config: TrainingConfigSchema;
|
|
228
|
+
}
|
|
229
|
+
): Promise<unknown>;
|
|
230
|
+
|
|
231
|
+
export declare function restoreTrainingCheckpointState(
|
|
232
|
+
model: {
|
|
233
|
+
loraParams?: () => Tensor[];
|
|
234
|
+
paramGroups?: () => Record<string, Tensor[]>;
|
|
235
|
+
},
|
|
236
|
+
optimizer: unknown,
|
|
237
|
+
checkpointRecord: unknown,
|
|
238
|
+
config: TrainingConfigSchema
|
|
239
|
+
): Promise<{
|
|
240
|
+
step: number;
|
|
241
|
+
epoch: number;
|
|
242
|
+
batch: number;
|
|
243
|
+
checkpointHash: string | null;
|
|
244
|
+
previousCheckpointHash: string | null;
|
|
245
|
+
checkpointKey: string | null;
|
|
246
|
+
resumeAudits: Array<Record<string, unknown>>;
|
|
247
|
+
resumeAuditCount: number;
|
|
248
|
+
} | null>;
|
package/src/training/runner.js
CHANGED
|
@@ -713,7 +713,7 @@ function looksLikeTrainingCheckpointRecord(value) {
|
|
|
713
713
|
return Number.isInteger(progress.step) && progress.step >= 0;
|
|
714
714
|
}
|
|
715
715
|
|
|
716
|
-
async function createTrainingCheckpointPayload(model, optimizer, context) {
|
|
716
|
+
export async function createTrainingCheckpointPayload(model, optimizer, context) {
|
|
717
717
|
const freezeMap = context.config?.training?.ul?.freeze
|
|
718
718
|
?? context.config?.training?.distill?.freeze
|
|
719
719
|
?? {};
|
|
@@ -747,7 +747,7 @@ async function createTrainingCheckpointPayload(model, optimizer, context) {
|
|
|
747
747
|
};
|
|
748
748
|
}
|
|
749
749
|
|
|
750
|
-
async function restoreTrainingCheckpointState(model, optimizer, checkpointRecord, config) {
|
|
750
|
+
export async function restoreTrainingCheckpointState(model, optimizer, checkpointRecord, config) {
|
|
751
751
|
if (!looksLikeTrainingCheckpointRecord(checkpointRecord)) {
|
|
752
752
|
return null;
|
|
753
753
|
}
|
|
@@ -837,6 +837,8 @@ export class TrainingRunner {
|
|
|
837
837
|
this.lossScaler = options.lossScaler || new DynamicLossScaler(config.training.lossScaling);
|
|
838
838
|
this.onStep = options.onStep || null;
|
|
839
839
|
this.onEpoch = options.onEpoch || null;
|
|
840
|
+
this.onCheckpoint = options.onCheckpoint || null;
|
|
841
|
+
this.resolveCheckpointKey = options.resolveCheckpointKey || null;
|
|
840
842
|
this.lastArtifact = null;
|
|
841
843
|
this.lastCheckpoint = null;
|
|
842
844
|
this.resumeState = null;
|
|
@@ -911,16 +913,39 @@ export class TrainingRunner {
|
|
|
911
913
|
batch: checkpointContext.batch,
|
|
912
914
|
config: this.config,
|
|
913
915
|
});
|
|
914
|
-
|
|
916
|
+
const resolvedCheckpointKey = this.resolveCheckpointKey
|
|
917
|
+
? await this.resolveCheckpointKey({
|
|
918
|
+
defaultCheckpointKey: checkpointKey,
|
|
919
|
+
step: checkpointContext.step,
|
|
920
|
+
epoch: checkpointContext.epoch,
|
|
921
|
+
batch: checkpointContext.batch,
|
|
922
|
+
})
|
|
923
|
+
: checkpointKey;
|
|
924
|
+
const saveResult = await saveCheckpoint(resolvedCheckpointKey, payload, {
|
|
915
925
|
...checkpointMetadata,
|
|
916
926
|
optimizerHash: hashStableJson(payload?.trainingState?.optimizerSlots || {}),
|
|
917
927
|
});
|
|
918
928
|
this.lastCheckpoint = {
|
|
919
|
-
key:
|
|
929
|
+
key: resolvedCheckpointKey,
|
|
930
|
+
defaultKey: checkpointKey,
|
|
931
|
+
path: saveResult?.path || null,
|
|
932
|
+
metadata: saveResult?.metadata || null,
|
|
920
933
|
step: checkpointContext.step,
|
|
921
934
|
epoch: checkpointContext.epoch,
|
|
922
935
|
batch: checkpointContext.batch,
|
|
923
936
|
};
|
|
937
|
+
if (this.onCheckpoint) {
|
|
938
|
+
await this.onCheckpoint({
|
|
939
|
+
key: resolvedCheckpointKey,
|
|
940
|
+
defaultCheckpointKey: checkpointKey,
|
|
941
|
+
path: saveResult?.path || null,
|
|
942
|
+
metadata: saveResult?.metadata || null,
|
|
943
|
+
payload,
|
|
944
|
+
step: checkpointContext.step,
|
|
945
|
+
epoch: checkpointContext.epoch,
|
|
946
|
+
batch: checkpointContext.batch,
|
|
947
|
+
});
|
|
948
|
+
}
|
|
924
949
|
};
|
|
925
950
|
|
|
926
951
|
const artifactSession = distillContract.enabled
|
package/src/training/suite.d.ts
CHANGED
|
@@ -176,6 +176,66 @@ export interface RunTrainingSuiteOptions {
|
|
|
176
176
|
timestamp?: string | Date;
|
|
177
177
|
}
|
|
178
178
|
|
|
179
|
+
export interface DistillDataScope {
|
|
180
|
+
sourceLangs: string[] | null;
|
|
181
|
+
targetLangs: string[] | null;
|
|
182
|
+
pairAllowlist: string[] | null;
|
|
183
|
+
sourceLangSet: Set<string> | null;
|
|
184
|
+
targetLangSet: Set<string> | null;
|
|
185
|
+
pairAllowlistSet: Set<string> | null;
|
|
186
|
+
strictPairContract: boolean;
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
export interface DistillDatasetReport {
|
|
190
|
+
absolutePath: string;
|
|
191
|
+
rowCount: number;
|
|
192
|
+
sampleCount: number;
|
|
193
|
+
directionCounts: Record<string, number>;
|
|
194
|
+
dataScope: {
|
|
195
|
+
sourceLangs: string[] | null;
|
|
196
|
+
targetLangs: string[] | null;
|
|
197
|
+
pairAllowlist: string[] | null;
|
|
198
|
+
strictPairContract: boolean;
|
|
199
|
+
} | null;
|
|
200
|
+
shardCount?: number;
|
|
201
|
+
shardPaths?: string[];
|
|
202
|
+
createDataset(options?: Record<string, unknown>): {
|
|
203
|
+
batches(): AsyncGenerator<Record<string, unknown>, void, unknown>;
|
|
204
|
+
};
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
export interface DistillRuntimeContext {
|
|
208
|
+
stage: 'stage_a' | 'stage_b';
|
|
209
|
+
teacherPipeline: Record<string, unknown>;
|
|
210
|
+
studentPipeline: Record<string, unknown>;
|
|
211
|
+
teacherModelId: string;
|
|
212
|
+
studentModelId: string;
|
|
213
|
+
teacherModelUrl: string | null;
|
|
214
|
+
studentModelUrl: string | null;
|
|
215
|
+
topK: number;
|
|
216
|
+
temperature: number;
|
|
217
|
+
alphaKd: number;
|
|
218
|
+
alphaCe: number;
|
|
219
|
+
tripletMargin: number;
|
|
220
|
+
studentGraphMode: string;
|
|
221
|
+
targetTokenMode: string;
|
|
222
|
+
cleanup(): Promise<void>;
|
|
223
|
+
}
|
|
224
|
+
|
|
225
|
+
export interface DistillStudentFixture {
|
|
226
|
+
config: Record<string, unknown>;
|
|
227
|
+
model: {
|
|
228
|
+
forward: (input: unknown, tape: unknown) => Promise<unknown>;
|
|
229
|
+
forwardDistill?: (batch: unknown, tape: unknown, options?: Record<string, unknown>) => Promise<{ logits: unknown }>;
|
|
230
|
+
cleanupDistillStep?: () => void;
|
|
231
|
+
loraParams?: () => unknown[];
|
|
232
|
+
paramGroups?: () => Record<string, unknown[]>;
|
|
233
|
+
};
|
|
234
|
+
outputDim?: number;
|
|
235
|
+
embeddingDim?: number;
|
|
236
|
+
cleanup(): void;
|
|
237
|
+
}
|
|
238
|
+
|
|
179
239
|
export declare const trainingHarness: TrainingHarness;
|
|
180
240
|
|
|
181
241
|
export declare function runTrainingSuite(
|
|
@@ -185,3 +245,55 @@ export declare function runTrainingSuite(
|
|
|
185
245
|
export declare function runTrainingBenchSuite(
|
|
186
246
|
options?: RunTrainingSuiteOptions
|
|
187
247
|
): Promise<TrainingBenchSuiteResult>;
|
|
248
|
+
|
|
249
|
+
export declare function resolveDistillDataScope(
|
|
250
|
+
options?: RunTrainingSuiteOptions,
|
|
251
|
+
trainingConfig?: Record<string, unknown> | null
|
|
252
|
+
): DistillDataScope;
|
|
253
|
+
|
|
254
|
+
export declare function buildDistillPrompt(sample: Record<string, unknown>): string;
|
|
255
|
+
|
|
256
|
+
export declare function normalizeDistillStudentGraphMode(value: unknown): string;
|
|
257
|
+
|
|
258
|
+
export declare function loadDistillDatasetFromJsonl(
|
|
259
|
+
datasetPath: string,
|
|
260
|
+
scopeOptions?: DistillDataScope | null
|
|
261
|
+
): Promise<DistillDatasetReport | null>;
|
|
262
|
+
|
|
263
|
+
export declare function loadDistillModelHandle(
|
|
264
|
+
modelRef: string,
|
|
265
|
+
role: string,
|
|
266
|
+
loadOptions?: Record<string, unknown>
|
|
267
|
+
): Promise<{
|
|
268
|
+
modelRef: string;
|
|
269
|
+
modelUrl: string | null;
|
|
270
|
+
manifest: Record<string, unknown>;
|
|
271
|
+
pipeline: Record<string, unknown>;
|
|
272
|
+
}>;
|
|
273
|
+
|
|
274
|
+
export declare function createDistillRuntimeContext(
|
|
275
|
+
options?: RunTrainingSuiteOptions,
|
|
276
|
+
trainingConfig?: Record<string, unknown> | null
|
|
277
|
+
): Promise<DistillRuntimeContext>;
|
|
278
|
+
|
|
279
|
+
export declare function createToyModelFixture(
|
|
280
|
+
overrides?: Record<string, unknown>
|
|
281
|
+
): {
|
|
282
|
+
config: Record<string, unknown>;
|
|
283
|
+
model: {
|
|
284
|
+
forward: (input: unknown, tape: unknown) => Promise<unknown>;
|
|
285
|
+
loraParams(): unknown[];
|
|
286
|
+
paramGroups(): Record<string, unknown[]>;
|
|
287
|
+
};
|
|
288
|
+
batch: Record<string, unknown>;
|
|
289
|
+
cleanup(): void;
|
|
290
|
+
};
|
|
291
|
+
|
|
292
|
+
export declare function createDistillStudentRuntimeModelFixture(
|
|
293
|
+
overrides?: Record<string, unknown>,
|
|
294
|
+
options?: Record<string, unknown>
|
|
295
|
+
): Promise<DistillStudentFixture>;
|
|
296
|
+
|
|
297
|
+
export declare function buildDistillTrainingOverrides(
|
|
298
|
+
options?: RunTrainingSuiteOptions
|
|
299
|
+
): Record<string, unknown> | null;
|
package/src/training/suite.js
CHANGED
|
@@ -190,7 +190,7 @@ function normalizeDistillPairAllowlist(value) {
|
|
|
190
190
|
return [...new Set(normalized)];
|
|
191
191
|
}
|
|
192
192
|
|
|
193
|
-
function resolveDistillDataScope(options = {}, trainingConfig = null) {
|
|
193
|
+
export function resolveDistillDataScope(options = {}, trainingConfig = null) {
|
|
194
194
|
const distillConfig = trainingConfig?.distill || {};
|
|
195
195
|
const sourceLangs = normalizeDistillLanguageAllowlist(
|
|
196
196
|
options.distillSourceLangs ?? distillConfig.sourceLangs ?? null
|
|
@@ -301,7 +301,7 @@ function resolveLanguageName(langCode) {
|
|
|
301
301
|
return normalized || 'target';
|
|
302
302
|
}
|
|
303
303
|
|
|
304
|
-
function buildDistillPrompt(sample) {
|
|
304
|
+
export function buildDistillPrompt(sample) {
|
|
305
305
|
const direction = String(sample?.direction || '').trim();
|
|
306
306
|
const [srcCodeRaw, tgtCodeRaw] = direction.split('->');
|
|
307
307
|
const srcCode = normalizeLangCode(srcCodeRaw) || srcCodeRaw || 'source';
|
|
@@ -328,7 +328,7 @@ function clampDistillTopK(value) {
|
|
|
328
328
|
return Math.max(2, Math.min(256, parsed));
|
|
329
329
|
}
|
|
330
330
|
|
|
331
|
-
function normalizeDistillStudentGraphMode(value) {
|
|
331
|
+
export function normalizeDistillStudentGraphMode(value) {
|
|
332
332
|
const normalized = normalizeOptionalString(value);
|
|
333
333
|
if (!normalized) return DISTILL_STUDENT_GRAPH_FULL;
|
|
334
334
|
const compact = normalized.toLowerCase().replace(/[-\s]/g, '_');
|
|
@@ -605,7 +605,7 @@ function createDistillTensorDataset(samples, options = {}) {
|
|
|
605
605
|
};
|
|
606
606
|
}
|
|
607
607
|
|
|
608
|
-
async function loadDistillDatasetFromJsonl(datasetPath, scopeOptions = null) {
|
|
608
|
+
export async function loadDistillDatasetFromJsonl(datasetPath, scopeOptions = null) {
|
|
609
609
|
const normalizedPath = normalizeDistillDatasetPath(datasetPath);
|
|
610
610
|
if (!normalizedPath) return null;
|
|
611
611
|
if (!isNodeRuntime()) {
|
|
@@ -820,7 +820,7 @@ async function initializeInferenceFromStore(modelId) {
|
|
|
820
820
|
return { pipeline, manifest };
|
|
821
821
|
}
|
|
822
822
|
|
|
823
|
-
async function loadDistillModelHandle(modelRef, role, loadOptions = {}) {
|
|
823
|
+
export async function loadDistillModelHandle(modelRef, role, loadOptions = {}) {
|
|
824
824
|
const normalizedRef = normalizeOptionalString(modelRef);
|
|
825
825
|
if (!normalizedRef) {
|
|
826
826
|
throw new Error(`Distill ${role} model reference is required.`);
|
|
@@ -876,7 +876,7 @@ function resolveDistillModelRefs(options = {}, trainingConfig = null) {
|
|
|
876
876
|
};
|
|
877
877
|
}
|
|
878
878
|
|
|
879
|
-
async function createDistillRuntimeContext(options = {}, trainingConfig = null) {
|
|
879
|
+
export async function createDistillRuntimeContext(options = {}, trainingConfig = null) {
|
|
880
880
|
const { teacherModelRef, studentModelRef } = resolveDistillModelRefs(options, trainingConfig);
|
|
881
881
|
if (!teacherModelRef || !studentModelRef) {
|
|
882
882
|
throw new Error('Distill stage requires teacherModelId and studentModelId.');
|
|
@@ -967,7 +967,7 @@ async function ensureTrainingGpuRuntime() {
|
|
|
967
967
|
await initDevice();
|
|
968
968
|
}
|
|
969
969
|
|
|
970
|
-
function createToyModelFixture(overrides = {}) {
|
|
970
|
+
export function createToyModelFixture(overrides = {}) {
|
|
971
971
|
const config = createTrainingConfig({
|
|
972
972
|
...overrides,
|
|
973
973
|
training: {
|
|
@@ -1790,7 +1790,7 @@ async function createDistillStudentTransformerModelFixture(overrides = {}, optio
|
|
|
1790
1790
|
};
|
|
1791
1791
|
}
|
|
1792
1792
|
|
|
1793
|
-
async function createDistillStudentRuntimeModelFixture(overrides = {}, options = {}) {
|
|
1793
|
+
export async function createDistillStudentRuntimeModelFixture(overrides = {}, options = {}) {
|
|
1794
1794
|
const distillRuntime = options.distillRuntime && typeof options.distillRuntime === 'object'
|
|
1795
1795
|
? options.distillRuntime
|
|
1796
1796
|
: null;
|
|
@@ -2085,7 +2085,7 @@ function buildUlTrainingOverrides(options = {}) {
|
|
|
2085
2085
|
};
|
|
2086
2086
|
}
|
|
2087
2087
|
|
|
2088
|
-
function buildDistillTrainingOverrides(options = {}) {
|
|
2088
|
+
export function buildDistillTrainingOverrides(options = {}) {
|
|
2089
2089
|
const trainingConfig = normalizeTrainingConfigOverride(options.trainingConfig);
|
|
2090
2090
|
const explicitStage = normalizeTrainingStage(options.trainingStage || trainingConfig?.distill?.stage);
|
|
2091
2091
|
const distillEnabled = isDistillStage(explicitStage) || trainingConfig?.distill?.enabled === true;
|