@elizaos/training 2.0.0-alpha.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (207)
  1. package/Dockerfile +75 -0
  2. package/Makefile +374 -0
  3. package/README.md +346 -0
  4. package/config/rubrics.json +137 -0
  5. package/data/.gitkeep +0 -0
  6. package/data/degen/.gitkeep +2 -0
  7. package/data/trader/.gitkeep +2 -0
  8. package/docker-compose.test.yml +57 -0
  9. package/package.json +58 -0
  10. package/python/config/babylon_atropos.yaml +90 -0
  11. package/python/config/profiles/12gb.json +11 -0
  12. package/python/config/profiles/16gb.json +10 -0
  13. package/python/config/profiles/24gb.json +10 -0
  14. package/python/config/profiles/48gb.json +10 -0
  15. package/python/config/profiles/cpu.json +11 -0
  16. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  17. package/python/config/profiles/l40-2gpu.json +22 -0
  18. package/python/config/profiles/l40-4gpu.json +21 -0
  19. package/python/config/profiles/l40.json +17 -0
  20. package/python/config/tinker_training.yaml +143 -0
  21. package/python/curriculum_state.json +165 -0
  22. package/python/env.template +86 -0
  23. package/python/env.training.template +46 -0
  24. package/python/pyproject.toml +41 -0
  25. package/python/requirements-ci.txt +31 -0
  26. package/python/requirements.txt +87 -0
  27. package/python/scripts/__init__.py +4 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/run_ab_test.py +143 -0
  36. package/python/scripts/run_full_pipeline.py +544 -0
  37. package/python/scripts/run_tinker_training.py +192 -0
  38. package/python/scripts/run_training.py +914 -0
  39. package/python/scripts/test_judge.py +155 -0
  40. package/python/scripts/test_pipeline.py +356 -0
  41. package/python/scripts/test_trained_model.py +380 -0
  42. package/python/scripts/train_local.py +528 -0
  43. package/python/setup.py +20 -0
  44. package/python/src/__init__.py +190 -0
  45. package/python/src/data_bridge/__init__.py +24 -0
  46. package/python/src/data_bridge/converter.py +435 -0
  47. package/python/src/data_bridge/reader.py +393 -0
  48. package/python/src/models.py +283 -0
  49. package/python/src/training/__init__.py +605 -0
  50. package/python/src/training/ab_testing.py +404 -0
  51. package/python/src/training/action_executor.py +621 -0
  52. package/python/src/training/archetype_trainer.py +347 -0
  53. package/python/src/training/atropos_trainer.py +980 -0
  54. package/python/src/training/babylon_env.py +1254 -0
  55. package/python/src/training/error_recovery.py +647 -0
  56. package/python/src/training/evaluation.py +856 -0
  57. package/python/src/training/fast_simulator.py +880 -0
  58. package/python/src/training/format_validator.py +584 -0
  59. package/python/src/training/hybrid_env.py +522 -0
  60. package/python/src/training/kl_controller.py +628 -0
  61. package/python/src/training/multi_prompt_dataset.py +883 -0
  62. package/python/src/training/multi_turn.py +656 -0
  63. package/python/src/training/online_env.py +1084 -0
  64. package/python/src/training/quality_scorer.py +391 -0
  65. package/python/src/training/quality_utils.py +633 -0
  66. package/python/src/training/rewards.py +1344 -0
  67. package/python/src/training/rlaif_env.py +17 -0
  68. package/python/src/training/rollout_generator.py +502 -0
  69. package/python/src/training/rubric_loader.py +198 -0
  70. package/python/src/training/scenario_pool.py +1072 -0
  71. package/python/src/training/schemas.py +481 -0
  72. package/python/src/training/service_manager.py +552 -0
  73. package/python/src/training/simulation_bridge.py +535 -0
  74. package/python/src/training/tick_reward_attribution.py +399 -0
  75. package/python/src/training/tinker_client.py +575 -0
  76. package/python/src/training/tinker_trainer.py +646 -0
  77. package/python/src/training/tokenization_utils.py +402 -0
  78. package/python/tests/e2e/__init__.py +13 -0
  79. package/python/tests/e2e/conftest.py +258 -0
  80. package/python/tests/e2e/test_full_pipeline.py +643 -0
  81. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  82. package/python/tests/integration/__init__.py +12 -0
  83. package/python/tests/integration/conftest.py +383 -0
  84. package/python/tests/integration/test_db_integration.py +649 -0
  85. package/python/tests/integration/test_json_mode_integration.py +554 -0
  86. package/python/tests/test_action_executor.py +594 -0
  87. package/python/tests/test_archetype_scoring.py +1027 -0
  88. package/python/tests/test_atropos_integration.py +360 -0
  89. package/python/tests/test_evaluation.py +727 -0
  90. package/python/tests/test_format_validator.py +486 -0
  91. package/python/tests/test_kl_controller.py +432 -0
  92. package/python/tests/test_lr_scheduler.py +579 -0
  93. package/python/tests/test_multi_turn.py +590 -0
  94. package/python/tests/test_online_env.py +519 -0
  95. package/python/tests/test_quality_scorer.py +474 -0
  96. package/python/tests/test_scenario_pool.py +735 -0
  97. package/python/tests/test_service_manager.py +585 -0
  98. package/python/tests/test_simulation_rollout.py +581 -0
  99. package/python/tests/test_tokenization_utils.py +501 -0
  100. package/python/tests/test_training_orchestrator.py +497 -0
  101. package/python/tests/test_training_output_structure.py +661 -0
  102. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  103. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  104. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  105. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  106. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  107. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  108. package/scripts/assess-training-data.ts +422 -0
  109. package/scripts/e2e-training-test.ts +550 -0
  110. package/scripts/export-rubrics.ts +64 -0
  111. package/scripts/generate-research-report.ts +1523 -0
  112. package/scripts/generate_dataset.sh +173 -0
  113. package/scripts/json-mode-benchmark.ts +399 -0
  114. package/scripts/real-archetype-benchmark.ts +210 -0
  115. package/scripts/run-baseline-comparison.ts +116 -0
  116. package/scripts/run-full-pipeline.ts +272 -0
  117. package/scripts/runpod_setup.sh +137 -0
  118. package/scripts/runpod_validate.sh +147 -0
  119. package/scripts/test-model-in-game.ts +955 -0
  120. package/scripts/test-scoring.ts +73 -0
  121. package/scripts/test-trained-model.ts +209 -0
  122. package/scripts/train-and-test.ts +824 -0
  123. package/scripts/verify-final.ts +118 -0
  124. package/src/adapter.ts +516 -0
  125. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  126. package/src/archetypes/derive-archetype.ts +249 -0
  127. package/src/archetypes/index.ts +22 -0
  128. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  129. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  130. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  131. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  132. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  133. package/src/benchmark/BenchmarkRunner.ts +685 -0
  134. package/src/benchmark/BenchmarkValidator.ts +206 -0
  135. package/src/benchmark/FastEvalRunner.ts +225 -0
  136. package/src/benchmark/MetricsValidator.ts +165 -0
  137. package/src/benchmark/MetricsVisualizer.ts +909 -0
  138. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  139. package/src/benchmark/ModelRegistry.ts +158 -0
  140. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  141. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  142. package/src/benchmark/SimulationEngine.ts +832 -0
  143. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  144. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  145. package/src/benchmark/index.ts +89 -0
  146. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  147. package/src/benchmark/simulation-types.ts +78 -0
  148. package/src/dependencies.ts +439 -0
  149. package/src/generation/TrajectoryGenerator.ts +387 -0
  150. package/src/generation/index.ts +12 -0
  151. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  152. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  153. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  154. package/src/huggingface/index.ts +27 -0
  155. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  156. package/src/index.ts +102 -0
  157. package/src/init-training.ts +53 -0
  158. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  159. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  160. package/src/metrics/index.ts +8 -0
  161. package/src/metrics/types.ts +200 -0
  162. package/src/rubrics/__tests__/index.test.ts +184 -0
  163. package/src/rubrics/ass-kisser.ts +85 -0
  164. package/src/rubrics/degen.ts +80 -0
  165. package/src/rubrics/goody-twoshoes.ts +84 -0
  166. package/src/rubrics/index.ts +236 -0
  167. package/src/rubrics/information-trader.ts +84 -0
  168. package/src/rubrics/infosec.ts +101 -0
  169. package/src/rubrics/liar.ts +104 -0
  170. package/src/rubrics/perps-trader.ts +87 -0
  171. package/src/rubrics/researcher.ts +81 -0
  172. package/src/rubrics/scammer.ts +82 -0
  173. package/src/rubrics/social-butterfly.ts +73 -0
  174. package/src/rubrics/super-predictor.ts +97 -0
  175. package/src/rubrics/trader.ts +67 -0
  176. package/src/scoring/ArchetypeScoringService.ts +486 -0
  177. package/src/scoring/JudgePromptBuilder.ts +556 -0
  178. package/src/scoring/LLMJudgeCache.ts +401 -0
  179. package/src/scoring/index.ts +9 -0
  180. package/src/training/AutomationPipeline.ts +916 -0
  181. package/src/training/BenchmarkService.ts +518 -0
  182. package/src/training/ConfigValidator.ts +220 -0
  183. package/src/training/MarketOutcomesTracker.ts +187 -0
  184. package/src/training/ModelDeployer.ts +186 -0
  185. package/src/training/ModelFetcher.ts +76 -0
  186. package/src/training/ModelSelectionService.ts +341 -0
  187. package/src/training/ModelUsageVerifier.ts +160 -0
  188. package/src/training/MultiModelOrchestrator.ts +580 -0
  189. package/src/training/RLModelConfig.ts +407 -0
  190. package/src/training/RewardBackpropagationService.ts +149 -0
  191. package/src/training/RulerScoringService.ts +666 -0
  192. package/src/training/TrainingMonitor.ts +166 -0
  193. package/src/training/TrajectoryRecorder.ts +399 -0
  194. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  195. package/src/training/index.ts +100 -0
  196. package/src/training/logRLConfig.ts +34 -0
  197. package/src/training/pipeline.ts +129 -0
  198. package/src/training/storage/ModelStorageService.ts +279 -0
  199. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  200. package/src/training/storage/index.ts +17 -0
  201. package/src/training/types.ts +207 -0
  202. package/src/training/window-utils.ts +138 -0
  203. package/src/utils/index.ts +101 -0
  204. package/src/utils/logger.ts +59 -0
  205. package/src/utils/snowflake.ts +17 -0
  206. package/src/utils/synthetic-detector.ts +111 -0
  207. package/tsconfig.json +20 -0
