@elizaos/training 2.0.0-alpha.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/data/.gitkeep +0 -0
- package/data/degen/.gitkeep +2 -0
- package/data/trader/.gitkeep +2 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +58 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +206 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +89 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +439 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Benchmark Data Validator
|
|
3
|
+
*
|
|
4
|
+
* Validates benchmark snapshot data to ensure it's properly formatted
|
|
5
|
+
* and contains all required fields.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
import type { JsonValue } from '../adapter';
|
|
9
|
+
import { logger } from '../utils/logger';
|
|
10
|
+
import type { BenchmarkGameSnapshot } from './BenchmarkDataGenerator';
|
|
11
|
+
|
|
12
|
+
/**
 * Outcome of validating a benchmark snapshot.
 *
 * As produced by `BenchmarkValidator.validate`, `valid` is true exactly when
 * `errors` is empty.
 */
export interface BenchmarkValidationResult {
  /** Overall verdict: no hard errors were found. */
  valid: boolean;

  /** Human-readable problems that make the snapshot unusable. */
  errors: string[];

  /** Human-readable findings that are suspicious but do not affect `valid`. */
  warnings: string[];
}
|
|
17
|
+
|
|
18
|
+
/**
 * Structural validator for benchmark snapshot data.
 *
 * All members are static and the class keeps no state. Methods refer to the
 * class by name rather than `this`, so they keep working when detached
 * (e.g. `const { validateOrThrow } = BenchmarkValidator`).
 */
export class BenchmarkValidator {
  /**
   * Validate a benchmark snapshot.
   *
   * Malformed input is reported, never thrown: hard problems are collected in
   * `errors` (which makes the result invalid) and suspicious-but-tolerated
   * findings in `warnings`.
   *
   * @param snapshot - Untrusted value, typically freshly parsed JSON.
   * @returns The collected errors and warnings; `valid` is true exactly when
   *   `errors` is empty.
   */
  static validate(snapshot: unknown): BenchmarkValidationResult {
    const errors: string[] = [];
    const warnings: string[] = [];

    // 1. Check required top-level fields
    if (!snapshot || typeof snapshot !== 'object') {
      errors.push('Snapshot is null, undefined, or not an object');
      return { valid: false, errors, warnings };
    }

    const snap = snapshot as Record<string, JsonValue>;

    if (!snap.id) errors.push('Missing required field: id');
    if (!snap.version) errors.push('Missing required field: version');
    if (typeof snap.duration !== 'number')
      errors.push('Missing or invalid field: duration');
    if (typeof snap.tickInterval !== 'number')
      errors.push('Missing or invalid field: tickInterval');
    if (!snap.initialState) errors.push('Missing required field: initialState');
    if (!Array.isArray(snap.ticks))
      errors.push('Missing or invalid field: ticks (must be array)');
    if (!snap.groundTruth) errors.push('Missing required field: groundTruth');

    const initialState = BenchmarkValidator.asRecord(snap.initialState);
    const groundTruth = BenchmarkValidator.asRecord(snap.groundTruth);

    // 2. Validate initial state
    if (initialState) {
      BenchmarkValidator.validateInitialState(initialState, errors, warnings);
    }

    // 3. Validate ticks
    if (Array.isArray(snap.ticks)) {
      BenchmarkValidator.validateTicks(snap.ticks, errors, warnings);
    }

    // 4. Validate ground truth
    if (groundTruth) {
      BenchmarkValidator.validateGroundTruth(groundTruth, errors);
    }

    // 5. Cross-validate: markets in initialState should have outcomes in groundTruth
    if (initialState && groundTruth) {
      BenchmarkValidator.crossValidateMarkets(
        initialState,
        groundTruth,
        errors,
        warnings
      );
    }

    logger.info('Benchmark validation complete', {
      valid: errors.length === 0,
      errors: errors.length,
      warnings: warnings.length,
    });

    return {
      valid: errors.length === 0,
      errors,
      warnings,
    };
  }

