synth-ai 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/baseline/banking77_baseline.py +204 -0
- examples/baseline/crafter_baseline.py +407 -0
- examples/baseline/pokemon_red_baseline.py +326 -0
- examples/baseline/simple_baseline.py +56 -0
- examples/baseline/warming_up_to_rl_baseline.py +239 -0
- examples/blog_posts/gepa/README.md +355 -0
- examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
- examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
- examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
- examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
- examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
- examples/blog_posts/gepa/gepa_baseline.py +204 -0
- examples/blog_posts/gepa/query_prompts_example.py +97 -0
- examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
- examples/blog_posts/gepa/task_apps.py +105 -0
- examples/blog_posts/gepa/test_gepa_local.sh +67 -0
- examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
- examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +12 -10
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +1 -0
- examples/blog_posts/pokemon_vl/extract_images.py +239 -0
- examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
- examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
- examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
- examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
- examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
- examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
- examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +1 -1
- examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +60 -10
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +1 -1
- examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -0
- examples/rl/configs/rl_from_base_qwen17.toml +1 -0
- examples/swe/task_app/hosted/inference/openai_client.py +0 -34
- examples/swe/task_app/hosted/policy_routes.py +17 -0
- examples/swe/task_app/hosted/rollout.py +4 -2
- examples/task_apps/banking77/__init__.py +6 -0
- examples/task_apps/banking77/banking77_task_app.py +841 -0
- examples/task_apps/banking77/deploy_wrapper.py +46 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +24 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +355 -58
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +68 -7
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +78 -21
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
- examples/task_apps/gepa_benchmarks/__init__.py +7 -0
- examples/task_apps/gepa_benchmarks/common.py +260 -0
- examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
- examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
- examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
- examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
- examples/task_apps/pokemon_red/task_app.py +254 -36
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +1 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +53 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +152 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +31 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +1 -0
- synth_ai/api/train/builders.py +90 -1
- synth_ai/api/train/cli.py +396 -21
- synth_ai/api/train/config_finder.py +13 -2
- synth_ai/api/train/configs/__init__.py +15 -1
- synth_ai/api/train/configs/prompt_learning.py +442 -0
- synth_ai/api/train/configs/rl.py +29 -0
- synth_ai/api/train/task_app.py +1 -1
- synth_ai/api/train/validators.py +277 -0
- synth_ai/baseline/__init__.py +25 -0
- synth_ai/baseline/config.py +209 -0
- synth_ai/baseline/discovery.py +214 -0
- synth_ai/baseline/execution.py +146 -0
- synth_ai/cli/__init__.py +85 -17
- synth_ai/cli/__main__.py +0 -0
- synth_ai/cli/claude.py +70 -0
- synth_ai/cli/codex.py +84 -0
- synth_ai/cli/commands/__init__.py +1 -0
- synth_ai/cli/commands/baseline/__init__.py +12 -0
- synth_ai/cli/commands/baseline/core.py +637 -0
- synth_ai/cli/commands/baseline/list.py +93 -0
- synth_ai/cli/commands/eval/core.py +13 -10
- synth_ai/cli/commands/filter/core.py +53 -17
- synth_ai/cli/commands/help/core.py +0 -1
- synth_ai/cli/commands/smoke/__init__.py +7 -0
- synth_ai/cli/commands/smoke/core.py +1436 -0
- synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
- synth_ai/cli/commands/status/subcommands/usage.py +203 -0
- synth_ai/cli/commands/train/judge_schemas.py +1 -0
- synth_ai/cli/commands/train/judge_validation.py +1 -0
- synth_ai/cli/commands/train/validation.py +0 -57
- synth_ai/cli/demo.py +35 -3
- synth_ai/cli/deploy/__init__.py +40 -25
- synth_ai/cli/deploy.py +162 -0
- synth_ai/cli/legacy_root_backup.py +14 -8
- synth_ai/cli/opencode.py +107 -0
- synth_ai/cli/root.py +9 -5
- synth_ai/cli/task_app_deploy.py +1 -1
- synth_ai/cli/task_apps.py +53 -53
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
- synth_ai/judge_schemas.py +1 -0
- synth_ai/learning/__init__.py +10 -0
- synth_ai/learning/prompt_learning_client.py +276 -0
- synth_ai/learning/prompt_learning_types.py +184 -0
- synth_ai/pricing/__init__.py +2 -0
- synth_ai/pricing/model_pricing.py +57 -0
- synth_ai/streaming/handlers.py +53 -4
- synth_ai/streaming/streamer.py +19 -0
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +2 -0
- synth_ai/task/tracing_utils.py +25 -25
- synth_ai/task/validators.py +44 -8
- synth_ai/task_app_cfgs.py +21 -0
- synth_ai/tracing_v3/config.py +162 -19
- synth_ai/tracing_v3/constants.py +1 -1
- synth_ai/tracing_v3/db_config.py +24 -38
- synth_ai/tracing_v3/storage/config.py +47 -13
- synth_ai/tracing_v3/storage/factory.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +113 -11
- synth_ai/tracing_v3/turso/native_manager.py +92 -16
- synth_ai/types.py +8 -0
- synth_ai/urls.py +11 -0
- synth_ai/utils/__init__.py +30 -1
- synth_ai/utils/agents.py +74 -0
- synth_ai/utils/bin.py +39 -0
- synth_ai/utils/cli.py +149 -5
- synth_ai/utils/env.py +17 -17
- synth_ai/utils/json.py +72 -0
- synth_ai/utils/modal.py +283 -1
- synth_ai/utils/paths.py +48 -0
- synth_ai/utils/uvicorn.py +113 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/METADATA +102 -4
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/RECORD +162 -88
- synth_ai/cli/commands/deploy/__init__.py +0 -23
- synth_ai/cli/commands/deploy/core.py +0 -614
- synth_ai/cli/commands/deploy/errors.py +0 -72
- synth_ai/cli/commands/deploy/validation.py +0 -11
- synth_ai/cli/deploy/core.py +0 -5
- synth_ai/cli/deploy/errors.py +0 -23
- synth_ai/cli/deploy/validation.py +0 -5
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,326 @@
|
|
|
1
|
+
"""Pokemon Red baseline file for Game Boy emulation evaluation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
from synth_ai.baseline import BaselineConfig, BaselineTaskRunner, DataSplit, TaskResult
|
|
8
|
+
from synth_ai.inference import InferenceClient
|
|
9
|
+
import os
|
|
10
|
+
import httpx
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
from synth_ai.environments.examples.red.environment import PokemonRedEnvironment
|
|
14
|
+
from synth_ai.environments.examples.red.taskset import (
|
|
15
|
+
PokemonRedTaskInstance,
|
|
16
|
+
PokemonRedTaskInstanceMetadata,
|
|
17
|
+
)
|
|
18
|
+
POKEMON_RED_AVAILABLE = True
|
|
19
|
+
except ImportError:
|
|
20
|
+
POKEMON_RED_AVAILABLE = False
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class PokemonRedTaskRunner(BaselineTaskRunner):
    """Task runner for Pokemon Red Game Boy emulation.

    Each episode formats the emulator observation (text summary plus an
    optional base64 screenshot) into a chat prompt, asks the model for an
    ``execute_sequence`` tool call, and steps the environment one button
    press at a time until termination or the step budget is exhausted.
    """

    def __init__(self, policy_config: Dict[str, Any], env_config: Dict[str, Any]):
        super().__init__(policy_config, env_config)

        if not POKEMON_RED_AVAILABLE:
            raise ImportError(
                "Pokemon Red environment not available. "
                "Install synth-ai with Pokemon Red support."
            )

        # Inference settings: model id, sampling params, optional Synth endpoint.
        self.model = policy_config["model"]
        self.temperature = policy_config.get("temperature", 0.0)
        self.max_tokens = policy_config.get("max_tokens", 512)
        self.inference_url = policy_config.get("inference_url")

        # Single tool: a bounded sequence of button holds (frame counts at 60fps).
        self.tools = [{
            "type": "function",
            "function": {
                "name": "execute_sequence",
                "description": "Execute multiple button presses in sequence",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "actions": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "button": {
                                        "type": "string",
                                        "enum": ["UP", "DOWN", "LEFT", "RIGHT", "A", "B", "START", "SELECT"],
                                    },
                                    "frames": {
                                        "type": "integer",
                                        "minimum": 1,
                                        "maximum": 120,
                                        "description": "Frames to hold button (60fps)",
                                    },
                                },
                                "required": ["button", "frames"],
                            },
                            "minItems": 1,
                            "maxItems": 20,
                        },
                    },
                    "required": ["actions"],
                },
            },
        }]

    def _format_observation(self, obs: Dict[str, Any], step: int, max_steps: int) -> str:
        """Render the observation dict as a text prompt for the LLM.

        Only keys present in ``obs`` are included; missing keys are simply
        omitted so the prompt degrades gracefully on partial observations.
        """
        lines = [
            f"Pokemon Red - Step {step}/{max_steps}",
            "",
        ]

        # Position
        if "map_id" in obs:
            lines.append(f"Location: Map {obs['map_id']}")
        if "player_x" in obs and "player_y" in obs:
            lines.append(f"Position: ({obs['player_x']}, {obs['player_y']})")

        # Party
        if "party_count" in obs:
            lines.append(f"Party Size: {obs['party_count']}")
        if "party_pokemon" in obs and obs["party_pokemon"]:
            pokemon = obs["party_pokemon"][0]
            lines.append(
                f"First Pokemon: Level {pokemon.get('level', '?')}, "
                f"HP {pokemon.get('hp_current', '?')}/{pokemon.get('hp_max', '?')}"
            )

        # Battle
        if obs.get("in_battle"):
            lines.append("=== IN BATTLE ===")
            if "enemy_hp_current" in obs:
                lines.append(
                    f"Enemy HP: {obs['enemy_hp_current']}/{obs.get('enemy_hp_max', '?')}"
                )
            if "battle_turn" in obs:
                lines.append(f"Battle Turn: {obs['battle_turn']}")

        # Progress
        if "badges" in obs:
            lines.append(f"Badges: {obs['badges']}")
        if "money" in obs:
            lines.append(f"Money: ${obs['money']}")

        # Dialogue
        if obs.get("text_box_active"):
            lines.append("Text box is active - press A to advance dialogue")

        lines.append("")
        lines.append("What actions should we take?")

        return "\n".join(lines)

    async def _call_llm(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Request an ``execute_sequence`` tool call from the configured backend.

        Uses the Synth InferenceClient when ``inference_url`` looks like an
        HTTP endpoint; otherwise falls back to a raw OpenAI/Groq-compatible
        chat-completions POST.
        """
        if self.inference_url and self.inference_url.startswith("http"):
            api_key = os.getenv("SYNTH_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
            base_url = self.inference_url.rstrip("/")
            if not base_url.endswith("/api"):
                # Append /api only when it is not already somewhere in the path.
                base_url = f"{base_url}/api" if "/api" not in base_url else base_url
            client = InferenceClient(base_url=base_url, api_key=api_key)
            return await client.create_chat_completion(
                model=self.model,
                messages=messages,
                tools=self.tools,
                tool_choice={"type": "function", "function": {"name": "execute_sequence"}},
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )

        api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GROQ_API_KEY") or ""
        base_url = "https://api.openai.com/v1" if "openai" in self.model.lower() else "https://api.groq.com/openai/v1"
        async with httpx.AsyncClient() as http_client:
            resp = await http_client.post(
                f"{base_url}/chat/completions",
                json={
                    "model": self.model,
                    "messages": messages,
                    "tools": self.tools,
                    "tool_choice": {"type": "function", "function": {"name": "execute_sequence"}},
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                },
                headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
            )
            return resp.json()

    @staticmethod
    def _extract_actions(response: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Pull the ``actions`` list out of the first tool call, if any.

        Handles both chat-completion shaped responses (``choices[0].message``)
        and flat ``tool_calls`` payloads. ``arguments`` may arrive as a
        JSON-encoded string (the OpenAI wire format) or as an already-parsed
        dict; both are accepted. Returns an empty list on any malformed
        payload instead of raising mid-episode.
        """
        import json

        tool_calls: List[Dict[str, Any]] = []
        if response.get("choices"):
            message = response["choices"][0].get("message", {})
            tool_calls = message.get("tool_calls", []) or []
        elif "tool_calls" in response:
            tool_calls = response["tool_calls"] or []

        if not tool_calls:
            return []

        args = tool_calls[0].get("function", {}).get("arguments", {})
        if isinstance(args, str):
            # OpenAI-compatible APIs serialize tool arguments as JSON text.
            try:
                args = json.loads(args)
            except (ValueError, TypeError):
                return []
        if not isinstance(args, dict):
            return []
        actions = args.get("actions", [])
        return actions if isinstance(actions, list) else []

    async def run_task(self, seed: int) -> TaskResult:
        """Run a single Pokemon Red episode and return its aggregate result.

        Raises:
            ValueError: if ``rom_path`` is missing from ``env_config``.
        """
        from synth_ai.environments.environment.tools import EnvToolCall

        # Create task instance
        rom_path = self.env_config.get("rom_path")
        if not rom_path:
            raise ValueError("rom_path required in env_config for Pokemon Red")

        init_state_path = self.env_config.get("init_state_path")
        max_steps = self.env_config.get("max_steps", 500)

        metadata = PokemonRedTaskInstanceMetadata(
            seed=seed,
            rom_path=rom_path,
            init_state_path=init_state_path,
            reward_type=self.env_config.get("reward_type", "pallet_town_progression"),
        )

        task_instance = PokemonRedTaskInstance(
            id=f"pokemon-red-{seed}",
            metadata=metadata,
        )

        # Create environment
        env = PokemonRedEnvironment(task_instance=task_instance)

        # Initialize environment; unwrap an .observation attribute if present.
        raw_obs = await env.initialize()
        observation = getattr(raw_obs, "observation", raw_obs)
        obs_dict = observation if isinstance(observation, dict) else {}

        # Episode accumulators
        total_reward = 0.0
        total_steps = 0
        event_rewards: List[Dict[str, Any]] = []
        battle_won = False
        game_over = False

        try:
            for step in range(max_steps):
                prompt = self._format_observation(obs_dict, step, max_steps)

                # Attach the current frame as an image part when available.
                messages: List[Dict[str, Any]] = [{"role": "user", "content": prompt}]
                if obs_dict.get("observation_image_base64"):
                    messages[0]["content"] = [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{obs_dict['observation_image_base64']}"
                            },
                        },
                        {"type": "text", "text": prompt},
                    ]

                response = await self._call_llm(messages)
                actions = self._extract_actions(response)
                if not actions:
                    # Model produced no usable tool call: stop the episode
                    # rather than spinning on empty responses.
                    break

                # Execute each button press as its own environment step.
                for action_spec in actions:
                    if total_steps >= max_steps:
                        break

                    tool_call = EnvToolCall(
                        name="execute_sequence",
                        arguments={"actions": [action_spec]},
                    )
                    step_result = await env.step([tool_call])
                    total_steps += 1

                    step_obs = getattr(step_result, "observation", step_result)
                    obs_dict = step_obs if isinstance(step_obs, dict) else {}

                    # Some step results report a null reward; treat it as 0.
                    reward = getattr(step_result, "reward", 0.0) or 0.0
                    total_reward += reward
                    if reward > 0:
                        event_rewards.append({
                            "step": total_steps,
                            "reward": reward,
                        })

                    if getattr(step_result, "terminated", False) or getattr(step_result, "truncated", False):
                        game_over = True
                        break

                    # battle_outcome: 1 is recorded as a win, 2 ends the
                    # episode (presumably a loss/blackout — per env convention).
                    if obs_dict.get("battle_outcome") == 1:
                        battle_won = True
                    elif obs_dict.get("battle_outcome") == 2:
                        game_over = True

                if game_over:
                    break
        finally:
            # Always release the emulator, even if inference or stepping raised.
            await env.terminate()

        return TaskResult(
            seed=seed,
            success=True,
            outcome_reward=total_reward,
            event_rewards=event_rewards,
            total_steps=total_steps,
            metadata={
                "battle_won": battle_won,
                "game_over": game_over,
                "final_map": obs_dict.get("map_id"),
                "badges": obs_dict.get("badges", 0),
                "party_size": obs_dict.get("party_count", 0),
            },
        )
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
# Register the Pokemon Red baseline only when the optional emulator
# dependencies imported successfully at module load.
if POKEMON_RED_AVAILABLE:
    pokemon_red_baseline = BaselineConfig(
        baseline_id="pokemon_red",
        name="Pokemon Red",
        description="Pokemon Red Game Boy emulation with PyBoy",
        task_runner=PokemonRedTaskRunner,
        # Seed ranges: 0-19 train, 20-24 val, 25-29 test.
        splits={
            split_name: DataSplit(name=split_name, seeds=seed_list)
            for split_name, seed_list in (
                ("train", list(range(20))),
                ("val", list(range(20, 25))),
                ("test", list(range(25, 30))),
            )
        },
        default_policy_config={
            "model": "groq:llama-3.1-70b-versatile",
            "temperature": 0.0,
            "max_tokens": 512,
        },
        default_env_config={
            "rom_path": None,  # Must be provided
            "init_state_path": None,  # Optional
            "reward_type": "pallet_town_progression",
            "max_steps": 500,
        },
        metadata={
            "environment": "pokemon_red",
            "task_type": "emulation",
            "requires_rom": True,
        },
    )
|
|
326
|
+
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Simple example baseline file for testing."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from synth_ai.baseline import BaselineConfig, BaselineTaskRunner, DataSplit, TaskResult
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SimpleTaskRunner(BaselineTaskRunner):
    """Trivial runner that reports success for every seed (test fixture)."""

    async def run_task(self, seed: int) -> TaskResult:
        """Return a successful single-step TaskResult for *seed*."""
        details = {
            "seed": seed,
            "message": f"Task completed successfully for seed {seed}",
        }
        return TaskResult(
            seed=seed,
            success=True,
            outcome_reward=1.0,
            total_steps=1,
            metadata=details,
        )
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Baseline registration for the trivial always-succeed runner.
simple_baseline = BaselineConfig(
    baseline_id="simple",
    name="Simple Baseline",
    description="A simple baseline for testing",
    task_runner=SimpleTaskRunner,
    # Splits carry a per-split difficulty tag in their metadata.
    splits={
        split_name: DataSplit(
            name=split_name,
            seeds=seed_list,
            metadata={"difficulty": difficulty},
        )
        for split_name, seed_list, difficulty in (
            ("train", list(range(10)), "easy"),
            ("val", list(range(10, 15)), "medium"),
            ("test", list(range(15, 20)), "hard"),
        )
    },
    default_policy_config={
        "model": "gpt-4o-mini",
        "temperature": 0.0,
    },
    default_env_config={
        "max_steps": 10,
    },
)
|
|
56
|
+
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
"""Warming Up to RL baseline file for Gymnasium environments."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict
|
|
6
|
+
|
|
7
|
+
import gymnasium as gym
|
|
8
|
+
|
|
9
|
+
from synth_ai.baseline import BaselineConfig, BaselineTaskRunner, DataSplit, TaskResult
|
|
10
|
+
from synth_ai.inference import InferenceClient
|
|
11
|
+
import os
|
|
12
|
+
import httpx
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class WarmingUpToRLTaskRunner(BaselineTaskRunner):
    """Task runner for Gymnasium environments (CartPole, FrozenLake, etc.).

    Each episode renders the raw observation as text, asks the model for a
    ``take_action`` tool call, and steps the environment with the chosen
    discrete action index until termination/truncation or the step budget.
    """

    def __init__(self, policy_config: Dict[str, Any], env_config: Dict[str, Any]):
        super().__init__(policy_config, env_config)

        # Inference settings: model id, sampling params, optional Synth endpoint.
        self.model = policy_config["model"]
        self.temperature = policy_config.get("temperature", 0.0)
        self.max_tokens = policy_config.get("max_tokens", 128)
        self.inference_url = policy_config.get("inference_url")

        # Gymnasium environment id to instantiate per episode.
        self.env_name = env_config.get("env_name", "CartPole-v1")

    def _get_action_tool(self, env: gym.Env) -> Dict[str, Any]:
        """Generate a ``take_action`` tool schema from the env's action space.

        For Discrete spaces the schema bounds the action index; otherwise an
        unbounded integer schema is returned as a fallback.
        """
        if isinstance(env.action_space, gym.spaces.Discrete):
            return {
                "type": "function",
                "function": {
                    "name": "take_action",
                    "description": f"Take action in {env.spec.id if env.spec else self.env_name}",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "action": {
                                "type": "integer",
                                "minimum": 0,
                                # Cast to int: Discrete.n is a NumPy integer,
                                # which is not JSON-serializable when posted.
                                "maximum": int(env.action_space.n) - 1,
                                "description": "Action index",
                            }
                        },
                        "required": ["action"],
                    },
                },
            }
        else:
            # Default for unknown action spaces
            return {
                "type": "function",
                "function": {
                    "name": "take_action",
                    "description": "Take action in the environment",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "action": {
                                "type": "integer",
                                "description": "Action index",
                            }
                        },
                        "required": ["action"],
                    },
                },
            }

    def _format_observation(self, obs: Any, env: gym.Env, step: int, max_steps: int) -> str:
        """Render the observation and step budget as a text prompt."""
        obs_str = str(obs)
        if hasattr(env, "spec") and env.spec:
            env_id = env.spec.id
        else:
            env_id = self.env_name

        return f"""Environment: {env_id}
Step: {step}/{max_steps}
Observation: {obs_str}

What action should we take?"""

    async def _call_llm(self, messages: List[Dict[str, Any]], action_tool: Dict[str, Any]) -> Dict[str, Any]:
        """Request a ``take_action`` tool call from the configured backend.

        Uses the Synth InferenceClient when ``inference_url`` looks like an
        HTTP endpoint; otherwise falls back to a raw OpenAI/Groq-compatible
        chat-completions POST.
        """
        if self.inference_url and self.inference_url.startswith("http"):
            api_key = os.getenv("SYNTH_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
            base_url = self.inference_url.rstrip("/")
            if not base_url.endswith("/api"):
                # Append /api only when it is not already somewhere in the path.
                base_url = f"{base_url}/api" if "/api" not in base_url else base_url
            client = InferenceClient(base_url=base_url, api_key=api_key)
            return await client.create_chat_completion(
                model=self.model,
                messages=messages,
                tools=[action_tool],
                tool_choice={"type": "function", "function": {"name": "take_action"}},
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )

        api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GROQ_API_KEY") or ""
        base_url = "https://api.openai.com/v1" if "openai" in self.model.lower() else "https://api.groq.com/openai/v1"
        async with httpx.AsyncClient() as http_client:
            resp = await http_client.post(
                f"{base_url}/chat/completions",
                json={
                    "model": self.model,
                    "messages": messages,
                    "tools": [action_tool],
                    "tool_choice": {"type": "function", "function": {"name": "take_action"}},
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                },
                headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
            )
            return resp.json()

    @staticmethod
    def _extract_action(response: Dict[str, Any], env: gym.Env) -> int:
        """Extract the integer action from the first tool call.

        ``arguments`` may arrive as a JSON-encoded string (the OpenAI wire
        format) or an already-parsed dict; both are accepted, and the value
        is coerced to ``int`` since JSON decoding can yield str/float.
        Falls back to a random sample when no tool call was produced, and
        to action 0 on malformed arguments (matching the original default).
        """
        import json

        tool_calls: List[Dict[str, Any]] = []
        if response.get("choices"):
            message = response["choices"][0].get("message", {})
            tool_calls = message.get("tool_calls", []) or []
        elif "tool_calls" in response:
            tool_calls = response["tool_calls"] or []

        if not tool_calls:
            # Fallback: sample random action
            return env.action_space.sample()

        args = tool_calls[0].get("function", {}).get("arguments", {})
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except (ValueError, TypeError):
                args = {}
        if not isinstance(args, dict):
            return 0
        try:
            return int(args.get("action", 0))
        except (TypeError, ValueError):
            return 0

    async def run_task(self, seed: int) -> TaskResult:
        """Run a single Gymnasium episode seeded with *seed*."""

        # Create environment
        env = gym.make(self.env_name)

        # Episode accumulators
        total_reward = 0.0
        total_steps = 0
        max_steps = self.env_config.get("max_steps", 500)

        terminated = False
        truncated = False

        try:
            # Reset with seed for reproducibility.
            obs, info = env.reset(seed=seed)

            # Tool schema depends on the action space, so build it per env.
            action_tool = self._get_action_tool(env)

            for step in range(max_steps):
                prompt = self._format_observation(obs, env, step, max_steps)
                messages = [{"role": "user", "content": prompt}]

                response = await self._call_llm(messages, action_tool)
                action = self._extract_action(response, env)

                # Step environment
                obs, reward, terminated, truncated, info = env.step(action)
                total_reward += reward
                total_steps += 1

                if terminated or truncated:
                    break
        finally:
            # Always release the env, even if inference or stepping raised.
            env.close()

        return TaskResult(
            seed=seed,
            success=True,
            outcome_reward=total_reward,
            total_steps=total_steps,
            metadata={
                "env_name": self.env_name,
                "episode_length": total_steps,
                "terminated": terminated,
                "truncated": truncated,
                "final_reward": total_reward,
            },
        )
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
# Baseline registration: CartPole control task on Gymnasium.
cartpole_baseline = BaselineConfig(
    baseline_id="cartpole",
    name="CartPole-v1",
    description="Balance a pole on a cart using Gymnasium",
    task_runner=WarmingUpToRLTaskRunner,
    # Seed ranges: 0-99 train, 100-119 val, 120-139 test.
    splits={
        split_name: DataSplit(name=split_name, seeds=seed_list)
        for split_name, seed_list in (
            ("train", list(range(100))),
            ("val", list(range(100, 120))),
            ("test", list(range(120, 140))),
        )
    },
    default_policy_config={
        "model": "groq:llama-3.1-70b-versatile",
        "temperature": 0.0,
        "max_tokens": 128,
    },
    default_env_config={
        "env_name": "CartPole-v1",
        "max_steps": 500,
    },
    metadata={
        "environment": "CartPole-v1",
        "task_type": "control",
        "max_reward": 500,
    },
    tags=["rl", "gymnasium", "control"],
)
|
|
212
|
+
|
|
213
|
+
# Baseline registration: FrozenLake navigation task on Gymnasium.
frozenlake_baseline = BaselineConfig(
    baseline_id="frozenlake",
    name="FrozenLake-v1",
    description="Navigate a frozen lake to reach goal using Gymnasium",
    task_runner=WarmingUpToRLTaskRunner,
    # Seed ranges: 0-99 train, 100-119 val, 120-139 test.
    splits={
        split_name: DataSplit(name=split_name, seeds=seed_list)
        for split_name, seed_list in (
            ("train", list(range(100))),
            ("val", list(range(100, 120))),
            ("test", list(range(120, 140))),
        )
    },
    default_policy_config={
        "model": "groq:llama-3.1-70b-versatile",
        "temperature": 0.0,
        "max_tokens": 128,
    },
    default_env_config={
        "env_name": "FrozenLake-v1",
        "max_steps": 100,
    },
    metadata={
        "environment": "FrozenLake-v1",
        "task_type": "navigation",
        "max_reward": 1,
    },
    tags=["rl", "gymnasium", "navigation"],
)
|
|
239
|
+
|