@mlx-node/trl 0.0.0 → 0.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +389 -0
- package/package.json +16 -5
- package/dist/data/dataset.d.ts +0 -22
- package/dist/data/dataset.d.ts.map +0 -1
- package/dist/data/dataset.js +0 -142
- package/dist/data/sft-dataset.d.ts +0 -156
- package/dist/data/sft-dataset.d.ts.map +0 -1
- package/dist/data/sft-dataset.js +0 -415
- package/dist/index.d.ts +0 -33
- package/dist/index.d.ts.map +0 -1
- package/dist/index.js +0 -47
- package/dist/trainers/grpo-config.d.ts +0 -42
- package/dist/trainers/grpo-config.d.ts.map +0 -1
- package/dist/trainers/grpo-config.js +0 -220
- package/dist/trainers/grpo-entropy.d.ts +0 -33
- package/dist/trainers/grpo-entropy.d.ts.map +0 -1
- package/dist/trainers/grpo-entropy.js +0 -18
- package/dist/trainers/grpo-trainer.d.ts +0 -602
- package/dist/trainers/grpo-trainer.d.ts.map +0 -1
- package/dist/trainers/grpo-trainer.js +0 -1439
- package/dist/trainers/sft-config.d.ts +0 -32
- package/dist/trainers/sft-config.d.ts.map +0 -1
- package/dist/trainers/sft-config.js +0 -186
- package/dist/trainers/sft-trainer.d.ts +0 -141
- package/dist/trainers/sft-trainer.d.ts.map +0 -1
- package/dist/trainers/sft-trainer.js +0 -502
- package/dist/trainers/training-logger.d.ts +0 -375
- package/dist/trainers/training-logger.d.ts.map +0 -1
- package/dist/trainers/training-logger.js +0 -542
- package/dist/types.d.ts +0 -54
- package/dist/types.d.ts.map +0 -1
- package/dist/types.js +0 -1
- package/dist/utils/path-security.d.ts +0 -51
- package/dist/utils/path-security.d.ts.map +0 -1
- package/dist/utils/path-security.js +0 -69
- package/dist/utils/xml-parser.d.ts +0 -6
- package/dist/utils/xml-parser.d.ts.map +0 -1
- package/dist/utils/xml-parser.js +0 -184
|
@@ -1,1439 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* GRPO Training Engine - Rust-Native Training
|
|
3
|
-
*
|
|
4
|
-
* This module provides a Rust-native GRPO training engine that minimizes
|
|
5
|
-
* FFI overhead by keeping the training loop entirely in Rust.
|
|
6
|
-
*
|
|
7
|
-
* ## Key Features
|
|
8
|
-
* - Training loop runs in Rust (eliminates FFI overhead)
|
|
9
|
-
* - Built-in reward functions (tool use, XML format, length, JSON schema)
|
|
10
|
-
* - Custom JS rewards via callback pattern
|
|
11
|
-
* - Gradient accumulation and memory management in Rust
|
|
12
|
-
* - High-level train() method for full training runs
|
|
13
|
-
* - Low-level trainStep() for custom training loops
|
|
14
|
-
*
|
|
15
|
-
* ## High-Level Usage (train with dataset)
|
|
16
|
-
* ```typescript
|
|
17
|
-
* const trainer = await GRPOTrainer.create({
|
|
18
|
-
* modelPath: './model',
|
|
19
|
-
* modelName: 'qwen3-0.6b',
|
|
20
|
-
* rewardFunction: (prompts, completions) => [...scores],
|
|
21
|
-
* });
|
|
22
|
-
* await trainer.train(dataset);
|
|
23
|
-
* ```
|
|
24
|
-
*
|
|
25
|
-
* ## Low-Level Usage (step-by-step)
|
|
26
|
-
* ```typescript
|
|
27
|
-
* const model = await Qwen3Model.loadPretrained(modelPath);
|
|
28
|
-
* const trainer = new GRPOTrainer(model, config);
|
|
29
|
-
*
|
|
30
|
-
* trainer.registerBuiltinReward({
|
|
31
|
-
* rewardType: 'ToolUse',
|
|
32
|
-
* allowedTools: ['search', 'calculate'],
|
|
33
|
-
* });
|
|
34
|
-
*
|
|
35
|
-
* for (const batch of dataset) {
|
|
36
|
-
* const completions = await trainer.generateBatch(batch.prompts);
|
|
37
|
-
* const rewards = await myRewardFunction(batch.prompts, completions);
|
|
38
|
-
* const metrics = await trainer.trainStep(batch.prompts, rewards);
|
|
39
|
-
* }
|
|
40
|
-
* ```
|
|
41
|
-
*/
|
|
42
|
-
import { createHash } from 'node:crypto';
import { existsSync, mkdirSync, writeFileSync, readFileSync, readdirSync, copyFileSync, cpSync, rmSync, statSync, } from 'node:fs';
import { basename, dirname, join } from 'node:path';
import * as readline from 'node:readline';
import { GrpoTrainingEngine, NativeRewardRegistry, Qwen3Model, OutputStore, buildRewardOutputs, } from '@mlx-node/core';
import { createTrainingLogger } from './training-logger';
|
|
48
|
-
// Re-export native types
|
|
49
|
-
export { GrpoTrainingEngine, NativeRewardRegistry, OutputStore } from '@mlx-node/core';
|
|
50
|
-
/**
|
|
51
|
-
* Default configuration
|
|
52
|
-
*/
|
|
53
|
-
export const DEFAULT_GRPO_CONFIG = {
    // --- Optimisation ---
    learningRate: 1e-6,
    gradientAccumulationSteps: 1,
    gradientClipNorm: 1.0,
    weightDecay: 0.01,
    // --- Schedule / batching ---
    numEpochs: 1,
    batchSize: 1,
    groupSize: 4,
    // --- GRPO loss ---
    clipEpsilon: 0.2,
    klCoef: 0.0,
    lossType: 'grpo',
    advantageNormalization: true,
    // --- Generation / sampling ---
    maxCompletionLength: 256,
    temperature: 0.8,
    topP: 0.95,
    repetitionPenalty: 1.1,
    // --- Intervals (all counted in steps) ---
    logInterval: 1,
    saveInterval: 100,
    evalInterval: 100,
    // --- Logging sinks and checkpoint retention ---
    logConsole: true,
    logJsonl: true,
    maxCheckpoints: 3,
};
|
|
76
|
-
/**
|
|
77
|
-
* Compute a hash of dataset content for identity checking on resume.
|
|
78
|
-
*
|
|
79
|
-
* Hashes the first N examples to create a fingerprint that can detect
|
|
80
|
-
* if the dataset has been modified between training runs.
|
|
81
|
-
*
|
|
82
|
-
* @param dataset - Array of dataset examples
|
|
83
|
-
* @param sampleSize - Number of examples to hash (default: 10)
|
|
84
|
-
* @returns 16-character hex hash string
|
|
85
|
-
*/
|
|
86
|
-
export function computeDatasetHash(dataset, sampleSize = 10) {
    // Only the leading `sampleSize` examples take part in the fingerprint.
    // Each example's `prompt` is JSON-serialised and the pieces are glued
    // together with '|||' before hashing.
    const serializedPrompts = [];
    for (const example of dataset.slice(0, sampleSize)) {
        serializedPrompts.push(JSON.stringify(example.prompt));
    }
    const hasher = createHash('sha256');
    hasher.update(serializedPrompts.join('|||'));
    // Truncate the hex digest to its first 16 characters.
    return hasher.digest('hex').slice(0, 16);
}
|
|
92
|
-
// Note: RewardOutput is now built using the Rust buildRewardOutputs function
|
|
93
|
-
// which handles tool call parsing and thinking extraction natively.
|
|
94
|
-
/**
|
|
95
|
-
* Error thrown when a reward function times out.
|
|
96
|
-
*/
|
|
97
|
-
export class RewardTimeoutError extends Error {
    /** The timeout value (from the `timeoutMs` argument) this error was created with. */
    timeoutMs;
    constructor(message, timeoutMs) {
        super(message);
        // Tag the instance so it can be told apart from a generic Error by name.
        Object.assign(this, { timeoutMs, name: 'RewardTimeoutError' });
    }
}
|
|
105
|
-
/**
|
|
106
|
-
* Wraps a promise with a timeout.
|
|
107
|
-
*
|
|
108
|
-
* @param promise - The promise to wrap
|
|
109
|
-
* @param timeoutMs - Timeout in milliseconds (0 = no timeout)
|
|
110
|
-
* @param errorMessage - Error message if timeout is reached
|
|
111
|
-
* @returns The promise result or throws RewardTimeoutError
|
|
112
|
-
*/
|
|
113
|
-
function withTimeout(promise, timeoutMs, errorMessage) {
    // A zero, negative, missing or non-numeric budget disables the timeout.
    // `!(timeoutMs > 0)` is used instead of `timeoutMs <= 0` so that
    // undefined/NaN are also treated as "no timeout"; passed to setTimeout
    // they would be coerced to a ~1ms delay and fire almost immediately.
    if (!(timeoutMs > 0)) {
        return promise;
    }
    // Node clamps delays above 2^31-1 ms down to 1ms, so cap the timer delay
    // explicitly. The error still reports the budget the caller asked for.
    const MAX_TIMER_DELAY_MS = 2 ** 31 - 1;
    const delayMs = Math.min(timeoutMs, MAX_TIMER_DELAY_MS);
    return new Promise((resolve, reject) => {
        const timer = setTimeout(() => {
            reject(new RewardTimeoutError(errorMessage, timeoutMs));
        }, delayMs);
        // Promise.resolve() lets a synchronous (non-thenable) result pass
        // straight through instead of crashing on a missing `.then`.
        Promise.resolve(promise).then((result) => {
            clearTimeout(timer);
            resolve(result);
        }, (error) => {
            clearTimeout(timer);
            reject(error);
        });
    });
}
|
|
133
|
-
/**
|
|
134
|
-
* GRPO Trainer - Rust-Native Training Engine
|
|
135
|
-
*
|
|
136
|
-
* Provides a TypeScript-friendly interface to the Rust training engine.
|
|
137
|
-
* Supports both high-level training (train()) and low-level step-by-step (trainStep()).
|
|
138
|
-
*/
|
|
139
|
-
export class GRPOTrainer {
|
|
140
|
-
engine;
|
|
141
|
-
model;
|
|
142
|
-
config;
|
|
143
|
-
rewardFn;
|
|
144
|
-
currentEpoch = 0;
|
|
145
|
-
currentStep = 0;
|
|
146
|
-
/** Original model path (for tokenizer files when saving checkpoints) */
|
|
147
|
-
originalModelPath;
|
|
148
|
-
// TUI state
|
|
149
|
-
paused = false;
|
|
150
|
-
stopRequested = false;
|
|
151
|
-
stdinInterface;
|
|
152
|
-
logger;
|
|
153
|
-
sampleDisplayMode = 'all';
|
|
154
|
-
// Output recording
|
|
155
|
-
outputStore;
|
|
156
|
-
outputStoreInitPromise;
|
|
157
|
-
outputStoreRunId;
|
|
158
|
-
outputStorePath;
|
|
159
|
-
// Crash recovery
|
|
160
|
-
lastCheckpointStep = 0;
|
|
161
|
-
signalHandlersInstalled = false;
|
|
162
|
-
// Last known good checkpoint tracking (for NaN gradient recovery)
|
|
163
|
-
lastGoodCheckpointPath = null;
|
|
164
|
-
lastGoodCheckpointStep = 0;
|
|
165
|
-
// Dataset tracking for resume validation
|
|
166
|
-
datasetMetadata;
|
|
167
|
-
processedBatchIndices = new Set();
|
|
168
|
-
/**
|
|
169
|
-
* Create a new GRPO trainer from a model
|
|
170
|
-
*
|
|
171
|
-
* @param model - Pre-loaded Qwen3 model
|
|
172
|
-
* @param config - Training configuration
|
|
173
|
-
*/
|
|
174
|
-
constructor(model, config = {}, logger) {
|
|
175
|
-
// Auto-detect TUI mode from environment variable (set by mlx-train TUI)
|
|
176
|
-
const tuiModeFromEnv = process.env.MLX_TUI_MODE === '1';
|
|
177
|
-
if (tuiModeFromEnv && config.tuiMode === undefined) {
|
|
178
|
-
config.tuiMode = true;
|
|
179
|
-
}
|
|
180
|
-
// Auto-enable database persistence in TUI mode (enables Database tab)
|
|
181
|
-
if (tuiModeFromEnv && config.outputStore === undefined) {
|
|
182
|
-
config.outputStore = { enabled: true };
|
|
183
|
-
}
|
|
184
|
-
this.config = { ...DEFAULT_GRPO_CONFIG, ...config };
|
|
185
|
-
this.model = model;
|
|
186
|
-
// Create or use provided logger (TUI mode auto-detected from MLX_TUI_MODE env var)
|
|
187
|
-
this.logger =
|
|
188
|
-
logger ??
|
|
189
|
-
createTrainingLogger({
|
|
190
|
-
logConsole: this.config.logConsole,
|
|
191
|
-
logJsonl: this.config.logJsonl,
|
|
192
|
-
outputDir: this.config.outputDir,
|
|
193
|
-
runName: this.config.runName,
|
|
194
|
-
logInterval: this.config.logInterval ?? 1,
|
|
195
|
-
});
|
|
196
|
-
// Set reward function if provided
|
|
197
|
-
if (this.config.rewardFunction) {
|
|
198
|
-
this.rewardFn = this.config.rewardFunction;
|
|
199
|
-
}
|
|
200
|
-
// Convert to native config
|
|
201
|
-
const engineConfig = {
|
|
202
|
-
learningRate: this.config.learningRate,
|
|
203
|
-
gradientAccumulationSteps: this.config.gradientAccumulationSteps,
|
|
204
|
-
gradientClipNorm: this.config.gradientClipNorm,
|
|
205
|
-
groupSize: this.config.groupSize,
|
|
206
|
-
clipEpsilon: this.config.clipEpsilon,
|
|
207
|
-
klCoef: this.config.klCoef,
|
|
208
|
-
lossType: this.config.lossType,
|
|
209
|
-
maxCompletionLength: this.config.maxCompletionLength,
|
|
210
|
-
temperature: this.config.temperature,
|
|
211
|
-
topP: this.config.topP,
|
|
212
|
-
topK: this.config.topK,
|
|
213
|
-
repetitionPenalty: this.config.repetitionPenalty,
|
|
214
|
-
// Tool calling support
|
|
215
|
-
tools: this.config.tools,
|
|
216
|
-
enableThinking: this.config.enableThinking,
|
|
217
|
-
// Memory optimization
|
|
218
|
-
lmHeadChunkSize: this.config.lmHeadChunkSize,
|
|
219
|
-
forwardChunkSize: this.config.forwardChunkSize,
|
|
220
|
-
vocabChunkSize: this.config.vocabChunkSize,
|
|
221
|
-
// Parallel batch generation
|
|
222
|
-
useParallelBatchGeneration: this.config.useParallelBatchGeneration,
|
|
223
|
-
};
|
|
224
|
-
this.engine = new GrpoTrainingEngine(model, engineConfig);
|
|
225
|
-
// Setup stdin handler if TUI mode
|
|
226
|
-
if (this.config.tuiMode) {
|
|
227
|
-
this.setupStdinHandler();
|
|
228
|
-
}
|
|
229
|
-
// Always setup signal handlers for crash recovery
|
|
230
|
-
this.setupSignalHandlers();
|
|
231
|
-
}
|
|
232
|
-
/**
|
|
233
|
-
* Setup stdin handler for TUI control commands
|
|
234
|
-
*/
|
|
235
|
-
setupStdinHandler() {
|
|
236
|
-
if (!this.config.tuiMode)
|
|
237
|
-
return;
|
|
238
|
-
this.stdinInterface = readline.createInterface({
|
|
239
|
-
input: process.stdin,
|
|
240
|
-
output: process.stdout,
|
|
241
|
-
terminal: false,
|
|
242
|
-
});
|
|
243
|
-
this.stdinInterface.on('line', (line) => {
|
|
244
|
-
const cmd = line.trim();
|
|
245
|
-
this.handleStdinCommand(cmd);
|
|
246
|
-
});
|
|
247
|
-
}
|
|
248
|
-
/**
|
|
249
|
-
* Setup OS signal handlers for graceful shutdown on crash/interrupt
|
|
250
|
-
*
|
|
251
|
-
* Catches SIGTERM, SIGINT, and uncaught exceptions to:
|
|
252
|
-
* - Save emergency checkpoint (if > 10 steps since last)
|
|
253
|
-
* - Finalize OutputStore with 'crashed' status
|
|
254
|
-
* - Exit cleanly
|
|
255
|
-
*/
|
|
256
|
-
setupSignalHandlers() {
|
|
257
|
-
if (this.signalHandlersInstalled)
|
|
258
|
-
return;
|
|
259
|
-
this.signalHandlersInstalled = true;
|
|
260
|
-
const gracefulShutdown = async (signal) => {
|
|
261
|
-
this.logger.warn(`Received ${signal}, initiating graceful shutdown...`);
|
|
262
|
-
this.stopRequested = true;
|
|
263
|
-
try {
|
|
264
|
-
// Skip checkpoint if recent one exists (within 10 steps)
|
|
265
|
-
const stepsSinceCheckpoint = this.currentStep - this.lastCheckpointStep;
|
|
266
|
-
if (this.config.outputDir && stepsSinceCheckpoint > 10) {
|
|
267
|
-
this.logger.info(`Saving emergency checkpoint (${stepsSinceCheckpoint} steps since last)...`);
|
|
268
|
-
await this.saveCheckpoint(`emergency-${this.currentStep}`);
|
|
269
|
-
}
|
|
270
|
-
else if (stepsSinceCheckpoint <= 10) {
|
|
271
|
-
this.logger.info(`Skipping checkpoint (only ${stepsSinceCheckpoint} steps since last)`);
|
|
272
|
-
}
|
|
273
|
-
// Finalize OutputStore with crashed status
|
|
274
|
-
if (this.outputStore) {
|
|
275
|
-
await this.outputStore.endRun('crashed');
|
|
276
|
-
await this.outputStore.flush();
|
|
277
|
-
this.logger.info('OutputStore finalized with crashed status');
|
|
278
|
-
}
|
|
279
|
-
}
|
|
280
|
-
catch (e) {
|
|
281
|
-
console.error('Emergency save failed:', e);
|
|
282
|
-
}
|
|
283
|
-
// Cleanup stdin interface
|
|
284
|
-
if (this.stdinInterface) {
|
|
285
|
-
this.stdinInterface.close();
|
|
286
|
-
}
|
|
287
|
-
process.exit(0);
|
|
288
|
-
};
|
|
289
|
-
process.on('SIGTERM', () => {
|
|
290
|
-
gracefulShutdown('SIGTERM').catch(console.error);
|
|
291
|
-
});
|
|
292
|
-
process.on('SIGINT', () => {
|
|
293
|
-
gracefulShutdown('SIGINT').catch(console.error);
|
|
294
|
-
});
|
|
295
|
-
process.on('uncaughtException', (err) => {
|
|
296
|
-
console.error('Uncaught exception:', err);
|
|
297
|
-
gracefulShutdown('uncaughtException').catch(console.error);
|
|
298
|
-
});
|
|
299
|
-
process.on('unhandledRejection', (reason) => {
|
|
300
|
-
console.error('Unhandled rejection:', reason);
|
|
301
|
-
gracefulShutdown('unhandledRejection').catch(console.error);
|
|
302
|
-
});
|
|
303
|
-
}
|
|
304
|
-
/**
|
|
305
|
-
* Initialize the output store for recording training outputs
|
|
306
|
-
*/
|
|
307
|
-
async initOutputStore(stepsPerEpoch) {
|
|
308
|
-
// Guard against re-initialization (e.g., train() called after trainStepAuto())
|
|
309
|
-
if (this.outputStore)
|
|
310
|
-
return;
|
|
311
|
-
const cfg = this.config.outputStore;
|
|
312
|
-
if (!cfg?.enabled) {
|
|
313
|
-
// Even without outputStore, send UI resume state if resuming from checkpoint
|
|
314
|
-
if (this.config.resumeFromCheckpoint && this.currentStep > 0 && stepsPerEpoch) {
|
|
315
|
-
this.sendResumeStateUiOnly(stepsPerEpoch);
|
|
316
|
-
}
|
|
317
|
-
return;
|
|
318
|
-
}
|
|
319
|
-
const localPath = cfg.localPath ?? join(this.config.outputDir ?? '.', 'outputs.db');
|
|
320
|
-
this.outputStorePath = localPath;
|
|
321
|
-
// Ensure parent directory exists (for lazy init via trainStepAuto)
|
|
322
|
-
const parentDir = dirname(localPath);
|
|
323
|
-
if (parentDir !== '.' && !existsSync(parentDir)) {
|
|
324
|
-
mkdirSync(parentDir, { recursive: true });
|
|
325
|
-
}
|
|
326
|
-
this.outputStore = await OutputStore.local(localPath);
|
|
327
|
-
// If resuming from checkpoint AND we have a run name AND checkpoint was actually loaded,
|
|
328
|
-
// try to resume the existing database run. currentStep > 0 means checkpoint was loaded.
|
|
329
|
-
if (this.config.resumeFromCheckpoint && this.config.runName && this.currentStep > 0) {
|
|
330
|
-
const existingRun = await this.outputStore.findRunByName(this.config.runName);
|
|
331
|
-
if (existingRun) {
|
|
332
|
-
this.logger.info(`Resuming database run: ${this.config.runName} (${existingRun.id})`);
|
|
333
|
-
await this.outputStore.resumeRun(existingRun.id);
|
|
334
|
-
this.outputStoreRunId = existingRun.id;
|
|
335
|
-
// Clean up any database records that are ahead of the checkpoint step
|
|
336
|
-
// This prevents UNIQUE constraint errors when re-recording steps
|
|
337
|
-
// Uses cascade delete to also clean up orphaned generations, tool_calls, and logs
|
|
338
|
-
if (this.currentStep > 0) {
|
|
339
|
-
const cleanupStats = await this.outputStore.deleteAllAfterStep(existingRun.id, this.currentStep);
|
|
340
|
-
if (cleanupStats.stepsDeleted > 0) {
|
|
341
|
-
this.logger.info(`Cleaned up stale records after step ${this.currentStep}: ` +
|
|
342
|
-
`${cleanupStats.stepsDeleted} steps, ${cleanupStats.generationsDeleted} generations, ` +
|
|
343
|
-
`${cleanupStats.logsDeleted} logs`);
|
|
344
|
-
}
|
|
345
|
-
}
|
|
346
|
-
this.logger.databasePath(localPath, this.outputStoreRunId, this.config.runName ?? undefined);
|
|
347
|
-
// Send resume state to TUI for sparkline and aggregate restoration
|
|
348
|
-
await this.sendResumeState(existingRun.id, stepsPerEpoch);
|
|
349
|
-
return;
|
|
350
|
-
}
|
|
351
|
-
else {
|
|
352
|
-
// Run name specified but not found - warn and create new
|
|
353
|
-
this.logger.warn(`No existing run found with name: ${this.config.runName}. Starting new run.`);
|
|
354
|
-
}
|
|
355
|
-
}
|
|
356
|
-
// Start a new run with sanitized config (no auth token)
|
|
357
|
-
const modelName = this.config.modelName ?? 'qwen3';
|
|
358
|
-
const modelPath = this.originalModelPath ?? this.config.modelPath ?? undefined;
|
|
359
|
-
const sanitizedConfig = {
|
|
360
|
-
...this.config,
|
|
361
|
-
outputStore: this.config.outputStore ? { ...this.config.outputStore, authToken: undefined } : undefined,
|
|
362
|
-
};
|
|
363
|
-
// Use startRunWithName if a run name is provided, otherwise use startRun
|
|
364
|
-
if (this.config.runName) {
|
|
365
|
-
this.outputStoreRunId = await this.outputStore.startRunWithName(this.config.runName, modelName, modelPath, JSON.stringify(sanitizedConfig));
|
|
366
|
-
}
|
|
367
|
-
else {
|
|
368
|
-
this.outputStoreRunId = await this.outputStore.startRun(modelName, modelPath, JSON.stringify(sanitizedConfig));
|
|
369
|
-
}
|
|
370
|
-
// Emit database path to TUI for the Database tab
|
|
371
|
-
this.logger.databasePath(localPath, this.outputStoreRunId, this.config.runName ?? undefined);
|
|
372
|
-
// If resuming from checkpoint but didn't find existing DB run, still send UI state
|
|
373
|
-
// This ensures TUI displays correct batch progress even without historical data
|
|
374
|
-
if (this.config.resumeFromCheckpoint && this.currentStep > 0 && stepsPerEpoch) {
|
|
375
|
-
this.sendResumeStateUiOnly(stepsPerEpoch);
|
|
376
|
-
}
|
|
377
|
-
}
|
|
378
|
-
/**
|
|
379
|
-
* Send minimal resume state to TUI for UI display only (no historical data)
|
|
380
|
-
*
|
|
381
|
-
* Used when resuming from checkpoint without a matching database run.
|
|
382
|
-
* Ensures TUI shows correct epoch/batch progress.
|
|
383
|
-
*/
|
|
384
|
-
sendResumeStateUiOnly(stepsPerEpoch) {
|
|
385
|
-
if (!this.logger.isTuiMode)
|
|
386
|
-
return;
|
|
387
|
-
const totalEpochs = this.config.numEpochs ?? 1;
|
|
388
|
-
const stepInEpoch = this.currentStep > 0 ? ((this.currentStep - 1) % stepsPerEpoch) + 1 : 0;
|
|
389
|
-
this.logger.resumeState({
|
|
390
|
-
step: this.currentStep,
|
|
391
|
-
epoch: this.currentEpoch + 1, // 1-indexed
|
|
392
|
-
totalEpochs,
|
|
393
|
-
stepInEpoch,
|
|
394
|
-
totalStepsInEpoch: stepsPerEpoch,
|
|
395
|
-
metricsHistory: [], // No historical data
|
|
396
|
-
aggregates: {
|
|
397
|
-
bestReward: 0,
|
|
398
|
-
avgReward: 0,
|
|
399
|
-
rewardCount: 0,
|
|
400
|
-
bestLoss: Infinity,
|
|
401
|
-
avgLoss: 0,
|
|
402
|
-
lossCount: 0,
|
|
403
|
-
totalTokens: 0,
|
|
404
|
-
avgGenerationTimeMs: 0,
|
|
405
|
-
avgTrainingTimeMs: 0,
|
|
406
|
-
},
|
|
407
|
-
});
|
|
408
|
-
this.logger.info(`Sent UI resume state to TUI (step ${this.currentStep}, no historical data)`);
|
|
409
|
-
}
|
|
410
|
-
/**
|
|
411
|
-
* Send resume state to TUI for restoring sparklines and aggregates
|
|
412
|
-
*
|
|
413
|
-
* Queries the database for historical metrics and aggregates, then sends
|
|
414
|
-
* to TUI via the resumeState message.
|
|
415
|
-
*
|
|
416
|
-
* @param runId - Database run ID
|
|
417
|
-
* @param actualStepsPerEpoch - Actual steps per epoch from dataset (if known)
|
|
418
|
-
*/
|
|
419
|
-
async sendResumeState(runId, actualStepsPerEpoch) {
|
|
420
|
-
if (!this.outputStore || !this.logger.isTuiMode)
|
|
421
|
-
return;
|
|
422
|
-
try {
|
|
423
|
-
// Query historical metrics (last 60 for sparklines)
|
|
424
|
-
const metricsHistory = await this.outputStore.getRecentStepMetrics(runId, 60);
|
|
425
|
-
// Query aggregate statistics
|
|
426
|
-
const aggregates = await this.outputStore.getRunAggregates(runId);
|
|
427
|
-
// Use actual steps per epoch if provided, otherwise use a reasonable default
|
|
428
|
-
const totalEpochs = this.config.numEpochs ?? 1;
|
|
429
|
-
const stepsPerEpoch = actualStepsPerEpoch ?? 50;
|
|
430
|
-
// Calculate step within current epoch
|
|
431
|
-
const stepInEpoch = this.currentStep > 0 ? ((this.currentStep - 1) % stepsPerEpoch) + 1 : 0;
|
|
432
|
-
// Send resume state to TUI
|
|
433
|
-
// Note: epoch is 1-indexed to match epochStart() convention
|
|
434
|
-
this.logger.resumeState({
|
|
435
|
-
step: this.currentStep,
|
|
436
|
-
epoch: this.currentEpoch + 1,
|
|
437
|
-
totalEpochs,
|
|
438
|
-
stepInEpoch,
|
|
439
|
-
totalStepsInEpoch: stepsPerEpoch,
|
|
440
|
-
metricsHistory: metricsHistory.map((m) => ({
|
|
441
|
-
step: Number(m.step),
|
|
442
|
-
loss: m.loss,
|
|
443
|
-
meanReward: m.meanReward,
|
|
444
|
-
stdAdvantage: m.stdAdvantage,
|
|
445
|
-
perplexity: m.perplexity ?? undefined,
|
|
446
|
-
tokenAccuracy: m.tokenAccuracy ?? undefined,
|
|
447
|
-
generationTimeMs: m.generationTimeMs ?? undefined,
|
|
448
|
-
trainingTimeMs: m.trainingTimeMs ?? undefined,
|
|
449
|
-
})),
|
|
450
|
-
aggregates: {
|
|
451
|
-
bestReward: aggregates.bestReward,
|
|
452
|
-
avgReward: aggregates.avgReward,
|
|
453
|
-
rewardCount: Number(aggregates.rewardCount),
|
|
454
|
-
bestLoss: aggregates.bestLoss,
|
|
455
|
-
avgLoss: aggregates.avgLoss,
|
|
456
|
-
lossCount: Number(aggregates.lossCount),
|
|
457
|
-
totalTokens: Number(aggregates.totalTokens),
|
|
458
|
-
avgGenerationTimeMs: aggregates.avgGenerationTimeMs,
|
|
459
|
-
avgTrainingTimeMs: aggregates.avgTrainingTimeMs,
|
|
460
|
-
},
|
|
461
|
-
});
|
|
462
|
-
this.logger.info(`Sent ${metricsHistory.length} historical metrics to TUI`);
|
|
463
|
-
}
|
|
464
|
-
catch (err) {
|
|
465
|
-
this.logger.warn(`Failed to send resume state to TUI: ${err}`);
|
|
466
|
-
}
|
|
467
|
-
}
|
|
468
|
-
/**
|
|
469
|
-
* Ensure output store is initialized (lazy initialization for low-level API users)
|
|
470
|
-
* Uses promise mutex to prevent race conditions from concurrent calls.
|
|
471
|
-
*
|
|
472
|
-
* Call this method from custom training loops before starting training
|
|
473
|
-
* to enable database recording and TUI database tab.
|
|
474
|
-
*/
|
|
475
|
-
async ensureOutputStoreInitialized() {
|
|
476
|
-
if (this.outputStore)
|
|
477
|
-
return; // Already initialized
|
|
478
|
-
if (!this.config.outputStore?.enabled)
|
|
479
|
-
return; // Not enabled
|
|
480
|
-
// Use promise mutex to prevent concurrent initialization
|
|
481
|
-
if (this.outputStoreInitPromise) {
|
|
482
|
-
await this.outputStoreInitPromise;
|
|
483
|
-
return;
|
|
484
|
-
}
|
|
485
|
-
this.outputStoreInitPromise = this.initOutputStore();
|
|
486
|
-
try {
|
|
487
|
-
await this.outputStoreInitPromise;
|
|
488
|
-
}
|
|
489
|
-
catch (err) {
|
|
490
|
-
// Clear promise on failure to allow retry
|
|
491
|
-
this.outputStoreInitPromise = undefined;
|
|
492
|
-
throw err;
|
|
493
|
-
}
|
|
494
|
-
}
|
|
495
|
-
/**
|
|
496
|
-
* Get the output store (for querying recorded data)
|
|
497
|
-
*/
|
|
498
|
-
getOutputStore() {
|
|
499
|
-
return this.outputStore;
|
|
500
|
-
}
|
|
501
|
-
/**
|
|
502
|
-
* Handle a command received from stdin
|
|
503
|
-
*/
|
|
504
|
-
handleStdinCommand(cmd) {
|
|
505
|
-
switch (cmd) {
|
|
506
|
-
case 'PAUSE':
|
|
507
|
-
this.paused = true;
|
|
508
|
-
this.logger.paused(this.currentStep);
|
|
509
|
-
break;
|
|
510
|
-
case 'RESUME':
|
|
511
|
-
this.paused = false;
|
|
512
|
-
this.logger.resumed(this.currentStep);
|
|
513
|
-
break;
|
|
514
|
-
case 'SAVE_CHECKPOINT':
|
|
515
|
-
// Will be handled in the training loop
|
|
516
|
-
this.saveCheckpoint().catch(() => { });
|
|
517
|
-
break;
|
|
518
|
-
case 'STOP':
|
|
519
|
-
this.stopRequested = true;
|
|
520
|
-
break;
|
|
521
|
-
default:
|
|
522
|
-
// Handle SET commands (e.g., SET sample_display=best_worst)
|
|
523
|
-
if (cmd.startsWith('SET ')) {
|
|
524
|
-
const keyValue = cmd.slice(4); // Remove 'SET ' prefix
|
|
525
|
-
const eqIdx = keyValue.indexOf('=');
|
|
526
|
-
if (eqIdx > 0) {
|
|
527
|
-
const key = keyValue.slice(0, eqIdx);
|
|
528
|
-
const value = keyValue.slice(eqIdx + 1);
|
|
529
|
-
if (key === 'sample_display') {
|
|
530
|
-
if (value === 'all' || value === 'best_worst' || value === 'random') {
|
|
531
|
-
this.sampleDisplayMode = value;
|
|
532
|
-
}
|
|
533
|
-
}
|
|
534
|
-
}
|
|
535
|
-
}
|
|
536
|
-
break;
|
|
537
|
-
}
|
|
538
|
-
}
|
|
539
|
-
/**
|
|
540
|
-
* Wait for resume if paused, with polling
|
|
541
|
-
*/
|
|
542
|
-
async waitForResume() {
|
|
543
|
-
while (this.paused && !this.stopRequested) {
|
|
544
|
-
await new Promise((resolve) => setTimeout(resolve, 100));
|
|
545
|
-
}
|
|
546
|
-
}
|
|
547
|
-
/**
|
|
548
|
-
* Create a trainer by loading a model from disk
|
|
549
|
-
*
|
|
550
|
-
* This is the recommended way to create a trainer for training runs.
|
|
551
|
-
* If resumeFromCheckpoint is set, loads from checkpoint instead of modelPath.
|
|
552
|
-
*
|
|
553
|
-
* @param config - Configuration including modelPath
|
|
554
|
-
* @returns Promise<GRPOTrainer>
|
|
555
|
-
*/
|
|
556
|
-
static async create(config) {
|
|
557
|
-
if (!config.modelPath) {
|
|
558
|
-
throw new Error('modelPath is required when using GRPOTrainer.create()');
|
|
559
|
-
}
|
|
560
|
-
// Validate unsupported config options (fail fast)
|
|
561
|
-
if (config.advantageNormalization === false) {
|
|
562
|
-
throw new Error('advantageNormalization=false is not yet supported. Remove this option or set to true.');
|
|
563
|
-
}
|
|
564
|
-
if (config.weightDecay !== undefined && config.weightDecay !== 0.01) {
|
|
565
|
-
throw new Error('Custom weightDecay is not yet implemented. Optimizer uses simple SGD. Remove weightDecay from config or use default (0.01).');
|
|
566
|
-
}
|
|
567
|
-
if (config.rewardType === 'model') {
|
|
568
|
-
throw new Error('rewardType="model" is not implemented. Use rewardType="function" with a custom reward function.');
|
|
569
|
-
}
|
|
570
|
-
if (config.rewardModelPath) {
|
|
571
|
-
throw new Error('rewardModelPath is not implemented. Use rewardType="function" with a custom reward function.');
|
|
572
|
-
}
|
|
573
|
-
if (config.device && config.device !== 'metal') {
|
|
574
|
-
throw new Error(`device="${config.device}" is not supported. MLX only supports Metal GPU. Remove device from config.`);
|
|
575
|
-
}
|
|
576
|
-
// Create logger early (before model loading)
|
|
577
|
-
// TUI mode is auto-detected from MLX_TUI_MODE env var (set by mlx-tui)
|
|
578
|
-
const logger = createTrainingLogger({
|
|
579
|
-
logConsole: config.logConsole,
|
|
580
|
-
logJsonl: config.logJsonl,
|
|
581
|
-
outputDir: config.outputDir,
|
|
582
|
-
runName: config.runName,
|
|
583
|
-
logInterval: config.logInterval ?? 1,
|
|
584
|
-
});
|
|
585
|
-
let modelPath = config.modelPath;
|
|
586
|
-
let resumedState = null;
|
|
587
|
-
// Handle checkpoint resumption
|
|
588
|
-
if (config.resumeFromCheckpoint) {
|
|
589
|
-
const checkpointPath = config.resumeFromCheckpoint === 'latest'
|
|
590
|
-
? GRPOTrainer.findLatestCheckpoint(config.outputDir)
|
|
591
|
-
: config.resumeFromCheckpoint;
|
|
592
|
-
if (checkpointPath) {
|
|
593
|
-
const statePath = join(checkpointPath, 'training_state.json');
|
|
594
|
-
if (existsSync(statePath)) {
|
|
595
|
-
resumedState = JSON.parse(readFileSync(statePath, 'utf-8'));
|
|
596
|
-
// Fallback: If training_state.json has step 0 but checkpoint name suggests otherwise,
|
|
597
|
-
// derive step from checkpoint name (e.g., checkpoint-8 → step 8)
|
|
598
|
-
// This handles cases where training_state.json was corrupted or overwritten
|
|
599
|
-
if (resumedState && resumedState.step === 0) {
|
|
600
|
-
const checkpointName = checkpointPath.split('/').pop() ?? '';
|
|
601
|
-
const match = checkpointName.match(/^checkpoint-(\d+)$/);
|
|
602
|
-
if (match) {
|
|
603
|
-
const derivedStep = parseInt(match[1], 10);
|
|
604
|
-
if (derivedStep > 0) {
|
|
605
|
-
logger.warn(`Checkpoint ${checkpointName} has step 0 in training_state.json but name suggests step ${derivedStep}. Using ${derivedStep}.`);
|
|
606
|
-
resumedState.step = derivedStep;
|
|
607
|
-
// Estimate epoch from step (will be refined by actual training data)
|
|
608
|
-
}
|
|
609
|
-
}
|
|
610
|
-
}
|
|
611
|
-
logger.info(`Resuming from checkpoint: ${checkpointPath} (step ${resumedState?.step}, epoch ${resumedState?.epoch})`);
|
|
612
|
-
}
|
|
613
|
-
// Load model weights from checkpoint
|
|
614
|
-
modelPath = checkpointPath;
|
|
615
|
-
}
|
|
616
|
-
else if (config.resumeFromCheckpoint === 'latest') {
|
|
617
|
-
logger.info('No checkpoint found, starting fresh training');
|
|
618
|
-
}
|
|
619
|
-
}
|
|
620
|
-
// Get model name for display
|
|
621
|
-
const modelName = modelPath.split('/').pop() ?? 'Unknown';
|
|
622
|
-
logger.status('loading', `Loading ${modelName}...`);
|
|
623
|
-
// Load model from disk (checkpoint or original)
|
|
624
|
-
const model = await Qwen3Model.loadPretrained(modelPath);
|
|
625
|
-
logger.status('loading', `${modelName} loaded`);
|
|
626
|
-
// Create trainer with the pre-created logger
|
|
627
|
-
const trainer = new GRPOTrainer(model, config, logger);
|
|
628
|
-
// Always store the original model path (for tokenizer files when saving checkpoints)
|
|
629
|
-
trainer.originalModelPath = config.modelPath;
|
|
630
|
-
// Restore training state if resuming
|
|
631
|
-
if (resumedState) {
|
|
632
|
-
trainer.currentStep = resumedState.step;
|
|
633
|
-
trainer.currentEpoch = resumedState.epoch;
|
|
634
|
-
// Restore dataset metadata for resume validation
|
|
635
|
-
if (resumedState.dataset) {
|
|
636
|
-
trainer.datasetMetadata = {
|
|
637
|
-
size: resumedState.dataset.size,
|
|
638
|
-
contentHash: resumedState.dataset.contentHash,
|
|
639
|
-
shuffleSeed: resumedState.dataset.shuffleSeed,
|
|
640
|
-
};
|
|
641
|
-
// Restore processed batch indices
|
|
642
|
-
if (resumedState.dataset.processedBatchIndices) {
|
|
643
|
-
trainer.processedBatchIndices = new Set(resumedState.dataset.processedBatchIndices);
|
|
644
|
-
}
|
|
645
|
-
logger.debug(`Restored dataset metadata: size=${resumedState.dataset.size}, hash=${resumedState.dataset.contentHash}, ` +
|
|
646
|
-
`${trainer.processedBatchIndices.size} processed batches`);
|
|
647
|
-
}
|
|
648
|
-
// If resuming from a regular checkpoint (not emergency), track it as last known good
|
|
649
|
-
// This allows recovery to fall back to this checkpoint if NaN gradients occur
|
|
650
|
-
if (modelPath && !modelPath.includes('emergency-')) {
|
|
651
|
-
trainer.lastGoodCheckpointPath = modelPath;
|
|
652
|
-
trainer.lastGoodCheckpointStep = resumedState.step;
|
|
653
|
-
logger.debug(`Initialized last good checkpoint from resumed state: step ${resumedState.step}`);
|
|
654
|
-
}
|
|
655
|
-
}
|
|
656
|
-
return trainer;
|
|
657
|
-
}
|
|
658
|
-
/**
|
|
659
|
-
* Find the latest checkpoint in the output directory
|
|
660
|
-
*/
|
|
661
|
-
static findLatestCheckpoint(outputDir) {
|
|
662
|
-
if (!outputDir || !existsSync(outputDir)) {
|
|
663
|
-
return null;
|
|
664
|
-
}
|
|
665
|
-
const entries = readdirSync(outputDir, { withFileTypes: true });
|
|
666
|
-
const checkpoints = entries
|
|
667
|
-
.filter((e) => e.isDirectory() && e.name.startsWith('checkpoint-'))
|
|
668
|
-
.map((e) => ({
|
|
669
|
-
name: e.name,
|
|
670
|
-
step: parseInt(e.name.replace('checkpoint-', ''), 10),
|
|
671
|
-
path: join(outputDir, e.name),
|
|
672
|
-
}))
|
|
673
|
-
.filter((c) => !isNaN(c.step))
|
|
674
|
-
.sort((a, b) => b.step - a.step);
|
|
675
|
-
return checkpoints.length > 0 ? checkpoints[0].path : null;
|
|
676
|
-
}
|
|
677
|
-
/**
|
|
678
|
-
* Register a built-in reward function
|
|
679
|
-
*
|
|
680
|
-
* Built-in rewards run entirely in Rust with no FFI overhead.
|
|
681
|
-
*
|
|
682
|
-
* @example
|
|
683
|
-
* ```typescript
|
|
684
|
-
* // Tool use validation
|
|
685
|
-
* trainer.registerBuiltinReward({
|
|
686
|
-
* rewardType: 'ToolUse',
|
|
687
|
-
* allowedTools: ['search', 'calculate'],
|
|
688
|
-
* required: true,
|
|
689
|
-
* weight: 1.0,
|
|
690
|
-
* });
|
|
691
|
-
*
|
|
692
|
-
* // XML format validation
|
|
693
|
-
* trainer.registerBuiltinReward({
|
|
694
|
-
* rewardType: 'XmlFormat',
|
|
695
|
-
* requiredTags: ['thinking', 'answer'],
|
|
696
|
-
* weight: 0.5,
|
|
697
|
-
* });
|
|
698
|
-
*
|
|
699
|
-
* // Length-based reward
|
|
700
|
-
* trainer.registerBuiltinReward({
|
|
701
|
-
* rewardType: 'Length',
|
|
702
|
-
* minLength: 100,
|
|
703
|
-
* maxLength: 500,
|
|
704
|
-
* useChars: true,
|
|
705
|
-
* });
|
|
706
|
-
* ```
|
|
707
|
-
*/
|
|
708
|
-
registerBuiltinReward(config) {
|
|
709
|
-
this.engine.registerBuiltinReward(config);
|
|
710
|
-
}
|
|
711
|
-
/**
|
|
712
|
-
* Set a custom JavaScript reward function
|
|
713
|
-
*
|
|
714
|
-
* The function will be called after generation to compute rewards.
|
|
715
|
-
*
|
|
716
|
-
* @param fn - Reward function that takes prompts and completions
|
|
717
|
-
*/
|
|
718
|
-
setRewardFunction(fn) {
|
|
719
|
-
this.rewardFn = fn;
|
|
720
|
-
}
|
|
721
|
-
/**
|
|
722
|
-
* Generate completions for prompts
|
|
723
|
-
*
|
|
724
|
-
* Generates `groupSize` completions per prompt.
|
|
725
|
-
* Returns all data needed for training, including tokens and log probabilities.
|
|
726
|
-
*
|
|
727
|
-
* @param prompts - Array of chat conversations
|
|
728
|
-
* @returns GenerateBatchResult with completion texts and native generation data
|
|
729
|
-
*/
|
|
730
|
-
async generateBatch(prompts) {
|
|
731
|
-
if (prompts.length === 0) {
|
|
732
|
-
return {
|
|
733
|
-
completionTexts: [],
|
|
734
|
-
nativeResult: {
|
|
735
|
-
completionTexts: [],
|
|
736
|
-
completionTokens: [],
|
|
737
|
-
completionLogprobs: [],
|
|
738
|
-
completionLengths: [],
|
|
739
|
-
finishReasons: [],
|
|
740
|
-
},
|
|
741
|
-
tokenCounts: [],
|
|
742
|
-
finishReasons: [],
|
|
743
|
-
};
|
|
744
|
-
}
|
|
745
|
-
// Call the native engine to generate completions with full data
|
|
746
|
-
const nativeResult = await this.engine.generateBatchForTraining(prompts);
|
|
747
|
-
return {
|
|
748
|
-
completionTexts: nativeResult.completionTexts,
|
|
749
|
-
nativeResult,
|
|
750
|
-
tokenCounts: nativeResult.completionLengths,
|
|
751
|
-
finishReasons: nativeResult.finishReasons,
|
|
752
|
-
};
|
|
753
|
-
}
|
|
754
|
-
/**
|
|
755
|
-
* Score completions using built-in rewards
|
|
756
|
-
*
|
|
757
|
-
* @param prompts - Prompt texts (one per completion)
|
|
758
|
-
* @param completions - Completion texts
|
|
759
|
-
* @returns Array of reward scores
|
|
760
|
-
*/
|
|
761
|
-
scoreCompletions(prompts, completions) {
|
|
762
|
-
return this.engine.scoreCompletions(prompts, completions);
|
|
763
|
-
}
|
|
764
|
-
/**
|
|
765
|
-
* Score generations using the configured reward function.
|
|
766
|
-
*
|
|
767
|
-
* Builds RewardOutput array with structured completion data and passes to reward function.
|
|
768
|
-
*
|
|
769
|
-
* @param prompts - Array of chat conversations
|
|
770
|
-
* @param completions - Generated completion texts
|
|
771
|
-
* @param context - Context for the reward function
|
|
772
|
-
* @param groupSize - Number of completions per prompt (optional, defaults to config.groupSize)
|
|
773
|
-
* @param tokenCounts - Token counts for each completion (optional, defaults to 0s)
|
|
774
|
-
* @param finishReasons - Finish reasons from generation (optional, e.g. "eos", "length", "repetition")
|
|
775
|
-
* @returns Promise<Float32Array> of reward scores
|
|
776
|
-
*/
|
|
777
|
-
async scoreGenerations(prompts, completions, context, groupSize, tokenCounts, finishReasons) {
|
|
778
|
-
const effectiveGroupSize = groupSize ?? this.config.groupSize ?? 4;
|
|
779
|
-
const expectedCompletions = prompts.length * effectiveGroupSize;
|
|
780
|
-
if (completions.length !== expectedCompletions) {
|
|
781
|
-
throw new Error(`Expected ${expectedCompletions} completions (${prompts.length} prompts × ${effectiveGroupSize} groupSize) but got ${completions.length}`);
|
|
782
|
-
}
|
|
783
|
-
if (!this.rewardFn && !this.engine.hasBuiltinRewards) {
|
|
784
|
-
throw new Error('No reward function configured. Set rewardFunction in config or call setRewardFunction()');
|
|
785
|
-
}
|
|
786
|
-
// Convert ChatMessage[][] to string[] for Rust function
|
|
787
|
-
const promptTexts = prompts.map((msgs) => msgs.map((m) => `${m.role}: ${m.content}`).join('\n'));
|
|
788
|
-
// Use provided token counts or default to 0
|
|
789
|
-
const effectiveTokenCounts = tokenCounts ?? completions.map(() => 0);
|
|
790
|
-
// Use provided finish reasons or default to empty (triggers inference fallback in Rust)
|
|
791
|
-
const effectiveFinishReasons = finishReasons ?? [];
|
|
792
|
-
// Build structured reward outputs using Rust function
|
|
793
|
-
const rewardOutputs = buildRewardOutputs(promptTexts, completions, effectiveTokenCounts, effectiveFinishReasons, effectiveGroupSize);
|
|
794
|
-
let rewards;
|
|
795
|
-
if (this.rewardFn) {
|
|
796
|
-
// Get timeout from config (default 60 seconds, 0 = no timeout)
|
|
797
|
-
const rewardTimeout = this.config.rewardTimeout ?? 60_000;
|
|
798
|
-
// Wrap reward function call with timeout
|
|
799
|
-
const rewardPromise = Promise.resolve(this.rewardFn(rewardOutputs, context));
|
|
800
|
-
rewards = await withTimeout(rewardPromise, rewardTimeout, `Reward function timed out after ${rewardTimeout}ms. ` +
|
|
801
|
-
`Consider increasing rewardTimeout in config or optimizing your reward function.`);
|
|
802
|
-
}
|
|
803
|
-
else {
|
|
804
|
-
// For built-in rewards, extract prompts and completions for legacy API
|
|
805
|
-
const promptStrings = rewardOutputs.map((o) => o.prompt);
|
|
806
|
-
const completionTexts = rewardOutputs.map((o) => o.completion.rawText);
|
|
807
|
-
rewards = this.scoreCompletions(promptStrings, completionTexts);
|
|
808
|
-
}
|
|
809
|
-
const rewardsArray = rewards instanceof Float32Array ? rewards : Float32Array.from(rewards);
|
|
810
|
-
if (rewardsArray.length !== expectedCompletions) {
|
|
811
|
-
throw new Error(`Reward function returned ${rewardsArray.length} rewards but expected ${expectedCompletions}`);
|
|
812
|
-
}
|
|
813
|
-
return rewardsArray;
|
|
814
|
-
}
|
|
815
|
-
/**
|
|
816
|
-
* Run a training step
|
|
817
|
-
*
|
|
818
|
-
* This method:
|
|
819
|
-
* 1. Generates completions with tokens and log probabilities
|
|
820
|
-
* 2. Computes rewards using the configured reward function
|
|
821
|
-
* 3. Trains using the SAME completions that were scored (no double-generation)
|
|
822
|
-
*
|
|
823
|
-
* @param prompts - Array of chat conversations
|
|
824
|
-
* @returns Training step metrics
|
|
825
|
-
*/
|
|
826
|
-
async trainStep(prompts, context) {
|
|
827
|
-
const { metrics } = await this.trainStepAuto(prompts, context);
|
|
828
|
-
return metrics;
|
|
829
|
-
}
|
|
830
|
-
/**
|
|
831
|
-
* Run a complete training step with automatic reward computation
|
|
832
|
-
*
|
|
833
|
-
* This method combines generation, reward scoring, and training into a single
|
|
834
|
-
* Rust call, eliminating FFI overhead by keeping token data in Rust memory.
|
|
835
|
-
*
|
|
836
|
-
* 1. Generates completions with full token/logprob data (stays in Rust)
|
|
837
|
-
* 2. Calls JS reward function with RewardOutput[]
|
|
838
|
-
* 3. Performs training update using the in-memory data
|
|
839
|
-
*
|
|
840
|
-
* @param prompts - Array of chat conversations
|
|
841
|
-
* @param context - Context for the reward function
|
|
842
|
-
* @returns Training metrics and generated completions
|
|
843
|
-
*/
|
|
844
|
-
async trainStepAuto(prompts, context) {
|
|
845
|
-
// Lazy initialize output store for low-level API users
|
|
846
|
-
await this.ensureOutputStoreInitialized();
|
|
847
|
-
if (!this.rewardFn && !this.engine.hasBuiltinRewards) {
|
|
848
|
-
throw new Error('No reward function configured. Set rewardFunction in config or call setRewardFunction()');
|
|
849
|
-
}
|
|
850
|
-
// Create reward callback that parses JSON and converts output to number[]
|
|
851
|
-
// The Rust side serializes Vec<RewardOutput> to JSON because complex nested types
|
|
852
|
-
// don't convert properly through ThreadsafeFunction
|
|
853
|
-
// Note: With CalleeHandled=true (default), callback receives (err, value) format
|
|
854
|
-
// Using ThreadsafeFunction<T, Promise<R>> pattern so Rust can await the Promise
|
|
855
|
-
const rewardCallback = async (err, outputsJson) => {
|
|
856
|
-
const rewardStart = Date.now();
|
|
857
|
-
if (err) {
|
|
858
|
-
throw new Error(`Reward callback error from Rust: ${err.message}`);
|
|
859
|
-
}
|
|
860
|
-
if (!outputsJson || outputsJson === 'null') {
|
|
861
|
-
throw new Error(`Invalid JSON received from Rust: ${outputsJson}`);
|
|
862
|
-
}
|
|
863
|
-
// Parse JSON and convert snake_case to camelCase for TypeScript compatibility
|
|
864
|
-
// Rust's serde serializes as snake_case but TypeScript expects camelCase
|
|
865
|
-
const rawOutputs = JSON.parse(outputsJson);
|
|
866
|
-
// Convert to RewardOutput format with proper camelCase properties
|
|
867
|
-
const outputs = rawOutputs.map((o) => ({
|
|
868
|
-
prompt: o.prompt,
|
|
869
|
-
completion: {
|
|
870
|
-
text: o.completion.text,
|
|
871
|
-
rawText: o.completion.raw_text,
|
|
872
|
-
toolCalls: o.completion.tool_calls.map((tc) => ({
|
|
873
|
-
id: tc.id,
|
|
874
|
-
name: tc.name,
|
|
875
|
-
arguments: tc.arguments,
|
|
876
|
-
status: tc.status, // 'ok' | 'invalid_json' | 'missing_name'
|
|
877
|
-
error: tc.error,
|
|
878
|
-
rawContent: tc.raw_content,
|
|
879
|
-
})),
|
|
880
|
-
thinking: o.completion.thinking ?? undefined,
|
|
881
|
-
numTokens: o.completion.num_tokens,
|
|
882
|
-
finishReason: o.completion.finish_reason,
|
|
883
|
-
},
|
|
884
|
-
expectedAnswer: o.expected_answer ?? undefined,
|
|
885
|
-
}));
|
|
886
|
-
this.logger.info(` → Computing rewards for ${outputs.length} completions...`);
|
|
887
|
-
let rewards;
|
|
888
|
-
if (this.rewardFn) {
|
|
889
|
-
// Get timeout from config (default 60 seconds, 0 = no timeout)
|
|
890
|
-
const rewardTimeout = this.config.rewardTimeout ?? 60_000;
|
|
891
|
-
// Wrap reward function call with timeout
|
|
892
|
-
const rewardPromise = Promise.resolve(
|
|
893
|
-
// @ts-expect-error context is optional
|
|
894
|
-
this.rewardFn(outputs, context));
|
|
895
|
-
rewards = await withTimeout(rewardPromise, rewardTimeout, `Reward function timed out after ${rewardTimeout}ms. ` +
|
|
896
|
-
`Consider increasing rewardTimeout in config or optimizing your reward function.`);
|
|
897
|
-
}
|
|
898
|
-
else {
|
|
899
|
-
// Use built-in rewards
|
|
900
|
-
const promptStrings = outputs.map((o) => o.prompt);
|
|
901
|
-
const completionTexts = outputs.map((o) => o.completion.rawText);
|
|
902
|
-
rewards = this.scoreCompletions(promptStrings, completionTexts);
|
|
903
|
-
}
|
|
904
|
-
// Convert Float32Array to plain number[] for NAPI compatibility
|
|
905
|
-
let result;
|
|
906
|
-
if (rewards instanceof Float32Array) {
|
|
907
|
-
result = Array.from(rewards, (v) => Number(v));
|
|
908
|
-
}
|
|
909
|
-
else {
|
|
910
|
-
result = rewards.map((v) => Number(v));
|
|
911
|
-
}
|
|
912
|
-
const rewardDuration = Date.now() - rewardStart;
|
|
913
|
-
const avgReward = result.reduce((a, b) => a + b, 0) / result.length;
|
|
914
|
-
this.logger.info(` → Rewards computed in ${rewardDuration}ms (avg=${avgReward.toFixed(2)})`);
|
|
915
|
-
return result;
|
|
916
|
-
};
|
|
917
|
-
// Call unified Rust method - generation, scoring, and training in one FFI call
|
|
918
|
-
// Use recording method if output store is enabled
|
|
919
|
-
const recordOutputs = !!this.outputStore;
|
|
920
|
-
const result = await this.engine.trainStepAuto(prompts, rewardCallback, recordOutputs);
|
|
921
|
-
this.currentStep++;
|
|
922
|
-
// Record outputs to database if enabled
|
|
923
|
-
if (this.outputStore && result.outputsJson) {
|
|
924
|
-
try {
|
|
925
|
-
await this.outputStore.recordStepFromOutputs(this.currentStep, result.metrics, result.outputsJson, result.rewards, this.config.groupSize ?? 4);
|
|
926
|
-
}
|
|
927
|
-
catch (err) {
|
|
928
|
-
// Log error but don't fail training
|
|
929
|
-
console.error('[OutputStore] Failed to record step:', err);
|
|
930
|
-
}
|
|
931
|
-
}
|
|
932
|
-
return {
|
|
933
|
-
metrics: { ...result.metrics, epoch: this.currentEpoch },
|
|
934
|
-
completions: result.completions,
|
|
935
|
-
rewards: result.rewards,
|
|
936
|
-
completionLengths: result.completionLengths,
|
|
937
|
-
};
|
|
938
|
-
}
|
|
939
|
-
/**
|
|
940
|
-
* Increment the step counter (for custom training loops)
|
|
941
|
-
*
|
|
942
|
-
* Call this after each training step when using low-level APIs like
|
|
943
|
-
* engine.trainStepWithGenerations() instead of trainer.trainStepAuto().
|
|
944
|
-
*/
|
|
945
|
-
incrementStep() {
|
|
946
|
-
this.currentStep++;
|
|
947
|
-
}
|
|
948
|
-
/**
|
|
949
|
-
* Get the current step number
|
|
950
|
-
*/
|
|
951
|
-
getStep() {
|
|
952
|
-
return this.currentStep;
|
|
953
|
-
}
|
|
954
|
-
/**
|
|
955
|
-
* Get the current epoch number
|
|
956
|
-
*/
|
|
957
|
-
getEpoch() {
|
|
958
|
-
return this.currentEpoch;
|
|
959
|
-
}
|
|
960
|
-
/**
|
|
961
|
-
* Record a training step to the output store database (for custom training loops)
|
|
962
|
-
*
|
|
963
|
-
* Use this when building custom training loops with engine.trainStepWithGenerations().
|
|
964
|
-
* The step number should be the value after incrementStep() was called.
|
|
965
|
-
*
|
|
966
|
-
* @param step - Step number
|
|
967
|
-
* @param metrics - Step metrics from the engine
|
|
968
|
-
* @param completions - Generated completion texts
|
|
969
|
-
* @param rewards - Reward values for each completion
|
|
970
|
-
* @param prompts - Prompt messages for each completion
|
|
971
|
-
*/
|
|
972
|
-
async recordStepToDatabase(step, metrics, completions, rewards, prompts) {
|
|
973
|
-
if (!this.outputStore)
|
|
974
|
-
return;
|
|
975
|
-
const groupSize = this.config.groupSize ?? 4;
|
|
976
|
-
// Build outputs JSON in the format expected by recordStepFromOutputs
|
|
977
|
-
const outputs = completions.map((text, i) => ({
|
|
978
|
-
prompt: prompts[Math.floor(i / groupSize)] ?? '',
|
|
979
|
-
completion: {
|
|
980
|
-
text,
|
|
981
|
-
raw_text: text,
|
|
982
|
-
tool_calls: [],
|
|
983
|
-
thinking: null,
|
|
984
|
-
num_tokens: text.length, // Approximate
|
|
985
|
-
finish_reason: 'eos',
|
|
986
|
-
},
|
|
987
|
-
expected_answer: null,
|
|
988
|
-
}));
|
|
989
|
-
const outputsJson = JSON.stringify(outputs);
|
|
990
|
-
try {
|
|
991
|
-
await this.outputStore.recordStepFromOutputs(step, {
|
|
992
|
-
step,
|
|
993
|
-
loss: metrics.loss,
|
|
994
|
-
totalTokens: metrics.totalTokens,
|
|
995
|
-
meanReward: metrics.meanReward,
|
|
996
|
-
stdReward: metrics.stdReward,
|
|
997
|
-
meanAdvantage: metrics.meanAdvantage,
|
|
998
|
-
stdAdvantage: metrics.stdAdvantage,
|
|
999
|
-
generationTimeMs: 0,
|
|
1000
|
-
trainingTimeMs: 0,
|
|
1001
|
-
peakMemoryMb: 0,
|
|
1002
|
-
activeMemoryMb: 0,
|
|
1003
|
-
gradientsApplied: true,
|
|
1004
|
-
}, outputsJson, rewards, groupSize);
|
|
1005
|
-
}
|
|
1006
|
-
catch (err) {
|
|
1007
|
-
console.error('[OutputStore] Failed to record step:', err);
|
|
1008
|
-
}
|
|
1009
|
-
}
|
|
1010
|
-
/**
|
|
1011
|
-
* Run a full training loop over a dataset
|
|
1012
|
-
*
|
|
1013
|
-
* This is the high-level training API that handles:
|
|
1014
|
-
* - Epoch iteration
|
|
1015
|
-
* - Batching
|
|
1016
|
-
* - Generation and reward computation
|
|
1017
|
-
* - Logging (if configured)
|
|
1018
|
-
* - Checkpoint saving and resumption
|
|
1019
|
-
* - TUI mode support (pause/resume, sample reporting)
|
|
1020
|
-
*
|
|
1021
|
-
* @param dataset - Array of DatasetExample items
|
|
1022
|
-
*/
|
|
1023
|
-
async train(dataset) {
|
|
1024
|
-
if (dataset.length === 0) {
|
|
1025
|
-
return;
|
|
1026
|
-
}
|
|
1027
|
-
const numEpochs = this.config.numEpochs ?? 1;
|
|
1028
|
-
const batchSize = this.config.batchSize ?? 1;
|
|
1029
|
-
const saveInterval = this.config.saveInterval ?? 100;
|
|
1030
|
-
// Create output directory if needed
|
|
1031
|
-
if (this.config.outputDir && !existsSync(this.config.outputDir)) {
|
|
1032
|
-
mkdirSync(this.config.outputDir, { recursive: true });
|
|
1033
|
-
}
|
|
1034
|
-
// Calculate total steps per epoch BEFORE initOutputStore (needed for accurate resume state)
|
|
1035
|
-
const stepsPerEpoch = Math.ceil(dataset.length / batchSize);
|
|
1036
|
-
// Compute current dataset metadata
|
|
1037
|
-
const currentDatasetHash = computeDatasetHash(dataset);
|
|
1038
|
-
const currentDatasetMetadata = {
|
|
1039
|
-
size: dataset.length,
|
|
1040
|
-
contentHash: currentDatasetHash,
|
|
1041
|
-
};
|
|
1042
|
-
// Validate dataset on resume if we have previous metadata
|
|
1043
|
-
if (this.datasetMetadata && this.currentStep > 0) {
|
|
1044
|
-
const prevMeta = this.datasetMetadata;
|
|
1045
|
-
if (prevMeta.size !== dataset.length) {
|
|
1046
|
-
this.logger.warn(`[Resume] Dataset size mismatch: checkpoint was trained on ${prevMeta.size} examples, ` +
|
|
1047
|
-
`current dataset has ${dataset.length} examples. Batch indices may not align correctly.`);
|
|
1048
|
-
}
|
|
1049
|
-
if (prevMeta.contentHash !== currentDatasetHash) {
|
|
1050
|
-
this.logger.warn(`[Resume] Dataset content mismatch: checkpoint dataset hash ${prevMeta.contentHash}, ` +
|
|
1051
|
-
`current dataset hash ${currentDatasetHash}. Dataset may have been modified or shuffled differently.`);
|
|
1052
|
-
}
|
|
1053
|
-
// Log validation result
|
|
1054
|
-
if (prevMeta.size === dataset.length && prevMeta.contentHash === currentDatasetHash) {
|
|
1055
|
-
this.logger.info(`[Resume] Dataset validated: ${dataset.length} examples, hash ${currentDatasetHash} (matches checkpoint)`);
|
|
1056
|
-
}
|
|
1057
|
-
}
|
|
1058
|
-
// Store current dataset metadata for future checkpoints
|
|
1059
|
-
this.datasetMetadata = currentDatasetMetadata;
|
|
1060
|
-
// Initialize output store if enabled (pass stepsPerEpoch for accurate batch display on resume)
|
|
1061
|
-
await this.initOutputStore(stepsPerEpoch);
|
|
1062
|
-
// Determine starting point based on resumed state
|
|
1063
|
-
const startEpoch = this.currentEpoch;
|
|
1064
|
-
const startStep = this.currentStep;
|
|
1065
|
-
const startBatchIdx = startStep > 0 ? startStep % stepsPerEpoch : 0;
|
|
1066
|
-
// Get model name from path
|
|
1067
|
-
const modelName = this.originalModelPath?.split('/').pop() ?? this.config.modelPath?.split('/').pop() ?? 'Unknown';
|
|
1068
|
-
// Log training start
|
|
1069
|
-
this.logger.init(modelName, {
|
|
1070
|
-
trainingType: 'grpo',
|
|
1071
|
-
numEpochs,
|
|
1072
|
-
batchSize,
|
|
1073
|
-
groupSize: this.config.groupSize ?? 4,
|
|
1074
|
-
learningRate: this.config.learningRate ?? 1e-6,
|
|
1075
|
-
}, dataset.length);
|
|
1076
|
-
if (startStep > 0) {
|
|
1077
|
-
this.logger.info(`Resuming from step ${startStep} (epoch ${startEpoch + 1}, batch ${startBatchIdx + 1}/${stepsPerEpoch})`);
|
|
1078
|
-
}
|
|
1079
|
-
for (let epoch = startEpoch; epoch < numEpochs; epoch++) {
|
|
1080
|
-
// Check for stop request
|
|
1081
|
-
if (this.stopRequested)
|
|
1082
|
-
break;
|
|
1083
|
-
this.currentEpoch = epoch;
|
|
1084
|
-
this.startEpoch();
|
|
1085
|
-
const epochStartTime = Date.now();
|
|
1086
|
-
// Log epoch start
|
|
1087
|
-
this.logger.epochStart(epoch, numEpochs, stepsPerEpoch);
|
|
1088
|
-
// Calculate starting batch index for this epoch
|
|
1089
|
-
const epochStartBatch = epoch === startEpoch && startStep > 0 ? startBatchIdx * batchSize : 0;
|
|
1090
|
-
// Iterate through batches
|
|
1091
|
-
for (let i = epochStartBatch; i < dataset.length; i += batchSize) {
|
|
1092
|
-
// Check for stop request
|
|
1093
|
-
if (this.stopRequested)
|
|
1094
|
-
break;
|
|
1095
|
-
// Wait if paused
|
|
1096
|
-
if (this.paused) {
|
|
1097
|
-
await this.waitForResume();
|
|
1098
|
-
if (this.stopRequested)
|
|
1099
|
-
break;
|
|
1100
|
-
}
|
|
1101
|
-
// Calculate batch index (0-indexed within epoch)
|
|
1102
|
-
const batchIdx = Math.floor(i / batchSize);
|
|
1103
|
-
// Skip already processed batches on resume (from checkpoint's processedBatchIndices)
|
|
1104
|
-
if (this.processedBatchIndices.has(batchIdx)) {
|
|
1105
|
-
this.logger.debug(`Skipping already processed batch ${batchIdx + 1}/${stepsPerEpoch} (from checkpoint)`);
|
|
1106
|
-
continue;
|
|
1107
|
-
}
|
|
1108
|
-
const batch = dataset.slice(i, Math.min(i + batchSize, dataset.length));
|
|
1109
|
-
// Extract prompts and answers from batch
|
|
1110
|
-
const prompts = batch.map((ex) => ex.prompt);
|
|
1111
|
-
// Verbose logging for debugging stuck batches
|
|
1112
|
-
const batchNum = batchIdx + 1;
|
|
1113
|
-
this.logger.info(`Batch ${batchNum}/${stepsPerEpoch} starting (${prompts.length} prompts × ${this.config.groupSize ?? 4} groups)`);
|
|
1114
|
-
// Run training step with auto reward computation
|
|
1115
|
-
const stepStartTime = Date.now();
|
|
1116
|
-
const { metrics, completions, rewards, completionLengths } = await this.trainStepAuto(prompts);
|
|
1117
|
-
const stepDuration = Date.now() - stepStartTime;
|
|
1118
|
-
this.logger.info(`Batch ${batchNum}/${stepsPerEpoch} done in ${(stepDuration / 1000).toFixed(1)}s ` +
|
|
1119
|
-
`(gen=${metrics.generationTimeMs?.toFixed(0) ?? '?'}ms, train=${metrics.trainingTimeMs?.toFixed(0) ?? '?'}ms, loss=${metrics.loss.toFixed(4)})`);
|
|
1120
|
-
// Log step metrics (logger handles TUI/console mode internally)
|
|
1121
|
-
this.logger.step(metrics, batchIdx, stepsPerEpoch);
|
|
1122
|
-
// Track processed batch for resume
|
|
1123
|
-
this.processedBatchIndices.add(batchIdx);
|
|
1124
|
-
// Report generation samples to TUI based on display mode
|
|
1125
|
-
// In console mode, logger.generation() is a no-op
|
|
1126
|
-
const groupSize = this.config.groupSize ?? 4;
|
|
1127
|
-
// Determine which sample indices to report based on display mode
|
|
1128
|
-
let indicesToReport;
|
|
1129
|
-
if (this.sampleDisplayMode === 'all') {
|
|
1130
|
-
// Report all samples
|
|
1131
|
-
indicesToReport = Array.from({ length: completions.length }, (_, i) => i);
|
|
1132
|
-
}
|
|
1133
|
-
else if (this.sampleDisplayMode === 'best_worst') {
|
|
1134
|
-
// Find indices of best (max reward) and worst (min reward) samples
|
|
1135
|
-
let bestIdx = 0;
|
|
1136
|
-
let worstIdx = 0;
|
|
1137
|
-
let bestReward = rewards[0];
|
|
1138
|
-
let worstReward = rewards[0];
|
|
1139
|
-
for (let j = 1; j < rewards.length; j++) {
|
|
1140
|
-
if (rewards[j] > bestReward) {
|
|
1141
|
-
bestReward = rewards[j];
|
|
1142
|
-
bestIdx = j;
|
|
1143
|
-
}
|
|
1144
|
-
if (rewards[j] < worstReward) {
|
|
1145
|
-
worstReward = rewards[j];
|
|
1146
|
-
worstIdx = j;
|
|
1147
|
-
}
|
|
1148
|
-
}
|
|
1149
|
-
// Avoid duplicates if best and worst are the same
|
|
1150
|
-
indicesToReport = bestIdx === worstIdx ? [bestIdx] : [bestIdx, worstIdx];
|
|
1151
|
-
}
|
|
1152
|
-
else {
|
|
1153
|
-
// random: pick 2 random samples (or fewer if completions.length < 2)
|
|
1154
|
-
const numSamples = Math.min(2, completions.length);
|
|
1155
|
-
const shuffled = Array.from({ length: completions.length }, (_, i) => i);
|
|
1156
|
-
// Fisher-Yates partial shuffle for first numSamples
|
|
1157
|
-
for (let k = 0; k < numSamples; k++) {
|
|
1158
|
-
const randIdx = k + Math.floor(Math.random() * (shuffled.length - k));
|
|
1159
|
-
[shuffled[k], shuffled[randIdx]] = [shuffled[randIdx], shuffled[k]];
|
|
1160
|
-
}
|
|
1161
|
-
indicesToReport = shuffled.slice(0, numSamples);
|
|
1162
|
-
}
|
|
1163
|
-
for (const j of indicesToReport) {
|
|
1164
|
-
// Get the prompt for this completion (each prompt has groupSize completions)
|
|
1165
|
-
const promptIdx = Math.floor(j / groupSize);
|
|
1166
|
-
const promptMessages = prompts[promptIdx] ?? [];
|
|
1167
|
-
// Format prompt as text (last user message is most relevant)
|
|
1168
|
-
const lastUserMsg = promptMessages.filter((m) => m.role === 'user').pop();
|
|
1169
|
-
const promptText = lastUserMsg?.content ?? '';
|
|
1170
|
-
this.logger.generation({
|
|
1171
|
-
index: j,
|
|
1172
|
-
prompt: promptText,
|
|
1173
|
-
completion: completions[j],
|
|
1174
|
-
reward: rewards[j],
|
|
1175
|
-
tokens: completionLengths[j] ?? this.config.maxCompletionLength ?? 256,
|
|
1176
|
-
});
|
|
1177
|
-
}
|
|
1178
|
-
// Save checkpoint periodically
|
|
1179
|
-
if (this.config.outputDir && this.currentStep > 0 && this.currentStep % saveInterval === 0) {
|
|
1180
|
-
const path = await this.saveCheckpoint();
|
|
1181
|
-
if (path) {
|
|
1182
|
-
this.logger.checkpoint(path, this.currentStep);
|
|
1183
|
-
}
|
|
1184
|
-
}
|
|
1185
|
-
// Check for emergency checkpoint (triggered by consecutive NaN gradients)
|
|
1186
|
-
if (this.config.outputDir && this.engine.needsEmergencySave) {
|
|
1187
|
-
this.logger.warn(`[EMERGENCY] Emergency save triggered after ${this.engine.nanGradientCount} consecutive NaN gradients at step ${this.currentStep}`);
|
|
1188
|
-
// Save current (possibly corrupted) state for debugging
|
|
1189
|
-
const debugCheckpointPath = `emergency-debug-step-${this.currentStep}`;
|
|
1190
|
-
await this.saveCheckpoint(debugCheckpointPath, { isEmergency: true });
|
|
1191
|
-
this.logger.info(`[EMERGENCY] Saved debug checkpoint with current (possibly corrupted) state to ${debugCheckpointPath}`);
|
|
1192
|
-
// If we have a last known good checkpoint, copy it for recovery
|
|
1193
|
-
if (this.lastGoodCheckpointPath && existsSync(this.lastGoodCheckpointPath)) {
|
|
1194
|
-
this.logger.warn(`[EMERGENCY] Reverting to last good checkpoint from step ${this.lastGoodCheckpointStep}: ${this.lastGoodCheckpointPath}`);
|
|
1195
|
-
// Copy last good checkpoint to a recovery location
|
|
1196
|
-
const outputDir = this.config.outputDir ?? './outputs';
|
|
1197
|
-
const recoveryPath = join(outputDir, `emergency-recovery-step-${this.lastGoodCheckpointStep}`);
|
|
1198
|
-
try {
|
|
1199
|
-
// Remove existing recovery checkpoint if it exists
|
|
1200
|
-
if (existsSync(recoveryPath)) {
|
|
1201
|
-
rmSync(recoveryPath, { recursive: true, force: true });
|
|
1202
|
-
}
|
|
1203
|
-
// Copy the last good checkpoint to recovery location
|
|
1204
|
-
cpSync(this.lastGoodCheckpointPath, recoveryPath, { recursive: true });
|
|
1205
|
-
this.logger.info(`[EMERGENCY] Copied last good checkpoint to ${recoveryPath}`);
|
|
1206
|
-
this.logger.warn(`[EMERGENCY] Recovery checkpoint available at: ${recoveryPath}\n` +
|
|
1207
|
-
` To resume from the last good state, use: resumeFromCheckpoint: '${recoveryPath}'`);
|
|
1208
|
-
}
|
|
1209
|
-
catch (copyError) {
|
|
1210
|
-
this.logger.error(`[EMERGENCY] Failed to copy last good checkpoint: ${copyError}`);
|
|
1211
|
-
}
|
|
1212
|
-
}
|
|
1213
|
-
else {
|
|
1214
|
-
this.logger.error(`[EMERGENCY] No previous good checkpoint available for recovery! ` +
|
|
1215
|
-
`The debug checkpoint contains the current (potentially corrupted) model state.`);
|
|
1216
|
-
}
|
|
1217
|
-
// Clear the emergency flag
|
|
1218
|
-
this.engine.clearEmergencySaveFlag();
|
|
1219
|
-
// Log recovery guidance
|
|
1220
|
-
this.logger.warn(`[EMERGENCY] Training will continue, but model quality may be degraded.\n` +
|
|
1221
|
-
` Recommendations:\n` +
|
|
1222
|
-
` - Reduce learning rate (current: ${this.config.learningRate ?? 1e-6})\n` +
|
|
1223
|
-
` - Check training data for anomalies\n` +
|
|
1224
|
-
` - Consider stopping and resuming from the recovery checkpoint`);
|
|
1225
|
-
}
|
|
1226
|
-
}
|
|
1227
|
-
const epochEndTime = Date.now();
|
|
1228
|
-
const epochTimeSecs = (epochEndTime - epochStartTime) / 1000;
|
|
1229
|
-
this.endEpoch(epochTimeSecs);
|
|
1230
|
-
this.logger.epochEnd(epoch, numEpochs, epochTimeSecs);
|
|
1231
|
-
// Clear processed batch indices at epoch boundary (new epoch = new batches)
|
|
1232
|
-
this.processedBatchIndices.clear();
|
|
1233
|
-
}
|
|
1234
|
-
// Save final checkpoint
|
|
1235
|
-
if (this.config.outputDir && !this.stopRequested) {
|
|
1236
|
-
const path = await this.saveCheckpoint('final');
|
|
1237
|
-
if (path) {
|
|
1238
|
-
this.logger.checkpoint(path, this.currentStep);
|
|
1239
|
-
}
|
|
1240
|
-
}
|
|
1241
|
-
// Log completion
|
|
1242
|
-
this.logger.complete(this.currentStep);
|
|
1243
|
-
// End output store run if active
|
|
1244
|
-
if (this.outputStore) {
|
|
1245
|
-
const status = this.stopRequested ? 'stopped' : 'completed';
|
|
1246
|
-
await this.outputStore.endRun(status);
|
|
1247
|
-
await this.outputStore.flush();
|
|
1248
|
-
}
|
|
1249
|
-
// Cleanup stdin interface
|
|
1250
|
-
if (this.stdinInterface) {
|
|
1251
|
-
this.stdinInterface.close();
|
|
1252
|
-
}
|
|
1253
|
-
}
|
|
1254
|
-
/**
|
|
1255
|
-
* Save a checkpoint with model weights and training state
|
|
1256
|
-
*
|
|
1257
|
-
* Regular checkpoints (non-emergency) are tracked as "last known good" checkpoints.
|
|
1258
|
-
* When NaN gradients occur, the emergency save logic can restore from the last good checkpoint.
|
|
1259
|
-
*
|
|
1260
|
-
* @param name - Checkpoint name (default: "checkpoint-{step}")
|
|
1261
|
-
* @param options - Optional settings for checkpoint save behavior
|
|
1262
|
-
* @param options.isEmergency - If true, this is an emergency checkpoint (debug state, not "good")
|
|
1263
|
-
* @returns Path to saved checkpoint, or empty string if save was skipped due to corruption
|
|
1264
|
-
*/
|
|
1265
|
-
async saveCheckpoint(name, options) {
|
|
1266
|
-
const isEmergency = options?.isEmergency ?? false;
|
|
1267
|
-
const checkpointName = name ?? `checkpoint-${this.currentStep}`;
|
|
1268
|
-
const outputDir = this.config.outputDir ?? './outputs';
|
|
1269
|
-
const checkpointPath = join(outputDir, checkpointName);
|
|
1270
|
-
// Create checkpoint directory
|
|
1271
|
-
if (!existsSync(checkpointPath)) {
|
|
1272
|
-
mkdirSync(checkpointPath, { recursive: true });
|
|
1273
|
-
}
|
|
1274
|
-
// Save training state with dataset metadata for resume validation
|
|
1275
|
-
const state = {
|
|
1276
|
-
step: this.currentStep,
|
|
1277
|
-
epoch: this.currentEpoch,
|
|
1278
|
-
timestamp: new Date().toISOString(),
|
|
1279
|
-
dataset: this.datasetMetadata
|
|
1280
|
-
? {
|
|
1281
|
-
size: this.datasetMetadata.size,
|
|
1282
|
-
contentHash: this.datasetMetadata.contentHash,
|
|
1283
|
-
shuffleSeed: this.datasetMetadata.shuffleSeed,
|
|
1284
|
-
processedBatchIndices: Array.from(this.processedBatchIndices),
|
|
1285
|
-
}
|
|
1286
|
-
: undefined,
|
|
1287
|
-
};
|
|
1288
|
-
const statePath = join(checkpointPath, 'training_state.json');
|
|
1289
|
-
writeFileSync(statePath, JSON.stringify(state, null, 2));
|
|
1290
|
-
// Save model weights
|
|
1291
|
-
await this.model.saveModel(checkpointPath);
|
|
1292
|
-
// Copy tokenizer files from original model path (required for loading checkpoints)
|
|
1293
|
-
const tokenizerSource = this.originalModelPath ?? this.config.modelPath;
|
|
1294
|
-
if (tokenizerSource) {
|
|
1295
|
-
const tokenizerFiles = ['tokenizer.json', 'tokenizer_config.json', 'vocab.json', 'merges.txt'];
|
|
1296
|
-
for (const file of tokenizerFiles) {
|
|
1297
|
-
const srcPath = join(tokenizerSource, file);
|
|
1298
|
-
const destPath = join(checkpointPath, file);
|
|
1299
|
-
if (existsSync(srcPath) && !existsSync(destPath)) {
|
|
1300
|
-
copyFileSync(srcPath, destPath);
|
|
1301
|
-
}
|
|
1302
|
-
}
|
|
1303
|
-
}
|
|
1304
|
-
this.logger.info(`Checkpoint saved: ${checkpointPath}`);
|
|
1305
|
-
// Track last checkpoint step for emergency save throttling
|
|
1306
|
-
this.lastCheckpointStep = this.currentStep;
|
|
1307
|
-
// Track as "last known good" checkpoint (only for regular saves, not emergency saves)
|
|
1308
|
-
if (!isEmergency) {
|
|
1309
|
-
this.lastGoodCheckpointPath = checkpointPath;
|
|
1310
|
-
this.lastGoodCheckpointStep = this.currentStep;
|
|
1311
|
-
this.logger.debug(`Tracked as last good checkpoint: step ${this.currentStep}`);
|
|
1312
|
-
}
|
|
1313
|
-
// Clean up old checkpoints to save disk space
|
|
1314
|
-
const maxCheckpoints = this.config.maxCheckpoints ?? 3;
|
|
1315
|
-
if (maxCheckpoints > 0) {
|
|
1316
|
-
this.cleanupOldCheckpoints(outputDir, maxCheckpoints);
|
|
1317
|
-
}
|
|
1318
|
-
return checkpointPath;
|
|
1319
|
-
}
|
|
1320
|
-
/**
|
|
1321
|
-
* Remove old checkpoints, keeping only the most recent ones
|
|
1322
|
-
* Preserves 'final' and 'emergency-*' checkpoints
|
|
1323
|
-
*/
|
|
1324
|
-
cleanupOldCheckpoints(outputDir, maxToKeep) {
|
|
1325
|
-
try {
|
|
1326
|
-
const entries = readdirSync(outputDir, { withFileTypes: true });
|
|
1327
|
-
// Find regular checkpoint directories (checkpoint-N pattern)
|
|
1328
|
-
const checkpoints = [];
|
|
1329
|
-
for (const entry of entries) {
|
|
1330
|
-
if (!entry.isDirectory())
|
|
1331
|
-
continue;
|
|
1332
|
-
// Skip 'final' and 'emergency-*' checkpoints
|
|
1333
|
-
if (entry.name === 'final' || entry.name.startsWith('emergency-'))
|
|
1334
|
-
continue;
|
|
1335
|
-
// Match checkpoint-N pattern
|
|
1336
|
-
const match = entry.name.match(/^checkpoint-(\d+)$/);
|
|
1337
|
-
if (match) {
|
|
1338
|
-
const checkpointPath = join(outputDir, entry.name);
|
|
1339
|
-
const stat = statSync(checkpointPath);
|
|
1340
|
-
checkpoints.push({
|
|
1341
|
-
name: entry.name,
|
|
1342
|
-
step: parseInt(match[1], 10),
|
|
1343
|
-
mtime: stat.mtime,
|
|
1344
|
-
});
|
|
1345
|
-
}
|
|
1346
|
-
}
|
|
1347
|
-
// Sort by step number descending (newest first)
|
|
1348
|
-
checkpoints.sort((a, b) => b.step - a.step);
|
|
1349
|
-
// Remove old checkpoints beyond maxToKeep
|
|
1350
|
-
if (checkpoints.length > maxToKeep) {
|
|
1351
|
-
const toRemove = checkpoints.slice(maxToKeep);
|
|
1352
|
-
for (const checkpoint of toRemove) {
|
|
1353
|
-
const checkpointPath = join(outputDir, checkpoint.name);
|
|
1354
|
-
rmSync(checkpointPath, { recursive: true, force: true });
|
|
1355
|
-
this.logger.debug(`Removed old checkpoint: ${checkpoint.name}`);
|
|
1356
|
-
}
|
|
1357
|
-
}
|
|
1358
|
-
}
|
|
1359
|
-
catch (error) {
|
|
1360
|
-
// Don't fail training if cleanup fails
|
|
1361
|
-
this.logger.warn(`Failed to cleanup old checkpoints: ${error}`);
|
|
1362
|
-
}
|
|
1363
|
-
}
|
|
1364
|
-
/**
|
|
1365
|
-
* Start a new training epoch
|
|
1366
|
-
*/
|
|
1367
|
-
startEpoch() {
|
|
1368
|
-
this.engine.startEpoch();
|
|
1369
|
-
}
|
|
1370
|
-
/**
|
|
1371
|
-
* End the current epoch and get metrics
|
|
1372
|
-
*
|
|
1373
|
-
* @param epochTimeSecs - Duration of the epoch in seconds
|
|
1374
|
-
*/
|
|
1375
|
-
endEpoch(epochTimeSecs) {
|
|
1376
|
-
return this.engine.endEpoch(epochTimeSecs);
|
|
1377
|
-
}
|
|
1378
|
-
/**
|
|
1379
|
-
* Reset the trainer for a new training run
|
|
1380
|
-
*/
|
|
1381
|
-
reset() {
|
|
1382
|
-
this.engine.reset();
|
|
1383
|
-
}
|
|
1384
|
-
/**
|
|
1385
|
-
* Get current training step
|
|
1386
|
-
*/
|
|
1387
|
-
get step() {
|
|
1388
|
-
return Number(this.engine.step);
|
|
1389
|
-
}
|
|
1390
|
-
/**
|
|
1391
|
-
* Get current epoch
|
|
1392
|
-
*/
|
|
1393
|
-
get epoch() {
|
|
1394
|
-
return this.engine.epoch;
|
|
1395
|
-
}
|
|
1396
|
-
/**
|
|
1397
|
-
* Get current micro-step within gradient accumulation
|
|
1398
|
-
*/
|
|
1399
|
-
get microStep() {
|
|
1400
|
-
return this.engine.microStep;
|
|
1401
|
-
}
|
|
1402
|
-
/**
|
|
1403
|
-
* Check if built-in rewards are configured
|
|
1404
|
-
*/
|
|
1405
|
-
get hasBuiltinRewards() {
|
|
1406
|
-
return this.engine.hasBuiltinRewards;
|
|
1407
|
-
}
|
|
1408
|
-
/**
|
|
1409
|
-
* Get names of registered reward functions
|
|
1410
|
-
*/
|
|
1411
|
-
get rewardNames() {
|
|
1412
|
-
return this.engine.rewardNames;
|
|
1413
|
-
}
|
|
1414
|
-
/**
|
|
1415
|
-
* Get the underlying native engine
|
|
1416
|
-
*
|
|
1417
|
-
* For advanced use cases that need direct access.
|
|
1418
|
-
*/
|
|
1419
|
-
getNativeEngine() {
|
|
1420
|
-
return this.engine;
|
|
1421
|
-
}
|
|
1422
|
-
}
|
|
1423
|
-
/**
|
|
1424
|
-
* Create a standalone reward registry for testing rewards
|
|
1425
|
-
*
|
|
1426
|
-
* @example
|
|
1427
|
-
* ```typescript
|
|
1428
|
-
* const registry = createRewardRegistry();
|
|
1429
|
-
* registry.register({
|
|
1430
|
-
* rewardType: 'ToolUse',
|
|
1431
|
-
* allowedTools: ['search'],
|
|
1432
|
-
* });
|
|
1433
|
-
*
|
|
1434
|
-
* const score = registry.score('prompt', 'completion with <tool_call>...</tool_call>');
|
|
1435
|
-
* ```
|
|
1436
|
-
*/
|
|
1437
|
-
export function createRewardRegistry() {
    // Each call constructs and returns a new NativeRewardRegistry instance.
    const registry = new NativeRewardRegistry();
    return registry;
}
|