@mlx-node/trl 0.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. package/dist/data/dataset.d.ts +22 -0
  2. package/dist/data/dataset.d.ts.map +1 -0
  3. package/dist/data/dataset.js +142 -0
  4. package/dist/data/sft-dataset.d.ts +156 -0
  5. package/dist/data/sft-dataset.d.ts.map +1 -0
  6. package/dist/data/sft-dataset.js +415 -0
  7. package/dist/index.d.ts +33 -0
  8. package/dist/index.d.ts.map +1 -0
  9. package/dist/index.js +47 -0
  10. package/dist/trainers/grpo-config.d.ts +42 -0
  11. package/dist/trainers/grpo-config.d.ts.map +1 -0
  12. package/dist/trainers/grpo-config.js +220 -0
  13. package/dist/trainers/grpo-entropy.d.ts +33 -0
  14. package/dist/trainers/grpo-entropy.d.ts.map +1 -0
  15. package/dist/trainers/grpo-entropy.js +18 -0
  16. package/dist/trainers/grpo-trainer.d.ts +602 -0
  17. package/dist/trainers/grpo-trainer.d.ts.map +1 -0
  18. package/dist/trainers/grpo-trainer.js +1439 -0
  19. package/dist/trainers/sft-config.d.ts +32 -0
  20. package/dist/trainers/sft-config.d.ts.map +1 -0
  21. package/dist/trainers/sft-config.js +186 -0
  22. package/dist/trainers/sft-trainer.d.ts +141 -0
  23. package/dist/trainers/sft-trainer.d.ts.map +1 -0
  24. package/dist/trainers/sft-trainer.js +502 -0
  25. package/dist/trainers/training-logger.d.ts +375 -0
  26. package/dist/trainers/training-logger.d.ts.map +1 -0
  27. package/dist/trainers/training-logger.js +542 -0
  28. package/dist/types.d.ts +54 -0
  29. package/dist/types.d.ts.map +1 -0
  30. package/dist/types.js +1 -0
  31. package/dist/utils/path-security.d.ts +51 -0
  32. package/dist/utils/path-security.d.ts.map +1 -0
  33. package/dist/utils/path-security.js +69 -0
  34. package/dist/utils/xml-parser.d.ts +6 -0
  35. package/dist/utils/xml-parser.d.ts.map +1 -0
  36. package/dist/utils/xml-parser.js +184 -0
  37. package/package.json +29 -0