  /**
   * Quick sanity check (fast, minimal validation).
   *
   * Only checks that `id`, `initialState` and `groundTruth` are truthy and that
   * `ticks` is an array, so the type guard is optimistic. Use
   * {@link BenchmarkValidator.validate} for a full structural check.
   */
  static sanityCheck(snapshot: unknown): snapshot is BenchmarkGameSnapshot {
    if (!snapshot || typeof snapshot !== 'object') return false;
    const snap = snapshot as Record<string, JsonValue>;
    return !!(
      snap.id &&
      snap.initialState &&
      Array.isArray(snap.ticks) &&
      snap.groundTruth
    );
  }

  /**
   * Validate and throw if invalid.
   *
   * @throws Error whose message is `Invalid benchmark data: ` followed by all
   *   errors joined with `, `. Warnings never cause a throw.
   */
  static validateOrThrow(
    snapshot: unknown
  ): asserts snapshot is BenchmarkGameSnapshot {
    const result = BenchmarkValidator.validate(snapshot);

    if (!result.valid) {
      throw new Error(`Invalid benchmark data: ${result.errors.join(', ')}`);
    }
  }

  /** Returns `value` as a string-keyed record when it is a non-null object (arrays included), otherwise `undefined`. */
  private static asRecord(
    value: unknown
  ): Record<string, JsonValue> | undefined {
    return typeof value === 'object' && value !== null
      ? (value as Record<string, JsonValue>)
      : undefined;
  }

  /** Checks the tick counter and the market/agent collections of `initialState`. */
  private static validateInitialState(
    state: Record<string, JsonValue>,
    errors: string[],
    warnings: string[]
  ): void {
    if (typeof state.tick !== 'number') {
      errors.push('initialState.tick must be a number');
    } else if (state.tick !== 0) {
      // Only judge the value once the type is right, so a missing tick is
      // reported once (as an error) rather than also as a warning.
      warnings.push('initialState.tick should be 0');
    }

    if (!Array.isArray(state.predictionMarkets)) {
      errors.push('initialState.predictionMarkets must be an array');
    }

    if (!Array.isArray(state.perpetualMarkets)) {
      errors.push('initialState.perpetualMarkets must be an array');
    }

    if (!Array.isArray(state.agents)) {
      errors.push('initialState.agents must be an array');
    }
  }

  /** Checks each tick's shape and that tick numbers match their array index. */
  private static validateTicks(
    ticks: readonly unknown[],
    errors: string[],
    warnings: string[]
  ): void {
    if (ticks.length === 0) {
      warnings.push('Ticks array is empty');
    }

    ticks.forEach((tick, index) => {
      const tickObj = BenchmarkValidator.asRecord(tick);
      if (!tickObj) {
        errors.push(`Tick ${index}: invalid tick object`);
        return;
      }

      if (typeof tickObj.number !== 'number') {
        errors.push(`Tick ${index}: missing or invalid 'number' field`);
      } else if (tickObj.number !== index) {
        // Ticks are expected to be numbered sequentially from 0.
        warnings.push(
          `Tick ${index}: number ${tickObj.number} doesn't match index`
        );
      }

      if (!Array.isArray(tickObj.events)) {
        errors.push(`Tick ${index}: events must be an array`);
      }

      if (!tickObj.state) {
        errors.push(`Tick ${index}: missing state`);
      }
    });
  }

  /** Checks that every ground-truth collection is present with the right container type. */
  private static validateGroundTruth(
    gt: Record<string, JsonValue>,
    errors: string[]
  ): void {
    if (!gt.marketOutcomes || typeof gt.marketOutcomes !== 'object') {
      errors.push('groundTruth.marketOutcomes must be an object');
    }

    if (!gt.priceHistory || typeof gt.priceHistory !== 'object') {
      errors.push('groundTruth.priceHistory must be an object');
    }

    if (!Array.isArray(gt.optimalActions)) {
      errors.push('groundTruth.optimalActions must be an array');
    }

    if (!Array.isArray(gt.socialOpportunities)) {
      errors.push('groundTruth.socialOpportunities must be an array');
    }

    if (!Array.isArray(gt.hiddenFacts)) {
      errors.push('groundTruth.hiddenFacts must be an array');
    }

    if (!Array.isArray(gt.hiddenEvents)) {
      errors.push('groundTruth.hiddenEvents must be an array');
    }

    if (!gt.trueFacts || typeof gt.trueFacts !== 'object') {
      errors.push('groundTruth.trueFacts must be an object');
    }
  }

