@mlx-node/trl 0.0.0 → 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +389 -0
- package/package.json +16 -5
- package/dist/data/dataset.d.ts +0 -22
- package/dist/data/dataset.d.ts.map +0 -1
- package/dist/data/dataset.js +0 -142
- package/dist/data/sft-dataset.d.ts +0 -156
- package/dist/data/sft-dataset.d.ts.map +0 -1
- package/dist/data/sft-dataset.js +0 -415
- package/dist/index.d.ts +0 -33
- package/dist/index.d.ts.map +0 -1
- package/dist/index.js +0 -47
- package/dist/trainers/grpo-config.d.ts +0 -42
- package/dist/trainers/grpo-config.d.ts.map +0 -1
- package/dist/trainers/grpo-config.js +0 -220
- package/dist/trainers/grpo-entropy.d.ts +0 -33
- package/dist/trainers/grpo-entropy.d.ts.map +0 -1
- package/dist/trainers/grpo-entropy.js +0 -18
- package/dist/trainers/grpo-trainer.d.ts +0 -602
- package/dist/trainers/grpo-trainer.d.ts.map +0 -1
- package/dist/trainers/grpo-trainer.js +0 -1439
- package/dist/trainers/sft-config.d.ts +0 -32
- package/dist/trainers/sft-config.d.ts.map +0 -1
- package/dist/trainers/sft-config.js +0 -186
- package/dist/trainers/sft-trainer.d.ts +0 -141
- package/dist/trainers/sft-trainer.d.ts.map +0 -1
- package/dist/trainers/sft-trainer.js +0 -502
- package/dist/trainers/training-logger.d.ts +0 -375
- package/dist/trainers/training-logger.d.ts.map +0 -1
- package/dist/trainers/training-logger.js +0 -542
- package/dist/types.d.ts +0 -54
- package/dist/types.d.ts.map +0 -1
- package/dist/types.js +0 -1
- package/dist/utils/path-security.d.ts +0 -51
- package/dist/utils/path-security.d.ts.map +0 -1
- package/dist/utils/path-security.js +0 -69
- package/dist/utils/xml-parser.d.ts +0 -6
- package/dist/utils/xml-parser.d.ts.map +0 -1
- package/dist/utils/xml-parser.js +0 -184
package/README.md
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
1
|
+
# @mlx-node/trl
|
|
2
|
+
|
|
3
|
+
Training library for language models on Apple Silicon. Supports GRPO (Group Relative Policy Optimization) and SFT (Supervised Fine-Tuning) with Metal GPU acceleration, built-in reward functions, dataset handling, and checkpoint management.
|
|
4
|
+
|
|
5
|
+
## Requirements
|
|
6
|
+
|
|
7
|
+
- macOS with Apple Silicon (M1 or later)
|
|
8
|
+
- Node.js 18+
|
|
9
|
+
|
|
10
|
+
## Installation
|
|
11
|
+
|
|
12
|
+
```bash
|
|
13
|
+
npm install @mlx-node/trl
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
## Quick Start
|
|
17
|
+
|
|
18
|
+
### GRPO Training
|
|
19
|
+
|
|
20
|
+
```typescript
|
|
21
|
+
import { GRPOTrainer } from '@mlx-node/trl';
|
|
22
|
+
|
|
23
|
+
const trainer = await GRPOTrainer.create({
|
|
24
|
+
modelPath: './models/Qwen3-0.6B',
|
|
25
|
+
outputDir: './output/grpo-run',
|
|
26
|
+
learningRate: 1e-6,
|
|
27
|
+
groupSize: 4,
|
|
28
|
+
maxCompletionLength: 256,
|
|
29
|
+
temperature: 0.8,
|
|
30
|
+
rewardFunction: async (outputs) => {
|
|
31
|
+
return outputs.map((o) => (o.text.includes('correct') ? 1.0 : 0.0));
|
|
32
|
+
},
|
|
33
|
+
});
|
|
34
|
+
|
|
35
|
+
const dataset = await loadDataset('train');
|
|
36
|
+
await trainer.train(dataset);
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
### SFT Training
|
|
40
|
+
|
|
41
|
+
```typescript
|
|
42
|
+
import { SFTTrainer } from '@mlx-node/trl';
|
|
43
|
+
|
|
44
|
+
const trainer = await SFTTrainer.create({
|
|
45
|
+
modelName: './models/Qwen3-0.6B',
|
|
46
|
+
outputDir: './output/sft-run',
|
|
47
|
+
learningRate: 2e-5,
|
|
48
|
+
batchSize: 4,
|
|
49
|
+
numEpochs: 3,
|
|
50
|
+
completionOnly: true,
|
|
51
|
+
});
|
|
52
|
+
|
|
53
|
+
await trainer.train('./data/training.jsonl');
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
## GRPO Training
|
|
57
|
+
|
|
58
|
+
GRPO generates multiple completions per prompt, scores them with reward functions, and trains the model to prefer higher-reward outputs.
|
|
59
|
+
|
|
60
|
+
### Loss Variants
|
|
61
|
+
|
|
62
|
+
| Loss Type | Description |
|
|
63
|
+
| --------- | ------------------------------------------- |
|
|
64
|
+
| `grpo` | Standard Group Relative Policy Optimization |
|
|
65
|
+
| `dapo` | Dynamic sampling with adaptive clipping |
|
|
66
|
+
| `dr_grpo` | Dr.GRPO with improved gradient estimation |
|
|
67
|
+
| `bnpo` | Batch-normalized policy optimization |
|
|
68
|
+
|
|
69
|
+
### Configuration
|
|
70
|
+
|
|
71
|
+
```typescript
|
|
72
|
+
import { GRPOTrainer, GRPOTrainerConfig } from '@mlx-node/trl';
|
|
73
|
+
|
|
74
|
+
const config: GRPOTrainerConfig = {
|
|
75
|
+
// Model
|
|
76
|
+
modelPath: './models/Qwen3-0.6B',
|
|
77
|
+
outputDir: './output',
|
|
78
|
+
|
|
79
|
+
// Training
|
|
80
|
+
learningRate: 1e-6,
|
|
81
|
+
batchSize: 1,
|
|
82
|
+
numEpochs: 1,
|
|
83
|
+
gradientAccumulationSteps: 1,
|
|
84
|
+
gradientClipNorm: 1.0,
|
|
85
|
+
weightDecay: 0.01,
|
|
86
|
+
|
|
87
|
+
// GRPO
|
|
88
|
+
groupSize: 4, // completions per prompt
|
|
89
|
+
clipEpsilon: 0.2, // PPO clipping
|
|
90
|
+
klCoef: 0.0, // KL divergence coefficient
|
|
91
|
+
lossType: 'grpo', // grpo | dapo | dr_grpo | bnpo
|
|
92
|
+
|
|
93
|
+
// Generation
|
|
94
|
+
maxCompletionLength: 256,
|
|
95
|
+
temperature: 0.8,
|
|
96
|
+
topP: 0.95,
|
|
97
|
+
repetitionPenalty: 1.1,
|
|
98
|
+
|
|
99
|
+
// Tool calling
|
|
100
|
+
tools: [toolDef],
|
|
101
|
+
enableThinking: true,
|
|
102
|
+
|
|
103
|
+
// Rewards
|
|
104
|
+
rewardFunction: myRewardFn,
|
|
105
|
+
|
|
106
|
+
// Memory optimization
|
|
107
|
+
gradientCheckpointing: true,
|
|
108
|
+
lmHeadChunkSize: 2,
|
|
109
|
+
vocabChunkSize: 65536,
|
|
110
|
+
|
|
111
|
+
// Checkpointing
|
|
112
|
+
saveInterval: 100,
|
|
113
|
+
maxCheckpoints: 3,
|
|
114
|
+
resumeFromCheckpoint: './output/checkpoint-500',
|
|
115
|
+
|
|
116
|
+
// Optimizer
|
|
117
|
+
optimizerType: 'adamw', // adamw | sgd
|
|
118
|
+
};
|
|
119
|
+
```
|
|
120
|
+
|
|
121
|
+
### TOML Configuration
|
|
122
|
+
|
|
123
|
+
Load training config from a TOML file:
|
|
124
|
+
|
|
125
|
+
```typescript
|
|
126
|
+
import { loadTomlConfig, applyOverrides } from '@mlx-node/trl';
|
|
127
|
+
|
|
128
|
+
const config = loadTomlConfig('./train.toml');
|
|
129
|
+
applyOverrides(config, ['learningRate=2e-6', 'batchSize=2']);
|
|
130
|
+
```
|
|
131
|
+
|
|
132
|
+
### Built-in Rewards
|
|
133
|
+
|
|
134
|
+
Register native Rust reward functions for high-performance scoring:
|
|
135
|
+
|
|
136
|
+
```typescript
|
|
137
|
+
trainer.registerBuiltinReward({
|
|
138
|
+
type: 'ToolUse',
|
|
139
|
+
weight: 1.0,
|
|
140
|
+
allowedTools: ['get_weather', 'search'],
|
|
141
|
+
});
|
|
142
|
+
|
|
143
|
+
trainer.registerBuiltinReward({
|
|
144
|
+
type: 'XmlFormat',
|
|
145
|
+
weight: 0.5,
|
|
146
|
+
requiredTags: ['reasoning', 'answer'],
|
|
147
|
+
});
|
|
148
|
+
|
|
149
|
+
trainer.registerBuiltinReward({
|
|
150
|
+
type: 'Length',
|
|
151
|
+
weight: 0.3,
|
|
152
|
+
min: 50,
|
|
153
|
+
max: 500,
|
|
154
|
+
});
|
|
155
|
+
|
|
156
|
+
trainer.registerBuiltinReward({
|
|
157
|
+
type: 'JsonSchema',
|
|
158
|
+
weight: 1.0,
|
|
159
|
+
});
|
|
160
|
+
```
|
|
161
|
+
|
|
162
|
+
### Custom Reward Functions
|
|
163
|
+
|
|
164
|
+
```typescript
|
|
165
|
+
import { RewardFunction, RewardOutput } from '@mlx-node/trl';
|
|
166
|
+
|
|
167
|
+
const reward: RewardFunction = async (outputs: RewardOutput[]) => {
|
|
168
|
+
return outputs.map((output) => {
|
|
169
|
+
let score = 0;
|
|
170
|
+
if (output.toolCalls?.length) score += 0.5;
|
|
171
|
+
if (output.text.length > 100) score += 0.3;
|
|
172
|
+
return score;
|
|
173
|
+
});
|
|
174
|
+
};
|
|
175
|
+
|
|
176
|
+
trainer.setRewardFunction(reward);
|
|
177
|
+
```
|
|
178
|
+
|
|
179
|
+
### Custom Training Loop
|
|
180
|
+
|
|
181
|
+
For advanced use cases, use the low-level API:
|
|
182
|
+
|
|
183
|
+
```typescript
|
|
184
|
+
const trainer = await GRPOTrainer.create(config);
|
|
185
|
+
|
|
186
|
+
for (const batch of dataset) {
|
|
187
|
+
const generations = await trainer.generateBatch(batch.prompts);
|
|
188
|
+
const rewards = await trainer.scoreGenerations(batch.prompts, generations.completions, context);
|
|
189
|
+
const metrics = trainer.trainStep(batch.prompts, context);
|
|
190
|
+
trainer.incrementStep();
|
|
191
|
+
|
|
192
|
+
if (metrics.step % 100 === 0) {
|
|
193
|
+
await trainer.saveCheckpoint();
|
|
194
|
+
}
|
|
195
|
+
}
|
|
196
|
+
```
|
|
197
|
+
|
|
198
|
+
### Output Store (SQLite)
|
|
199
|
+
|
|
200
|
+
Record all training generations and metrics to SQLite for analysis:
|
|
201
|
+
|
|
202
|
+
```typescript
|
|
203
|
+
const trainer = await GRPOTrainer.create({
|
|
204
|
+
...config,
|
|
205
|
+
outputStore: {
|
|
206
|
+
enabled: true,
|
|
207
|
+
database: './output/training.db',
|
|
208
|
+
},
|
|
209
|
+
});
|
|
210
|
+
```
|
|
211
|
+
|
|
212
|
+
## SFT Training
|
|
213
|
+
|
|
214
|
+
Supervised fine-tuning with autograd, gradient accumulation, and completion-only masking.
|
|
215
|
+
|
|
216
|
+
### Dataset Formats
|
|
217
|
+
|
|
218
|
+
Two formats are auto-detected from JSONL files:
|
|
219
|
+
|
|
220
|
+
**Prompt-Completion:**
|
|
221
|
+
|
|
222
|
+
```json
|
|
223
|
+
{ "prompt": [{ "role": "user", "content": "Hello" }], "completion": { "role": "assistant", "content": "Hi!" } }
|
|
224
|
+
```
|
|
225
|
+
|
|
226
|
+
**Conversation:**
|
|
227
|
+
|
|
228
|
+
```json
|
|
229
|
+
{
|
|
230
|
+
"messages": [
|
|
231
|
+
{ "role": "user", "content": "Hello" },
|
|
232
|
+
{ "role": "assistant", "content": "Hi!" }
|
|
233
|
+
]
|
|
234
|
+
}
|
|
235
|
+
```
|
|
236
|
+
|
|
237
|
+
### SFT Configuration
|
|
238
|
+
|
|
239
|
+
```typescript
|
|
240
|
+
import { SFTTrainer, SFTTrainerConfig } from '@mlx-node/trl';
|
|
241
|
+
|
|
242
|
+
const config: SFTTrainerConfig = {
|
|
243
|
+
modelName: './models/Qwen3-0.6B',
|
|
244
|
+
outputDir: './output/sft',
|
|
245
|
+
learningRate: 2e-5,
|
|
246
|
+
batchSize: 4,
|
|
247
|
+
gradientAccumulationSteps: 8,
|
|
248
|
+
numEpochs: 3,
|
|
249
|
+
maxSeqLength: 2048,
|
|
250
|
+
completionOnly: true, // only compute loss on assistant tokens
|
|
251
|
+
labelSmoothing: 0.1,
|
|
252
|
+
maxGradNorm: 1.0,
|
|
253
|
+
weightDecay: 0.01,
|
|
254
|
+
loggingSteps: 10,
|
|
255
|
+
saveSteps: 100,
|
|
256
|
+
maxCheckpoints: 3,
|
|
257
|
+
gradientCheckpointing: true,
|
|
258
|
+
};
|
|
259
|
+
```
|
|
260
|
+
|
|
261
|
+
### Programmatic Dataset
|
|
262
|
+
|
|
263
|
+
```typescript
|
|
264
|
+
import { SFTDataset, createSFTDataset } from '@mlx-node/trl';
|
|
265
|
+
|
|
266
|
+
const dataset = createSFTDataset(examples, tokenizer, {
|
|
267
|
+
maxSeqLength: 2048,
|
|
268
|
+
completionOnly: true,
|
|
269
|
+
});
|
|
270
|
+
|
|
271
|
+
const trainer = await SFTTrainer.create(config);
|
|
272
|
+
await trainer.train(dataset);
|
|
273
|
+
```
|
|
274
|
+
|
|
275
|
+
## Datasets
|
|
276
|
+
|
|
277
|
+
### GSM8K Loader
|
|
278
|
+
|
|
279
|
+
Built-in loader for the GSM8K math dataset:
|
|
280
|
+
|
|
281
|
+
```typescript
|
|
282
|
+
import { loadLocalGsm8kDataset, LocalGsm8kDatasetLoader } from '@mlx-node/trl';
|
|
283
|
+
|
|
284
|
+
// Direct load
|
|
285
|
+
const examples = await loadLocalGsm8kDataset('train', { limit: 1000 });
|
|
286
|
+
|
|
287
|
+
// Via DatasetLoader interface
|
|
288
|
+
const loader = new LocalGsm8kDatasetLoader('./data/gsm8k');
|
|
289
|
+
const trainData = await loader.load('train');
|
|
290
|
+
```
|
|
291
|
+
|
|
292
|
+
### Custom Datasets
|
|
293
|
+
|
|
294
|
+
Implement the `DatasetLoader` interface:
|
|
295
|
+
|
|
296
|
+
```typescript
|
|
297
|
+
import { DatasetLoader, DatasetExample } from '@mlx-node/trl';
|
|
298
|
+
|
|
299
|
+
class MyDataset implements DatasetLoader {
|
|
300
|
+
async load(split: 'train' | 'test', limit?: number): Promise<DatasetExample[]> {
|
|
301
|
+
return examples.slice(0, limit).map((e) => ({
|
|
302
|
+
prompt: [
|
|
303
|
+
{ role: 'system', content: 'You are helpful.' },
|
|
304
|
+
{ role: 'user', content: e.question },
|
|
305
|
+
],
|
|
306
|
+
metadata: { answer: e.answer },
|
|
307
|
+
}));
|
|
308
|
+
}
|
|
309
|
+
}
|
|
310
|
+
```
|
|
311
|
+
|
|
312
|
+
## Utilities
|
|
313
|
+
|
|
314
|
+
### XML Chain-of-Thought Parser
|
|
315
|
+
|
|
316
|
+
Parse `<reasoning>...</reasoning><answer>...</answer>` format:
|
|
317
|
+
|
|
318
|
+
```typescript
|
|
319
|
+
import { parseXmlCot, extractXmlAnswer } from '@mlx-node/trl';
|
|
320
|
+
|
|
321
|
+
const result = parseXmlCot(modelOutput);
|
|
322
|
+
// { reasoning: "...", answer: "42", isStrictMatch: true, isSoftMatch: true, errors: [] }
|
|
323
|
+
|
|
324
|
+
const answer = extractXmlAnswer(modelOutput);
|
|
325
|
+
// "42"
|
|
326
|
+
```
|
|
327
|
+
|
|
328
|
+
### Model Conversion
|
|
329
|
+
|
|
330
|
+
Re-exported from `@mlx-node/core`:
|
|
331
|
+
|
|
332
|
+
```typescript
|
|
333
|
+
import { convertModel, convertParquetToJsonl } from '@mlx-node/trl';
|
|
334
|
+
```
|
|
335
|
+
|
|
336
|
+
## Features
|
|
337
|
+
|
|
338
|
+
- **Checkpoint resume** — automatic state restoration including optimizer, step count, and dataset position
|
|
339
|
+
- **Emergency save** — catches NaN gradients and SIGTERM/SIGINT for safe recovery
|
|
340
|
+
- **TUI mode** — interactive terminal UI with pause/resume/stop (via `mlx-tui` binary)
|
|
341
|
+
- **JSONL logging** — structured training logs for external monitoring
|
|
342
|
+
- **Multi-model** — supports Qwen3, Qwen3.5 Dense, and Qwen3.5 MoE architectures
|
|
343
|
+
- **Reward timeout** — configurable timeout for async reward functions (default 60s)
|
|
344
|
+
- **Path security** — traversal prevention for dataset file loading
|
|
345
|
+
|
|
346
|
+
## API Reference
|
|
347
|
+
|
|
348
|
+
### Trainers
|
|
349
|
+
|
|
350
|
+
| Class | Description |
|
|
351
|
+
| ------------- | --------------------------------------------------------------- |
|
|
352
|
+
| `GRPOTrainer` | GRPO training with generation, rewards, and policy optimization |
|
|
353
|
+
| `SFTTrainer` | Supervised fine-tuning with completion-only masking |
|
|
354
|
+
|
|
355
|
+
### Datasets
|
|
356
|
+
|
|
357
|
+
| Export | Description |
|
|
358
|
+
| ------------------------- | ------------------------------------------------ |
|
|
359
|
+
| `loadLocalGsm8kDataset()` | Load GSM8K JSONL dataset |
|
|
360
|
+
| `LocalGsm8kDatasetLoader` | `DatasetLoader` implementation for GSM8K |
|
|
361
|
+
| `SFTDataset` | Tokenized SFT dataset with padding and shuffling |
|
|
362
|
+
| `loadSFTDataset()` | Load SFT dataset from JSONL file |
|
|
363
|
+
| `createSFTDataset()` | Create SFT dataset from in-memory examples |
|
|
364
|
+
|
|
365
|
+
### Configuration
|
|
366
|
+
|
|
367
|
+
| Export | Description |
|
|
368
|
+
| ----------------------- | --------------------------------- |
|
|
369
|
+
| `GRPOTrainerConfig` | Full GRPO configuration interface |
|
|
370
|
+
| `SFTTrainerConfig` | Full SFT configuration interface |
|
|
371
|
+
| `loadTomlConfig()` | Load GRPO config from TOML file |
|
|
372
|
+
| `loadSFTTomlConfig()` | Load SFT config from TOML file |
|
|
373
|
+
| `getDefaultConfig()` | Default GRPO config |
|
|
374
|
+
| `getDefaultSFTConfig()` | Default SFT config |
|
|
375
|
+
|
|
376
|
+
### Types
|
|
377
|
+
|
|
378
|
+
| Type | Description |
|
|
379
|
+
| --------------------- | -------------------------------------------------- |
|
|
380
|
+
| `DatasetExample` | Training example with prompt messages and metadata |
|
|
381
|
+
| `RewardFunction<T>` | Custom reward function signature |
|
|
382
|
+
| `RewardOutput` | Structured completion data for reward scoring |
|
|
383
|
+
| `XmlParseResult` | Result of XML chain-of-thought parsing |
|
|
384
|
+
| `TrainStepMetrics` | Per-step training metrics |
|
|
385
|
+
| `BuiltinRewardConfig` | Configuration for native reward functions |
|
|
386
|
+
|
|
387
|
+
## License
|
|
388
|
+
|
|
389
|
+
[MIT](https://github.com/mlx-node/mlx-node/blob/main/LICENSE)
|
package/package.json
CHANGED
|
@@ -1,6 +1,16 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@mlx-node/trl",
|
|
3
|
-
"version": "0.0.
|
|
3
|
+
"version": "0.0.1",
|
|
4
|
+
"homepage": "https://github.com/mlx-node/mlx-node",
|
|
5
|
+
"bugs": {
|
|
6
|
+
"url": "https://github.com/mlx-node/mlx-node/issues"
|
|
7
|
+
},
|
|
8
|
+
"license": "MIT",
|
|
9
|
+
"repository": {
|
|
10
|
+
"type": "git",
|
|
11
|
+
"url": "https://github.com/mlx-node/mlx-node.git",
|
|
12
|
+
"directory": "packages/trl"
|
|
13
|
+
},
|
|
4
14
|
"files": [
|
|
5
15
|
"dist"
|
|
6
16
|
],
|
|
@@ -19,11 +29,12 @@
|
|
|
19
29
|
"test:trainer": "TEST_TRAINER=1 vite test run"
|
|
20
30
|
},
|
|
21
31
|
"dependencies": {
|
|
22
|
-
"@mlx-node/core": "0.0.
|
|
23
|
-
"@mlx-node/lm": "0.0.
|
|
24
|
-
"@std/toml": "npm:@jsr/std__toml@^1.0.11"
|
|
32
|
+
"@mlx-node/core": "0.0.1",
|
|
33
|
+
"@mlx-node/lm": "0.0.1",
|
|
34
|
+
"@std/toml": "npm:@jsr/std__toml@^1.0.11",
|
|
35
|
+
"change-case": "^5.4.4"
|
|
25
36
|
},
|
|
26
37
|
"devDependencies": {
|
|
27
|
-
"@huggingface/hub": "^2.7
|
|
38
|
+
"@huggingface/hub": "^2.10.7"
|
|
28
39
|
}
|
|
29
40
|
}
|
package/dist/data/dataset.d.ts
DELETED
|
@@ -1,22 +0,0 @@
|
|
|
1
|
-
import type { DatasetExample, ChatMessage, DatasetSplit, PromptFormatterOptions, PromptTemplate, DatasetLoader } from '../types';
import { type PathValidationOptions } from '../utils/path-security';
/**
 * Options accepted by the local GSM8K loader.
 *
 * The loader reads the prompt-formatting fields `includeOneShot` and
 * `oneShotExample` from this object. It also passes the whole object to
 * `getAllowedRoot` to determine the directory that `basePath` must stay within.
 */
export interface LocalDatasetOptions extends PromptFormatterOptions, PathValidationOptions {
    /**
     * Directory containing `train.jsonl` / `test.jsonl`.
     * Resolved against the allowed root; defaults to `<cwd>/data/gsm8k`.
     */
    basePath?: string;
    /** Builds the prompt messages for each question; defaults to `defaultPromptTemplate`. */
    promptTemplate?: PromptTemplate;
    /**
     * Extra metadata spread into every example's metadata after `split`,
     * `index` and `raw_answer`, so these entries win on key clashes.
     */
    metadata?: Record<string, unknown>;
}
/** System prompt asking the model to reply with `<reasoning>` and `<answer>` XML sections. */
export declare const SYSTEM_PROMPT: string;
/** Assistant-reply template; `{reasoning}` and `{answer}` are replaced when building one-shot examples. */
export declare const XML_COT_FORMAT = "<reasoning>\n{reasoning}\n</reasoning>\n<answer>\n{answer}\n</answer>";
/**
 * Default prompt builder: system message, optional one-shot user/assistant
 * pair, then the user question.
 */
export declare const defaultPromptTemplate: PromptTemplate;
/** Creates an example holding shallow copies of each prompt message and of `metadata`. */
export declare function createDatasetExample(prompt: ChatMessage[], metadata?: Record<string, unknown>): DatasetExample;
/** Extracts the final answer from a raw GSM8K solution string (delegates to `extractHashAnswer`). */
export declare function extractGsm8kAnswer(raw: string): string | null;
/**
 * Throws unless the example has at least one prompt message.
 * Every message must have non-empty string content and role
 * 'system' | 'user' | 'assistant'.
 */
export declare function validateDatasetExample(example: DatasetExample): void;
/**
 * Loads `<basePath>/<split>.jsonl` and converts each `{ question, answer }`
 * record into a validated chat example.
 *
 * @param options.limit - Maximum number of records to read; ignored unless it
 *   is a non-negative number.
 */
export declare function loadLocalGsm8kDataset(split: DatasetSplit, options?: LocalDatasetOptions & {
    limit?: number;
}): Promise<DatasetExample[]>;
/** `DatasetLoader` adapter that forwards stored options to `loadLocalGsm8kDataset`. */
export declare class LocalGsm8kDatasetLoader implements DatasetLoader {
    private readonly options;
    constructor(options?: LocalDatasetOptions);
    load(split: DatasetSplit, limit?: number): Promise<DatasetExample[]>;
}
//# sourceMappingURL=dataset.d.ts.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"dataset.d.ts","sourceRoot":"","sources":["../../src/data/dataset.ts"],"names":[],"mappings":"AAEA,OAAO,KAAK,EACV,cAAc,EACd,WAAW,EAEX,YAAY,EACZ,sBAAsB,EACtB,cAAc,EACd,aAAa,EACd,MAAM,UAAU,CAAC;AAElB,OAAO,EAA2C,KAAK,qBAAqB,EAAE,MAAM,wBAAwB,CAAC;AAE7G,MAAM,WAAW,mBAAoB,SAAQ,sBAAsB,EAAE,qBAAqB;IACxF,QAAQ,CAAC,EAAE,MAAM,CAAC;IAClB,cAAc,CAAC,EAAE,cAAc,CAAC;IAChC,QAAQ,CAAC,EAAE,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;CACpC;AAUD,eAAO,MAAM,aAAa,QASlB,CAAC;AAET,eAAO,MAAM,cAAc,0EAKjB,CAAC;AAWX,eAAO,MAAM,qBAAqB,EAAE,cAYnC,CAAC;AAEF,wBAAgB,oBAAoB,CAAC,MAAM,EAAE,WAAW,EAAE,EAAE,QAAQ,CAAC,EAAE,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,GAAG,cAAc,CAK9G;AAED,wBAAgB,kBAAkB,CAAC,GAAG,EAAE,MAAM,GAAG,MAAM,GAAG,IAAI,CAE7D;AAED,wBAAgB,sBAAsB,CAAC,OAAO,EAAE,cAAc,GAAG,IAAI,CAYpE;AAwDD,wBAAsB,qBAAqB,CACzC,KAAK,EAAE,YAAY,EACnB,OAAO,GAAE,mBAAmB,GAAG;IAAE,KAAK,CAAC,EAAE,MAAM,CAAA;CAAO,GACrD,OAAO,CAAC,cAAc,EAAE,CAAC,CA4B3B;AAED,qBAAa,uBAAwB,YAAW,aAAa;IAC3D,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAsB;gBAElC,OAAO,GAAE,mBAAwB;IAIvC,IAAI,CAAC,KAAK,EAAE,YAAY,EAAE,KAAK,CAAC,EAAE,MAAM,GAAG,OAAO,CAAC,cAAc,EAAE,CAAC;CAG3E"}
|
package/dist/data/dataset.js
DELETED
|
@@ -1,142 +0,0 @@
|
|
|
1
|
-
import { readFileSync } from 'node:fs';
|
|
2
|
-
import { resolve as resolvePath } from 'node:path';
|
|
3
|
-
import { extractHashAnswer } from '../utils/xml-parser';
|
|
4
|
-
import { validatePathContainment, getAllowedRoot } from '../utils/path-security';
|
|
5
|
-
// Default directory for GSM8K JSONL files: `<cwd>/data/gsm8k`.
// Computed once, when this module is first evaluated.
const DEFAULT_BASE_PATH = resolvePath(process.cwd(), 'data/gsm8k');
// Splits accepted by datasetFileForSplit; each maps to `<split>.jsonl`.
const VALID_SPLITS = new Set(['train', 'test']);
/**
 * System prompt asking the model to answer inside `<reasoning>` and
 * `<answer>` XML sections. Surrounding whitespace is trimmed.
 */
export const SYSTEM_PROMPT = `
Respond in the following format:

<reasoning>
...
</reasoning>
<answer>
...
</answer>
`.trim();
/**
 * Assistant-reply template used for one-shot examples. The `{reasoning}` and
 * `{answer}` placeholders are filled in by defaultPromptTemplate.
 */
export const XML_COT_FORMAT = `<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>`;
// System message built from SYSTEM_PROMPT. defaultPromptTemplate places this
// same object at the head of the prompts it builds.
const SYSTEM_MESSAGE = {
    role: 'system',
    content: SYSTEM_PROMPT,
};
/**
 * Builds one chat message.
 *
 * @param role - Chat role of the message.
 * @param content - Message text.
 * @returns A new `{ role, content }` object on every call.
 */
function createMessage(role, content) {
    const message = { role: role, content: content };
    return message;
}
export const defaultPromptTemplate = (question, options) => {
|
|
31
|
-
const messages = [SYSTEM_MESSAGE];
|
|
32
|
-
if (options?.includeOneShot && options.oneShotExample) {
|
|
33
|
-
const { question: exampleQuestion, reasoning, answer } = options.oneShotExample;
|
|
34
|
-
messages.push(createMessage('user', exampleQuestion), createMessage('assistant', XML_COT_FORMAT.replace('{reasoning}', reasoning).replace('{answer}', answer)));
|
|
35
|
-
}
|
|
36
|
-
messages.push(createMessage('user', question));
|
|
37
|
-
return messages;
|
|
38
|
-
};
|
|
39
|
-
/**
 * Wraps prompt messages and optional metadata into a dataset example.
 *
 * Each message and the metadata object are shallow-copied, so later edits to
 * the caller's objects do not leak into the example.
 *
 * @param prompt - Chat messages forming the prompt.
 * @param metadata - Optional extra data; `metadata` is `undefined` when falsy.
 * @returns `{ prompt, metadata }` built from the copies.
 */
export function createDatasetExample(prompt, metadata) {
    const copiedPrompt = prompt.map(function cloneMessage(message) {
        return { ...message };
    });
    let copiedMetadata;
    if (metadata) {
        copiedMetadata = { ...metadata };
    }
    return { prompt: copiedPrompt, metadata: copiedMetadata };
}
/**
 * Extracts the final answer from a raw GSM8K solution string.
 *
 * Thin alias for `extractHashAnswer` from ../utils/xml-parser. GSM8K solutions
 * presumably mark the final answer with `####` — see that helper for the
 * exact rule.
 *
 * @param raw - Raw answer text from the dataset.
 * @returns Whatever `extractHashAnswer` returns for `raw`.
 */
