@elizaos/training 2.0.0-alpha.13 → 2.0.0-alpha.15
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +2 -2
- package/research-output/training-runs/training-run-1773726941205.json +38 -0
- package/scripts/rank_trajectories.ts +0 -1
- package/scripts/run_task_benchmark.ts +4 -11
- package/src/adapter.ts +96 -49
- package/src/archetypes/ArchetypeConfigService.ts +188 -185
- package/src/archetypes/derive-archetype.ts +47 -47
- package/src/archetypes/index.ts +2 -2
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +70 -70
- package/src/benchmark/BenchmarkChartGenerator.ts +70 -69
- package/src/benchmark/BenchmarkDataGenerator.ts +136 -136
- package/src/benchmark/BenchmarkDataViewer.ts +32 -30
- package/src/benchmark/BenchmarkHistoryService.ts +13 -12
- package/src/benchmark/BenchmarkRunner.ts +87 -83
- package/src/benchmark/BenchmarkValidator.ts +48 -46
- package/src/benchmark/FastEvalRunner.ts +17 -16
- package/src/benchmark/MetricsValidator.ts +20 -21
- package/src/benchmark/MetricsVisualizer.ts +92 -85
- package/src/benchmark/ModelBenchmarkService.ts +90 -82
- package/src/benchmark/ModelRegistry.ts +44 -44
- package/src/benchmark/RulerBenchmarkIntegration.ts +24 -24
- package/src/benchmark/SimulationA2AInterface.ts +118 -118
- package/src/benchmark/SimulationEngine.ts +51 -51
- package/src/benchmark/TaskRunner.ts +87 -79
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +80 -80
- package/src/benchmark/__tests__/HeadToHead.test.ts +26 -26
- package/src/benchmark/index.ts +27 -27
- package/src/benchmark/parseSimulationMetrics.ts +32 -32
- package/src/benchmark/simulation-types.ts +10 -10
- package/src/dependencies.ts +34 -34
- package/src/generation/TrajectoryGenerator.ts +39 -37
- package/src/generation/index.ts +1 -1
- package/src/huggingface/HuggingFaceDatasetUploader.ts +72 -72
- package/src/huggingface/HuggingFaceIntegrationService.ts +59 -53
- package/src/huggingface/HuggingFaceModelUploader.ts +60 -59
- package/src/huggingface/index.ts +6 -6
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +32 -32
- package/src/index.ts +27 -27
- package/src/init-training.ts +6 -6
- package/src/metrics/TrajectoryMetricsExtractor.ts +70 -71
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +182 -182
- package/src/metrics/index.ts +2 -2
- package/src/rubrics/__tests__/index.test.ts +73 -73
- package/src/rubrics/ass-kisser.ts +6 -6
- package/src/rubrics/degen.ts +6 -6
- package/src/rubrics/goody-twoshoes.ts +6 -6
- package/src/rubrics/index.ts +50 -50
- package/src/rubrics/information-trader.ts +6 -6
- package/src/rubrics/infosec.ts +6 -6
- package/src/rubrics/liar.ts +6 -6
- package/src/rubrics/perps-trader.ts +6 -6
- package/src/rubrics/researcher.ts +6 -6
- package/src/rubrics/scammer.ts +6 -6
- package/src/rubrics/social-butterfly.ts +7 -7
- package/src/rubrics/super-predictor.ts +6 -6
- package/src/rubrics/trader.ts +5 -5
- package/src/scoring/ArchetypeScoringService.ts +56 -54
- package/src/scoring/JudgePromptBuilder.ts +96 -96
- package/src/scoring/LLMJudgeCache.ts +26 -23
- package/src/scoring/index.ts +3 -3
- package/src/training/AutomationPipeline.ts +149 -140
- package/src/training/BenchmarkService.ts +49 -45
- package/src/training/ConfigValidator.ts +38 -32
- package/src/training/MarketOutcomesTracker.ts +22 -12
- package/src/training/ModelDeployer.ts +15 -15
- package/src/training/ModelFetcher.ts +7 -7
- package/src/training/ModelSelectionService.ts +32 -32
- package/src/training/ModelUsageVerifier.ts +31 -24
- package/src/training/MultiModelOrchestrator.ts +44 -44
- package/src/training/RLModelConfig.ts +57 -57
- package/src/training/RewardBackpropagationService.ts +18 -17
- package/src/training/RulerScoringService.ts +73 -72
- package/src/training/TrainingMonitor.ts +29 -29
- package/src/training/TrajectoryRecorder.ts +25 -27
- package/src/training/__tests__/TrajectoryRecorder.test.ts +105 -105
- package/src/training/index.ts +36 -36
- package/src/training/logRLConfig.ts +7 -7
- package/src/training/pipeline.ts +13 -16
- package/src/training/storage/ModelStorageService.ts +32 -32
- package/src/training/storage/TrainingDataArchiver.ts +21 -21
- package/src/training/storage/index.ts +2 -2
- package/src/training/types.ts +6 -6
- package/src/training/window-utils.ts +14 -14
- package/src/utils/index.ts +7 -7
- package/src/utils/logger.ts +5 -5
- package/src/utils/snowflake.ts +1 -1
- package/src/utils/synthetic-detector.ts +7 -7
|
@@ -13,22 +13,23 @@
|
|
|
13
13
|
* Based on: https://art.openpipe.ai/fundamentals/ruler
|
|
14
14
|
*/
|
|
15
15
|
|
|
16
|
-
import { getTrainingDataAdapter, type JsonValue, type UUID } from
|
|
16
|
+
import { getTrainingDataAdapter, type JsonValue, type UUID } from "../adapter";
|
|
17
17
|
|
|
18
18
|
/** Cast string to UUID (replaces @elizaos/core asUUID) */
|
|
19
19
|
function asUUID(id: string): UUID {
|
|
20
20
|
return id as UUID;
|
|
21
21
|
}
|
|
22
|
-
|
|
22
|
+
|
|
23
|
+
import { v4 as uuidv4 } from "uuid";
|
|
23
24
|
import {
|
|
24
25
|
getLLMCaller,
|
|
25
26
|
getToTrainingMessages,
|
|
26
27
|
type TrajectoryForTraining,
|
|
27
28
|
type TrajectoryStepForTraining,
|
|
28
|
-
} from
|
|
29
|
-
import { getRubric, sanitizeArchetype } from
|
|
30
|
-
import { logger, splitIntoBatches } from
|
|
31
|
-
import type { TrajectoryStep as TrainingTrajectoryStep } from
|
|
29
|
+
} from "../dependencies";
|
|
30
|
+
import { getRubric, sanitizeArchetype } from "../rubrics";
|
|
31
|
+
import { logger, splitIntoBatches } from "../utils";
|
|
32
|
+
import type { TrajectoryStep as TrainingTrajectoryStep } from "./types";
|
|
32
33
|
|
|
33
34
|
// Use types from dependencies
|
|
34
35
|
type RichTrajectory = TrajectoryForTraining;
|
|
@@ -43,7 +44,7 @@ export interface RulerScore {
|
|
|
43
44
|
|
|
44
45
|
export interface MarketOutcomes {
|
|
45
46
|
stocks: Array<{ ticker: string; changePercent: number }>;
|
|
46
|
-
predictions: Array<{ marketId: string; outcome:
|
|
47
|
+
predictions: Array<{ marketId: string; outcome: "YES" | "NO" }>;
|
|
47
48
|
}
|
|
48
49
|
|
|
49
50
|
interface TrajectoryScore {
|
|
@@ -83,21 +84,21 @@ export class RulerScoringService {
|
|
|
83
84
|
const trajectoriesResult = await this.getTrajectoriesToScore(trajectoryIds);
|
|
84
85
|
|
|
85
86
|
if (trajectoriesResult.length === 0) {
|
|
86
|
-
logger.info(
|
|
87
|
+
logger.info("No trajectories to score", {}, "RulerScoring");
|
|
87
88
|
return 0;
|
|
88
89
|
}
|
|
89
90
|
|
|
90
91
|
const groups = this.groupByScenario(trajectoriesResult);
|
|
91
92
|
|
|
92
93
|
logger.info(
|
|
93
|
-
|
|
94
|
+
"Grouped trajectories for RULER scoring",
|
|
94
95
|
{
|
|
95
96
|
totalTrajectories: trajectoriesResult.length,
|
|
96
97
|
groups: groups.length,
|
|
97
98
|
avgGroupSize:
|
|
98
99
|
groups.length > 0 ? trajectoriesResult.length / groups.length : 0,
|
|
99
100
|
},
|
|
100
|
-
|
|
101
|
+
"RulerScoring",
|
|
101
102
|
);
|
|
102
103
|
|
|
103
104
|
let totalScored = 0;
|
|
@@ -105,13 +106,13 @@ export class RulerScoringService {
|
|
|
105
106
|
for (const group of groups) {
|
|
106
107
|
if (group.trajectories.length < this.minGroupSize) {
|
|
107
108
|
logger.warn(
|
|
108
|
-
|
|
109
|
+
"Skipping group with insufficient trajectories",
|
|
109
110
|
{
|
|
110
111
|
scenarioId: group.scenarioId,
|
|
111
112
|
count: group.trajectories.length,
|
|
112
113
|
minRequired: this.minGroupSize,
|
|
113
114
|
},
|
|
114
|
-
|
|
115
|
+
"RulerScoring",
|
|
115
116
|
);
|
|
116
117
|
continue;
|
|
117
118
|
}
|
|
@@ -125,12 +126,12 @@ export class RulerScoringService {
|
|
|
125
126
|
}
|
|
126
127
|
|
|
127
128
|
logger.info(
|
|
128
|
-
|
|
129
|
+
"RULER scoring complete",
|
|
129
130
|
{
|
|
130
131
|
totalScored,
|
|
131
132
|
totalTrajectories: trajectoriesResult.length,
|
|
132
133
|
},
|
|
133
|
-
|
|
134
|
+
"RulerScoring",
|
|
134
135
|
);
|
|
135
136
|
|
|
136
137
|
return totalScored;
|
|
@@ -148,7 +149,8 @@ export class RulerScoringService {
|
|
|
148
149
|
return null;
|
|
149
150
|
}
|
|
150
151
|
|
|
151
|
-
const updated =
|
|
152
|
+
const updated =
|
|
153
|
+
await getTrainingDataAdapter().getTrajectoryById(trajectoryId);
|
|
152
154
|
|
|
153
155
|
if (!updated || updated.aiJudgeReward === null) {
|
|
154
156
|
return null;
|
|
@@ -157,7 +159,7 @@ export class RulerScoringService {
|
|
|
157
159
|
return {
|
|
158
160
|
trajectoryId: updated.trajectoryId,
|
|
159
161
|
overallScore: updated.aiJudgeReward,
|
|
160
|
-
reasoning: updated.aiJudgeReasoning ||
|
|
162
|
+
reasoning: updated.aiJudgeReasoning || "",
|
|
161
163
|
scoredAt: updated.judgedAt || new Date(),
|
|
162
164
|
};
|
|
163
165
|
}
|
|
@@ -181,7 +183,7 @@ export class RulerScoringService {
|
|
|
181
183
|
episodeLength: number | null;
|
|
182
184
|
archetype: string | null;
|
|
183
185
|
}>,
|
|
184
|
-
scenarioId: string
|
|
186
|
+
scenarioId: string,
|
|
185
187
|
): Promise<number> {
|
|
186
188
|
const richTrajectories: Array<{
|
|
187
189
|
traj: RichTrajectory;
|
|
@@ -192,15 +194,15 @@ export class RulerScoringService {
|
|
|
192
194
|
for (const dbTraj of trajectoriesData) {
|
|
193
195
|
if (
|
|
194
196
|
!dbTraj.stepsJson ||
|
|
195
|
-
dbTraj.stepsJson ===
|
|
196
|
-
dbTraj.stepsJson ===
|
|
197
|
+
dbTraj.stepsJson === "null" ||
|
|
198
|
+
dbTraj.stepsJson === "[]"
|
|
197
199
|
) {
|
|
198
200
|
logger.warn(
|
|
199
|
-
|
|
201
|
+
"Skipping trajectory with invalid stepsJson",
|
|
200
202
|
{
|
|
201
203
|
trajectoryId: dbTraj.trajectoryId,
|
|
202
204
|
},
|
|
203
|
-
|
|
205
|
+
"RulerScoring",
|
|
204
206
|
);
|
|
205
207
|
continue;
|
|
206
208
|
}
|
|
@@ -249,11 +251,11 @@ export class RulerScoringService {
|
|
|
249
251
|
maxTokens: l.maxTokens,
|
|
250
252
|
latencyMs: l.latencyMs,
|
|
251
253
|
purpose: l.purpose as
|
|
252
|
-
|
|
|
253
|
-
|
|
|
254
|
-
|
|
|
255
|
-
|
|
|
256
|
-
|
|
|
254
|
+
| "action"
|
|
255
|
+
| "reasoning"
|
|
256
|
+
| "evaluation"
|
|
257
|
+
| "response"
|
|
258
|
+
| "other",
|
|
257
259
|
actionType: l.actionType,
|
|
258
260
|
})),
|
|
259
261
|
action: {
|
|
@@ -270,7 +272,7 @@ export class RulerScoringService {
|
|
|
270
272
|
reward: s.reward,
|
|
271
273
|
done: idx === steps.length - 1,
|
|
272
274
|
metadata: {},
|
|
273
|
-
})
|
|
275
|
+
}),
|
|
274
276
|
),
|
|
275
277
|
totalReward: steps.reduce((sum, s) => sum + s.reward, 0),
|
|
276
278
|
rewardComponents: {
|
|
@@ -278,7 +280,7 @@ export class RulerScoringService {
|
|
|
278
280
|
},
|
|
279
281
|
metrics: {
|
|
280
282
|
episodeLength: dbTraj.episodeLength || steps.length,
|
|
281
|
-
finalStatus:
|
|
283
|
+
finalStatus: "completed",
|
|
282
284
|
finalPnL: dbTraj.finalPnL || undefined,
|
|
283
285
|
},
|
|
284
286
|
metadata: {
|
|
@@ -295,24 +297,24 @@ export class RulerScoringService {
|
|
|
295
297
|
|
|
296
298
|
if (richTrajectories.length < this.minGroupSize) {
|
|
297
299
|
logger.warn(
|
|
298
|
-
|
|
300
|
+
"Insufficient valid trajectories in group",
|
|
299
301
|
{
|
|
300
302
|
scenarioId,
|
|
301
303
|
validCount: richTrajectories.length,
|
|
302
304
|
},
|
|
303
|
-
|
|
305
|
+
"RulerScoring",
|
|
304
306
|
);
|
|
305
307
|
return 0;
|
|
306
308
|
}
|
|
307
309
|
|
|
308
310
|
const commonPrefix = this.extractCommonPrefix(
|
|
309
|
-
richTrajectories.map((rt) => rt.messages)
|
|
311
|
+
richTrajectories.map((rt) => rt.messages),
|
|
310
312
|
);
|
|
311
313
|
|
|
312
314
|
const judgePrompt = this.buildJudgePrompt(
|
|
313
315
|
richTrajectories,
|
|
314
316
|
commonPrefix,
|
|
315
|
-
scenarioId
|
|
317
|
+
scenarioId,
|
|
316
318
|
);
|
|
317
319
|
|
|
318
320
|
const judgeResponse = await this.callJudge(judgePrompt);
|
|
@@ -322,12 +324,12 @@ export class RulerScoringService {
|
|
|
322
324
|
judgeResponse.scores.length !== richTrajectories.length
|
|
323
325
|
) {
|
|
324
326
|
logger.error(
|
|
325
|
-
|
|
327
|
+
"Invalid judge response",
|
|
326
328
|
{
|
|
327
329
|
expectedScores: richTrajectories.length,
|
|
328
330
|
receivedScores: judgeResponse?.scores.length || 0,
|
|
329
331
|
},
|
|
330
|
-
|
|
332
|
+
"RulerScoring",
|
|
331
333
|
);
|
|
332
334
|
return 0;
|
|
333
335
|
}
|
|
@@ -344,35 +346,35 @@ export class RulerScoringService {
|
|
|
344
346
|
|
|
345
347
|
if (!scoreData) {
|
|
346
348
|
logger.warn(
|
|
347
|
-
|
|
349
|
+
"Judge did not return score for trajectory",
|
|
348
350
|
{
|
|
349
351
|
expectedTrajId,
|
|
350
352
|
receivedIds: judgeResponse.scores.map((s) => s.trajectory_id),
|
|
351
353
|
},
|
|
352
|
-
|
|
354
|
+
"RulerScoring",
|
|
353
355
|
);
|
|
354
356
|
continue;
|
|
355
357
|
}
|
|
356
358
|
|
|
357
|
-
const trajectoryId = richTrajectories[i]
|
|
359
|
+
const trajectoryId = richTrajectories[i]?.traj.trajectoryId;
|
|
358
360
|
|
|
359
361
|
await getTrainingDataAdapter().updateTrajectoryScore(
|
|
360
362
|
trajectoryId,
|
|
361
363
|
Math.max(0, Math.min(1, scoreData.score)),
|
|
362
|
-
scoreData.explanation
|
|
364
|
+
scoreData.explanation,
|
|
363
365
|
);
|
|
364
366
|
|
|
365
367
|
scored++;
|
|
366
368
|
}
|
|
367
369
|
|
|
368
370
|
logger.info(
|
|
369
|
-
|
|
371
|
+
"Scored trajectory group",
|
|
370
372
|
{
|
|
371
373
|
scenarioId,
|
|
372
374
|
scored,
|
|
373
375
|
groupSize: richTrajectories.length,
|
|
374
376
|
},
|
|
375
|
-
|
|
377
|
+
"RulerScoring",
|
|
376
378
|
);
|
|
377
379
|
|
|
378
380
|
return scored;
|
|
@@ -391,13 +393,13 @@ export class RulerScoringService {
|
|
|
391
393
|
archetype: string;
|
|
392
394
|
}>,
|
|
393
395
|
commonPrefix: Array<{ role: string; content: string }>,
|
|
394
|
-
scenarioId: string
|
|
396
|
+
scenarioId: string,
|
|
395
397
|
): string {
|
|
396
398
|
// Build context section with game knowledge (injected into prompt)
|
|
397
399
|
const contextParts: string[] = [];
|
|
398
400
|
contextParts.push(`Scenario: ${scenarioId}`);
|
|
399
401
|
contextParts.push(
|
|
400
|
-
`\nTrajectory Performance Context (use this to inform your scoring)
|
|
402
|
+
`\nTrajectory Performance Context (use this to inform your scoring):`,
|
|
401
403
|
);
|
|
402
404
|
|
|
403
405
|
for (let i = 0; i < richTrajectories.length; i++) {
|
|
@@ -407,24 +409,24 @@ export class RulerScoringService {
|
|
|
407
409
|
contextParts.push(`\n${trajId}:`);
|
|
408
410
|
contextParts.push(` - Archetype: ${rt.archetype}`);
|
|
409
411
|
contextParts.push(
|
|
410
|
-
` - Final P&L: $${rt.traj.metrics.finalPnL?.toFixed(2) ||
|
|
412
|
+
` - Final P&L: $${rt.traj.metrics.finalPnL?.toFixed(2) || "0.00"}`,
|
|
411
413
|
);
|
|
412
414
|
contextParts.push(
|
|
413
|
-
` - Episode Length: ${rt.traj.metrics.episodeLength || 0} steps
|
|
415
|
+
` - Episode Length: ${rt.traj.metrics.episodeLength || 0} steps`,
|
|
414
416
|
);
|
|
415
417
|
contextParts.push(` - Total Reward: ${rt.traj.totalReward.toFixed(2)}`);
|
|
416
418
|
|
|
417
419
|
const actionTypes = rt.traj.steps
|
|
418
420
|
.filter((s: TrajectoryStep): boolean => !!s.action)
|
|
419
|
-
.map((s: TrajectoryStep): string => s.action
|
|
421
|
+
.map((s: TrajectoryStep): string => s.action?.actionType);
|
|
420
422
|
const uniqueActions = [...new Set(actionTypes)];
|
|
421
423
|
contextParts.push(
|
|
422
|
-
` - Actions Taken: ${uniqueActions.join(
|
|
424
|
+
` - Actions Taken: ${uniqueActions.join(", ")} (${actionTypes.length} total)`,
|
|
423
425
|
);
|
|
424
426
|
|
|
425
427
|
// Add success/error info
|
|
426
428
|
const errors = rt.traj.steps.filter(
|
|
427
|
-
(s: TrajectoryStep): boolean => !!s.action && !s.action.success
|
|
429
|
+
(s: TrajectoryStep): boolean => !!s.action && !s.action.success,
|
|
428
430
|
).length;
|
|
429
431
|
const successRate =
|
|
430
432
|
rt.traj.steps.length > 0
|
|
@@ -432,7 +434,7 @@ export class RulerScoringService {
|
|
|
432
434
|
((rt.traj.steps.length - errors) / rt.traj.steps.length) *
|
|
433
435
|
100
|
|
434
436
|
).toFixed(1)
|
|
435
|
-
:
|
|
437
|
+
: "0";
|
|
436
438
|
contextParts.push(` - Success Rate: ${successRate}%`);
|
|
437
439
|
|
|
438
440
|
if (errors > 0) {
|
|
@@ -462,24 +464,24 @@ export class RulerScoringService {
|
|
|
462
464
|
const userContent =
|
|
463
465
|
commonPrefix.length > 0
|
|
464
466
|
? `<context>\n${JSON.stringify(commonPrefix, null, 2)}\n</context>\n\n`
|
|
465
|
-
:
|
|
467
|
+
: "";
|
|
466
468
|
|
|
467
|
-
const prompt = `${userContent}${contextParts.join(
|
|
469
|
+
const prompt = `${userContent}${contextParts.join("\n")}\n\nTrajectories:\n\n${trajectorySections.join("\n\n")}`;
|
|
468
470
|
|
|
469
471
|
// Determine archetype-specific rubric
|
|
470
472
|
// If all trajectories share the same archetype, use that archetype's rubric
|
|
471
473
|
// Otherwise, fall back to the default rubric
|
|
472
474
|
const archetypes = [...new Set(richTrajectories.map((rt) => rt.archetype))];
|
|
473
475
|
const isSingleArchetype =
|
|
474
|
-
archetypes.length === 1 && archetypes[0] !==
|
|
476
|
+
archetypes.length === 1 && archetypes[0] !== "default";
|
|
475
477
|
const rubric = isSingleArchetype
|
|
476
478
|
? getRubric(archetypes[0]!)
|
|
477
479
|
: DEFAULT_RUBRIC;
|
|
478
480
|
const archetypeContext = isSingleArchetype
|
|
479
|
-
? `\n\nYou are evaluating ${archetypes[0]
|
|
481
|
+
? `\n\nYou are evaluating ${archetypes[0]?.toUpperCase()} agents. Score them based on how well they embody that archetype's behavior and goals.`
|
|
480
482
|
: archetypes.length > 1
|
|
481
|
-
? `\n\nNote: This group contains mixed archetypes (${archetypes.join(
|
|
482
|
-
:
|
|
483
|
+
? `\n\nNote: This group contains mixed archetypes (${archetypes.join(", ")}). Consider each agent's archetype when scoring.`
|
|
484
|
+
: "";
|
|
483
485
|
|
|
484
486
|
const systemPrompt = `You are an expert evaluator of AI agent performance. All trajectories below were given the same goal/scenario. Your job is to compare them and assign scores from 0 to 1 based on how well each trajectory achieved its goal.${archetypeContext}
|
|
485
487
|
|
|
@@ -529,26 +531,26 @@ Return ONLY the JSON, no other text.`;
|
|
|
529
531
|
const response = await llmCaller.callGroqDirect({
|
|
530
532
|
prompt: structuredPrompt,
|
|
531
533
|
system: promptData.system,
|
|
532
|
-
modelSize:
|
|
534
|
+
modelSize: "large",
|
|
533
535
|
temperature: 0.3,
|
|
534
536
|
maxTokens: 2000,
|
|
535
|
-
actionType:
|
|
537
|
+
actionType: "ruler_score_trajectories",
|
|
536
538
|
});
|
|
537
539
|
|
|
538
540
|
let jsonText = response.trim();
|
|
539
541
|
jsonText = jsonText
|
|
540
|
-
.replace(/```json\n?/g,
|
|
541
|
-
.replace(/```\n?/g,
|
|
542
|
+
.replace(/```json\n?/g, "")
|
|
543
|
+
.replace(/```\n?/g, "")
|
|
542
544
|
.trim();
|
|
543
545
|
|
|
544
546
|
const jsonMatch = jsonText.match(/\{[\s\S]*\}/);
|
|
545
547
|
if (!jsonMatch) {
|
|
546
548
|
logger.error(
|
|
547
|
-
|
|
549
|
+
"Judge response does not contain JSON",
|
|
548
550
|
{
|
|
549
551
|
response: response.substring(0, 500),
|
|
550
552
|
},
|
|
551
|
-
|
|
553
|
+
"RulerScoring",
|
|
552
554
|
);
|
|
553
555
|
return null;
|
|
554
556
|
}
|
|
@@ -557,9 +559,9 @@ Return ONLY the JSON, no other text.`;
|
|
|
557
559
|
|
|
558
560
|
if (!parsed.scores || !Array.isArray(parsed.scores)) {
|
|
559
561
|
logger.error(
|
|
560
|
-
|
|
562
|
+
"Invalid judge response structure",
|
|
561
563
|
{ parsed },
|
|
562
|
-
|
|
564
|
+
"RulerScoring",
|
|
563
565
|
);
|
|
564
566
|
return null;
|
|
565
567
|
}
|
|
@@ -579,7 +581,7 @@ Return ONLY the JSON, no other text.`;
|
|
|
579
581
|
* RULER deduplicates common prefixes to save tokens.
|
|
580
582
|
*/
|
|
581
583
|
private extractCommonPrefix(
|
|
582
|
-
messageLists: Array<Array<{ role: string; content: string }
|
|
584
|
+
messageLists: Array<Array<{ role: string; content: string }>>,
|
|
583
585
|
): Array<{ role: string; content: string }> {
|
|
584
586
|
if (messageLists.length === 0) return [];
|
|
585
587
|
|
|
@@ -591,8 +593,8 @@ Return ONLY the JSON, no other text.`;
|
|
|
591
593
|
const allMatch = messageLists.every(
|
|
592
594
|
(msgs) =>
|
|
593
595
|
msgs[i] &&
|
|
594
|
-
msgs[i]
|
|
595
|
-
msgs[i]
|
|
596
|
+
msgs[i]?.role === msg.role &&
|
|
597
|
+
msgs[i]?.content === msg.content,
|
|
596
598
|
);
|
|
597
599
|
|
|
598
600
|
if (allMatch) {
|
|
@@ -616,16 +618,16 @@ Return ONLY the JSON, no other text.`;
|
|
|
616
618
|
finalPnL: number | null;
|
|
617
619
|
episodeLength: number | null;
|
|
618
620
|
archetype: string | null;
|
|
619
|
-
}
|
|
621
|
+
}>,
|
|
620
622
|
): Array<{ scenarioId: string; trajectories: typeof trajectoriesData }> {
|
|
621
623
|
const groups = new Map<string, typeof trajectoriesData>();
|
|
622
624
|
|
|
623
625
|
for (const traj of trajectoriesData) {
|
|
624
|
-
const scenarioId = traj.scenarioId ||
|
|
626
|
+
const scenarioId = traj.scenarioId || "default";
|
|
625
627
|
if (!groups.has(scenarioId)) {
|
|
626
628
|
groups.set(scenarioId, []);
|
|
627
629
|
}
|
|
628
|
-
groups.get(scenarioId)
|
|
630
|
+
groups.get(scenarioId)?.push(traj);
|
|
629
631
|
}
|
|
630
632
|
|
|
631
633
|
return Array.from(groups.entries()).map(([scenarioId, trajs]) => ({
|
|
@@ -640,9 +642,7 @@ Return ONLY the JSON, no other text.`;
|
|
|
640
642
|
private async getTrajectoriesToScore(trajectoryIds?: string[]) {
|
|
641
643
|
const adapter = getTrainingDataAdapter();
|
|
642
644
|
return await adapter.getUnscoredTrajectories(
|
|
643
|
-
trajectoryIds && trajectoryIds.length > 0
|
|
644
|
-
? { trajectoryIds }
|
|
645
|
-
: undefined
|
|
645
|
+
trajectoryIds && trajectoryIds.length > 0 ? { trajectoryIds } : undefined,
|
|
646
646
|
);
|
|
647
647
|
}
|
|
648
648
|
|
|
@@ -650,7 +650,8 @@ Return ONLY the JSON, no other text.`;
|
|
|
650
650
|
* Score all unscored trajectories in a time window
|
|
651
651
|
*/
|
|
652
652
|
async scoreWindow(windowId: string): Promise<number> {
|
|
653
|
-
const trajectoryIds =
|
|
653
|
+
const trajectoryIds =
|
|
654
|
+
await getTrainingDataAdapter().getUnscoredWindowTrajectoryIds(windowId);
|
|
654
655
|
|
|
655
656
|
if (trajectoryIds.length === 0) {
|
|
656
657
|
return 0;
|
|
@@ -5,17 +5,17 @@
|
|
|
5
5
|
* Monitors Python training process and W&B runs.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
import { getTrainingDataAdapter } from
|
|
9
|
-
import { logger } from
|
|
8
|
+
import { getTrainingDataAdapter } from "../adapter";
|
|
9
|
+
import { logger } from "../utils/logger";
|
|
10
10
|
|
|
11
11
|
export type TrainingStatus =
|
|
12
|
-
|
|
|
13
|
-
|
|
|
14
|
-
|
|
|
15
|
-
|
|
|
16
|
-
|
|
|
17
|
-
|
|
|
18
|
-
|
|
|
12
|
+
| "pending"
|
|
13
|
+
| "preparing"
|
|
14
|
+
| "scoring"
|
|
15
|
+
| "training"
|
|
16
|
+
| "uploading"
|
|
17
|
+
| "completed"
|
|
18
|
+
| "failed";
|
|
19
19
|
|
|
20
20
|
export interface TrainingProgress {
|
|
21
21
|
batchId: string;
|
|
@@ -36,12 +36,12 @@ export class TrainingMonitor {
|
|
|
36
36
|
*/
|
|
37
37
|
async startMonitoring(batchId: string): Promise<void> {
|
|
38
38
|
const adapter = getTrainingDataAdapter();
|
|
39
|
-
await adapter.updateBatchStatus(batchId,
|
|
39
|
+
await adapter.updateBatchStatus(batchId, "training");
|
|
40
40
|
|
|
41
41
|
logger.info(
|
|
42
|
-
|
|
42
|
+
"Started monitoring training job",
|
|
43
43
|
{ batchId },
|
|
44
|
-
|
|
44
|
+
"TrainingMonitor",
|
|
45
45
|
);
|
|
46
46
|
}
|
|
47
47
|
|
|
@@ -50,23 +50,23 @@ export class TrainingMonitor {
|
|
|
50
50
|
*/
|
|
51
51
|
async updateProgress(
|
|
52
52
|
batchId: string,
|
|
53
|
-
progress: Partial<TrainingProgress
|
|
53
|
+
progress: Partial<TrainingProgress>,
|
|
54
54
|
): Promise<void> {
|
|
55
55
|
if (progress.status) {
|
|
56
56
|
const adapter = getTrainingDataAdapter();
|
|
57
57
|
const errorMsg =
|
|
58
|
-
progress.status ===
|
|
58
|
+
progress.status === "failed" ? progress.error : undefined;
|
|
59
59
|
await adapter.updateBatchStatus(batchId, progress.status, errorMsg);
|
|
60
60
|
}
|
|
61
61
|
|
|
62
62
|
logger.info(
|
|
63
|
-
|
|
63
|
+
"Updated training progress",
|
|
64
64
|
{
|
|
65
65
|
batchId,
|
|
66
66
|
status: progress.status,
|
|
67
67
|
progress: progress.progress,
|
|
68
68
|
},
|
|
69
|
-
|
|
69
|
+
"TrainingMonitor",
|
|
70
70
|
);
|
|
71
71
|
}
|
|
72
72
|
|
|
@@ -84,32 +84,32 @@ export class TrainingMonitor {
|
|
|
84
84
|
// Calculate progress based on status
|
|
85
85
|
let progress = 0;
|
|
86
86
|
switch (batch.status) {
|
|
87
|
-
case
|
|
87
|
+
case "pending":
|
|
88
88
|
progress = 0;
|
|
89
89
|
break;
|
|
90
|
-
case
|
|
90
|
+
case "preparing":
|
|
91
91
|
progress = 0.1;
|
|
92
92
|
break;
|
|
93
|
-
case
|
|
93
|
+
case "scoring":
|
|
94
94
|
progress = 0.3;
|
|
95
95
|
break;
|
|
96
|
-
case
|
|
96
|
+
case "training":
|
|
97
97
|
progress = 0.6;
|
|
98
98
|
break;
|
|
99
|
-
case
|
|
99
|
+
case "uploading":
|
|
100
100
|
progress = 0.9;
|
|
101
101
|
break;
|
|
102
|
-
case
|
|
102
|
+
case "completed":
|
|
103
103
|
progress = 1.0;
|
|
104
104
|
break;
|
|
105
|
-
case
|
|
105
|
+
case "failed":
|
|
106
106
|
progress = 0;
|
|
107
107
|
break;
|
|
108
108
|
}
|
|
109
109
|
|
|
110
110
|
// Estimate ETA based on average training time
|
|
111
111
|
let eta: number | undefined;
|
|
112
|
-
if (batch.status ===
|
|
112
|
+
if (batch.status === "training" && batch.startedAt) {
|
|
113
113
|
const avgTrainingTime = 2 * 60 * 60 * 1000; // 2 hours average
|
|
114
114
|
const elapsed = Date.now() - batch.startedAt.getTime();
|
|
115
115
|
eta = Math.max(0, avgTrainingTime - elapsed);
|
|
@@ -135,12 +135,12 @@ export class TrainingMonitor {
|
|
|
135
135
|
|
|
136
136
|
if (stuckJobs.length > 0) {
|
|
137
137
|
logger.warn(
|
|
138
|
-
|
|
138
|
+
"Found stuck training jobs",
|
|
139
139
|
{
|
|
140
140
|
count: stuckJobs.length,
|
|
141
141
|
jobs: stuckJobs,
|
|
142
142
|
},
|
|
143
|
-
|
|
143
|
+
"TrainingMonitor",
|
|
144
144
|
);
|
|
145
145
|
}
|
|
146
146
|
|
|
@@ -152,12 +152,12 @@ export class TrainingMonitor {
|
|
|
152
152
|
*/
|
|
153
153
|
async cancelJob(batchId: string, reason: string): Promise<void> {
|
|
154
154
|
const adapter = getTrainingDataAdapter();
|
|
155
|
-
await adapter.updateBatchStatus(batchId,
|
|
155
|
+
await adapter.updateBatchStatus(batchId, "failed", `Cancelled: ${reason}`);
|
|
156
156
|
|
|
157
157
|
logger.warn(
|
|
158
|
-
|
|
158
|
+
"Training job cancelled",
|
|
159
159
|
{ batchId, reason },
|
|
160
|
-
|
|
160
|
+
"TrainingMonitor",
|
|
161
161
|
);
|
|
162
162
|
}
|
|
163
163
|
}
|