@simulatte/doppler 0.1.5 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +23 -8
- package/package.json +7 -4
- package/src/config/kernels/kernel-ref-digests.js +39 -39
- package/src/config/kernels/registry.json +42 -2
- package/src/config/loader.js +31 -2
- package/src/config/merge.js +18 -0
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/required-inference-fields-contract-check.js +6 -0
- package/src/config/schema/inference-defaults.schema.js +3 -0
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.js +3 -0
- package/src/converter/rope-config.js +42 -0
- package/src/gpu/device.js +58 -0
- package/src/gpu/kernels/attention.js +98 -0
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/conv2d.js +1 -1
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/depthwise_conv2d.js +2 -1
- package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +2 -1
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/matmul.js +25 -0
- package/src/gpu/kernels/pixel_shuffle.js +1 -1
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.js +15 -2
- package/src/gpu/kernels/relu.wgsl +2 -1
- package/src/gpu/kernels/relu_f16.wgsl +2 -1
- package/src/gpu/kernels/repeat_channels.js +1 -1
- package/src/gpu/kernels/repeat_channels.wgsl +4 -5
- package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
- package/src/gpu/kernels/residual.js +44 -8
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +58 -6
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +11 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sana_linear_attention.js +1 -2
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +32 -14
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/transpose.js +15 -2
- package/src/gpu/kernels/transpose.wgsl +5 -6
- package/src/gpu/kernels/upsample2d.js +2 -1
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +16 -1
- package/src/inference/browser-harness.js +47 -1
- package/src/inference/pipelines/diffusion/pipeline.js +15 -6
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
- package/src/inference/pipelines/text/attention/record.js +11 -2
- package/src/inference/pipelines/text/attention/run.js +11 -2
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +4 -0
- package/src/inference/pipelines/text/config.js +68 -1
- package/src/inference/pipelines/text/execution-plan.js +23 -31
- package/src/inference/pipelines/text/execution-v0.js +29 -2
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +56 -9
- package/src/inference/pipelines/text/layer.js +11 -0
- package/src/inference/pipelines/text.js +4 -0
- package/src/inference/tokenizers/bundled.js +156 -33
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +142 -3
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +58 -3
- package/src/tooling/node-command-runner.js +15 -0
- package/src/tooling/node-webgpu.js +9 -87
- package/src/training/checkpoint-watch.d.ts +7 -0
- package/src/training/checkpoint-watch.js +106 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +12 -2
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +57 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +796 -0
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +453 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +29 -4
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +9 -9
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +539 -0
- package/src/version.js +1 -1
- package/tools/doppler-cli.js +137 -40
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import type { LoadedTrainingWorkload } from '../workloads.js';
|
|
2
|
+
|
|
3
|
+
/**
 * Evaluates an already-built student model on the workload's translation eval
 * datasets: all of them, or only the one whose id equals `evalDatasetId`.
 *
 * Resolves to one report object per evaluated dataset. When `layout` is given,
 * each report is also written to disk and its path is recorded as `reportPath`.
 */
export declare function evaluateDistillationModel(options: {
  loadedWorkload: LoadedTrainingWorkload;
  layout?: Record<string, string> | null;
  stageId: string;
  checkpointId: string;
  checkpointStep: number | null;
  checkpointPath?: string | null;
  distillRuntime: Record<string, unknown>;
  model: Record<string, unknown>;
  evalDatasetId?: string | null;
  configHash?: string | null;
  parentArtifacts?: Record<string, unknown>[];
}): Promise<Array<Record<string, unknown>>>;
|
|
16
|
+
|
|
17
|
+
/**
 * Loads the checkpoint at `checkpointPath` and restores it into a freshly created
 * student model fixture. The fixture is built for the stage-plan entry whose id
 * matches `stageId`, or for the first entry when none matches. The restored model
 * is then evaluated with {@link evaluateDistillationModel}.
 */
export declare function evaluateDistillationCheckpoint(options: {
  loadedWorkload: LoadedTrainingWorkload;
  checkpointPath: string;
  checkpointId?: string | null;
  checkpointStep?: number | null;
  stageId?: string | null;
  layout?: Record<string, string> | null;
  datasetPath?: string | null;
  stageAArtifact?: string | null;
  stageAArtifactHash?: string | null;
  evalDatasetId?: string | null;
  parentArtifacts?: Record<string, unknown>[];
}): Promise<Array<Record<string, unknown>>>;
|
|
30
|
+
|
|
31
|
+
/**
 * Reads a checkpoint marker file and parses its contents as JSON.
 *
 * @param markerPath - path to the marker file; relative paths are resolved
 *   against the current working directory.
 * @returns the absolute marker path together with the parsed marker object.
 */
