@simulatte/doppler 0.1.5 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. package/README.md +23 -8
  2. package/package.json +7 -4
  3. package/src/config/kernels/kernel-ref-digests.js +39 -39
  4. package/src/config/kernels/registry.json +42 -2
  5. package/src/config/loader.js +31 -2
  6. package/src/config/merge.js +18 -0
  7. package/src/config/presets/models/qwen3.json +9 -2
  8. package/src/config/presets/models/transformer.json +5 -0
  9. package/src/config/required-inference-fields-contract-check.js +6 -0
  10. package/src/config/schema/inference-defaults.schema.js +3 -0
  11. package/src/config/schema/inference.schema.d.ts +9 -0
  12. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  13. package/src/config/schema/manifest.schema.d.ts +6 -0
  14. package/src/config/schema/manifest.schema.js +3 -0
  15. package/src/converter/rope-config.js +42 -0
  16. package/src/gpu/device.js +58 -0
  17. package/src/gpu/kernels/attention.js +98 -0
  18. package/src/gpu/kernels/bias_add.wgsl +8 -6
  19. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  20. package/src/gpu/kernels/conv2d.js +1 -1
  21. package/src/gpu/kernels/conv2d.wgsl +7 -8
  22. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  23. package/src/gpu/kernels/depthwise_conv2d.js +2 -1
  24. package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
  25. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
  26. package/src/gpu/kernels/grouped_pointwise_conv2d.js +2 -1
  27. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
  28. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
  29. package/src/gpu/kernels/matmul.js +25 -0
  30. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  31. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  32. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  33. package/src/gpu/kernels/relu.js +15 -2
  34. package/src/gpu/kernels/relu.wgsl +2 -1
  35. package/src/gpu/kernels/relu_f16.wgsl +2 -1
  36. package/src/gpu/kernels/repeat_channels.js +1 -1
  37. package/src/gpu/kernels/repeat_channels.wgsl +4 -5
  38. package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
  39. package/src/gpu/kernels/residual.js +44 -8
  40. package/src/gpu/kernels/residual.wgsl +6 -3
  41. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  42. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  43. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  44. package/src/gpu/kernels/rmsnorm.js +58 -6
  45. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  46. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  47. package/src/gpu/kernels/rope.d.ts +2 -0
  48. package/src/gpu/kernels/rope.js +11 -1
  49. package/src/gpu/kernels/rope.wgsl +56 -40
  50. package/src/gpu/kernels/sana_linear_attention.js +1 -2
  51. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
  52. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
  53. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
  54. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
  55. package/src/gpu/kernels/silu.d.ts +1 -0
  56. package/src/gpu/kernels/silu.js +32 -14
  57. package/src/gpu/kernels/silu.wgsl +19 -9
  58. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  59. package/src/gpu/kernels/transpose.js +15 -2
  60. package/src/gpu/kernels/transpose.wgsl +5 -6
  61. package/src/gpu/kernels/upsample2d.js +2 -1
  62. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  63. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  64. package/src/gpu/kernels/utils.js +16 -1
  65. package/src/inference/browser-harness.js +47 -1
  66. package/src/inference/pipelines/diffusion/pipeline.js +15 -6
  67. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
  68. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
  69. package/src/inference/pipelines/text/attention/record.js +11 -2
  70. package/src/inference/pipelines/text/attention/run.js +11 -2
  71. package/src/inference/pipelines/text/chat-format.js +25 -1
  72. package/src/inference/pipelines/text/config.d.ts +4 -0
  73. package/src/inference/pipelines/text/config.js +68 -1
  74. package/src/inference/pipelines/text/execution-plan.js +23 -31
  75. package/src/inference/pipelines/text/execution-v0.js +29 -2
  76. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  77. package/src/inference/pipelines/text/init.d.ts +4 -0
  78. package/src/inference/pipelines/text/init.js +56 -9
  79. package/src/inference/pipelines/text/layer.js +11 -0
  80. package/src/inference/pipelines/text.js +4 -0
  81. package/src/inference/tokenizers/bundled.js +156 -33
  82. package/src/rules/tooling/command-runtime.rules.json +18 -0
  83. package/src/tooling/command-api.d.ts +27 -1
  84. package/src/tooling/command-api.js +142 -3
  85. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  86. package/src/tooling/node-browser-command-runner.js +58 -3
  87. package/src/tooling/node-command-runner.js +15 -0
  88. package/src/tooling/node-webgpu.js +9 -87
  89. package/src/training/checkpoint-watch.d.ts +7 -0
  90. package/src/training/checkpoint-watch.js +106 -0
  91. package/src/training/checkpoint.d.ts +6 -1
  92. package/src/training/checkpoint.js +12 -2
  93. package/src/training/distillation/artifacts.d.ts +71 -0
  94. package/src/training/distillation/artifacts.js +132 -0
  95. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  96. package/src/training/distillation/checkpoint-watch.js +57 -0
  97. package/src/training/distillation/dataset.d.ts +59 -0
  98. package/src/training/distillation/dataset.js +337 -0
  99. package/src/training/distillation/eval.d.ts +34 -0
  100. package/src/training/distillation/eval.js +310 -0
  101. package/src/training/distillation/index.d.ts +29 -0
  102. package/src/training/distillation/index.js +29 -0
  103. package/src/training/distillation/runtime.d.ts +20 -0
  104. package/src/training/distillation/runtime.js +121 -0
  105. package/src/training/distillation/scoreboard.d.ts +6 -0
  106. package/src/training/distillation/scoreboard.js +8 -0
  107. package/src/training/distillation/stage-a.d.ts +45 -0
  108. package/src/training/distillation/stage-a.js +338 -0
  109. package/src/training/distillation/stage-b.d.ts +24 -0
  110. package/src/training/distillation/stage-b.js +20 -0
  111. package/src/training/index.d.ts +10 -0
  112. package/src/training/index.js +10 -0
  113. package/src/training/lora-pipeline.d.ts +40 -0
  114. package/src/training/lora-pipeline.js +796 -0
  115. package/src/training/operator-artifacts.d.ts +62 -0
  116. package/src/training/operator-artifacts.js +140 -0
  117. package/src/training/operator-command.d.ts +5 -0
  118. package/src/training/operator-command.js +453 -0
  119. package/src/training/operator-eval.d.ts +48 -0
  120. package/src/training/operator-eval.js +230 -0
  121. package/src/training/operator-scoreboard.d.ts +5 -0
  122. package/src/training/operator-scoreboard.js +44 -0
  123. package/src/training/runner.d.ts +52 -0
  124. package/src/training/runner.js +29 -4
  125. package/src/training/suite.d.ts +112 -0
  126. package/src/training/suite.js +9 -9
  127. package/src/training/workloads.d.ts +164 -0
  128. package/src/training/workloads.js +539 -0
  129. package/src/version.js +1 -1
  130. package/tools/doppler-cli.js +137 -40
