@elizaos/training 2.0.0-alpha.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (207) hide show
  1. package/Dockerfile +75 -0
  2. package/Makefile +374 -0
  3. package/README.md +346 -0
  4. package/config/rubrics.json +137 -0
  5. package/data/.gitkeep +0 -0
  6. package/data/degen/.gitkeep +2 -0
  7. package/data/trader/.gitkeep +2 -0
  8. package/docker-compose.test.yml +57 -0
  9. package/package.json +58 -0
  10. package/python/config/babylon_atropos.yaml +90 -0
  11. package/python/config/profiles/12gb.json +11 -0
  12. package/python/config/profiles/16gb.json +10 -0
  13. package/python/config/profiles/24gb.json +10 -0
  14. package/python/config/profiles/48gb.json +10 -0
  15. package/python/config/profiles/cpu.json +11 -0
  16. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  17. package/python/config/profiles/l40-2gpu.json +22 -0
  18. package/python/config/profiles/l40-4gpu.json +21 -0
  19. package/python/config/profiles/l40.json +17 -0
  20. package/python/config/tinker_training.yaml +143 -0
  21. package/python/curriculum_state.json +165 -0
  22. package/python/env.template +86 -0
  23. package/python/env.training.template +46 -0
  24. package/python/pyproject.toml +41 -0
  25. package/python/requirements-ci.txt +31 -0
  26. package/python/requirements.txt +87 -0
  27. package/python/scripts/__init__.py +4 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/run_ab_test.py +143 -0
  36. package/python/scripts/run_full_pipeline.py +544 -0
  37. package/python/scripts/run_tinker_training.py +192 -0
  38. package/python/scripts/run_training.py +914 -0
  39. package/python/scripts/test_judge.py +155 -0
  40. package/python/scripts/test_pipeline.py +356 -0
  41. package/python/scripts/test_trained_model.py +380 -0
  42. package/python/scripts/train_local.py +528 -0
  43. package/python/setup.py +20 -0
  44. package/python/src/__init__.py +190 -0
  45. package/python/src/data_bridge/__init__.py +24 -0
  46. package/python/src/data_bridge/converter.py +435 -0
  47. package/python/src/data_bridge/reader.py +393 -0
  48. package/python/src/models.py +283 -0
  49. package/python/src/training/__init__.py +605 -0
  50. package/python/src/training/ab_testing.py +404 -0
  51. package/python/src/training/action_executor.py +621 -0
  52. package/python/src/training/archetype_trainer.py +347 -0
  53. package/python/src/training/atropos_trainer.py +980 -0
  54. package/python/src/training/babylon_env.py +1254 -0
  55. package/python/src/training/error_recovery.py +647 -0
  56. package/python/src/training/evaluation.py +856 -0
  57. package/python/src/training/fast_simulator.py +880 -0
  58. package/python/src/training/format_validator.py +584 -0
  59. package/python/src/training/hybrid_env.py +522 -0
  60. package/python/src/training/kl_controller.py +628 -0
  61. package/python/src/training/multi_prompt_dataset.py +883 -0
  62. package/python/src/training/multi_turn.py +656 -0
  63. package/python/src/training/online_env.py +1084 -0
  64. package/python/src/training/quality_scorer.py +391 -0
  65. package/python/src/training/quality_utils.py +633 -0
  66. package/python/src/training/rewards.py +1344 -0
  67. package/python/src/training/rlaif_env.py +17 -0
  68. package/python/src/training/rollout_generator.py +502 -0
  69. package/python/src/training/rubric_loader.py +198 -0
  70. package/python/src/training/scenario_pool.py +1072 -0
  71. package/python/src/training/schemas.py +481 -0
  72. package/python/src/training/service_manager.py +552 -0
  73. package/python/src/training/simulation_bridge.py +535 -0
  74. package/python/src/training/tick_reward_attribution.py +399 -0
  75. package/python/src/training/tinker_client.py +575 -0
  76. package/python/src/training/tinker_trainer.py +646 -0
  77. package/python/src/training/tokenization_utils.py +402 -0
  78. package/python/tests/e2e/__init__.py +13 -0
  79. package/python/tests/e2e/conftest.py +258 -0
  80. package/python/tests/e2e/test_full_pipeline.py +643 -0
  81. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  82. package/python/tests/integration/__init__.py +12 -0
  83. package/python/tests/integration/conftest.py +383 -0
  84. package/python/tests/integration/test_db_integration.py +649 -0
  85. package/python/tests/integration/test_json_mode_integration.py +554 -0
  86. package/python/tests/test_action_executor.py +594 -0
  87. package/python/tests/test_archetype_scoring.py +1027 -0
  88. package/python/tests/test_atropos_integration.py +360 -0
  89. package/python/tests/test_evaluation.py +727 -0
  90. package/python/tests/test_format_validator.py +486 -0
  91. package/python/tests/test_kl_controller.py +432 -0
  92. package/python/tests/test_lr_scheduler.py +579 -0
  93. package/python/tests/test_multi_turn.py +590 -0
  94. package/python/tests/test_online_env.py +519 -0
  95. package/python/tests/test_quality_scorer.py +474 -0
  96. package/python/tests/test_scenario_pool.py +735 -0
  97. package/python/tests/test_service_manager.py +585 -0
  98. package/python/tests/test_simulation_rollout.py +581 -0
  99. package/python/tests/test_tokenization_utils.py +501 -0
  100. package/python/tests/test_training_orchestrator.py +497 -0
  101. package/python/tests/test_training_output_structure.py +661 -0
  102. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  103. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  104. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  105. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  106. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  107. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  108. package/scripts/assess-training-data.ts +422 -0
  109. package/scripts/e2e-training-test.ts +550 -0
  110. package/scripts/export-rubrics.ts +64 -0
  111. package/scripts/generate-research-report.ts +1523 -0
  112. package/scripts/generate_dataset.sh +173 -0
  113. package/scripts/json-mode-benchmark.ts +399 -0
  114. package/scripts/real-archetype-benchmark.ts +210 -0
  115. package/scripts/run-baseline-comparison.ts +116 -0
  116. package/scripts/run-full-pipeline.ts +272 -0
  117. package/scripts/runpod_setup.sh +137 -0
  118. package/scripts/runpod_validate.sh +147 -0
  119. package/scripts/test-model-in-game.ts +955 -0
  120. package/scripts/test-scoring.ts +73 -0
  121. package/scripts/test-trained-model.ts +209 -0
  122. package/scripts/train-and-test.ts +824 -0
  123. package/scripts/verify-final.ts +118 -0
  124. package/src/adapter.ts +516 -0
  125. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  126. package/src/archetypes/derive-archetype.ts +249 -0
  127. package/src/archetypes/index.ts +22 -0
  128. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  129. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  130. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  131. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  132. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  133. package/src/benchmark/BenchmarkRunner.ts +685 -0
  134. package/src/benchmark/BenchmarkValidator.ts +206 -0
  135. package/src/benchmark/FastEvalRunner.ts +225 -0
  136. package/src/benchmark/MetricsValidator.ts +165 -0
  137. package/src/benchmark/MetricsVisualizer.ts +909 -0
  138. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  139. package/src/benchmark/ModelRegistry.ts +158 -0
  140. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  141. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  142. package/src/benchmark/SimulationEngine.ts +832 -0
  143. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  144. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  145. package/src/benchmark/index.ts +89 -0
  146. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  147. package/src/benchmark/simulation-types.ts +78 -0
  148. package/src/dependencies.ts +439 -0
  149. package/src/generation/TrajectoryGenerator.ts +387 -0
  150. package/src/generation/index.ts +12 -0
  151. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  152. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  153. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  154. package/src/huggingface/index.ts +27 -0
  155. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  156. package/src/index.ts +102 -0
  157. package/src/init-training.ts +53 -0
  158. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  159. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  160. package/src/metrics/index.ts +8 -0
  161. package/src/metrics/types.ts +200 -0
  162. package/src/rubrics/__tests__/index.test.ts +184 -0
  163. package/src/rubrics/ass-kisser.ts +85 -0
  164. package/src/rubrics/degen.ts +80 -0
  165. package/src/rubrics/goody-twoshoes.ts +84 -0
  166. package/src/rubrics/index.ts +236 -0
  167. package/src/rubrics/information-trader.ts +84 -0
  168. package/src/rubrics/infosec.ts +101 -0
  169. package/src/rubrics/liar.ts +104 -0
  170. package/src/rubrics/perps-trader.ts +87 -0
  171. package/src/rubrics/researcher.ts +81 -0
  172. package/src/rubrics/scammer.ts +82 -0
  173. package/src/rubrics/social-butterfly.ts +73 -0
  174. package/src/rubrics/super-predictor.ts +97 -0
  175. package/src/rubrics/trader.ts +67 -0
  176. package/src/scoring/ArchetypeScoringService.ts +486 -0
  177. package/src/scoring/JudgePromptBuilder.ts +556 -0
  178. package/src/scoring/LLMJudgeCache.ts +401 -0
  179. package/src/scoring/index.ts +9 -0
  180. package/src/training/AutomationPipeline.ts +916 -0
  181. package/src/training/BenchmarkService.ts +518 -0
  182. package/src/training/ConfigValidator.ts +220 -0
  183. package/src/training/MarketOutcomesTracker.ts +187 -0
  184. package/src/training/ModelDeployer.ts +186 -0
  185. package/src/training/ModelFetcher.ts +76 -0
  186. package/src/training/ModelSelectionService.ts +341 -0
  187. package/src/training/ModelUsageVerifier.ts +160 -0
  188. package/src/training/MultiModelOrchestrator.ts +580 -0
  189. package/src/training/RLModelConfig.ts +407 -0
  190. package/src/training/RewardBackpropagationService.ts +149 -0
  191. package/src/training/RulerScoringService.ts +666 -0
  192. package/src/training/TrainingMonitor.ts +166 -0
  193. package/src/training/TrajectoryRecorder.ts +399 -0
  194. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  195. package/src/training/index.ts +100 -0
  196. package/src/training/logRLConfig.ts +34 -0
  197. package/src/training/pipeline.ts +129 -0
  198. package/src/training/storage/ModelStorageService.ts +279 -0
  199. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  200. package/src/training/storage/index.ts +17 -0
  201. package/src/training/types.ts +207 -0
  202. package/src/training/window-utils.ts +138 -0
  203. package/src/utils/index.ts +101 -0
  204. package/src/utils/logger.ts +59 -0
  205. package/src/utils/snowflake.ts +17 -0
  206. package/src/utils/synthetic-detector.ts +111 -0
  207. package/tsconfig.json +20 -0
