@simulatte/doppler 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199)
  1. package/README.md +26 -10
  2. package/package.json +30 -6
  3. package/src/client/doppler-api.browser.d.ts +1 -0
  4. package/src/client/doppler-api.browser.js +288 -0
  5. package/src/client/doppler-api.js +1 -1
  6. package/src/client/doppler-provider/types.js +1 -1
  7. package/src/config/execution-contract-check.d.ts +33 -0
  8. package/src/config/execution-contract-check.js +72 -0
  9. package/src/config/execution-v0-contract-check.d.ts +94 -0
  10. package/src/config/execution-v0-contract-check.js +251 -0
  11. package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
  12. package/src/config/execution-v0-graph-contract-check.js +64 -0
  13. package/src/config/kernel-path-contract-check.d.ts +76 -0
  14. package/src/config/kernel-path-contract-check.js +479 -0
  15. package/src/config/kernel-path-loader.d.ts +16 -0
  16. package/src/config/kernel-path-loader.js +54 -0
  17. package/src/config/kernels/kernel-ref-digests.js +39 -27
  18. package/src/config/kernels/registry.json +598 -2
  19. package/src/config/loader.js +81 -48
  20. package/src/config/merge-contract-check.d.ts +16 -0
  21. package/src/config/merge-contract-check.js +321 -0
  22. package/src/config/merge-helpers.d.ts +58 -0
  23. package/src/config/merge-helpers.js +54 -0
  24. package/src/config/merge.js +21 -6
  25. package/src/config/presets/models/janus-text.json +2 -0
  26. package/src/config/presets/models/qwen3.json +9 -2
  27. package/src/config/presets/models/transformer.json +5 -0
  28. package/src/config/quantization-contract-check.d.ts +12 -0
  29. package/src/config/quantization-contract-check.js +91 -0
  30. package/src/config/required-inference-fields-contract-check.d.ts +24 -0
  31. package/src/config/required-inference-fields-contract-check.js +237 -0
  32. package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
  33. package/src/config/schema/browser-suite-metrics.schema.js +46 -0
  34. package/src/config/schema/conversion-report.schema.d.ts +40 -0
  35. package/src/config/schema/conversion-report.schema.js +108 -0
  36. package/src/config/schema/doppler.schema.js +12 -18
  37. package/src/config/schema/index.d.ts +22 -0
  38. package/src/config/schema/index.js +18 -0
  39. package/src/config/schema/inference-defaults.schema.js +3 -0
  40. package/src/config/schema/inference.schema.d.ts +9 -0
  41. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  42. package/src/config/schema/manifest.schema.d.ts +6 -0
  43. package/src/config/schema/manifest.schema.js +3 -0
  44. package/src/converter/core.d.ts +10 -0
  45. package/src/converter/core.js +27 -2
  46. package/src/converter/parsers/diffusion.js +63 -3
  47. package/src/converter/rope-config.js +42 -0
  48. package/src/gpu/device.js +58 -0
  49. package/src/gpu/kernels/attention.js +98 -0
  50. package/src/gpu/kernels/bias_add.wgsl +8 -6
  51. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  52. package/src/gpu/kernels/conv2d.js +1 -1
  53. package/src/gpu/kernels/conv2d.wgsl +7 -8
  54. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  55. package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
  56. package/src/gpu/kernels/depthwise_conv2d.js +99 -0
  57. package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
  58. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
  59. package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
  60. package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
  61. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
  62. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
  63. package/src/gpu/kernels/index.d.ts +30 -0
  64. package/src/gpu/kernels/index.js +25 -0
  65. package/src/gpu/kernels/matmul.js +25 -0
  66. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  67. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  68. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  69. package/src/gpu/kernels/relu.d.ts +18 -0
  70. package/src/gpu/kernels/relu.js +58 -0
  71. package/src/gpu/kernels/relu.wgsl +22 -0
  72. package/src/gpu/kernels/relu_f16.wgsl +24 -0
  73. package/src/gpu/kernels/repeat_channels.d.ts +21 -0
  74. package/src/gpu/kernels/repeat_channels.js +60 -0
  75. package/src/gpu/kernels/repeat_channels.wgsl +28 -0
  76. package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
  77. package/src/gpu/kernels/residual.js +44 -8
  78. package/src/gpu/kernels/residual.wgsl +6 -3
  79. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  80. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  81. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  82. package/src/gpu/kernels/rmsnorm.js +58 -6
  83. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  84. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  85. package/src/gpu/kernels/rope.d.ts +2 -0
  86. package/src/gpu/kernels/rope.js +11 -1
  87. package/src/gpu/kernels/rope.wgsl +56 -40
  88. package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
  89. package/src/gpu/kernels/sana_linear_attention.js +121 -0
  90. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
  91. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
  92. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
  93. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
  94. package/src/gpu/kernels/silu.d.ts +1 -0
  95. package/src/gpu/kernels/silu.js +32 -14
  96. package/src/gpu/kernels/silu.wgsl +19 -9
  97. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  98. package/src/gpu/kernels/transpose.js +15 -2
  99. package/src/gpu/kernels/transpose.wgsl +5 -6
  100. package/src/gpu/kernels/upsample2d.js +2 -1
  101. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  102. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  103. package/src/gpu/kernels/utils.js +16 -1
  104. package/src/index-browser.d.ts +1 -1
  105. package/src/index-browser.js +2 -2
  106. package/src/index.js +1 -1
  107. package/src/inference/browser-harness.js +109 -23
  108. package/src/inference/pipelines/diffusion/init.js +14 -0
  109. package/src/inference/pipelines/diffusion/pipeline.js +215 -77
  110. package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
  111. package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
  112. package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
  113. package/src/inference/pipelines/diffusion/scheduler.js +91 -3
  114. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
  115. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
  116. package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
  117. package/src/inference/pipelines/diffusion/types.d.ts +4 -0
  118. package/src/inference/pipelines/diffusion/vae.js +782 -78
  119. package/src/inference/pipelines/text/attention/record.js +11 -2
  120. package/src/inference/pipelines/text/attention/run.js +11 -2
  121. package/src/inference/pipelines/text/chat-format.js +25 -1
  122. package/src/inference/pipelines/text/config.d.ts +9 -0
  123. package/src/inference/pipelines/text/config.js +69 -2
  124. package/src/inference/pipelines/text/execution-plan.js +23 -31
  125. package/src/inference/pipelines/text/execution-v0.js +43 -95
  126. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  127. package/src/inference/pipelines/text/init.d.ts +4 -0
  128. package/src/inference/pipelines/text/init.js +56 -9
  129. package/src/inference/pipelines/text/layer.js +11 -0
  130. package/src/inference/pipelines/text.js +4 -0
  131. package/src/inference/tokenizers/bundled.js +156 -33
  132. package/src/rules/execution-rules-contract-check.d.ts +17 -0
  133. package/src/rules/execution-rules-contract-check.js +245 -0
  134. package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
  135. package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
  136. package/src/rules/kernels/relu.rules.json +6 -0
  137. package/src/rules/kernels/repeat-channels.rules.json +6 -0
  138. package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
  139. package/src/rules/layer-pattern-contract-check.d.ts +17 -0
  140. package/src/rules/layer-pattern-contract-check.js +231 -0
  141. package/src/rules/rule-registry.d.ts +28 -0
  142. package/src/rules/rule-registry.js +38 -0
  143. package/src/rules/tooling/command-runtime.rules.json +18 -0
  144. package/src/tooling/command-api.d.ts +27 -1
  145. package/src/tooling/command-api.js +142 -3
  146. package/src/tooling/conversion-config-materializer.d.ts +24 -0
  147. package/src/tooling/conversion-config-materializer.js +99 -0
  148. package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
  149. package/src/tooling/lean-execution-contract-runner.js +158 -0
  150. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  151. package/src/tooling/node-browser-command-runner.js +58 -3
  152. package/src/tooling/node-command-runner.js +15 -0
  153. package/src/tooling/node-convert.d.ts +10 -0
  154. package/src/tooling/node-converter.js +59 -0
  155. package/src/tooling/node-webgpu.js +11 -89
  156. package/src/training/checkpoint-watch.d.ts +7 -0
  157. package/src/training/checkpoint-watch.js +106 -0
  158. package/src/training/checkpoint.d.ts +6 -1
  159. package/src/training/checkpoint.js +12 -2
  160. package/src/training/distillation/artifacts.d.ts +71 -0
  161. package/src/training/distillation/artifacts.js +132 -0
  162. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  163. package/src/training/distillation/checkpoint-watch.js +57 -0
  164. package/src/training/distillation/dataset.d.ts +59 -0
  165. package/src/training/distillation/dataset.js +337 -0
  166. package/src/training/distillation/eval.d.ts +34 -0
  167. package/src/training/distillation/eval.js +310 -0
  168. package/src/training/distillation/index.d.ts +29 -0
  169. package/src/training/distillation/index.js +29 -0
  170. package/src/training/distillation/runtime.d.ts +20 -0
  171. package/src/training/distillation/runtime.js +121 -0
  172. package/src/training/distillation/scoreboard.d.ts +6 -0
  173. package/src/training/distillation/scoreboard.js +8 -0
  174. package/src/training/distillation/stage-a.d.ts +45 -0
  175. package/src/training/distillation/stage-a.js +338 -0
  176. package/src/training/distillation/stage-b.d.ts +24 -0
  177. package/src/training/distillation/stage-b.js +20 -0
  178. package/src/training/index.d.ts +10 -0
  179. package/src/training/index.js +10 -0
  180. package/src/training/lora-pipeline.d.ts +40 -0
  181. package/src/training/lora-pipeline.js +796 -0
  182. package/src/training/operator-artifacts.d.ts +62 -0
  183. package/src/training/operator-artifacts.js +140 -0
  184. package/src/training/operator-command.d.ts +5 -0
  185. package/src/training/operator-command.js +453 -0
  186. package/src/training/operator-eval.d.ts +48 -0
  187. package/src/training/operator-eval.js +230 -0
  188. package/src/training/operator-scoreboard.d.ts +5 -0
  189. package/src/training/operator-scoreboard.js +44 -0
  190. package/src/training/runner.d.ts +52 -0
  191. package/src/training/runner.js +29 -4
  192. package/src/training/suite.d.ts +112 -0
  193. package/src/training/suite.js +9 -9
  194. package/src/training/workloads.d.ts +164 -0
  195. package/src/training/workloads.js +539 -0
  196. package/src/version.d.ts +2 -0
  197. package/src/version.js +2 -0
  198. package/tools/convert-safetensors-node.js +47 -0
  199. package/tools/doppler-cli.js +252 -41
