@mlx-node/trl 0.0.0 → 0.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +389 -0
- package/package.json +16 -5
- package/dist/data/dataset.d.ts +0 -22
- package/dist/data/dataset.d.ts.map +0 -1
- package/dist/data/dataset.js +0 -142
- package/dist/data/sft-dataset.d.ts +0 -156
- package/dist/data/sft-dataset.d.ts.map +0 -1
- package/dist/data/sft-dataset.js +0 -415
- package/dist/index.d.ts +0 -33
- package/dist/index.d.ts.map +0 -1
- package/dist/index.js +0 -47
- package/dist/trainers/grpo-config.d.ts +0 -42
- package/dist/trainers/grpo-config.d.ts.map +0 -1
- package/dist/trainers/grpo-config.js +0 -220
- package/dist/trainers/grpo-entropy.d.ts +0 -33
- package/dist/trainers/grpo-entropy.d.ts.map +0 -1
- package/dist/trainers/grpo-entropy.js +0 -18
- package/dist/trainers/grpo-trainer.d.ts +0 -602
- package/dist/trainers/grpo-trainer.d.ts.map +0 -1
- package/dist/trainers/grpo-trainer.js +0 -1439
- package/dist/trainers/sft-config.d.ts +0 -32
- package/dist/trainers/sft-config.d.ts.map +0 -1
- package/dist/trainers/sft-config.js +0 -186
- package/dist/trainers/sft-trainer.d.ts +0 -141
- package/dist/trainers/sft-trainer.d.ts.map +0 -1
- package/dist/trainers/sft-trainer.js +0 -502
- package/dist/trainers/training-logger.d.ts +0 -375
- package/dist/trainers/training-logger.d.ts.map +0 -1
- package/dist/trainers/training-logger.js +0 -542
- package/dist/types.d.ts +0 -54
- package/dist/types.d.ts.map +0 -1
- package/dist/types.js +0 -1
- package/dist/utils/path-security.d.ts +0 -51
- package/dist/utils/path-security.d.ts.map +0 -1
- package/dist/utils/path-security.js +0 -69
- package/dist/utils/xml-parser.d.ts +0 -6
- package/dist/utils/xml-parser.d.ts.map +0 -1
- package/dist/utils/xml-parser.js +0 -184
|
@@ -1,156 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* SFT Dataset handling for Supervised Fine-Tuning
|
|
3
|
-
*
|
|
4
|
-
* Supports two data formats (auto-detected):
|
|
5
|
-
* 1. Prompt-Completion: { prompt: ChatMessage[], completion: ChatMessage }
|
|
6
|
-
* 2. Full Conversation: { messages: ChatMessage[] }
|
|
7
|
-
*
|
|
8
|
-
* Both formats produce tokenized batches with labels masked appropriately.
|
|
9
|
-
*/
|
|
10
|
-
import type { Qwen3Tokenizer } from '@mlx-node/core';
|
|
11
|
-
import type { ChatMessage } from '../types';
|
|
12
|
-
import { type PathValidationOptions } from '../utils/path-security';
|
|
13
|
-
/**
 * Token IDs that delimit chat turns; used to decide which positions of a
 * tokenized conversation contribute to the SFT loss.
 *
 * Normally read from the tokenizer, but may be supplied explicitly for
 * custom tokenizers or non-standard vocabularies.
 */
export interface SpecialTokenIds {
    /** ID of the turn-opening `<|im_start|>` token */
    imStart: number;
    /** ID of the turn-closing `<|im_end|>` token */
    imEnd: number;
    /** IDs treated as newlines; a newline terminates the role header of a turn */
    newlineTokens: Array<number>;
}
|
|
27
|
-
/**
 * Prompt-completion example (e.g. for tool-use training): the messages that
 * precede a response, plus the single message the model should learn to produce.
 */
export interface SFTPromptCompletionExample {
    /** Conversation context preceding the target message */
    prompt: Array<ChatMessage>;
    /** Target message to learn */
    completion: ChatMessage;
}
|
|
34
|
-
/**
 * Full multi-turn dialogue example: one ordered list of chat messages.
 */
export interface SFTConversationExample {
    /** All turns of the conversation, in order */
    messages: Array<ChatMessage>;
}
|
|
40
|
-
/**
 * Any supported SFT example; the concrete format is auto-detected.
 */
export type SFTExample =
    | SFTPromptCompletionExample
    | SFTConversationExample;
|
|
44
|
-
/**
 * A padded, tokenized batch ready for an SFT training step.
 *
 * `inputIds` and `labels` are flat buffers laid out row by row according to
 * `shape`; label positions that must not contribute to the loss hold -100.
 */
export interface SFTBatch {
    /** Token IDs fed to the model */
    inputIds: Int32Array;
    /** Target token IDs, or -100 for ignored positions */
    labels: Int32Array;
    /** Dimensions of the flat buffers */
    shape: [batchSize: number, seqLength: number];
}
|
|
52
|
-
/**
 * Options accepted when building an SFT dataset.
 *
 * The implementation's defaults are: maxSeqLength 2048, completionOnly false,
 * enableThinking false, seed 42.
 */
export interface SFTDatasetConfig {
    /** Upper bound on tokens per sequence; longer sequences are truncated */
    maxSeqLength?: number;
    /** When true, only completion / assistant tokens are used as labels */
    completionOnly?: boolean;
    /** Forwarded to the tokenizer's chat template */
    enableThinking?: boolean;
    /** Base seed for the epoch-based shuffle */
    seed?: number;
    /**
     * Explicit special token IDs used for label masking.
     *
     * Any field left out is derived from the tokenizer. Useful for custom
     * tokenizers or vocabularies that differ from the standard ChatML setup.
     */
    specialTokenIds?: Partial<SpecialTokenIds>;
}
|
|
69
|
-
/**
 * In-memory SFT dataset.
 *
 * Checks that all examples share one format, tokenizes them through the
 * tokenizer's chat template, builds masked labels and collates padded batches.
 */
