@mlx-node/trl 0.0.0 → 0.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. package/README.md +389 -0
  2. package/package.json +16 -5
  3. package/dist/data/dataset.d.ts +0 -22
  4. package/dist/data/dataset.d.ts.map +0 -1
  5. package/dist/data/dataset.js +0 -142
  6. package/dist/data/sft-dataset.d.ts +0 -156
  7. package/dist/data/sft-dataset.d.ts.map +0 -1
  8. package/dist/data/sft-dataset.js +0 -415
  9. package/dist/index.d.ts +0 -33
  10. package/dist/index.d.ts.map +0 -1
  11. package/dist/index.js +0 -47
  12. package/dist/trainers/grpo-config.d.ts +0 -42
  13. package/dist/trainers/grpo-config.d.ts.map +0 -1
  14. package/dist/trainers/grpo-config.js +0 -220
  15. package/dist/trainers/grpo-entropy.d.ts +0 -33
  16. package/dist/trainers/grpo-entropy.d.ts.map +0 -1
  17. package/dist/trainers/grpo-entropy.js +0 -18
  18. package/dist/trainers/grpo-trainer.d.ts +0 -602
  19. package/dist/trainers/grpo-trainer.d.ts.map +0 -1
  20. package/dist/trainers/grpo-trainer.js +0 -1439
  21. package/dist/trainers/sft-config.d.ts +0 -32
  22. package/dist/trainers/sft-config.d.ts.map +0 -1
  23. package/dist/trainers/sft-config.js +0 -186
  24. package/dist/trainers/sft-trainer.d.ts +0 -141
  25. package/dist/trainers/sft-trainer.d.ts.map +0 -1
  26. package/dist/trainers/sft-trainer.js +0 -502
  27. package/dist/trainers/training-logger.d.ts +0 -375
  28. package/dist/trainers/training-logger.d.ts.map +0 -1
  29. package/dist/trainers/training-logger.js +0 -542
  30. package/dist/types.d.ts +0 -54
  31. package/dist/types.d.ts.map +0 -1
  32. package/dist/types.js +0 -1
  33. package/dist/utils/path-security.d.ts +0 -51
  34. package/dist/utils/path-security.d.ts.map +0 -1
  35. package/dist/utils/path-security.js +0 -69
  36. package/dist/utils/xml-parser.d.ts +0 -6
  37. package/dist/utils/xml-parser.d.ts.map +0 -1
  38. package/dist/utils/xml-parser.js +0 -184
