@simulatte/doppler 0.1.4 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +26 -10
- package/package.json +30 -6
- package/src/client/doppler-api.browser.d.ts +1 -0
- package/src/client/doppler-api.browser.js +288 -0
- package/src/client/doppler-api.js +1 -1
- package/src/client/doppler-provider/types.js +1 -1
- package/src/config/execution-contract-check.d.ts +33 -0
- package/src/config/execution-contract-check.js +72 -0
- package/src/config/execution-v0-contract-check.d.ts +94 -0
- package/src/config/execution-v0-contract-check.js +251 -0
- package/src/config/execution-v0-graph-contract-check.d.ts +20 -0
- package/src/config/execution-v0-graph-contract-check.js +64 -0
- package/src/config/kernel-path-contract-check.d.ts +76 -0
- package/src/config/kernel-path-contract-check.js +479 -0
- package/src/config/kernel-path-loader.d.ts +16 -0
- package/src/config/kernel-path-loader.js +54 -0
- package/src/config/kernels/kernel-ref-digests.js +39 -27
- package/src/config/kernels/registry.json +598 -2
- package/src/config/loader.js +81 -48
- package/src/config/merge-contract-check.d.ts +16 -0
- package/src/config/merge-contract-check.js +321 -0
- package/src/config/merge-helpers.d.ts +58 -0
- package/src/config/merge-helpers.js +54 -0
- package/src/config/merge.js +21 -6
- package/src/config/presets/models/janus-text.json +2 -0
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/quantization-contract-check.d.ts +12 -0
- package/src/config/quantization-contract-check.js +91 -0
- package/src/config/required-inference-fields-contract-check.d.ts +24 -0
- package/src/config/required-inference-fields-contract-check.js +237 -0
- package/src/config/schema/browser-suite-metrics.schema.d.ts +17 -0
- package/src/config/schema/browser-suite-metrics.schema.js +46 -0
- package/src/config/schema/conversion-report.schema.d.ts +40 -0
- package/src/config/schema/conversion-report.schema.js +108 -0
- package/src/config/schema/doppler.schema.js +12 -18
- package/src/config/schema/index.d.ts +22 -0
- package/src/config/schema/index.js +18 -0
- package/src/config/schema/inference-defaults.schema.js +3 -0
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.js +3 -0
- package/src/converter/core.d.ts +10 -0
- package/src/converter/core.js +27 -2
- package/src/converter/parsers/diffusion.js +63 -3
- package/src/converter/rope-config.js +42 -0
- package/src/gpu/device.js +58 -0
- package/src/gpu/kernels/attention.js +98 -0
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/conv2d.js +1 -1
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/depthwise_conv2d.d.ts +29 -0
- package/src/gpu/kernels/depthwise_conv2d.js +99 -0
- package/src/gpu/kernels/depthwise_conv2d.wgsl +55 -0
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +59 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.d.ts +27 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +93 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +44 -0
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +48 -0
- package/src/gpu/kernels/index.d.ts +30 -0
- package/src/gpu/kernels/index.js +25 -0
- package/src/gpu/kernels/matmul.js +25 -0
- package/src/gpu/kernels/pixel_shuffle.js +1 -1
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.d.ts +18 -0
- package/src/gpu/kernels/relu.js +58 -0
- package/src/gpu/kernels/relu.wgsl +22 -0
- package/src/gpu/kernels/relu_f16.wgsl +24 -0
- package/src/gpu/kernels/repeat_channels.d.ts +21 -0
- package/src/gpu/kernels/repeat_channels.js +60 -0
- package/src/gpu/kernels/repeat_channels.wgsl +28 -0
- package/src/gpu/kernels/repeat_channels_f16.wgsl +30 -0
- package/src/gpu/kernels/residual.js +44 -8
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +58 -6
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +11 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sana_linear_attention.d.ts +27 -0
- package/src/gpu/kernels/sana_linear_attention.js +121 -0
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +43 -0
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +46 -0
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +51 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +53 -0
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +32 -14
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/transpose.js +15 -2
- package/src/gpu/kernels/transpose.wgsl +5 -6
- package/src/gpu/kernels/upsample2d.js +2 -1
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +16 -1
- package/src/index-browser.d.ts +1 -1
- package/src/index-browser.js +2 -2
- package/src/index.js +1 -1
- package/src/inference/browser-harness.js +109 -23
- package/src/inference/pipelines/diffusion/init.js +14 -0
- package/src/inference/pipelines/diffusion/pipeline.js +215 -77
- package/src/inference/pipelines/diffusion/sana-transformer.d.ts +53 -0
- package/src/inference/pipelines/diffusion/sana-transformer.js +738 -0
- package/src/inference/pipelines/diffusion/scheduler.d.ts +17 -1
- package/src/inference/pipelines/diffusion/scheduler.js +91 -3
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +11 -4
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +282 -0
- package/src/inference/pipelines/diffusion/text-encoder.js +18 -1
- package/src/inference/pipelines/diffusion/types.d.ts +4 -0
- package/src/inference/pipelines/diffusion/vae.js +782 -78
- package/src/inference/pipelines/text/attention/record.js +11 -2
- package/src/inference/pipelines/text/attention/run.js +11 -2
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +9 -0
- package/src/inference/pipelines/text/config.js +69 -2
- package/src/inference/pipelines/text/execution-plan.js +23 -31
- package/src/inference/pipelines/text/execution-v0.js +43 -95
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +56 -9
- package/src/inference/pipelines/text/layer.js +11 -0
- package/src/inference/pipelines/text.js +4 -0
- package/src/inference/tokenizers/bundled.js +156 -33
- package/src/rules/execution-rules-contract-check.d.ts +17 -0
- package/src/rules/execution-rules-contract-check.js +245 -0
- package/src/rules/kernels/depthwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/grouped-pointwise-conv2d.rules.json +6 -0
- package/src/rules/kernels/relu.rules.json +6 -0
- package/src/rules/kernels/repeat-channels.rules.json +6 -0
- package/src/rules/kernels/sana-linear-attention.rules.json +6 -0
- package/src/rules/layer-pattern-contract-check.d.ts +17 -0
- package/src/rules/layer-pattern-contract-check.js +231 -0
- package/src/rules/rule-registry.d.ts +28 -0
- package/src/rules/rule-registry.js +38 -0
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +142 -3
- package/src/tooling/conversion-config-materializer.d.ts +24 -0
- package/src/tooling/conversion-config-materializer.js +99 -0
- package/src/tooling/lean-execution-contract-runner.d.ts +43 -0
- package/src/tooling/lean-execution-contract-runner.js +158 -0
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +58 -3
- package/src/tooling/node-command-runner.js +15 -0
- package/src/tooling/node-convert.d.ts +10 -0
- package/src/tooling/node-converter.js +59 -0
- package/src/tooling/node-webgpu.js +11 -89
- package/src/training/checkpoint-watch.d.ts +7 -0
- package/src/training/checkpoint-watch.js +106 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +12 -2
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +57 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +796 -0
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +453 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +29 -4
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +9 -9
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +539 -0
- package/src/version.d.ts +2 -0
- package/src/version.js +2 -0
- package/tools/convert-safetensors-node.js +47 -0
- package/tools/doppler-cli.js +252 -41
|
@@ -0,0 +1,796 @@
|
|
|
1
|
+
import { mkdir, readFile, readdir, writeFile } from 'node:fs/promises';
|
|
2
|
+
import { join, resolve } from 'node:path';
|
|
3
|
+
|
|
4
|
+
import { loadBackwardRegistry } from '../config/backward-registry-loader.js';
|
|
5
|
+
import { acquireBuffer, readBuffer, releaseBuffer, uploadData } from '../memory/buffer-pool.js';
|
|
6
|
+
import { createTensor } from '../gpu/tensor.js';
|
|
7
|
+
import { runMatmul } from '../gpu/kernels/index.js';
|
|
8
|
+
import { runResidualAdd } from '../gpu/kernels/residual.js';
|
|
9
|
+
import { parseJsonl } from './datasets/jsonl.js';
|
|
10
|
+
import { LoraAdapter } from './lora.js';
|
|
11
|
+
import { TrainingRunner, restoreTrainingCheckpointState } from './runner.js';
|
|
12
|
+
import { AdamOptimizer } from './optimizer.js';
|
|
13
|
+
import { crossEntropyLoss } from './loss.js';
|
|
14
|
+
import { clipGradients } from './clip.js';
|
|
15
|
+
import { OpType, AutogradTape } from './autograd.js';
|
|
16
|
+
import { loadCheckpoint } from './checkpoint.js';
|
|
17
|
+
import { exportLoRAAdapter } from './export.js';
|
|
18
|
+
import { computeEvalMetrics } from './operator-eval.js';
|
|
19
|
+
import { appendScoreboardRow } from './operator-scoreboard.js';
|
|
20
|
+
import {
|
|
21
|
+
buildArtifactBase,
|
|
22
|
+
createTrainingRunLayout,
|
|
23
|
+
hashArtifactPayload,
|
|
24
|
+
writeJsonArtifact,
|
|
25
|
+
writeRunContract,
|
|
26
|
+
writeWorkloadLock,
|
|
27
|
+
} from './operator-artifacts.js';
|
|
28
|
+
import { watchFinalizedCheckpoints } from './checkpoint-watch.js';
|
|
29
|
+
import { loadLoRAFromManifest } from '../adapters/lora-loader.js';
|
|
30
|
+
|
|
31
|
+
/**
 * Recursively rebuilds a JSON-like value with every object's keys in sorted
 * order, so serialization becomes deterministic regardless of insertion order.
 * Arrays keep their element order; only their contents are normalized.
 * @param {*} value - Any JSON-compatible value.
 * @returns {*} A structurally equal value with sorted nested object keys.
 */
