synth-ai 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (169) hide show
  1. examples/baseline/banking77_baseline.py +204 -0
  2. examples/baseline/crafter_baseline.py +407 -0
  3. examples/baseline/pokemon_red_baseline.py +326 -0
  4. examples/baseline/simple_baseline.py +56 -0
  5. examples/baseline/warming_up_to_rl_baseline.py +239 -0
  6. examples/blog_posts/gepa/README.md +355 -0
  7. examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
  8. examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
  9. examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
  10. examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
  11. examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
  12. examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
  13. examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
  14. examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
  15. examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
  16. examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
  17. examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
  18. examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
  19. examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
  20. examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
  21. examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
  22. examples/blog_posts/gepa/gepa_baseline.py +204 -0
  23. examples/blog_posts/gepa/query_prompts_example.py +97 -0
  24. examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
  25. examples/blog_posts/gepa/task_apps.py +105 -0
  26. examples/blog_posts/gepa/test_gepa_local.sh +67 -0
  27. examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
  28. examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
  29. examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +12 -10
  30. examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +1 -0
  31. examples/blog_posts/pokemon_vl/extract_images.py +239 -0
  32. examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
  33. examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
  34. examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
  35. examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
  36. examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
  37. examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
  38. examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
  39. examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
  40. examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
  41. examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
  42. examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
  43. examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +1 -1
  44. examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
  45. examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +60 -10
  46. examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +1 -1
  47. examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
  48. examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
  49. examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
  50. examples/multi_step/configs/crafter_rl_outcome.toml +1 -0
  51. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -0
  52. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -0
  53. examples/rl/configs/rl_from_base_qwen17.toml +1 -0
  54. examples/swe/task_app/hosted/inference/openai_client.py +0 -34
  55. examples/swe/task_app/hosted/policy_routes.py +17 -0
  56. examples/swe/task_app/hosted/rollout.py +4 -2
  57. examples/task_apps/banking77/__init__.py +6 -0
  58. examples/task_apps/banking77/banking77_task_app.py +841 -0
  59. examples/task_apps/banking77/deploy_wrapper.py +46 -0
  60. examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
  61. examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
  62. examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
  63. examples/task_apps/crafter/task_app/grpo_crafter.py +24 -2
  64. examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
  65. examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +355 -58
  66. examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +68 -7
  67. examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +78 -21
  68. examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
  69. examples/task_apps/gepa_benchmarks/__init__.py +7 -0
  70. examples/task_apps/gepa_benchmarks/common.py +260 -0
  71. examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
  72. examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
  73. examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
  74. examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
  75. examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
  76. examples/task_apps/pokemon_red/task_app.py +254 -36
  77. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +1 -0
  78. examples/warming_up_to_rl/task_app/grpo_crafter.py +53 -4
  79. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
  80. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +152 -41
  81. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +31 -1
  82. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
  83. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
  84. examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +1 -0
  85. synth_ai/api/train/builders.py +90 -1
  86. synth_ai/api/train/cli.py +396 -21
  87. synth_ai/api/train/config_finder.py +13 -2
  88. synth_ai/api/train/configs/__init__.py +15 -1
  89. synth_ai/api/train/configs/prompt_learning.py +442 -0
  90. synth_ai/api/train/configs/rl.py +29 -0
  91. synth_ai/api/train/task_app.py +1 -1
  92. synth_ai/api/train/validators.py +277 -0
  93. synth_ai/baseline/__init__.py +25 -0
  94. synth_ai/baseline/config.py +209 -0
  95. synth_ai/baseline/discovery.py +214 -0
  96. synth_ai/baseline/execution.py +146 -0
  97. synth_ai/cli/__init__.py +85 -17
  98. synth_ai/cli/__main__.py +0 -0
  99. synth_ai/cli/claude.py +70 -0
  100. synth_ai/cli/codex.py +84 -0
  101. synth_ai/cli/commands/__init__.py +1 -0
  102. synth_ai/cli/commands/baseline/__init__.py +12 -0
  103. synth_ai/cli/commands/baseline/core.py +637 -0
  104. synth_ai/cli/commands/baseline/list.py +93 -0
  105. synth_ai/cli/commands/eval/core.py +13 -10
  106. synth_ai/cli/commands/filter/core.py +53 -17
  107. synth_ai/cli/commands/help/core.py +0 -1
  108. synth_ai/cli/commands/smoke/__init__.py +7 -0
  109. synth_ai/cli/commands/smoke/core.py +1436 -0
  110. synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
  111. synth_ai/cli/commands/status/subcommands/usage.py +203 -0
  112. synth_ai/cli/commands/train/judge_schemas.py +1 -0
  113. synth_ai/cli/commands/train/judge_validation.py +1 -0
  114. synth_ai/cli/commands/train/validation.py +0 -57
  115. synth_ai/cli/demo.py +35 -3
  116. synth_ai/cli/deploy/__init__.py +40 -25
  117. synth_ai/cli/deploy.py +162 -0
  118. synth_ai/cli/legacy_root_backup.py +14 -8
  119. synth_ai/cli/opencode.py +107 -0
  120. synth_ai/cli/root.py +9 -5
  121. synth_ai/cli/task_app_deploy.py +1 -1
  122. synth_ai/cli/task_apps.py +53 -53
  123. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
  124. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
  125. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
  126. synth_ai/judge_schemas.py +1 -0
  127. synth_ai/learning/__init__.py +10 -0
  128. synth_ai/learning/prompt_learning_client.py +276 -0
  129. synth_ai/learning/prompt_learning_types.py +184 -0
  130. synth_ai/pricing/__init__.py +2 -0
  131. synth_ai/pricing/model_pricing.py +57 -0
  132. synth_ai/streaming/handlers.py +53 -4
  133. synth_ai/streaming/streamer.py +19 -0
  134. synth_ai/task/apps/__init__.py +1 -0
  135. synth_ai/task/config.py +2 -0
  136. synth_ai/task/tracing_utils.py +25 -25
  137. synth_ai/task/validators.py +44 -8
  138. synth_ai/task_app_cfgs.py +21 -0
  139. synth_ai/tracing_v3/config.py +162 -19
  140. synth_ai/tracing_v3/constants.py +1 -1
  141. synth_ai/tracing_v3/db_config.py +24 -38
  142. synth_ai/tracing_v3/storage/config.py +47 -13
  143. synth_ai/tracing_v3/storage/factory.py +3 -3
  144. synth_ai/tracing_v3/turso/daemon.py +113 -11
  145. synth_ai/tracing_v3/turso/native_manager.py +92 -16
  146. synth_ai/types.py +8 -0
  147. synth_ai/urls.py +11 -0
  148. synth_ai/utils/__init__.py +30 -1
  149. synth_ai/utils/agents.py +74 -0
  150. synth_ai/utils/bin.py +39 -0
  151. synth_ai/utils/cli.py +149 -5
  152. synth_ai/utils/env.py +17 -17
  153. synth_ai/utils/json.py +72 -0
  154. synth_ai/utils/modal.py +283 -1
  155. synth_ai/utils/paths.py +48 -0
  156. synth_ai/utils/uvicorn.py +113 -0
  157. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/METADATA +102 -4
  158. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/RECORD +162 -88
  159. synth_ai/cli/commands/deploy/__init__.py +0 -23
  160. synth_ai/cli/commands/deploy/core.py +0 -614
  161. synth_ai/cli/commands/deploy/errors.py +0 -72
  162. synth_ai/cli/commands/deploy/validation.py +0 -11
  163. synth_ai/cli/deploy/core.py +0 -5
  164. synth_ai/cli/deploy/errors.py +0 -23
  165. synth_ai/cli/deploy/validation.py +0 -5
  166. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
  167. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
  168. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
  169. {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,26 @@
1
+ [eval]
2
+ app_id = "pokemon_red"
3
+ task_app_url = "http://127.0.0.1:8914"
4
+ model = "gpt-5-nano"
5
+ seeds = [0] # Single seed for testing
6
+ max_turns = 10 # 10 LLM calls per episode to allow more progress
7
+ concurrency = 1 # Run 1 rollout
8
+ env_name = "pokemon_red"
9
+ policy_name = "pokemon_vl_qwen3_vl" # Reuse policy config, will override model
10
+ trace_format = "full"
11
+ return_trace = true
12
+
13
+ [eval.policy_config]
14
+ provider = "openai" # Use OpenAI API for gpt-5-nano
15
+ model = "gpt-5-nano"
16
+ inference_url = "https://api.openai.com/v1"
17
+ temperature = 0.7
18
+ top_p = 0.95
19
+ max_tokens = 512
20
+ use_vision = true
21
+ image_only_mode = false
22
+ max_llm_calls = 10
23
+
24
+ [eval.env_config.env_params]
25
+ max_steps_per_episode = 100 # Allow time to achieve milestones
26
+
@@ -1,10 +1,10 @@
1
1
  [eval]
2
2
  app_id = "pokemon_red"
3
- task_app_url = "https://synth-laboratories--pokemon-vl-qwen-xml-fastapi-app.modal.run"
4
- model = "Qwen/Qwen3-VL-8B-Instruct" # Vision-capable Qwen3-VL model
5
- seeds = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
6
- max_turns = 10
7
- concurrency = 2
3
+ task_app_url = "http://127.0.0.1:8914"
4
+ model = "Qwen/Qwen3-VL-30B-A3B-Thinking" # Larger thinking variant - needs more time to load
5
+ seeds = [10, 11] # 2 seeds for quick testing
6
+ max_turns = 10 # 10 LLM calls per episode to allow more progress
7
+ concurrency = 2 # Run 2 rollouts in parallel
8
8
  env_name = "pokemon_red"
9
9
  policy_name = "pokemon_vl_qwen3_vl"
10
10
  trace_format = "full"
@@ -12,14 +12,16 @@ return_trace = true
12
12
 
13
13
  [eval.policy_config]
14
14
  provider = "synth" # Use Synth internal API for vision models
15
- model = "Qwen/Qwen3-VL-8B-Instruct" # Vision-capable Qwen3-VL model
16
- inference_url = "http://localhost:8000/api/inference/v1/chat/completions"
17
- temperature = 1.0 # Higher temperature to encourage tool calling
15
+ model = "Qwen/Qwen3-VL-30B-A3B-Thinking" # Larger thinking variant - needs more time to load
16
+ inference_url = "https://synth-laboratories-dev--learning-v2-service-fastapi-app.modal.run/chat/completions"
17
+ temperature = 1.0 # Higher temperature to encourage exploration
18
18
  top_p = 0.95
19
- max_tokens = 4096
19
+ max_tokens = 2048 # Reduced to avoid token budget issues
20
20
  use_vision = true
21
21
  image_only_mode = false
22
22
  max_llm_calls = 10
23
+ thinking_mode = "think" # Enable thinking/reasoning mode
24
+ thinking_budget = 3072 # Increased token budget for reasoning
23
25
 
24
26
  [eval.env_config.env_params]
25
- max_steps_per_episode = 10
27
+ max_steps_per_episode = 100 # Increased from 10 to allow time to achieve milestones
@@ -27,6 +27,7 @@ supports_vision = true
27
27
  [rollout]
28
28
  max_turns = 10
29
29
  episodes_per_batch = 64
30
+ task_app_origin_rewards_only = true
30
31
 
31
32
  [evaluation]
32
33
  instances = 100
@@ -0,0 +1,239 @@
1
+ #!/usr/bin/env python3
2
+ """Extract images from pokemon_vl trace database or trace JSON file and save to images_gpt5 directory.
3
+
4
+ Usage:
5
+ # From trace database:
6
+ python extract_images.py --trace-db traces/v3/pokemon_vl_gpt5nano.db
7
+
8
+ # From trace JSON file:
9
+ python extract_images.py --trace-json trace.json
10
+ """
11
+
12
+ import argparse
13
+ import base64
14
+ import json
15
+ import sqlite3
16
+ from pathlib import Path
17
+ from typing import Any
18
+
19
+ from synth_ai.tracing_v3.trace_utils import load_session_trace
20
+
21
+
22
def extract_image_urls_from_content(content: Any) -> list[str]:
    """Collect inline ``data:image`` URLs from a chat message's content.

    Supported shapes:
      * a list of OpenAI-style parts, where an image part is either
        ``{"type": "image_url", "image_url": {"url": ...}}`` or the shorthand
        ``{"type": "image_url", "image_url": "<url>"}``;
      * ``{"type": "image", "image": "<url>"}`` parts;
      * a JSON-encoded string of either of the above.

    Only ``data:image...`` URLs are returned; remote (http) URLs, text parts and
    anything that is not a list (after optional JSON decoding) yield nothing.
    """
    if isinstance(content, str):
        # Content may have been stored as serialized JSON; plain text is not an error.
        try:
            parsed = json.loads(content)
        except (ValueError, RecursionError):
            return []
        return extract_image_urls_from_content(parsed)

    if not isinstance(content, list):
        return []

    urls: list[str] = []
    for part in content:
        if not isinstance(part, dict):
            continue
        part_type = part.get("type")
        if part_type == "image_url":
            ref = part.get("image_url")
            # The URL may be nested ({"url": ...}) or given directly as a string.
            url = ref.get("url") if isinstance(ref, dict) else ref
        elif part_type == "image":
            url = part.get("image")
        else:
            continue
        if isinstance(url, str) and url.startswith("data:image"):
            urls.append(url)
    return urls
46
+
47
+
48
+ def extract_state_info_from_message(message: dict[str, Any]) -> dict[str, Any]:
49
+ """Extract state info from message metadata or content."""
50
+ metadata = message.get("metadata", {})
51
+ state = {}
52
+
53
+ # Try to get state from metadata
54
+ if "system_state_before" in metadata:
55
+ state_before = metadata["system_state_before"]
56
+ if isinstance(state_before, dict):
57
+ obs = state_before.get("obs", {})
58
+ state.update({
59
+ "position": obs.get("position", "?"),
60
+ "map_id": obs.get("map_id", "?"),
61
+ "player_x": obs.get("player_x", "?"),
62
+ "player_y": obs.get("player_y", "?"),
63
+ "text_box_active": obs.get("text_box_active", False),
64
+ })
65
+
66
+ # Try to extract from content text
67
+ content = message.get("content", "")
68
+ if isinstance(content, str) and "position" in content:
69
+ try:
70
+ # Look for state summary in content
71
+ if "State summary:" in content:
72
+ parts = content.split("State summary:")
73
+ if len(parts) > 1:
74
+ import ast
75
+ state_str = parts[1].split("'")[0] if "'" not in parts[1] else parts[1]
76
+ try:
77
+ state_dict = ast.literal_eval(state_str.split("'")[0] if "'" in state_str else state_str)
78
+ if isinstance(state_dict, dict):
79
+ state.update({
80
+ "position": state_dict.get("position", "?"),
81
+ "map_id": state_dict.get("map_id", "?"),
82
+ "player_x": state_dict.get("player_x", "?"),
83
+ "player_y": state_dict.get("player_y", "?"),
84
+ "text_box_active": state_dict.get("text_box_active", False),
85
+ })
86
+ except:
87
+ pass
88
+ except:
89
+ pass
90
+
91
+ return state
92
+
93
+
94
def _image_extension(header: str) -> str:
    """Map a data-URL header such as ``data:image/jpeg;base64`` to a file extension."""
    subtype = header.partition("/")[2].partition(";")[0].lower()
    ext = {"jpeg": "jpg", "svg+xml": "svg"}.get(subtype, subtype)
    # Fall back to png for an absent or unusual subtype so the filename stays simple.
    return ext if ext.isalnum() else "png"


def extract_images_from_trace_dict(trace: dict[str, Any], output_dir: Path):
    """Decode every inline base64 image in *trace*'s messages into *output_dir*.

    Messages come from ``markov_blanket_message_history`` (falling back to
    ``messages``). Each image is written as
    ``step_<NNN>_pos_<map>_<x>,<y>_textbox_<bool>.<ext>``, where ``NNN`` counts
    images saved from this trace and ``ext`` follows the data URL's MIME subtype.
    Images that fail to decode or write are reported and skipped.

    Returns:
        The number of images written.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    messages = trace.get("markov_blanket_message_history", []) or trace.get("messages", [])
    if not messages:
        print(" No messages found in trace")
        return 0

    print(f" Found {len(messages)} messages")

    image_count = 0
    for msg in messages:
        if not isinstance(msg, dict):
            continue
        image_urls = extract_image_urls_from_content(msg.get("content", ""))
        if not image_urls:
            continue

        # The same state applies to every image in this message.
        state = extract_state_info_from_message(msg)
        pos_str = f"{state.get('map_id', '?')}_{state.get('player_x', '?')},{state.get('player_y', '?')}"
        textbox_str = "True" if state.get("text_box_active") else "False"

        for img_url in image_urls:
            # Format: data:image/<subtype>;base64,<data>
            header, sep, b64_data = img_url.partition(",")
            if not sep or ";base64" not in header:
                # Only base64 payloads are supported (not percent-encoded data URLs).
                continue
            try:
                img_data = base64.b64decode(b64_data)
            except ValueError as e:  # binascii.Error subclasses ValueError
                print(f" Error decoding image: {e}")
                continue

            filename = f"step_{image_count:03d}_pos_{pos_str}_textbox_{textbox_str}.{_image_extension(header)}"
            try:
                (output_dir / filename).write_bytes(img_data)
            except OSError as e:
                print(f" Error writing image: {e}")
                continue

            print(f" Saved: {filename}")
            image_count += 1

    return image_count
148
+
149
+
150
+ def extract_images_from_trace_db(trace_db: str, output_dir: Path, model_filter: str | None = None):
151
+ """Extract images from trace database and save to output directory."""
152
+ conn = sqlite3.connect(trace_db)
153
+ conn.row_factory = sqlite3.Row
154
+
155
+ # Get all session IDs
156
+ query = "SELECT session_id, metadata FROM session_traces"
157
+ if model_filter:
158
+ query += " WHERE metadata LIKE ?"
159
+ params = (f'%{model_filter}%',)
160
+ else:
161
+ params = ()
162
+
163
+ rows = conn.execute(query, params).fetchall()
164
+
165
+ if not rows:
166
+ print(f"No traces found in {trace_db}")
167
+ return
168
+
169
+ print(f"Found {len(rows)} trace(s)")
170
+
171
+ total_images = 0
172
+ for row in rows:
173
+ session_id = row["session_id"]
174
+ print(f"\nProcessing session: {session_id}")
175
+
176
+ try:
177
+ trace = load_session_trace(conn, session_id)
178
+ except Exception as e:
179
+ print(f" Error loading trace: {e}")
180
+ continue
181
+
182
+ count = extract_images_from_trace_dict(trace, output_dir)
183
+ total_images += count
184
+
185
+ conn.close()
186
+ print(f"\n✓ Extracted {total_images} images to {output_dir}/")
187
+
188
+
189
def extract_images_from_trace_json(trace_json: Path, output_dir: Path):
    """Extract images from a trace JSON file into *output_dir*.

    The file may contain the trace object directly or wrapped under a
    ``"session_trace"`` key.

    Raises:
        ValueError: if the (unwrapped) trace is not a JSON object.
    """
    print(f"Loading trace from {trace_json}")

    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(trace_json, encoding="utf-8") as f:
        trace = json.load(f)

    if isinstance(trace, dict) and "session_trace" in trace:
        trace = trace["session_trace"]

    if not isinstance(trace, dict):
        raise ValueError(f"Expected a JSON object in {trace_json}, got {type(trace).__name__}")

    count = extract_images_from_trace_dict(trace, output_dir)
    print(f"\n✓ Extracted {count} images to {output_dir}/")
202
+
203
+
204
def main():
    """Command-line entry point: extract images from exactly one trace source."""
    parser = argparse.ArgumentParser(description=__doc__)
    # Exactly one input source; previously passing both silently ignored --trace-db.
    source = parser.add_mutually_exclusive_group(required=True)
    source.add_argument(
        "--trace-db",
        help="Path to trace database",
    )
    source.add_argument(
        "--trace-json",
        type=Path,
        help="Path to trace JSON file",
    )
    parser.add_argument(
        "--output-dir",
        default="examples/blog_posts/pokemon_vl/images_gpt5",
        help="Output directory for images",
    )
    parser.add_argument(
        "--model-filter",
        help="Filter traces by model name (optional)",
    )
    args = parser.parse_args()

    output_dir = Path(args.output_dir)

    if args.trace_json:
        extract_images_from_trace_json(args.trace_json, output_dir)
    elif args.trace_db:
        extract_images_from_trace_db(args.trace_db, output_dir, args.model_filter)
    else:
        # Reached only when --trace-db was given an empty string.
        parser.error("Must provide either --trace-db or --trace-json")


if __name__ == "__main__":
    main()
238
+
239
+
@@ -0,0 +1,326 @@
1
+ """Pokemon Red baseline file for Game Boy emulation evaluation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ from synth_ai.baseline import BaselineConfig, BaselineTaskRunner, DataSplit, TaskResult
8
+ from synth_ai.inference import InferenceClient
9
+ import os
10
+ import httpx
11
+
12
+ try:
13
+ from synth_ai.environments.examples.red.environment import PokemonRedEnvironment
14
+ from synth_ai.environments.examples.red.taskset import (
15
+ PokemonRedTaskInstance,
16
+ PokemonRedTaskInstanceMetadata,
17
+ )
18
+ POKEMON_RED_AVAILABLE = True
19
+ except ImportError:
20
+ POKEMON_RED_AVAILABLE = False
21
+
22
+
23
class PokemonRedTaskRunner(BaselineTaskRunner):
    """Task runner for Pokemon Red Game Boy emulation.

    Each LLM turn receives a text summary of the emulator state (plus the
    screenshot when the observation carries ``observation_image_base64``) and is
    forced to call the ``execute_sequence`` tool. The returned button presses are
    executed one at a time, each counting as one step against ``max_steps``.
    """

    # Used when no ``http`` inference_url is configured:
    # provider name -> (OpenAI-compatible base URL, env var holding that provider's key).
    _PROVIDER_ENDPOINTS: Dict[str, tuple[str, str]] = {
        "openai": ("https://api.openai.com/v1", "OPENAI_API_KEY"),
        "groq": ("https://api.groq.com/openai/v1", "GROQ_API_KEY"),
    }

    # Force the model to answer through the execute_sequence tool.
    _TOOL_CHOICE: Dict[str, Any] = {"type": "function", "function": {"name": "execute_sequence"}}

    def __init__(self, policy_config: Dict[str, Any], env_config: Dict[str, Any]):
        super().__init__(policy_config, env_config)

        if not POKEMON_RED_AVAILABLE:
            raise ImportError(
                "Pokemon Red environment not available. "
                "Install synth-ai with Pokemon Red support."
            )

        # Store config for inference
        self.model = policy_config["model"]
        self.temperature = policy_config.get("temperature", 0.0)
        self.max_tokens = policy_config.get("max_tokens", 512)
        self.inference_url = policy_config.get("inference_url")
        # httpx's default 5 s timeout can be too short for LLM completions;
        # an optional ``timeout`` (seconds) in the policy config overrides this.
        self.request_timeout = float(policy_config.get("timeout", 120.0))

        # Tool definition
        self.tools = [{
            "type": "function",
            "function": {
                "name": "execute_sequence",
                "description": "Execute multiple button presses in sequence",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "actions": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "button": {
                                        "type": "string",
                                        "enum": ["UP", "DOWN", "LEFT", "RIGHT", "A", "B", "START", "SELECT"],
                                    },
                                    "frames": {
                                        "type": "integer",
                                        "minimum": 1,
                                        "maximum": 120,
                                        "description": "Frames to hold button (60fps)",
                                    },
                                },
                                "required": ["button", "frames"],
                            },
                            "minItems": 1,
                            "maxItems": 20,
                        },
                    },
                    "required": ["actions"],
                },
            },
        }]

    def _format_observation(self, obs: Dict[str, Any], step: int, max_steps: int) -> str:
        """Format observation for LLM."""
        lines = [
            f"Pokemon Red - Step {step}/{max_steps}",
            "",
        ]

        # Position
        if "map_id" in obs:
            lines.append(f"Location: Map {obs['map_id']}")
        if "player_x" in obs and "player_y" in obs:
            lines.append(f"Position: ({obs['player_x']}, {obs['player_y']})")

        # Party
        if "party_count" in obs:
            lines.append(f"Party Size: {obs['party_count']}")
        if "party_pokemon" in obs and obs["party_pokemon"]:
            pokemon = obs["party_pokemon"][0]
            lines.append(
                f"First Pokemon: Level {pokemon.get('level', '?')}, "
                f"HP {pokemon.get('hp_current', '?')}/{pokemon.get('hp_max', '?')}"
            )

        # Battle
        if obs.get("in_battle"):
            lines.append("=== IN BATTLE ===")
            if "enemy_hp_current" in obs:
                lines.append(
                    f"Enemy HP: {obs['enemy_hp_current']}/{obs.get('enemy_hp_max', '?')}"
                )
            if "battle_turn" in obs:
                lines.append(f"Battle Turn: {obs['battle_turn']}")

        # Progress
        if "badges" in obs:
            lines.append(f"Badges: {obs['badges']}")
        if "money" in obs:
            lines.append(f"Money: ${obs['money']}")

        # Dialogue
        if obs.get("text_box_active"):
            lines.append("Text box is active - press A to advance dialogue")

        lines.append("")
        lines.append("What actions should we take?")

        return "\n".join(lines)

    def _build_messages(self, obs: Dict[str, Any], step: int, max_steps: int) -> List[Dict[str, Any]]:
        """Build the single user message for this turn, attaching the screenshot if present."""
        prompt = self._format_observation(obs, step, max_steps)
        image_b64 = obs.get("observation_image_base64")
        if not image_b64:
            return [{"role": "user", "content": prompt}]
        return [{
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                },
                {"type": "text", "text": prompt},
            ],
        }]

    def _synth_base_url(self) -> str:
        """Return ``inference_url`` without a trailing slash, appending ``/api`` if absent."""
        base_url = self.inference_url.rstrip("/")
        if "/api" not in base_url:
            base_url = f"{base_url}/api"
        return base_url

    def _resolve_direct_endpoint(self) -> tuple[str, str, str]:
        """Return ``(base_url, api_key, model_name)`` for a direct provider request.

        A ``provider:model`` prefix (e.g. ``groq:llama-3.1-70b-versatile``) picks the
        provider explicitly and is stripped before sending. Without a recognised
        prefix, names containing "openai" go to OpenAI and everything else to Groq.
        Each provider only ever receives its own API key.
        """
        prefix, sep, bare_model = self.model.partition(":")
        if sep and prefix.lower() in self._PROVIDER_ENDPOINTS:
            provider, model_name = prefix.lower(), bare_model
        else:
            provider = "openai" if "openai" in self.model.lower() else "groq"
            model_name = self.model
        base_url, key_env = self._PROVIDER_ENDPOINTS[provider]
        return base_url, os.getenv(key_env) or "", model_name

    async def _request_completion(self, messages: List[Dict[str, Any]], synth_client: Any | None) -> Any:
        """Request one tool-forced chat completion via Synth or a direct provider.

        Raises:
            httpx.HTTPStatusError: if a direct provider responds with an error status.
        """
        if synth_client is not None:
            return await synth_client.create_chat_completion(
                model=self.model,
                messages=messages,
                tools=self.tools,
                tool_choice=self._TOOL_CHOICE,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )

        base_url, api_key, model_name = self._resolve_direct_endpoint()
        async with httpx.AsyncClient(timeout=self.request_timeout) as http_client:
            resp = await http_client.post(
                f"{base_url}/chat/completions",
                json={
                    "model": model_name,
                    "messages": messages,
                    "tools": self.tools,
                    "tool_choice": self._TOOL_CHOICE,
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                },
                headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
            )
            # Surface API errors instead of treating them as "no actions".
            resp.raise_for_status()
            return resp.json()

    @staticmethod
    def _extract_actions(response: Any) -> List[Dict[str, Any]]:
        """Pull the ``actions`` list out of the first tool call in *response*.

        OpenAI-compatible APIs return tool-call ``arguments`` as a JSON string, so
        strings are decoded; already-decoded dicts are accepted too. Returns ``[]``
        when there is no usable tool call.
        """
        import json

        if not isinstance(response, dict):
            return []

        tool_calls: Any = []
        choices = response.get("choices")
        if choices:
            message = choices[0].get("message") or {}
            tool_calls = message.get("tool_calls") or []
        elif "tool_calls" in response:
            tool_calls = response["tool_calls"] or []
        if not tool_calls:
            return []

        arguments = (tool_calls[0].get("function") or {}).get("arguments")
        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments)
            except ValueError:
                return []
        if not isinstance(arguments, dict):
            return []
        actions = arguments.get("actions")
        return actions if isinstance(actions, list) else []

    @staticmethod
    def _as_obs_dict(result: Any) -> Dict[str, Any]:
        """Return ``result.observation`` (or *result* itself) when it is a dict, else ``{}``."""
        observation = getattr(result, "observation", result)
        return observation if isinstance(observation, dict) else {}

    @staticmethod
    def _read_field(result: Any, name: str, default: Any) -> Any:
        """Read *name* from a step result exposed either as attributes or as a dict.

        NOTE(review): the dict form is a fallback in case ``env.step`` returns a
        plain observation dict; confirm which shape PokemonRedEnvironment uses.
        """
        if isinstance(result, dict):
            return result.get(name, default)
        return getattr(result, name, default)

    async def run_task(self, seed: int) -> TaskResult:
        """Run a single Pokemon Red episode.

        The loop ends when the model returns no actions, the environment reports
        termination/truncation, a battle is lost (``battle_outcome == 2``), or
        ``max_steps`` button presses have been executed.

        Raises:
            ValueError: if ``rom_path`` is missing from the env config.
        """
        # Imported once per episode rather than once per executed action.
        from synth_ai.environments.environment.tools import EnvToolCall

        rom_path = self.env_config.get("rom_path")
        if not rom_path:
            raise ValueError("rom_path required in env_config for Pokemon Red")

        max_steps = self.env_config.get("max_steps", 500)

        metadata = PokemonRedTaskInstanceMetadata(
            seed=seed,
            rom_path=rom_path,
            init_state_path=self.env_config.get("init_state_path"),
            reward_type=self.env_config.get("reward_type", "pallet_town_progression"),
        )
        task_instance = PokemonRedTaskInstance(
            id=f"pokemon-red-{seed}",
            metadata=metadata,
        )

        # One client per episode instead of one per LLM call.
        synth_client = None
        if self.inference_url and self.inference_url.startswith("http"):
            api_key = os.getenv("SYNTH_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
            synth_client = InferenceClient(base_url=self._synth_base_url(), api_key=api_key)

        env = PokemonRedEnvironment(task_instance=task_instance)
        obs_dict = self._as_obs_dict(await env.initialize())

        total_reward = 0.0
        total_steps = 0
        event_rewards: List[Dict[str, Any]] = []
        battle_won = False
        game_over = False

        try:
            for step in range(max_steps):
                messages = self._build_messages(obs_dict, step, max_steps)
                response = await self._request_completion(messages, synth_client)
                actions = self._extract_actions(response)
                if not actions:
                    break

                for action_spec in actions:
                    if total_steps >= max_steps:
                        break

                    # Execute each button press as its own environment step.
                    step_result = await env.step([
                        EnvToolCall(name="execute_sequence", arguments={"actions": [action_spec]})
                    ])
                    total_steps += 1
                    obs_dict = self._as_obs_dict(step_result)

                    reward = float(self._read_field(step_result, "reward", 0.0) or 0.0)
                    total_reward += reward
                    if reward > 0:
                        event_rewards.append({"step": total_steps, "reward": reward})

                    if self._read_field(step_result, "terminated", False) or self._read_field(
                        step_result, "truncated", False
                    ):
                        game_over = True
                        break

                    outcome = obs_dict.get("battle_outcome")
                    if outcome == 1:
                        battle_won = True
                    elif outcome == 2:
                        game_over = True

                # ``range(max_steps)`` only bounds LLM turns; also stop once the
                # button-press budget is spent so no further LLM calls are wasted.
                if game_over or total_steps >= max_steps:
                    break
        finally:
            await env.terminate()

        return TaskResult(
            seed=seed,
            success=True,
            outcome_reward=total_reward,
            event_rewards=event_rewards,
            total_steps=total_steps,
            metadata={
                "battle_won": battle_won,
                "game_over": game_over,
                "final_map": obs_dict.get("map_id"),
                "badges": obs_dict.get("badges", 0),
                "party_size": obs_dict.get("party_count", 0),
            },
        )
295
+
296
+
297
+ # Define baseline config (only if Pokemon Red is available)
298
if POKEMON_RED_AVAILABLE:
    # Register the baseline only when the emulator environment imported cleanly.
    pokemon_vl_baseline = BaselineConfig(
        baseline_id="pokemon_vl",
        name="Pokemon VL - Pokemon Red",
        description="Pokemon Red Game Boy emulation baseline for vision-language agents",
        task_runner=PokemonRedTaskRunner,
        # Seeds 0-19 train, 20-24 validation, 25-29 test.
        splits={
            split_name: DataSplit(name=split_name, seeds=list(range(first_seed, stop_seed)))
            for split_name, first_seed, stop_seed in (
                ("train", 0, 20),
                ("val", 20, 25),
                ("test", 25, 30),
            )
        },
        # NOTE(review): confirm this Groq model id is still served.
        default_policy_config=dict(
            model="groq:llama-3.1-70b-versatile",
            temperature=0.0,
            max_tokens=512,
        ),
        default_env_config=dict(
            rom_path=None,  # required: the runner raises ValueError without it
            init_state_path=None,  # optional
            reward_type="pallet_town_progression",
            max_steps=500,
        ),
        metadata=dict(
            environment="pokemon_red",
            task_type="emulation",
            requires_rom=True,
        ),
    )
326
+