package/README.md ADDED
@@ -0,0 +1,389 @@
1
+ # @mlx-node/trl
2
+
3
+ Training library for language models on Apple Silicon. Supports GRPO (Group Relative Policy Optimization) and SFT (Supervised Fine-Tuning) with Metal GPU acceleration, built-in reward functions, dataset handling, and checkpoint management.
4
+
5
+ ## Requirements
6
+
7
+ - macOS with Apple Silicon (M1 or later)
8
+ - Node.js 18+
9
+
10
+ ## Installation
11
+
12
+ ```bash
13
+ npm install @mlx-node/trl
14
+ ```
15
+
16
+ ## Quick Start
17
+
18
+ ### GRPO Training
19
+
20
+ ```typescript
21
+ import { GRPOTrainer } from '@mlx-node/trl';
22
+
23
+ const trainer = await GRPOTrainer.create({
24
+ modelPath: './models/Qwen3-0.6B',
25
+ outputDir: './output/grpo-run',
26
+ learningRate: 1e-6,
27
+ groupSize: 4,
28
+ maxCompletionLength: 256,
29
+ temperature: 0.8,
30
+ rewardFunction: async (outputs) => {
31
+ return outputs.map((o) => (o.text.includes('correct') ? 1.0 : 0.0));
32
+ },
33
+ });
34
+
35
+ const dataset = await loadDataset('train'); // placeholder: your own dataset loader
36
+ await trainer.train(dataset);
37
+ ```
38
+
39
+ ### SFT Training
40
+
41
+ ```typescript
42
+ import { SFTTrainer } from '@mlx-node/trl';
43
+
44
+ const trainer = await SFTTrainer.create({
45
+ modelName: './models/Qwen3-0.6B',
46
+ outputDir: './output/sft-run',
47
+ learningRate: 2e-5,
48
+ batchSize: 4,
49
+ numEpochs: 3,
50
+ completionOnly: true,
51
+ });
52
+
53
+ await trainer.train('./data/training.jsonl');
54
+ ```
55
+
56
+ ## GRPO Training
57
+
58
+ GRPO generates multiple completions per prompt, scores them with reward functions, and trains the model to prefer higher-reward outputs.
59
+
60
+ ### Loss Variants
61
+
62
+ | Loss Type | Description |
63
+ | --------- | ------------------------------------------- |
64
+ | `grpo` | Standard Group Relative Policy Optimization |
65
+ | `dapo` | Dynamic sampling with adaptive clipping |
66
+ | `dr_grpo` | Dr.GRPO with improved gradient estimation |
67
+ | `bnpo` | Batch-normalized policy optimization |
68
+
69
+ ### Configuration
70
+
71
+ ```typescript
72
+ import { GRPOTrainer, GRPOTrainerConfig } from '@mlx-node/trl';
73
+
74
+ const config: GRPOTrainerConfig = {
75
+ // Model
76
+ modelPath: './models/Qwen3-0.6B',
77
+ outputDir: './output',
78
+
79
+ // Training
80
+ learningRate: 1e-6,
81
+ batchSize: 1,
82
+ numEpochs: 1,
83
+ gradientAccumulationSteps: 1,
84
+ gradientClipNorm: 1.0,
85
+ weightDecay: 0.01,
86
+
87
+ // GRPO
88
+ groupSize: 4, // completions per prompt
89
+ clipEpsilon: 0.2, // PPO clipping
90
+ klCoef: 0.0, // KL divergence coefficient
91
+ lossType: 'grpo', // grpo | dapo | dr_grpo | bnpo
92
+
93
+ // Generation
94
+ maxCompletionLength: 256,
95
+ temperature: 0.8,
96
+ topP: 0.95,
97
+ repetitionPenalty: 1.1,
98
+
99
+ // Tool calling
100
+ tools: [toolDef],
101
+ enableThinking: true,
102
+
103
+ // Rewards
104
+ rewardFunction: myRewardFn,
105
+
106
+ // Memory optimization
107
+ gradientCheckpointing: true,
108
+ lmHeadChunkSize: 2,
109
+ vocabChunkSize: 65536,
110
+
111
+ // Checkpointing
112
+ saveInterval: 100,
113
+ maxCheckpoints: 3,
114
+ resumeFromCheckpoint: './output/checkpoint-500',
115
+
116
+ // Optimizer
117
+ optimizerType: 'adamw', // adamw | sgd
118
+ };
119
+ ```
120
+
121
+ ### TOML Configuration
122
+
123
+ Load training config from a TOML file:
124
+
125
+ ```typescript
126
+ import { loadTomlConfig, applyOverrides } from '@mlx-node/trl';
127
+
128
+ const config = loadTomlConfig('./train.toml');
129
+ applyOverrides(config, ['learningRate=2e-6', 'batchSize=2']);
130
+ ```
131
+
132
+ ### Built-in Rewards
133
+
134
+ Register native Rust reward functions for high-performance scoring:
135
+
136
+ ```typescript
137
+ trainer.registerBuiltinReward({
138
+ type: 'ToolUse',
139
+ weight: 1.0,
140
+ allowedTools: ['get_weather', 'search'],
141
+ });
142
+
143
+ trainer.registerBuiltinReward({
144
+ type: 'XmlFormat',
145
+ weight: 0.5,
146
+ requiredTags: ['reasoning', 'answer'],
147
+ });
148
+
149
+ trainer.registerBuiltinReward({
150
+ type: 'Length',
151
+ weight: 0.3,
152
+ min: 50,
153
+ max: 500,
154
+ });
155
+
156
+ trainer.registerBuiltinReward({
157
+ type: 'JsonSchema',
158
+ weight: 1.0,
159
+ });
160
+ ```
161
+
162
+ ### Custom Reward Functions
163
+
164
+ ```typescript
165
+ import { RewardFunction, RewardOutput } from '@mlx-node/trl';
166
+
167
+ const reward: RewardFunction = async (outputs: RewardOutput[]) => {
168
+ return outputs.map((output) => {
169
+ let score = 0;
170
+ if (output.toolCalls?.length) score += 0.5;
171
+ if (output.text.length > 100) score += 0.3;
172
+ return score;
173
+ });
174
+ };
175
+
176
+ trainer.setRewardFunction(reward);
177
+ ```
178
+
179
+ ### Custom Training Loop
180
+
181
+ For advanced use cases, use the low-level API:
182
+
183
+ ```typescript
184
+ const trainer = await GRPOTrainer.create(config);
185
+
186
+ for (const batch of dataset) {
187
+ const generations = await trainer.generateBatch(batch.prompts);
188
+ const rewards = await trainer.scoreGenerations(batch.prompts, generations.completions, context);
189
+ const metrics = trainer.trainStep(batch.prompts, context);
190
+ trainer.incrementStep();
191
+
192
+ if (metrics.step % 100 === 0) {
193
+ await trainer.saveCheckpoint();
194
+ }
195
+ }
196
+ ```
197
+
198
+ ### Output Store (SQLite)
199
+
200
+ Record all training generations and metrics to SQLite for analysis:
201
+
202
+ ```typescript
203
+ const trainer = await GRPOTrainer.create({
204
+ ...config,
205
+ outputStore: {
206
+ enabled: true,
207
+ database: './output/training.db',
208
+ },
209
+ });
210
+ ```
211
+
212
+ ## SFT Training
213
+
214
+ Supervised fine-tuning with autograd, gradient accumulation, and completion-only masking.
215
+
216
+ ### Dataset Formats
217
+
218
+ Two formats are auto-detected from JSONL files:
219
+
220
+ **Prompt-Completion:**
221
+
222
+ ```json
223
+ { "prompt": [{ "role": "user", "content": "Hello" }], "completion": { "role": "assistant", "content": "Hi!" } }
224
+ ```
225
+
226
+ **Conversation:**
227
+
228
+ ```json
229
+ {
230
+ "messages": [
231
+ { "role": "user", "content": "Hello" },
232
+ { "role": "assistant", "content": "Hi!" }
233
+ ]
234
+ }
235
+ ```
236
+
237
+ ### SFT Configuration
238
+
239
+ ```typescript
240
+ import { SFTTrainer, SFTTrainerConfig } from '@mlx-node/trl';
241
+
242
+ const config: SFTTrainerConfig = {
243
+ modelName: './models/Qwen3-0.6B',
244
+ outputDir: './output/sft',
245
+ learningRate: 2e-5,
246
+ batchSize: 4,
247
+ gradientAccumulationSteps: 8,
248
+ numEpochs: 3,
249
+ maxSeqLength: 2048,
250
+ completionOnly: true, // only compute loss on assistant tokens
251
+ labelSmoothing: 0.1,
252
+ maxGradNorm: 1.0,
253
+ weightDecay: 0.01,
254
+ loggingSteps: 10,
255
+ saveSteps: 100,
256
+ maxCheckpoints: 3,
257
+ gradientCheckpointing: true,
258
+ };
259
+ ```
260
+
261
+ ### Programmatic Dataset
262
+
263
+ ```typescript
264
+ import { SFTDataset, createSFTDataset } from '@mlx-node/trl';
265
+
266
+ const dataset = createSFTDataset(examples, tokenizer, {
267
+ maxSeqLength: 2048,
268
+ completionOnly: true,
269
+ });
270
+
271
+ const trainer = await SFTTrainer.create(config);
272
+ await trainer.train(dataset);
273
+ ```
274
+
275
+ ## Datasets
276
+
277
+ ### GSM8K Loader
278
+
279
+ Built-in loader for the GSM8K math dataset:
280
+
281
+ ```typescript
282
+ import { loadLocalGsm8kDataset, LocalGsm8kDatasetLoader } from '@mlx-node/trl';
283
+
284
+ // Direct load
285
+ const examples = await loadLocalGsm8kDataset('train', { limit: 1000 });
286
+
287
+ // Via DatasetLoader interface
288
+ const loader = new LocalGsm8kDatasetLoader({ basePath: './data/gsm8k' });
289
+ const trainData = await loader.load('train');
290
+ ```
291
+
292
+ ### Custom Datasets
293
+
294
+ Implement the `DatasetLoader` interface:
295
+
296
+ ```typescript
297
+ import { DatasetLoader, DatasetExample } from '@mlx-node/trl';
298
+
299
+ class MyDataset implements DatasetLoader {
300
+ async load(split: 'train' | 'test', limit?: number): Promise<DatasetExample[]> {
301
+ return examples.slice(0, limit).map((e) => ({
302
+ prompt: [
303
+ { role: 'system', content: 'You are helpful.' },
304
+ { role: 'user', content: e.question },
305
+ ],
306
+ metadata: { answer: e.answer },
307
+ }));
308
+ }
309
+ }
310
+ ```
311
+
312
+ ## Utilities
313
+
314
+ ### XML Chain-of-Thought Parser
315
+
316
+ Parse `<reasoning>...</reasoning><answer>...</answer>` format:
317
+
318
+ ```typescript
319
+ import { parseXmlCot, extractXmlAnswer } from '@mlx-node/trl';
320
+
321
+ const result = parseXmlCot(modelOutput);
322
+ // { reasoning: "...", answer: "42", isStrictMatch: true, isSoftMatch: true, errors: [] }
323
+
324
+ const answer = extractXmlAnswer(modelOutput);
325
+ // "42"
326
+ ```
327
+
328
+ ### Model Conversion
329
+
330
+ Re-exported from `@mlx-node/core`:
331
+
332
+ ```typescript
333
+ import { convertModel, convertParquetToJsonl } from '@mlx-node/trl';
334
+ ```
335
+
336
+ ## Features
337
+
338
+ - **Checkpoint resume** — automatic state restoration including optimizer, step count, and dataset position
339
+ - **Emergency save** — catches NaN gradients and SIGTERM/SIGINT for safe recovery
340
+ - **TUI mode** — interactive terminal UI with pause/resume/stop (via `mlx-tui` binary)
341
+ - **JSONL logging** — structured training logs for external monitoring
342
+ - **Multi-model** — supports Qwen3, Qwen3.5 Dense, and Qwen3.5 MoE architectures
343
+ - **Reward timeout** — configurable timeout for async reward functions (default 60s)
344
+ - **Path security** — path-traversal prevention when loading dataset files
345
+
346
+ ## API Reference
347
+
348
+ ### Trainers
349
+
350
+ | Class | Description |
351
+ | ------------- | --------------------------------------------------------------- |
352
+ | `GRPOTrainer` | GRPO training with generation, rewards, and policy optimization |
353
+ | `SFTTrainer` | Supervised fine-tuning with completion-only masking |
354
+
355
+ ### Datasets
356
+
357
+ | Export | Description |
358
+ | ------------------------- | ------------------------------------------------ |
359
+ | `loadLocalGsm8kDataset()` | Load GSM8K JSONL dataset |
360
+ | `LocalGsm8kDatasetLoader` | `DatasetLoader` implementation for GSM8K |
361
+ | `SFTDataset` | Tokenized SFT dataset with padding and shuffling |
362
+ | `loadSFTDataset()` | Load SFT dataset from JSONL file |
363
+ | `createSFTDataset()` | Create SFT dataset from in-memory examples |
364
+
365
+ ### Configuration
366
+
367
+ | Export | Description |
368
+ | ----------------------- | --------------------------------- |
369
+ | `GRPOTrainerConfig` | Full GRPO configuration interface |
370
+ | `SFTTrainerConfig` | Full SFT configuration interface |
371
+ | `loadTomlConfig()` | Load GRPO config from TOML file |
372
+ | `loadSFTTomlConfig()` | Load SFT config from TOML file |
373
+ | `getDefaultConfig()` | Default GRPO config |
374
+ | `getDefaultSFTConfig()` | Default SFT config |
375
+
376
+ ### Types
377
+
378
+ | Type | Description |
379
+ | --------------------- | -------------------------------------------------- |
380
+ | `DatasetExample` | Training example with prompt messages and metadata |
381
+ | `RewardFunction<T>` | Custom reward function signature |
382
+ | `RewardOutput` | Structured completion data for reward scoring |
383
+ | `XmlParseResult` | Result of XML chain-of-thought parsing |
384
+ | `TrainStepMetrics` | Per-step training metrics |
385
+ | `BuiltinRewardConfig` | Configuration for native reward functions |
386
+
387
+ ## License
388
+
389
+ [MIT](https://github.com/mlx-node/mlx-node/blob/main/LICENSE)
package/package.json CHANGED
@@ -1,6 +1,16 @@
1
1
  {
2
2
  "name": "@mlx-node/trl",
3
- "version": "0.0.0",
3
+ "version": "0.0.2",
4
+ "homepage": "https://github.com/mlx-node/mlx-node",
5
+ "bugs": {
6
+ "url": "https://github.com/mlx-node/mlx-node/issues"
7
+ },
8
+ "license": "MIT",
9
+ "repository": {
10
+ "type": "git",
11
+ "url": "https://github.com/mlx-node/mlx-node.git",
12
+ "directory": "packages/trl"
13
+ },
4
14
  "files": [
5
15
  "dist"
6
16
  ],
@@ -19,11 +29,12 @@
19
29
  "test:trainer": "TEST_TRAINER=1 vite test run"
20
30
  },
21
31
  "dependencies": {
22
- "@mlx-node/core": "0.0.0",
23
- "@mlx-node/lm": "0.0.0",
24
- "@std/toml": "npm:@jsr/std__toml@^1.0.11"
32
+ "@mlx-node/core": "0.0.2",
33
+ "@mlx-node/lm": "0.0.2",
34
+ "@std/toml": "npm:@jsr/std__toml@^1.0.11",
35
+ "change-case": "^5.4.4"
25
36
  },
26
37
  "devDependencies": {
27
- "@huggingface/hub": "^2.7.1"
38
+ "@huggingface/hub": "^2.10.7"
28
39
  }
29
40
  }