@@ -0,0 +1,337 @@
1
+ import { mkdir, readFile, writeFile } from 'node:fs/promises';
2
+ import { dirname, join, resolve } from 'node:path';
3
+
4
+ import { parseJsonl } from '../datasets/jsonl.js';
5
+ import { sha256Hex } from '../../utils/sha256.js';
6
+
7
/**
 * Recursively rebuilds `value` so every non-array object has its own
 * enumerable keys inserted in ascending (code-unit) order. Arrays keep their
 * element order, but each element is normalized the same way. Anything that
 * is not a truthy object (primitives, null, undefined, functions) is returned
 * unchanged.
 *
 * @param {unknown} value
 * @returns {unknown} a key-sorted copy suitable for deterministic JSON output
 */
function stableSortObject(value) {
  if (Array.isArray(value)) {
    return value.map((item) => stableSortObject(item));
  }
  const isObject = value !== null && typeof value === 'object';
  if (!isObject) {
    return value;
  }
  return Object.keys(value)
    .sort()
    .reduce((accumulator, key) => {
      accumulator[key] = stableSortObject(value[key]);
      return accumulator;
    }, {});
}
20
+
21
/**
 * Serializes `value` to JSON after canonicalizing object key order with
 * `stableSortObject`, so structurally equal inputs yield identical strings.
 *
 * @param {unknown} value
 * @returns {string|undefined} JSON text (undefined for values JSON cannot encode)
 */
function stableJson(value) {
  const canonicalValue = stableSortObject(value);
  return JSON.stringify(canonicalValue);
}
24
+
25
/**
 * Coerces `value` to a trimmed string, treating null/undefined and
 * whitespace-only results as "absent".
 *
 * @param {unknown} value
 * @returns {string|null} the trimmed, non-empty string, or null
 */
function asOptionalString(value) {
  if (value == null) {
    return null;
  }
  const text = String(value).trim();
  return text === '' ? null : text;
}
30
+
31
/**
 * Accepts either an array or a comma-separated string and returns the
 * trimmed, non-empty entries (each normalized via `asOptionalString`).
 *
 * @param {unknown} value - array of values, or a string like "en, es"
 * @returns {string[]|null} the cleaned entries, or null when none remain or
 *   the input is neither an array nor a string
 */
function asOptionalStringArray(value) {
  let entries;
  if (Array.isArray(value)) {
    entries = value;
  } else if (typeof value === 'string') {
    entries = value.split(',');
  } else {
    return null;
  }
  const cleaned = [];
  for (const entry of entries) {
    const text = asOptionalString(entry);
    if (text) {
      cleaned.push(text);
    }
  }
  return cleaned.length === 0 ? null : cleaned;
}
44
+
45
/**
 * Normalizes a language tag: lower-cases it, turns underscores into hyphens,
 * and folds anything starting with "en" or "es" down to that two-letter code.
 * Other tags are returned in their lower-cased, hyphenated form.
 *
 * NOTE(review): the prefix match also folds e.g. "est" (ISO 639-2 Estonian)
 * into "es" — confirm that this breadth is intended.
 *
 * @param {unknown} value
 * @returns {string|null}
 */
