@mlx-node/trl 0.0.0 → 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +389 -0
- package/package.json +16 -5
- package/dist/data/dataset.d.ts +0 -22
- package/dist/data/dataset.d.ts.map +0 -1
- package/dist/data/dataset.js +0 -142
- package/dist/data/sft-dataset.d.ts +0 -156
- package/dist/data/sft-dataset.d.ts.map +0 -1
- package/dist/data/sft-dataset.js +0 -415
- package/dist/index.d.ts +0 -33
- package/dist/index.d.ts.map +0 -1
- package/dist/index.js +0 -47
- package/dist/trainers/grpo-config.d.ts +0 -42
- package/dist/trainers/grpo-config.d.ts.map +0 -1
- package/dist/trainers/grpo-config.js +0 -220
- package/dist/trainers/grpo-entropy.d.ts +0 -33
- package/dist/trainers/grpo-entropy.d.ts.map +0 -1
- package/dist/trainers/grpo-entropy.js +0 -18
- package/dist/trainers/grpo-trainer.d.ts +0 -602
- package/dist/trainers/grpo-trainer.d.ts.map +0 -1
- package/dist/trainers/grpo-trainer.js +0 -1439
- package/dist/trainers/sft-config.d.ts +0 -32
- package/dist/trainers/sft-config.d.ts.map +0 -1
- package/dist/trainers/sft-config.js +0 -186
- package/dist/trainers/sft-trainer.d.ts +0 -141
- package/dist/trainers/sft-trainer.d.ts.map +0 -1
- package/dist/trainers/sft-trainer.js +0 -502
- package/dist/trainers/training-logger.d.ts +0 -375
- package/dist/trainers/training-logger.d.ts.map +0 -1
- package/dist/trainers/training-logger.js +0 -542
- package/dist/types.d.ts +0 -54
- package/dist/types.d.ts.map +0 -1
- package/dist/types.js +0 -1
- package/dist/utils/path-security.d.ts +0 -51
- package/dist/utils/path-security.d.ts.map +0 -1
- package/dist/utils/path-security.js +0 -69
- package/dist/utils/xml-parser.d.ts +0 -6
- package/dist/utils/xml-parser.d.ts.map +0 -1
- package/dist/utils/xml-parser.js +0 -184
|
@@ -1,542 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* Unified Training Logger
|
|
3
|
-
*
|
|
4
|
-
* A high-level logger abstraction that handles TUI/console mode automatically.
|
|
5
|
-
* Eliminates verbose conditional checks throughout the codebase.
|
|
6
|
-
*
|
|
7
|
-
* @example
|
|
8
|
-
* ```typescript
|
|
9
|
-
* const logger = createTrainingLogger();
|
|
10
|
-
*
|
|
11
|
-
* logger.info('Loading model...'); // Console OR TUI log
|
|
12
|
-
* logger.status('loading', 'Loading...'); // TUI header update
|
|
13
|
-
* logger.step(metrics); // Training step
|
|
14
|
-
* logger.checkpoint(path, step); // Checkpoint saved
|
|
15
|
-
* ```
|
|
16
|
-
*/
|
|
17
|
-
import { appendFileSync, mkdirSync } from 'node:fs';
|
|
18
|
-
import { join } from 'node:path';
|
|
19
|
-
// ============================================================================
|
|
20
|
-
// Metrics Aggregator
|
|
21
|
-
// ============================================================================
|
|
22
|
-
/**
 * Accumulates numeric samples so a total and an average can be reported
 * later (used by TrainingLogger for per-epoch summaries).
 */
class MetricsAggregator {
    /** Samples recorded since construction or the last reset(), in insertion order. */
    values = [];
    /** Record one sample. */
    add(value) {
        this.values.push(value);
    }
    /** Arithmetic mean of the recorded samples; 0 when nothing has been recorded. */
    mean() {
        const count = this.values.length;
        return count === 0 ? 0 : this.sum() / count;
    }
    /** Total of the recorded samples; 0 when nothing has been recorded. */
    sum() {
        let total = 0;
        for (const sample of this.values) {
            total += sample;
        }
        return total;
    }
    /** Discard every recorded sample. */
    reset() {
        this.values = [];
    }
}
|
|
42
|
-
// ============================================================================
|
|
43
|
-
// Training Logger
|
|
44
|
-
// ============================================================================
|
|
45
|
-
/**
 * Optional per-step metric fields that are forwarded to the TUI only when the
 * trainer actually supplied them. The order here is the key order of the
 * emitted JSON message.
 */