@@ -0,0 +1,338 @@
1
+ import { mkdir } from 'node:fs/promises';
2
+ import { resolve } from 'node:path';
3
+
4
+ import { AdamOptimizer } from '../optimizer.js';
5
+ import { crossEntropyLoss } from '../loss.js';
6
+ import { clipGradients } from '../clip.js';
7
+ import { TrainingRunner } from '../runner.js';
8
+ import {
9
+ createDistillRuntimeContext,
10
+ createDistillStudentRuntimeModelFixture,
11
+ loadDistillDatasetFromJsonl,
12
+ resolveDistillDataScope,
13
+ } from '../suite.js';
14
+ import { loadCanonicalTranslationDataset } from './dataset.js';
15
+ import { evaluateDistillationModel } from './eval.js';
16
+ import {
17
+ buildDistillArtifactBase,
18
+ writeDistillCheckpointComplete,
19
+ writeDistillCheckpointMetadata,
20
+ writeDistillStageManifest,
21
+ } from './artifacts.js';
22
+ import { appendDistillationScoreboardRow } from './scoreboard.js';
23
+ import { buildDistillationTrainingConfigFromWorkload } from './runtime.js';
24
+
25
/**
 * Formats a training step as a fixed-width, zero-padded string used in
 * checkpoint ids (e.g. 42 -> "000042"). Values longer than the width are
 * returned unchanged.
 */
function padStep(step) {
  const STEP_WIDTH = 6;
  const digits = String(step);
  return digits.padStart(STEP_WIDTH, '0');
}
28
+
29
/**
 * Extracts a finite numeric value for `metric` from an eval report.
 *
 * Lookup order:
 *   1. `report[metric]` when it is a finite number;
 *   2. `report.metrics[metric]` when it is a finite number;
 *   3. `report.metrics[metric].score` when it is a finite number.
 * Returns null when the report is not an object or no finite value is found.
 */
