@simulatte/doppler 0.1.5 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. package/README.md +23 -8
  2. package/package.json +7 -4
  3. package/src/config/kernels/kernel-ref-digests.js +39 -39
  4. package/src/config/kernels/registry.json +42 -2
  5. package/src/config/loader.js +31 -2
  6. package/src/config/merge.js +18 -0
  7. package/src/config/presets/models/qwen3.json +9 -2
  8. package/src/config/presets/models/transformer.json +5 -0
  9. package/src/config/required-inference-fields-contract-check.js +6 -0
  10. package/src/config/schema/inference-defaults.schema.js +3 -0
  11. package/src/config/schema/inference.schema.d.ts +9 -0
  12. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  13. package/src/config/schema/manifest.schema.d.ts +6 -0
  14. package/src/config/schema/manifest.schema.js +3 -0
  15. package/src/converter/rope-config.js +42 -0
  16. package/src/gpu/device.js +58 -0
  17. package/src/gpu/kernels/attention.js +98 -0
  18. package/src/gpu/kernels/bias_add.wgsl +8 -6
  19. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  20. package/src/gpu/kernels/conv2d.js +1 -1
  21. package/src/gpu/kernels/conv2d.wgsl +7 -8
  22. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  23. package/src/gpu/kernels/depthwise_conv2d.js +2 -1
  24. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  25. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  26. package/src/gpu/kernels/grouped_pointwise_conv2d.js +2 -1
  27. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  28. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  29. package/src/gpu/kernels/matmul.js +25 -0
  30. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  31. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  32. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  33. package/src/gpu/kernels/relu.js +15 -2
  34. package/src/gpu/kernels/relu.wgsl +2 -1
  35. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  36. package/src/gpu/kernels/repeat_channels.js +1 -1
  37. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  38. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  39. package/src/gpu/kernels/residual.js +44 -8
  40. package/src/gpu/kernels/residual.wgsl +6 -3
  41. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  42. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  43. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  44. package/src/gpu/kernels/rmsnorm.js +58 -6
  45. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  46. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  47. package/src/gpu/kernels/rope.d.ts +2 -0
  48. package/src/gpu/kernels/rope.js +11 -1
  49. package/src/gpu/kernels/rope.wgsl +56 -40
  50. package/src/gpu/kernels/sana_linear_attention.js +1 -2
  51. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  52. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  53. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  54. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  55. package/src/gpu/kernels/silu.d.ts +1 -0
  56. package/src/gpu/kernels/silu.js +32 -14
  57. package/src/gpu/kernels/silu.wgsl +19 -9
  58. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  59. package/src/gpu/kernels/transpose.js +15 -2
  60. package/src/gpu/kernels/transpose.wgsl +5 -6
  61. package/src/gpu/kernels/upsample2d.js +2 -1
  62. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  63. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  64. package/src/gpu/kernels/utils.js +16 -1
  65. package/src/inference/browser-harness.js +47 -1
  66. package/src/inference/pipelines/diffusion/pipeline.js +15 -6
  67. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  68. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  69. package/src/inference/pipelines/text/attention/record.js +11 -2
  70. package/src/inference/pipelines/text/attention/run.js +11 -2
  71. package/src/inference/pipelines/text/chat-format.js +25 -1
  72. package/src/inference/pipelines/text/config.d.ts +4 -0
  73. package/src/inference/pipelines/text/config.js +68 -1
  74. package/src/inference/pipelines/text/execution-plan.js +23 -31
  75. package/src/inference/pipelines/text/execution-v0.js +29 -2
  76. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  77. package/src/inference/pipelines/text/init.d.ts +4 -0
  78. package/src/inference/pipelines/text/init.js +56 -9
  79. package/src/inference/pipelines/text/layer.js +11 -0
  80. package/src/inference/pipelines/text.js +4 -0
  81. package/src/inference/tokenizers/bundled.js +156 -33
  82. package/src/rules/tooling/command-runtime.rules.json +18 -0
  83. package/src/tooling/command-api.d.ts +27 -1
  84. package/src/tooling/command-api.js +142 -3
  85. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  86. package/src/tooling/node-browser-command-runner.js +58 -3
  87. package/src/tooling/node-command-runner.js +15 -0
  88. package/src/tooling/node-webgpu.js +9 -87
  89. package/src/training/checkpoint-watch.d.ts +7 -0
  90. package/src/training/checkpoint-watch.js +106 -0
  91. package/src/training/checkpoint.d.ts +6 -1
  92. package/src/training/checkpoint.js +12 -2
  93. package/src/training/distillation/artifacts.d.ts +71 -0
  94. package/src/training/distillation/artifacts.js +132 -0
  95. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  96. package/src/training/distillation/checkpoint-watch.js +57 -0
  97. package/src/training/distillation/dataset.d.ts +59 -0
  98. package/src/training/distillation/dataset.js +337 -0
  99. package/src/training/distillation/eval.d.ts +34 -0
  100. package/src/training/distillation/eval.js +310 -0
  101. package/src/training/distillation/index.d.ts +29 -0
  102. package/src/training/distillation/index.js +29 -0
  103. package/src/training/distillation/runtime.d.ts +20 -0
  104. package/src/training/distillation/runtime.js +121 -0
  105. package/src/training/distillation/scoreboard.d.ts +6 -0
  106. package/src/training/distillation/scoreboard.js +8 -0
  107. package/src/training/distillation/stage-a.d.ts +45 -0
  108. package/src/training/distillation/stage-a.js +338 -0
  109. package/src/training/distillation/stage-b.d.ts +24 -0
  110. package/src/training/distillation/stage-b.js +20 -0
  111. package/src/training/index.d.ts +10 -0
  112. package/src/training/index.js +10 -0
  113. package/src/training/lora-pipeline.d.ts +40 -0
  114. package/src/training/lora-pipeline.js +796 -0
  115. package/src/training/operator-artifacts.d.ts +62 -0
  116. package/src/training/operator-artifacts.js +140 -0
  117. package/src/training/operator-command.d.ts +5 -0
  118. package/src/training/operator-command.js +453 -0
  119. package/src/training/operator-eval.d.ts +48 -0
  120. package/src/training/operator-eval.js +230 -0
  121. package/src/training/operator-scoreboard.d.ts +5 -0
  122. package/src/training/operator-scoreboard.js +44 -0
  123. package/src/training/runner.d.ts +52 -0
  124. package/src/training/runner.js +29 -4
  125. package/src/training/suite.d.ts +112 -0
  126. package/src/training/suite.js +9 -9
  127. package/src/training/workloads.d.ts +164 -0
  128. package/src/training/workloads.js +539 -0
  129. package/src/version.js +1 -1
  130. package/tools/doppler-cli.js +137 -40
