@mlx-node/trl 0.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. package/dist/data/dataset.d.ts +22 -0
  2. package/dist/data/dataset.d.ts.map +1 -0
  3. package/dist/data/dataset.js +142 -0
  4. package/dist/data/sft-dataset.d.ts +156 -0
  5. package/dist/data/sft-dataset.d.ts.map +1 -0
  6. package/dist/data/sft-dataset.js +415 -0
  7. package/dist/index.d.ts +33 -0
  8. package/dist/index.d.ts.map +1 -0
  9. package/dist/index.js +47 -0
  10. package/dist/trainers/grpo-config.d.ts +42 -0
  11. package/dist/trainers/grpo-config.d.ts.map +1 -0
  12. package/dist/trainers/grpo-config.js +220 -0
  13. package/dist/trainers/grpo-entropy.d.ts +33 -0
  14. package/dist/trainers/grpo-entropy.d.ts.map +1 -0
  15. package/dist/trainers/grpo-entropy.js +18 -0
  16. package/dist/trainers/grpo-trainer.d.ts +602 -0
  17. package/dist/trainers/grpo-trainer.d.ts.map +1 -0
  18. package/dist/trainers/grpo-trainer.js +1439 -0
  19. package/dist/trainers/sft-config.d.ts +32 -0
  20. package/dist/trainers/sft-config.d.ts.map +1 -0
  21. package/dist/trainers/sft-config.js +186 -0
  22. package/dist/trainers/sft-trainer.d.ts +141 -0
  23. package/dist/trainers/sft-trainer.d.ts.map +1 -0
  24. package/dist/trainers/sft-trainer.js +502 -0
  25. package/dist/trainers/training-logger.d.ts +375 -0
  26. package/dist/trainers/training-logger.d.ts.map +1 -0
  27. package/dist/trainers/training-logger.js +542 -0
  28. package/dist/types.d.ts +54 -0
  29. package/dist/types.d.ts.map +1 -0
  30. package/dist/types.js +1 -0
  31. package/dist/utils/path-security.d.ts +51 -0
  32. package/dist/utils/path-security.d.ts.map +1 -0
  33. package/dist/utils/path-security.js +69 -0
  34. package/dist/utils/xml-parser.d.ts +6 -0
  35. package/dist/utils/xml-parser.d.ts.map +1 -0
  36. package/dist/utils/xml-parser.js +184 -0
  37. package/package.json +29 -0
@@ -0,0 +1,22 @@
1
import type { DatasetExample, ChatMessage, DatasetSplit, PromptFormatterOptions, PromptTemplate, DatasetLoader } from '../types';
import { type PathValidationOptions } from '../utils/path-security';
/**
 * Options for loading a local GSM8K-style JSONL dataset.
 */
