@simulatte/doppler 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199)
  1. package/README.md +26 -10
  2. package/package.json +30 -6
  3. package/src/client/doppler-api.browser.d.ts +1 -0
  4. package/src/client/doppler-api.browser.js +288 -0
  5. package/src/client/doppler-api.js +1 -1
  6. package/src/client/doppler-provider/types.js +1 -1
  7. package/src/config/execution-contract-check.d.ts +33 -0
  8. package/src/config/execution-contract-check.js +72 -0
  9. package/src/config/execution-v0-contract-check.d.ts +94 -0
  10. package/src/config/execution-v0-contract-check.js +251 -0
  11. package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
  12. package/src/config/execution-v0-graph-contract-check.js +64 -0
  13. package/src/config/kernel-path-contract-check.d.ts +76 -0
  14. package/src/config/kernel-path-contract-check.js +479 -0
  15. package/src/config/kernel-path-loader.d.ts +16 -0
  16. package/src/config/kernel-path-loader.js +54 -0
  17. package/src/config/kernels/kernel-ref-digests.js +39 -27
  18. package/src/config/kernels/registry.json +598 -2
  19. package/src/config/loader.js +81 -48
  20. package/src/config/merge-contract-check.d.ts +16 -0
  21. package/src/config/merge-contract-check.js +321 -0
  22. package/src/config/merge-helpers.d.ts +58 -0
  23. package/src/config/merge-helpers.js +54 -0
  24. package/src/config/merge.js +21 -6
  25. package/src/config/presets/models/janus-text.json +2 -0
  26. package/src/config/presets/models/qwen3.json +9 -2
  27. package/src/config/presets/models/transformer.json +5 -0
  28. package/src/config/quantization-contract-check.d.ts +12 -0
  29. package/src/config/quantization-contract-check.js +91 -0
  30. package/src/config/required-inference-fields-contract-check.d.ts +24 -0
  31. package/src/config/required-inference-fields-contract-check.js +237 -0
  32. package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
  33. package/src/config/schema/browser-suite-metrics.schema.js +46 -0
  34. package/src/config/schema/conversion-report.schema.d.ts +40 -0
  35. package/src/config/schema/conversion-report.schema.js +108 -0
  36. package/src/config/schema/doppler.schema.js +12 -18
  37. package/src/config/schema/index.d.ts +22 -0
  38. package/src/config/schema/index.js +18 -0
  39. package/src/config/schema/inference-defaults.schema.js +3 -0
  40. package/src/config/schema/inference.schema.d.ts +9 -0
  41. package/src/config/schema/kernel-path.schema.d.ts +6 -0
  42. package/src/config/schema/manifest.schema.d.ts +6 -0
  43. package/src/config/schema/manifest.schema.js +3 -0
  44. package/src/converter/core.d.ts +10 -0
  45. package/src/converter/core.js +27 -2
  46. package/src/converter/parsers/diffusion.js +63 -3
  47. package/src/converter/rope-config.js +42 -0
  48. package/src/gpu/device.js +58 -0
  49. package/src/gpu/kernels/attention.js +98 -0
  50. package/src/gpu/kernels/bias_add.wgsl +8 -6
  51. package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
  52. package/src/gpu/kernels/conv2d.js +1 -1
  53. package/src/gpu/kernels/conv2d.wgsl +7 -8
  54. package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
  55. package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
  56. package/src/gpu/kernels/depthwise_conv2d.js +99 -0
  57. package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
  58. package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
  59. package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
  60. package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
  61. package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
  62. package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
  63. package/src/gpu/kernels/index.d.ts +30 -0
  64. package/src/gpu/kernels/index.js +25 -0
  65. package/src/gpu/kernels/matmul.js +25 -0
  66. package/src/gpu/kernels/pixel_shuffle.js +1 -1
  67. package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
  68. package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
  69. package/src/gpu/kernels/relu.d.ts +18 -0
  70. package/src/gpu/kernels/relu.js +58 -0
  71. package/src/gpu/kernels/relu.wgsl +22 -0
  72. package/src/gpu/kernels/relu_f16.wgsl +24 -0
  73. package/src/gpu/kernels/repeat_channels.d.ts +21 -0
  74. package/src/gpu/kernels/repeat_channels.js +60 -0
  75. package/src/gpu/kernels/repeat_channels.wgsl +28 -0
  76. package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
  77. package/src/gpu/kernels/residual.js +44 -8
  78. package/src/gpu/kernels/residual.wgsl +6 -3
  79. package/src/gpu/kernels/residual_f16.wgsl +2 -1
  80. package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
  81. package/src/gpu/kernels/residual_vec4.wgsl +2 -1
  82. package/src/gpu/kernels/rmsnorm.js +58 -6
  83. package/src/gpu/kernels/rmsnorm.wgsl +14 -6
  84. package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
  85. package/src/gpu/kernels/rope.d.ts +2 -0
  86. package/src/gpu/kernels/rope.js +11 -1
  87. package/src/gpu/kernels/rope.wgsl +56 -40
  88. package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
  89. package/src/gpu/kernels/sana_linear_attention.js +121 -0
  90. package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
  91. package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
  92. package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
  93. package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
  94. package/src/gpu/kernels/silu.d.ts +1 -0
  95. package/src/gpu/kernels/silu.js +32 -14
  96. package/src/gpu/kernels/silu.wgsl +19 -9
  97. package/src/gpu/kernels/silu_f16.wgsl +19 -9
  98. package/src/gpu/kernels/transpose.js +15 -2
  99. package/src/gpu/kernels/transpose.wgsl +5 -6
  100. package/src/gpu/kernels/upsample2d.js +2 -1
  101. package/src/gpu/kernels/upsample2d.wgsl +6 -9
  102. package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
  103. package/src/gpu/kernels/utils.js +16 -1
  104. package/src/index-browser.d.ts +1 -1
  105. package/src/index-browser.js +2 -2
  106. package/src/index.js +1 -1
  107. package/src/inference/browser-harness.js +109 -23
  108. package/src/inference/pipelines/diffusion/init.js +14 -0
  109. package/src/inference/pipelines/diffusion/pipeline.js +215 -77
  110. package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
  111. package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
  112. package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
  113. package/src/inference/pipelines/diffusion/scheduler.js +91 -3
  114. package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
  115. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
  116. package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
  117. package/src/inference/pipelines/diffusion/types.d.ts +4 -0
  118. package/src/inference/pipelines/diffusion/vae.js +782 -78
  119. package/src/inference/pipelines/text/attention/record.js +11 -2
  120. package/src/inference/pipelines/text/attention/run.js +11 -2
  121. package/src/inference/pipelines/text/chat-format.js +25 -1
  122. package/src/inference/pipelines/text/config.d.ts +9 -0
  123. package/src/inference/pipelines/text/config.js +69 -2
  124. package/src/inference/pipelines/text/execution-plan.js +23 -31
  125. package/src/inference/pipelines/text/execution-v0.js +43 -95
  126. package/src/inference/pipelines/text/ffn/standard.js +3 -0
  127. package/src/inference/pipelines/text/init.d.ts +4 -0
  128. package/src/inference/pipelines/text/init.js +56 -9
  129. package/src/inference/pipelines/text/layer.js +11 -0
  130. package/src/inference/pipelines/text.js +4 -0
  131. package/src/inference/tokenizers/bundled.js +156 -33
  132. package/src/rules/execution-rules-contract-check.d.ts +17 -0
  133. package/src/rules/execution-rules-contract-check.js +245 -0
  134. package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
  135. package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
  136. package/src/rules/kernels/relu.rules.json +6 -0
  137. package/src/rules/kernels/repeat-channels.rules.json +6 -0
  138. package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
  139. package/src/rules/layer-pattern-contract-check.d.ts +17 -0
  140. package/src/rules/layer-pattern-contract-check.js +231 -0
  141. package/src/rules/rule-registry.d.ts +28 -0
  142. package/src/rules/rule-registry.js +38 -0
  143. package/src/rules/tooling/command-runtime.rules.json +18 -0
  144. package/src/tooling/command-api.d.ts +27 -1
  145. package/src/tooling/command-api.js +142 -3
  146. package/src/tooling/conversion-config-materializer.d.ts +24 -0
  147. package/src/tooling/conversion-config-materializer.js +99 -0
  148. package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
  149. package/src/tooling/lean-execution-contract-runner.js +158 -0
  150. package/src/tooling/node-browser-command-runner.d.ts +4 -0
  151. package/src/tooling/node-browser-command-runner.js +58 -3
  152. package/src/tooling/node-command-runner.js +15 -0
  153. package/src/tooling/node-convert.d.ts +10 -0
  154. package/src/tooling/node-converter.js +59 -0
  155. package/src/tooling/node-webgpu.js +11 -89
  156. package/src/training/checkpoint-watch.d.ts +7 -0
  157. package/src/training/checkpoint-watch.js +106 -0
  158. package/src/training/checkpoint.d.ts +6 -1
  159. package/src/training/checkpoint.js +12 -2
  160. package/src/training/distillation/artifacts.d.ts +71 -0
  161. package/src/training/distillation/artifacts.js +132 -0
  162. package/src/training/distillation/checkpoint-watch.d.ts +10 -0
  163. package/src/training/distillation/checkpoint-watch.js +57 -0
  164. package/src/training/distillation/dataset.d.ts +59 -0
  165. package/src/training/distillation/dataset.js +337 -0
  166. package/src/training/distillation/eval.d.ts +34 -0
  167. package/src/training/distillation/eval.js +310 -0
  168. package/src/training/distillation/index.d.ts +29 -0
  169. package/src/training/distillation/index.js +29 -0
  170. package/src/training/distillation/runtime.d.ts +20 -0
  171. package/src/training/distillation/runtime.js +121 -0
  172. package/src/training/distillation/scoreboard.d.ts +6 -0
  173. package/src/training/distillation/scoreboard.js +8 -0
  174. package/src/training/distillation/stage-a.d.ts +45 -0
  175. package/src/training/distillation/stage-a.js +338 -0
  176. package/src/training/distillation/stage-b.d.ts +24 -0
  177. package/src/training/distillation/stage-b.js +20 -0
  178. package/src/training/index.d.ts +10 -0
  179. package/src/training/index.js +10 -0
  180. package/src/training/lora-pipeline.d.ts +40 -0
  181. package/src/training/lora-pipeline.js +796 -0
  182. package/src/training/operator-artifacts.d.ts +62 -0
  183. package/src/training/operator-artifacts.js +140 -0
  184. package/src/training/operator-command.d.ts +5 -0
  185. package/src/training/operator-command.js +453 -0
  186. package/src/training/operator-eval.d.ts +48 -0
  187. package/src/training/operator-eval.js +230 -0
  188. package/src/training/operator-scoreboard.d.ts +5 -0
  189. package/src/training/operator-scoreboard.js +44 -0
  190. package/src/training/runner.d.ts +52 -0
  191. package/src/training/runner.js +29 -4
  192. package/src/training/suite.d.ts +112 -0
  193. package/src/training/suite.js +9 -9
  194. package/src/training/workloads.d.ts +164 -0
  195. package/src/training/workloads.js +539 -0
  196. package/src/version.d.ts +2 -0
  197. package/src/version.js +2 -0
  198. package/tools/convert-safetensors-node.js +47 -0
  199. package/tools/doppler-cli.js +252 -41
