@simulatte/doppler 0.1.5 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +23 -8
- package/package.json +7 -4
- package/src/config/kernels/kernel-ref-digests.js +39 -39
- package/src/config/kernels/registry.json +42 -2
- package/src/config/loader.js +31 -2
- package/src/config/merge.js +18 -0
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/required-inference-fields-contract-check.js +6 -0
- package/src/config/schema/inference-defaults.schema.js +3 -0
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.js +3 -0
- package/src/converter/rope-config.js +42 -0
- package/src/gpu/device.js +58 -0
- package/src/gpu/kernels/attention.js +98 -0
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/conv2d.js +1 -1
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/depthwise_conv2d.js +2 -1
- package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +2 -1
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/matmul.js +25 -0
- package/src/gpu/kernels/pixel_shuffle.js +1 -1
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.js +15 -2
- package/src/gpu/kernels/relu.wgsl +2 -1
- package/src/gpu/kernels/relu_f16.wgsl +2 -1
- package/src/gpu/kernels/repeat_channels.js +1 -1
- package/src/gpu/kernels/repeat_channels.wgsl +4 -5
- package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
- package/src/gpu/kernels/residual.js +44 -8
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +58 -6
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +11 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sana_linear_attention.js +1 -2
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +32 -14
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/transpose.js +15 -2
- package/src/gpu/kernels/transpose.wgsl +5 -6
- package/src/gpu/kernels/upsample2d.js +2 -1
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +16 -1
- package/src/inference/browser-harness.js +47 -1
- package/src/inference/pipelines/diffusion/pipeline.js +15 -6
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
- package/src/inference/pipelines/text/attention/record.js +11 -2
- package/src/inference/pipelines/text/attention/run.js +11 -2
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +4 -0
- package/src/inference/pipelines/text/config.js +68 -1
- package/src/inference/pipelines/text/execution-plan.js +23 -31
- package/src/inference/pipelines/text/execution-v0.js +29 -2
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +56 -9
- package/src/inference/pipelines/text/layer.js +11 -0
- package/src/inference/pipelines/text.js +4 -0
- package/src/inference/tokenizers/bundled.js +156 -33
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +142 -3
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +58 -3
- package/src/tooling/node-command-runner.js +15 -0
- package/src/tooling/node-webgpu.js +9 -87
- package/src/training/checkpoint-watch.d.ts +7 -0
- package/src/training/checkpoint-watch.js +106 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +12 -2
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +57 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +796 -0
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +453 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +29 -4
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +9 -9
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +539 -0
- package/src/version.js +1 -1
- package/tools/doppler-cli.js +137 -40
|
@@ -184,12 +184,22 @@ export async function saveCheckpoint(key, payload, options = {}) {
|
|
|
184
184
|
|
|
185
185
|
if (useNodeStore) {
|
|
186
186
|
await writeNodeCheckpointRecord(nodePath, data);
|
|
187
|
-
return
|
|
187
|
+
return {
|
|
188
|
+
key,
|
|
189
|
+
path: nodePath,
|
|
190
|
+
metadata: data.metadata,
|
|
191
|
+
data,
|
|
192
|
+
};
|
|
188
193
|
}
|
|
189
194
|
|
|
190
195
|
return new Promise((resolve, reject) => {
|
|
191
196
|
const tx = browserStore.db.transaction(browserStore.storeName, 'readwrite');
|
|
192
|
-
tx.oncomplete = () => resolve(
|
|
197
|
+
tx.oncomplete = () => resolve({
|
|
198
|
+
key,
|
|
199
|
+
path: null,
|
|
200
|
+
metadata: data.metadata,
|
|
201
|
+
data,
|
|
202
|
+
});
|
|
193
203
|
tx.onerror = () => reject(tx.error);
|
|
194
204
|
const store = tx.objectStore(browserStore.storeName);
|
|
195
205
|
store.put(data, key);
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import type { LoadedTrainingWorkload } from '../workloads.js';
|
|
2
|
+
|
|
3
|
+
/**
 * Create the on-disk layout for a distillation run and write its run contract
 * and workload lock.
 *
 * When `runRoot` is provided, every layout directory is placed directly under
 * that (absolute-resolved) path; otherwise the layout comes from the shared
 * training-run layout builder, keyed by workload id and `timestamp`.
 * All layout directories are created recursively before any file is written.
 * Rejects when the loaded workload is missing or its `kind` is not "distill".
 */
export declare function createDistillationRunArtifacts(options: {
  loadedWorkload: LoadedTrainingWorkload;
  /** Explicit run directory; resolved to an absolute path when given. */
  runRoot?: string | null;
  /** Forwarded to the default layout builder when `runRoot` is not set. */
  timestamp?: string | Date | null;
}): Promise<{
  /** Absolute directory paths for each artifact category of the run. */
  layout: {
    runRoot: string;
    logs: string;
    checkpoints: string;
    eval: string;
    scoreboard: string;
    exports: string;
    compare: string;
    /** Named "quality-gate" on disk when `runRoot` is given explicitly. */
    qualityGate: string;
  };
  /** Location and digest of the written run contract. */
  runContract: { path: string; sha256: string; relativePath: string };
  /** Location and digest of the written workload lock. */
  workloadLock: { path: string; sha256: string; relativePath: string };
  /** The payload that was written as the run contract. */
  runContractPayload: Record<string, unknown>;
}>;
|
|
22
|
+
|
|
23
|
+
/**
 * Write a stage manifest as JSON to
 * `<layout.checkpoints>/<stageId>/distill_stage_manifest.json`, where the stage
 * directory is taken from `payload.stageId`, then `payload.stage`, defaulting
 * to "stage".
 */
export declare function writeDistillStageManifest(
  layout: Record<string, string>,
  payload: Record<string, unknown>
): Promise<{ path: string; sha256: string; relativePath: string }>;
|
|
27
|
+
|
|
28
|
+
/**
 * Write checkpoint metadata as JSON to
 * `<layout.checkpoints>/<stageId>/<checkpointId>/checkpoint.json`.
 */
export declare function writeDistillCheckpointMetadata(
  layout: Record<string, string>,
  stageId: string,
  checkpointId: string,
  payload: Record<string, unknown>
): Promise<{ path: string; sha256: string; relativePath: string }>;
|
|
34
|
+
|
|
35
|
+
/**
 * Write a checkpoint completion record as JSON to
 * `<layout.checkpoints>/<stageId>/<checkpointId>/checkpoint.complete.json`.
 * NOTE(review): presumably this is the "finalized" marker that checkpoint
 * watchers pick up — confirm against the watcher's marker discovery.
 */
export declare function writeDistillCheckpointComplete(
  layout: Record<string, string>,
  stageId: string,
  checkpointId: string,
  payload: Record<string, unknown>
): Promise<{ path: string; sha256: string; relativePath: string }>;
|
|
41
|
+
|
|
42
|
+
/**
 * Write an eval report as JSON to
 * `<layout.eval>/<stage>/<checkpointId>__<evalDatasetId>.json`, using
 * `payload.stage`, `payload.checkpointId` and `payload.evalDatasetId`
 * (defaults "stage", "checkpoint" and "eval" respectively).
 */
export declare function writeDistillEvalReport(
  layout: Record<string, string>,
  payload: Record<string, unknown>
): Promise<{ path: string; sha256: string; relativePath: string }>;
|
|
46
|
+
|
|
47
|
+
/** Write the compare report as JSON to `<layout.compare>/compare.json`. */
export declare function writeDistillCompareReport(
  layout: Record<string, string>,
  payload: Record<string, unknown>
): Promise<{ path: string; sha256: string; relativePath: string }>;
|
|
51
|
+
|
|
52
|
+
/** Write the quality-gate report as JSON to `<layout.qualityGate>/quality-gate.json`. */
export declare function writeDistillQualityGateReport(
  layout: Record<string, string>,
  payload: Record<string, unknown>
): Promise<{ path: string; sha256: string; relativePath: string }>;
|
|
56
|
+
|
|
57
|
+
/**
 * Build the common artifact envelope for a distillation report, stamped with
 * an `artifactHash` computed over the envelope before the hash is added.
 *
 * The report id is `<prefix>_<workloadId>_<stage>_<step>`, where `step` is the
 * integer `checkpointStep` zero-padded to 6 digits, or "final" when no integer
 * step is given.
 */
export declare function buildDistillArtifactBase(
  loadedWorkload: LoadedTrainingWorkload,
  options: {
    /** Report-id prefix; defaults to "dst". */
    prefix?: string;
    artifactType: string;
    /** Defaults to the workload's dataset path. */
    datasetPath?: string | null;
    datasetHash?: string | null;
    /** Stage label; "stage" in the report id and null in the payload when omitted. */
    stage?: string | null;
    /** Non-integer values are treated as "no step" (report id suffix "final"). */
    checkpointStep?: number | null;
    /** Defaults to an empty list. */
    parentArtifacts?: Array<Record<string, unknown>>;
    /** Defaults to "node". */
    runtime?: string;
    /** Defaults to "node". */
    surface?: string;
    /** Defaults to the workload's config hash. */
    configHash?: string | null;
  }
): Record<string, unknown>;
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
import { mkdir } from 'node:fs/promises';
|
|
2
|
+
import { join, resolve } from 'node:path';
|
|
3
|
+
|
|
4
|
+
import {
|
|
5
|
+
buildArtifactBase,
|
|
6
|
+
createTrainingRunLayout,
|
|
7
|
+
hashArtifactPayload,
|
|
8
|
+
writeJsonArtifact,
|
|
9
|
+
writeRunContract,
|
|
10
|
+
writeWorkloadLock,
|
|
11
|
+
} from '../operator-artifacts.js';
|
|
12
|
+
|
|
13
|
+
/**
 * Compose a report identifier of the form `<prefix>_<workloadId>_<suffix>`.
 * Each part is stringified the same way a template literal would.
 */
function toReportId(prefix, workloadId, suffix) {
  return ''.concat(prefix, '_', workloadId, '_', suffix);
}
|
|
16
|
+
|
|
17
|
+
/**
 * Build the directory layout for a run rooted at an explicit path.
 * The root is resolved once and every category directory hangs directly off it.
 */
function resolveExplicitRunLayout(runRoot) {
  const root = resolve(String(runRoot));
  return {
    runRoot: root,
    logs: join(root, 'logs'),
    checkpoints: join(root, 'checkpoints'),
    eval: join(root, 'eval'),
    scoreboard: join(root, 'scoreboard'),
    exports: join(root, 'exports'),
    compare: join(root, 'compare'),
    qualityGate: join(root, 'quality-gate'),
  };
}

/**
 * Create the on-disk layout for a distillation run and write its run contract
 * and workload lock.
 *
 * @param {object} options
 * @param {object} options.loadedWorkload - Loaded workload pack; `workload.kind`
 *   must be "distill".
 * @param {string|null} [options.runRoot] - Explicit run directory. When omitted,
 *   the layout comes from `createTrainingRunLayout` (workload id + timestamp).
 * @param {string|Date|null} [options.timestamp] - Forwarded to
 *   `createTrainingRunLayout` when no `runRoot` is given.
 * @returns {Promise<{layout: object, runContract: object, workloadLock: object,
 *   runContractPayload: object}>}
 * @throws {Error} When the workload is missing or is not a distill workload.
 */
export async function createDistillationRunArtifacts(options) {
  const loadedWorkload = options.loadedWorkload;
  const workload = loadedWorkload?.workload;
  if (!workload || workload.kind !== 'distill') {
    throw new Error('createDistillationRunArtifacts requires a distill workload pack.');
  }
  const layout = options.runRoot
    ? resolveExplicitRunLayout(options.runRoot)
    : await createTrainingRunLayout({
      kind: 'distill',
      workloadId: workload.id,
      timestamp: options.timestamp || null,
    });
  // Every layout entry is a directory; create them all before writing files.
  await Promise.all(
    Object.values(layout).map((dirPath) => mkdir(dirPath, { recursive: true }))
  );
  const runContractPayload = {
    artifactType: 'training_run_contract',
    schemaVersion: 1,
    generatedAt: new Date().toISOString(),
    workloadId: workload.id,
    workloadPath: loadedWorkload.absolutePath,
    workloadSha256: loadedWorkload.workloadSha256,
    configHash: workload.configHash,
    claimBoundary: workload.claimBoundary,
    kind: workload.kind,
    stagePlan: workload.pipeline.stagePlan,
    datasetPath: workload.datasetPath,
    evalDatasets: workload.evalDatasets,
    surfaceSupport: workload.surfaceSupport,
  };
  const runContract = await writeRunContract(layout, runContractPayload);
  const workloadLock = await writeWorkloadLock(layout, loadedWorkload);
  return {
    layout,
    runContract,
    workloadLock,
    runContractPayload,
  };
}
|
|
66
|
+
|
|
67
|
+
/**
 * Write a stage manifest to
 * `<layout.checkpoints>/<stageId>/distill_stage_manifest.json`.
 *
 * The stage directory is the first non-blank (trimmed) value of
 * `payload.stageId` / `payload.stage`, falling back to "stage".
 *
 * @param {Record<string, string>} layout - Run layout; `checkpoints` is used.
 * @param {object} payload - Manifest contents, written verbatim.
 * @returns {Promise<*>} Whatever `writeJsonArtifact` resolves to.
 */
export async function writeDistillStageManifest(layout, payload) {
  // Trim before applying the fallback: a whitespace-only id would otherwise
  // become an empty path segment and drop the manifest into the checkpoints root.
  const stageId = [payload?.stageId, payload?.stage]
    .map((candidate) => (candidate ? String(candidate).trim() : ''))
    .find((candidate) => candidate.length > 0) || 'stage';
  const filePath = join(layout.checkpoints, stageId, 'distill_stage_manifest.json');
  return writeJsonArtifact(filePath, payload);
}
|
|
72
|
+
|
|
73
|
+
/**
 * Persist checkpoint metadata at
 * `<layout.checkpoints>/<stageId>/<checkpointId>/checkpoint.json`.
 *
 * @returns {Promise<*>} Whatever `writeJsonArtifact` resolves to.
 */
export async function writeDistillCheckpointMetadata(layout, stageId, checkpointId, payload) {
  const checkpointDir = join(layout.checkpoints, stageId, checkpointId);
  return writeJsonArtifact(join(checkpointDir, 'checkpoint.json'), payload);
}
|
|
77
|
+
|
|
78
|
+
/**
 * Persist the checkpoint completion record at
 * `<layout.checkpoints>/<stageId>/<checkpointId>/checkpoint.complete.json`.
 *
 * @returns {Promise<*>} Whatever `writeJsonArtifact` resolves to.
 */
export async function writeDistillCheckpointComplete(layout, stageId, checkpointId, payload) {
  const checkpointDir = join(layout.checkpoints, stageId, checkpointId);
  return writeJsonArtifact(join(checkpointDir, 'checkpoint.complete.json'), payload);
}
|
|
82
|
+
|
|
83
|
+
/**
 * Write an eval report to
 * `<layout.eval>/<stage>/<checkpointId>__<evalDatasetId>.json`.
 *
 * Segments come from `payload.stage`, `payload.checkpointId` and
 * `payload.evalDatasetId`; each is trimmed and falls back to "stage",
 * "checkpoint" and "eval" respectively when missing or blank.
 *
 * @param {Record<string, string>} layout - Run layout; `eval` is used.
 * @param {object} payload - Report contents, written verbatim.
 * @returns {Promise<*>} Whatever `writeJsonArtifact` resolves to.
 */
export async function writeDistillEvalReport(layout, payload) {
  // Trim before falling back so whitespace-only values still produce a usable
  // path segment instead of an empty one (e.g. "__eval.json").
  const segment = (value, fallback) => String(value || '').trim() || fallback;
  const stageId = segment(payload?.stage, 'stage');
  const checkpointId = segment(payload?.checkpointId, 'checkpoint');
  const evalDatasetId = segment(payload?.evalDatasetId, 'eval');
  const filePath = join(layout.eval, stageId, `${checkpointId}__${evalDatasetId}.json`);
  return writeJsonArtifact(filePath, payload);
}
|
|
90
|
+
|
|
91
|
+
/**
 * Persist the compare report at `<layout.compare>/compare.json`.
 *
 * @returns {Promise<*>} Whatever `writeJsonArtifact` resolves to.
 */
export async function writeDistillCompareReport(layout, payload) {
  const reportPath = join(layout.compare, 'compare.json');
  return writeJsonArtifact(reportPath, payload);
}
|
|
94
|
+
|
|
95
|
+
/**
 * Persist the quality-gate report at `<layout.qualityGate>/quality-gate.json`.
 *
 * @returns {Promise<*>} Whatever `writeJsonArtifact` resolves to.
 */
export async function writeDistillQualityGateReport(layout, payload) {
  const reportPath = join(layout.qualityGate, 'quality-gate.json');
  return writeJsonArtifact(reportPath, payload);
}
|
|
98
|
+
|
|
99
|
+
/**
 * Build the shared artifact envelope for a distillation report and stamp it
 * with `artifactHash`, which is computed over the envelope before the hash
 * field is added.
 *
 * Report id: `<prefix>_<workloadId>_<stage>_<step>`; `prefix` defaults to
 * "dst", `stage` to "stage", and `step` is the integer checkpoint step padded
 * to 6 digits, or "final" when `checkpointStep` is not an integer.
 *
 * @param {object} loadedWorkload - Loaded workload (`workload`, `absolutePath`, `workloadSha256`).
 * @param {object} options - See the type declaration for the accepted fields.
 * @returns {object} Envelope plus `artifactHash`.
 */
export function buildDistillArtifactBase(loadedWorkload, options) {
  const { workload, absolutePath, workloadSha256 } = loadedWorkload;
  let step = null;
  if (Number.isInteger(options.checkpointStep)) {
    step = options.checkpointStep;
  }
  const stepLabel = step === null ? 'final' : String(step).padStart(6, '0');
  const stageLabel = options.stage || 'stage';
  const reportId = toReportId(options.prefix || 'dst', workload.id, `${stageLabel}_${stepLabel}`);

  const envelope = buildArtifactBase({
    artifactType: options.artifactType,
    reportId,
    workload,
    workloadPath: absolutePath,
    workloadSha256,
    datasetPath: options.datasetPath || workload.datasetPath,
    datasetHash: options.datasetHash || null,
    baseModelId: workload.baseModelId,
    teacherModelId: workload.teacherModelId,
    studentModelId: workload.studentModelId,
    stage: options.stage || null,
    checkpointStep: step,
    parentArtifacts: options.parentArtifacts || [],
    runtime: options.runtime || 'node',
    surface: options.surface || 'node',
    claimBoundary: workload.claimBoundary,
    configHash: options.configHash || workload.configHash,
  });
  const artifactHash = hashArtifactPayload(envelope);
  return { ...envelope, artifactHash };
}
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
import type { LoadedTrainingWorkload } from '../workloads.js';
|
|
2
|
+
|
|
3
|
+
/**
 * Watch a distillation run's checkpoints directory and, for each finalized
 * checkpoint marker, evaluate the checkpoint and append one scoreboard row per
 * eval report.
 */
export declare function watchDistillationCheckpoints(options: {
  loadedWorkload: LoadedTrainingWorkload;
  /** Run layout; `checkpoints` and `scoreboard` supply the defaults below. */
  layout: Record<string, string>;
  /** Defaults to `layout.checkpoints`. */
  checkpointsDir?: string | null;
  /** Defaults to `<layout.scoreboard>/checkpoint-watch-manifest.json`. */
  manifestPath?: string | null;
  /** Defaults to 2000 when unset or 0. */
  pollIntervalMs?: number | null;
  /** Only an explicit `true` enables stop-when-idle. */
  stopWhenIdle?: boolean;
}): Promise<{ ok: true; processedCount: number; manifestPath: string }>;
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import { join, resolve } from 'node:path';
|
|
2
|
+
|
|
3
|
+
import { watchFinalizedCheckpoints } from '../checkpoint-watch.js';
|
|
4
|
+
import { appendDistillationScoreboardRow } from './scoreboard.js';
|
|
5
|
+
import { evaluateDistillationCheckpoint, readDistillCheckpointMarker } from './eval.js';
|
|
6
|
+
|
|
7
|
+
/**
 * Watch a distillation run's checkpoint directory and score each finalized
 * checkpoint as it appears.
 *
 * For every marker reported by `watchFinalizedCheckpoints`, the marker is read,
 * the checkpoint is evaluated via `evaluateDistillationCheckpoint`, and one
 * scoreboard row per returned report is appended (sequentially) under the
 * marker's stage ("stage" when the marker has none).
 *
 * Defaults: `checkpointsDir` → `layout.checkpoints`; `manifestPath` →
 * `<layout.scoreboard>/checkpoint-watch-manifest.json`; `pollIntervalMs` →
 * 2000 when not truthy (e.g. unset or 0); `stopWhenIdle` only for an explicit
 * `true`. Rows always select on the report's primary metric with goal "max".
 */
export async function watchDistillationCheckpoints(options) {
  const { layout, loadedWorkload } = options;

  // One scoreboard row describing an eval report for the given stage.
  const buildScoreboardRow = (stage, report) => ({
    artifactType: 'training_scoreboard',
    schemaVersion: 1,
    generatedAt: new Date().toISOString(),
    stage,
    checkpointId: report.checkpointId,
    checkpointStep: report.checkpointStep,
    evalDatasetId: report.evalDatasetId,
    selectionMetric: report.primaryMetric,
    selectionGoal: 'max',
    primaryMetric: report.primaryMetric,
    primaryScore: report.primaryScore,
    bleu: report.bleu,
    chrf: report.chrf,
    reportPath: report.reportPath || null,
    metrics: {
      bleu: report.bleu,
      chrf: report.chrf,
      primaryScore: report.primaryScore,
    },
  });

  const scoreCheckpoint = async (markerPath) => {
    const { marker } = await readDistillCheckpointMarker(markerPath);
    const reports = await evaluateDistillationCheckpoint({
      loadedWorkload,
      checkpointPath: marker.checkpointPath,
      checkpointId: marker.checkpointId,
      checkpointStep: marker.checkpointStep,
      stageId: marker.stage,
      layout,
      stageAArtifact: marker.stageArtifact || null,
      stageAArtifactHash: marker.stageArtifactHash || null,
    });
    for (const report of reports) {
      const row = buildScoreboardRow(marker.stage, report);
      await appendDistillationScoreboardRow(layout, String(marker.stage || 'stage'), row, {
        selectionMetric: report.primaryMetric,
        selectionGoal: 'max',
      });
    }
  };

  return watchFinalizedCheckpoints({
    checkpointsDir: resolve(options.checkpointsDir || layout.checkpoints),
    manifestPath: resolve(options.manifestPath || join(layout.scoreboard, 'checkpoint-watch-manifest.json')),
    pollIntervalMs: options.pollIntervalMs || 2000,
    stopWhenIdle: options.stopWhenIdle === true,
    onCheckpoint: scoreCheckpoint,
  });
}
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
/** One normalized translation example produced by `normalizeTranslationPairRow`. */
export interface CanonicalTranslationRow {
  /** Explicit `row_id`/`rowId` from the record, otherwise a derived SHA-256 hex digest. */
  row_id: string;
  /** Normalized source language code, or null when the record has none. */
  src_lang: string | null;
  /** Normalized target language code, or null when the record has none. */
  tgt_lang: string | null;
  /** Direction in `src->tgt` form, or null when it cannot be determined. */
  pair: string | null;
  /** Trimmed, non-empty source text. */
  source: string;
  /** Trimmed, non-empty preferred target text. */
  target_pos: string;
  /** Optional dispreferred target text. */
  target_neg: string | null;
}
|
|
10
|
+
|
|
11
|
+
/** A translation dataset loaded, normalized and filtered from disk. */
export interface CanonicalTranslationDataset {
  /** Absolute path of the dataset file. */
  absolutePath: string;
  /** Raw file contents as read. */
  raw: string;
  /** Rows that survived normalization and the language/pair filters. */
  rows: CanonicalTranslationRow[];
  rowCount: number;
  /** Row count per `pair`; rows without a pair are counted under "unknown". */
  directionCounts: Record<string, number>;
  /** SHA-256 of the raw file contents. */
  datasetHash: string;
  /** SHA-256 of the key-sorted JSON serialization of `rows`. */
  canonicalHash: string;
  /** NOTE(review): presumably a hash over the ordered row ids — confirm in the loader. */
  rowIdsHash: string;
}
|
|
21
|
+
|
|
22
|
+
/**
 * Normalize a translation direction to `src->tgt` form.
 *
 * Accepts `src->tgt` or `src-tgt` (case-insensitive, `_` read as `-`,
 * whitespace ignored). When `value` is blank, the direction is built from
 * `srcLang`/`tgtLang` if both are given. Returns null when no two-sided
 * direction can be formed.
 */
export declare function normalizeDistillationPair(
  value: unknown,
  srcLang?: string | null,
  tgtLang?: string | null
): string | null;
|
|
27
|
+
|
|
28
|
+
/**
 * Convert one raw dataset record into a canonical row.
 *
 * Returns null for non-object records and for records lacking a source or a
 * positive target. With `strictPairContract: true`, throws when language codes
 * or the pair are missing, or when the pair disagrees with src/tgt languages.
 *
 * @param index - Position of the record in the file; feeds the derived row id.
 */
export declare function normalizeTranslationPairRow(
  record: Record<string, unknown>,
  index: number,
  options?: { strictPairContract?: boolean }
): CanonicalTranslationRow | null;
|
|
33
|
+
|
|
34
|
+
/**
 * Load a translation dataset (`.json` parsed as a JSON array, anything else as
 * JSONL), normalize every row, and apply the language/pair allowlists.
 * Rejects when the file is not an array or no rows remain.
 */
export declare function loadCanonicalTranslationDataset(
  datasetPath: string,
  options?: {
    /** Enforce src/tgt/pair presence and consistency on every row. */
    strictPairContract?: boolean;
    /** Allowed source languages (array or comma-separated string). */
    sourceLangs?: string[] | string | null;
    /** Allowed target languages (array or comma-separated string). */
    targetLangs?: string[] | string | null;
    /** Allowed `src->tgt` pairs (array or comma-separated string). */
    pairAllowlist?: string[] | string | null;
  }
): Promise<CanonicalTranslationDataset>;
|
|
43
|
+
|
|
44
|
+
/**
 * Build and persist a frozen subset of a translation dataset.
 * NOTE(review): the implementation is outside this view; the descriptions
 * below are inferred from field names and the subset helpers — confirm.
 */
export declare function buildFrozenSubset(options: {
  datasetPath: string;
  /** Directory that receives the subset files. */
  outputDir: string;
  strictPairContract?: boolean;
  sourceLangs?: string[] | string | null;
  targetLangs?: string[] | string | null;
  pairAllowlist?: string[] | string | null;
  /**
   * Presumably normalized like the subset-spec helper: `size`/`count`/`rowCount`,
   * `seed` (default 1337), `balanceBy` ("pair" enables round-robin by pair),
   * and an optional parent subset manifest — confirm.
   */
  subsetSpec?: Record<string, unknown> | null;
}): Promise<{
  dataset: CanonicalTranslationDataset;
  subsetRows: CanonicalTranslationRow[];
  subsetJsonlPath: string;
  rowIdsPath: string;
  manifestPath: string;
  manifest: Record<string, unknown>;
}>;
|
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
import { mkdir, readFile, writeFile } from 'node:fs/promises';
|
|
2
|
+
import { dirname, join, resolve } from 'node:path';
|
|
3
|
+
|
|
4
|
+
import { parseJsonl } from '../datasets/jsonl.js';
|
|
5
|
+
import { sha256Hex } from '../../utils/sha256.js';
|
|
6
|
+
|
|
7
|
+
/**
 * Recursively rebuild `value` so that every plain object has its keys in
 * sorted order. Arrays keep element order; primitives and null pass through.
 * Used to make JSON serialization independent of key insertion order.
 */
function stableSortObject(value) {
  if (Array.isArray(value)) {
    return value.map((item) => stableSortObject(item));
  }
  if (typeof value !== 'object' || value === null) {
    return value;
  }
  return Object.keys(value)
    .sort()
    .reduce((accumulator, key) => {
      accumulator[key] = stableSortObject(value[key]);
      return accumulator;
    }, {});
}
|
|
20
|
+
|
|
21
|
+
/** Serialize `value` as JSON with object keys sorted at every nesting level. */
function stableJson(value) {
  const canonical = stableSortObject(value);
  return JSON.stringify(canonical);
}
|
|
24
|
+
|
|
25
|
+
/**
 * Coerce `value` to a trimmed string. Returns null for null/undefined and for
 * values that are empty after trimming.
 */
function asOptionalString(value) {
  if (value == null) {
    return null;
  }
  const text = String(value).trim();
  return text.length > 0 ? text : null;
}
|
|
30
|
+
|
|
31
|
+
/**
 * Normalize a list-like option into a non-empty array of trimmed strings.
 * Accepts an array or a comma-separated string. Any other input, or a list
 * with no non-blank entries, yields null.
 */
function asOptionalStringArray(value) {
  let candidates;
  if (Array.isArray(value)) {
    candidates = value;
  } else if (typeof value === 'string') {
    candidates = value.split(',');
  } else {
    return null;
  }
  const cleaned = [];
  for (const candidate of candidates) {
    const text = asOptionalString(candidate);
    if (text) {
      cleaned.push(text);
    }
  }
  return cleaned.length === 0 ? null : cleaned;
}
|
|
44
|
+
|
|
45
|
+
// Primary subtags (plus spelled-out names) that collapse to the two codes the
// distillation pipeline special-cases. The whole primary subtag is compared,
// not a string prefix, so unrelated codes that merely start with "en"/"es"
// (e.g. ISO 639-2 "est" = Estonian, "esperanto", "enm") keep their own identity.
const ENGLISH_LANG_ALIASES = new Set(['en', 'eng', 'english']);
const SPANISH_LANG_ALIASES = new Set(['es', 'espanol', 'español']);

/**
 * Normalize a language code: trim, lowercase, and map `_` to `-`.
 * English/Spanish variants (any region or POSIX locale suffix, e.g. `en_US`,
 * `es-MX`, `en_US.UTF-8`, `eng`) collapse to `en`/`es`; every other code is
 * returned in its compacted form.
 *
 * @param {unknown} value - Raw language code.
 * @returns {string|null} Normalized code, or null for null/undefined/blank input.
 */
function normalizeLangCode(value) {
  const normalized = asOptionalString(value);
  if (!normalized) return null;
  const compact = normalized.toLowerCase().replace(/_/g, '-');
  // Primary subtag = text before the first region ("-"), encoding (".") or
  // modifier ("@") separator.
  const primary = compact.split(/[-.@]/)[0];
  if (ENGLISH_LANG_ALIASES.has(primary)) return 'en';
  if (SPANISH_LANG_ALIASES.has(primary)) return 'es';
  return compact;
}
|
|
53
|
+
|
|
54
|
+
/**
 * Normalize a translation direction into canonical `src->tgt` form.
 *
 * Accepts `src->tgt` or `src-tgt` (case-insensitive, `_` read as `-`,
 * whitespace ignored); each side goes through `normalizeLangCode`. When
 * `value` is blank, the direction is built from `srcLang`/`tgtLang`, which are
 * normalized the same way so both paths yield the same canonical form.
 *
 * @param {unknown} value - Raw pair string.
 * @param {string|null} [srcLang] - Fallback source language.
 * @param {string|null} [tgtLang] - Fallback target language.
 * @returns {string|null} Canonical pair, or null when it cannot be determined.
 */
export function normalizeDistillationPair(value, srcLang = null, tgtLang = null) {
  const pair = asOptionalString(value);
  if (!pair) {
    const fallbackSource = normalizeLangCode(srcLang);
    const fallbackTarget = normalizeLangCode(tgtLang);
    return fallbackSource && fallbackTarget ? `${fallbackSource}->${fallbackTarget}` : null;
  }
  const normalized = pair.toLowerCase().replace(/_/g, '-').replace(/\s+/g, '');
  // Without an explicit arrow a single hyphen separates the two codes, so
  // region subtags (e.g. "en-us") are only representable in the arrow form.
  const separator = normalized.includes('->') ? '->' : '-';
  const parts = normalized.split(separator).filter(Boolean);
  if (parts.length !== 2) return null;
  const [source, target] = parts.map((part) => normalizeLangCode(part) || part);
  return `${source}->${target}`;
}
|
|
68
|
+
|
|
69
|
+
/**
 * Return the first non-blank (trimmed) string found on `record` under the
 * given keys, checked in order; null when none qualifies.
 */
function resolveStringCandidate(record, keys) {
  for (let position = 0; position < keys.length; position += 1) {
    const candidate = asOptionalString(record?.[keys[position]]);
    if (candidate !== null) {
      return candidate;
    }
  }
  return null;
}
|
|
76
|
+
|
|
77
|
+
/**
 * Resolve a row identifier. An explicit `row_id` (or `rowId`) on the record
 * wins; otherwise the id is the SHA-256 of the key-sorted JSON of the row's
 * index and canonical fields. Because the index is hashed, an identical row at
 * a different position receives a different derived id.
 */
function resolveStableRowId(record, index, canonical) {
  const explicitId = asOptionalString(record?.row_id ?? record?.rowId);
  if (explicitId) {
    return explicitId;
  }
  const fingerprint = {
    index,
    src_lang: canonical.src_lang,
    tgt_lang: canonical.tgt_lang,
    pair: canonical.pair,
    source: canonical.source,
    target_pos: canonical.target_pos,
    target_neg: canonical.target_neg,
  };
  return sha256Hex(stableJson(fingerprint));
}
|
|
90
|
+
|
|
91
|
+
/**
 * Convert one raw dataset record into a canonical translation row.
 *
 * Field aliases (first non-blank wins):
 *   source     ← source | query | prompt
 *   target_pos ← target_pos | target | pos | completion
 *   target_neg ← target_neg | neg
 *   src_lang   ← src_lang | source_lang
 *   tgt_lang   ← tgt_lang | target_lang | lang
 *
 * @param {object} record - Raw parsed record.
 * @param {number} index - Record position; feeds the derived row id.
 * @param {{ strictPairContract?: boolean }} [options]
 * @returns {object|null} Canonical row, or null for non-object records and
 *   records missing a source or positive target.
 * @throws {Error} In strict mode, when languages or pair are missing or disagree.
 */
export function normalizeTranslationPairRow(record, index, options = {}) {
  if (!(record && typeof record === 'object') || Array.isArray(record)) {
    return null;
  }
  const source = resolveStringCandidate(record, ['source', 'query', 'prompt']);
  const targetPos = resolveStringCandidate(record, ['target_pos', 'target', 'pos', 'completion']);
  const targetNeg = resolveStringCandidate(record, ['target_neg', 'neg']);
  if (!source || !targetPos) {
    return null;
  }

  const srcLang = normalizeLangCode(record.src_lang ?? record.source_lang);
  const tgtLang = normalizeLangCode(record.tgt_lang ?? record.target_lang ?? record.lang);
  const pair = normalizeDistillationPair(record.pair, srcLang, tgtLang);

  if (options.strictPairContract === true) {
    if (!srcLang || !tgtLang) {
      throw new Error('strictPairContract requires src_lang and tgt_lang on each row.');
    }
    if (!pair) {
      throw new Error('strictPairContract requires pair on each row.');
    }
    const expectedPair = `${srcLang}->${tgtLang}`;
    if (pair !== expectedPair) {
      throw new Error(`row pair "${record.pair}" does not match src/tgt "${expectedPair}".`);
    }
  }

  const canonical = {
    src_lang: srcLang,
    tgt_lang: tgtLang,
    pair: pair || null,
    source,
    target_pos: targetPos,
    target_neg: targetNeg,
  };
  // The id is derived from the canonical fields, so it is attached last.
  canonical.row_id = resolveStableRowId(record, index, canonical);
  return canonical;
}
|
|
127
|
+
|
|
128
|
+
/**
 * Build a Set of normalized filter values from an array or comma-separated
 * string, dropping entries the normalizer rejects. Returns null (meaning
 * "no filter") when nothing survives.
 */
function normalizeFilterSet(value, normalizer) {
  const rawEntries = asOptionalStringArray(value);
  if (!rawEntries) return null;
  const accepted = new Set();
  for (const entry of rawEntries) {
    const normalized = normalizer(entry);
    if (normalized) {
      accepted.add(normalized);
    }
  }
  return accepted.size > 0 ? accepted : null;
}
|
|
136
|
+
|
|
137
|
+
/**
 * Keep only rows whose `src_lang`, `tgt_lang` and `pair` fall within the
 * configured allowlists (`sourceLangs`, `targetLangs`, `pairAllowlist`).
 * An unset allowlist imposes no constraint; when a field is filtered, rows
 * missing that field are dropped.
 */
function applyDatasetFilters(rows, options = {}) {
  const activeChecks = [
    ['src_lang', normalizeFilterSet(options.sourceLangs, normalizeLangCode)],
    ['tgt_lang', normalizeFilterSet(options.targetLangs, normalizeLangCode)],
    ['pair', normalizeFilterSet(options.pairAllowlist, normalizeDistillationPair)],
  ].filter(([, allowed]) => allowed !== null);
  return rows.filter((row) => activeChecks.every(
    ([field, allowed]) => Boolean(row[field]) && allowed.has(row[field])
  ));
}
|
|
154
|
+
|
|
155
|
+
/**
 * Normalize a subset specification.
 *
 * Reads `size` (aliases `count`, `rowCount`), `seed`, the balance mode
 * (`balanceBy` | `allocationMode` | `stratifyBy`), an optional id (`id` |
 * `name`) and parent manifest path. Non-positive or non-integer sizes become
 * null (no truncation); a non-integer seed falls back to 1337.
 *
 * @returns {object|null} Normalized spec, or null when `value` is not a plain object.
 */
function normalizeSubsetSpec(value) {
  const isPlainObject = value && typeof value === 'object' && !Array.isArray(value);
  if (!isPlainObject) return null;
  const DEFAULT_SEED = 1337;
  const requestedSize = Number(value.size ?? value.count ?? value.rowCount ?? 0);
  const requestedSeed = Number(value.seed ?? DEFAULT_SEED);
  return {
    id: asOptionalString(value.id ?? value.name),
    size: Number.isInteger(requestedSize) && requestedSize > 0 ? requestedSize : null,
    seed: Number.isInteger(requestedSeed) ? requestedSeed : DEFAULT_SEED,
    balanceBy: asOptionalString(value.balanceBy ?? value.allocationMode ?? value.stratifyBy),
    parentSubsetManifest: asOptionalString(
      value.parentSubsetManifest ?? value.parentSubset ?? value.parentManifest
    ),
  };
}
|
|
176
|
+
|
|
177
|
+
/** Seeded, deterministic sort key for a row: SHA-256 hex of "<seed>:<rowId>". */
function buildDeterministicRank(seed, rowId) {
  const rankInput = ''.concat(seed, ':', rowId);
  return sha256Hex(rankInput);
}
|
|
180
|
+
|
|
181
|
+
/**
 * Load the row-id set recorded in a parent subset manifest.
 *
 * @param {string|null} parentSubsetManifest - Path (resolved against the
 *   current working directory) to a JSON file with a `rowIds` array.
 * @returns {Promise<Set<string>|null>} null when no manifest is configured;
 *   otherwise the stringified row ids. The Set is empty when `rowIds` is
 *   missing or not an array.
 * @throws When the file cannot be read or does not contain valid JSON.
 */
async function resolveParentRowIds(parentSubsetManifest) {
  if (!parentSubsetManifest) return null;
  const raw = await readFile(resolve(parentSubsetManifest), 'utf8');
  const parsed = JSON.parse(raw);
  const rowIds = Array.isArray(parsed?.rowIds) ? parsed.rowIds : [];
  // A single construction covers the empty case too; ids are stringified so
  // numeric ids in the manifest match string row ids.
  return new Set(rowIds.map((entry) => String(entry)));
}
|
|
191
|
+
|
|
192
|
+
/**
 * Select up to `size` rows, round-robin across translation pairs so every
 * direction is represented as evenly as the data allows.
 *
 * Rows are grouped by `pair` (rows without one share the "unknown" bucket).
 * Each bucket is ordered by its seeded rank, and buckets are visited in sorted
 * key order, taking one row from each bucket per round.
 *
 * @param {Array<{row_id: string, pair: string|null}>} rows
 * @param {number} size - Maximum number of rows to return.
 * @param {number} seed - Seed mixed into each row's rank.
 * @returns {Array} Selected rows (new array; row objects are not copied).
 */
function buildPairBalancedSubset(rows, size, seed) {
  // Hash each row once up front rather than inside the sort comparator, which
  // would recompute SHA-256 for both operands on every comparison.
  const rankOf = new Map();
  const byPair = new Map();
  for (const row of rows) {
    if (!rankOf.has(row)) {
      rankOf.set(row, buildDeterministicRank(seed, row.row_id));
    }
    const key = row.pair || 'unknown';
    const bucket = byPair.get(key) || [];
    bucket.push(row);
    byPair.set(key, bucket);
  }
  // localeCompare is kept (instead of a plain code-unit comparison) so the
  // ordering for a given seed stays identical; frozen subsets must reproduce.
  for (const bucket of byPair.values()) {
    bucket.sort((left, right) => rankOf.get(left).localeCompare(rankOf.get(right)));
  }
  const pairKeys = [...byPair.keys()].sort((left, right) => left.localeCompare(right));
  const selected = [];
  // Round `cursor` takes the cursor-th row of every bucket that still has one;
  // stop at the target size or once a full round adds nothing.
  for (let cursor = 0; selected.length < size; cursor += 1) {
    let progressed = false;
    for (const pairKey of pairKeys) {
      const bucket = byPair.get(pairKey);
      if (cursor >= bucket.length) continue;
      selected.push(bucket[cursor]);
      progressed = true;
      if (selected.length >= size) break;
    }
    if (!progressed) break;
  }
  return selected;
}
|
|
224
|
+
|
|
225
|
+
/**
 * Choose the rows of a frozen subset.
 *
 * Without a positive `size` (or when `size` covers every row), a copy of all
 * rows is returned. With `balanceBy` "pair"/"pair_balance" the selection is
 * round-robin across pairs; otherwise rows are ordered by their seeded rank
 * and the first `size` are kept.
 *
 * @param {Array<{row_id: string}>} rows
 * @param {{size: number|null, seed: number, balanceBy?: string|null}|null} subsetSpec
 * @returns {Array} Selected rows (always a new array).
 */
function selectSubsetRows(rows, subsetSpec) {
  if (!subsetSpec || !subsetSpec.size || subsetSpec.size >= rows.length) {
    return rows.slice();
  }
  if (subsetSpec.balanceBy === 'pair' || subsetSpec.balanceBy === 'pair_balance') {
    return buildPairBalancedSubset(rows, subsetSpec.size, subsetSpec.seed);
  }
  // Decorate-sort-undecorate: hash each row once instead of re-hashing both
  // operands on every comparison. The stable sort and unchanged localeCompare
  // keep the resulting order identical for a given seed.
  return rows
    .map((row) => ({ row, rank: buildDeterministicRank(subsetSpec.seed, row.row_id) }))
    .sort((left, right) => left.rank.localeCompare(right.rank))
    .slice(0, subsetSpec.size)
    .map(({ row }) => row);
}
|
|
241
|
+
|
|
242
|
+
/**
 * Counts rows per translation direction (`row.pair`).
 *
 * Rows whose `pair` is falsy are counted under 'unknown'. Keys appear in
 * first-seen order, following normal object key ordering rules.
 *
 * @param {object[]} rows - Rows with an optional `pair` field.
 * @returns {Object<string, number>} Map of pair key to row count, as a plain object.
 */
function computeDirectionCounts(rows) {
  // Accumulate in a Map. Pair names come from the dataset, and a plain `{}`
  // would pick up inherited Object.prototype members. For example, "constructor"
  // would start from the Object function instead of 0, and "__proto__" would be
  // dropped entirely.
  const counts = new Map();
  for (const row of rows) {
    const key = String(row.pair || 'unknown');
    counts.set(key, (counts.get(key) ?? 0) + 1);
  }
  // Object.fromEntries defines own data properties, so every key, including
  // "__proto__", is stored safely.
  return Object.fromEntries(counts);
}
|
|
250
|
+
|
|
251
|
+
/**
 * Loads a translation-pair dataset and reduces it to its canonical, filtered rows.
 *
 * Files ending in `.json` (case-insensitive) are parsed as a single JSON value
 * that must be an array. Any other file is parsed as JSONL via parseJsonl. Each
 * entry goes through normalizeTranslationPairRow; falsy results are dropped.
 * The remaining rows are then passed through applyDatasetFilters.
 *
 * @param {string} datasetPath - Path to the dataset, resolved against the CWD.
 * @param {object} [options] - Forwarded to normalizeTranslationPairRow and
 *   applyDatasetFilters (e.g. strictPairContract, sourceLangs, targetLangs, pairAllowlist).
 * @returns {Promise<{absolutePath: string, raw: string, rows: object[], rowCount: number,
 *   directionCounts: Object<string, number>, datasetHash: string, canonicalHash: string,
 *   rowIdsHash: string}>}
 *   `datasetHash` hashes the file text exactly as read. `canonicalHash` hashes
 *   stableJson of the filtered rows. `rowIdsHash` hashes the row IDs joined by
 *   '\n' with no trailing newline.
 * @throws {SyntaxError} When a `.json` file is not valid JSON (path included, original error as `cause`).
 * @throws {Error} When the parsed content is not an array, or no rows survive normalization and filters.
 */
export async function loadCanonicalTranslationDataset(datasetPath, options = {}) {
  const absolutePath = resolve(String(datasetPath));
  const raw = await readFile(absolutePath, 'utf8');
  // readFile(..., 'utf8') keeps a leading BOM, which JSON.parse rejects. Strip it
  // for parsing only; `raw` (and therefore datasetHash) stays byte-for-byte as read.
  const text = raw.startsWith('\uFEFF') ? raw.slice(1) : raw;

  let parsed;
  if (absolutePath.toLowerCase().endsWith('.json')) {
    try {
      parsed = JSON.parse(text);
    } catch (error) {
      throw new SyntaxError(
        `Distillation dataset "${absolutePath}" is not valid JSON: ${error.message}`,
        { cause: error },
      );
    }
  } else {
    parsed = parseJsonl(text);
  }
  if (!Array.isArray(parsed)) {
    throw new Error(`Distillation dataset "${absolutePath}" must be a JSON array or JSONL file.`);
  }

  const normalizedRows = [];
  for (let index = 0; index < parsed.length; index += 1) {
    const row = normalizeTranslationPairRow(parsed[index], index, options);
    if (row) {
      normalizedRows.push(row);
    }
  }
  const filteredRows = applyDatasetFilters(normalizedRows, options);
  if (filteredRows.length === 0) {
    throw new Error(`Distillation dataset "${absolutePath}" has no usable rows after contract checks and filters.`);
  }

  return {
    absolutePath,
    raw,
    rows: filteredRows,
    rowCount: filteredRows.length,
    directionCounts: computeDirectionCounts(filteredRows),
    datasetHash: sha256Hex(raw),
    canonicalHash: sha256Hex(stableJson(filteredRows)),
    rowIdsHash: sha256Hex(filteredRows.map((row) => row.row_id).join('\n')),
  };
}
|
|
283
|
+
|
|
284
|
+
/**
 * Builds a deterministic subset of a canonical translation dataset and writes it
 * to `options.outputDir` as three artifacts:
 *   - subset.jsonl         one JSON-serialized row per line
 *   - row_ids.txt          one row_id per line (trailing newline included)
 *   - subset_manifest.json provenance: dataset hashes, subset spec, counts, row IDs
 *
 * If the normalized subset spec names a `parentSubsetManifest`, selection is
 * restricted to the row IDs listed in that manifest.
 *
 * @param {object} options
 * @param {string} options.datasetPath - Dataset passed to loadCanonicalTranslationDataset.
 * @param {string} options.outputDir - Directory that receives the artifacts (created if missing).
 * @param {boolean} [options.strictPairContract] - Only `true` enables the strict pair contract.
 * @param {*} [options.sourceLangs] - Forwarded to the dataset loader.
 * @param {*} [options.targetLangs] - Forwarded to the dataset loader.
 * @param {*} [options.pairAllowlist] - Forwarded to the dataset loader.
 * @param {*} [options.subsetSpec] - Raw spec, normalized by normalizeSubsetSpec.
 * @returns {Promise<object>} `{ dataset, subsetRows, subsetJsonlPath, rowIdsPath, manifestPath, manifest }`.
 * @throws {Error} When `outputDir` is missing or blank, or when the selection yields no rows.
 */
export async function buildFrozenSubset(options) {
  // String(undefined) is "undefined". Without this check, a missing outputDir
  // would silently write artifacts into a directory literally named "undefined".
  const outputDirOption = options.outputDir == null ? '' : String(options.outputDir);
  if (outputDirOption.trim() === '') {
    throw new Error('buildFrozenSubset requires a non-empty outputDir.');
  }

  const dataset = await loadCanonicalTranslationDataset(options.datasetPath, {
    strictPairContract: options.strictPairContract === true,
    sourceLangs: options.sourceLangs,
    targetLangs: options.targetLangs,
    pairAllowlist: options.pairAllowlist,
  });
  const subsetSpec = normalizeSubsetSpec(options.subsetSpec);
  const parentRowIds = subsetSpec?.parentSubsetManifest
    ? await resolveParentRowIds(subsetSpec.parentSubsetManifest)
    : null;
  const scopedRows = parentRowIds
    ? dataset.rows.filter((row) => parentRowIds.has(row.row_id))
    : dataset.rows;
  const subsetRows = selectSubsetRows(scopedRows, subsetSpec);
  // The loader already rejects an empty dataset. Apply the same rule here, so
  // a parent manifest that matches nothing fails loudly instead of freezing
  // newline-only artifacts.
  if (subsetRows.length === 0) {
    throw new Error(
      `Frozen subset of distillation dataset "${dataset.absolutePath}" selected no rows`
      + ` (parent subset manifest: ${subsetSpec?.parentSubsetManifest || 'none'}).`,
    );
  }

  const outputDir = resolve(outputDirOption);
  const subsetJsonlPath = join(outputDir, 'subset.jsonl');
  const rowIdsPath = join(outputDir, 'row_ids.txt');
  const manifestPath = join(outputDir, 'subset_manifest.json');
  const rowIds = subsetRows.map((row) => row.row_id);
  const serializedRows = `${subsetRows.map((row) => JSON.stringify(row)).join('\n')}\n`;
  const rowIdsText = `${rowIds.join('\n')}\n`;
  const manifest = {
    artifactType: 'subset_manifest',
    schemaVersion: 1,
    generatedAt: new Date().toISOString(),
    datasetPath: dataset.absolutePath,
    datasetHash: dataset.datasetHash,
    canonicalHash: dataset.canonicalHash,
    // Hash of the exact row_ids.txt contents (including the trailing newline).
    rowIdsHash: sha256Hex(rowIdsText),
    universeRowCount: dataset.rowCount,
    subsetRowCount: subsetRows.length,
    directionCounts: computeDirectionCounts(subsetRows),
    subsetSpec,
    parentSubsetManifest: subsetSpec?.parentSubsetManifest || null,
    rowIds,
    output: {
      subsetJsonlPath,
      rowIdsPath,
    },
  };

  // manifestPath lives directly in outputDir, so one recursive mkdir covers all three files.
  await mkdir(outputDir, { recursive: true });
  await writeFile(subsetJsonlPath, serializedRows, 'utf8');
  await writeFile(rowIdsPath, rowIdsText, 'utf8');
  // The manifest is written after the files it describes.
  await writeFile(manifestPath, `${JSON.stringify(manifest, null, 2)}\n`, 'utf8');
  return {
    dataset,
    subsetRows,
    subsetJsonlPath,
    rowIdsPath,
    manifestPath,
    manifest,
  };
}
|