function stableSortObject(value) {
  if (Array.isArray(value)) {
    return value.map((entry) => stableSortObject(entry));
  }
  if (value === null || typeof value !== 'object') {
    // Primitives (and null/undefined) pass through unchanged.
    return value;
  }
  return Object.fromEntries(
    Object.keys(value)
      .sort()
      .map((key) => [key, stableSortObject(value[key])])
  );
}
|
|
44
|
+
|
|
45
|
+
/**
 * Deterministic JSON serialization: normalizes key order first so equal
 * values always produce byte-identical strings (useful for hashing).
 * @param {*} value - Any JSON-compatible value.
 * @returns {string} Canonical JSON text.
 */
function stableJson(value) {
  const canonical = stableSortObject(value);
  return JSON.stringify(canonical);
}
|
|
48
|
+
|
|
49
|
+
/**
 * Uploads float32 host data into a freshly acquired pooled GPU buffer and
 * wraps it in a tensor view.
 * @param {Float32Array|number[]} values - Host values; copied into a Float32Array if needed.
 * @param {number[]} shape - Logical tensor shape (copied defensively).
 * @param {string} label - Debug label attached to the buffer and tensor.
 * @returns {object} Tensor over the uploaded buffer with dtype 'f32'.
 */
function makeTensorFromFloat32(values, shape, label) {
  let hostData = values;
  if (!(hostData instanceof Float32Array)) {
    hostData = new Float32Array(values);
  }
  const gpuBuffer = acquireBuffer(hostData.byteLength, undefined, label);
  uploadData(gpuBuffer, hostData);
  return createTensor(gpuBuffer, 'f32', Array.from(shape), label);
}
|
|
55
|
+
|
|
56
|
+
/**
 * Uploads uint32 host data into a freshly acquired pooled GPU buffer and
 * wraps it in a tensor view.
 * @param {Uint32Array|number[]} values - Host values; copied into a Uint32Array if needed.
 * @param {number[]} shape - Logical tensor shape (copied defensively).
 * @param {string} label - Debug label attached to the buffer and tensor.
 * @returns {object} Tensor over the uploaded buffer with dtype 'u32'.
 */
function makeTensorFromUint32(values, shape, label) {
  let hostData = values;
  if (!(hostData instanceof Uint32Array)) {
    hostData = new Uint32Array(values);
  }
  const gpuBuffer = acquireBuffer(hostData.byteLength, undefined, label);
  uploadData(gpuBuffer, hostData);
  return createTensor(gpuBuffer, 'u32', Array.from(shape), label);
}
|
|
62
|
+
|
|
63
|
+
/**
 * Returns a tensor's backing buffer to the pool. Safe no-op for
 * null/undefined tensors or tensors without a (truthy) buffer.
 * @param {object|null|undefined} tensor - Tensor whose buffer should be released.
 */
function releaseTensor(tensor) {
  const buffer = tensor?.buffer;
  if (buffer) {
    releaseBuffer(buffer);
  }
}
|
|
67
|
+
|
|
68
|
+
/**
 * Builds the tiny fixed-size (3-in, 2-out) LoRA demo model for the toy
 * workload: a frozen base weight plus a trainable LoRA adapter on a single
 * target module.
 * @param {object} workload - Parsed workload; reads pipeline.adapter.{targetModules,rank,alpha}.
 * @returns {{model: object, cleanup: Function}} Model plus a cleanup callback
 *   that disposes the adapter and releases the base-weight buffer.
 * @throws {Error} If the adapter config lists no target modules.
 */