@@ -0,0 +1,230 @@
1
+ import { readFile } from 'node:fs/promises';
2
+ import { resolve } from 'node:path';
3
+
4
+ import { parseJsonl } from './datasets/jsonl.js';
5
+
/**
 * Split arbitrary input into whitespace-delimited tokens.
 *
 * Non-string and nullish inputs are coerced to strings first; an empty or
 * all-whitespace input yields an empty array rather than [''].
 *
 * @param {unknown} text - Raw text (coerced via String()).
 * @returns {string[]} Whitespace-separated tokens, possibly empty.
 */
function asTokenSequence(text) {
  const normalized = String(text ?? '').trim();
  if (normalized === '') {
    return [];
  }
  // After trim(), splitting on runs of whitespace can never produce
  // empty tokens, so no additional filtering is required.
  return normalized.split(/\s+/);
}
12
+
/**
 * Count character n-grams of a fixed order in the given text.
 *
 * The text is trimmed and iterated by Unicode code point (Array.from), so
 * surrogate pairs are treated as single characters.
 *
 * @param {unknown} text - Raw text (coerced via String()).
 * @param {number} n - N-gram order (number of code points per gram).
 * @returns {Map<string, number>} Gram -> occurrence count; empty when the
 *   text has fewer than n code points.
 */
function extractCharacterNgrams(text, n) {
  const codePoints = Array.from(String(text ?? '').trim());
  const counts = new Map();
  // When codePoints.length < n the upper bound is negative and the loop
  // body never runs, returning an empty Map.
  const lastStart = codePoints.length - n;
  for (let start = 0; start <= lastStart; start += 1) {
    const gram = codePoints.slice(start, start + n).join('');
    counts.set(gram, (counts.get(gram) ?? 0) + 1);
  }
  return counts;
}
25
+
/**
 * Sum the clipped overlap between two count maps.
 *
 * For every key present in `source`, contributes min(sourceCount,
 * targetCount), i.e. the standard clipped n-gram match count used by
 * BLEU/chrF style metrics.
 *
 * @param {Map<string, number>} source - Counts to be clipped.
 * @param {Map<string, number>} target - Reference counts used as the clip.
 * @returns {number} Total clipped overlap.
 */
function countOverlap(source, target) {
  let shared = 0;
  source.forEach((sourceCount, key) => {
    shared += Math.min(sourceCount, target.get(key) ?? 0);
  });
  return shared;
}
34
+
/**
 * Accumulate corpus-level BLEU sufficient statistics.
 *
 * For every hypothesis/reference pair and every n-gram order 1..maxOrder,
 * tallies the clipped n-gram matches and the number of possible hypothesis
 * n-grams, plus total token lengths for the brevity penalty.
 *
 * @param {unknown[]} hypotheses - Candidate texts (coerced per pair).
 * @param {unknown[]} references - Reference texts, index-aligned.
 * @param {number} [maxOrder=4] - Highest n-gram order to score.
 * @returns {{matchesByOrder: number[], possibleByOrder: number[],
 *   hypothesisLength: number, referenceLength: number}}
 */