@@ -1,22 +0,0 @@
1
- import type { DatasetExample, ChatMessage, DatasetSplit, PromptFormatterOptions, PromptTemplate, DatasetLoader } from '../types';
2
- import { type PathValidationOptions } from '../utils/path-security';
3
/**
 * Options for loading a local GSM8K dataset. Prompt-formatting options
 * (one-shot settings) and path-validation options are inherited from the
 * extended interfaces.
 */
export interface LocalDatasetOptions extends PromptFormatterOptions, PathValidationOptions {
    /**
     * Directory containing `train.jsonl` / `test.jsonl`, resolved against the
     * allowed root. When omitted, `data/gsm8k` under the working directory is used.
     */
    basePath?: string;
    /** Builds the chat prompt for each question; defaults to `defaultPromptTemplate`. */
    promptTemplate?: PromptTemplate;
    /** Extra metadata spread into every generated example's metadata. */
    metadata?: Record<string, unknown>;
}
/** System prompt asking the model to reply with `<reasoning>` and `<answer>` XML sections. */
export declare const SYSTEM_PROMPT: string;
/** Assistant-turn template with `{reasoning}` and `{answer}` placeholders. */
export declare const XML_COT_FORMAT = "<reasoning>\n{reasoning}\n</reasoning>\n<answer>\n{answer}\n</answer>";
/** Default prompt template: system message, optional one-shot demonstration, then the question. */
export declare const defaultPromptTemplate: PromptTemplate;
/** Creates a dataset example from shallow copies of the prompt messages and metadata. */
export declare function createDatasetExample(prompt: ChatMessage[], metadata?: Record<string, unknown>): DatasetExample;
/** Extracts the final answer from a raw GSM8K answer string, or `null`. */
export declare function extractGsm8kAnswer(raw: string): string | null;
/**
 * Throws if the example has no prompt messages, a message with missing/blank
 * content, or a role other than `system`, `user` or `assistant`.
 */