function createToyLoraModel(workload) {
  const [targetModule] = workload.pipeline.adapter.targetModules;
  if (!targetModule) {
    throw new Error('LoRA workload requires at least one adapter target module.');
  }
  // Fixed 3x2 base weight; values are arbitrary but deterministic.
  const baseWeight = makeTensorFromFloat32(
    [0.08, -0.12, 0.16, 0.22, -0.03, 0.09],
    [3, 2],
    'lora_toy_base_weight'
  );
  const adapter = new LoraAdapter({
    inDim: 3,
    outDim: 2,
    rank: workload.pipeline.adapter.rank,
    alpha: workload.pipeline.adapter.alpha,
  });
  const model = {
    adapter,
    baseWeight,
    targetModule,
    // logits = input @ baseWeight + adapter(input); both ops are recorded on
    // the tape so gradients can flow back through the LoRA factors.
    async forward(inputTensor, tape) {
      const rowCount = Number.isInteger(inputTensor?.shape?.[0]) ? inputTensor.shape[0] : 1;
      const baseLogits = await tape.record(
        OpType.MATMUL,
        (a, b) => runMatmul(a, b, rowCount, 2, 3, { transposeB: false }),
        [inputTensor, baseWeight],
        { M: rowCount, N: 2, K: 3, transposeB: false }
      );
      const loraDelta = await adapter.forward(inputTensor, tape);
      return tape.record(
        OpType.RESIDUAL_ADD,
        (a, b) => runResidualAdd(a, b, rowCount * 2),
        [baseLogits, loraDelta],
        { size: rowCount * 2 }
      );
    },
    loraParams() {
      return [adapter.A, adapter.B];
    },
    // Group names mirror what the training runner's freeze config expects.
    paramGroups() {
      return {
        encoder: [],
        prior: [],
        decoder: [],
        base: [baseWeight],
        lora: [adapter.A, adapter.B],
      };
    },
  };
  return {
    model,
    cleanup() {
      adapter.dispose();
      releaseTensor(baseWeight);
    },
  };
}
|
|
125
|
+
|
|
126
|
+
/**
 * Validates and coerces one raw dataset record into {id, input, target}.
 * The feature vector may arrive under `input` or `features`; the class label
 * under `target` or `label`.
 * @param {object} record - Raw row parsed from JSON/JSONL.
 * @param {number} index - Zero-based row position (errors report it 1-based).
 * @returns {{id: string, input: number[], target: number}} Normalized row.
 * @throws {Error} On non-object rows, wrong arity, non-finite features, or
 *   labels outside {0, 1}.
 */
function normalizeToyRow(record, index) {
  const rowNumber = index + 1;
  const isPlainObject = Boolean(record) && typeof record === 'object' && !Array.isArray(record);
  if (!isPlainObject) {
    throw new Error(`LoRA toy dataset row ${rowNumber} must be an object.`);
  }
  let values = null;
  if (Array.isArray(record.input)) {
    values = record.input;
  } else if (Array.isArray(record.features)) {
    values = record.features;
  }
  if (!Array.isArray(values) || values.length !== 3) {
    throw new Error(`LoRA toy dataset row ${rowNumber} requires input[3].`);
  }
  const input = [];
  for (let valueIndex = 0; valueIndex < values.length; valueIndex += 1) {
    const parsed = Number(values[valueIndex]);
    if (!Number.isFinite(parsed)) {
      throw new Error(`LoRA toy dataset row ${rowNumber} input[${valueIndex}] must be finite.`);
    }
    input.push(parsed);
  }
  // Binary classification only: label must be exactly 0 or 1.
  const target = Number(record.target ?? record.label);
  if (!Number.isInteger(target) || target < 0 || target > 1) {
    throw new Error(`LoRA toy dataset row ${rowNumber} requires integer target 0 or 1.`);
  }
  return {
    id: String(record.id || `row-${rowNumber}`),
    input,
    target,
  };
}
|
|
153
|
+
|
|
154
|
+
/**
 * Reads and normalizes the toy LoRA dataset from disk. Files ending in
 * `.json` are parsed as a whole array; everything else is treated as JSONL.
 * @param {string} datasetPath - Path to the dataset file.
 * @returns {Promise<{absolutePath: string, raw: string, rows: object[], datasetHash: string}>}
 *   Normalized rows plus a content hash over them.
 * @throws {Error} If the parsed payload is not an array, or any row is invalid.
 */
async function loadToyLoraDataset(datasetPath) {
  const absolutePath = resolve(String(datasetPath));
  const raw = await readFile(absolutePath, 'utf8');
  let rows;
  if (absolutePath.endsWith('.json')) {
    rows = JSON.parse(raw);
  } else {
    rows = parseJsonl(raw);
  }
  if (!Array.isArray(rows)) {
    throw new Error(`LoRA dataset "${absolutePath}" must be a JSON array or JSONL file.`);
  }
  const normalizedRows = rows.map((row, index) => normalizeToyRow(row, index));
  return {
    absolutePath,
    raw,
    rows: normalizedRows,
    datasetHash: hashArtifactPayload({ rows: normalizedRows }),
  };
}
|
|
171
|
+
|
|
172
|
+
/**
 * Wraps normalized rows in an async-iterable batch source. GPU tensors are
 * reused across same-sized batches (host data re-uploaded in place) and
 * released when iteration ends. Note the yielded tensors are overwritten on
 * the next iteration step, so consumers must finish with a batch before
 * advancing the generator.
 * @param {Array<{input: number[], target: number}>} rows - Normalized dataset rows.
 * @param {number} batchSize - Maximum rows per batch (final batch may be smaller).
 * @returns {{batches: Function}} Object exposing an async generator of
 *   {input, targets} tensor pairs.
 */