function resolveComparableMetric(report, metric) {
  if (report === null || typeof report !== 'object') {
    return null;
  }
  const isFiniteNumber = (value) => typeof value === 'number' && Number.isFinite(value);

  const topLevel = report[metric];
  if (isFiniteNumber(topLevel)) {
    return topLevel;
  }

  const metricsBag = (report.metrics && typeof report.metrics === 'object')
    ? report.metrics
    : null;
  if (metricsBag === null) {
    return null;
  }

  const entry = metricsBag[metric];
  if (isFiniteNumber(entry)) {
    return entry;
  }

  // Some reports nest the value as { score: number }.
  const score = entry?.score;
  if (score != null && Number.isFinite(score)) {
    return score;
  }
  return null;
}
45
+
46
/**
 * Picks the report with the best comparable value for `metric`.
 *
 * `goal` is 'min' to prefer lower values; anything else (including a missing
 * goal, which defaults to 'max') prefers higher values. Reports without a
 * finite metric value are ignored, ties keep the earliest report, and null is
 * returned when no report has a usable value.
 */
function selectBestReport(reports, metric, goal) {
  const minimize = String(goal || 'max').trim() === 'min';
  let winner = null;
  let winnerValue = null;
  for (const candidate of reports) {
    const candidateValue = resolveComparableMetric(candidate, metric);
    if (!Number.isFinite(candidateValue)) {
      continue;
    }
    const improves = winner === null
      || (minimize ? candidateValue < winnerValue : candidateValue > winnerValue);
    if (improves) {
      winner = candidate;
      winnerValue = candidateValue;
    }
  }
  return winner;
}
68
+
69
/**
 * True when the stage evaluates after every checkpoint. A missing or empty
 * `evalSchedule` is treated as 'on_checkpoint'; surrounding whitespace is ignored.
 */
function shouldEvalOnCheckpoint(stageEntry) {
  const configured = stageEntry?.evalSchedule;
  const schedule = configured ? String(configured).trim() : 'on_checkpoint';
  return schedule === 'on_checkpoint';
}
73
+
74
/**
 * True when the stage evaluates only once, after training finishes
 * (`evalSchedule` of 'final', surrounding whitespace ignored). A missing or
 * empty schedule defaults to 'on_checkpoint' and therefore yields false.
 */
function shouldEvalAtEnd(stageEntry) {
  const configured = stageEntry?.evalSchedule;
  if (!configured) {
    return false;
  }
  return String(configured).trim() === 'final';
}
78
+
79
/**
 * Appends one `training_scoreboard` row per eval report for a checkpoint and
 * collects each report into `sink` (in order, before its row is written).
 * Shared by the 'on_checkpoint' and 'final' eval schedules so both produce the
 * same scoreboard rows.
 */
async function recordDistillEvalReports(layout, stageEntry, stageId, checkpointRef, reports, sink) {
  for (const report of reports) {
    sink.push(report);
    await appendDistillationScoreboardRow(layout, stageId, {
      artifactType: 'training_scoreboard',
      schemaVersion: 1,
      generatedAt: new Date().toISOString(),
      stage: stageId,
      checkpointId: checkpointRef.checkpointId,
      checkpointStep: checkpointRef.step,
      evalDatasetId: report.evalDatasetId,
      selectionMetric: stageEntry.selectionMetric,
      selectionGoal: stageEntry.selectionGoal,
      primaryMetric: report.primaryMetric,
      primaryScore: report.primaryScore,
      bleu: report.bleu,
      chrf: report.chrf,
      reportPath: report.reportPath || null,
      metrics: {
        bleu: report.bleu,
        chrf: report.chrf,
        primaryScore: report.primaryScore,
      },
    }, {
      selectionMetric: stageEntry.selectionMetric,
      selectionGoal: stageEntry.selectionGoal,
    });
  }
}

