@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,481 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Trajectory Schema Definitions
|
|
3
|
+
|
|
4
|
+
Provides strict schema validation for trajectories to ensure data integrity
|
|
5
|
+
between TypeScript trajectory generation and Python training pipeline.
|
|
6
|
+
|
|
7
|
+
This module catches schema drift early and provides clear error messages
|
|
8
|
+
when data doesn't match expectations.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
13
|
+
import json
|
|
14
|
+
import logging
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# ============================================================================
|
|
20
|
+
# Step Schemas
|
|
21
|
+
# ============================================================================
|
|
22
|
+
|
|
23
|
+
@dataclass
class EnvironmentStateSchema:
    """Schema for step environment state"""
    agent_balance: float = 0.0
    agent_pnl: float = 0.0
    agent_points: int = 0
    open_positions: int = 0
    timestamp: Optional[int] = None

    # Optional reputation/influence fields
    reputation_delta: Optional[int] = None
    followers_gained: Optional[int] = None
    positive_reactions: Optional[int] = None
    information_spread: Optional[int] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "EnvironmentStateSchema":
        """Create from dictionary with field name normalization"""
        def pick(camel: str, snake: str, default: Any = None) -> Any:
            # camelCase keys take precedence over their snake_case twins
            return data.get(camel, data.get(snake, default))

        return cls(
            agent_balance=pick("agentBalance", "agent_balance", 0.0),
            agent_pnl=pick("agentPnL", "agent_pnl", 0.0),
            agent_points=pick("agentPoints", "agent_points", 0),
            open_positions=pick("openPositions", "open_positions", 0),
            timestamp=data.get("timestamp"),
            reputation_delta=pick("reputationDelta", "reputation_delta"),
            followers_gained=pick("followersGained", "followers_gained"),
            positive_reactions=pick("positiveReactions", "positive_reactions"),
            information_spread=pick("informationSpread", "information_spread"),
        )
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass
class ActionParametersSchema:
    """Schema for action parameters.

    All fields are optional because different action types populate
    different subsets (trading vs. social vs. batch-recording).
    """
    # Trading parameters
    ticker: Optional[str] = None
    amount: Optional[float] = None
    leverage: Optional[float] = None
    confidence: Optional[float] = None
    market_id: Optional[str] = None

    # Social parameters
    target_user_id: Optional[str] = None
    recipient_id: Optional[str] = None
    message: Optional[str] = None

    # Archetype (for batch recording mode)
    archetype: Optional[str] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ActionParametersSchema":
        """Create from dictionary with field name normalization.

        Accepts both camelCase and snake_case keys (matching the other
        ``*Schema.from_dict`` helpers in this module), plus legacy aliases:
        "size"/"quantity" for ``amount`` and "market" for ``market_id``.
        camelCase takes precedence when multiple spellings are present.
        """
        return cls(
            ticker=data.get("ticker"),
            # Legacy payloads used "size" or "quantity" for the trade size
            amount=data.get("amount", data.get("size", data.get("quantity"))),
            leverage=data.get("leverage"),
            confidence=data.get("confidence"),
            # snake_case fallback added for parity with sibling schemas;
            # "market" is the oldest alias and is consulted last
            market_id=data.get("marketId", data.get("market_id", data.get("market"))),
            target_user_id=data.get("targetUserId", data.get("target_user_id")),
            recipient_id=data.get("recipientId", data.get("recipient_id")),
            message=data.get("message"),
            archetype=data.get("archetype"),
        )
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@dataclass
class ActionResultSchema:
    """Schema for action result.

    Captures the outcome of an executed action; every field except
    ``success`` is optional because outcomes differ per action type.
    """
    position_id: Optional[str] = None
    pnl: Optional[float] = None
    success: bool = True
    error: Optional[str] = None
    archetype: Optional[str] = None

    # Prediction-specific
    correct: Optional[bool] = None
    prediction_correct: Optional[bool] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ActionResultSchema":
        """Create from dictionary with field name normalization.

        Accepts both camelCase and snake_case keys, matching the other
        ``*Schema.from_dict`` helpers in this module; camelCase wins when
        both spellings are present.
        """
        return cls(
            # snake_case fallbacks added for parity with sibling schemas
            position_id=data.get("positionId", data.get("position_id")),
            pnl=data.get("pnl"),
            # A missing "success" key is treated as a successful action
            success=data.get("success", True),
            error=data.get("error"),
            archetype=data.get("archetype"),
            correct=data.get("correct"),
            prediction_correct=data.get("predictionCorrect", data.get("prediction_correct")),
        )
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@dataclass
class ActionSchema:
    """Schema for trajectory action"""
    action_type: str
    parameters: ActionParametersSchema = field(default_factory=ActionParametersSchema)
    success: bool = True
    result: ActionResultSchema = field(default_factory=ActionResultSchema)
    reasoning: Optional[str] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ActionSchema":
        """Create from dictionary with field name normalization"""
        # camelCase keys win over snake_case duplicates for the type name
        kind = data.get("actionType", data.get("action_type", "unknown"))
        params = ActionParametersSchema.from_dict(data.get("parameters", {}))
        outcome = ActionResultSchema.from_dict(data.get("result", {}))
        return cls(
            action_type=kind,
            parameters=params,
            success=data.get("success", True),
            result=outcome,
            reasoning=data.get("reasoning"),
        )
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@dataclass
class LLMCallSchema:
    """Schema for LLM call within a step"""
    model: str
    purpose: str = "action"
    system_prompt: Optional[str] = None
    user_prompt: Optional[str] = None
    response: Optional[str] = None
    reasoning: Optional[str] = None
    temperature: float = 0.7
    max_tokens: int = 1000
    latency_ms: Optional[int] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LLMCallSchema":
        """Create from dictionary with field name normalization"""
        def pick(camel: str, snake: str, default: Any = None) -> Any:
            # camelCase keys take precedence over their snake_case twins
            return data.get(camel, data.get(snake, default))

        return cls(
            model=data.get("model", "unknown"),
            purpose=data.get("purpose", "action"),
            system_prompt=pick("systemPrompt", "system_prompt"),
            user_prompt=pick("userPrompt", "user_prompt"),
            response=data.get("response"),
            reasoning=data.get("reasoning"),
            temperature=data.get("temperature", 0.7),
            max_tokens=pick("maxTokens", "max_tokens", 1000),
            latency_ms=pick("latencyMs", "latency_ms"),
        )
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@dataclass
class StepSchema:
    """Schema for a single trajectory step"""
    step_number: int
    timestamp: Optional[int] = None
    environment_state: EnvironmentStateSchema = field(default_factory=EnvironmentStateSchema)
    action: ActionSchema = field(default_factory=lambda: ActionSchema(action_type="unknown"))
    llm_calls: List[LLMCallSchema] = field(default_factory=list)
    reward: float = 0.0
    observation: Optional[Dict[str, Any]] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StepSchema":
        """Create from dictionary with field name normalization"""
        # camelCase keys win over snake_case duplicates
        raw_env = data.get("environmentState", data.get("environment_state", {}))
        raw_calls = data.get("llmCalls", data.get("llm_calls", []))
        return cls(
            step_number=data.get("stepNumber", data.get("step_number", 0)),
            timestamp=data.get("timestamp"),
            environment_state=EnvironmentStateSchema.from_dict(raw_env),
            action=ActionSchema.from_dict(data.get("action", {})),
            llm_calls=[LLMCallSchema.from_dict(call) for call in raw_calls],
            reward=data.get("reward", 0.0),
            observation=data.get("observation"),
        )
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
# ============================================================================
|
|
196
|
+
# Trajectory Schema
|
|
197
|
+
# ============================================================================
|
|
198
|
+
|
|
199
|
+
@dataclass
class TrajectorySchema:
    """Schema for a complete trajectory"""
    trajectory_id: str
    agent_id: str
    window_id: str
    scenario_id: Optional[str] = None
    archetype: Optional[str] = None
    steps_json: str = "[]"
    final_pnl: float = 0.0
    final_balance: Optional[float] = None
    episode_length: int = 0
    total_reward: float = 0.0
    trades_executed: int = 0
    is_training_data: bool = True

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TrajectorySchema":
        """Create from dictionary with field name normalization"""
        def pick(camel: str, snake: str, default: Any = None) -> Any:
            # camelCase keys take precedence over their snake_case twins
            return data.get(camel, data.get(snake, default))

        return cls(
            trajectory_id=pick("trajectoryId", "trajectory_id", ""),
            agent_id=pick("agentId", "agent_id", ""),
            window_id=pick("windowId", "window_id", ""),
            scenario_id=pick("scenarioId", "scenario_id"),
            archetype=data.get("archetype"),
            steps_json=pick("stepsJson", "steps_json", "[]"),
            # PnL/reward may arrive as strings or ints; coerce to float
            final_pnl=float(pick("finalPnL", "final_pnl", 0.0)),
            final_balance=pick("finalBalance", "final_balance"),
            episode_length=pick("episodeLength", "episode_length", 0),
            total_reward=float(pick("totalReward", "total_reward", 0.0)),
            trades_executed=pick("tradesExecuted", "trades_executed", 0),
            is_training_data=pick("isTrainingData", "is_training_data", True),
        )

    def get_steps(self) -> List[StepSchema]:
        """Parse and return steps as StepSchema objects"""
        try:
            raw_steps = json.loads(self.steps_json)
        except json.JSONDecodeError:
            # Malformed steps_json degrades to an empty step list
            return []
        return [StepSchema.from_dict(raw) for raw in raw_steps]

    def extract_archetype_from_steps(self) -> Optional[str]:
        """Extract archetype from step action parameters if not set at trajectory level"""
        if self.archetype:
            return self.archetype

        for step in self.get_steps():
            # check action.parameters first, then action.result
            found = step.action.parameters.archetype or step.action.result.archetype
            if found:
                return found

        return None
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
# ============================================================================
|
|
257
|
+
# Validation Functions
|
|
258
|
+
# ============================================================================
|
|
259
|
+
|
|
260
|
+
def validate_trajectory(data: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """
    Validate trajectory data against schema.

    Returns:
        Tuple of (is_valid, list of error messages)
    """
    errors: List[str] = []

    # Required identity fields may arrive in camelCase or snake_case
    for field_name in ("trajectoryId", "agentId", "windowId"):
        if field_name not in data and _camel_to_snake(field_name) not in data:
            errors.append(f"Missing required field: {field_name}")

    # stepsJson, when present, must decode to a non-empty list of valid steps
    steps_json = data.get("stepsJson", data.get("steps_json", "[]"))
    if steps_json:
        try:
            steps = json.loads(steps_json)
        except json.JSONDecodeError as e:
            errors.append(f"Invalid JSON in stepsJson: {e}")
        else:
            if not isinstance(steps, list):
                errors.append(f"stepsJson must be an array, got {type(steps).__name__}")
            elif not steps:
                errors.append("stepsJson is empty - trajectory has no steps")
            else:
                for i, step in enumerate(steps):
                    errors.extend(_validate_step(step, i))

    # finalPnL must coerce to float when supplied
    pnl = data.get("finalPnL", data.get("final_pnl"))
    if pnl is not None:
        try:
            float(pnl)
        except (TypeError, ValueError):
            errors.append(f"finalPnL must be a number, got {type(pnl).__name__}")

    # episodeLength must be a non-negative integer when supplied
    episode_length = data.get("episodeLength", data.get("episode_length"))
    if episode_length is not None and (
        not isinstance(episode_length, int) or episode_length < 0
    ):
        errors.append("episodeLength must be a non-negative integer")

    return not errors, errors
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def _validate_step(step: Dict[str, Any], index: int) -> List[str]:
|
|
309
|
+
"""Validate a single step"""
|
|
310
|
+
errors = []
|
|
311
|
+
prefix = f"Step {index}"
|
|
312
|
+
|
|
313
|
+
# stepNumber should exist
|
|
314
|
+
if "stepNumber" not in step and "step_number" not in step:
|
|
315
|
+
errors.append(f"{prefix}: missing stepNumber")
|
|
316
|
+
|
|
317
|
+
# action should exist and have actionType
|
|
318
|
+
action = step.get("action", {})
|
|
319
|
+
if not action:
|
|
320
|
+
errors.append(f"{prefix}: missing action")
|
|
321
|
+
elif "actionType" not in action and "action_type" not in action:
|
|
322
|
+
errors.append(f"{prefix}: action missing actionType")
|
|
323
|
+
|
|
324
|
+
# environmentState should exist
|
|
325
|
+
env_state = step.get("environmentState", step.get("environment_state"))
|
|
326
|
+
if not env_state:
|
|
327
|
+
errors.append(f"{prefix}: missing environmentState")
|
|
328
|
+
|
|
329
|
+
return errors
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def validate_step(data: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """
    Validate step data against schema.

    Returns:
        Tuple of (is_valid, list of error messages)
    """
    # Delegate to the shared per-step validator, reporting as step index 0
    step_errors = _validate_step(data, 0)
    return not step_errors, step_errors
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def validate_llm_call(data: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """
    Validate LLM call data against schema.

    Returns:
        Tuple of (is_valid, list of error messages)
    """
    # Only 'model' is mandatory; all other call metadata is optional
    errors = [] if "model" in data else ["LLM call missing 'model' field"]
    return not errors, errors
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def _camel_to_snake(name: str) -> str:
|
|
359
|
+
"""Convert camelCase to snake_case"""
|
|
360
|
+
result = []
|
|
361
|
+
for i, char in enumerate(name):
|
|
362
|
+
if char.isupper() and i > 0:
|
|
363
|
+
result.append("_")
|
|
364
|
+
result.append(char.lower())
|
|
365
|
+
return "".join(result)
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
# ============================================================================
|
|
369
|
+
# Schema Comparison
|
|
370
|
+
# ============================================================================
|
|
371
|
+
|
|
372
|
+
def compare_trajectory_formats(
    json_data: Dict[str, Any],
    db_data: Dict[str, Any],
) -> Tuple[bool, List[str]]:
    """
    Compare trajectory data from JSON and database formats.

    Returns:
        Tuple of (are_equivalent, list of difference descriptions)
    """
    differences: List[str] = []

    # JSON field name -> DB field name (currently identical on both sides)
    field_mapping = {
        "trajectoryId": "trajectoryId",
        "agentId": "agentId",
        "windowId": "windowId",
        "scenarioId": "scenarioId",
        "archetype": "archetype",
        "stepsJson": "stepsJson",
        "finalPnL": "finalPnL",
        "episodeLength": "episodeLength",
    }

    for json_field, db_field in field_mapping.items():
        json_val = json_data.get(json_field)
        db_val = db_data.get(db_field)
        if json_val == db_val:
            continue

        # Tolerate tiny float drift between serialization layers
        both_numeric = isinstance(json_val, (int, float)) and isinstance(db_val, (int, float))
        if both_numeric and abs(float(json_val) - float(db_val)) < 0.001:
            continue

        differences.append(f"{json_field}: JSON={json_val!r}, DB={db_val!r}")

    return not differences, differences
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
# ============================================================================
|
|
412
|
+
# Export Validation Results
|
|
413
|
+
# ============================================================================
|
|
414
|
+
|
|
415
|
+
@dataclass
class ValidationResult:
    """Outcome of a schema validation pass.

    Truthiness mirrors ``is_valid`` so a result can be used directly in
    conditionals (``if result: ...``).
    """
    is_valid: bool
    errors: List[str]
    warnings: List[str] = field(default_factory=list)

    def __bool__(self) -> bool:
        # Allow `if result:` as shorthand for `if result.is_valid:`
        return self.is_valid
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def validate_trajectory_file(file_path: str) -> ValidationResult:
    """
    Validate a trajectory JSON file.

    Args:
        file_path: Path to JSON file

    Returns:
        ValidationResult with validation status and any errors/warnings
    """
    errors: List[str] = []
    warnings: List[str] = []

    try:
        with open(file_path, "r") as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        return ValidationResult(False, [f"Invalid JSON: {e}"])
    except FileNotFoundError:
        return ValidationResult(False, [f"File not found: {file_path}"])

    # Check for trajectory wrapper
    if "trajectory" not in data:
        warnings.append("Missing 'trajectory' wrapper - treating root as trajectory data")
        traj_data = data
    else:
        traj_data = data["trajectory"]

    # Validate trajectory
    is_valid, traj_errors = validate_trajectory(traj_data)
    errors.extend(traj_errors)

    # Check for archetype at the trajectory level, then fall back to steps
    archetype = traj_data.get("archetype")
    if not archetype:
        # Accept both key styles, matching validate_trajectory (previously
        # snake_case "steps_json" was ignored here, causing false warnings)
        steps_json = traj_data.get("stepsJson", traj_data.get("steps_json", "[]"))
        try:
            steps = json.loads(steps_json)
            # Look in both action.parameters and action.result, matching
            # TrajectorySchema.extract_archetype_from_steps
            found_archetype = any(
                step.get("action", {}).get("parameters", {}).get("archetype")
                or step.get("action", {}).get("result", {}).get("archetype")
                for step in steps
            )
            if not found_archetype:
                warnings.append("No archetype found at trajectory or step level - will use 'default'")
        except json.JSONDecodeError:
            pass  # invalid stepsJson was already reported by validate_trajectory

    return ValidationResult(
        is_valid=len(errors) == 0,
        errors=errors,
        warnings=warnings,
    )
|
|
481
|
+
|