@mlx-node/trl 0.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. package/dist/data/dataset.d.ts +22 -0
  2. package/dist/data/dataset.d.ts.map +1 -0
  3. package/dist/data/dataset.js +142 -0
  4. package/dist/data/sft-dataset.d.ts +156 -0
  5. package/dist/data/sft-dataset.d.ts.map +1 -0
  6. package/dist/data/sft-dataset.js +415 -0
  7. package/dist/index.d.ts +33 -0
  8. package/dist/index.d.ts.map +1 -0
  9. package/dist/index.js +47 -0
  10. package/dist/trainers/grpo-config.d.ts +42 -0
  11. package/dist/trainers/grpo-config.d.ts.map +1 -0
  12. package/dist/trainers/grpo-config.js +220 -0
  13. package/dist/trainers/grpo-entropy.d.ts +33 -0
  14. package/dist/trainers/grpo-entropy.d.ts.map +1 -0
  15. package/dist/trainers/grpo-entropy.js +18 -0
  16. package/dist/trainers/grpo-trainer.d.ts +602 -0
  17. package/dist/trainers/grpo-trainer.d.ts.map +1 -0
  18. package/dist/trainers/grpo-trainer.js +1439 -0
  19. package/dist/trainers/sft-config.d.ts +32 -0
  20. package/dist/trainers/sft-config.d.ts.map +1 -0
  21. package/dist/trainers/sft-config.js +186 -0
  22. package/dist/trainers/sft-trainer.d.ts +141 -0
  23. package/dist/trainers/sft-trainer.d.ts.map +1 -0
  24. package/dist/trainers/sft-trainer.js +502 -0
  25. package/dist/trainers/training-logger.d.ts +375 -0
  26. package/dist/trainers/training-logger.d.ts.map +1 -0
  27. package/dist/trainers/training-logger.js +542 -0
  28. package/dist/types.d.ts +54 -0
  29. package/dist/types.d.ts.map +1 -0
  30. package/dist/types.js +1 -0
  31. package/dist/utils/path-security.d.ts +51 -0
  32. package/dist/utils/path-security.d.ts.map +1 -0
  33. package/dist/utils/path-security.js +69 -0
  34. package/dist/utils/xml-parser.d.ts +6 -0
  35. package/dist/utils/xml-parser.d.ts.map +1 -0
  36. package/dist/utils/xml-parser.js +184 -0
  37. package/package.json +29 -0