/**
 * Runs one distillation stage end to end and writes its artifacts.
 *
 * Flow:
 * - Validate the batch contract (training.batchSize and training.accumSteps must be 1).
 * - Build the stage training config.
 * - Load the dataset through the suite loader (training) and as a canonical
 *   translation dataset, whose `canonicalHash` is recorded as `datasetHash`.
 * - Create the teacher/student runtime and the student model fixture.
 * - Train with `TrainingRunner`, writing a metadata file and a "complete"
 *   marker for every checkpoint.
 * - Evaluate per `stageEntry.evalSchedule`:
 *   - 'on_checkpoint' (the default) evaluates inside the checkpoint callback.
 *   - 'final' evaluates the last checkpoint once training ends.
 *   Both schedules append scoreboard rows and link the reports to the
 *   checkpoint metadata as a parent artifact.
 * - Select the best report by `selectionMetric`/`selectionGoal` and write the
 *   stage manifest.
 *
 * @param options.loadedWorkload - loaded workload; `.workload` holds the config read here.
 * @param options.stageEntry - stage plan entry; `id`, `steps`, `checkpointEvery`,
 *   `evalSchedule`, `selectionMetric`, `selectionGoal` and `objective` are read.
 * @param options.layout - run layout; `runRoot` and `checkpoints` are read here, and the
 *   whole object is passed to the artifact, eval and scoreboard writers.
 * @param options.datasetPath - overrides `workload.datasetPath`.
 * @param options.legacyArtifactDir - defaults to `<layout.runRoot>/legacy-stage-artifacts`.
 * @param options.stageAArtifact - Stage A artifact path (supplied by the Stage B entry point).
 * @param options.stageAArtifactHash - hash of the Stage A artifact.
 * @param options.parentArtifacts - lineage recorded on checkpoint metadata and the stage manifest.
 * @param options.timestamp - forwarded to the runner.
 * @returns A stage summary with these fields:
 *   - `stageId` and the internal training stage;
 *   - runner metrics, checkpoint artifacts and eval reports;
 *   - the best report and the stage manifest path;
 *   - the legacy artifact and the last checkpoint.
 * @throws Error when batchSize/accumSteps are not 1 or the dataset cannot be resolved.
 */
