@elizaos/training 2.0.0-alpha.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/data/.gitkeep +0 -0
- package/data/degen/.gitkeep +2 -0
- package/data/trader/.gitkeep +2 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +58 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +206 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +89 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +439 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,646 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tinker Trainer
|
|
3
|
+
|
|
4
|
+
Lightweight GRPO trainer using Tinker API.
|
|
5
|
+
Replaces heavy local vLLM + PyTorch training with cloud-based training.
|
|
6
|
+
|
|
7
|
+
This trainer:
|
|
8
|
+
1. Uses TinkerClient for training and inference
|
|
9
|
+
2. Integrates with RLAIFEnv for trajectory collection
|
|
10
|
+
3. Implements GRPO/IS training loop
|
|
11
|
+
4. Handles weight synchronization
|
|
12
|
+
|
|
13
|
+
Benefits over local training:
|
|
14
|
+
- No local GPU required
|
|
15
|
+
- Access to larger models (Qwen3-235B)
|
|
16
|
+
- Faster weight sync (no vLLM restarts)
|
|
17
|
+
- Better on-policy training with low staleness
|
|
18
|
+
- Pay only for training time, not idle GPU
|
|
19
|
+
|
|
20
|
+
Based on: tinker-atropos integration (Nous Research)
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import json
|
|
24
|
+
import logging
|
|
25
|
+
import os
|
|
26
|
+
from dataclasses import dataclass, field
|
|
27
|
+
from datetime import datetime, timezone
|
|
28
|
+
from pathlib import Path
|
|
29
|
+
from typing import List
|
|
30
|
+
|
|
31
|
+
import numpy as np
|
|
32
|
+
from dotenv import load_dotenv
|
|
33
|
+
from pydantic import BaseModel, Field
|
|
34
|
+
|
|
35
|
+
from .tinker_client import (
|
|
36
|
+
TinkerClient,
|
|
37
|
+
TinkerConfig,
|
|
38
|
+
TinkerDatum,
|
|
39
|
+
TINKER_AVAILABLE,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
logger = logging.getLogger(__name__)
|
|
43
|
+
|
|
44
|
+
# Load environment variables from the repository root (four levels up
# from this file). `.env.local` is loaded first with override=True so it
# wins over both the process environment and `.env`; `.env` is loaded
# with override=False so it only fills in values not already set.
project_root = Path(__file__).parent.parent.parent.parent
env_path = project_root / ".env"
env_local_path = project_root / ".env.local"

if env_local_path.exists():
    load_dotenv(env_local_path, override=True)
if env_path.exists():
    load_dotenv(env_path, override=False)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class TinkerTrainingConfig(BaseModel):
    """Configuration for Tinker-based training.

    Groups the model, GRPO hyperparameter, data-selection, RLAIF-judge,
    logging, and inference settings consumed by ``TinkerTrainer``.
    ``database_url`` defaults to the ``DATABASE_URL`` environment
    variable read at instantiation time.
    """

    # Model settings
    base_model: str = Field(
        default="Qwen/Qwen3-30B-A3B-Instruct",
        description="Base model from Tinker's supported models",
    )
    lora_rank: int = Field(default=32, description="LoRA rank for fine-tuning")

    # Training hyperparameters
    learning_rate: float = Field(default=4e-5, description="Learning rate")
    training_steps: int = Field(default=100, description="Number of training steps")
    group_size: int = Field(default=4, description="Group size for GRPO comparison")

    # Weight sync settings
    weight_sync_interval: int = Field(
        default=5, description="Sync weights to sampler every N steps"
    )

    # Environment settings (trajectory selection from the database)
    database_url: str = Field(
        default_factory=lambda: os.getenv("DATABASE_URL", ""),
        description="PostgreSQL connection URL",
    )
    lookback_hours: int = Field(
        default=72, description="Hours to look back for trajectories"
    )
    min_agents_per_window: int = Field(
        default=2, description="Minimum agents per window"
    )
    min_actions_per_trajectory: int = Field(
        default=3, description="Minimum actions per trajectory"
    )
    max_steps_per_trajectory: int = Field(
        default=20, description="Max steps to include per trajectory"
    )
    max_token_length: int = Field(default=4096, description="Maximum sequence length")

    # RLAIF Judge settings
    judge_model: str = Field(default="gpt-4o-mini", description="Model for RLAIF judge")
    judge_temperature: float = Field(default=0.3, description="Judge temperature")

    # Logging settings
    log_to_file: bool = Field(default=True, description="Log metrics to file")
    log_file: str = Field(
        default="./logs/tinker_training_metrics.jsonl", description="Metrics log file"
    )

    # Inference settings (forwarded to the Tinker sampling client)
    inference_max_tokens: int = Field(
        default=512, description="Max tokens for inference"
    )
    inference_temperature: float = Field(
        default=0.7, description="Temperature for inference"
    )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@dataclass
class TrainingMetrics:
    """Metrics from a single training step.

    ``timestamp`` defaults to the UTC creation time of the instance;
    ``windows_processed`` is filled in by the training loop after the
    step completes.
    """

    step: int  # 1-based training step index
    loss: float  # loss reported by the Tinker train step
    num_samples: int  # number of datums in the batch
    logprobs_mean: float = 0.0  # mean log-prob reported by Tinker
    pos_advantage_mean: float = 0.0  # mean of positive advantages
    neg_advantage_mean: float = 0.0  # mean of negative advantages
    avg_score: float = 0.0  # mean raw judge score for the group
    windows_processed: int = 0  # cumulative groups trained on this run
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class TinkerTrainer:
|
|
129
|
+
"""
|
|
130
|
+
GRPO Trainer using Tinker API.
|
|
131
|
+
|
|
132
|
+
This replaces local heavyweight trainer flows with a lighter implementation:
|
|
133
|
+
- No local vLLM management
|
|
134
|
+
- No GPU requirements on training machine
|
|
135
|
+
- Training happens in Tinker cloud
|
|
136
|
+
- Only data loading runs locally
|
|
137
|
+
|
|
138
|
+
The training loop:
|
|
139
|
+
1. Load trajectory groups from database
|
|
140
|
+
2. Score trajectories using LLM judge (RLAIF)
|
|
141
|
+
3. Convert to training format
|
|
142
|
+
4. Call Tinker for forward_backward + optim_step
|
|
143
|
+
5. Periodically sync weights to sampling client
|
|
144
|
+
"""
|
|
145
|
+
|
|
146
|
+
    def __init__(self, config: TinkerTrainingConfig):
        """Create the trainer and its Tinker client.

        Args:
            config: Full training configuration.

        Raises:
            RuntimeError: if the ``tinker`` package is not installed.
        """
        if not TINKER_AVAILABLE:
            raise RuntimeError(
                "Tinker not installed. Install with: pip install tinker"
            )

        self.config = config
        # Mirror the relevant model/inference settings into the Tinker
        # client's own config object.
        self.tinker_config = TinkerConfig(
            base_model=config.base_model,
            lora_rank=config.lora_rank,
            learning_rate=config.learning_rate,
            default_max_tokens=config.inference_max_tokens,
            default_temperature=config.inference_temperature,
        )
        self.tinker_client = TinkerClient(self.tinker_config)

        self.current_step = 0
        # Run identifier used in metric records and synced-weight names.
        self.run_id = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
        self.all_metrics: List[TrainingMetrics] = []

        # Database pool (lazy init; created in setup via _connect_database)
        self._db_pool = None

        # Judge client (lazy init; created in setup via _init_judge)
        self._judge_client = None
|
|
171
|
+
|
|
172
|
+
    async def setup(self) -> None:
        """Initialize Tinker client, metrics logging, database pool, and judge.

        Must complete before training; ``train()`` calls it automatically.
        """
        logger.info(f"Setting up Tinker trainer with {self.config.base_model}")
        logger.info(f"Run ID: {self.run_id}")

        # Initialize Tinker
        self.tinker_client.setup()
        logger.info("Tinker client initialized")

        # Ensure the metrics log directory exists before the first append.
        if self.config.log_to_file:
            log_dir = Path(self.config.log_file).parent
            log_dir.mkdir(parents=True, exist_ok=True)
            logger.info(f"Metrics will be logged to: {self.config.log_file}")

        # Connect to database
        await self._connect_database()

        # Initialize judge
        await self._init_judge()

        logger.info("Setup complete")
|
|
194
|
+
|
|
195
|
+
    async def _connect_database(self) -> None:
        """Create the asyncpg pool from ``config.database_url``.

        Raises:
            ValueError: if no database URL is configured.
        """
        # Local import: asyncpg is only needed once training actually runs.
        import asyncpg

        if not self.config.database_url:
            raise ValueError("DATABASE_URL not set")

        self._db_pool = await asyncpg.create_pool(
            self.config.database_url,
            min_size=2,
            max_size=10,
            command_timeout=60,
        )
        logger.info("Connected to database")
|
|
209
|
+
|
|
210
|
+
    async def _init_judge(self) -> None:
        """Initialize the AsyncOpenAI client used as the RLAIF judge.

        Credentials come from the OpenAI SDK's default environment
        configuration (e.g. OPENAI_API_KEY).
        """
        import openai

        self._judge_client = openai.AsyncOpenAI()
        logger.info(f"Judge initialized with model: {self.config.judge_model}")
|
|
216
|
+
|
|
217
|
+
async def cleanup(self) -> None:
|
|
218
|
+
"""Clean up resources"""
|
|
219
|
+
if self._db_pool:
|
|
220
|
+
await self._db_pool.close()
|
|
221
|
+
self._db_pool = None
|
|
222
|
+
logger.info("Database connection closed")
|
|
223
|
+
|
|
224
|
+
def log_metrics(self, metrics: TrainingMetrics) -> None:
|
|
225
|
+
"""Log metrics to file"""
|
|
226
|
+
if self.config.log_to_file:
|
|
227
|
+
metrics_dict = {
|
|
228
|
+
"timestamp": metrics.timestamp,
|
|
229
|
+
"run_id": self.run_id,
|
|
230
|
+
"step": metrics.step,
|
|
231
|
+
"loss": metrics.loss,
|
|
232
|
+
"num_samples": metrics.num_samples,
|
|
233
|
+
"logprobs_mean": metrics.logprobs_mean,
|
|
234
|
+
"pos_advantage_mean": metrics.pos_advantage_mean,
|
|
235
|
+
"neg_advantage_mean": metrics.neg_advantage_mean,
|
|
236
|
+
"avg_score": metrics.avg_score,
|
|
237
|
+
"windows_processed": metrics.windows_processed,
|
|
238
|
+
}
|
|
239
|
+
with open(self.config.log_file, "a") as f:
|
|
240
|
+
f.write(json.dumps(metrics_dict) + "\n")
|
|
241
|
+
|
|
242
|
+
self.all_metrics.append(metrics)
|
|
243
|
+
|
|
244
|
+
    async def load_trajectory_groups(self) -> List[dict]:
        """Load trajectory groups from database.

        Fetches trajectories created within ``lookback_hours`` that have
        a non-empty ``stepsJson`` and enough recorded actions, then
        buckets them by window/scenario. Groups with fewer than
        ``min_agents_per_window`` members are discarded, since GRPO needs
        multiple trajectories to compare.

        Returns:
            List of ``{"group_key", "trajectories"}`` dicts.

        Raises:
            RuntimeError: if the database pool has not been created yet.
        """
        if not self._db_pool:
            raise RuntimeError("Database not connected")

        async with self._db_pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT
                    t."trajectoryId",
                    t."agentId",
                    t."windowId",
                    t."scenarioId",
                    t."stepsJson",
                    t."finalPnL",
                    t."episodeLength",
                    t."totalReward",
                    u.username as agent_name
                FROM trajectories t
                LEFT JOIN "User" u ON t."agentId" = u.id
                WHERE
                    t."createdAt" > NOW() - $1::interval
                    AND t."stepsJson" IS NOT NULL
                    AND t."stepsJson"::text != 'null'
                    AND t."stepsJson"::text != '[]'
                    AND t."episodeLength" >= $2
                ORDER BY t."windowId", t."scenarioId", t."createdAt"
                """,
                f"{self.config.lookback_hours} hours",
                self.config.min_actions_per_trajectory,
            )

        # Group by window/scenario
        groups: dict = {}
        for row in rows:
            group_key = f"{row['windowId']}_{row['scenarioId'] or 'default'}"

            if group_key not in groups:
                groups[group_key] = []

            # Client-side re-check of the step count parsed from stepsJson.
            steps = json.loads(row["stepsJson"] or "[]")
            if len(steps) < self.config.min_actions_per_trajectory:
                continue

            groups[group_key].append(
                {
                    "trajectory_id": row["trajectoryId"],
                    "agent_id": row["agentId"],
                    # Fall back to a truncated agent id when no username.
                    "agent_name": row["agent_name"] or row["agentId"][:8],
                    "window_id": row["windowId"],
                    "scenario_id": row["scenarioId"],
                    "steps": steps,
                    "final_pnl": float(row["finalPnL"] or 0),
                    "episode_length": row["episodeLength"] or len(steps),
                    "total_reward": float(row["totalReward"] or 0),
                }
            )

        # Filter groups with enough trajectories
        valid_groups = [
            {"group_key": k, "trajectories": v}
            for k, v in groups.items()
            if len(v) >= self.config.min_agents_per_window
        ]

        logger.info(f"Loaded {len(valid_groups)} trajectory groups")
        return valid_groups
|
|
311
|
+
|
|
312
|
+
def trajectory_to_messages(self, traj: dict) -> List[dict]:
|
|
313
|
+
"""Convert trajectory to chat messages format"""
|
|
314
|
+
messages = []
|
|
315
|
+
|
|
316
|
+
# System message
|
|
317
|
+
system_content = f"""You are a trading agent in a prediction market simulation.
|
|
318
|
+
|
|
319
|
+
Agent: {traj.get('agent_name', 'Agent')}
|
|
320
|
+
Window: {traj.get('window_id', 'Unknown')}
|
|
321
|
+
Final P&L: ${traj.get('final_pnl', 0):.2f}
|
|
322
|
+
|
|
323
|
+
Your goal is to make profitable trading decisions based on market analysis."""
|
|
324
|
+
|
|
325
|
+
messages.append({"role": "system", "content": system_content})
|
|
326
|
+
|
|
327
|
+
# Convert steps
|
|
328
|
+
steps = traj.get("steps", [])
|
|
329
|
+
max_steps = self.config.max_steps_per_trajectory
|
|
330
|
+
|
|
331
|
+
if len(steps) > max_steps:
|
|
332
|
+
steps = steps[-max_steps:]
|
|
333
|
+
|
|
334
|
+
for step_idx, step in enumerate(steps):
|
|
335
|
+
if not isinstance(step, dict):
|
|
336
|
+
continue
|
|
337
|
+
|
|
338
|
+
# Get LLM calls if available
|
|
339
|
+
llm_calls = step.get("llmCalls", step.get("llm_calls", []))
|
|
340
|
+
|
|
341
|
+
if llm_calls:
|
|
342
|
+
for llm_call in llm_calls:
|
|
343
|
+
purpose = llm_call.get("purpose", "action")
|
|
344
|
+
user_prompt = llm_call.get(
|
|
345
|
+
"userPrompt", llm_call.get("user_prompt", "")
|
|
346
|
+
)
|
|
347
|
+
|
|
348
|
+
# Build user content
|
|
349
|
+
user_content = f"[Step {step_idx + 1}, {purpose.upper()}]\n"
|
|
350
|
+
|
|
351
|
+
env_state = step.get(
|
|
352
|
+
"environmentState", step.get("environment_state", {})
|
|
353
|
+
)
|
|
354
|
+
if env_state:
|
|
355
|
+
balance = env_state.get(
|
|
356
|
+
"agentBalance", env_state.get("agent_balance", 0)
|
|
357
|
+
)
|
|
358
|
+
pnl = env_state.get("agentPnL", env_state.get("agent_pnl", 0))
|
|
359
|
+
positions = env_state.get(
|
|
360
|
+
"openPositions", env_state.get("open_positions", 0)
|
|
361
|
+
)
|
|
362
|
+
user_content += (
|
|
363
|
+
f"State: Balance=${balance:.2f}, "
|
|
364
|
+
f"P&L=${pnl:.2f}, Positions={positions}\n\n"
|
|
365
|
+
)
|
|
366
|
+
|
|
367
|
+
if user_prompt:
|
|
368
|
+
user_content += user_prompt
|
|
369
|
+
|
|
370
|
+
messages.append({"role": "user", "content": user_content})
|
|
371
|
+
|
|
372
|
+
# Assistant response
|
|
373
|
+
response = llm_call.get("response", "")
|
|
374
|
+
reasoning = llm_call.get("reasoning", "")
|
|
375
|
+
|
|
376
|
+
assistant_content = ""
|
|
377
|
+
if reasoning:
|
|
378
|
+
assistant_content += f"<thinking>\n{reasoning}\n</thinking>\n\n"
|
|
379
|
+
if response:
|
|
380
|
+
assistant_content += response
|
|
381
|
+
|
|
382
|
+
if assistant_content.strip():
|
|
383
|
+
messages.append(
|
|
384
|
+
{"role": "assistant", "content": assistant_content}
|
|
385
|
+
)
|
|
386
|
+
else:
|
|
387
|
+
# Fallback: build from environment state and action
|
|
388
|
+
env_state = step.get(
|
|
389
|
+
"environmentState", step.get("environment_state", {})
|
|
390
|
+
)
|
|
391
|
+
balance = env_state.get(
|
|
392
|
+
"agentBalance", env_state.get("agent_balance", 0)
|
|
393
|
+
)
|
|
394
|
+
pnl = env_state.get("agentPnL", env_state.get("agent_pnl", 0))
|
|
395
|
+
positions = env_state.get(
|
|
396
|
+
"openPositions", env_state.get("open_positions", 0)
|
|
397
|
+
)
|
|
398
|
+
|
|
399
|
+
user_content = (
|
|
400
|
+
f"[Step {step_idx + 1}]\n"
|
|
401
|
+
f"Market Update:\n"
|
|
402
|
+
f"- Balance: ${balance:.2f}\n"
|
|
403
|
+
f"- P&L: ${pnl:.2f}\n"
|
|
404
|
+
f"- Open Positions: {positions}"
|
|
405
|
+
)
|
|
406
|
+
|
|
407
|
+
messages.append({"role": "user", "content": user_content})
|
|
408
|
+
|
|
409
|
+
# Action as assistant message
|
|
410
|
+
action = step.get("action", {})
|
|
411
|
+
action_type = action.get(
|
|
412
|
+
"actionType", action.get("action_type", "wait")
|
|
413
|
+
)
|
|
414
|
+
params = action.get("parameters", {})
|
|
415
|
+
reasoning = action.get("reasoning", "")
|
|
416
|
+
|
|
417
|
+
assistant_content = ""
|
|
418
|
+
if reasoning:
|
|
419
|
+
assistant_content += f"<thinking>\n{reasoning}\n</thinking>\n\n"
|
|
420
|
+
assistant_content += f"Action: {action_type}"
|
|
421
|
+
if params:
|
|
422
|
+
assistant_content += f"\nParameters: {json.dumps(params, indent=2)}"
|
|
423
|
+
|
|
424
|
+
messages.append({"role": "assistant", "content": assistant_content})
|
|
425
|
+
|
|
426
|
+
return messages
|
|
427
|
+
|
|
428
|
+
async def score_trajectories(
|
|
429
|
+
self, trajectories: List[dict]
|
|
430
|
+
) -> List[float]:
|
|
431
|
+
"""Score trajectories using LLM judge (RLAIF)"""
|
|
432
|
+
# Build judge prompt
|
|
433
|
+
prompt_parts = [
|
|
434
|
+
"# Trading Agent Evaluation\n",
|
|
435
|
+
"Score each trajectory from 0.0 to 1.0 based on:\n",
|
|
436
|
+
"- Profitability (higher P&L = higher score)\n",
|
|
437
|
+
"- Risk management\n",
|
|
438
|
+
"- Decision quality\n\n",
|
|
439
|
+
"## Trajectories:\n",
|
|
440
|
+
]
|
|
441
|
+
|
|
442
|
+
for i, traj in enumerate(trajectories):
|
|
443
|
+
prompt_parts.append(f"\n### Trajectory {i + 1}:")
|
|
444
|
+
prompt_parts.append(f"- Agent: {traj.get('agent_name', 'Unknown')}")
|
|
445
|
+
prompt_parts.append(f"- Final P&L: ${traj.get('final_pnl', 0):.2f}")
|
|
446
|
+
prompt_parts.append(f"- Episode Length: {traj.get('episode_length', 0)}")
|
|
447
|
+
|
|
448
|
+
prompt_parts.append("\n## Output (JSON only):")
|
|
449
|
+
prompt_parts.append(
|
|
450
|
+
'{"scores": [{"trajectory_id": 1, "score": 0.85}, ...]}'
|
|
451
|
+
)
|
|
452
|
+
|
|
453
|
+
judge_prompt = "\n".join(prompt_parts)
|
|
454
|
+
|
|
455
|
+
# Call judge
|
|
456
|
+
response = await self._judge_client.chat.completions.create(
|
|
457
|
+
model=self.config.judge_model,
|
|
458
|
+
messages=[
|
|
459
|
+
{
|
|
460
|
+
"role": "system",
|
|
461
|
+
"content": "You are an expert evaluator. Respond with valid JSON only.",
|
|
462
|
+
},
|
|
463
|
+
{"role": "user", "content": judge_prompt},
|
|
464
|
+
],
|
|
465
|
+
max_tokens=500,
|
|
466
|
+
temperature=self.config.judge_temperature,
|
|
467
|
+
)
|
|
468
|
+
|
|
469
|
+
# Parse response
|
|
470
|
+
content = response.choices[0].message.content or ""
|
|
471
|
+
try:
|
|
472
|
+
# Clean and parse JSON
|
|
473
|
+
clean = content.strip().replace("```json", "").replace("```", "")
|
|
474
|
+
if "{" in clean:
|
|
475
|
+
start = clean.find("{")
|
|
476
|
+
end = clean.rfind("}") + 1
|
|
477
|
+
parsed = json.loads(clean[start:end])
|
|
478
|
+
scores_data = parsed.get("scores", parsed)
|
|
479
|
+
|
|
480
|
+
scores = []
|
|
481
|
+
for item in scores_data:
|
|
482
|
+
if isinstance(item, dict):
|
|
483
|
+
scores.append(float(item.get("score", 0.5)))
|
|
484
|
+
else:
|
|
485
|
+
scores.append(float(item))
|
|
486
|
+
|
|
487
|
+
if len(scores) == len(trajectories):
|
|
488
|
+
return scores
|
|
489
|
+
|
|
490
|
+
except (json.JSONDecodeError, ValueError, KeyError) as e:
|
|
491
|
+
logger.warning(f"Failed to parse judge response: {e}")
|
|
492
|
+
|
|
493
|
+
# Fallback: P&L-based scoring
|
|
494
|
+
pnls = [t.get("final_pnl", 0) for t in trajectories]
|
|
495
|
+
min_pnl, max_pnl = min(pnls), max(pnls)
|
|
496
|
+
pnl_range = max_pnl - min_pnl if max_pnl != min_pnl else 1.0
|
|
497
|
+
|
|
498
|
+
return [(p - min_pnl) / pnl_range for p in pnls]
|
|
499
|
+
|
|
500
|
+
async def train_on_group(
|
|
501
|
+
self, group: dict
|
|
502
|
+
) -> TrainingMetrics | None:
|
|
503
|
+
"""Train on a single trajectory group"""
|
|
504
|
+
trajectories = group["trajectories"]
|
|
505
|
+
|
|
506
|
+
# Sample if too many
|
|
507
|
+
if len(trajectories) > self.config.group_size:
|
|
508
|
+
import random
|
|
509
|
+
|
|
510
|
+
trajectories = random.sample(trajectories, self.config.group_size)
|
|
511
|
+
|
|
512
|
+
if len(trajectories) < 2:
|
|
513
|
+
logger.warning(f"Group {group['group_key']} has insufficient trajectories")
|
|
514
|
+
return None
|
|
515
|
+
|
|
516
|
+
# Score trajectories
|
|
517
|
+
scores = await self.score_trajectories(trajectories)
|
|
518
|
+
|
|
519
|
+
# Normalize to mean 0 for GRPO
|
|
520
|
+
mean_score = sum(scores) / len(scores)
|
|
521
|
+
advantages = [s - mean_score for s in scores]
|
|
522
|
+
|
|
523
|
+
# Normalize variance
|
|
524
|
+
if len(advantages) > 1:
|
|
525
|
+
std = float(np.std(advantages))
|
|
526
|
+
if std > 1e-8:
|
|
527
|
+
advantages = [a / std for a in advantages]
|
|
528
|
+
|
|
529
|
+
# Convert to training data
|
|
530
|
+
data: List[TinkerDatum] = []
|
|
531
|
+
valid_advantages: List[float] = []
|
|
532
|
+
|
|
533
|
+
for traj, advantage in zip(trajectories, advantages):
|
|
534
|
+
messages = self.trajectory_to_messages(traj)
|
|
535
|
+
|
|
536
|
+
if len(messages) < 3: # Need at least system + user + assistant
|
|
537
|
+
continue
|
|
538
|
+
|
|
539
|
+
# Get last assistant message as completion
|
|
540
|
+
assistant_msgs = [m for m in messages if m["role"] == "assistant"]
|
|
541
|
+
if not assistant_msgs:
|
|
542
|
+
continue
|
|
543
|
+
|
|
544
|
+
completion = assistant_msgs[-1]["content"]
|
|
545
|
+
context_messages = messages[:-1] # All but last
|
|
546
|
+
|
|
547
|
+
# Prepare datum
|
|
548
|
+
datum = self.tinker_client.prepare_datum(
|
|
549
|
+
messages=context_messages,
|
|
550
|
+
completion=completion,
|
|
551
|
+
)
|
|
552
|
+
|
|
553
|
+
data.append(datum)
|
|
554
|
+
valid_advantages.append(advantage)
|
|
555
|
+
|
|
556
|
+
if not data:
|
|
557
|
+
logger.warning("No valid training data from group")
|
|
558
|
+
return None
|
|
559
|
+
|
|
560
|
+
# Train step
|
|
561
|
+
result = self.tinker_client.train_step(
|
|
562
|
+
data=data,
|
|
563
|
+
scores=valid_advantages,
|
|
564
|
+
loss_fn="importance_sampling",
|
|
565
|
+
)
|
|
566
|
+
|
|
567
|
+
return TrainingMetrics(
|
|
568
|
+
step=self.current_step,
|
|
569
|
+
loss=result.loss,
|
|
570
|
+
num_samples=result.num_samples,
|
|
571
|
+
logprobs_mean=result.logprobs_mean,
|
|
572
|
+
pos_advantage_mean=result.pos_advantage_mean,
|
|
573
|
+
neg_advantage_mean=result.neg_advantage_mean,
|
|
574
|
+
avg_score=float(np.mean(scores)),
|
|
575
|
+
)
|
|
576
|
+
|
|
577
|
+
    async def train(self) -> dict:
        """Main training loop.

        Calls ``setup()``, loads trajectory groups once up front, then
        for each training step trains on the next group (cycling through
        the list), logging metrics and syncing weights every
        ``weight_sync_interval`` steps plus a final sync at the end.
        Resources are released via ``cleanup()`` even on failure.

        Returns:
            Summary dict: run id, steps completed, windows processed,
            final synced-weights name, and the metrics file path (or
            ``None`` when file logging is disabled).

        Raises:
            ValueError: if no trajectory groups are available.
        """
        await self.setup()

        try:
            logger.info(f"Starting training for {self.config.training_steps} steps")

            # Load all trajectory groups
            all_groups = await self.load_trajectory_groups()

            if not all_groups:
                raise ValueError("No trajectory groups found")

            group_idx = 0
            windows_processed = 0

            for step in range(self.config.training_steps):
                self.current_step = step + 1
                logger.info(
                    f"Step {self.current_step}/{self.config.training_steps}"
                )

                # Get next group (circular)
                group = all_groups[group_idx % len(all_groups)]
                group_idx += 1

                # Train on group
                metrics = await self.train_on_group(group)

                if metrics:
                    windows_processed += 1
                    metrics.windows_processed = windows_processed

                    logger.info(
                        f" Loss: {metrics.loss:.4f}, "
                        f"Samples: {metrics.num_samples}, "
                        f"Avg Score: {metrics.avg_score:.3f}"
                    )

                    self.log_metrics(metrics)
                else:
                    logger.warning(" No metrics (empty batch)")

                # Sync weights periodically
                if self.current_step % self.config.weight_sync_interval == 0:
                    logger.info("Syncing weights to sampling client...")
                    self.tinker_client.sync_weights(
                        name=f"eliza-{self.run_id}-step-{self.current_step}"
                    )

            # Final weight sync
            final_name = f"eliza-{self.run_id}-final"
            self.tinker_client.sync_weights(name=final_name)
            logger.info(f"Training complete! Final weights: {final_name}")

            return {
                "success": True,
                "run_id": self.run_id,
                "steps": self.current_step,
                "windows_processed": windows_processed,
                "final_weights": final_name,
                "metrics_file": self.config.log_file if self.config.log_to_file else None,
            }

        finally:
            await self.cleanup()
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
# Backward compatibility alias while imports migrate.
# New code should import ``TinkerTrainer`` directly.
BabylonTinkerTrainer = TinkerTrainer
|