function createToyDatasetBatches(rows, batchSize) {
  return {
    async *batches() {
      let inputTensor = null;
      let targetTensor = null;
      let tensorBatchSize = 0;
      try {
        for (let offset = 0; offset < rows.length; offset += batchSize) {
          const batchRows = rows.slice(offset, offset + batchSize);
          const count = batchRows.length;
          const inputData = new Float32Array(count * 3);
          const targetData = new Uint32Array(count);
          batchRows.forEach((row, rowIndex) => {
            inputData.set(row.input, rowIndex * 3);
            targetData[rowIndex] = row.target;
          });
          const needFreshTensors = !inputTensor || !targetTensor || tensorBatchSize !== count;
          if (needFreshTensors) {
            // First batch, or batch size changed (e.g. final partial batch):
            // drop the old tensors and allocate correctly sized ones.
            releaseTensor(inputTensor);
            releaseTensor(targetTensor);
            inputTensor = makeTensorFromFloat32(inputData, [count, 3], 'lora_toy_input');
            targetTensor = makeTensorFromUint32(targetData, [count], 'lora_toy_target');
            tensorBatchSize = count;
          } else {
            // Same shape as last time: overwrite the GPU buffers in place.
            uploadData(inputTensor.buffer, inputData);
            uploadData(targetTensor.buffer, targetData);
          }
          yield {
            input: inputTensor,
            targets: targetTensor,
          };
        }
      } finally {
        // Runs on normal completion and on early generator return alike.
        releaseTensor(inputTensor);
        releaseTensor(targetTensor);
      }
    },
  };
}
|
|
209
|
+
|
|
210
|
+
/**
 * Gathers every buffer backing the model's parameters into a Set so cleanup
 * passes can recognize (and skip) them.
 * @param {{paramGroups: Function}} model - Model exposing named parameter groups.
 * @returns {Set<object>} Buffers that must not be released by tape cleanup.
 */
function collectProtectedBuffers(model) {
  const protectedBuffers = new Set();
  for (const params of Object.values(model.paramGroups())) {
    for (const tensor of params) {
      const buffer = tensor?.buffer;
      if (buffer) {
        protectedBuffers.add(buffer);
      }
    }
  }
  return protectedBuffers;
}
|
|
222
|
+
|
|
223
|
+
/**
 * Releases the buffers produced as op outputs on an autograd tape, skipping
 * protected (parameter) buffers and never releasing the same buffer twice.
 * No-op when the tape has no records array.
 * @param {{records?: Array}|null} tape - Autograd tape whose op outputs to free.
 * @param {Set<object>} [protectedBuffers] - Buffers to leave untouched.
 */
function disposeTapeOutputs(tape, protectedBuffers = new Set()) {
  const records = tape?.records;
  if (!Array.isArray(records)) return;
  const alreadyReleased = new Set();
  for (const record of records) {
    const buffer = record?.output?.buffer;
    if (!buffer || protectedBuffers.has(buffer) || alreadyReleased.has(buffer)) {
      continue;
    }
    alreadyReleased.add(buffer);
    releaseBuffer(buffer);
  }
}
|
|
234
|
+
|
|
235
|
+
/**
 * Index of the largest finite value; non-finite entries (NaN, ±Infinity from
 * bad reads) are treated as negative infinity. Ties keep the earliest index;
 * an empty input yields 0.
 * @param {number[]|Float32Array} values - Candidate scores.
 * @returns {number} Index of the winning entry.
 */
function argmax(values) {
  let winner = 0;
  let winningScore = Number.NEGATIVE_INFINITY;
  for (const [position, raw] of values.entries()) {
    const score = Number.isFinite(raw) ? raw : Number.NEGATIVE_INFINITY;
    if (score > winningScore) {
      winningScore = score;
      winner = position;
    }
  }
  return winner;
}
|
|
247
|
+
|
|
248
|
+
/**
 * Runs every configured eval dataset through the toy model, predicting the
 * argmax over the 2-class logits for each row, and writes one eval-report
 * artifact per dataset when a run layout is provided.
 * @param {object} workload - Workload config (reads evalDatasets, id, configHash, baseModelId).
 * @param {object} model - Model as returned by createToyLoraModel().model.
 * @param {object} dataset - Already-loaded training dataset (reused when an
 *   eval dataset shares its absolute path).
 * @param {object|null} [layout] - Run directory layout; null skips writing report files.
 * @param {object} [checkpointMeta] - Provenance fields copied into each report.
 * @returns {Promise<object[]>} Report payloads (each with reportPath when written).
 * @throws {Error} For eval kinds other than classification/text_generation.
 */
async function evaluateToyLoraModel(workload, model, dataset, layout = null, checkpointMeta = {}) {
  const protectedBuffers = collectProtectedBuffers(model);
  const evalDatasets = Array.isArray(workload.evalDatasets) ? workload.evalDatasets : [];
  const evalReports = [];
  for (const evalDataset of evalDatasets) {
    const supportedKind =
      evalDataset.evalKind === 'classification' || evalDataset.evalKind === 'text_generation';
    if (!supportedKind) {
      throw new Error(`LoRA eval currently supports classification/text_generation only, got "${evalDataset.evalKind}".`);
    }
    // Reuse the already-loaded training dataset when the paths match.
    const materialized = evalDataset.datasetPath === dataset.absolutePath
      ? dataset
      : await loadToyLoraDataset(evalDataset.datasetPath);
    const predictions = [];
    const labels = [];
    for (const row of materialized.rows) {
      const tape = new AutogradTape(loadBackwardRegistry());
      const inputTensor = makeTensorFromFloat32(row.input, [1, 3], 'lora_eval_input');
      let logits = null;
      try {
        logits = await model.forward(inputTensor, tape);
        const logitsData = new Float32Array(await readBuffer(logits.buffer));
        predictions.push(String(argmax(logitsData)));
        labels.push(String(row.target));
      } finally {
        releaseTensor(inputTensor);
        // NOTE(review): logits is also the final tape output, so
        // disposeTapeOutputs below may encounter the same buffer again after
        // this explicit release — assumes the buffer pool tolerates that
        // ordering; confirm releaseBuffer semantics.
        if (logits?.buffer && !protectedBuffers.has(logits.buffer)) {
          releaseBuffer(logits.buffer);
        }
        disposeTapeOutputs(tape, protectedBuffers);
      }
    }
    // NOTE(review): metrics are always computed as 'classification', even
    // when evalKind is 'text_generation' — presumably intentional for the
    // toy model; verify against computeEvalMetrics.
    const metrics = computeEvalMetrics('classification', predictions, labels, {});
    const reportPayload = {
      artifactType: 'training_eval_report',
      schemaVersion: 1,
      generatedAt: new Date().toISOString(),
      workloadId: workload.id,
      workloadPath: checkpointMeta.workloadPath || null,
      workloadSha256: checkpointMeta.workloadSha256 || null,
      configHash: checkpointMeta.configHash || workload.configHash,
      datasetPath: evalDataset.datasetPath,
      datasetHash: materialized.datasetHash,
      baseModelId: workload.baseModelId,
      stage: 'lora',
      checkpointStep: checkpointMeta.checkpointStep ?? null,
      evalDatasetId: evalDataset.id,
      metrics,
      primaryMetric: metrics.primaryMetric,
      primaryScore: metrics.primaryScore,
      accuracy: metrics.accuracy?.score ?? null,
    };
    let reportFile = null;
    if (layout) {
      const reportName = `${checkpointMeta.checkpointId || 'checkpoint'}__${evalDataset.id}.json`;
      reportFile = await writeJsonArtifact(join(layout.eval, reportName), reportPayload);
    }
    evalReports.push({
      ...reportPayload,
      reportPath: reportFile?.path || null,
    });
  }
  return evalReports;
}
|
|
312
|
+
|
|
313
|
+
/**
 * Builds the run-contract payload recorded at the start of a training run,
 * binding the run to the exact workload file, hash, and eval datasets.
 * @param {object} loadedWorkload - Loaded workload wrapper
 *   ({workload, absolutePath, workloadSha256}).
 * @returns {object} training_run_contract payload (schemaVersion 1).
 */