function normalizeLangCode(value) {
  const text = asOptionalString(value);
  if (text === null) {
    return null;
  }
  const code = text.toLowerCase().replaceAll('_', '-');
  for (const family of ['en', 'es']) {
    if (code.startsWith(family)) {
      return family;
    }
  }
  return code;
}
53
+
54
/**
 * Produces a canonical "src->tgt" direction string.
 *
 * When `value` is absent, the pair is built from `srcLang`/`tgtLang` (used
 * as given), or null if either is missing. Otherwise `value` is lower-cased,
 * underscores become hyphens, whitespace is dropped, and it is split on "->"
 * (or on "-" when no arrow is present). Exactly two parts are required; each
 * part is passed through `normalizeLangCode`.
 *
 * @param {unknown} value - raw pair such as "en_US->es" or "en-es"
 * @param {string|null} [srcLang]
 * @param {string|null} [tgtLang]
 * @returns {string|null}
 */
export function normalizeDistillationPair(value, srcLang = null, tgtLang = null) {
  const rawPair = asOptionalString(value);
  if (rawPair === null) {
    return srcLang && tgtLang ? `${srcLang}->${tgtLang}` : null;
  }
  const compact = rawPair.toLowerCase().replaceAll('_', '-').replace(/\s+/g, '');
  const delimiter = compact.includes('->') ? '->' : '-';
  const parts = compact.split(delimiter).filter(Boolean);
  if (parts.length !== 2) {
    return null;
  }
  const [from, to] = parts.map((part) => normalizeLangCode(part) || part);
  return `${from}->${to}`;
}
68
+
69
/**
 * Returns the first key in `keys` whose value on `record` is a non-blank
 * string (after `asOptionalString` normalization).
 *
 * @param {object|null|undefined} record
 * @param {string[]} keys - candidate field names, in priority order
 * @returns {string|null}
 */
function resolveStringCandidate(record, keys) {
  for (let position = 0; position < keys.length; position += 1) {
    const candidate = asOptionalString(record?.[keys[position]]);
    if (candidate !== null) {
      return candidate;
    }
  }
  return null;
}
76
+
77
/**
 * Picks the row identifier: an explicit `row_id`/`rowId` on the record wins.
 * Otherwise the id is the sha256 of the key-sorted JSON of the row's
 * canonical fields plus its position `index`. Because the index is part of
 * the hash, reordering the dataset changes derived ids.
 *
 * @param {object} record - raw input row
 * @param {number} index - position of the row in the parsed file
 * @param {object} canonical - normalized row fields (without row_id)
 * @returns {string}
 */
function resolveStableRowId(record, index, canonical) {
  const provided = asOptionalString(record?.row_id ?? record?.rowId);
  if (provided !== null) {
    return provided;
  }
  // Key order here is irrelevant: stableJson sorts keys before hashing.
  const fingerprint = {
    source: canonical.source,
    target_pos: canonical.target_pos,
    target_neg: canonical.target_neg,
    pair: canonical.pair,
    src_lang: canonical.src_lang,
    tgt_lang: canonical.tgt_lang,
    index,
  };
  return sha256Hex(stableJson(fingerprint));
}
90
+
91
/**
 * Converts one raw dataset record into the canonical translation-row shape:
 * { src_lang, tgt_lang, pair, source, target_pos, target_neg, row_id }.
 *
 * Source text is read from source/query/prompt, the positive target from
 * target_pos/target/pos/completion, and the optional negative from
 * target_neg/neg. Rows without a source or positive target are skipped
 * (null is returned).
 *
 * With `options.strictPairContract === true`, every row must carry src/tgt
 * languages, resolve to a pair, and that pair must equal "src->tgt".
 * Violations throw, and the message is prefixed with the row index so the
 * offending row can be located.
 *
 * @param {unknown} record - raw parsed row
 * @param {number} index - position of the row in the source file
 * @param {{ strictPairContract?: boolean }} [options]
 * @returns {object|null} canonical row, or null when the row is unusable
 * @throws {Error} on strict pair-contract violations
 */
export function normalizeTranslationPairRow(record, index, options = {}) {
  if (!record || typeof record !== 'object' || Array.isArray(record)) {
    return null;
  }
  const source = resolveStringCandidate(record, ['source', 'query', 'prompt']);
  const targetPos = resolveStringCandidate(record, ['target_pos', 'target', 'pos', 'completion']);
  const targetNeg = resolveStringCandidate(record, ['target_neg', 'neg']);
  if (!source || !targetPos) {
    return null;
  }
  const srcLang = normalizeLangCode(record.src_lang ?? record.source_lang);
  const tgtLang = normalizeLangCode(record.tgt_lang ?? record.target_lang ?? record.lang);
  const pair = normalizeDistillationPair(record.pair, srcLang, tgtLang);
  if (options?.strictPairContract === true) {
    // Large datasets make an unlabeled contract error nearly impossible to trace.
    const rowLabel = `Row ${index}`;
    if (!srcLang || !tgtLang) {
      throw new Error(`${rowLabel}: strictPairContract requires src_lang and tgt_lang on each row.`);
    }
    if (!pair) {
      throw new Error(`${rowLabel}: strictPairContract requires pair on each row.`);
    }
    const expectedPair = `${srcLang}->${tgtLang}`;
    if (pair !== expectedPair) {
      throw new Error(`${rowLabel}: row pair "${record.pair}" does not match src/tgt "${expectedPair}".`);
    }
  }
  const canonical = {
    src_lang: srcLang,
    tgt_lang: tgtLang,
    pair: pair || null,
    source,
    target_pos: targetPos,
    target_neg: targetNeg,
  };
  canonical.row_id = resolveStableRowId(record, index, canonical);
  return canonical;
}
127
+
128
/**
 * Builds an allow-list Set from an array or comma-separated string, mapping
 * every entry through `normalizer` and dropping falsy results.
 *
 * @param {unknown} value
 * @param {(entry: string) => (string|null)} normalizer
 * @returns {Set<string>|null} null when nothing survives (meaning "no filter")
 */
function normalizeFilterSet(value, normalizer) {
  const entries = asOptionalStringArray(value);
  if (entries === null) {
    return null;
  }
  const accepted = new Set();
  for (const entry of entries) {
    const normalized = normalizer(entry);
    if (normalized) {
      accepted.add(normalized);
    }
  }
  return accepted.size > 0 ? accepted : null;
}
136
+
137
/**
 * Keeps only rows whose source language, target language and pair satisfy
 * the optional allow-lists in `options` (sourceLangs, targetLangs,
 * pairAllowlist). An absent allow-list admits every row. A present one
 * rejects rows whose corresponding field is missing.
 *
 * @param {object[]} rows - canonical rows
 * @param {object} [options]
 * @returns {object[]} a new, filtered array
 */
function applyDatasetFilters(rows, options = {}) {
  const sourceLangs = normalizeFilterSet(options.sourceLangs, normalizeLangCode);
  const targetLangs = normalizeFilterSet(options.targetLangs, normalizeLangCode);
  const pairs = normalizeFilterSet(options.pairAllowlist, normalizeDistillationPair);
  const admits = (allowed, fieldValue) => !allowed || (Boolean(fieldValue) && allowed.has(fieldValue));
  return rows.filter((row) => admits(sourceLangs, row.src_lang)
    && admits(targetLangs, row.tgt_lang)
    && admits(pairs, row.pair));
}
154
+
155
/**
 * Normalizes a user-supplied subset specification.
 *
 * - size: from size/count/rowCount; kept only if a positive integer, else null
 * - seed: integer from `seed`, defaulting to 1337
 * - balanceBy: from balanceBy/allocationMode/stratifyBy
 * - id: from id/name
 * - parentSubsetManifest: from parentSubsetManifest/parentSubset/parentManifest
 *
 * @param {unknown} value
 * @returns {{id: string|null, size: number|null, seed: number,
 *   balanceBy: string|null, parentSubsetManifest: string|null}|null}
 *   null unless `value` is a non-array object
 */
