@simulatte/doppler 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199) hide show
  1. package/README.md +26 -10
  2. package/package.json +30 -6
  3. package/src/client/doppler-api.browser.d.ts +1 -0
  4. package/src/client/doppler-api.browser.js +288 -0
  5. package/src/client/doppler-api.js +1 -1
  6. package/src/client/doppler-provider/types.js +1 -1
  7. package/src/config/execution-contract-check.d.ts +33 -0
  8. package/src/config/execution-contract-check.js +72 -0
  9. package/src/config/execution-v0-contract-check.d.ts +94 -0
  10. package/src/config/execution-v0-contract-check.js +251 -0
  11. package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
  12. package/src/config/execution-v0-graph-contract-check.js +64 -0
  13. package/src/config/kernel-path-contract-check.d.ts +76 -0
  14. package/src/config/kernel-path-contract-check.js +479 -0
  15. package/src/config/kernel-path-loader.d.ts +16 -0
  16. package/src/config/kernel-path-loader.js +54 -0
  17. package/src/config/kernels/kernel-ref-digests.js +39 -27
  18. package/src/config/kernels/registry.json +598 -2
  19. package/src/config/loader.js +81 -48
  20. package/src/config/merge-contract-check.d.ts +16 -0
  21. package/src/config/merge-contract-check.js +321 -0
  22. package/src/config/merge-helpers.d.ts +58 -0
  23. package/src/config/merge-helpers.js +54 -0
  24. package/src/config/merge.js +21 -6
  25. package/src/config/presets/models/janus-text.json +2 -0
  26. package/src/config/presets/models/qwen3.json +9 -2
  27. package/src/config/presets/models/transformer.json +5 -0
  28. package/src/config/quantization-contract-check.d.ts +12 -0
  29. package/src/config/quantization-contract-check.js +91 -0
  30. package/src/config/required-inference-fields-contract-check.d.ts +24 -0
  31. package/src/config/required-inference-fields-contract-check.js +237 -0
  32. package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
  33. package/src/config/schema/browser-suite-metrics.schema.js +46 -0
  34. package/src/config/schema/conversion-report.schema.d.ts +40 -0
  35. package/src/config/schema/conversion-report.schema.js +108 -0
  36. package/src/config/schema/doppler.schema.js +12 -18
  37. package/src/config/schema/index.d.ts +22 -0
  38. package/src/config/schema/index.js +18 -0
  39. package/src/config/schema/inference-defaults.schema.js +3 -0
  40. package/src/config/schema/inference.schema.d.ts +9 -0
  41. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  42. package/src/config/schema/manifest.schema.d.ts +6 -0
  43. package/src/config/schema/manifest.schema.js +3 -0
  44. package/src/converter/core.d.ts +10 -0
  45. package/src/converter/core.js +27 -2
  46. package/src/converter/parsers/diffusion.js +63 -3
  47. package/src/converter/rope-config.js +42 -0
  48. package/src/gpu/device.js +58 -0
  49. package/src/gpu/kernels/attention.js +98 -0
  50. package/src/gpu/kernels/bias_add.wgsl +8 -6
  51. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  52. package/src/gpu/kernels/conv2d.js +1 -1
  53. package/src/gpu/kernels/conv2d.wgsl +7 -8
  54. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  55. package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
  56. package/src/gpu/kernels/depthwise_conv2d.js +99 -0
  57. package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
  58. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
  59. package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
  60. package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
  61. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
  62. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
  63. package/src/gpu/kernels/index.d.ts +30 -0
  64. package/src/gpu/kernels/index.js +25 -0
  65. package/src/gpu/kernels/matmul.js +25 -0
  66. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  67. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  68. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  69. package/src/gpu/kernels/relu.d.ts +18 -0
  70. package/src/gpu/kernels/relu.js +58 -0
  71. package/src/gpu/kernels/relu.wgsl +22 -0
  72. package/src/gpu/kernels/relu_f16.wgsl +24 -0
  73. package/src/gpu/kernels/repeat_channels.d.ts +21 -0
  74. package/src/gpu/kernels/repeat_channels.js +60 -0
  75. package/src/gpu/kernels/repeat_channels.wgsl +28 -0
  76. package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
  77. package/src/gpu/kernels/residual.js +44 -8
  78. package/src/gpu/kernels/residual.wgsl +6 -3
  79. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  80. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  81. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  82. package/src/gpu/kernels/rmsnorm.js +58 -6
  83. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  84. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  85. package/src/gpu/kernels/rope.d.ts +2 -0
  86. package/src/gpu/kernels/rope.js +11 -1
  87. package/src/gpu/kernels/rope.wgsl +56 -40
  88. package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
  89. package/src/gpu/kernels/sana_linear_attention.js +121 -0
  90. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
  91. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
  92. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
  93. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
  94. package/src/gpu/kernels/silu.d.ts +1 -0
  95. package/src/gpu/kernels/silu.js +32 -14
  96. package/src/gpu/kernels/silu.wgsl +19 -9
  97. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  98. package/src/gpu/kernels/transpose.js +15 -2
  99. package/src/gpu/kernels/transpose.wgsl +5 -6
  100. package/src/gpu/kernels/upsample2d.js +2 -1
  101. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  102. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  103. package/src/gpu/kernels/utils.js +16 -1
  104. package/src/index-browser.d.ts +1 -1
  105. package/src/index-browser.js +2 -2
  106. package/src/index.js +1 -1
  107. package/src/inference/browser-harness.js +109 -23
  108. package/src/inference/pipelines/diffusion/init.js +14 -0
  109. package/src/inference/pipelines/diffusion/pipeline.js +215 -77
  110. package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
  111. package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
  112. package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
  113. package/src/inference/pipelines/diffusion/scheduler.js +91 -3
  114. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
  115. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
  116. package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
  117. package/src/inference/pipelines/diffusion/types.d.ts +4 -0
  118. package/src/inference/pipelines/diffusion/vae.js +782 -78
  119. package/src/inference/pipelines/text/attention/record.js +11 -2
  120. package/src/inference/pipelines/text/attention/run.js +11 -2
  121. package/src/inference/pipelines/text/chat-format.js +25 -1
  122. package/src/inference/pipelines/text/config.d.ts +9 -0
  123. package/src/inference/pipelines/text/config.js +69 -2
  124. package/src/inference/pipelines/text/execution-plan.js +23 -31
  125. package/src/inference/pipelines/text/execution-v0.js +43 -95
  126. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  127. package/src/inference/pipelines/text/init.d.ts +4 -0
  128. package/src/inference/pipelines/text/init.js +56 -9
  129. package/src/inference/pipelines/text/layer.js +11 -0
  130. package/src/inference/pipelines/text.js +4 -0
  131. package/src/inference/tokenizers/bundled.js +156 -33
  132. package/src/rules/execution-rules-contract-check.d.ts +17 -0
  133. package/src/rules/execution-rules-contract-check.js +245 -0
  134. package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
  135. package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
  136. package/src/rules/kernels/relu.rules.json +6 -0
  137. package/src/rules/kernels/repeat-channels.rules.json +6 -0
  138. package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
  139. package/src/rules/layer-pattern-contract-check.d.ts +17 -0
  140. package/src/rules/layer-pattern-contract-check.js +231 -0
  141. package/src/rules/rule-registry.d.ts +28 -0
  142. package/src/rules/rule-registry.js +38 -0
  143. package/src/rules/tooling/command-runtime.rules.json +18 -0
  144. package/src/tooling/command-api.d.ts +27 -1
  145. package/src/tooling/command-api.js +142 -3
  146. package/src/tooling/conversion-config-materializer.d.ts +24 -0
  147. package/src/tooling/conversion-config-materializer.js +99 -0
  148. package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
  149. package/src/tooling/lean-execution-contract-runner.js +158 -0
  150. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  151. package/src/tooling/node-browser-command-runner.js +58 -3
  152. package/src/tooling/node-command-runner.js +15 -0
  153. package/src/tooling/node-convert.d.ts +10 -0
  154. package/src/tooling/node-converter.js +59 -0
  155. package/src/tooling/node-webgpu.js +11 -89
  156. package/src/training/checkpoint-watch.d.ts +7 -0
  157. package/src/training/checkpoint-watch.js +106 -0
  158. package/src/training/checkpoint.d.ts +6 -1
  159. package/src/training/checkpoint.js +12 -2
  160. package/src/training/distillation/artifacts.d.ts +71 -0
  161. package/src/training/distillation/artifacts.js +132 -0
  162. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  163. package/src/training/distillation/checkpoint-watch.js +57 -0
  164. package/src/training/distillation/dataset.d.ts +59 -0
  165. package/src/training/distillation/dataset.js +337 -0
  166. package/src/training/distillation/eval.d.ts +34 -0
  167. package/src/training/distillation/eval.js +310 -0
  168. package/src/training/distillation/index.d.ts +29 -0
  169. package/src/training/distillation/index.js +29 -0
  170. package/src/training/distillation/runtime.d.ts +20 -0
  171. package/src/training/distillation/runtime.js +121 -0
  172. package/src/training/distillation/scoreboard.d.ts +6 -0
  173. package/src/training/distillation/scoreboard.js +8 -0
  174. package/src/training/distillation/stage-a.d.ts +45 -0
  175. package/src/training/distillation/stage-a.js +338 -0
  176. package/src/training/distillation/stage-b.d.ts +24 -0
  177. package/src/training/distillation/stage-b.js +20 -0
  178. package/src/training/index.d.ts +10 -0
  179. package/src/training/index.js +10 -0
  180. package/src/training/lora-pipeline.d.ts +40 -0
  181. package/src/training/lora-pipeline.js +796 -0
  182. package/src/training/operator-artifacts.d.ts +62 -0
  183. package/src/training/operator-artifacts.js +140 -0
  184. package/src/training/operator-command.d.ts +5 -0
  185. package/src/training/operator-command.js +453 -0
  186. package/src/training/operator-eval.d.ts +48 -0
  187. package/src/training/operator-eval.js +230 -0
  188. package/src/training/operator-scoreboard.d.ts +5 -0
  189. package/src/training/operator-scoreboard.js +44 -0
  190. package/src/training/runner.d.ts +52 -0
  191. package/src/training/runner.js +29 -4
  192. package/src/training/suite.d.ts +112 -0
  193. package/src/training/suite.js +9 -9
  194. package/src/training/workloads.d.ts +164 -0
  195. package/src/training/workloads.js +539 -0
  196. package/src/version.d.ts +2 -0
  197. package/src/version.js +2 -0
  198. package/tools/convert-safetensors-node.js +47 -0
  199. package/tools/doppler-cli.js +252 -41