@@ -0,0 +1,32 @@
1
+ export declare class SFTConfigError extends Error { // Thrown by the SFT config helpers: unreadable file, unparsable TOML, unknown key, or a value of the wrong type. The implementation sets `name` to 'SFTConfigError'.
2
+ constructor(message: string); // message: human-readable description of the configuration problem
3
+ }
4
+ export interface SFTTrainerConfig { // Flat SFT training configuration. Defaults quoted below are the values in DEFAULT_SFT_CONFIG; "integer" means the loader rejects non-integer values for that key.
5
+ model_name: string; // default 'Qwen/Qwen3-0.6B'
6
+ output_dir: string; // default 'outputs/sft'
7
+ run_name: string; // default 'sft-run'
8
+ learning_rate: number; // default 2e-5
9
+ batch_size: number; // integer; default 4
10
+ gradient_accumulation_steps: number; // integer; default 1
11
+ num_epochs: number; // integer; default 3
12
+ max_train_samples: number; // integer; default 0 (presumably means "no limit" - confirm in the trainer)
13
+ max_grad_norm: number; // default 1.0
14
+ weight_decay: number; // default 0.01
15
+ max_seq_length: number; // integer; default 2048
16
+ completion_only: boolean; // default false (the implementation notes this was chosen for TRL parity)
17
+ label_smoothing: number; // default 0.0
18
+ logging_steps: number; // integer; default 10
19
+ save_steps: number; // integer; default 100
20
+ max_checkpoints: number; // integer; default 3
21
+ log_jsonl: boolean; // default true
22
+ tui_mode: boolean; // default false
23
+ seed: number; // integer; default 42
24
+ resume_from_checkpoint: string; // default '' (presumably means "do not resume" - confirm in the trainer)
25
+ }
26
+ declare const DEFAULT_SFT_CONFIG: SFTTrainerConfig; // Built-in defaults. Frozen at runtime with Object.freeze, so use getDefaultSFTConfig() when a mutable copy is needed.
27
+ export declare function getDefaultSFTConfig(): SFTTrainerConfig; // Returns a fresh shallow copy of DEFAULT_SFT_CONFIG.
28
+ export declare function mergeSFTConfig(base: SFTTrainerConfig, update: Partial<SFTTrainerConfig>): SFTTrainerConfig; // Returns a new object: `base` overlaid with the defined values of `update` (undefined values are skipped). Throws SFTConfigError on an unknown key. Values are not type-checked here.
29
+ export declare function loadSFTTomlConfig(filePath: string): SFTTrainerConfig; // Reads and parses the TOML file, keeps only recognized top-level keys (others are skipped silently), coerces their values, and overlays them on the defaults. Throws SFTConfigError on read, parse, or value errors.
30
+ export declare function applySFTOverrides(config: SFTTrainerConfig, overrides: string[]): SFTTrainerConfig; // Applies "key=value" strings on top of `config` and returns a new object. Throws SFTConfigError on a missing '=', an unknown key, or an invalid value.
31
+ export { DEFAULT_SFT_CONFIG };
32
+ //# sourceMappingURL=sft-config.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"sft-config.d.ts","sourceRoot":"","sources":["../../src/trainers/sft-config.ts"],"names":[],"mappings":"AAIA,qBAAa,cAAe,SAAQ,KAAK;gBAC3B,OAAO,EAAE,MAAM;CAI5B;AAED,MAAM,WAAW,gBAAgB;IAE/B,UAAU,EAAE,MAAM,CAAC;IACnB,UAAU,EAAE,MAAM,CAAC;IACnB,QAAQ,EAAE,MAAM,CAAC;IAGjB,aAAa,EAAE,MAAM,CAAC;IACtB,UAAU,EAAE,MAAM,CAAC;IACnB,2BAA2B,EAAE,MAAM,CAAC;IACpC,UAAU,EAAE,MAAM,CAAC;IACnB,iBAAiB,EAAE,MAAM,CAAC;IAC1B,aAAa,EAAE,MAAM,CAAC;IACtB,YAAY,EAAE,MAAM,CAAC;IAGrB,cAAc,EAAE,MAAM,CAAC;IACvB,eAAe,EAAE,OAAO,CAAC;IACzB,eAAe,EAAE,MAAM,CAAC;IAGxB,aAAa,EAAE,MAAM,CAAC;IACtB,UAAU,EAAE,MAAM,CAAC;IACnB,eAAe,EAAE,MAAM,CAAC;IACxB,SAAS,EAAE,OAAO,CAAC;IACnB,QAAQ,EAAE,OAAO,CAAC;IAGlB,IAAI,EAAE,MAAM,CAAC;IACb,sBAAsB,EAAE,MAAM,CAAC;CAChC;AAED,QAAA,MAAM,kBAAkB,EAAE,gBAyBxB,CAAC;AA0GH,wBAAgB,mBAAmB,IAAI,gBAAgB,CAEtD;AAED,wBAAgB,cAAc,CAAC,IAAI,EAAE,gBAAgB,EAAE,MAAM,EAAE,OAAO,CAAC,gBAAgB,CAAC,GAAG,gBAAgB,CAa1G;AAED,wBAAgB,iBAAiB,CAAC,QAAQ,EAAE,MAAM,GAAG,gBAAgB,CAsBpE;AAED,wBAAgB,iBAAiB,CAAC,MAAM,EAAE,gBAAgB,EAAE,SAAS,EAAE,MAAM,EAAE,GAAG,gBAAgB,CAkBjG;AAED,OAAO,EAAE,kBAAkB,EAAE,CAAC"}
@@ -0,0 +1,186 @@
1
+ import { parse as parseToml } from '@std/toml';
2
+ import { readFileSync } from 'node:fs';
3
+ import { resolve as resolvePath } from 'node:path';
4
/**
 * Error raised for every SFT configuration problem reported by this module:
 * unreadable or unparsable TOML, unknown keys, and values of the wrong type.
 *
 * The `name` property is overridden so the error can be told apart from a
 * generic `Error` in logs and `error.name` checks.
 */
export class SFTConfigError extends Error {
  /**
   * @param message - Human-readable description of the configuration problem.
   */
  constructor(message) {
    super(message);
    this.name = 'SFTConfigError';
  }
}
10
/**
 * Built-in SFT training defaults.
 *
 * The object is frozen; callers that need a mutable configuration must copy it.
 */
