synth-ai 0.2.4.dev7__py3-none-any.whl → 0.2.4.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- synth_ai/__init__.py +1 -1
- synth_ai/cli/__init__.py +6 -0
- synth_ai/cli/balance.py +3 -15
- synth_ai/cli/demo.py +68 -9
- synth_ai/cli/rl_demo.py +137 -0
- synth_ai/cli/root.py +65 -0
- synth_ai/config/base_url.py +47 -0
- synth_ai/demos/core/__init__.py +1 -0
- synth_ai/demos/core/cli.py +621 -0
- synth_ai/demos/demo_task_apps/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/core.py +374 -0
- synth_ai/demos/demo_task_apps/math/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/math/app.py +37 -0
- synth_ai/demos/demo_task_apps/math/config.toml +44 -0
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +60 -0
- synth_ai/demos/demo_task_apps/math/deploy_task_app.sh +22 -0
- synth_ai/environments/examples/bandit/__init__.py +33 -0
- synth_ai/environments/examples/bandit/engine.py +294 -0
- synth_ai/environments/examples/bandit/environment.py +194 -0
- synth_ai/environments/examples/bandit/taskset.py +200 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +250 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +59 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_config.toml +24 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/crafter_synth_config.toml +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_config_modal.toml +32 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py +384 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_action_results.py +53 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_agent_actions.py +178 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_latest_run.py +222 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_lm_traces.py +183 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_no_rewards.py +210 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_trace_issue.py +206 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_db_schema.py +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_latest_results.py +64 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/debug_agent_responses.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/quick_trace_check.py +77 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/compare_experiments.py +324 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py +362 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_enhanced_hooks.py +332 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_events.py +97 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_results.py +217 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_hook_storage.py +87 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_seeds.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/compare_seed_performance.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/custom_eval_pipelines.py +400 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/plot_hook_frequency.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/seed_analysis_summary.py +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +858 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +52 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +874 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +216 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/compare_traces.py +296 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_env_serialization.py +464 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/debug_player_loss.py +112 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_service.py +203 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_slowness.py +305 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_by_difficulty.py +126 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_example.py +94 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/explore_saved_states.py +142 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft.py +26 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft_OLD.py +984 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_gemini.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py +386 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py +205 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py +150 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py +283 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/prepare_vertex_ft.py +280 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/profile_env_slowness.py +456 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/replicate_issue.py +166 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_and_eval.py +102 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_comparison.py +128 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py +655 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py +202 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py +166 -0
- synth_ai/environments/examples/crafter_classic/environment.py +41 -2
- synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py +202 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_issue.py +159 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_spawning.py +158 -0
- synth_ai/environments/examples/crafter_custom/old/compare_worlds.py +71 -0
- synth_ai/environments/examples/crafter_custom/old/dataset_stats.py +105 -0
- synth_ai/environments/examples/crafter_custom/old/diamond_spawning_summary.py +119 -0
- synth_ai/environments/examples/crafter_custom/old/example_dataset_usage.py +52 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +112 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +221 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +831 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +899 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +95 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/http.py +102 -0
- synth_ai/inference/__init__.py +7 -0
- synth_ai/inference/client.py +20 -0
- synth_ai/install_sqld.sh +40 -0
- synth_ai/jobs/client.py +246 -0
- synth_ai/learning/__init__.py +24 -0
- synth_ai/learning/client.py +149 -0
- synth_ai/learning/config.py +43 -0
- synth_ai/learning/constants.py +29 -0
- synth_ai/learning/ft_client.py +59 -0
- synth_ai/learning/health.py +43 -0
- synth_ai/learning/jobs.py +205 -0
- synth_ai/learning/rl_client.py +256 -0
- synth_ai/learning/sse.py +58 -0
- synth_ai/learning/validators.py +48 -0
- synth_ai/lm/core/main_v3.py +13 -0
- synth_ai/lm/core/synth_models.py +48 -0
- synth_ai/lm/core/vendor_clients.py +9 -6
- synth_ai/lm/vendors/core/openai_api.py +31 -3
- synth_ai/lm/vendors/openai_standard.py +45 -14
- synth_ai/lm/vendors/supported/custom_endpoint.py +12 -2
- synth_ai/lm/vendors/synth_client.py +372 -28
- synth_ai/rl/__init__.py +30 -0
- synth_ai/rl/contracts.py +32 -0
- synth_ai/rl/env_keys.py +137 -0
- synth_ai/rl/secrets.py +19 -0
- synth_ai/scripts/verify_rewards.py +100 -0
- synth_ai/task/__init__.py +10 -0
- synth_ai/task/contracts.py +120 -0
- synth_ai/task/health.py +28 -0
- synth_ai/task/validators.py +12 -0
- synth_ai/tracing_v3/hooks.py +3 -1
- synth_ai/tracing_v3/session_tracer.py +123 -2
- synth_ai/tracing_v3/turso/manager.py +218 -0
- synth_ai/tracing_v3/turso/models.py +53 -0
- synth_ai-0.2.4.dev9.dist-info/METADATA +91 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/RECORD +147 -30
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/entry_points.txt +1 -0
- synth_ai/tui/__init__.py +0 -1
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -164
- synth_ai/tui/cli/query_experiments_v3.py +0 -164
- synth_ai/tui/dashboard.py +0 -340
- synth_ai-0.2.4.dev7.dist-info/METADATA +0 -193
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,724 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Generate Fine-tuning Data for Gemini Models
|
|
4
|
+
===========================================
|
|
5
|
+
This script generates high-quality trajectories from Crafter using Gemini models
|
|
6
|
+
and converts them to JSONL format suitable for Vertex AI fine-tuning.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
import json
|
|
11
|
+
import uuid
|
|
12
|
+
import argparse
|
|
13
|
+
import toml
|
|
14
|
+
import logging
|
|
15
|
+
from datetime import datetime
|
|
16
|
+
from typing import Dict, Any, Optional, List, Set, Tuple
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
import sys
|
|
19
|
+
import os
|
|
20
|
+
import numpy as np
|
|
21
|
+
from collections import defaultdict
|
|
22
|
+
import time
|
|
23
|
+
from tqdm.asyncio import tqdm_asyncio
|
|
24
|
+
from httpx import AsyncClient
|
|
25
|
+
|
|
26
|
+
# Add the src directory to the path
|
|
27
|
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "src"))
|
|
28
|
+
|
|
29
|
+
from synth_ai.lm.core.main import LM
|
|
30
|
+
from synth_ai.lm.tools.base import BaseTool
|
|
31
|
+
from pydantic import BaseModel, Field
|
|
32
|
+
|
|
33
|
+
# Import TaskInstance and related classes
|
|
34
|
+
from synth_ai.environments.tasks.core import (
|
|
35
|
+
Impetus,
|
|
36
|
+
Intent,
|
|
37
|
+
Task,
|
|
38
|
+
TaskInstance,
|
|
39
|
+
)
|
|
40
|
+
from synth_ai.environments.examples.crafter_classic.taskset import CrafterTaskInstance, CrafterTaskInstanceMetadata
|
|
41
|
+
|
|
42
|
+
# Import trace evaluation utilities
|
|
43
|
+
sys.path.append(str(Path(__file__).parent))
|
|
44
|
+
from trace_eval import evaluate_trace, WEIGHTS
|
|
45
|
+
from filter_traces_sft import load_trace, extract_trajectory_score, extract_llm_calls
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# --- Helper Functions ---
|
|
49
|
+
def _parse_stat_value(line: str) -> Optional[int]:
    """Return N from a ``"Label: N/M"`` line, or None when the line is malformed."""
    try:
        return int(line.split(":")[1].strip().split("/")[0])
    except (IndexError, ValueError):
        return None