export async function runDistillationStage(options) {
  const loadedWorkload = options.loadedWorkload;
  const workload = loadedWorkload.workload;
  const stageEntry = options.stageEntry;
  const datasetPath = options.datasetPath || workload.datasetPath;
  const legacyArtifactDir = options.legacyArtifactDir || resolve(options.layout.runRoot, 'legacy-stage-artifacts');
  // Validate before any dataset or runtime resources are allocated.
  if (workload.training.batchSize !== 1) {
    throw new Error('Distillation stage currently requires training.batchSize=1.');
  }
  if (workload.training.accumSteps !== 1) {
    throw new Error('Distillation stage currently requires training.accumSteps=1.');
  }
  const configBundle = buildDistillationTrainingConfigFromWorkload(loadedWorkload, stageEntry, {
    datasetPath,
    artifactDir: legacyArtifactDir,
    stageAArtifact: options.stageAArtifact || null,
    stageAArtifactHash: options.stageAArtifactHash || null,
  });
  const pipeline = workload.pipeline;
  const strictPairContract = pipeline.strictPairContract === true;
  const distillDataScope = resolveDistillDataScope({
    distillSourceLangs: pipeline.sourceLangs,
    distillTargetLangs: pipeline.targetLangs,
    distillPairAllowlist: pipeline.pairAllowlist,
    strictPairContract,
  }, configBundle.trainingConfig.training);
  const distillDatasetReport = await loadDistillDatasetFromJsonl(datasetPath, distillDataScope);
  if (!distillDatasetReport) {
    throw new Error(`Unable to resolve distillation dataset "${datasetPath}".`);
  }
  const canonicalDataset = await loadCanonicalTranslationDataset(datasetPath, {
    strictPairContract,
    sourceLangs: pipeline.sourceLangs,
    targetLangs: pipeline.targetLangs,
    pairAllowlist: pipeline.pairAllowlist,
  });
  const distillRuntime = await createDistillRuntimeContext({
    teacherModelId: workload.teacherModelId,
    studentModelId: workload.studentModelId,
    trainingStage: configBundle.internalStage,
    studentGraphMode: pipeline.studentGraphMode,
  }, configBundle.trainingConfig.training);
  let fixture = null;
  try {
    fixture = await createDistillStudentRuntimeModelFixture({
      training: configBundle.trainingConfig.training,
    }, {
      distillRuntime,
      studentGraphMode: pipeline.studentGraphMode,
    });
    const stageId = stageEntry.id;
    const stageCheckpointsDir = resolve(options.layout.checkpoints, stageId);
    await mkdir(stageCheckpointsDir, { recursive: true });
    const dataset = distillDatasetReport.createDataset({
      batchSize: workload.training.batchSize,
      shuffle: false,
      seed: workload.seed,
      distillRuntime,
    });
    const checkpointArtifacts = [];
    const evalReports = [];
    const runner = new TrainingRunner(fixture.config, {
      optimizer: new AdamOptimizer(fixture.config),
      crossEntropyLoss,
      clipGradients,
      resolveCheckpointKey({ step }) {
        const checkpointId = `checkpoint-${padStep(step)}`;
        return resolve(stageCheckpointsDir, checkpointId, 'state.json');
      },
      onCheckpoint: async (checkpoint) => {
        const checkpointId = `checkpoint-${padStep(checkpoint.step)}`;
        const checkpointBase = buildDistillArtifactBase(loadedWorkload, {
          prefix: 'dst_ckpt',
          artifactType: 'training_checkpoint',
          datasetPath: distillDatasetReport.absolutePath,
          datasetHash: canonicalDataset.canonicalHash,
          stage: stageId,
          checkpointStep: checkpoint.step,
          configHash: configBundle.trainingConfigHash,
          parentArtifacts: options.parentArtifacts || [],
        });
        const metadataPayload = {
          ...checkpointBase,
          checkpointId,
          checkpointPath: checkpoint.path,
          step: checkpoint.step,
          epoch: checkpoint.epoch,
          batch: checkpoint.batch,
          optimizerStatePresent: true,
          schedulerStatePresent: workload.training.optimizer.scheduler.enabled === true,
          stageArtifact: options.stageAArtifact || null,
          resumeLineage: checkpoint.metadata?.lineage || null,
        };
        const metadataFile = await writeDistillCheckpointMetadata(
          options.layout,
          stageId,
          checkpointId,
          metadataPayload
        );
        // The "complete" marker is written after the metadata so readers can
        // treat its presence as "checkpoint fully recorded".
        const completeFile = await writeDistillCheckpointComplete(
          options.layout,
          stageId,
          checkpointId,
          {
            ...metadataPayload,
            metadataPath: metadataFile.path,
            finalized: true,
          }
        );
        checkpointArtifacts.push({
          checkpointId,
          checkpointPath: checkpoint.path,
          metadataPath: metadataFile.path,
          completePath: completeFile.path,
          step: checkpoint.step,
        });
        if (!shouldEvalOnCheckpoint(stageEntry)) {
          return;
        }
        const reports = await evaluateDistillationModel({
          loadedWorkload,
          layout: options.layout,
          stageId,
          checkpointId,
          checkpointStep: checkpoint.step,
          checkpointPath: checkpoint.path,
          distillRuntime,
          model: fixture.model,
          configHash: configBundle.trainingConfigHash,
          parentArtifacts: [{
            artifactType: 'training_checkpoint',
            path: metadataFile.path,
            checkpointId,
          }],
        });
        await recordDistillEvalReports(
          options.layout,
          stageEntry,
          stageId,
          { checkpointId, step: checkpoint.step },
          reports,
          evalReports
        );
      },
    });
    const metrics = await runner.run(fixture.model, dataset, {
      epochs: 1,
      batchSize: workload.training.batchSize,
      shuffle: false,
      maxSteps: stageEntry.steps,
      checkpointEvery: stageEntry.checkpointEvery,
      modelId: workload.studentModelId,
      trainingStage: configBundle.internalStage,
      runtimePreset: null,
      command: 'distill',
      surface: 'node',
      timestamp: options.timestamp || null,
      distillArtifactDir: legacyArtifactDir,
      stageAArtifact: options.stageAArtifact || null,
      stageAArtifactHash: options.stageAArtifactHash || null,
      teacherModelId: workload.teacherModelId,
      studentModelId: workload.studentModelId,
      distillDatasetId: workload.datasetId,
      distillDatasetPath: distillDatasetReport.absolutePath,
      distillSourceLangs: pipeline.sourceLangs,
      distillTargetLangs: pipeline.targetLangs,
      distillPairAllowlist: pipeline.pairAllowlist,
      strictPairContract,
    });
    if (shouldEvalAtEnd(stageEntry) && runner.lastCheckpoint) {
      const finalCheckpoint = runner.lastCheckpoint;
      const finalCheckpointId = `checkpoint-${padStep(finalCheckpoint.step)}`;
      // Link the final eval to the metadata written by onCheckpoint for the same
      // step, matching the lineage recorded by the 'on_checkpoint' schedule.
      const finalArtifact = checkpointArtifacts.find((entry) => entry.step === finalCheckpoint.step);
      const reports = await evaluateDistillationModel({
        loadedWorkload,
        layout: options.layout,
        stageId,
        checkpointId: finalCheckpointId,
        checkpointStep: finalCheckpoint.step,
        checkpointPath: finalCheckpoint.path || null,
        distillRuntime,
        model: fixture.model,
        configHash: configBundle.trainingConfigHash,
        ...(finalArtifact
          ? {
            parentArtifacts: [{
              artifactType: 'training_checkpoint',
              path: finalArtifact.metadataPath,
              checkpointId: finalCheckpointId,
            }],
          }
          : {}),
      });
      await recordDistillEvalReports(
        options.layout,
        stageEntry,
        stageId,
        { checkpointId: finalCheckpointId, step: finalCheckpoint.step },
        reports,
        evalReports
      );
    }
    const bestReport = selectBestReport(
      evalReports,
      stageEntry.selectionMetric,
      stageEntry.selectionGoal
    );
    const stageManifestPayload = {
      ...buildDistillArtifactBase(loadedWorkload, {
        prefix: 'dst_stage',
        artifactType: 'distill_stage_manifest',
        datasetPath: distillDatasetReport.absolutePath,
        datasetHash: canonicalDataset.canonicalHash,
        stage: stageId,
        checkpointStep: runner.lastCheckpoint?.step ?? null,
        configHash: configBundle.trainingConfigHash,
        parentArtifacts: options.parentArtifacts || [],
      }),
      stageId,
      trainingStage: configBundle.internalStage,
      objective: stageEntry.objective,
      stagePlanEntry: stageEntry,
      stepCount: Array.isArray(metrics) ? metrics.length : 0,
      checkpointCount: checkpointArtifacts.length,
      bestCheckpointId: bestReport?.checkpointId || null,
      selectionMetric: stageEntry.selectionMetric,
      selectionGoal: stageEntry.selectionGoal,
      checkpointArtifacts,
      evalReports: evalReports.map((report) => ({
        checkpointId: report.checkpointId,
        evalDatasetId: report.evalDatasetId,
        reportPath: report.reportPath || null,
        primaryMetric: report.primaryMetric,
        primaryScore: report.primaryScore,
        bleu: report.bleu,
        chrf: report.chrf,
      })),
      legacyArtifact: runner.lastArtifact || null,
      lastCheckpoint: runner.lastCheckpoint || null,
    };
    const stageManifest = await writeDistillStageManifest(options.layout, stageManifestPayload);
    return {
      stageId,
      trainingStage: configBundle.internalStage,
      metrics,
      checkpointArtifacts,
      evalReports,
      bestReport,
      stageManifestPath: stageManifest.path,
      legacyArtifact: runner.lastArtifact || null,
      lastCheckpoint: runner.lastCheckpoint || null,
    };
  } finally {
    // Await the fixture cleanup (it may be async), and make sure the runtime
    // cleanup still runs if the fixture cleanup throws or rejects.
    try {
      await fixture?.cleanup?.();
    } finally {
      await distillRuntime.cleanup();
    }
  }
}
335
+
336
/**
 * Stage A entry point. Stage A has no extra preconditions, so this forwards
 * `options` unchanged to the shared `runDistillationStage` runner and resolves
 * with its result.
 */