export declare function validateDatasetExample(example: DatasetExample): void;
/**
 * Loads `<split>.jsonl` from the validated base path and converts each record
 * into a validated example; `limit` caps the number of records read.
 */
export declare function loadLocalGsm8kDataset(split: DatasetSplit, options?: LocalDatasetOptions & {
    limit?: number;
}): Promise<DatasetExample[]>;
/** `DatasetLoader` that calls `loadLocalGsm8kDataset` with options fixed at construction. */
export declare class LocalGsm8kDatasetLoader implements DatasetLoader {
    private readonly options;
    constructor(options?: LocalDatasetOptions);
    load(split: DatasetSplit, limit?: number): Promise<DatasetExample[]>;
}
22
- //# sourceMappingURL=dataset.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"dataset.d.ts","sourceRoot":"","sources":["../../src/data/dataset.ts"],"names":[],"mappings":"AAEA,OAAO,KAAK,EACV,cAAc,EACd,WAAW,EAEX,YAAY,EACZ,sBAAsB,EACtB,cAAc,EACd,aAAa,EACd,MAAM,UAAU,CAAC;AAElB,OAAO,EAA2C,KAAK,qBAAqB,EAAE,MAAM,wBAAwB,CAAC;AAE7G,MAAM,WAAW,mBAAoB,SAAQ,sBAAsB,EAAE,qBAAqB;IACxF,QAAQ,CAAC,EAAE,MAAM,CAAC;IAClB,cAAc,CAAC,EAAE,cAAc,CAAC;IAChC,QAAQ,CAAC,EAAE,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;CACpC;AAUD,eAAO,MAAM,aAAa,QASlB,CAAC;AAET,eAAO,MAAM,cAAc,0EAKjB,CAAC;AAWX,eAAO,MAAM,qBAAqB,EAAE,cAYnC,CAAC;AAEF,wBAAgB,oBAAoB,CAAC,MAAM,EAAE,WAAW,EAAE,EAAE,QAAQ,CAAC,EAAE,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,GAAG,cAAc,CAK9G;AAED,wBAAgB,kBAAkB,CAAC,GAAG,EAAE,MAAM,GAAG,MAAM,GAAG,IAAI,CAE7D;AAED,wBAAgB,sBAAsB,CAAC,OAAO,EAAE,cAAc,GAAG,IAAI,CAYpE;AAwDD,wBAAsB,qBAAqB,CACzC,KAAK,EAAE,YAAY,EACnB,OAAO,GAAE,mBAAmB,GAAG;IAAE,KAAK,CAAC,EAAE,MAAM,CAAA;CAAO,GACrD,OAAO,CAAC,cAAc,EAAE,CAAC,CA4B3B;AAED,qBAAa,uBAAwB,YAAW,aAAa;IAC3D,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAsB;gBAElC,OAAO,GAAE,mBAAwB;IAIvC,IAAI,CAAC,KAAK,EAAE,YAAY,EAAE,KAAK,CAAC,EAAE,MAAM,GAAG,OAAO,CAAC,cAAc,EAAE,CAAC;CAG3E"}
@@ -1,142 +0,0 @@
1
- import { readFileSync } from 'node:fs';
2
- import { resolve as resolvePath } from 'node:path';
3
- import { extractHashAnswer } from '../utils/xml-parser';
4
- import { validatePathContainment, getAllowedRoot } from '../utils/path-security';
5
// Fallback dataset directory: `data/gsm8k` relative to the working directory,
// resolved once when this module is first loaded.
const DEFAULT_BASE_PATH = resolvePath('data', 'gsm8k');
// Splits accepted by `datasetFileForSplit`; each maps to `<split>.jsonl`.
const VALID_SPLITS = new Set(['train', 'test']);
7
/**
 * System prompt instructing the model to answer with a `<reasoning>` section
 * followed by an `<answer>` section. `.trim()` removes the newlines that open
 * and close the template literal.
 */