export declare class SFTDataset {
    private examples;
    private tokenizer;
    private config;
    private format;
    private shuffledIndices;
    private rng;
    /** Special token IDs resolved at construction time for label masking */
    private specialTokenIds;
    constructor(examples: Array<SFTExample>, tokenizer: Qwen3Tokenizer, config?: SFTDatasetConfig);
    /** Number of examples held by the dataset */
    get length(): number;
    /**
     * Reorder the dataset deterministically for the given epoch.
     *
     * The shuffle is seeded with (baseSeed + epoch), so a resumed run
     * reproduces the same order for the same epoch.
     *
     * @param epoch - Epoch number, added to the base seed
     */
    shuffleForEpoch(epoch: number): void;
    /**
     * Tokenize the examples at the given positions (in the current shuffled
     * order) and pack them into one padded batch.
     */
    collateBatch(indices: Array<number>): Promise<SFTBatch>;
    /** Yield consecutive batches covering the whole dataset */
    batches(batchSize: number): AsyncGenerator<SFTBatch>;
    /** Number of batches `batches(batchSize)` will yield */
    numBatches(batchSize: number): number;
    /** Seeded linear congruential pseudo-random generator factory */
    private createSeededRandom;
    /**
     * Length of the shared prefix of two token arrays; tolerates chat
     * templates whose prompt tokens are not an exact prefix of the full sequence
     */
    private findCommonPrefixLength;
    /** Tokenize a prompt-completion example */
    private tokenizePromptCompletion;
    /** Whether a token ID is one of the configured newline tokens */
    private isNewlineToken;
    /**
     * Tokenize a conversation example in a single pass, locating assistant
     * turns via special token IDs; non-assistant tokens are masked with -100
     * when completion-only training is enabled.
     */
    private tokenizeConversation;
    /** Dispatch tokenization according to the dataset's format */
    private tokenizeExample;
}
|
|
138
|
-
/**
 * Read an SFT dataset from a JSONL file, one example per non-empty line.
 *
 * Accepted record shapes:
 *   1. Prompt-Completion: {"prompt": [...], "completion": {...}}
 *   2. Conversation:      {"messages": [...]}
 *
 * @param path - JSONL file path, resolved against the allowed root (cwd by default);
 *   paths escaping that root are rejected
 * @param tokenizer - Qwen3 tokenizer used for templating and masking
 * @param config - Dataset options, an optional record `limit`, and path validation options
 */
export declare function loadSFTDataset(
    path: string,
    tokenizer: Qwen3Tokenizer,
    config?: SFTDatasetConfig & {
        limit?: number;
    } & PathValidationOptions,
): Promise<SFTDataset>;
|
|
152
|
-
/**
 * Build an SFT dataset directly from in-memory examples.
 */