function normalizeSubsetSpec(value) {
  const isPlainObject = Boolean(value) && typeof value === 'object' && !Array.isArray(value);
  if (!isPlainObject) {
    return null;
  }
  const requestedSize = Number(value.size ?? value.count ?? value.rowCount ?? 0);
  const requestedSeed = Number(value.seed ?? 1337);
  const balanceBy = asOptionalString(value.balanceBy ?? value.allocationMode ?? value.stratifyBy);
  // Property order matters: this object is serialized into subset manifests.
  return {
    id: asOptionalString(value.id ?? value.name) || null,
    size: Number.isInteger(requestedSize) && requestedSize > 0 ? requestedSize : null,
    seed: Number.isInteger(requestedSeed) ? requestedSeed : 1337,
    balanceBy,
    parentSubsetManifest: asOptionalString(
      value.parentSubsetManifest ?? value.parentSubset ?? value.parentManifest
    ),
  };
}
176
+
177
/**
 * Seeded, reproducible sort key for a row: sha256 of "<seed>:<rowId>".
 *
 * @param {number} seed
 * @param {string} rowId
 * @returns {string} hex digest
 */
function buildDeterministicRank(seed, rowId) {
  const rankInput = `${seed}:${rowId}`;
  return sha256Hex(rankInput);
}
180
+
181
/**
 * Loads the row ids recorded in a parent subset manifest so a child subset
 * can be restricted to them.
 *
 * The path is resolved against the current working directory
 * (`path.resolve`). The manifest must contain a `rowIds` array; every entry
 * is stringified. Previously a manifest without `rowIds` silently produced an
 * empty Set, and the caller then froze an empty subset. That case now fails
 * loudly instead. An explicit empty `rowIds: []` is still accepted.
 *
 * @param {string|null} parentSubsetManifest - manifest path, or null/empty
 * @returns {Promise<Set<string>|null>} null when no manifest was requested
 * @throws {Error} when the manifest has no `rowIds` array
 */
async function resolveParentRowIds(parentSubsetManifest) {
  if (!parentSubsetManifest) return null;
  const absolutePath = resolve(parentSubsetManifest);
  const raw = await readFile(absolutePath, 'utf8');
  const parsed = JSON.parse(raw);
  if (!Array.isArray(parsed?.rowIds)) {
    throw new Error(`Parent subset manifest "${absolutePath}" is missing a rowIds array.`);
  }
  return new Set(parsed.rowIds.map((entry) => String(entry)));
}
191
+
192
/**
 * Selects up to `size` rows, round-robin across translation directions.
 *
 * Rows are bucketed by `pair` ("unknown" when missing). Each bucket is
 * ordered by the seeded deterministic rank. Directions are then visited in
 * `localeCompare` order, taking the next-ranked row from each until `size`
 * rows are chosen or every bucket is exhausted.
 *
 * Each row's rank is hashed once up front. The original recomputed two
 * sha256 digests inside every sort comparison, which is O(n log n) hashes.
 * The resulting order is unchanged: same keys, same comparator, and
 * Array.prototype.sort is stable.
 *
 * @param {object[]} rows - canonical rows with `pair` and `row_id`
 * @param {number} size - maximum number of rows to return
 * @param {number} seed - rank seed
 * @returns {object[]}
 */
function buildPairBalancedSubset(rows, size, seed) {
  const byPair = new Map();
  for (const row of rows) {
    const key = row.pair || 'unknown';
    let bucket = byPair.get(key);
    if (!bucket) {
      bucket = [];
      byPair.set(key, bucket);
    }
    bucket.push({ row, rank: buildDeterministicRank(seed, row.row_id) });
  }
  for (const bucket of byPair.values()) {
    bucket.sort((left, right) => left.rank.localeCompare(right.rank));
  }
  const pairKeys = [...byPair.keys()].sort((left, right) => left.localeCompare(right));
  const selected = [];
  for (let cursor = 0; selected.length < size; cursor += 1) {
    let progressed = false;
    for (const pairKey of pairKeys) {
      const bucket = byPair.get(pairKey);
      if (cursor >= bucket.length) continue;
      selected.push(bucket[cursor].row);
      progressed = true;
      if (selected.length >= size) break;
    }
    // Every bucket is exhausted: fewer rows exist than requested.
    if (!progressed) break;
  }
  return selected;
}
224
+
225
/**
 * Applies a normalized subset spec to `rows`.
 *
 * - No spec, no size, or size >= row count: returns a shallow copy of `rows`.
 * - balanceBy "pair"/"pair_balance": delegates to `buildPairBalancedSubset`.
 * - Otherwise: the `size` rows with the lowest seeded deterministic rank.
 *
 * Each row is ranked once (decorate-sort-undecorate) rather than hashing
 * twice per comparator call. The selection is identical to before.
 *
 * @param {object[]} rows
 * @param {{size: number|null, seed: number, balanceBy: string|null}|null} subsetSpec
 * @returns {object[]} a new array
 */
function selectSubsetRows(rows, subsetSpec) {
  if (!subsetSpec || !subsetSpec.size || subsetSpec.size >= rows.length) {
    return rows.slice();
  }
  if (subsetSpec.balanceBy === 'pair' || subsetSpec.balanceBy === 'pair_balance') {
    return buildPairBalancedSubset(rows, subsetSpec.size, subsetSpec.seed);
  }
  return rows
    .map((row) => ({ row, rank: buildDeterministicRank(subsetSpec.seed, row.row_id) }))
    .sort((left, right) => left.rank.localeCompare(right.rank))
    .slice(0, subsetSpec.size)
    .map((entry) => entry.row);
}
241
+
242
/**
 * Tallies rows per translation direction; rows without a pair count under
 * "unknown".
 *
 * @param {object[]} rows
 * @returns {Record<string, number>}
 */
function computeDirectionCounts(rows) {
  return rows.reduce((counts, row) => {
    const direction = row.pair || 'unknown';
    counts[direction] = (counts[direction] || 0) + 1;
    return counts;
  }, {});
}
250
+
251
/**
 * Reads a translation dataset, normalizes and filters its rows, and computes
 * the hashes used to fingerprint it.
 *
 * Files ending in ".json" are parsed as a JSON array; anything else goes
 * through `parseJsonl`. Each record is normalized with
 * `normalizeTranslationPairRow` (unusable rows are dropped; strict-contract
 * violations throw), then filtered by `applyDatasetFilters`.
 *
 * @param {string} datasetPath - resolved against the working directory
 * @param {object} [options] - strictPairContract, sourceLangs, targetLangs, pairAllowlist
 * @returns {Promise<object>} absolutePath, raw text, rows, rowCount,
 *   directionCounts, datasetHash (raw bytes), canonicalHash (normalized rows),
 *   rowIdsHash
 * @throws {Error} when the file is not an array or no rows survive filtering
 */