const DEFAULT_SFT_CONFIG = Object.freeze({
  // Model and run identification
  model_name: 'Qwen/Qwen3-0.6B',
  output_dir: 'outputs/sft',
  run_name: 'sft-run',

  // Optimisation
  learning_rate: 2e-5,
  batch_size: 4,
  gradient_accumulation_steps: 1,
  num_epochs: 3,
  max_train_samples: 0,
  max_grad_norm: 1.0,
  weight_decay: 0.01,

  // Sequence and loss handling
  max_seq_length: 2048,
  completion_only: false, // false for parity with TRL
  label_smoothing: 0.0,

  // Logging and checkpointing
  logging_steps: 10,
  save_steps: 100,
  max_checkpoints: 3,
  log_jsonl: true,
  tui_mode: false,

  // Reproducibility and resumption
  seed: 42,
  resume_from_checkpoint: '',
});

/**
 * Expected primitive type ('string' | 'number' | 'boolean') for each config
 * key, in the same key order as the defaults.
 *
 * Every default is a primitive of its key's intended type, so the table is
 * derived from the defaults instead of being spelled out a second time.
 */
const SFT_CONFIG_VALUE_TYPES = Object.fromEntries(
  Object.entries(DEFAULT_SFT_CONFIG).map(([name, fallback]) => [name, typeof fallback]),
);

/** Numeric keys for which only whole numbers are accepted. */
const SFT_INTEGER_KEYS = new Set([
  'batch_size',
  'gradient_accumulation_steps',
  'num_epochs',
  'max_train_samples',
  'max_seq_length',
  'logging_steps',
  'save_steps',
  'max_checkpoints',
  'seed',
]);
65
/**
 * Produce a mutable shallow copy of the frozen default configuration.
 *
 * @returns A new plain object holding every default key/value pair.
 */
function cloneDefaults() {
  return Object.assign({}, DEFAULT_SFT_CONFIG);
}
68
/**
 * Report whether `value` names a known configuration key.
 *
 * Only own properties of the type table count, so inherited names such as
 * `toString` or `constructor` are not treated as config keys.
 *
 * @param value - Candidate key.
 * @returns `true` when the key is listed in SFT_CONFIG_VALUE_TYPES.
 */
function isConfigKey(value) {
  const { hasOwnProperty } = Object.prototype;
  return hasOwnProperty.call(SFT_CONFIG_VALUE_TYPES, value);
}
71
/**
 * Convert a raw config value into a boolean.
 *
 * Booleans pass through unchanged. Strings are trimmed and lower-cased, then
 * matched against `true/1/yes/on` and `false/0/no/off`.
 *
 * @param value - Raw value from TOML or from a `key=value` override.
 * @param key - Config key being set; used only in the error message.
 * @returns The boolean value.
 * @throws SFTConfigError when the value is neither a boolean nor a recognized word.
 */
function coerceBoolean(value, key) {
  switch (typeof value) {
    case 'boolean':
      return value;
    case 'string': {
      const word = value.trim().toLowerCase();
      if (word === 'true' || word === '1' || word === 'yes' || word === 'on') {
        return true;
      }
      if (word === 'false' || word === '0' || word === 'no' || word === 'off') {
        return false;
      }
      break;
    }
    default:
      break;
  }
  throw new SFTConfigError(`Invalid boolean for ${key}: ${String(value)}`);
}
83
/**
 * Convert a raw config value into a finite number.
 *
 * Only numbers and numeric strings are accepted. Every other type is rejected
 * up front, because `Number()` would otherwise turn clearly wrong input into a
 * valid-looking number: `true` becomes 1, `[]` and `null` become 0, `[5]`
 * becomes 5, and a `Date` (for example a TOML datetime) becomes its epoch
 * milliseconds.
 *
 * @param value - Raw value from TOML or from a `key=value` override.
 * @param key - Config key being set; selects the integer check and is named in errors.
 * @returns The parsed finite number.
 * @throws SFTConfigError when the value is not a number or string, is an empty
 *   or whitespace-only string, does not parse to a finite number, or is
 *   fractional for a key listed in SFT_INTEGER_KEYS.
 */