@@ -0,0 +1,646 @@
1
+ """
2
+ Tinker Trainer
3
+
4
+ Lightweight GRPO trainer using Tinker API.
5
+ Replaces heavy local vLLM + PyTorch training with cloud-based training.
6
+
7
+ This trainer:
8
+ 1. Uses TinkerClient for training and inference
9
+ 2. Integrates with RLAIFEnv for trajectory collection
10
+ 3. Implements GRPO/IS training loop
11
+ 4. Handles weight synchronization
12
+
13
+ Benefits over local training:
14
+ - No local GPU required
15
+ - Access to larger models (Qwen3-235B)
16
+ - Faster weight sync (no vLLM restarts)
17
+ - Better on-policy training with low staleness
18
+ - Pay only for training time, not idle GPU
19
+
20
+ Based on: tinker-atropos integration (Nous Research)
21
+ """
22
+
23
+ import json
24
+ import logging
25
+ import os
26
+ from dataclasses import dataclass, field
27
+ from datetime import datetime, timezone
28
+ from pathlib import Path
29
+ from typing import List
30
+
31
+ import numpy as np
32
+ from dotenv import load_dotenv
33
+ from pydantic import BaseModel, Field
34
+
35
+ from .tinker_client import (
36
+ TinkerClient,
37
+ TinkerConfig,
38
+ TinkerDatum,
39
+ TINKER_AVAILABLE,
40
+ )
41
+
42
logger = logging.getLogger(__name__)