  /** Warns about prediction markets in `initialState` that have no entry in `groundTruth.marketOutcomes`. */
  private static crossValidateMarkets(
    initialState: Record<string, JsonValue>,
    groundTruth: Record<string, JsonValue>,
    errors: string[],
    warnings: string[]
  ): void {
    const markets = Array.isArray(initialState.predictionMarkets)
      ? initialState.predictionMarkets
      : [];
    const outcomes = BenchmarkValidator.asRecord(groundTruth.marketOutcomes) ?? {};

    markets.forEach((entry, index) => {
      const market = BenchmarkValidator.asRecord(entry);
      if (!market) {
        // Previously a null entry crashed validate() with a TypeError.
        errors.push(`initialState.predictionMarkets[${index}] must be an object`);
        return;
      }

      // Own-property check: `in` would treat inherited keys such as
      // "constructor" or "toString" as existing outcomes.
      if (
        typeof market.id === 'string' &&
        market.id !== '' &&
        !Object.prototype.hasOwnProperty.call(outcomes, market.id)
      ) {
        warnings.push(
          `Market ${market.id} in initialState but no outcome in groundTruth`
        );
      }
    });
  }
}
|
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Fast Evaluation Runner
|
|
3
|
+
*
|
|
4
|
+
* Provides efficient evaluation of agents on benchmarks with:
|
|
5
|
+
* - Fast-forward mode (skip waiting)
|
|
6
|
+
* - Batch processing
|
|
7
|
+
* - Parallel execution
|
|
8
|
+
* - Progress tracking
|
|
9
|
+
*/
|
|
10
|
+
|
|
11
|
+
import { logger } from '../utils/logger';
|
|
12
|
+
import { type BenchmarkRunConfig, BenchmarkRunner } from './BenchmarkRunner';
|
|
13
|
+
import type { SimulationResult } from './SimulationEngine';
|
|
14
|
+
|
|
15
|
+
/**
 * Options for a fast benchmark evaluation (`FastEvalRunner.run`).
 */
export interface FastEvalConfig {
  /** Benchmark file path, passed to `BenchmarkRunner.runSingle` for every run. */
  benchmarkPath: string;

  /** Agent runtime to test. */
  agentRuntime: BenchmarkRunConfig['agentRuntime'];

  /** User ID of the agent under test. */
  agentUserId: string;

  /** Maximum number of runs started concurrently in one batch (1 when omitted). */
  parallelRuns?: number;

  /** Total number of runs to execute (1 when omitted). */
  iterations?: number;

  /** Whether each run saves its trajectory data (false when omitted). */
  saveTrajectory?: boolean;

  /** Output directory; each run writes under `<outputDir>/run-<n>` with 1-based `n`. */
  outputDir: string;

  /**
   * Fast-forward mode (documented default: true).
   *
   * NOTE(review): `FastEvalRunner.run` does not currently forward this flag to
   * `BenchmarkRunner.runSingle`, so setting it has no effect on the runs.
   */
  fastForward?: boolean;

  /**
   * Progress callback, invoked after each run finishes with the number of
   * completed runs so far, the total number of runs, and the finished run's
   * name (`run-<n>`).
   */
  onProgress?: (progress: {
    completed: number;
    total: number;
    currentRun?: string;
  }) => void;
}
|
|
47
|
+
|
|
48
|
+
/**
 * Aggregated outcome of a fast benchmark evaluation (`FastEvalRunner.run`).
 */
export interface FastEvalResult {
  /** Every run's result, in run order. */
  results: SimulationResult[];

  /** Statistics aggregated over `results`. */
  summary: {
    /** Mean of `metrics.totalPnl` across runs. */
    avgPnl: number;
    /** Mean of `metrics.predictionMetrics.accuracy` across runs. */
    avgAccuracy: number;
    /** Mean of `metrics.optimalityScore` across runs. */
    avgOptimality: number;
    /** Wall-clock duration of the whole evaluation, in milliseconds. */
    totalDuration: number;
    /** Number of entries in `results`. */
    runsCompleted: number;
  };

  /** Run with the highest `metrics.totalPnl`. */
  bestRun: SimulationResult;

