@mlx-node/trl 0.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/data/dataset.d.ts +22 -0
- package/dist/data/dataset.d.ts.map +1 -0
- package/dist/data/dataset.js +142 -0
- package/dist/data/sft-dataset.d.ts +156 -0
- package/dist/data/sft-dataset.d.ts.map +1 -0
- package/dist/data/sft-dataset.js +415 -0
- package/dist/index.d.ts +33 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +47 -0
- package/dist/trainers/grpo-config.d.ts +42 -0
- package/dist/trainers/grpo-config.d.ts.map +1 -0
- package/dist/trainers/grpo-config.js +220 -0
- package/dist/trainers/grpo-entropy.d.ts +33 -0
- package/dist/trainers/grpo-entropy.d.ts.map +1 -0
- package/dist/trainers/grpo-entropy.js +18 -0
- package/dist/trainers/grpo-trainer.d.ts +602 -0
- package/dist/trainers/grpo-trainer.d.ts.map +1 -0
- package/dist/trainers/grpo-trainer.js +1439 -0
- package/dist/trainers/sft-config.d.ts +32 -0
- package/dist/trainers/sft-config.d.ts.map +1 -0
- package/dist/trainers/sft-config.js +186 -0
- package/dist/trainers/sft-trainer.d.ts +141 -0
- package/dist/trainers/sft-trainer.d.ts.map +1 -0
- package/dist/trainers/sft-trainer.js +502 -0
- package/dist/trainers/training-logger.d.ts +375 -0
- package/dist/trainers/training-logger.d.ts.map +1 -0
- package/dist/trainers/training-logger.js +542 -0
- package/dist/types.d.ts +54 -0
- package/dist/types.d.ts.map +1 -0
- package/dist/types.js +1 -0
- package/dist/utils/path-security.d.ts +51 -0
- package/dist/utils/path-security.d.ts.map +1 -0
- package/dist/utils/path-security.js +69 -0
- package/dist/utils/xml-parser.d.ts +6 -0
- package/dist/utils/xml-parser.d.ts.map +1 -0
- package/dist/utils/xml-parser.js +184 -0
- package/package.json +29 -0
|
@@ -0,0 +1,502 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* SFT (Supervised Fine-Tuning) Trainer
|
|
3
|
+
*
|
|
4
|
+
* This module provides a Rust-native SFT training engine for training
|
|
5
|
+
* models on fixed prompt-completion pairs using cross-entropy loss.
|
|
6
|
+
*
|
|
7
|
+
* ## Key Features
|
|
8
|
+
* - Training loop runs in Rust (eliminates FFI overhead)
|
|
9
|
+
* - Cross-entropy loss with completion masking (ignore_index=-100)
|
|
10
|
+
* - Label smoothing support
|
|
11
|
+
* - Gradient accumulation and clipping
|
|
12
|
+
* - High-level train() method for full training runs
|
|
13
|
+
* - Low-level trainStep() for custom training loops
|
|
14
|
+
*
|
|
15
|
+
* ## Usage
|
|
16
|
+
* ```typescript
|
|
17
|
+
* const trainer = await SFTTrainer.create({
|
|
18
|
+
* modelPath: './model',
|
|
19
|
+
* learningRate: 2e-5,
|
|
20
|
+
* numEpochs: 3,
|
|
21
|
+
* });
|
|
22
|
+
* await trainer.train(dataset);
|
|
23
|
+
* ```
|
|
24
|
+
*/
|
|
25
|
+
import { existsSync, mkdirSync, writeFileSync, readFileSync, readdirSync, copyFileSync, rmSync } from 'node:fs';
|
|
26
|
+
import { join, parse } from 'node:path';
|
|
27
|
+
import * as readline from 'node:readline';
|
|
28
|
+
import { SftTrainingEngine, Qwen3Model, Qwen3Tokenizer, MxArray, } from '@mlx-node/core';
|
|
29
|
+
import { getDefaultSFTConfig, mergeSFTConfig } from './sft-config';
|
|
30
|
+
import { loadSFTDataset } from '../data/sft-dataset';
|
|
31
|
+
import { createTrainingLogger } from './training-logger';
|
|
32
|
+
// Re-export types
|
|
33
|
+
export { SftTrainingEngine } from '@mlx-node/core';
|
|
34
|
+
/**
 * SFT Trainer - Rust-Native Training Engine
 *
 * Provides a TypeScript-friendly interface to the Rust SFT training engine.
 */