function computeBleuStats(hypotheses, references, maxOrder = 4) {
  // Token n-grams are keyed with \u0001 as an unlikely join separator so
  // multi-token grams cannot collide with single tokens.
  const countNgrams = (tokens, order) => {
    const counts = new Map();
    for (let start = 0; start <= tokens.length - order; start += 1) {
      const key = tokens.slice(start, start + order).join('\u0001');
      counts.set(key, (counts.get(key) ?? 0) + 1);
    }
    return counts;
  };

  const matchesByOrder = new Array(maxOrder).fill(0);
  const possibleByOrder = new Array(maxOrder).fill(0);
  let hypothesisLength = 0;
  let referenceLength = 0;

  for (let pair = 0; pair < hypotheses.length; pair += 1) {
    const hypTokens = asTokenSequence(hypotheses[pair]);
    const refTokens = asTokenSequence(references[pair]);
    hypothesisLength += hypTokens.length;
    referenceLength += refTokens.length;
    for (let order = 1; order <= maxOrder; order += 1) {
      const hypCounts = countNgrams(hypTokens, order);
      const refCounts = countNgrams(refTokens, order);
      matchesByOrder[order - 1] += countOverlap(hypCounts, refCounts);
      possibleByOrder[order - 1] += Math.max(0, hypTokens.length - order + 1);
    }
  }

  return {
    matchesByOrder,
    possibleByOrder,
    hypothesisLength,
    referenceLength,
  };
}
69
+
/**
 * Corpus-level BLEU with add-one smoothed n-gram precisions.
 *
 * Precisions use (matches + 1) / (possible + 1) smoothing; the geometric
 * mean is taken over `maxOrder` orders and multiplied by the standard
 * brevity penalty.
 *
 * @param {unknown[]} hypotheses - Candidate texts.
 * @param {unknown[]} references - Index-aligned reference texts.
 * @param {{maxOrder?: number}} [options] - maxOrder defaults to 4.
 * @returns {{score: number, brevityPenalty: number, precisions: number[],
 *   hypothesisLength: number, referenceLength: number}}
 * @throws {Error} When the inputs are not equally sized arrays.
 */
export function computeBleuScore(hypotheses, references, options = {}) {
  const requestedOrder = options.maxOrder;
  const maxOrder = Number.isInteger(requestedOrder) && requestedOrder > 0
    ? requestedOrder
    : 4;
  const sameShape = Array.isArray(hypotheses)
    && Array.isArray(references)
    && hypotheses.length === references.length;
  if (!sameShape) {
    throw new Error('computeBleuScore requires equally sized hypothesis and reference arrays.');
  }
  if (hypotheses.length === 0) {
    return {
      score: 0,
      brevityPenalty: 0,
      precisions: new Array(maxOrder).fill(0),
      hypothesisLength: 0,
      referenceLength: 0,
    };
  }

  const stats = computeBleuStats(hypotheses, references, maxOrder);
  const precisions = [];
  let precisionLogSum = 0;
  for (let orderIndex = 0; orderIndex < maxOrder; orderIndex += 1) {
    const matches = stats.matchesByOrder[orderIndex];
    const possible = stats.possibleByOrder[orderIndex];
    // Add-one smoothing avoids zero precisions wiping out the geometric mean.
    const precision = possible === 0 ? 0 : (matches + 1) / (possible + 1);
    precisions.push(precision);
    // Floor at 1e-16 so Math.log never sees 0 (-Infinity).
    precisionLogSum += Math.log(Math.max(precision, 1e-16));
  }

  const brevityPenalty = stats.hypothesisLength > stats.referenceLength
    ? 1
    : Math.exp(1 - (stats.referenceLength / Math.max(stats.hypothesisLength, 1)));
  return {
    score: brevityPenalty * Math.exp(precisionLogSum / maxOrder),
    brevityPenalty,
    precisions,
    hypothesisLength: stats.hypothesisLength,
    referenceLength: stats.referenceLength,
  };
}
111
+
/**
 * Corpus-level chrF score over character n-grams.
 *
 * Precision/recall are averaged uniformly over orders 1..maxOrder and
 * combined with an F-beta harmonic mean (beta defaults to 2, i.e. chrF2).
 *
 * @param {unknown[]} hypotheses - Candidate texts.
 * @param {unknown[]} references - Index-aligned reference texts.
 * @param {{maxOrder?: number, beta?: number}} [options]
 * @returns {{score: number, precision: number, recall: number}}
 * @throws {Error} When the inputs are not equally sized arrays.
 */