export const SYSTEM_PROMPT = `
Respond in the following format:

<reasoning>
...
</reasoning>
<answer>
...
</answer>
`.trim();
/**
 * Assistant-turn template in XML chain-of-thought form. `defaultPromptTemplate`
 * fills the `{reasoning}` and `{answer}` placeholders for one-shot examples.
 */
export const XML_COT_FORMAT = `<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>`;
// Source of the leading system message in every prompt built by
// `defaultPromptTemplate`.
const SYSTEM_MESSAGE = {
    role: 'system',
    content: SYSTEM_PROMPT,
};
27
/**
 * Builds a chat message object.
 *
 * @param role - Message role (`system`, `user` or `assistant`).
 * @param content - Message text.
 * @returns A new `{ role, content }` object.
 */
function createMessage(role, content) {
    const message = { role: role, content: content };
    return message;
}
30
/**
 * Default prompt template. It produces a system message, an optional one-shot
 * demonstration (a user question plus an assistant turn rendered with
 * `XML_COT_FORMAT`), and then the actual question as the final user message.
 *
 * @param question - The question to ask.
 * @param options - Optional prompt-formatter options. The demonstration is
 *   added only when both `includeOneShot` and `oneShotExample` are set.
 * @returns A fresh array of chat messages.
 */
export const defaultPromptTemplate = (question, options) => {
    // Copy the shared system message so a caller mutating the returned
    // messages cannot corrupt the module-level constant reused by every prompt.
    const messages = [{ ...SYSTEM_MESSAGE }];
    if (options?.includeOneShot && options.oneShotExample) {
        const { question: exampleQuestion, reasoning, answer } = options.oneShotExample;
        // Fill both placeholders in one pass with a replacer function.
        // - A string replacement would expand `$&`, `$$`, `$'`, ... sequences
        //   in the user text (e.g. LaTeX `$$x$$`).
        // - Two chained replaces would substitute a literal `{answer}` that
        //   occurs inside `reasoning`.
        const assistantTurn = XML_COT_FORMAT.replace(/\{(reasoning|answer)\}/g, (_match, key) => (key === 'reasoning' ? reasoning : answer));
        messages.push(createMessage('user', exampleQuestion), createMessage('assistant', assistantTurn));
    }
    messages.push(createMessage('user', question));
    return messages;
};
39
/**
 * Wraps a prompt and optional metadata into a dataset example.
 *
 * Each message and the metadata object are shallow-copied, so later changes
 * to the caller's objects do not leak into the example (nested values are
 * still shared).
 *
 * @param prompt - Chat messages forming the prompt.
 * @param metadata - Optional metadata; a falsy value yields `metadata: undefined`.
 * @returns A new `{ prompt, metadata }` example.
 */