export async function loadCanonicalTranslationDataset(datasetPath, options = {}) {
  const absolutePath = resolve(String(datasetPath));
  const raw = await readFile(absolutePath, 'utf8');
  const isJsonArrayFile = absolutePath.endsWith('.json');
  const parsed = isJsonArrayFile ? JSON.parse(raw) : parseJsonl(raw);
  if (!Array.isArray(parsed)) {
    throw new Error(`Distillation dataset "${absolutePath}" must be a JSON array or JSONL file.`);
  }
  const normalizedRows = parsed
    .map((record, index) => normalizeTranslationPairRow(record, index, options))
    .filter((row) => row);
  const keptRows = applyDatasetFilters(normalizedRows, options);
  if (keptRows.length === 0) {
    throw new Error(`Distillation dataset "${absolutePath}" has no usable rows after contract checks and filters.`);
  }
  const idList = keptRows.map((row) => row.row_id);
  return {
    absolutePath,
    raw,
    rows: keptRows,
    rowCount: keptRows.length,
    directionCounts: computeDirectionCounts(keptRows),
    datasetHash: sha256Hex(raw),
    canonicalHash: sha256Hex(stableJson(keptRows)),
    rowIdsHash: sha256Hex(idList.join('\n')),
  };
}
283
+
284
/**
 * Builds a reproducible ("frozen") subset of a distillation dataset and
 * writes it to `options.outputDir` as:
 *   - subset.jsonl         one canonical row per line
 *   - row_ids.txt          one row id per line
 *   - subset_manifest.json provenance (hashes, counts, spec, row ids)
 *
 * Validation happens before any rows are selected or files are written:
 *   - `outputDir` must be present and non-blank. Otherwise `String(undefined)`
 *     would resolve to a directory literally named "undefined", and an empty
 *     string would resolve to the working directory.
 *   - When a parent subset manifest is given, it must match at least one row
 *     of the dataset; an empty frozen subset is never useful.
 *
 * @param {object} options - datasetPath, outputDir, subsetSpec, strictPairContract,
 *   sourceLangs, targetLangs, pairAllowlist
 * @returns {Promise<object>} dataset, subsetRows, written paths, and manifest
 * @throws {Error} on missing outputDir or an empty parent-scoped row set
 */
export async function buildFrozenSubset(options) {
  if (asOptionalString(options?.outputDir) === null) {
    throw new Error('buildFrozenSubset requires options.outputDir.');
  }
  const dataset = await loadCanonicalTranslationDataset(options.datasetPath, {
    strictPairContract: options.strictPairContract === true,
    sourceLangs: options.sourceLangs,
    targetLangs: options.targetLangs,
    pairAllowlist: options.pairAllowlist,
  });
  const subsetSpec = normalizeSubsetSpec(options.subsetSpec);
  // resolveParentRowIds returns null when no parent manifest is configured.
  const parentRowIds = await resolveParentRowIds(subsetSpec?.parentSubsetManifest);
  const scopedRows = parentRowIds
    ? dataset.rows.filter((row) => parentRowIds.has(row.row_id))
    : dataset.rows;
  if (parentRowIds && scopedRows.length === 0) {
    throw new Error(
      `Parent subset manifest "${subsetSpec.parentSubsetManifest}" matched no rows of distillation dataset "${dataset.absolutePath}".`
    );
  }
  const subsetRows = selectSubsetRows(scopedRows, subsetSpec);
  const rowIds = subsetRows.map((row) => row.row_id);
  const outputDir = resolve(String(options.outputDir));
  const subsetJsonlPath = join(outputDir, 'subset.jsonl');
  const rowIdsPath = join(outputDir, 'row_ids.txt');
  const manifestPath = join(outputDir, 'subset_manifest.json');
  const serializedRows = `${subsetRows.map((row) => JSON.stringify(row)).join('\n')}\n`;
  const rowIdsText = `${rowIds.join('\n')}\n`;
  const manifest = {
    artifactType: 'subset_manifest',
    schemaVersion: 1,
    generatedAt: new Date().toISOString(),
    datasetPath: dataset.absolutePath,
    datasetHash: dataset.datasetHash,
    canonicalHash: dataset.canonicalHash,
    rowIdsHash: sha256Hex(rowIdsText),
    universeRowCount: dataset.rowCount,
    subsetRowCount: subsetRows.length,
    directionCounts: computeDirectionCounts(subsetRows),
    subsetSpec,
    parentSubsetManifest: subsetSpec?.parentSubsetManifest || null,
    rowIds,
    output: {
      subsetJsonlPath,
      rowIdsPath,
    },
  };
  // All three artifacts live directly in outputDir, so one mkdir suffices.
  await mkdir(outputDir, { recursive: true });
  await writeFile(subsetJsonlPath, serializedRows, 'utf8');
  await writeFile(rowIdsPath, rowIdsText, 'utf8');
  await writeFile(manifestPath, `${JSON.stringify(manifest, null, 2)}\n`, 'utf8');
  return {
    dataset,
    subsetRows,
    subsetJsonlPath,
    rowIdsPath,
    manifestPath,
    manifest,
  };
}
@@ -0,0 +1,34 @@
1
+ import type { LoadedTrainingWorkload } from '../workloads.js';
2
+
3
/**
 * Evaluates an in-memory distillation student against the workload's eval
 * datasets, or only the one named by `evalDatasetId` when given.
 *
 * The implementation rejects runtimes whose `studentGraphMode` is not
 * "transformer_full" and eval datasets whose `evalKind` is not "translation".
 * Resolves to an array of eval report records.
 */
export declare function evaluateDistillationModel(options: {
  loadedWorkload: LoadedTrainingWorkload;
  model: Record<string, unknown>;
  distillRuntime: Record<string, unknown>;
  stageId: string;
  checkpointId: string;
  checkpointStep: number | null;
  checkpointPath?: string | null;
  layout?: Record<string, string> | null;
  evalDatasetId?: string | null;
  configHash?: string | null;
  parentArtifacts?: Record<string, unknown>[];
}): Promise<Array<Record<string, unknown>>>;
16
+
17
/**
 * Evaluates a saved distillation checkpoint located at `checkpointPath`.
 * The optional fields narrow or annotate the run; see the implementation for
 * how each is consumed. Resolves to an array of eval report records.
 */
export declare function evaluateDistillationCheckpoint(options: {
  loadedWorkload: LoadedTrainingWorkload;
  checkpointPath: string;
  checkpointId?: string | null;
  checkpointStep?: number | null;
  stageId?: string | null;
  layout?: Record<string, string> | null;
  datasetPath?: string | null;
  stageAArtifact?: string | null;
  stageAArtifactHash?: string | null;
  evalDatasetId?: string | null;
  parentArtifacts?: Record<string, unknown>[];
}): Promise<Array<Record<string, unknown>>>;
30
+
31
/**
 * Reads a distillation checkpoint marker file. Resolves to the marker's
 * absolute path and the marker record (presumably the parsed file contents —
 * confirm against the implementation).
 */