function buildRunContract(loadedWorkload) {
  const { workload, absolutePath, workloadSha256 } = loadedWorkload;
  return {
    artifactType: 'training_run_contract',
    schemaVersion: 1,
    generatedAt: new Date().toISOString(),
    workloadId: workload.id,
    workloadPath: absolutePath,
    workloadSha256,
    configHash: workload.configHash,
    claimBoundary: workload.claimBoundary,
    kind: workload.kind,
    evalDatasets: workload.evalDatasets,
  };
}
|
|
327
|
+
|
|
328
|
+
/**
 * Assembles a provenance artifact payload for this workload via
 * buildArtifactBase, then stamps it with its own content hash.
 * @param {object} loadedWorkload - Loaded workload wrapper.
 * @param {object} options - Requires artifactType/prefix/id; optional
 *   datasetPath, datasetHash, stage, checkpointStep, parentArtifacts, and
 *   configHash overrides fall back to workload-level values.
 * @returns {object} Base payload plus artifactHash.
 */
function buildArtifact(loadedWorkload, options) {
  const { workload, absolutePath, workloadSha256 } = loadedWorkload;
  const payload = buildArtifactBase({
    artifactType: options.artifactType,
    reportId: `${options.prefix}_${workload.id}_${options.id}`,
    workload,
    workloadPath: absolutePath,
    workloadSha256,
    datasetPath: options.datasetPath || workload.datasetPath,
    datasetHash: options.datasetHash || null,
    baseModelId: workload.baseModelId,
    stage: options.stage || 'lora',
    checkpointStep: options.checkpointStep ?? null,
    parentArtifacts: options.parentArtifacts || [],
    runtime: 'node',
    surface: 'node',
    claimBoundary: workload.claimBoundary,
    configHash: options.configHash || workload.configHash,
  });
  return {
    ...payload,
    artifactHash: hashArtifactPayload(payload),
  };
}
|
|
352
|
+
|
|
353
|
+
/**
 * Exports the trained adapter to a LoRA manifest file, round-trips it
 * through the manifest loader as a sanity check, and records an export
 * artifact alongside it under the run's exports directory.
 * @param {object} loadedWorkload - Loaded workload wrapper.
 * @param {object} layout - Run layout (uses layout.exports).
 * @param {object} model - Toy model holding the adapter and target module.
 * @param {string} checkpointId - Checkpoint identifier used in file names.
 * @param {number|null} checkpointStep - Step number recorded in the artifact.
 * @param {string|null} datasetHash - Hash of the training dataset.
 * @returns {Promise<{checkpointId: string, manifestPath: string, exportPath: string, manifest: object}>}
 */
async function exportToyLoraModel(loadedWorkload, layout, model, checkpointId, checkpointStep, datasetHash) {
  const workload = loadedWorkload.workload;
  const targetModule = model.targetModule || workload.pipeline.adapter.targetModules[0];
  const exportConfig = workload.pipeline.export;
  const fallbackName = `${workload.id}-${checkpointId}`;
  const exported = await exportLoRAAdapter({
    id: exportConfig?.id || fallbackName,
    name: exportConfig?.name || fallbackName,
    baseModel: workload.baseModelId,
    rank: workload.pipeline.adapter.rank,
    alpha: workload.pipeline.adapter.alpha,
    targetModules: [targetModule],
    tensors: [
      { name: `layers.0.${targetModule}.lora_a`, tensor: model.adapter.A },
      { name: `layers.0.${targetModule}.lora_b`, tensor: model.adapter.B },
    ],
  });
  const manifestPath = join(layout.exports, `${checkpointId}.adapter.manifest.json`);
  await writeFile(manifestPath, exported.json, 'utf8');
  // Validate the manifest by loading it back before recording the artifact.
  await loadLoRAFromManifest(exported.manifest, {});
  const artifactPayload = {
    ...buildArtifact(loadedWorkload, {
      prefix: 'lora_export',
      id: checkpointId,
      artifactType: 'lora_adapter_manifest',
      checkpointStep,
      datasetHash,
    }),
    checkpointId,
    manifestPath,
    manifest: exported.manifest,
  };
  const artifactFile = await writeJsonArtifact(
    join(layout.exports, `${checkpointId}.export.json`),
    artifactPayload
  );
  return {
    checkpointId,
    manifestPath,
    exportPath: artifactFile.path,
    manifest: exported.manifest,
  };
}
|
|
394
|
+
|
|
395
|
+
/**
 * Finds the lexicographically newest checkpoint directory under
 * `<runRoot>/checkpoints` and returns its well-known file paths. Relies on
 * checkpoint directory names sorting chronologically (locale-aware compare).
 * @param {string} runRoot - Root directory of a training run.
 * @returns {Promise<{checkpointId: string, checkpointPath: string, markerPath: string}>}
 * @throws {Error} When the checkpoints directory contains no subdirectories.
 */
async function selectLatestCheckpoint(runRoot) {
  const checkpointsDir = join(runRoot, 'checkpoints');
  const entries = await readdir(checkpointsDir, { withFileTypes: true });
  const names = entries
    .filter((entry) => entry.isDirectory())
    .map((entry) => entry.name);
  names.sort((left, right) => left.localeCompare(right));
  if (names.length === 0) {
    throw new Error(`No checkpoints found in ${checkpointsDir}.`);
  }
  const latest = names[names.length - 1];
  const latestDir = join(checkpointsDir, latest);
  return {
    checkpointId: latest,
    checkpointPath: join(latestDir, 'state.json'),
    markerPath: join(latestDir, 'checkpoint.complete.json'),
  };
}
|
|
412
|
+
|
|
413
|
+
/**
 * Run the full LoRA training pipeline for a loaded workload: prepare the run
 * directory layout, write the run contract and workload lock, train the toy
 * LoRA model, and emit checkpoint artifacts, eval reports, scoreboard rows,
 * and (optionally) model exports.
 *
 * Currently restricted to baseModelId="training-toy" with the
 * "toy_linear_classification_jsonl" dataset format.
 *
 * @param {object} options
 * @param {object} options.loadedWorkload - loaded workload wrapper; must expose
 *   `workload`, `absolutePath`, and `workloadSha256`.
 * @param {string} [options.runRoot] - explicit run root directory; when omitted
 *   a fresh layout is created via createTrainingRunLayout.
 * @param {string|null} [options.timestamp] - timestamp forwarded to
 *   createTrainingRunLayout when no runRoot is given.
 * @returns {Promise<object>} run summary: { ok, kind, action, workloadId,
 *   runRoot, checkpointArtifacts, evalReports, exports, metrics, lastCheckpoint }.
 * @throws {Error} when the workload kind, base model, or dataset format is unsupported.
 */