@@ -0,0 +1,324 @@
1
+ /**
2
+ * Benchmark Data Viewer
3
+ *
4
+ * Provides utilities to view and inspect benchmark data.
5
+ * Useful for validation and understanding benchmark structure.
6
+ */
7
+
8
+ import type { JsonValue } from '../adapter';
9
+ import { promises as fs } from 'fs';
10
+ import type {
11
+ BenchmarkGameSnapshot,
12
+ GameState,
13
+ GroundTruth,
14
+ Tick,
15
+ } from './BenchmarkDataGenerator';
16
+ import { BenchmarkValidator } from './BenchmarkValidator';
17
+
18
/**
 * Switches accepted by `BenchmarkDataViewer.view` and
 * `BenchmarkDataViewer.print`. Every field is optional; an empty object
 * yields the default (non-verbose, no ground truth) output.
 */
export interface BenchmarkViewOptions {
  /**
   * When set, `view` attaches the ground-truth summary and `print` lists
   * per-type event counts plus each validation error/warning message.
   */
  verbose?: boolean;

  /**
   * Request a summary-only view.
   * NOTE(review): not read by the viewer code in this file — confirm whether
   * a caller elsewhere relies on it.
   */
  summary?: boolean;

  /** When set, `view` attaches the ground-truth summary to the result. */
  showGroundTruth?: boolean;