export async function runDistillationStageA(options) {
  const stageResult = await runDistillationStage(options);
  return stageResult;
}
@@ -0,0 +1,24 @@
1
+ import type { LoadedTrainingWorkload, DistillStagePlanEntry } from '../workloads.js';
2
+
3
/**
 * Runs distillation Stage B on top of a completed Stage A.
 *
 * A Stage A artifact path is required. It comes from `stageAArtifact` or, if
 * that is absent, from `priorStageResult.legacyArtifact.manifestPath`. The
 * returned promise rejects when neither is available. The remaining options
 * are forwarded to the shared distillation stage runner.
 *
 * @returns The stage summary produced by the shared stage runner.
 */
export declare function runDistillationStageB(options: {
  loadedWorkload: LoadedTrainingWorkload;
  stageEntry: DistillStagePlanEntry;
  layout: Record<string, string>;
  datasetPath?: string | null;
  stageAArtifact?: string | null;
  stageAArtifactHash?: string | null;
  legacyArtifactDir?: string | null;
  timestamp?: string | Date | null;
  parentArtifacts?: Record<string, unknown>[];
  priorStageResult?: Record<string, unknown> | null;
}): Promise<{
  stageId: string;
  trainingStage: 'stage_a' | 'stage_b';
  metrics: Array<Record<string, unknown>>;
  checkpointArtifacts: Record<string, unknown>[];
  evalReports: Record<string, unknown>[];
  bestReport: Record<string, unknown> | null;
  stageManifestPath: string;
  legacyArtifact: Record<string, unknown> | null;
  lastCheckpoint: Record<string, unknown> | null;
}>;
@@ -0,0 +1,20 @@
1
+ import { resolve } from 'node:path';
2
+
3
+ import { runDistillationStage } from './stage-a.js';
4
+
5
/**
 * Runs distillation Stage B on top of a completed Stage A.
 *
 * Artifact path:
 * - Taken from `options.stageAArtifact`, or else from
 *   `options.priorStageResult.legacyArtifact.manifestPath`.
 * - Resolved to an absolute path.
 *
 * Artifact hash:
 * - Taken from `options.stageAArtifactHash` when supplied.
 * - Otherwise the prior stage's `manifestHash` is used, but only when the
 *   prior stage's manifest path resolves to the same file.
 * - This avoids pairing an explicitly supplied artifact with another
 *   artifact's hash. In that case the hash is null.
 *
 * @param options - same options as the shared stage runner, plus `priorStageResult`.
 * @returns The stage summary from `runDistillationStage`.
 * @throws Error when no Stage A artifact path can be determined.
 */