# Environment files: .env.local takes precedence (loaded first with
# override=True), then .env fills in anything still unset (override=False).
project_root = Path(__file__).parent.parent.parent.parent
env_path = project_root / ".env"
env_local_path = project_root / ".env.local"

for _dotenv_file, _override in ((env_local_path, True), (env_path, False)):
    if _dotenv_file.exists():
        load_dotenv(_dotenv_file, override=_override)
53
+
54
+
55
class TinkerTrainingConfig(BaseModel):
    """Configuration for Tinker-based training.

    Groups model selection, GRPO hyperparameters, data-loading filters,
    RLAIF judge settings, and logging/inference options. All values have
    defaults; ``database_url`` is read from the DATABASE_URL env var.
    """

    # Model settings
    base_model: str = Field(
        default="Qwen/Qwen3-30B-A3B-Instruct",
        description="Base model from Tinker's supported models",
    )
    lora_rank: int = Field(default=32, description="LoRA rank for fine-tuning")

    # Training hyperparameters
    learning_rate: float = Field(default=4e-5, description="Learning rate")
    training_steps: int = Field(default=100, description="Number of training steps")
    group_size: int = Field(default=4, description="Group size for GRPO comparison")

    # Weight sync settings
    weight_sync_interval: int = Field(
        default=5, description="Sync weights to sampler every N steps"
    )

    # Environment settings
    database_url: str = Field(
        default_factory=lambda: os.getenv("DATABASE_URL", ""),
        description="PostgreSQL connection URL",
    )
    lookback_hours: int = Field(
        default=72, description="Hours to look back for trajectories"
    )
    min_agents_per_window: int = Field(
        default=2, description="Minimum agents per window"
    )
    min_actions_per_trajectory: int = Field(
        default=3, description="Minimum actions per trajectory"
    )
    max_steps_per_trajectory: int = Field(
        default=20, description="Max steps to include per trajectory"
    )
    max_token_length: int = Field(default=4096, description="Maximum sequence length")

    # RLAIF Judge settings
    judge_model: str = Field(default="gpt-4o-mini", description="Model for RLAIF judge")
    judge_temperature: float = Field(default=0.3, description="Judge temperature")

    # Logging settings
    log_to_file: bool = Field(default=True, description="Log metrics to file")
    log_file: str = Field(
        default="./logs/tinker_training_metrics.jsonl", description="Metrics log file"
    )

    # Inference settings
    inference_max_tokens: int = Field(
        default=512, description="Max tokens for inference"
    )
    inference_temperature: float = Field(
        default=0.7, description="Temperature for inference"
    )