export function createDatasetExample(prompt, metadata) {
    const shallowCopy = (value) => ({ ...value });
    const promptCopy = prompt.map((message) => shallowCopy(message));
    let metadataCopy;
    if (metadata) {
        metadataCopy = shallowCopy(metadata);
    }
    return { prompt: promptCopy, metadata: metadataCopy };
}
45
/**
 * Extracts the final answer from a raw GSM8K answer string.
 *
 * Thin wrapper over `extractHashAnswer` from the XML parser utilities
 * (presumably the text after GSM8K's `####` marker — see that helper).
 *
 * @param raw - The dataset's raw answer text.
 * @returns Whatever `extractHashAnswer` yields: the answer string or `null`.
 */
export function extractGsm8kAnswer(raw) {
    const extracted = extractHashAnswer(raw);
    return extracted;
}
48
/**
 * Checks that a dataset example has a usable prompt.
 *
 * Messages are checked in order and the first failure is reported.
 *
 * @param example - Example to check.
 * @throws Error if `prompt` is not a non-empty array, if a message is missing
 *   or has non-string/blank `content`, or if a message's `role` is not
 *   `system`, `user` or `assistant`.
 */
export function validateDatasetExample(example) {
    const messages = example.prompt;
    if (!(Array.isArray(messages) && messages.length > 0)) {
        throw new Error('Dataset example must contain at least one prompt message.');
    }
    const allowedRoles = ['system', 'user', 'assistant'];
    for (const message of messages) {
        const hasText = Boolean(message) && typeof message.content === 'string' && message.content.trim() !== '';
        if (!hasText) {
            throw new Error('Prompt messages must include non-empty textual content.');
        }
        if (!allowedRoles.includes(message.role)) {
            throw new Error(`Unsupported chat role: ${String(message.role)}`);
        }
    }
}
61
/**
 * Determines the dataset base directory and checks it against the allowed root.
 *
 * A caller-supplied path is resolved relative to the allowed root. Without
 * one, `DEFAULT_BASE_PATH` is used. Either way, the result is passed to
 * `validatePathContainment` before it is returned.
 *
 * @param optionPath - Optional user-provided base path.
 * @param options - Path-validation options forwarded to `getAllowedRoot`.
 * @returns The absolute base directory.
 */