@@ -0,0 +1,796 @@
1
+ import { mkdir, readFile, readdir, writeFile } from 'node:fs/promises';
2
+ import { join, resolve } from 'node:path';
3
+
4
+ import { loadBackwardRegistry } from '../config/backward-registry-loader.js';
5
+ import { acquireBuffer, readBuffer, releaseBuffer, uploadData } from '../memory/buffer-pool.js';
6
+ import { createTensor } from '../gpu/tensor.js';
7
+ import { runMatmul } from '../gpu/kernels/index.js';
8
+ import { runResidualAdd } from '../gpu/kernels/residual.js';
9
+ import { parseJsonl } from './datasets/jsonl.js';
10
+ import { LoraAdapter } from './lora.js';
11
+ import { TrainingRunner, restoreTrainingCheckpointState } from './runner.js';
12
+ import { AdamOptimizer } from './optimizer.js';
13
+ import { crossEntropyLoss } from './loss.js';
14
+ import { clipGradients } from './clip.js';
15
+ import { OpType, AutogradTape } from './autograd.js';
16
+ import { loadCheckpoint } from './checkpoint.js';
17
+ import { exportLoRAAdapter } from './export.js';
18
+ import { computeEvalMetrics } from './operator-eval.js';
19
+ import { appendScoreboardRow } from './operator-scoreboard.js';
20
+ import {
21
+ buildArtifactBase,
22
+ createTrainingRunLayout,
23
+ hashArtifactPayload,
24
+ writeJsonArtifact,
25
+ writeRunContract,
26
+ writeWorkloadLock,
27
+ } from './operator-artifacts.js';
28
+ import { watchFinalizedCheckpoints } from './checkpoint-watch.js';
29
+ import { loadLoRAFromManifest } from '../adapters/lora-loader.js';
30
+
31
/**
 * Recursively rebuild a JSON-like value so every plain object has its keys
 * in sorted order. Arrays keep their element order; primitives (and null)
 * pass through unchanged.
 * @param {*} value - Arbitrary JSON-compatible value.
 * @returns {*} Structurally equal value with deterministic key ordering.
 */