export async function runLoraPipeline(options) {
  const loadedWorkload = options.loadedWorkload;
  const workload = loadedWorkload.workload;
  if (workload.kind !== 'lora') {
    throw new Error('runLoraPipeline requires a lora workload.');
  }
  if (workload.baseModelId !== 'training-toy') {
    throw new Error('LoRA run currently supports baseModelId="training-toy" only.');
  }
  if (workload.pipeline.datasetFormat !== 'toy_linear_classification_jsonl') {
    throw new Error('LoRA run currently supports datasetFormat="toy_linear_classification_jsonl" only.');
  }
  let layout;
  if (options.runRoot) {
    // Resolve the root once; the previous version recomputed
    // resolve(String(options.runRoot)) for every layout key.
    const runRoot = resolve(String(options.runRoot));
    layout = {
      runRoot,
      logs: join(runRoot, 'logs'),
      checkpoints: join(runRoot, 'checkpoints'),
      eval: join(runRoot, 'eval'),
      scoreboard: join(runRoot, 'scoreboard'),
      exports: join(runRoot, 'exports'),
      compare: join(runRoot, 'compare'),
      qualityGate: join(runRoot, 'quality-gate'),
    };
  } else {
    layout = await createTrainingRunLayout({
      kind: 'lora',
      workloadId: workload.id,
      timestamp: options.timestamp || null,
    });
  }
  // Materialize every layout directory before any artifact is written.
  await Promise.all(Object.values(layout).map((dirPath) => mkdir(dirPath, { recursive: true })));
  await writeRunContract(layout, buildRunContract(loadedWorkload));
  await writeWorkloadLock(layout, loadedWorkload);
  const dataset = await loadToyLoraDataset(workload.datasetPath);
  const fixture = createToyLoraModel(workload);
  try {
    const evalReports = [];
    const checkpointArtifacts = [];
    const exports = [];
    const runner = new TrainingRunner({
      training: {
        enabled: true,
        optimizer: {
          type: workload.training.optimizer.type,
          lr: workload.training.optimizer.lr,
          beta1: workload.training.optimizer.beta1,
          beta2: workload.training.optimizer.beta2,
          eps: workload.training.optimizer.eps,
          weightDecay: workload.training.optimizer.weightDecay,
          scheduler: workload.training.optimizer.scheduler,
        },
        gradient: {
          maxNorm: workload.training.gradientClipping.maxNorm,
        },
        precision: workload.training.precision,
        lossScaling: { enabled: false },
        // Distillation is disabled for LoRA runs, but the runner contract
        // requires the full distill config shape, hence the explicit nulls.
        distill: {
          enabled: false,
          stage: 'stage_a',
          teacherModelId: null,
          studentModelId: null,
          datasetId: null,
          datasetPath: null,
          languagePair: null,
          sourceLangs: null,
          targetLangs: null,
          pairAllowlist: null,
          strictPairContract: false,
          shardIndex: null,
          shardCount: null,
          resumeFrom: null,
          artifactDir: null,
          stageAArtifact: null,
          stageAArtifactHash: null,
          temperature: 1,
          alphaKd: 1,
          alphaCe: 0,
          allowHintFallback: false,
          tripletMargin: 0.2,
          studentGraphMode: 'projection_head',
          // LoRA semantics: base weights frozen, adapter weights trainable.
          freeze: { encoder: false, prior: false, decoder: false, base: true, lora: false },
        },
        // UL (unsupervised latent) training is likewise disabled.
        ul: {
          enabled: false,
          stage: 'stage1_joint',
          stage1Artifact: null,
          stage1ArtifactHash: null,
          artifactDir: null,
          lambda0: 5,
          seed: workload.seed,
          noiseSchedule: { name: 'linear', minSigma: 0.1, maxSigma: 1, steps: 1 },
          priorAlignment: { enabled: false, weight: 1 },
          decoderSigmoidWeight: { enabled: false, maxWeight: 1 },
          lossWeights: { prior: 1, decoder: 1, recon: 1 },
          freeze: null,
        },
      },
    }, {
      optimizer: new AdamOptimizer({
        training: {
          optimizer: {
            type: workload.training.optimizer.type,
            lr: workload.training.optimizer.lr,
            beta1: workload.training.optimizer.beta1,
            beta2: workload.training.optimizer.beta2,
            eps: workload.training.optimizer.eps,
            weightDecay: workload.training.optimizer.weightDecay,
            scheduler: workload.training.optimizer.scheduler,
          },
          gradient: {
            maxNorm: workload.training.gradientClipping.maxNorm,
          },
          precision: workload.training.precision,
        },
      }),
      crossEntropyLoss,
      clipGradients,
      resolveCheckpointKey({ step }) {
        return join(layout.checkpoints, `checkpoint-${String(step).padStart(6, '0')}`, 'state.json');
      },
      // Invoked by the runner after each checkpoint: persists the checkpoint
      // artifact (plus a `.complete` marker), optionally exports the model,
      // then evaluates the model and appends scoreboard rows.
      onCheckpoint: async (checkpoint) => {
        const checkpointId = `checkpoint-${String(checkpoint.step).padStart(6, '0')}`;
        const checkpointPayload = {
          ...buildArtifact(loadedWorkload, {
            prefix: 'lora_ckpt',
            id: checkpointId,
            artifactType: 'training_checkpoint',
            datasetHash: dataset.datasetHash,
            checkpointStep: checkpoint.step,
          }),
          checkpointId,
          checkpointPath: checkpoint.path,
          optimizerStatePresent: true,
          schedulerStatePresent: workload.training.optimizer.scheduler.enabled === true,
          resumeLineage: checkpoint.metadata?.lineage || null,
        };
        await writeJsonArtifact(
          join(layout.checkpoints, checkpointId, 'checkpoint.json'),
          checkpointPayload
        );
        // The `.complete.json` marker signals that the checkpoint write
        // finished; watchers key off it (see watchFinalizedCheckpoints usage).
        const checkpointArtifact = await writeJsonArtifact(
          join(layout.checkpoints, checkpointId, 'checkpoint.complete.json'),
          checkpointPayload
        );
        checkpointArtifacts.push({
          checkpointId,
          checkpointPath: checkpoint.path,
          markerPath: checkpointArtifact.path,
          checkpointStep: checkpoint.step,
        });
        if (workload.pipeline.export?.enabled === true && workload.pipeline.export.atCheckpoints === true) {
          exports.push(await exportToyLoraModel(
            loadedWorkload,
            layout,
            fixture.model,
            checkpointId,
            checkpoint.step,
            dataset.datasetHash
          ));
        }
        const reports = await evaluateToyLoraModel(workload, fixture.model, dataset, layout, {
          checkpointId,
          checkpointStep: checkpoint.step,
          configHash: workload.configHash,
          workloadPath: loadedWorkload.absolutePath,
          workloadSha256: loadedWorkload.workloadSha256,
        });
        for (const report of reports) {
          evalReports.push(report);
          await appendScoreboardRow(layout.scoreboard, {
            artifactType: 'training_scoreboard',
            schemaVersion: 1,
            generatedAt: new Date().toISOString(),
            checkpointId,
            checkpointStep: checkpoint.step,
            evalDatasetId: report.evalDatasetId,
            primaryMetric: report.primaryMetric,
            primaryScore: report.primaryScore,
            accuracy: report.accuracy,
            metrics: {
              accuracy: report.accuracy,
              primaryScore: report.primaryScore,
            },
          }, {
            selectionMetric: workload.selectionMetric,
            selectionGoal: workload.selectionGoal,
          });
        }
      },
    });
    const metrics = await runner.run(
      fixture.model,
      createToyDatasetBatches(dataset.rows, workload.training.batchSize),
      {
        epochs: 1,
        batchSize: workload.training.batchSize,
        shuffle: false,
        maxSteps: workload.training.steps,
        checkpointEvery: workload.checkpointEvery,
        modelId: workload.baseModelId,
      }
    );
    const finalCheckpointId = runner.lastCheckpoint
      ? `checkpoint-${String(runner.lastCheckpoint.step).padStart(6, '0')}`
      : null;
    // Ensure the final checkpoint is exported even when atCheckpoints exports
    // were disabled (or the final step did not coincide with one).
    if (workload.pipeline.export?.enabled === true && finalCheckpointId && exports.every((entry) => entry.checkpointId !== finalCheckpointId)) {
      exports.push(await exportToyLoraModel(
        loadedWorkload,
        layout,
        fixture.model,
        finalCheckpointId,
        runner.lastCheckpoint.step,
        dataset.datasetHash
      ));
    }
    return {
      ok: true,
      kind: 'lora',
      action: 'run',
      workloadId: workload.id,
      runRoot: layout.runRoot,
      checkpointArtifacts,
      evalReports,
      exports,
      metrics,
      lastCheckpoint: runner.lastCheckpoint,
    };
  } finally {
    // Always release fixture resources, even when training/eval throws.
    fixture.cleanup();
  }
}
|
|
642
|
+
|
|
643
|
+
/**
 * Evaluate a single LoRA checkpoint against the workload's toy dataset.
 *
 * Loads the dataset and checkpoint, restores the checkpoint state into a
 * freshly created toy model (base weights frozen, adapters trainable), and
 * delegates scoring to evaluateToyLoraModel.
 *
 * @param {object} options
 * @param {object} options.loadedWorkload - loaded workload wrapper.
 * @param {string} options.checkpointPath - path to the checkpoint state file.
 * @param {string} [options.checkpointId] - identifier recorded in the report.
 * @param {number|null} [options.checkpointStep] - step recorded in the report.
 * @param {object|null} [options.layout] - optional run layout for report output.
 * @returns {Promise<object[]>} eval reports from evaluateToyLoraModel.
 * @throws {Error} when the checkpoint cannot be loaded.
 */