export class SFTTrainer {
    // Native training engine (SftTrainingEngine from @mlx-node/core)
    engine;
    // Model passed to the constructor (the trained model lives in the engine)
    model;
    // Tokenizer used for dataset loading; files are copied into checkpoints
    tokenizer;
    // Merged configuration: defaults from getDefaultSFTConfig() + user overrides
    config;
    // Epoch/step counters mirrored from the engine for logging and checkpoints
    currentEpoch = 0;
    currentStep = 0;
    /** Original model path (for tokenizer files when saving checkpoints) */
    originalModelPath;
    // TUI state
    paused = false;
    stopRequested = false;
    // readline interface for stdin control commands; only set in TUI mode
    stdinInterface;
    // Training logger (console and/or JSONL, see createTrainingLogger)
    logger;
    // Sample display mode accepted via `SET sample_display=...` stdin command:
    // 'all' | 'best_worst' | 'random'
    sampleDisplayMode = 'all';
|
|
54
|
+
/**
|
|
55
|
+
* Create a new SFT trainer from a model
|
|
56
|
+
*
|
|
57
|
+
* @param model - Pre-loaded Qwen3 model
|
|
58
|
+
* @param tokenizer - Pre-loaded tokenizer
|
|
59
|
+
* @param config - Training configuration
|
|
60
|
+
* @param logger - Optional custom logger
|
|
61
|
+
*/
|
|
62
|
+
constructor(model, tokenizer, config = {}, logger) {
|
|
63
|
+
// Auto-detect TUI mode from environment variable
|
|
64
|
+
const tuiModeFromEnv = process.env.MLX_TUI_MODE === '1';
|
|
65
|
+
if (tuiModeFromEnv && config.tui_mode === undefined) {
|
|
66
|
+
config.tui_mode = true;
|
|
67
|
+
}
|
|
68
|
+
this.config = mergeSFTConfig(getDefaultSFTConfig(), config);
|
|
69
|
+
this.model = model;
|
|
70
|
+
this.tokenizer = tokenizer;
|
|
71
|
+
// Create or use provided logger
|
|
72
|
+
this.logger =
|
|
73
|
+
logger ??
|
|
74
|
+
createTrainingLogger({
|
|
75
|
+
logConsole: !this.config.tui_mode,
|
|
76
|
+
logJsonl: this.config.log_jsonl,
|
|
77
|
+
outputDir: this.config.output_dir,
|
|
78
|
+
runName: this.config.run_name,
|
|
79
|
+
logInterval: this.config.logging_steps,
|
|
80
|
+
});
|
|
81
|
+
// Convert to native config
|
|
82
|
+
const engineConfig = {
|
|
83
|
+
learningRate: this.config.learning_rate,
|
|
84
|
+
gradientAccumulationSteps: this.config.gradient_accumulation_steps,
|
|
85
|
+
gradientClipNorm: this.config.max_grad_norm,
|
|
86
|
+
weightDecay: this.config.weight_decay,
|
|
87
|
+
labelSmoothing: this.config.label_smoothing,
|
|
88
|
+
};
|
|
89
|
+
this.engine = new SftTrainingEngine(model, engineConfig);
|
|
90
|
+
// Setup stdin handler if TUI mode
|
|
91
|
+
if (this.config.tui_mode) {
|
|
92
|
+
this.setupStdinHandler();
|
|
93
|
+
}
|
|
94
|
+
}
|
|
95
|
+
/**
|
|
96
|
+
* Setup stdin handler for TUI control commands
|
|
97
|
+
*/
|
|
98
|
+
setupStdinHandler() {
|
|
99
|
+
if (!this.config.tui_mode)
|
|
100
|
+
return;
|
|
101
|
+
this.stdinInterface = readline.createInterface({
|
|
102
|
+
input: process.stdin,
|
|
103
|
+
output: process.stdout,
|
|
104
|
+
terminal: false,
|
|
105
|
+
});
|
|
106
|
+
this.stdinInterface.on('line', (line) => {
|
|
107
|
+
const cmd = line.trim();
|
|
108
|
+
this.handleStdinCommand(cmd);
|
|
109
|
+
});
|
|
110
|
+
}
|
|
111
|
+
/**
|
|
112
|
+
* Handle a command received from stdin
|
|
113
|
+
*/
|
|
114
|
+
handleStdinCommand(cmd) {
|
|
115
|
+
switch (cmd) {
|
|
116
|
+
case 'PAUSE':
|
|
117
|
+
this.paused = true;
|
|
118
|
+
this.logger.paused(this.currentStep);
|
|
119
|
+
break;
|
|
120
|
+
case 'RESUME':
|
|
121
|
+
this.paused = false;
|
|
122
|
+
this.logger.resumed(this.currentStep);
|
|
123
|
+
break;
|
|
124
|
+
case 'SAVE_CHECKPOINT':
|
|
125
|
+
this.saveCheckpoint().catch(() => { });
|
|
126
|
+
break;
|
|
127
|
+
case 'STOP':
|
|
128
|
+
this.stopRequested = true;
|
|
129
|
+
break;
|
|
130
|
+
default:
|
|
131
|
+
// Handle SET commands (e.g., SET sample_display=best_worst)
|
|
132
|
+
if (cmd.startsWith('SET ')) {
|
|
133
|
+
const keyValue = cmd.slice(4); // Remove 'SET ' prefix
|
|
134
|
+
const eqIdx = keyValue.indexOf('=');
|
|
135
|
+
if (eqIdx > 0) {
|
|
136
|
+
const key = keyValue.slice(0, eqIdx);
|
|
137
|
+
const value = keyValue.slice(eqIdx + 1);
|
|
138
|
+
if (key === 'sample_display') {
|
|
139
|
+
if (value === 'all' || value === 'best_worst' || value === 'random') {
|
|
140
|
+
this.sampleDisplayMode = value;
|
|
141
|
+
}
|
|
142
|
+
}
|
|
143
|
+
}
|
|
144
|
+
}
|
|
145
|
+
break;
|
|
146
|
+
}
|
|
147
|
+
}
|
|
148
|
+
/**
|
|
149
|
+
* Wait for resume if paused
|
|
150
|
+
*/
|
|
151
|
+
async waitForResume() {
|
|
152
|
+
while (this.paused && !this.stopRequested) {
|
|
153
|
+
await new Promise((resolve) => setTimeout(resolve, 100));
|
|
154
|
+
}
|
|
155
|
+
}
|
|
156
|
+
/**
 * Create a trainer by loading a model from disk
 *
 * Handles optional checkpoint resumption: with
 * `resume_from_checkpoint: 'latest'` the newest `checkpoint-N` directory
 * under `output_dir` is used; an explicit path is used verbatim.
 *
 * @param config - Configuration including modelPath
 * @returns Promise<SFTTrainer>
 * @throws Error if config.model_name is missing
 */
static async create(config) {
    if (!config.model_name) {
        throw new Error('model_name is required when using SFTTrainer.create()');
    }
    // Create logger early (before model loading, so load progress is visible)
    const logger = createTrainingLogger({
        logConsole: !config.tui_mode,
        logJsonl: config.log_jsonl ?? true,
        outputDir: config.output_dir,
        runName: config.run_name,
        logInterval: config.logging_steps ?? 10,
    });
    let modelPath = config.model_name;
    let resumedState = null;
    // Handle checkpoint resumption
    if (config.resume_from_checkpoint) {
        const checkpointPath = config.resume_from_checkpoint === 'latest'
            ? SFTTrainer.findLatestCheckpoint(config.output_dir)
            : config.resume_from_checkpoint;
        if (checkpointPath) {
            // training_state.json carries the step/epoch to restore
            const statePath = join(checkpointPath, 'training_state.json');
            if (existsSync(statePath)) {
                resumedState = JSON.parse(readFileSync(statePath, 'utf-8'));
                logger.info(`Resuming from checkpoint: ${checkpointPath} (step ${resumedState?.step}, epoch ${resumedState?.epoch})`);
            }
            // Load weights from the checkpoint instead of the base model
            modelPath = checkpointPath;
        }
        else if (config.resume_from_checkpoint === 'latest') {
            // 'latest' requested but no checkpoint directory was found
            logger.info('No checkpoint found, starting fresh training');
        }
    }
    // Get model name for display
    const modelName = parse(modelPath).base || 'Unknown';
    logger.status('loading', `Loading ${modelName}...`);
    // Load model and tokenizer
    const model = await Qwen3Model.loadPretrained(modelPath);
    const tokenizer = await Qwen3Tokenizer.fromPretrained(join(modelPath, 'tokenizer.json'));
    logger.status('loading', `${modelName} loaded`);
    // Create trainer
    const trainer = new SFTTrainer(model, tokenizer, config, logger);
    // Remember the base model path so checkpoints can copy tokenizer files
    // from it even when weights were loaded from a checkpoint directory
    trainer.originalModelPath = config.model_name;
    // Restore training state if resuming
    if (resumedState) {
        trainer.currentStep = resumedState.step;
        trainer.currentEpoch = resumedState.epoch;
        // Also restore engine state to sync step/epoch accounting
        trainer.engine.restoreState(resumedState.step, resumedState.epoch);
    }
    return trainer;
}
|
|
212
|
+
/**
|
|
213
|
+
* Find the latest checkpoint in the output directory
|
|
214
|
+
*/
|
|
215
|
+
static findLatestCheckpoint(outputDir) {
|
|
216
|
+
if (!outputDir || !existsSync(outputDir)) {
|
|
217
|
+
return null;
|
|
218
|
+
}
|
|
219
|
+
const entries = readdirSync(outputDir, { withFileTypes: true });
|
|
220
|
+
const checkpoints = entries
|
|
221
|
+
.filter((e) => e.isDirectory() && e.name.startsWith('checkpoint-'))
|
|
222
|
+
.map((e) => ({
|
|
223
|
+
name: e.name,
|
|
224
|
+
step: parseInt(e.name.replace('checkpoint-', ''), 10),
|
|
225
|
+
path: join(outputDir, e.name),
|
|
226
|
+
}))
|
|
227
|
+
.filter((c) => !isNaN(c.step))
|
|
228
|
+
.sort((a, b) => b.step - a.step);
|
|
229
|
+
return checkpoints.length > 0 ? checkpoints[0].path : null;
|
|
230
|
+
}
|
|
231
|
+
/**
|
|
232
|
+
* Run a single training step
|
|
233
|
+
*
|
|
234
|
+
* @param batch - Tokenized batch with input_ids and labels
|
|
235
|
+
* @returns Training step metrics
|
|
236
|
+
*/
|
|
237
|
+
async trainStep(batch) {
|
|
238
|
+
// Convert Int32Array to MxArray
|
|
239
|
+
const inputIds = MxArray.fromInt32(batch.inputIds, BigInt64Array.from(batch.shape.map(BigInt)));
|
|
240
|
+
const labels = MxArray.fromInt32(batch.labels, BigInt64Array.from(batch.shape.map(BigInt)));
|
|
241
|
+
// Call native engine
|
|
242
|
+
const metrics = await this.engine.trainStep(inputIds, labels);
|
|
243
|
+
// Sync step with engine when gradients are applied (fixes gradient accumulation accounting)
|
|
244
|
+
// Note: metrics.step is i64 from Rust; JS number may lose precision beyond 2^53-1,
|
|
245
|
+
// but such step counts are unrealistic for any practical training run.
|
|
246
|
+
if (metrics.gradientsApplied) {
|
|
247
|
+
this.currentStep = Number(metrics.step);
|
|
248
|
+
}
|
|
249
|
+
return {
|
|
250
|
+
metrics,
|
|
251
|
+
epoch: this.currentEpoch,
|
|
252
|
+
};
|
|
253
|
+
}
|
|
254
|
+
/**
 * Run a full training loop over a dataset
 *
 * Supports mid-epoch resume (batch skipping), TUI pause/stop commands,
 * periodic and emergency checkpointing, and a final checkpoint on
 * successful completion.
 *
 * @param dataset - SFT dataset or path to JSONL file
 */
async train(dataset) {
    // Load dataset if path provided
    let sftDataset;
    if (typeof dataset === 'string') {
        sftDataset = await loadSFTDataset(dataset, this.tokenizer, {
            maxSeqLength: this.config.max_seq_length,
            completionOnly: this.config.completion_only,
            seed: this.config.seed,
            // max_train_samples <= 0 means "no limit"
            limit: this.config.max_train_samples > 0 ? this.config.max_train_samples : undefined,
        });
    }
    else {
        sftDataset = dataset;
    }
    // Nothing to train on
    if (sftDataset.length === 0) {
        return;
    }
    const numEpochs = this.config.num_epochs;
    const batchSize = this.config.batch_size;
    const saveInterval = this.config.save_steps;
    // Create output directory
    if (this.config.output_dir && !existsSync(this.config.output_dir)) {
        mkdirSync(this.config.output_dir, { recursive: true });
    }
    // Calculate steps per epoch (in batches)
    const stepsPerEpoch = sftDataset.numBatches(batchSize);
    // Compute resume position (all logic centralized in Rust)
    const resumePos = this.engine.computeResumePosition(stepsPerEpoch);
    const effectiveStartEpoch = resumePos.startEpoch;
    const effectiveStartBatchIdx = resumePos.startBatchIdx;
    // Get model name for display: prefer the original path, fall back to
    // the configured model_name, then 'Unknown'
    const modelName = (this.originalModelPath ? parse(this.originalModelPath).base : null) ??
        (this.config.model_name ? parse(this.config.model_name).base : null) ??
        'Unknown';
    // Log training start
    this.logger.init(modelName, {
        trainingType: 'sft',
        numEpochs,
        batchSize,
        groupSize: 1, // SFT doesn't use groups
        learningRate: this.config.learning_rate,
    }, sftDataset.length);
    if (this.currentStep > 0) {
        if (resumePos.isEpochBoundary) {
            this.logger.info(`Resuming at epoch boundary, advancing to epoch ${effectiveStartEpoch + 1}`);
        }
        else {
            this.logger.info(`Resuming from step ${this.currentStep} (epoch ${effectiveStartEpoch + 1}, batch ${effectiveStartBatchIdx + 1}/${stepsPerEpoch})`);
        }
    }
    for (let epoch = effectiveStartEpoch; epoch < numEpochs; epoch++) {
        if (this.stopRequested)
            break;
        this.currentEpoch = epoch;
        this.engine.startEpoch(epoch);
        const epochStartTime = Date.now();
        // Use epoch-based shuffle (deterministic, reproducible via seed + epoch)
        sftDataset.shuffleForEpoch(epoch);
        // Log epoch start
        this.logger.epochStart(epoch, numEpochs, stepsPerEpoch);
        // Determine batch start position for this epoch (non-zero only in
        // the first epoch of a mid-epoch resume)
        const batchStart = epoch === effectiveStartEpoch ? effectiveStartBatchIdx : 0;
        // Iterate through batches
        let batchIdx = 0;
        for await (const batch of sftDataset.batches(batchSize)) {
            if (this.stopRequested)
                break;
            // Skip batches if resuming mid-epoch
            if (batchIdx < batchStart) {
                batchIdx++;
                continue;
            }
            // Wait if paused (TUI PAUSE command); re-check stop afterwards
            if (this.paused) {
                await this.waitForResume();
                if (this.stopRequested)
                    break;
            }
            // Run training step
            const { metrics } = await this.trainStep(batch);
            // Log step metrics (only when gradients are applied to avoid duplicate logs during accumulation)
            if (metrics.gradientsApplied) {
                this.logger.step({
                    step: this.currentStep,
                    loss: metrics.loss,
                    totalTokens: metrics.totalTokens,
                    // SFT-specific metrics (no reward/advantage!)
                    perplexity: Math.exp(metrics.loss),
                    // Token accuracy is not currently tracked in the SFT engine
                    // Could be added later if the Rust engine exposes it
                    trainingTimeMs: metrics.trainingTimeMs,
                }, batchIdx, stepsPerEpoch);
                // Save checkpoint periodically
                if (this.config.output_dir && this.currentStep > 0 && this.currentStep % saveInterval === 0) {
                    const path = await this.saveCheckpoint();
                    if (path) {
                        this.logger.checkpoint(path, this.currentStep);
                    }
                }
            }
            // Check for emergency checkpoint (engine flags NaN gradients)
            if (this.config.output_dir && this.engine.needsEmergencySave()) {
                this.logger.warn(`[EMERGENCY] Saving emergency checkpoint at step ${this.currentStep} due to NaN gradients`);
                await this.saveCheckpoint(`emergency-checkpoint-${this.currentStep}`);
                this.engine.clearEmergencySave();
            }
            batchIdx++;
        }
        // Flush any remaining accumulated gradients (TRL parity)
        const flushed = this.engine.flushGradients();
        if (flushed) {
            this.currentStep = this.engine.getStep();
            // Check if flush step aligns with save interval
            if (this.config.output_dir && this.currentStep > 0 && this.currentStep % saveInterval === 0) {
                const path = await this.saveCheckpoint();
                if (path) {
                    this.logger.checkpoint(path, this.currentStep);
                }
            }
        }
        const epochEndTime = Date.now();
        const epochTimeSecs = (epochEndTime - epochStartTime) / 1000;
        this.engine.endEpoch(epochTimeSecs);
        this.logger.epochEnd(epoch, numEpochs, epochTimeSecs);
    }
    // Save final checkpoint (skipped when training was stopped early)
    if (this.config.output_dir && !this.stopRequested) {
        const path = await this.saveCheckpoint('final');
        if (path) {
            this.logger.checkpoint(path, this.currentStep);
        }
    }
    // Log completion
    this.logger.complete(this.currentStep);
    // Cleanup: close the stdin listener so the process can exit
    if (this.stdinInterface) {
        this.stdinInterface.close();
    }
}
|
|
398
|
+
/**
 * Save a checkpoint with model weights and training state
 *
 * Writes training_state.json (step/epoch/timestamp), saves the trained
 * model weights from the engine, and copies tokenizer files from the
 * original model directory so the checkpoint is loadable standalone.
 *
 * @param name - Checkpoint name (default: "checkpoint-{step}")
 * @returns Path to saved checkpoint
 */
async saveCheckpoint(name) {
    const checkpointName = name ?? `checkpoint-${this.currentStep}`;
    const outputDir = this.config.output_dir ?? './outputs';
    const checkpointPath = join(outputDir, checkpointName);
    // Create checkpoint directory
    if (!existsSync(checkpointPath)) {
        mkdirSync(checkpointPath, { recursive: true });
    }
    // Save training state (consumed by SFTTrainer.create on resume)
    const state = {
        step: this.currentStep,
        epoch: this.currentEpoch,
        timestamp: new Date().toISOString(),
        trainerType: 'sft',
    };
    const statePath = join(checkpointPath, 'training_state.json');
    writeFileSync(statePath, JSON.stringify(state, null, 2));
    // Save model weights (use trained model from engine, not original)
    const trainedModel = this.engine.getModel();
    await trainedModel.saveModel(checkpointPath);
    // Copy tokenizer files
    const tokenizerSource = this.originalModelPath ?? this.config.model_name;
    if (tokenizerSource) {
        const tokenizerFiles = ['tokenizer.json', 'tokenizer_config.json', 'vocab.json', 'merges.txt'];
        for (const file of tokenizerFiles) {
            const srcPath = join(tokenizerSource, file);
            const destPath = join(checkpointPath, file);
            // Only copy files that exist and were not already written
            if (existsSync(srcPath) && !existsSync(destPath)) {
                copyFileSync(srcPath, destPath);
            }
        }
    }
    this.logger.info(`Checkpoint saved: ${checkpointPath}`);
    // Clean up old checkpoints (max_checkpoints <= 0 disables cleanup)
    const maxCheckpoints = this.config.max_checkpoints;
    if (maxCheckpoints > 0) {
        this.cleanupOldCheckpoints(outputDir, maxCheckpoints);
    }
    return checkpointPath;
}
|
|
444
|
+
/**
|
|
445
|
+
* Remove old checkpoints, keeping only the most recent ones
|
|
446
|
+
*/
|
|
447
|
+
cleanupOldCheckpoints(outputDir, maxToKeep) {
|
|
448
|
+
try {
|
|
449
|
+
const entries = readdirSync(outputDir, { withFileTypes: true });
|
|
450
|
+
const checkpoints = [];
|
|
451
|
+
for (const entry of entries) {
|
|
452
|
+
if (!entry.isDirectory())
|
|
453
|
+
continue;
|
|
454
|
+
if (entry.name === 'final' || entry.name.startsWith('emergency-'))
|
|
455
|
+
continue;
|
|
456
|
+
const match = entry.name.match(/^checkpoint-(\d+)$/);
|
|
457
|
+
if (match) {
|
|
458
|
+
checkpoints.push({
|
|
459
|
+
name: entry.name,
|
|
460
|
+
step: parseInt(match[1], 10),
|
|
461
|
+
});
|
|
462
|
+
}
|
|
463
|
+
}
|
|
464
|
+
checkpoints.sort((a, b) => b.step - a.step);
|
|
465
|
+
if (checkpoints.length > maxToKeep) {
|
|
466
|
+
const toRemove = checkpoints.slice(maxToKeep);
|
|
467
|
+
for (const checkpoint of toRemove) {
|
|
468
|
+
const checkpointPath = join(outputDir, checkpoint.name);
|
|
469
|
+
rmSync(checkpointPath, { recursive: true, force: true });
|
|
470
|
+
this.logger.debug(`Removed old checkpoint: ${checkpoint.name}`);
|
|
471
|
+
}
|
|
472
|
+
}
|
|
473
|
+
}
|
|
474
|
+
catch (error) {
|
|
475
|
+
this.logger.warn(`Failed to cleanup old checkpoints: ${error}`);
|
|
476
|
+
}
|
|
477
|
+
}
|
|
478
|
+
/**
 * Get current training step
 *
 * Read from the engine, which is the source of truth for step accounting.
 */
get step() {
    return this.engine.getStep();
}
/**
 * Get current epoch
 *
 * Read from the engine, which is the source of truth for epoch accounting.
 */
get epoch() {
    return this.engine.getEpoch();
}
/**
 * Get the underlying model for inference
 *
 * Returns the engine's (trained) model, not the instance passed to the
 * constructor.
 */
getModel() {
    return this.engine.getModel();
}
/**
 * Get the tokenizer
 */
getTokenizer() {
    return this.tokenizer;
}
|
|
502
|
+
}
|