export async function runDistillationStageB(options) {
  const priorStage = options.priorStageResult || null;
  const priorArtifact = priorStage?.legacyArtifact || null;
  const stageAArtifact = options.stageAArtifact
    || priorArtifact?.manifestPath
    || null;
  if (!stageAArtifact) {
    throw new Error('Distillation stage-b requires a Stage A artifact path.');
  }
  const stageAArtifactPath = resolve(String(stageAArtifact));
  let stageAArtifactHash = options.stageAArtifactHash || null;
  if (!stageAArtifactHash && priorArtifact?.manifestPath && priorArtifact.manifestHash) {
    // Only inherit the prior stage's hash when it describes the same artifact file.
    if (resolve(String(priorArtifact.manifestPath)) === stageAArtifactPath) {
      stageAArtifactHash = priorArtifact.manifestHash;
    }
  }
  return runDistillationStage({
    ...options,
    stageAArtifact: stageAArtifactPath,
    stageAArtifactHash,
  });
}
@@ -10,6 +10,15 @@ export { exportLoRAAdapter } from './export.js';
10
10
  export { DynamicLossScaler, detectOverflow } from './loss-scaling.js';
11
11
  export { TrainingRunner, runTraining } from './runner.js';
12
12
  export { runTrainingSuite, runTrainingBenchSuite, trainingHarness } from './suite.js';