  /**
   * When set, `print` also emits the hidden fact/event counts and the list
   * of true-fact keys from the ground-truth summary.
   */
  showHidden?: boolean;

  /**
   * Restrict the view to a range of ticks.
   * NOTE(review): not read by the viewer code in this file; whether `end` is
   * inclusive is therefore unspecified — confirm before relying on it.
   */
  tickRange?: { start: number; end: number };
}
34
+
35
/**
 * Condensed, count-oriented description of a benchmark snapshot, as produced
 * by `BenchmarkDataViewer.view` and rendered by `BenchmarkDataViewer.print`.
 */
export interface BenchmarkView {
  /** Identity and timing fields copied verbatim from the snapshot. */
  id: string;
  version: string;
  /** Passed to `new Date(...)` when printed. */
  createdAt: number;
  /** Printed as `duration / 60` minutes — presumably seconds; confirm. */
  duration: number;
  /** Printed with an `s` suffix — presumably seconds; confirm. */
  tickInterval: number;

  /** Element counts of the collections in the snapshot's initial state. */
  initialState: {
    predictionMarkets: number;
    perpetualMarkets: number;
    agents: number;
    posts: number;
    groupChats: number;
  };

  /** Aggregate figures over the snapshot's ticks. */
  ticks: {
    /** Number of ticks in the snapshot. */
    total: number;
    /** Number of ticks carrying at least one event. */
    withEvents: number;
    /** Event type → number of occurrences across all ticks. */
    eventTypes: Record<string, number>;
  };