function coerceNumber(value, key) {
  const kind = typeof value;
  if (kind !== 'number' && kind !== 'string') {
    throw new SFTConfigError(`Invalid number for ${key}: ${String(value)}`);
  }
  // Number('') and Number('   ') are 0, so blank strings need their own check.
  if (kind === 'string' && value.trim() === '') {
    throw new SFTConfigError(`Invalid number for ${key}: empty string`);
  }
  const parsed = kind === 'number' ? value : Number(value);
  // Rejects NaN (unparsable text) as well as +/-Infinity.
  if (!Number.isFinite(parsed)) {
    throw new SFTConfigError(`Invalid number for ${key}: ${String(value)}`);
  }
  if (SFT_INTEGER_KEYS.has(key) && !Number.isInteger(parsed)) {
    throw new SFTConfigError(`Expected integer for ${key}, received ${parsed}`);
  }
  return parsed;
}
96
/**
 * Accept a raw config value only if it already is a string.
 *
 * @param value - Raw value from TOML or from a `key=value` override.
 * @param key - Config key being set; used only in the error message.
 * @returns The value unchanged.
 * @throws SFTConfigError when the value is not a string.
 */
function coerceString(value, key) {
  if (typeof value !== 'string') {
    throw new SFTConfigError(`Invalid string for ${key}: ${String(value)}`);
  }
  return value;
}
101
/**
 * Coerce a raw value to the type registered for `key`.
 *
 * @param key - Config key; its entry in SFT_CONFIG_VALUE_TYPES selects the coercion.
 * @param value - Raw value to convert.
 * @returns The coerced boolean, number, or string.
 * @throws SFTConfigError when the value does not fit the key's type.
 */
function coerceValue(key, value) {
  switch (SFT_CONFIG_VALUE_TYPES[key]) {
    case 'boolean':
      return coerceBoolean(value, key);
    case 'number':
      return coerceNumber(value, key);
    default:
      // Anything that is not boolean or number is handled as a string.
      return coerceString(value, key);
  }
}
111
/**
 * Write one key/value pair onto a configuration object, in place.
 *
 * @param config - Object to mutate.
 * @param key - Property name to set.
 * @param value - Value to store under `key`.
 */
function setConfigValue(config, key, value) {
  Object.assign(config, { [key]: value });
}
114
/**
 * Reduce a parsed TOML table to the recognized config keys, with every value
 * coerced to its registered type.
 *
 * Entries whose key is not a config key are skipped without an error.
 *
 * @param record - Parsed top-level TOML table.
 * @returns A new object holding only the recognized, coerced entries.
 * @throws SFTConfigError when a recognized key carries an invalid value.
 */
function normalizeTomlRecord(record) {
  const recognized = {};
  const knownEntries = Object.entries(record).filter(([name]) => isConfigKey(name));
  for (const [name, raw] of knownEntries) {
    setConfigValue(recognized, name, coerceValue(name, raw));
  }
  return recognized;
}
124
/**
 * Get a fresh, mutable copy of the default SFT configuration.
 *
 * @returns A new object; changing it does not affect the frozen defaults.
 */
export function getDefaultSFTConfig() {
  return { ...DEFAULT_SFT_CONFIG };
}
127
/**
 * Overlay `update` on `base` without mutating either argument.
 *
 * Entries whose value is `undefined` are skipped, so they never erase a value
 * from `base`. Values are stored as given; no type coercion happens here.
 *
 * @param base - Complete configuration to start from.
 * @param update - Partial configuration; a falsy value yields a plain copy of `base`.
 * @returns A new configuration object.
 * @throws SFTConfigError when `update` contains a key that is not a config key.
 */
export function mergeSFTConfig(base, update) {
  const merged = { ...base };
  if (!update) {
    return merged;
  }
  const definedEntries = Object.entries(update).filter(([, candidate]) => candidate !== undefined);
  for (const [name, candidate] of definedEntries) {
    if (!isConfigKey(name)) {
      throw new SFTConfigError(`Unknown configuration key: ${name}`);
    }
    setConfigValue(merged, name, candidate);
  }
  return merged;
}
142
/**
 * Load an SFT configuration from a TOML file.
 *
 * The path is resolved to an absolute path, the file is read as UTF-8 and
 * parsed, recognized top-level keys are coerced to their registered types, and
 * the result is overlaid on the defaults.
 *
 * @param filePath - Path to the TOML file; resolved with `path.resolve`.
 * @returns A complete configuration (defaults plus values from the file).
 * @throws SFTConfigError when the file cannot be read, is not valid TOML, does
 *   not parse to a table, or holds an invalid value for a recognized key.
 */
