@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,611 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Model Benchmark Service (For HuggingFace Integration)
|
|
3
|
+
*
|
|
4
|
+
* Runs benchmark tests on trained RL models for HuggingFace upload decisions.
|
|
5
|
+
* Compares new models against baselines and previous versions.
|
|
6
|
+
*
|
|
7
|
+
* **Purpose:** Evaluate models for HuggingFace upload
|
|
8
|
+
* **Used by:** HuggingFace integration, weekly CRON, CLI scripts
|
|
9
|
+
* **Storage:** benchmark_results table (dedicated table)
|
|
10
|
+
* **Focus:** Public model release, baseline comparison
|
|
11
|
+
*
|
|
12
|
+
* **Note:** For training pipeline benchmarking, see BenchmarkService
|
|
13
|
+
*
|
|
14
|
+
* @see BenchmarkService - For training pipeline evaluation
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
import { getTrainingDataAdapter } from '../adapter';
|
|
18
|
+
import { ethers } from 'ethers';
|
|
19
|
+
import { promises as fs } from 'fs';
|
|
20
|
+
import * as path from 'path';
|
|
21
|
+
import { getAgentRuntimeManager } from '../dependencies';
|
|
22
|
+
import { logger } from '../utils/logger';
|
|
23
|
+
import { generateSnowflakeId } from '../utils/snowflake';
|
|
24
|
+
import { BenchmarkRunner } from './BenchmarkRunner';
|
|
25
|
+
import {
|
|
26
|
+
type JsonValue,
|
|
27
|
+
parseSimulationMetrics,
|
|
28
|
+
} from './parseSimulationMetrics';
|
|
29
|
+
import type { SimulationMetrics, SimulationResult } from './SimulationEngine';
|
|
30
|
+
|
|
31
|
+
/**
 * Input options for benchmarking a single trained model.
 */
export interface ModelBenchmarkOptions {
  modelId: string; // Database ID of the trained model to benchmark
  benchmarkPaths: string[]; // Paths to benchmark JSON files
  outputDir?: string; // Simulation output dir; defaults to benchmarks/model-results/<model version>
  saveResults?: boolean; // Persist results to DB and file; also forces a re-run when cached results exist
}
|
|
37
|
+
|
|
38
|
+
/**
 * Outcome of running one benchmark file against one trained model.
 */
export interface ModelBenchmarkResult {
  modelId: string; // Database ID of the benchmarked model
  modelVersion: string; // Model version string (empty when rebuilt from the results table)
  benchmarkId: string; // ID reported by the simulation run
  benchmarkPath: string; // Benchmark JSON file that was executed
  runAt: Date; // When the benchmark was run
  metrics: SimulationMetrics; // Full simulation metrics for this run
  // Present only when a matching baseline file was found for this benchmark.
  comparisonToBaseline?: {
    pnlDelta: number; // model PnL minus baseline PnL
    accuracyDelta: number; // model prediction accuracy minus baseline's
    optimalityDelta: number; // model optimality score minus baseline's
    improved: boolean; // true when model PnL beat the baseline PnL
  };
}
|
|
52
|
+
|
|
53
|
+
/**
 * Result of comparing a new model's averaged benchmarks against the
 * aggregated baseline benchmarks.
 */
export interface ModelComparisonResult {
  newModel: {
    modelId: string;
    version: string;
    avgMetrics: AverageMetrics; // Averages across the model's benchmark runs
  };
  baseline: {
    modelId: string; // Always the literal 'baseline' (aggregated baseline files, not one model)
    avgMetrics: AverageMetrics;
  };
  improvement: {
    pnlDelta: number; // new model avg PnL minus baseline avg PnL
    accuracyDelta: number; // new model avg accuracy minus baseline's
    optimalityDelta: number; // new model avg optimality minus baseline's
    isImprovement: boolean; // weighted majority of the three deltas being positive
  };
  // 'deploy' when improved with positive PnL delta; 'baseline_better' when
  // PnL regressed by more than 100; otherwise 'keep_training'.
  recommendation: 'deploy' | 'keep_training' | 'baseline_better';
}
|
|
71
|
+
|
|
72
|
+
/**
 * Metrics averaged over `benchmarkCount` benchmark results.
 * All fields are zero when no results were available.
 */