@@ -0,0 +1,1439 @@
1
+ /**
2
+ * GRPO Training Engine - Rust-Native Training
3
+ *
4
+ * This module provides a Rust-native GRPO training engine that minimizes
5
+ * FFI overhead by keeping the training loop entirely in Rust.
6
+ *
7
+ * ## Key Features
8
+ * - Training loop runs in Rust (eliminates FFI overhead)
9
+ * - Built-in reward functions (tool use, XML format, length, JSON schema)
10
+ * - Custom JS rewards via callback pattern
11
+ * - Gradient accumulation and memory management in Rust
12
+ * - High-level train() method for full training runs
13
+ * - Low-level trainStep() for custom training loops
14
+ *
15
+ * ## High-Level Usage (train with dataset)
16
+ * ```typescript
17
+ * const trainer = await GRPOTrainer.create({
18
+ * modelPath: './model',
19
+ * modelName: 'qwen3-0.6b',
20
+ * rewardFunction: (prompts, completions) => [...scores],
21
+ * });
22
+ * await trainer.train(dataset);
23
+ * ```
24
+ *
25
+ * ## Low-Level Usage (step-by-step)
26
+ * ```typescript
27
+ * const model = await Qwen3Model.loadPretrained(modelPath);
28
+ * const trainer = new GRPOTrainer(model, config);
29
+ *
30
+ * trainer.registerBuiltinReward({
31
+ * rewardType: 'ToolUse',
32
+ * allowedTools: ['search', 'calculate'],
33
+ * });
34
+ *
35
+ * for (const batch of dataset) {
36
+ * const completions = await trainer.generateBatch(batch.prompts);
37
+ * const rewards = await myRewardFunction(batch.prompts, completions);
38
+ * const metrics = await trainer.trainStep(batch.prompts, rewards);
39
+ * }
40
+ * ```
41
+ */
42
+ import { existsSync, mkdirSync, writeFileSync, readFileSync, readdirSync, copyFileSync, cpSync, rmSync, statSync, } from 'node:fs';
43
+ import { createHash } from 'node:crypto';
44
+ import { dirname, join } from 'node:path';
45
+ import * as readline from 'node:readline';
46
+ import { GrpoTrainingEngine, NativeRewardRegistry, Qwen3Model, OutputStore, buildRewardOutputs, } from '@mlx-node/core';
47
+ import { createTrainingLogger } from './training-logger';
48
+ // Re-export native types
49
+ export { GrpoTrainingEngine, NativeRewardRegistry, OutputStore } from '@mlx-node/core';
50
+ /**
51
+ * Default configuration
52
+ */
53
+ export const DEFAULT_GRPO_CONFIG = {
54
+ learningRate: 1e-6,
55
+ gradientAccumulationSteps: 1,
56
+ gradientClipNorm: 1.0,
57
+ weightDecay: 0.01,
58
+ numEpochs: 1,
59
+ batchSize: 1,
60
+ groupSize: 4,
61
+ clipEpsilon: 0.2,
62
+ klCoef: 0.0,
63
+ lossType: 'grpo',
64
+ advantageNormalization: true,
65
+ maxCompletionLength: 256,
66
+ temperature: 0.8,
67
+ topP: 0.95,
68
+ repetitionPenalty: 1.1,
69
+ logInterval: 1,
70
+ saveInterval: 100,
71
+ evalInterval: 100,
72
+ logConsole: true,
73
+ logJsonl: true,
74
+ maxCheckpoints: 3,
75
+ };
76
+ /**
77
+ * Compute a hash of dataset content for identity checking on resume.
78
+ *
79
+ * Hashes the first N examples to create a fingerprint that can detect
80
+ * if the dataset has been modified between training runs.
81
+ *
82
+ * @param dataset - Array of dataset examples
83
+ * @param sampleSize - Number of examples to hash (default: 10)
84
+ * @returns 16-character hex hash string
85
+ */
86
+ export function computeDatasetHash(dataset, sampleSize = 10) {
87
+ const samples = dataset.slice(0, sampleSize);
88
+ // Create a content string from prompts (stringified for consistency)
89
+ const content = samples.map((ex) => JSON.stringify(ex.prompt)).join('|||');
90
+ return createHash('sha256').update(content).digest('hex').slice(0, 16);
91
+ }
92
+ // Note: RewardOutput is now built using the Rust buildRewardOutputs function
93
+ // which handles tool call parsing and thinking extraction natively.
94
+ /**
95
+ * Error thrown when a reward function times out.
96
+ */
97
+ export class RewardTimeoutError extends Error {
98
+ timeoutMs;
99
+ constructor(message, timeoutMs) {
100
+ super(message);
101
+ this.timeoutMs = timeoutMs;
102
+ this.name = 'RewardTimeoutError';
103
+ }
104
+ }
105
+ /**
106
+ * Wraps a promise with a timeout.
107
+ *
108
+ * @param promise - The promise to wrap
109
+ * @param timeoutMs - Timeout in milliseconds (0 = no timeout)
110
+ * @param errorMessage - Error message if timeout is reached
111
+ * @returns The promise result or throws RewardTimeoutError
112
+ */
113
+ function withTimeout(promise, timeoutMs, errorMessage) {
114
+ // Timeout of 0 means no timeout
115
+ if (timeoutMs <= 0) {
116
+ return promise;
117
+ }
118
+ return new Promise((resolve, reject) => {
119
+ const timer = setTimeout(() => {
120
+ reject(new RewardTimeoutError(errorMessage, timeoutMs));
121
+ }, timeoutMs);
122
+ promise
123
+ .then((result) => {
124
+ clearTimeout(timer);
125
+ resolve(result);
126
+ })
127
+ .catch((error) => {
128
+ clearTimeout(timer);
129
+ reject(error);
130
+ });
131
+ });
132
+ }
133
+ /**
134
+ * GRPO Trainer - Rust-Native Training Engine
135
+ *
136
+ * Provides a TypeScript-friendly interface to the Rust training engine.
137
+ * Supports both high-level training (train()) and low-level step-by-step (trainStep()).
138
+ */
139
+ export class GRPOTrainer {
140
+ engine;
141
+ model;
142
+ config;
143
+ rewardFn;
144
+ currentEpoch = 0;
145
+ currentStep = 0;
146
+ /** Original model path (for tokenizer files when saving checkpoints) */
147
+ originalModelPath;
148
+ // TUI state
149
+ paused = false;
150
+ stopRequested = false;
151
+ stdinInterface;
152
+ logger;
153
+ sampleDisplayMode = 'all';
154
+ // Output recording
155
+ outputStore;
156
+ outputStoreInitPromise;
157
+ outputStoreRunId;
158
+ outputStorePath;
159
+ // Crash recovery
160
+ lastCheckpointStep = 0;
161
+ signalHandlersInstalled = false;
162
+ // Last known good checkpoint tracking (for NaN gradient recovery)
163
+ lastGoodCheckpointPath = null;
164
+ lastGoodCheckpointStep = 0;
165
+ // Dataset tracking for resume validation
166
+ datasetMetadata;
167
+ processedBatchIndices = new Set();
168
+ /**
169
+ * Create a new GRPO trainer from a model
170
+ *
171
+ * @param model - Pre-loaded Qwen3 model
172
+ * @param config - Training configuration
173
+ */
174
+ constructor(model, config = {}, logger) {
175
+ // Auto-detect TUI mode from environment variable (set by mlx-train TUI)
176
+ const tuiModeFromEnv = process.env.MLX_TUI_MODE === '1';
177
+ if (tuiModeFromEnv && config.tuiMode === undefined) {
178
+ config.tuiMode = true;
179
+ }
180
+ // Auto-enable database persistence in TUI mode (enables Database tab)
181
+ if (tuiModeFromEnv && config.outputStore === undefined) {
182
+ config.outputStore = { enabled: true };
183
+ }
184
+ this.config = { ...DEFAULT_GRPO_CONFIG, ...config };
185
+ this.model = model;
186
+ // Create or use provided logger (TUI mode auto-detected from MLX_TUI_MODE env var)
187
+ this.logger =
188
+ logger ??
189
+ createTrainingLogger({
190
+ logConsole: this.config.logConsole,
191
+ logJsonl: this.config.logJsonl,
192
+ outputDir: this.config.outputDir,
193
+ runName: this.config.runName,
194
+ logInterval: this.config.logInterval ?? 1,
195
+ });
196
+ // Set reward function if provided
197
+ if (this.config.rewardFunction) {
198
+ this.rewardFn = this.config.rewardFunction;
199
+ }
200
+ // Convert to native config
201
+ const engineConfig = {
202
+ learningRate: this.config.learningRate,
203
+ gradientAccumulationSteps: this.config.gradientAccumulationSteps,
204
+ gradientClipNorm: this.config.gradientClipNorm,
205
+ groupSize: this.config.groupSize,
206
+ clipEpsilon: this.config.clipEpsilon,
207
+ klCoef: this.config.klCoef,
208
+ lossType: this.config.lossType,
209
+ maxCompletionLength: this.config.maxCompletionLength,
210
+ temperature: this.config.temperature,
211
+ topP: this.config.topP,
212
+ topK: this.config.topK,
213
+ repetitionPenalty: this.config.repetitionPenalty,
214
+ // Tool calling support
215
+ tools: this.config.tools,
216
+ enableThinking: this.config.enableThinking,
217
+ // Memory optimization
218
+ lmHeadChunkSize: this.config.lmHeadChunkSize,
219
+ forwardChunkSize: this.config.forwardChunkSize,
220
+ vocabChunkSize: this.config.vocabChunkSize,
221
+ // Parallel batch generation
222
+ useParallelBatchGeneration: this.config.useParallelBatchGeneration,
223
+ };
224
+ this.engine = new GrpoTrainingEngine(model, engineConfig);
225
+ // Setup stdin handler if TUI mode
226
+ if (this.config.tuiMode) {
227
+ this.setupStdinHandler();
228
+ }
229
+ // Always setup signal handlers for crash recovery
230
+ this.setupSignalHandlers();
231
+ }
232
+ /**
233
+ * Setup stdin handler for TUI control commands
234
+ */
235
+ setupStdinHandler() {
236
+ if (!this.config.tuiMode)
237
+ return;
238
+ this.stdinInterface = readline.createInterface({
239
+ input: process.stdin,
240
+ output: process.stdout,
241
+ terminal: false,
242
+ });
243
+ this.stdinInterface.on('line', (line) => {
244
+ const cmd = line.trim();
245
+ this.handleStdinCommand(cmd);
246
+ });
247
+ }
248
+ /**
249
+ * Setup OS signal handlers for graceful shutdown on crash/interrupt
250
+ *
251
+ * Catches SIGTERM, SIGINT, and uncaught exceptions to:
252
+ * - Save emergency checkpoint (if > 10 steps since last)
253
+ * - Finalize OutputStore with 'crashed' status
254
+ * - Exit cleanly
255
+ */
256
+ setupSignalHandlers() {
257
+ if (this.signalHandlersInstalled)
258
+ return;
259
+ this.signalHandlersInstalled = true;
260
+ const gracefulShutdown = async (signal) => {
261
+ this.logger.warn(`Received ${signal}, initiating graceful shutdown...`);
262
+ this.stopRequested = true;
263
+ try {
264
+ // Skip checkpoint if recent one exists (within 10 steps)
265
+ const stepsSinceCheckpoint = this.currentStep - this.lastCheckpointStep;
266
+ if (this.config.outputDir && stepsSinceCheckpoint > 10) {
267
+ this.logger.info(`Saving emergency checkpoint (${stepsSinceCheckpoint} steps since last)...`);
268
+ await this.saveCheckpoint(`emergency-${this.currentStep}`);
269
+ }
270
+ else if (stepsSinceCheckpoint <= 10) {
271
+ this.logger.info(`Skipping checkpoint (only ${stepsSinceCheckpoint} steps since last)`);
272
+ }
273
+ // Finalize OutputStore with crashed status
274
+ if (this.outputStore) {
275
+ await this.outputStore.endRun('crashed');
276
+ await this.outputStore.flush();
277
+ this.logger.info('OutputStore finalized with crashed status');
278
+ }
279
+ }
280
+ catch (e) {
281
+ console.error('Emergency save failed:', e);
282
+ }
283
+ // Cleanup stdin interface
284
+ if (this.stdinInterface) {
285
+ this.stdinInterface.close();
286
+ }
287
+ process.exit(0);
288
+ };
289
+ process.on('SIGTERM', () => {
290
+ gracefulShutdown('SIGTERM').catch(console.error);
291
+ });
292
+ process.on('SIGINT', () => {
293
+ gracefulShutdown('SIGINT').catch(console.error);
294
+ });
295
+ process.on('uncaughtException', (err) => {
296
+ console.error('Uncaught exception:', err);
297
+ gracefulShutdown('uncaughtException').catch(console.error);
298
+ });
299
+ process.on('unhandledRejection', (reason) => {
300
+ console.error('Unhandled rejection:', reason);
301
+ gracefulShutdown('unhandledRejection').catch(console.error);
302
+ });
303
+ }
304
+ /**
305
+ * Initialize the output store for recording training outputs
306
+ */
307
+ async initOutputStore(stepsPerEpoch) {
308
+ // Guard against re-initialization (e.g., train() called after trainStepAuto())
309
+ if (this.outputStore)
310
+ return;
311
+ const cfg = this.config.outputStore;
312
+ if (!cfg?.enabled) {
313
+ // Even without outputStore, send UI resume state if resuming from checkpoint
314
+ if (this.config.resumeFromCheckpoint && this.currentStep > 0 && stepsPerEpoch) {
315
+ this.sendResumeStateUiOnly(stepsPerEpoch);
316
+ }
317
+ return;
318
+ }
319
+ const localPath = cfg.localPath ?? join(this.config.outputDir ?? '.', 'outputs.db');
320
+ this.outputStorePath = localPath;
321
+ // Ensure parent directory exists (for lazy init via trainStepAuto)
322
+ const parentDir = dirname(localPath);
323
+ if (parentDir !== '.' && !existsSync(parentDir)) {
324
+ mkdirSync(parentDir, { recursive: true });
325
+ }
326
+ this.outputStore = await OutputStore.local(localPath);
327
+ // If resuming from checkpoint AND we have a run name AND checkpoint was actually loaded,
328
+ // try to resume the existing database run. currentStep > 0 means checkpoint was loaded.
329
+ if (this.config.resumeFromCheckpoint && this.config.runName && this.currentStep > 0) {
330
+ const existingRun = await this.outputStore.findRunByName(this.config.runName);
331
+ if (existingRun) {
332
+ this.logger.info(`Resuming database run: ${this.config.runName} (${existingRun.id})`);
333
+ await this.outputStore.resumeRun(existingRun.id);
334
+ this.outputStoreRunId = existingRun.id;
335
+ // Clean up any database records that are ahead of the checkpoint step
336
+ // This prevents UNIQUE constraint errors when re-recording steps
337
+ // Uses cascade delete to also clean up orphaned generations, tool_calls, and logs
338
+ if (this.currentStep > 0) {
339
+ const cleanupStats = await this.outputStore.deleteAllAfterStep(existingRun.id, this.currentStep);
340
+ if (cleanupStats.stepsDeleted > 0) {
341
+ this.logger.info(`Cleaned up stale records after step ${this.currentStep}: ` +
342
+ `${cleanupStats.stepsDeleted} steps, ${cleanupStats.generationsDeleted} generations, ` +
343
+ `${cleanupStats.logsDeleted} logs`);
344
+ }
345
+ }
346
+ this.logger.databasePath(localPath, this.outputStoreRunId, this.config.runName ?? undefined);
347
+ // Send resume state to TUI for sparkline and aggregate restoration
348
+ await this.sendResumeState(existingRun.id, stepsPerEpoch);
349
+ return;
350
+ }
351
+ else {
352
+ // Run name specified but not found - warn and create new
353
+ this.logger.warn(`No existing run found with name: ${this.config.runName}. Starting new run.`);
354
+ }
355
+ }
356
+ // Start a new run with sanitized config (no auth token)
357
+ const modelName = this.config.modelName ?? 'qwen3';
358
+ const modelPath = this.originalModelPath ?? this.config.modelPath ?? undefined;
359
+ const sanitizedConfig = {
360
+ ...this.config,
361
+ outputStore: this.config.outputStore ? { ...this.config.outputStore, authToken: undefined } : undefined,
362
+ };
363
+ // Use startRunWithName if a run name is provided, otherwise use startRun
364
+ if (this.config.runName) {
365
+ this.outputStoreRunId = await this.outputStore.startRunWithName(this.config.runName, modelName, modelPath, JSON.stringify(sanitizedConfig));
366
+ }
367
+ else {
368
+ this.outputStoreRunId = await this.outputStore.startRun(modelName, modelPath, JSON.stringify(sanitizedConfig));
369
+ }
370
+ // Emit database path to TUI for the Database tab
371
+ this.logger.databasePath(localPath, this.outputStoreRunId, this.config.runName ?? undefined);
372
+ // If resuming from checkpoint but didn't find existing DB run, still send UI state
373
+ // This ensures TUI displays correct batch progress even without historical data
374
+ if (this.config.resumeFromCheckpoint && this.currentStep > 0 && stepsPerEpoch) {
375
+ this.sendResumeStateUiOnly(stepsPerEpoch);
376
+ }
377
+ }
378
+ /**
379
+ * Send minimal resume state to TUI for UI display only (no historical data)
380
+ *
381
+ * Used when resuming from checkpoint without a matching database run.
382
+ * Ensures TUI shows correct epoch/batch progress.
383
+ */
384
+ sendResumeStateUiOnly(stepsPerEpoch) {
385
+ if (!this.logger.isTuiMode)
386
+ return;
387
+ const totalEpochs = this.config.numEpochs ?? 1;
388
+ const stepInEpoch = this.currentStep > 0 ? ((this.currentStep - 1) % stepsPerEpoch) + 1 : 0;
389
+ this.logger.resumeState({
390
+ step: this.currentStep,
391
+ epoch: this.currentEpoch + 1, // 1-indexed
392
+ totalEpochs,
393
+ stepInEpoch,
394
+ totalStepsInEpoch: stepsPerEpoch,
395
+ metricsHistory: [], // No historical data
396
+ aggregates: {
397
+ bestReward: 0,
398
+ avgReward: 0,
399
+ rewardCount: 0,
400
+ bestLoss: Infinity,
401
+ avgLoss: 0,
402
+ lossCount: 0,
403
+ totalTokens: 0,
404
+ avgGenerationTimeMs: 0,
405
+ avgTrainingTimeMs: 0,
406
+ },
407
+ });
408
+ this.logger.info(`Sent UI resume state to TUI (step ${this.currentStep}, no historical data)`);
409
+ }
410
+ /**
411
+ * Send resume state to TUI for restoring sparklines and aggregates
412
+ *
413
+ * Queries the database for historical metrics and aggregates, then sends
414
+ * to TUI via the resumeState message.
415
+ *
416
+ * @param runId - Database run ID
417
+ * @param actualStepsPerEpoch - Actual steps per epoch from dataset (if known)
418
+ */
419
+ async sendResumeState(runId, actualStepsPerEpoch) {
420
+ if (!this.outputStore || !this.logger.isTuiMode)
421
+ return;
422
+ try {
423
+ // Query historical metrics (last 60 for sparklines)
424
+ const metricsHistory = await this.outputStore.getRecentStepMetrics(runId, 60);
425
+ // Query aggregate statistics
426
+ const aggregates = await this.outputStore.getRunAggregates(runId);
427
+ // Use actual steps per epoch if provided, otherwise use a reasonable default
428
+ const totalEpochs = this.config.numEpochs ?? 1;
429
+ const stepsPerEpoch = actualStepsPerEpoch ?? 50;
430
+ // Calculate step within current epoch
431
+ const stepInEpoch = this.currentStep > 0 ? ((this.currentStep - 1) % stepsPerEpoch) + 1 : 0;
432
+ // Send resume state to TUI
433
+ // Note: epoch is 1-indexed to match epochStart() convention
434
+ this.logger.resumeState({
435
+ step: this.currentStep,
436
+ epoch: this.currentEpoch + 1,
437
+ totalEpochs,
438
+ stepInEpoch,
439
+ totalStepsInEpoch: stepsPerEpoch,
440
+ metricsHistory: metricsHistory.map((m) => ({
441
+ step: Number(m.step),
442
+ loss: m.loss,
443
+ meanReward: m.meanReward,
444
+ stdAdvantage: m.stdAdvantage,
445
+ perplexity: m.perplexity ?? undefined,
446
+ tokenAccuracy: m.tokenAccuracy ?? undefined,
447
+ generationTimeMs: m.generationTimeMs ?? undefined,
448
+ trainingTimeMs: m.trainingTimeMs ?? undefined,
449
+ })),
450
+ aggregates: {
451
+ bestReward: aggregates.bestReward,
452
+ avgReward: aggregates.avgReward,
453
+ rewardCount: Number(aggregates.rewardCount),
454
+ bestLoss: aggregates.bestLoss,
455
+ avgLoss: aggregates.avgLoss,
456
+ lossCount: Number(aggregates.lossCount),
457
+ totalTokens: Number(aggregates.totalTokens),
458
+ avgGenerationTimeMs: aggregates.avgGenerationTimeMs,
459
+ avgTrainingTimeMs: aggregates.avgTrainingTimeMs,
460
+ },
461
+ });
462
+ this.logger.info(`Sent ${metricsHistory.length} historical metrics to TUI`);
463
+ }
464
+ catch (err) {
465
+ this.logger.warn(`Failed to send resume state to TUI: ${err}`);
466
+ }
467
+ }
468
+ /**
469
+ * Ensure output store is initialized (lazy initialization for low-level API users)
470
+ * Uses promise mutex to prevent race conditions from concurrent calls.
471
+ *
472
+ * Call this method from custom training loops before starting training
473
+ * to enable database recording and TUI database tab.
474
+ */
475
+ async ensureOutputStoreInitialized() {
476
+ if (this.outputStore)
477
+ return; // Already initialized
478
+ if (!this.config.outputStore?.enabled)
479
+ return; // Not enabled
480
+ // Use promise mutex to prevent concurrent initialization
481
+ if (this.outputStoreInitPromise) {
482
+ await this.outputStoreInitPromise;
483
+ return;
484
+ }
485
+ this.outputStoreInitPromise = this.initOutputStore();
486
+ try {
487
+ await this.outputStoreInitPromise;
488
+ }
489
+ catch (err) {
490
+ // Clear promise on failure to allow retry
491
+ this.outputStoreInitPromise = undefined;
492
+ throw err;
493
+ }
494
+ }
495
+ /**
496
+ * Get the output store (for querying recorded data)
497
+ */
498
+ getOutputStore() {
499
+ return this.outputStore;
500
+ }
501
+ /**
502
+ * Handle a command received from stdin
503
+ */
504
+ handleStdinCommand(cmd) {
505
+ switch (cmd) {
506
+ case 'PAUSE':
507
+ this.paused = true;
508
+ this.logger.paused(this.currentStep);
509
+ break;
510
+ case 'RESUME':
511
+ this.paused = false;
512
+ this.logger.resumed(this.currentStep);
513
+ break;
514
+ case 'SAVE_CHECKPOINT':
515
+ // Will be handled in the training loop
516
+ this.saveCheckpoint().catch(() => { });
517
+ break;
518
+ case 'STOP':
519
+ this.stopRequested = true;
520
+ break;
521
+ default:
522
+ // Handle SET commands (e.g., SET sample_display=best_worst)
523
+ if (cmd.startsWith('SET ')) {
524
+ const keyValue = cmd.slice(4); // Remove 'SET ' prefix
525
+ const eqIdx = keyValue.indexOf('=');
526
+ if (eqIdx > 0) {
527
+ const key = keyValue.slice(0, eqIdx);
528
+ const value = keyValue.slice(eqIdx + 1);
529
+ if (key === 'sample_display') {
530
+ if (value === 'all' || value === 'best_worst' || value === 'random') {
531
+ this.sampleDisplayMode = value;
532
+ }
533
+ }
534
+ }
535
+ }
536
+ break;
537
+ }
538
+ }
539
+ /**
540
+ * Wait for resume if paused, with polling
541
+ */
542
+ async waitForResume() {
543
+ while (this.paused && !this.stopRequested) {
544
+ await new Promise((resolve) => setTimeout(resolve, 100));
545
+ }
546
+ }
547
+ /**
548
+ * Create a trainer by loading a model from disk
549
+ *
550
+ * This is the recommended way to create a trainer for training runs.
551
+ * If resumeFromCheckpoint is set, loads from checkpoint instead of modelPath.
552
+ *
553
+ * @param config - Configuration including modelPath
554
+ * @returns Promise<GRPOTrainer>
555
+ */
556
+ static async create(config) {
557
+ if (!config.modelPath) {
558
+ throw new Error('modelPath is required when using GRPOTrainer.create()');
559
+ }
560
+ // Validate unsupported config options (fail fast)
561
+ if (config.advantageNormalization === false) {
562
+ throw new Error('advantageNormalization=false is not yet supported. Remove this option or set to true.');
563
+ }
564
+ if (config.weightDecay !== undefined && config.weightDecay !== 0.01) {
565
+ throw new Error('Custom weightDecay is not yet implemented. Optimizer uses simple SGD. Remove weightDecay from config or use default (0.01).');
566
+ }
567
+ if (config.rewardType === 'model') {
568
+ throw new Error('rewardType="model" is not implemented. Use rewardType="function" with a custom reward function.');
569
+ }
570
+ if (config.rewardModelPath) {
571
+ throw new Error('rewardModelPath is not implemented. Use rewardType="function" with a custom reward function.');
572
+ }
573
+ if (config.device && config.device !== 'metal') {
574
+ throw new Error(`device="${config.device}" is not supported. MLX only supports Metal GPU. Remove device from config.`);
575
+ }
576
+ // Create logger early (before model loading)
577
+ // TUI mode is auto-detected from MLX_TUI_MODE env var (set by mlx-tui)
578
+ const logger = createTrainingLogger({
579
+ logConsole: config.logConsole,
580
+ logJsonl: config.logJsonl,
581
+ outputDir: config.outputDir,
582
+ runName: config.runName,
583
+ logInterval: config.logInterval ?? 1,
584
+ });
585
+ let modelPath = config.modelPath;
586
+ let resumedState = null;
587
+ // Handle checkpoint resumption
588
+ if (config.resumeFromCheckpoint) {
589
+ const checkpointPath = config.resumeFromCheckpoint === 'latest'
590
+ ? GRPOTrainer.findLatestCheckpoint(config.outputDir)
591
+ : config.resumeFromCheckpoint;
592
+ if (checkpointPath) {
593
+ const statePath = join(checkpointPath, 'training_state.json');
594
+ if (existsSync(statePath)) {
595
+ resumedState = JSON.parse(readFileSync(statePath, 'utf-8'));
596
+ // Fallback: If training_state.json has step 0 but checkpoint name suggests otherwise,
597
+ // derive step from checkpoint name (e.g., checkpoint-8 → step 8)
598
+ // This handles cases where training_state.json was corrupted or overwritten
599
+ if (resumedState && resumedState.step === 0) {
600
+ const checkpointName = checkpointPath.split('/').pop() ?? '';
601
+ const match = checkpointName.match(/^checkpoint-(\d+)$/);
602
+ if (match) {
603
+ const derivedStep = parseInt(match[1], 10);
604
+ if (derivedStep > 0) {
605
+ logger.warn(`Checkpoint ${checkpointName} has step 0 in training_state.json but name suggests step ${derivedStep}. Using ${derivedStep}.`);
606
+ resumedState.step = derivedStep;
607
+ // Estimate epoch from step (will be refined by actual training data)
608
+ }
609
+ }
610
+ }
611
+ logger.info(`Resuming from checkpoint: ${checkpointPath} (step ${resumedState?.step}, epoch ${resumedState?.epoch})`);
612
+ }
613
+ // Load model weights from checkpoint
614
+ modelPath = checkpointPath;
615
+ }
616
+ else if (config.resumeFromCheckpoint === 'latest') {
617
+ logger.info('No checkpoint found, starting fresh training');
618
+ }
619
+ }
620
+ // Get model name for display
621
+ const modelName = modelPath.split('/').pop() ?? 'Unknown';
622
+ logger.status('loading', `Loading ${modelName}...`);
623
+ // Load model from disk (checkpoint or original)
624
+ const model = await Qwen3Model.loadPretrained(modelPath);
625
+ logger.status('loading', `${modelName} loaded`);
626
+ // Create trainer with the pre-created logger
627
+ const trainer = new GRPOTrainer(model, config, logger);
628
+ // Always store the original model path (for tokenizer files when saving checkpoints)
629
+ trainer.originalModelPath = config.modelPath;
630
+ // Restore training state if resuming
631
+ if (resumedState) {
632
+ trainer.currentStep = resumedState.step;
633
+ trainer.currentEpoch = resumedState.epoch;
634
+ // Restore dataset metadata for resume validation
635
+ if (resumedState.dataset) {
636
+ trainer.datasetMetadata = {
637
+ size: resumedState.dataset.size,
638
+ contentHash: resumedState.dataset.contentHash,
639
+ shuffleSeed: resumedState.dataset.shuffleSeed,
640
+ };
641
+ // Restore processed batch indices
642
+ if (resumedState.dataset.processedBatchIndices) {
643
+ trainer.processedBatchIndices = new Set(resumedState.dataset.processedBatchIndices);
644
+ }
645
+ logger.debug(`Restored dataset metadata: size=${resumedState.dataset.size}, hash=${resumedState.dataset.contentHash}, ` +
646
+ `${trainer.processedBatchIndices.size} processed batches`);
647
+ }
648
+ // If resuming from a regular checkpoint (not emergency), track it as last known good
649
+ // This allows recovery to fall back to this checkpoint if NaN gradients occur
650
+ if (modelPath && !modelPath.includes('emergency-')) {
651
+ trainer.lastGoodCheckpointPath = modelPath;
652
+ trainer.lastGoodCheckpointStep = resumedState.step;
653
+ logger.debug(`Initialized last good checkpoint from resumed state: step ${resumedState.step}`);
654
+ }
655
+ }
656
+ return trainer;
657
+ }
658
+ /**
659
+ * Find the latest checkpoint in the output directory
660
+ */
661
+ static findLatestCheckpoint(outputDir) {
662
+ if (!outputDir || !existsSync(outputDir)) {
663
+ return null;
664
+ }
665
+ const entries = readdirSync(outputDir, { withFileTypes: true });
666
+ const checkpoints = entries
667
+ .filter((e) => e.isDirectory() && e.name.startsWith('checkpoint-'))
668
+ .map((e) => ({
669
+ name: e.name,
670
+ step: parseInt(e.name.replace('checkpoint-', ''), 10),
671
+ path: join(outputDir, e.name),
672
+ }))
673
+ .filter((c) => !isNaN(c.step))
674
+ .sort((a, b) => b.step - a.step);
675
+ return checkpoints.length > 0 ? checkpoints[0].path : null;
676
+ }
677
+ /**
678
+ * Register a built-in reward function
679
+ *
680
+ * Built-in rewards run entirely in Rust with no FFI overhead.
681
+ *
682
+ * @example
683
+ * ```typescript
684
+ * // Tool use validation
685
+ * trainer.registerBuiltinReward({
686
+ * rewardType: 'ToolUse',
687
+ * allowedTools: ['search', 'calculate'],
688
+ * required: true,
689
+ * weight: 1.0,
690
+ * });
691
+ *
692
+ * // XML format validation
693
+ * trainer.registerBuiltinReward({
694
+ * rewardType: 'XmlFormat',
695
+ * requiredTags: ['thinking', 'answer'],
696
+ * weight: 0.5,
697
+ * });
698
+ *
699
+ * // Length-based reward
700
+ * trainer.registerBuiltinReward({
701
+ * rewardType: 'Length',
702
+ * minLength: 100,
703
+ * maxLength: 500,
704
+ * useChars: true,
705
+ * });
706
+ * ```
707
+ */
708
+ registerBuiltinReward(config) {
709
+ this.engine.registerBuiltinReward(config);
710
+ }
711
+ /**
712
+ * Set a custom JavaScript reward function
713
+ *
714
+ * The function will be called after generation to compute rewards.
715
+ *
716
+ * @param fn - Reward function that takes prompts and completions
717
+ */
718
+ setRewardFunction(fn) {
719
+ this.rewardFn = fn;
720
+ }
721
+ /**
722
+ * Generate completions for prompts
723
+ *
724
+ * Generates `groupSize` completions per prompt.
725
+ * Returns all data needed for training, including tokens and log probabilities.
726
+ *
727
+ * @param prompts - Array of chat conversations
728
+ * @returns GenerateBatchResult with completion texts and native generation data
729
+ */
730
+ async generateBatch(prompts) {
731
+ if (prompts.length === 0) {
732
+ return {
733
+ completionTexts: [],
734
+ nativeResult: {
735
+ completionTexts: [],
736
+ completionTokens: [],
737
+ completionLogprobs: [],
738
+ completionLengths: [],
739
+ finishReasons: [],
740
+ },
741
+ tokenCounts: [],
742
+ finishReasons: [],
743
+ };
744
+ }
745
+ // Call the native engine to generate completions with full data
746
+ const nativeResult = await this.engine.generateBatchForTraining(prompts);
747
+ return {
748
+ completionTexts: nativeResult.completionTexts,
749
+ nativeResult,
750
+ tokenCounts: nativeResult.completionLengths,
751
+ finishReasons: nativeResult.finishReasons,
752
+ };
753
+ }
754
+ /**
755
+ * Score completions using built-in rewards
756
+ *
757
+ * @param prompts - Prompt texts (one per completion)
758
+ * @param completions - Completion texts
759
+ * @returns Array of reward scores
760
+ */
761
+ scoreCompletions(prompts, completions) {
762
+ return this.engine.scoreCompletions(prompts, completions);
763
+ }
764
+ /**
765
+ * Score generations using the configured reward function.
766
+ *
767
+ * Builds RewardOutput array with structured completion data and passes to reward function.
768
+ *
769
+ * @param prompts - Array of chat conversations
770
+ * @param completions - Generated completion texts
771
+ * @param context - Context for the reward function
772
+ * @param groupSize - Number of completions per prompt (optional, defaults to config.groupSize)
773
+ * @param tokenCounts - Token counts for each completion (optional, defaults to 0s)
774
+ * @param finishReasons - Finish reasons from generation (optional, e.g. "eos", "length", "repetition")
775
+ * @returns Promise<Float32Array> of reward scores
776
+ */
777
+ async scoreGenerations(prompts, completions, context, groupSize, tokenCounts, finishReasons) {
778
+ const effectiveGroupSize = groupSize ?? this.config.groupSize ?? 4;
779
+ const expectedCompletions = prompts.length * effectiveGroupSize;
780
+ if (completions.length !== expectedCompletions) {
781
+ throw new Error(`Expected ${expectedCompletions} completions (${prompts.length} prompts × ${effectiveGroupSize} groupSize) but got ${completions.length}`);
782
+ }
783
+ if (!this.rewardFn && !this.engine.hasBuiltinRewards) {
784
+ throw new Error('No reward function configured. Set rewardFunction in config or call setRewardFunction()');
785
+ }
786
+ // Convert ChatMessage[][] to string[] for Rust function
787
+ const promptTexts = prompts.map((msgs) => msgs.map((m) => `${m.role}: ${m.content}`).join('\n'));
788
+ // Use provided token counts or default to 0
789
+ const effectiveTokenCounts = tokenCounts ?? completions.map(() => 0);
790
+ // Use provided finish reasons or default to empty (triggers inference fallback in Rust)
791
+ const effectiveFinishReasons = finishReasons ?? [];
792
+ // Build structured reward outputs using Rust function
793
+ const rewardOutputs = buildRewardOutputs(promptTexts, completions, effectiveTokenCounts, effectiveFinishReasons, effectiveGroupSize);
794
+ let rewards;
795
+ if (this.rewardFn) {
796
+ // Get timeout from config (default 60 seconds, 0 = no timeout)
797
+ const rewardTimeout = this.config.rewardTimeout ?? 60_000;
798
+ // Wrap reward function call with timeout
799
+ const rewardPromise = Promise.resolve(this.rewardFn(rewardOutputs, context));
800
+ rewards = await withTimeout(rewardPromise, rewardTimeout, `Reward function timed out after ${rewardTimeout}ms. ` +
801
+ `Consider increasing rewardTimeout in config or optimizing your reward function.`);
802
+ }
803
+ else {
804
+ // For built-in rewards, extract prompts and completions for legacy API
805
+ const promptStrings = rewardOutputs.map((o) => o.prompt);
806
+ const completionTexts = rewardOutputs.map((o) => o.completion.rawText);
807
+ rewards = this.scoreCompletions(promptStrings, completionTexts);
808
+ }
809
+ const rewardsArray = rewards instanceof Float32Array ? rewards : Float32Array.from(rewards);
810
+ if (rewardsArray.length !== expectedCompletions) {
811
+ throw new Error(`Reward function returned ${rewardsArray.length} rewards but expected ${expectedCompletions}`);
812
+ }
813
+ return rewardsArray;
814
+ }
815
+ /**
816
+ * Run a training step
817
+ *
818
+ * This method:
819
+ * 1. Generates completions with tokens and log probabilities
820
+ * 2. Computes rewards using the configured reward function
821
+ * 3. Trains using the SAME completions that were scored (no double-generation)
822
+ *
823
+ * @param prompts - Array of chat conversations
824
+ * @returns Training step metrics
825
+ */
826
+ async trainStep(prompts, context) {
827
+ const { metrics } = await this.trainStepAuto(prompts, context);
828
+ return metrics;
829
+ }
830
    /**
     * Run a complete training step with automatic reward computation.
     *
     * This method combines generation, reward scoring, and training into a single
     * Rust call, eliminating FFI overhead by keeping token data in Rust memory.
     *
     * 1. Generates completions with full token/logprob data (stays in Rust)
     * 2. Calls JS reward function with RewardOutput[]
     * 3. Performs training update using the in-memory data
     *
     * Side effects: increments `this.currentStep` and, when an output store is
     * configured, records the step to the database (best-effort, never fatal).
     *
     * @param prompts - Array of chat conversations
     * @param context - Context for the reward function
     * @returns Training metrics and generated completions
     * @throws Error when no reward function (JS or built-in) is configured
     */
    async trainStepAuto(prompts, context) {
        // Lazy initialize output store for low-level API users
        await this.ensureOutputStoreInitialized();
        if (!this.rewardFn && !this.engine.hasBuiltinRewards) {
            throw new Error('No reward function configured. Set rewardFunction in config or call setRewardFunction()');
        }
        // Create reward callback that parses JSON and converts output to number[].
        // The Rust side serializes Vec<RewardOutput> to JSON because complex nested types
        // don't convert properly through ThreadsafeFunction.
        // Note: With CalleeHandled=true (default), callback receives (err, value) format.
        // Using ThreadsafeFunction<T, Promise<R>> pattern so Rust can await the Promise.
        const rewardCallback = async (err, outputsJson) => {
            const rewardStart = Date.now();
            if (err) {
                throw new Error(`Reward callback error from Rust: ${err.message}`);
            }
            if (!outputsJson || outputsJson === 'null') {
                throw new Error(`Invalid JSON received from Rust: ${outputsJson}`);
            }
            // Parse JSON and convert snake_case to camelCase for TypeScript compatibility.
            // Rust's serde serializes as snake_case but TypeScript expects camelCase.
            const rawOutputs = JSON.parse(outputsJson);
            // Convert to RewardOutput format with proper camelCase properties
            const outputs = rawOutputs.map((o) => ({
                prompt: o.prompt,
                completion: {
                    text: o.completion.text,
                    rawText: o.completion.raw_text,
                    toolCalls: o.completion.tool_calls.map((tc) => ({
                        id: tc.id,
                        name: tc.name,
                        arguments: tc.arguments,
                        status: tc.status, // 'ok' | 'invalid_json' | 'missing_name'
                        error: tc.error,
                        rawContent: tc.raw_content,
                    })),
                    // Normalize nulls from serde to undefined for optional fields.
                    thinking: o.completion.thinking ?? undefined,
                    numTokens: o.completion.num_tokens,
                    finishReason: o.completion.finish_reason,
                },
                expectedAnswer: o.expected_answer ?? undefined,
            }));
            this.logger.info(` → Computing rewards for ${outputs.length} completions...`);
            let rewards;
            if (this.rewardFn) {
                // Get timeout from config (default 60 seconds, 0 = no timeout)
                const rewardTimeout = this.config.rewardTimeout ?? 60_000;
                // Wrap reward function call with timeout
                const rewardPromise = Promise.resolve(
                // @ts-expect-error context is optional
                this.rewardFn(outputs, context));
                rewards = await withTimeout(rewardPromise, rewardTimeout, `Reward function timed out after ${rewardTimeout}ms. ` +
                    `Consider increasing rewardTimeout in config or optimizing your reward function.`);
            }
            else {
                // Use built-in rewards
                const promptStrings = outputs.map((o) => o.prompt);
                const completionTexts = outputs.map((o) => o.completion.rawText);
                rewards = this.scoreCompletions(promptStrings, completionTexts);
            }
            // Convert Float32Array to plain number[] for NAPI compatibility
            let result;
            if (rewards instanceof Float32Array) {
                result = Array.from(rewards, (v) => Number(v));
            }
            else {
                result = rewards.map((v) => Number(v));
            }
            const rewardDuration = Date.now() - rewardStart;
            // NOTE(review): avgReward is NaN if result is empty — presumably the
            // engine never invokes the callback with zero completions; confirm.
            const avgReward = result.reduce((a, b) => a + b, 0) / result.length;
            this.logger.info(` → Rewards computed in ${rewardDuration}ms (avg=${avgReward.toFixed(2)})`);
            return result;
        };
        // Call unified Rust method - generation, scoring, and training in one FFI call.
        // Use recording method if output store is enabled.
        const recordOutputs = !!this.outputStore;
        const result = await this.engine.trainStepAuto(prompts, rewardCallback, recordOutputs);
        // Step counter is bumped BEFORE recording so the DB row carries the new step.
        this.currentStep++;
        // Record outputs to database if enabled
        if (this.outputStore && result.outputsJson) {
            try {
                await this.outputStore.recordStepFromOutputs(this.currentStep, result.metrics, result.outputsJson, result.rewards, this.config.groupSize ?? 4);
            }
            catch (err) {
                // Log error but don't fail training
                console.error('[OutputStore] Failed to record step:', err);
            }
        }
        return {
            metrics: { ...result.metrics, epoch: this.currentEpoch },
            completions: result.completions,
            rewards: result.rewards,
            completionLengths: result.completionLengths,
        };
    }
939
    /**
     * Increment the step counter (for custom training loops).
     *
     * Call this after each training step when using low-level APIs like
     * engine.trainStepWithGenerations() instead of trainer.trainStepAuto()
     * (which increments the counter itself).
     */
    incrementStep() {
        this.currentStep++;
    }
948
    /**
     * Get the current step number.
     *
     * Reflects the trainer-side counter (see incrementStep / trainStepAuto),
     * not the native engine's `step` property.
     */
    getStep() {
        return this.currentStep;
    }
954
    /**
     * Get the current epoch number.
     *
     * Reflects the trainer-side counter updated by train(), not the native
     * engine's `epoch` property.
     */
    getEpoch() {
        return this.currentEpoch;
    }
960
+ /**
961
+ * Record a training step to the output store database (for custom training loops)
962
+ *
963
+ * Use this when building custom training loops with engine.trainStepWithGenerations().
964
+ * The step number should be the value after incrementStep() was called.
965
+ *
966
+ * @param step - Step number
967
+ * @param metrics - Step metrics from the engine
968
+ * @param completions - Generated completion texts
969
+ * @param rewards - Reward values for each completion
970
+ * @param prompts - Prompt messages for each completion
971
+ */
972
+ async recordStepToDatabase(step, metrics, completions, rewards, prompts) {
973
+ if (!this.outputStore)
974
+ return;
975
+ const groupSize = this.config.groupSize ?? 4;
976
+ // Build outputs JSON in the format expected by recordStepFromOutputs
977
+ const outputs = completions.map((text, i) => ({
978
+ prompt: prompts[Math.floor(i / groupSize)] ?? '',
979
+ completion: {
980
+ text,
981
+ raw_text: text,
982
+ tool_calls: [],
983
+ thinking: null,
984
+ num_tokens: text.length, // Approximate
985
+ finish_reason: 'eos',
986
+ },
987
+ expected_answer: null,
988
+ }));
989
+ const outputsJson = JSON.stringify(outputs);
990
+ try {
991
+ await this.outputStore.recordStepFromOutputs(step, {
992
+ step,
993
+ loss: metrics.loss,
994
+ totalTokens: metrics.totalTokens,
995
+ meanReward: metrics.meanReward,
996
+ stdReward: metrics.stdReward,
997
+ meanAdvantage: metrics.meanAdvantage,
998
+ stdAdvantage: metrics.stdAdvantage,
999
+ generationTimeMs: 0,
1000
+ trainingTimeMs: 0,
1001
+ peakMemoryMb: 0,
1002
+ activeMemoryMb: 0,
1003
+ gradientsApplied: true,
1004
+ }, outputsJson, rewards, groupSize);
1005
+ }
1006
+ catch (err) {
1007
+ console.error('[OutputStore] Failed to record step:', err);
1008
+ }
1009
+ }
1010
    /**
     * Run a full training loop over a dataset.
     *
     * This is the high-level training API that handles:
     * - Epoch iteration
     * - Batching
     * - Generation and reward computation
     * - Logging (if configured)
     * - Checkpoint saving and resumption (with dataset-hash validation)
     * - TUI mode support (pause/resume, sample reporting)
     * - Emergency recovery after consecutive NaN gradients
     *
     * @param dataset - Array of DatasetExample items
     */
    async train(dataset) {
        if (dataset.length === 0) {
            return;
        }
        const numEpochs = this.config.numEpochs ?? 1;
        const batchSize = this.config.batchSize ?? 1;
        const saveInterval = this.config.saveInterval ?? 100;
        // Create output directory if needed
        if (this.config.outputDir && !existsSync(this.config.outputDir)) {
            mkdirSync(this.config.outputDir, { recursive: true });
        }
        // Calculate total steps per epoch BEFORE initOutputStore (needed for accurate resume state)
        const stepsPerEpoch = Math.ceil(dataset.length / batchSize);
        // Compute current dataset metadata (size + content hash) for resume validation
        const currentDatasetHash = computeDatasetHash(dataset);
        const currentDatasetMetadata = {
            size: dataset.length,
            contentHash: currentDatasetHash,
        };
        // Validate dataset on resume if we have previous metadata; mismatches only warn,
        // they do not abort training.
        if (this.datasetMetadata && this.currentStep > 0) {
            const prevMeta = this.datasetMetadata;
            if (prevMeta.size !== dataset.length) {
                this.logger.warn(`[Resume] Dataset size mismatch: checkpoint was trained on ${prevMeta.size} examples, ` +
                    `current dataset has ${dataset.length} examples. Batch indices may not align correctly.`);
            }
            if (prevMeta.contentHash !== currentDatasetHash) {
                this.logger.warn(`[Resume] Dataset content mismatch: checkpoint dataset hash ${prevMeta.contentHash}, ` +
                    `current dataset hash ${currentDatasetHash}. Dataset may have been modified or shuffled differently.`);
            }
            // Log validation result
            if (prevMeta.size === dataset.length && prevMeta.contentHash === currentDatasetHash) {
                this.logger.info(`[Resume] Dataset validated: ${dataset.length} examples, hash ${currentDatasetHash} (matches checkpoint)`);
            }
        }
        // Store current dataset metadata for future checkpoints
        this.datasetMetadata = currentDatasetMetadata;
        // Initialize output store if enabled (pass stepsPerEpoch for accurate batch display on resume)
        await this.initOutputStore(stepsPerEpoch);
        // Determine starting point based on resumed state
        const startEpoch = this.currentEpoch;
        const startStep = this.currentStep;
        const startBatchIdx = startStep > 0 ? startStep % stepsPerEpoch : 0;
        // Get model name from path (last path segment)
        const modelName = this.originalModelPath?.split('/').pop() ?? this.config.modelPath?.split('/').pop() ?? 'Unknown';
        // Log training start
        this.logger.init(modelName, {
            trainingType: 'grpo',
            numEpochs,
            batchSize,
            groupSize: this.config.groupSize ?? 4,
            learningRate: this.config.learningRate ?? 1e-6,
        }, dataset.length);
        if (startStep > 0) {
            this.logger.info(`Resuming from step ${startStep} (epoch ${startEpoch + 1}, batch ${startBatchIdx + 1}/${stepsPerEpoch})`);
        }
        for (let epoch = startEpoch; epoch < numEpochs; epoch++) {
            // Check for stop request
            if (this.stopRequested)
                break;
            this.currentEpoch = epoch;
            this.startEpoch();
            const epochStartTime = Date.now();
            // Log epoch start
            this.logger.epochStart(epoch, numEpochs, stepsPerEpoch);
            // Calculate starting batch index for this epoch (only the resume epoch skips ahead)
            const epochStartBatch = epoch === startEpoch && startStep > 0 ? startBatchIdx * batchSize : 0;
            // Iterate through batches
            for (let i = epochStartBatch; i < dataset.length; i += batchSize) {
                // Check for stop request
                if (this.stopRequested)
                    break;
                // Wait if paused (TUI pause/resume support)
                if (this.paused) {
                    await this.waitForResume();
                    if (this.stopRequested)
                        break;
                }
                // Calculate batch index (0-indexed within epoch)
                const batchIdx = Math.floor(i / batchSize);
                // Skip already processed batches on resume (from checkpoint's processedBatchIndices)
                if (this.processedBatchIndices.has(batchIdx)) {
                    this.logger.debug(`Skipping already processed batch ${batchIdx + 1}/${stepsPerEpoch} (from checkpoint)`);
                    continue;
                }
                const batch = dataset.slice(i, Math.min(i + batchSize, dataset.length));
                // Extract prompts and answers from batch
                const prompts = batch.map((ex) => ex.prompt);
                // Verbose logging for debugging stuck batches
                const batchNum = batchIdx + 1;
                this.logger.info(`Batch ${batchNum}/${stepsPerEpoch} starting (${prompts.length} prompts × ${this.config.groupSize ?? 4} groups)`);
                // Run training step with auto reward computation
                const stepStartTime = Date.now();
                const { metrics, completions, rewards, completionLengths } = await this.trainStepAuto(prompts);
                const stepDuration = Date.now() - stepStartTime;
                this.logger.info(`Batch ${batchNum}/${stepsPerEpoch} done in ${(stepDuration / 1000).toFixed(1)}s ` +
                    `(gen=${metrics.generationTimeMs?.toFixed(0) ?? '?'}ms, train=${metrics.trainingTimeMs?.toFixed(0) ?? '?'}ms, loss=${metrics.loss.toFixed(4)})`);
                // Log step metrics (logger handles TUI/console mode internally)
                this.logger.step(metrics, batchIdx, stepsPerEpoch);
                // Track processed batch for resume
                this.processedBatchIndices.add(batchIdx);
                // Report generation samples to TUI based on display mode.
                // In console mode, logger.generation() is a no-op.
                const groupSize = this.config.groupSize ?? 4;
                // Determine which sample indices to report based on display mode
                let indicesToReport;
                if (this.sampleDisplayMode === 'all') {
                    // Report all samples
                    indicesToReport = Array.from({ length: completions.length }, (_, i) => i);
                }
                else if (this.sampleDisplayMode === 'best_worst') {
                    // Find indices of best (max reward) and worst (min reward) samples
                    let bestIdx = 0;
                    let worstIdx = 0;
                    let bestReward = rewards[0];
                    let worstReward = rewards[0];
                    for (let j = 1; j < rewards.length; j++) {
                        if (rewards[j] > bestReward) {
                            bestReward = rewards[j];
                            bestIdx = j;
                        }
                        if (rewards[j] < worstReward) {
                            worstReward = rewards[j];
                            worstIdx = j;
                        }
                    }
                    // Avoid duplicates if best and worst are the same
                    indicesToReport = bestIdx === worstIdx ? [bestIdx] : [bestIdx, worstIdx];
                }
                else {
                    // random: pick 2 random samples (or fewer if completions.length < 2)
                    const numSamples = Math.min(2, completions.length);
                    const shuffled = Array.from({ length: completions.length }, (_, i) => i);
                    // Fisher-Yates partial shuffle for first numSamples
                    for (let k = 0; k < numSamples; k++) {
                        const randIdx = k + Math.floor(Math.random() * (shuffled.length - k));
                        [shuffled[k], shuffled[randIdx]] = [shuffled[randIdx], shuffled[k]];
                    }
                    indicesToReport = shuffled.slice(0, numSamples);
                }
                for (const j of indicesToReport) {
                    // Get the prompt for this completion (each prompt has groupSize completions)
                    const promptIdx = Math.floor(j / groupSize);
                    const promptMessages = prompts[promptIdx] ?? [];
                    // Format prompt as text (last user message is most relevant)
                    const lastUserMsg = promptMessages.filter((m) => m.role === 'user').pop();
                    const promptText = lastUserMsg?.content ?? '';
                    this.logger.generation({
                        index: j,
                        prompt: promptText,
                        completion: completions[j],
                        reward: rewards[j],
                        tokens: completionLengths[j] ?? this.config.maxCompletionLength ?? 256,
                    });
                }
                // Save checkpoint periodically
                if (this.config.outputDir && this.currentStep > 0 && this.currentStep % saveInterval === 0) {
                    const path = await this.saveCheckpoint();
                    if (path) {
                        this.logger.checkpoint(path, this.currentStep);
                    }
                }
                // Check for emergency checkpoint (triggered by consecutive NaN gradients)
                if (this.config.outputDir && this.engine.needsEmergencySave) {
                    this.logger.warn(`[EMERGENCY] Emergency save triggered after ${this.engine.nanGradientCount} consecutive NaN gradients at step ${this.currentStep}`);
                    // Save current (possibly corrupted) state for debugging
                    const debugCheckpointPath = `emergency-debug-step-${this.currentStep}`;
                    await this.saveCheckpoint(debugCheckpointPath, { isEmergency: true });
                    this.logger.info(`[EMERGENCY] Saved debug checkpoint with current (possibly corrupted) state to ${debugCheckpointPath}`);
                    // If we have a last known good checkpoint, copy it for recovery
                    if (this.lastGoodCheckpointPath && existsSync(this.lastGoodCheckpointPath)) {
                        this.logger.warn(`[EMERGENCY] Reverting to last good checkpoint from step ${this.lastGoodCheckpointStep}: ${this.lastGoodCheckpointPath}`);
                        // Copy last good checkpoint to a recovery location
                        const outputDir = this.config.outputDir ?? './outputs';
                        const recoveryPath = join(outputDir, `emergency-recovery-step-${this.lastGoodCheckpointStep}`);
                        try {
                            // Remove existing recovery checkpoint if it exists
                            if (existsSync(recoveryPath)) {
                                rmSync(recoveryPath, { recursive: true, force: true });
                            }
                            // Copy the last good checkpoint to recovery location
                            cpSync(this.lastGoodCheckpointPath, recoveryPath, { recursive: true });
                            this.logger.info(`[EMERGENCY] Copied last good checkpoint to ${recoveryPath}`);
                            this.logger.warn(`[EMERGENCY] Recovery checkpoint available at: ${recoveryPath}\n` +
                                ` To resume from the last good state, use: resumeFromCheckpoint: '${recoveryPath}'`);
                        }
                        catch (copyError) {
                            this.logger.error(`[EMERGENCY] Failed to copy last good checkpoint: ${copyError}`);
                        }
                    }
                    else {
                        this.logger.error(`[EMERGENCY] No previous good checkpoint available for recovery! ` +
                            `The debug checkpoint contains the current (potentially corrupted) model state.`);
                    }
                    // Clear the emergency flag
                    this.engine.clearEmergencySaveFlag();
                    // Log recovery guidance; training intentionally continues after an emergency save.
                    this.logger.warn(`[EMERGENCY] Training will continue, but model quality may be degraded.\n` +
                        ` Recommendations:\n` +
                        ` - Reduce learning rate (current: ${this.config.learningRate ?? 1e-6})\n` +
                        ` - Check training data for anomalies\n` +
                        ` - Consider stopping and resuming from the recovery checkpoint`);
                }
            }
            const epochEndTime = Date.now();
            const epochTimeSecs = (epochEndTime - epochStartTime) / 1000;
            this.endEpoch(epochTimeSecs);
            this.logger.epochEnd(epoch, numEpochs, epochTimeSecs);
            // Clear processed batch indices at epoch boundary (new epoch = new batches)
            this.processedBatchIndices.clear();
        }
        // Save final checkpoint (skipped when the user requested a stop)
        if (this.config.outputDir && !this.stopRequested) {
            const path = await this.saveCheckpoint('final');
            if (path) {
                this.logger.checkpoint(path, this.currentStep);
            }
        }
        // Log completion
        this.logger.complete(this.currentStep);
        // End output store run if active
        if (this.outputStore) {
            const status = this.stopRequested ? 'stopped' : 'completed';
            await this.outputStore.endRun(status);
            await this.outputStore.flush();
        }
        // Cleanup stdin interface
        if (this.stdinInterface) {
            this.stdinInterface.close();
        }
    }
1254
    /**
     * Save a checkpoint with model weights and training state.
     *
     * Regular checkpoints (non-emergency) are tracked as "last known good" checkpoints.
     * When NaN gradients occur, the emergency save logic can restore from the last good checkpoint.
     *
     * Order of operations: write training_state.json, save model weights, copy
     * tokenizer files from the original model path, then prune old checkpoints.
     *
     * @param name - Checkpoint name (default: "checkpoint-{step}")
     * @param options - Optional settings for checkpoint save behavior
     * @param options.isEmergency - If true, this is an emergency checkpoint (debug state, not "good")
     * @returns Path to saved checkpoint, or empty string if save was skipped due to corruption
     */
    async saveCheckpoint(name, options) {
        const isEmergency = options?.isEmergency ?? false;
        const checkpointName = name ?? `checkpoint-${this.currentStep}`;
        const outputDir = this.config.outputDir ?? './outputs';
        const checkpointPath = join(outputDir, checkpointName);
        // Create checkpoint directory
        if (!existsSync(checkpointPath)) {
            mkdirSync(checkpointPath, { recursive: true });
        }
        // Save training state with dataset metadata for resume validation
        const state = {
            step: this.currentStep,
            epoch: this.currentEpoch,
            timestamp: new Date().toISOString(),
            dataset: this.datasetMetadata
                ? {
                    size: this.datasetMetadata.size,
                    contentHash: this.datasetMetadata.contentHash,
                    shuffleSeed: this.datasetMetadata.shuffleSeed,
                    processedBatchIndices: Array.from(this.processedBatchIndices),
                }
                : undefined,
        };
        const statePath = join(checkpointPath, 'training_state.json');
        writeFileSync(statePath, JSON.stringify(state, null, 2));
        // Save model weights
        await this.model.saveModel(checkpointPath);
        // Copy tokenizer files from original model path (required for loading checkpoints)
        const tokenizerSource = this.originalModelPath ?? this.config.modelPath;
        if (tokenizerSource) {
            const tokenizerFiles = ['tokenizer.json', 'tokenizer_config.json', 'vocab.json', 'merges.txt'];
            for (const file of tokenizerFiles) {
                const srcPath = join(tokenizerSource, file);
                const destPath = join(checkpointPath, file);
                // Copy only files that exist at the source and are not already present.
                if (existsSync(srcPath) && !existsSync(destPath)) {
                    copyFileSync(srcPath, destPath);
                }
            }
        }
        this.logger.info(`Checkpoint saved: ${checkpointPath}`);
        // Track last checkpoint step for emergency save throttling
        this.lastCheckpointStep = this.currentStep;
        // Track as "last known good" checkpoint (only for regular saves, not emergency saves)
        if (!isEmergency) {
            this.lastGoodCheckpointPath = checkpointPath;
            this.lastGoodCheckpointStep = this.currentStep;
            this.logger.debug(`Tracked as last good checkpoint: step ${this.currentStep}`);
        }
        // Clean up old checkpoints to save disk space (maxCheckpoints <= 0 disables cleanup)
        const maxCheckpoints = this.config.maxCheckpoints ?? 3;
        if (maxCheckpoints > 0) {
            this.cleanupOldCheckpoints(outputDir, maxCheckpoints);
        }
        return checkpointPath;
    }
1320
+ /**
1321
+ * Remove old checkpoints, keeping only the most recent ones
1322
+ * Preserves 'final' and 'emergency-*' checkpoints
1323
+ */
1324
+ cleanupOldCheckpoints(outputDir, maxToKeep) {
1325
+ try {
1326
+ const entries = readdirSync(outputDir, { withFileTypes: true });
1327
+ // Find regular checkpoint directories (checkpoint-N pattern)
1328
+ const checkpoints = [];
1329
+ for (const entry of entries) {
1330
+ if (!entry.isDirectory())
1331
+ continue;
1332
+ // Skip 'final' and 'emergency-*' checkpoints
1333
+ if (entry.name === 'final' || entry.name.startsWith('emergency-'))
1334
+ continue;
1335
+ // Match checkpoint-N pattern
1336
+ const match = entry.name.match(/^checkpoint-(\d+)$/);
1337
+ if (match) {
1338
+ const checkpointPath = join(outputDir, entry.name);
1339
+ const stat = statSync(checkpointPath);
1340
+ checkpoints.push({
1341
+ name: entry.name,
1342
+ step: parseInt(match[1], 10),
1343
+ mtime: stat.mtime,
1344
+ });
1345
+ }
1346
+ }
1347
+ // Sort by step number descending (newest first)
1348
+ checkpoints.sort((a, b) => b.step - a.step);
1349
+ // Remove old checkpoints beyond maxToKeep
1350
+ if (checkpoints.length > maxToKeep) {
1351
+ const toRemove = checkpoints.slice(maxToKeep);
1352
+ for (const checkpoint of toRemove) {
1353
+ const checkpointPath = join(outputDir, checkpoint.name);
1354
+ rmSync(checkpointPath, { recursive: true, force: true });
1355
+ this.logger.debug(`Removed old checkpoint: ${checkpoint.name}`);
1356
+ }
1357
+ }
1358
+ }
1359
+ catch (error) {
1360
+ // Don't fail training if cleanup fails
1361
+ this.logger.warn(`Failed to cleanup old checkpoints: ${error}`);
1362
+ }
1363
+ }
1364
    /**
     * Start a new training epoch.
     *
     * Delegates epoch bookkeeping to the native engine.
     */
    startEpoch() {
        this.engine.startEpoch();
    }
1370
    /**
     * End the current epoch and get metrics.
     *
     * @param epochTimeSecs - Duration of the epoch in seconds
     * @returns Epoch metrics as produced by the native engine
     */
    endEpoch(epochTimeSecs) {
        return this.engine.endEpoch(epochTimeSecs);
    }
1378
    /**
     * Reset the trainer for a new training run.
     *
     * Resets only native-engine state; trainer-side counters
     * (currentStep/currentEpoch) are not touched here.
     */
    reset() {
        this.engine.reset();
    }
1384
    /**
     * Get current training step.
     *
     * Reads the native engine's counter (may be a BigInt, hence the Number()
     * conversion); distinct from getStep(), which reads the trainer-side counter.
     */
    get step() {
        return Number(this.engine.step);
    }
1390
    /**
     * Get current epoch (as tracked by the native engine).
     */
    get epoch() {
        return this.engine.epoch;
    }
1396
    /**
     * Get current micro-step within gradient accumulation
     * (as tracked by the native engine).
     */
    get microStep() {
        return this.engine.microStep;
    }
1402
    /**
     * Check if built-in rewards are configured
     * (i.e. registerBuiltinReward has been called on the engine).
     */
    get hasBuiltinRewards() {
        return this.engine.hasBuiltinRewards;
    }
1408
    /**
     * Get names of registered reward functions (from the native engine).
     */
    get rewardNames() {
        return this.engine.rewardNames;
    }
1414
    /**
     * Get the underlying native engine.
     *
     * For advanced use cases that need direct access. Mutating the returned
     * engine bypasses this trainer's bookkeeping (step counters, output store).
     */
    getNativeEngine() {
        return this.engine;
    }
1422
+ }
1423
/**
 * Create a standalone reward registry for testing rewards.
 *
 * Returns a fresh native registry instance; useful for exercising built-in
 * reward configs outside of a trainer.
 *
 * @example
 * ```typescript
 * const registry = createRewardRegistry();
 * registry.register({
 *   rewardType: 'ToolUse',
 *   allowedTools: ['search'],
 * });
 *
 * const score = registry.score('prompt', 'completion with <tool_call>...</tool_call>');
 * ```
 */
export function createRewardRegistry() {
    return new NativeRewardRegistry();
}
+ }