def parse_observation_text(obs_text: str) -> Dict[str, Any]:
    """Parse a text observation into the structured dict used by the agent.

    Lines are matched by substring, and the first matching rule wins:
      * ``Health:`` / ``Hunger:`` / ``Thirst:``: integer before the ``/``.
      * ``Position:``: ``[x, y]`` parsed from ``(x, y)``.
      * ``Inventory:`` / ``Achievements:``: start a section. Later lines are read as
        inventory entries or as comma-separated achievement names.

    Absent or malformed fields keep their defaults: stats 10, position [0, 0],
    empty inventory/achievements/semantic map, and done False.

    Args:
        obs_text: Observation text. Empty (or None) yields the defaults.

    Returns:
        Dict with keys health, hunger, thirst, inventory, achievements_dict,
        player_position, semantic_map and done.
    """
    obs_data: Dict[str, Any] = {
        "health": 10,
        "hunger": 10,
        "thirst": 10,
        "inventory": {},
        "achievements_dict": {},
        "player_position": [0, 0],
        "semantic_map": [],
        "done": False,
    }

    if not obs_text:
        return obs_data

    current_section: Optional[str] = None

    for raw_line in obs_text.strip().split("\n"):
        line = raw_line.strip()
        if not line:
            continue

        # Stats: a malformed value leaves the default in place.
        if "Health:" in line:
            value = _parse_stat_value(line)
            if value is not None:
                obs_data["health"] = value
        elif "Hunger:" in line:
            value = _parse_stat_value(line)
            if value is not None:
                obs_data["hunger"] = value
        elif "Thirst:" in line:
            value = _parse_stat_value(line)
            if value is not None:
                obs_data["thirst"] = value
        elif "Position:" in line:
            try:
                x, y = line.split(":")[1].strip().strip("()").split(",")
                obs_data["player_position"] = [int(x), int(y)]
            except (IndexError, ValueError):
                pass  # malformed position: keep the previous value
        elif "Inventory:" in line:
            current_section = "inventory"
        elif "Achievements:" in line:
            current_section = "achievements"
        elif current_section == "inventory" and " - " in line:
            # NOTE(review): this expects "<item> - <label>: <count>". A stripped line such
            # as "- wood: 2" contains no " - " and is skipped; confirm against the
            # service's observation text format.
            try:
                item, count = line.split(" - ")
                obs_data["inventory"][item.strip().strip("-").strip()] = int(
                    count.split(":")[1].strip()
                )
            except (IndexError, ValueError):
                pass  # malformed inventory entry: skip it
        elif current_section == "achievements":
            # Every non-stat line after "Achievements:" is read as a name list.
            for ach in line.split(", "):
                ach = ach.strip()
                if ach:
                    obs_data["achievements_dict"][ach] = True

    return obs_data
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# --- Configuration ---
|
|
123
|
+
class GenerationConfig:
    """Settings for generating and filtering Gemini fine-tuning trajectories.

    Every attribute starts at a built-in default. If ``config_path`` names an existing
    file, values from its ``[generation]``, ``[service]`` and ``[quality]`` TOML tables
    override the matching attributes. An empty or missing path is silently ignored.
    """

    def __init__(self, config_path: Optional[str] = None):
        # Rollout settings
        self.model_name = "gemini-2.5-flash"
        self.num_rollouts = 100
        self.max_turns = 30
        self.difficulty = "easy"
        self.seed = 42

        # Environment service
        self.service_base_url = "http://localhost:8901"
        self.service_timeout = 30.0

        # Output locations
        self.traces_dir = Path("traces_gemini")
        self.ft_data_dir = Path("ft_data_gemini")

        # Quality filtering and model reasoning
        self.min_score_threshold = 2.0  # minimum trajectory score to keep
        self.min_achievements = 3  # minimum unlocked achievements to keep
        self.enable_thinking = True  # request thinking/reasoning from the model
        self.thinking_budget = 15000  # token budget for thinking

        should_load = bool(config_path) and os.path.exists(config_path)
        if should_load:
            self.load_from_toml(config_path)

    def load_from_toml(self, config_path: str):
        """Override attributes with values read from the TOML file at ``config_path``.

        Keys that are absent from the file leave the current attribute value unchanged.
        """
        parsed = toml.load(config_path)

        # section name -> ((key in TOML, attribute on self), ...)
        overrides = {
            "generation": (
                ("model_name", "model_name"),
                ("num_rollouts", "num_rollouts"),
                ("max_turns", "max_turns"),
                ("difficulty", "difficulty"),
                ("seed", "seed"),
            ),
            "service": (
                ("base_url", "service_base_url"),
                ("timeout", "service_timeout"),
            ),
            "quality": (
                ("min_score_threshold", "min_score_threshold"),
                ("min_achievements", "min_achievements"),
                ("enable_thinking", "enable_thinking"),
                ("thinking_budget", "thinking_budget"),
            ),
        }

        for section_name, pairs in overrides.items():
            section = parsed.get(section_name, {})
            for toml_key, attr in pairs:
                setattr(self, attr, section.get(toml_key, getattr(self, attr)))
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
# --- Crafter Action Tool ---
|
|
174
|
+
class CrafterAction(BaseTool):
    """Tool that sends one named Crafter action to the environment service.

    The action name is mapped to an integer id and POSTed to
    ``{base_url}/env/CrafterClassic/step`` as an ``interact`` tool call for
    ``instance_id``.
    """

    name: str = "crafter_action"
    description: str = "Perform an action in the Crafter environment"
    params: List[tuple] = [
        ("action", "str", "The action to perform (e.g., 'move_north', 'collect', 'craft_wood_pickaxe')")
    ]

    def __init__(self, instance_id: str, client: AsyncClient, base_url: str = "http://localhost:8901"):
        """Bind the tool to one environment instance.

        Args:
            instance_id: ``env_id`` returned by the service's initialize endpoint.
            client: HTTP client used for the step requests.
            base_url: Environment service root. It defaults to the previously hard-coded
                localhost URL; pass the configured service URL so steps reach the same
                server that created the instance.
        """
        super().__init__()
        self.instance_id = instance_id
        self.client = client
        self.base_url = base_url.rstrip("/")

        # Action mapping from string to integer.
        # NOTE(review): upstream Crafter orders its actions as noop, move_left,
        # move_right, move_up, move_down, do, sleep, place_*, make_*. These names
        # (e.g. "eat", "collect") and ids do not follow that order. Confirm against the
        # service's `interact` tool before trusting generated data.
        self.action_map = {
            "noop": 0,
            "move_north": 1,
            "move_south": 2,
            "move_east": 3,
            "move_west": 4,
            "attack": 5,
            "collect": 6,
            "craft_wood_pickaxe": 7,
            "craft_stone_pickaxe": 8,
            "craft_iron_pickaxe": 9,
            "craft_wood_sword": 10,
            "craft_stone_sword": 11,
            "craft_iron_sword": 12,
            "eat": 13,
            "drink": 14,
            "sleep": 15,
            "place_stone": 16,
            "place_table": 17,
            "place_furnace": 18,
            "place_plant": 19,
        }

    def _action_to_int(self, action: str) -> int:
        """Map an action name (case- and whitespace-insensitive) to its id.

        Unknown names and non-string values map to 0 ("noop") instead of raising.
        """
        if not isinstance(action, str):
            return 0
        return self.action_map.get(action.strip().lower(), 0)

    async def _run(self, action: str) -> Dict[str, Any]:
        """Execute ``action`` in the environment and return the service's JSON response.

        Raises:
            httpx.HTTPStatusError: If the step request returns an error status.
        """
        response = await self.client.post(
            f"{self.base_url}/env/CrafterClassic/step",
            json={
                "env_id": self.instance_id,
                "tool_calls": [{
                    "tool_name": "interact",
                    "tool_call_id": str(uuid.uuid4()),
                    "tool_args": {"action": self._action_to_int(action)},
                }],
            },
        )
        response.raise_for_status()
        # Return the full result for the agent to process.
        return response.json()
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
# --- Gemini Agent ---
|
|
235
|
+
class GeminiCrafterAgent:
    """Agent that plays Crafter using Gemini models via synth-ai LM.

    Each ``step`` sends the current observation to the model, executes the chosen
    action through a ``CrafterAction`` tool, and appends the exchange to
    ``self.messages`` (OpenAI-style role dicts).
    """

    def __init__(self, model_name: str, instance_id: str, client: AsyncClient,
                 base_url: Optional[str] = None):
        """Create an agent bound to one environment instance.

        Args:
            model_name: Model used both for acting and for formatting.
            instance_id: ``env_id`` of the environment instance to act in.
            client: HTTP client shared with the action tool.
            base_url: Environment service root for the action tool. When None, the
                tool keeps its own default URL.
        """
        self.model_name = model_name
        self.instance_id = instance_id
        self.client = client

        # Initialize LM with Gemini model; some temperature for trajectory diversity.
        self.lm = LM(
            model_name=model_name,
            formatting_model_name=model_name,
            temperature=0.7,
        )

        # Create action tool. Point it at the configured service so it steps the
        # same server that created the instance.
        self.action_tool = CrafterAction(instance_id, client)
        if base_url is not None:
            self.action_tool.base_url = base_url.rstrip("/")

        # Raw JSON response of the most recent environment step executed by ``step``.
        self.last_step_result: Optional[Dict[str, Any]] = None

        # Initialize conversation history.
        self.messages: List[Dict[str, Any]] = []

        # System prompt
        self.system_prompt = """You are an expert Crafter player. Your goal is to achieve as many objectives as possible in the game.

Key objectives (achievements) in order of importance:
1. Basic survival: collect resources, eat when hungry, drink when thirsty
2. Tool progression: craft pickaxe → stone pickaxe → iron pickaxe
3. Advanced goals: make iron sword, defeat enemies

Action format: Use the crafter_action tool with one of these actions:
- Movement: move_north, move_south, move_east, move_west
- Resource gathering: collect (gathers wood/stone/etc), attack (mines harder materials)
- Crafting: craft_wood_pickaxe, craft_stone_pickaxe, craft_iron_pickaxe, craft_wood_sword, craft_stone_sword, craft_iron_sword
- Survival: eat, drink, sleep
- Placing: place_stone, place_table, place_furnace, place_plant

Tips:
- Start by collecting wood (stand near trees and use 'collect')
- Craft a wood pickaxe early to mine stone
- Monitor your health, hunger, and thirst
- Explore to find water, coal, and iron
- Use the semantic map to navigate efficiently"""

        self.messages.append({"role": "system", "content": self.system_prompt})

    def _format_observation(self, obs_data: Dict[str, Any]) -> str:
        """Render an observation dict as the text prompt shown to the model."""
        lines = ["=== Current State ==="]

        # Stats
        position = obs_data.get('player_position', [0, 0])
        lines.append(f"Position: ({position[0]}, {position[1]})")
        lines.append(f"Health: {obs_data.get('health', 0)}/10")
        lines.append(f"Hunger: {obs_data.get('hunger', 0)}/10")
        lines.append(f"Thirst: {obs_data.get('thirst', 0)}/10")

        # Inventory (only items with a positive count)
        inventory = obs_data.get('inventory', {})
        if inventory:
            lines.append("\nInventory:")
            for item, count in inventory.items():
                if count > 0:
                    lines.append(f" - {item}: {count}")

        # Achievements
        achievements = obs_data.get('achievements_dict', {})
        unlocked = [k for k, v in achievements.items() if v]
        if unlocked:
            lines.append(f"\nAchievements: {', '.join(unlocked)}")

        # Local view (simplified)
        lines.append("\nNearby (5x5 grid around you):")
        semantic_map = obs_data.get('semantic_map', [])
        if semantic_map:
            # Treat a flat list of square length as a square 2D map; anything else is skipped.
            map_size = int(np.sqrt(len(semantic_map)))
            if map_size * map_size == len(semantic_map):
                map_2d = np.array(semantic_map).reshape(map_size, map_size)
                center = map_size // 2
                view_radius = 2

                # NOTE(review): assumed semantic-ID -> symbol mapping; confirm against the
                # environment's semantic map encoding. Unknown ids render as '?'.
                id_to_symbol = {
                    0: '.',  # void/empty
                    1: 'G',  # grass
                    2: 'T',  # tree
                    3: 'S',  # stone
                    4: 'W',  # water
                    5: 'C',  # coal
                    6: 'I',  # iron
                    7: '@',  # player
                    8: 'E',  # enemy
                    9: 'F',  # furnace
                    10: 'P',  # plant
                }

                for dy in range(-view_radius, view_radius + 1):
                    row = []
                    for dx in range(-view_radius, view_radius + 1):
                        y, x = center + dy, center + dx
                        if 0 <= y < map_size and 0 <= x < map_size:
                            symbol = id_to_symbol.get(int(map_2d[y, x]), '?')
                            if dy == 0 and dx == 0:
                                symbol = '@'  # Player position
                            row.append(symbol)
                        else:
                            row.append(' ')
                    lines.append(' ' + ' '.join(row))

        return '\n'.join(lines)

    @staticmethod
    def _parse_tool_call(tool_call: Any) -> Tuple[Optional[str], str]:
        """Return ``(tool_call_id, action)`` from a tool call given as a dict or object.

        Arguments may arrive as a dict or as a JSON-encoded string (the OpenAI wire
        format). Missing, unparseable or non-string actions fall back to "noop".
        """
        if isinstance(tool_call, dict):
            call_id = tool_call.get("id")
            function = tool_call.get("function") or {}
        else:
            call_id = getattr(tool_call, "id", None)
            function = getattr(tool_call, "function", None)

        if isinstance(function, dict):
            arguments = function.get("arguments")
        else:
            arguments = getattr(function, "arguments", None)

        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments)
            except json.JSONDecodeError:
                arguments = {}
        if not isinstance(arguments, dict):
            arguments = {}

        action = arguments.get("action", "noop")
        return call_id, action if isinstance(action, str) else "noop"

    async def step(self, obs_data: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
        """Query the model for one action, execute it, and record the exchange.

        The action IS executed in the environment here (via the action tool). Its raw
        response is stored in ``self.last_step_result`` so callers can read the next
        observation without stepping the environment a second time.

        Args:
            obs_data: Structured observation (as produced by ``parse_observation_text``).

        Returns:
            ``(action, metadata)``. ``metadata`` is ``{"thinking": ...}`` when the raw
            response carried a truthy ``thinking`` field, else ``{}``.
        """
        obs_text = self._format_observation(obs_data)
        self.messages.append({"role": "user", "content": obs_text})

        response = await self.lm.ainvoke(
            self.messages,
            tools=[self.action_tool],
            tool_choice="required",
        )

        tool_calls = getattr(response, "tool_calls", None)
        tool_call_id: Optional[str] = None

        if tool_calls:
            tool_call_id, action = self._parse_tool_call(tool_calls[0])
            self.messages.append({
                "role": "assistant",
                "content": getattr(response, "content", None) or f"Taking action: {action}",
                "tool_calls": tool_calls,
            })
        else:
            # No tool call: keep the text reply in the history and fall back to a no-op.
            content = response.content if hasattr(response, "content") else str(response)
            self.messages.append({"role": "assistant", "content": content})
            action = "noop"

        # Extract thinking if available
        thinking = None
        raw = getattr(response, "_raw_response", None)
        if isinstance(raw, dict) and "thinking" in raw:
            thinking = raw["thinking"]

        # Execute action (this advances the environment by one step).
        result = await self.action_tool._run(action)
        self.last_step_result = result

        # A tool call must be answered with a matching tool message.
        if tool_calls:
            self.messages.append({
                "role": "tool",
                "content": json.dumps(result),
                "tool_call_id": tool_call_id,
            })

        return action, ({"thinking": thinking} if thinking else {})


# --- Main Generation Functions ---
def _parse_step_observation(step_data: Dict[str, Any]) -> Dict[str, Any]:
    """Turn a ``/step`` response into the observation dict consumed by the agent.

    The first entry of ``observations`` is parsed from its ``human_observation`` text
    and tagged with the response's ``done`` flag. A response without observations is
    treated as the end of the episode.
    """
    observations = step_data.get("observations")
    if not observations:
        return {"done": True}

    first = observations[0]
    done = step_data.get("done", False)
    if "human_observation" not in first:
        return {"done": done}

    parsed = parse_observation_text(first["human_observation"])
    parsed["done"] = done
    return parsed


async def generate_trajectory(config: GenerationConfig, instance_num: int) -> Optional[Dict[str, Any]]:
    """Play one Crafter episode with the configured Gemini model.

    Creates an environment instance on ``config.service_base_url`` and lets a
    ``GeminiCrafterAgent`` act for up to ``config.max_turns`` turns, or until the
    service reports ``done``. It always attempts to terminate the instance afterwards.

    Args:
        config: Generation settings (model, service URL/timeout, difficulty, seed, turns).
        instance_num: Rollout index. It offsets the seed and labels log lines.

    Returns:
        A dict with ``actions``, ``observations``, ``llm_calls``, ``achievements`` and
        ``final_score`` (the number of unlocked achievements), or None if the rollout
        failed. The error is printed.
    """
    async with AsyncClient(timeout=config.service_timeout) as client:
        instance_id: Optional[str] = None
        try:
            # Create task instance
            task_instance = CrafterTaskInstance(
                id=uuid.uuid4(),
                impetus=Impetus(
                    instructions="Survive and unlock achievements. Focus on collecting resources, crafting tools, and progressing through the game."
                ),
                intent=Intent(
                    rubric={"goal": "Unlock as many achievements as possible."},
                    gold_trajectories=None,
                    gold_state_diff={},
                    deterministic_eval_functions=[],
                ),
                metadata=CrafterTaskInstanceMetadata(
                    difficulty=config.difficulty,
                    seed=config.seed + instance_num,
                    num_trees_radius=4,
                    num_cows_radius=2,
                    num_hostiles_radius=0 if config.difficulty == "easy" else 2,
                    world_config="normal",
                ),
                is_reproducible=True,
                initial_engine_snapshot=None,  # will be filled lazily when env starts
            )

            # Initialize environment
            create_response = await client.post(
                f"{config.service_base_url}/env/CrafterClassic/initialize",
                json={"task_instance": await task_instance.serialize()},
            )
            create_response.raise_for_status()
            env_data = create_response.json()
            instance_id = env_data["env_id"]

            print(f"🎮 Instance {instance_num}: Created {instance_id}")

            # Initial observation; tolerate a missing or empty "observations" list.
            obs_data = (env_data.get("observations") or [{}])[0]
            if "human_observation" in obs_data:
                obs_text = obs_data["human_observation"]
                obs_data = parse_observation_text(obs_text)
                obs_data["raw_text"] = obs_text

            agent = GeminiCrafterAgent(
                model_name=config.model_name,
                instance_id=instance_id,
                client=client,
                base_url=config.service_base_url,
            )

            trajectory: Dict[str, Any] = {
                "instance_id": instance_id,
                "instance_num": instance_num,
                "model": config.model_name,
                "start_time": datetime.now().isoformat(),
                "actions": [],
                "observations": [],
                "llm_calls": [],
                "achievements": {},
                "final_score": 0.0,
            }

            for turn in range(config.max_turns):
                # Stop before querying the model. The service already reported the
                # episode as finished, so acting again would step a finished episode.
                if obs_data.get("done", False):
                    print(f"✅ Instance {instance_num}: Episode done at turn {turn}")
                    break

                # agent.step() both queries the model AND executes the action through
                # its CrafterAction tool, so the environment has advanced one step here.
                action, metadata = await agent.step(obs_data)

                trajectory["llm_calls"].append({
                    "turn": turn,
                    "messages": agent.messages[-3:],  # Last 3 messages (user, assistant, tool)
                    "action": action,
                    "metadata": metadata,
                })
                trajectory["actions"].append(action)
                trajectory["observations"].append(obs_data)

                # Reuse the agent's step response. Previously the same action was POSTed
                # a second time here, advancing the environment twice per turn.
                obs_data = _parse_step_observation(agent.last_step_result or {})

            trajectory["end_time"] = datetime.now().isoformat()

            # Final achievements: latest observation, merged with the last recorded one.
            # Copied so the record does not alias an observation's dict.
            achievements = dict(obs_data.get("achievements_dict", {}))
            if trajectory["observations"]:
                achievements.update(trajectory["observations"][-1].get("achievements_dict", {}))
            trajectory["achievements"] = achievements

            # Score = number of unlocked achievements
            unlocked_achievements = sum(1 for unlocked in achievements.values() if unlocked)
            trajectory["final_score"] = float(unlocked_achievements)

            print(f"š Instance {instance_num}: Score={trajectory['final_score']:.1f}, Achievements={unlocked_achievements}")

            return trajectory

        except Exception as e:
            print(f"ā Instance {instance_num}: Error - {e}")
            return None

        finally:
            # Best-effort cleanup on both success and failure, so failed rollouts
            # do not leak instances on the service.
            if instance_id is not None:
                try:
                    await client.post(
                        f"{config.service_base_url}/env/CrafterClassic/terminate",
                        json={"env_id": instance_id},
                    )
                except Exception as exc:
                    print(f"Instance {instance_num}: failed to terminate {instance_id} - {exc}")
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
async def generate_all_trajectories(config: GenerationConfig) -> List[Dict[str, Any]]:
    """Run ``config.num_rollouts`` rollouts concurrently and collect the successful ones.

    Each rollout is the coroutine ``generate_trajectory(config, index)``. Results are
    consumed in completion order (not index order). Falsy results are dropped, such as
    the ``None`` that ``generate_trajectory`` returns when a rollout raises. A tqdm
    progress bar advances once per finished rollout, whether or not it succeeded.
    """
    total = config.num_rollouts
    print(f"\nš Generating {config.num_rollouts} trajectories with {config.model_name}")

    # Coroutines are only created here; nothing runs until as_completed awaits them.
    pending = [generate_trajectory(config, idx) for idx in range(total)]
    collected: List[Dict[str, Any]] = []

    with tqdm_asyncio(total=total, desc="Generating") as progress:
        for next_done in asyncio.as_completed(pending):
            result = await next_done
            if result:
                collected.append(result)
            progress.update(1)

    return collected
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
def filter_high_quality_trajectories(trajectories: List[Dict[str, Any]],
                                     min_score: float = 2.0,
                                     min_achievements: int = 3) -> List[Dict[str, Any]]:
    """Keep only the trajectories that clear both quality thresholds.

    A trajectory passes when both conditions hold:

    * its ``final_score`` (0.0 if missing) is at least ``min_score``;
    * the number of truthy values in its ``achievements`` mapping (empty if missing)
      is at least ``min_achievements``.

    A short summary of how many trajectories survived is printed.

    Args:
        trajectories: Trajectory dicts to screen.
        min_score: Minimum ``final_score`` required.
        min_achievements: Minimum number of unlocked (truthy) achievements required.

    Returns:
        The passing trajectories, in their original order.
    """
    def _passes(candidate: Dict[str, Any]) -> bool:
        # Count unlocked achievements first, then read the score, then compare both.
        unlocked = len([flag for flag in candidate.get("achievements", {}).values() if flag])
        score = candidate.get("final_score", 0.0)
        return score >= min_score and unlocked >= min_achievements

    kept = [candidate for candidate in trajectories if _passes(candidate)]
    total = len(trajectories)

    print(f"\nš Filtering Results:")
    print(f" Total trajectories: {total}")
    if not trajectories:
        print(f" High quality: 0 (no trajectories generated)")
    else:
        print(f" High quality: {len(kept)} ({len(kept)/total*100:.1f}%)")

    return kept
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
def convert_to_vertex_ai_format(trajectories: List[Dict[str, Any]], output_path: Path) -> int:
    """Write one (user prompt, assistant reply) training example per LLM call as JSONL.

    For each entry in a trajectory's ``llm_calls`` list:

    * the last assistant message is taken as the reply;
    * the closest user message *before* that reply is taken as the prompt.

    Each pair is written as one line of the form
    ``{"messages": [{"role": "user", ...}, {"role": "assistant", ...}]}``.
    A call is skipped when:

    * it has fewer than two messages;
    * it has no assistant message;
    * no user message precedes the final assistant message;
    * either selected content is empty.

    Messages with other roles (e.g. system) are not copied into the example.

    Args:
        trajectories: Trajectory dicts, each optionally carrying an ``llm_calls`` list
            whose items optionally carry a ``messages`` list of ``{"role", "content"}``
            dicts.
        output_path: Destination JSONL file. Parent directories are created, and an
            existing file is overwritten.

    Returns:
        The number of examples written.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    examples: List[Dict[str, Any]] = []

    for traj in trajectories:
        for llm_call in traj.get("llm_calls", []):
            messages = llm_call.get("messages", [])

            # A usable example needs at least a prompt and a reply.
            if len(messages) < 2:
                continue

            # Locate the final assistant reply, then the user prompt preceding it.
            # Choosing the last user and the last assistant message independently
            # would pair a reply with a later, unrelated prompt.
            reply_idx = next(
                (i for i in range(len(messages) - 1, -1, -1)
                 if messages[i].get("role") == "assistant"),
                None,
            )
            if reply_idx is None:
                continue

            reply = messages[reply_idx].get("content")
            prompt = next(
                (m.get("content") for m in reversed(messages[:reply_idx])
                 if m.get("role") == "user"),
                None,
            )
            if not (prompt and reply):
                continue

            examples.append({
                "messages": [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": reply},
                ]
            })

    with open(output_path, "w", encoding="utf-8") as f:
        for example in examples:
            f.write(json.dumps(example) + "\n")

    print(f"\n✅ Wrote {len(examples)} examples to {output_path}")
    return len(examples)
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
def save_trajectories(trajectories: List[Dict[str, Any]], traces_dir: Path) -> None:
    """Write each trajectory to ``traces_dir/trajectory_NNNN.json`` as indented JSON.

    Files are numbered by position in ``trajectories``, starting at ``0000``. The
    directory is created if needed, and same-named files are overwritten.

    NOTE(review): files left over from an earlier, larger run (higher indices) are
    not removed. Anything that later globs this directory will pick them up too.

    Args:
        trajectories: JSON-serialisable trajectory dicts.
        traces_dir: Output directory.
    """
    traces_dir.mkdir(parents=True, exist_ok=True)

    for i, traj in enumerate(trajectories):
        trace_path = traces_dir / f"trajectory_{i:04d}.json"
        with open(trace_path, "w", encoding="utf-8") as f:
            json.dump(traj, f, indent=2)

    print(f"💾 Saved {len(trajectories)} trajectories to {traces_dir}")
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
def _filter_and_export(trajectories: List[Dict[str, Any]], config: "GenerationConfig") -> tuple:
    """Filter trajectories with the config's thresholds and export survivors as JSONL.

    Returns:
        ``(filtered, output_path, num_examples)``.
    """
    filtered = filter_high_quality_trajectories(
        trajectories,
        min_score=config.min_score_threshold,
        min_achievements=config.min_achievements,
    )
    output_path = config.ft_data_dir / "crafter_gemini_ft.jsonl"
    num_examples = convert_to_vertex_ai_format(filtered, output_path)
    return filtered, output_path, num_examples


# --- Main ---
async def main():
    """CLI entry point: build a Gemini fine-tuning JSONL file from Crafter trajectories.

    Without ``--filter-only``, the command:

    * generates ``config.num_rollouts`` trajectories;
    * saves them all to ``config.traces_dir``;
    * filters them by ``config.min_score_threshold`` and ``config.min_achievements``;
    * writes the survivors to ``config.ft_data_dir / "crafter_gemini_ft.jsonl"``.

    With ``--filter-only``, it loads every ``*.json`` file in ``config.traces_dir``
    instead of generating, then filters and exports the same way.
    """
    parser = argparse.ArgumentParser(description="Generate Gemini fine-tuning data for Crafter")
    parser.add_argument("--config", type=str, help="Path to TOML config file")
    parser.add_argument("--num-rollouts", type=int, help="Number of rollouts to generate")
    parser.add_argument("--model", type=str, help="Gemini model name")
    parser.add_argument("--filter-only", action="store_true", help="Only filter existing traces")
    parser.add_argument("--min-achievements", type=int, help="Minimum achievements for filtering")

    args = parser.parse_args()

    # Load config
    config = GenerationConfig(args.config)

    # Override with command-line args. Unspecified options are None, so compare
    # against None: truthiness would silently ignore an explicit 0, e.g.
    # `--min-achievements 0` to disable the achievement filter.
    if args.num_rollouts is not None:
        config.num_rollouts = args.num_rollouts
    if args.model:
        config.model_name = args.model
    if args.min_achievements is not None:
        config.min_achievements = args.min_achievements

    if args.filter_only:
        # Re-filter trajectories saved by an earlier run.
        print("š Filtering existing trajectories...")

        trajectories = []
        for trace_file in sorted(config.traces_dir.glob("*.json")):
            with open(trace_file, encoding="utf-8") as f:
                trajectories.append(json.load(f))

        filtered, output_path, num_examples = _filter_and_export(trajectories, config)

        print("\n🎯 Summary:")
        print(f" Filtered trajectories: {len(filtered)}")
        print(f" Total training examples: {num_examples}")
        print(f" Output: {output_path}")

    else:
        # Generate new trajectories and keep all of them on disk before filtering.
        trajectories = await generate_all_trajectories(config)
        save_trajectories(trajectories, config.traces_dir)

        filtered, output_path, num_examples = _filter_and_export(trajectories, config)

        print("\n🎯 Summary:")
        print(f" Generated trajectories: {len(trajectories)}")
        print(f" High quality trajectories: {len(filtered)}")
        print(f" Total training examples: {num_examples}")
        print(f" Output: {output_path}")
|
|
721
|
+
|
|
722
|
+
|
|
723
|
+
if __name__ == "__main__":
    # Script entry point: drive the async main() to completion on a fresh event loop.
    asyncio.run(main())
|