function stableSortObject(value) {
  if (Array.isArray(value)) {
    return value.map((entry) => stableSortObject(entry));
  }
  if (value === null || value === undefined || typeof value !== 'object') {
    return value;
  }
  return Object.fromEntries(
    Object.keys(value)
      .sort()
      .map((key) => [key, stableSortObject(value[key])])
  );
}
44
+
45
/**
 * Serialize a value to JSON with object keys sorted at every depth, so equal
 * structures always hash/compare identically.
 * @param {*} value - JSON-compatible value.
 * @returns {string} Canonical JSON string.
 */
function stableJson(value) {
  const canonical = stableSortObject(value);
  return JSON.stringify(canonical);
}
48
+
49
/**
 * Upload float32 host data into a pooled GPU buffer and wrap it in a tensor.
 * @param {Float32Array|number[]} values - Host values (copied unless already Float32Array).
 * @param {number[]} shape - Logical tensor shape (copied defensively).
 * @param {string} label - Debug label for the buffer/tensor.
 * @returns {object} f32 tensor backed by a pool buffer.
 */
function makeTensorFromFloat32(values, shape, label) {
  const data = values instanceof Float32Array ? values : Float32Array.from(values);
  const buffer = acquireBuffer(data.byteLength, undefined, label);
  uploadData(buffer, data);
  return createTensor(buffer, 'f32', Array.from(shape), label);
}
55
+
56
/**
 * Upload uint32 host data into a pooled GPU buffer and wrap it in a tensor.
 * @param {Uint32Array|number[]} values - Host values (copied unless already Uint32Array).
 * @param {number[]} shape - Logical tensor shape (copied defensively).
 * @param {string} label - Debug label for the buffer/tensor.
 * @returns {object} u32 tensor backed by a pool buffer.
 */
function makeTensorFromUint32(values, shape, label) {
  const data = values instanceof Uint32Array ? values : Uint32Array.from(values);
  const buffer = acquireBuffer(data.byteLength, undefined, label);
  uploadData(buffer, data);
  return createTensor(buffer, 'u32', Array.from(shape), label);
}
62
+
63
/**
 * Return a tensor's GPU buffer to the pool. Tolerates null/undefined tensors
 * and tensors without a buffer (no-op in those cases).
 * @param {object|null} tensor - Tensor whose buffer should be released.
 */
function releaseTensor(tensor) {
  const buffer = tensor?.buffer;
  if (buffer) {
    releaseBuffer(buffer);
  }
}
67
+
68
/**
 * Build the tiny fixed-size (3 inputs, 2 outputs) linear model used by the
 * toy LoRA training workload: a constant base weight plus a trainable
 * LoRA adapter whose delta is added to the base logits.
 *
 * @param {object} workload - Parsed lora workload (reads pipeline.adapter.*).
 * @returns {{model: object, cleanup: Function}} Model facade plus a cleanup
 *   hook that disposes the adapter and releases the base weight buffer.
 * @throws {Error} When the workload declares no adapter target modules.
 */
function createToyLoraModel(workload) {
  const adapterConfig = workload.pipeline.adapter;
  const primaryTarget = adapterConfig.targetModules[0];
  if (!primaryTarget) {
    throw new Error('LoRA workload requires at least one adapter target module.');
  }
  const inDim = 3;
  const outDim = 2;
  // Fixed seed weights keep the toy problem fully deterministic.
  const baseWeight = makeTensorFromFloat32(
    [0.08, -0.12, 0.16, 0.22, -0.03, 0.09],
    [inDim, outDim],
    'lora_toy_base_weight'
  );
  const adapter = new LoraAdapter({
    inDim,
    outDim,
    rank: adapterConfig.rank,
    alpha: adapterConfig.alpha,
  });
  const model = {
    adapter,
    baseWeight,
    targetModule: primaryTarget,
    // logits = input @ baseWeight + adapter(input), all recorded on the tape
    // so gradients can flow back through both paths.
    async forward(inputTensor, tape) {
      const rows = Number.isInteger(inputTensor?.shape?.[0]) ? inputTensor.shape[0] : 1;
      const baseLogits = await tape.record(
        OpType.MATMUL,
        (a, b) => runMatmul(a, b, rows, outDim, inDim, { transposeB: false }),
        [inputTensor, baseWeight],
        { M: rows, N: outDim, K: inDim, transposeB: false }
      );
      const loraDelta = await adapter.forward(inputTensor, tape);
      return tape.record(
        OpType.RESIDUAL_ADD,
        (a, b) => runResidualAdd(a, b, rows * outDim),
        [baseLogits, loraDelta],
        { size: rows * outDim }
      );
    },
    loraParams() {
      return [adapter.A, adapter.B];
    },
    paramGroups() {
      return {
        encoder: [],
        prior: [],
        decoder: [],
        base: [baseWeight],
        lora: [adapter.A, adapter.B],
      };
    },
  };
  const cleanup = () => {
    adapter.dispose();
    releaseTensor(baseWeight);
  };
  return { model, cleanup };
}
125
+
126
/**
 * Validate and coerce one raw dataset record into {id, input, target}.
 * The 3-element feature vector may arrive under `input` or `features`; the
 * binary label may arrive under `target` or `label`.
 * @param {object} record - Raw parsed row.
 * @param {number} index - Zero-based row index (used in error messages and generated ids).
 * @returns {{id: string, input: number[], target: number}} Normalized row.
 * @throws {Error} On non-object rows, wrong feature arity, non-finite
 *   features, or labels outside {0, 1}.
 */
function normalizeToyRow(record, index) {
  const rowNumber = index + 1;
  const isPlainObject = record && typeof record === 'object' && !Array.isArray(record);
  if (!isPlainObject) {
    throw new Error(`LoRA toy dataset row ${rowNumber} must be an object.`);
  }
  let rawValues = null;
  if (Array.isArray(record.input)) {
    rawValues = record.input;
  } else if (Array.isArray(record.features)) {
    rawValues = record.features;
  }
  if (!Array.isArray(rawValues) || rawValues.length !== 3) {
    throw new Error(`LoRA toy dataset row ${rowNumber} requires input[3].`);
  }
  const input = rawValues.map((value, valueIndex) => {
    const parsed = Number(value);
    if (!Number.isFinite(parsed)) {
      throw new Error(`LoRA toy dataset row ${rowNumber} input[${valueIndex}] must be finite.`);
    }
    return parsed;
  });
  const target = Number(record.target ?? record.label);
  if (!Number.isInteger(target) || target < 0 || target > 1) {
    throw new Error(`LoRA toy dataset row ${rowNumber} requires integer target 0 or 1.`);
  }
  return {
    id: String(record.id || `row-${rowNumber}`),
    input,
    target,
  };
}
153
+
154
/**
 * Read a toy LoRA dataset from disk (a .json array, or JSONL otherwise),
 * normalize every row, and return the rows with a content hash for
 * artifact lineage.
 * @param {string} datasetPath - Path to the dataset file.
 * @returns {Promise<{absolutePath: string, raw: string, rows: object[], datasetHash: string}>}
 * @throws {Error} When the parsed payload is not an array or a row is invalid.
 */