export function extractGsm8kAnswer(raw) {
    const answer = extractHashAnswer(raw);
    return answer;
}
/**
 * Throws when a dataset example is not a usable chat prompt.
 *
 * @param example - Example to check.
 * @throws Error if `prompt` is not a non-empty array, if any message is
 *   missing or has empty/non-string content, or if any role is not
 *   'system', 'user' or 'assistant'. Messages are checked in order, and the
 *   content check runs before the role check for each message.
 */
export function validateDatasetExample(example) {
    const messages = example.prompt;
    const hasMessages = Array.isArray(messages) && messages.length > 0;
    if (!hasMessages) {
        throw new Error('Dataset example must contain at least one prompt message.');
    }
    const allowedRoles = new Set(['system', 'user', 'assistant']);
    for (let position = 0; position < messages.length; position += 1) {
        const message = messages[position];
        const hasText = Boolean(message) && typeof message.content === 'string' && message.content.trim() !== '';
        if (!hasText) {
            throw new Error('Prompt messages must include non-empty textual content.');
        }
        if (!allowedRoles.has(message.role)) {
            throw new Error(`Unsupported chat role: ${String(message.role)}`);
        }
    }
}
/**
 * Resolves the directory to load GSM8K files from and checks that it stays
 * inside the allowed root (computed by `getAllowedRoot(options)`).
 *
 * @param optionPath - User-supplied directory. When falsy, DEFAULT_BASE_PATH
 *   is used; otherwise it is resolved relative to the allowed root.
 * @param options - Options forwarded to `getAllowedRoot`.
 * @returns The validated absolute directory path.
 */