export function loadSFTTomlConfig(filePath) {
  const absolutePath = resolvePath(filePath);
  // Thrown values are not guaranteed to be Error instances.
  const reasonOf = (failure) => (failure instanceof Error ? failure.message : String(failure));

  let tomlText;
  try {
    tomlText = readFileSync(absolutePath, 'utf8');
  }
  catch (failure) {
    throw new SFTConfigError(`Failed to read config at ${absolutePath}: ${reasonOf(failure)}`);
  }

  let document;
  try {
    document = parseToml(tomlText);
  }
  catch (failure) {
    throw new SFTConfigError(`Failed to parse TOML at ${absolutePath}: ${reasonOf(failure)}`);
  }

  const isTable = document !== null && typeof document === 'object' && !Array.isArray(document);
  if (!isTable) {
    throw new SFTConfigError(`Expected table at ${absolutePath}`);
  }

  const fromFile = normalizeTomlRecord(document);
  return mergeSFTConfig(getDefaultSFTConfig(), fromFile);
}
167
/**
 * Apply `key=value` override strings on top of a configuration.
 *
 * Each entry is split at its first `=`; key and value are trimmed, and the
 * value is coerced to the key's registered type. When a key appears more than
 * once, the last occurrence wins. The input configuration is not mutated.
 *
 * @param config - Configuration to start from.
 * @param overrides - Entries of the form `key=value`.
 * @returns A new configuration with the overrides applied.
 * @throws SFTConfigError when an entry has no `=`, names an unknown key, or
 *   carries a value that cannot be coerced.
 */