async function loadToyLoraDataset(datasetPath) {
  const absolutePath = resolve(String(datasetPath));
  const raw = await readFile(absolutePath, 'utf8');
  let rows;
  if (absolutePath.endsWith('.json')) {
    rows = JSON.parse(raw);
  } else {
    rows = parseJsonl(raw);
  }
  if (!Array.isArray(rows)) {
    throw new Error(`LoRA dataset "${absolutePath}" must be a JSON array or JSONL file.`);
  }
  const normalizedRows = rows.map((row, index) => normalizeToyRow(row, index));
  return {
    absolutePath,
    raw,
    rows: normalizedRows,
    datasetHash: hashArtifactPayload({ rows: normalizedRows }),
  };
}
171
+
172
/**
 * Wrap normalized rows in a batch source for the training runner.
 * Yields {input, targets} GPU tensors; the underlying buffers are reused
 * across consecutive batches of identical size and released when iteration
 * completes or aborts.
 * NOTE: yielded tensors are owned by the generator — consumers must not
 * release them or retain them past the next iteration.
 * @param {Array<{input: number[], target: number}>} rows - Normalized rows.
 * @param {number} batchSize - Maximum rows per batch (last batch may be smaller).
 * @returns {{batches: Function}} Object exposing an async generator of batches.
 */
function createToyDatasetBatches(rows, batchSize) {
  return {
    async *batches() {
      let inputTensor = null;
      let targetTensor = null;
      let currentCapacity = 0;
      try {
        for (let start = 0; start < rows.length; start += batchSize) {
          const slice = rows.slice(start, start + batchSize);
          const features = new Float32Array(slice.length * 3);
          const labels = new Uint32Array(slice.length);
          slice.forEach((row, i) => {
            features.set(row.input, i * 3);
            labels[i] = row.target;
          });
          const reusable = inputTensor && targetTensor && currentCapacity === slice.length;
          if (reusable) {
            // Same geometry as the previous batch: overwrite in place.
            uploadData(inputTensor.buffer, features);
            uploadData(targetTensor.buffer, labels);
          } else {
            releaseTensor(inputTensor);
            releaseTensor(targetTensor);
            inputTensor = makeTensorFromFloat32(features, [slice.length, 3], 'lora_toy_input');
            targetTensor = makeTensorFromUint32(labels, [slice.length], 'lora_toy_target');
            currentCapacity = slice.length;
          }
          yield {
            input: inputTensor,
            targets: targetTensor,
          };
        }
      } finally {
        releaseTensor(inputTensor);
        releaseTensor(targetTensor);
      }
    },
  };
}
209
+
210
/**
 * Gather the GPU buffers backing every model parameter into a Set, so that
 * tape/output cleanup can avoid releasing trainable weights.
 * @param {{paramGroups: () => Record<string, Array<{buffer?: object}>>}} model
 * @returns {Set<object>} Distinct parameter buffers.
 */
function collectProtectedBuffers(model) {
  const buffers = Object.values(model.paramGroups())
    .flat()
    .map((tensor) => tensor?.buffer)
    .filter(Boolean);
  return new Set(buffers);
}
222
+
223
/**
 * Release every distinct output buffer recorded on the tape, skipping any
 * protected parameter buffers. Each buffer is released at most once; tapes
 * without a records array are ignored.
 * @param {{records?: Array<{output?: {buffer?: object}}>}} tape
 * @param {Set<object>} [protectedBuffers] - Buffers that must stay alive.
 */
function disposeTapeOutputs(tape, protectedBuffers = new Set()) {
  const records = tape?.records;
  if (!Array.isArray(records)) return;
  const seen = new Set();
  for (const record of records) {
    const buffer = record?.output?.buffer;
    if (!buffer || protectedBuffers.has(buffer) || seen.has(buffer)) {
      continue;
    }
    seen.add(buffer);
    releaseBuffer(buffer);
  }
}
234
+
235
/**
 * Index of the largest finite value. Non-finite entries (NaN, ±Infinity) are
 * treated as negative infinity. Ties resolve to the earliest index; an empty
 * input yields 0.
 * @param {ArrayLike<number>} values - Scores to scan.
 * @returns {number} Winning index.
 */
function argmax(values) {
  let winner = 0;
  let best = Number.NEGATIVE_INFINITY;
  for (let i = 0; i < values.length; i += 1) {
    const candidate = Number.isFinite(values[i]) ? values[i] : Number.NEGATIVE_INFINITY;
    if (candidate > best) {
      best = candidate;
      winner = i;
    }
  }
  return winner;
}
247
+
248
/**
 * Evaluate the toy LoRA model against every eval dataset declared by the
 * workload, optionally persisting one JSON report per dataset under the
 * run layout.
 *
 * Rows are scored one at a time (batch of 1). Regardless of the declared
 * evalKind, metrics are computed via the 'classification' path — matching
 * the toy pipeline's current capabilities.
 *
 * @param {object} workload - Lora workload (reads evalDatasets, id, configHash, baseModelId).
 * @param {object} model - Model facade from createToyLoraModel.
 * @param {object} dataset - Materialized training dataset (reused when an eval set shares its path).
 * @param {object|null} [layout] - Run directory layout; null skips writing report files.
 * @param {object} [checkpointMeta] - Checkpoint provenance merged into each report.
 * @returns {Promise<object[]>} One report object per eval dataset.
 * @throws {Error} For evalKind values other than classification/text_generation.
 */