export interface AverageMetrics {
  totalPnl: number; // mean total PnL per run
  accuracy: number; // mean prediction accuracy
  winRate: number; // mean perp-trading win rate
  optimality: number; // mean optimality score
  benchmarkCount: number; // number of results the averages were taken over
}
|
|
79
|
+
|
|
80
|
+
export class ModelBenchmarkService {
|
|
81
|
+
  /**
   * Benchmark a trained model against standard benchmarks
   *
   * Loads the model record, returns cached results when present (unless
   * `saveResults` forces a re-run), then runs each benchmark file through
   * BenchmarkRunner with the RL model forced onto a dedicated test agent.
   * Per-benchmark failures are logged and skipped, so the returned array may
   * be shorter than `options.benchmarkPaths`. When at least one benchmark
   * succeeded, the model row is updated with aggregate optimality/PnL.
   *
   * @param options - Model ID, benchmark paths, optional output dir, save flag.
   * @returns One result per benchmark that completed successfully.
   * @throws Error if no model exists for `options.modelId`.
   */
  static async benchmarkModel(
    options: ModelBenchmarkOptions
  ): Promise<ModelBenchmarkResult[]> {
    logger.info('Starting model benchmark', { modelId: options.modelId });

    // Load model from database
    const adapter = getTrainingDataAdapter();
    const model = await adapter.getModelById(options.modelId);

    if (!model) {
      throw new Error(`Model not found: ${options.modelId}`);
    }

    // Check if model already benchmarked; saveResults=true bypasses the cache
    // so results are always regenerated and persisted.
    const existingBenchmarks = await this.getModelBenchmarks(options.modelId);
    if (existingBenchmarks.length > 0 && !options.saveResults) {
      logger.info('Model already benchmarked', {
        modelId: options.modelId,
        count: existingBenchmarks.length,
      });
      return existingBenchmarks;
    }

    // Create test agent for benchmarking (shared, looked up by username)
    const testAgentId = await this.getOrCreateTestAgent();

    const results: ModelBenchmarkResult[] = [];

    // Run each benchmark sequentially; runs share the same test agent.
    for (const benchmarkPath of options.benchmarkPaths) {
      logger.info('Running benchmark', {
        benchmark: benchmarkPath,
        modelId: options.modelId,
      });

      try {
        // Get agent runtime (will use the RL model if configured)
        const runtime = await getAgentRuntimeManager().getRuntime(testAgentId);

        // Run benchmark
        const simulationResult: SimulationResult =
          await BenchmarkRunner.runSingle({
            benchmarkPath,
            agentRuntime: runtime,
            agentUserId: testAgentId,
            saveTrajectory: false,
            outputDir:
              options.outputDir ||
              path.join(
                process.cwd(),
                'benchmarks',
                'model-results',
                model.version
              ),
            forceModel: model.storagePath, // Use the RL model
          });

        // Create benchmark result
        const benchmarkResult: ModelBenchmarkResult = {
          modelId: options.modelId,
          modelVersion: model.version,
          benchmarkId: simulationResult.benchmarkId,
          benchmarkPath,
          runAt: new Date(),
          metrics: simulationResult.metrics,
        };

        // Compare to baseline if available (comparison is omitted otherwise)
        const baseline = await this.getBaselineBenchmark(benchmarkPath);
        if (baseline) {
          benchmarkResult.comparisonToBaseline = {
            pnlDelta: simulationResult.metrics.totalPnl - baseline.totalPnl,
            accuracyDelta:
              simulationResult.metrics.predictionMetrics.accuracy -
              baseline.predictionMetrics.accuracy,
            optimalityDelta:
              simulationResult.metrics.optimalityScore -
              baseline.optimalityScore,
            // "improved" is judged on PnL alone
            improved: simulationResult.metrics.totalPnl > baseline.totalPnl,
          };
        }

        results.push(benchmarkResult);

        logger.info('Benchmark completed', {
          benchmark: benchmarkPath,
          pnl: simulationResult.metrics.totalPnl,
          accuracy: simulationResult.metrics.predictionMetrics.accuracy,
        });

        // Save result if requested (to both database and files)
        if (options.saveResults) {
          await this.saveBenchmarkResultToDatabase(benchmarkResult);
          await this.saveBenchmarkResult(benchmarkResult);
        }
      } catch (error) {
        // Deliberate best-effort: a failed benchmark is logged and skipped so
        // the remaining benchmarks still run.
        logger.error('Benchmark failed', { benchmark: benchmarkPath, error });
      }
    }

    // Update model with aggregate benchmark score
    if (results.length > 0) {
      const avgOptimality =
        results.reduce((sum, r) => sum + r.metrics.optimalityScore, 0) /
        results.length;
      const avgPnl =
        results.reduce((sum, r) => sum + r.metrics.totalPnl, 0) /
        results.length;

      await adapter.updateModelBenchmark(
        options.modelId,
        avgOptimality,
        avgPnl,
        (model.benchmarkCount || 0) + results.length,
      );
    }

    logger.info('Model benchmark complete', {
      modelId: options.modelId,
      benchmarksRun: results.length,
    });

    return results;
  }
|
|
208
|
+
|
|
209
|
+
/**
|
|
210
|
+
* Compare new model against baseline
|
|
211
|
+
*/
|
|
212
|
+
static async compareToBaseline(
|
|
213
|
+
modelId: string
|
|
214
|
+
): Promise<ModelComparisonResult> {
|
|
215
|
+
// Get new model benchmarks
|
|
216
|
+
const newModelBenchmarks = await this.getModelBenchmarks(modelId);
|
|
217
|
+
|
|
218
|
+
if (newModelBenchmarks.length === 0) {
|
|
219
|
+
throw new Error(`No benchmarks found for model: ${modelId}`);
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
// Calculate new model average metrics
|
|
223
|
+
const newModelMetrics = this.calculateAverageMetrics(
|
|
224
|
+
newModelBenchmarks.map((b) => b.metrics)
|
|
225
|
+
);
|
|
226
|
+
|
|
227
|
+
// Get baseline benchmarks (use best baseline model)
|
|
228
|
+
const baselineMetrics = await this.getBaselineAverageMetrics();
|
|
229
|
+
|
|
230
|
+
// Calculate improvement
|
|
231
|
+
const pnlDelta = newModelMetrics.totalPnl - baselineMetrics.totalPnl;
|
|
232
|
+
const accuracyDelta = newModelMetrics.accuracy - baselineMetrics.accuracy;
|
|
233
|
+
const optimalityDelta =
|
|
234
|
+
newModelMetrics.optimality - baselineMetrics.optimality;
|
|
235
|
+
|
|
236
|
+
// Determine if this is an improvement (weighted score)
|
|
237
|
+
const improvementScore =
|
|
238
|
+
(pnlDelta > 0 ? 1 : 0) * 0.4 +
|
|
239
|
+
(accuracyDelta > 0 ? 1 : 0) * 0.3 +
|
|
240
|
+
(optimalityDelta > 0 ? 1 : 0) * 0.3;
|
|
241
|
+
|
|
242
|
+
const isImprovement = improvementScore > 0.5;
|
|
243
|
+
|
|
244
|
+
let recommendation: 'deploy' | 'keep_training' | 'baseline_better';
|
|
245
|
+
if (isImprovement && pnlDelta > 0) {
|
|
246
|
+
recommendation = 'deploy';
|
|
247
|
+
} else if (pnlDelta < -100) {
|
|
248
|
+
recommendation = 'baseline_better';
|
|
249
|
+
} else {
|
|
250
|
+
recommendation = 'keep_training';
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
return {
|
|
254
|
+
newModel: {
|
|
255
|
+
modelId,
|
|
256
|
+
version: newModelBenchmarks[0]!.modelVersion,
|
|
257
|
+
avgMetrics: newModelMetrics,
|
|
258
|
+
},
|
|
259
|
+
baseline: {
|
|
260
|
+
modelId: 'baseline',
|
|
261
|
+
avgMetrics: baselineMetrics,
|
|
262
|
+
},
|
|
263
|
+
improvement: {
|
|
264
|
+
pnlDelta,
|
|
265
|
+
accuracyDelta,
|
|
266
|
+
optimalityDelta,
|
|
267
|
+
isImprovement,
|
|
268
|
+
},
|
|
269
|
+
recommendation,
|
|
270
|
+
};
|
|
271
|
+
}
|
|
272
|
+
|
|
273
|
+
/**
|
|
274
|
+
* Get all unbenchmarked models
|
|
275
|
+
*/
|
|
276
|
+
static async getUnbenchmarkedModels(): Promise<string[]> {
|
|
277
|
+
return getTrainingDataAdapter().getUnbenchmarkedModels();
|
|
278
|
+
}
|
|
279
|
+
|
|
280
|
+
/**
|
|
281
|
+
* Get model benchmark results
|
|
282
|
+
*/
|
|
283
|
+
private static async getModelBenchmarks(
|
|
284
|
+
modelId: string
|
|
285
|
+
): Promise<ModelBenchmarkResult[]> {
|
|
286
|
+
// For now, read from files
|
|
287
|
+
// In production, you'd store these in a database table
|
|
288
|
+
|
|
289
|
+
const benchmarksDir = path.join(
|
|
290
|
+
process.cwd(),
|
|
291
|
+
'benchmarks',
|
|
292
|
+
'model-results'
|
|
293
|
+
);
|
|
294
|
+
const results: ModelBenchmarkResult[] = [];
|
|
295
|
+
|
|
296
|
+
try {
|
|
297
|
+
const model = await getTrainingDataAdapter().getModelById(modelId);
|
|
298
|
+
|
|
299
|
+
if (!model) return results;
|
|
300
|
+
|
|
301
|
+
const modelDir = path.join(benchmarksDir, model.version);
|
|
302
|
+
const files = await fs.readdir(modelDir).catch(() => []);
|
|
303
|
+
|
|
304
|
+
for (const file of files) {
|
|
305
|
+
if (file.endsWith('.json')) {
|
|
306
|
+
const filePath = path.join(modelDir, file);
|
|
307
|
+
const data = JSON.parse(await fs.readFile(filePath, 'utf-8'));
|
|
308
|
+
|
|
309
|
+
if (data.modelId === modelId) {
|
|
310
|
+
results.push(data);
|
|
311
|
+
}
|
|
312
|
+
}
|
|
313
|
+
}
|
|
314
|
+
} catch (error) {
|
|
315
|
+
logger.warn('Could not load benchmark results', { error });
|
|
316
|
+
}
|
|
317
|
+
|
|
318
|
+
return results;
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
/**
|
|
322
|
+
* Save benchmark result to database
|
|
323
|
+
*/
|
|
324
|
+
private static async saveBenchmarkResultToDatabase(
|
|
325
|
+
result: ModelBenchmarkResult
|
|
326
|
+
): Promise<void> {
|
|
327
|
+
await getTrainingDataAdapter().insertBenchmarkResult({
|
|
328
|
+
id: await generateSnowflakeId(),
|
|
329
|
+
modelId: result.modelId,
|
|
330
|
+
benchmarkId: result.benchmarkId,
|
|
331
|
+
benchmarkPath: result.benchmarkPath,
|
|
332
|
+
runAt: result.runAt,
|
|
333
|
+
totalPnl: result.metrics.totalPnl,
|
|
334
|
+
predictionAccuracy: result.metrics.predictionMetrics.accuracy,
|
|
335
|
+
perpWinRate: result.metrics.perpMetrics.winRate,
|
|
336
|
+
optimalityScore: result.metrics.optimalityScore,
|
|
337
|
+
detailedMetrics: JSON.parse(JSON.stringify(result.metrics)),
|
|
338
|
+
baselinePnlDelta: result.comparisonToBaseline?.pnlDelta ?? null,
|
|
339
|
+
baselineAccuracyDelta: result.comparisonToBaseline?.accuracyDelta ?? null,
|
|
340
|
+
improved: result.comparisonToBaseline?.improved ?? null,
|
|
341
|
+
duration: result.metrics.timing.totalDuration,
|
|
342
|
+
});
|
|
343
|
+
|
|
344
|
+
logger.info('Benchmark result saved to database', {
|
|
345
|
+
modelId: result.modelId,
|
|
346
|
+
benchmarkId: result.benchmarkId,
|
|
347
|
+
});
|
|
348
|
+
}
|
|
349
|
+
|
|
350
|
+
/**
|
|
351
|
+
* Save benchmark result to file
|
|
352
|
+
*/
|
|
353
|
+
private static async saveBenchmarkResult(
|
|
354
|
+
result: ModelBenchmarkResult
|
|
355
|
+
): Promise<void> {
|
|
356
|
+
const outputDir = path.join(
|
|
357
|
+
process.cwd(),
|
|
358
|
+
'benchmarks',
|
|
359
|
+
'model-results',
|
|
360
|
+
result.modelVersion
|
|
361
|
+
);
|
|
362
|
+
await fs.mkdir(outputDir, { recursive: true });
|
|
363
|
+
|
|
364
|
+
const filename = `benchmark-${result.benchmarkId}-${Date.now()}.json`;
|
|
365
|
+
const filePath = path.join(outputDir, filename);
|
|
366
|
+
|
|
367
|
+
await fs.writeFile(filePath, JSON.stringify(result, null, 2));
|
|
368
|
+
|
|
369
|
+
logger.info('Benchmark result saved to file', { filePath });
|
|
370
|
+
}
|
|
371
|
+
|
|
372
|
+
/**
|
|
373
|
+
* Get benchmark results from database
|
|
374
|
+
*/
|
|
375
|
+
static async getBenchmarkResultsFromDatabase(
|
|
376
|
+
modelId: string
|
|
377
|
+
): Promise<ModelBenchmarkResult[]> {
|
|
378
|
+
const results = await getTrainingDataAdapter().getBenchmarkResultsByModel(modelId);
|
|
379
|
+
|
|
380
|
+
return results.map((r) => ({
|
|
381
|
+
modelId: r.modelId,
|
|
382
|
+
modelVersion: '', // Not stored in results table
|
|
383
|
+
benchmarkId: r.benchmarkId,
|
|
384
|
+
benchmarkPath: r.benchmarkPath,
|
|
385
|
+
runAt: r.runAt,
|
|
386
|
+
metrics: parseSimulationMetrics(r.detailedMetrics as JsonValue),
|
|
387
|
+
comparisonToBaseline:
|
|
388
|
+
r.baselinePnlDelta !== null
|
|
389
|
+
? {
|
|
390
|
+
pnlDelta: r.baselinePnlDelta,
|
|
391
|
+
accuracyDelta: r.baselineAccuracyDelta ?? 0,
|
|
392
|
+
optimalityDelta: 0, // Not stored separately
|
|
393
|
+
improved: r.improved || false,
|
|
394
|
+
}
|
|
395
|
+
: undefined,
|
|
396
|
+
}));
|
|
397
|
+
}
|
|
398
|
+
|
|
399
|
+
/**
|
|
400
|
+
* Get baseline benchmark for comparison
|
|
401
|
+
*/
|
|
402
|
+
private static async getBaselineBenchmark(
|
|
403
|
+
benchmarkPath: string
|
|
404
|
+
): Promise<SimulationMetrics | null> {
|
|
405
|
+
try {
|
|
406
|
+
// Look for baseline result for this benchmark
|
|
407
|
+
const baselinesDir = path.join(process.cwd(), 'benchmarks', 'baselines');
|
|
408
|
+
const files = await fs.readdir(baselinesDir).catch(() => []);
|
|
409
|
+
|
|
410
|
+
for (const file of files) {
|
|
411
|
+
if (file.endsWith('.json')) {
|
|
412
|
+
const filePath = path.join(baselinesDir, file);
|
|
413
|
+
const data = JSON.parse(await fs.readFile(filePath, 'utf-8'));
|
|
414
|
+
|
|
415
|
+
if (
|
|
416
|
+
data.benchmark?.path === benchmarkPath ||
|
|
417
|
+
data.benchmark === benchmarkPath
|
|
418
|
+
) {
|
|
419
|
+
return data.metrics;
|
|
420
|
+
}
|
|
421
|
+
}
|
|
422
|
+
}
|
|
423
|
+
} catch (error) {
|
|
424
|
+
logger.warn('Could not load baseline benchmark', { error });
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
return null;
|
|
428
|
+
}
|
|
429
|
+
|
|
430
|
+
/**
|
|
431
|
+
* Calculate average metrics across multiple benchmark results
|
|
432
|
+
*/
|
|
433
|
+
private static calculateAverageMetrics(
|
|
434
|
+
metricsArray: SimulationMetrics[]
|
|
435
|
+
): AverageMetrics {
|
|
436
|
+
if (metricsArray.length === 0) {
|
|
437
|
+
return {
|
|
438
|
+
totalPnl: 0,
|
|
439
|
+
accuracy: 0,
|
|
440
|
+
winRate: 0,
|
|
441
|
+
optimality: 0,
|
|
442
|
+
benchmarkCount: 0,
|
|
443
|
+
};
|
|
444
|
+
}
|
|
445
|
+
|
|
446
|
+
const totals = metricsArray.reduce(
|
|
447
|
+
(acc, metrics) => ({
|
|
448
|
+
pnl: acc.pnl + metrics.totalPnl,
|
|
449
|
+
accuracy: acc.accuracy + metrics.predictionMetrics.accuracy,
|
|
450
|
+
winRate: acc.winRate + metrics.perpMetrics.winRate,
|
|
451
|
+
optimality: acc.optimality + metrics.optimalityScore,
|
|
452
|
+
}),
|
|
453
|
+
{ pnl: 0, accuracy: 0, winRate: 0, optimality: 0 }
|
|
454
|
+
);
|
|
455
|
+
|
|
456
|
+
const count = metricsArray.length;
|
|
457
|
+
|
|
458
|
+
return {
|
|
459
|
+
totalPnl: totals.pnl / count,
|
|
460
|
+
accuracy: totals.accuracy / count,
|
|
461
|
+
winRate: totals.winRate / count,
|
|
462
|
+
optimality: totals.optimality / count,
|
|
463
|
+
benchmarkCount: count,
|
|
464
|
+
};
|
|
465
|
+
}
|
|
466
|
+
|
|
467
|
+
/**
|
|
468
|
+
* Get baseline average metrics
|
|
469
|
+
*/
|
|
470
|
+
private static async getBaselineAverageMetrics(): Promise<AverageMetrics> {
|
|
471
|
+
const baselinesDir = path.join(process.cwd(), 'benchmarks', 'baselines');
|
|
472
|
+
const metricsArray: SimulationMetrics[] = [];
|
|
473
|
+
|
|
474
|
+
try {
|
|
475
|
+
const files = await fs.readdir(baselinesDir).catch(() => []);
|
|
476
|
+
|
|
477
|
+
for (const file of files) {
|
|
478
|
+
if (file.endsWith('.json')) {
|
|
479
|
+
const filePath = path.join(baselinesDir, file);
|
|
480
|
+
const data = JSON.parse(await fs.readFile(filePath, 'utf-8'));
|
|
481
|
+
|
|
482
|
+
if (data.metrics) {
|
|
483
|
+
metricsArray.push(data.metrics);
|
|
484
|
+
}
|
|
485
|
+
}
|
|
486
|
+
}
|
|
487
|
+
} catch (error) {
|
|
488
|
+
logger.warn('Could not load baseline metrics', { error });
|
|
489
|
+
}
|
|
490
|
+
|
|
491
|
+
return this.calculateAverageMetrics(metricsArray);
|
|
492
|
+
}
|
|
493
|
+
|
|
494
|
+
/**
|
|
495
|
+
* Get or create test agent for benchmarking
|
|
496
|
+
*/
|
|
497
|
+
private static async getOrCreateTestAgent(): Promise<string> {
|
|
498
|
+
const testAgentUsername = 'model-benchmark-agent';
|
|
499
|
+
const adapter = getTrainingDataAdapter();
|
|
500
|
+
|
|
501
|
+
const existing = await adapter.getUserByUsername(testAgentUsername);
|
|
502
|
+
|
|
503
|
+
if (existing) {
|
|
504
|
+
return existing.id;
|
|
505
|
+
}
|
|
506
|
+
|
|
507
|
+
// Create new test agent
|
|
508
|
+
const agentId = await generateSnowflakeId();
|
|
509
|
+
const agent = await adapter.createUser({
|
|
510
|
+
id: agentId,
|
|
511
|
+
privyId: `did:privy:model-benchmark-${agentId}`,
|
|
512
|
+
username: testAgentUsername,
|
|
513
|
+
displayName: 'Model Benchmark Agent',
|
|
514
|
+
walletAddress: ethers.Wallet.createRandom().address,
|
|
515
|
+
isAgent: true,
|
|
516
|
+
virtualBalance: '10000',
|
|
517
|
+
reputationPoints: 1000,
|
|
518
|
+
isTest: true,
|
|
519
|
+
updatedAt: new Date(),
|
|
520
|
+
});
|
|
521
|
+
|
|
522
|
+
// Create agent config in separate table
|
|
523
|
+
if (agent) {
|
|
524
|
+
await adapter.createAgentConfig({
|
|
525
|
+
id: await generateSnowflakeId(),
|
|
526
|
+
userId: agentId,
|
|
527
|
+
autonomousTrading: true,
|
|
528
|
+
autonomousPosting: false,
|
|
529
|
+
autonomousCommenting: false,
|
|
530
|
+
systemPrompt:
|
|
531
|
+
'You are a test agent for benchmarking model performance.',
|
|
532
|
+
modelTier: 'pro',
|
|
533
|
+
updatedAt: new Date(),
|
|
534
|
+
});
|
|
535
|
+
}
|
|
536
|
+
|
|
537
|
+
if (!agent) {
|
|
538
|
+
throw new Error('Failed to create model benchmark test agent');
|
|
539
|
+
}
|
|
540
|
+
|
|
541
|
+
logger.info('Created model benchmark test agent', { agentId: agent.id });
|
|
542
|
+
|
|
543
|
+
return agent.id;
|
|
544
|
+
}
|
|
545
|
+
|
|
546
|
+
/**
|
|
547
|
+
* Get standard benchmark paths for model evaluation
|
|
548
|
+
*/
|
|
549
|
+
static async getStandardBenchmarkPaths(): Promise<string[]> {
|
|
550
|
+
const benchmarksDir = path.join(process.cwd(), 'benchmarks');
|
|
551
|
+
const standardBenchmarks: string[] = [];
|
|
552
|
+
|
|
553
|
+
try {
|
|
554
|
+
// First, look in benchmarks/standard/ directory
|
|
555
|
+
const standardDir = path.join(benchmarksDir, 'standard');
|
|
556
|
+
if (
|
|
557
|
+
await fs
|
|
558
|
+
.access(standardDir)
|
|
559
|
+
.then(() => true)
|
|
560
|
+
.catch(() => false)
|
|
561
|
+
) {
|
|
562
|
+
const standardFiles = await fs.readdir(standardDir);
|
|
563
|
+
for (const file of standardFiles) {
|
|
564
|
+
if (file.startsWith('standard-') && file.endsWith('.json')) {
|
|
565
|
+
standardBenchmarks.push(path.join(standardDir, file));
|
|
566
|
+
}
|
|
567
|
+
}
|
|
568
|
+
}
|
|
569
|
+
|
|
570
|
+
// If standard benchmarks found, use those
|
|
571
|
+
if (standardBenchmarks.length > 0) {
|
|
572
|
+
logger.info(
|
|
573
|
+
`Using ${standardBenchmarks.length} standard benchmarks from benchmarks/standard/`
|
|
574
|
+
);
|
|
575
|
+
return standardBenchmarks;
|
|
576
|
+
}
|
|
577
|
+
|
|
578
|
+
// Fallback: Look for week-long benchmarks in main directory
|
|
579
|
+
const files = await fs.readdir(benchmarksDir);
|
|
580
|
+
for (const file of files) {
|
|
581
|
+
if (file.startsWith('benchmark-week-') && file.endsWith('.json')) {
|
|
582
|
+
standardBenchmarks.push(path.join(benchmarksDir, file));
|
|
583
|
+
}
|
|
584
|
+
}
|
|
585
|
+
|
|
586
|
+
// If still nothing, use any benchmark files
|
|
587
|
+
if (standardBenchmarks.length === 0) {
|
|
588
|
+
for (const file of files) {
|
|
589
|
+
if (
|
|
590
|
+
file.startsWith('benchmark-') &&
|
|
591
|
+
file.endsWith('.json') &&
|
|
592
|
+
!file.includes('comparison')
|
|
593
|
+
) {
|
|
594
|
+
const filePath = path.join(benchmarksDir, file);
|
|
595
|
+
standardBenchmarks.push(filePath);
|
|
596
|
+
}
|
|
597
|
+
}
|
|
598
|
+
}
|
|
599
|
+
} catch (error) {
|
|
600
|
+
logger.error('Could not load standard benchmarks', { error });
|
|
601
|
+
}
|
|
602
|
+
|
|
603
|
+
if (standardBenchmarks.length === 0) {
|
|
604
|
+
logger.warn(
|
|
605
|
+
'No standard benchmarks found. Generate benchmark fixtures before upload.'
|
|
606
|
+
);
|
|
607
|
+
}
|
|
608
|
+
|
|
609
|
+
return standardBenchmarks;
|
|
610
|
+
}
|
|
611
|
+
}
|