export declare function readDistillCheckpointMarker(
  markerPath: string,
): Promise<{ absolutePath: string; marker: Record<string, unknown> }>;
@@ -0,0 +1,310 @@
1
+ import { readFile } from 'node:fs/promises';
2
+ import { dirname, join, resolve } from 'node:path';
3
+
4
+ import { loadBackwardRegistry } from '../../config/backward-registry-loader.js';
5
+ import { f16ToF32Array } from '../../inference/kv-cache/types.js';
6
+ import { readBuffer, releaseBuffer } from '../../memory/buffer-pool.js';
7
+ import { AutogradTape } from '../autograd.js';
8
+ import { loadCheckpoint } from '../checkpoint.js';
9
+ import { computeEvalMetrics } from '../operator-eval.js';
10
+ import {
11
+ buildDistillPrompt,
12
+ createDistillRuntimeContext,
13
+ createDistillStudentRuntimeModelFixture,
14
+ resolveDistillDataScope,
15
+ } from '../suite.js';
16
+ import { restoreTrainingCheckpointState } from '../runner.js';
17
+ import { loadCanonicalTranslationDataset } from './dataset.js';
18
+ import { buildDistillArtifactBase, writeDistillEvalReport } from './artifacts.js';
19
+ import { buildDistillationTrainingConfigFromWorkload, resolveInternalDistillStage } from './runtime.js';
20
+
21
/**
 * Wraps raw bytes as Float32 values. For dtype "f16" the bytes are viewed as
 * Uint16 halves and widened via `f16ToF32Array`; any other dtype is viewed
 * directly as Float32.
 *
 * @param {ArrayBuffer} raw - assumed to be an ArrayBuffer (TODO confirm what readBuffer returns)
 * @param {string} [dtype='f32']
 * @returns {Float32Array}
 */
function toFloat32Array(raw, dtype = 'f32') {
  return dtype === 'f16'
    ? f16ToF32Array(new Uint16Array(raw))
    : new Float32Array(raw);
}
27
+
28
/**
 * Returns the workload's eval dataset entries, optionally narrowed to the
 * entry whose `id` equals `requestedEvalDatasetId`.
 *
 * @param {{evalDatasets?: unknown}} workload
 * @param {string|null} [requestedEvalDatasetId]
 * @returns {object[]} the workload's own array when unfiltered, else a new array
 */
function resolveEvalDatasets(workload, requestedEvalDatasetId = null) {
  const configured = Array.isArray(workload.evalDatasets) ? workload.evalDatasets : [];
  return requestedEvalDatasetId
    ? configured.filter((entry) => entry.id === requestedEvalDatasetId)
    : configured;
}
35
+
36
/**
 * Index of the largest finite value; non-finite entries (NaN, ±Infinity) are
 * ignored. Ties keep the earliest index. Returns 0 for empty input or when
 * no entry is finite.
 *
 * @param {ArrayLike<number>} values
 * @returns {number}
 */
function argmax(values) {
  let winner = 0;
  let best = Number.NEGATIVE_INFINITY;
  for (let position = 0; position < values.length; position += 1) {
    const candidate = values[position];
    if (!Number.isFinite(candidate)) {
      continue;
    }
    if (candidate > best) {
      best = candidate;
      winner = position;
    }
  }
  return winner;
}
48
+
49
/**
 * Returns `value.buffer` to the buffer pool at most once per `released` set.
 * Non-objects and objects without a buffer are ignored.
 *
 * @param {unknown} value - tensor-like object with a `buffer` field
 * @param {Set<unknown>} released - buffers already released in this pass (mutated)
 */
function releaseTensorLike(value, released) {
  const isObject = Boolean(value) && typeof value === 'object';
  const buffer = isObject ? value.buffer : null;
  if (!buffer || released.has(buffer)) {
    return;
  }
  released.add(buffer);
  releaseBuffer(buffer);
}
56
+
57
/**
 * Releases the buffers of every output recorded on an autograd tape, except
 * those in `protectedBuffers`. A record's output may be a single tensor-like
 * object or an array of them. Each buffer is released at most once per call.
 *
 * @param {{records?: unknown}} tape
 * @param {Set<unknown>} [protectedBuffers] - buffers that must stay alive
 */
function disposeTapeOutputs(tape, protectedBuffers = new Set()) {
  if (!tape || !Array.isArray(tape.records)) {
    return;
  }
  const released = new Set();
  const releaseIfUnprotected = (tensor) => {
    if (tensor?.buffer && !protectedBuffers.has(tensor.buffer)) {
      releaseTensorLike(tensor, released);
    }
  };
  for (const record of tape.records) {
    const output = record?.output;
    if (!output || typeof output !== 'object') {
      continue;
    }
    if (output.buffer && !protectedBuffers.has(output.buffer)) {
      releaseTensorLike(output, released);
    } else if (Array.isArray(output)) {
      for (const entry of output) {
        releaseIfUnprotected(entry);
      }
    }
  }
}
76
+
77
/**
 * Gathers the buffers backing the model's parameters, read from
 * `model.paramGroups()` (an object whose values are arrays of tensors) when
 * that method exists. These buffers must never be released during eval.
 *
 * @param {object|null|undefined} model
 * @returns {Set<unknown>}
 */
function collectProtectedBuffers(model) {
  const groups = typeof model?.paramGroups === 'function'
    ? model.paramGroups()
    : {};
  const buffers = Object.values(groups || {})
    .flatMap((params) => (Array.isArray(params) ? params : []))
    .map((tensor) => tensor?.buffer)
    .filter(Boolean);
  return new Set(buffers);
}
91
+
92
/**
 * Copies a logits tensor back to the host and decodes it to Float32 values,
 * treating a missing dtype as "f32".
 *
 * @param {{buffer: unknown, dtype?: string}} tensor
 * @returns {Promise<Float32Array>}
 */
async function readLogitsTensor(tensor) {
  const bytes = await readBuffer(tensor.buffer);
  const dtype = tensor.dtype || 'f32';
  return toFloat32Array(bytes, dtype);
}
96
+
97
/**
 * Greedy decoding for the distillation eval fixture. Each step re-runs a full
 * forward pass on prompt + text generated so far, takes the argmax token,
 * and stops at `decodePolicy.maxTokens` or (unless `stopOnEos === false`) at
 * the tokenizer's EOS token.
 *
 * NOTE(review): argmax runs over the entire logits tensor, so this assumes
 * `forwardDistill` returns a single vocab row for the last position —
 * confirm.
 *
 * Buffer hygiene: after each step the logits buffer is released and the tape
 * outputs are disposed, while model parameter buffers are never touched. The
 * logits buffer is excluded from tape disposal, because if the forward pass
 * recorded logits on the tape it would otherwise be returned to the pool
 * twice. Tape disposal runs even if `cleanupDistillStep` throws.
 *
 * @param {object} model - exposes forwardDistill and optionally cleanupDistillStep/paramGroups
 * @param {object} tokenizer - exposes decode and optionally getSpecialTokens
 * @param {string} prompt
 * @param {{maxTokens?: number, stopOnEos?: boolean}} [decodePolicy]
 * @returns {Promise<string>} decoded generated text
 * @throws {Error} when decodePolicy.maxTokens is not a positive integer
 */