  /** Ground-truth counts; only present when the caller asked for them. */
  groundTruth?: {
    marketOutcomes: number;
    /** Ticker → number of recorded price-history entries. */
    priceHistory: Record<string, number>;
    optimalActions: number;
    socialOpportunities: number;
    hiddenFacts: number;
    hiddenEvents: number;
    /** Keys of the ground truth's `trueFacts` object. */
    trueFacts: string[];
  };

  /** Outcome of running the benchmark validator over the snapshot. */
  validation: {
    valid: boolean;
    errors: string[];
    warnings: string[];
  };
}
77
+
78
/**
 * Static helpers for inspecting a benchmark snapshot: building a
 * count-oriented summary, printing it, and looking up per-tick details.
 *
 * NOTE(review): `BenchmarkViewOptions.summary` and
 * `BenchmarkViewOptions.tickRange` are accepted but not consulted by any
 * method here; their intended semantics are not visible in this file.
 */
export class BenchmarkDataViewer {
  /**
   * Load a benchmark file, validate it and build a summary view.
   *
   * @param filePath Path of a JSON file holding a benchmark snapshot.
   * @param options  `showGroundTruth` or `verbose` attaches the ground-truth
   *                 summary, provided the snapshot actually carries one.
   * @returns The summary together with the validator's result.
   * @throws If the file cannot be read or does not contain valid JSON.
   */
  static async view(
    filePath: string,
    options: BenchmarkViewOptions = {}
  ): Promise<BenchmarkView> {
    const raw = await fs.readFile(filePath, 'utf-8');
    const snapshot = JSON.parse(raw) as BenchmarkGameSnapshot;

    const validation = BenchmarkValidator.validate(snapshot);

    const view: BenchmarkView = {
      id: snapshot.id,
      version: snapshot.version,
      createdAt: snapshot.createdAt,
      duration: snapshot.duration,
      tickInterval: snapshot.tickInterval,

      initialState: {
        predictionMarkets: snapshot.initialState.predictionMarkets.length,
        perpetualMarkets: snapshot.initialState.perpetualMarkets.length,
        agents: snapshot.initialState.agents.length,
        posts: snapshot.initialState.posts?.length ?? 0,
        groupChats: snapshot.initialState.groupChats?.length ?? 0,
      },

      ticks: BenchmarkDataViewer.analyzeTicks(snapshot.ticks),

      validation,
    };

    // The file content is unvalidated JSON, so ground truth may be absent
    // (verifyAgentCannotAccessHiddenFacts already allows for that). Leave the
    // summary off in that case instead of failing with a TypeError.
    const wantsGroundTruth = Boolean(options.showGroundTruth || options.verbose);
    if (wantsGroundTruth && snapshot.groundTruth) {
      view.groundTruth = BenchmarkDataViewer.analyzeGroundTruth(
        snapshot.groundTruth
      );
    }

    return view;
  }

  /**
   * Count ticks, ticks that carry events, and occurrences of each event type.
   */
  private static analyzeTicks(ticks: Tick[]): BenchmarkView['ticks'] {
    // Tally in a Map: with a plain `{}` an event type named like an
    // Object.prototype member (e.g. "constructor") would pick up the inherited
    // value instead of starting from zero.
    const counts = new Map<string, number>();
    let withEvents = 0;

    for (const tick of ticks) {
      if (tick.events.length > 0) {
        withEvents++;
      }

      for (const event of tick.events) {
        counts.set(event.type, (counts.get(event.type) ?? 0) + 1);
      }
    }

    return {
      total: ticks.length,
      withEvents,
      eventTypes: Object.fromEntries(counts),
    };
  }