  /** Run with the lowest `metrics.totalPnl`. */
  worstRun: SimulationResult;
}
|
|
65
|
+
|
|
66
|
+
/**
 * Runs an agent against a benchmark several times, in concurrent batches, and
 * aggregates the resulting metrics.
 *
 * All members are static; they refer to the class by name rather than `this`
 * so they keep working when detached.
 */
export class FastEvalRunner {
  /**
   * Run fast evaluation
   *
   * Executes `iterations` runs of `BenchmarkRunner.runSingle`, starting at most
   * `parallelRuns` of them concurrently per batch, then aggregates the results.
   *
   * @param config - Fast evaluation configuration. `iterations` and
   *   `parallelRuns` default to 1; values that are missing, non-finite or below
   *   1 are treated as 1, and fractional values are rounded down.
   * @returns FastEvalResult with all run results and summary statistics
   * @throws Rejects as soon as any run rejects (batches use `Promise.all`);
   *   results of already-completed runs are discarded.
   *
   * @remarks
   * - Runs multiple iterations in parallel batches
   * - Provides progress callbacks for monitoring
   * - Calculates aggregate statistics across all runs
   * - Identifies best and worst performing runs (by `metrics.totalPnl`)
   *
   * @example
   * ```typescript
   * const result = await FastEvalRunner.run({
   *   benchmarkPath: './benchmarks/test.json',
   *   agentRuntime: runtime,
   *   agentUserId: 'agent-123',
   *   parallelRuns: 3,
   *   iterations: 10,
   *   outputDir: './results'
   * });
   * console.log(`Average P&L: ${result.summary.avgPnl}`);
   * ```
   */
  static async run(config: FastEvalConfig): Promise<FastEvalResult> {
    const startTime = Date.now();
    // Normalising guards against a non-positive or fractional batch size,
    // which previously looped forever (negative) or ran zero runs and then
    // crashed on an empty `reduce` (fractional < 1).
    const iterations = FastEvalRunner.normalizeCount(config.iterations);
    const parallelRuns = FastEvalRunner.normalizeCount(config.parallelRuns);

    logger.info('Starting fast evaluation', {
      benchmarkPath: config.benchmarkPath,
      agentUserId: config.agentUserId,
      iterations,
      parallelRuns,
    });

    const results: SimulationResult[] = [];
    let completed = 0;

    // Run iterations in batches
    for (
      let batchStart = 0;
      batchStart < iterations;
      batchStart += parallelRuns
    ) {
      const batchEnd = Math.min(batchStart + parallelRuns, iterations);
      const batchSize = batchEnd - batchStart;

      logger.info(
        `Running batch ${batchStart + 1}-${batchEnd} of ${iterations}`
      );

      // Run batch in parallel
      const batchPromises = Array.from({ length: batchSize }, async (_, i) => {
        const runName = `run-${batchStart + i + 1}`;

        const result = await BenchmarkRunner.runSingle({
          benchmarkPath: config.benchmarkPath,
          agentRuntime: config.agentRuntime,
          agentUserId: config.agentUserId,
          saveTrajectory: config.saveTrajectory ?? false,
          outputDir: `${config.outputDir}/${runName}`,
        });

        completed++;
        config.onProgress?.({
          completed,
          total: iterations,
          currentRun: runName,
        });
        return result;
      });

      // Promise.all preserves index order, so `results` stays in run order.
      results.push(...(await Promise.all(batchPromises)));
    }

    const totalDuration = Date.now() - startTime;

    // Calculate summary; `results` is non-empty because iterations >= 1.
    const mean = (pick: (r: SimulationResult) => number): number =>
      results.reduce((sum, r) => sum + pick(r), 0) / results.length;

    const avgPnl = mean((r) => r.metrics.totalPnl);
    const avgAccuracy = mean((r) => r.metrics.predictionMetrics.accuracy);
    const avgOptimality = mean((r) => r.metrics.optimalityScore);

    const bestRun = results.reduce((best, current) =>
      current.metrics.totalPnl > best.metrics.totalPnl ? current : best
    );

    const worstRun = results.reduce((worst, current) =>
      current.metrics.totalPnl < worst.metrics.totalPnl ? current : worst
    );

    const summary = {
      avgPnl,
      avgAccuracy,
      avgOptimality,
      totalDuration,
      runsCompleted: results.length,
    };

    logger.info('Fast evaluation completed', summary);

    return {
      results,
      summary,
      bestRun,
      worstRun,
    };
  }

