@mlx-node/trl 0.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. package/dist/data/dataset.d.ts +22 -0
  2. package/dist/data/dataset.d.ts.map +1 -0
  3. package/dist/data/dataset.js +142 -0
  4. package/dist/data/sft-dataset.d.ts +156 -0
  5. package/dist/data/sft-dataset.d.ts.map +1 -0
  6. package/dist/data/sft-dataset.js +415 -0
  7. package/dist/index.d.ts +33 -0
  8. package/dist/index.d.ts.map +1 -0
  9. package/dist/index.js +47 -0
  10. package/dist/trainers/grpo-config.d.ts +42 -0
  11. package/dist/trainers/grpo-config.d.ts.map +1 -0
  12. package/dist/trainers/grpo-config.js +220 -0
  13. package/dist/trainers/grpo-entropy.d.ts +33 -0
  14. package/dist/trainers/grpo-entropy.d.ts.map +1 -0
  15. package/dist/trainers/grpo-entropy.js +18 -0
  16. package/dist/trainers/grpo-trainer.d.ts +602 -0
  17. package/dist/trainers/grpo-trainer.d.ts.map +1 -0
  18. package/dist/trainers/grpo-trainer.js +1439 -0
  19. package/dist/trainers/sft-config.d.ts +32 -0
  20. package/dist/trainers/sft-config.d.ts.map +1 -0
  21. package/dist/trainers/sft-config.js +186 -0
  22. package/dist/trainers/sft-trainer.d.ts +141 -0
  23. package/dist/trainers/sft-trainer.d.ts.map +1 -0
  24. package/dist/trainers/sft-trainer.js +502 -0
  25. package/dist/trainers/training-logger.d.ts +375 -0
  26. package/dist/trainers/training-logger.d.ts.map +1 -0
  27. package/dist/trainers/training-logger.js +542 -0
  28. package/dist/types.d.ts +54 -0
  29. package/dist/types.d.ts.map +1 -0
  30. package/dist/types.js +1 -0
  31. package/dist/utils/path-security.d.ts +51 -0
  32. package/dist/utils/path-security.d.ts.map +1 -0
  33. package/dist/utils/path-security.js +69 -0
  34. package/dist/utils/xml-parser.d.ts +6 -0
  35. package/dist/utils/xml-parser.d.ts.map +1 -0
  36. package/dist/utils/xml-parser.js +184 -0
  37. package/package.json +29 -0
@@ -0,0 +1,502 @@
1
+ /**
2
+ * SFT (Supervised Fine-Tuning) Trainer
3
+ *
4
+ * This module provides a Rust-native SFT training engine for training
5
+ * models on fixed prompt-completion pairs using cross-entropy loss.
6
+ *
7
+ * ## Key Features
8
+ * - Training loop runs in Rust (eliminates FFI overhead)
9
+ * - Cross-entropy loss with completion masking (ignore_index=-100)
10
+ * - Label smoothing support
11
+ * - Gradient accumulation and clipping
12
+ * - High-level train() method for full training runs
13
+ * - Low-level trainStep() for custom training loops
14
+ *
15
+ * ## Usage
16
+ * ```typescript
17
+ * const trainer = await SFTTrainer.create({
18
+ * modelPath: './model',
19
+ * learningRate: 2e-5,
20
+ * numEpochs: 3,
21
+ * });
22
+ * await trainer.train(dataset);
23
+ * ```
24
+ */
25
+ import { existsSync, mkdirSync, writeFileSync, readFileSync, readdirSync, copyFileSync, rmSync } from 'node:fs';
26
+ import { join, parse } from 'node:path';
27
+ import * as readline from 'node:readline';
28
+ import { SftTrainingEngine, Qwen3Model, Qwen3Tokenizer, MxArray, } from '@mlx-node/core';
29
+ import { getDefaultSFTConfig, mergeSFTConfig } from './sft-config';
30
+ import { loadSFTDataset } from '../data/sft-dataset';
31
+ import { createTrainingLogger } from './training-logger';
32
+ // Re-export types
33
+ export { SftTrainingEngine } from '@mlx-node/core';
34
/**
 * SFT Trainer - Rust-Native Training Engine
 *
 * Provides a TypeScript-friendly interface to the Rust SFT training engine.
 * The heavy training loop (forward/backward, optimizer step, gradient
 * accumulation) runs inside the native `SftTrainingEngine`; this class
 * orchestrates epochs, batching, checkpointing, TUI control, and logging.
 */