export declare function readDistillCheckpointMarker(
  markerPath: string,
): Promise<{ absolutePath: string; marker: Record<string, unknown> }>;
|
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
import { readFile } from 'node:fs/promises';
|
|
2
|
+
import { dirname, join, resolve } from 'node:path';
|
|
3
|
+
|
|
4
|
+
import { loadBackwardRegistry } from '../../config/backward-registry-loader.js';
|
|
5
|
+
import { f16ToF32Array } from '../../inference/kv-cache/types.js';
|
|
6
|
+
import { readBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
7
|
+
import { AutogradTape } from '../autograd.js';
|
|
8
|
+
import { loadCheckpoint } from '../checkpoint.js';
|
|
9
|
+
import { computeEvalMetrics } from '../operator-eval.js';
|
|
10
|
+
import {
|
|
11
|
+
buildDistillPrompt,
|
|
12
|
+
createDistillRuntimeContext,
|
|
13
|
+
createDistillStudentRuntimeModelFixture,
|
|
14
|
+
resolveDistillDataScope,
|
|
15
|
+
} from '../suite.js';
|
|
16
|
+
import { restoreTrainingCheckpointState } from '../runner.js';
|
|
17
|
+
import { loadCanonicalTranslationDataset } from './dataset.js';
|
|
18
|
+
import { buildDistillArtifactBase, writeDistillEvalReport } from './artifacts.js';
|
|
19
|
+
import { buildDistillationTrainingConfigFromWorkload, resolveInternalDistillStage } from './runtime.js';
|
|
20
|
+
|
|
21
|
+
/**
 * Interprets raw buffer contents as a Float32Array.
 *
 * @param raw - raw bytes read back from a GPU buffer (anything the typed-array
 *   constructors accept).
 * @param dtype - element type stored in `raw`. 'f16' data is widened to f32;
 *   every other value is treated as f32.
 * @returns a Float32Array view/copy of the data.
 */
function toFloat32Array(raw, dtype = 'f32') {
  return dtype === 'f16'
    ? f16ToF32Array(new Uint16Array(raw))
    : new Float32Array(raw);
}
|
|
27
|
+
|
|
28
|
+
/**
 * Returns the workload's eval dataset entries, optionally narrowed to one id.
 *
 * @param workload - workload object; a non-array `evalDatasets` is treated as empty.
 * @param requestedEvalDatasetId - when truthy, only entries whose `id` matches are kept.
 * @returns the (possibly filtered) list of eval dataset entries. With no filter,
 *   the workload's own array is returned as-is.
 */
function resolveEvalDatasets(workload, requestedEvalDatasetId = null) {
  const candidates = Array.isArray(workload.evalDatasets) ? workload.evalDatasets : [];
  return requestedEvalDatasetId
    ? candidates.filter((dataset) => dataset.id === requestedEvalDatasetId)
    : candidates;
}
|
|
35
|
+
|
|
36
|
+
/**
 * Index of the largest finite value in `values`.
 *
 * Non-finite entries (NaN, ±Infinity, non-numbers) are ignored. Ties resolve to
 * the earliest index. Returns 0 when `values` is empty or has no finite entry.
 *
 * @param values - array or typed array of scores.
 * @returns the winning index.
 */
function argmax(values) {
  let winner = 0;
  let winningScore = Number.NEGATIVE_INFINITY;
  for (let i = 0; i < values.length; i++) {
    const candidate = values[i];
    if (Number.isFinite(candidate) && candidate > winningScore) {
      winningScore = candidate;
      winner = i;
    }
  }
  return winner;
}
|
|
48
|
+
|
|
49
|
+
/**
 * Returns a tensor-like object's `buffer` to the buffer pool at most once.
 *
 * @param value - object that may carry a `buffer` property; anything else is ignored.
 * @param released - set of buffers already released in this sweep. It is updated
 *   so repeated references to the same buffer are released only once.
 */
function releaseTensorLike(value, released) {
  const buffer = value && typeof value === 'object' ? value.buffer : null;
  if (buffer && !released.has(buffer)) {
    released.add(buffer);
    releaseBuffer(buffer);
  }
}
|
|
56
|
+
|
|
57
|
+
/**
 * Releases the GPU buffers produced by the operations recorded on `tape`.
 *
 * A record's `output` may be a single tensor-like object or an array of them.
 * Each buffer is released at most once, and any buffer in `protectedBuffers`
 * is left alone.
 *
 * @param tape - autograd tape; ignored unless it has an array `records`.
 * @param protectedBuffers - buffers that must not be released.
 */
function disposeTapeOutputs(tape, protectedBuffers = new Set()) {
  if (!tape || !Array.isArray(tape.records)) return;
  const released = new Set();
  const isReleasable = (tensor) => Boolean(tensor?.buffer) && !protectedBuffers.has(tensor.buffer);
  for (const record of tape.records) {
    const output = record?.output;
    if (!output || typeof output !== 'object') continue;
    if (isReleasable(output)) {
      releaseTensorLike(output, released);
    } else if (Array.isArray(output)) {
      output
        .filter(isReleasable)
        .forEach((entry) => releaseTensorLike(entry, released));
    }
  }
}
|
|
76
|
+
|
|
77
|
+
/**
 * Gathers the buffers of every trainable parameter tensor exposed by the model.
 *
 * @param model - object that may expose `paramGroups()`, returning a map of
 *   group name to an array of tensors. Non-array groups and tensors without a
 *   `buffer` are skipped.
 * @returns a Set of parameter buffers, which must never be freed during decoding.
 */
function collectProtectedBuffers(model) {
  const groups = typeof model?.paramGroups === 'function' ? model.paramGroups() : {};
  const tensors = Object.values(groups || {})
    .flatMap((params) => (Array.isArray(params) ? params : []));
  return new Set(
    tensors
      .filter((tensor) => tensor?.buffer)
      .map((tensor) => tensor.buffer)
  );
}
|
|
91
|
+
|
|
92
|
+
/**
 * Reads a logits tensor back from the GPU as a Float32Array.
 *
 * @param tensor - tensor-like object with a `buffer` and an optional `dtype`
 *   (defaults to 'f32' when absent).
 * @returns the tensor contents widened to f32.
 */
async function readLogitsTensor(tensor) {
  const bytes = await readBuffer(tensor.buffer);
  const dtype = tensor.dtype || 'f32';
  return toFloat32Array(bytes, dtype);
}
|
|
96
|
+
|
|
97
|
+
/**
 * Greedily decodes a continuation of `prompt` with the student model and
 * returns the decoded text of the generated tokens.
 *
 * Each step calls `model.forwardDistill` on the original prompt plus the text
 * decoded so far, and takes the argmax of the returned logits as the next token.
 * Decoding stops after `decodePolicy.maxTokens` tokens. It also stops when the
 * tokenizer's EOS token is produced, unless `decodePolicy.stopOnEos === false`;
 * the EOS token is not included in the output.
 *
 * @param model - student model exposing `forwardDistill`, and optionally
 *   `paramGroups` and `cleanupDistillStep`.
 * @param tokenizer - tokenizer exposing `decode`, and optionally `getSpecialTokens`.
 * @param prompt - fully formatted prompt text.
 * @param decodePolicy - `{ maxTokens, stopOnEos }` from the eval dataset entry.
 * @returns the decoded hypothesis text.
 * @throws Error when `decodePolicy.maxTokens` is not a positive integer.
 */
async function greedyDecodeFixture(model, tokenizer, prompt, decodePolicy = {}) {
  const maxTokens = Number.isInteger(decodePolicy?.maxTokens) && decodePolicy.maxTokens > 0
    ? decodePolicy.maxTokens
    : null;
  if (!maxTokens) {
    throw new Error('Translation eval requires evalDatasets[].decodePolicy.maxTokens in the workload pack.');
  }
  const stopOnEos = decodePolicy?.stopOnEos !== false;
  const eosToken = tokenizer?.getSpecialTokens?.()?.eos ?? null;
  // Model parameter buffers must survive every per-step cleanup below.
  const protectedBuffers = collectProtectedBuffers(model);
  const generated = [];
  let currentPrompt = prompt;
  for (let step = 0; step < maxTokens; step += 1) {
    const tape = new AutogradTape(loadBackwardRegistry());
    let logits = null;
    try {
      const result = await model.forwardDistill(
        {
          distill: {
            prompts: [currentPrompt],
          },
        },
        tape,
        { phase: 'anchor' }
      );
      logits = result?.logits || result;
      const values = await readLogitsTensor(logits);
      // NOTE(review): argmax runs over the whole logits buffer, which assumes
      // forwardDistill returns logits for the final position only — confirm.
      const tokenId = argmax(values);
      if (stopOnEos && eosToken != null && tokenId === eosToken) {
        break;
      }
      generated.push(tokenId);
      currentPrompt = `${prompt}${tokenizer.decode(generated, false, false)}`;
    } finally {
      // Buffers the tape sweep must skip: the model parameters, plus the logits
      // buffer once it has been released here. Without the latter, a logits
      // tensor that is also recorded as a tape output would be returned to the
      // pool twice.
      const skipBuffers = new Set(protectedBuffers);
      if (logits?.buffer && !protectedBuffers.has(logits.buffer)) {
        releaseBuffer(logits.buffer);
        skipBuffers.add(logits.buffer);
      }
      model.cleanupDistillStep?.();
      disposeTapeOutputs(tape, skipBuffers);
    }
  }
  return tokenizer.decode(generated, true, true);
}
|
|
140
|
+
|
|
141
|
+
/**
 * Flattens a metrics result into top-level report fields.
 *
 * Each score is taken from `metrics[name].score`. A missing metric, or a
 * missing/nullish score, becomes `null`. A falsy `primaryMetric` also becomes `null`.
 *
 * @param metrics - metrics object (may be null/undefined).
 * @returns `{ bleu, chrf, exact_match, accuracy, primaryMetric, primaryScore }`.
 */
function flattenMetricSummary(metrics) {
  const scoreOf = (name) => metrics?.[name]?.score ?? null;
  return {
    bleu: scoreOf('bleu'),
    chrf: scoreOf('chrf'),
    exact_match: scoreOf('exactMatch'),
    accuracy: scoreOf('accuracy'),
    primaryMetric: metrics?.primaryMetric || null,
    primaryScore: metrics?.primaryScore ?? null,
  };
}
|
|
151
|
+
|
|
152
|
+
/**
 * Evaluates an in-memory student model on the workload's translation eval datasets.
 *
 * For each resolved eval dataset, every row is greedily decoded with the student
 * tokenizer and scored against `row.target_pos`. One report per dataset is built,
 * holding the full metrics, flattened scores and up to five sample rows. Each
 * report is written to disk when `options.layout` is set.
 *
 * @param options - see the accompanying .d.ts declaration for the full option shape.
 * @returns one report object per evaluated dataset. `reportPath` is null when
 *   no layout was given.
 * @throws Error when the student graph mode is not "transformer_full", when no
 *   eval dataset resolves, or when a dataset's `evalKind` is not "translation".
 */
export async function evaluateDistillationModel(options) {
  const {
    loadedWorkload,
    stageId,
    checkpointId,
    checkpointStep,
    distillRuntime,
    model,
  } = options;
  const { workload } = loadedWorkload;
  if (distillRuntime.studentGraphMode !== 'transformer_full') {
    throw new Error(
      `Distillation eval requires studentGraphMode="transformer_full"; got "${distillRuntime.studentGraphMode}".`
    );
  }
  const evalDatasets = resolveEvalDatasets(workload, options.evalDatasetId || null);
  if (!evalDatasets.length) {
    throw new Error(`No eval datasets resolved for workload "${workload.id}".`);
  }

  // Number of decoded rows copied verbatim into each report for inspection.
  const sampleLimit = 5;
  const reports = [];
  for (const evalDataset of evalDatasets) {
    if (evalDataset.evalKind !== 'translation') {
      throw new Error(`Distillation eval currently supports translation eval only, got "${evalDataset.evalKind}".`);
    }
    const pipeline = workload.pipeline;
    // Per-dataset language filters override the workload-wide pipeline settings.
    const dataset = await loadCanonicalTranslationDataset(evalDataset.datasetPath, {
      strictPairContract: pipeline.strictPairContract === true,
      sourceLangs: evalDataset.sourceLangs || pipeline.sourceLangs,
      targetLangs: evalDataset.targetLangs || pipeline.targetLangs,
      pairAllowlist: evalDataset.pairAllowlist || pipeline.pairAllowlist,
    });

    const hypotheses = [];
    const references = [];
    const samples = [];
    for (const row of dataset.rows) {
      let direction = row.pair;
      if (!direction) {
        direction = row.src_lang && row.tgt_lang
          ? `${row.src_lang}->${row.tgt_lang}`
          : 'unknown';
      }
      const prompt = buildDistillPrompt({ direction, source: row.source });
      const hypothesis = await greedyDecodeFixture(
        model,
        distillRuntime.studentPipeline.tokenizer,
        prompt,
        evalDataset.decodePolicy || {}
      );
      hypotheses.push(hypothesis);
      references.push(row.target_pos);
      if (samples.length < sampleLimit) {
        samples.push({
          row_id: row.row_id,
          source: row.source,
          reference: row.target_pos,
          hypothesis,
        });
      }
    }

    const metrics = computeEvalMetrics('translation', hypotheses, references, {});
    const summary = flattenMetricSummary(metrics);
    const artifactBase = buildDistillArtifactBase(loadedWorkload, {
      prefix: 'dst_eval',
      artifactType: 'training_eval_report',
      datasetPath: dataset.absolutePath,
      datasetHash: dataset.canonicalHash,
      stage: stageId,
      checkpointStep,
      parentArtifacts: options.parentArtifacts || [],
      configHash: options.configHash || workload.configHash,
    });
    const reportPayload = {
      ...artifactBase,
      checkpointId,
      checkpointPath: options.checkpointPath || null,
      evalDatasetId: evalDataset.id,
      evalKind: evalDataset.evalKind,
      metrics,
      ...summary,
      rowCount: dataset.rowCount,
      sampleRows: samples,
    };
    const written = options.layout
      ? await writeDistillEvalReport(options.layout, reportPayload)
      : null;
    reports.push({ ...reportPayload, reportPath: written?.path || null });
  }
  return reports;
}
|
|
241
|
+
|
|
242
|
+
/**
 * Loads a saved checkpoint into a fresh student model and evaluates it.
 *
 * Steps:
 * 1. Pick the stage-plan entry whose id matches `options.stageId`, or the first
 *    entry when none matches.
 * 2. Build the stage's training config and a distillation runtime.
 * 3. Create a student model fixture and restore the checkpoint state into it.
 * 4. Delegate to `evaluateDistillationModel`.
 *
 * The fixture and runtime are always cleaned up, but only after evaluation has
 * fully settled.
 *
 * @param options - see the accompanying .d.ts declaration for the full option shape.
 * @returns the per-dataset eval reports.
 * @throws Error when the workload has no stage entry, `checkpointPath` is
 *   missing, or no checkpoint exists at that path.
 */
export async function evaluateDistillationCheckpoint(options) {
  const loadedWorkload = options.loadedWorkload;
  const workload = loadedWorkload.workload;
  const stagePlan = Array.isArray(workload.pipeline?.stagePlan) ? workload.pipeline.stagePlan : [];
  const requestedStageId = String(options.stageId || '').trim();
  const stageEntry = stagePlan.find((entry) => entry?.id === requestedStageId)
    || stagePlan[0];
  if (!stageEntry) {
    throw new Error(`No stage entry resolved for workload "${workload.id}".`);
  }
  const internalStage = resolveInternalDistillStage(stageEntry);
  if (options.checkpointPath == null || String(options.checkpointPath).trim() === '') {
    throw new Error('evaluateDistillationCheckpoint requires options.checkpointPath.');
  }
  const checkpointPath = resolve(String(options.checkpointPath));
  const checkpointRecord = await loadCheckpoint(checkpointPath);
  if (!checkpointRecord) {
    throw new Error(`Checkpoint not found: ${checkpointPath}`);
  }
  const configBundle = buildDistillationTrainingConfigFromWorkload(loadedWorkload, stageEntry, {
    datasetPath: options.datasetPath || workload.datasetPath,
    stageAArtifact: options.stageAArtifact || null,
    stageAArtifactHash: options.stageAArtifactHash || null,
    // NOTE(review): assumes checkpoints live two directory levels below the
    // run's artifact dir — confirm against the checkpoint writer.
    artifactDir: dirname(dirname(checkpointPath)),
  });
  const distillRuntime = await createDistillRuntimeContext({
    teacherModelId: workload.teacherModelId,
    studentModelId: workload.studentModelId,
    trainingStage: internalStage,
    studentGraphMode: workload.pipeline.studentGraphMode,
  }, configBundle.trainingConfig.training);
  let fixture = null;
  try {
    fixture = await createDistillStudentRuntimeModelFixture({
      training: configBundle.trainingConfig.training,
    }, {
      distillRuntime,
      studentGraphMode: workload.pipeline.studentGraphMode,
    });
    // A stub optimizer is passed: eval only needs the restored model state.
    await restoreTrainingCheckpointState(
      fixture.model,
      { getState: () => null, stepCount: 0 },
      checkpointRecord,
      fixture.config
    );
    // `return await` is required here. A bare `return promise` would run the
    // finally block immediately, tearing down the fixture and runtime while the
    // evaluation is still in flight.
    return await evaluateDistillationModel({
      loadedWorkload,
      layout: options.layout || null,
      stageId: stageEntry.id,
      checkpointId: options.checkpointId || 'checkpoint',
      // `??` so that a legitimate step 0 is preserved rather than nulled.
      checkpointStep: options.checkpointStep ?? null,
      checkpointPath,
      distillRuntime,
      model: fixture.model,
      evalDatasetId: options.evalDatasetId || null,
      configHash: configBundle.trainingConfigHash,
      parentArtifacts: options.parentArtifacts || [],
    });
  } finally {
    // Release the runtime even if fixture cleanup throws.
    try {
      await fixture?.cleanup?.();
    } finally {
      await distillRuntime.cleanup();
    }
  }
}
|
|
302
|
+
|
|
303
|
+
/**
 * Reads a checkpoint marker file and parses it as JSON.
 *
 * @param markerPath - marker file path; relative paths resolve against the
 *   current working directory.
 * @returns `{ absolutePath, marker }`, where `marker` is the parsed JSON.
 * @throws the underlying fs error if the file cannot be read, or a SyntaxError
 *   if its contents are not valid JSON.
 */
export async function readDistillCheckpointMarker(markerPath) {
  const absolutePath = resolve(String(markerPath));
  const marker = JSON.parse(await readFile(absolutePath, 'utf8'));
  return { absolutePath, marker };
}
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
export {
|
|
2
|
+
normalizeDistillationPair,
|
|
3
|
+
normalizeTranslationPairRow,
|
|
4
|
+
loadCanonicalTranslationDataset,
|
|
5
|
+
buildFrozenSubset,
|
|
6
|
+
} from './dataset.js';
|
|
7
|
+
export {
|
|
8
|
+
createDistillationRunArtifacts,
|
|
9
|
+
writeDistillStageManifest,
|
|
10
|
+
writeDistillCheckpointMetadata,
|
|
11
|
+
writeDistillCheckpointComplete,
|
|
12
|
+
writeDistillEvalReport,
|
|
13
|
+
writeDistillCompareReport,
|
|
14
|
+
writeDistillQualityGateReport,
|
|
15
|
+
buildDistillArtifactBase,
|
|
16
|
+
} from './artifacts.js';
|
|
17
|
+
export { appendDistillationScoreboardRow } from './scoreboard.js';
|
|
18
|
+
export {
|
|
19
|
+
buildDistillationTrainingConfigFromWorkload,
|
|
20
|
+
resolveInternalDistillStage,
|
|
21
|
+
} from './runtime.js';
|
|
22
|
+
export {
|
|
23
|
+
evaluateDistillationModel,
|
|
24
|
+
evaluateDistillationCheckpoint,
|
|
25
|
+
readDistillCheckpointMarker,
|
|
26
|
+
} from './eval.js';
|
|
27
|
+
export { runDistillationStage, runDistillationStageA } from './stage-a.js';
|
|
28
|
+
export { runDistillationStageB } from './stage-b.js';
|
|
29
|
+
export { watchDistillationCheckpoints } from './checkpoint-watch.js';
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
export {
|
|
2
|
+
normalizeDistillationPair,
|
|
3
|
+
normalizeTranslationPairRow,
|
|
4
|
+
loadCanonicalTranslationDataset,
|
|
5
|
+
buildFrozenSubset,
|
|
6
|
+
} from './dataset.js';
|
|
7
|
+
export {
|
|
8
|
+
createDistillationRunArtifacts,
|
|
9
|
+
writeDistillStageManifest,
|
|
10
|
+
writeDistillCheckpointMetadata,
|
|
11
|
+
writeDistillCheckpointComplete,
|
|
12
|
+
writeDistillEvalReport,
|
|
13
|
+
writeDistillCompareReport,
|
|
14
|
+
writeDistillQualityGateReport,
|
|
15
|
+
buildDistillArtifactBase,
|
|
16
|
+
} from './artifacts.js';
|
|
17
|
+
export { appendDistillationScoreboardRow } from './scoreboard.js';
|
|
18
|
+
export {
|
|
19
|
+
buildDistillationTrainingConfigFromWorkload,
|
|
20
|
+
resolveInternalDistillStage,
|
|
21
|
+
} from './runtime.js';
|
|
22
|
+
export {
|
|
23
|
+
evaluateDistillationModel,
|
|
24
|
+
evaluateDistillationCheckpoint,
|
|
25
|
+
readDistillCheckpointMarker,
|
|
26
|
+
} from './eval.js';
|
|
27
|
+
export { runDistillationStage, runDistillationStageA } from './stage-a.js';
|
|
28
|
+
export { runDistillationStageB } from './stage-b.js';
|
|
29
|
+
export { watchDistillationCheckpoints } from './checkpoint-watch.js';
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import type { LoadedTrainingWorkload, DistillStagePlanEntry } from '../workloads.js';
|
|
2
|
+
|
|
3
|
+
/**
 * Maps a workload stage-plan entry onto the internal distillation stage used by
 * the JS runner. The mapping uses the entry's `trainingStage`/`id` and
 * `objective` labels.
 *
 * @throws Error for "sft" stages and for unrecognized labels.
 */
export declare function resolveInternalDistillStage(
  stageEntry: Record<string, unknown> | DistillStagePlanEntry,
): 'stage_a' | 'stage_b';
|
|
6
|
+
|
|
7
|
+
/**
 * Builds the internal training config for one stage of a `distill` workload.
 * Also returns a deterministic hash of the inputs that identify that config.
 *
 * @throws Error when the workload is not of kind "distill" or the stage is unsupported.
 */
export declare function buildDistillationTrainingConfigFromWorkload(
  loadedWorkload: LoadedTrainingWorkload,
  stageEntry: Record<string, unknown> | DistillStagePlanEntry,
  options?: {
    datasetPath?: string | null;
    artifactDir?: string | null;
    stageAArtifact?: string | null;
    stageAArtifactHash?: string | null;
  },
): {
  internalStage: 'stage_a' | 'stage_b';
  trainingConfig: Record<string, unknown>;
  trainingConfigHash: string;
};
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
import { createTrainingConfig } from '../../config/training-defaults.js';
|
|
2
|
+
import { sha256Hex } from '../../utils/sha256.js';
|
|
3
|
+
|
|
4
|
+
/**
 * Recursively rebuilds plain objects with their keys inserted in sorted order,
 * so that JSON serialization is independent of the original key order.
 *
 * Arrays keep their element order (elements are normalized recursively).
 * Primitives and null are returned unchanged.
 *
 * @param value - any JSON-like value.
 * @returns a structurally equal value with sorted object keys.
 */
function stableSortObject(value) {
  if (Array.isArray(value)) {
    return value.map((item) => stableSortObject(item));
  }
  if (value === null || typeof value !== 'object') {
    return value;
  }
  return Object.keys(value)
    .sort()
    .reduce((acc, key) => {
      acc[key] = stableSortObject(value[key]);
      return acc;
    }, {});
}
|
|
17
|
+
|
|
18
|
+
/**
 * Serializes `value` to JSON with object keys sorted at every depth.
 * The output therefore depends only on content, not on key insertion order,
 * which makes it safe to hash.
 *
 * @param value - any JSON-like value.
 * @returns the canonical JSON string.
 */
function stableJson(value) {
  const canonical = stableSortObject(value);
  return JSON.stringify(canonical);
}
|
|
21
|
+
|
|
22
|
+
/**
 * Canonicalizes a stage/objective label for comparison.
 *
 * The label is trimmed and lower-cased, and each run of whitespace and/or
 * hyphens is collapsed to a single underscore. Falsy input yields ''.
 *
 * @param value - raw label (any type; non-strings are stringified).
 * @returns the normalized label.
 */
function normalizeStageLabel(value) {
  const text = value ? String(value) : '';
  return text.trim().toLowerCase().replace(/[\s-]+/g, '_');
}
|
|
26
|
+
|
|
27
|
+
/**
 * Maps a workload stage-plan entry to the runner's internal stage id.
 *
 * The rules are checked in order:
 * 1. An "sft" stage or objective is rejected.
 * 2. stage_b / post_sft_distill / post_sft_triplet stages, or a "triplet"
 *    objective, map to 'stage_b'.
 * 3. stage_a / kd stages, or a "kd" / "cross_entropy" objective, map to 'stage_a'.
 *
 * Labels are compared after normalization (case, whitespace and hyphens folded).
 *
 * @param stageEntry - stage-plan entry with `trainingStage`/`id` and `objective`.
 * @returns 'stage_a' or 'stage_b'.
 * @throws Error for "sft" and for any unrecognized stage.
 */
export function resolveInternalDistillStage(stageEntry) {
  const stage = normalizeStageLabel(stageEntry?.trainingStage || stageEntry?.id || '');
  const objective = normalizeStageLabel(stageEntry?.objective || '');

  if ([stage, objective].includes('sft')) {
    throw new Error(
      'Distillation workload stage uses "sft", but the current JS distill runner only supports the KD-oriented stage_a contract. Use objective="kd" / trainingStage="stage_a" explicitly.'
    );
  }

  const stageBStages = ['stage_b', 'post_sft_distill', 'post_sft_triplet'];
  if (stageBStages.includes(stage) || objective === 'triplet') {
    return 'stage_b';
  }

  const stageAStages = ['stage_a', 'kd'];
  const stageAObjectives = ['kd', 'cross_entropy'];
  if (stageAStages.includes(stage) || stageAObjectives.includes(objective)) {
    return 'stage_a';
  }

  throw new Error(
    `Unsupported distillation stage "${stageEntry?.trainingStage || stageEntry?.id || 'unknown'}".`
  );
}
|
|
55
|
+
|
|
56
|
+
/**
 * Translates a `distill` workload stage into the internal training config.
 *
 * The distill section copies the teacher/student ids, dataset, language
 * filters, loss weights and graph mode from the workload. For 'stage_b' it
 * also adds a freeze map that freezes the encoder, prior and decoder, while
 * leaving base and lora trainable. The accompanying hash covers the workload
 * config hash, the raw stage entry, the dataset path and the stage-A artifact
 * inputs; `artifactDir` is deliberately not part of it.
 *
 * @param loadedWorkload - loaded workload wrapper (`{ workload, ... }`).
 * @param stageEntry - stage-plan entry to build the config for.
 * @param options - optional dataset path, artifact dir and stage-A artifact overrides.
 * @returns `{ internalStage, trainingConfig, trainingConfigHash }`.
 * @throws Error when the workload is not of kind "distill" or the stage is unsupported.
 */
export function buildDistillationTrainingConfigFromWorkload(loadedWorkload, stageEntry, options = {}) {
  const { workload } = loadedWorkload;
  if (workload.kind !== 'distill') {
    throw new Error('buildDistillationTrainingConfigFromWorkload requires a distill workload.');
  }
  const internalStage = resolveInternalDistillStage(stageEntry);
  const { pipeline } = workload;
  const datasetPath = options.datasetPath || workload.datasetPath;
  const stageAArtifact = options.stageAArtifact || null;
  const stageAArtifactHash = options.stageAArtifactHash || null;

  const distillTraining = {
    enabled: true,
    stage: internalStage,
    teacherModelId: workload.teacherModelId,
    studentModelId: workload.studentModelId,
    datasetId: workload.datasetId,
    datasetPath,
    sourceLangs: pipeline.sourceLangs,
    targetLangs: pipeline.targetLangs,
    pairAllowlist: pipeline.pairAllowlist,
    strictPairContract: pipeline.strictPairContract === true,
    stageAArtifact,
    stageAArtifactHash,
    artifactDir: options.artifactDir || null,
    temperature: pipeline.temperature,
    alphaKd: pipeline.alphaKd,
    alphaCe: pipeline.alphaCe,
    tripletMargin: pipeline.tripletMargin,
    studentGraphMode: pipeline.studentGraphMode,
    ...(internalStage === 'stage_b'
      ? {
        freeze: {
          encoder: true,
          prior: true,
          decoder: true,
          base: false,
          lora: false,
        },
      }
      : {}),
  };

  const { training } = workload;
  const { optimizer } = training;
  const trainingConfig = createTrainingConfig({
    training: {
      enabled: true,
      optimizer: {
        type: optimizer.type,
        lr: optimizer.lr,
        beta1: optimizer.beta1,
        beta2: optimizer.beta2,
        eps: optimizer.eps,
        weightDecay: optimizer.weightDecay,
        scheduler: optimizer.scheduler,
      },
      gradient: {
        maxNorm: training.gradientClipping.maxNorm,
      },
      precision: training.precision,
      distill: distillTraining,
    },
  });

  const trainingConfigHash = sha256Hex(stableJson({
    workloadConfigHash: workload.configHash,
    stageEntry,
    datasetPath,
    stageAArtifact,
    stageAArtifactHash,
  }));

  return { internalStage, trainingConfig, trainingConfigHash };
}
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
/**
 * Appends `row` to the scoreboard directory kept for `stageId` under
 * `layout.scoreboard`.
 *
 * @returns the `rowsPath` and `summaryPath` reported by the underlying
 *   scoreboard writer.
 */
export declare function appendDistillationScoreboardRow(
  layout: Record<string, string>,
  stageId: string,
  row: Record<string, unknown>,
  options?: {
    selectionMetric?: string | null;
    selectionGoal?: string | null;
  },
): Promise<{
  rowsPath: string;
  summaryPath: string;
}>;
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
import { join } from 'node:path';
|
|
2
|
+
|
|
3
|
+
import { appendScoreboardRow } from '../operator-scoreboard.js';
|
|
4
|
+
|
|
5
|
+
/**
 * Appends a scoreboard row for one distillation stage.
 *
 * Rows are grouped per stage under `<layout.scoreboard>/<stageId>`. A falsy
 * `stageId` falls back to the directory name "stage".
 *
 * @param layout - run layout; only `layout.scoreboard` is read here.
 * @param stageId - stage identifier used as the sub-directory name.
 * @param row - scoreboard row to append.
 * @param options - forwarded unchanged to the underlying scoreboard writer.
 * @returns whatever the underlying `appendScoreboardRow` resolves to.
 */
export async function appendDistillationScoreboardRow(layout, stageId, row, options = {}) {
  const scoreboardRoot = layout.scoreboard;
  const stageDirectory = join(scoreboardRoot, String(stageId || 'stage'));
  return appendScoreboardRow(stageDirectory, row, options);
}
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import type { LoadedTrainingWorkload, DistillStagePlanEntry } from '../workloads.js';
|
|
2
|
+
|
|
3
|
+
/**
 * Runs one distillation stage described by `stageEntry` for the given workload.
 *
 * NOTE(review): the implementation is not visible alongside this declaration;
 * the result shape below is the declared contract only.
 */
export declare function runDistillationStage(options: {
  loadedWorkload: LoadedTrainingWorkload;
  stageEntry: DistillStagePlanEntry;
  layout: Record<string, string>;
  datasetPath?: string | null;
  stageAArtifact?: string | null;
  stageAArtifactHash?: string | null;
  legacyArtifactDir?: string | null;
  timestamp?: string | Date | null;
  parentArtifacts?: Record<string, unknown>[];
}): Promise<{
  stageId: string;
  trainingStage: 'stage_a' | 'stage_b';
  metrics: Array<Record<string, unknown>>;
  checkpointArtifacts: Record<string, unknown>[];
  evalReports: Record<string, unknown>[];
  bestReport: Record<string, unknown> | null;
  stageManifestPath: string;
  legacyArtifact: Record<string, unknown> | null;
  lastCheckpoint: Record<string, unknown> | null;
}>;
|
|
24
|
+
|
|
25
|
+
/**
 * Stage-A distillation entry point. Its declared options and result shape are
 * identical to those of `runDistillationStage`.
 */
export declare function runDistillationStageA(options: {
  loadedWorkload: LoadedTrainingWorkload;
  stageEntry: DistillStagePlanEntry;
  layout: Record<string, string>;
  datasetPath?: string | null;
  stageAArtifact?: string | null;
  stageAArtifactHash?: string | null;
  legacyArtifactDir?: string | null;
  timestamp?: string | Date | null;
  parentArtifacts?: Record<string, unknown>[];
}): Promise<{
  stageId: string;
  trainingStage: 'stage_a' | 'stage_b';
  metrics: Array<Record<string, unknown>>;
  checkpointArtifacts: Record<string, unknown>[];
  evalReports: Record<string, unknown>[];
  bestReport: Record<string, unknown> | null;
  stageManifestPath: string;
  legacyArtifact: Record<string, unknown> | null;
  lastCheckpoint: Record<string, unknown> | null;
}>;
|