export interface LocalDatasetOptions extends PromptFormatterOptions, PathValidationOptions {
    /** Dataset directory; resolved against the allowed root and containment-checked. */
    basePath?: string;
    /** Custom prompt builder; defaults to `defaultPromptTemplate` when omitted. */
    promptTemplate?: PromptTemplate;
    /** Extra metadata merged into every generated example's metadata. */
    metadata?: Record<string, unknown>;
}
/** System prompt instructing the model to respond in the XML reasoning/answer format. */
export declare const SYSTEM_PROMPT: string;
/** XML scaffold with `{reasoning}` / `{answer}` placeholders, filled in for one-shot examples. */
export declare const XML_COT_FORMAT = "<reasoning>\n{reasoning}\n</reasoning>\n<answer>\n{answer}\n</answer>";
/** Default template: system message, optional one-shot pair, then the user question. */
export declare const defaultPromptTemplate: PromptTemplate;
/** Builds a DatasetExample with defensive copies of the prompt messages and metadata. */
export declare function createDatasetExample(prompt: ChatMessage[], metadata?: Record<string, unknown>): DatasetExample;
/** Extracts the GSM8K gold answer via the shared `extractHashAnswer` parser (null when absent). */
export declare function extractGsm8kAnswer(raw: string): string | null;
/** Throws when the example has no messages, empty/non-string content, or an unsupported role. */
export declare function validateDatasetExample(example: DatasetExample): void;
/** Loads `<split>.jsonl` from the (path-validated) base directory into validated examples. */
export declare function loadLocalGsm8kDataset(split: DatasetSplit, options?: LocalDatasetOptions & {
    limit?: number;
}): Promise<DatasetExample[]>;
/** DatasetLoader that captures options at construction and loads splits on demand. */
export declare class LocalGsm8kDatasetLoader implements DatasetLoader {
    private readonly options;
    constructor(options?: LocalDatasetOptions);
    load(split: DatasetSplit, limit?: number): Promise<DatasetExample[]>;
}
//# sourceMappingURL=dataset.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"dataset.d.ts","sourceRoot":"","sources":["../../src/data/dataset.ts"],"names":[],"mappings":"AAEA,OAAO,KAAK,EACV,cAAc,EACd,WAAW,EAEX,YAAY,EACZ,sBAAsB,EACtB,cAAc,EACd,aAAa,EACd,MAAM,UAAU,CAAC;AAElB,OAAO,EAA2C,KAAK,qBAAqB,EAAE,MAAM,wBAAwB,CAAC;AAE7G,MAAM,WAAW,mBAAoB,SAAQ,sBAAsB,EAAE,qBAAqB;IACxF,QAAQ,CAAC,EAAE,MAAM,CAAC;IAClB,cAAc,CAAC,EAAE,cAAc,CAAC;IAChC,QAAQ,CAAC,EAAE,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;CACpC;AAUD,eAAO,MAAM,aAAa,QASlB,CAAC;AAET,eAAO,MAAM,cAAc,0EAKjB,CAAC;AAWX,eAAO,MAAM,qBAAqB,EAAE,cAYnC,CAAC;AAEF,wBAAgB,oBAAoB,CAAC,MAAM,EAAE,WAAW,EAAE,EAAE,QAAQ,CAAC,EAAE,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,GAAG,cAAc,CAK9G;AAED,wBAAgB,kBAAkB,CAAC,GAAG,EAAE,MAAM,GAAG,MAAM,GAAG,IAAI,CAE7D;AAED,wBAAgB,sBAAsB,CAAC,OAAO,EAAE,cAAc,GAAG,IAAI,CAYpE;AAwDD,wBAAsB,qBAAqB,CACzC,KAAK,EAAE,YAAY,EACnB,OAAO,GAAE,mBAAmB,GAAG;IAAE,KAAK,CAAC,EAAE,MAAM,CAAA;CAAO,GACrD,OAAO,CAAC,cAAc,EAAE,CAAC,CA4B3B;AAED,qBAAa,uBAAwB,YAAW,aAAa;IAC3D,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAsB;gBAElC,OAAO,GAAE,mBAAwB;IAIvC,IAAI,CAAC,KAAK,EAAE,YAAY,EAAE,KAAK,CAAC,EAAE,MAAM,GAAG,OAAO,CAAC,cAAc,EAAE,CAAC;CAG3E"}
@@ -0,0 +1,142 @@
1
+ import { readFileSync } from 'node:fs';
2
+ import { resolve as resolvePath } from 'node:path';
3
+ import { extractHashAnswer } from '../utils/xml-parser';
4
+ import { validatePathContainment, getAllowedRoot } from '../utils/path-security';
5
// Default on-disk location of the GSM8K JSONL splits, relative to the process cwd.
const DEFAULT_BASE_PATH = resolvePath(process.cwd(), 'data/gsm8k');
// Splits that have a corresponding `<split>.jsonl` file on disk.
const VALID_SPLITS = new Set(['train', 'test']);
// System prompt asking the model to wrap its output in <reasoning>/<answer> tags.
export const SYSTEM_PROMPT = `
Respond in the following format:

<reasoning>
...
</reasoning>
<answer>
...
</answer>
`.trim();
// XML chain-of-thought scaffold; `{reasoning}` and `{answer}` are substituted
// when building the one-shot assistant example in `defaultPromptTemplate`.
export const XML_COT_FORMAT = `<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>`;
23
// Shared system message placed (by reference) at the head of every prompt
// built by `defaultPromptTemplate`; examples later deep-copy their messages.
const SYSTEM_MESSAGE = {
    role: 'system',
    content: SYSTEM_PROMPT,
};
27
// Build a single chat message record from a role and its text content.
function createMessage(role, content) {
    const message = { role: role, content: content };
    return message;
}
30
// Default prompt builder: system message, optional one-shot demonstration
// (question + XML-formatted assistant answer), then the actual question.
export const defaultPromptTemplate = (question, options) => {
    const messages = [SYSTEM_MESSAGE];
    const oneShot = options?.includeOneShot ? options.oneShotExample : undefined;
    if (oneShot) {
        const filledAnswer = XML_COT_FORMAT
            .replace('{reasoning}', oneShot.reasoning)
            .replace('{answer}', oneShot.answer);
        messages.push(createMessage('user', oneShot.question));
        messages.push(createMessage('assistant', filledAnswer));
    }
    messages.push(createMessage('user', question));
    return messages;
};
39
// Assemble a dataset example, defensively copying both the prompt messages
// and the metadata so callers cannot mutate the stored example afterwards.
export function createDatasetExample(prompt, metadata) {
    const promptCopy = prompt.map((message) => Object.assign({}, message));
    let metadataCopy;
    if (metadata) {
        metadataCopy = { ...metadata };
    }
    return { prompt: promptCopy, metadata: metadataCopy };
}
45
// Delegate GSM8K answer extraction to the shared xml-parser utility.
export function extractGsm8kAnswer(raw) {
    const answer = extractHashAnswer(raw);
    return answer;
}
48
// Validate that an example carries a well-formed chat prompt.
// Throws on the first violation: empty prompt, missing/blank content, or an
// unsupported role.
export function validateDatasetExample(example) {
    const { prompt } = example;
    if (!Array.isArray(prompt) || prompt.length === 0) {
        throw new Error('Dataset example must contain at least one prompt message.');
    }
    const allowedRoles = ['system', 'user', 'assistant'];
    for (const message of prompt) {
        const hasText = Boolean(message) && typeof message.content === 'string' && message.content.trim() !== '';
        if (!hasText) {
            throw new Error('Prompt messages must include non-empty textual content.');
        }
        if (!allowedRoles.includes(message.role)) {
            throw new Error(`Unsupported chat role: ${String(message.role)}`);
        }
    }
}
61
// Resolve the dataset base directory, always enforcing containment within the
// allowed root (both for caller-supplied paths and for the built-in default).
function resolveBasePath(optionPath, options) {
    const allowedRoot = getAllowedRoot(options);
    if (optionPath) {
        // User-provided path: resolve relative to the root, then verify containment.
        const candidate = resolvePath(allowedRoot, optionPath);
        validatePathContainment(candidate, allowedRoot);
        return candidate;
    }
    // No path given: the default location must also live inside the allowed root.
    validatePathContainment(DEFAULT_BASE_PATH, allowedRoot);
    return DEFAULT_BASE_PATH;
}
73
// Map a split name to its JSONL file name; reject anything outside VALID_SPLITS.
function datasetFileForSplit(split) {
    if (VALID_SPLITS.has(split)) {
        return `${split}.jsonl`;
    }
    const expected = Array.from(VALID_SPLITS).join(', ');
    throw new Error(`Unsupported GSM8K split "${split}". Expected one of: ${expected}`);
}
79
// Read a dataset file as UTF-8 text, wrapping any low-level fs error with a
// message that names the offending path.
function readDatasetFile(filePath) {
    let contents;
    try {
        contents = readFileSync(filePath, 'utf8');
    }
    catch (error) {
        const reason = error instanceof Error ? error.message : String(error);
        throw new Error(`Failed to read dataset file at ${filePath}: ${reason}`);
    }
    return contents;
}
88
// Parse up to `limit` question/answer records from a JSONL file.
// Blank lines are skipped; each remaining line must be JSON with string
// "question" and "answer" fields.
// NOTE: the position in parse errors indexes non-blank lines, not raw file lines.
function readJsonl(path, limit) {
    const raw = readDatasetFile(path);
    const lines = raw.split(/\r?\n/).filter((line) => line.trim().length > 0);
    const cap = typeof limit === 'number' && limit >= 0 ? limit : Number.POSITIVE_INFINITY;
    const records = [];
    for (let i = 0; i < lines.length; i += 1) {
        if (records.length >= cap) {
            break;
        }
        try {
            const parsed = JSON.parse(lines[i]);
            if (typeof parsed.question !== 'string' || typeof parsed.answer !== 'string') {
                throw new Error('Record must include string "question" and "answer" fields.');
            }
            records.push({ question: parsed.question, answer: parsed.answer });
        }
        catch (error) {
            const message = error instanceof Error ? error.message : String(error);
            throw new Error(`Failed to parse JSONL record at ${path}:${i + 1} - ${message}`);
        }
    }
    return records;
}
109
/**
 * Load a GSM8K split from local JSONL storage and turn every record into a
 * validated prompt example.
 *
 * @param split - Dataset split name ('train' or 'test').
 * @param options - Loader options plus an optional record `limit`.
 * @returns Array of validated DatasetExample objects.
 */