const TUI_OPTIONAL_STEP_FIELDS = [
    // GRPO-specific
    'meanReward',
    'stdReward',
    'stdAdvantage',
    // SFT-specific
    'perplexity',
    'tokenAccuracy',
    // Timing
    'generationTimeMs',
    'trainingTimeMs',
    // Memory
    'peakMemoryMb',
    'activeMemoryMb',
];
/**
 * Unified training logger that routes every event to the right sink.
 *
 * - TUI mode (environment variable `MLX_TUI_MODE=1`): each event is written to
 *   stdout as one JSON object per line.
 * - Console mode (everything else): human-readable `console.*` output, gated
 *   by `logConsole` for non-warning/error messages.
 * - Optional JSONL file (`logJsonl` + `outputDir`, console mode only): one
 *   JSON record per event appended to `<outputDir>/<runName>.jsonl`.
 *
 * File logging is best-effort: failures never throw into the training loop.
 */
export class TrainingLogger {
    /** Resolved configuration (defaults applied, `tuiMode` taken from the environment). */
    config;
    /** Absolute/relative path of the JSONL log, or undefined when file logging is off. */
    jsonlPath;
    /** Construction time in ms since epoch; used for the total-time report. */
    startTime;
    /** Time of the previous console step log; used for the ms/step estimate. */
    lastLogTime;
    // Epoch-level aggregators (cleared at the end of every epochEnd call)
    epochLoss = new MetricsAggregator();
    epochReward = new MetricsAggregator();
    epochAdvantage = new MetricsAggregator();
    epochTokens = new MetricsAggregator();
    /**
     * @param config - Optional settings: `logConsole` (default true),
     *   `logJsonl` (default false), `logInterval` (default 1), `outputDir`,
     *   `runName`. A `tuiMode` entry, if present, is ignored.
     */
    constructor(config = {}) {
        // TUI mode is ONLY enabled via environment variable (set by mlx-tui)
        // This ensures users don't accidentally miss output in CLI mode
        const tuiMode = process.env.MLX_TUI_MODE === '1';
        // A zero, negative or NaN interval would make `step % interval` NaN
        // (so nothing is ever logged) and divide the step time by zero.
        const requestedInterval = config.logInterval ?? 1;
        const logInterval = requestedInterval > 0 ? requestedInterval : 1;
        this.config = {
            tuiMode,
            logConsole: config.logConsole ?? true,
            logJsonl: config.logJsonl ?? false,
            logInterval,
            outputDir: config.outputDir,
            runName: config.runName,
        };
        this.startTime = Date.now();
        this.lastLogTime = this.startTime;
        this.setupJsonl();
    }
    /**
     * Prepare the JSONL log file and write the `training_start` record.
     * Does nothing in TUI mode, when `logJsonl` is off, or without `outputDir`.
     */
    setupJsonl() {
        if (!this.config.logJsonl || this.config.tuiMode)
            return;
        if (!this.config.outputDir)
            return;
        try {
            mkdirSync(this.config.outputDir, { recursive: true });
            const runName = this.config.runName || 'grpo-training';
            this.jsonlPath = join(this.config.outputDir, `${runName}.jsonl`);
            this.writeJsonl({
                event: 'training_start',
                timestamp: new Date().toISOString(),
                config: this.config,
            });
        }
        catch (err) {
            // Logging must never crash training, but the user explicitly asked
            // for a JSONL file, so say once that it will not be produced.
            const reason = err instanceof Error ? err.message : String(err);
            console.warn(`JSONL logging disabled: could not prepare ${this.config.outputDir} (${reason})`);
        }
    }
    // ==========================================================================
    // Core Logging Methods
    // ==========================================================================
    /** Log info message (TUI log line, or console when `logConsole` is on). */
    info(message) {
        if (this.config.tuiMode) {
            this.writeTui({ type: 'log', level: 'info', message });
        }
        else if (this.config.logConsole) {
            console.log(message);
        }
    }
    /** Log warning message. In console mode this is printed even when `logConsole` is off. */
    warn(message) {
        if (this.config.tuiMode) {
            this.writeTui({ type: 'log', level: 'warn', message });
        }
        else {
            console.warn(message);
        }
    }
    /** Log error message. In console mode this is printed even when `logConsole` is off. */
    error(message) {
        if (this.config.tuiMode) {
            this.writeTui({ type: 'log', level: 'error', message });
        }
        else {
            console.error(message);
        }
    }
    /**
     * Log debug message. Always forwarded in TUI mode; in console mode it is
     * printed only when `logConsole` is on and the `DEBUG` environment
     * variable is set.
     */
    debug(message) {
        if (this.config.tuiMode) {
            this.writeTui({ type: 'log', level: 'debug', message });
        }
        else if (this.config.logConsole && process.env.DEBUG) {
            console.log(`[DEBUG] ${message}`);
        }
    }
    /** Update status (updates TUI header, shows the message in console). */
    status(phase, message) {
        if (this.config.tuiMode) {
            this.writeTui({ type: 'status', phase, message });
        }
        else if (this.config.logConsole) {
            console.log(message);
        }
    }
    /** Print decorative banner (console only, suppressed in TUI mode). */
    banner(...lines) {
        if (this.config.tuiMode)
            return;
        if (!this.config.logConsole)
            return;
        for (const line of lines) {
            console.log(line);
        }
    }
    // ==========================================================================
    // Training Event Methods
    // ==========================================================================
    /**
     * Log training initialization.
     *
     * @param model - Model identifier, forwarded to the TUI as-is
     * @param config - Trainer configuration; `trainingType` defaults to 'grpo'
     * @param numExamples - Optional dataset size; the JSONL `training_config`
     *   record is written only when this is provided
     */
    init(model, config, numExamples) {
        const trainingType = config.trainingType ?? 'grpo';
        if (this.config.tuiMode) {
            this.writeTui({ type: 'init', model, config: { ...config, trainingType } });
        }
        else if (this.config.logConsole) {
            const trainingLabel = trainingType === 'sft' ? 'SFT' : 'GRPO';
            console.log(`\n Starting ${trainingLabel} training`);
            if (numExamples)
                console.log(` Examples: ${numExamples}`);
            console.log(` Epochs: ${config.numEpochs}`);
            console.log(` Batch size: ${config.batchSize}`);
            // Group size is only printed for non-SFT runs
            if (trainingType !== 'sft') {
                console.log(` Group size: ${config.groupSize}`);
            }
            console.log(` Learning rate: ${config.learningRate}`);
        }
        if (this.jsonlPath && numExamples !== undefined) {
            this.writeJsonl({
                event: 'training_config',
                num_examples: numExamples,
                config: { ...config, trainingType },
                timestamp: new Date().toISOString(),
            });
        }
    }
    /** Log epoch start. `epoch` is reported as `epoch + 1` (presumably zero-based at the call site). */
    epochStart(epoch, totalEpochs, numBatches) {
        if (this.config.tuiMode) {
            this.writeTui({
                type: 'epoch_start',
                epoch: epoch + 1,
                totalEpochs,
                numBatches,
            });
        }
        else if (this.config.logConsole) {
            console.log(`\n=== Epoch ${epoch + 1}/${totalEpochs} (${numBatches} batches) ===`);
        }
        if (this.jsonlPath) {
            this.writeJsonl({
                event: 'epoch_start',
                epoch: epoch + 1,
                num_batches: numBatches,
                timestamp: new Date().toISOString(),
            });
        }
    }
    /**
     * Log one training step and fold its metrics into the epoch aggregates.
     *
     * @param metrics - Step metrics; `step`, `loss` and `totalTokens` are
     *   always read, every other field is optional
     * @param batchIdx - Optional batch index, shown as `batchIdx + 1` on the console
     * @param numBatches - Optional batch count, shown together with `batchIdx`
     */
    step(metrics, batchIdx, numBatches) {
        // Aggregate for epoch (use available metrics)
        this.epochLoss.add(metrics.loss);
        if (metrics.meanReward !== undefined) {
            this.epochReward.add(metrics.meanReward);
        }
        if (metrics.meanAdvantage !== undefined) {
            this.epochAdvantage.add(metrics.meanAdvantage);
        }
        this.epochTokens.add(metrics.totalTokens);
        const isLogStep = metrics.step % this.config.logInterval === 0;
        if (this.config.tuiMode) {
            // Build step message with only available fields (no fake values!)
            const stepMsg = {
                type: 'step',
                step: metrics.step,
                loss: metrics.loss,
                totalTokens: metrics.totalTokens,
            };
            for (const field of TUI_OPTIONAL_STEP_FIELDS) {
                if (metrics[field] !== undefined) {
                    stepMsg[field] = metrics[field];
                }
            }
            this.writeTui(stepMsg);
        }
        else if (this.config.logConsole && isLogStep) {
            const now = Date.now();
            // Average wall time per step since the previous console log
            const stepTime = (now - this.lastLogTime) / this.config.logInterval;
            this.lastLogTime = now;
            const batchInfo = batchIdx !== undefined && numBatches !== undefined ? ` | Batch ${batchIdx + 1}/${numBatches}` : '';
            // Build log message based on available metrics (SFT vs GRPO)
            let logMsg = `Step ${metrics.step}${batchInfo} | Loss: ${metrics.loss.toFixed(4)}`;
            if (metrics.perplexity !== undefined) {
                // SFT format
                logMsg += ` | Perplexity: ${metrics.perplexity.toFixed(2)}`;
            }
            if (metrics.tokenAccuracy !== undefined) {
                logMsg += ` | Acc: ${(metrics.tokenAccuracy * 100).toFixed(1)}%`;
            }
            if (metrics.meanReward !== undefined) {
                // GRPO format
                logMsg += ` | Reward: ${metrics.meanReward.toFixed(4)}`;
            }
            if (metrics.meanAdvantage !== undefined) {
                logMsg += ` | Adv: ${metrics.meanAdvantage.toFixed(4)}`;
            }
            logMsg += ` | Tokens: ${metrics.totalTokens} | Time: ${stepTime.toFixed(0)}ms/step`;
            console.log(logMsg);
        }
        if (this.jsonlPath && isLogStep) {
            const record = {
                event: 'step',
                step: metrics.step,
                loss: metrics.loss,
                // The reward/advantage keys are always present (0 when absent)
                // so existing consumers of the file keep working.
                mean_reward: metrics.meanReward ?? 0,
                std_reward: metrics.stdReward ?? 0,
                mean_advantage: metrics.meanAdvantage ?? 0,
                total_tokens: metrics.totalTokens,
            };
            // SFT metrics are reported on the console and to the TUI; keep
            // them in the file log as well instead of dropping them.
            if (metrics.perplexity !== undefined) {
                record.perplexity = metrics.perplexity;
            }
            if (metrics.tokenAccuracy !== undefined) {
                record.token_accuracy = metrics.tokenAccuracy;
            }
            record.timestamp = new Date().toISOString();
            this.writeJsonl(record);
        }
    }
    /**
     * Log epoch end/summary from the values aggregated by step(), then clear
     * the aggregates for the next epoch.
     */
    epochEnd(epoch, totalEpochs, epochTimeSecs) {
        const avgLoss = this.epochLoss.mean();
        const avgReward = this.epochReward.mean();
        const avgAdvantage = this.epochAdvantage.mean();
        const totalTokens = this.epochTokens.sum();
        if (this.config.tuiMode) {
            this.writeTui({
                type: 'epoch_end',
                epoch: epoch + 1,
                avgLoss,
                avgReward,
                epochTimeSecs: epochTimeSecs ?? 0,
            });
        }
        else if (this.config.logConsole) {
            console.log(`\nEpoch ${epoch + 1}/${totalEpochs} Summary | ` +
                `Avg Loss: ${avgLoss.toFixed(4)} | ` +
                `Avg Reward: ${avgReward.toFixed(4)} | ` +
                `Avg Advantage: ${avgAdvantage.toFixed(4)} | ` +
                `Total Tokens: ${totalTokens.toFixed(0)}`);
        }
        if (this.jsonlPath) {
            this.writeJsonl({
                event: 'epoch',
                epoch: epoch + 1,
                avg_loss: avgLoss,
                avg_reward: avgReward,
                avg_advantage: avgAdvantage,
                total_tokens: totalTokens,
                timestamp: new Date().toISOString(),
            });
        }
        // Reset aggregators
        this.epochLoss.reset();
        this.epochReward.reset();
        this.epochAdvantage.reset();
        this.epochTokens.reset();
    }
    /** Log checkpoint saved. */
    checkpoint(path, step) {
        if (this.config.tuiMode) {
            this.writeTui({ type: 'checkpoint', path, step });
        }
        else if (this.config.logConsole) {
            console.log(`💾 Checkpoint saved: ${path}`);
        }
        if (this.jsonlPath) {
            this.writeJsonl({
                event: 'checkpoint',
                step,
                path,
                timestamp: new Date().toISOString(),
            });
        }
    }
    /** Log training completion together with the wall time since construction. */
    complete(totalSteps) {
        const totalTime = Date.now() - this.startTime;
        const totalMinutes = totalTime / 60000;
        const totalTimeSecs = totalTime / 1000;
        if (this.config.tuiMode) {
            this.writeTui({ type: 'complete', totalSteps, totalTimeSecs });
        }
        else if (this.config.logConsole) {
            console.log(`\n✓ Training complete! Final step: ${totalSteps} | Total time: ${totalMinutes.toFixed(2)} minutes`);
        }
        if (this.jsonlPath) {
            this.writeJsonl({
                event: 'training_complete',
                final_step: totalSteps,
                total_time_ms: totalTime,
                timestamp: new Date().toISOString(),
            });
        }
    }
    /** Log generation sample (TUI only). */
    generation(sample) {
        if (!this.config.tuiMode)
            return;
        this.writeTui({
            type: 'generation',
            index: sample.index,
            prompt: sample.prompt,
            completion: sample.completion,
            reward: sample.reward,
            tokens: sample.tokens,
            rewardDetails: sample.rewardDetails,
        });
    }
    /** Log training paused (TUI only). */
    paused(step) {
        if (!this.config.tuiMode)
            return;
        this.writeTui({ type: 'paused', step });
    }
    /** Log training resumed (TUI only). */
    resumed(step) {
        if (!this.config.tuiMode)
            return;
        this.writeTui({ type: 'resumed', step });
    }
    /** Log database path for TUI DB tab (TUI only). */
    databasePath(path, runId, runName) {
        if (!this.config.tuiMode)
            return;
        this.writeTui({ type: 'database_path', path, runId, runName });
    }
    /**
     * Send resume state to TUI for sparkline and aggregate restoration.
     * Called when resuming from checkpoint to restore TUI history.
     */
    resumeState(state) {
        if (!this.config.tuiMode)
            return;
        this.writeTui({
            type: 'resume_state',
            step: state.step,
            epoch: state.epoch,
            totalEpochs: state.totalEpochs,
            stepInEpoch: state.stepInEpoch,
            totalStepsInEpoch: state.totalStepsInEpoch,
            metricsHistory: state.metricsHistory,
            aggregates: state.aggregates,
        });
    }
    /**
     * Send an interactive prompt to the TUI and wait for response.
     * Only works in TUI mode - returns null in non-TUI mode.
     *
     * The reply is read from stdin as a line of the form
     * `PROMPT:<id>:<value>`. Every line of every incoming chunk is inspected,
     * so a reply that arrives in the same chunk as other lines is still
     * recognised. The promise stays pending until a matching line arrives.
     *
     * @param id - Unique ID for this prompt
     * @param message - Message to display
     * @param choices - Available choices
     * @param options - Prompt options (`default` index or indices, `multiSelect`)
     * @returns The selected value(s), or null if not in TUI mode.
     * For multi-select, returns comma-separated values (use promptMulti for array).
     */
    async prompt(id, message, choices, options) {
        if (!this.config.tuiMode) {
            return null; // Caller should handle non-TUI mode
        }
        const requestedDefault = options?.default;
        let defaultIndices;
        if (requestedDefault !== undefined) {
            defaultIndices = Array.isArray(requestedDefault) ? requestedDefault : [requestedDefault];
        }
        // Send prompt to TUI
        this.writeTui({
            type: 'prompt',
            id,
            message,
            choices,
            default: defaultIndices,
            multiSelect: options?.multiSelect ?? false,
        });
        // Matching on the full prefix (rather than splitting on ':') keeps
        // values with colons intact and also works for ids containing ':'.
        const responsePrefix = `PROMPT:${id}:`;
        // Wait for response via stdin
        return new Promise((resolve) => {
            const onData = (data) => {
                // A chunk is not guaranteed to be exactly one line.
                for (const rawLine of data.toString().split(/\r?\n/)) {
                    const line = rawLine.trim();
                    if (!line.startsWith(responsePrefix))
                        continue;
                    process.stdin.removeListener('data', onData);
                    // Only stop the stream when nobody else is reading it;
                    // pausing would starve any other stdin consumer.
                    if (process.stdin.listenerCount('data') === 0) {
                        process.stdin.pause();
                    }
                    resolve(line.slice(responsePrefix.length));
                    return;
                }
            };
            process.stdin.resume();
            process.stdin.on('data', onData);
        });
    }
    /**
     * Send a multi-select prompt to the TUI and wait for response.
     * Convenience wrapper that returns an array of selected values.
     *
     * @param id - Unique ID for this prompt
     * @param message - Message to display
     * @param choices - Available choices
     * @param defaultIndices - Optional default selection indices
     * @returns Array of selected values (empty when nothing was selected),
     *   or null if not in TUI mode
     */
    async promptMulti(id, message, choices, defaultIndices) {
        const result = await this.prompt(id, message, choices, {
            multiSelect: true,
            default: defaultIndices,
        });
        if (result === null)
            return null;
        if (result === '')
            return []; // No selections
        return result.split(',');
    }
    // ==========================================================================
    // Accessors
    // ==========================================================================
    /** Check if TUI mode is enabled */
    get isTuiMode() {
        return this.config.tuiMode;
    }
    /** Get the log interval */
    get logInterval() {
        return this.config.logInterval;
    }
    // ==========================================================================
    // Internal Methods
    // ==========================================================================
    /** Write one JSON message per line to stdout; no-op outside TUI mode. */
    writeTui(msg) {
        if (!this.config.tuiMode)
            return;
        process.stdout.write(JSON.stringify(msg) + '\n');
    }
    /** Append one JSON record per line to the JSONL file; no-op when file logging is off. */
    writeJsonl(data) {
        if (!this.jsonlPath)
            return;
        try {
            const line = JSON.stringify(data) + '\n';
            appendFileSync(this.jsonlPath, line, 'utf8');
        }
        catch {
            // Silently ignore - logging shouldn't crash training
        }
    }
}
|
|
519
|
-
// ============================================================================
|
|
520
|
-
// Factory Function
|
|
521
|
-
// ============================================================================
|
|
522
|
-
/**
 * Factory for TrainingLogger: builds a new logger from the given options.
 *
 * TUI mode cannot be selected through `config`; the logger enables it only
 * when the `MLX_TUI_MODE` environment variable is `'1'`.
 *
 * @param config - Optional logger settings, passed to the constructor unchanged
 * @returns A new TrainingLogger instance
 *
 * @example
 * ```typescript
 * // Defaults: console output on, JSONL off, log every step
 * const logger = createTrainingLogger();
 *
 * // Explicit configuration
 * const logger = createTrainingLogger({
 *   logConsole: true,
 *   logJsonl: true,
 *   outputDir: './outputs',
 *   logInterval: 10,
 * });
 * ```
 */