111
+
112
+
113
@dataclass
class TrainingMetrics:
    """Metrics recorded for a single training step."""

    step: int  # 1-based training step index
    loss: float  # loss returned by the Tinker train_step call
    num_samples: int  # number of trajectory datums in the batch
    logprobs_mean: float = 0.0
    pos_advantage_mean: float = 0.0  # mean of positive GRPO advantages
    neg_advantage_mean: float = 0.0  # mean of negative GRPO advantages
    avg_score: float = 0.0  # mean judge score before advantage normalization
    windows_processed: int = 0  # running count of groups trained so far
    # UTC ISO-8601 timestamp captured when the metrics object is created
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
126
+
127
+
128
class TinkerTrainer:
    """
    GRPO Trainer using Tinker API.

    This replaces local heavyweight trainer flows with a lighter implementation:
    - No local vLLM management
    - No GPU requirements on training machine
    - Training happens in Tinker cloud
    - Only data loading runs locally

    The training loop:
    1. Load trajectory groups from database
    2. Score trajectories using LLM judge (RLAIF)
    3. Convert to training format
    4. Call Tinker for forward_backward + optim_step
    5. Periodically sync weights to sampling client
    """

    def __init__(self, config: TinkerTrainingConfig):
        """Build the trainer. Raises RuntimeError if the tinker package is absent."""
        if not TINKER_AVAILABLE:
            raise RuntimeError(
                "Tinker not installed. Install with: pip install tinker"
            )

        self.config = config
        # Mirror the relevant training/inference settings into the Tinker client config.
        self.tinker_config = TinkerConfig(
            base_model=config.base_model,
            lora_rank=config.lora_rank,
            learning_rate=config.learning_rate,
            default_max_tokens=config.inference_max_tokens,
            default_temperature=config.inference_temperature,
        )
        self.tinker_client = TinkerClient(self.tinker_config)

        self.current_step = 0
        # Run ID doubles as the checkpoint-name prefix used in sync_weights().
        self.run_id = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
        self.all_metrics: List[TrainingMetrics] = []

        # Database pool (lazy init in _connect_database)
        self._db_pool = None

        # Judge client (lazy init in _init_judge)
        self._judge_client = None

    async def setup(self) -> None:
        """Initialize Tinker client, metrics log directory, database, and judge."""
        logger.info(f"Setting up Tinker trainer with {self.config.base_model}")
        logger.info(f"Run ID: {self.run_id}")

        # Initialize Tinker
        self.tinker_client.setup()
        logger.info("Tinker client initialized")

        # Setup logging: ensure the metrics file's parent directory exists.
        if self.config.log_to_file:
            log_dir = Path(self.config.log_file).parent
            log_dir.mkdir(parents=True, exist_ok=True)
            logger.info(f"Metrics will be logged to: {self.config.log_file}")

        # Connect to database
        await self._connect_database()

        # Initialize judge
        await self._init_judge()

        logger.info("Setup complete")

    async def _connect_database(self) -> None:
        """Connect to PostgreSQL database.

        Raises:
            ValueError: if the config carries no DATABASE_URL.
        """
        import asyncpg

        if not self.config.database_url:
            raise ValueError("DATABASE_URL not set")

        self._db_pool = await asyncpg.create_pool(
            self.config.database_url,
            min_size=2,
            max_size=10,
            command_timeout=60,
        )
        logger.info("Connected to database")

    async def _init_judge(self) -> None:
        """Initialize OpenAI client for RLAIF judge.

        NOTE(review): AsyncOpenAI() reads OPENAI_API_KEY from the environment;
        no explicit key is passed here — confirm the env is populated by the
        dotenv loading at module import.
        """
        import openai

        self._judge_client = openai.AsyncOpenAI()
        logger.info(f"Judge initialized with model: {self.config.judge_model}")

    async def cleanup(self) -> None:
        """Clean up resources (currently just the DB pool); safe to call twice."""
        if self._db_pool:
            await self._db_pool.close()
            self._db_pool = None
            logger.info("Database connection closed")

    def log_metrics(self, metrics: TrainingMetrics) -> None:
        """Append metrics to the JSONL log file (if enabled) and to in-memory history."""
        if self.config.log_to_file:
            metrics_dict = {
                "timestamp": metrics.timestamp,
                "run_id": self.run_id,
                "step": metrics.step,
                "loss": metrics.loss,
                "num_samples": metrics.num_samples,
                "logprobs_mean": metrics.logprobs_mean,
                "pos_advantage_mean": metrics.pos_advantage_mean,
                "neg_advantage_mean": metrics.neg_advantage_mean,
                "avg_score": metrics.avg_score,
                "windows_processed": metrics.windows_processed,
            }
            # Append mode: one JSON object per line (JSONL).
            with open(self.config.log_file, "a") as f:
                f.write(json.dumps(metrics_dict) + "\n")

        self.all_metrics.append(metrics)

    async def load_trajectory_groups(self) -> List[dict]:
        """Load trajectory groups from database.

        Returns:
            A list of dicts ``{"group_key": str, "trajectories": [dict, ...]}``,
            one per (window, scenario) pair that has at least
            ``min_agents_per_window`` qualifying trajectories.

        Raises:
            RuntimeError: if called before ``setup()`` connected the pool.
        """
        if not self._db_pool:
            raise RuntimeError("Database not connected")

        async with self._db_pool.acquire() as conn:
            # $1 is an interval string like "72 hours"; $2 filters short episodes
            # at the SQL level (re-checked per-row below after JSON parsing).
            rows = await conn.fetch(
                """
                SELECT
                    t."trajectoryId",
                    t."agentId",
                    t."windowId",
                    t."scenarioId",
                    t."stepsJson",
                    t."finalPnL",
                    t."episodeLength",
                    t."totalReward",
                    u.username as agent_name
                FROM trajectories t
                LEFT JOIN "User" u ON t."agentId" = u.id
                WHERE
                    t."createdAt" > NOW() - $1::interval
                    AND t."stepsJson" IS NOT NULL
                    AND t."stepsJson"::text != 'null'
                    AND t."stepsJson"::text != '[]'
                    AND t."episodeLength" >= $2
                ORDER BY t."windowId", t."scenarioId", t."createdAt"
                """,
                f"{self.config.lookback_hours} hours",
                self.config.min_actions_per_trajectory,
            )

        # Group by window/scenario
        groups: dict = {}
        for row in rows:
            group_key = f"{row['windowId']}_{row['scenarioId'] or 'default'}"

            if group_key not in groups:
                groups[group_key] = []

            # NOTE(review): json.loads assumes stepsJson arrives as TEXT;
            # if the column is JSONB asyncpg may already decode it — confirm.
            steps = json.loads(row["stepsJson"] or "[]")
            if len(steps) < self.config.min_actions_per_trajectory:
                continue

            groups[group_key].append(
                {
                    "trajectory_id": row["trajectoryId"],
                    "agent_id": row["agentId"],
                    # Fall back to a short agent-id prefix when no username joined.
                    "agent_name": row["agent_name"] or row["agentId"][:8],
                    "window_id": row["windowId"],
                    "scenario_id": row["scenarioId"],
                    "steps": steps,
                    "final_pnl": float(row["finalPnL"] or 0),
                    "episode_length": row["episodeLength"] or len(steps),
                    "total_reward": float(row["totalReward"] or 0),
                }
            )

        # Filter groups with enough trajectories for a GRPO comparison.
        valid_groups = [
            {"group_key": k, "trajectories": v}
            for k, v in groups.items()
            if len(v) >= self.config.min_agents_per_window
        ]

        logger.info(f"Loaded {len(valid_groups)} trajectory groups")
        return valid_groups

    def trajectory_to_messages(self, traj: dict) -> List[dict]:
        """Convert a trajectory dict to OpenAI-style chat messages.

        Emits a system message, then per step either the recorded LLM calls
        (preferred) or a synthesized user/assistant pair built from the
        environment state and action. Only the last
        ``max_steps_per_trajectory`` steps are kept. Step dicts use both
        camelCase and snake_case keys, so every lookup tries both.
        """
        messages = []

        # System message
        system_content = f"""You are a trading agent in a prediction market simulation.

Agent: {traj.get('agent_name', 'Agent')}
Window: {traj.get('window_id', 'Unknown')}
Final P&L: ${traj.get('final_pnl', 0):.2f}

Your goal is to make profitable trading decisions based on market analysis."""

        messages.append({"role": "system", "content": system_content})

        # Convert steps, keeping only the most recent ones.
        steps = traj.get("steps", [])
        max_steps = self.config.max_steps_per_trajectory

        if len(steps) > max_steps:
            steps = steps[-max_steps:]

        for step_idx, step in enumerate(steps):
            if not isinstance(step, dict):
                continue

            # Get LLM calls if available
            llm_calls = step.get("llmCalls", step.get("llm_calls", []))

            if llm_calls:
                for llm_call in llm_calls:
                    purpose = llm_call.get("purpose", "action")
                    user_prompt = llm_call.get(
                        "userPrompt", llm_call.get("user_prompt", "")
                    )

                    # Build user content
                    user_content = f"[Step {step_idx + 1}, {purpose.upper()}]\n"

                    env_state = step.get(
                        "environmentState", step.get("environment_state", {})
                    )
                    if env_state:
                        balance = env_state.get(
                            "agentBalance", env_state.get("agent_balance", 0)
                        )
                        pnl = env_state.get("agentPnL", env_state.get("agent_pnl", 0))
                        positions = env_state.get(
                            "openPositions", env_state.get("open_positions", 0)
                        )
                        user_content += (
                            f"State: Balance=${balance:.2f}, "
                            f"P&L=${pnl:.2f}, Positions={positions}\n\n"
                        )

                    if user_prompt:
                        user_content += user_prompt

                    messages.append({"role": "user", "content": user_content})

                    # Assistant response: reasoning wrapped in <thinking> tags.
                    response = llm_call.get("response", "")
                    reasoning = llm_call.get("reasoning", "")

                    assistant_content = ""
                    if reasoning:
                        assistant_content += f"<thinking>\n{reasoning}\n</thinking>\n\n"
                    if response:
                        assistant_content += response

                    # Skip empty assistant turns entirely.
                    if assistant_content.strip():
                        messages.append(
                            {"role": "assistant", "content": assistant_content}
                        )
            else:
                # Fallback: build from environment state and action
                env_state = step.get(
                    "environmentState", step.get("environment_state", {})
                )
                balance = env_state.get(
                    "agentBalance", env_state.get("agent_balance", 0)
                )
                pnl = env_state.get("agentPnL", env_state.get("agent_pnl", 0))
                positions = env_state.get(
                    "openPositions", env_state.get("open_positions", 0)
                )

                user_content = (
                    f"[Step {step_idx + 1}]\n"
                    f"Market Update:\n"
                    f"- Balance: ${balance:.2f}\n"
                    f"- P&L: ${pnl:.2f}\n"
                    f"- Open Positions: {positions}"
                )

                messages.append({"role": "user", "content": user_content})

                # Action as assistant message
                action = step.get("action", {})
                action_type = action.get(
                    "actionType", action.get("action_type", "wait")
                )
                params = action.get("parameters", {})
                reasoning = action.get("reasoning", "")

                assistant_content = ""
                if reasoning:
                    assistant_content += f"<thinking>\n{reasoning}\n</thinking>\n\n"
                assistant_content += f"Action: {action_type}"
                if params:
                    assistant_content += f"\nParameters: {json.dumps(params, indent=2)}"

                messages.append({"role": "assistant", "content": assistant_content})

        return messages

    async def score_trajectories(
        self, trajectories: List[dict]
    ) -> List[float]:
        """Score trajectories using LLM judge (RLAIF).

        Returns one score in [0, 1] per trajectory. If the judge response
        cannot be parsed, or yields the wrong number of scores, falls back
        to min-max-normalized final P&L.
        """
        # Build judge prompt
        prompt_parts = [
            "# Trading Agent Evaluation\n",
            "Score each trajectory from 0.0 to 1.0 based on:\n",
            "- Profitability (higher P&L = higher score)\n",
            "- Risk management\n",
            "- Decision quality\n\n",
            "## Trajectories:\n",
        ]

        for i, traj in enumerate(trajectories):
            prompt_parts.append(f"\n### Trajectory {i + 1}:")
            prompt_parts.append(f"- Agent: {traj.get('agent_name', 'Unknown')}")
            prompt_parts.append(f"- Final P&L: ${traj.get('final_pnl', 0):.2f}")
            prompt_parts.append(f"- Episode Length: {traj.get('episode_length', 0)}")

        prompt_parts.append("\n## Output (JSON only):")
        prompt_parts.append(
            '{"scores": [{"trajectory_id": 1, "score": 0.85}, ...]}'
        )

        judge_prompt = "\n".join(prompt_parts)

        # Call judge
        response = await self._judge_client.chat.completions.create(
            model=self.config.judge_model,
            messages=[
                {
                    "role": "system",
                    "content": "You are an expert evaluator. Respond with valid JSON only.",
                },
                {"role": "user", "content": judge_prompt},
            ],
            max_tokens=500,
            temperature=self.config.judge_temperature,
        )

        # Parse response
        content = response.choices[0].message.content or ""
        try:
            # Clean markdown fences, then extract the outermost JSON object.
            clean = content.strip().replace("```json", "").replace("```", "")
            if "{" in clean:
                start = clean.find("{")
                end = clean.rfind("}") + 1
                parsed = json.loads(clean[start:end])
                # Accept either {"scores": [...]} or a bare list.
                scores_data = parsed.get("scores", parsed)

                scores = []
                for item in scores_data:
                    if isinstance(item, dict):
                        scores.append(float(item.get("score", 0.5)))
                    else:
                        scores.append(float(item))

                # Only trust the judge if it scored every trajectory.
                if len(scores) == len(trajectories):
                    return scores

        except (json.JSONDecodeError, ValueError, KeyError) as e:
            logger.warning(f"Failed to parse judge response: {e}")

        # Fallback: P&L-based scoring (min-max normalized; all-equal P&L -> all 0.0)
        pnls = [t.get("final_pnl", 0) for t in trajectories]
        min_pnl, max_pnl = min(pnls), max(pnls)
        pnl_range = max_pnl - min_pnl if max_pnl != min_pnl else 1.0

        return [(p - min_pnl) / pnl_range for p in pnls]

    async def train_on_group(
        self, group: dict
    ) -> TrainingMetrics | None:
        """Train on a single trajectory group.

        Scores the group, converts scores to variance-normalized GRPO
        advantages, builds Tinker datums (last assistant turn as the
        completion), and runs one importance-sampling train step.

        Returns:
            TrainingMetrics for the step, or None if the group produced no
            usable training data.
        """
        trajectories = group["trajectories"]

        # Sample if too many for one GRPO group.
        if len(trajectories) > self.config.group_size:
            import random

            trajectories = random.sample(trajectories, self.config.group_size)

        if len(trajectories) < 2:
            logger.warning(f"Group {group['group_key']} has insufficient trajectories")
            return None

        # Score trajectories
        scores = await self.score_trajectories(trajectories)

        # Normalize to mean 0 for GRPO
        mean_score = sum(scores) / len(scores)
        advantages = [s - mean_score for s in scores]

        # Normalize variance (skip near-zero std to avoid blow-up).
        if len(advantages) > 1:
            std = float(np.std(advantages))
            if std > 1e-8:
                advantages = [a / std for a in advantages]

        # Convert to training data
        data: List[TinkerDatum] = []
        valid_advantages: List[float] = []

        for traj, advantage in zip(trajectories, advantages):
            messages = self.trajectory_to_messages(traj)

            if len(messages) < 3:  # Need at least system + user + assistant
                continue

            # Get last assistant message as completion
            assistant_msgs = [m for m in messages if m["role"] == "assistant"]
            if not assistant_msgs:
                continue

            completion = assistant_msgs[-1]["content"]
            context_messages = messages[:-1]  # All but last

            # Prepare datum
            datum = self.tinker_client.prepare_datum(
                messages=context_messages,
                completion=completion,
            )

            data.append(datum)
            valid_advantages.append(advantage)

        if not data:
            logger.warning("No valid training data from group")
            return None

        # Train step
        result = self.tinker_client.train_step(
            data=data,
            scores=valid_advantages,
            loss_fn="importance_sampling",
        )

        return TrainingMetrics(
            step=self.current_step,
            loss=result.loss,
            num_samples=result.num_samples,
            logprobs_mean=result.logprobs_mean,
            pos_advantage_mean=result.pos_advantage_mean,
            neg_advantage_mean=result.neg_advantage_mean,
            avg_score=float(np.mean(scores)),
        )

    async def train(self) -> dict:
        """Main training loop.

        Runs ``setup()``, iterates ``training_steps`` times cycling through
        trajectory groups, syncs weights every ``weight_sync_interval``
        steps plus once at the end, and always cleans up.

        Returns:
            Summary dict with run id, steps run, groups trained, final
            checkpoint name, and metrics file path (or None).

        Raises:
            ValueError: if no trajectory groups are available.
        """
        await self.setup()

        try:
            logger.info(f"Starting training for {self.config.training_steps} steps")

            # Load all trajectory groups once up front; steps cycle over them.
            all_groups = await self.load_trajectory_groups()

            if not all_groups:
                raise ValueError("No trajectory groups found")

            group_idx = 0
            windows_processed = 0

            for step in range(self.config.training_steps):
                self.current_step = step + 1
                logger.info(
                    f"Step {self.current_step}/{self.config.training_steps}"
                )

                # Get next group (circular)
                group = all_groups[group_idx % len(all_groups)]
                group_idx += 1

                # Train on group
                metrics = await self.train_on_group(group)

                if metrics:
                    windows_processed += 1
                    metrics.windows_processed = windows_processed

                    logger.info(
                        f"  Loss: {metrics.loss:.4f}, "
                        f"Samples: {metrics.num_samples}, "
                        f"Avg Score: {metrics.avg_score:.3f}"
                    )

                    self.log_metrics(metrics)
                else:
                    logger.warning("  No metrics (empty batch)")

                # Sync weights periodically
                if self.current_step % self.config.weight_sync_interval == 0:
                    logger.info("Syncing weights to sampling client...")
                    self.tinker_client.sync_weights(
                        name=f"eliza-{self.run_id}-step-{self.current_step}"
                    )

            # Final weight sync
            final_name = f"eliza-{self.run_id}-final"
            self.tinker_client.sync_weights(name=final_name)
            logger.info(f"Training complete! Final weights: {final_name}")

            return {
                "success": True,
                "run_id": self.run_id,
                "steps": self.current_step,
                "windows_processed": windows_processed,
                "final_weights": final_name,
                "metrics_file": self.config.log_file if self.config.log_to_file else None,
            }

        finally:
            await self.cleanup()
643
+
644
+
645
# Backward-compatibility alias kept while downstream imports migrate
# from BabylonTinkerTrainer to TinkerTrainer.
BabylonTinkerTrainer = TinkerTrainer