  /**
   * Run evaluation with a progress bar on stdout.
   *
   * Redraws a 40-character bar in place (via `\r`) whenever the rounded
   * percentage changes, forwards every progress event to `config.onProgress`,
   * and always ends the bar line with a newline, even if the evaluation
   * rejects.
   */
  static async runWithProgress(
    config: FastEvalConfig
  ): Promise<FastEvalResult> {
    let lastProgress = 0;

    try {
      return await FastEvalRunner.run({
        ...config,
        onProgress: (progress) => {
          const percent = Math.round((progress.completed / progress.total) * 100);
          if (percent !== lastProgress) {
            const barLength = 40;
            const filled = Math.round(
              (progress.completed / progress.total) * barLength
            );
            const bar = '█'.repeat(filled) + '░'.repeat(barLength - filled);
            process.stdout.write(
              `\r[${bar}] ${percent}% (${progress.completed}/${progress.total})`
            );
            lastProgress = percent;
          }

          if (config.onProgress) {
            config.onProgress(progress);
          }
        },
      });
    } finally {
      // Terminate the progress line so subsequent output starts on a fresh line.
      process.stdout.write('\n');
    }
  }

  /** Coerces an optional count option to a positive integer, falling back to 1. */
  private static normalizeCount(value: number | undefined): number {
    return value !== undefined && Number.isFinite(value) && value >= 1
      ? Math.floor(value)
      : 1;
  }
}
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Metrics Validator
|
|
3
|
+
*
|
|
4
|
+
* Validates that benchmark metrics are calculated correctly against ground truth.
|
|
5
|
+
*/
|
|
6
|
+
|
|
7
|
+
import type { ValidationResult } from '../training/ConfigValidator';
|
|
8
|
+
import { logger } from '../utils/logger';
|
|
9
|
+
import type { GroundTruth } from './BenchmarkDataGenerator';
|
|
10
|
+
import type { AgentAction, SimulationMetrics } from './simulation-types';
|
|
11
|
+
|
|
12
|
+
/**
 * Consistency checks for the metrics a benchmark simulation reports.
 *
 * All members are static and the class keeps no state.
 */
export class MetricsValidator {
  /**
   * Validate metrics against ground truth.
   *
   * Checks that:
   * - reported scores lie in their valid ranges;
   * - derived ratios (prediction accuracy, perp win rate) agree with the raw
   *   counts within an absolute tolerance of 0.01;
   * - `buy_prediction` actions reference markets known to the ground truth.
   *
   * Range and tolerance checks are written as "fail unless provably within
   * bounds", so `NaN` metrics are reported instead of silently passing.
   *
   * @param metrics - Metrics reported by the simulation.
   * @param actions - Actions the agent took; only `buy_prediction` actions are inspected.
   * @param groundTruth - Benchmark ground truth; only `marketOutcomes` is read.
   * @returns `valid` is true exactly when no errors were found.
   */
  static validate(
    metrics: SimulationMetrics,
    actions: AgentAction[],
    groundTruth: GroundTruth
  ): ValidationResult {
    const errors: string[] = [];
    const warnings: string[] = [];

    // 1. Validate prediction actions against ground truth
    const predictionValidation = MetricsValidator.validatePredictionMetrics(
      metrics.predictionMetrics,
      actions,
      groundTruth
    );
    errors.push(...predictionValidation.errors);
    warnings.push(...predictionValidation.warnings);

    // 2. Validate optimality score is in valid range
    if (!MetricsValidator.inRange(metrics.optimalityScore, 0, 100)) {
      errors.push(`Optimality score out of range: ${metrics.optimalityScore}`);
    }

    // 3. Validate timing metrics are reasonable
    const { avgResponseTime, maxResponseTime } = metrics.timing;
    if (!(avgResponseTime >= 0)) {
      errors.push(`Invalid average response time: ${avgResponseTime}`);
    }

    if (maxResponseTime < avgResponseTime) {
      errors.push(
        `Max response time less than average: ${maxResponseTime} < ${avgResponseTime}`
      );
    }

    // 4. Validate action counts match
    const predictionActionCount = actions.filter(
      (a) => a.type === 'buy_prediction'
    ).length;
    const { correctPredictions, incorrectPredictions, totalPositions, accuracy } =
      metrics.predictionMetrics;
    if (predictionActionCount !== totalPositions) {
      warnings.push(
        `Prediction action count mismatch: ${predictionActionCount} actions vs ${totalPositions} positions`
      );
    }

    // 5. Validate accuracy calculation (absolute 0.01 tolerance for floating point)
    const calculatedAccuracy =
      totalPositions > 0 ? correctPredictions / totalPositions : 0;
    if (!MetricsValidator.withinTolerance(calculatedAccuracy, accuracy)) {
      errors.push(
        `Accuracy calculation mismatch: reported ${accuracy}, calculated ${calculatedAccuracy}`
      );
    }

    // 6. Validate correct + incorrect = total
    if (correctPredictions + incorrectPredictions !== totalPositions) {
      errors.push(
        `Prediction count mismatch: ${correctPredictions} + ${incorrectPredictions} != ${totalPositions}`
      );
    }

    // 7. Validate perp win rate calculation
    const { totalTrades, profitableTrades, winRate } = metrics.perpMetrics;
    if (totalTrades > 0) {
      const calculatedWinRate = profitableTrades / totalTrades;
      if (!MetricsValidator.withinTolerance(calculatedWinRate, winRate)) {
        errors.push(
          `Win rate calculation mismatch: reported ${winRate}, calculated ${calculatedWinRate}`
        );
      }
    }

    logger.info('Metrics validation complete', {
      valid: errors.length === 0,
      errors: errors.length,
      warnings: warnings.length,
    });

    return {
      valid: errors.length === 0,
      errors,
      warnings,
    };
  }