function resolveBasePath(optionPath, options) {
    const allowedRoot = getAllowedRoot(options);
    const candidate = optionPath ? resolvePath(allowedRoot, optionPath) : DEFAULT_BASE_PATH;
    // Both the default and user-provided paths must stay inside the allowed root.
    validatePathContainment(candidate, allowedRoot);
    return candidate;
}
/**
 * Maps a split name to its JSONL file name.
 *
 * @param split - Split name; must be a member of VALID_SPLITS.
 * @returns `<split>.jsonl`.
 * @throws Error listing the accepted splits when `split` is unknown.
 */
function datasetFileForSplit(split) {
    if (VALID_SPLITS.has(split)) {
        return `${split}.jsonl`;
    }
    const expected = [...VALID_SPLITS].join(', ');
    throw new Error(`Unsupported GSM8K split "${split}". Expected one of: ${expected}`);
}
/**
 * Reads a dataset file as UTF-8 text.
 *
 * @param filePath - Path of the file to read.
 * @returns The file contents.
 * @throws Error whose message names `filePath` and includes the underlying
 *   read error's message.
 */
function readDatasetFile(filePath) {
    let contents;
    try {
        contents = readFileSync(filePath, 'utf8');
    }
    catch (err) {
        const reason = err instanceof Error ? err.message : String(err);
        throw new Error(`Failed to read dataset file at ${filePath}: ${reason}`);
    }
    return contents;
}
function readJsonl(path, limit) {
|
|
89
|
-
const fileContents = readDatasetFile(path);
|
|
90
|
-
const lines = fileContents.split(/\r?\n/).filter((line) => line.trim().length > 0);
|
|
91
|
-
const records = [];
|
|
92
|
-
const max = typeof limit === 'number' && limit >= 0 ? limit : Number.POSITIVE_INFINITY;
|
|
93
|
-
for (let i = 0; i < lines.length && records.length < max; i += 1) {
|
|
94
|
-
const line = lines[i];
|
|
95
|
-
try {
|
|
96
|
-
const parsed = JSON.parse(line);
|
|
97
|
-
if (typeof parsed.question !== 'string' || typeof parsed.answer !== 'string') {
|
|
98
|
-
throw new Error('Record must include string "question" and "answer" fields.');
|
|
99
|
-
}
|
|
100
|
-
records.push({ question: parsed.question, answer: parsed.answer });
|
|
101
|
-
}
|
|
102
|
-
catch (error) {
|
|
103
|
-
const message = error instanceof Error ? error.message : String(error);
|
|
104
|
-
throw new Error(`Failed to parse JSONL record at ${path}:${i + 1} - ${message}`);
|
|
105
|
-
}
|
|
106
|
-
}
|
|
107
|
-
return records;
|
|
108
|
-
}
|
|
109
|
-
/**
 * Loads one split of a local GSM8K dataset as chat-formatted examples.
 *
 * Reads `<basePath>/<split>.jsonl`. `basePath` defaults to `<cwd>/data/gsm8k`
 * and is checked against the allowed root; the final file path is checked
 * against `basePath`.
 *
 * @param split - 'train' or 'test'.
 * @param options - Base path, prompt template (defaults to
 *   defaultPromptTemplate), one-shot settings, extra metadata, and `limit`.
 * @returns Validated examples. Each example's metadata holds `split`, `index`
 *   and `raw_answer`, followed by `options.metadata` entries (which win on
 *   key clashes).
 * @throws Propagates errors for an unknown split, failed path validation, an
 *   unreadable or malformed file, or an invalid generated prompt.
 */
export async function loadLocalGsm8kDataset(split, options = {}) {
    const basePath = resolveBasePath(options.basePath, options);
    const filePath = resolvePath(basePath, datasetFileForSplit(split));
    // Defense in depth: the joined file path must not escape basePath.
    validatePathContainment(filePath, basePath);
    const buildPrompt = options.promptTemplate ?? defaultPromptTemplate;
    const records = readJsonl(filePath, options.limit);
    const examples = [];
    records.forEach((record, index) => {
        const example = createDatasetExample(buildPrompt(record.question, {
            includeOneShot: options.includeOneShot,
            oneShotExample: options.oneShotExample,
        }), {
            split,
            index,
            raw_answer: record.answer,
            ...options.metadata,
        });
        validateDatasetExample(example);
        examples.push(example);
    });
    return examples;
}
export class LocalGsm8kDatasetLoader {
|
|
135
|
-
options;
|
|
136
|
-
constructor(options = {}) {
|
|
137
|
-
this.options = { ...options };
|
|
138
|
-
}
|
|
139
|
-
async load(split, limit) {
|
|
140
|
-
return loadLocalGsm8kDataset(split, { ...this.options, limit });
|
|
141
|
-
}
|
|
142
|
-
}
|