export function createTrainingLogger(config) {
    const logger = new TrainingLogger(config);
    return logger;
}
|
package/dist/types.d.ts
DELETED
|
@@ -1,54 +0,0 @@
|
|
|
1
|
-
export type { CompletionInfo, RewardOutput } from '@mlx-node/core';
import type { RewardOutput } from '@mlx-node/core';
/** Roles a chat message can carry. */
export type ChatRole =
    | 'system'
    | 'user'
    | 'assistant'
    | 'tool';
/** One chat message: who is speaking and the text content. */
export interface ChatMessage {
    role: ChatRole;
    content: string;
}
/** A message that is part of a completion; adds no fields to ChatMessage. */
export interface CompletionMessage extends ChatMessage {
}
/** A completion, expressed as a list of messages. */
export type Completion = Array<CompletionMessage>;
/**
 * Dataset split name. 'train' and 'test' are the suggested values, but any
 * string is accepted (`string & {}` keeps the literals visible to tooling).
 */
export type DatasetSplit = 'train' | 'test' | (string & {});
/** One dataset entry: the prompt messages plus optional free-form metadata. */
export interface DatasetExample {
    prompt: Array<ChatMessage>;
    metadata?: Record<string, unknown>;
}
/**
 * Outcome of parsing a completion for reasoning/answer sections.
 * NOTE(review): the exact meaning of strict vs. soft match is defined by the
 * XML parser, which is not visible here.
 */