async function evaluateToyLoraModel(workload, model, dataset, layout = null, checkpointMeta = {}) {
  const protectedBuffers = collectProtectedBuffers(model);
  const evalDatasets = Array.isArray(workload.evalDatasets) ? workload.evalDatasets : [];
  const evalReports = [];
  for (const evalDataset of evalDatasets) {
    const supported = evalDataset.evalKind === 'classification' || evalDataset.evalKind === 'text_generation';
    if (!supported) {
      throw new Error(`LoRA eval currently supports classification/text_generation only, got "${evalDataset.evalKind}".`);
    }
    // Reuse the already-loaded training dataset when paths coincide.
    const materialized = evalDataset.datasetPath === dataset.absolutePath
      ? dataset
      : await loadToyLoraDataset(evalDataset.datasetPath);
    const predictions = [];
    const labels = [];
    for (const row of materialized.rows) {
      const tape = new AutogradTape(loadBackwardRegistry());
      const inputTensor = makeTensorFromFloat32(row.input, [1, 3], 'lora_eval_input');
      let logits = null;
      try {
        logits = await model.forward(inputTensor, tape);
        const logitsData = new Float32Array(await readBuffer(logits.buffer));
        predictions.push(String(argmax(logitsData)));
        labels.push(String(row.target));
      } finally {
        // Release the input, the logits (unless they alias a parameter), and
        // every intermediate the tape recorded.
        releaseTensor(inputTensor);
        if (logits?.buffer && !protectedBuffers.has(logits.buffer)) {
          releaseBuffer(logits.buffer);
        }
        disposeTapeOutputs(tape, protectedBuffers);
      }
    }
    const metrics = computeEvalMetrics('classification', predictions, labels, {});
    const reportPayload = {
      artifactType: 'training_eval_report',
      schemaVersion: 1,
      generatedAt: new Date().toISOString(),
      workloadId: workload.id,
      workloadPath: checkpointMeta.workloadPath || null,
      workloadSha256: checkpointMeta.workloadSha256 || null,
      configHash: checkpointMeta.configHash || workload.configHash,
      datasetPath: evalDataset.datasetPath,
      datasetHash: materialized.datasetHash,
      baseModelId: workload.baseModelId,
      stage: 'lora',
      checkpointStep: checkpointMeta.checkpointStep ?? null,
      evalDatasetId: evalDataset.id,
      metrics,
      primaryMetric: metrics.primaryMetric,
      primaryScore: metrics.primaryScore,
      accuracy: metrics.accuracy?.score ?? null,
    };
    let reportFile = null;
    if (layout) {
      const reportName = `${checkpointMeta.checkpointId || 'checkpoint'}__${evalDataset.id}.json`;
      reportFile = await writeJsonArtifact(join(layout.eval, reportName), reportPayload);
    }
    evalReports.push({
      ...reportPayload,
      reportPath: reportFile?.path || null,
    });
  }
  return evalReports;
}
312
+
313
/**
 * Shape the run-contract artifact recorded at the start of a training run,
 * binding the run to the exact workload file, hash, and eval plan it used.
 * @param {{workload: object, absolutePath: string, workloadSha256: string}} loadedWorkload
 * @returns {object} training_run_contract payload.
 */
function buildRunContract(loadedWorkload) {
  const { workload, absolutePath, workloadSha256 } = loadedWorkload;
  return {
    artifactType: 'training_run_contract',
    schemaVersion: 1,
    generatedAt: new Date().toISOString(),
    workloadId: workload.id,
    workloadPath: absolutePath,
    workloadSha256,
    configHash: workload.configHash,
    claimBoundary: workload.claimBoundary,
    kind: workload.kind,
    evalDatasets: workload.evalDatasets,
  };
}
327
+
328
/**
 * Compose a hashed artifact payload for this run: delegate the common fields
 * to buildArtifactBase, then stamp the payload's own content hash on top.
 * @param {{workload: object, absolutePath: string, workloadSha256: string}} loadedWorkload
 * @param {object} options - prefix/id/artifactType plus optional overrides
 *   (datasetPath, datasetHash, stage, checkpointStep, parentArtifacts, configHash).
 * @returns {object} Artifact payload including an artifactHash field.
 */
function buildArtifact(loadedWorkload, options) {
  const { workload, absolutePath, workloadSha256 } = loadedWorkload;
  const payload = buildArtifactBase({
    artifactType: options.artifactType,
    reportId: `${options.prefix}_${workload.id}_${options.id}`,
    workload,
    workloadPath: absolutePath,
    workloadSha256,
    datasetPath: options.datasetPath || workload.datasetPath,
    datasetHash: options.datasetHash || null,
    baseModelId: workload.baseModelId,
    stage: options.stage || 'lora',
    checkpointStep: options.checkpointStep ?? null,
    parentArtifacts: options.parentArtifacts || [],
    runtime: 'node',
    surface: 'node',
    claimBoundary: workload.claimBoundary,
    configHash: options.configHash || workload.configHash,
  });
  return {
    ...payload,
    artifactHash: hashArtifactPayload(payload),
  };
}
352
+
353
/**
 * Export the adapter weights as a LoRA manifest, verify the manifest
 * round-trips through loadLoRAFromManifest, and record an export artifact
 * alongside it in the run's exports directory.
 * @param {object} loadedWorkload - { workload, absolutePath, workloadSha256 }.
 * @param {object} layout - Run directory layout (uses layout.exports).
 * @param {object} model - Model facade holding adapter.A / adapter.B.
 * @param {string} checkpointId - Checkpoint directory name being exported.
 * @param {number} checkpointStep - Step number for artifact lineage.
 * @param {string} datasetHash - Training dataset hash for artifact lineage.
 * @returns {Promise<{checkpointId: string, manifestPath: string, exportPath: string, manifest: object}>}
 */
async function exportToyLoraModel(loadedWorkload, layout, model, checkpointId, checkpointStep, datasetHash) {
  const workload = loadedWorkload.workload;
  const exportConfig = workload.pipeline.export;
  const targetModule = model.targetModule || workload.pipeline.adapter.targetModules[0];
  const fallbackName = `${workload.id}-${checkpointId}`;
  const exported = await exportLoRAAdapter({
    id: exportConfig?.id || fallbackName,
    name: exportConfig?.name || fallbackName,
    baseModel: workload.baseModelId,
    rank: workload.pipeline.adapter.rank,
    alpha: workload.pipeline.adapter.alpha,
    targetModules: [targetModule],
    tensors: [
      { name: `layers.0.${targetModule}.lora_a`, tensor: model.adapter.A },
      { name: `layers.0.${targetModule}.lora_b`, tensor: model.adapter.B },
    ],
  });
  const manifestPath = join(layout.exports, `${checkpointId}.adapter.manifest.json`);
  await writeFile(manifestPath, exported.json, 'utf8');
  // Round-trip check: the manifest we just produced must load cleanly.
  await loadLoRAFromManifest(exported.manifest, {});
  const artifactPayload = {
    ...buildArtifact(loadedWorkload, {
      prefix: 'lora_export',
      id: checkpointId,
      artifactType: 'lora_adapter_manifest',
      checkpointStep,
      datasetHash,
    }),
    checkpointId,
    manifestPath,
    manifest: exported.manifest,
  };
  const artifactFile = await writeJsonArtifact(
    join(layout.exports, `${checkpointId}.export.json`),
    artifactPayload
  );
  return {
    checkpointId,
    manifestPath,
    exportPath: artifactFile.path,
    manifest: exported.manifest,
  };
}
394
+
395
/**
 * Pick the lexicographically last checkpoint directory under
 * <runRoot>/checkpoints — zero-padded step names make that the newest —
 * and return its id plus the conventional state/marker file paths.
 * @param {string} runRoot - Run directory containing a checkpoints/ folder.
 * @returns {Promise<{checkpointId: string, checkpointPath: string, markerPath: string}>}
 * @throws {Error} When the checkpoints directory holds no subdirectories.
 */