export async function loadLocalGsm8kDataset(split, options = {}) {
    const basePath = resolveBasePath(options.basePath, options);
    const filePath = resolvePath(basePath, datasetFileForSplit(split));
    // Defense in depth: the joined file path must stay inside the base directory
    // even after the file-name component is appended.
    validatePathContainment(filePath, basePath);
    const buildPrompt = options.promptTemplate ?? defaultPromptTemplate;
    const records = readJsonl(filePath, options.limit);
    const examples = [];
    records.forEach((record, index) => {
        const prompt = buildPrompt(record.question, {
            includeOneShot: options.includeOneShot,
            oneShotExample: options.oneShotExample,
        });
        // Caller-supplied metadata is spread last so it can override the defaults.
        const example = createDatasetExample(prompt, {
            split,
            index,
            raw_answer: record.answer,
            ...options.metadata,
        });
        validateDatasetExample(example);
        examples.push(example);
    });
    return examples;
}
134
// DatasetLoader implementation that copies its options at construction time
// and re-applies them (plus a per-call limit) on every load.
export class LocalGsm8kDatasetLoader {
    options;
    constructor(options = {}) {
        // Shallow copy so later mutation of the caller's object has no effect.
        this.options = Object.assign({}, options);
    }
    async load(split, limit) {
        const effective = { ...this.options, limit };
        return loadLocalGsm8kDataset(split, effective);
    }
}
@@ -0,0 +1,156 @@
1
/**
 * SFT Dataset handling for Supervised Fine-Tuning
 *
 * Supports two data formats (auto-detected):
 * 1. Prompt-Completion: { prompt: ChatMessage[], completion: ChatMessage }
 * 2. Full Conversation: { messages: ChatMessage[] }
 *
 * Both formats produce tokenized batches with labels masked appropriately.
 */