export function computeChrfScore(hypotheses, references, options = {}) {
  const maxOrder = Number.isInteger(options.maxOrder) && options.maxOrder > 0
    ? options.maxOrder
    : 6;
  const beta = Number.isFinite(options.beta) && options.beta > 0 ? options.beta : 2;
  const sameShape = Array.isArray(hypotheses)
    && Array.isArray(references)
    && hypotheses.length === references.length;
  if (!sameShape) {
    throw new Error('computeChrfScore requires equally sized hypothesis and reference arrays.');
  }
  if (hypotheses.length === 0) {
    return {
      score: 0,
      precision: 0,
      recall: 0,
    };
  }

  let precisionSum = 0;
  let recallSum = 0;
  for (let order = 1; order <= maxOrder; order += 1) {
    let shared = 0;
    let hypTotal = 0;
    let refTotal = 0;
    for (let pair = 0; pair < hypotheses.length; pair += 1) {
      const hypGrams = extractCharacterNgrams(hypotheses[pair], order);
      const refGrams = extractCharacterNgrams(references[pair], order);
      shared += countOverlap(hypGrams, refGrams);
      for (const count of hypGrams.values()) {
        hypTotal += count;
      }
      for (const count of refGrams.values()) {
        refTotal += count;
      }
    }
    // Orders with no n-grams at all contribute 0 to the averages.
    if (hypTotal > 0) {
      precisionSum += shared / hypTotal;
    }
    if (refTotal > 0) {
      recallSum += shared / refTotal;
    }
  }

  const precision = precisionSum / maxOrder;
  const recall = recallSum / maxOrder;
  const betaSquared = beta * beta;
  const denominator = (betaSquared * precision) + recall;
  // precision + recall === 0 iff the F-beta denominator is 0 (beta > 0).
  const score = denominator === 0
    ? 0
    : ((1 + betaSquared) * precision * recall) / denominator;
  return { score, precision, recall };
}
157
+
/**
 * Fraction of hypothesis/reference pairs that match exactly after trimming.
 *
 * Both sides are coerced to strings and trimmed before comparison; nullish
 * entries compare as the empty string.
 *
 * @param {unknown[]} hypotheses - Candidate texts.
 * @param {unknown[]} references - Index-aligned reference texts.
 * @returns {{score: number, matches: number, total: number}}
 * @throws {Error} When the inputs are not equally sized arrays.
 */
export function computeExactMatch(hypotheses, references) {
  const sameShape = Array.isArray(hypotheses)
    && Array.isArray(references)
    && hypotheses.length === references.length;
  if (!sameShape) {
    throw new Error('computeExactMatch requires equally sized hypothesis and reference arrays.');
  }
  const total = hypotheses.length;
  if (total === 0) {
    return { score: 0, matches: 0, total: 0 };
  }
  const canonical = (value) => String(value ?? '').trim();
  let matches = 0;
  hypotheses.forEach((hypothesis, index) => {
    if (canonical(hypothesis) === canonical(references[index])) {
      matches += 1;
    }
  });
  return {
    score: matches / total,
    matches,
    total,
  };
}
177
+
/**
 * Classification accuracy, expressed as trimmed exact-match agreement.
 *
 * @param {unknown[]} labels - Gold labels (used as references).
 * @param {unknown[]} predictions - Predicted labels (used as hypotheses).
 * @returns {{score: number, matches: number, total: number}}
 * @throws {Error} When the inputs are not equally sized arrays.
 */
export function computeAccuracy(labels, predictions) {
  // Delegate to exact match with the arguments swapped into
  // (hypotheses, references) order.
  const agreement = computeExactMatch(predictions, labels);
  return agreement;
}
181
+
/**
 * Dispatch metric computation for a given evaluation kind.
 *
 * - 'translation'      -> BLEU (primary) + chrF
 * - 'text_generation'  -> trimmed exact match
 * - 'classification'   -> accuracy (hypotheses are predictions)
 * - 'retrieval'/'custom' -> not implemented; throws
 *
 * @param {unknown} evalKind - Evaluation kind identifier.
 * @param {unknown[]} hypotheses - Model outputs.
 * @param {unknown[]} references - Gold references/labels.
 * @param {{bleu?: object, chrf?: object}} [options] - Per-metric options.
 * @returns {Record<string, unknown>} Metric bundle with primaryMetric/primaryScore.
 * @throws {Error} For unknown or unimplemented kinds.
 */
export function computeEvalMetrics(evalKind, hypotheses, references, options = {}) {
  const normalizedKind = String(evalKind || '').trim();
  switch (normalizedKind) {
    case 'translation': {
      const bleu = computeBleuScore(hypotheses, references, options.bleu || {});
      const chrf = computeChrfScore(hypotheses, references, options.chrf || {});
      return {
        bleu,
        chrf,
        primaryMetric: 'bleu',
        primaryScore: bleu.score,
      };
    }
    case 'text_generation': {
      const exactMatch = computeExactMatch(hypotheses, references);
      return {
        exactMatch,
        primaryMetric: 'exact_match',
        primaryScore: exactMatch.score,
      };
    }
    case 'classification': {
      // computeAccuracy expects (labels, predictions).
      const accuracy = computeAccuracy(references, hypotheses);
      return {
        accuracy,
        primaryMetric: 'accuracy',
        primaryScore: accuracy.score,
      };
    }
    case 'retrieval':
    case 'custom':
      throw new Error(`Eval kind "${normalizedKind}" requires a custom evaluator and is not yet implemented.`);
    default:
      throw new Error(`Unsupported eval kind "${normalizedKind}".`);
  }
}
215
+
/**
 * Load an evaluation dataset from a `.json` (array) or JSONL file.
 *
 * The path is resolved to an absolute path; files ending in '.json' are
 * parsed as a single JSON document, everything else (including '.jsonl')
 * goes through parseJsonl.
 *
 * @param {unknown} datasetPath - Path to the dataset file.
 * @returns {Promise<{absolutePath: string, rows: unknown[], raw: string}>}
 * @throws {Error} When the parsed content is not an array.
 */