  /**
   * Reduce the ground-truth record to counts (and the list of true-fact keys).
   */
  private static analyzeGroundTruth(
    groundTruth: GroundTruth
  ): BenchmarkView['groundTruth'] {
    const priceHistory = Object.fromEntries(
      Object.entries(groundTruth.priceHistory).map(([ticker, history]) => [
        ticker,
        history.length,
      ])
    );

    return {
      marketOutcomes: Object.keys(groundTruth.marketOutcomes).length,
      priceHistory,
      optimalActions: groundTruth.optimalActions.length,
      socialOpportunities: groundTruth.socialOpportunities.length,
      hiddenFacts: groundTruth.hiddenFacts?.length ?? 0,
      hiddenEvents: groundTruth.hiddenEvents?.length ?? 0,
      trueFacts: Object.keys(groundTruth.trueFacts ?? {}),
    };
  }

  /**
   * Write a human-readable rendering of `view` to the console.
   *
   * @param view    Summary produced by {@link BenchmarkDataViewer.view}.
   * @param options `verbose` adds per-type event counts and the individual
   *                validation messages; `showHidden` adds the hidden
   *                fact/event figures of the ground-truth section.
   */
  static print(view: BenchmarkView, options: BenchmarkViewOptions = {}): void {
    // toISOString() throws a RangeError on an invalid date; a viewer meant for
    // inspecting possibly broken benchmarks should still print the rest.
    const created = new Date(view.createdAt);
    const createdLabel = Number.isNaN(created.getTime())
      ? 'Invalid Date'
      : created.toISOString();

    console.log('\n📊 Benchmark Data View\n');
    console.log(`ID: ${view.id}`);
    console.log(`Version: ${view.version}`);
    console.log(`Created: ${createdLabel}`);
    console.log(`Duration: ${(view.duration / 60).toFixed(1)} minutes`);
    console.log(`Tick Interval: ${view.tickInterval}s`);

    console.log('\n📈 Initial State:');
    console.log(` Prediction Markets: ${view.initialState.predictionMarkets}`);
    console.log(` Perpetual Markets: ${view.initialState.perpetualMarkets}`);
    console.log(` Agents: ${view.initialState.agents}`);
    console.log(` Posts: ${view.initialState.posts}`);
    console.log(` Group Chats: ${view.initialState.groupChats}`);

    console.log('\n⏱️ Ticks:');
    console.log(` Total: ${view.ticks.total}`);
    console.log(` With Events: ${view.ticks.withEvents}`);
    if (options.verbose) {
      console.log(` Event Types:`);
      for (const [type, count] of Object.entries(view.ticks.eventTypes)) {
        console.log(` ${type}: ${count}`);
      }
    }

    if (view.groundTruth) {
      console.log('\n🎯 Ground Truth:');
      console.log(` Market Outcomes: ${view.groundTruth.marketOutcomes}`);
      console.log(` Price History:`);
      for (const [ticker, count] of Object.entries(
        view.groundTruth.priceHistory
      )) {
        console.log(` ${ticker}: ${count} ticks`);
      }
      console.log(` Optimal Actions: ${view.groundTruth.optimalActions}`);
      console.log(
        ` Social Opportunities: ${view.groundTruth.socialOpportunities}`
      );
      if (options.showHidden) {
        console.log(` Hidden Facts: ${view.groundTruth.hiddenFacts}`);
        console.log(` Hidden Events: ${view.groundTruth.hiddenEvents}`);
        console.log(` True Facts: ${view.groundTruth.trueFacts.join(', ')}`);
      }
    }

    console.log('\n✅ Validation:');
    console.log(` Valid: ${view.validation.valid ? '✅' : '❌'}`);
    if (view.validation.errors.length > 0) {
      console.log(` Errors: ${view.validation.errors.length}`);
      if (options.verbose) {
        for (const error of view.validation.errors) {
          console.log(` ❌ ${error}`);
        }
      }
    }
    if (view.validation.warnings.length > 0) {
      console.log(` Warnings: ${view.validation.warnings.length}`);
      if (options.verbose) {
        for (const warning of view.validation.warnings) {
          console.log(` ⚠️ ${warning}`);
        }
      }
    }

    console.log('');
  }