function resolveBasePath(optionPath, options) {
    const allowedRoot = getAllowedRoot(options);
    const candidate = optionPath ? resolvePath(allowedRoot, optionPath) : DEFAULT_BASE_PATH;
    validatePathContainment(candidate, allowedRoot);
    return candidate;
}
73
/**
 * Maps a dataset split to its JSONL file name.
 *
 * @param split - Split name; must be one of `VALID_SPLITS`.
 * @returns `<split>.jsonl`.
 * @throws Error listing the accepted splits for any other value.
 */
function datasetFileForSplit(split) {
    if (VALID_SPLITS.has(split)) {
        return `${split}.jsonl`;
    }
    const expected = [...VALID_SPLITS].join(', ');
    throw new Error(`Unsupported GSM8K split "${split}". Expected one of: ${expected}`);
}
79
/**
 * Reads a dataset file as UTF-8 text.
 *
 * @param filePath - Absolute path of the file to read.
 * @returns The file contents.
 * @throws Error prefixed with the file path if the read fails.
 */
function readDatasetFile(filePath) {
    let contents;
    try {
        contents = readFileSync(filePath, 'utf8');
    }
    catch (error) {
        const reason = error instanceof Error ? error.message : String(error);
        throw new Error(`Failed to read dataset file at ${filePath}: ${reason}`);
    }
    return contents;
}
88
/**
 * Reads a GSM8K-style JSONL file and returns its `{ question, answer }` records.
 *
 * Whitespace-only lines are skipped. Reading stops once `limit` records have
 * been collected; a negative or non-numeric `limit` means "no limit".
 *
 * @param path - File to read (via `readDatasetFile`).
 * @param limit - Optional maximum number of records to return.
 * @returns The parsed records, in file order.
 * @throws Error naming `path:line`, where line is the 1-based physical line in
 *   the file, when a line is not valid JSON or is not an object with string
 *   `question` and `answer` fields.
 */