export async function loadEvalDataset(datasetPath) {
  const absolutePath = resolve(String(datasetPath));
  const raw = await readFile(absolutePath, 'utf8');
  let rows;
  if (absolutePath.endsWith('.json')) {
    rows = JSON.parse(raw);
  } else {
    rows = parseJsonl(raw);
  }
  if (!Array.isArray(rows)) {
    throw new Error(`Eval dataset "${absolutePath}" must be a JSON array or JSONL file.`);
  }
  return {
    absolutePath,
    rows,
    raw,
  };
}
@@ -0,0 +1,5 @@
1
+ export declare function appendScoreboardRow(
2
+ scoreboardDir: string,
3
+ row: Record<string, unknown>,
4
+ options?: { selectionMetric?: string | null; selectionGoal?: string | null }
5
+ ): Promise<{ rowsPath: string; summaryPath: string }>;
@@ -0,0 +1,44 @@
1
+ import { join } from 'node:path';
2
+
3
+ import { writeJsonArtifact, writeNdjsonRow } from './operator-artifacts.js';
4
+
/**
 * Extract a finite numeric metric from a scoreboard row.
 *
 * Looks for `row[metric]` first, then falls back to `row.metrics[metric]`.
 * Non-numeric and non-finite values (NaN, Infinity) are rejected.
 *
 * @param {unknown} row - Scoreboard row (plain object expected).
 * @param {string} metric - Metric key to look up.
 * @returns {number | null} The finite metric value, or null when absent.
 */
function resolveComparableMetric(row, metric) {
  if (!row || typeof row !== 'object') {
    return null;
  }
  const isUsable = (value) => typeof value === 'number' && Number.isFinite(value);
  if (isUsable(row[metric])) {
    return row[metric];
  }
  const nestedMetrics = row.metrics;
  if (nestedMetrics && typeof nestedMetrics === 'object' && isUsable(nestedMetrics[metric])) {
    return nestedMetrics[metric];
  }
  return null;
}
18
+
/**
 * Append a row to the NDJSON scoreboard and refresh the summary artifact.
 *
 * The row is appended to `<scoreboardDir>/scoreboard.ndjson`, then a
 * `latest.json` summary is (re)written containing the row as both `latest`
 * and `best`.
 *
 * @param {string} scoreboardDir - Directory holding scoreboard artifacts.
 * @param {Record<string, unknown>} row - Row to append.
 * @param {{selectionMetric?: string | null, selectionGoal?: string | null}} [options]
 * @returns {Promise<{rowsPath: string, summaryPath: string}>}
 */