  /**
   * Look up one tick by its position in `snapshot.ticks`.
   *
   * @param snapshot   Snapshot to read from.
   * @param tickNumber Index into `snapshot.ticks`.
   * @returns The tick, its state and a `{ type, data }` projection of its
   *          events; all-empty (`null`, `null`, `[]`) when no tick exists at
   *          that index.
   */
  static getTickDetails(
    snapshot: BenchmarkGameSnapshot,
    tickNumber: number
  ): {
    tick: Tick | null;
    state: GameState | null;
    events: Array<{ type: string; data: Record<string, JsonValue> }>;
  } {
    const tick = snapshot.ticks[tickNumber] ?? null;

    if (!tick) {
      return { tick: null, state: null, events: [] };
    }

    const events = tick.events.map((event) => ({
      type: event.type,
      data: event.data,
    }));

    return { tick, state: tick.state, events };
  }

  /**
   * Collect the ground-truth entries recorded for a given tick.
   *
   * Hidden facts and events are filtered by their `tick` field;
   * `marketOutcomes` is returned in full (it is not tick-scoped here). A
   * snapshot without ground truth yields empty collections.
   */
  static getGroundTruthForTick(
    snapshot: BenchmarkGameSnapshot,
    tickNumber: number
  ): {
    hiddenFacts: Array<{ fact: string; category: string }>;
    hiddenEvents: Array<{ type: string; description: string }>;
    marketOutcomes: Record<string, boolean>;
  } {
    // Ground truth may be missing from a hand-edited or partial file; treat
    // that as "nothing recorded" rather than throwing.
    const gt = snapshot.groundTruth;

    return {
      hiddenFacts: (gt?.hiddenFacts ?? [])
        .filter((f) => f.tick === tickNumber)
        .map((f) => ({ fact: f.fact, category: f.category })),
      hiddenEvents: (gt?.hiddenEvents ?? [])
        .filter((e) => e.tick === tickNumber)
        .map((e) => ({ type: e.type, description: e.description })),
      marketOutcomes: gt?.marketOutcomes ?? {},
    };
  }

  /**
   * Check that ground-truth data has not leaked into agent-visible game state
   * (`canAccess` should always come back `false`).
   *
   * Both the initial state and every tick's `state` are inspected for the
   * keys `groundTruth`, `hiddenFacts` and `hiddenEvents`.
   */
  static verifyAgentCannotAccessHiddenFacts(snapshot: BenchmarkGameSnapshot): {
    canAccess: boolean;
    reason: string;
  } {
    // Agents can only access game state via SimulationA2AInterface.
    // Ground truth is stored separately and must not appear inside any state.
    const leakKeys = ['groundTruth', 'hiddenFacts', 'hiddenEvents'];

    // Per-tick states are game state too, so a leak there matters just as
    // much as one in the initial state.
    const states: Array<GameState | null | undefined> = [
      snapshot.initialState,
      ...(snapshot.ticks ?? []).map((tick) => tick?.state),
    ];

    const hasGroundTruthInState = states.some((state) => {
      if (state === null || state === undefined) {
        return false;
      }
      const stateKeys = Object.keys(state);
      return leakKeys.some((key) => stateKeys.includes(key));
    });

    if (hasGroundTruthInState) {
      return {
        canAccess: true,
        reason: 'Ground truth found in game state (security issue!)',
      };
    }

    const hasGroundTruth = !!snapshot.groundTruth;
    const hasHiddenFacts = !!snapshot.groundTruth?.hiddenFacts?.length;

    return {
      canAccess: false,
      reason:
        hasGroundTruth && hasHiddenFacts
          ? 'Ground truth exists but is properly isolated from game state'
          : 'No ground truth data found',
    };
  }
}
@@ -0,0 +1,221 @@
1
+ /**
2
+ * Benchmark History Service
3
+ *
4
+ * Persists benchmark results to the database for historical tracking and analysis.
5
+ */
6
+
7
+ import {
8
+ getTrainingDataAdapter,
9
+ type BenchmarkResultRecord,
10
+ type JsonValue,
11
+ } from '../adapter';
12
+ import { logger } from '../utils/logger';
13
+ import { generateSnowflakeId } from '../utils/snowflake';
14
+ import type { SimulationMetrics } from './SimulationEngine';
15
+
16
/**
 * Everything `BenchmarkHistoryService.saveResult` needs to persist one
 * benchmark run.
 */
