@elizaos/training 2.0.0-alpha.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/data/.gitkeep +0 -0
- package/data/degen/.gitkeep +2 -0
- package/data/trader/.gitkeep +2 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +58 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +206 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +89 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +439 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "elizaos-training"
|
|
3
|
+
version = "0.2.0"
|
|
4
|
+
description = "RL training pipeline for ElizaOS agents using Atropos"
|
|
5
|
+
authors = [
|
|
6
|
+
{name = "ElizaOS Contributors"}
|
|
7
|
+
]
|
|
8
|
+
requires-python = ">=3.10"
|
|
9
|
+
dependencies = [
|
|
10
|
+
"atroposlib>=0.1.0",
|
|
11
|
+
"asyncpg>=0.29.0",
|
|
12
|
+
"python-dotenv>=1.0.0",
|
|
13
|
+
"pydantic>=2.0.0",
|
|
14
|
+
"numpy>=1.24.0",
|
|
15
|
+
"openai>=1.0.0",
|
|
16
|
+
"litellm>=1.0.0",
|
|
17
|
+
"tabulate>=0.9.0",
|
|
18
|
+
"vllm>=0.4.0",
|
|
19
|
+
"transformers>=4.40.0",
|
|
20
|
+
]
|
|
21
|
+
|
|
22
|
+
[project.optional-dependencies]
|
|
23
|
+
dev = [
|
|
24
|
+
"pytest>=7.0.0",
|
|
25
|
+
"pytest-asyncio>=0.21.0",
|
|
26
|
+
"black>=23.0.0",
|
|
27
|
+
"ruff>=0.1.0",
|
|
28
|
+
]
|
|
29
|
+
|
|
30
|
+
[project.scripts]
|
|
31
|
+
train-mmo = "scripts.train_mmo:main"
|
|
32
|
+
check-windows = "scripts.check_windows:main"
|
|
33
|
+
run-migrations = "scripts.run_migrations:main"
|
|
34
|
+
|
|
35
|
+
[tool.ruff]
|
|
36
|
+
line-length = 100
|
|
37
|
+
target-version = "py310"
|
|
38
|
+
|
|
39
|
+
[tool.black]
|
|
40
|
+
line-length = 100
|
|
41
|
+
target-version = ["py310"]
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# Lightweight dependencies for CI test runs (avoids GPU / vLLM installs)
|
|
2
|
+
atroposlib>=0.3.0
|
|
3
|
+
|
|
4
|
+
# Database (integration tests)
|
|
5
|
+
asyncpg>=0.29.0
|
|
6
|
+
psycopg2-binary>=2.9.9
|
|
7
|
+
|
|
8
|
+
# HTTP/API
|
|
9
|
+
httpx>=0.26.0
|
|
10
|
+
aiohttp>=3.9.0
|
|
11
|
+
requests>=2.31.0
|
|
12
|
+
openai>=1.0.0
|
|
13
|
+
|
|
14
|
+
# Config / typing
|
|
15
|
+
python-dotenv>=1.0.0
|
|
16
|
+
pydantic>=2.5.0
|
|
17
|
+
pyyaml>=6.0.1
|
|
18
|
+
|
|
19
|
+
# Testing
|
|
20
|
+
pytest>=7.4.0
|
|
21
|
+
pytest-asyncio>=0.21.0
|
|
22
|
+
|
|
23
|
+
# Utilities used by the training pipeline
|
|
24
|
+
wandb>=0.16.0
|
|
25
|
+
tqdm>=4.66.0
|
|
26
|
+
psutil>=5.9.0
|
|
27
|
+
numpy>=1.24.0
|
|
28
|
+
tenacity>=8.2.0
|
|
29
|
+
rich>=13.0.0
|
|
30
|
+
jsonlines>=4.0.0
|
|
31
|
+
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
# Babylon RL Training - Atropos + Tinker Framework
|
|
2
|
+
# Supports: Tinker Cloud (RECOMMENDED), CUDA (GPU), MLX (Apple Silicon), CPU
|
|
3
|
+
|
|
4
|
+
# ===========================================
|
|
5
|
+
# Tinker API (RECOMMENDED - Cloud Training)
|
|
6
|
+
# ===========================================
|
|
7
|
+
# Tinker provides cloud-based training without local GPU requirements
|
|
8
|
+
# Get API key from: https://tinker-docs.thinkingmachines.ai/
|
|
9
|
+
tinker>=0.1.0
|
|
10
|
+
|
|
11
|
+
# ===========================================
|
|
12
|
+
# Core Atropos Framework (for environments)
|
|
13
|
+
# ===========================================
|
|
14
|
+
atroposlib>=0.3.0
|
|
15
|
+
|
|
16
|
+
# ===========================================
|
|
17
|
+
# Database
|
|
18
|
+
# ===========================================
|
|
19
|
+
asyncpg>=0.29.0
|
|
20
|
+
psycopg2-binary>=2.9.9
|
|
21
|
+
|
|
22
|
+
# ===========================================
|
|
23
|
+
# HTTP/API
|
|
24
|
+
# ===========================================
|
|
25
|
+
httpx>=0.26.0
|
|
26
|
+
aiohttp>=3.9.0
|
|
27
|
+
requests>=2.31.0
|
|
28
|
+
|
|
29
|
+
# ===========================================
|
|
30
|
+
# OpenAI-compatible client (for RLAIF judge)
|
|
31
|
+
# ===========================================
|
|
32
|
+
openai>=1.0.0
|
|
33
|
+
|
|
34
|
+
# ===========================================
|
|
35
|
+
# Configuration
|
|
36
|
+
# ===========================================
|
|
37
|
+
pyyaml>=6.0.1
|
|
38
|
+
python-dotenv>=1.0.0
|
|
39
|
+
pydantic>=2.5.0
|
|
40
|
+
pydantic-cli>=3.0.0
|
|
41
|
+
|
|
42
|
+
# ===========================================
|
|
43
|
+
# Testing
|
|
44
|
+
# ===========================================
|
|
45
|
+
pytest>=7.4.0
|
|
46
|
+
pytest-asyncio>=0.21.0
|
|
47
|
+
|
|
48
|
+
# ===========================================
|
|
49
|
+
# Experiment Tracking (Optional)
|
|
50
|
+
# ===========================================
|
|
51
|
+
wandb>=0.16.0 # W&B integration, falls back to offline mode if no API key
|
|
52
|
+
|
|
53
|
+
# ===========================================
|
|
54
|
+
# Utilities
|
|
55
|
+
# ===========================================
|
|
56
|
+
tqdm>=4.66.0
|
|
57
|
+
psutil>=5.9.0
|
|
58
|
+
numpy>=1.24.0
|
|
59
|
+
tenacity>=8.2.0
|
|
60
|
+
rich>=13.0.0
|
|
61
|
+
jsonlines>=4.0.0
|
|
62
|
+
|
|
63
|
+
# ============================================
|
|
64
|
+
# OPTIONAL: Local Training (GPU/CPU)
|
|
65
|
+
# ============================================
|
|
66
|
+
# Only needed if NOT using Tinker for training
|
|
67
|
+
# Uncomment if you need local fallback:
|
|
68
|
+
#
|
|
69
|
+
# torch>=2.1.0
|
|
70
|
+
# transformers>=4.36.0
|
|
71
|
+
# peft>=0.8.0
|
|
72
|
+
# vllm>=0.3.0
|
|
73
|
+
# accelerate>=1.12.0
|
|
74
|
+
|
|
75
|
+
# ============================================
|
|
76
|
+
# OPTIONAL: MLX Backend (Apple Silicon only)
|
|
77
|
+
# ============================================
|
|
78
|
+
# Install on Mac with Apple Silicon:
|
|
79
|
+
# pip install mlx mlx-lm
|
|
80
|
+
#
|
|
81
|
+
# For fine-tuning support:
|
|
82
|
+
# pip install mlx-lm[finetuning]
|
|
83
|
+
#
|
|
84
|
+
# Recommended models for MLX:
|
|
85
|
+
# - mlx-community/Qwen2.5-3B-Instruct-4bit
|
|
86
|
+
# - mlx-community/Qwen2.5-7B-Instruct-4bit
|
|
87
|
+
# - mlx-community/Qwen3-4B-4bit
|
|
@@ -0,0 +1,412 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Import JSON Trajectories to PostgreSQL
|
|
4
|
+
|
|
5
|
+
Bridges JSON-mode generated trajectories to the database for integration testing
|
|
6
|
+
and production seeding. This enables testing the database pipeline with data
|
|
7
|
+
generated by the TypeScript simulation engine.
|
|
8
|
+
|
|
9
|
+
Usage:
|
|
10
|
+
# Import from default training-data-output directory
|
|
11
|
+
python scripts/import_json_trajectories.py
|
|
12
|
+
|
|
13
|
+
# Import from custom directory
|
|
14
|
+
python scripts/import_json_trajectories.py --source ./my-trajectories
|
|
15
|
+
|
|
16
|
+
# Dry run (validate without inserting)
|
|
17
|
+
python scripts/import_json_trajectories.py --dry-run
|
|
18
|
+
|
|
19
|
+
# Verbose output
|
|
20
|
+
python scripts/import_json_trajectories.py --verbose
|
|
21
|
+
|
|
22
|
+
Environment:
|
|
23
|
+
DATABASE_URL: PostgreSQL connection string (required)
|
|
24
|
+
|
|
25
|
+
Requirements:
|
|
26
|
+
- psycopg2
|
|
27
|
+
- JSON trajectory files generated by generate-training-data.ts
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
import argparse
import hashlib
import json
import logging
import os
import sys
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.data_bridge.reader import JsonTrajectoryReader, validate_llm_calls
from src.training.rubric_loader import normalize_archetype, has_custom_rubric
|
|
46
|
+
|
|
47
|
+
# Module-level logging: timestamped, level-tagged messages for all script output.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass
class ImportStats:
    """Statistics for import operation.

    Tracks file counts, validation outcomes, insert results, and a
    per-archetype tally of every valid trajectory seen during a run.
    """
    total_files: int = 0
    valid_trajectories: int = 0
    invalid_trajectories: int = 0
    inserted: int = 0
    skipped_existing: int = 0
    failed: int = 0
    # default_factory gives each instance its own dict; this replaces the
    # previous `= None` default (mis-typed as Dict) plus a __post_init__
    # that patched it back to {}.
    archetypes_seen: Dict[str, int] = field(default_factory=dict)

    def record_archetype(self, archetype: str) -> None:
        """Increment the tally for the normalized form of *archetype*."""
        normalized = normalize_archetype(archetype)
        self.archetypes_seen[normalized] = self.archetypes_seen.get(normalized, 0) + 1
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def get_db_connection():
    """Open a psycopg2 connection using the DATABASE_URL environment variable.

    Raises:
        ValueError: when DATABASE_URL is unset or empty.
        ImportError: when psycopg2 is not installed.
    """
    # Guard clause: an empty or missing DSN is a configuration error.
    if not (dsn := os.environ.get("DATABASE_URL")):
        raise ValueError(
            "DATABASE_URL environment variable not set. "
            "Example: postgresql://babylon:password@localhost:5433/babylon"
        )

    # psycopg2 is an optional dependency; import lazily so validation-only
    # (dry-run) callers never need it installed.
    try:
        import psycopg2
    except ImportError:
        raise ImportError(
            "psycopg2 is required for database import. "
            "Install with: pip install psycopg2-binary"
        )

    return psycopg2.connect(dsn)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def generate_snowflake_id() -> str:
    """Generate a unique ID similar to TypeScript snowflake.

    The result is the current epoch time in milliseconds followed by
    8 hex characters derived from OS-level randomness.
    """
    entropy = hashlib.sha256(os.urandom(8)).hexdigest()
    millis = int(datetime.now().timestamp() * 1000)
    return str(millis) + entropy[:8]
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def extract_archetype_from_trajectory(traj_data: Dict) -> str:
    """
    Extract archetype from trajectory data.

    Priority:
    1. trajectory.archetype field
    2. First step's action.parameters.archetype
    3. Default to 'trader'
    """
    # A trajectory-level archetype wins, unless it is the "default" placeholder.
    top_level = traj_data.get("archetype")
    if top_level and top_level != "default":
        return normalize_archetype(top_level)

    # stepsJson may arrive pre-parsed (list) or as a JSON string.
    raw_steps = traj_data.get("stepsJson", "[]")
    steps = json.loads(raw_steps) if isinstance(raw_steps, str) else raw_steps

    # First step carrying an archetype in its action parameters wins.
    for entry in steps:
        candidate = entry.get("action", {}).get("parameters", {}).get("archetype")
        if candidate:
            return normalize_archetype(candidate)

    return "trader"
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def validate_trajectory(traj_data: Dict) -> tuple[bool, List[str]]:
    """
    Validate a trajectory for import.

    Returns:
        (is_valid, list_of_issues)
    """
    problems: List[str] = []

    # Identifier fields must be present and truthy.
    for key in ("trajectoryId", "agentId", "windowId"):
        if not traj_data.get(key):
            problems.append(f"Missing required field: {key}")

    # Steps must parse, be non-empty, and pass LLM-call validation.
    raw_steps = traj_data.get("stepsJson", "[]")
    try:
        steps = json.loads(raw_steps) if isinstance(raw_steps, str) else raw_steps
    except json.JSONDecodeError as e:
        problems.append(f"Invalid stepsJson: {e}")
    else:
        if len(steps) == 0:
            problems.append("No steps in trajectory")
        else:
            llm_ok, llm_problems = validate_llm_calls(steps)
            if not llm_ok:
                problems.extend(llm_problems)

    # The resolved archetype must map to a known rubric.
    archetype = extract_archetype_from_trajectory(traj_data)
    if archetype != "default" and not has_custom_rubric(archetype):
        problems.append(f"Unknown archetype: {archetype}")

    return not problems, problems
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def check_trajectory_exists(conn, trajectory_id: str) -> bool:
    """Check if trajectory already exists in database.

    Args:
        conn: Open DB-API connection (psycopg2).
        trajectory_id: Value of the "trajectoryId" column to look up.

    Returns:
        True when a row with this trajectoryId is already present.
    """
    cur = conn.cursor()
    try:
        cur.execute(
            'SELECT 1 FROM trajectories WHERE "trajectoryId" = %s LIMIT 1',
            (trajectory_id,)
        )
        return cur.fetchone() is not None
    finally:
        # Always release the cursor, even if execute/fetch raises.
        cur.close()
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def insert_trajectory(conn, traj_data: Dict) -> bool:
    """
    Insert a trajectory into the database.

    Args:
        conn: Open DB-API connection (psycopg2).
        traj_data: Trajectory record as parsed from JSON.

    Returns:
        True if inserted, False if failed
    """
    trajectory_id = traj_data.get("trajectoryId")
    archetype = extract_archetype_from_trajectory(traj_data)

    # Parse timestamps: ISO-8601 strings (with an optional trailing "Z")
    # or fall back to "now" when absent.
    start_time = traj_data.get("startTime")
    if isinstance(start_time, str):
        start_time = datetime.fromisoformat(start_time.replace("Z", "+00:00"))
    elif start_time is None:
        start_time = datetime.now()

    end_time = traj_data.get("endTime")
    if isinstance(end_time, str):
        end_time = datetime.fromisoformat(end_time.replace("Z", "+00:00"))
    elif end_time is None:
        end_time = datetime.now()

    # BUGFIX: nullable numerics must use explicit `is not None` checks;
    # the previous truthiness test silently stored a legitimate 0 / 0.0
    # finalBalance or aiJudgeReward as NULL.
    final_balance = traj_data.get("finalBalance")
    ai_judge_reward = traj_data.get("aiJudgeReward")

    # Prepare values (order must match the column list in the INSERT below)
    values = (
        traj_data.get("id", generate_snowflake_id()),
        trajectory_id,
        traj_data.get("agentId"),
        archetype,
        traj_data.get("windowId"),
        traj_data.get("scenarioId"),
        traj_data.get("stepsJson", "[]"),
        json.dumps(traj_data.get("rewardComponents", {})),
        json.dumps(traj_data.get("metrics", {})),
        json.dumps(traj_data.get("metadata", {})),
        float(traj_data.get("finalPnL", 0)),
        int(traj_data.get("episodeLength", 0)),
        float(traj_data.get("totalReward", 0)),
        traj_data.get("finalStatus", "completed"),
        float(final_balance) if final_balance is not None else None,
        int(traj_data.get("tradesExecuted", 0)),
        int(traj_data.get("postsCreated", 0)),
        float(ai_judge_reward) if ai_judge_reward is not None else None,
        traj_data.get("isTrainingData", True),
        traj_data.get("isEvaluation", False),
        traj_data.get("usedInTraining", False),
        start_time,
        end_time,
        int(traj_data.get("durationMs", 0)),
        1,  # windowHours
        datetime.now(),
        datetime.now(),
    )

    cur = conn.cursor()
    try:
        cur.execute(
            '''
            INSERT INTO trajectories (
                "id", "trajectoryId", "agentId", "archetype", "windowId",
                "scenarioId", "stepsJson", "rewardComponentsJson", "metricsJson",
                "metadataJson", "finalPnL", "episodeLength", "totalReward",
                "finalStatus", "finalBalance", "tradesExecuted", "postsCreated",
                "aiJudgeReward", "isTrainingData", "isEvaluation", "usedInTraining",
                "startTime", "endTime", "durationMs", "windowHours",
                "createdAt", "updatedAt"
            ) VALUES (
                %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
            )
            ''',
            values
        )
        conn.commit()
    finally:
        # Release the cursor even when execute/commit raises; the caller
        # catches the exception and counts the failure.
        cur.close()
    return True
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def import_trajectories(
    source_dir: Path,
    dry_run: bool = False,
    verbose: bool = False,
) -> ImportStats:
    """
    Import all trajectories from source directory to database.

    Args:
        source_dir: Directory containing trajectory JSON files
        dry_run: If True, validate but don't insert
        verbose: If True, log each trajectory

    Returns:
        Import statistics
    """
    stats = ImportStats()

    # Find trajectory files: prefer a "trajectories" subdirectory, fall back
    # to the source directory itself.
    traj_dir = source_dir / "trajectories"
    if not traj_dir.exists():
        logger.warning(f"Trajectories directory not found: {traj_dir}")
        traj_dir = source_dir

    json_files = list(traj_dir.glob("*.json"))
    stats.total_files = len(json_files)

    if stats.total_files == 0:
        logger.warning(f"No JSON files found in {traj_dir}")
        return stats

    logger.info(f"Found {stats.total_files} trajectory files in {traj_dir}")

    # Get database connection (skip if dry run)
    conn = None
    if not dry_run:
        conn = get_db_connection()
        logger.info("Connected to database")

    try:
        # Each file is processed independently: one bad file must not abort
        # the whole run.
        for file_path in json_files:
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    data = json.load(f)

                # Handle wrapped format (trajectory key) vs direct format
                traj_data = data.get("trajectory", data)
                trajectory_id = traj_data.get("trajectoryId", file_path.stem)

                # Validate
                is_valid, issues = validate_trajectory(traj_data)
                if not is_valid:
                    stats.invalid_trajectories += 1
                    if verbose:
                        logger.warning(f"Invalid trajectory {trajectory_id}: {issues}")
                    continue

                stats.valid_trajectories += 1
                archetype = extract_archetype_from_trajectory(traj_data)
                stats.record_archetype(archetype)

                if verbose:
                    logger.info(f"Validated: {trajectory_id} (archetype: {archetype})")

                if dry_run:
                    continue

                # Check if exists
                if check_trajectory_exists(conn, trajectory_id):
                    stats.skipped_existing += 1
                    if verbose:
                        logger.info(f"Skipped existing: {trajectory_id}")
                    continue

                # Insert
                try:
                    insert_trajectory(conn, traj_data)
                    stats.inserted += 1
                    if verbose:
                        logger.info(f"Inserted: {trajectory_id}")
                except Exception as e:
                    stats.failed += 1
                    logger.error(f"Failed to insert {trajectory_id}: {e}")

            except json.JSONDecodeError as e:
                stats.invalid_trajectories += 1
                logger.warning(f"Invalid JSON in {file_path}: {e}")
            except Exception as e:
                stats.failed += 1
                logger.error(f"Error processing {file_path}: {e}")
    finally:
        # ROBUSTNESS: close the connection even when an exception (e.g.
        # KeyboardInterrupt) escapes the processing loop; previously the
        # close only ran on the happy path.
        if conn:
            conn.close()

    return stats
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def main():
    """CLI entry point: parse arguments, run the import, print a summary."""
    parser = argparse.ArgumentParser(
        description="Import JSON trajectories to PostgreSQL database"
    )
    parser.add_argument(
        "--source",
        type=Path,
        default=Path("./training-data-output"),
        help="Source directory containing trajectory JSON files"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Validate trajectories without inserting to database"
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Log each trajectory processed"
    )
    args = parser.parse_args()

    # Guard clauses before doing any work.
    if not args.source.exists():
        logger.error(f"Source directory not found: {args.source}")
        sys.exit(1)

    if args.dry_run:
        logger.info("DRY RUN MODE - No database modifications")

    stats = import_trajectories(
        source_dir=args.source,
        dry_run=args.dry_run,
        verbose=args.verbose,
    )

    # Render the summary report.
    banner = "=" * 50
    print("\n" + banner)
    print("IMPORT SUMMARY")
    print(banner)
    print(f"Total files: {stats.total_files}")
    print(f"Valid trajectories: {stats.valid_trajectories}")
    print(f"Invalid trajectories: {stats.invalid_trajectories}")

    if not args.dry_run:
        print(f"Inserted: {stats.inserted}")
        print(f"Skipped (existing): {stats.skipped_existing}")
        print(f"Failed: {stats.failed}")

    if stats.archetypes_seen:
        print("\nArchetypes found:")
        for archetype, count in sorted(stats.archetypes_seen.items()):
            print(f"  - {archetype}: {count}")

    # Non-zero exit when any trajectory failed to import.
    sys.exit(1 if stats.failed > 0 else 0)