export class SFTTrainer {
    engine;
    model;
    tokenizer;
    config;
    currentEpoch = 0;
    currentStep = 0;
    /** Original model path (for tokenizer files when saving checkpoints) */
    originalModelPath;
    // TUI state
    paused = false;
    stopRequested = false;
    stdinInterface;
    logger;
    sampleDisplayMode = 'all';
    /**
     * Create a new SFT trainer from a model
     *
     * @param model - Pre-loaded Qwen3 model
     * @param tokenizer - Pre-loaded tokenizer
     * @param config - Training configuration (not mutated)
     * @param logger - Optional custom logger
     */
    constructor(model, tokenizer, config = {}, logger) {
        // Auto-detect TUI mode from the environment. Work on a shallow copy
        // when we need to override tui_mode, so the caller-owned config
        // object is never mutated (previous behavior wrote into it).
        const tuiModeFromEnv = process.env.MLX_TUI_MODE === '1';
        const effectiveConfig = tuiModeFromEnv && config.tui_mode === undefined
            ? { ...config, tui_mode: true }
            : config;
        this.config = mergeSFTConfig(getDefaultSFTConfig(), effectiveConfig);
        this.model = model;
        this.tokenizer = tokenizer;
        // Create or use provided logger. In TUI mode, console logging is
        // suppressed so the TUI owns stdout.
        this.logger =
            logger ??
                createTrainingLogger({
                    logConsole: !this.config.tui_mode,
                    logJsonl: this.config.log_jsonl,
                    outputDir: this.config.output_dir,
                    runName: this.config.run_name,
                    logInterval: this.config.logging_steps,
                });
        // Convert the snake_case user config into the native engine config.
        const engineConfig = {
            learningRate: this.config.learning_rate,
            gradientAccumulationSteps: this.config.gradient_accumulation_steps,
            gradientClipNorm: this.config.max_grad_norm,
            weightDecay: this.config.weight_decay,
            labelSmoothing: this.config.label_smoothing,
        };
        this.engine = new SftTrainingEngine(model, engineConfig);
        // Setup stdin handler if TUI mode
        if (this.config.tui_mode) {
            this.setupStdinHandler();
        }
    }
    /**
     * Setup stdin handler for TUI control commands (PAUSE/RESUME/STOP/...).
     */
    setupStdinHandler() {
        if (!this.config.tui_mode)
            return;
        this.stdinInterface = readline.createInterface({
            input: process.stdin,
            output: process.stdout,
            terminal: false,
        });
        this.stdinInterface.on('line', (line) => {
            const cmd = line.trim();
            this.handleStdinCommand(cmd);
        });
    }
    /**
     * Handle a single command received from stdin.
     *
     * Recognized commands: PAUSE, RESUME, SAVE_CHECKPOINT, STOP, and
     * `SET key=value` (currently only `sample_display`).
     */
    handleStdinCommand(cmd) {
        switch (cmd) {
            case 'PAUSE':
                this.paused = true;
                this.logger.paused(this.currentStep);
                break;
            case 'RESUME':
                this.paused = false;
                this.logger.resumed(this.currentStep);
                break;
            case 'SAVE_CHECKPOINT':
                // Fire-and-forget, but never swallow the failure silently:
                // surface it through the logger so the operator can see it.
                this.saveCheckpoint().catch((err) => {
                    this.logger.warn(`On-demand checkpoint save failed: ${err}`);
                });
                break;
            case 'STOP':
                this.stopRequested = true;
                break;
            default:
                // Handle SET commands (e.g., SET sample_display=best_worst)
                if (cmd.startsWith('SET ')) {
                    const keyValue = cmd.slice(4); // Remove 'SET ' prefix
                    const eqIdx = keyValue.indexOf('=');
                    if (eqIdx > 0) {
                        const key = keyValue.slice(0, eqIdx);
                        const value = keyValue.slice(eqIdx + 1);
                        if (key === 'sample_display') {
                            if (value === 'all' || value === 'best_worst' || value === 'random') {
                                this.sampleDisplayMode = value;
                            }
                        }
                    }
                }
                break;
        }
    }
    /**
     * Busy-wait (100 ms poll) until the trainer is resumed or a stop is
     * requested. Called from the training loop while paused.
     */
    async waitForResume() {
        while (this.paused && !this.stopRequested) {
            await new Promise((resolve) => setTimeout(resolve, 100));
        }
    }
    /**
     * Create a trainer by loading a model from disk.
     *
     * Handles optional checkpoint resumption (`resume_from_checkpoint`,
     * including the special value 'latest').
     *
     * @param config - Configuration including model_name
     * @returns Promise<SFTTrainer>
     * @throws Error if model_name is missing
     */
    static async create(config) {
        if (!config.model_name) {
            throw new Error('model_name is required when using SFTTrainer.create()');
        }
        // Honor MLX_TUI_MODE here too, so the early logger is consistent
        // with the constructor's auto-detection (previously the env var was
        // ignored at this point and loading messages leaked to the console).
        const tuiMode = config.tui_mode ?? (process.env.MLX_TUI_MODE === '1');
        // Create logger early so loading progress can be reported.
        const logger = createTrainingLogger({
            logConsole: !tuiMode,
            logJsonl: config.log_jsonl ?? true,
            outputDir: config.output_dir,
            runName: config.run_name,
            logInterval: config.logging_steps ?? 10,
        });
        let modelPath = config.model_name;
        let resumedState = null;
        // Handle checkpoint resumption
        if (config.resume_from_checkpoint) {
            const checkpointPath = config.resume_from_checkpoint === 'latest'
                ? SFTTrainer.findLatestCheckpoint(config.output_dir)
                : config.resume_from_checkpoint;
            if (checkpointPath) {
                const statePath = join(checkpointPath, 'training_state.json');
                if (existsSync(statePath)) {
                    resumedState = JSON.parse(readFileSync(statePath, 'utf-8'));
                    logger.info(`Resuming from checkpoint: ${checkpointPath} (step ${resumedState?.step}, epoch ${resumedState?.epoch})`);
                }
                modelPath = checkpointPath;
            }
            else if (config.resume_from_checkpoint === 'latest') {
                logger.info('No checkpoint found, starting fresh training');
            }
        }
        // Get model name for display
        const modelName = parse(modelPath).base || 'Unknown';
        logger.status('loading', `Loading ${modelName}...`);
        // Load model and tokenizer
        const model = await Qwen3Model.loadPretrained(modelPath);
        const tokenizer = await Qwen3Tokenizer.fromPretrained(join(modelPath, 'tokenizer.json'));
        logger.status('loading', `${modelName} loaded`);
        // Create trainer
        const trainer = new SFTTrainer(model, tokenizer, config, logger);
        trainer.originalModelPath = config.model_name;
        // Restore training state if resuming
        if (resumedState) {
            trainer.currentStep = resumedState.step;
            trainer.currentEpoch = resumedState.epoch;
            // Also restore engine state to sync step/epoch accounting
            trainer.engine.restoreState(resumedState.step, resumedState.epoch);
        }
        return trainer;
    }
    /**
     * Find the latest `checkpoint-N` directory in the output directory.
     *
     * @returns Absolute checkpoint path, or null if none exists.
     */
    static findLatestCheckpoint(outputDir) {
        if (!outputDir || !existsSync(outputDir)) {
            return null;
        }
        const entries = readdirSync(outputDir, { withFileTypes: true });
        const checkpoints = entries
            .filter((e) => e.isDirectory() && e.name.startsWith('checkpoint-'))
            .map((e) => ({
            name: e.name,
            step: parseInt(e.name.replace('checkpoint-', ''), 10),
            path: join(outputDir, e.name),
        }))
            // Number.isNaN (non-coercing) instead of the global isNaN.
            .filter((c) => !Number.isNaN(c.step))
            .sort((a, b) => b.step - a.step);
        return checkpoints.length > 0 ? checkpoints[0].path : null;
    }
    /**
     * Run a single training step.
     *
     * @param batch - Tokenized batch with inputIds, labels, and shape
     * @returns Training step metrics plus the current epoch
     */
    async trainStep(batch) {
        // Convert Int32Array to MxArray
        const inputIds = MxArray.fromInt32(batch.inputIds, BigInt64Array.from(batch.shape.map(BigInt)));
        const labels = MxArray.fromInt32(batch.labels, BigInt64Array.from(batch.shape.map(BigInt)));
        // Call native engine
        const metrics = await this.engine.trainStep(inputIds, labels);
        // Sync step with engine when gradients are applied (fixes gradient accumulation accounting)
        // Note: metrics.step is i64 from Rust; JS number may lose precision beyond 2^53-1,
        // but such step counts are unrealistic for any practical training run.
        if (metrics.gradientsApplied) {
            this.currentStep = Number(metrics.step);
        }
        return {
            metrics,
            epoch: this.currentEpoch,
        };
    }
    /**
     * Run a full training loop over a dataset.
     *
     * Supports mid-epoch resumption, pause/resume via TUI commands,
     * periodic + final checkpointing, and emergency saves on NaN gradients.
     *
     * @param dataset - SFT dataset object or path to a JSONL file
     */
    async train(dataset) {
        // Load dataset if path provided
        let sftDataset;
        if (typeof dataset === 'string') {
            sftDataset = await loadSFTDataset(dataset, this.tokenizer, {
                maxSeqLength: this.config.max_seq_length,
                completionOnly: this.config.completion_only,
                seed: this.config.seed,
                limit: this.config.max_train_samples > 0 ? this.config.max_train_samples : undefined,
            });
        }
        else {
            sftDataset = dataset;
        }
        if (sftDataset.length === 0) {
            return;
        }
        const numEpochs = this.config.num_epochs;
        const batchSize = this.config.batch_size;
        const saveInterval = this.config.save_steps;
        // Create output directory
        if (this.config.output_dir && !existsSync(this.config.output_dir)) {
            mkdirSync(this.config.output_dir, { recursive: true });
        }
        // Calculate steps per epoch (in batches)
        const stepsPerEpoch = sftDataset.numBatches(batchSize);
        // Compute resume position (all logic centralized in Rust)
        const resumePos = this.engine.computeResumePosition(stepsPerEpoch);
        const effectiveStartEpoch = resumePos.startEpoch;
        const effectiveStartBatchIdx = resumePos.startBatchIdx;
        // Get model name for display (prefer the original path over config)
        const modelName = (this.originalModelPath ? parse(this.originalModelPath).base : null) ??
            (this.config.model_name ? parse(this.config.model_name).base : null) ??
            'Unknown';
        // Log training start
        this.logger.init(modelName, {
            trainingType: 'sft',
            numEpochs,
            batchSize,
            groupSize: 1, // SFT doesn't use groups
            learningRate: this.config.learning_rate,
        }, sftDataset.length);
        if (this.currentStep > 0) {
            if (resumePos.isEpochBoundary) {
                this.logger.info(`Resuming at epoch boundary, advancing to epoch ${effectiveStartEpoch + 1}`);
            }
            else {
                this.logger.info(`Resuming from step ${this.currentStep} (epoch ${effectiveStartEpoch + 1}, batch ${effectiveStartBatchIdx + 1}/${stepsPerEpoch})`);
            }
        }
        for (let epoch = effectiveStartEpoch; epoch < numEpochs; epoch++) {
            if (this.stopRequested)
                break;
            this.currentEpoch = epoch;
            this.engine.startEpoch(epoch);
            const epochStartTime = Date.now();
            // Use epoch-based shuffle (deterministic, reproducible via seed + epoch)
            sftDataset.shuffleForEpoch(epoch);
            // Log epoch start
            this.logger.epochStart(epoch, numEpochs, stepsPerEpoch);
            // Determine batch start position for this epoch (only the resume
            // epoch skips batches; subsequent epochs start at 0).
            const batchStart = epoch === effectiveStartEpoch ? effectiveStartBatchIdx : 0;
            // Iterate through batches
            let batchIdx = 0;
            for await (const batch of sftDataset.batches(batchSize)) {
                if (this.stopRequested)
                    break;
                // Skip batches if resuming mid-epoch
                if (batchIdx < batchStart) {
                    batchIdx++;
                    continue;
                }
                // Wait if paused
                if (this.paused) {
                    await this.waitForResume();
                    if (this.stopRequested)
                        break;
                }
                // Run training step
                const { metrics } = await this.trainStep(batch);
                // Log step metrics (only when gradients are applied to avoid duplicate logs during accumulation)
                if (metrics.gradientsApplied) {
                    this.logger.step({
                        step: this.currentStep,
                        loss: metrics.loss,
                        totalTokens: metrics.totalTokens,
                        // SFT-specific metrics (no reward/advantage!)
                        perplexity: Math.exp(metrics.loss),
                        // Token accuracy is not currently tracked in the SFT engine
                        // Could be added later if the Rust engine exposes it
                        trainingTimeMs: metrics.trainingTimeMs,
                    }, batchIdx, stepsPerEpoch);
                    // Save checkpoint periodically
                    if (this.config.output_dir && this.currentStep > 0 && this.currentStep % saveInterval === 0) {
                        const path = await this.saveCheckpoint();
                        if (path) {
                            this.logger.checkpoint(path, this.currentStep);
                        }
                    }
                }
                // Check for emergency checkpoint (NaN gradients detected in Rust)
                if (this.config.output_dir && this.engine.needsEmergencySave()) {
                    this.logger.warn(`[EMERGENCY] Saving emergency checkpoint at step ${this.currentStep} due to NaN gradients`);
                    await this.saveCheckpoint(`emergency-checkpoint-${this.currentStep}`);
                    this.engine.clearEmergencySave();
                }
                batchIdx++;
            }
            // Flush any remaining accumulated gradients (TRL parity)
            const flushed = this.engine.flushGradients();
            if (flushed) {
                this.currentStep = this.engine.getStep();
                // Check if flush step aligns with save interval
                if (this.config.output_dir && this.currentStep > 0 && this.currentStep % saveInterval === 0) {
                    const path = await this.saveCheckpoint();
                    if (path) {
                        this.logger.checkpoint(path, this.currentStep);
                    }
                }
            }
            const epochEndTime = Date.now();
            const epochTimeSecs = (epochEndTime - epochStartTime) / 1000;
            this.engine.endEpoch(epochTimeSecs);
            this.logger.epochEnd(epoch, numEpochs, epochTimeSecs);
        }
        // Save final checkpoint
        if (this.config.output_dir && !this.stopRequested) {
            const path = await this.saveCheckpoint('final');
            if (path) {
                this.logger.checkpoint(path, this.currentStep);
            }
        }
        // Log completion
        this.logger.complete(this.currentStep);
        // Cleanup
        if (this.stdinInterface) {
            this.stdinInterface.close();
        }
    }
    /**
     * Save a checkpoint with model weights and training state.
     *
     * Writes training_state.json, the trained model weights, and copies the
     * tokenizer files next to them; then prunes old checkpoints.
     *
     * @param name - Checkpoint name (default: "checkpoint-{step}")
     * @returns Path to saved checkpoint
     */
    async saveCheckpoint(name) {
        const checkpointName = name ?? `checkpoint-${this.currentStep}`;
        const outputDir = this.config.output_dir ?? './outputs';
        const checkpointPath = join(outputDir, checkpointName);
        // Create checkpoint directory
        if (!existsSync(checkpointPath)) {
            mkdirSync(checkpointPath, { recursive: true });
        }
        // Save training state
        const state = {
            step: this.currentStep,
            epoch: this.currentEpoch,
            timestamp: new Date().toISOString(),
            trainerType: 'sft',
        };
        const statePath = join(checkpointPath, 'training_state.json');
        writeFileSync(statePath, JSON.stringify(state, null, 2));
        // Save model weights (use trained model from engine, not original)
        const trainedModel = this.engine.getModel();
        await trainedModel.saveModel(checkpointPath);
        // Copy tokenizer files
        const tokenizerSource = this.originalModelPath ?? this.config.model_name;
        if (tokenizerSource) {
            const tokenizerFiles = ['tokenizer.json', 'tokenizer_config.json', 'vocab.json', 'merges.txt'];
            for (const file of tokenizerFiles) {
                const srcPath = join(tokenizerSource, file);
                const destPath = join(checkpointPath, file);
                if (existsSync(srcPath) && !existsSync(destPath)) {
                    copyFileSync(srcPath, destPath);
                }
            }
        }
        this.logger.info(`Checkpoint saved: ${checkpointPath}`);
        // Clean up old checkpoints
        const maxCheckpoints = this.config.max_checkpoints;
        if (maxCheckpoints > 0) {
            this.cleanupOldCheckpoints(outputDir, maxCheckpoints);
        }
        return checkpointPath;
    }
    /**
     * Remove old `checkpoint-N` directories, keeping only the most recent
     * `maxToKeep`. 'final' and 'emergency-*' checkpoints are never removed.
     */
    cleanupOldCheckpoints(outputDir, maxToKeep) {
        try {
            const entries = readdirSync(outputDir, { withFileTypes: true });
            const checkpoints = [];
            for (const entry of entries) {
                if (!entry.isDirectory())
                    continue;
                if (entry.name === 'final' || entry.name.startsWith('emergency-'))
                    continue;
                const match = entry.name.match(/^checkpoint-(\d+)$/);
                if (match) {
                    checkpoints.push({
                        name: entry.name,
                        step: parseInt(match[1], 10),
                    });
                }
            }
            checkpoints.sort((a, b) => b.step - a.step);
            if (checkpoints.length > maxToKeep) {
                const toRemove = checkpoints.slice(maxToKeep);
                for (const checkpoint of toRemove) {
                    const checkpointPath = join(outputDir, checkpoint.name);
                    rmSync(checkpointPath, { recursive: true, force: true });
                    this.logger.debug(`Removed old checkpoint: ${checkpoint.name}`);
                }
            }
        }
        catch (error) {
            // Best-effort cleanup: failure to prune must not abort training.
            this.logger.warn(`Failed to cleanup old checkpoints: ${error}`);
        }
    }
    /**
     * Get current training step (engine-authoritative).
     */
    get step() {
        return this.engine.getStep();
    }
    /**
     * Get current epoch (engine-authoritative).
     */
    get epoch() {
        return this.engine.getEpoch();
    }
    /**
     * Get the underlying (trained) model for inference.
     */
    getModel() {
        return this.engine.getModel();
    }
    /**
     * Get the tokenizer.
     */
    getTokenizer() {
        return this.tokenizer;
    }
}