@@ -0,0 +1,29 @@
1
+ export {
2
+ normalizeDistillationPair,
3
+ normalizeTranslationPairRow,
4
+ loadCanonicalTranslationDataset,
5
+ buildFrozenSubset,
6
+ } from './dataset.js';
7
+ export {
8
+ createDistillationRunArtifacts,
9
+ writeDistillStageManifest,
10
+ writeDistillCheckpointMetadata,
11
+ writeDistillCheckpointComplete,
12
+ writeDistillEvalReport,
13
+ writeDistillCompareReport,
14
+ writeDistillQualityGateReport,
15
+ buildDistillArtifactBase,
16
+ } from './artifacts.js';
17
+ export { appendDistillationScoreboardRow } from './scoreboard.js';
18
+ export {
19
+ buildDistillationTrainingConfigFromWorkload,
20
+ resolveInternalDistillStage,
21
+ } from './runtime.js';
22
+ export {
23
+ evaluateDistillationModel,
24
+ evaluateDistillationCheckpoint,
25
+ readDistillCheckpointMarker,
26
+ } from './eval.js';
27
+ export { runDistillationStage, runDistillationStageA } from './stage-a.js';
28
+ export { runDistillationStageB } from './stage-b.js';
29
+ export { watchDistillationCheckpoints } from './checkpoint-watch.js';
@@ -0,0 +1,29 @@
1
+ export {
2
+ normalizeDistillationPair,
3
+ normalizeTranslationPairRow,
4
+ loadCanonicalTranslationDataset,
5
+ buildFrozenSubset,
6
+ } from './dataset.js';
7
+ export {
8
+ createDistillationRunArtifacts,
9
+ writeDistillStageManifest,
10
+ writeDistillCheckpointMetadata,
11
+ writeDistillCheckpointComplete,
12
+ writeDistillEvalReport,
13
+ writeDistillCompareReport,
14
+ writeDistillQualityGateReport,
15
+ buildDistillArtifactBase,
16
+ } from './artifacts.js';
17
+ export { appendDistillationScoreboardRow } from './scoreboard.js';
18
+ export {
19
+ buildDistillationTrainingConfigFromWorkload,
20
+ resolveInternalDistillStage,
21
+ } from './runtime.js';
22
+ export {
23
+ evaluateDistillationModel,
24
+ evaluateDistillationCheckpoint,
25
+ readDistillCheckpointMarker,
26
+ } from './eval.js';
27
+ export { runDistillationStage, runDistillationStageA } from './stage-a.js';
28
+ export { runDistillationStageB } from './stage-b.js';
29
+ export { watchDistillationCheckpoints } from './checkpoint-watch.js';
@@ -0,0 +1,20 @@
1
+ import type { LoadedTrainingWorkload, DistillStagePlanEntry } from '../workloads.js';
2
+
3
/**
 * Maps a workload stage-plan entry onto the internal distillation stage name.
 *
 * The implementation inspects the entry's `trainingStage` (falling back to
 * `id`) and `objective` labels.
 *
 * @param stageEntry Stage-plan entry, or a loosely typed record with the same fields.
 * @returns The internal stage identifier, `'stage_a'` or `'stage_b'`.
 * @throws Error when the entry names an "sft" stage or an unrecognized stage.
 */
