@elizaos/training 2.0.0-alpha.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/data/.gitkeep +0 -0
- package/data/degen/.gitkeep +2 -0
- package/data/trader/.gitkeep +2 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +58 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +206 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +89 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +439 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,399 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tick Reward Attribution
|
|
3
|
+
|
|
4
|
+
A single agent tick may contain multiple LLM calls with different purposes:
|
|
5
|
+
1. REASONING - Analysis and planning (should I trade?)
|
|
6
|
+
2. ACTION - Decision making (what trade to make?)
|
|
7
|
+
3. RESPONSE - Communication (what to say?)
|
|
8
|
+
4. EVALUATION - Self-assessment (how did I do?)
|
|
9
|
+
|
|
10
|
+
The global tick reward needs to be attributed back to individual calls
|
|
11
|
+
to train each prompt type effectively.
|
|
12
|
+
|
|
13
|
+
This module implements credit assignment from tick outcomes to individual LLM calls.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from dataclasses import dataclass, field
|
|
18
|
+
from enum import Enum
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class CallPurpose(str, Enum):
    """Purpose categories for LLM calls within a tick.

    Mixes in ``str`` so members compare equal to their string values
    (e.g. ``CallPurpose.ACTION == "action"``) and serialize cleanly.
    """
    REASONING = "reasoning"  # analysis and planning (should I trade?)
    ACTION = "action"  # decision making (what trade to make?)
    RESPONSE = "response"  # communication (what to say?)
    EVALUATION = "evaluation"  # self-assessment (how did I do?)
    OTHER = "other"  # anything else; receives zero reward weight by default
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class LLMCallRecord:
    """Record of a single LLM call within a tick.

    One tick may contain several of these (reasoning, action, response,
    evaluation); each record carries enough context to become a
    stand-alone training sample once a reward has been attributed to it.
    """
    # Position of this call within the tick (presumably 0-based — confirm with producer).
    call_index: int
    # Why this call was made; drives reward weighting in TickRewardAttributor.
    purpose: CallPurpose
    action_type: str | None  # e.g., 'evaluate_trading_opportunity', 'execute_response'

    # The actual prompts exchanged with the model
    system_prompt: str
    user_prompt: str
    response: str

    # Model info (sampling configuration and observed latency)
    model: str
    temperature: float
    max_tokens: int
    latency_ms: int  # call latency in milliseconds

    # Outcome tracking
    led_to_action: bool = False  # Did this call lead to an action?
    action_success: bool | None = None  # If led to action, was it successful? None = no action taken

    # Attributed reward (calculated by TickRewardAttributor.attribute_rewards)
    attributed_reward: float = 0.0
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@dataclass
class TickOutcome:
    """Outcome of a complete tick.

    Aggregated signals consumed by TickRewardAttributor to shape per-call
    rewards: financial results, trade counts, and social engagement.
    """
    tick_number: int

    # Financial outcome
    pnl_delta: float  # Change in P&L this tick
    balance_delta: float  # Change in balance

    # Action outcomes (trades_executed should cover successful + failed — verify upstream)
    trades_executed: int
    trades_successful: int
    trades_failed: int

    # Social outcomes (used to score RESPONSE calls)
    posts_created: int
    responses_sent: int
    engagement_received: int  # likes, replies, etc.

    # Overall quality signals
    action_count: int
    wait_count: int  # Ticks where agent chose to wait
    error_count: int
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@dataclass
class TickData:
    """Complete data for a single tick with multiple LLM calls.

    Bundles every LLM call made during the tick with the tick's final
    outcome and the global reward to be attributed across those calls.
    """
    tick_number: int
    timestamp: int  # presumably an epoch timestamp; units unspecified here — confirm with producer
    agent_id: str

    # All LLM calls made during this tick
    llm_calls: list[LLMCallRecord] = field(default_factory=list)

    # Final outcome; None until the tick resolves (attribution then assigns 0.0 to all calls)
    outcome: TickOutcome | None = None

    # Global tick reward (from environment or judge)
    global_reward: float = 0.0
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class TickRewardAttributor:
    """
    Attributes global tick reward to individual LLM calls.

    The key insight is that different call types contribute differently:
    - REASONING calls set up the decision (credit if action succeeds)
    - ACTION calls make the decision (direct credit from outcome)
    - RESPONSE calls handle communication (credit from social metrics)
    - EVALUATION calls assess performance (credit from accuracy)
    """

    def __init__(
        self,
        reasoning_weight: float = 0.25,
        action_weight: float = 0.50,
        response_weight: float = 0.15,
        evaluation_weight: float = 0.10,
    ):
        """
        Initialize with weights for each call type.

        Args:
            reasoning_weight: Fraction of action reward attributed to reasoning
            action_weight: Fraction attributed to action decision
            response_weight: Fraction attributed to response generation
            evaluation_weight: Fraction attributed to self-evaluation

        Raises:
            ValueError: If the weights sum to zero or a negative value
                (normalization would otherwise divide by zero).
        """
        self.weights = {
            CallPurpose.REASONING: reasoning_weight,
            CallPurpose.ACTION: action_weight,
            CallPurpose.RESPONSE: response_weight,
            CallPurpose.EVALUATION: evaluation_weight,
            # OTHER calls deliberately get no share of the reward pool.
            CallPurpose.OTHER: 0.0,
        }

        # Validate weights sum to ~1.0; normalize if they drift.
        total = sum(self.weights.values())
        if total <= 0:
            # Previously an all-zero weight set fell through to a
            # ZeroDivisionError; fail fast with an actionable message.
            raise ValueError(f"Reward weights must sum to a positive value, got {total}")
        if abs(total - 1.0) > 0.01:
            logger.warning(f"Reward weights sum to {total}, normalizing...")
            for k in self.weights:
                self.weights[k] /= total

    def attribute_rewards(self, tick: TickData) -> list[LLMCallRecord]:
        """
        Attribute the global tick reward to individual LLM calls.

        The attribution strategy:
        1. Calculate base reward per call type (global reward * weight)
        2. Adjust based on purpose-specific outcomes (trade success,
           P&L, social engagement)
        3. Shape per-call credit: calls that led to successful actions
           are boosted, calls that led to failed actions are reduced

        Returns:
            The tick's LLM calls with ``attributed_reward`` set
            (the records are mutated in place).
        """
        if not tick.llm_calls:
            return []

        if tick.outcome is None:
            # No outcome yet, so there is nothing to attribute.
            for call in tick.llm_calls:
                call.attributed_reward = 0.0
            return tick.llm_calls

        global_reward = tick.global_reward
        outcome = tick.outcome

        # Group calls by purpose (every purpose gets a bucket so lookups
        # below never KeyError).
        calls_by_purpose: dict[CallPurpose, list[LLMCallRecord]] = {
            purpose: [] for purpose in CallPurpose
        }
        for call in tick.llm_calls:
            calls_by_purpose[call.purpose].append(call)

        # Calculate base reward pool for each purpose.
        purpose_rewards: dict[CallPurpose, float] = {}

        for purpose, weight in self.weights.items():
            calls = calls_by_purpose[purpose]
            if not calls:
                continue

            # Base reward from global reward
            base_reward = global_reward * weight

            # Adjust based on purpose-specific outcomes
            if purpose == CallPurpose.ACTION:
                # Action calls get reward based on trade success.
                if outcome.trades_executed > 0:
                    success_rate = outcome.trades_successful / outcome.trades_executed
                    base_reward *= (0.5 + 0.5 * success_rate)  # Scale by success

                # P&L bonus, clamped to [-1, 1] at +/-100 units of P&L.
                pnl_bonus = min(1.0, max(-1.0, outcome.pnl_delta / 100.0))
                base_reward += pnl_bonus * 0.2 * abs(global_reward)

            elif purpose == CallPurpose.RESPONSE:
                # Response calls get reward based on engagement
                # (5 engagements per response counts as full engagement).
                if outcome.responses_sent > 0:
                    engagement_rate = min(1.0, outcome.engagement_received / (outcome.responses_sent * 5))
                    base_reward *= (0.5 + 0.5 * engagement_rate)

            elif purpose == CallPurpose.REASONING:
                # Reasoning calls share credit with action outcomes.
                if outcome.trades_executed > 0:
                    success_rate = outcome.trades_successful / outcome.trades_executed
                    base_reward *= success_rate  # Reasoning credited by action success

            purpose_rewards[purpose] = base_reward

        # Distribute each purpose's pool to its individual calls.
        for purpose, calls in calls_by_purpose.items():
            if not calls:
                continue

            total_reward = purpose_rewards.get(purpose, 0.0)

            # Even split of the pool, then per-call outcome shaping.
            for call in calls:
                # Base share
                share = total_reward / len(calls)

                # Outcome adjustment
                if call.led_to_action and call.action_success:
                    # This call led to a successful action - boost it.
                    share *= 1.2
                elif call.led_to_action and call.action_success is False:
                    # This call led to a failed action - reduce it.
                    share *= 0.6

                call.attributed_reward = share

        return tick.llm_calls

    def attribute_batch(self, ticks: list[TickData]) -> list[TickData]:
        """
        Attribute rewards for a batch of ticks.

        Also applies relative normalization across the batch for GRPO:
        within each purpose group, attributed rewards are shifted to mean
        zero so GRPO compares calls of the same kind against each other
        rather than against an absolute scale.
        """
        # First pass: attribute individual rewards per tick.
        for tick in ticks:
            self.attribute_rewards(tick)

        # Second pass: group all calls by purpose across the batch.
        all_calls_by_purpose: dict[CallPurpose, list[LLMCallRecord]] = {
            purpose: [] for purpose in CallPurpose
        }

        for tick in ticks:
            for call in tick.llm_calls:
                all_calls_by_purpose[call.purpose].append(call)

        # Normalize each purpose group to mean 0 (for GRPO). Groups with
        # fewer than two calls carry no relative signal, so skip them.
        for calls in all_calls_by_purpose.values():
            if len(calls) < 2:
                continue

            rewards = [c.attributed_reward for c in calls]
            mean_reward = sum(rewards) / len(rewards)

            for call in calls:
                call.attributed_reward -= mean_reward

        return ticks
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def build_training_samples_from_tick(
    tick: TickData,
    trajectory_id: str,
    trajectory_score: float,
) -> list[dict]:
    """
    Build training samples from a tick with attributed rewards.

    Each LLM call in the tick is converted into one training sample
    containing the original conversation, the reward signals
    (tick-level, call-attributed, and trajectory-level), and outcome
    context for downstream analysis.
    """

    def _to_sample(call: LLMCallRecord) -> dict:
        # Flatten one LLM call plus its reward signals into a flat dict.
        return {
            "trajectory_id": trajectory_id,
            "tick_number": tick.tick_number,
            "call_index": call.call_index,
            "purpose": call.purpose.value,
            "action_type": call.action_type,

            # The actual training data
            "messages": [
                {"role": "system", "content": call.system_prompt},
                {"role": "user", "content": call.user_prompt},
                {"role": "assistant", "content": call.response},
            ],

            # Reward signals
            "tick_reward": tick.global_reward,
            "attributed_reward": call.attributed_reward,
            "trajectory_score": trajectory_score,

            # Outcome context
            "led_to_action": call.led_to_action,
            "action_success": call.action_success,

            # Model info for analysis
            "model": call.model,
            "temperature": call.temperature,
        }

    return [_to_sample(call) for call in tick.llm_calls]
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def group_samples_for_grpo(
    samples: list[dict],
    group_size: int = 4,
    min_variance: float = 0.01,
) -> list[list[dict]]:
    """
    Group samples by purpose for GRPO training.

    Returns groups where samples have the same purpose but different
    attributed rewards (for relative comparison).

    Args:
        samples: Training samples (must contain "purpose" and
            "attributed_reward" keys, as produced by
            build_training_samples_from_tick).
        group_size: Number of samples per GRPO group.
        min_variance: Minimum reward variance a group must exhibit;
            near-flat groups carry no relative signal and are dropped.

    Returns:
        List of groups, each a list of ``group_size`` samples that
        share a purpose.
    """
    # Bucket samples by purpose so comparisons stay apples-to-apples.
    by_purpose: dict[str, list[dict]] = {}
    for sample in samples:
        by_purpose.setdefault(sample["purpose"], []).append(sample)

    groups: list[list[dict]] = []

    for purpose_samples in by_purpose.values():
        if len(purpose_samples) < group_size:
            continue

        # Sort by attributed reward so strided picks span the reward range.
        sorted_samples = sorted(
            purpose_samples,
            key=lambda s: s["attributed_reward"],
        )

        n = len(sorted_samples)
        # Loop invariants hoisted out of the window loop. The window
        # stride is clamped to >= 1: the original `group_size // 2` is 0
        # when group_size == 1, which made range() raise ValueError.
        step = n // group_size
        stride = max(1, group_size // 2)

        for i in range(0, n - group_size + 1, stride):
            # Pick samples spread across the sorted order (saturating at
            # the last index) to maximize within-group reward spread.
            group = [
                sorted_samples[min(i + j * step, n - 1)]
                for j in range(group_size)
            ]

            # Drop groups whose rewards are (nearly) flat — GRPO needs
            # within-group contrast to produce a learning signal.
            rewards = [s["attributed_reward"] for s in group]
            mean_r = sum(rewards) / len(rewards)
            variance = sum((r - mean_r) ** 2 for r in rewards) / len(rewards)

            if variance >= min_variance:
                groups.append(group)

    return groups
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
# Example of how Eliza prompt structure maps to our training:
|
|
369
|
+
"""
|
|
370
|
+
ELIZA MESSAGE HANDLER (single tick, multiple outputs):
|
|
371
|
+
|
|
372
|
+
Input:
|
|
373
|
+
<task>Generate dialog and actions for {{agentName}}</task>
|
|
374
|
+
<providers>{{providers}}</providers>
|
|
375
|
+
|
|
376
|
+
Output:
|
|
377
|
+
<response>
|
|
378
|
+
<thought>Your thought here</thought> <- PURPOSE: reasoning
|
|
379
|
+
<actions>ACTION1,ACTION2</actions> <- PURPOSE: action
|
|
380
|
+
<providers>PROVIDER1,PROVIDER2</providers> <- metadata
|
|
381
|
+
<text>Your response text here</text> <- PURPOSE: response
|
|
382
|
+
</response>
|
|
383
|
+
|
|
384
|
+
In RL training, we break this into 3 training samples:
|
|
385
|
+
1. REASONING sample: Input -> <thought>...</thought>
|
|
386
|
+
Reward: Attributed based on whether actions succeeded
|
|
387
|
+
|
|
388
|
+
2. ACTION sample: Input + thought context -> <actions>...</actions>
|
|
389
|
+
Reward: Direct from action outcome (P&L, success)
|
|
390
|
+
|
|
391
|
+
3. RESPONSE sample: Input + thought + actions -> <text>...</text>
|
|
392
|
+
Reward: From social engagement metrics
|
|
393
|
+
|
|
394
|
+
This allows the model to learn:
|
|
395
|
+
- Better reasoning that leads to good actions
|
|
396
|
+
- Better action selection given good reasoning
|
|
397
|
+
- Better responses given successful actions
|
|
398
|
+
"""
|
|
399
|
+
|