export interface BenchmarkResultInput {
  /** Model that was benchmarked. */
  modelId: string;
  /** Benchmark the model was run against. */
  benchmarkId: string;
  /** Stored as-is alongside the result. */
  benchmarkPath: string;
  /**
   * Simulation metrics for the run. Headline figures are copied into
   * dedicated columns and the whole object is stored as detailed metrics.
   */
  metrics: SimulationMetrics;
  /** Run duration; unit is not established in this file — confirm. */
  duration: number;
  /**
   * Optional comparison against a baseline run. When omitted, the three
   * baseline-related fields are stored as `null`.
   */
  baselineComparison?: {
    pnlDelta: number;
    accuracyDelta: number;
    improved: boolean;
  };
}
28
+
29
/**
 * Filter accepted by `BenchmarkHistoryService.getResults`. All fields are
 * optional and are forwarded to the data adapter's benchmark-result query.
 */
export interface BenchmarkHistoryQuery {
  /** Restrict to one model. */
  modelId?: string;
  /** Restrict to one benchmark. */
  benchmarkId?: string;
  /** Lower date bound; inclusivity is decided by the adapter — confirm. */
  startDate?: Date;
  /** Upper date bound; inclusivity is decided by the adapter — confirm. */
  endDate?: Date;
  /** Maximum number of rows; `getResults` falls back to 100 when omitted. */
  limit?: number;
}
36
+
37
/**
 * Time series for one model as returned by
 * `BenchmarkHistoryService.getTrendData`. The four arrays are parallel:
 * index `i` of each refers to the same benchmark run.
 */
export interface BenchmarkTrendData {
  /** Model the series belongs to. */
  modelId: string;
  /** `runAt` of each run. */
  dates: Date[];
  /** `totalPnl` of each run. */
  pnlHistory: number[];
  /** `predictionAccuracy` of each run. */
  accuracyHistory: number[];
  /** `optimalityScore` of each run. */
  optimalityHistory: number[];
}
44
+
45
/**
 * Persists benchmark results through the training data adapter and offers
 * read helpers (latest result, trends, cross-model comparison, summaries).
 */
export class BenchmarkHistoryService {
  /**
   * Store one benchmark run.
   *
   * @param input Run description; see {@link BenchmarkResultInput}.
   * @returns The stored record, with `createdAt` equal to `runAt`.
   */
  static async saveResult(
    input: BenchmarkResultInput
  ): Promise<BenchmarkResultRecord> {
    const id = await generateSnowflakeId();
    const now = new Date();
    const { metrics, baselineComparison } = input;

    const insertData = {
      id,
      modelId: input.modelId,
      benchmarkId: input.benchmarkId,
      benchmarkPath: input.benchmarkPath,
      runAt: now,
      totalPnl: metrics.totalPnl,
      predictionAccuracy: metrics.predictionMetrics.accuracy,
      perpWinRate: metrics.perpMetrics.winRate,
      optimalityScore: metrics.optimalityScore,
      // JSON round-trip: stores a detached, JSON-compatible copy of the metrics.
      detailedMetrics: JSON.parse(JSON.stringify(metrics)) as JsonValue,
      baselinePnlDelta: baselineComparison?.pnlDelta ?? null,
      baselineAccuracyDelta: baselineComparison?.accuracyDelta ?? null,
      improved: baselineComparison?.improved ?? null,
      duration: input.duration,
    };

    await getTrainingDataAdapter().insertBenchmarkResult(insertData);

    logger.info('Saved benchmark result', {
      id,
      modelId: input.modelId,
      benchmarkId: input.benchmarkId,
      totalPnl: metrics.totalPnl,
    });

    return { ...insertData, createdAt: now };
  }

