@mlx-node/trl 0.0.0 → 0.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. package/README.md +389 -0
  2. package/package.json +16 -5
  3. package/dist/data/dataset.d.ts +0 -22
  4. package/dist/data/dataset.d.ts.map +0 -1
  5. package/dist/data/dataset.js +0 -142
  6. package/dist/data/sft-dataset.d.ts +0 -156
  7. package/dist/data/sft-dataset.d.ts.map +0 -1
  8. package/dist/data/sft-dataset.js +0 -415
  9. package/dist/index.d.ts +0 -33
  10. package/dist/index.d.ts.map +0 -1
  11. package/dist/index.js +0 -47
  12. package/dist/trainers/grpo-config.d.ts +0 -42
  13. package/dist/trainers/grpo-config.d.ts.map +0 -1
  14. package/dist/trainers/grpo-config.js +0 -220
  15. package/dist/trainers/grpo-entropy.d.ts +0 -33
  16. package/dist/trainers/grpo-entropy.d.ts.map +0 -1
  17. package/dist/trainers/grpo-entropy.js +0 -18
  18. package/dist/trainers/grpo-trainer.d.ts +0 -602
  19. package/dist/trainers/grpo-trainer.d.ts.map +0 -1
  20. package/dist/trainers/grpo-trainer.js +0 -1439
  21. package/dist/trainers/sft-config.d.ts +0 -32
  22. package/dist/trainers/sft-config.d.ts.map +0 -1
  23. package/dist/trainers/sft-config.js +0 -186
  24. package/dist/trainers/sft-trainer.d.ts +0 -141
  25. package/dist/trainers/sft-trainer.d.ts.map +0 -1
  26. package/dist/trainers/sft-trainer.js +0 -502
  27. package/dist/trainers/training-logger.d.ts +0 -375
  28. package/dist/trainers/training-logger.d.ts.map +0 -1
  29. package/dist/trainers/training-logger.js +0 -542
  30. package/dist/types.d.ts +0 -54
  31. package/dist/types.d.ts.map +0 -1
  32. package/dist/types.js +0 -1
  33. package/dist/utils/path-security.d.ts +0 -51
  34. package/dist/utils/path-security.d.ts.map +0 -1
  35. package/dist/utils/path-security.js +0 -69
  36. package/dist/utils/xml-parser.d.ts +0 -6
  37. package/dist/utils/xml-parser.d.ts.map +0 -1
  38. package/dist/utils/xml-parser.js +0 -184
package/dist/data/sft-dataset.d.ts DELETED
@@ -1,156 +0,0 @@
- /**
- * SFT Dataset handling for Supervised Fine-Tuning
- *
- * Supports two data formats (auto-detected):
- * 1. Prompt-Completion: { prompt: ChatMessage[], completion: ChatMessage }
- * 2. Full Conversation: { messages: ChatMessage[] }
- *
- * Both formats produce tokenized batches with labels masked appropriately.
- */
- import type { Qwen3Tokenizer } from '@mlx-node/core';
- import type { ChatMessage } from '../types';
- import { type PathValidationOptions } from '../utils/path-security';
- /**
- * Special token IDs for SFT label masking
- *
- * These are used to detect assistant message boundaries in tokenized conversations.
- * The IDs can be derived from the tokenizer or provided explicitly.
- */
- export interface SpecialTokenIds {
- /** Token ID for <|im_start|> */
- imStart: number;
- /** Token ID for <|im_end|> */
- imEnd: number;
- /** Token IDs that represent newlines (for detecting end of role header) */
- newlineTokens: number[];
- }
- /**
- * Prompt-Completion format for tool-use training
- */
- export interface SFTPromptCompletionExample {
- prompt: ChatMessage[];
- completion: ChatMessage;
- }
- /**
- * Full conversation format for multi-turn dialogue
- */
- export interface SFTConversationExample {
- messages: ChatMessage[];
- }
- /**
- * Union type for SFT examples
- */
- export type SFTExample = SFTPromptCompletionExample | SFTConversationExample;
- /**
- * A tokenized batch ready for SFT training
- */
- export interface SFTBatch {
- inputIds: Int32Array;
- labels: Int32Array;
- shape: [number, number];
- }
- /**
- * Configuration for SFT dataset
- */
- export interface SFTDatasetConfig {
- maxSeqLength?: number;
- completionOnly?: boolean;
- enableThinking?: boolean;
- seed?: number;
- /**
- * Special token IDs for label masking.
- *
- * If not provided, these are automatically derived from the tokenizer.
- * This option allows explicit overriding for custom tokenizers or
- * non-standard vocabularies.
- */
- specialTokenIds?: Partial<SpecialTokenIds>;
- }
- /**
- * SFT Dataset class for handling SFT training data
- */
- export declare class SFTDataset {
- private examples;
- private tokenizer;
- private config;
- private format;
- private shuffledIndices;
- private rng;
- /** Cached special token IDs for label masking */
- private specialTokenIds;
- constructor(examples: SFTExample[], tokenizer: Qwen3Tokenizer, config?: SFTDatasetConfig);
- /**
- * Get the number of examples in the dataset
- */
- get length(): number;
- /**
- * Shuffle dataset for a specific epoch using epoch-based seeding.
- * This ensures reproducible shuffles across training resumes.
- * Each epoch gets a deterministic shuffle based on (baseSeed + epoch).
- *
- * @param epoch - The epoch number (used as seed offset)
- */
- shuffleForEpoch(epoch: number): void;
- /**
- * Create a seeded pseudo-random number generator (Linear Congruential Generator)
- */
- private createSeededRandom;
- /**
- * Find length of common prefix between two token arrays
- * Handles chat template quirks where prompt tokens may not be exact prefix of full tokens
- */
- private findCommonPrefixLength;
- /**
- * Tokenize a prompt-completion example
- */
- private tokenizePromptCompletion;
- /**
- * Check if a token ID is a newline token
- */
- private isNewlineToken;
- /**
- * Tokenize a conversation example
- *
- * For conversations, we train on all assistant turns.
- * Non-assistant tokens (system, user) are masked with -100.
- *
- * Uses single-pass tokenization with token-based boundary detection.
- * Token IDs are derived from the tokenizer for portability across models.
- */
- private tokenizeConversation;
- /**
- * Tokenize a single example based on its format
- */
- private tokenizeExample;
- /**
- * Collate multiple examples into a padded batch
- */
- collateBatch(indices: number[]): Promise<SFTBatch>;
- /**
- * Generate batches for training
- */
- batches(batchSize: number): AsyncGenerator<SFTBatch>;
- /**
- * Get total number of batches for a given batch size
- */
- numBatches(batchSize: number): number;
- }
- /**
- * Load SFT dataset from a JSONL file
- *
- * Supports two formats:
- * 1. Prompt-Completion: {"prompt": [...], "completion": {...}}
- * 2. Conversation: {"messages": [...]}
- *
- * @param path - Path to the JSONL file (relative to cwd or allowedRoot)
- * @param tokenizer - Qwen3 tokenizer instance
- * @param config - Optional configuration including path validation options
- */
- export declare function loadSFTDataset(path: string, tokenizer: Qwen3Tokenizer, config?: SFTDatasetConfig & {
- limit?: number;
- } & PathValidationOptions): Promise<SFTDataset>;
- /**
- * Create SFT dataset from examples directly
- */
- export declare function createSFTDataset(examples: SFTExample[], tokenizer: Qwen3Tokenizer, config?: SFTDatasetConfig): SFTDataset;
- //# sourceMappingURL=sft-dataset.d.ts.map
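
Note: the two auto-detected record shapes referenced in the removed declarations above look roughly like the following. This is an illustrative sketch, not code from the package; only the `role`/`content` fields of `ChatMessage` are assumed here (the full type lived in `dist/types.d.ts`).

```ts
// Illustrative sketch of the two SFT example shapes (SFTPromptCompletionExample
// and SFTConversationExample). The inline ChatMessage type is an assumption;
// the real one was exported from './types'.
type ChatMessage = { role: 'system' | 'user' | 'assistant'; content: string };

// 1. Prompt-Completion: with completionOnly set, only the completion turn is
//    trainable; prompt tokens are masked with -100.
const promptCompletion: { prompt: ChatMessage[]; completion: ChatMessage } = {
  prompt: [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'What is 2 + 2?' },
  ],
  completion: { role: 'assistant', content: 'The answer is 4.' },
};

// 2. Full conversation: with completionOnly set, every assistant turn is
//    trainable and system/user tokens are masked.
const conversation: { messages: ChatMessage[] } = {
  messages: [
    { role: 'user', content: 'Hi!' },
    { role: 'assistant', content: 'Hello! How can I help?' },
  ],
};
```

In a JSONL file, each such record occupies one line; `loadSFTDataset` rejects a completion whose `role` is not `'assistant'` and a `messages` array with no assistant turn (see the validation logic in the removed `sft-dataset.js` below).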
package/dist/data/sft-dataset.d.ts.map DELETED
@@ -1 +0,0 @@
- {"version":3,"file":"sft-dataset.d.ts","sourceRoot":"","sources":["../../src/data/sft-dataset.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAIH,OAAO,KAAK,EAAE,cAAc,EAAE,MAAM,gBAAgB,CAAC;AACrD,OAAO,KAAK,EAAE,WAAW,EAAE,MAAM,UAAU,CAAC;AAC5C,OAAO,EAA2C,KAAK,qBAAqB,EAAE,MAAM,wBAAwB,CAAC;AAK7G;;;;;GAKG;AACH,MAAM,WAAW,eAAe;IAC9B,gCAAgC;IAChC,OAAO,EAAE,MAAM,CAAC;IAChB,8BAA8B;IAC9B,KAAK,EAAE,MAAM,CAAC;IACd,2EAA2E;IAC3E,aAAa,EAAE,MAAM,EAAE,CAAC;CACzB;AAgDD;;GAEG;AACH,MAAM,WAAW,0BAA0B;IACzC,MAAM,EAAE,WAAW,EAAE,CAAC;IACtB,UAAU,EAAE,WAAW,CAAC;CACzB;AAED;;GAEG;AACH,MAAM,WAAW,sBAAsB;IACrC,QAAQ,EAAE,WAAW,EAAE,CAAC;CACzB;AAED;;GAEG;AACH,MAAM,MAAM,UAAU,GAAG,0BAA0B,GAAG,sBAAsB,CAAC;AAE7E;;GAEG;AACH,MAAM,WAAW,QAAQ;IACvB,QAAQ,EAAE,UAAU,CAAC;IACrB,MAAM,EAAE,UAAU,CAAC;IACnB,KAAK,EAAE,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;CACzB;AAED;;GAEG;AACH,MAAM,WAAW,gBAAgB;IAC/B,YAAY,CAAC,EAAE,MAAM,CAAC;IACtB,cAAc,CAAC,EAAE,OAAO,CAAC;IACzB,cAAc,CAAC,EAAE,OAAO,CAAC;IACzB,IAAI,CAAC,EAAE,MAAM,CAAC;IAEd;;;;;;OAMG;IACH,eAAe,CAAC,EAAE,OAAO,CAAC,eAAe,CAAC,CAAC;CAC5C;AAeD;;GAEG;AACH,qBAAa,UAAU;IACrB,OAAO,CAAC,QAAQ,CAAe;IAC/B,OAAO,CAAC,SAAS,CAAiB;IAClC,OAAO,CAAC,MAAM,CAAkF;IAChG,OAAO,CAAC,MAAM,CAAuC;IACrD,OAAO,CAAC,eAAe,CAAW;IAClC,OAAO,CAAC,GAAG,CAAe;IAC1B,iDAAiD;IACjD,OAAO,CAAC,eAAe,CAAkB;gBAE7B,QAAQ,EAAE,UAAU,EAAE,EAAE,SAAS,EAAE,cAAc,EAAE,MAAM,GAAE,gBAAqB;IAsC5F;;OAEG;IACH,IAAI,MAAM,IAAI,MAAM,CAEnB;IAED;;;;;;OAMG;IACH,eAAe,CAAC,KAAK,EAAE,MAAM,GAAG,IAAI;IAYpC;;OAEG;IACH,OAAO,CAAC,kBAAkB;IAQ1B;;;OAGG;IACH,OAAO,CAAC,sBAAsB;IAS9B;;OAEG;YACW,wBAAwB;IA8CtC;;OAEG;IACH,OAAO,CAAC,cAAc;IAItB;;;;;;;;OAQG;YACW,oBAAoB;IA2DlC;;OAEG;YACW,eAAe;IAQ7B;;OAEG;IACG,YAAY,CAAC,OAAO,EAAE,MAAM,EAAE,GAAG,OAAO,CAAC,QAAQ,CAAC;IA8CxD;;OAEG;IACI,OAAO,CAAC,SAAS,EAAE,MAAM,GAAG,cAAc,CAAC,QAAQ,CAAC;IAQ3D;;OAEG;IACH,UAAU,CAAC,SAAS,EAAE,MAAM,GAAG,MAAM;CAGtC;AAoFD;;;;;;;;;;GAUG;AACH,wBAAsB,cAAc,CAClC,IAAI,EAAE,MAAM,EACZ,SAAS,EAAE,cAAc,EACzB,MAAM,CAAC,EAAE,gBAAgB,GAAG;IAAE,KAAK,CAAC,EAAE,MAAM,CAAA;CAAE,GAAG,qBAAqB,GACrE,OAAO,CAAC,UAAU,CAAC,CAarB;AAED;;GAEG;AACH,wBAAgB,gBAAgB,CAC9B,QAAQ,EAAE,UAAU,EAAE,EACtB,SAAS,EAAE,cAAc,EACzB,MAAM,CAAC,EAAE,gBAAgB,GACxB,UAAU,CAEZ"}
package/dist/data/sft-dataset.js DELETED
@@ -1,415 +0,0 @@
- /**
- * SFT Dataset handling for Supervised Fine-Tuning
- *
- * Supports two data formats (auto-detected):
- * 1. Prompt-Completion: { prompt: ChatMessage[], completion: ChatMessage }
- * 2. Full Conversation: { messages: ChatMessage[] }
- *
- * Both formats produce tokenized batches with labels masked appropriately.
- */
- import { readFileSync } from 'node:fs';
- import { resolve as resolvePath } from 'node:path';
- import { validatePathContainment, getAllowedRoot } from '../utils/path-security';
- // -100 is the standard ignore index for cross-entropy loss
- const IGNORE_INDEX = -100;
- /**
- * Get special token IDs from a tokenizer
- *
- * Queries the tokenizer to get the actual token IDs for special tokens.
- * This ensures portability across different tokenizers/vocabularies.
- *
- * @param tokenizer - The tokenizer instance
- * @returns Special token IDs derived from the tokenizer
- * @throws Error if required special tokens are not found
- */
- function getSpecialTokenIds(tokenizer) {
- // Get im_start and im_end tokens using the tokenizer's special token getters
- const imStartToken = tokenizer.getImStartToken(); // "<|im_start|>"
- const imEndToken = tokenizer.getImEndToken(); // "<|im_end|>"
- const imStart = tokenizer.tokenToId(imStartToken);
- const imEnd = tokenizer.tokenToId(imEndToken);
- // Validate that we got valid IDs (tokenToId returns null for unknown tokens)
- if (imStart === null || imEnd === null) {
- throw new Error(`Tokenizer does not have required special tokens for ChatML format. ` +
- `Got im_start=${imStart}, im_end=${imEnd}. ` +
- `This tokenizer may not be compatible with ChatML format.`);
- }
- // Get newline token IDs - these vary by tokenizer
- // Try common newline representations
- const newlineTokens = [];
- const potentialNewlines = ['\n', ' \n', '\r\n', '\n\n'];
- for (const nl of potentialNewlines) {
- const id = tokenizer.tokenToId(nl);
- if (id !== null && !newlineTokens.includes(id)) {
- newlineTokens.push(id);
- }
- }
- // If no newline tokens found, we'll rely on the fallback in tokenizeConversation
- return {
- imStart,
- imEnd,
- newlineTokens,
- };
- }
- /**
- * Detect the format of an SFT example
- */
- function detectFormat(example) {
- if ('prompt' in example && 'completion' in example) {
- return 'prompt-completion';
- }
- if ('messages' in example) {
- return 'conversation';
- }
- throw new Error('Invalid SFT example format. Expected either {prompt, completion} or {messages}');
- }
- /**
- * SFT Dataset class for handling SFT training data
- */
- export class SFTDataset {
- examples;
- tokenizer;
- config;
- format;
- shuffledIndices;
- rng;
- /** Cached special token IDs for label masking */
- specialTokenIds;
- constructor(examples, tokenizer, config = {}) {
- if (examples.length === 0) {
- throw new Error('SFT dataset must contain at least one example');
- }
- this.examples = examples;
- this.tokenizer = tokenizer;
- this.config = {
- maxSeqLength: config.maxSeqLength ?? 2048,
- completionOnly: config.completionOnly ?? false, // Changed to false for TRL parity
- enableThinking: config.enableThinking ?? false,
- seed: config.seed ?? 42,
- };
- this.rng = this.createSeededRandom(this.config.seed);
- // Get special token IDs from tokenizer, with optional overrides
- const derivedTokenIds = getSpecialTokenIds(tokenizer);
- this.specialTokenIds = {
- imStart: config.specialTokenIds?.imStart ?? derivedTokenIds.imStart,
- imEnd: config.specialTokenIds?.imEnd ?? derivedTokenIds.imEnd,
- newlineTokens: config.specialTokenIds?.newlineTokens ?? derivedTokenIds.newlineTokens,
- };
- // Detect format from first example
- this.format = detectFormat(examples[0]);
- // Validate all examples have the same format
- for (let i = 1; i < examples.length; i++) {
- const fmt = detectFormat(examples[i]);
- if (fmt !== this.format) {
- throw new Error(`Inconsistent SFT data format: example 0 is ${this.format}, example ${i} is ${fmt}`);
- }
- }
- // Initialize indices
- this.shuffledIndices = Array.from({ length: examples.length }, (_, i) => i);
- }
- /**
- * Get the number of examples in the dataset
- */
- get length() {
- return this.examples.length;
- }
- /**
- * Shuffle dataset for a specific epoch using epoch-based seeding.
- * This ensures reproducible shuffles across training resumes.
- * Each epoch gets a deterministic shuffle based on (baseSeed + epoch).
- *
- * @param epoch - The epoch number (used as seed offset)
- */
- shuffleForEpoch(epoch) {
- // Reset RNG with epoch-specific seed for reproducibility
- this.rng = this.createSeededRandom(this.config.seed + epoch);
- // Reset indices to original order
- this.shuffledIndices = Array.from({ length: this.examples.length }, (_, i) => i);
- // Fisher-Yates shuffle
- for (let i = this.shuffledIndices.length - 1; i > 0; i--) {
- const j = Math.floor(this.rng() * (i + 1));
- [this.shuffledIndices[i], this.shuffledIndices[j]] = [this.shuffledIndices[j], this.shuffledIndices[i]];
- }
- }
- /**
- * Create a seeded pseudo-random number generator (Linear Congruential Generator)
- */
- createSeededRandom(seed) {
- let s = seed;
- return () => {
- s = (s * 1103515245 + 12345) & 0x7fffffff;
- return s / 0x7fffffff;
- };
- }
- /**
- * Find length of common prefix between two token arrays
- * Handles chat template quirks where prompt tokens may not be exact prefix of full tokens
- */
- findCommonPrefixLength(prompt, full) {
- let i = 0;
- const maxLen = Math.min(prompt.length, full.length);
- while (i < maxLen && prompt[i] === full[i]) {
- i++;
- }
- return i;
- }
- /**
- * Tokenize a prompt-completion example
- */
- async tokenizePromptCompletion(example) {
- // Tokenize prompt with generation prompt (so the model learns to continue)
- const promptTokens = await this.tokenizer.applyChatTemplate(example.prompt, true, // add generation prompt
- null, this.config.enableThinking);
- // Create full messages for tokenization
- const fullMessages = [...example.prompt, example.completion];
- const fullTokens = await this.tokenizer.applyChatTemplate(fullMessages, false, // no generation prompt at the end
- null, this.config.enableThinking);
- // Convert to regular arrays for manipulation
- const promptArr = Array.from(promptTokens, Number);
- const inputIds = Array.from(fullTokens, Number);
- // Use common prefix detection to handle chat template quirks
- // (some templates may not produce prompt tokens as exact prefix of full tokens)
- const promptLen = this.findCommonPrefixLength(promptArr, inputIds);
- if (promptLen !== promptArr.length) {
- console.warn(`[SFT Dataset] Prompt tokens differ from prefix of full sequence ` +
- `(${promptArr.length} vs ${promptLen}). Using common prefix for masking.`);
- }
- // Create labels: -100 for prompt tokens, actual tokens for completion
- const labels = inputIds.map((id, i) => {
- if (this.config.completionOnly && i < promptLen) {
- return IGNORE_INDEX;
- }
- return id;
- });
- return { inputIds, labels };
- }
- /**
- * Check if a token ID is a newline token
- */
- isNewlineToken(tokenId) {
- return this.specialTokenIds.newlineTokens.includes(tokenId);
- }
- /**
- * Tokenize a conversation example
- *
- * For conversations, we train on all assistant turns.
- * Non-assistant tokens (system, user) are masked with -100.
- *
- * Uses single-pass tokenization with token-based boundary detection.
- * Token IDs are derived from the tokenizer for portability across models.
- */
- async tokenizeConversation(example) {
- const messages = example.messages;
- // Single tokenization pass
- const fullTokens = await this.tokenizer.applyChatTemplate(messages, false, null, this.config.enableThinking);
- const inputIds = Array.from(fullTokens, Number);
- // If not masking prompts, all tokens are trainable
- if (!this.config.completionOnly) {
- return { inputIds, labels: inputIds.slice() };
- }
- // Token-based boundary detection using special tokens (derived from tokenizer)
- const { imStart, imEnd } = this.specialTokenIds;
- // Get "assistant" token ID (it's a single token in Qwen3)
- const assistantTokenId = this.tokenizer.tokenToId('assistant');
- const labels = Array.from({ length: inputIds.length }, () => IGNORE_INDEX);
- let inAssistant = false;
- for (let i = 0; i < inputIds.length; i++) {
- // Detect assistant region: <|im_start|> followed by "assistant" token
- if (inputIds[i] === imStart && i + 1 < inputIds.length && inputIds[i + 1] === assistantTokenId) {
- // Skip the <|im_start|>assistant\n header, start training from content
- // Find the newline after "assistant"
- let j = i + 2;
- while (j < inputIds.length && inputIds[j] !== imEnd) {
- // Look for newline token (dynamically derived from tokenizer)
- if (this.isNewlineToken(inputIds[j])) {
- inAssistant = true;
- i = j; // Skip to after header
- break;
- }
- j++;
- }
- if (!inAssistant) {
- // Fallback: just start after assistant token
- inAssistant = true;
- i = i + 1;
- }
- continue;
- }
- if (inAssistant && inputIds[i] !== imEnd) {
- labels[i] = inputIds[i];
- }
- if (inputIds[i] === imEnd) {
- inAssistant = false;
- }
- }
- return { inputIds, labels };
- }
- /**
- * Tokenize a single example based on its format
- */
- async tokenizeExample(example) {
- if (this.format === 'prompt-completion') {
- return this.tokenizePromptCompletion(example);
- }
- else {
- return this.tokenizeConversation(example);
- }
- }
- /**
- * Collate multiple examples into a padded batch
- */
- async collateBatch(indices) {
- const examples = indices.map((i) => this.examples[this.shuffledIndices[i]]);
- // Tokenize all examples
- const tokenized = [];
- for (const example of examples) {
- tokenized.push(await this.tokenizeExample(example));
- }
- // Find max length (capped at maxSeqLength)
- const maxLen = Math.min(this.config.maxSeqLength, Math.max(...tokenized.map((t) => t.inputIds.length)));
- // Pad and truncate
- const batchSize = examples.length;
- const paddedInputIds = new Int32Array(batchSize * maxLen);
- const paddedLabels = new Int32Array(batchSize * maxLen);
- const padTokenId = this.tokenizer.getPadTokenId();
- for (let b = 0; b < batchSize; b++) {
- const { inputIds, labels } = tokenized[b];
- const seqLen = Math.min(inputIds.length, maxLen);
- // Truncate from the left if necessary (keep the end of the sequence)
- const startIdx = Math.max(0, inputIds.length - maxLen);
- for (let s = 0; s < maxLen; s++) {
- const offset = b * maxLen + s;
- if (s < seqLen) {
- paddedInputIds[offset] = inputIds[startIdx + s];
- paddedLabels[offset] = labels[startIdx + s];
- }
- else {
- // Pad
- paddedInputIds[offset] = padTokenId;
- paddedLabels[offset] = IGNORE_INDEX;
- }
- }
- }
- return {
- inputIds: paddedInputIds,
- labels: paddedLabels,
- shape: [batchSize, maxLen],
- };
- }
- /**
- * Generate batches for training
- */
- async *batches(batchSize) {
- for (let i = 0; i < this.examples.length; i += batchSize) {
- const end = Math.min(i + batchSize, this.examples.length);
- const indices = Array.from({ length: end - i }, (_, j) => i + j);
- yield await this.collateBatch(indices);
- }
- }
- /**
- * Get total number of batches for a given batch size
- */
- numBatches(batchSize) {
- return Math.ceil(this.examples.length / batchSize);
- }
- }
- /**
- * Read JSONL file and parse into records
- */
- function readJsonl(path, limit) {
- let fileContents;
- try {
- fileContents = readFileSync(path, 'utf8');
- }
- catch (error) {
- const message = error instanceof Error ? error.message : String(error);
- throw new Error(`Failed to read SFT dataset at ${path}: ${message}`);
- }
- const lines = fileContents.split(/\r?\n/).filter((line) => line.trim().length > 0);
- const records = [];
- const max = typeof limit === 'number' && limit > 0 ? limit : Number.POSITIVE_INFINITY;
- for (let i = 0; i < lines.length && records.length < max; i++) {
- const line = lines[i];
- try {
- const parsed = JSON.parse(line);
- records.push(parsed);
- }
- catch (error) {
- const message = error instanceof Error ? error.message : String(error);
- throw new Error(`Failed to parse JSONL at ${path}:${i + 1} - ${message}`);
- }
- }
- return records;
- }
- /**
- * Validate an SFT example
- */
- function validateSFTExample(example, index) {
- if (typeof example !== 'object' || example === null) {
- throw new Error(`SFT example ${index} must be an object`);
- }
- const obj = example;
- // Check for prompt-completion format
- if ('prompt' in obj && 'completion' in obj) {
- if (!Array.isArray(obj.prompt)) {
- throw new Error(`SFT example ${index}: prompt must be an array of messages`);
- }
- if (typeof obj.completion !== 'object' || obj.completion === null) {
- throw new Error(`SFT example ${index}: completion must be a message object`);
- }
- const completion = obj.completion;
- if (completion.role !== 'assistant') {
- throw new Error(`SFT example ${index}: completion.role must be 'assistant'`);
- }
- if (typeof completion.content !== 'string') {
- throw new Error(`SFT example ${index}: completion.content must be a string`);
- }
- return {
- prompt: obj.prompt,
- completion: obj.completion,
- };
- }
- // Check for conversation format
- if ('messages' in obj) {
- if (!Array.isArray(obj.messages)) {
- throw new Error(`SFT example ${index}: messages must be an array`);
- }
- if (obj.messages.length === 0) {
- throw new Error(`SFT example ${index}: messages cannot be empty`);
- }
- // Check that at least one message is from assistant
- const hasAssistant = obj.messages.some((m) => typeof m === 'object' && m !== null && m.role === 'assistant');
- if (!hasAssistant) {
- throw new Error(`SFT example ${index}: messages must contain at least one assistant message`);
- }
- return { messages: obj.messages };
- }
- throw new Error(`SFT example ${index}: must have either {prompt, completion} or {messages}`);
- }
- /**
- * Load SFT dataset from a JSONL file
- *
- * Supports two formats:
- * 1. Prompt-Completion: {"prompt": [...], "completion": {...}}
- * 2. Conversation: {"messages": [...]}
- *
- * @param path - Path to the JSONL file (relative to cwd or allowedRoot)
- * @param tokenizer - Qwen3 tokenizer instance
- * @param config - Optional configuration including path validation options
- */
- export async function loadSFTDataset(path, tokenizer, config) {
- const allowedRoot = getAllowedRoot(config);
- const absolutePath = resolvePath(allowedRoot, path);
- // Validate the path stays within allowed root to prevent directory traversal
- validatePathContainment(absolutePath, allowedRoot);
- const rawRecords = readJsonl(absolutePath, config?.limit);
- // Validate and convert
- const examples = rawRecords.map((record, i) => validateSFTExample(record, i));
- return new SFTDataset(examples, tokenizer, config);
- }
- /**
- * Create SFT dataset from examples directly
- */
- export function createSFTDataset(examples, tokenizer, config) {
- return new SFTDataset(examples, tokenizer, config);
- }
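
For clarity, the completion-only labelling rule implemented above (prompt and padding positions receive the ignore index -100, so cross-entropy loss only covers completion tokens) reduces to something like this minimal sketch; the token IDs below are made up for illustration and are not from the package.

```ts
// Minimal sketch of the labelling rule from tokenizePromptCompletion above;
// the token IDs are hypothetical and only illustrate the masking.
const IGNORE_INDEX = -100;

function maskPromptTokens(inputIds: number[], promptLen: number): number[] {
  // Positions before promptLen are ignored by the loss; the rest are trained on.
  return inputIds.map((id, i) => (i < promptLen ? IGNORE_INDEX : id));
}

// A 4-token prompt followed by a 3-token completion:
console.log(maskPromptTokens([101, 5, 6, 7, 42, 43, 44], 4));
// -> [-100, -100, -100, -100, 42, 43, 44]
```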
package/dist/index.d.ts DELETED
@@ -1,33 +0,0 @@
- /**
- * @mlx-node/trl - Training utilities for MLX models
- *
- * This package provides everything needed for training ML models,
- * aligned with Python's TRL (Transformer Reinforcement Learning) library.
- *
- * For model loading and inference, import from @mlx-node/lm.
- *
- * @example
- * ```typescript
- * import { GRPOTrainer, GRPOConfig, loadLocalGsm8kDataset } from '@mlx-node/trl';
- * import { ModelLoader } from '@mlx-node/lm';
- *
- * const model = await ModelLoader.loadPretrained('./models/qwen3-0.6b');
- * const trainer = await GRPOTrainer.create({ modelPath: './models/qwen3-0.6b' });
- * ```
- */
- export type { ToolDefinition, FunctionDefinition, FunctionParameters } from '@mlx-node/core';
- export { MxArray } from '@mlx-node/core';
- export { convertModel, convertParquetToJsonl } from '@mlx-node/core';
- export type { ConversionOptions, ConversionResult } from '@mlx-node/core';
- export { type MLXGRPOConfig, ConfigError, getDefaultConfig, mergeConfig, loadTomlConfig, applyOverrides, } from './trainers/grpo-config';
- export { GRPOTrainer, type GRPOTrainerConfig, DEFAULT_GRPO_CONFIG, createRewardRegistry, computeDatasetHash, RewardTimeoutError, type GenerateBatchResult, type TrainStepMetrics, type TrainingMetrics, type TrainingState, type DatasetMetadata, GrpoTrainingEngine, NativeRewardRegistry, type GrpoEngineConfig, type EngineStepMetrics, type EngineEpochMetrics, type BuiltinRewardConfig, } from './trainers/grpo-trainer';
- export { TrainingLogger, createTrainingLogger, type TrainingLoggerConfig, type TrainingMetrics as TrainingLoggerMetrics, type GenerationSample, type TrainingConfigFields, type TuiMessage, type LogEvent, type PromptChoice, type PromptOptions, } from './trainers/training-logger';
- export { type EntropyFilteringConfig, DEFAULT_ENTROPY_CONFIG } from './trainers/grpo-entropy';
- export { SFTTrainer, SftTrainingEngine, type SFTTrainStepResult, type SFTTrainingState, type SftEngineConfig, type SftStepMetrics, type SftEpochMetrics, } from './trainers/sft-trainer';
- export { type SFTTrainerConfig, SFTConfigError, getDefaultSFTConfig, mergeSFTConfig, loadSFTTomlConfig, applySFTOverrides, DEFAULT_SFT_CONFIG, } from './trainers/sft-config';
- export * from './data/dataset';
- export { SFTDataset, loadSFTDataset, createSFTDataset, type SFTExample, type SFTPromptCompletionExample, type SFTConversationExample, type SFTBatch, type SFTDatasetConfig, type SpecialTokenIds, } from './data/sft-dataset';
- export * from './utils/xml-parser';
- export { validatePathContainment, resolveAndValidatePath, getAllowedRoot, PathTraversalError, type PathValidationOptions, } from './utils/path-security';
- export type { ChatRole, ChatMessage, CompletionMessage, Completion, DatasetSplit, DatasetExample, XmlParseResult, RewardComputationInput, PromptFormatterOptions, PromptTemplate, DatasetLoader, RewardFunction, PromptFormatter, CompletionInfo, RewardOutput, } from './types';
- //# sourceMappingURL=index.d.ts.map
package/dist/index.d.ts.map DELETED
@@ -1 +0,0 @@
- {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;GAgBG;AAOH,YAAY,EAAE,cAAc,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,MAAM,gBAAgB,CAAC;AAG7F,OAAO,EAAE,OAAO,EAAE,MAAM,gBAAgB,CAAC;AAWzC,OAAO,EAAE,YAAY,EAAE,qBAAqB,EAAE,MAAM,gBAAgB,CAAC;AACrE,YAAY,EAAE,iBAAiB,EAAE,gBAAgB,EAAE,MAAM,gBAAgB,CAAC;AAO1E,OAAO,EACL,KAAK,aAAa,EAClB,WAAW,EACX,gBAAgB,EAChB,WAAW,EACX,cAAc,EACd,cAAc,GACf,MAAM,wBAAwB,CAAC;AAEhC,OAAO,EACL,WAAW,EACX,KAAK,iBAAiB,EACtB,mBAAmB,EACnB,oBAAoB,EACpB,kBAAkB,EAClB,kBAAkB,EAClB,KAAK,mBAAmB,EACxB,KAAK,gBAAgB,EACrB,KAAK,eAAe,EACpB,KAAK,aAAa,EAClB,KAAK,eAAe,EAEpB,kBAAkB,EAClB,oBAAoB,EACpB,KAAK,gBAAgB,EACrB,KAAK,iBAAiB,EACtB,KAAK,kBAAkB,EACvB,KAAK,mBAAmB,GACzB,MAAM,yBAAyB,CAAC;AAGjC,OAAO,EACL,cAAc,EACd,oBAAoB,EACpB,KAAK,oBAAoB,EACzB,KAAK,eAAe,IAAI,qBAAqB,EAC7C,KAAK,gBAAgB,EACrB,KAAK,oBAAoB,EACzB,KAAK,UAAU,EACf,KAAK,QAAQ,EACb,KAAK,YAAY,EACjB,KAAK,aAAa,GACnB,MAAM,4BAA4B,CAAC;AAGpC,OAAO,EAAE,KAAK,sBAAsB,EAAE,sBAAsB,EAAE,MAAM,yBAAyB,CAAC;AAG9F,OAAO,EACL,UAAU,EACV,iBAAiB,EACjB,KAAK,kBAAkB,EACvB,KAAK,gBAAgB,EACrB,KAAK,eAAe,EACpB,KAAK,cAAc,EACnB,KAAK,eAAe,GACrB,MAAM,wBAAwB,CAAC;AAEhC,OAAO,EACL,KAAK,gBAAgB,EACrB,cAAc,EACd,mBAAmB,EACnB,cAAc,EACd,iBAAiB,EACjB,iBAAiB,EACjB,kBAAkB,GACnB,MAAM,uBAAuB,CAAC;AAG/B,cAAc,gBAAgB,CAAC;AAC/B,OAAO,EACL,UAAU,EACV,cAAc,EACd,gBAAgB,EAChB,KAAK,UAAU,EACf,KAAK,0BAA0B,EAC/B,KAAK,sBAAsB,EAC3B,KAAK,QAAQ,EACb,KAAK,gBAAgB,EACrB,KAAK,eAAe,GACrB,MAAM,oBAAoB,CAAC;AAG5B,cAAc,oBAAoB,CAAC;AACnC,OAAO,EACL,uBAAuB,EACvB,sBAAsB,EACtB,cAAc,EACd,kBAAkB,EAClB,KAAK,qBAAqB,GAC3B,MAAM,uBAAuB,CAAC;AAG/B,YAAY,EACV,QAAQ,EACR,WAAW,EACX,iBAAiB,EACjB,UAAU,EACV,YAAY,EACZ,cAAc,EACd,cAAc,EACd,sBAAsB,EACtB,sBAAsB,EACtB,cAAc,EACd,aAAa,EACb,cAAc,EACd,eAAe,EAEf,cAAc,EACd,YAAY,GACb,MAAM,SAAS,CAAC"}
package/dist/index.js DELETED
@@ -1,47 +0,0 @@
- /**
- * @mlx-node/trl - Training utilities for MLX models
- *
- * This package provides everything needed for training ML models,
- * aligned with Python's TRL (Transformer Reinforcement Learning) library.
- *
- * For model loading and inference, import from @mlx-node/lm.
- *
- * @example
- * ```typescript
- * import { GRPOTrainer, GRPOConfig, loadLocalGsm8kDataset } from '@mlx-node/trl';
- * import { ModelLoader } from '@mlx-node/lm';
- *
- * const model = await ModelLoader.loadPretrained('./models/qwen3-0.6b');
- * const trainer = await GRPOTrainer.create({ modelPath: './models/qwen3-0.6b' });
- * ```
- */
- // Core tensor (for custom rewards/models)
- export { MxArray } from '@mlx-node/core';
- // Activations are internal-only (Rust) - used by transformers, sampling, GRPO
- // Transformer components are now internal-only (Rust)
- // Use model.chat() or model.generate() instead
- // GRPO utilities (computeAdvantages, computeEntropy, getHighEntropyMask) are internal-only
- // They are used by GRPOTrainingEngine in Rust
- // Model conversion
- export { convertModel, convertParquetToJsonl } from '@mlx-node/core';
- // =============================================================================
- // TRL-specific exports
- // =============================================================================
- // Trainers
- export { ConfigError, getDefaultConfig, mergeConfig, loadTomlConfig, applyOverrides, } from './trainers/grpo-config';
- export { GRPOTrainer, DEFAULT_GRPO_CONFIG, createRewardRegistry, computeDatasetHash, RewardTimeoutError,
- // Re-export native types from trainer
- GrpoTrainingEngine, NativeRewardRegistry, } from './trainers/grpo-trainer';
- // Unified Training Logger (recommended)
- export { TrainingLogger, createTrainingLogger, } from './trainers/training-logger';
- // Entropy configuration
- export { DEFAULT_ENTROPY_CONFIG } from './trainers/grpo-entropy';
- // SFT Trainer
- export { SFTTrainer, SftTrainingEngine, } from './trainers/sft-trainer';
- export { SFTConfigError, getDefaultSFTConfig, mergeSFTConfig, loadSFTTomlConfig, applySFTOverrides, DEFAULT_SFT_CONFIG, } from './trainers/sft-config';
- // Data
- export * from './data/dataset';
- export { SFTDataset, loadSFTDataset, createSFTDataset, } from './data/sft-dataset';
- // Utils
- export * from './utils/xml-parser';
- export { validatePathContainment, resolveAndValidatePath, getAllowedRoot, PathTraversalError, } from './utils/path-security';