export async function evaluateLoraCheckpoint(options) {
  const { loadedWorkload } = options;
  const { workload } = loadedWorkload;
  const statePath = resolve(String(options.checkpointPath));
  const dataset = await loadToyLoraDataset(workload.datasetPath);
  const checkpointRecord = await loadCheckpoint(statePath);
  if (!checkpointRecord) {
    throw new Error(`Checkpoint not found: ${statePath}`);
  }
  // Freeze policy mirrors the training run: base frozen, LoRA adapters live.
  const restoreConfig = {
    training: {
      distill: { freeze: { encoder: false, prior: false, decoder: false, base: true, lora: false } },
      ul: { freeze: null },
    },
  };
  const fixture = createToyLoraModel(workload);
  try {
    // No optimizer state is needed for evaluation, hence the null getState stub.
    await restoreTrainingCheckpointState(fixture.model, { getState: () => null }, checkpointRecord, restoreConfig);
    const evalContext = {
      checkpointId: options.checkpointId || 'checkpoint',
      checkpointStep: options.checkpointStep ?? null,
      configHash: workload.configHash,
      workloadPath: loadedWorkload.absolutePath,
      workloadSha256: loadedWorkload.workloadSha256,
    };
    return evaluateToyLoraModel(workload, fixture.model, dataset, options.layout || null, evalContext);
  } finally {
    fixture.cleanup();
  }
}
|
|
671
|
+
|
|
672
|
+
/**
 * Export a trained LoRA checkpoint as a model artifact.
 *
 * Restores the checkpoint state into a fresh toy model and hands it to
 * exportToyLoraModel, writing into the layout's exports directory (falling
 * back to options.exportsDir, then to the default reports path).
 *
 * @param {object} options
 * @param {object} options.loadedWorkload - loaded workload wrapper.
 * @param {string} options.checkpointPath - path to the checkpoint state file.
 * @param {object|null} [options.layout] - run layout; its `exports` dir is used when present.
 * @param {string} [options.exportsDir] - fallback exports directory.
 * @param {string} [options.checkpointId] - identifier recorded in the export.
 * @param {number|null} [options.checkpointStep] - step recorded in the export.
 * @param {string|null} [options.datasetHash] - dataset hash recorded in the export.
 * @returns {Promise<object>} export descriptor from exportToyLoraModel.
 * @throws {Error} when the checkpoint cannot be loaded.
 */
