@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,393 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Babylon Trajectory Reader
|
|
3
|
+
|
|
4
|
+
Reads trajectories from PostgreSQL database or local JSON files for training.
|
|
5
|
+
Validates LLM call quality to ensure training data authenticity.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import Optional, List, Dict
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
import logging
|
|
14
|
+
|
|
15
|
+
# Handle optional psycopg2 import for JSON-only workflows.
|
|
16
|
+
try:
|
|
17
|
+
import psycopg2
|
|
18
|
+
except ImportError:
|
|
19
|
+
psycopg2 = None
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class TrajectoryRow:
    """One raw trajectory row fetched from the database.

    Column values are carried as-is; the *_json fields hold serialized
    JSON strings that callers parse themselves. Used by
    PostgresTrajectoryReader and the module-level query helpers.
    """

    trajectory_id: str    # Primary trajectory identifier.
    agent_id: str         # Agent that produced the trajectory.
    window_id: str        # Training window the trajectory belongs to.
    steps_json: str       # Serialized list of trajectory steps.
    metrics_json: str     # Serialized metrics payload.
    metadata_json: str    # Serialized metadata payload.
    total_reward: float   # Total reward recorded for the episode.
    episode_length: int   # Number of steps in the episode.
    final_status: str     # Terminal status ("unknown" when absent in DB).
    final_pnl: Optional[float]        # Final profit/loss, if recorded.
    trades_executed: Optional[int]    # Number of trades, if recorded.
    ai_judge_reward: Optional[float]  # AI-judge score, if scored.
    archetype: Optional[str]          # Agent archetype label, if any.
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def get_connection():
    """
    Open a PostgreSQL connection using the DATABASE_URL environment variable.

    Returns:
        An open psycopg2 connection.

    Raises:
        ValueError: If DATABASE_URL not set
        ImportError: If psycopg2 is not installed
    """
    # The optional dependency check comes first so callers get an
    # actionable install hint before any configuration errors.
    if psycopg2 is None:
        raise ImportError(
            "psycopg2 is not installed. Please install it with 'pip install psycopg2-binary'")
    url = os.environ.get("DATABASE_URL")
    if url:
        return psycopg2.connect(url)
    raise ValueError("DATABASE_URL environment variable required")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def validate_llm_calls(steps: list, min_steps_with_llm: int = 3) -> tuple[bool, list[str]]:
    """
    Check that trajectory steps contain genuine LLM call records.

    Training data MUST have actual LLM calls with real prompts and
    responses; a call counts as real only when its system prompt, user
    prompt, and response are each at least 20 characters long.
    Synthetic or placeholder data fails this bar and would cause
    training failures.

    Args:
        steps: List of trajectory step dicts (camelCase or snake_case keys).
        min_steps_with_llm: Minimum number of steps that must contain at
            least one valid LLM call.

    Returns:
        Tuple of (is_valid, list of issue descriptions)
    """
    problems: list[str] = []

    if not steps:
        problems.append("Trajectory has no steps.")
        return False, problems

    qualifying_steps = 0
    for step_idx, step in enumerate(steps):
        calls = step.get("llmCalls") or step.get("llm_calls") or []
        if not calls:
            # Steps without LLM calls are tolerated individually; the
            # aggregate threshold below catches trajectories with too few.
            continue

        step_has_valid_call = False
        for call_no, record in enumerate(calls):
            sys_text = record.get("systemPrompt") or record.get("system_prompt") or ""
            usr_text = record.get("userPrompt") or record.get("user_prompt") or ""
            out_text = record.get("response") or ""

            faults = [
                msg
                for text, msg in (
                    (sys_text, "system_prompt too short"),
                    (usr_text, "user_prompt too short"),
                    (out_text, "response too short"),
                )
                if len(text) < 20
            ]
            if faults:
                problems.append(
                    f"Step {step_idx}, Call {call_no}: " + ", ".join(faults))
            else:
                step_has_valid_call = True

        if step_has_valid_call:
            qualifying_steps += 1

    if qualifying_steps < min_steps_with_llm:
        problems.append(
            f"Only {qualifying_steps}/{len(steps)} steps have valid LLM calls (need at least {min_steps_with_llm})")

    return not problems, problems
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class PostgresTrajectoryReader:
    """Reads Babylon trajectories from a PostgreSQL database.

    Intended for use as an async context manager so the connection is
    opened and closed deterministically:

        async with PostgresTrajectoryReader(url) as reader:
            window_ids = await reader.get_window_ids()
    """

    def __init__(self, database_url: str):
        """
        Args:
            database_url: PostgreSQL connection string (required).

        Raises:
            ImportError: If the optional psycopg2 dependency is missing.
            ValueError: If database_url is empty.
        """
        if psycopg2 is None:
            raise ImportError(
                "psycopg2 is not installed for PostgresTrajectoryReader. Please install it with 'pip install psycopg2-binary'")
        if not database_url:
            raise ValueError(
                "DATABASE_URL must be provided for PostgresTrajectoryReader")
        self.db_url = database_url
        self.conn = None

    async def __aenter__(self):
        """Connect to the database upon entering the async context."""
        # Re-check to satisfy static analysis; psycopg2 is None at module
        # level when the optional import failed.
        if psycopg2 is None:
            raise ImportError(
                "psycopg2 is not installed, cannot connect to database.")
        self.conn = psycopg2.connect(self.db_url)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the database connection upon exiting the context."""
        if self.conn:
            self.conn.close()

    async def get_window_ids(self, limit: int = 100, only_scored: bool = True, lookback_hours: int = 168, min_agents: int = 1) -> list[str]:
        """Return distinct window IDs with recent training data.

        Args:
            limit: Maximum number of window IDs to return.
            only_scored: Restrict to windows with AI-judge-scored rows.
            lookback_hours: Only consider rows created within this many
                hours (default 168 = one week).
            min_agents: Currently unused; retained for backward
                compatibility with existing callers.

        Raises:
            ConnectionError: If called outside the async context.
        """
        if not self.conn:
            raise ConnectionError("Database not connected.")
        with self.conn.cursor() as cur:
            # psycopg2 substitutes %s inside the INTERVAL literal with the
            # adapted integer, yielding e.g. INTERVAL '168 hours'.
            query = """
                SELECT DISTINCT "windowId" FROM trajectories
                WHERE "isTrainingData" = true AND "createdAt" > NOW() - INTERVAL '%s hours'
            """
            params = [lookback_hours]
            if only_scored:
                query += ' AND "aiJudgeReward" IS NOT NULL'
            query += ' ORDER BY "windowId" DESC LIMIT %s'
            params.append(limit)
            cur.execute(query, tuple(params))
            return [row[0] for row in cur.fetchall() if row[0]]

    async def get_trajectories_by_window(
        self, window_id: str, min_score: Optional[float] = None,
        validate: bool = True, min_actions: int = 1
    ) -> list[TrajectoryRow]:
        """Fetch training trajectories for one window.

        Args:
            window_id: Window ID to query.
            min_score: Optional minimum "aiJudgeReward" filter.
            validate: Skip trajectories whose steps cannot be parsed or
                lack real LLM calls (see validate_llm_calls).
            min_actions: Minimum "episodeLength" required.

        Returns:
            List of TrajectoryRow instances.

        Raises:
            ConnectionError: If called outside the async context.
        """
        if not self.conn:
            raise ConnectionError("Database not connected.")
        with self.conn.cursor() as cur:
            query = """
                SELECT "trajectoryId", "agentId", "windowId", "stepsJson", "metricsJson", "metadataJson",
                       "totalReward", "episodeLength", "finalStatus", "finalPnL", "tradesExecuted",
                       "aiJudgeReward", "archetype"
                FROM trajectories WHERE "windowId" = %s AND "isTrainingData" = true AND "episodeLength" >= %s
            """
            params: list = [window_id, min_actions]
            if min_score is not None:
                query += ' AND "aiJudgeReward" >= %s'
                params.append(min_score)
            cur.execute(query, tuple(params))
            rows = cur.fetchall()

        results = []
        for row in rows:
            trajectory = TrajectoryRow(
                trajectory_id=row[0], agent_id=row[1], window_id=row[2], steps_json=row[3],
                metrics_json=row[4], metadata_json=row[5],
                total_reward=float(row[6] or 0.0),
                episode_length=int(row[7] or 0), final_status=row[8] or "unknown",
                # BUGFIX: explicit None checks so legitimate zero values
                # (0.0 PnL, 0 trades, 0.0 judge reward) are preserved
                # instead of being collapsed to None by truthiness.
                final_pnl=float(row[9]) if row[9] is not None else None,
                trades_executed=int(row[10]) if row[10] is not None else None,
                ai_judge_reward=float(row[11]) if row[11] is not None else None,
                archetype=row[12],
            )
            if validate:
                try:
                    steps = json.loads(trajectory.steps_json)
                    is_valid, issues = validate_llm_calls(steps)
                    if not is_valid:
                        logger.debug(
                            f"Skipping DB trajectory {trajectory.trajectory_id}: {issues}")
                        continue
                except (json.JSONDecodeError, TypeError):
                    logger.warning(
                        f"Could not parse steps_json for trajectory {trajectory.trajectory_id}")
                    continue
            results.append(trajectory)
        return results
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class JsonTrajectoryReader:
    """Reads Babylon trajectories from a local directory of JSON files."""

    def __init__(self, directory_path: str):
        """Index every *.json trajectory file under directory_path.

        Raises:
            FileNotFoundError: If directory_path is not a directory.
        """
        self._directory = Path(directory_path)
        self._trajectories_by_window: Dict[str, List[Dict]] = {}

        if not self._directory.is_dir():
            raise FileNotFoundError(
                f"Source directory not found: {self._directory.resolve()}")

        self._scan_files()
        logger.info(
            f"Found {len(self._trajectories_by_window)} windows in {self._directory}")

    def _scan_files(self):
        """Group every parseable JSON file in the directory by windowId."""
        seen_any = False
        for json_path in self._directory.glob("*.json"):
            seen_any = True
            try:
                with json_path.open('r', encoding='utf-8') as handle:
                    payload = json.load(handle)
                # Files either wrap the trajectory under a "trajectory"
                # key or are the trajectory object themselves.
                record = payload.get('trajectory', payload)
                bucket = record.get("windowId", "default_window")
                self._trajectories_by_window.setdefault(bucket, []).append(record)
            except (json.JSONDecodeError, KeyError, TypeError) as err:
                logger.warning(f"Skipping invalid JSON file {json_path}: {err}")

        if not seen_any:
            logger.warning(
                f"No JSON files found in directory: {self._directory}")

    def get_window_ids(self) -> List[str]:
        """Return the window IDs discovered during the initial scan."""
        return list(self._trajectories_by_window.keys())

    def get_trajectories_by_window(self, window_id: str) -> List[Dict]:
        """Return raw trajectory dicts for window_id ([] if unknown)."""
        return self._trajectories_by_window.get(window_id, [])
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def get_window_ids(limit: int = 100, only_scored: bool = True) -> list[str]:
    """
    Get distinct window IDs with training data.

    Args:
        limit: Maximum windows to return
        only_scored: Only return windows with scored trajectories

    Returns:
        List of window IDs

    Raises:
        ValueError / ImportError: Propagated from get_connection().
    """
    conn = get_connection()
    try:
        # try/finally guarantees the connection is released even when
        # the query fails (the original leaked it on error).
        with conn.cursor() as cur:
            query = 'SELECT DISTINCT "windowId" FROM trajectories WHERE "isTrainingData" = true'
            if only_scored:
                query += ' AND "aiJudgeReward" IS NOT NULL'
            query += ' ORDER BY "windowId" DESC LIMIT %s'
            cur.execute(query, (limit,))
            rows = cur.fetchall()
    finally:
        conn.close()
    return [row[0] for row in rows if row[0]]
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def get_trajectories_by_window(
    window_id: str,
    min_score: Optional[float] = None,
    validate: bool = True,
) -> list[TrajectoryRow]:
    """
    Get trajectories for a specific window.

    Args:
        window_id: Window ID to query
        min_score: Optional minimum AI judge score
        validate: Whether to validate LLM calls; trajectories whose
            steps cannot be parsed or lack real LLM calls are skipped.

    Returns:
        List of trajectory rows
    """
    conn = get_connection()
    try:
        # Ensure the connection is released even if the query fails.
        with conn.cursor() as cur:
            query = """
                SELECT "trajectoryId", "agentId", "windowId", "stepsJson", "metricsJson", "metadataJson",
                       "totalReward", "episodeLength", "finalStatus", "finalPnL", "tradesExecuted",
                       "aiJudgeReward", "archetype"
                FROM trajectories WHERE "windowId" = %s AND "isTrainingData" = true
            """
            params: list = [window_id]
            if min_score is not None:
                query += ' AND "aiJudgeReward" >= %s'
                params.append(min_score)
            cur.execute(query, params)
            rows = cur.fetchall()
    finally:
        conn.close()

    results: list[TrajectoryRow] = []
    for row in rows:
        trajectory = TrajectoryRow(
            trajectory_id=row[0], agent_id=row[1], window_id=row[2], steps_json=row[3],
            metrics_json=row[4], metadata_json=row[5],
            total_reward=float(row[6] or 0.0),
            episode_length=int(row[7] or 0), final_status=row[8] or "unknown",
            # BUGFIX: explicit None checks preserve legitimate zero
            # values (0.0 PnL, 0 trades, 0.0 judge reward).
            final_pnl=float(row[9]) if row[9] is not None else None,
            trades_executed=int(row[10]) if row[10] is not None else None,
            ai_judge_reward=float(row[11]) if row[11] is not None else None,
            archetype=row[12],
        )
        if validate:
            try:
                steps = json.loads(trajectory.steps_json)
            except (json.JSONDecodeError, TypeError):
                # Match PostgresTrajectoryReader: skip unparseable rows
                # instead of aborting the whole fetch.
                logger.warning(
                    f"Could not parse steps_json for trajectory {trajectory.trajectory_id}")
                continue
            is_valid, _ = validate_llm_calls(steps)
            if not is_valid:
                continue
        results.append(trajectory)
    return results
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def get_all_training_trajectories(
    limit: int = 1000,
    min_score: Optional[float] = None,
    archetype: Optional[str] = None,
) -> list[TrajectoryRow]:
    """
    Get all training trajectories.

    Args:
        limit: Maximum trajectories to return
        min_score: Optional minimum AI judge score
        archetype: Optional filter by archetype

    Returns:
        List of trajectory rows, newest first (ordered by "createdAt")
    """
    conn = get_connection()
    try:
        # Ensure the connection is released even if the query fails.
        with conn.cursor() as cur:
            query = """
                SELECT "trajectoryId", "agentId", "windowId", "stepsJson", "metricsJson", "metadataJson",
                       "totalReward", "episodeLength", "finalStatus", "finalPnL", "tradesExecuted",
                       "aiJudgeReward", "archetype"
                FROM trajectories WHERE "isTrainingData" = true
            """
            params: list = []
            if min_score is not None:
                query += ' AND "aiJudgeReward" >= %s'
                params.append(min_score)
            if archetype is not None:
                query += ' AND "archetype" = %s'
                params.append(archetype)
            query += ' ORDER BY "createdAt" DESC LIMIT %s'
            params.append(limit)
            cur.execute(query, params)
            rows = cur.fetchall()
    finally:
        conn.close()
    return [TrajectoryRow(
        trajectory_id=r[0], agent_id=r[1], window_id=r[2], steps_json=r[3],
        metrics_json=r[4], metadata_json=r[5], total_reward=float(r[6] or 0.0),
        episode_length=int(r[7] or 0), final_status=r[8] or "unknown",
        # BUGFIX: explicit None checks preserve legitimate zero values.
        final_pnl=float(r[9]) if r[9] is not None else None,
        trades_executed=int(r[10]) if r[10] is not None else None,
        ai_judge_reward=float(r[11]) if r[11] is not None else None,
        archetype=r[12],
    ) for r in rows]
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def get_trajectory_stats() -> dict:
    """Return aggregate statistics over all training trajectories.

    Returns:
        Dict with keys: total, scored, avg_score, min_score, max_score,
        archetypes. All zeroed when the aggregate query returns no row.
    """
    conn = get_connection()
    try:
        # try/finally guarantees the connection is released even when
        # the query fails (the original leaked it on error).
        with conn.cursor() as cur:
            cur.execute("""
                SELECT COUNT(*), COUNT("aiJudgeReward"), AVG("aiJudgeReward"),
                       MIN("aiJudgeReward"), MAX("aiJudgeReward"), COUNT(DISTINCT "archetype")
                FROM trajectories WHERE "isTrainingData" = true
            """)
            row = cur.fetchone()
    finally:
        conn.close()

    if row is None:
        return {"total": 0, "scored": 0, "avg_score": 0.0, "min_score": 0.0, "max_score": 0.0, "archetypes": 0}

    return {
        "total": row[0] or 0, "scored": row[1] or 0,
        # NULL aggregates (no scored rows) map to 0.0.
        "avg_score": float(row[2]) if row[2] is not None else 0.0,
        "min_score": float(row[3]) if row[3] is not None else 0.0,
        "max_score": float(row[4]) if row[4] is not None else 0.0,
        "archetypes": row[5] or 0,
    }
|
|
@@ -0,0 +1,283 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Shared type definitions for RL training.
|
|
3
|
+
Strong, validated types.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Dict, List, Literal
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from pydantic.alias_generators import to_camel
|
|
10
|
+
|
|
11
|
+
# Type alias for JSON-serializable values (kept loose as str -> object).
JsonDict = Dict[str, object]

# Type alias for chat messages with known structure
ChatMessage = Dict[str, str]  # {"role": str, "content": str}

# Base pydantic config shared by the models below: generates camelCase
# field aliases (matching the TypeScript-side JSON) while still
# accepting snake_case names via populate_by_name.
camel_case_config = ConfigDict(
    alias_generator=to_camel,
    populate_by_name=True,
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class EnvironmentState(BaseModel):
    """Snapshot of the trading environment at a single trajectory step."""
    model_config = camel_case_config

    # Agent's cash balance at this point.
    agent_balance: float
    # Explicit alias for the 'agentPnL' field from the JSON data
    # (the to_camel generator would produce 'agentPnl' instead).
    agent_pnl: float = Field(..., alias='agentPnL')
    # Number of currently open positions.
    open_positions: int
    # Number of active markets; defaults to 0 when absent from input.
    active_markets: int = 0
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ProviderAccess(BaseModel):
    """Data accessed from a provider during a trajectory step."""
    # Combines camelCase conversion with allowing extra fields, so
    # provider-specific keys are retained rather than rejected.
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        extra="allow"
    )

    # Name of the provider that supplied the data.
    provider_name: str
    # Raw payload returned by the provider (JSON-style dict).
    data: JsonDict
    # Why the agent accessed this provider.
    purpose: str
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class LLMCall(BaseModel):
    """
    Single LLM call record.
    Matches the TypeScript LLMCall interface in plugin-trajectory-logger/types.ts
    """
    model_config = camel_case_config

    # Model identifier used for the call.
    model: str
    model_version: str | None = None
    # Prompts sent and the raw completion text received.
    system_prompt: str
    user_prompt: str
    response: str
    # Optional reasoning trace, when the provider exposes one.
    reasoning: str | None = None
    # Sampling parameters used for the call.
    temperature: float
    max_tokens: int
    # Optional telemetry; None when the logger did not capture it.
    latency_ms: int | None = None
    prompt_tokens: int | None = None
    completion_tokens: int | None = None
    # What the call was used for within the step.
    purpose: Literal['action', 'reasoning', 'evaluation', 'response', 'other']
    # Associated action type, when applicable.
    action_type: str | None = None
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class Action(BaseModel):
    """Action taken by agent in a trajectory step."""
    # Combines camelCase conversion with allowing extra fields, so
    # action-specific keys survive validation.
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        extra="allow"
    )

    # Kind of action (free-form string, not constrained to an enum here).
    action_type: str
    # Action arguments as a JSON-style dict.
    parameters: JsonDict
    # Whether execution succeeded.
    success: bool
    # Result payload when available; error message when execution failed.
    result: JsonDict | None = None
    error: str | None = None
    # Optional rationale recorded for the action.
    reasoning: str | None = None
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class TrajectoryStep(BaseModel):
    """Single step in a trajectory."""
    model_config = camel_case_config

    # Position of this step within the episode.
    step_number: int
    # Step timestamp (integer; epoch units not fixed here — see producer).
    timestamp: int
    # Environment snapshot for this step.
    environment_state: EnvironmentState
    # Providers consulted and LLM calls made during the step.
    provider_accesses: List[ProviderAccess] = Field(default_factory=list)
    llm_calls: List[LLMCall] = Field(default_factory=list)
    # Action taken, if any.
    action: Action | None = None
    # Per-step reward; defaults to 0.0.
    reward: float = 0.0
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class TrainingTrajectory(BaseModel):
    """Complete trajectory from database.

    Aggregates per-step records plus episode-level summary statistics for
    one agent run. Left mutable (frozen=False) so loaders can backfill
    summary fields after construction.
    """
    # Combines camelCase conversion with mutability
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        frozen=False,
    )

    trajectory_id: str
    agent_id: str

    id: str = ""  # storage row id; may be empty before persistence — confirm against loader
    window_id: str = "default"
    start_time: datetime | None = None
    end_time: datetime | None = None
    duration_ms: int = 0
    scenario_id: str | None = None
    episode_id: str | None = None
    # Builtin generics (PEP 585) for consistency with dict[str, ...] usage elsewhere in this module.
    steps: list[TrajectoryStep] = Field(default_factory=list)
    total_reward: float = 0.0
    final_pnl: float = 0.0
    final_balance: float | None = None
    trades_executed: int = 0
    successful_trades: int = 0
    failed_trades: int = 0
    posts_created: int = 0
    # NOTE(review): an int counter here, while TrajectoryStep.provider_accesses is a list —
    # same name, different type/meaning at the two levels; confirm downstream expectations.
    provider_accesses: int = 0
    episode_length: int = 0
    final_status: str = "completed"
    archetype: str | None = None
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class StockOutcome(BaseModel):
    """Market outcome for a stock.

    Price movement over a training window plus optional sentiment and
    the news events observed during the window.
    """
    model_config = camel_case_config
    ticker: str
    start_price: float
    end_price: float
    change_percent: float  # presumably (end - start) / start * 100 — confirm against producer
    sentiment: Literal['BULLISH', 'BEARISH', 'NEUTRAL'] | None = None
    # Builtin generics (PEP 585) for consistency with dict[str, ...] fields elsewhere in this module.
    news_events: list[str] = Field(default_factory=list)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class PredictionOutcome(BaseModel):
    """Outcome for a prediction market"""
    model_config = camel_case_config
    market_id: str
    question: str  # human-readable market question text
    outcome: Literal['YES', 'NO', 'UNRESOLVED']
    final_probability: float  # market probability at window end; presumably in [0, 1] — confirm
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class MarketOutcomes(BaseModel):
    """All market outcomes for a window"""
    model_config = camel_case_config
    window_id: str
    window_start: datetime
    window_end: datetime
    stocks: dict[str, StockOutcome] = Field(default_factory=dict)  # presumably keyed by ticker — verify
    predictions: dict[str, PredictionOutcome] = Field(default_factory=dict)  # presumably keyed by market_id — verify
    overall_trend: Literal['BULLISH', 'BEARISH', 'NEUTRAL'] | None = None
    volatility: Literal['HIGH', 'MEDIUM', 'LOW'] | None = None
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class WindowStatistics(BaseModel):
    """Statistics for a training window"""
    model_config = camel_case_config
    window_id: str
    agent_count: int
    trajectory_count: int
    total_actions: int
    # P&L aggregates across the window's trajectories
    avg_pnl: float
    min_pnl: float
    max_pnl: float
    start_time: datetime
    end_time: datetime
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class TrainingBatchSummary(BaseModel):
    """Summary of a training batch"""
    model_config = camel_case_config
    windows: int  # number of windows contributing to this batch
    total_trajectories: int
    avg_trajectories_per_window: float
    # Score aggregates across the batch
    score_min: float
    score_max: float
    score_avg: float
    # P&L aggregates across the batch
    pnl_min: float
    pnl_max: float
    pnl_avg: float
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# =============================================================================
|
|
195
|
+
# Atropos-compatible types
|
|
196
|
+
# =============================================================================
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class AtroposScoredItem(BaseModel):
    """Single scored item for Atropos training.

    One tokenized sequence with its loss mask and a scalar score;
    optional per-token logprobs and the chat transcript it came from.
    """
    model_config = camel_case_config
    # Builtin generics (PEP 585) for consistency with dict[str, ...] fields elsewhere in this module.
    tokens: list[int]
    masks: list[int]  # presumably a per-token loss mask parallel to tokens — confirm against Atropos API
    score: float
    logprobs: list[float] = Field(default_factory=list)
    messages: list[ChatMessage] = Field(default_factory=list)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class AtroposScoredGroup(BaseModel):
    """Group of scored items for Atropos GRPO training.

    Outer lists run over group members; each member carries its own
    token/mask/logprob sequences and score.
    """
    model_config = camel_case_config
    # Builtin generics (PEP 585) for consistency with dict[str, ...] fields elsewhere in this module.
    tokens: list[list[int]]
    masks: list[list[int]]
    scores: list[float]
    inference_logprobs: list[list[float]] = Field(default_factory=list)
    messages: list[list[ChatMessage]] = Field(default_factory=list)
    env_id: int | None = None

    @property
    def group_size(self) -> int:
        """Number of members in the group (length of the outer tokens list)."""
        return len(self.tokens)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
class TrajectoryGroup(BaseModel):
    """Group of trajectories for relative comparison.

    Trajectories sharing a window (and optionally a scenario) are grouped
    so their rewards/P&L can be compared against each other.
    """
    model_config = camel_case_config
    group_key: str
    window_id: str
    scenario_id: str | None = None
    # Builtin generics (PEP 585) for consistency with dict[str, ...] fields elsewhere in this module.
    trajectories: list[TrainingTrajectory]

    @property
    def size(self) -> int:
        """Number of trajectories in the group."""
        return len(self.trajectories)

    def get_pnl_stats(self) -> dict[str, float]:
        """Get P&L statistics (min/max/mean of final_pnl) for the group.

        Returns zeros for an empty group rather than raising.
        """
        pnls = [t.final_pnl for t in self.trajectories]
        if not pnls:
            # Single guard replaces three repeated per-key emptiness checks.
            return {"min": 0, "max": 0, "mean": 0}
        return {
            "min": min(pnls),
            "max": max(pnls),
            "mean": sum(pnls) / len(pnls),
        }
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
class JudgeScore(BaseModel):
    """Score from LLM judge for a trajectory"""
    model_config = camel_case_config
    trajectory_id: str
    score: float = Field(ge=0.0, le=1.0)  # normalized score, validated to [0, 1]
    explanation: str  # judge's free-text justification for the score
    confidence: float = Field(default=1.0, ge=0.0, le=1.0)  # judge-reported confidence, defaults to certain
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class JudgeResponse(BaseModel):
    """Response from LLM judge for a group of trajectories.

    Carries the judge's overall reasoning plus one JudgeScore per
    trajectory it evaluated.
    """
    model_config = camel_case_config
    reasoning: str
    scores: List[JudgeScore]

    def get_score_for(self, trajectory_id: str) -> float | None:
        """Return the score for *trajectory_id*, or None if it was not judged."""
        matching = (entry.score for entry in self.scores
                    if entry.trajectory_id == trajectory_id)
        return next(matching, None)
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
# Backward compatibility alias while downstream code migrates.
# New code should use TrainingTrajectory directly.
BabylonTrajectory = TrainingTrajectory
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class TrainingMetrics(BaseModel):
|
|
274
|
+
"""Metrics from a training step"""
|
|
275
|
+
model_config = camel_case_config
|
|
276
|
+
step: int
|
|
277
|
+
loss: float
|
|
278
|
+
grad_norm: float
|
|
279
|
+
learning_rate: float
|
|
280
|
+
pos_logp: float = 0.0
|
|
281
|
+
neg_logp: float = 0.0
|
|
282
|
+
num_samples: int = 0
|
|
283
|
+
timestamp: datetime = Field(default_factory=datetime.now)
|