export declare function createSFTDataset(
    examples: Array<SFTExample>,
    tokenizer: Qwen3Tokenizer,
    config?: SFTDatasetConfig,
): SFTDataset;
|
|
156
|
-
//# sourceMappingURL=sft-dataset.d.ts.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"sft-dataset.d.ts","sourceRoot":"","sources":["../../src/data/sft-dataset.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAIH,OAAO,KAAK,EAAE,cAAc,EAAE,MAAM,gBAAgB,CAAC;AACrD,OAAO,KAAK,EAAE,WAAW,EAAE,MAAM,UAAU,CAAC;AAC5C,OAAO,EAA2C,KAAK,qBAAqB,EAAE,MAAM,wBAAwB,CAAC;AAK7G;;;;;GAKG;AACH,MAAM,WAAW,eAAe;IAC9B,gCAAgC;IAChC,OAAO,EAAE,MAAM,CAAC;IAChB,8BAA8B;IAC9B,KAAK,EAAE,MAAM,CAAC;IACd,2EAA2E;IAC3E,aAAa,EAAE,MAAM,EAAE,CAAC;CACzB;AAgDD;;GAEG;AACH,MAAM,WAAW,0BAA0B;IACzC,MAAM,EAAE,WAAW,EAAE,CAAC;IACtB,UAAU,EAAE,WAAW,CAAC;CACzB;AAED;;GAEG;AACH,MAAM,WAAW,sBAAsB;IACrC,QAAQ,EAAE,WAAW,EAAE,CAAC;CACzB;AAED;;GAEG;AACH,MAAM,MAAM,UAAU,GAAG,0BAA0B,GAAG,sBAAsB,CAAC;AAE7E;;GAEG;AACH,MAAM,WAAW,QAAQ;IACvB,QAAQ,EAAE,UAAU,CAAC;IACrB,MAAM,EAAE,UAAU,CAAC;IACnB,KAAK,EAAE,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;CACzB;AAED;;GAEG;AACH,MAAM,WAAW,gBAAgB;IAC/B,YAAY,CAAC,EAAE,MAAM,CAAC;IACtB,cAAc,CAAC,EAAE,OAAO,CAAC;IACzB,cAAc,CAAC,EAAE,OAAO,CAAC;IACzB,IAAI,CAAC,EAAE,MAAM,CAAC;IAEd;;;;;;OAMG;IACH,eAAe,CAAC,EAAE,OAAO,CAAC,eAAe,CAAC,CAAC;CAC5C;AAeD;;GAEG;AACH,qBAAa,UAAU;IACrB,OAAO,CAAC,QAAQ,CAAe;IAC/B,OAAO,CAAC,SAAS,CAAiB;IAClC,OAAO,CAAC,MAAM,CAAkF;IAChG,OAAO,CAAC,MAAM,CAAuC;IACrD,OAAO,CAAC,eAAe,CAAW;IAClC,OAAO,CAAC,GAAG,CAAe;IAC1B,iDAAiD;IACjD,OAAO,CAAC,eAAe,CAAkB;gBAE7B,QAAQ,EAAE,UAAU,EAAE,EAAE,SAAS,EAAE,cAAc,EAAE,MAAM,GAAE,gBAAqB;IAsC5F;;OAEG;IACH,IAAI,MAAM,IAAI,MAAM,CAEnB;IAED;;;;;;OAMG;IACH,eAAe,CAAC,KAAK,EAAE,MAAM,GAAG,IAAI;IAYpC;;OAEG;IACH,OAAO,CAAC,kBAAkB;IAQ1B;;;OAGG;IACH,OAAO,CAAC,sBAAsB;IAS9B;;OAEG;YACW,wBAAwB;IA8CtC;;OAEG;IACH,OAAO,CAAC,cAAc;IAItB;;;;;;;;OAQG;YACW,oBAAoB;IA2DlC;;OAEG;YACW,eAAe;IAQ7B;;OAEG;IACG,YAAY,CAAC,OAAO,EAAE,MAAM,EAAE,GAAG,OAAO,CAAC,QAAQ,CAAC;IA8CxD;;OAEG;IACI,OAAO,CAAC,SAAS,EAAE,MAAM,GAAG,cAAc,CAAC,QAAQ,CAAC;IAQ3D;;OAEG;IACH,UAAU,CAAC,SAAS,EAAE,MAAM,GAAG,MAAM;CAGtC;AAoFD;;;;;;;;;;GAUG;AACH,wBAAsB,cAAc,CAClC,IAAI,EAAE,MAAM,EACZ,SAAS,EAAE,cAAc,EACzB,MAAM,CAAC,EAAE,gBAAgB,GAAG;IAAE,KAAK,CAAC,EAAE,MAAM,CAAA;CAAE,GAAG,qBAAqB,GACrE,OAAO,CAAC,UAAU,CAAC,CAarB;AAED;;GA
EG;AACH,wBAAgB,gBAAgB,CAC9B,QAAQ,EAAE,UAAU,EAAE,EACtB,SAAS,EAAE,cAAc,EACzB,MAAM,CAAC,EAAE,gBAAgB,GACxB,UAAU,CAEZ"}
|
package/dist/data/sft-dataset.js
DELETED
|
@@ -1,415 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* SFT Dataset handling for Supervised Fine-Tuning
|
|
3
|
-
*
|
|
4
|
-
* Supports two data formats (auto-detected):
|
|
5
|
-
* 1. Prompt-Completion: { prompt: ChatMessage[], completion: ChatMessage }
|
|
6
|
-
* 2. Full Conversation: { messages: ChatMessage[] }
|
|
7
|
-
*
|
|
8
|
-
* Both formats produce tokenized batches with labels masked appropriately.
|
|
9
|
-
*/
|
|
10
|
-
import { readFileSync } from 'node:fs';
|
|
11
|
-
import { resolve as resolvePath } from 'node:path';
|
|
12
|
-
import { validatePathContainment, getAllowedRoot } from '../utils/path-security';
|
|
13
|
-
/** Label value skipped by cross-entropy loss (the conventional ignore index, -100). */
const IGNORE_INDEX = -100;
|
|
15
|
-
/**
|
|
16
|
-
* Get special token IDs from a tokenizer
|
|
17
|
-
*
|
|
18
|
-
* Queries the tokenizer to get the actual token IDs for special tokens.
|
|
19
|
-
* This ensures portability across different tokenizers/vocabularies.
|
|
20
|
-
*
|
|
21
|
-
* @param tokenizer - The tokenizer instance
|
|
22
|
-
* @returns Special token IDs derived from the tokenizer
|
|
23
|
-
* @throws Error if required special tokens are not found
|
|
24
|
-
*/
|
|
25
|
-
function getSpecialTokenIds(tokenizer) {
|
|
26
|
-
// Get im_start and im_end tokens using the tokenizer's special token getters
|
|
27
|
-
const imStartToken = tokenizer.getImStartToken(); // "<|im_start|>"
|
|
28
|
-
const imEndToken = tokenizer.getImEndToken(); // "<|im_end|>"
|
|
29
|
-
const imStart = tokenizer.tokenToId(imStartToken);
|
|
30
|
-
const imEnd = tokenizer.tokenToId(imEndToken);
|
|
31
|
-
// Validate that we got valid IDs (tokenToId returns null for unknown tokens)
|
|
32
|
-
if (imStart === null || imEnd === null) {
|
|
33
|
-
throw new Error(`Tokenizer does not have required special tokens for ChatML format. ` +
|
|
34
|
-
`Got im_start=${imStart}, im_end=${imEnd}. ` +
|
|
35
|
-
`This tokenizer may not be compatible with ChatML format.`);
|
|
36
|
-
}
|
|
37
|
-
// Get newline token IDs - these vary by tokenizer
|
|
38
|
-
// Try common newline representations
|
|
39
|
-
const newlineTokens = [];
|
|
40
|
-
const potentialNewlines = ['\n', ' \n', '\r\n', '\n\n'];
|
|
41
|
-
for (const nl of potentialNewlines) {
|
|
42
|
-
const id = tokenizer.tokenToId(nl);
|
|
43
|
-
if (id !== null && !newlineTokens.includes(id)) {
|
|
44
|
-
newlineTokens.push(id);
|
|
45
|
-
}
|
|
46
|
-
}
|
|
47
|
-
// If no newline tokens found, we'll rely on the fallback in tokenizeConversation
|
|
48
|
-
return {
|
|
49
|
-
imStart,
|
|
50
|
-
imEnd,
|
|
51
|
-
newlineTokens,
|
|
52
|
-
};
|
|
53
|
-
}
|
|
54
|
-
/**
 * Detect the format of an SFT example.
 *
 * @param example - Candidate example (may come from unvalidated user input)
 * @returns 'prompt-completion' when both `prompt` and `completion` keys exist,
 *   otherwise 'conversation' when a `messages` key exists
 * @throws Error if the example is not an object or matches neither format
 */