export declare function resolveInternalDistillStage(
  stageEntry: DistillStagePlanEntry | Record<string, unknown>
): 'stage_a' | 'stage_b';
6
+
7
/**
 * Builds the training configuration for one distillation stage of a loaded
 * workload, together with a hash identifying that configuration.
 *
 * @param loadedWorkload Loaded workload; its `workload.kind` must be `'distill'`.
 * @param stageEntry Stage-plan entry describing the stage to configure.
 * @param options Optional overrides: dataset path, artifact directory, and the
 *   Stage A artifact path/hash to record in the distill config.
 * @returns The resolved internal stage, the training config, and its hash.
 * @throws Error when the workload is not a distill workload or the stage is unsupported.
 */
export declare function buildDistillationTrainingConfigFromWorkload(
  loadedWorkload: LoadedTrainingWorkload,
  stageEntry: DistillStagePlanEntry | Record<string, unknown>,
  options?: {
    datasetPath?: string | null;
    artifactDir?: string | null;
    stageAArtifact?: string | null;
    stageAArtifactHash?: string | null;
  }
): {
  internalStage: 'stage_a' | 'stage_b';
  trainingConfig: Record<string, unknown>;
  trainingConfigHash: string;
};
@@ -0,0 +1,121 @@
1
+ import { createTrainingConfig } from '../../config/training-defaults.js';
2
+ import { sha256Hex } from '../../utils/sha256.js';
3
+
4
/**
 * Recursively copies `value` so that every object's keys are inserted in
 * sorted order, making a later `JSON.stringify` independent of the original
 * insertion order.
 *
 * Arrays keep their element order; each element is normalized recursively.
 * Primitives, `null`, and `undefined` are returned unchanged.
 *
 * @param value Any JSON-like value.
 * @returns A key-sorted deep copy for objects and arrays, otherwise `value` itself.
 */
function stableSortObject(value) {
  if (Array.isArray(value)) {
    return value.map((entry) => stableSortObject(entry));
  }
  if (!value || typeof value !== 'object') {
    return value;
  }
  // Object.fromEntries defines own data properties, so a "__proto__" key
  // (which JSON.parse yields as an ordinary own property) is copied as data.
  // Plain `sorted[key] = ...` assignment on a `{}` literal would route that
  // key through the prototype setter and drop it from the serialized output.
  return Object.fromEntries(
    Object.keys(value)
      .sort()
      .map((key) => [key, stableSortObject(value[key])])
  );
}
17
+
18
/**
 * Serializes `value` to a JSON string after normalizing object key order with
 * `stableSortObject`, so equal structures produce equal strings.
 *
 * @param value Any JSON-like value.
 * @returns The result of `JSON.stringify` on the key-sorted copy.
 */
function stableJson(value) {
  const normalized = stableSortObject(value);
  return JSON.stringify(normalized);
}
21
+
22
/**
 * Canonicalizes a stage or objective label for comparison: falsy input becomes
 * an empty string, the text is trimmed and lower-cased, and each run of
 * whitespace and/or hyphens collapses into a single underscore.
 *
 * @param value Raw label (any value; coerced with `String`).
 * @returns The normalized label, e.g. `" Stage-A "` becomes `"stage_a"`.
 */
function normalizeStageLabel(value) {
  const text = String(value || '');
  const lowered = text.trim().toLowerCase();
  return lowered.replace(/[\s-]+/g, '_');
}
26
+
27
/**
 * Resolves a workload stage-plan entry to the internal distillation stage.
 *
 * The stage label is taken from `trainingStage`, falling back to `id`; the
 * objective label from `objective`. Both are normalized before matching.
 * Stage B matches are checked before Stage A matches.
 *
 * @param stageEntry Stage-plan entry (may be null/undefined; treated as empty).
 * @returns `'stage_b'` for stage_b / post_sft_distill / post_sft_triplet stages
 *   or a triplet objective; `'stage_a'` for stage_a / kd stages or a kd /
 *   cross_entropy objective.
 * @throws Error when either label is "sft", or when nothing matches.
 */