async function selectLatestCheckpoint(runRoot) {
  const checkpointsDir = join(runRoot, 'checkpoints');
  const entries = await readdir(checkpointsDir, { withFileTypes: true });
  const names = [];
  for (const entry of entries) {
    if (entry.isDirectory()) {
      names.push(entry.name);
    }
  }
  names.sort((a, b) => a.localeCompare(b));
  const latest = names.at(-1);
  if (!latest) {
    throw new Error(`No checkpoints found in ${checkpointsDir}.`);
  }
  const checkpointRoot = join(checkpointsDir, latest);
  return {
    checkpointId: latest,
    checkpointPath: join(checkpointRoot, 'state.json'),
    markerPath: join(checkpointRoot, 'checkpoint.complete.json'),
  };
}
412
+
413
+ export async function runLoraPipeline(options) {
414
+ const loadedWorkload = options.loadedWorkload;
415
+ const workload = loadedWorkload.workload;
416
+ if (workload.kind !== 'lora') {
417
+ throw new Error('runLoraPipeline requires a lora workload.');
418
+ }
419
+ if (workload.baseModelId !== 'training-toy') {
420
+ throw new Error('LoRA run currently supports baseModelId="training-toy" only.');
421
+ }
422
+ if (workload.pipeline.datasetFormat !== 'toy_linear_classification_jsonl') {
423
+ throw new Error('LoRA run currently supports datasetFormat="toy_linear_classification_jsonl" only.');
424
+ }
425
+ const layout = options.runRoot
426
+ ? {
427
+ runRoot: resolve(String(options.runRoot)),
428
+ logs: join(resolve(String(options.runRoot)), 'logs'),
429
+ checkpoints: join(resolve(String(options.runRoot)), 'checkpoints'),
430
+ eval: join(resolve(String(options.runRoot)), 'eval'),
431
+ scoreboard: join(resolve(String(options.runRoot)), 'scoreboard'),
432
+ exports: join(resolve(String(options.runRoot)), 'exports'),
433
+ compare: join(resolve(String(options.runRoot)), 'compare'),
434
+ qualityGate: join(resolve(String(options.runRoot)), 'quality-gate'),
435
+ }
436
+ : await createTrainingRunLayout({
437
+ kind: 'lora',
438
+ workloadId: workload.id,
439
+ timestamp: options.timestamp || null,
440
+ });
441
+ await Promise.all(Object.values(layout).map((dirPath) => mkdir(dirPath, { recursive: true })));
442
+ await writeRunContract(layout, buildRunContract(loadedWorkload));
443
+ await writeWorkloadLock(layout, loadedWorkload);
444
+ const dataset = await loadToyLoraDataset(workload.datasetPath);
445
+ const fixture = createToyLoraModel(workload);
446
+ try {
447
+ const evalReports = [];
448
+ const checkpointArtifacts = [];
449
+ const exports = [];
450
+ const runner = new TrainingRunner({
451
+ training: {
452
+ enabled: true,
453
+ optimizer: {
454
+ type: workload.training.optimizer.type,
455
+ lr: workload.training.optimizer.lr,
456
+ beta1: workload.training.optimizer.beta1,
457
+ beta2: workload.training.optimizer.beta2,
458
+ eps: workload.training.optimizer.eps,
459
+ weightDecay: workload.training.optimizer.weightDecay,
460
+ scheduler: workload.training.optimizer.scheduler,
461
+ },
462
+ gradient: {
463
+ maxNorm: workload.training.gradientClipping.maxNorm,
464
+ },
465
+ precision: workload.training.precision,
466
+ lossScaling: { enabled: false },
467
+ distill: {
468
+ enabled: false,
469
+ stage: 'stage_a',
470
+ teacherModelId: null,
471
+ studentModelId: null,
472
+ datasetId: null,
473
+ datasetPath: null,
474
+ languagePair: null,
475
+ sourceLangs: null,
476
+ targetLangs: null,
477
+ pairAllowlist: null,
478
+ strictPairContract: false,
479
+ shardIndex: null,
480
+ shardCount: null,
481
+ resumeFrom: null,
482
+ artifactDir: null,
483
+ stageAArtifact: null,
484
+ stageAArtifactHash: null,
485
+ temperature: 1,
486
+ alphaKd: 1,
487
+ alphaCe: 0,
488
+ allowHintFallback: false,
489
+ tripletMargin: 0.2,
490
+ studentGraphMode: 'projection_head',
491
+ freeze: { encoder: false, prior: false, decoder: false, base: true, lora: false },
492
+ },
493
+ ul: {
494
+ enabled: false,
495
+ stage: 'stage1_joint',
496
+ stage1Artifact: null,
497
+ stage1ArtifactHash: null,
498
+ artifactDir: null,
499
+ lambda0: 5,
500
+ seed: workload.seed,
501
+ noiseSchedule: { name: 'linear', minSigma: 0.1, maxSigma: 1, steps: 1 },
502
+ priorAlignment: { enabled: false, weight: 1 },
503
+ decoderSigmoidWeight: { enabled: false, maxWeight: 1 },
504
+ lossWeights: { prior: 1, decoder: 1, recon: 1 },
505
+ freeze: null,
506
+ },
507
+ },
508
+ }, {
509
+ optimizer: new AdamOptimizer({
510
+ training: {
511
+ optimizer: {
512
+ type: workload.training.optimizer.type,
513
+ lr: workload.training.optimizer.lr,
514
+ beta1: workload.training.optimizer.beta1,
515
+ beta2: workload.training.optimizer.beta2,
516
+ eps: workload.training.optimizer.eps,
517
+ weightDecay: workload.training.optimizer.weightDecay,
518
+ scheduler: workload.training.optimizer.scheduler,
519
+ },
520
+ gradient: {
521
+ maxNorm: workload.training.gradientClipping.maxNorm,
522
+ },
523
+ precision: workload.training.precision,
524
+ },
525
+ }),
526
+ crossEntropyLoss,
527
+ clipGradients,
528
+ resolveCheckpointKey({ step }) {
529
+ return join(layout.checkpoints, `checkpoint-${String(step).padStart(6, '0')}`, 'state.json');
530
+ },
531
+ onCheckpoint: async (checkpoint) => {
532
+ const checkpointId = `checkpoint-${String(checkpoint.step).padStart(6, '0')}`;
533
+ const checkpointPayload = {
534
+ ...buildArtifact(loadedWorkload, {
535
+ prefix: 'lora_ckpt',
536
+ id: checkpointId,
537
+ artifactType: 'training_checkpoint',
538
+ datasetHash: dataset.datasetHash,
539
+ checkpointStep: checkpoint.step,
540
+ }),
541
+ checkpointId,
542
+ checkpointPath: checkpoint.path,
543
+ optimizerStatePresent: true,
544
+ schedulerStatePresent: workload.training.optimizer.scheduler.enabled === true,
545
+ resumeLineage: checkpoint.metadata?.lineage || null,
546
+ };
547
+ await writeJsonArtifact(
548
+ join(layout.checkpoints, checkpointId, 'checkpoint.json'),
549
+ checkpointPayload
550
+ );
551
+ const checkpointArtifact = await writeJsonArtifact(
552
+ join(layout.checkpoints, checkpointId, 'checkpoint.complete.json'),
553
+ checkpointPayload
554
+ );
555
+ checkpointArtifacts.push({
556
+ checkpointId,
557
+ checkpointPath: checkpoint.path,
558
+ markerPath: checkpointArtifact.path,
559
+ checkpointStep: checkpoint.step,
560
+ });
561
+ if (workload.pipeline.export?.enabled === true && workload.pipeline.export.atCheckpoints === true) {
562
+ exports.push(await exportToyLoraModel(
563
+ loadedWorkload,
564
+ layout,
565
+ fixture.model,
566
+ checkpointId,
567
+ checkpoint.step,
568
+ dataset.datasetHash
569
+ ));
570
+ }
571
+ const reports = await evaluateToyLoraModel(workload, fixture.model, dataset, layout, {
572
+ checkpointId,
573
+ checkpointStep: checkpoint.step,
574
+ configHash: workload.configHash,
575
+ workloadPath: loadedWorkload.absolutePath,
576
+ workloadSha256: loadedWorkload.workloadSha256,
577
+ });
578
+ for (const report of reports) {
579
+ evalReports.push(report);
580
+ await appendScoreboardRow(layout.scoreboard, {
581
+ artifactType: 'training_scoreboard',
582
+ schemaVersion: 1,
583
+ generatedAt: new Date().toISOString(),
584
+ checkpointId,
585
+ checkpointStep: checkpoint.step,
586
+ evalDatasetId: report.evalDatasetId,
587
+ primaryMetric: report.primaryMetric,
588
+ primaryScore: report.primaryScore,
589
+ accuracy: report.accuracy,
590
+ metrics: {
591
+ accuracy: report.accuracy,
592
+ primaryScore: report.primaryScore,
593
+ },
594
+ }, {
595
+ selectionMetric: workload.selectionMetric,
596
+ selectionGoal: workload.selectionGoal,
597
+ });
598
+ }
599
+ },
600
+ });
601
+ const metrics = await runner.run(
602
+ fixture.model,
603
+ createToyDatasetBatches(dataset.rows, workload.training.batchSize),
604
+ {
605
+ epochs: 1,
606
+ batchSize: workload.training.batchSize,
607
+ shuffle: false,
608
+ maxSteps: workload.training.steps,
609
+ checkpointEvery: workload.checkpointEvery,
610
+ modelId: workload.baseModelId,
611
+ }
612
+ );
613
+ const finalCheckpointId = runner.lastCheckpoint
614
+ ? `checkpoint-${String(runner.lastCheckpoint.step).padStart(6, '0')}`
615
+ : null;
616
+ if (workload.pipeline.export?.enabled === true && finalCheckpointId && exports.every((entry) => entry.checkpointId !== finalCheckpointId)) {
617
+ exports.push(await exportToyLoraModel(
618
+ loadedWorkload,
619
+ layout,
620
+ fixture.model,
621
+ finalCheckpointId,
622
+ runner.lastCheckpoint.step,
623
+ dataset.datasetHash
624
+ ));
625
+ }
626
+ return {
627
+ ok: true,
628
+ kind: 'lora',
629
+ action: 'run',
630
+ workloadId: workload.id,
631
+ runRoot: layout.runRoot,
632
+ checkpointArtifacts,
633
+ evalReports,
634
+ exports,
635
+ metrics,
636
+ lastCheckpoint: runner.lastCheckpoint,
637
+ };
638
+ } finally {
639
+ fixture.cleanup();
640
+ }
641
+ }
642
+
643
/**
 * Restore a saved LoRA checkpoint into a fresh toy model and evaluate it
 * against the workload's dataset.
 *
 * @param {object} options
 * @param {object} options.loadedWorkload - Loaded workload bundle (`workload`, `absolutePath`, `workloadSha256`).
 * @param {string} options.checkpointPath - Path to the checkpoint state file; resolved against cwd.
 * @param {string} [options.checkpointId] - Identifier recorded in the eval report (defaults to `'checkpoint'`).
 * @param {number|null} [options.checkpointStep] - Step recorded in the eval report.
 * @param {object|null} [options.layout] - Optional run layout forwarded to the evaluator.
 * @returns {Promise<*>} Whatever `evaluateToyLoraModel` resolves to.
 * @throws {Error} When no checkpoint record exists at `checkpointPath`.
 */