import type { Qwen3Tokenizer } from '@mlx-node/core';
import type { ChatMessage } from '../types';
import { type PathValidationOptions } from '../utils/path-security';
/**
 * Special token IDs for SFT label masking
 *
 * These are used to detect assistant message boundaries in tokenized conversations.
 * The IDs can be derived from the tokenizer or provided explicitly.
 */
export interface SpecialTokenIds {
    /** Token ID for <|im_start|> */
    imStart: number;
    /** Token ID for <|im_end|> */
    imEnd: number;
    /** Token IDs that represent newlines (for detecting end of role header) */
    newlineTokens: number[];
}
/**
 * Prompt-Completion format for tool-use training
 */
export interface SFTPromptCompletionExample {
    prompt: ChatMessage[];
    completion: ChatMessage;
}
/**
 * Full conversation format for multi-turn dialogue
 */
export interface SFTConversationExample {
    messages: ChatMessage[];
}
/**
 * Union type for SFT examples
 */
export type SFTExample = SFTPromptCompletionExample | SFTConversationExample;
/**
 * A tokenized batch ready for SFT training
 */
export interface SFTBatch {
    /** Token IDs for the padded batch (see `shape` for dimensions). */
    inputIds: Int32Array;
    /** Per-token training targets; non-trained positions are masked with -100. */
    labels: Int32Array;
    /** Batch dimensions — NOTE(review): presumably [batchSize, seqLength]; confirm in collateBatch. */
    shape: [number, number];
}
/**
 * Configuration for SFT dataset
 */