export interface XmlParseResult {
    reasoning: string | null;
    answer: string | null;
    isStrictMatch: boolean;
    isSoftMatch: boolean;
    errors: Array<string>;
}
/** Prompts, completions and (nullable) reference answers for reward computation. */
export interface RewardComputationInput {
    prompts: Array<Array<ChatMessage>>;
    completions: Array<Completion>;
    answers: Array<string | null>;
}
/**
 * Unified reward function type for GRPO training.
 *
 * Receives the RewardOutput objects describing the completions together with
 * a caller-defined context, and returns one reward per output, either
 * directly or through a promise.
 */
export type RewardFunction<T = unknown> = (
    outputs: Array<RewardOutput>,
    context: T,
) => Array<number> | Float32Array | Promise<Array<number> | Float32Array>;
/** Options accepted by a PromptTemplate. */
export interface PromptFormatterOptions {
    includeOneShot?: boolean;
    oneShotExample?: {
        question: string;
        reasoning: string;
        answer: string;
    };
}
/** Builds the chat messages for a question, optionally honouring formatter options. */
export type PromptTemplate = (
    question: string,
    options?: PromptFormatterOptions,
) => Array<ChatMessage>;
/**
 * Converts a ChatMessage array to a string for reward function input.
 *
 * Lets callers customise how prompts are rendered as strings for different
 * model architectures (Qwen3, Llama, etc.).
 */
