synth-ai 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/baseline/banking77_baseline.py +204 -0
- examples/baseline/crafter_baseline.py +407 -0
- examples/baseline/pokemon_red_baseline.py +326 -0
- examples/baseline/simple_baseline.py +56 -0
- examples/baseline/warming_up_to_rl_baseline.py +239 -0
- examples/blog_posts/gepa/README.md +355 -0
- examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
- examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
- examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
- examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
- examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
- examples/blog_posts/gepa/gepa_baseline.py +204 -0
- examples/blog_posts/gepa/query_prompts_example.py +97 -0
- examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
- examples/blog_posts/gepa/task_apps.py +105 -0
- examples/blog_posts/gepa/test_gepa_local.sh +67 -0
- examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
- examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +12 -10
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +1 -0
- examples/blog_posts/pokemon_vl/extract_images.py +239 -0
- examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
- examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
- examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
- examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
- examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
- examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
- examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +1 -1
- examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +60 -10
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +1 -1
- examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -0
- examples/rl/configs/rl_from_base_qwen17.toml +1 -0
- examples/swe/task_app/hosted/inference/openai_client.py +0 -34
- examples/swe/task_app/hosted/policy_routes.py +17 -0
- examples/swe/task_app/hosted/rollout.py +4 -2
- examples/task_apps/banking77/__init__.py +6 -0
- examples/task_apps/banking77/banking77_task_app.py +841 -0
- examples/task_apps/banking77/deploy_wrapper.py +46 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +24 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +355 -58
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +68 -7
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +78 -21
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
- examples/task_apps/gepa_benchmarks/__init__.py +7 -0
- examples/task_apps/gepa_benchmarks/common.py +260 -0
- examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
- examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
- examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
- examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
- examples/task_apps/pokemon_red/task_app.py +254 -36
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +1 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +53 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +152 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +31 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +1 -0
- synth_ai/api/train/builders.py +90 -1
- synth_ai/api/train/cli.py +396 -21
- synth_ai/api/train/config_finder.py +13 -2
- synth_ai/api/train/configs/__init__.py +15 -1
- synth_ai/api/train/configs/prompt_learning.py +442 -0
- synth_ai/api/train/configs/rl.py +29 -0
- synth_ai/api/train/task_app.py +1 -1
- synth_ai/api/train/validators.py +277 -0
- synth_ai/baseline/__init__.py +25 -0
- synth_ai/baseline/config.py +209 -0
- synth_ai/baseline/discovery.py +214 -0
- synth_ai/baseline/execution.py +146 -0
- synth_ai/cli/__init__.py +85 -17
- synth_ai/cli/__main__.py +0 -0
- synth_ai/cli/claude.py +70 -0
- synth_ai/cli/codex.py +84 -0
- synth_ai/cli/commands/__init__.py +1 -0
- synth_ai/cli/commands/baseline/__init__.py +12 -0
- synth_ai/cli/commands/baseline/core.py +637 -0
- synth_ai/cli/commands/baseline/list.py +93 -0
- synth_ai/cli/commands/eval/core.py +13 -10
- synth_ai/cli/commands/filter/core.py +53 -17
- synth_ai/cli/commands/help/core.py +0 -1
- synth_ai/cli/commands/smoke/__init__.py +7 -0
- synth_ai/cli/commands/smoke/core.py +1436 -0
- synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
- synth_ai/cli/commands/status/subcommands/usage.py +203 -0
- synth_ai/cli/commands/train/judge_schemas.py +1 -0
- synth_ai/cli/commands/train/judge_validation.py +1 -0
- synth_ai/cli/commands/train/validation.py +0 -57
- synth_ai/cli/demo.py +35 -3
- synth_ai/cli/deploy/__init__.py +40 -25
- synth_ai/cli/deploy.py +162 -0
- synth_ai/cli/legacy_root_backup.py +14 -8
- synth_ai/cli/opencode.py +107 -0
- synth_ai/cli/root.py +9 -5
- synth_ai/cli/task_app_deploy.py +1 -1
- synth_ai/cli/task_apps.py +53 -53
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
- synth_ai/judge_schemas.py +1 -0
- synth_ai/learning/__init__.py +10 -0
- synth_ai/learning/prompt_learning_client.py +276 -0
- synth_ai/learning/prompt_learning_types.py +184 -0
- synth_ai/pricing/__init__.py +2 -0
- synth_ai/pricing/model_pricing.py +57 -0
- synth_ai/streaming/handlers.py +53 -4
- synth_ai/streaming/streamer.py +19 -0
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +2 -0
- synth_ai/task/tracing_utils.py +25 -25
- synth_ai/task/validators.py +44 -8
- synth_ai/task_app_cfgs.py +21 -0
- synth_ai/tracing_v3/config.py +162 -19
- synth_ai/tracing_v3/constants.py +1 -1
- synth_ai/tracing_v3/db_config.py +24 -38
- synth_ai/tracing_v3/storage/config.py +47 -13
- synth_ai/tracing_v3/storage/factory.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +113 -11
- synth_ai/tracing_v3/turso/native_manager.py +92 -16
- synth_ai/types.py +8 -0
- synth_ai/urls.py +11 -0
- synth_ai/utils/__init__.py +30 -1
- synth_ai/utils/agents.py +74 -0
- synth_ai/utils/bin.py +39 -0
- synth_ai/utils/cli.py +149 -5
- synth_ai/utils/env.py +17 -17
- synth_ai/utils/json.py +72 -0
- synth_ai/utils/modal.py +283 -1
- synth_ai/utils/paths.py +48 -0
- synth_ai/utils/uvicorn.py +113 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/METADATA +102 -4
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/RECORD +162 -88
- synth_ai/cli/commands/deploy/__init__.py +0 -23
- synth_ai/cli/commands/deploy/core.py +0 -614
- synth_ai/cli/commands/deploy/errors.py +0 -72
- synth_ai/cli/commands/deploy/validation.py +0 -11
- synth_ai/cli/deploy/core.py +0 -5
- synth_ai/cli/deploy/errors.py +0 -23
- synth_ai/cli/deploy/validation.py +0 -5
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
[eval]
|
|
2
|
+
app_id = "pokemon_red"
|
|
3
|
+
task_app_url = "http://127.0.0.1:8914"
|
|
4
|
+
model = "gpt-5-nano"
|
|
5
|
+
seeds = [0] # Single seed for testing
|
|
6
|
+
max_turns = 10 # 10 LLM calls per episode to allow more progress
|
|
7
|
+
concurrency = 1 # Run 1 rollout
|
|
8
|
+
env_name = "pokemon_red"
|
|
9
|
+
policy_name = "pokemon_vl_qwen3_vl" # Reuse policy config, will override model
|
|
10
|
+
trace_format = "full"
|
|
11
|
+
return_trace = true
|
|
12
|
+
|
|
13
|
+
[eval.policy_config]
|
|
14
|
+
provider = "openai" # Use OpenAI API for gpt-5-nano
|
|
15
|
+
model = "gpt-5-nano"
|
|
16
|
+
inference_url = "https://api.openai.com/v1"
|
|
17
|
+
temperature = 0.7
|
|
18
|
+
top_p = 0.95
|
|
19
|
+
max_tokens = 512
|
|
20
|
+
use_vision = true
|
|
21
|
+
image_only_mode = false
|
|
22
|
+
max_llm_calls = 10
|
|
23
|
+
|
|
24
|
+
[eval.env_config.env_params]
|
|
25
|
+
max_steps_per_episode = 100 # Allow time to achieve milestones
|
|
26
|
+
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
[eval]
|
|
2
2
|
app_id = "pokemon_red"
|
|
3
|
-
task_app_url = "
|
|
4
|
-
model = "Qwen/Qwen3-VL-
|
|
5
|
-
seeds = [10, 11
|
|
6
|
-
max_turns = 10
|
|
7
|
-
concurrency = 2
|
|
3
|
+
task_app_url = "http://127.0.0.1:8914"
|
|
4
|
+
model = "Qwen/Qwen3-VL-30B-A3B-Thinking" # Larger thinking variant - needs more time to load
|
|
5
|
+
seeds = [10, 11] # 2 seeds for quick testing
|
|
6
|
+
max_turns = 10 # 10 LLM calls per episode to allow more progress
|
|
7
|
+
concurrency = 2 # Run 2 rollouts in parallel
|
|
8
8
|
env_name = "pokemon_red"
|
|
9
9
|
policy_name = "pokemon_vl_qwen3_vl"
|
|
10
10
|
trace_format = "full"
|
|
@@ -12,14 +12,16 @@ return_trace = true
|
|
|
12
12
|
|
|
13
13
|
[eval.policy_config]
|
|
14
14
|
provider = "synth" # Use Synth internal API for vision models
|
|
15
|
-
model = "Qwen/Qwen3-VL-
|
|
16
|
-
inference_url = "
|
|
17
|
-
temperature = 1.0 # Higher temperature to encourage
|
|
15
|
+
model = "Qwen/Qwen3-VL-30B-A3B-Thinking" # Larger thinking variant - needs more time to load
|
|
16
|
+
inference_url = "https://synth-laboratories-dev--learning-v2-service-fastapi-app.modal.run/chat/completions"
|
|
17
|
+
temperature = 1.0 # Higher temperature to encourage exploration
|
|
18
18
|
top_p = 0.95
|
|
19
|
-
max_tokens =
|
|
19
|
+
max_tokens = 2048 # Reduced to avoid token budget issues
|
|
20
20
|
use_vision = true
|
|
21
21
|
image_only_mode = false
|
|
22
22
|
max_llm_calls = 10
|
|
23
|
+
thinking_mode = "think" # Enable thinking/reasoning mode
|
|
24
|
+
thinking_budget = 3072 # Increased token budget for reasoning
|
|
23
25
|
|
|
24
26
|
[eval.env_config.env_params]
|
|
25
|
-
max_steps_per_episode =
|
|
27
|
+
max_steps_per_episode = 100 # Increased from 3 to allow time to achieve milestones
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Extract images from pokemon_vl trace database or trace JSON file and save to images_gpt5 directory.
|
|
3
|
+
|
|
4
|
+
Usage:
|
|
5
|
+
# From trace database:
|
|
6
|
+
python extract_images.py --trace-db traces/v3/pokemon_vl_gpt5nano.db
|
|
7
|
+
|
|
8
|
+
# From trace JSON file:
|
|
9
|
+
python extract_images.py --trace-json trace.json
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import argparse
|
|
13
|
+
import base64
|
|
14
|
+
import json
|
|
15
|
+
import sqlite3
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
from synth_ai.tracing_v3.trace_utils import load_session_trace
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def extract_image_urls_from_content(content: Any) -> list[str]:
    """Collect inline ``data:image`` URLs from a message's content.

    Args:
        content: Either a list of content parts (dicts whose ``"type"`` is
            ``"image_url"`` or ``"image"``) or a string holding the JSON
            encoding of such a list. Any other value yields no URLs.

    Returns:
        The URLs that start with ``data:image``, in order of appearance.
        URLs with any other scheme are ignored.
    """
    if isinstance(content, str):
        # Content may be stored JSON-encoded. Plain text is expected and is
        # not an error, so only the parse itself is guarded; problems in the
        # recursive call are no longer silently swallowed.
        try:
            parsed = json.loads(content)
        except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
            return []
        return extract_image_urls_from_content(parsed)

    if not isinstance(content, list):
        return []

    urls: list[str] = []
    for part in content:
        if not isinstance(part, dict):
            continue

        kind = part.get("type")
        if kind == "image_url":
            image_ref = part.get("image_url")
            # The URL is normally nested as {"url": ...}; tolerate a bare
            # string instead of raising AttributeError on it.
            candidate = image_ref.get("url") if isinstance(image_ref, dict) else image_ref
        elif kind == "image":
            candidate = part.get("image")
        else:
            continue

        if isinstance(candidate, str) and candidate.startswith("data:image"):
            urls.append(candidate)

    return urls
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def extract_state_info_from_message(message: dict[str, Any]) -> dict[str, Any]:
|
|
49
|
+
"""Extract state info from message metadata or content."""
|
|
50
|
+
metadata = message.get("metadata", {})
|
|
51
|
+
state = {}
|
|
52
|
+
|
|
53
|
+
# Try to get state from metadata
|
|
54
|
+
if "system_state_before" in metadata:
|
|
55
|
+
state_before = metadata["system_state_before"]
|
|
56
|
+
if isinstance(state_before, dict):
|
|
57
|
+
obs = state_before.get("obs", {})
|
|
58
|
+
state.update({
|
|
59
|
+
"position": obs.get("position", "?"),
|
|
60
|
+
"map_id": obs.get("map_id", "?"),
|
|
61
|
+
"player_x": obs.get("player_x", "?"),
|
|
62
|
+
"player_y": obs.get("player_y", "?"),
|
|
63
|
+
"text_box_active": obs.get("text_box_active", False),
|
|
64
|
+
})
|
|
65
|
+
|
|
66
|
+
# Try to extract from content text
|
|
67
|
+
content = message.get("content", "")
|
|
68
|
+
if isinstance(content, str) and "position" in content:
|
|
69
|
+
try:
|
|
70
|
+
# Look for state summary in content
|
|
71
|
+
if "State summary:" in content:
|
|
72
|
+
parts = content.split("State summary:")
|
|
73
|
+
if len(parts) > 1:
|
|
74
|
+
import ast
|
|
75
|
+
state_str = parts[1].split("'")[0] if "'" not in parts[1] else parts[1]
|
|
76
|
+
try:
|
|
77
|
+
state_dict = ast.literal_eval(state_str.split("'")[0] if "'" in state_str else state_str)
|
|
78
|
+
if isinstance(state_dict, dict):
|
|
79
|
+
state.update({
|
|
80
|
+
"position": state_dict.get("position", "?"),
|
|
81
|
+
"map_id": state_dict.get("map_id", "?"),
|
|
82
|
+
"player_x": state_dict.get("player_x", "?"),
|
|
83
|
+
"player_y": state_dict.get("player_y", "?"),
|
|
84
|
+
"text_box_active": state_dict.get("text_box_active", False),
|
|
85
|
+
})
|
|
86
|
+
except:
|
|
87
|
+
pass
|
|
88
|
+
except:
|
|
89
|
+
pass
|
|
90
|
+
|
|
91
|
+
return state
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def extract_images_from_trace_dict(trace: dict[str, Any], output_dir: Path):
    """Decode every inline image found in *trace* and write it to *output_dir*.

    Messages are read from ``trace["markov_blanket_message_history"]`` or,
    when that is missing or empty, from ``trace["messages"]``. Each image is
    saved as ``step_<NNN>_pos_<map>_<x>,<y>_textbox_<bool>.png``, where the
    position fields come from ``extract_state_info_from_message``.

    Args:
        trace: Trace dictionary.
        output_dir: Directory to write into; created if it does not exist.

    Returns:
        The number of images written.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    messages = trace.get("markov_blanket_message_history", []) or trace.get("messages", [])
    if not messages:
        print(" No messages found in trace")
        return 0

    print(f" Found {len(messages)} messages")

    image_count = 0
    # NOTE(review): numbering restarts at 0 on every call, so calling this for
    # several traces with the same output_dir can overwrite files that share a
    # step index and position.
    step_idx = 0
    for msg in messages:
        if not isinstance(msg, dict):
            # Skip malformed entries instead of failing on ``msg.get``.
            continue

        image_urls = extract_image_urls_from_content(msg.get("content", ""))
        if not image_urls:
            continue

        # The state is per message, so compute the filename pieces once.
        state = extract_state_info_from_message(msg)
        pos_str = f"{state.get('map_id', '?')}_{state.get('player_x', '?')},{state.get('player_y', '?')}"
        textbox_str = "True" if state.get("text_box_active") else "False"

        for img_url in image_urls:
            if not img_url.startswith("data:image"):
                continue

            # Format: data:image/png;base64,<data>
            _header, separator, b64_data = img_url.partition(",")
            if not separator:
                continue

            filename = f"step_{step_idx:03d}_pos_{pos_str}_textbox_{textbox_str}.png"
            try:
                img_data = base64.b64decode(b64_data)
                (output_dir / filename).write_bytes(img_data)
            except (ValueError, OSError) as e:
                # binascii.Error (bad base64) is a ValueError; OSError covers
                # write failures. Anything else is a real bug and propagates.
                print(f" Error decoding image: {e}")
                continue

            print(f" Saved: {filename}")
            image_count += 1
            step_idx += 1

    return image_count
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def extract_images_from_trace_db(trace_db: str, output_dir: Path, model_filter: str | None = None):
|
|
151
|
+
"""Extract images from trace database and save to output directory."""
|
|
152
|
+
conn = sqlite3.connect(trace_db)
|
|
153
|
+
conn.row_factory = sqlite3.Row
|
|
154
|
+
|
|
155
|
+
# Get all session IDs
|
|
156
|
+
query = "SELECT session_id, metadata FROM session_traces"
|
|
157
|
+
if model_filter:
|
|
158
|
+
query += " WHERE metadata LIKE ?"
|
|
159
|
+
params = (f'%{model_filter}%',)
|
|
160
|
+
else:
|
|
161
|
+
params = ()
|
|
162
|
+
|
|
163
|
+
rows = conn.execute(query, params).fetchall()
|
|
164
|
+
|
|
165
|
+
if not rows:
|
|
166
|
+
print(f"No traces found in {trace_db}")
|
|
167
|
+
return
|
|
168
|
+
|
|
169
|
+
print(f"Found {len(rows)} trace(s)")
|
|
170
|
+
|
|
171
|
+
total_images = 0
|
|
172
|
+
for row in rows:
|
|
173
|
+
session_id = row["session_id"]
|
|
174
|
+
print(f"\nProcessing session: {session_id}")
|
|
175
|
+
|
|
176
|
+
try:
|
|
177
|
+
trace = load_session_trace(conn, session_id)
|
|
178
|
+
except Exception as e:
|
|
179
|
+
print(f" Error loading trace: {e}")
|
|
180
|
+
continue
|
|
181
|
+
|
|
182
|
+
count = extract_images_from_trace_dict(trace, output_dir)
|
|
183
|
+
total_images += count
|
|
184
|
+
|
|
185
|
+
conn.close()
|
|
186
|
+
print(f"\n✓ Extracted {total_images} images to {output_dir}/")
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def extract_images_from_trace_json(trace_json: Path, output_dir: Path):
    """Extract images from a trace stored as a JSON file.

    Args:
        trace_json: Path to the JSON file. The trace may be the root object
            or nested under a top-level ``"session_trace"`` key.
        output_dir: Directory the images are written to.

    Raises:
        ValueError: If the JSON root (after unwrapping) is not an object.
    """
    print(f"Loading trace from {trace_json}")

    # Read as UTF-8 explicitly rather than with the platform default encoding.
    with open(trace_json, encoding="utf-8") as f:
        trace = json.load(f)

    # Handle trace wrapped in "session_trace" key
    if isinstance(trace, dict) and "session_trace" in trace:
        trace = trace["session_trace"]

    if not isinstance(trace, dict):
        # Fail with a clear message instead of an AttributeError further down.
        raise ValueError(
            f"Expected a JSON object in {trace_json}, got {type(trace).__name__}"
        )

    count = extract_images_from_trace_dict(trace, output_dir)
    print(f"\n✓ Extracted {count} images to {output_dir}/")
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def main():
    """Command-line entry point.

    Parses ``--trace-db``, ``--trace-json``, ``--output-dir`` and
    ``--model-filter``. A JSON trace takes precedence when both inputs are
    given; with neither, argparse reports an error and exits.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--trace-db", help="Path to trace database")
    cli.add_argument("--trace-json", type=Path, help="Path to trace JSON file")
    cli.add_argument(
        "--output-dir",
        default="examples/blog_posts/pokemon_vl/images_gpt5",
        help="Output directory for images",
    )
    cli.add_argument("--model-filter", help="Filter traces by model name (optional)")
    opts = cli.parse_args()

    # Guard clause: at least one input source is required.
    if not (opts.trace_json or opts.trace_db):
        cli.error("Must provide either --trace-db or --trace-json")

    destination = Path(opts.output_dir)
    if opts.trace_json:
        extract_images_from_trace_json(opts.trace_json, destination)
    else:
        extract_images_from_trace_db(opts.trace_db, destination, opts.model_filter)


if __name__ == "__main__":
    main()
|
|
238
|
+
|
|
239
|
+
|
|
@@ -0,0 +1,326 @@
|
|
|
1
|
+
"""Pokemon Red baseline file for Game Boy emulation evaluation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
from synth_ai.baseline import BaselineConfig, BaselineTaskRunner, DataSplit, TaskResult
|
|
8
|
+
from synth_ai.inference import InferenceClient
|
|
9
|
+
import os
|
|
10
|
+
import httpx
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
from synth_ai.environments.examples.red.environment import PokemonRedEnvironment
|
|
14
|
+
from synth_ai.environments.examples.red.taskset import (
|
|
15
|
+
PokemonRedTaskInstance,
|
|
16
|
+
PokemonRedTaskInstanceMetadata,
|
|
17
|
+
)
|
|
18
|
+
POKEMON_RED_AVAILABLE = True
|
|
19
|
+
except ImportError:
|
|
20
|
+
POKEMON_RED_AVAILABLE = False
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class PokemonRedTaskRunner(BaselineTaskRunner):
    """Task runner for Pokemon Red Game Boy emulation.

    Each episode alternates between asking a chat model for a batch of button
    presses (through the ``execute_sequence`` tool) and applying those presses
    to the environment one at a time.
    """

    def __init__(self, policy_config: Dict[str, Any], env_config: Dict[str, Any]):
        """Store inference settings and build the tool schema.

        Args:
            policy_config: Must contain ``"model"``; may contain
                ``"temperature"`` (default 0.0), ``"max_tokens"`` (default
                512) and ``"inference_url"``.
            env_config: Environment settings, passed to the base class.

        Raises:
            ImportError: If the Pokemon Red environment could not be imported.
            KeyError: If ``policy_config`` has no ``"model"`` entry.
        """
        super().__init__(policy_config, env_config)

        if not POKEMON_RED_AVAILABLE:
            raise ImportError(
                "Pokemon Red environment not available. "
                "Install synth-ai with Pokemon Red support."
            )

        # Store config for inference
        self.model = policy_config["model"]
        self.temperature = policy_config.get("temperature", 0.0)
        self.max_tokens = policy_config.get("max_tokens", 512)
        self.inference_url = policy_config.get("inference_url")

        # Tool definition
        self.tools = [{
            "type": "function",
            "function": {
                "name": "execute_sequence",
                "description": "Execute multiple button presses in sequence",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "actions": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "button": {
                                        "type": "string",
                                        "enum": ["UP", "DOWN", "LEFT", "RIGHT", "A", "B", "START", "SELECT"],
                                    },
                                    "frames": {
                                        "type": "integer",
                                        "minimum": 1,
                                        "maximum": 120,
                                        "description": "Frames to hold button (60fps)",
                                    },
                                },
                                "required": ["button", "frames"],
                            },
                            "minItems": 1,
                            "maxItems": 20,
                        },
                    },
                    "required": ["actions"],
                },
            },
        }]

    def _format_observation(self, obs: Dict[str, Any], step: int, max_steps: int) -> str:
        """Format observation for LLM.

        Only keys present in *obs* are rendered, so an empty observation
        produces just the header and the closing question.
        """
        lines = [
            f"Pokemon Red - Step {step}/{max_steps}",
            "",
        ]

        # Position
        if "map_id" in obs:
            lines.append(f"Location: Map {obs['map_id']}")
        if "player_x" in obs and "player_y" in obs:
            lines.append(f"Position: ({obs['player_x']}, {obs['player_y']})")

        # Party
        if "party_count" in obs:
            lines.append(f"Party Size: {obs['party_count']}")
        if "party_pokemon" in obs and obs["party_pokemon"]:
            pokemon = obs["party_pokemon"][0]
            lines.append(
                f"First Pokemon: Level {pokemon.get('level', '?')}, "
                f"HP {pokemon.get('hp_current', '?')}/{pokemon.get('hp_max', '?')}"
            )

        # Battle
        if obs.get("in_battle"):
            lines.append("=== IN BATTLE ===")
            if "enemy_hp_current" in obs:
                lines.append(
                    f"Enemy HP: {obs['enemy_hp_current']}/{obs.get('enemy_hp_max', '?')}"
                )
            if "battle_turn" in obs:
                lines.append(f"Battle Turn: {obs['battle_turn']}")

        # Progress
        if "badges" in obs:
            lines.append(f"Badges: {obs['badges']}")
        if "money" in obs:
            lines.append(f"Money: ${obs['money']}")

        # Dialogue
        if obs.get("text_box_active"):
            lines.append("Text box is active - press A to advance dialogue")

        lines.append("")
        lines.append("What actions should we take?")

        return "\n".join(lines)

    @staticmethod
    def _observation_dict(result: Any) -> Dict[str, Any]:
        """Return ``result.observation`` (or *result* itself) if it is a dict, else ``{}``."""
        observation = getattr(result, "observation", result)
        return observation if isinstance(observation, dict) else {}

    @staticmethod
    def _parse_actions(response: Any) -> List[Dict[str, Any]]:
        """Pull the ``actions`` list out of a chat-completion response.

        Accepts either a ``choices``-style payload or a payload carrying
        ``tool_calls`` at the top level. Tool-call arguments may arrive as a
        dict or as a JSON-encoded string; the original code assumed a dict and
        raised ``AttributeError`` on the string form. Anything malformed
        yields an empty list, which ends the episode.
        """
        import json

        if not isinstance(response, dict):
            return []

        tool_calls: Any = []
        choices = response.get("choices")
        if choices:
            message = choices[0].get("message", {}) or {}
            tool_calls = message.get("tool_calls") or []
        elif "tool_calls" in response:
            tool_calls = response["tool_calls"] or []

        if not tool_calls:
            return []

        function_call = tool_calls[0].get("function") or {}
        arguments = function_call.get("arguments", {})
        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments)
            except ValueError:  # json.JSONDecodeError is a ValueError
                return []
        if not isinstance(arguments, dict):
            return []

        actions = arguments.get("actions", [])
        return actions if isinstance(actions, list) else []

    async def _request_completion(self, messages: List[Dict[str, Any]]) -> Any:
        """Send *messages* to the model, forcing an ``execute_sequence`` tool call.

        When ``inference_url`` is an HTTP(S) URL, the request goes through
        ``InferenceClient``. Otherwise it is posted directly to the OpenAI or
        Groq chat-completions endpoint, chosen by whether the model name
        contains ``"openai"``.
        """
        tool_choice = {"type": "function", "function": {"name": "execute_sequence"}}

        if self.inference_url and self.inference_url.startswith("http"):
            api_key = os.getenv("SYNTH_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
            base_url = self.inference_url.rstrip("/")
            # Append "/api" only when the URL does not already contain it.
            if "/api" not in base_url:
                base_url = f"{base_url}/api"
            client = InferenceClient(base_url=base_url, api_key=api_key)
            return await client.create_chat_completion(
                model=self.model,
                messages=messages,
                tools=self.tools,
                tool_choice=tool_choice,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )

        api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GROQ_API_KEY") or ""
        base_url = "https://api.openai.com/v1" if "openai" in self.model.lower() else "https://api.groq.com/openai/v1"
        async with httpx.AsyncClient() as http_client:
            resp = await http_client.post(
                f"{base_url}/chat/completions",
                json={
                    "model": self.model,
                    "messages": messages,
                    "tools": self.tools,
                    "tool_choice": tool_choice,
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                },
                headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
            )
            return resp.json()

    async def run_task(self, seed: int) -> TaskResult:
        """Run a single Pokemon Red episode.

        Args:
            seed: Seed recorded in the task metadata and the result.

        Returns:
            A ``TaskResult`` with the summed reward, the positive per-step
            rewards, the number of environment steps taken and a summary of
            the final observation. ``success`` is always ``True`` when the
            episode completes without raising.

        Raises:
            ValueError: If ``env_config`` has no ``rom_path``.
        """
        # Create task instance
        rom_path = self.env_config.get("rom_path")
        if not rom_path:
            raise ValueError("rom_path required in env_config for Pokemon Red")

        init_state_path = self.env_config.get("init_state_path")
        max_steps = self.env_config.get("max_steps", 500)

        metadata = PokemonRedTaskInstanceMetadata(
            seed=seed,
            rom_path=rom_path,
            init_state_path=init_state_path,
            reward_type=self.env_config.get("reward_type", "pallet_town_progression"),
        )

        task_instance = PokemonRedTaskInstance(
            id=f"pokemon-red-{seed}",
            metadata=metadata,
        )

        # Imported lazily, but once per episode rather than once per action.
        from synth_ai.environments.environment.tools import EnvToolCall

        # Create and initialize environment
        env = PokemonRedEnvironment(task_instance=task_instance)
        obs_dict = self._observation_dict(await env.initialize())

        # Episode loop
        total_reward = 0.0
        total_steps = 0
        event_rewards: List[Dict[str, Any]] = []
        battle_won = False
        game_over = False

        try:
            for step in range(max_steps):
                # Stop once the step budget is spent; previously the loop kept
                # requesting completions whose actions could never be executed.
                if total_steps >= max_steps:
                    break

                prompt = self._format_observation(obs_dict, step, max_steps)

                # Add image if available
                messages: List[Dict[str, Any]] = [{"role": "user", "content": prompt}]
                if obs_dict.get("observation_image_base64"):
                    messages[0]["content"] = [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{obs_dict['observation_image_base64']}"
                            },
                        },
                        {"type": "text", "text": prompt},
                    ]

                # Get action from LLM
                response = await self._request_completion(messages)
                actions = self._parse_actions(response)
                if not actions:
                    break

                # Execute actions one button press at a time
                for action_spec in actions:
                    if total_steps >= max_steps:
                        break

                    tool_call = EnvToolCall(
                        name="execute_sequence",
                        arguments={"actions": [action_spec]},
                    )

                    step_result = await env.step([tool_call])
                    total_steps += 1
                    obs_dict = self._observation_dict(step_result)

                    # A missing or None reward counts as zero.
                    reward = getattr(step_result, "reward", 0.0) or 0.0
                    total_reward += reward

                    if reward > 0:
                        event_rewards.append({
                            "step": total_steps,
                            "reward": reward,
                        })

                    # Check termination
                    if getattr(step_result, "terminated", False) or getattr(step_result, "truncated", False):
                        game_over = True
                        break

                    # Check battle outcome
                    # NOTE(review): 1 is treated as a win and 2 as a loss;
                    # confirm against the environment's encoding.
                    if obs_dict.get("battle_outcome") == 1:
                        battle_won = True
                    elif obs_dict.get("battle_outcome") == 2:
                        game_over = True

                if game_over:
                    break
        finally:
            # Always release the environment, even if inference or stepping fails.
            await env.terminate()

        return TaskResult(
            seed=seed,
            success=True,
            outcome_reward=total_reward,
            event_rewards=event_rewards,
            total_steps=total_steps,
            metadata={
                "battle_won": battle_won,
                "game_over": game_over,
                "final_map": obs_dict.get("map_id"),
                "badges": obs_dict.get("badges", 0),
                "party_size": obs_dict.get("party_count", 0),
            },
        )
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
# Define baseline config (only if Pokemon Red is available)
# POKEMON_RED_AVAILABLE is set by the guarded import at the top of this module,
# so the baseline is simply not defined when that import fails.
if POKEMON_RED_AVAILABLE:
    pokemon_vl_baseline = BaselineConfig(
        baseline_id="pokemon_vl",
        name="Pokemon VL - Pokemon Red",
        description="Pokemon Red Game Boy emulation baseline for vision-language agents",
        task_runner=PokemonRedTaskRunner,
        # Disjoint seed ranges: 0-19 train, 20-24 val, 25-29 test.
        splits={
            "train": DataSplit(name="train", seeds=list(range(20))),
            "val": DataSplit(name="val", seeds=list(range(20, 25))),
            "test": DataSplit(name="test", seeds=list(range(25, 30))),
        },
        # These keys are the ones PokemonRedTaskRunner.__init__ reads from
        # policy_config.
        default_policy_config={
            "model": "groq:llama-3.1-70b-versatile",
            "temperature": 0.0,
            "max_tokens": 512,
        },
        # run_task raises ValueError while rom_path is unset, so callers must
        # override it.
        default_env_config={
            "rom_path": None,  # Must be provided
            "init_state_path": None,  # Optional
            "reward_type": "pallet_town_progression",
            "max_steps": 500,
        },
        metadata={
            "environment": "pokemon_red",
            "task_type": "emulation",
            "requires_rom": True,
        },
    )
|
|
326
|
+
|