export interface SFTDatasetConfig {
    /** Maximum tokenized sequence length per example. */
    maxSeqLength?: number;
    /** When true, presumably only completion tokens contribute to the loss — verify against implementation. */
    completionOnly?: boolean;
    /** Chat-template "thinking" flag — NOTE(review): confirm how the tokenizer consumes this. */
    enableThinking?: boolean;
    /** Base seed for deterministic epoch shuffling (see shuffleForEpoch). */
    seed?: number;
    /**
     * Special token IDs for label masking.
     *
     * If not provided, these are automatically derived from the tokenizer.
     * This option allows explicit overriding for custom tokenizers or
     * non-standard vocabularies.
     */
    specialTokenIds?: Partial<SpecialTokenIds>;
}
/**
 * SFT Dataset class for handling SFT training data
 */
export declare class SFTDataset {
    private examples;
    private tokenizer;
    private config;
    private format;
    private shuffledIndices;
    private rng;
    /** Cached special token IDs for label masking */
    private specialTokenIds;
    constructor(examples: SFTExample[], tokenizer: Qwen3Tokenizer, config?: SFTDatasetConfig);
    /**
     * Get the number of examples in the dataset
     */
    get length(): number;
    /**
     * Shuffle dataset for a specific epoch using epoch-based seeding.
     * This ensures reproducible shuffles across training resumes.
     * Each epoch gets a deterministic shuffle based on (baseSeed + epoch).
     *
     * @param epoch - The epoch number (used as seed offset)
     */
    shuffleForEpoch(epoch: number): void;
    /**
     * Create a seeded pseudo-random number generator (Linear Congruential Generator)
     */
    private createSeededRandom;
    /**
     * Find length of common prefix between two token arrays
     * Handles chat template quirks where prompt tokens may not be exact prefix of full tokens
     */
    private findCommonPrefixLength;
    /**
     * Tokenize a prompt-completion example
     */
    private tokenizePromptCompletion;
    /**
     * Check if a token ID is a newline token
     */
    private isNewlineToken;
    /**
     * Tokenize a conversation example
     *
     * For conversations, we train on all assistant turns.
     * Non-assistant tokens (system, user) are masked with -100.
     *
     * Uses single-pass tokenization with token-based boundary detection.
     * Token IDs are derived from the tokenizer for portability across models.
     */
    private tokenizeConversation;
    /**
     * Tokenize a single example based on its format
     */
    private tokenizeExample;
    /**
     * Collate multiple examples into a padded batch
     */
    collateBatch(indices: number[]): Promise<SFTBatch>;
    /**
     * Generate batches for training
     */
    batches(batchSize: number): AsyncGenerator<SFTBatch>;
    /**
     * Get total number of batches for a given batch size
     */
    numBatches(batchSize: number): number;
}
/**
 * Load SFT dataset from a JSONL file
 *
 * Supports two formats:
 * 1. Prompt-Completion: {"prompt": [...], "completion": {...}}
 * 2. Conversation: {"messages": [...]}
 *
 * @param path - Path to the JSONL file (relative to cwd or allowedRoot)
 * @param tokenizer - Qwen3 tokenizer instance
 * @param config - Optional configuration including path validation options
 */
export declare function loadSFTDataset(path: string, tokenizer: Qwen3Tokenizer, config?: SFTDatasetConfig & {
    limit?: number;
} & PathValidationOptions): Promise<SFTDataset>;
/**
 * Create SFT dataset from examples directly
 */