  /**
   * Validate prediction actions against ground truth.
   *
   * Emits a warning for every `buy_prediction` action whose market id is
   * missing or has no own entry in `groundTruth.marketOutcomes`. Correctness
   * of the predicted outcome itself is not checked here. Never produces errors.
   */
  private static validatePredictionMetrics(
    _predictionMetrics: SimulationMetrics['predictionMetrics'],
    actions: AgentAction[],
    groundTruth: GroundTruth
  ): ValidationResult {
    const errors: string[] = [];
    const warnings: string[] = [];

    for (const action of actions) {
      if (action.type !== 'buy_prediction') continue;

      // `data` comes from the agent; tolerate it being absent or malformed
      // rather than throwing a TypeError.
      const rawMarketId = (
        action.data as { marketId?: unknown } | null | undefined
      )?.marketId;
      if (rawMarketId === undefined || rawMarketId === null) {
        warnings.push('Prediction action has no marketId');
        continue;
      }

      const marketId = String(rawMarketId);
      // Own-property check: `in` would accept inherited keys such as
      // "constructor" as markets with ground truth.
      if (
        !Object.prototype.hasOwnProperty.call(
          groundTruth.marketOutcomes,
          marketId
        )
      ) {
        warnings.push(`No ground truth for market ${marketId}`);
      }
    }

    return { valid: errors.length === 0, errors, warnings };
  }

  /**
   * Quick sanity check for metrics.
   *
   * Returns false when optimality is outside [0, 100], accuracy or win rate is
   * outside [0, 1], or a response time is negative. `NaN` in any of these
   * fields also fails.
   */
  static sanityCheck(metrics: SimulationMetrics): boolean {
    return (
      MetricsValidator.inRange(metrics.optimalityScore, 0, 100) &&
      MetricsValidator.inRange(metrics.predictionMetrics.accuracy, 0, 1) &&
      MetricsValidator.inRange(metrics.perpMetrics.winRate, 0, 1) &&
      metrics.timing.avgResponseTime >= 0 &&
      metrics.timing.maxResponseTime >= 0
    );
  }

  /** True when `value` lies in [min, max]; always false for NaN. */
  private static inRange(value: number, min: number, max: number): boolean {
    return value >= min && value <= max;
  }

  /** True when `a` and `b` differ by at most `tolerance`; false if either is NaN. */
  private static withinTolerance(a: number, b: number, tolerance = 0.01): boolean {
    return Math.abs(a - b) <= tolerance;
  }
}
|