export async function appendScoreboardRow(scoreboardDir, row, options = {}) {
  const rowsPath = join(scoreboardDir, 'scoreboard.ndjson');
  await writeNdjsonRow(rowsPath, row);

  // Option overrides win over row-provided metric/goal hints.
  const metric = String(options.selectionMetric || row.selectionMetric || row.primaryMetric || '').trim();
  const goal = String(options.selectionGoal || row.selectionGoal || 'max').trim();
  const comparable = resolveComparableMetric(row, metric);
  // NOTE(review): `best` currently mirrors the just-appended row — prior
  // rows are never consulted and `goal` is only recorded, not applied.
  // Confirm whether historical-best selection is intended upstream.
  const best = comparable === null
    ? row
    : { ...row, selectionMetricValue: comparable };
  const summary = {
    artifactType: 'training_scoreboard',
    schemaVersion: 1,
    generatedAt: new Date().toISOString(),
    selectionMetric: metric || null,
    selectionGoal: goal,
    latest: row,
    best,
  };
  const summaryResult = await writeJsonArtifact(join(scoreboardDir, 'latest.json'), summary);
  return {
    rowsPath,
    summaryPath: summaryResult.path,
  };
}
@@ -90,6 +90,16 @@ export interface TrainingStepMetricsEntry {
90
90
  export interface TrainingRunnerCallbacks {
91
91
  onStep?: (entry: TrainingStepMetricsEntry) => Promise<void> | void;
92
92
  onEpoch?: (entry: { epoch: number; steps: number; loss: number }) => Promise<void> | void;
93
+ onCheckpoint?: (entry: {
94
+ key: string;
95
+ defaultCheckpointKey: string | null;
96
+ path: string | null;
97
+ metadata: Record<string, unknown> | null;
98
+ payload: unknown;
99
+ step: number;
100
+ epoch: number;
101
+ batch: number;
102
+ }) => Promise<void> | void;
93
103
  }
94
104
 
95
105
  export interface TrainingRunnerOptions extends TrainingRunnerCallbacks {
@@ -106,6 +116,12 @@ export interface TrainingRunnerOptions extends TrainingRunnerCallbacks {
106
116
  ) => Promise<ClipMetrics>;
107
117
  lossScaler?: DynamicLossScaler;
108
118
  trainingObjective?: TrainingObjective;
119
+ resolveCheckpointKey?: (entry: {
120
+ defaultCheckpointKey: string | null;
121
+ step: number;
122
+ epoch: number;
123
+ batch: number;
124
+ }) => Promise<string> | string;
109
125
  }
110
126
 
111
127
  export interface TrainingRunOptions {
@@ -159,6 +175,9 @@ export declare class TrainingRunner {
159
175
  lastArtifact: UlArtifactFinalizeResult | DistillArtifactFinalizeResult | null;
160
176
  lastCheckpoint: {
161
177
  key: string;
178
+ defaultKey?: string | null;
179
+ path?: string | null;
180
+ metadata?: Record<string, unknown> | null;
162
181
  step: number;
163
182
  epoch: number;
164
183
  batch: number;
@@ -194,3 +213,36 @@ export declare function runTraining(
194
213
  config: TrainingConfigSchema,
195
214
  options?: TrainingRunOptions & TrainingRunnerOptions
196
215
  ): Promise<TrainingStepMetricsEntry[]>;
216
+
217
+ export declare function createTrainingCheckpointPayload(
218
+ model: {
219
+ loraParams?: () => Tensor[];
220
+ paramGroups?: () => Record<string, Tensor[]>;
221
+ },
222
+ optimizer: unknown,
223
+ context: {
224
+ step: number;
225
+ epoch: number;
226
+ batch: number;
227
+ config: TrainingConfigSchema;
228
+ }
229
+ ): Promise<unknown>;
230
+
231
+ export declare function restoreTrainingCheckpointState(
232
+ model: {
233
+ loraParams?: () => Tensor[];
234
+ paramGroups?: () => Record<string, Tensor[]>;
235
+ },
236
+ optimizer: unknown,
237
+ checkpointRecord: unknown,
238
+ config: TrainingConfigSchema
239
+ ): Promise<{
240
+ step: number;
241
+ epoch: number;
242
+ batch: number;
243
+ checkpointHash: string | null;
244
+ previousCheckpointHash: string | null;
245
+ checkpointKey: string | null;
246
+ resumeAudits: Array<Record<string, unknown>>;
247
+ resumeAuditCount: number;
248
+ } | null>;
@@ -713,7 +713,7 @@ function looksLikeTrainingCheckpointRecord(value) {
713
713
  return Number.isInteger(progress.step) && progress.step >= 0;
714
714
  }
715
715
 
716
- async function createTrainingCheckpointPayload(model, optimizer, context) {
716
+ export async function createTrainingCheckpointPayload(model, optimizer, context) {
717
717
  const freezeMap = context.config?.training?.ul?.freeze
718
718
  ?? context.config?.training?.distill?.freeze
719
719
  ?? {};
@@ -747,7 +747,7 @@ async function createTrainingCheckpointPayload(model, optimizer, context) {
747
747
  };
748
748
  }
749
749
 
750
- async function restoreTrainingCheckpointState(model, optimizer, checkpointRecord, config) {
750
+ export async function restoreTrainingCheckpointState(model, optimizer, checkpointRecord, config) {
751
751
  if (!looksLikeTrainingCheckpointRecord(checkpointRecord)) {
752
752
  return null;
753
753
  }
@@ -837,6 +837,8 @@ export class TrainingRunner {
837
837
  this.lossScaler = options.lossScaler || new DynamicLossScaler(config.training.lossScaling);
838
838
  this.onStep = options.onStep || null;
839
839
  this.onEpoch = options.onEpoch || null;
840
+ this.onCheckpoint = options.onCheckpoint || null;
841
+ this.resolveCheckpointKey = options.resolveCheckpointKey || null;
840
842
  this.lastArtifact = null;
841
843
  this.lastCheckpoint = null;
842
844
  this.resumeState = null;
@@ -911,16 +913,39 @@ export class TrainingRunner {
911
913
  batch: checkpointContext.batch,
912
914
  config: this.config,
913
915
  });
914
- await saveCheckpoint(checkpointKey, payload, {
916
+ const resolvedCheckpointKey = this.resolveCheckpointKey
917
+ ? await this.resolveCheckpointKey({
918
+ defaultCheckpointKey: checkpointKey,
919
+ step: checkpointContext.step,
920
+ epoch: checkpointContext.epoch,
921
+ batch: checkpointContext.batch,
922
+ })
923
+ : checkpointKey;
924
+ const saveResult = await saveCheckpoint(resolvedCheckpointKey, payload, {
915
925
  ...checkpointMetadata,
916
926
  optimizerHash: hashStableJson(payload?.trainingState?.optimizerSlots || {}),
917
927
  });
918
928
  this.lastCheckpoint = {
919
- key: checkpointKey,
929
+ key: resolvedCheckpointKey,
930
+ defaultKey: checkpointKey,
931
+ path: saveResult?.path || null,
932
+ metadata: saveResult?.metadata || null,
920
933
  step: checkpointContext.step,
921
934
  epoch: checkpointContext.epoch,
922
935
  batch: checkpointContext.batch,
923
936
  };
937
+ if (this.onCheckpoint) {
938
+ await this.onCheckpoint({
939
+ key: resolvedCheckpointKey,
940
+ defaultCheckpointKey: checkpointKey,
941
+ path: saveResult?.path || null,
942
+ metadata: saveResult?.metadata || null,
943
+ payload,
944
+ step: checkpointContext.step,
945
+ epoch: checkpointContext.epoch,
946
+ batch: checkpointContext.batch,
947
+ });
948
+ }
924
949
  };
925
950
 
926
951
  const artifactSession = distillContract.enabled
@@ -176,6 +176,66 @@ export interface RunTrainingSuiteOptions {
176
176
  timestamp?: string | Date;
177
177
  }
178
178
 
179
+ export interface DistillDataScope {
180
+ sourceLangs: string[] | null;
181
+ targetLangs: string[] | null;
182
+ pairAllowlist: string[] | null;
183
+ sourceLangSet: Set<string> | null;
184
+ targetLangSet: Set<string> | null;
185
+ pairAllowlistSet: Set<string> | null;
186
+ strictPairContract: boolean;
187
+ }
188
+
189
+ export interface DistillDatasetReport {
190
+ absolutePath: string;
191
+ rowCount: number;
192
+ sampleCount: number;
193
+ directionCounts: Record<string, number>;
194
+ dataScope: {
195
+ sourceLangs: string[] | null;
196
+ targetLangs: string[] | null;
197
+ pairAllowlist: string[] | null;
198
+ strictPairContract: boolean;
199
+ } | null;
200
+ shardCount?: number;
201
+ shardPaths?: string[];
202
+ createDataset(options?: Record<string, unknown>): {
203
+ batches(): AsyncGenerator<Record<string, unknown>, void, unknown>;
204
+ };
205
+ }
206
+
207
+ export interface DistillRuntimeContext {
208
+ stage: 'stage_a' | 'stage_b';
209
+ teacherPipeline: Record<string, unknown>;
210
+ studentPipeline: Record<string, unknown>;
211
+ teacherModelId: string;
212
+ studentModelId: string;
213
+ teacherModelUrl: string | null;
214
+ studentModelUrl: string | null;
215
+ topK: number;
216
+ temperature: number;
217
+ alphaKd: number;
218
+ alphaCe: number;
219
+ tripletMargin: number;
220
+ studentGraphMode: string;
221
+ targetTokenMode: string;
222
+ cleanup(): Promise<void>;
223
+ }
224
+
225
+ export interface DistillStudentFixture {
226
+ config: Record<string, unknown>;
227
+ model: {
228
+ forward: (input: unknown, tape: unknown) => Promise<unknown>;
229
+ forwardDistill?: (batch: unknown, tape: unknown, options?: Record<string, unknown>) => Promise<{ logits: unknown }>;
230
+ cleanupDistillStep?: () => void;
231
+ loraParams?: () => unknown[];
232
+ paramGroups?: () => Record<string, unknown[]>;
233
+ };
234
+ outputDim?: number;
235
+ embeddingDim?: number;
236
+ cleanup(): void;
237
+ }
238
+
179
239
  export declare const trainingHarness: TrainingHarness;
180
240
 
181
241
  export declare function runTrainingSuite(
@@ -185,3 +245,55 @@ export declare function runTrainingSuite(
185
245
  export declare function runTrainingBenchSuite(
186
246
  options?: RunTrainingSuiteOptions
187
247
  ): Promise<TrainingBenchSuiteResult>;
248
+
249
+ export declare function resolveDistillDataScope(
250
+ options?: RunTrainingSuiteOptions,
251
+ trainingConfig?: Record<string, unknown> | null
252
+ ): DistillDataScope;
253
+
254
+ export declare function buildDistillPrompt(sample: Record<string, unknown>): string;
255
+
256
+ export declare function normalizeDistillStudentGraphMode(value: unknown): string;
257
+
258
+ export declare function loadDistillDatasetFromJsonl(
259
+ datasetPath: string,
260
+ scopeOptions?: DistillDataScope | null
261
+ ): Promise<DistillDatasetReport | null>;
262
+
263
+ export declare function loadDistillModelHandle(
264
+ modelRef: string,
265
+ role: string,
266
+ loadOptions?: Record<string, unknown>
267
+ ): Promise<{
268
+ modelRef: string;
269
+ modelUrl: string | null;
270
+ manifest: Record<string, unknown>;
271
+ pipeline: Record<string, unknown>;
272
+ }>;
273
+
274
+ export declare function createDistillRuntimeContext(
275
+ options?: RunTrainingSuiteOptions,
276
+ trainingConfig?: Record<string, unknown> | null
277
+ ): Promise<DistillRuntimeContext>;
278
+
279
+ export declare function createToyModelFixture(
280
+ overrides?: Record<string, unknown>
281
+ ): {
282
+ config: Record<string, unknown>;
283
+ model: {
284
+ forward: (input: unknown, tape: unknown) => Promise<unknown>;
285
+ loraParams(): unknown[];
286
+ paramGroups(): Record<string, unknown[]>;
287
+ };
288
+ batch: Record<string, unknown>;
289
+ cleanup(): void;
290
+ };
291
+
292
+ export declare function createDistillStudentRuntimeModelFixture(
293
+ overrides?: Record<string, unknown>,
294
+ options?: Record<string, unknown>
295
+ ): Promise<DistillStudentFixture>;
296
+
297
+ export declare function buildDistillTrainingOverrides(
298
+ options?: RunTrainingSuiteOptions
299
+ ): Record<string, unknown> | null;
@@ -190,7 +190,7 @@ function normalizeDistillPairAllowlist(value) {
190
190
  return [...new Set(normalized)];
191
191
  }
192
192
 
193
- function resolveDistillDataScope(options = {}, trainingConfig = null) {
193
+ export function resolveDistillDataScope(options = {}, trainingConfig = null) {
194
194
  const distillConfig = trainingConfig?.distill || {};
195
195
  const sourceLangs = normalizeDistillLanguageAllowlist(
196
196
  options.distillSourceLangs ?? distillConfig.sourceLangs ?? null
@@ -301,7 +301,7 @@ function resolveLanguageName(langCode) {
301
301
  return normalized || 'target';
302
302
  }
303
303
 
304
- function buildDistillPrompt(sample) {
304
+ export function buildDistillPrompt(sample) {
305
305
  const direction = String(sample?.direction || '').trim();
306
306
  const [srcCodeRaw, tgtCodeRaw] = direction.split('->');
307
307
  const srcCode = normalizeLangCode(srcCodeRaw) || srcCodeRaw || 'source';
@@ -328,7 +328,7 @@ function clampDistillTopK(value) {
328
328
  return Math.max(2, Math.min(256, parsed));
329
329
  }
330
330
 
331
- function normalizeDistillStudentGraphMode(value) {
331
+ export function normalizeDistillStudentGraphMode(value) {
332
332
  const normalized = normalizeOptionalString(value);
333
333
  if (!normalized) return DISTILL_STUDENT_GRAPH_FULL;
334
334
  const compact = normalized.toLowerCase().replace(/[-\s]/g, '_');
@@ -605,7 +605,7 @@ function createDistillTensorDataset(samples, options = {}) {
605
605
  };
606
606
  }
607
607
 
608
- async function loadDistillDatasetFromJsonl(datasetPath, scopeOptions = null) {
608
+ export async function loadDistillDatasetFromJsonl(datasetPath, scopeOptions = null) {
609
609
  const normalizedPath = normalizeDistillDatasetPath(datasetPath);
610
610
  if (!normalizedPath) return null;
611
611
  if (!isNodeRuntime()) {
@@ -820,7 +820,7 @@ async function initializeInferenceFromStore(modelId) {
820
820
  return { pipeline, manifest };
821
821
  }
822
822
 
823
- async function loadDistillModelHandle(modelRef, role, loadOptions = {}) {
823
+ export async function loadDistillModelHandle(modelRef, role, loadOptions = {}) {
824
824
  const normalizedRef = normalizeOptionalString(modelRef);
825
825
  if (!normalizedRef) {
826
826
  throw new Error(`Distill ${role} model reference is required.`);
@@ -876,7 +876,7 @@ function resolveDistillModelRefs(options = {}, trainingConfig = null) {
876
876
  };
877
877
  }
878
878
 
879
- async function createDistillRuntimeContext(options = {}, trainingConfig = null) {
879
+ export async function createDistillRuntimeContext(options = {}, trainingConfig = null) {
880
880
  const { teacherModelRef, studentModelRef } = resolveDistillModelRefs(options, trainingConfig);
881
881
  if (!teacherModelRef || !studentModelRef) {
882
882
  throw new Error('Distill stage requires teacherModelId and studentModelId.');
@@ -967,7 +967,7 @@ async function ensureTrainingGpuRuntime() {
967
967
  await initDevice();
968
968
  }
969
969
 
970
- function createToyModelFixture(overrides = {}) {
970
+ export function createToyModelFixture(overrides = {}) {
971
971
  const config = createTrainingConfig({
972
972
  ...overrides,
973
973
  training: {
@@ -1790,7 +1790,7 @@ async function createDistillStudentTransformerModelFixture(overrides = {}, optio
1790
1790
  };
1791
1791
  }
1792
1792
 
1793
- async function createDistillStudentRuntimeModelFixture(overrides = {}, options = {}) {
1793
+ export async function createDistillStudentRuntimeModelFixture(overrides = {}, options = {}) {
1794
1794
  const distillRuntime = options.distillRuntime && typeof options.distillRuntime === 'object'
1795
1795
  ? options.distillRuntime
1796
1796
  : null;
@@ -2085,7 +2085,7 @@ function buildUlTrainingOverrides(options = {}) {
2085
2085
  };
2086
2086
  }
2087
2087
 
2088
- function buildDistillTrainingOverrides(options = {}) {
2088
+ export function buildDistillTrainingOverrides(options = {}) {
2089
2089
  const trainingConfig = normalizeTrainingConfigOverride(options.trainingConfig);
2090
2090
  const explicitStage = normalizeTrainingStage(options.trainingStage || trainingConfig?.distill?.stage);
2091
2091
  const distillEnabled = isDistillStage(explicitStage) || trainingConfig?.distill?.enabled === true;