export declare function createSFTDataset(examples: SFTExample[], tokenizer: Qwen3Tokenizer, config?: SFTDatasetConfig): SFTDataset;
//# sourceMappingURL=sft-dataset.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"sft-dataset.d.ts","sourceRoot":"","sources":["../../src/data/sft-dataset.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAIH,OAAO,KAAK,EAAE,cAAc,EAAE,MAAM,gBAAgB,CAAC;AACrD,OAAO,KAAK,EAAE,WAAW,EAAE,MAAM,UAAU,CAAC;AAC5C,OAAO,EAA2C,KAAK,qBAAqB,EAAE,MAAM,wBAAwB,CAAC;AAK7G;;;;;GAKG;AACH,MAAM,WAAW,eAAe;IAC9B,gCAAgC;IAChC,OAAO,EAAE,MAAM,CAAC;IAChB,8BAA8B;IAC9B,KAAK,EAAE,MAAM,CAAC;IACd,2EAA2E;IAC3E,aAAa,EAAE,MAAM,EAAE,CAAC;CACzB;AAgDD;;GAEG;AACH,MAAM,WAAW,0BAA0B;IACzC,MAAM,EAAE,WAAW,EAAE,CAAC;IACtB,UAAU,EAAE,WAAW,CAAC;CACzB;AAED;;GAEG;AACH,MAAM,WAAW,sBAAsB;IACrC,QAAQ,EAAE,WAAW,EAAE,CAAC;CACzB;AAED;;GAEG;AACH,MAAM,MAAM,UAAU,GAAG,0BAA0B,GAAG,sBAAsB,CAAC;AAE7E;;GAEG;AACH,MAAM,WAAW,QAAQ;IACvB,QAAQ,EAAE,UAAU,CAAC;IACrB,MAAM,EAAE,UAAU,CAAC;IACnB,KAAK,EAAE,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;CACzB;AAED;;GAEG;AACH,MAAM,WAAW,gBAAgB;IAC/B,YAAY,CAAC,EAAE,MAAM,CAAC;IACtB,cAAc,CAAC,EAAE,OAAO,CAAC;IACzB,cAAc,CAAC,EAAE,OAAO,CAAC;IACzB,IAAI,CAAC,EAAE,MAAM,CAAC;IAEd;;;;;;OAMG;IACH,eAAe,CAAC,EAAE,OAAO,CAAC,eAAe,CAAC,CAAC;CAC5C;AAeD;;GAEG;AACH,qBAAa,UAAU;IACrB,OAAO,CAAC,QAAQ,CAAe;IAC/B,OAAO,CAAC,SAAS,CAAiB;IAClC,OAAO,CAAC,MAAM,CAAkF;IAChG,OAAO,CAAC,MAAM,CAAuC;IACrD,OAAO,CAAC,eAAe,CAAW;IAClC,OAAO,CAAC,GAAG,CAAe;IAC1B,iDAAiD;IACjD,OAAO,CAAC,eAAe,CAAkB;gBAE7B,QAAQ,EAAE,UAAU,EAAE,EAAE,SAAS,EAAE,cAAc,EAAE,MAAM,GAAE,gBAAqB;IAsC5F;;OAEG;IACH,IAAI,MAAM,IAAI,MAAM,CAEnB;IAED;;;;;;OAMG;IACH,eAAe,CAAC,KAAK,EAAE,MAAM,GAAG,IAAI;IAYpC;;OAEG;IACH,OAAO,CAAC,kBAAkB;IAQ1B;;;OAGG;IACH,OAAO,CAAC,sBAAsB;IAS9B;;OAEG;YACW,wBAAwB;IA8CtC;;OAEG;IACH,OAAO,CAAC,cAAc;IAItB;;;;;;;;OAQG;YACW,oBAAoB;IA2DlC;;OAEG;YACW,eAAe;IAQ7B;;OAEG;IACG,YAAY,CAAC,OAAO,EAAE,MAAM,EAAE,GAAG,OAAO,CAAC,QAAQ,CAAC;IA8CxD;;OAEG;IACI,OAAO,CAAC,SAAS,EAAE,MAAM,GAAG,cAAc,CAAC,QAAQ,CAAC;IAQ3D;;OAEG;IACH,UAAU,CAAC,SAAS,EAAE,MAAM,GAAG,MAAM;CAGtC;AAoFD;;;;;;;;;;GAUG;AACH,wBAAsB,cAAc,CAClC,IAAI,EAAE,MAAM,EACZ,SAAS,EAAE,cAAc,EACzB,MAAM,CAAC,EAAE,gBAAgB,GAAG;IAAE,KAAK,CAAC,EAAE,MAAM,CAAA;CAAE,GAAG,qBAAqB,GACrE,OAAO,CAAC,UAAU,CAAC,CAarB;AAED;;
GAEG;AACH,wBAAgB,gBAAgB,CAC9B,QAAQ,EAAE,UAAU,EAAE,EACtB,SAAS,EAAE,cAAc,EACzB,MAAM,CAAC,EAAE,gBAAgB,GACxB,UAAU,CAEZ"}