export async function exportLoraCheckpoint(options) {
  const loadedWorkload = options.loadedWorkload;
  const workload = loadedWorkload.workload;
  // Compute the fallback once; the previous version duplicated this
  // resolve(...) expression in two places.
  const fallbackExportsDir = resolve(options.exportsDir || 'reports/training/lora/exports');
  const layout = options.layout || { exports: fallbackExportsDir };
  const checkpointPath = resolve(String(options.checkpointPath));
  const checkpointRecord = await loadCheckpoint(checkpointPath);
  if (!checkpointRecord) {
    throw new Error(`Checkpoint not found: ${checkpointPath}`);
  }
  const fixture = createToyLoraModel(workload);
  try {
    // No optimizer state is needed for export; base weights stay frozen.
    await restoreTrainingCheckpointState(fixture.model, { getState: () => null }, checkpointRecord, {
      training: {
        distill: { freeze: { encoder: false, prior: false, decoder: false, base: true, lora: false } },
        ul: { freeze: null },
      },
    });
    const checkpointId = options.checkpointId || 'checkpoint';
    return exportToyLoraModel(
      loadedWorkload,
      // Guarantee an exports dir even when a caller-supplied layout lacks one.
      { ...layout, exports: layout.exports || fallbackExportsDir },
      fixture.model,
      checkpointId,
      options.checkpointStep ?? null,
      options.datasetHash || null
    );
  } finally {
    fixture.cleanup();
  }
}
|
|
704
|
+
|
|
705
|
+
/**
 * Watch a run's checkpoints directory and evaluate each finalized checkpoint
 * as its `.complete` marker appears.
 *
 * @param {object} options
 * @param {string} options.runRoot - run root containing `checkpoints/` and `scoreboard/`.
 * @param {object} options.loadedWorkload - loaded workload wrapper for evaluation.
 * @param {number} [options.pollIntervalMs=2000] - polling interval.
 * @param {boolean} [options.stopWhenIdle] - stop watching once no new markers appear.
 * @returns {Promise<*>} whatever watchFinalizedCheckpoints resolves to.
 */
export async function watchLoraCheckpoints(options) {
  // NOTE(review): selectLatestCheckpoint presumably returns null when the run
  // has no checkpoints yet — guarded with ?. below; confirm against its impl.
  const latestCheckpoint = await selectLatestCheckpoint(options.runRoot);
  return watchFinalizedCheckpoints({
    checkpointsDir: join(options.runRoot, 'checkpoints'),
    manifestPath: join(options.runRoot, 'scoreboard', 'watch-manifest.json'),
    pollIntervalMs: options.pollIntervalMs || 2000,
    stopWhenIdle: options.stopWhenIdle === true,
    // Each marker file carries the checkpoint metadata; fall back to the
    // latest known checkpoint when the marker omits path/id fields.
    onCheckpoint: async (markerPath) => {
      const raw = await readFile(markerPath, 'utf8');
      const marker = JSON.parse(raw);
      await evaluateLoraCheckpoint({
        loadedWorkload: options.loadedWorkload,
        checkpointPath: marker.checkpointPath || latestCheckpoint?.checkpointPath,
        checkpointId: marker.checkpointId || latestCheckpoint?.checkpointId,
        checkpointStep: marker.checkpointStep ?? null,
        layout: {
          eval: join(options.runRoot, 'eval'),
        },
      });
    },
  });
}
|
|
727
|
+
|
|
728
|
+
/**
 * Compare all eval reports in a run's `eval/` directory and write a ranked
 * compare report (best-first by primaryScore) to `compare/compare.json`.
 *
 * @param {object} options
 * @param {string} options.runRoot - run root containing `eval/` and `compare/`.
 * @returns {Promise<object>} the compare payload plus `comparePath`.
 */
export async function compareLoraRun(options) {
  const evalDir = join(options.runRoot, 'eval');
  const entries = await readdir(evalDir, { withFileTypes: true });
  // The report reads are independent — load them in parallel rather than
  // awaiting each file sequentially. Promise.all preserves entry order.
  const reports = await Promise.all(
    entries
      .filter((entry) => entry.isFile() && entry.name.endsWith('.json'))
      .map(async (entry) => JSON.parse(await readFile(join(evalDir, entry.name), 'utf8')))
  );
  // Sort a copy, descending by primaryScore; missing scores rank last.
  const sorted = reports
    .slice()
    .sort((left, right) => {
      const leftScore = Number(left?.primaryScore ?? Number.NEGATIVE_INFINITY);
      const rightScore = Number(right?.primaryScore ?? Number.NEGATIVE_INFINITY);
      return rightScore - leftScore;
    });
  const payload = {
    artifactType: 'training_compare_report',
    schemaVersion: 1,
    generatedAt: new Date().toISOString(),
    runRoot: options.runRoot,
    count: sorted.length,
    best: sorted[0] || null,
    // Project each report down to the comparison-relevant fields.
    reports: sorted.map((report) => ({
      checkpointId: report.checkpointId || null,
      evalDatasetId: report.evalDatasetId || null,
      primaryMetric: report.primaryMetric || null,
      primaryScore: report.primaryScore ?? null,
      accuracy: report.accuracy ?? null,
      reportPath: report.reportPath || null,
    })),
  };
  const artifact = await writeJsonArtifact(join(options.runRoot, 'compare', 'compare.json'), payload);
  return {
    ...payload,
    comparePath: artifact.path,
  };
}
|
|
766
|
+
|
|
767
|
+
/**
 * Quality-gate a LoRA run: verify the required contract files exist and are
 * readable, then persist a pass/fail report to `quality-gate/quality-gate.json`.
 *
 * @param {object} options
 * @param {string} options.runRoot - run root directory to gate.
 * @returns {Promise<object>} the gate payload plus `reportPath`.
 */
export async function qualityGateLoraRun(options) {
  const runRoot = resolve(String(options.runRoot));
  const checks = [];
  // Readability (not mere existence) is the gate criterion for each file.
  for (const fileName of ['run_contract.json', 'workload.lock.json']) {
    const filePath = join(runRoot, fileName);
    try {
      await readFile(filePath, 'utf8');
      checks.push({ path: filePath, ok: true });
    } catch (error) {
      checks.push({ path: filePath, ok: false, error: error?.message || String(error) });
    }
  }
  const passed = checks.every((entry) => entry.ok === true);
  const payload = {
    artifactType: 'training_quality_gate',
    schemaVersion: 1,
    generatedAt: new Date().toISOString(),
    runRoot,
    passed,
    checks,
  };
  const artifact = await writeJsonArtifact(join(runRoot, 'quality-gate', 'quality-gate.json'), payload);
  return {
    ...payload,
    reportPath: artifact.path,
  };
}
|