export function resolveInternalDistillStage(stageEntry) {
  const stageLabel = normalizeStageLabel(stageEntry?.trainingStage || stageEntry?.id || '');
  const objectiveLabel = normalizeStageLabel(stageEntry?.objective || '');

  // SFT is rejected explicitly rather than being reported as "unsupported".
  if ([stageLabel, objectiveLabel].includes('sft')) {
    throw new Error(
      'Distillation workload stage uses "sft", but the current JS distill runner only supports the KD-oriented stage_a contract. Use objective="kd" / trainingStage="stage_a" explicitly.'
    );
  }

  const matchesStageB = ['stage_b', 'post_sft_distill', 'post_sft_triplet'].includes(stageLabel)
    || objectiveLabel === 'triplet';
  if (matchesStageB) {
    return 'stage_b';
  }

  const matchesStageA = ['stage_a', 'kd'].includes(stageLabel)
    || ['kd', 'cross_entropy'].includes(objectiveLabel);
  if (matchesStageA) {
    return 'stage_a';
  }

  throw new Error(
    `Unsupported distillation stage "${stageEntry?.trainingStage || stageEntry?.id || 'unknown'}".`
  );
}
55
+
56
/**
 * Builds the training config for one distillation stage of a loaded workload.
 *
 * The distill section is assembled from the workload's model/dataset ids and
 * its `pipeline` settings; Stage B additionally gets a `freeze` map. The
 * optimizer, gradient clipping, and precision settings are copied from
 * `workload.training`, and the whole thing is passed to `createTrainingConfig`.
 *
 * The returned hash covers the workload config hash, the stage entry, the
 * effective dataset path, and the Stage A artifact path/hash. It does not
 * include `options.artifactDir`.
 *
 * @param loadedWorkload Loaded workload wrapper; `workload.kind` must be `'distill'`.
 * @param stageEntry Stage-plan entry for the stage being configured.
 * @param options Optional `datasetPath`, `artifactDir`, `stageAArtifact`, `stageAArtifactHash`.
 * @returns `{ internalStage, trainingConfig, trainingConfigHash }`.
 * @throws Error when the workload is not a distill workload, or the stage is unsupported.
 */
export function buildDistillationTrainingConfigFromWorkload(loadedWorkload, stageEntry, options = {}) {
  const { workload } = loadedWorkload;
  if (workload.kind !== 'distill') {
    throw new Error('buildDistillationTrainingConfigFromWorkload requires a distill workload.');
  }
  const internalStage = resolveInternalDistillStage(stageEntry);

  // Values shared by the distill config and the hash payload.
  const effectiveDatasetPath = options.datasetPath || workload.datasetPath;
  const stageAArtifact = options.stageAArtifact || null;
  const stageAArtifactHash = options.stageAArtifactHash || null;

  const { pipeline } = workload;
  const distillTraining = {
    enabled: true,
    stage: internalStage,
    teacherModelId: workload.teacherModelId,
    studentModelId: workload.studentModelId,
    datasetId: workload.datasetId,
    datasetPath: effectiveDatasetPath,
    sourceLangs: pipeline.sourceLangs,
    targetLangs: pipeline.targetLangs,
    pairAllowlist: pipeline.pairAllowlist,
    strictPairContract: pipeline.strictPairContract === true,
    stageAArtifact,
    stageAArtifactHash,
    artifactDir: options.artifactDir || null,
    temperature: pipeline.temperature,
    alphaKd: pipeline.alphaKd,
    alphaCe: pipeline.alphaCe,
    tripletMargin: pipeline.tripletMargin,
    studentGraphMode: pipeline.studentGraphMode,
    // Stage B carries an explicit freeze map; Stage A has no `freeze` key at all.
    ...(internalStage === 'stage_b'
      ? {
        freeze: {
          encoder: true,
          prior: true,
          decoder: true,
          base: false,
          lora: false,
        },
      }
      : {}),
  };

  const optimizerSettings = workload.training.optimizer;
  const trainingConfig = createTrainingConfig({
    training: {
      enabled: true,
      optimizer: {
        type: optimizerSettings.type,
        lr: optimizerSettings.lr,
        beta1: optimizerSettings.beta1,
        beta2: optimizerSettings.beta2,
        eps: optimizerSettings.eps,
        weightDecay: optimizerSettings.weightDecay,
        scheduler: optimizerSettings.scheduler,
      },
      gradient: {
        maxNorm: workload.training.gradientClipping.maxNorm,
      },
      precision: workload.training.precision,
      distill: distillTraining,
    },
  });

  const hashPayload = {
    workloadConfigHash: workload.configHash,
    stageEntry,
    datasetPath: effectiveDatasetPath,
    stageAArtifact,
    stageAArtifactHash,
  };
  return {
    internalStage,
    trainingConfig,
    trainingConfigHash: sha256Hex(stableJson(hashPayload)),
  };
}
@@ -0,0 +1,6 @@
1
/**
 * Appends one row to the scoreboard of a distillation stage.
 *
 * The implementation writes under `layout.scoreboard` joined with the stage
 * id (or `"stage"` when the id is falsy).
 *
 * @param layout Run layout; the `scoreboard` entry is the scoreboard root directory.
 * @param stageId Stage identifier used as the sub-directory name.
 * @param row Scoreboard row payload.
 * @param options Optional selection metric and goal forwarded to the scoreboard writer.
 * @returns Paths of the rows file and the summary file.
 */
export declare function appendDistillationScoreboardRow(
  layout: Record<string, string>,
  stageId: string,
  row: Record<string, unknown>,
  options?: {
    selectionMetric?: string | null;
    selectionGoal?: string | null;
  }
): Promise<{
  rowsPath: string;
  summaryPath: string;
}>;
@@ -0,0 +1,8 @@
1
+ import { join } from 'node:path';
2
+
3
+ import { appendScoreboardRow } from '../operator-scoreboard.js';
4
+
5
/**
 * Appends `row` to the scoreboard directory of a distillation stage.
 *
 * The target directory is `layout.scoreboard` joined with the stage id; a
 * falsy stage id falls back to the directory name `"stage"`.
 *
 * @param layout Run layout object providing the `scoreboard` root directory.
 * @param stageId Stage identifier (coerced to a string).
 * @param row Scoreboard row payload, passed through unchanged.
 * @param options Options passed through unchanged to `appendScoreboardRow`.
 * @returns Whatever `appendScoreboardRow` resolves to.
 */
export async function appendDistillationScoreboardRow(layout, stageId, row, options = {}) {
  const stageFolder = String(stageId || 'stage');
  const targetDir = join(layout.scoreboard, stageFolder);
  return appendScoreboardRow(targetDir, row, options);
}
@@ -0,0 +1,45 @@
1
+ import type { LoadedTrainingWorkload, DistillStagePlanEntry } from '../workloads.js';
2
+
3
/**
 * Runs a single distillation stage: trains the student, writes checkpoint
 * artifacts, evaluates according to the stage's eval schedule, and writes the
 * stage manifest.
 *
 * @param options Stage inputs: the loaded workload, the stage-plan entry, the
 *   run layout, and optional dataset / Stage A artifact / lineage overrides.
 * @returns Stage summary including per-step metrics, checkpoint artifacts,
 *   eval reports, the selected best report, and the manifest path.
 */