function readJsonl(path, limit) {
    // Strip a UTF-8 byte-order mark: readFileSync(..., 'utf8') keeps it and
    // JSON.parse rejects it on the first line.
    const fileContents = readDatasetFile(path).replace(/^\uFEFF/, '');
    const lines = fileContents.split(/\r?\n/);
    const records = [];
    const max = typeof limit === 'number' && limit >= 0 ? limit : Number.POSITIVE_INFINITY;
    // Walk the physical lines (rather than a blank-filtered copy) so that error
    // messages point at the real line number even when the file has blank lines.
    for (let i = 0; i < lines.length && records.length < max; i += 1) {
        const line = lines[i];
        if (line.trim().length === 0) {
            continue;
        }
        try {
            const parsed = JSON.parse(line);
            // Guard against `null`/primitive records before touching properties,
            // so they get the schema error rather than an opaque TypeError.
            if (parsed === null ||
                typeof parsed !== 'object' ||
                typeof parsed.question !== 'string' ||
                typeof parsed.answer !== 'string') {
                throw new Error('Record must include string "question" and "answer" fields.');
            }
            records.push({ question: parsed.question, answer: parsed.answer });
        }
        catch (error) {
            const message = error instanceof Error ? error.message : String(error);
            throw new Error(`Failed to parse JSONL record at ${path}:${i + 1} - ${message}`);
        }
    }
    return records;
}
109
/**
 * Loads one split of a local GSM8K JSONL dataset and turns each record into a
 * validated dataset example.
 *
 * The file is read synchronously (through `readDatasetFile`), even though this
 * function returns a promise.
 *
 * @param split - Split name; selects `<basePath>/<split>.jsonl`.
 * @param options - Base path and path-validation options, an optional prompt
 *   template and one-shot settings, extra `metadata` merged into every
 *   example, and an optional `limit` on the number of records read.
 * @returns The examples, in file order. Each example's metadata carries
 *   `split`, `index` and `raw_answer`. Keys in `options.metadata` are spread
 *   last and therefore take precedence.
 * Rejects on an unknown split, a failed path check, an unreadable or malformed
 * file, or a prompt that fails `validateDatasetExample`.
 */
export async function loadLocalGsm8kDataset(split, options = {}) {
    const basePath = resolveBasePath(options.basePath, options);
    const filePath = resolvePath(basePath, datasetFileForSplit(split));
    // Defence in depth: ensure joining the file name onto the base path cannot
    // escape the base directory.
    validatePathContainment(filePath, basePath);
    const promptTemplate = options.promptTemplate ?? defaultPromptTemplate;
    const records = readJsonl(filePath, options.limit);
    const examples = [];
    for (const [index, record] of records.entries()) {
        const formatterOptions = {
            includeOneShot: options.includeOneShot,
            oneShotExample: options.oneShotExample,
        };
        const prompt = promptTemplate(record.question, formatterOptions);
        const example = createDatasetExample(prompt, {
            split,
            index,
            raw_answer: record.answer,
            ...options.metadata,
        });
        validateDatasetExample(example);
        examples.push(example);
    }
    return examples;
}
134
/**
 * `DatasetLoader` backed by local GSM8K JSONL files. It calls
 * `loadLocalGsm8kDataset` with the options captured at construction.
 */
export class LocalGsm8kDatasetLoader {
    options;
    /**
     * @param options - Options forwarded to `loadLocalGsm8kDataset`, or a string
     *   used as shorthand for `{ basePath: options }`. Spreading a string would
     *   otherwise produce an object of single characters and silently drop the path.
     */
    constructor(options = {}) {
        this.options = typeof options === 'string' ? { basePath: options } : { ...options };
    }
    /**
     * Loads one split using the stored options.
     *
     * @param split - Split to load.
     * @param limit - Optional cap on the number of records read.
     * @returns The examples for that split.
     */
    async load(split, limit) {
        return loadLocalGsm8kDataset(split, { ...this.options, limit });
    }
}