@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,633 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Shared Quality Utilities
|
|
3
|
+
|
|
4
|
+
Common quality scoring and validation functions used across the training pipeline.
|
|
5
|
+
Extracted to avoid duplication between rollout_generator and fast_simulator.
|
|
6
|
+
|
|
7
|
+
ENHANCED v3:
|
|
8
|
+
- Archetype-specific scoring weights
|
|
9
|
+
- Reasoning-action alignment validation with Financial Literacy
|
|
10
|
+
- XML Structure validation
|
|
11
|
+
- Coherence heuristics
|
|
12
|
+
- Curriculum learning support
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import re
|
|
16
|
+
import json
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
from datetime import datetime, timezone
|
|
19
|
+
from typing import TYPE_CHECKING, Literal
|
|
20
|
+
|
|
21
|
+
from ..models import (
|
|
22
|
+
BabylonTrajectory,
|
|
23
|
+
TrajectoryStep,
|
|
24
|
+
Action,
|
|
25
|
+
EnvironmentState,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from .rollout_generator import AgentTickData
|
|
30
|
+
|
|
31
|
+
# Archetype-specific quality weights.
# Each row maps a scoring component to its share of the tick quality score
# ("llm_calls" = XML/format quality, "reasoning", "action" outcome, "feedback"
# presence); every row sums to 1.0. Consumed by calculate_tick_quality_score,
# which falls back to the "default" row for unknown archetypes.
ARCHETYPE_WEIGHTS: dict[str, dict[str, float]] = {
    # Research-heavy archetypes prioritize reasoning
    "researcher": {"llm_calls": 0.3, "reasoning": 0.45, "action": 0.15, "feedback": 0.1},
    "information-trader": {"llm_calls": 0.3, "reasoning": 0.4, "action": 0.2, "feedback": 0.1},
    "super-predictor": {"llm_calls": 0.3, "reasoning": 0.4, "action": 0.2, "feedback": 0.1},

    # Action-heavy archetypes prioritize execution
    "trader": {"llm_calls": 0.3, "reasoning": 0.2, "action": 0.4, "feedback": 0.1},
    "degen": {"llm_calls": 0.2, "reasoning": 0.15, "action": 0.55, "feedback": 0.1},
    "perps-trader": {"llm_calls": 0.25, "reasoning": 0.2, "action": 0.45, "feedback": 0.1},

    # Social archetypes prioritize engagement (response quality)
    "social-butterfly": {"llm_calls": 0.35, "reasoning": 0.25, "action": 0.25, "feedback": 0.15},
    "ass-kisser": {"llm_calls": 0.35, "reasoning": 0.3, "action": 0.2, "feedback": 0.15},
    "goody-twoshoes": {"llm_calls": 0.35, "reasoning": 0.3, "action": 0.2, "feedback": 0.15},

    # Deceptive archetypes prioritize reasoning (planning deception)
    "scammer": {"llm_calls": 0.25, "reasoning": 0.4, "action": 0.25, "feedback": 0.1},
    "liar": {"llm_calls": 0.25, "reasoning": 0.4, "action": 0.25, "feedback": 0.1},

    # Balanced
    "infosec": {"llm_calls": 0.3, "reasoning": 0.3, "action": 0.3, "feedback": 0.1},

    # Default
    "default": {"llm_calls": 0.4, "reasoning": 0.3, "action": 0.2, "feedback": 0.1},
}
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def validate_xml_structure(response: str) -> float:
    """
    Validate that the response contains valid decision XML tags.

    Criteria:
    1. Must contain <decisions> and </decisions> tags.
    2. Must contain at least one inner <decision> element.
    3. Attributes 'amount' and 'ticker' (or 'marketId') should be present.

    Returns:
        +0.5 for valid syntax and attributes
        -1.0 for broken XML or missing wrapper tags
        -0.5 for a wrapper with no inner <decision> element
        -0.2 for missing attributes in otherwise valid tags
    """
    if not response:
        return -1.0

    # Check for wrapping tags
    if "<decisions>" not in response or "</decisions>" not in response:
        return -1.0

    # Check for inner tags.
    # BUGFIX: the previous substring test (`"<decision" not in response`)
    # could never fire, because "<decisions>" — already verified above —
    # contains "<decision" as a substring. Require the tag name to be
    # terminated by whitespace, '>' or '/' so only a genuine inner
    # <decision ...> element counts.
    if not re.search(r"<decision[\s>/]", response):
        return -0.5  # Has wrappers but no decision element

    # Check for critical attributes (simple heuristic regex to handle both quote styles)
    has_ticker = re.search(
        r'ticker="[^"]+"', response) or re.search(r"ticker='[^']+'", response)
    has_market = re.search(
        r'marketId="[^"]+"', response) or re.search(r"marketId='[^']+'", response)
    has_amount = re.search(
        r'amount="[^"]+"', response) or re.search(r"amount='[^']+'", response)

    # Need either ticker OR marketId, AND amount
    if (not has_ticker and not has_market) or not has_amount:
        return -0.2  # Penalty for partial hallucination / missing args

    return 0.5
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def check_reasoning_action_alignment(
|
|
101
|
+
reasoning_text: str,
|
|
102
|
+
action: Action | None,
|
|
103
|
+
) -> float:
|
|
104
|
+
"""
|
|
105
|
+
Check if reasoning aligns with action taken, including Financial Literacy check.
|
|
106
|
+
|
|
107
|
+
Components:
|
|
108
|
+
1. Directional Alignment (Up/Buy vs Down/Sell)
|
|
109
|
+
2. Financial Literacy Bonus (referencing Exposure or PnL)
|
|
110
|
+
|
|
111
|
+
Returns:
|
|
112
|
+
Score between 0.0 and 1.0
|
|
113
|
+
"""
|
|
114
|
+
if not action or not reasoning_text:
|
|
115
|
+
return 0.5 # Neutral if we can't check
|
|
116
|
+
|
|
117
|
+
reasoning_lower = reasoning_text.lower()
|
|
118
|
+
action_type = action.action_type.lower()
|
|
119
|
+
|
|
120
|
+
score = 0.5
|
|
121
|
+
|
|
122
|
+
# --- 1. Financial Literacy Check ---
|
|
123
|
+
literacy_bonus = 0.0
|
|
124
|
+
if "exposure" in reasoning_lower:
|
|
125
|
+
literacy_bonus += 0.15
|
|
126
|
+
if "pnl" in reasoning_lower or "profit" in reasoning_lower or "loss" in reasoning_lower:
|
|
127
|
+
literacy_bonus += 0.15
|
|
128
|
+
|
|
129
|
+
# --- 2. Directional Alignment ---
|
|
130
|
+
# Sentiment indicators
|
|
131
|
+
bullish_words = ["bullish", "buy", "long",
|
|
132
|
+
"upward", "positive", "opportunity", "moon"]
|
|
133
|
+
bearish_words = ["bearish", "sell", "short",
|
|
134
|
+
"downward", "negative", "avoid", "dump"]
|
|
135
|
+
wait_words = ["wait", "hold", "unclear",
|
|
136
|
+
"uncertain", "need more data", "observing"]
|
|
137
|
+
|
|
138
|
+
# Count sentiment
|
|
139
|
+
bullish_score = sum(1 for w in bullish_words if w in reasoning_lower)
|
|
140
|
+
bearish_score = sum(1 for w in bearish_words if w in reasoning_lower)
|
|
141
|
+
wait_score = sum(1 for w in wait_words if w in reasoning_lower)
|
|
142
|
+
|
|
143
|
+
# Check alignment
|
|
144
|
+
is_buy = action_type in ["buy", "buy_prediction", "open_perp", "long"]
|
|
145
|
+
is_sell = action_type in ["sell", "sell_prediction", "close_perp", "short"]
|
|
146
|
+
is_wait = action_type in ["wait", "hold"]
|
|
147
|
+
|
|
148
|
+
if is_buy:
|
|
149
|
+
if bullish_score > bearish_score:
|
|
150
|
+
score = 0.7 # Aligned
|
|
151
|
+
elif bearish_score > bullish_score:
|
|
152
|
+
score = 0.0 # Misaligned (Hallucination penalty)
|
|
153
|
+
else:
|
|
154
|
+
score = 0.4
|
|
155
|
+
elif is_sell:
|
|
156
|
+
if bearish_score > bullish_score:
|
|
157
|
+
score = 0.7 # Aligned
|
|
158
|
+
elif bullish_score > bearish_score:
|
|
159
|
+
score = 0.0 # Misaligned (Hallucination penalty)
|
|
160
|
+
else:
|
|
161
|
+
score = 0.4
|
|
162
|
+
elif is_wait:
|
|
163
|
+
if wait_score > 0:
|
|
164
|
+
score = 0.7
|
|
165
|
+
else:
|
|
166
|
+
score = 0.5
|
|
167
|
+
|
|
168
|
+
# Cap total at 1.0
|
|
169
|
+
return min(1.0, score + literacy_bonus)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def check_reasoning_coherence(reasoning_text: str) -> float:
    """
    Heuristic coherence score for reasoning text, in [0, 1].

    Rewards structure (lists/bullets), explicit conclusions, moderate
    sentence count, vocabulary diversity, and quantitative references.
    """
    # Trivially short text cannot be coherent reasoning.
    if not reasoning_text or len(reasoning_text) < 20:
        return 0.1

    text = reasoning_text
    lowered = text.lower()
    total = 0.0

    # Structure: numbered items or bullet markers.
    if re.search(r'(\d+[\.\):]|\-|\*|\•)', text):
        total += 0.25

    # Conclusions: does the text commit to an outcome?
    markers = (
        "therefore", "conclusion", "decision", "recommend",
        "suggest", "final", "result", "action:", "execute",
    )
    if any(marker in lowered for marker in markers):
        total += 0.25

    # Sentence count: 2-10 is the sweet spot; longer gets partial credit.
    sentence_count = len(text.split('. '))
    if 2 <= sentence_count <= 10:
        total += 0.2
    elif sentence_count > 10:
        total += 0.1  # Too verbose

    # Vocabulary diversity: repetition signals low-quality generation.
    tokens = lowered.split()
    if len(tokens) > 10:
        if len(set(tokens)) / len(tokens) > 0.4:
            total += 0.15  # Good vocabulary diversity
        else:
            total -= 0.1   # Repetitive
    else:
        total += 0.1

    # Quantitative grounding: prices, percentages, magnitudes.
    if re.search(r'\$?\d+(?:\.\d+)?(?:%|k|K|M)?', text):
        total += 0.15

    return min(max(total, 0.0), 1.0)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def calculate_detailed_tick_quality(
|
|
220
|
+
llm_calls: list,
|
|
221
|
+
action: Action | None,
|
|
222
|
+
feedback: dict | None,
|
|
223
|
+
archetype: str | None = None,
|
|
224
|
+
) -> tuple[float, float]:
|
|
225
|
+
"""
|
|
226
|
+
Calculate detailed quality scores.
|
|
227
|
+
Returns: (format_score, reasoning_score)
|
|
228
|
+
"""
|
|
229
|
+
format_score = 0.0
|
|
230
|
+
reasoning_score = 0.0
|
|
231
|
+
|
|
232
|
+
# 1. Format Score (XML)
|
|
233
|
+
if llm_calls:
|
|
234
|
+
last_call = llm_calls[-1]
|
|
235
|
+
if last_call.response:
|
|
236
|
+
format_score = validate_xml_structure(last_call.response)
|
|
237
|
+
|
|
238
|
+
# 2. Reasoning Score
|
|
239
|
+
reasoning_texts = []
|
|
240
|
+
for call in llm_calls:
|
|
241
|
+
if call.reasoning:
|
|
242
|
+
reasoning_texts.append(call.reasoning)
|
|
243
|
+
if call.response:
|
|
244
|
+
reasoning_texts.append(call.response)
|
|
245
|
+
|
|
246
|
+
if action and action.reasoning:
|
|
247
|
+
reasoning_texts.append(action.reasoning)
|
|
248
|
+
|
|
249
|
+
full_reasoning = " ".join(reasoning_texts)
|
|
250
|
+
|
|
251
|
+
if full_reasoning:
|
|
252
|
+
reasoning_score = check_reasoning_action_alignment(
|
|
253
|
+
full_reasoning, action)
|
|
254
|
+
# Coherence boost
|
|
255
|
+
reasoning_score += check_reasoning_coherence(full_reasoning) * 0.2
|
|
256
|
+
|
|
257
|
+
return format_score, min(1.0, reasoning_score)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def calculate_tick_quality_score(
    llm_calls: list,
    action: Action | None,
    feedback: dict | None,
    archetype: str | None = None,
) -> float:
    """
    Calculate quality score for a single tick (0-1).

    Legacy wrapper that folds the detailed component scores into a single
    float using archetype-specific weights, preserving the original API.
    """
    weights = ARCHETYPE_WEIGHTS.get(
        archetype or "default", ARCHETYPE_WEIGHTS["default"])

    # Detailed format/reasoning components.
    fmt_component, reasoning_component = calculate_detailed_tick_quality(
        llm_calls, action, feedback, archetype)

    # Action outcome: success > error > unknown, 0 when no action recorded.
    action_component = 0.0
    if action:
        if action.success:
            action_component = 1.0
        else:
            action_component = 0.25 if action.error else 0.5

    # Feedback presence is binary.
    feedback_component = 1.0 if feedback else 0.0

    # Map the format score's native range [-1.0, 0.5] onto [0.0, 1.0]
    # (0.5 -> 1.0, -1.0 -> 0.0) so it composes with the other components.
    fmt_normalized = (fmt_component + 1.0) / 1.5

    weighted = (
        fmt_normalized * weights["llm_calls"]
        + reasoning_component * weights["reasoning"]
        + action_component * weights["action"]
        + feedback_component * weights["feedback"]
    )

    return min(1.0, max(0.0, weighted))
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
# Curriculum buckets used to order trajectories from easy to hard.
CurriculumLevel = Literal["easy", "medium", "hard"]


@dataclass
class TrajectoryDifficulty:
    """Trajectory difficulty assessment for curriculum learning"""
    # Discrete bucket derived from `score` thresholds (see
    # assess_trajectory_difficulty: >= 0.6 hard, >= 0.3 medium, else easy).
    level: CurriculumLevel
    score: float  # 0-1, higher = harder
    # Human-readable explanations of which factors contributed.
    reasons: list[str]
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def calculate_trajectory_quality_score(
|
|
318
|
+
ticks: list["AgentTickData"],
|
|
319
|
+
archetype: str | None = None,
|
|
320
|
+
) -> float:
|
|
321
|
+
"""
|
|
322
|
+
Calculate overall quality score for a trajectory (0-1).
|
|
323
|
+
|
|
324
|
+
Args:
|
|
325
|
+
ticks: List of tick data
|
|
326
|
+
archetype: Agent archetype for weight customization
|
|
327
|
+
"""
|
|
328
|
+
if not ticks:
|
|
329
|
+
return 0.0
|
|
330
|
+
|
|
331
|
+
scores = [
|
|
332
|
+
calculate_tick_quality_score(
|
|
333
|
+
tick.llm_calls,
|
|
334
|
+
tick.action,
|
|
335
|
+
tick.feedback,
|
|
336
|
+
archetype=archetype,
|
|
337
|
+
)
|
|
338
|
+
for tick in ticks
|
|
339
|
+
]
|
|
340
|
+
|
|
341
|
+
return sum(scores) / len(scores)
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def assess_trajectory_difficulty(
    ticks: list["AgentTickData"],
) -> TrajectoryDifficulty:
    """
    Assess difficulty of a trajectory for curriculum learning.

    Five factors, each contributing up to 0.2:
    - Trajectory length
    - Action-type diversity
    - Complex action parameters (leverage, large sizes)
    - Decision reversals (buy -> sell and vice versa)
    - Average reasoning depth
    """
    if not ticks:
        return TrajectoryDifficulty(level="easy", score=0.0, reasons=["Empty trajectory"])

    reasons: list[str] = []
    score = 0.0
    tick_count = len(ticks)

    # Factor 1: longer trajectories demand more sustained behavior.
    if tick_count > 20:
        score += 0.2
        reasons.append(f"Long trajectory ({tick_count} ticks)")
    elif tick_count > 10:
        score += 0.1

    # Factor 2: more distinct action types = harder policy to learn.
    distinct_actions = {t.action.action_type for t in ticks if t.action}
    if len(distinct_actions) >= 4:
        score += 0.2
        reasons.append(f"High action diversity ({len(distinct_actions)} types)")
    elif len(distinct_actions) >= 2:
        score += 0.1

    # Factor 3: leveraged or large positions count as complex actions.
    # A single tick can contribute twice (leverage AND size).
    complex_count = 0
    for t in ticks:
        if not (t.action and t.action.parameters):
            continue
        params = t.action.parameters
        # str() round-trip tolerates non-numeric parameter values
        # (and keeps static type-checkers happy).
        try:
            if float(str(params.get("leverage", 1))) > 1:
                complex_count += 1
        except (ValueError, TypeError):
            pass
        try:
            if float(str(params.get("amount", 0))) > 1000:
                complex_count += 1
        except (ValueError, TypeError):
            pass

    if complex_count >= 3:
        score += 0.2
        reasons.append(f"Complex action parameters ({complex_count})")
    elif complex_count >= 1:
        score += 0.1

    # Factor 4: direction flips between consecutive recorded actions.
    longs = ("buy", "long")
    shorts = ("sell", "short")
    flips = 0
    last_kind = None
    for t in ticks:
        if not t.action:
            continue
        kind = t.action.action_type
        if last_kind and (
            (last_kind in longs and kind in shorts)
            or (last_kind in shorts and kind in longs)
        ):
            flips += 1
        last_kind = kind

    if flips >= 2:
        score += 0.2
        reasons.append(f"Multiple reversals ({flips})")
    elif flips >= 1:
        score += 0.1

    # Factor 5: average reasoning text length per tick.
    char_total = 0
    for t in ticks:
        char_total += sum(len(c.reasoning or "") for c in t.llm_calls)
        if t.action:
            char_total += len(t.action.reasoning or "")

    avg_chars = char_total / tick_count
    if avg_chars > 200:
        score += 0.2
        reasons.append(
            f"Deep reasoning required (avg {avg_chars:.0f} chars)")
    elif avg_chars > 100:
        score += 0.1

    # Normalize and bucket.
    score = min(score, 1.0)
    level: CurriculumLevel = (
        "hard" if score >= 0.6 else "medium" if score >= 0.3 else "easy"
    )

    return TrajectoryDifficulty(
        level=level,
        score=score,
        reasons=reasons if reasons else ["Standard complexity"],
    )
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def build_trajectory_from_ticks(
|
|
458
|
+
trajectory_id: str,
|
|
459
|
+
agent_id: str,
|
|
460
|
+
ticks: list["AgentTickData"],
|
|
461
|
+
min_steps: int = 1,
|
|
462
|
+
) -> BabylonTrajectory | None:
|
|
463
|
+
"""
|
|
464
|
+
Build a BabylonTrajectory from tick data.
|
|
465
|
+
|
|
466
|
+
Args:
|
|
467
|
+
trajectory_id: Unique trajectory ID
|
|
468
|
+
agent_id: Agent ID
|
|
469
|
+
ticks: List of AgentTickData
|
|
470
|
+
min_steps: Minimum steps required (returns None if fewer)
|
|
471
|
+
|
|
472
|
+
Returns:
|
|
473
|
+
BabylonTrajectory or None if insufficient data
|
|
474
|
+
"""
|
|
475
|
+
if len(ticks) < min_steps:
|
|
476
|
+
return None
|
|
477
|
+
|
|
478
|
+
steps = []
|
|
479
|
+
for tick in ticks:
|
|
480
|
+
step = TrajectoryStep(
|
|
481
|
+
step_number=tick.tick_number,
|
|
482
|
+
timestamp=tick.timestamp,
|
|
483
|
+
environment_state=tick.environment_state,
|
|
484
|
+
provider_accesses=[],
|
|
485
|
+
llm_calls=tick.llm_calls,
|
|
486
|
+
action=tick.action or Action(
|
|
487
|
+
action_type="wait",
|
|
488
|
+
parameters={},
|
|
489
|
+
success=True,
|
|
490
|
+
),
|
|
491
|
+
reward=tick.reward,
|
|
492
|
+
)
|
|
493
|
+
steps.append(step)
|
|
494
|
+
|
|
495
|
+
# Calculate final metrics
|
|
496
|
+
final_pnl = ticks[-1].environment_state.agent_pnl if ticks else 0.0
|
|
497
|
+
final_balance = ticks[-1].environment_state.agent_balance if ticks else 10000.0
|
|
498
|
+
total_reward = sum(t.reward for t in ticks)
|
|
499
|
+
|
|
500
|
+
# Count trades and posts
|
|
501
|
+
trades_executed = sum(
|
|
502
|
+
1 for t in ticks
|
|
503
|
+
if t.action and t.action.action_type in [
|
|
504
|
+
"buy", "sell", "buy_prediction", "sell_prediction",
|
|
505
|
+
"open_perp", "close_perp"
|
|
506
|
+
]
|
|
507
|
+
)
|
|
508
|
+
posts_created = sum(
|
|
509
|
+
1 for t in ticks
|
|
510
|
+
if t.action and t.action.action_type in ["create_post", "post"]
|
|
511
|
+
)
|
|
512
|
+
|
|
513
|
+
now = datetime.now(timezone.utc)
|
|
514
|
+
|
|
515
|
+
return BabylonTrajectory(
|
|
516
|
+
id=trajectory_id,
|
|
517
|
+
trajectory_id=trajectory_id,
|
|
518
|
+
agent_id=agent_id,
|
|
519
|
+
window_id=now.strftime("%Y-%m-%dT%H:00"),
|
|
520
|
+
start_time=datetime.fromtimestamp(
|
|
521
|
+
ticks[0].timestamp / 1000, tz=timezone.utc),
|
|
522
|
+
end_time=datetime.fromtimestamp(
|
|
523
|
+
ticks[-1].timestamp / 1000, tz=timezone.utc),
|
|
524
|
+
duration_ms=ticks[-1].timestamp - ticks[0].timestamp,
|
|
525
|
+
steps=steps,
|
|
526
|
+
total_reward=total_reward,
|
|
527
|
+
final_pnl=final_pnl,
|
|
528
|
+
final_balance=final_balance,
|
|
529
|
+
trades_executed=trades_executed,
|
|
530
|
+
posts_created=posts_created,
|
|
531
|
+
episode_length=len(steps),
|
|
532
|
+
final_status="completed",
|
|
533
|
+
)
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def state_to_observation(game_state: dict) -> dict:
    """Build the observation dict an agent sees from the raw game state.

    Missing keys fall back to empty/zero defaults, so a partial game
    state still yields a well-formed observation.
    """
    g = game_state.get
    return {
        "tick": g("tick", 0),
        "time": g("currentTime", 0),
        "markets": g("predictionMarkets", []),
        "perpetuals": g("perpetualMarkets", []),
        # Truncate the feeds so the observation stays small and fast to process.
        "news": g("news", [])[:5],
        "posts": g("socialFeed", [])[:10],
    }
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
def state_to_env_state(game_state: dict, agent_id: str) -> EnvironmentState:
    """Extract the environment state for one agent from the full game state.

    Args:
        game_state: Full simulation state dict; reads the "portfolios" and
            "predictionMarkets" keys.
        agent_id: Agent whose portfolio entry ("agentId") should be used.

    Returns:
        EnvironmentState with the agent's balance, PnL and open-position
        count, plus the number of active prediction markets. When the agent
        has no portfolio entry, the starting defaults are used
        (balance 10000.0, pnl 0.0, 0 positions).
    """
    # First portfolio belonging to this agent, or {} if none.
    portfolio: dict = next(
        (p for p in game_state.get("portfolios", [])
         if p.get("agentId") == agent_id),
        {},
    )

    return EnvironmentState(
        agent_balance=portfolio.get("balance", 10000.0),
        # BUGFIX: was `agentPnL=`, inconsistent with the snake_case keyword
        # args in this same call and with the attribute read elsewhere in
        # this file (environment_state.agent_pnl). If EnvironmentState is a
        # plain dataclass, the camelCase keyword was a runtime TypeError.
        # NOTE(review): if the model declares a camelCase alias instead,
        # confirm `agent_pnl` is accepted (e.g. pydantic populate_by_name).
        agent_pnl=portfolio.get("pnl", 0.0),
        open_positions=portfolio.get(
            "positionCount", portfolio.get("positions", 0)),
        active_markets=len(game_state.get("predictionMarkets", [])),
    )
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
@dataclass
class ValidationResult:
    """Outcome of validating a rollout trajectory."""

    # Whether the trajectory passed every quality check.
    is_valid: bool
    # Human-readable description of each failed check.
    issues: list[str]
    # Aggregate quality score (as produced by calculate_trajectory_quality_score).
    quality_score: float

    @property
    def issue_count(self) -> int:
        """Number of recorded quality issues."""
        return len(self.issues)
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
def validate_trajectory_quality(
    ticks: list["AgentTickData"],
    min_ticks: int = 5,
    min_llm_calls_per_tick: float = 0.8,  # 80% of ticks should have LLM calls
    min_quality_score: float = 0.5,
) -> ValidationResult:
    """
    Validate that a trajectory meets the quality requirements for training.

    Args:
        ticks: List of tick data
        min_ticks: Minimum number of ticks required
        min_llm_calls_per_tick: Minimum fraction of ticks with LLM calls
        min_quality_score: Minimum quality score threshold

    Returns:
        ValidationResult carrying validity, the list of issues, and the score
    """
    problems: list[str] = []

    if len(ticks) < min_ticks:
        problems.append(f"Too few ticks: {len(ticks)} < {min_ticks}")

    # With no ticks at all, none of the remaining checks are meaningful.
    if not ticks:
        return ValidationResult(is_valid=False, issues=problems, quality_score=0.0)

    # Fraction of ticks that recorded at least one LLM call.
    call_coverage = sum(1 for t in ticks if t.llm_calls) / len(ticks)
    if call_coverage < min_llm_calls_per_tick:
        problems.append(
            f"Low LLM call coverage: {call_coverage:.1%} < {min_llm_calls_per_tick:.1%}")

    # Count calls that are missing either the prompt or the response.
    empty_calls = sum(
        1
        for tick in ticks
        for call in tick.llm_calls
        if not call.user_prompt or not call.response
    )
    if empty_calls > 0:
        problems.append(f"{empty_calls} LLM calls with empty prompt/response")

    quality_score = calculate_trajectory_quality_score(ticks)
    if quality_score < min_quality_score:
        problems.append(
            f"Quality score too low: {quality_score:.2f} < {min_quality_score}")

    return ValidationResult(
        is_valid=not problems,
        issues=problems,
        quality_score=quality_score,
    )
|