export declare function runDistillationStage(options: {
  loadedWorkload: LoadedTrainingWorkload;
  stageEntry: DistillStagePlanEntry;
  layout: Record<string, string>;
  datasetPath?: string | null;
  stageAArtifact?: string | null;
  stageAArtifactHash?: string | null;
  legacyArtifactDir?: string | null;
  timestamp?: string | Date | null;
  parentArtifacts?: Record<string, unknown>[];
}): Promise<{
  stageId: string;
  trainingStage: 'stage_a' | 'stage_b';
  metrics: Array<Record<string, unknown>>;
  checkpointArtifacts: Record<string, unknown>[];
  evalReports: Record<string, unknown>[];
  bestReport: Record<string, unknown> | null;
  stageManifestPath: string;
  legacyArtifact: Record<string, unknown> | null;
  lastCheckpoint: Record<string, unknown> | null;
}>;
24
+
25
/**
 * Runs distillation Stage A. The implementation delegates directly to
 * `runDistillationStage` with the same options and result shape.
 *
 * @param options Stage inputs: the loaded workload, the stage-plan entry, the
 *   run layout, and optional dataset / Stage A artifact / lineage overrides.
 * @returns Stage summary including per-step metrics, checkpoint artifacts,
 *   eval reports, the selected best report, and the manifest path.
 */
export declare function runDistillationStageA(options: {
  loadedWorkload: LoadedTrainingWorkload;
  stageEntry: DistillStagePlanEntry;
  layout: Record<string, string>;
  datasetPath?: string | null;
  stageAArtifact?: string | null;
  stageAArtifactHash?: string | null;
  legacyArtifactDir?: string | null;
  timestamp?: string | Date | null;
  parentArtifacts?: Record<string, unknown>[];
}): Promise<{
  stageId: string;
  trainingStage: 'stage_a' | 'stage_b';
  metrics: Array<Record<string, unknown>>;
  checkpointArtifacts: Record<string, unknown>[];
  evalReports: Record<string, unknown>[];
  bestReport: Record<string, unknown> | null;
  stageManifestPath: string;
  legacyArtifact: Record<string, unknown> | null;
  lastCheckpoint: Record<string, unknown> | null;
}>;
@@ -0,0 +1,338 @@
1
+ import { mkdir } from 'node:fs/promises';
2
+ import { resolve } from 'node:path';
3
+
4
+ import { AdamOptimizer } from '../optimizer.js';
5
+ import { crossEntropyLoss } from '../loss.js';
6
+ import { clipGradients } from '../clip.js';
7
+ import { TrainingRunner } from '../runner.js';
8
+ import {
9
+ createDistillRuntimeContext,
10
+ createDistillStudentRuntimeModelFixture,
11
+ loadDistillDatasetFromJsonl,
12
+ resolveDistillDataScope,
13
+ } from '../suite.js';
14
+ import { loadCanonicalTranslationDataset } from './dataset.js';
15
+ import { evaluateDistillationModel } from './eval.js';
16
+ import {
17
+ buildDistillArtifactBase,
18
+ writeDistillCheckpointComplete,
19
+ writeDistillCheckpointMetadata,
20
+ writeDistillStageManifest,
21
+ } from './artifacts.js';
22
+ import { appendDistillationScoreboardRow } from './scoreboard.js';
23
+ import { buildDistillationTrainingConfigFromWorkload } from './runtime.js';
24
+
25
/**
 * Formats a step counter as a string left-padded with zeros to at least six
 * characters, as used in `checkpoint-NNNNNN` identifiers.
 *
 * @param step Step number (coerced with `String`).
 * @returns The padded string; longer values are returned unpadded.
 */
function padStep(step) {
  const digits = String(step);
  return digits.padStart(6, '0');
}
28
+
29
/**
 * Extracts a finite numeric value for `metric` from an eval report.
 *
 * Lookup order, first finite number wins:
 *   1. `report[metric]`
 *   2. `report.metrics[metric]`
 *   3. `report.metrics[metric].score`
 *
 * @param report Eval report object (non-objects yield `null`).
 * @param metric Metric name to look up.
 * @returns The finite number found, or `null` when none of the locations holds one.
 */
function resolveComparableMetric(report, metric) {
  const isFiniteNumber = (candidate) => typeof candidate === 'number' && Number.isFinite(candidate);

  if (!report || typeof report !== 'object') {
    return null;
  }

  const topLevel = report[metric];
  if (isFiniteNumber(topLevel)) {
    return topLevel;
  }

  // `metrics` is only consulted when it is a truthy object.
  const metricsBag = report.metrics && typeof report.metrics === 'object' ? report.metrics : null;
  const nestedEntry = metricsBag?.[metric];
  if (isFiniteNumber(nestedEntry)) {
    return nestedEntry;
  }

  const wrappedScore = nestedEntry?.score;
  return isFiniteNumber(wrappedScore) ? wrappedScore : null;
}
45
+
46
/**
 * Picks the report with the best value of `metric`.
 *
 * Reports without a finite value for the metric are ignored. A trimmed goal
 * of exactly `'min'` selects the smallest value; any other goal (including a
 * missing one) selects the largest. On ties the earliest report is kept.
 *
 * @param reports Iterable of eval reports.
 * @param metric Metric name resolved through `resolveComparableMetric`.
 * @param goal Selection goal, `'min'` or anything else for maximize.
 * @returns The winning report, or `null` when no report has a comparable value.
 */