export type PromptFormatter = (messages: Array<ChatMessage>) => string;
/** Source of dataset examples for a given split, optionally capped by `limit`. */
export interface DatasetLoader {
    load(split: DatasetSplit, limit?: number): Promise<Array<DatasetExample>>;
}
|
|
54
|
-
//# sourceMappingURL=types.d.ts.map
|
package/dist/types.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"types.d.ts","sourceRoot":"","sources":["../src/types.ts"],"names":[],"mappings":"AACA,YAAY,EAAE,cAAc,EAAE,YAAY,EAAE,MAAM,gBAAgB,CAAC;AACnE,OAAO,KAAK,EAAE,YAAY,EAAE,MAAM,gBAAgB,CAAC;AAEnD,MAAM,MAAM,QAAQ,GAAG,QAAQ,GAAG,MAAM,GAAG,WAAW,GAAG,MAAM,CAAC;AAEhE,MAAM,WAAW,WAAW;IAC1B,IAAI,EAAE,QAAQ,CAAC;IACf,OAAO,EAAE,MAAM,CAAC;CACjB;AAED,MAAM,WAAW,iBAAkB,SAAQ,WAAW;CAAG;AAEzD,MAAM,MAAM,UAAU,GAAG,iBAAiB,EAAE,CAAC;AAE7C,MAAM,MAAM,YAAY,GAAG,OAAO,GAAG,MAAM,GAAG,CAAC,MAAM,GAAG,EAAE,CAAC,CAAC;AAE5D,MAAM,WAAW,cAAc;IAC7B,MAAM,EAAE,WAAW,EAAE,CAAC;IACtB,QAAQ,CAAC,EAAE,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;CACpC;AAED,MAAM,WAAW,cAAc;IAC7B,SAAS,EAAE,MAAM,GAAG,IAAI,CAAC;IACzB,MAAM,EAAE,MAAM,GAAG,IAAI,CAAC;IACtB,aAAa,EAAE,OAAO,CAAC;IACvB,WAAW,EAAE,OAAO,CAAC;IACrB,MAAM,EAAE,MAAM,EAAE,CAAC;CAClB;AAED,MAAM,WAAW,sBAAsB;IACrC,OAAO,EAAE,WAAW,EAAE,EAAE,CAAC;IACzB,WAAW,EAAE,UAAU,EAAE,CAAC;IAC1B,OAAO,EAAE,CAAC,MAAM,GAAG,IAAI,CAAC,EAAE,CAAC;CAC5B;AAED;;;;;GAKG;AACH,MAAM,MAAM,cAAc,CAAC,CAAC,GAAG,OAAO,IAAI,CACxC,OAAO,EAAE,YAAY,EAAE,EACvB,OAAO,EAAE,CAAC,KACP,MAAM,EAAE,GAAG,YAAY,GAAG,OAAO,CAAC,MAAM,EAAE,GAAG,YAAY,CAAC,CAAC;AAEhE,MAAM,WAAW,sBAAsB;IACrC,cAAc,CAAC,EAAE,OAAO,CAAC;IACzB,cAAc,CAAC,EAAE;QACf,QAAQ,EAAE,MAAM,CAAC;QACjB,SAAS,EAAE,MAAM,CAAC;QAClB,MAAM,EAAE,MAAM,CAAC;KAChB,CAAC;CACH;AAED,MAAM,MAAM,cAAc,GAAG,CAAC,QAAQ,EAAE,MAAM,EAAE,OAAO,CAAC,EAAE,sBAAsB,KAAK,WAAW,EAAE,CAAC;AAEnG;;;;;GAKG;AACH,MAAM,MAAM,eAAe,GAAG,CAAC,QAAQ,EAAE,WAAW,EAAE,KAAK,MAAM,CAAC;AAElE,MAAM,WAAW,aAAa;IAC5B,IAAI,CAAC,KAAK,EAAE,YAAY,EAAE,KAAK,CAAC,EAAE,MAAM,GAAG,OAAO,CAAC,cAAc,EAAE,CAAC,CAAC;CACtE"}
|
package/dist/types.js
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
export {};
|
|
@@ -1,51 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* Path security utilities to prevent directory traversal attacks.
|
|
3
|
-
*
|
|
4
|
-
* These utilities ensure that user-provided paths stay within allowed directories,
|
|
5
|
-
* preventing malicious paths like `../../../etc/` from accessing arbitrary files.
|
|
6
|
-
*/
|
|
7
|
-
/**
 * Raised when a path is found to lie outside its allowed root directory.
 * Carries both the offending resolved path and the root it was checked against.
 */
