@elizaos/training 2.0.0-alpha.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/data/.gitkeep +0 -0
- package/data/degen/.gitkeep +2 -0
- package/data/trader/.gitkeep +2 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +58 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +206 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +89 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +439 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,649 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Database Integration Tests
|
|
3
|
+
|
|
4
|
+
Tests the complete database-based training data pipeline:
|
|
5
|
+
1. Trajectory insertion with archetype
|
|
6
|
+
2. Querying trajectories by archetype
|
|
7
|
+
3. Scoring database trajectories
|
|
8
|
+
4. GRPO group formation from database
|
|
9
|
+
5. End-to-end database pipeline
|
|
10
|
+
|
|
11
|
+
These tests REQUIRE a running PostgreSQL database.
|
|
12
|
+
Use docker compose -f docker-compose.test.yml up -d before running.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import os
|
|
17
|
+
import sys
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Dict, List
|
|
20
|
+
from datetime import datetime
|
|
21
|
+
import uuid
|
|
22
|
+
|
|
23
|
+
import pytest
|
|
24
|
+
|
|
25
|
+
# Add src to path
|
|
26
|
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
|
27
|
+
|
|
28
|
+
from src.training.rewards import (
|
|
29
|
+
archetype_composite_reward,
|
|
30
|
+
BehaviorMetrics,
|
|
31
|
+
TrajectoryRewardInputs,
|
|
32
|
+
)
|
|
33
|
+
from src.training.rubric_loader import (
|
|
34
|
+
normalize_archetype,
|
|
35
|
+
has_custom_rubric,
|
|
36
|
+
get_available_archetypes,
|
|
37
|
+
)
|
|
38
|
+
from tests.integration.conftest import (
|
|
39
|
+
TrajectoryFixture,
|
|
40
|
+
skip_if_no_database,
|
|
41
|
+
is_database_available,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Skip all tests in this module if database is not available
|
|
46
|
+
pytestmark = skip_if_no_database()
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def generate_test_id() -> str:
    """Produce a short, unique identifier for test records.

    Combines the fixed ``test-`` prefix with the first 8 hex digits of a
    random UUID so concurrently running tests cannot collide.
    """
    suffix = uuid.uuid4().hex[:8]
    return "test-" + suffix
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class TestDatabaseTrajectoryOperations:
    """Trajectory CRUD tests against a live PostgreSQL database.

    Every row a test inserts carries a per-test unique prefix in its
    "trajectoryId"; the autouse fixture deletes rows with that prefix on
    teardown, so tests from parallel runs cannot interfere.
    """

    @pytest.fixture(autouse=True)
    def setup_db(self, database_url: str):
        """Open a DB connection, yield to the test, then clean up.

        Fix over the previous version: roll back any open/aborted
        transaction before the cleanup DELETE (a test that raised
        mid-statement would otherwise leave the connection in "current
        transaction is aborted" state, making teardown itself fail), and
        close the connection in a ``finally`` so it is always released.
        """
        import psycopg2
        self.conn = psycopg2.connect(database_url)
        self.test_prefix = f"test-{uuid.uuid4().hex[:8]}"
        yield
        try:
            # Reset any aborted transaction left behind by a failing test
            # so the cleanup DELETE below can run.
            self.conn.rollback()
            # Cleanup test data
            cur = self.conn.cursor()
            cur.execute(
                'DELETE FROM trajectories WHERE "trajectoryId" LIKE %s',
                (f"{self.test_prefix}%",)
            )
            self.conn.commit()
            cur.close()
        finally:
            # Always release the connection, even if cleanup errors.
            self.conn.close()

    def _insert_trajectory(
        self,
        trajectory_id: str,
        agent_id: str,
        archetype: str,
        window_id: str,
        steps: List[Dict],
        final_pnl: float,
        episode_length: int,
    ) -> None:
        """Insert a test trajectory into the database.

        JSON columns other than "stepsJson" are filled with empty objects;
        timestamps use the wall clock since no test asserts on them.
        """
        cur = self.conn.cursor()
        cur.execute(
            '''
            INSERT INTO trajectories (
                "id", "trajectoryId", "agentId", "archetype", "windowId",
                "stepsJson", "rewardComponentsJson", "metricsJson", "metadataJson",
                "finalPnL", "episodeLength", "totalReward",
                "finalStatus", "isTrainingData", "startTime", "endTime",
                "durationMs", "createdAt", "updatedAt"
            ) VALUES (
                %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
            )
            ''',
            (
                trajectory_id,
                trajectory_id,  # "id" and "trajectoryId" share the same value
                agent_id,
                archetype,
                window_id,
                json.dumps(steps),
                "{}",  # rewardComponentsJson
                "{}",  # metricsJson
                "{}",  # metadataJson
                final_pnl,
                episode_length,
                0.5,  # totalReward placeholder; not asserted on
                "completed",
                True,
                datetime.now(),
                datetime.now(),
                5000,
                datetime.now(),
                datetime.now(),
            )
        )
        self.conn.commit()
        cur.close()

    def test_insert_trajectory_with_archetype(
        self,
        sample_trader_trajectory: TrajectoryFixture,
    ):
        """Test inserting a trajectory with archetype."""
        traj_id = f"{self.test_prefix}-trader-001"

        self._insert_trajectory(
            trajectory_id=traj_id,
            agent_id="test-agent",
            archetype="trader",
            window_id=f"{self.test_prefix}-window-1",
            steps=sample_trader_trajectory.steps,
            final_pnl=sample_trader_trajectory.final_pnl,
            episode_length=sample_trader_trajectory.episode_length,
        )

        # Verify insertion round-trips the archetype value.
        cur = self.conn.cursor()
        cur.execute(
            'SELECT "archetype" FROM trajectories WHERE "trajectoryId" = %s',
            (traj_id,)
        )
        row = cur.fetchone()
        cur.close()

        assert row is not None
        assert row[0] == "trader"

    def test_query_trajectories_by_archetype(self):
        """Test querying trajectories filtered by archetype."""
        window_id = f"{self.test_prefix}-window-multi"

        # Insert one trajectory per archetype into the same window.
        for i, archetype in enumerate(["trader", "degen", "scammer"]):
            self._insert_trajectory(
                trajectory_id=f"{self.test_prefix}-{archetype}-{i}",
                agent_id=f"agent-{archetype}",
                archetype=archetype,
                window_id=window_id,
                steps=[],
                final_pnl=100.0 * (i + 1),
                episode_length=3,
            )

        # Query traders only; the archetype filter must exclude the rest.
        cur = self.conn.cursor()
        cur.execute(
            '''
            SELECT "trajectoryId", "archetype" FROM trajectories
            WHERE "windowId" = %s AND "archetype" = %s
            ''',
            (window_id, "trader")
        )
        rows = cur.fetchall()
        cur.close()

        assert len(rows) == 1
        assert rows[0][1] == "trader"

    def test_query_trajectories_by_window_with_archetypes(self):
        """Test querying all trajectories in window with their archetypes."""
        window_id = f"{self.test_prefix}-window-group"
        archetypes = ["trader", "degen", "scammer", "social-butterfly"]

        # Insert one trajectory per archetype into the same window.
        for i, archetype in enumerate(archetypes):
            self._insert_trajectory(
                trajectory_id=f"{self.test_prefix}-group-{i}",
                agent_id=f"agent-{archetype}",
                archetype=archetype,
                window_id=window_id,
                steps=[],
                final_pnl=100.0 * (i + 1),
                episode_length=3,
            )

        # Query all trajectories in the window.
        cur = self.conn.cursor()
        cur.execute(
            '''
            SELECT "trajectoryId", "archetype", "finalPnL" FROM trajectories
            WHERE "windowId" = %s AND "isTrainingData" = true
            ORDER BY "archetype"
            ''',
            (window_id,)
        )
        rows = cur.fetchall()
        cur.close()

        assert len(rows) == 4
        db_archetypes = {row[1] for row in rows}
        assert db_archetypes == set(archetypes)

    def test_null_archetype_defaults_to_default(self):
        """Test that NULL archetype is handled correctly."""
        traj_id = f"{self.test_prefix}-null-arch"

        # Insert with NULL archetype (raw SQL; the helper always sets one).
        cur = self.conn.cursor()
        cur.execute(
            '''
            INSERT INTO trajectories (
                "id", "trajectoryId", "agentId", "archetype", "windowId",
                "stepsJson", "rewardComponentsJson", "metricsJson", "metadataJson",
                "finalPnL", "episodeLength", "totalReward",
                "finalStatus", "isTrainingData", "startTime", "endTime",
                "durationMs", "createdAt", "updatedAt"
            ) VALUES (
                %s, %s, %s, NULL, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
            )
            ''',
            (
                traj_id, traj_id, "test-agent", f"{self.test_prefix}-window-null",
                "[]", "{}", "{}", "{}", 100.0, 3, 0.5, "completed", True,
                datetime.now(), datetime.now(), 5000, datetime.now(), datetime.now(),
            )
        )
        self.conn.commit()
        cur.close()

        # Query the NULL back and check normalization maps it to "default".
        cur = self.conn.cursor()
        cur.execute(
            'SELECT "archetype" FROM trajectories WHERE "trajectoryId" = %s',
            (traj_id,)
        )
        row = cur.fetchone()
        cur.close()

        archetype = row[0]
        normalized = normalize_archetype(archetype)
        assert normalized == "default"

    def test_archetype_case_normalization_in_scoring(self):
        """Test that archetype case variations are normalized correctly."""
        # (raw DB value, expected canonical form)
        test_cases = [
            ("TRADER", "trader"),
            ("Social_Butterfly", "social-butterfly"),
            ("goody_twoshoes", "goody-twoshoes"),
        ]

        for db_value, expected_normalized in test_cases:
            normalized = normalize_archetype(db_value)
            assert normalized == expected_normalized
            # Each canonical form must resolve to a real rubric.
            assert has_custom_rubric(normalized)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
class TestDatabaseScoring:
    """Scoring tests that read their trajectories back from the database.

    Rows are tagged with a per-test unique prefix and removed on teardown,
    so concurrent test runs do not see each other's data.
    """

    @pytest.fixture(autouse=True)
    def setup_db(self, database_url: str):
        """Open a DB connection, yield to the test, then clean up.

        Fix over the previous version: roll back any open/aborted
        transaction before the cleanup DELETE (a failed statement inside
        a test would otherwise abort the transaction and make teardown
        fail), and close the connection in a ``finally``.
        """
        import psycopg2
        self.conn = psycopg2.connect(database_url)
        self.test_prefix = f"test-{uuid.uuid4().hex[:8]}"
        yield
        try:
            # Reset any aborted transaction so the DELETE can run.
            self.conn.rollback()
            cur = self.conn.cursor()
            cur.execute(
                'DELETE FROM trajectories WHERE "trajectoryId" LIKE %s',
                (f"{self.test_prefix}%",)
            )
            self.conn.commit()
            cur.close()
        finally:
            # Always release the connection, even if cleanup errors.
            self.conn.close()

    def _insert_and_fetch_trajectory(
        self,
        archetype: str,
        final_pnl: float,
        steps: List[Dict],
    ) -> Dict:
        """Insert a trajectory and return the row as fetched via RETURNING.

        Returning the database's own values (rather than the inputs)
        verifies the round-trip through column types.
        """
        traj_id = f"{self.test_prefix}-{archetype}-{uuid.uuid4().hex[:4]}"
        window_id = f"{self.test_prefix}-scoring-window"

        cur = self.conn.cursor()
        cur.execute(
            '''
            INSERT INTO trajectories (
                "id", "trajectoryId", "agentId", "archetype", "windowId",
                "stepsJson", "rewardComponentsJson", "metricsJson", "metadataJson",
                "finalPnL", "episodeLength", "totalReward",
                "finalStatus", "isTrainingData", "startTime", "endTime",
                "durationMs", "createdAt", "updatedAt"
            ) VALUES (
                %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
            )
            RETURNING "trajectoryId", "archetype", "stepsJson", "finalPnL", "episodeLength"
            ''',
            (
                traj_id, traj_id, f"agent-{archetype}", archetype, window_id,
                json.dumps(steps), "{}", "{}", "{}", final_pnl, len(steps), 0.5, "completed", True,
                datetime.now(), datetime.now(), 5000, datetime.now(), datetime.now(),
            )
        )
        row = cur.fetchone()
        self.conn.commit()
        cur.close()

        return {
            "trajectory_id": row[0],
            "archetype": row[1],
            # NOTE(review): assumes "stepsJson" comes back as text — a
            # jsonb column would already be decoded; confirm schema.
            "steps": json.loads(row[2]),
            "final_pnl": float(row[3]),
            "episode_length": row[4],
        }

    def test_score_trader_from_database(
        self,
        sample_trader_trajectory: TrajectoryFixture,
    ):
        """Test scoring a trader trajectory from database."""
        traj_data = self._insert_and_fetch_trajectory(
            archetype="trader",
            final_pnl=150.0,
            steps=sample_trader_trajectory.steps,
        )

        behavior = BehaviorMetrics(
            trades_executed=3,
            total_pnl=traj_data["final_pnl"],
            episode_length=traj_data["episode_length"],
        )

        inputs = TrajectoryRewardInputs(
            final_pnl=traj_data["final_pnl"],
            starting_balance=10000.0,
            end_balance=10000.0 + traj_data["final_pnl"],
            format_score=0.8,
            reasoning_score=0.75,
        )

        score = archetype_composite_reward(
            inputs,
            normalize_archetype(traj_data["archetype"]),
            behavior
        )

        assert 0.0 <= score <= 1.0
        assert score > 0.3  # Profitable trader should score reasonably

    def test_score_multiple_archetypes_from_database(self):
        """Test scoring multiple archetype trajectories from database."""
        # (archetype, final PnL) pairs spanning profitable and losing runs.
        archetypes_pnl = [
            ("trader", 200.0),
            ("degen", -500.0),
            ("scammer", 300.0),
            ("social-butterfly", 20.0),
        ]

        scores = {}
        for archetype, pnl in archetypes_pnl:
            traj_data = self._insert_and_fetch_trajectory(
                archetype=archetype,
                final_pnl=pnl,
                steps=[],
            )

            behavior = BehaviorMetrics(
                trades_executed=5,
                total_pnl=pnl,
                episode_length=5,
            )

            inputs = TrajectoryRewardInputs(
                final_pnl=pnl,
                starting_balance=10000.0,
                end_balance=10000.0 + pnl,
                format_score=0.7,
                reasoning_score=0.7,
            )

            score = archetype_composite_reward(
                inputs,
                normalize_archetype(archetype),
                behavior
            )
            scores[archetype] = score

        # All scores should be valid, regardless of archetype or PnL sign.
        for archetype, score in scores.items():
            assert 0.0 <= score <= 1.0, f"{archetype} has invalid score: {score}"

    def test_grpo_group_formation_from_database(self):
        """Test forming GRPO groups from database trajectories."""
        window_id = f"{self.test_prefix}-grpo-window"
        archetypes = ["trader", "degen", "scammer"]

        # Insert multiple trajectories into the same window (a GRPO group).
        for i, archetype in enumerate(archetypes):
            cur = self.conn.cursor()
            traj_id = f"{self.test_prefix}-grpo-{i}"
            cur.execute(
                '''
                INSERT INTO trajectories (
                    "id", "trajectoryId", "agentId", "archetype", "windowId",
                    "stepsJson", "rewardComponentsJson", "metricsJson", "metadataJson",
                    "finalPnL", "episodeLength", "totalReward",
                    "finalStatus", "isTrainingData", "startTime", "endTime",
                    "durationMs", "createdAt", "updatedAt"
                ) VALUES (
                    %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
                )
                ''',
                (
                    traj_id, traj_id, f"agent-{i}", archetype, window_id,
                    "[]", "{}", "{}", "{}", 100.0 * (i + 1), 3, 0.5, "completed", True,
                    datetime.now(), datetime.now(), 5000, datetime.now(), datetime.now(),
                )
            )
            self.conn.commit()
            cur.close()

        # Query the whole group back.
        cur = self.conn.cursor()
        cur.execute(
            '''
            SELECT "trajectoryId", "archetype", "finalPnL", "stepsJson", "episodeLength"
            FROM trajectories
            WHERE "windowId" = %s AND "isTrainingData" = true
            ''',
            (window_id,)
        )
        rows = cur.fetchall()
        cur.close()

        assert len(rows) >= 2  # GRPO requires at least 2

        # Score every trajectory in the group.
        scores = []
        for row in rows:
            archetype = normalize_archetype(row[1])
            pnl = float(row[2])

            behavior = BehaviorMetrics(trades_executed=3, total_pnl=pnl)
            inputs = TrajectoryRewardInputs(
                final_pnl=pnl,
                starting_balance=10000.0,
                end_balance=10000.0 + pnl,
                format_score=0.7,
                reasoning_score=0.7,
            )

            score = archetype_composite_reward(inputs, archetype, behavior)
            scores.append(score)

        # Center scores for GRPO: the centered group must have ~zero mean.
        mean_score = sum(scores) / len(scores)
        centered = [s - mean_score for s in scores]
        centered_mean = sum(centered) / len(centered)

        assert abs(centered_mean) < 0.01
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
class TestEndToEndDatabasePipeline:
    """End-to-end pipeline tests: insert → query → score → GRPO-center.

    Rows are tagged with a per-test unique prefix and removed on teardown,
    so concurrent test runs do not see each other's data.
    """

    @pytest.fixture(autouse=True)
    def setup_db(self, database_url: str):
        """Open a DB connection, yield to the test, then clean up.

        Fix over the previous version: roll back any open/aborted
        transaction before the cleanup DELETE (a failed statement inside
        a test would otherwise abort the transaction and make teardown
        fail), and close the connection in a ``finally``.
        """
        import psycopg2
        self.conn = psycopg2.connect(database_url)
        self.test_prefix = f"test-{uuid.uuid4().hex[:8]}"
        yield
        try:
            # Reset any aborted transaction so the DELETE can run.
            self.conn.rollback()
            cur = self.conn.cursor()
            cur.execute(
                'DELETE FROM trajectories WHERE "trajectoryId" LIKE %s',
                (f"{self.test_prefix}%",)
            )
            self.conn.commit()
            cur.close()
        finally:
            # Always release the connection, even if cleanup errors.
            self.conn.close()

    def test_full_database_pipeline(
        self,
        trajectory_group: List[TrajectoryFixture],
    ):
        """Test full pipeline: insert → query → score → center."""
        window_id = f"{self.test_prefix}-full-pipeline"

        # Step 1: Insert trajectories (one per fixture in the group).
        cur = self.conn.cursor()
        for traj in trajectory_group:
            traj_id = f"{self.test_prefix}-{traj.archetype}"
            cur.execute(
                '''
                INSERT INTO trajectories (
                    "id", "trajectoryId", "agentId", "archetype", "windowId",
                    "stepsJson", "rewardComponentsJson", "metricsJson", "metadataJson",
                    "finalPnL", "episodeLength", "totalReward",
                    "finalStatus", "isTrainingData", "startTime", "endTime",
                    "durationMs", "createdAt", "updatedAt"
                ) VALUES (
                    %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
                )
                ''',
                (
                    traj_id, traj_id, traj.agent_id, traj.archetype, window_id,
                    json.dumps(traj.steps), "{}", "{}", "{}", traj.final_pnl, traj.episode_length,
                    traj.total_reward, "completed", True,
                    datetime.now(), datetime.now(), 5000, datetime.now(), datetime.now(),
                )
            )
        self.conn.commit()
        cur.close()

        # Step 2: Query trajectories back.
        cur = self.conn.cursor()
        cur.execute(
            '''
            SELECT "trajectoryId", "archetype", "stepsJson", "finalPnL", "episodeLength"
            FROM trajectories
            WHERE "windowId" = %s AND "isTrainingData" = true
            ''',
            (window_id,)
        )
        rows = cur.fetchall()
        cur.close()

        # trajectory_group fixture is expected to contain 3 trajectories.
        assert len(rows) == 3

        # Step 3: Score each trajectory with its archetype-specific rubric.
        scored_trajectories = []
        for row in rows:
            traj_id, archetype, steps_json, pnl, episode_length = row
            archetype_norm = normalize_archetype(archetype)
            # NOTE(review): assumes "stepsJson" comes back as text — a
            # jsonb column would already be decoded; confirm schema.
            steps = json.loads(steps_json)

            behavior = BehaviorMetrics(
                # Every non-"hold" action counts as an executed trade.
                trades_executed=len([s for s in steps if s.get("action", {}).get("actionType") != "hold"]),
                total_pnl=float(pnl),
                episode_length=episode_length,
            )

            inputs = TrajectoryRewardInputs(
                final_pnl=float(pnl),
                starting_balance=10000.0,
                end_balance=10000.0 + float(pnl),
                format_score=0.7,
                reasoning_score=0.7,
            )

            score = archetype_composite_reward(inputs, archetype_norm, behavior)
            scored_trajectories.append({
                "trajectory_id": traj_id,
                "archetype": archetype_norm,
                "pnl": float(pnl),
                "score": score,
            })

        # Step 4: Center scores for GRPO (group mean subtracted).
        mean_score = sum(t["score"] for t in scored_trajectories) / len(scored_trajectories)
        for t in scored_trajectories:
            t["centered_score"] = t["score"] - mean_score

        # Centered scores must sum to ~zero by construction.
        centered_mean = sum(t["centered_score"] for t in scored_trajectories) / len(scored_trajectories)
        assert abs(centered_mean) < 0.01

        # Verify trader scores higher than degen (positive PnL vs negative)
        trader = next(t for t in scored_trajectories if t["archetype"] == "trader")
        degen = next(t for t in scored_trajectories if t["archetype"] == "degen")
        assert trader["score"] > degen["score"]

    def test_pipeline_with_all_archetypes(self):
        """Test pipeline handles all valid archetypes."""
        window_id = f"{self.test_prefix}-all-archetypes"
        archetypes = get_available_archetypes()

        # Insert one trajectory per archetype.
        cur = self.conn.cursor()
        for i, archetype in enumerate(archetypes):
            traj_id = f"{self.test_prefix}-all-{i}"
            cur.execute(
                '''
                INSERT INTO trajectories (
                    "id", "trajectoryId", "agentId", "archetype", "windowId",
                    "stepsJson", "rewardComponentsJson", "metricsJson", "metadataJson",
                    "finalPnL", "episodeLength", "totalReward",
                    "finalStatus", "isTrainingData", "startTime", "endTime",
                    "durationMs", "createdAt", "updatedAt"
                ) VALUES (
                    %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
                )
                ''',
                (
                    traj_id, traj_id, f"agent-{archetype}", archetype, window_id,
                    "[]", "{}", "{}", "{}", 100.0 + i * 10, 3, 0.5, "completed", True,
                    datetime.now(), datetime.now(), 5000, datetime.now(), datetime.now(),
                )
            )
        self.conn.commit()
        cur.close()

        # Query everything back.
        cur = self.conn.cursor()
        cur.execute(
            '''
            SELECT "archetype", "finalPnL" FROM trajectories
            WHERE "windowId" = %s
            ''',
            (window_id,)
        )
        rows = cur.fetchall()
        cur.close()

        assert len(rows) == len(archetypes)

        # Score each: every available archetype must have a rubric and
        # produce a score in [0, 1].
        for archetype, pnl in rows:
            normalized = normalize_archetype(archetype)
            assert has_custom_rubric(normalized), f"{archetype} should have rubric"

            behavior = BehaviorMetrics(trades_executed=3, total_pnl=float(pnl))
            inputs = TrajectoryRewardInputs(
                final_pnl=float(pnl),
                starting_balance=10000.0,
                end_balance=10000.0 + float(pnl),
                format_score=0.7,
                reasoning_score=0.7,
            )

            score = archetype_composite_reward(inputs, normalized, behavior)
            assert 0.0 <= score <= 1.0
|
|
649
|
+
|