function selectBestReport(reports, metric, goal) {
  const minimize = String(goal || 'max').trim() === 'min';
  let winner = null;
  let winnerScore = null;
  for (const candidate of reports) {
    const score = resolveComparableMetric(candidate, metric);
    if (!Number.isFinite(score)) {
      continue;
    }
    // The first comparable report always wins; later ones must be strictly better.
    const improves = winner === null
      || (minimize ? score < winnerScore : score > winnerScore);
    if (improves) {
      winner = candidate;
      winnerScore = score;
    }
  }
  return winner;
}
68
+
69
/**
 * Tells whether the stage evaluates at every checkpoint.
 *
 * A missing or falsy `evalSchedule` defaults to `'on_checkpoint'`.
 *
 * @param stageEntry Stage-plan entry (may be null/undefined).
 * @returns `true` only when the trimmed schedule equals `'on_checkpoint'`.
 */
function shouldEvalOnCheckpoint(stageEntry) {
  const configured = stageEntry?.evalSchedule || 'on_checkpoint';
  return String(configured).trim() === 'on_checkpoint';
}
73
+
74
/**
 * Tells whether the stage evaluates once, after training finishes.
 *
 * A missing or falsy `evalSchedule` defaults to `'on_checkpoint'`, so the
 * result is `false` unless the schedule is set explicitly.
 *
 * @param stageEntry Stage-plan entry (may be null/undefined).
 * @returns `true` only when the trimmed schedule equals `'final'`.
 */
function shouldEvalAtEnd(stageEntry) {
  const configured = stageEntry?.evalSchedule || 'on_checkpoint';
  return String(configured).trim() === 'final';
}
78
+
79
/**
 * Runs one distillation stage end to end.
 *
 * Steps, in order:
 *   1. Validate that the workload uses `batchSize=1` and `accumSteps=1`.
 *   2. Build the stage training config and its hash.
 *   3. Load the dataset twice: once as the training dataset report and once
 *      as the canonical dataset whose `canonicalHash` is stamped on artifacts.
 *   4. Create the distill runtime context and the student model fixture.
 *   5. Train for one epoch (capped at `stageEntry.steps`). On every checkpoint
 *      write checkpoint metadata and a completion marker and, when the eval
 *      schedule is `on_checkpoint`, evaluate and append scoreboard rows.
 *   6. When the eval schedule is `final`, evaluate the last checkpoint.
 *   7. Select the best report and write the stage manifest.
 *
 * The fixture and the runtime context are always released, even when an
 * earlier step throws.
 *
 * @param options Stage inputs: `loadedWorkload`, `stageEntry`, `layout`, and
 *   optional `datasetPath`, `stageAArtifact`, `stageAArtifactHash`,
 *   `legacyArtifactDir`, `timestamp`, `parentArtifacts`.
 * @returns Stage summary: `stageId`, `trainingStage`, `metrics`,
 *   `checkpointArtifacts`, `evalReports`, `bestReport`, `stageManifestPath`,
 *   `legacyArtifact`, `lastCheckpoint`.
 * @throws Error when batch size or accumulation steps are not 1, or when the
 *   dataset cannot be resolved.
 */