export declare class PathTraversalError extends Error {
    readonly resolvedPath: string;
    readonly allowedRoot: string;
    constructor(resolvedPath: string, allowedRoot: string);
}
/**
 * Checks that an already-resolved path is contained within an allowed root
 * directory, guarding against traversal via '../' sequences.
 *
 * @param resolvedPath - Fully resolved absolute path to check
 * @param allowedRoot - Root directory the path must stay inside
 * @throws PathTraversalError when the path escapes the allowed root
 */
export declare function validatePathContainment(resolvedPath: string, allowedRoot: string): void;
/**
 * Resolves a user-supplied path and checks that the result stays within an
 * allowed root directory.
 *
 * @param userPath - Path provided by the user (relative or absolute)
 * @param allowedRoot - Root directory the resolved path must stay inside
 * @returns The resolved absolute path
 * @throws PathTraversalError when the resolved path escapes the allowed root
 */
export declare function resolveAndValidatePath(userPath: string, allowedRoot: string): string;
/**
 * Settings that control path validation.
 */
export interface PathValidationOptions {
    /**
     * Root directory that every path must be contained within.
     * Documented default when omitted: process.cwd().
     * NOTE(review): getAllowedRoot is documented as consulting
     * MLX_NODE_DATA_ROOT before the working directory — confirm which
     * default actually applies.
     */
    allowedRoot?: string;
}
/**
 * Determines the allowed root directory from the options or the environment.
 * Documented behaviour: the MLX_NODE_DATA_ROOT environment variable is
 * checked first, with the current working directory as the fallback.
 *
 * @param options - Optional validation options
 * @returns The allowed root directory path
 */