async function greedyDecodeFixture(model, tokenizer, prompt, decodePolicy = {}) {
  const maxTokens = Number.isInteger(decodePolicy?.maxTokens) && decodePolicy.maxTokens > 0
    ? decodePolicy.maxTokens
    : null;
  if (!maxTokens) {
    throw new Error('Translation eval requires evalDatasets[].decodePolicy.maxTokens in the workload pack.');
  }
  const stopOnEos = decodePolicy?.stopOnEos !== false;
  const eosToken = tokenizer?.getSpecialTokens?.()?.eos ?? null;
  const protectedBuffers = collectProtectedBuffers(model);
  const generated = [];
  let currentPrompt = prompt;
  for (let step = 0; step < maxTokens; step += 1) {
    const tape = new AutogradTape(loadBackwardRegistry());
    let logits = null;
    try {
      const result = await model.forwardDistill(
        {
          distill: {
            prompts: [currentPrompt],
          },
        },
        tape,
        { phase: 'anchor' }
      );
      logits = result?.logits || result;
      const values = await readLogitsTensor(logits);
      const tokenId = argmax(values);
      if (stopOnEos && eosToken != null && tokenId === eosToken) {
        break;
      }
      generated.push(tokenId);
      currentPrompt = `${prompt}${tokenizer.decode(generated, false, false)}`;
    } finally {
      // Buffers the tape pass must skip: parameters, plus the logits buffer
      // released just below (it may also appear as a tape output).
      const skipBuffers = new Set(protectedBuffers);
      if (logits?.buffer && !protectedBuffers.has(logits.buffer)) {
        releaseBuffer(logits.buffer);
        skipBuffers.add(logits.buffer);
      }
      try {
        model.cleanupDistillStep?.();
      } finally {
        disposeTapeOutputs(tape, skipBuffers);
      }
    }
  }
  return tokenizer.decode(generated, true, true);
}
140
+
141
/**
 * Flattens a computed-metrics object into scalar fields for reports. Missing
 * scores become null, and a falsy `primaryMetric` becomes null.
 *
 * @param {object|null|undefined} metrics
 * @returns {{bleu: *, chrf: *, exact_match: *, accuracy: *, primaryMetric: *, primaryScore: *}}
 */
function flattenMetricSummary(metrics) {
  const scoreOf = (metricName) => metrics?.[metricName]?.score ?? null;
  return {
    bleu: scoreOf('bleu'),
    chrf: scoreOf('chrf'),
    exact_match: scoreOf('exactMatch'),
    accuracy: scoreOf('accuracy'),
    primaryMetric: metrics?.primaryMetric || null,
    primaryScore: metrics?.primaryScore ?? null,
  };
}
151
+
152
/**
 * Evaluate an already-loaded distillation student model on the workload's eval
 * datasets and build one `training_eval_report` payload per dataset.
 *
 * Every resolved eval dataset is validated up front (only translation eval is
 * supported) before any decoding happens. Each row is then greedily decoded and
 * scored against its `target_pos` reference. When `options.layout` is provided,
 * each report is also written via `writeDistillEvalReport`.
 *
 * @param {object} options
 * @param {object} options.loadedWorkload - loaded workload wrapper; `.workload` is read.
 * @param {string} options.stageId - stage label recorded in the report.
 * @param {string} options.checkpointId - checkpoint label recorded in the report.
 * @param {*} options.checkpointStep - step recorded in the report.
 * @param {object} options.distillRuntime - must have `studentGraphMode === 'transformer_full'`
 *   and expose `studentPipeline.tokenizer`.
 * @param {object} options.model - model passed to the greedy decoder.
 * @param {string|null} [options.evalDatasetId] - restrict evaluation to one eval dataset id.
 * @param {object|null} [options.layout] - when set, reports are written to disk.
 * @param {string|null} [options.checkpointPath]
 * @param {Array} [options.parentArtifacts]
 * @param {string} [options.configHash] - defaults to `workload.configHash`.
 * @param {number} [options.maxSampleRows=5] - number of decoded rows echoed into `sampleRows`.
 * @returns {Promise<object[]>} one report (payload plus `reportPath`) per eval dataset.
 * @throws {Error} on a non-`transformer_full` student graph, when no eval datasets
 *   resolve, or when any resolved dataset is not a translation eval.
 */
export async function evaluateDistillationModel(options) {
  const loadedWorkload = options.loadedWorkload;
  const workload = loadedWorkload.workload;
  const stageId = options.stageId;
  const checkpointId = options.checkpointId;
  const checkpointStep = options.checkpointStep;
  const distillRuntime = options.distillRuntime;
  const model = options.model;
  const maxSampleRows = options.maxSampleRows ?? 5;
  if (distillRuntime.studentGraphMode !== 'transformer_full') {
    throw new Error(
      `Distillation eval requires studentGraphMode="transformer_full"; got "${distillRuntime.studentGraphMode}".`
    );
  }
  const evalDatasets = resolveEvalDatasets(workload, options.evalDatasetId || null);
  if (evalDatasets.length === 0) {
    throw new Error(`No eval datasets resolved for workload "${workload.id}".`);
  }
  // Validate all datasets before decoding any of them. Otherwise an unsupported
  // entry late in the list would throw only after earlier datasets had been
  // decoded (and, with a layout, their reports already written).
  const unsupported = evalDatasets.find((entry) => entry.evalKind !== 'translation');
  if (unsupported) {
    throw new Error(`Distillation eval currently supports translation eval only, got "${unsupported.evalKind}".`);
  }
  const reports = [];
  for (const evalDataset of evalDatasets) {
    // Dataset-level language filters override the pipeline-level ones.
    const dataset = await loadCanonicalTranslationDataset(evalDataset.datasetPath, {
      strictPairContract: workload.pipeline.strictPairContract === true,
      sourceLangs: evalDataset.sourceLangs || workload.pipeline.sourceLangs,
      targetLangs: evalDataset.targetLangs || workload.pipeline.targetLangs,
      pairAllowlist: evalDataset.pairAllowlist || workload.pipeline.pairAllowlist,
    });
    const hypotheses = [];
    const references = [];
    const samples = [];
    for (const row of dataset.rows) {
      // Prefer an explicit `pair`; otherwise derive "src->tgt" from the language columns.
      const direction = row.pair || (
        row.src_lang && row.tgt_lang
          ? `${row.src_lang}->${row.tgt_lang}`
          : 'unknown'
      );
      const prompt = buildDistillPrompt({ direction, source: row.source });
      const hypothesis = await greedyDecodeFixture(
        model,
        distillRuntime.studentPipeline.tokenizer,
        prompt,
        evalDataset.decodePolicy || {}
      );
      hypotheses.push(hypothesis);
      references.push(row.target_pos);
      if (samples.length < maxSampleRows) {
        samples.push({
          row_id: row.row_id,
          source: row.source,
          reference: row.target_pos,
          hypothesis,
        });
      }
    }
    const computedMetrics = computeEvalMetrics('translation', hypotheses, references, {});
    const flattened = flattenMetricSummary(computedMetrics);
    const reportPayload = {
      ...buildDistillArtifactBase(loadedWorkload, {
        prefix: 'dst_eval',
        artifactType: 'training_eval_report',
        datasetPath: dataset.absolutePath,
        datasetHash: dataset.canonicalHash,
        stage: stageId,
        checkpointStep,
        parentArtifacts: options.parentArtifacts || [],
        configHash: options.configHash || workload.configHash,
      }),
      checkpointId,
      checkpointPath: options.checkpointPath || null,
      evalDatasetId: evalDataset.id,
      evalKind: evalDataset.evalKind,
      metrics: computedMetrics,
      ...flattened,
      rowCount: dataset.rowCount,
      sampleRows: samples,
    };
    const reportFile = options.layout
      ? await writeDistillEvalReport(options.layout, reportPayload)
      : null;
    reports.push({
      ...reportPayload,
      reportPath: reportFile?.path || null,
    });
  }
  return reports;
}
241
+
242
/**
 * Restore a saved distillation checkpoint into a fresh student model fixture and
 * evaluate it with `evaluateDistillationModel`.
 *
 * The distill runtime and the model fixture are created here and are always
 * cleaned up (fixture first, then runtime) after evaluation has fully finished
 * or failed.
 *
 * @param {object} options
 * @param {object} options.loadedWorkload - loaded workload wrapper; `.workload` is read.
 * @param {string} options.checkpointPath - checkpoint file to restore (required).
 * @param {string} [options.stageId] - stage plan entry to use; defaults to the first entry.
 * @param {string} [options.checkpointId='checkpoint']
 * @param {number|null} [options.checkpointStep] - recorded in the report; 0 is preserved.
 * @param {string} [options.datasetPath] - defaults to `workload.datasetPath`.
 * @param {string|null} [options.evalDatasetId]
 * @param {object|null} [options.layout]
 * @param {Array} [options.parentArtifacts]
 * @param {*} [options.stageAArtifact]
 * @param {*} [options.stageAArtifactHash]
 * @returns {Promise<object[]>} reports produced by `evaluateDistillationModel`.
 * @throws {Error} when the stage cannot be resolved, the checkpoint path is
 *   missing, or the checkpoint cannot be loaded.
 */