  /**
   * Query stored results.
   *
   * @param query Optional filters; `limit` defaults to 100.
   */
  static async getResults(
    query: BenchmarkHistoryQuery
  ): Promise<BenchmarkResultRecord[]> {
    return getTrainingDataAdapter().queryBenchmarkResults({
      modelId: query.modelId,
      benchmarkId: query.benchmarkId,
      startDate: query.startDate,
      endDate: query.endDate,
      limit: query.limit ?? 100,
    });
  }

  /**
   * Fetch the first result the adapter returns for a model with `limit: 1`.
   *
   * @returns That record, or `null` when the model has no results.
   */
  static async getLatestResult(
    modelId: string
  ): Promise<BenchmarkResultRecord | null> {
    const [latest] = await getTrainingDataAdapter().queryBenchmarkResults({
      modelId,
      limit: 1,
    });
    return latest ?? null;
  }

  /**
   * Build parallel time series (dates, PnL, accuracy, optimality) for a model.
   *
   * @param modelId Model to chart.
   * @param limit   Maximum number of runs to include (default 20).
   * @returns Series in the reverse of the adapter's order.
   */
  static async getTrendData(
    modelId: string,
    limit = 20
  ): Promise<BenchmarkTrendData> {
    const results = await getTrainingDataAdapter().queryBenchmarkResults({
      modelId,
      limit,
    });

    // queryBenchmarkResults returns desc by runAt; reverse for chronological.
    // Reverse a copy: Array.prototype.reverse mutates in place, and the array
    // belongs to the adapter (it may be cached or shared with other callers).
    const chronological = [...results].reverse();

    return {
      modelId,
      dates: chronological.map((r) => r.runAt),
      pnlHistory: chronological.map((r) => r.totalPnl),
      accuracyHistory: chronological.map((r) => r.predictionAccuracy),
      optimalityHistory: chronological.map((r) => r.optimalityScore),
    };
  }

  /**
   * Fetch up to 10 results per model for side-by-side comparison.
   *
   * @param modelIds    Models to compare; the map's key order follows this list.
   * @param benchmarkId Optional benchmark filter applied to every model.
   * @returns Map from model id to that model's results.
   */
  static async getModelComparison(
    modelIds: string[],
    benchmarkId?: string
  ): Promise<Map<string, BenchmarkResultRecord[]>> {
    const adapter = getTrainingDataAdapter();

    // The per-model queries are independent, so issue them together instead
    // of awaiting each one in turn.
    const entries = await Promise.all(
      modelIds.map(
        async (modelId): Promise<[string, BenchmarkResultRecord[]]> => [
          modelId,
          await adapter.queryBenchmarkResults({
            modelId,
            benchmarkId,
            limit: 10,
          }),
        ]
      )
    );

    return new Map(entries);
  }

  /**
   * Get summary statistics for all models, as computed by the data adapter.
   */
  static async getModelSummary(): Promise<
    Array<{
      modelId: string;
      runCount: number;
      avgPnl: number;
      avgAccuracy: number;
      avgOptimality: number;
      bestPnl: number;
      latestRun: Date;
    }>
  > {
    return getTrainingDataAdapter().getBenchmarkModelSummary();
  }

  /**
   * Compare a model's first-returned result on a benchmark with a baseline
   * model's first-returned result on the same benchmark.
   *
   * @returns PnL figures and their difference (`improved` when the delta is
   *          strictly positive), or `null` when either side has no result.
   */
  static async checkImprovement(
    modelId: string,
    baselineModelId: string,
    benchmarkId: string
  ): Promise<{
    improved: boolean;
    modelPnl: number;
    baselinePnl: number;
    delta: number;
  } | null> {
    const adapter = getTrainingDataAdapter();

    // Two independent lookups — run them concurrently.
    const [modelResults, baselineResults] = await Promise.all([
      adapter.queryBenchmarkResults({ modelId, benchmarkId, limit: 1 }),
      adapter.queryBenchmarkResults({
        modelId: baselineModelId,
        benchmarkId,
        limit: 1,
      }),
    ]);

    const modelResult = modelResults[0];
    const baselineResult = baselineResults[0];

    if (!modelResult || !baselineResult) {
      return null;
    }

    const delta = modelResult.totalPnl - baselineResult.totalPnl;

    return {
      improved: delta > 0,
      modelPnl: modelResult.totalPnl,
      baselinePnl: baselineResult.totalPnl,
      delta,
    };
  }
}