export async function runDistillationStage(options) {
  const loadedWorkload = options.loadedWorkload;
  const workload = loadedWorkload.workload;
  const stageEntry = options.stageEntry;
  const datasetPath = options.datasetPath || workload.datasetPath;
  const legacyArtifactDir = options.legacyArtifactDir
    || resolve(options.layout.runRoot, 'legacy-stage-artifacts');
  if (workload.training.batchSize !== 1) {
    throw new Error('Distillation stage currently requires training.batchSize=1.');
  }
  if (workload.training.accumSteps !== 1) {
    throw new Error('Distillation stage currently requires training.accumSteps=1.');
  }
  const configBundle = buildDistillationTrainingConfigFromWorkload(loadedWorkload, stageEntry, {
    datasetPath,
    artifactDir: legacyArtifactDir,
    stageAArtifact: options.stageAArtifact || null,
    stageAArtifactHash: options.stageAArtifactHash || null,
  });
  const distillDataScope = resolveDistillDataScope({
    distillSourceLangs: workload.pipeline.sourceLangs,
    distillTargetLangs: workload.pipeline.targetLangs,
    distillPairAllowlist: workload.pipeline.pairAllowlist,
    strictPairContract: workload.pipeline.strictPairContract === true,
  }, configBundle.trainingConfig.training);
  const distillDatasetReport = await loadDistillDatasetFromJsonl(datasetPath, distillDataScope);
  if (!distillDatasetReport) {
    throw new Error(`Unable to resolve distillation dataset "${datasetPath}".`);
  }
  // Loaded separately so artifacts can record the canonical dataset hash.
  const canonicalDataset = await loadCanonicalTranslationDataset(datasetPath, {
    strictPairContract: workload.pipeline.strictPairContract === true,
    sourceLangs: workload.pipeline.sourceLangs,
    targetLangs: workload.pipeline.targetLangs,
    pairAllowlist: workload.pipeline.pairAllowlist,
  });
  const distillRuntime = await createDistillRuntimeContext({
    teacherModelId: workload.teacherModelId,
    studentModelId: workload.studentModelId,
    trainingStage: configBundle.internalStage,
    studentGraphMode: workload.pipeline.studentGraphMode,
  }, configBundle.trainingConfig.training);
  let fixture = null;
  try {
    fixture = await createDistillStudentRuntimeModelFixture({
      training: configBundle.trainingConfig.training,
    }, {
      distillRuntime,
      studentGraphMode: workload.pipeline.studentGraphMode,
    });
    const stageId = stageEntry.id;
    const stageCheckpointsDir = resolve(options.layout.checkpoints, stageId);
    await mkdir(stageCheckpointsDir, { recursive: true });
    const dataset = distillDatasetReport.createDataset({
      batchSize: workload.training.batchSize,
      shuffle: false,
      seed: workload.seed,
      distillRuntime,
    });
    const checkpointArtifacts = [];
    const evalReports = [];
    const runner = new TrainingRunner(fixture.config, {
      optimizer: new AdamOptimizer(fixture.config),
      crossEntropyLoss,
      clipGradients,
      // Checkpoint state lives at <checkpoints>/<stageId>/checkpoint-NNNNNN/state.json.
      resolveCheckpointKey({ step }) {
        const checkpointId = `checkpoint-${padStep(step)}`;
        return resolve(stageCheckpointsDir, checkpointId, 'state.json');
      },
      onCheckpoint: async (checkpoint) => {
        const checkpointId = `checkpoint-${padStep(checkpoint.step)}`;
        const checkpointBase = buildDistillArtifactBase(loadedWorkload, {
          prefix: 'dst_ckpt',
          artifactType: 'training_checkpoint',
          datasetPath: distillDatasetReport.absolutePath,
          datasetHash: canonicalDataset.canonicalHash,
          stage: stageId,
          checkpointStep: checkpoint.step,
          configHash: configBundle.trainingConfigHash,
          parentArtifacts: options.parentArtifacts || [],
        });
        const metadataPayload = {
          ...checkpointBase,
          checkpointId,
          checkpointPath: checkpoint.path,
          step: checkpoint.step,
          epoch: checkpoint.epoch,
          batch: checkpoint.batch,
          optimizerStatePresent: true,
          schedulerStatePresent: workload.training.optimizer.scheduler.enabled === true,
          stageArtifact: options.stageAArtifact || null,
          resumeLineage: checkpoint.metadata?.lineage || null,
        };
        const metadataFile = await writeDistillCheckpointMetadata(
          options.layout,
          stageId,
          checkpointId,
          metadataPayload
        );
        // The completion marker is written after the metadata file and points back to it.
        const completeFile = await writeDistillCheckpointComplete(
          options.layout,
          stageId,
          checkpointId,
          {
            ...metadataPayload,
            metadataPath: metadataFile.path,
            finalized: true,
          }
        );
        checkpointArtifacts.push({
          checkpointId,
          checkpointPath: checkpoint.path,
          metadataPath: metadataFile.path,
          completePath: completeFile.path,
          step: checkpoint.step,
        });
        if (!shouldEvalOnCheckpoint(stageEntry)) {
          return;
        }
        const reports = await evaluateDistillationModel({
          loadedWorkload,
          layout: options.layout,
          stageId,
          checkpointId,
          checkpointStep: checkpoint.step,
          checkpointPath: checkpoint.path,
          distillRuntime,
          model: fixture.model,
          configHash: configBundle.trainingConfigHash,
          parentArtifacts: [{
            artifactType: 'training_checkpoint',
            path: metadataFile.path,
            checkpointId,
          }],
        });
        for (const report of reports) {
          evalReports.push(report);
          await appendDistillationScoreboardRow(options.layout, stageId, {
            artifactType: 'training_scoreboard',
            schemaVersion: 1,
            generatedAt: new Date().toISOString(),
            stage: stageId,
            checkpointId,
            checkpointStep: checkpoint.step,
            evalDatasetId: report.evalDatasetId,
            selectionMetric: stageEntry.selectionMetric,
            selectionGoal: stageEntry.selectionGoal,
            primaryMetric: report.primaryMetric,
            primaryScore: report.primaryScore,
            bleu: report.bleu,
            chrf: report.chrf,
            reportPath: report.reportPath || null,
            metrics: {
              bleu: report.bleu,
              chrf: report.chrf,
              primaryScore: report.primaryScore,
            },
          }, {
            selectionMetric: stageEntry.selectionMetric,
            selectionGoal: stageEntry.selectionGoal,
          });
        }
      },
    });
    const metrics = await runner.run(fixture.model, dataset, {
      epochs: 1,
      batchSize: workload.training.batchSize,
      shuffle: false,
      maxSteps: stageEntry.steps,
      checkpointEvery: stageEntry.checkpointEvery,
      modelId: workload.studentModelId,
      trainingStage: configBundle.internalStage,
      runtimePreset: null,
      command: 'distill',
      surface: 'node',
      timestamp: options.timestamp || null,
      distillArtifactDir: legacyArtifactDir,
      stageAArtifact: options.stageAArtifact || null,
      stageAArtifactHash: options.stageAArtifactHash || null,
      teacherModelId: workload.teacherModelId,
      studentModelId: workload.studentModelId,
      distillDatasetId: workload.datasetId,
      distillDatasetPath: distillDatasetReport.absolutePath,
      distillSourceLangs: workload.pipeline.sourceLangs,
      distillTargetLangs: workload.pipeline.targetLangs,
      distillPairAllowlist: workload.pipeline.pairAllowlist,
      strictPairContract: workload.pipeline.strictPairContract === true,
    });
    // NOTE(review): reports from the "final" schedule are collected but, unlike
    // per-checkpoint reports, are not appended to the scoreboard and carry no
    // parentArtifacts. Confirm this asymmetry is intended.
    if (shouldEvalAtEnd(stageEntry) && runner.lastCheckpoint) {
      const finalCheckpointId = `checkpoint-${padStep(runner.lastCheckpoint.step)}`;
      const reports = await evaluateDistillationModel({
        loadedWorkload,
        layout: options.layout,
        stageId,
        checkpointId: finalCheckpointId,
        checkpointStep: runner.lastCheckpoint.step,
        checkpointPath: runner.lastCheckpoint.path || null,
        distillRuntime,
        model: fixture.model,
        configHash: configBundle.trainingConfigHash,
      });
      evalReports.push(...reports);
    }
    const bestReport = selectBestReport(
      evalReports,
      stageEntry.selectionMetric,
      stageEntry.selectionGoal
    );
    const stageManifestPayload = {
      ...buildDistillArtifactBase(loadedWorkload, {
        prefix: 'dst_stage',
        artifactType: 'distill_stage_manifest',
        datasetPath: distillDatasetReport.absolutePath,
        datasetHash: canonicalDataset.canonicalHash,
        stage: stageId,
        checkpointStep: runner.lastCheckpoint?.step ?? null,
        configHash: configBundle.trainingConfigHash,
        parentArtifacts: options.parentArtifacts || [],
      }),
      stageId,
      trainingStage: configBundle.internalStage,
      objective: stageEntry.objective,
      stagePlanEntry: stageEntry,
      stepCount: Array.isArray(metrics) ? metrics.length : 0,
      checkpointCount: checkpointArtifacts.length,
      bestCheckpointId: bestReport?.checkpointId || null,
      selectionMetric: stageEntry.selectionMetric,
      selectionGoal: stageEntry.selectionGoal,
      checkpointArtifacts,
      // Only a compact summary of each report is stored in the manifest.
      evalReports: evalReports.map((report) => ({
        checkpointId: report.checkpointId,
        evalDatasetId: report.evalDatasetId,
        reportPath: report.reportPath || null,
        primaryMetric: report.primaryMetric,
        primaryScore: report.primaryScore,
        bleu: report.bleu,
        chrf: report.chrf,
      })),
      legacyArtifact: runner.lastArtifact || null,
      lastCheckpoint: runner.lastCheckpoint || null,
    };
    const stageManifest = await writeDistillStageManifest(options.layout, stageManifestPayload);
    return {
      stageId,
      trainingStage: configBundle.internalStage,
      metrics,
      checkpointArtifacts,
      evalReports,
      bestReport,
      stageManifestPath: stageManifest.path,
      legacyArtifact: runner.lastArtifact || null,
      lastCheckpoint: runner.lastCheckpoint || null,
    };
  } finally {
    // Release the fixture first, awaiting it in case cleanup is asynchronous.
    // The nested finally guarantees the runtime context is released even when
    // the fixture cleanup throws or rejects.
    try {
      await fixture?.cleanup?.();
    } finally {
      await distillRuntime.cleanup();
    }
  }
}
335
+
336
/**
 * Runs distillation Stage A by delegating to `runDistillationStage` with the
 * options unchanged.
 *
 * @param options Same options object accepted by `runDistillationStage`.
 * @returns The stage summary produced by `runDistillationStage`.
 */
