@elizaos/training 2.0.0-alpha.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/data/.gitkeep +0 -0
- package/data/degen/.gitkeep +2 -0
- package/data/trader/.gitkeep +2 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +58 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +206 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +89 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +439 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Archetype-Aware Training Pipeline
|
|
3
|
+
|
|
4
|
+
Train agents with different "values" using archetype-specific rubrics.
|
|
5
|
+
Supports training single archetypes, multiple archetypes, or all archetypes at once.
|
|
6
|
+
|
|
7
|
+
Usage:
|
|
8
|
+
# Train a single archetype
|
|
9
|
+
trainer = ArchetypeTrainer()
|
|
10
|
+
await trainer.train_archetype("trader")
|
|
11
|
+
|
|
12
|
+
# Train multiple archetypes
|
|
13
|
+
await trainer.train_archetypes(["trader", "scammer", "social-butterfly"])
|
|
14
|
+
|
|
15
|
+
# Train all archetypes
|
|
16
|
+
await trainer.train_all_archetypes()
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import asyncio
|
|
20
|
+
import logging
|
|
21
|
+
import os
|
|
22
|
+
from dataclasses import dataclass
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
from typing import Dict, List, Optional
|
|
25
|
+
|
|
26
|
+
# Import rubrics from centralized loader (single source of truth)
|
|
27
|
+
from .rubric_loader import (
|
|
28
|
+
get_rubric,
|
|
29
|
+
get_priority_metrics,
|
|
30
|
+
get_available_archetypes,
|
|
31
|
+
reload_rubrics,
|
|
32
|
+
DEFAULT_RUBRIC,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
logger = logging.getLogger(__name__)
|
|
36
|
+
|
|
37
|
+
# ============================================================================
|
|
38
|
+
# Archetype Rubrics - Loaded from config/rubrics.json via rubric_loader
|
|
39
|
+
# ============================================================================
|
|
40
|
+
#
|
|
41
|
+
# All rubrics are now defined in packages/training/config/rubrics.json
|
|
42
|
+
# This is the single source of truth shared between TypeScript and Python.
|
|
43
|
+
#
|
|
44
|
+
# Use these functions (imported from rubric_loader):
|
|
45
|
+
# get_rubric(archetype) - Get the rubric text for an archetype
|
|
46
|
+
# get_priority_metrics(archetype) - Get priority metrics for scoring
|
|
47
|
+
# get_available_archetypes() - Get list of all archetypes
|
|
48
|
+
# reload_rubrics() - Reload rubrics from JSON file
|
|
49
|
+
# DEFAULT_RUBRIC - Fallback rubric for unknown archetypes
|
|
50
|
+
# ============================================================================
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# ============================================================================
|
|
54
|
+
# Archetype Training Configuration
|
|
55
|
+
# ============================================================================
|
|
56
|
+
|
|
57
|
+
@dataclass
class ArchetypeTrainingConfig:
    """Tunable settings for an archetype-specific training run.

    Bundles model selection, optimizer hyperparameters, data-window limits,
    and output/logging locations so a single object can configure
    ``ArchetypeTrainer``.
    """

    # Base model to fine-tune.
    base_model: str = "Qwen/Qwen3-4B"

    # Optimizer / training-loop hyperparameters.
    training_steps: int = 100
    batch_size: int = 4
    learning_rate: float = 1e-5

    # Trajectory-selection window. NOTE(review): min_trajectories_per_archetype
    # is not consumed anywhere in this module — confirm downstream usage.
    min_trajectories_per_archetype: int = 10
    lookback_hours: int = 72

    # Where trained checkpoints land; when save_per_archetype is True the
    # final checkpoint is copied into a per-archetype directory.
    output_dir: str = "./trained_models"
    save_per_archetype: bool = True

    # LLM used to judge/score rollouts against the archetype rubric.
    judge_model: str = "gpt-4o-mini"

    # JSONL training-log destination (one file per archetype).
    log_to_file: bool = True
    log_dir: str = "./logs"
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass
class ArchetypeTrainingResult:
    """Summary returned after training a single archetype."""

    archetype: str          # name of the archetype that was trained
    trajectories_used: int  # approximate trajectories consumed (steps * batch_size)
    training_steps: int     # optimizer steps actually completed
    final_loss: float       # loss from the last recorded metrics entry (0 if none)
    checkpoint_path: str    # directory of the saved model ("" if no checkpoint)
    metrics: Dict           # {"training_metrics": [...per-step metric dicts...]}
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# ============================================================================
|
|
97
|
+
# Main Archetype Trainer
|
|
98
|
+
# ============================================================================
|
|
99
|
+
|
|
100
|
+
class ArchetypeTrainer:
    """
    Orchestrates RL training runs across one or more archetypes.

    Each archetype is trained against its own rubric (loaded via
    ``rubric_loader``), so the resulting models embody different values/goals.
    """

    def __init__(self, config: Optional[ArchetypeTrainingConfig] = None):
        # Fall back to the default settings when no config is supplied.
        self.config = config or ArchetypeTrainingConfig()
        self._ensure_dirs()

    def _ensure_dirs(self):
        """Create the model-output and log directories if they are missing."""
        for directory in (self.config.output_dir, self.config.log_dir):
            Path(directory).mkdir(parents=True, exist_ok=True)

    async def train_archetype(
        self,
        archetype: str,
        trajectories: Optional[List] = None,
    ) -> ArchetypeTrainingResult:
        """
        Run a full training pass for one archetype.

        Args:
            archetype: Name of the archetype to train (e.g., "trader", "scammer")
            trajectories: Optional pre-loaded trajectories. If None, loads from DB.
                NOTE(review): this argument is not consumed in this method —
                confirm whether the underlying trainer is expected to use it.

        Returns:
            ArchetypeTrainingResult with training metrics and checkpoint path
        """
        # Imported lazily so importing this module does not pull in the heavy
        # environment/trainer dependency chain.
        from .babylon_env import BabylonEnvConfig
        from .atropos_trainer import BabylonAtroposTrainer, AtroposTrainingConfig

        logger.info(f"Starting training for archetype: {archetype}")

        # Archetype-specific scoring rubric (single source of truth: rubrics.json).
        archetype_rubric = get_rubric(archetype)

        # Environment config carrying the rubric. It is built but deliberately
        # unused here; in the full pipeline it would be handed to the
        # BabylonRLAIFEnv server at startup.
        _ = BabylonEnvConfig(
            scoring_rubric=archetype_rubric,
            judge_model=self.config.judge_model,
            lookback_hours=self.config.lookback_hours,
        )

        # Per-archetype trainer settings, with a dedicated JSONL log file.
        atropos_config = AtroposTrainingConfig(
            model_name=self.config.base_model,
            training_steps=self.config.training_steps,
            batch_size=self.config.batch_size,
            learning_rate=self.config.learning_rate,
            log_to_file=self.config.log_to_file,
            log_file=f"{self.config.log_dir}/training_{archetype}.jsonl",
        )

        trainer = BabylonAtroposTrainer(atropos_config)
        raw = await trainer.train()

        checkpoint_path = raw.get("final_checkpoint", "")

        # Copy the final checkpoint into an archetype-named directory so that
        # runs for different archetypes do not overwrite one another.
        if checkpoint_path and self.config.save_per_archetype:
            archetype_dir = f"{self.config.output_dir}/{archetype}_model"
            import shutil
            if os.path.exists(checkpoint_path):
                shutil.copytree(checkpoint_path, archetype_dir, dirs_exist_ok=True)
                checkpoint_path = archetype_dir

        steps_done = raw.get("steps", 0)
        step_metrics = raw.get("metrics")
        # Last recorded loss, or 0 when the trainer produced no metrics.
        last_loss = step_metrics[-1].get("loss", 0) if step_metrics else 0

        return ArchetypeTrainingResult(
            archetype=archetype,
            trajectories_used=steps_done * self.config.batch_size,
            training_steps=steps_done,
            final_loss=last_loss,
            checkpoint_path=checkpoint_path,
            metrics={"training_metrics": raw.get("metrics", [])},
        )

    async def train_archetypes(
        self,
        archetypes: List[str],
        parallel: bool = False,
    ) -> List[ArchetypeTrainingResult]:
        """
        Train several archetypes, either one after another or concurrently.

        Args:
            archetypes: List of archetype names to train
            parallel: If True, train archetypes in parallel (requires more resources)

        Returns:
            List of ArchetypeTrainingResult for each archetype that succeeded
        """
        logger.info(f"Training {len(archetypes)} archetypes: {archetypes}")

        if not parallel:
            # Sequential mode: safer on constrained hardware; one archetype's
            # failure is logged and the remaining archetypes still run.
            completed: List[ArchetypeTrainingResult] = []
            for archetype in archetypes:
                try:
                    completed.append(await self.train_archetype(archetype))
                except Exception as e:
                    logger.error(f"Failed to train {archetype}: {e}")
            return completed

        # Parallel mode: launch every run at once; return_exceptions keeps one
        # failure from cancelling the rest, so we filter afterwards.
        outcomes = await asyncio.gather(
            *(self.train_archetype(arch) for arch in archetypes),
            return_exceptions=True,
        )
        completed = []
        for name, outcome in zip(archetypes, outcomes):
            if isinstance(outcome, Exception):
                logger.error(f"Failed to train {name}: {outcome}")
            else:
                completed.append(outcome)
        return completed

    async def train_all_archetypes(
        self,
        parallel: bool = False,
    ) -> List[ArchetypeTrainingResult]:
        """
        Train every archetype known to the rubric loader.

        Args:
            parallel: If True, train in parallel

        Returns:
            List of ArchetypeTrainingResult for all archetypes
        """
        return await self.train_archetypes(get_available_archetypes(), parallel=parallel)

    def get_trained_model_path(self, archetype: str) -> Optional[str]:
        """Return the saved model directory for *archetype*, or None if absent."""
        candidate = f"{self.config.output_dir}/{archetype}_model"
        if os.path.exists(candidate):
            return candidate
        return None

    def list_trained_archetypes(self) -> List[str]:
        """Names of archetypes whose trained model directory exists on disk."""
        base = Path(self.config.output_dir)
        return [
            arch
            for arch in get_available_archetypes()
            if (base / f"{arch}_model").exists()
        ]
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
# ============================================================================
|
|
257
|
+
# CLI Entry Point
|
|
258
|
+
# ============================================================================
|
|
259
|
+
|
|
260
|
+
def main():
    """Command-line entry point: train one, several, or all archetypes."""
    import argparse

    parser = argparse.ArgumentParser(description="Train agents with archetype-specific values")

    # Flag spec table keeps the CLI surface easy to scan in one place.
    flag_specs = [
        ("--archetype", dict(
            type=str,
            default=None,
            help="Single archetype to train (e.g., 'trader', 'scammer')",
        )),
        ("--archetypes", dict(
            type=str,
            nargs="+",
            default=None,
            help="Multiple archetypes to train (e.g., --archetypes trader scammer)",
        )),
        ("--all", dict(
            action="store_true",
            help="Train all available archetypes",
        )),
        ("--parallel", dict(
            action="store_true",
            help="Train archetypes in parallel (requires more resources)",
        )),
        ("--list", dict(
            action="store_true",
            help="List all available archetypes",
        )),
        ("--steps", dict(
            type=int,
            default=100,
            help="Training steps per archetype",
        )),
        ("--output-dir", dict(
            type=str,
            default="./trained_models",
            help="Directory to save trained models",
        )),
    ]
    for flag, options in flag_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # --list is informational only: print the archetype names and exit.
    if args.list:
        print("Available archetypes:")
        for arch in get_available_archetypes():
            print(f"  - {arch}")
        return

    trainer = ArchetypeTrainer(
        ArchetypeTrainingConfig(
            training_steps=args.steps,
            output_dir=args.output_dir,
        )
    )

    async def _run():
        # Resolve which training mode was requested; --all wins over
        # --archetypes, which wins over --archetype.
        if args.all:
            outcomes = await trainer.train_all_archetypes(parallel=args.parallel)
        elif args.archetypes:
            outcomes = await trainer.train_archetypes(args.archetypes, parallel=args.parallel)
        elif args.archetype:
            outcomes = [await trainer.train_archetype(args.archetype)]
        else:
            print("Please specify --archetype, --archetypes, or --all")
            print("Use --list to see available archetypes")
            return

        # Human-readable summary of every completed run.
        print("\n" + "=" * 60)
        print("TRAINING COMPLETE")
        print("=" * 60)
        for outcome in outcomes:
            print(f"\n{outcome.archetype}:")
            print(f"  Steps: {outcome.training_steps}")
            print(f"  Final Loss: {outcome.final_loss:.4f}")
            print(f"  Checkpoint: {outcome.checkpoint_path}")

    asyncio.run(_run())
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
if __name__ == "__main__":
    # Run the CLI when this module is executed directly.
    main()
|