export async function evaluateDistillationCheckpoint(options) {
  const loadedWorkload = options.loadedWorkload;
  const workload = loadedWorkload.workload;
  const stagePlan = Array.isArray(workload.pipeline.stagePlan) ? workload.pipeline.stagePlan : [];
  const requestedStageId = String(options.stageId || '').trim();
  let stageEntry;
  if (requestedStageId) {
    stageEntry = stagePlan.find((entry) => entry.id === requestedStageId);
    if (!stageEntry) {
      // Fail loudly instead of silently evaluating a different stage than requested.
      const known = stagePlan.map((entry) => entry.id).join(', ') || '(none)';
      throw new Error(
        `Stage "${requestedStageId}" not found in stage plan for workload "${workload.id}" (available: ${known}).`
      );
    }
  } else {
    stageEntry = stagePlan[0];
  }
  if (!stageEntry) {
    throw new Error(`No stage entry resolved for workload "${workload.id}".`);
  }
  const internalStage = resolveInternalDistillStage(stageEntry);
  // Without this check, String(undefined) would resolve to a file literally named "undefined".
  if (options.checkpointPath == null || String(options.checkpointPath).trim() === '') {
    throw new Error('evaluateDistillationCheckpoint requires options.checkpointPath.');
  }
  const checkpointPath = resolve(String(options.checkpointPath));
  const checkpointRecord = await loadCheckpoint(checkpointPath);
  if (!checkpointRecord) {
    throw new Error(`Checkpoint not found: ${checkpointPath}`);
  }
  const configBundle = buildDistillationTrainingConfigFromWorkload(loadedWorkload, stageEntry, {
    datasetPath: options.datasetPath || workload.datasetPath,
    stageAArtifact: options.stageAArtifact || null,
    stageAArtifactHash: options.stageAArtifactHash || null,
    // Artifact dir is taken as two levels above the checkpoint file.
    artifactDir: dirname(dirname(checkpointPath)),
  });
  const distillRuntime = await createDistillRuntimeContext({
    teacherModelId: workload.teacherModelId,
    studentModelId: workload.studentModelId,
    trainingStage: internalStage,
    studentGraphMode: workload.pipeline.studentGraphMode,
  }, configBundle.trainingConfig.training);
  let fixture = null;
  try {
    fixture = await createDistillStudentRuntimeModelFixture({
      training: configBundle.trainingConfig.training,
    }, {
      distillRuntime,
      studentGraphMode: workload.pipeline.studentGraphMode,
    });
    // Only model weights are restored; a stub optimizer state is supplied.
    await restoreTrainingCheckpointState(
      fixture.model,
      { getState: () => null, stepCount: 0 },
      checkpointRecord,
      fixture.config
    );
    // `return await` is required: a bare `return promise` would let the finally
    // block tear down the model/runtime while evaluation is still running.
    return await evaluateDistillationModel({
      loadedWorkload,
      layout: options.layout || null,
      stageId: stageEntry.id,
      checkpointId: options.checkpointId || 'checkpoint',
      checkpointStep: options.checkpointStep ?? null,
      checkpointPath,
      distillRuntime,
      model: fixture.model,
      evalDatasetId: options.evalDatasetId || null,
      configHash: configBundle.trainingConfigHash,
      parentArtifacts: options.parentArtifacts || [],
    });
  } finally {
    // Nested so the runtime is released even if fixture cleanup throws.
    try {
      await fixture?.cleanup?.();
    } finally {
      await distillRuntime.cleanup();
    }
  }
}
302
+
303
/**
 * Read and parse a JSON checkpoint marker file.
 *
 * @param {string} markerPath - path to the marker file; resolved against the cwd.
 * @returns {Promise<{absolutePath: string, marker: *}>} the resolved path and parsed JSON.
 * @throws {TypeError} when `markerPath` is missing or blank. Without this check,
 *   String(undefined) would resolve to a file named "undefined".
 * @throws {SyntaxError} when the file is not valid JSON; the message names the file
 *   and the original parse error is attached as `cause`.
 * Filesystem errors from `readFile` (e.g. ENOENT) propagate unchanged.
 */
export async function readDistillCheckpointMarker(markerPath) {
  if (markerPath == null || String(markerPath).trim() === '') {
    throw new TypeError('readDistillCheckpointMarker requires a marker path.');
  }
  const absolutePath = resolve(String(markerPath));
  const raw = await readFile(absolutePath, 'utf8');
  let marker;
  try {
    marker = JSON.parse(raw);
  } catch (err) {
    throw new SyntaxError(`Invalid JSON in checkpoint marker ${absolutePath}: ${err.message}`, { cause: err });
  }
  return {
    absolutePath,
    marker,
  };
}