if __name__ == "__main__":
    main()
|
|
412
|
+
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
# Local Fine-Tuning Pipeline
|
|
2
|
+
|
|
3
|
+
This directory contains scripts to train RL adapters from Babylon simulation logs.
|
|
4
|
+
|
|
5
|
+
## Workflow
|
|
6
|
+
|
|
7
|
+
1. **Generate Data:** Run `bun packages/engine/examples/generate-training-data.ts`
|
|
8
|
+
2. **Score & Format:** Run `python ingest_and_score.py`
|
|
9
|
+
3. **Train:** Run `python train_from_csv.py`
|
|
10
|
+
4. **Test:** Run `python test_adapter.py`
|
|
11
|
+
|
|
12
|
+
## Quick Start
|
|
13
|
+
|
|
14
|
+
If you do not have a local Postgres database, Atropos server, or vLLM instance running, you can use the **Offline Pipeline**. This generates data to JSON files and uses direct PyTorch/HuggingFace libraries for training.
|
|
15
|
+
|
|
16
|
+
### Prerequisites
|
|
17
|
+
|
|
18
|
+
1. `GROQ_API_KEY` or `OPENAI_API_KEY` set in environment.
|
|
19
|
+
2. Python dependencies: `pip install torch transformers peft pandas datasets trl`
|
|
20
|
+
|
|
21
|
+
### Step 1: Generate Data (TypeScript)
|
|
22
|
+
|
|
23
|
+
Runs the game simulation in-memory and dumps "Observation -> Action" logs to JSON.
|
|
24
|
+
|
|
25
|
+
```bash
|
|
26
|
+
# Runs 24 simulated hours
|
|
27
|
+
bun packages/engine/examples/generate-training-data.ts
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
_Output:_ `training-data-output/trajectories/*.json`
|
|
31
|
+
|
|
32
|
+
### Step 2: Process & Score (Python)
|
|
33
|
+
|
|
34
|
+
Converts raw JSON logs into a scored CSV dataset (System/User/Assistant format).
|
|
35
|
+
|
|
36
|
+
```bash
|
|
37
|
+
cd packages/training/python/scripts/local-finetune
|
|
38
|
+
python ingest_and_score.py
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
_Output:_ `packages/training/data/scored_trajectories.csv`
|
|
42
|
+
|
|
43
|
+
### Step 3: Train Model (Python)
|
|
44
|
+
|
|
45
|
+
Fine-tunes a base model (Qwen2.5-0.5B by default) on your scored data using LoRA.
|
|
46
|
+
|
|
47
|
+
```bash
|
|
48
|
+
python train_from_csv.py --output ./my-model-v1
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
### Step 4: Test Inference
|
|
52
|
+
|
|
53
|
+
Interactively chat with your new LoRA adapter to verify behavior.
|
|
54
|
+
|
|
55
|
+
```bash
|
|
56
|
+
python test_adapter.py
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
---
|
|
60
|
+
|
|
61
|
+
## 🏗️ Production Architecture (Tinker/Atropos)
|
|
62
|
+
|
|
63
|
+
_For the full cloud-based pipeline involving Postgres, GRPO, and Tinker compute, refer to `scripts/run_full_pipeline.py`._
|