@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
"""
Compatibility module exposing neutral RLAIF environment names.

Pure re-export shim: everything here is defined in ``babylon_env``.
Import from this module to use the neutral ``RLAIFEnv`` /
``RLAIFEnvConfig`` names alongside the Babylon-branded originals.
"""

from .babylon_env import (
    RLAIFEnv,
    RLAIFEnvConfig,
    BabylonRLAIFEnv,
    BabylonEnvConfig,
)

# Public API: both the neutral aliases and the Babylon-named originals.
__all__ = [
    "RLAIFEnv",
    "RLAIFEnvConfig",
    "BabylonRLAIFEnv",
    "BabylonEnvConfig",
]
|
@@ -0,0 +1,502 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Babylon Fast Rollout Generator
|
|
3
|
+
|
|
4
|
+
Generates high-quality rollouts at maximum speed for RL training.
|
|
5
|
+
Captures the COMPLETE agent tick including all thinking, planning, and execution.
|
|
6
|
+
|
|
7
|
+
A complete agent tick consists of:
|
|
8
|
+
1. Environment Observation - What the agent sees
|
|
9
|
+
2. Thinking/Reasoning - Internal deliberation
|
|
10
|
+
3. Planning - What actions to take
|
|
11
|
+
4. Action Execution - The actual action
|
|
12
|
+
5. Feedback - Result and reward
|
|
13
|
+
|
|
14
|
+
We need to capture ALL of this for training.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
import json
|
|
19
|
+
import logging
|
|
20
|
+
from dataclasses import dataclass, field
|
|
21
|
+
from typing import Callable, Protocol
|
|
22
|
+
import time
|
|
23
|
+
|
|
24
|
+
from ..models import (
|
|
25
|
+
BabylonTrajectory,
|
|
26
|
+
LLMCall,
|
|
27
|
+
Action,
|
|
28
|
+
EnvironmentState,
|
|
29
|
+
)
|
|
30
|
+
from .quality_utils import (
|
|
31
|
+
calculate_trajectory_quality_score,
|
|
32
|
+
build_trajectory_from_ticks,
|
|
33
|
+
state_to_observation,
|
|
34
|
+
state_to_env_state,
|
|
35
|
+
calculate_detailed_tick_quality
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
from .rewards import TrajectoryRewardInputs, composite_reward, calculate_risk_reward
|
|
39
|
+
|
|
40
|
+
logger = logging.getLogger(__name__)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class AgentTickData:
    """
    Complete data for a single agent tick.

    This captures EVERYTHING the agent does in one tick:
    - All observations received
    - All LLM calls made (thinking, planning, action)
    - The final action taken
    - Feedback received
    """
    tick_number: int
    timestamp: int

    # Environment observation
    observation: dict
    environment_state: EnvironmentState

    # All LLM calls during this tick
    llm_calls: list[LLMCall] = field(default_factory=list)

    # The reasoning chain (concatenated thinking)
    reasoning_chain: str = ""

    # Final action
    action: Action | None = None

    # Feedback from environment
    feedback: dict = field(default_factory=dict)
    reward: float = 0.0

    def get_full_context(self) -> str:
        """Render every captured piece of this tick as a single text blob."""
        # Observation header and payload come first, unconditionally.
        lines: list[str] = [
            f"=== OBSERVATION (Tick {self.tick_number}) ===",
            json.dumps(self.observation, indent=2),
        ]

        # Every LLM call, in the order it was made (1-based labels).
        for idx, llm_call in enumerate(self.llm_calls, start=1):
            lines.append(f"\n=== LLM CALL {idx} ({llm_call.purpose}) ===")
            lines.append(f"System: {llm_call.system_prompt}")
            lines.append(f"User: {llm_call.user_prompt}")
            lines.append(f"Response: {llm_call.response}")
            if llm_call.reasoning:
                lines.append(f"Reasoning: {llm_call.reasoning}")

        # Final action section only when an action exists.
        if self.action:
            lines.extend([
                "\n=== ACTION ===",
                f"Type: {self.action.action_type}",
                f"Parameters: {json.dumps(self.action.parameters)}",
            ])
            if self.action.reasoning:
                lines.append(f"Reasoning: {self.action.reasoning}")

        # Feedback section only when feedback is non-empty (truthy dict).
        if self.feedback:
            lines.extend([
                "\n=== FEEDBACK ===",
                json.dumps(self.feedback, indent=2),
                f"Reward: {self.reward}",
            ])

        return "\n".join(lines)
|
+
|
|
107
|
+
|
|
108
|
+
@dataclass
class RolloutConfig:
    """Configuration for rollout generation"""

    # Speed settings
    fast_forward: bool = True  # presumably disables waiting between ticks; not read in this module — TODO confirm
    parallel_agents: int = 4  # NOTE(review): not read in this module; callers likely size agent batches with it — confirm
    max_ticks_per_agent: int = 100  # hard cap on ticks per rollout (loop bound in generate_rollout)

    # Quality settings
    min_llm_calls_per_tick: int = 1  # ticks with fewer LLM calls fail _validate_tick_quality
    require_action: bool = True  # when True, a tick without an action fails _validate_tick_quality

    # Database settings
    database_url: str = ""  # NOTE(review): not read in this module; presumably used by persistence callers — confirm
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@dataclass
class RolloutResult:
    """Result of a rollout generation run"""
    agent_id: str  # agent that produced this rollout
    trajectory_id: str  # "rollout-<agent_id>-<start ms>" (see generate_rollout)
    ticks_completed: int  # number of ticks executed before the loop ended
    total_duration_ms: int  # wall-clock duration of the whole rollout
    avg_tick_duration_ms: float  # mean wall-clock time per tick (0 if no ticks)
    total_llm_calls: int  # sum of LLM calls across all ticks
    total_reward: float  # sum of per-tick rewards
    final_pnl: float  # copied from trajectory.final_pnl, 0.0 when no trajectory
    quality_score: float  # 0-1 based on completeness
    trajectory: BabylonTrajectory | None = None  # full trajectory built from the ticks
|
+
|
|
139
|
+
|
|
140
|
+
class AgentRunner(Protocol):
    """Protocol for agent implementations - consistent with FastSimulator"""

    async def run_tick(
        self,
        agent_id: str,
        observation: dict,
        env_state: EnvironmentState,
    ) -> AgentTickData:
        """Run a single tick and return complete tick data.

        Args:
            agent_id: Unique agent identifier.
            observation: Environment observation for this tick.
            env_state: Structured environment state for this agent.

        Returns:
            AgentTickData capturing observation, LLM calls, and action.
            The caller (FastRolloutGenerator) overwrites tick_number and
            timestamp afterwards, so implementations need not set them.
        """
        ...
|
+
|
|
152
|
+
|
|
153
|
+
class FastRolloutGenerator:
    """
    Generates rollouts at maximum speed while maintaining quality.

    Key features:
    - Fast-forward mode skips all waiting
    - Parallel agent execution
    - Complete tick capture (observation → thinking → action → feedback)
    - Quality validation

    NOTE(review): the simulation object uses camelCase methods
    (isComplete/getGameState/performAction/advanceTick) — presumably a
    JS/TS bridge; confirm the expected interface against SimulationEngine.
    """

    def __init__(self, config: RolloutConfig):
        # Running counters aggregated across all rollouts this instance makes.
        self.config = config
        self.rollouts_generated = 0
        self.total_ticks = 0
        self.start_time: float | None = None  # set by generate_parallel_rollouts

    async def generate_rollout(
        self,
        agent: AgentRunner,
        agent_id: str,
        simulation,  # SimulationEngine instance
    ) -> RolloutResult:
        """
        Generate a single rollout from an agent running through simulation.

        Args:
            agent: Agent implementation
            agent_id: Unique agent identifier
            simulation: Simulation engine providing environment

        Returns:
            RolloutResult with complete trajectory data
        """
        start_time = time.time()
        tick_durations: list[float] = []
        all_ticks: list[AgentTickData] = []
        total_llm_calls = 0
        total_reward = 0.0

        # Millisecond timestamp makes the id unique per agent per run.
        trajectory_id = f"rollout-{agent_id}-{int(start_time * 1000)}"

        logger.info(f"Starting rollout generation for agent {agent_id}")

        # Run through simulation
        tick_number = 0
        while not simulation.isComplete() and tick_number < self.config.max_ticks_per_agent:
            tick_start = time.time()

            # Get observation from simulation
            game_state = simulation.getGameState()
            observation = state_to_observation(game_state)
            env_state = state_to_env_state(game_state, agent_id)

            # Agent processes tick (captures all LLM calls)
            tick_data = await agent.run_tick(agent_id, observation, env_state)
            # Overwrite bookkeeping fields: the generator, not the agent,
            # owns tick numbering and timestamps.
            tick_data.tick_number = tick_number
            tick_data.timestamp = int(time.time() * 1000)

            # Execute action in simulation if provided
            # ("wait" is a no-op and is never sent to the simulation).
            if tick_data.action and tick_data.action.action_type != "wait":
                result = await simulation.performAction(
                    tick_data.action.action_type,
                    tick_data.action.parameters,
                )
                # result is assumed to be a dict with optional
                # success/result/error keys — TODO confirm against engine.
                tick_data.feedback = result
                tick_data.action.success = result.get("success", False)
                tick_data.action.result = result.get("result")
                tick_data.action.error = result.get("error")

            # Calculate reward for this tick using The Judge logic
            tick_data.reward = self._calculate_tick_reward(
                tick_data, env_state)
            total_reward += tick_data.reward

            # Validate tick quality — failures are logged but the tick is
            # still recorded (quality gating happens downstream).
            if not self._validate_tick_quality(tick_data):
                logger.warning(f"Tick {tick_number} failed quality check")

            # Store tick
            all_ticks.append(tick_data)
            total_llm_calls += len(tick_data.llm_calls)

            # Track timing
            tick_duration = time.time() - tick_start
            tick_durations.append(tick_duration)

            # Advance simulation (no artificial delay in fast-forward mode)
            simulation.advanceTick()
            tick_number += 1

            # Log progress periodically (rolling average over last 50 ticks)
            if tick_number % 50 == 0:
                avg_tick = sum(tick_durations[-50:]) / \
                    min(50, len(tick_durations))
                logger.info(
                    f"Tick {tick_number}: avg {avg_tick*1000:.1f}ms/tick")

        total_duration_ms = int((time.time() - start_time) * 1000)
        # Ternary guards the zero-tick case (simulation complete at start).
        avg_tick_duration = sum(tick_durations) / \
            len(tick_durations) if tick_durations else 0

        # Build trajectory from ticks
        trajectory = build_trajectory_from_ticks(
            trajectory_id=trajectory_id,
            agent_id=agent_id,
            ticks=all_ticks,
            min_steps=1,
        )

        # Calculate quality score
        quality_score = calculate_trajectory_quality_score(all_ticks)

        result = RolloutResult(
            agent_id=agent_id,
            trajectory_id=trajectory_id,
            ticks_completed=tick_number,
            total_duration_ms=total_duration_ms,
            avg_tick_duration_ms=avg_tick_duration * 1000,
            total_llm_calls=total_llm_calls,
            total_reward=total_reward,
            final_pnl=trajectory.final_pnl if trajectory else 0.0,
            quality_score=quality_score,
            trajectory=trajectory,
        )

        self.rollouts_generated += 1
        self.total_ticks += tick_number

        logger.info(
            f"Rollout complete: {tick_number} ticks in {total_duration_ms}ms "
            f"({avg_tick_duration*1000:.1f}ms/tick), quality={quality_score:.2f}"
        )

        return result

    async def generate_parallel_rollouts(
        self,
        agents: list[tuple[AgentRunner, str]],  # (agent, agent_id)
        simulation_factory: Callable,
    ) -> list[RolloutResult]:
        """
        Generate multiple rollouts in parallel.

        Args:
            agents: List of (agent, agent_id) tuples
            simulation_factory: Factory to create simulation instances

        Returns:
            List of RolloutResults (failed rollouts are dropped, not raised)
        """
        self.start_time = time.time()

        logger.info(
            f"Starting parallel rollout generation for {len(agents)} agents")

        # Create tasks for each agent — each agent gets its own fresh
        # simulation instance from the factory.
        tasks = []
        for agent, agent_id in agents:
            simulation = simulation_factory()
            # NOTE(review): initialize() return value is discarded; if it is
            # a coroutine it would never be awaited — confirm it is sync.
            simulation.initialize()
            tasks.append(self.generate_rollout(agent, agent_id, simulation))

        # Run all in parallel
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Filter out errors — one failed agent does not abort the batch.
        valid_results = []
        for r in results:
            if isinstance(r, Exception):
                logger.error(f"Rollout failed: {r}")
            else:
                valid_results.append(r)

        total_time = time.time() - self.start_time
        logger.info(
            f"Parallel rollout generation complete: "
            f"{len(valid_results)}/{len(agents)} succeeded, "
            f"{self.total_ticks} total ticks in {total_time:.1f}s "
            f"({self.total_ticks/total_time:.1f} ticks/s)"
        )

        return valid_results

    def _calculate_tick_reward(
        self,
        tick_data: AgentTickData,
        env_state: EnvironmentState,
    ) -> float:
        """
        Calculate reward for a single tick.

        Combines:
        1. Financial Performance (PnL)
        2. Format Compliance (XML validation)
        3. Reasoning Alignment (Financial Literacy)
        4. Risk Management (Exposure penalties)
        """
        # 1. Quality Scores (Format & Reasoning)
        fmt_score, rsn_score = calculate_detailed_tick_quality(
            tick_data.llm_calls,
            tick_data.action,
            tick_data.feedback
        )

        # 2. Risk Calculation
        # Exposure proxy: active positions / max reasonable positions (e.g. 10)
        # Or ideally use a dedicated exposure field if available in env_state
        exposure = min(1.0, env_state.open_positions / 10.0)

        # Default to "wait" when the agent produced no action this tick.
        action_type = "wait"
        if tick_data.action:
            action_type = tick_data.action.action_type

        # A negative risk reward marks this tick's action as risky.
        risk_penalty_count = 0
        if calculate_risk_reward(exposure, action_type) < 0:
            risk_penalty_count = 1

        # 3. Financials (PnL for this tick)
        # assumes feedback carries a numeric "pnl_delta" when available — TODO confirm
        pnl = tick_data.feedback.get("pnl_delta", 0.0)

        # Build Inputs for Composite Reward
        inputs = TrajectoryRewardInputs(
            final_pnl=pnl,
            starting_balance=env_state.agent_balance,
            end_balance=env_state.agent_balance + pnl,
            format_score=fmt_score,
            reasoning_score=rsn_score,
            risky_actions_count=risk_penalty_count,
            total_actions=1 if tick_data.action else 0,
            successful_actions=1 if tick_data.action and tick_data.action.success else 0
        )

        return composite_reward(inputs)

    def _validate_tick_quality(self, tick_data: AgentTickData) -> bool:
        """Validate that tick data meets quality requirements.

        Checks (in order): minimum LLM-call count, action presence when
        required, and non-empty responses on every LLM call.
        """
        # Must have LLM calls if configured
        if self.config.min_llm_calls_per_tick > 0:
            if len(tick_data.llm_calls) < self.config.min_llm_calls_per_tick:
                return False

        # Must have action if configured
        if self.config.require_action and tick_data.action is None:
            return False

        # LLM calls must have non-empty responses
        for call in tick_data.llm_calls:
            if not call.response or len(call.response.strip()) == 0:
                return False

        return True
|
+
|
|
406
|
+
|
|
407
|
+
class RolloutQualityValidator:
    """
    Validates that rollouts meet quality standards for training.

    All methods are static; the class is a namespace for validation logic.
    """

    @staticmethod
    def validate_rollout(result: "RolloutResult") -> tuple[bool, list[str]]:
        """
        Validate a rollout result.

        Checks: trajectory presence, minimum tick count (5), at least one
        LLM call per tick on average, quality score >= 0.5, and that every
        trajectory step has fully-populated LLM calls.

        Args:
            result: The rollout to validate.

        Returns:
            (is_valid, list of issues) — is_valid is True iff issues is empty.
        """
        issues = []

        # Must have trajectory — without it no further checks make sense.
        if result.trajectory is None:
            issues.append("No trajectory data")
            return False, issues

        # Minimum ticks
        if result.ticks_completed < 5:
            issues.append(f"Too few ticks: {result.ticks_completed} < 5")

        # Must have LLM calls (at least one per completed tick on average)
        if result.total_llm_calls < result.ticks_completed:
            issues.append(
                f"Low LLM call rate: {result.total_llm_calls} calls for {result.ticks_completed} ticks")

        # Quality score threshold
        if result.quality_score < 0.5:
            issues.append(
                f"Quality score too low: {result.quality_score:.2f} < 0.5")

        # Check trajectory steps
        traj = result.trajectory
        for i, step in enumerate(traj.steps):
            # Each step should have LLM calls
            if not step.llm_calls:
                issues.append(f"Step {i} has no LLM calls")

            # Each LLM call should have content
            for j, call in enumerate(step.llm_calls):
                if not call.user_prompt or not call.response:
                    issues.append(
                        f"Step {i}, call {j}: missing prompt or response")
                if not call.system_prompt:
                    issues.append(f"Step {i}, call {j}: missing system prompt")

        is_valid = len(issues) == 0
        return is_valid, issues

    @staticmethod
    def print_quality_report(results: list["RolloutResult"]) -> None:
        """Print a quality report for a batch of rollouts.

        Safe on an empty batch and on batches with zero total ticks
        (previously both cases raised ZeroDivisionError).
        """
        print("\n" + "=" * 60)
        print(" ROLLOUT QUALITY REPORT")
        print("=" * 60)

        total = len(results)
        # Guard: avoid division by zero on an empty batch.
        if total == 0:
            print("\nNo rollouts to report.")
            print("=" * 60 + "\n")
            return

        valid_count = 0
        total_ticks = 0
        total_llm_calls = 0
        total_quality = 0.0
        all_issues: list[str] = []

        for result in results:
            is_valid, issues = RolloutQualityValidator.validate_rollout(result)
            if is_valid:
                valid_count += 1
            all_issues.extend(issues)

            total_ticks += result.ticks_completed
            total_llm_calls += result.total_llm_calls
            total_quality += result.quality_score

        print(
            f"\nValid rollouts: {valid_count}/{total} ({valid_count/total*100:.1f}%)")
        print(f"Total ticks: {total_ticks}")
        print(f"Total LLM calls: {total_llm_calls}")
        print(f"Average quality score: {total_quality/total:.2f}")
        # Guard: a batch where every rollout completed zero ticks.
        if total_ticks > 0:
            print(f"LLM calls per tick: {total_llm_calls/total_ticks:.1f}")
        else:
            print("LLM calls per tick: n/a (no ticks)")

        if all_issues:
            print(f"\nIssues found ({len(all_issues)} total):")
            # Group and count issues
            issue_counts: dict[str, int] = {}
            for issue in all_issues:
                # Normalize issue text (strip the variable part after ":")
                key = issue.split(":")[0] if ":" in issue else issue
                issue_counts[key] = issue_counts.get(key, 0) + 1

            # Show the ten most frequent issue categories.
            for issue, count in sorted(issue_counts.items(), key=lambda x: -x[1])[:10]:
                print(f"  - {issue}: {count} occurrences")

        print("=" * 60 + "\n")
|