function detectFormat(example) {
    const invalidFormatMessage = 'Invalid SFT example format. Expected either {prompt, completion} or {messages}';
    // Guard first: the `in` operator throws an opaque TypeError on null/primitives
    if (typeof example !== 'object' || example === null) {
        throw new Error(invalidFormatMessage);
    }
    if ('prompt' in example && 'completion' in example) {
        return 'prompt-completion';
    }
    if ('messages' in example) {
        return 'conversation';
    }
    throw new Error(invalidFormatMessage);
}
|
|
66
|
-
/**
|
|
67
|
-
* SFT Dataset class for handling SFT training data
|
|
68
|
-
*/
|
|
69
|
-
export class SFTDataset {
|
|
70
|
-
examples;
|
|
71
|
-
tokenizer;
|
|
72
|
-
config;
|
|
73
|
-
format;
|
|
74
|
-
shuffledIndices;
|
|
75
|
-
rng;
|
|
76
|
-
/** Cached special token IDs for label masking */
|
|
77
|
-
specialTokenIds;
|
|
78
|
-
constructor(examples, tokenizer, config = {}) {
|
|
79
|
-
if (examples.length === 0) {
|
|
80
|
-
throw new Error('SFT dataset must contain at least one example');
|
|
81
|
-
}
|
|
82
|
-
this.examples = examples;
|
|
83
|
-
this.tokenizer = tokenizer;
|
|
84
|
-
this.config = {
|
|
85
|
-
maxSeqLength: config.maxSeqLength ?? 2048,
|
|
86
|
-
completionOnly: config.completionOnly ?? false, // Changed to false for TRL parity
|
|
87
|
-
enableThinking: config.enableThinking ?? false,
|
|
88
|
-
seed: config.seed ?? 42,
|
|
89
|
-
};
|
|
90
|
-
this.rng = this.createSeededRandom(this.config.seed);
|
|
91
|
-
// Get special token IDs from tokenizer, with optional overrides
|
|
92
|
-
const derivedTokenIds = getSpecialTokenIds(tokenizer);
|
|
93
|
-
this.specialTokenIds = {
|
|
94
|
-
imStart: config.specialTokenIds?.imStart ?? derivedTokenIds.imStart,
|
|
95
|
-
imEnd: config.specialTokenIds?.imEnd ?? derivedTokenIds.imEnd,
|
|
96
|
-
newlineTokens: config.specialTokenIds?.newlineTokens ?? derivedTokenIds.newlineTokens,
|
|
97
|
-
};
|
|
98
|
-
// Detect format from first example
|
|
99
|
-
this.format = detectFormat(examples[0]);
|
|
100
|
-
// Validate all examples have the same format
|
|
101
|
-
for (let i = 1; i < examples.length; i++) {
|
|
102
|
-
const fmt = detectFormat(examples[i]);
|
|
103
|
-
if (fmt !== this.format) {
|
|
104
|
-
throw new Error(`Inconsistent SFT data format: example 0 is ${this.format}, example ${i} is ${fmt}`);
|
|
105
|
-
}
|
|
106
|
-
}
|
|
107
|
-
// Initialize indices
|
|
108
|
-
this.shuffledIndices = Array.from({ length: examples.length }, (_, i) => i);
|
|
109
|
-
}
|
|
110
|
-
/**
|
|
111
|
-
* Get the number of examples in the dataset
|
|
112
|
-
*/
|
|
113
|
-
get length() {
|
|
114
|
-
return this.examples.length;
|
|
115
|
-
}
|
|
116
|
-
/**
|
|
117
|
-
* Shuffle dataset for a specific epoch using epoch-based seeding.
|
|
118
|
-
* This ensures reproducible shuffles across training resumes.
|
|
119
|
-
* Each epoch gets a deterministic shuffle based on (baseSeed + epoch).
|
|
120
|
-
*
|
|
121
|
-
* @param epoch - The epoch number (used as seed offset)
|
|
122
|
-
*/
|
|
123
|
-
shuffleForEpoch(epoch) {
|
|
124
|
-
// Reset RNG with epoch-specific seed for reproducibility
|
|
125
|
-
this.rng = this.createSeededRandom(this.config.seed + epoch);
|
|
126
|
-
// Reset indices to original order
|
|
127
|
-
this.shuffledIndices = Array.from({ length: this.examples.length }, (_, i) => i);
|
|
128
|
-
// Fisher-Yates shuffle
|
|
129
|
-
for (let i = this.shuffledIndices.length - 1; i > 0; i--) {
|
|
130
|
-
const j = Math.floor(this.rng() * (i + 1));
|
|
131
|
-
[this.shuffledIndices[i], this.shuffledIndices[j]] = [this.shuffledIndices[j], this.shuffledIndices[i]];
|
|
132
|
-
}
|
|
133
|
-
}
|
|
134
|
-
/**
|
|
135
|
-
* Create a seeded pseudo-random number generator (Linear Congruential Generator)
|
|
136
|
-
*/
|
|
137
|
-
createSeededRandom(seed) {
|
|
138
|
-
let s = seed;
|
|
139
|
-
return () => {
|
|
140
|
-
s = (s * 1103515245 + 12345) & 0x7fffffff;
|
|
141
|
-
return s / 0x7fffffff;
|
|
142
|
-
};
|
|
143
|
-
}
|
|
144
|
-
/**
|
|
145
|
-
* Find length of common prefix between two token arrays
|
|
146
|
-
* Handles chat template quirks where prompt tokens may not be exact prefix of full tokens
|
|
147
|
-
*/
|
|
148
|
-
findCommonPrefixLength(prompt, full) {
|
|
149
|
-
let i = 0;
|
|
150
|
-
const maxLen = Math.min(prompt.length, full.length);
|
|
151
|
-
while (i < maxLen && prompt[i] === full[i]) {
|
|
152
|
-
i++;
|
|
153
|
-
}
|
|
154
|
-
return i;
|
|
155
|
-
}
|
|
156
|
-
/**
|
|
157
|
-
* Tokenize a prompt-completion example
|
|
158
|
-
*/
|
|
159
|
-
async tokenizePromptCompletion(example) {
|
|
160
|
-
// Tokenize prompt with generation prompt (so the model learns to continue)
|
|
161
|
-
const promptTokens = await this.tokenizer.applyChatTemplate(example.prompt, true, // add generation prompt
|
|
162
|
-
null, this.config.enableThinking);
|
|
163
|
-
// Create full messages for tokenization
|
|
164
|
-
const fullMessages = [...example.prompt, example.completion];
|
|
165
|
-
const fullTokens = await this.tokenizer.applyChatTemplate(fullMessages, false, // no generation prompt at the end
|
|
166
|
-
null, this.config.enableThinking);
|
|
167
|
-
// Convert to regular arrays for manipulation
|
|
168
|
-
const promptArr = Array.from(promptTokens, Number);
|
|
169
|
-
const inputIds = Array.from(fullTokens, Number);
|
|
170
|
-
// Use common prefix detection to handle chat template quirks
|
|
171
|
-
// (some templates may not produce prompt tokens as exact prefix of full tokens)
|
|
172
|
-
const promptLen = this.findCommonPrefixLength(promptArr, inputIds);
|
|
173
|
-
if (promptLen !== promptArr.length) {
|
|
174
|
-
console.warn(`[SFT Dataset] Prompt tokens differ from prefix of full sequence ` +
|
|
175
|
-
`(${promptArr.length} vs ${promptLen}). Using common prefix for masking.`);
|
|
176
|
-
}
|
|
177
|
-
// Create labels: -100 for prompt tokens, actual tokens for completion
|
|
178
|
-
const labels = inputIds.map((id, i) => {
|
|
179
|
-
if (this.config.completionOnly && i < promptLen) {
|
|
180
|
-
return IGNORE_INDEX;
|
|
181
|
-
}
|
|
182
|
-
return id;
|
|
183
|
-
});
|
|
184
|
-
return { inputIds, labels };
|
|
185
|
-
}
|
|
186
|
-
/**
|
|
187
|
-
* Check if a token ID is a newline token
|
|
188
|
-
*/
|
|
189
|
-
isNewlineToken(tokenId) {
|
|
190
|
-
return this.specialTokenIds.newlineTokens.includes(tokenId);
|
|
191
|
-
}
|
|
192
|
-
/**
|
|
193
|
-
* Tokenize a conversation example
|
|
194
|
-
*
|
|
195
|
-
* For conversations, we train on all assistant turns.
|
|
196
|
-
* Non-assistant tokens (system, user) are masked with -100.
|
|
197
|
-
*
|
|
198
|
-
* Uses single-pass tokenization with token-based boundary detection.
|
|
199
|
-
* Token IDs are derived from the tokenizer for portability across models.
|
|
200
|
-
*/
|
|
201
|
-
async tokenizeConversation(example) {
|
|
202
|
-
const messages = example.messages;
|
|
203
|
-
// Single tokenization pass
|
|
204
|
-
const fullTokens = await this.tokenizer.applyChatTemplate(messages, false, null, this.config.enableThinking);
|
|
205
|
-
const inputIds = Array.from(fullTokens, Number);
|
|
206
|
-
// If not masking prompts, all tokens are trainable
|
|
207
|
-
if (!this.config.completionOnly) {
|
|
208
|
-
return { inputIds, labels: inputIds.slice() };
|
|
209
|
-
}
|
|
210
|
-
// Token-based boundary detection using special tokens (derived from tokenizer)
|
|
211
|
-
const { imStart, imEnd } = this.specialTokenIds;
|
|
212
|
-
// Get "assistant" token ID (it's a single token in Qwen3)
|
|
213
|
-
const assistantTokenId = this.tokenizer.tokenToId('assistant');
|
|
214
|
-
const labels = Array.from({ length: inputIds.length }, () => IGNORE_INDEX);
|
|
215
|
-
let inAssistant = false;
|
|
216
|
-
for (let i = 0; i < inputIds.length; i++) {
|
|
217
|
-
// Detect assistant region: <|im_start|> followed by "assistant" token
|
|
218
|
-
if (inputIds[i] === imStart && i + 1 < inputIds.length && inputIds[i + 1] === assistantTokenId) {
|
|
219
|
-
// Skip the <|im_start|>assistant\n header, start training from content
|
|
220
|
-
// Find the newline after "assistant"
|
|
221
|
-
let j = i + 2;
|
|
222
|
-
while (j < inputIds.length && inputIds[j] !== imEnd) {
|
|
223
|
-
// Look for newline token (dynamically derived from tokenizer)
|
|
224
|
-
if (this.isNewlineToken(inputIds[j])) {
|
|
225
|
-
inAssistant = true;
|
|
226
|
-
i = j; // Skip to after header
|
|
227
|
-
break;
|
|
228
|
-
}
|
|
229
|
-
j++;
|
|
230
|
-
}
|
|
231
|
-
if (!inAssistant) {
|
|
232
|
-
// Fallback: just start after assistant token
|
|
233
|
-
inAssistant = true;
|
|
234
|
-
i = i + 1;
|
|
235
|
-
}
|
|
236
|
-
continue;
|
|
237
|
-
}
|
|
238
|
-
if (inAssistant && inputIds[i] !== imEnd) {
|
|
239
|
-
labels[i] = inputIds[i];
|
|
240
|
-
}
|
|
241
|
-
if (inputIds[i] === imEnd) {
|
|
242
|
-
inAssistant = false;
|
|
243
|
-
}
|
|
244
|
-
}
|
|
245
|
-
return { inputIds, labels };
|
|
246
|
-
}
|
|
247
|
-
/**
|
|
248
|
-
* Tokenize a single example based on its format
|
|
249
|
-
*/
|
|
250
|
-
async tokenizeExample(example) {
|
|
251
|
-
if (this.format === 'prompt-completion') {
|
|
252
|
-
return this.tokenizePromptCompletion(example);
|
|
253
|
-
}
|
|
254
|
-
else {
|
|
255
|
-
return this.tokenizeConversation(example);
|
|
256
|
-
}
|
|
257
|
-
}
|
|
258
|
-
/**
|
|
259
|
-
* Collate multiple examples into a padded batch
|
|
260
|
-
*/
|
|
261
|
-
async collateBatch(indices) {
|
|
262
|
-
const examples = indices.map((i) => this.examples[this.shuffledIndices[i]]);
|
|
263
|
-
// Tokenize all examples
|
|
264
|
-
const tokenized = [];
|
|
265
|
-
for (const example of examples) {
|
|
266
|
-
tokenized.push(await this.tokenizeExample(example));
|
|
267
|
-
}
|
|
268
|
-
// Find max length (capped at maxSeqLength)
|
|
269
|
-
const maxLen = Math.min(this.config.maxSeqLength, Math.max(...tokenized.map((t) => t.inputIds.length)));
|
|
270
|
-
// Pad and truncate
|
|
271
|
-
const batchSize = examples.length;
|
|
272
|
-
const paddedInputIds = new Int32Array(batchSize * maxLen);
|
|
273
|
-
const paddedLabels = new Int32Array(batchSize * maxLen);
|
|
274
|
-
const padTokenId = this.tokenizer.getPadTokenId();
|
|
275
|
-
for (let b = 0; b < batchSize; b++) {
|
|
276
|
-
const { inputIds, labels } = tokenized[b];
|
|
277
|
-
const seqLen = Math.min(inputIds.length, maxLen);
|
|
278
|
-
// Truncate from the left if necessary (keep the end of the sequence)
|
|
279
|
-
const startIdx = Math.max(0, inputIds.length - maxLen);
|
|
280
|
-
for (let s = 0; s < maxLen; s++) {
|
|
281
|
-
const offset = b * maxLen + s;
|
|
282
|
-
if (s < seqLen) {
|
|
283
|
-
paddedInputIds[offset] = inputIds[startIdx + s];
|
|
284
|
-
paddedLabels[offset] = labels[startIdx + s];
|
|
285
|
-
}
|
|
286
|
-
else {
|
|
287
|
-
// Pad
|
|
288
|
-
paddedInputIds[offset] = padTokenId;
|
|
289
|
-
paddedLabels[offset] = IGNORE_INDEX;
|
|
290
|
-
}
|
|
291
|
-
}
|
|
292
|
-
}
|
|
293
|
-
return {
|
|
294
|
-
inputIds: paddedInputIds,
|
|
295
|
-
labels: paddedLabels,
|
|
296
|
-
shape: [batchSize, maxLen],
|
|
297
|
-
};
|
|
298
|
-
}
|
|
299
|
-
/**
|
|
300
|
-
* Generate batches for training
|
|
301
|
-
*/
|
|
302
|
-
async *batches(batchSize) {
|
|
303
|
-
for (let i = 0; i < this.examples.length; i += batchSize) {
|
|
304
|
-
const end = Math.min(i + batchSize, this.examples.length);
|
|
305
|
-
const indices = Array.from({ length: end - i }, (_, j) => i + j);
|
|
306
|
-
yield await this.collateBatch(indices);
|
|
307
|
-
}
|
|
308
|
-
}
|
|
309
|
-
/**
|
|
310
|
-
* Get total number of batches for a given batch size
|
|
311
|
-
*/
|
|
312
|
-
numBatches(batchSize) {
|
|
313
|
-
return Math.ceil(this.examples.length / batchSize);
|
|
314
|
-
}
|
|
315
|
-
}
|
|
316
|
-
/**
|
|
317
|
-
* Read JSONL file and parse into records
|
|
318
|
-
*/
|
|
319
|
-
function readJsonl(path, limit) {
|
|
320
|
-
let fileContents;
|
|
321
|
-
try {
|
|
322
|
-
fileContents = readFileSync(path, 'utf8');
|
|
323
|
-
}
|
|
324
|
-
catch (error) {
|
|
325
|
-
const message = error instanceof Error ? error.message : String(error);
|
|
326
|
-
throw new Error(`Failed to read SFT dataset at ${path}: ${message}`);
|
|
327
|
-
}
|
|
328
|
-
const lines = fileContents.split(/\r?\n/).filter((line) => line.trim().length > 0);
|
|
329
|
-
const records = [];
|
|
330
|
-
const max = typeof limit === 'number' && limit > 0 ? limit : Number.POSITIVE_INFINITY;
|
|
331
|
-
for (let i = 0; i < lines.length && records.length < max; i++) {
|
|
332
|
-
const line = lines[i];
|
|
333
|
-
try {
|
|
334
|
-
const parsed = JSON.parse(line);
|
|
335
|
-
records.push(parsed);
|
|
336
|
-
}
|
|
337
|
-
catch (error) {
|
|
338
|
-
const message = error instanceof Error ? error.message : String(error);
|
|
339
|
-
throw new Error(`Failed to parse JSONL at ${path}:${i + 1} - ${message}`);
|
|
340
|
-
}
|
|
341
|
-
}
|
|
342
|
-
return records;
|
|
343
|
-
}
|
|
344
|
-
/**
 * Check that every entry of a message list is an object with a string `role`.
 *
 * @throws Error naming the example index, field and offending position
 */