13
+ export { runTrainingOperatorCommand } from './operator-command.js';
14
+ export {
15
+ runLoraPipeline,
16
+ evaluateLoraCheckpoint,
17
+ exportLoraCheckpoint,
18
+ watchLoraCheckpoints,
19
+ compareLoraRun,
20
+ qualityGateLoraRun,
21
+ } from './lora-pipeline.js';
13
22
  export type {
14
23
  TrainingObjective,
15
24
  TrainingObjectiveContext,
@@ -31,6 +40,7 @@ export type {
31
40
  DistillArtifactSession,
32
41
  DistillArtifactFinalizeResult,
33
42
  } from './artifacts.js';
43
+ export * as distillation from './distillation/index.js';
34
44
  export {
35
45
  createDistillArtifactSession,
36
46
  resolveDistillTrainingContract,
@@ -10,6 +10,15 @@ export { exportLoRAAdapter } from './export.js';
10
10
  export { DynamicLossScaler, detectOverflow } from './loss-scaling.js';
11
11
  export { TrainingRunner, runTraining } from './runner.js';
12
12
  export { runTrainingSuite, runTrainingBenchSuite, trainingHarness } from './suite.js';
13
+ export { runTrainingOperatorCommand } from './operator-command.js';
14
+ export {
15
+ runLoraPipeline,
16
+ evaluateLoraCheckpoint,
17
+ exportLoraCheckpoint,
18
+ watchLoraCheckpoints,
19
+ compareLoraRun,
20
+ qualityGateLoraRun,
21
+ } from './lora-pipeline.js';
13
22
  export {
14
23
  createTrainingObjective,
15
24
  isTrainingObjective,
@@ -28,6 +37,7 @@ export {
28
37
  resolveUlTrainingContract,
29
38
  resolveStage1ArtifactContext,
30
39
  } from './artifacts.js';
40
+ export * as distillation from './distillation/index.js';
31
41
  export {
32
42
  resolveUlNoiseScale,
33
43
  resolveUlScheduledLambda,
@@ -0,0 +1,40 @@
1
+ import type { LoadedTrainingWorkload } from './workloads.js';
2
+
3
/**
 * Declaration of the LoRA pipeline entry point (implemented in lora-pipeline.js).
 *
 * Options:
 * - `loadedWorkload`: the loaded training workload.
 * - `runRoot`: optional; presumably overrides the run output directory (confirm in the implementation).
 * - `timestamp`: optional string or Date.
 *
 * @returns A promise for an untyped result record; its shape is not described here.
 */
export declare function runLoraPipeline(options: {
  loadedWorkload: LoadedTrainingWorkload;
  runRoot?: string | null;
  timestamp?: string | Date | null;
}): Promise<Record<string, unknown>>;
8
+
9
/**
 * Declaration for evaluating a single LoRA checkpoint.
 *
 * Options:
 * - `loadedWorkload` and `checkpointPath` are required.
 * - `checkpointId`, `checkpointStep` and `layout` are optional.
 *
 * @returns A promise for a list of untyped records. These are presumably one
 *   eval report per eval dataset; confirm in lora-pipeline.js.
 */
export declare function evaluateLoraCheckpoint(options: {
  loadedWorkload: LoadedTrainingWorkload;
  checkpointPath: string;
  checkpointId?: string | null;
  checkpointStep?: number | null;
  layout?: Record<string, string> | null;
}): Promise<Array<Record<string, unknown>>>;
16
+
17
/**
 * Declaration for exporting a LoRA checkpoint.
 *
 * Options:
 * - `loadedWorkload` and `checkpointPath` are required.
 * - `checkpointId`, `checkpointStep` and `layout` are optional.
 * - `exportsDir`: optional; presumably the output directory override (confirm in lora-pipeline.js).
 * - `datasetHash`: optional; presumably recorded in the export's provenance.
 *
 * @returns A promise for an untyped record describing the export.
 */
export declare function exportLoraCheckpoint(options: {
  loadedWorkload: LoadedTrainingWorkload;
  checkpointPath: string;
  checkpointId?: string | null;
  checkpointStep?: number | null;
  layout?: Record<string, string> | null;
  exportsDir?: string | null;
  datasetHash?: string | null;
}): Promise<Record<string, unknown>>;
26
+
27
/**
 * Declaration for watching a run directory for LoRA checkpoints.
 *
 * Options:
 * - `loadedWorkload` and `runRoot` are required.
 * - `pollIntervalMs` is an optional polling interval.
 * - `stopWhenIdle` presumably ends the watch once no new checkpoints appear
 *   (confirm in lora-pipeline.js).
 *
 * @returns Resolves with `ok: true`, the number of processed checkpoints, and a manifest path.
 */
export declare function watchLoraCheckpoints(options: {
  loadedWorkload: LoadedTrainingWorkload;
  runRoot: string;
  pollIntervalMs?: number | null;
  stopWhenIdle?: boolean;
}): Promise<{
  ok: true;
  processedCount: number;
  manifestPath: string;
}>;
33
+
34
/**
 * Declaration for comparing the results of a LoRA run.
 *
 * Takes the run directory (`runRoot`) only.
 *
 * @returns A promise for an untyped record; its shape is not described here.
 */
export declare function compareLoraRun(options: {
  runRoot: string;
}): Promise<Record<string, unknown>>;
37
+
38
/**
 * Declaration for applying the quality gate to a LoRA run.
 *
 * Takes the run directory (`runRoot`) only.
 *
 * @returns A promise for an untyped record. It presumably contains the gate
 *   verdict; confirm in lora-pipeline.js.
 */
export declare function qualityGateLoraRun(options: {
  runRoot: string;
}): Promise<Record<string, unknown>>;