@elizaos/training 2.0.0-alpha.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/data/.gitkeep +0 -0
- package/data/degen/.gitkeep +2 -0
- package/data/trader/.gitkeep +2 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +58 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +206 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +89 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +439 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,656 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Multi-Turn Episode Manager for GRPO Training
|
|
3
|
+
|
|
4
|
+
⚠️ STATUS: NOT YET INTEGRATED
|
|
5
|
+
This module is ready to use but not currently called by babylon_env.py or online_env.py.
|
|
6
|
+
To integrate, see TRAINING_ROADMAP.md Phase 4.
|
|
7
|
+
|
|
8
|
+
Handles multi-turn trading episodes with proper credit assignment
|
|
9
|
+
using Generalized Advantage Estimation (GAE).
|
|
10
|
+
|
|
11
|
+
For trading scenarios:
|
|
12
|
+
- A single action's value depends on future market movements
|
|
13
|
+
- Subsequent actions affect overall trajectory value
|
|
14
|
+
- Final episode outcome determines success
|
|
15
|
+
|
|
16
|
+
This module provides:
|
|
17
|
+
1. TurnData - Structure for individual turn data
|
|
18
|
+
2. EpisodeBuffer - Buffer for collecting episode turns
|
|
19
|
+
3. MultiTurnEpisodeManager - GAE-based advantage computation
|
|
20
|
+
4. Trajectory utilities for GRPO training
|
|
21
|
+
|
|
22
|
+
Usage:
|
|
23
|
+
manager = MultiTurnEpisodeManager(gamma=0.99, gae_lambda=0.95)
|
|
24
|
+
|
|
25
|
+
# Collect episode turns
|
|
26
|
+
episode = EpisodeBuffer(scenario_id="test-1")
|
|
27
|
+
for turn in range(max_turns):
|
|
28
|
+
turn_data = TurnData(...)
|
|
29
|
+
episode.add_turn(turn_data)
|
|
30
|
+
if done:
|
|
31
|
+
break
|
|
32
|
+
|
|
33
|
+
# Compute advantages
|
|
34
|
+
manager.compute_advantages(episode.turns)
|
|
35
|
+
|
|
36
|
+
# Create training items
|
|
37
|
+
training_items = manager.create_training_items(episode.turns, tokenizer)
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
import logging
|
|
41
|
+
from dataclasses import dataclass, field
|
|
42
|
+
from datetime import datetime, timezone
|
|
43
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
44
|
+
|
|
45
|
+
logger = logging.getLogger(__name__)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# =============================================================================
|
|
49
|
+
# Turn and Episode Data Structures
|
|
50
|
+
# =============================================================================
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
class TurnData:
    """
    Per-turn record for a multi-turn episode.

    Carries everything needed to train on this turn: the raw state and
    action, the quality scores, the reward, the GAE outputs
    (value / advantage / return-to-go), and optional pre-computed
    token data for GRPO.
    """

    # Turn identification
    turn_number: int
    episode_id: str = ""

    # State at this turn
    state: Dict[str, Any] = field(default_factory=dict)
    observation: str = ""

    # Action taken
    action: Dict[str, Any] = field(default_factory=dict)
    action_text: str = ""
    action_type: str = ""

    # Conversation messages accumulated up to this point
    messages: List[Dict[str, str]] = field(default_factory=list)

    # Reward received after the action
    reward: float = 0.0

    # Quality scores
    format_score: float = 0.0
    reasoning_score: float = 0.0

    # Episode termination
    done: bool = False
    termination_reason: str = ""

    # Computed values (filled in by GAE)
    value: float = 0.0
    advantage: float = 0.0
    return_to_go: float = 0.0

    # Token data for training
    tokens: List[int] = field(default_factory=list)
    masks: List[int] = field(default_factory=list)
    logprobs: List[float] = field(default_factory=list)

    # Metadata
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> Dict:
        """Serialize to a compact, rounded dict suitable for logging."""
        summary: Dict[str, Any] = {
            "turn_number": self.turn_number,
            "episode_id": self.episode_id,
            "action_type": self.action_type,
            # Truncated so log lines stay readable
            "action_text": self.action_text[:200],
        }
        summary["reward"] = round(self.reward, 4)
        summary["format_score"] = round(self.format_score, 3)
        summary["reasoning_score"] = round(self.reasoning_score, 3)
        summary["done"] = self.done
        summary["value"] = round(self.value, 4)
        summary["advantage"] = round(self.advantage, 4)
        summary["return_to_go"] = round(self.return_to_go, 4)
        return summary
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@dataclass
class EpisodeBuffer:
    """
    Buffer for collecting turns during an episode.

    Accumulates TurnData as the episode unfolds and, once a turn arrives
    with ``done=True``, freezes the episode-level metrics (length, total
    reward, success flag, end time).
    """

    # Episode identification
    episode_id: str
    scenario_id: str = ""
    archetype: str = "trader"

    # Turns collected, in chronological order
    # (annotation is a forward reference so this class does not depend on
    # definition order within the module)
    turns: List["TurnData"] = field(default_factory=list)

    # Episode-level metrics (computed by _finalize after completion)
    total_reward: float = 0.0
    total_pnl: float = 0.0
    episode_length: int = 0
    completed: bool = False
    success: bool = False

    # Timing
    start_time: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    end_time: Optional[datetime] = None

    def add_turn(self, turn: "TurnData") -> None:
        """Append a turn, stamping it with this episode's id.

        The episode is finalized automatically when ``turn.done`` is True.
        """
        turn.episode_id = self.episode_id
        self.turns.append(turn)

        if turn.done:
            self._finalize()

    def _finalize(self) -> None:
        """Compute episode-level metrics once the episode has ended."""
        self.completed = True
        self.end_time = datetime.now(timezone.utc)
        self.episode_length = len(self.turns)
        self.total_reward = sum(t.reward for t in self.turns)

        # An episode counts as successful when its cumulative reward
        # is positive.
        if self.turns:
            self.success = self.total_reward > 0

    def get_messages(self) -> List[Dict[str, str]]:
        """Return a copy of the chat messages carried by the last turn."""
        if self.turns:
            return self.turns[-1].messages.copy()
        return []

    def get_trajectory(self) -> List[Tuple[str, float, str]]:
        """Get (action_type, reward, truncated action_text) for each turn."""
        return [
            (t.action_type, t.reward, t.action_text[:100])
            for t in self.turns
        ]

    def to_dict(self) -> Dict:
        """Convert to a dictionary for logging."""
        return {
            "episode_id": self.episode_id,
            "scenario_id": self.scenario_id,
            "archetype": self.archetype,
            "episode_length": self.episode_length,
            "total_reward": round(self.total_reward, 4),
            "total_pnl": round(self.total_pnl, 2),
            "completed": self.completed,
            "success": self.success,
            "turns": [t.to_dict() for t in self.turns],
        }
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# =============================================================================
|
|
195
|
+
# Multi-Turn Episode Manager
|
|
196
|
+
# =============================================================================
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
@dataclass
class GAEConfig:
    """Configuration for Generalized Advantage Estimation (GAE)."""
    gamma: float = 0.99  # Discount factor for future rewards
    gae_lambda: float = 0.95  # GAE lambda (bias/variance trade-off)
    normalize_advantages: bool = True  # Normalize advantages across the batch
    clip_advantages: bool = True  # Clip extreme advantages to +/- advantage_clip
    advantage_clip: float = 10.0  # Symmetric clip threshold
    # NOTE(review): not read by MultiTurnEpisodeManager in this module;
    # confirm external consumers before relying on it.
    use_value_normalization: bool = True  # Normalize values
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class MultiTurnEpisodeManager:
    """
    Manages multi-turn trading episodes with GAE-based credit assignment.

    For trading scenarios, a single action's value depends on:
    - Future market movements
    - Subsequent actions
    - Final episode outcome

    This manager computes proper advantages for each turn to enable
    credit assignment in multi-turn settings.

    NOTE(review): compute_advantages uses the Monte Carlo return as the
    value estimate (value = return_to_go). Since return_to_go satisfies
    the same recursion as the TD target, the per-turn TD error works out
    to ~0 and so do the resulting GAE advantages — confirm this is the
    intended behavior before integration (see TRAINING_ROADMAP.md).
    """

    def __init__(
        self,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        max_turns: int = 20,
        normalize_advantages: bool = True,
    ):
        # Only gamma/lambda/normalization are configurable here; the clip
        # settings keep their GAEConfig defaults.
        self.config = GAEConfig(
            gamma=gamma,
            gae_lambda=gae_lambda,
            normalize_advantages=normalize_advantages,
        )
        self.max_turns = max_turns

        # Statistics
        self._episodes_processed = 0
        self._total_turns = 0

    def compute_advantages(self, turns: List[TurnData]) -> None:
        """
        Compute GAE advantages for each turn in an episode.

        Modifies turns in-place to add:
        - value: Estimated value of the state
        - return_to_go: Cumulative discounted future reward
        - advantage: GAE advantage estimate

        Args:
            turns: List of turns in chronological order
        """
        if not turns:
            return

        # Step 1: Compute return-to-go (cumulative discounted reward)
        # Going backwards from the end
        cumulative = 0.0
        for turn in reversed(turns):
            cumulative = turn.reward + self.config.gamma * cumulative
            turn.return_to_go = cumulative

        # Step 2: Estimate values
        # Simple approach: value = return_to_go (Monte Carlo estimate)
        # NOTE(review): with V == return-to-go the TD error below is ~0
        # for every turn, which makes the GAE advantages ~0 per episode.
        for turn in turns:
            turn.value = turn.return_to_go

        # Step 3: Compute GAE advantages
        next_value = 0.0
        gae = 0.0

        for turn in reversed(turns):
            # TD error (value bootstraps to 0 past a terminal turn)
            if turn.done:
                next_value = 0.0

            delta = turn.reward + self.config.gamma * next_value - turn.value

            # GAE accumulation
            gae = delta + self.config.gamma * self.config.gae_lambda * gae
            turn.advantage = gae

            next_value = turn.value

        # Step 4: Clip extreme advantages
        if self.config.clip_advantages:
            for turn in turns:
                turn.advantage = max(
                    -self.config.advantage_clip,
                    min(self.config.advantage_clip, turn.advantage),
                )

        self._episodes_processed += 1
        self._total_turns += len(turns)

    def compute_batch_advantages(
        self,
        episodes: List[List[TurnData]],
    ) -> None:
        """
        Compute advantages for a batch of episodes.

        Also normalizes advantages across the batch if configured
        (zero mean, unit variance, with a 1e-8 floor on the std).

        Args:
            episodes: List of episodes, each containing turns
        """
        # Compute per-episode advantages
        for episode in episodes:
            self.compute_advantages(episode)

        # Normalize across batch
        if self.config.normalize_advantages and episodes:
            all_advantages = []
            for episode in episodes:
                for turn in episode:
                    all_advantages.append(turn.advantage)

            if all_advantages:
                mean = sum(all_advantages) / len(all_advantages)

                if len(all_advantages) > 1:
                    variance = sum((a - mean) ** 2 for a in all_advantages) / len(all_advantages)
                    # Floor the std to avoid dividing by ~0 when
                    # advantages are (near-)constant
                    std = max(variance ** 0.5, 1e-8)
                else:
                    std = 1.0

                # Normalize
                for episode in episodes:
                    for turn in episode:
                        turn.advantage = (turn.advantage - mean) / std

    def create_training_items(
        self,
        turns: List[TurnData],
        tokenizer,
        train_on_all_assistant_turns: bool = True,
    ) -> List[Dict]:
        """
        Create training items from episode turns.

        Args:
            turns: List of turns with computed advantages
            tokenizer: Tokenizer for token processing
            train_on_all_assistant_turns: Train on all turns or just last
                NOTE(review): currently unused by this method — confirm
                the intended behavior before relying on it.

        Returns:
            List of training items suitable for GRPO
        """
        items = []

        for turn in turns:
            # Skip if no advantage computed
            # (also drops genuinely zero-signal turns)
            if turn.advantage == 0.0 and turn.reward == 0.0:
                continue

            # If tokens not pre-computed, compute from messages
            if not turn.tokens and turn.messages:
                tokens = self._tokenize_messages(tokenizer, turn.messages)
                turn.tokens = tokens
                turn.masks = self._create_masks(tokenizer, turn.messages, tokens)

            item = {
                "tokens": turn.tokens,
                "masks": turn.masks,
                "score": turn.advantage,  # Use advantage as score for GRPO
                "metadata": {
                    "turn": turn.turn_number,
                    "episode_id": turn.episode_id,
                    "reward": turn.reward,
                    "value": turn.value,
                    "return_to_go": turn.return_to_go,
                    "action_type": turn.action_type,
                },
            }

            items.append(item)

        return items

    def _tokenize_messages(
        self,
        tokenizer,
        messages: List[Dict[str, str]],
    ) -> List[int]:
        """Tokenize chat messages via the tokenizer's chat template.

        Returns an empty list (and logs a warning) on any tokenizer error.
        """
        try:
            tokens = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=False,
                return_tensors=None,
            )
            return tokens
        except Exception as e:
            logger.warning(f"Tokenization failed: {e}")
            return []

    def _create_masks(
        self,
        tokenizer,
        messages: List[Dict[str, str]],
        tokens: List[int],
    ) -> List[int]:
        """
        Create training masks for tokens (1 = train on token, 0 = skip).

        Masks assistant turns for training.
        """
        if not tokens:
            return []

        # Simple approach: find last assistant turn and mask it
        masks = [0] * len(tokens)

        # Find the position where the last assistant response starts
        # This is an approximation; proper implementation would use
        # the tokenizer's chat template structure

        if not messages:
            return masks

        # Find last assistant message
        last_assistant_idx = -1
        for i, msg in enumerate(messages):
            if msg.get("role") == "assistant":
                last_assistant_idx = i

        # No assistant turn at all: nothing trainable
        if last_assistant_idx == -1:
            return masks

        # Tokenize everything before the last assistant message to
        # measure the prompt prefix length
        prompt_messages = messages[:last_assistant_idx]
        if prompt_messages:
            prompt_tokens = self._tokenize_messages(tokenizer, prompt_messages)
            prompt_len = len(prompt_tokens)
        else:
            prompt_len = 0

        # Mark tokens after prompt as trainable
        # (includes any chat-template framing tokens after the prefix —
        # part of the approximation noted above)
        for i in range(prompt_len, len(tokens)):
            masks[i] = 1

        return masks

    def get_stats(self) -> Dict:
        """Get manager statistics (episode/turn counts and GAE settings)."""
        return {
            "episodes_processed": self._episodes_processed,
            "total_turns": self._total_turns,
            "avg_turns_per_episode": (
                self._total_turns / max(1, self._episodes_processed)
            ),
            "gamma": self.config.gamma,
            "gae_lambda": self.config.gae_lambda,
            "max_turns": self.max_turns,
        }
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
# =============================================================================
|
|
460
|
+
# Reward Shaping Utilities
|
|
461
|
+
# =============================================================================
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def shape_trading_rewards(
    turns: List[TurnData],
    format_weight: float = 0.2,
    reasoning_weight: float = 0.1,
    pnl_weight: float = 0.5,
    action_weight: float = 0.2,
) -> None:
    """
    Shape rewards for trading episodes, in place.

    Each turn's reward becomes a weighted blend of:
    - the raw (usually PnL-based) reward,
    - format quality (think tags, action JSON),
    - reasoning quality,
    - an action-quality term (active trades > waiting > invalid actions).

    Args:
        turns: Episode turns (modified in-place)
        format_weight: Weight for format score
        reasoning_weight: Weight for reasoning score
        pnl_weight: Weight for the raw PnL-based reward
        action_weight: Weight for action quality
    """
    for turn in turns:
        # Action-quality term: reward active trading most, waiting a
        # little, and penalize anything unrecognized.
        if turn.action_type in ("buy", "sell", "open_perp", "close_perp"):
            action_quality = 0.1
        elif turn.action_type == "wait":
            action_quality = 0.05
        else:
            action_quality = -0.1

        turn.reward = (
            pnl_weight * turn.reward
            + format_weight * turn.format_score
            + reasoning_weight * turn.reasoning_score
            + action_weight * action_quality
        )
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
def compute_episode_return(
    turns: List[TurnData],
    gamma: float = 0.99,
) -> float:
    """
    Compute the total discounted return: sum over t of gamma^t * reward_t.

    Args:
        turns: Episode turns in chronological order
        gamma: Discount factor

    Returns:
        Discounted return (0.0 for an empty episode)
    """
    if not turns:
        return 0.0

    running_total, weight = 0.0, 1.0
    for step in turns:
        running_total += weight * step.reward
        weight *= gamma
    return running_total
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
def normalize_episode_rewards(
    episodes: List[List["TurnData"]],
) -> None:
    """
    Normalize turn rewards to zero mean, unit variance across episodes.

    Useful for reducing variance in training. Modifies turns in-place;
    a no-op when fewer than two rewards exist (normalization is
    undefined in that case).

    Args:
        episodes: List of episodes, each a list of turns
    """
    all_rewards = [turn.reward for episode in episodes for turn in episode]

    # Need at least two samples for a meaningful mean/std
    # (this single check also covers the empty case).
    if len(all_rewards) < 2:
        return

    mean = sum(all_rewards) / len(all_rewards)
    variance = sum((r - mean) ** 2 for r in all_rewards) / len(all_rewards)
    # Floor the std to avoid division by ~0 when rewards are constant
    std = max(variance ** 0.5, 1e-8)

    for episode in episodes:
        for turn in episode:
            turn.reward = (turn.reward - mean) / std
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
# =============================================================================
|
|
571
|
+
# Episode Collectors
|
|
572
|
+
# =============================================================================
|
|
573
|
+
|
|
574
|
+
|
|
575
|
+
class EpisodeCollector:
    """
    Utility for collecting episodes with consistent structure.

    Wraps EpisodeBuffer creation and turn routing during rollouts, and
    retains at most ``max_episodes`` finished episodes in memory.
    """

    def __init__(self, max_episodes: int = 1000):
        self.max_episodes = max_episodes
        self.episodes: List[EpisodeBuffer] = []
        self._current_episode: Optional[EpisodeBuffer] = None

    def start_episode(
        self,
        scenario_id: str,
        archetype: str = "trader",
    ) -> EpisodeBuffer:
        """Open a fresh episode and make it the active one."""
        import uuid

        fresh = EpisodeBuffer(
            episode_id=f"ep-{uuid.uuid4().hex[:8]}",
            scenario_id=scenario_id,
            archetype=archetype,
        )
        self._current_episode = fresh
        return fresh

    def add_turn(self, turn: TurnData) -> None:
        """Route a turn into the active episode; finalize it on `done`."""
        if self._current_episode is None:
            raise RuntimeError("No active episode. Call start_episode first.")

        self._current_episode.add_turn(turn)

        if turn.done:
            self._finalize_current()

    def _finalize_current(self) -> None:
        """Move the active episode into storage, trimming the backlog."""
        active = self._current_episode
        if active is not None:
            self.episodes.append(active)

            # Keep only the most recent max_episodes entries
            if len(self.episodes) > self.max_episodes:
                self.episodes = self.episodes[-self.max_episodes:]

        self._current_episode = None

    def get_completed_episodes(self) -> List[EpisodeBuffer]:
        """All episodes that reached a terminal turn."""
        return [ep for ep in self.episodes if ep.completed]

    def get_successful_episodes(self) -> List[EpisodeBuffer]:
        """Completed episodes flagged as successful."""
        return [ep for ep in self.episodes if ep.completed and ep.success]

    def clear(self) -> None:
        """Drop all stored episodes and the active one."""
        self.episodes = []
        self._current_episode = None

    def get_stats(self) -> Dict:
        """Summary statistics over the collected episodes."""
        finished = self.get_completed_episodes()
        wins = self.get_successful_episodes()
        denom = max(1, len(finished))

        return {
            "total_episodes": len(self.episodes),
            "completed_episodes": len(finished),
            "successful_episodes": len(wins),
            "success_rate": len(wins) / denom,
            "avg_episode_length": (
                sum(ep.episode_length for ep in finished) / denom
            ),
            "avg_reward": (
                sum(ep.total_reward for ep in finished) / denom
            ),
        }
|
|
656
|
+
|