export declare function getAllowedRoot(options?: PathValidationOptions): string;
|
|
51
|
-
//# sourceMappingURL=path-security.d.ts.map
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"path-security.d.ts","sourceRoot":"","sources":["../../src/utils/path-security.ts"],"names":[],"mappings":"AAAA;;;;;GAKG;AAIH;;GAEG;AACH,qBAAa,kBAAmB,SAAQ,KAAK;aAEzB,YAAY,EAAE,MAAM;aACpB,WAAW,EAAE,MAAM;gBADnB,YAAY,EAAE,MAAM,EACpB,WAAW,EAAE,MAAM;CAKtC;AAED;;;;;;;GAOG;AACH,wBAAgB,uBAAuB,CAAC,YAAY,EAAE,MAAM,EAAE,WAAW,EAAE,MAAM,GAAG,IAAI,CAWvF;AAED;;;;;;;GAOG;AACH,wBAAgB,sBAAsB,CAAC,QAAQ,EAAE,MAAM,EAAE,WAAW,EAAE,MAAM,GAAG,MAAM,CAIpF;AAED;;GAEG;AACH,MAAM,WAAW,qBAAqB;IACpC;;;OAGG;IACH,WAAW,CAAC,EAAE,MAAM,CAAC;CACtB;AAED;;;;;;GAMG;AACH,wBAAgB,cAAc,CAAC,OAAO,CAAC,EAAE,qBAAqB,GAAG,MAAM,CAYtE"}
|