export async function runDistillationStageA(options) {
  const stageResult = await runDistillationStage(options);
  return stageResult;
}
@@ -0,0 +1,24 @@
1
+ import type { LoadedTrainingWorkload, DistillStagePlanEntry } from '../workloads.js';
2
+
3
/**
 * Runs distillation Stage B.
 *
 * The implementation needs a Stage A artifact path: it uses
 * `options.stageAArtifact` when provided and otherwise falls back to
 * `options.priorStageResult.legacyArtifact.manifestPath`. If neither yields a
 * value the returned promise rejects with an `Error`. The resolved path is
 * then handed, together with the remaining options, to the shared stage
 * runner.
 *
 * @param options Stage B inputs (see the individual fields).
 * @returns The stage result produced by the shared stage runner.
 * @throws Error when no Stage A artifact path can be determined.
 */
export declare function runDistillationStageB(options: {
  loadedWorkload: LoadedTrainingWorkload;
  stageEntry: DistillStagePlanEntry;
  layout: Record<string, string>;
  datasetPath?: string | null;
  /** Stage A artifact path; falls back to the prior stage's legacy artifact manifest path. */
  stageAArtifact?: string | null;
  /** Stage A artifact hash; falls back to the prior stage's legacy artifact manifest hash, else `null`. */
  stageAArtifactHash?: string | null;
  legacyArtifactDir?: string | null;
  timestamp?: string | Date | null;
  parentArtifacts?: Record<string, unknown>[];
  /** Result of the preceding stage, consulted only for the Stage A artifact fallbacks. */
  priorStageResult?: Record<string, unknown> | null;
}): Promise<{
  stageId: string;
  trainingStage: 'stage_a' | 'stage_b';
  metrics: Array<Record<string, unknown>>;
  checkpointArtifacts: Record<string, unknown>[];
  evalReports: Record<string, unknown>[];
  bestReport: Record<string, unknown> | null;
  /** Path of the stage manifest written for this stage. */
  stageManifestPath: string;
  legacyArtifact: Record<string, unknown> | null;
  lastCheckpoint: Record<string, unknown> | null;
}>;
@@ -0,0 +1,20 @@
1
+ import { resolve } from 'node:path';
2
+
3
+ import { runDistillationStage } from './stage-a.js';
4
+
5
/**
 * Runs distillation Stage B on top of a Stage A artifact.
 *
 * The Stage A artifact path is taken from `options.stageAArtifact`, or, when
 * that is falsy, from `options.priorStageResult.legacyArtifact.manifestPath`.
 * The chosen path is stringified and made absolute with `path.resolve` before
 * being passed on. The artifact hash follows the same precedence
 * (`options.stageAArtifactHash`, then the prior stage's
 * `legacyArtifact.manifestHash`) and defaults to `null`.
 *
 * @param options Stage options; every property is forwarded to
 *   `runDistillationStage`, with `stageAArtifact` and `stageAArtifactHash`
 *   replaced by the resolved values.
 * @returns The promise result of `runDistillationStage`.
 * @throws Error when no Stage A artifact path is available from either source.
 */
export async function runDistillationStageB(options) {
  const priorResult = options.priorStageResult || null;
  // Read lazily so the prior stage is only consulted when the explicit option is falsy.
  const fromLegacyArtifact = (field) => priorResult?.legacyArtifact?.[field];

  const artifactPath = options.stageAArtifact || fromLegacyArtifact('manifestPath');
  if (!artifactPath) {
    throw new Error('Distillation stage-b requires a Stage A artifact path.');
  }

  const stageOptions = {
    ...options,
    stageAArtifact: resolve(String(artifactPath)),
    stageAArtifactHash: options.stageAArtifactHash || fromLegacyArtifact('manifestHash') || null,
  };
  return runDistillationStage(stageOptions);
}
@@ -10,6 +10,15 @@ export { exportLoRAAdapter } from './export.js';
10
10
  export { DynamicLossScaler, detectOverflow } from './loss-scaling.js';
11
11
  export { TrainingRunner, runTraining } from './runner.js';
12
12
  export { runTrainingSuite, runTrainingBenchSuite, trainingHarness } from './suite.js';
13
+ export { runTrainingOperatorCommand } from './operator-command.js';
14
+ export {
15
+ runLoraPipeline,
16
+ evaluateLoraCheckpoint,
17
+ exportLoraCheckpoint,
18
+ watchLoraCheckpoints,
19
+ compareLoraRun,
20
+ qualityGateLoraRun,
21
+ } from './lora-pipeline.js';
13
22
  export type {
14
23
  TrainingObjective,
15
24
  TrainingObjectiveContext,
@@ -31,6 +40,7 @@ export type {
31
40
  DistillArtifactSession,
32
41
  DistillArtifactFinalizeResult,
33
42
  } from './artifacts.js';
43
+ export * as distillation from './distillation/index.js';
34
44
  export {
35
45
  createDistillArtifactSession,
36
46
  resolveDistillTrainingContract,