export function applySFTOverrides(config, overrides) {
  if (!overrides.length) {
    return { ...config };
  }
  const parsedOverrides = {};
  for (const entry of overrides) {
    const separatorAt = entry.indexOf('=');
    if (separatorAt < 0) {
      throw new SFTConfigError(`Invalid override "${entry}", expected key=value format`);
    }
    const name = entry.slice(0, separatorAt).trim();
    if (!isConfigKey(name)) {
      throw new SFTConfigError(`Unknown configuration key in override: ${name}`);
    }
    const text = entry.slice(separatorAt + 1).trim();
    setConfigValue(parsedOverrides, name, coerceValue(name, text));
  }
  return mergeSFTConfig(config, parsedOverrides);
}
186
+ export { DEFAULT_SFT_CONFIG };
@@ -0,0 +1,141 @@
1
+ /**
2
+ * SFT (Supervised Fine-Tuning) Trainer
3
+ *
4
+ * This module provides a Rust-native SFT training engine for training
5
+ * models on fixed prompt-completion pairs using cross-entropy loss.
6
+ *
7
+ * ## Key Features
8
+ * - Training loop runs in Rust (eliminates FFI overhead)
9
+ * - Cross-entropy loss with completion masking (ignore_index=-100)
10
+ * - Label smoothing support
11
+ * - Gradient accumulation and clipping
12
+ * - High-level train() method for full training runs
13
+ * - Low-level trainStep() for custom training loops
14
+ *
15
+ * ## Usage
16
+ * ```typescript
17
+ * const trainer = await SFTTrainer.create({
18
+ * model_name: 'Qwen/Qwen3-0.6B',
19
+ * learning_rate: 2e-5,
20
+ * num_epochs: 3,
21
+ * });
22
+ * await trainer.train(dataset);
23
+ * ```
24
+ */
25
+ import { Qwen3Model, Qwen3Tokenizer, type SftStepMetrics } from '@mlx-node/core';
26
+ import type { SFTTrainerConfig } from './sft-config';
27
+ import { SFTDataset, type SFTBatch } from '../data/sft-dataset';
28
+ import { type TrainingLogger } from './training-logger';
29
+ export { SftTrainingEngine } from '@mlx-node/core';
30
+ export type { SftEngineConfig, SftStepMetrics, SftEpochMetrics } from '@mlx-node/core';
31
+ /**
32
+ * Training state saved with checkpoints for resumption
33
+ */
34
+ export interface SFTTrainingState {
35
+ step: number; // training step recorded in the checkpoint
36
+ epoch: number; // epoch recorded in the checkpoint
37
+ timestamp: string; // NOTE(review): string format is not visible here - presumably ISO-8601; confirm in saveCheckpoint
38
+ trainerType: 'sft'; // literal tag; always 'sft' for state written by this trainer
39
+ }
40
+ /**
41
+ * Training step result
42
+ */
43
+ export interface SFTTrainStepResult {
44
+ /** Step metrics (the SftStepMetrics type comes from '@mlx-node/core') */
45
+ metrics: SftStepMetrics;
46
+ /** Current epoch */
47
+ epoch: number;
48
+ }
49
+ /**
50
+ * SFT Trainer - Rust-Native Training Engine
51
+ *
52
+ * Provides a TypeScript-friendly interface to the Rust SFT training engine.
53
+ */
54
+ export declare class SFTTrainer {
55
+ private engine;
56
+ private model;
57
+ private tokenizer;
58
+ private config;
59
+ private currentEpoch; // presumably backs the `epoch` getter - confirm in the implementation
60
+ private currentStep; // presumably backs the `step` getter - confirm in the implementation
61
+ /** Original model path (for tokenizer files when saving checkpoints) */
62
+ private originalModelPath?;
63
+ private paused; // presumably set from stdin commands and checked by waitForResume - confirm
64
+ private stopRequested; // presumably set from stdin commands to end training early - confirm
65
+ private stdinInterface?;
66
+ private logger;
67
+ private sampleDisplayMode;
68
+ /**
69
+ * Create a new SFT trainer from a model
70
+ *
71
+ * @param model - Pre-loaded Qwen3 model
72
+ * @param tokenizer - Pre-loaded tokenizer
73
+ * @param config - Optional partial training configuration
74
+ * @param logger - Optional custom logger
75
+ */
76
+ constructor(model: Qwen3Model, tokenizer: Qwen3Tokenizer, config?: Partial<SFTTrainerConfig>, logger?: TrainingLogger);
77
+ /**
78
+ * Setup stdin handler for TUI control commands
79
+ */
80
+ private setupStdinHandler;
81
+ /**
82
+ * Handle a command received from stdin
83
+ */
84
+ private handleStdinCommand;
85
+ /**
86
+ * Wait for resume if paused
87
+ */
88
+ private waitForResume;
89
+ /**
90
+ * Create a trainer by loading a model from disk
91
+ *
92
+ * @param config - Partial SFT configuration. Note that SFTTrainerConfig declares `model_name`; it has no `modelPath` key.
93
+ * @returns Promise resolving to the constructed trainer
94
+ */
95
+ static create(config: Partial<SFTTrainerConfig>): Promise<SFTTrainer>;
96
+ /**
97
+ * Find the latest checkpoint in the output directory
98
+ *
+ * Returns a string or null; presumably the checkpoint path, or null when none is found (confirm in the implementation).
+ */
99
+ static findLatestCheckpoint(outputDir?: string): string | null;
100
+ /**
101
+ * Run a single training step
102
+ *
103
+ * @param batch - Tokenized batch with input_ids and labels
104
+ * @returns Step metrics together with the current epoch (see SFTTrainStepResult)
105
+ */
106
+ trainStep(batch: SFTBatch): Promise<SFTTrainStepResult>;
107
+ /**
108
+ * Run a full training loop over a dataset
109
+ *
110
+ * @param dataset - SFT dataset or path to JSONL file
111
+ */
112
+ train(dataset: SFTDataset | string): Promise<void>;
113
+ /**
114
+ * Save a checkpoint with model weights and training state
115
+ *
116
+ * @param name - Checkpoint name (default: "checkpoint-{step}")
117
+ * @returns Path to saved checkpoint
118
+ */
119
+ saveCheckpoint(name?: string): Promise<string>;
120
+ /**
121
+ * Remove old checkpoints, keeping only the most recent ones
122
+ */
123
+ private cleanupOldCheckpoints;
124
+ /**
125
+ * Get current training step
126
+ */
127
+ get step(): number;
128
+ /**
129
+ * Get current epoch
130
+ */
131
+ get epoch(): number;
132
+ /**
133
+ * Get the underlying model for inference
134
+ */
135
+ getModel(): Qwen3Model;
136
+ /**
137
+ * Get the tokenizer
138
+ */
139
+ getTokenizer(): Qwen3Tokenizer;
140
+ }
141
+ //# sourceMappingURL=sft-trainer.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"sft-trainer.d.ts","sourceRoot":"","sources":["../../src/trainers/sft-trainer.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;;;;;;;;GAuBG;AAMH,OAAO,EAEL,UAAU,EACV,cAAc,EAGd,KAAK,cAAc,EAEpB,MAAM,gBAAgB,CAAC;AAExB,OAAO,KAAK,EAAE,gBAAgB,EAAE,MAAM,cAAc,CAAC;AAErD,OAAO,EAAE,UAAU,EAAkB,KAAK,QAAQ,EAAE,MAAM,qBAAqB,CAAC;AAChF,OAAO,EAAwB,KAAK,cAAc,EAAE,MAAM,mBAAmB,CAAC;AAG9E,OAAO,EAAE,iBAAiB,EAAE,MAAM,gBAAgB,CAAC;AACnD,YAAY,EAAE,eAAe,EAAE,cAAc,EAAE,eAAe,EAAE,MAAM,gBAAgB,CAAC;AAEvF;;GAEG;AACH,MAAM,WAAW,gBAAgB;IAC/B,IAAI,EAAE,MAAM,CAAC;IACb,KAAK,EAAE,MAAM,CAAC;IACd,SAAS,EAAE,MAAM,CAAC;IAClB,WAAW,EAAE,KAAK,CAAC;CACpB;AAED;;GAEG;AACH,MAAM,WAAW,kBAAkB;IACjC,mBAAmB;IACnB,OAAO,EAAE,cAAc,CAAC;IACxB,oBAAoB;IACpB,KAAK,EAAE,MAAM,CAAC;CACf;AAED;;;;GAIG;AACH,qBAAa,UAAU;IACrB,OAAO,CAAC,MAAM,CAAoB;IAClC,OAAO,CAAC,KAAK,CAAa;IAC1B,OAAO,CAAC,SAAS,CAAiB;IAClC,OAAO,CAAC,MAAM,CAAmB;IACjC,OAAO,CAAC,YAAY,CAAa;IACjC,OAAO,CAAC,WAAW,CAAa;IAChC,wEAAwE;IACxE,OAAO,CAAC,iBAAiB,CAAC,CAAS;IAGnC,OAAO,CAAC,MAAM,CAAkB;IAChC,OAAO,CAAC,aAAa,CAAkB;IACvC,OAAO,CAAC,cAAc,CAAC,CAAqB;IAC5C,OAAO,CAAC,MAAM,CAAiB;IAC/B,OAAO,CAAC,iBAAiB,CAA0C;IAEnE;;;;;;;OAOG;gBAED,KAAK,EAAE,UAAU,EACjB,SAAS,EAAE,cAAc,EACzB,MAAM,GAAE,OAAO,CAAC,gBAAgB,CAAM,EACtC,MAAM,CAAC,EAAE,cAAc;IAwCzB;;OAEG;IACH,OAAO,CAAC,iBAAiB;IAezB;;OAEG;IACH,OAAO,CAAC,kBAAkB;IAmC1B;;OAEG;YACW,aAAa;IAM3B;;;;;OAKG;WACU,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,gBAAgB,CAAC,GAAG,OAAO,CAAC,UAAU,CAAC;IA+D3E;;OAEG;IACH,MAAM,CAAC,oBAAoB,CAAC,SAAS,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,IAAI;IAmB9D;;;;;OAKG;IACG,SAAS,CAAC,KAAK,EAAE,QAAQ,GAAG,OAAO,CAAC,kBAAkB,CAAC;IAqB7D;;;;OAIG;IACG,KAAK,CAAC,OAAO,EAAE,UAAU,GAAG,MAAM,GAAG,OAAO,CAAC,IAAI,CAAC;IA8KxD;;;;;OAKG;IACG,cAAc,CAAC,IAAI,CAAC,EAAE,MAAM,GAAG,OAAO,CAAC,MAAM,CAAC;IAgDpD;;OAEG;IACH,OAAO,CAAC,qBAAqB;IAiC7B;;OAEG;IACH,IAAI,IAAI,IAAI,MAAM,CAEjB;IAED;;OAEG;IACH,IAAI,KAAK,IAAI,MAAM,CAElB;IAED;;OAEG;IACH,QAAQ,IAAI,UAAU;IAItB;;OAEG;IACH,YAAY,IAAI,cAAc;CAG/B"}