function validateMessageEntries(messages, index, field) {
    for (let position = 0; position < messages.length; position++) {
        const message = messages[position];
        if (typeof message !== 'object' || message === null || typeof message.role !== 'string') {
            throw new Error(`SFT example ${index}: ${field}[${position}] must be a message object with a string role`);
        }
    }
}
/**
 * Validate an SFT example.
 *
 * Catches malformed records at load time with a precise message, instead of
 * letting them fail later inside the tokenizer.
 *
 * @param example - Parsed JSONL record
 * @param index - Record position, used in error messages
 * @returns `{ prompt, completion }` or `{ messages }` (other keys are dropped)
 * @throws Error describing the first problem found
 */
function validateSFTExample(example, index) {
    if (typeof example !== 'object' || example === null) {
        throw new Error(`SFT example ${index} must be an object`);
    }
    const obj = example;
    // Check for prompt-completion format
    if ('prompt' in obj && 'completion' in obj) {
        if (!Array.isArray(obj.prompt)) {
            throw new Error(`SFT example ${index}: prompt must be an array of messages`);
        }
        validateMessageEntries(obj.prompt, index, 'prompt');
        if (typeof obj.completion !== 'object' || obj.completion === null) {
            throw new Error(`SFT example ${index}: completion must be a message object`);
        }
        const completion = obj.completion;
        if (completion.role !== 'assistant') {
            throw new Error(`SFT example ${index}: completion.role must be 'assistant'`);
        }
        if (typeof completion.content !== 'string') {
            throw new Error(`SFT example ${index}: completion.content must be a string`);
        }
        return {
            prompt: obj.prompt,
            completion: obj.completion,
        };
    }
    // Check for conversation format
    if ('messages' in obj) {
        if (!Array.isArray(obj.messages)) {
            throw new Error(`SFT example ${index}: messages must be an array`);
        }
        if (obj.messages.length === 0) {
            throw new Error(`SFT example ${index}: messages cannot be empty`);
        }
        validateMessageEntries(obj.messages, index, 'messages');
        // Entries are known to be objects here; at least one must be from the assistant
        const hasAssistant = obj.messages.some((m) => m.role === 'assistant');
        if (!hasAssistant) {
            throw new Error(`SFT example ${index}: messages must contain at least one assistant message`);
        }
        return { messages: obj.messages };
    }
    throw new Error(`SFT example ${index}: must have either {prompt, completion} or {messages}`);
}
|
|
389
|
-
/**
|
|
390
|
-
* Load SFT dataset from a JSONL file
|
|
391
|
-
*
|
|
392
|
-
* Supports two formats:
|
|
393
|
-
* 1. Prompt-Completion: {"prompt": [...], "completion": {...}}
|
|
394
|
-
* 2. Conversation: {"messages": [...]}
|
|
395
|
-
*
|
|
396
|
-
* @param path - Path to the JSONL file (relative to cwd or allowedRoot)
|
|
397
|
-
* @param tokenizer - Qwen3 tokenizer instance
|
|
398
|
-
* @param config - Optional configuration including path validation options
|
|
399
|
-
*/
|
|
400
|
-
export async function loadSFTDataset(path, tokenizer, config) {
|
|
401
|
-
const allowedRoot = getAllowedRoot(config);
|
|
402
|
-
const absolutePath = resolvePath(allowedRoot, path);
|
|
403
|
-
// Validate the path stays within allowed root to prevent directory traversal
|
|
404
|
-
validatePathContainment(absolutePath, allowedRoot);
|
|
405
|
-
const rawRecords = readJsonl(absolutePath, config?.limit);
|
|
406
|
-
// Validate and convert
|
|
407
|
-
const examples = rawRecords.map((record, i) => validateSFTExample(record, i));
|
|
408
|
-
return new SFTDataset(examples, tokenizer, config);
|
|
409
|
-
}
|
|
410
|
-
/**
|
|
411
|
-
* Create SFT dataset from examples directly
|
|
412
|
-
*/
|
|
413
|
-
export function createSFTDataset(examples, tokenizer, config) {
|
|
414
|
-
return new SFTDataset(examples, tokenizer, config);
|
|
415
|
-
}
|
package/dist/index.d.ts
DELETED
|
@@ -1,33 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* @mlx-node/trl - Training utilities for MLX models
|
|
3
|
-
*
|
|
4
|
-
* This package provides everything needed for training ML models,
|
|
5
|
-
* aligned with Python's TRL (Transformer Reinforcement Learning) library.
|
|
6
|
-
*
|
|
7
|
-
* For model loading and inference, import from @mlx-node/lm.
|
|
8
|
-
*
|
|
9
|
-
* @example
|
|
10
|
-
* ```typescript
|
|
11
|
-
* import { GRPOTrainer, loadLocalGsm8kDataset } from '@mlx-node/trl';
|
|
12
|
-
* import { ModelLoader } from '@mlx-node/lm';
|
|
13
|
-
*
|
|
14
|
-
* const model = await ModelLoader.loadPretrained('./models/qwen3-0.6b');
|
|
15
|
-
* const trainer = await GRPOTrainer.create({ modelPath: './models/qwen3-0.6b' });
|
|
16
|
-
* ```
|
|
17
|
-
*/
|
|
18
|
-
export type { ToolDefinition, FunctionDefinition, FunctionParameters } from '@mlx-node/core';
|
|
19
|
-
export { MxArray } from '@mlx-node/core';
|
|
20
|
-
export { convertModel, convertParquetToJsonl } from '@mlx-node/core';
|
|
21
|
-
export type { ConversionOptions, ConversionResult } from '@mlx-node/core';
|
|
22
|
-
export { type MLXGRPOConfig, ConfigError, getDefaultConfig, mergeConfig, loadTomlConfig, applyOverrides, } from './trainers/grpo-config';
|
|
23
|
-
export { GRPOTrainer, type GRPOTrainerConfig, DEFAULT_GRPO_CONFIG, createRewardRegistry, computeDatasetHash, RewardTimeoutError, type GenerateBatchResult, type TrainStepMetrics, type TrainingMetrics, type TrainingState, type DatasetMetadata, GrpoTrainingEngine, NativeRewardRegistry, type GrpoEngineConfig, type EngineStepMetrics, type EngineEpochMetrics, type BuiltinRewardConfig, } from './trainers/grpo-trainer';
|
|
24
|
-
export { TrainingLogger, createTrainingLogger, type TrainingLoggerConfig, type TrainingMetrics as TrainingLoggerMetrics, type GenerationSample, type TrainingConfigFields, type TuiMessage, type LogEvent, type PromptChoice, type PromptOptions, } from './trainers/training-logger';
|
|
25
|
-
export { type EntropyFilteringConfig, DEFAULT_ENTROPY_CONFIG } from './trainers/grpo-entropy';
|
|
26
|
-
export { SFTTrainer, SftTrainingEngine, type SFTTrainStepResult, type SFTTrainingState, type SftEngineConfig, type SftStepMetrics, type SftEpochMetrics, } from './trainers/sft-trainer';
|
|
27
|
-
export { type SFTTrainerConfig, SFTConfigError, getDefaultSFTConfig, mergeSFTConfig, loadSFTTomlConfig, applySFTOverrides, DEFAULT_SFT_CONFIG, } from './trainers/sft-config';
|
|
28
|
-
export * from './data/dataset';
|
|
29
|
-
export { SFTDataset, loadSFTDataset, createSFTDataset, type SFTExample, type SFTPromptCompletionExample, type SFTConversationExample, type SFTBatch, type SFTDatasetConfig, type SpecialTokenIds, } from './data/sft-dataset';
|
|
30
|
-
export * from './utils/xml-parser';
|
|
31
|
-
export { validatePathContainment, resolveAndValidatePath, getAllowedRoot, PathTraversalError, type PathValidationOptions, } from './utils/path-security';
|
|
32
|
-
export type { ChatRole, ChatMessage, CompletionMessage, Completion, DatasetSplit, DatasetExample, XmlParseResult, RewardComputationInput, PromptFormatterOptions, PromptTemplate, DatasetLoader, RewardFunction, PromptFormatter, CompletionInfo, RewardOutput, } from './types';
|
|
33
|
-
//# sourceMappingURL=index.d.ts.map
|
package/dist/index.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;GAgBG;AAOH,YAAY,EAAE,cAAc,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,MAAM,gBAAgB,CAAC;AAG7F,OAAO,EAAE,OAAO,EAAE,MAAM,gBAAgB,CAAC;AAWzC,OAAO,EAAE,YAAY,EAAE,qBAAqB,EAAE,MAAM,gBAAgB,CAAC;AACrE,YAAY,EAAE,iBAAiB,EAAE,gBAAgB,EAAE,MAAM,gBAAgB,CAAC;AAO1E,OAAO,EACL,KAAK,aAAa,EAClB,WAAW,EACX,gBAAgB,EAChB,WAAW,EACX,cAAc,EACd,cAAc,GACf,MAAM,wBAAwB,CAAC;AAEhC,OAAO,EACL,WAAW,EACX,KAAK,iBAAiB,EACtB,mBAAmB,EACnB,oBAAoB,EACpB,kBAAkB,EAClB,kBAAkB,EAClB,KAAK,mBAAmB,EACxB,KAAK,gBAAgB,EACrB,KAAK,eAAe,EACpB,KAAK,aAAa,EAClB,KAAK,eAAe,EAEpB,kBAAkB,EAClB,oBAAoB,EACpB,KAAK,gBAAgB,EACrB,KAAK,iBAAiB,EACtB,KAAK,kBAAkB,EACvB,KAAK,mBAAmB,GACzB,MAAM,yBAAyB,CAAC;AAGjC,OAAO,EACL,cAAc,EACd,oBAAoB,EACpB,KAAK,oBAAoB,EACzB,KAAK,eAAe,IAAI,qBAAqB,EAC7C,KAAK,gBAAgB,EACrB,KAAK,oBAAoB,EACzB,KAAK,UAAU,EACf,KAAK,QAAQ,EACb,KAAK,YAAY,EACjB,KAAK,aAAa,GACnB,MAAM,4BAA4B,CAAC;AAGpC,OAAO,EAAE,KAAK,sBAAsB,EAAE,sBAAsB,EAAE,MAAM,yBAAyB,CAAC;AAG9F,OAAO,EACL,UAAU,EACV,iBAAiB,EACjB,KAAK,kBAAkB,EACvB,KAAK,gBAAgB,EACrB,KAAK,eAAe,EACpB,KAAK,cAAc,EACnB,KAAK,eAAe,GACrB,MAAM,wBAAwB,CAAC;AAEhC,OAAO,EACL,KAAK,gBAAgB,EACrB,cAAc,EACd,mBAAmB,EACnB,cAAc,EACd,iBAAiB,EACjB,iBAAiB,EACjB,kBAAkB,GACnB,MAAM,uBAAuB,CAAC;AAG/B,cAAc,gBAAgB,CAAC;AAC/B,OAAO,EACL,UAAU,EACV,cAAc,EACd,gBAAgB,EAChB,KAAK,UAAU,EACf,KAAK,0BAA0B,EAC/B,KAAK,sBAAsB,EAC3B,KAAK,QAAQ,EACb,KAAK,gBAAgB,EACrB,KAAK,eAAe,GACrB,MAAM,oBAAoB,CAAC;AAG5B,cAAc,oBAAoB,CAAC;AACnC,OAAO,EACL,uBAAuB,EACvB,sBAAsB,EACtB,cAAc,EACd,kBAAkB,EAClB,KAAK,qBAAqB,GAC3B,MAAM,uBAAuB,CAAC;AAG/B,YAAY,EACV,QAAQ,EACR,WAAW,EACX,iBAAiB,EACjB,UAAU,EACV,YAAY,EACZ,cAAc,EACd,cAAc,EACd,sBAAsB,EACtB,sBAAsB,EACtB,cAAc,EACd,aAAa,EACb,cAAc,EACd,eAAe,EAEf,cAAc,EACd,YAAY,GACb,MAAM,SAAS,CAAC"}
|
package/dist/index.js
DELETED
|
@@ -1,47 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* @mlx-node/trl - Training utilities for MLX models
|
|
3
|
-
*
|
|
4
|
-
* This package provides everything needed for training ML models,
|
|
5
|
-
* aligned with Python's TRL (Transformer Reinforcement Learning) library.
|
|
6
|
-
*
|
|
7
|
-
* For model loading and inference, import from @mlx-node/lm.
|
|
8
|
-
*
|
|
9
|
-
* @example
|
|
10
|
-
* ```typescript
|
|
11
|
-
* import { GRPOTrainer, GRPOConfig, loadLocalGsm8kDataset } from '@mlx-node/trl';
|
|
12
|
-
* import { ModelLoader } from '@mlx-node/lm';
|
|
13
|
-
*
|
|
14
|
-
* const model = await ModelLoader.loadPretrained('./models/qwen3-0.6b');
|
|
15
|
-
* const trainer = await GRPOTrainer.create({ modelPath: './models/qwen3-0.6b' });
|
|
16
|
-
* ```
|
|
17
|
-
*/
|
|
18
|
-
// Core tensor (for custom rewards/models)
|
|
19
|
-
export { MxArray } from '@mlx-node/core';
|
|
20
|
-
// Activations are internal-only (Rust) - used by transformers, sampling, GRPO
|
|
21
|
-
// Transformer components are now internal-only (Rust)
|
|
22
|
-
// Use model.chat() or model.generate() instead
|
|
23
|
-
// GRPO utilities (computeAdvantages, computeEntropy, getHighEntropyMask) are internal-only
|
|
24
|
-
// They are used by GRPOTrainingEngine in Rust
|
|
25
|
-
// Model conversion
|
|
26
|
-
export { convertModel, convertParquetToJsonl } from '@mlx-node/core';
|
|
27
|
-
// =============================================================================
|
|
28
|
-
// TRL-specific exports
|
|
29
|
-
// =============================================================================
|
|
30
|
-
// Trainers
|
|
31
|
-
export { ConfigError, getDefaultConfig, mergeConfig, loadTomlConfig, applyOverrides, } from './trainers/grpo-config';
|
|
32
|
-
export { GRPOTrainer, DEFAULT_GRPO_CONFIG, createRewardRegistry, computeDatasetHash, RewardTimeoutError,
|
|
33
|
-
// Re-export native types from trainer
|
|
34
|
-
GrpoTrainingEngine, NativeRewardRegistry, } from './trainers/grpo-trainer';
|
|
35
|
-
// Unified Training Logger (recommended)
|
|
36
|
-
export { TrainingLogger, createTrainingLogger, } from './trainers/training-logger';
|
|
37
|
-
// Entropy configuration
|
|
38
|
-
export { DEFAULT_ENTROPY_CONFIG } from './trainers/grpo-entropy';
|
|
39
|
-
// SFT Trainer
|
|
40
|
-
export { SFTTrainer, SftTrainingEngine, } from './trainers/sft-trainer';
|
|
41
|
-
export { SFTConfigError, getDefaultSFTConfig, mergeSFTConfig, loadSFTTomlConfig, applySFTOverrides, DEFAULT_SFT_CONFIG, } from './trainers/sft-config';
|
|
42
|
-
// Data
|
|
43
|
-
export * from './data/dataset';
|
|
44
|
-
export { SFTDataset, loadSFTDataset, createSFTDataset, } from './data/sft-dataset';
|
|
45
|
-
// Utils
|
|
46
|
-
export * from './utils/xml-parser';
|
|
47
|
-
export { validatePathContainment, resolveAndValidatePath, getAllowedRoot, PathTraversalError, } from './utils/path-security';
|