export async function evaluateLoraCheckpoint(options) {
  const { loadedWorkload } = options;
  const { workload } = loadedWorkload;
  const resolvedPath = resolve(String(options.checkpointPath));
  const dataset = await loadToyLoraDataset(workload.datasetPath);
  const record = await loadCheckpoint(resolvedPath);
  if (!record) {
    throw new Error(`Checkpoint not found: ${resolvedPath}`);
  }
  const fixture = createToyLoraModel(workload);
  try {
    // Restore with the standard toy-LoRA freeze map: only the base weights frozen.
    await restoreTrainingCheckpointState(fixture.model, { getState: () => null }, record, {
      training: {
        distill: { freeze: { encoder: false, prior: false, decoder: false, base: true, lora: false } },
        ul: { freeze: null },
      },
    });
    return evaluateToyLoraModel(workload, fixture.model, dataset, options.layout || null, {
      checkpointId: options.checkpointId || 'checkpoint',
      checkpointStep: options.checkpointStep ?? null,
      configHash: workload.configHash,
      workloadPath: loadedWorkload.absolutePath,
      workloadSha256: loadedWorkload.workloadSha256,
    });
  } finally {
    // Always release the fixture, even when restore/eval throws.
    fixture.cleanup();
  }
}
671
+
672
/**
 * Restore a saved LoRA checkpoint into a fresh toy model and export it via
 * `exportToyLoraModel`.
 *
 * @param {object} options
 * @param {object} options.loadedWorkload - Loaded workload bundle.
 * @param {string} options.checkpointPath - Path to the checkpoint state file; resolved against cwd.
 * @param {object} [options.layout] - Optional run layout; `layout.exports` falls back to `exportsDir`.
 * @param {string} [options.exportsDir] - Exports directory fallback (defaults to `reports/training/lora/exports`).
 * @param {string} [options.checkpointId] - Identifier recorded in the export (defaults to `'checkpoint'`).
 * @param {number|null} [options.checkpointStep] - Step recorded in the export.
 * @param {string|null} [options.datasetHash] - Dataset hash recorded in the export.
 * @returns {Promise<*>} Whatever `exportToyLoraModel` resolves to.
 * @throws {Error} When no checkpoint record exists at `checkpointPath`.
 */
export async function exportLoraCheckpoint(options) {
  const { loadedWorkload } = options;
  const { workload } = loadedWorkload;
  // Single source of truth for the fallback exports directory (used twice below).
  const defaultExportsDir = resolve(options.exportsDir || 'reports/training/lora/exports');
  const layout = options.layout || { exports: defaultExportsDir };
  const resolvedPath = resolve(String(options.checkpointPath));
  const record = await loadCheckpoint(resolvedPath);
  if (!record) {
    throw new Error(`Checkpoint not found: ${resolvedPath}`);
  }
  const fixture = createToyLoraModel(workload);
  try {
    // Restore with the standard toy-LoRA freeze map: only the base weights frozen.
    await restoreTrainingCheckpointState(fixture.model, { getState: () => null }, record, {
      training: {
        distill: { freeze: { encoder: false, prior: false, decoder: false, base: true, lora: false } },
        ul: { freeze: null },
      },
    });
    return exportToyLoraModel(
      loadedWorkload,
      // A caller-supplied layout may omit `exports`; guarantee it here.
      { ...layout, exports: layout.exports || defaultExportsDir },
      fixture.model,
      options.checkpointId || 'checkpoint',
      options.checkpointStep ?? null,
      options.datasetHash || null
    );
  } finally {
    // Always release the fixture, even when restore/export throws.
    fixture.cleanup();
  }
}
704
+
705
/**
 * Watch a run's checkpoints directory and evaluate each finalized checkpoint
 * as its marker file appears.
 *
 * Fix: `selectLatestCheckpoint` can yield no result on a fresh run (no
 * checkpoints written yet). The original dereferenced it unconditionally
 * (`latestCheckpoint.checkpointPath`), so a marker without `checkpointPath`
 * crashed the watcher callback with a raw TypeError. We now use optional
 * chaining and fail with a descriptive error instead.
 *
 * @param {object} options
 * @param {string} options.runRoot - Root directory of the training run.
 * @param {object} options.loadedWorkload - Loaded workload bundle forwarded to the evaluator.
 * @param {number} [options.pollIntervalMs=2000] - Poll interval for the watcher.
 * @param {boolean} [options.stopWhenIdle] - Stop polling once no new checkpoints arrive.
 * @returns {Promise<*>} Whatever `watchFinalizedCheckpoints` resolves to.
 */
export async function watchLoraCheckpoints(options) {
  // NOTE(review): may be null/undefined on a fresh run — guarded below.
  const latestCheckpoint = await selectLatestCheckpoint(options.runRoot);
  return watchFinalizedCheckpoints({
    checkpointsDir: join(options.runRoot, 'checkpoints'),
    manifestPath: join(options.runRoot, 'scoreboard', 'watch-manifest.json'),
    pollIntervalMs: options.pollIntervalMs || 2000,
    stopWhenIdle: options.stopWhenIdle === true,
    onCheckpoint: async (markerPath) => {
      const raw = await readFile(markerPath, 'utf8');
      const marker = JSON.parse(raw);
      // Prefer the marker's own path; fall back to the latest pre-scan result.
      const checkpointPath = marker.checkpointPath || latestCheckpoint?.checkpointPath;
      if (!checkpointPath) {
        throw new Error(
          `Checkpoint marker ${markerPath} has no checkpointPath and no prior checkpoint was found`
        );
      }
      await evaluateLoraCheckpoint({
        loadedWorkload: options.loadedWorkload,
        checkpointPath,
        checkpointId: marker.checkpointId || latestCheckpoint?.checkpointId,
        checkpointStep: marker.checkpointStep ?? null,
        layout: {
          eval: join(options.runRoot, 'eval'),
        },
      });
    },
  });
}
727
+
728
/**
 * Aggregate every eval report under `<runRoot>/eval`, rank them by primary
 * score (descending, missing scores last), and persist a compare report.
 *
 * @param {object} options
 * @param {string} options.runRoot - Root directory of the training run.
 * @returns {Promise<object>} The compare payload plus `comparePath`.
 */
export async function compareLoraRun(options) {
  const evalDir = join(options.runRoot, 'eval');
  const dirEntries = await readdir(evalDir, { withFileTypes: true });
  const jsonEntries = dirEntries.filter(
    (entry) => entry.isFile() && entry.name.endsWith('.json')
  );
  const reports = [];
  for (const entry of jsonEntries) {
    const text = await readFile(join(evalDir, entry.name), 'utf8');
    reports.push(JSON.parse(text));
  }
  // Reports without a usable primaryScore sort as -Infinity (i.e. last).
  const scoreOf = (report) => Number(report?.primaryScore ?? Number.NEGATIVE_INFINITY);
  const ranked = [...reports].sort((a, b) => scoreOf(b) - scoreOf(a));
  const payload = {
    artifactType: 'training_compare_report',
    schemaVersion: 1,
    generatedAt: new Date().toISOString(),
    runRoot: options.runRoot,
    count: ranked.length,
    best: ranked[0] || null,
    reports: ranked.map((report) => ({
      checkpointId: report.checkpointId || null,
      evalDatasetId: report.evalDatasetId || null,
      primaryMetric: report.primaryMetric || null,
      primaryScore: report.primaryScore ?? null,
      accuracy: report.accuracy ?? null,
      reportPath: report.reportPath || null,
    })),
  };
  const artifact = await writeJsonArtifact(join(options.runRoot, 'compare', 'compare.json'), payload);
  return {
    ...payload,
    comparePath: artifact.path,
  };
}
766
+
767
/**
 * Verify that the mandatory run artifacts exist and persist a quality-gate
 * report under `<runRoot>/quality-gate`.
 *
 * @param {object} options
 * @param {string} options.runRoot - Root directory of the training run; resolved against cwd.
 * @returns {Promise<object>} The quality-gate payload plus `reportPath`.
 */
export async function qualityGateLoraRun(options) {
  const runRoot = resolve(String(options.runRoot));
  // Existence is probed by actually reading the file (matches how the run
  // artifacts are consumed elsewhere); any read error fails the check.
  const probe = async (filePath) => {
    try {
      await readFile(filePath, 'utf8');
      return { path: filePath, ok: true };
    } catch (error) {
      return { path: filePath, ok: false, error: error?.message || String(error) };
    }
  };
  const checks = [];
  for (const name of ['run_contract.json', 'workload.lock.json']) {
    checks.push(await probe(join(runRoot, name)));
  }
  const payload = {
    artifactType: 'training_quality_gate',
    schemaVersion: 1,
    generatedAt: new Date().toISOString(),
    runRoot,
    passed: checks.every((check) => check.ok === true),
    checks,
  };
  const artifact = await writeJsonArtifact(join(runRoot, 'quality-gate', 'quality-gate.json'), payload);
  return {
    ...payload,
    reportPath: artifact.path,
  };
}