synth-ai 0.2.16__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/analyze_semantic_words.sh +2 -2
- examples/baseline/banking77_baseline.py +204 -0
- examples/baseline/crafter_baseline.py +407 -0
- examples/baseline/pokemon_red_baseline.py +326 -0
- examples/baseline/simple_baseline.py +56 -0
- examples/baseline/warming_up_to_rl_baseline.py +239 -0
- examples/blog_posts/gepa/README.md +355 -0
- examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
- examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
- examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
- examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
- examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
- examples/blog_posts/gepa/gepa_baseline.py +204 -0
- examples/blog_posts/gepa/query_prompts_example.py +97 -0
- examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
- examples/blog_posts/gepa/task_apps.py +105 -0
- examples/blog_posts/gepa/test_gepa_local.sh +67 -0
- examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
- examples/blog_posts/pokemon_vl/README.md +98 -0
- examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +27 -0
- examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
- examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +43 -0
- examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
- examples/blog_posts/pokemon_vl/extract_images.py +239 -0
- examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
- examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
- examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
- examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
- examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
- examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
- examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
- examples/blog_posts/warming_up_to_rl/README.md +158 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
- examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +91 -0
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
- examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
- examples/dev/qwen3_32b_qlora_4xh100.toml +5 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +2 -1
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +65 -107
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +2 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +2 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
- examples/multi_step/configs/verilog_rl_lora.toml +80 -123
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -3
- examples/qwen_coder/configs/coder_lora_4b.toml +4 -1
- examples/qwen_coder/configs/coder_lora_small.toml +1 -3
- examples/qwen_vl/README.md +10 -12
- examples/qwen_vl/SETUP_COMPLETE.md +7 -8
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +2 -3
- examples/qwen_vl/collect_data_via_cli.md +76 -84
- examples/qwen_vl/collect_vision_traces.py +4 -4
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +40 -57
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +1 -2
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +20 -37
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +21 -40
- examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
- examples/qwen_vl/configs/{filter_qwen2vl_sft.toml → filter_qwen3vl_sft.toml} +4 -5
- examples/qwen_vl/configs/filter_vision_sft.toml +2 -3
- examples/qwen_vl/crafter_qwen_vl_agent.py +5 -5
- examples/qwen_vl/run_vision_comparison.sh +6 -7
- examples/rl/README.md +5 -5
- examples/rl/configs/rl_from_base_qwen.toml +26 -1
- examples/rl/configs/rl_from_base_qwen17.toml +6 -2
- examples/rl/task_app/README.md +1 -2
- examples/rl/task_app/math_single_step.py +2 -2
- examples/run_crafter_demo.sh +2 -2
- examples/sft/README.md +1 -1
- examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -1
- examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -1
- examples/swe/task_app/README.md +32 -2
- examples/swe/task_app/grpo_swe_mini.py +4 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +1 -1
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +37 -10
- examples/swe/task_app/hosted/inference/openai_client.py +4 -38
- examples/swe/task_app/hosted/policy_routes.py +17 -0
- examples/swe/task_app/hosted/rollout.py +4 -2
- examples/swe/task_app/morph_backend.py +178 -0
- examples/task_apps/banking77/__init__.py +6 -0
- examples/task_apps/banking77/banking77_task_app.py +841 -0
- examples/task_apps/banking77/deploy_wrapper.py +46 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
- examples/task_apps/crafter/task_app/README.md +1 -1
- examples/task_apps/crafter/task_app/grpo_crafter.py +90 -5
- examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +1 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +4 -26
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +372 -107
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +81 -12
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +82 -11
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +1 -1
- examples/task_apps/gepa_benchmarks/__init__.py +7 -0
- examples/task_apps/gepa_benchmarks/common.py +260 -0
- examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
- examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
- examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
- examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
- examples/task_apps/math/README.md +1 -2
- examples/task_apps/pokemon_red/README.md +3 -4
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +6 -5
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +1 -2
- examples/task_apps/pokemon_red/task_app.py +288 -39
- examples/task_apps/sokoban/README.md +2 -3
- examples/task_apps/verilog/eval_groq_qwen32b.toml +12 -14
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +1 -1
- examples/vlm/configs/crafter_vlm_gpt4o.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +0 -2
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +3 -2
- examples/warming_up_to_rl/run_local_rollout_traced.py +1 -1
- examples/warming_up_to_rl/task_app/README.md +1 -1
- examples/warming_up_to_rl/task_app/grpo_crafter.py +185 -5
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +3 -27
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +156 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +37 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen.toml +27 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +6 -0
- synth_ai/api/train/builders.py +99 -4
- synth_ai/api/train/cli.py +516 -26
- synth_ai/api/train/config_finder.py +13 -2
- synth_ai/api/train/configs/__init__.py +23 -2
- synth_ai/api/train/configs/prompt_learning.py +442 -0
- synth_ai/api/train/configs/rl.py +61 -7
- synth_ai/api/train/configs/sft.py +6 -2
- synth_ai/api/train/configs/shared.py +59 -2
- synth_ai/api/train/task_app.py +1 -1
- synth_ai/api/train/validators.py +277 -0
- synth_ai/auth/credentials.py +119 -0
- synth_ai/baseline/__init__.py +25 -0
- synth_ai/baseline/config.py +209 -0
- synth_ai/baseline/discovery.py +214 -0
- synth_ai/baseline/execution.py +146 -0
- synth_ai/cli/__init__.py +94 -18
- synth_ai/cli/__main__.py +0 -0
- synth_ai/cli/claude.py +70 -0
- synth_ai/cli/codex.py +84 -0
- synth_ai/cli/commands/__init__.py +18 -0
- synth_ai/cli/commands/baseline/__init__.py +12 -0
- synth_ai/cli/commands/baseline/core.py +637 -0
- synth_ai/cli/commands/baseline/list.py +93 -0
- synth_ai/cli/commands/demo/__init__.py +6 -0
- synth_ai/cli/commands/demo/core.py +163 -0
- synth_ai/cli/commands/eval/__init__.py +19 -0
- synth_ai/cli/commands/eval/core.py +1112 -0
- synth_ai/cli/commands/eval/errors.py +81 -0
- synth_ai/cli/commands/eval/validation.py +133 -0
- synth_ai/cli/commands/filter/__init__.py +12 -0
- synth_ai/cli/commands/filter/core.py +424 -0
- synth_ai/cli/commands/filter/errors.py +55 -0
- synth_ai/cli/commands/filter/validation.py +77 -0
- synth_ai/cli/commands/help/__init__.py +177 -0
- synth_ai/cli/commands/help/core.py +72 -0
- synth_ai/cli/commands/smoke/__init__.py +7 -0
- synth_ai/cli/commands/smoke/core.py +1436 -0
- synth_ai/cli/commands/status/__init__.py +64 -0
- synth_ai/cli/commands/status/client.py +192 -0
- synth_ai/cli/commands/status/config.py +92 -0
- synth_ai/cli/commands/status/errors.py +20 -0
- synth_ai/cli/commands/status/formatters.py +164 -0
- synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
- synth_ai/cli/commands/status/subcommands/files.py +79 -0
- synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
- synth_ai/cli/commands/status/subcommands/models.py +79 -0
- synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
- synth_ai/cli/commands/status/subcommands/runs.py +81 -0
- synth_ai/cli/commands/status/subcommands/summary.py +47 -0
- synth_ai/cli/commands/status/subcommands/usage.py +203 -0
- synth_ai/cli/commands/status/utils.py +114 -0
- synth_ai/cli/commands/train/__init__.py +53 -0
- synth_ai/cli/commands/train/core.py +21 -0
- synth_ai/cli/commands/train/errors.py +117 -0
- synth_ai/cli/commands/train/judge_schemas.py +200 -0
- synth_ai/cli/commands/train/judge_validation.py +305 -0
- synth_ai/cli/commands/train/validation.py +386 -0
- synth_ai/cli/demo.py +30 -158
- synth_ai/cli/deploy/__init__.py +43 -0
- synth_ai/cli/deploy.py +162 -0
- synth_ai/cli/eval/__init__.py +36 -0
- synth_ai/cli/eval/core.py +5 -0
- synth_ai/cli/eval/errors.py +31 -0
- synth_ai/cli/eval/validation.py +5 -0
- synth_ai/cli/filter/__init__.py +28 -0
- synth_ai/cli/filter/core.py +5 -0
- synth_ai/cli/filter/errors.py +23 -0
- synth_ai/cli/filter/validation.py +5 -0
- synth_ai/cli/legacy_root_backup.py +14 -8
- synth_ai/cli/modal_serve/__init__.py +12 -0
- synth_ai/cli/modal_serve/core.py +14 -0
- synth_ai/cli/modal_serve/errors.py +8 -0
- synth_ai/cli/modal_serve/validation.py +11 -0
- synth_ai/cli/opencode.py +107 -0
- synth_ai/cli/root.py +9 -5
- synth_ai/cli/serve/__init__.py +12 -0
- synth_ai/cli/serve/core.py +14 -0
- synth_ai/cli/serve/errors.py +8 -0
- synth_ai/cli/serve/validation.py +11 -0
- synth_ai/cli/setup.py +20 -265
- synth_ai/cli/status.py +7 -126
- synth_ai/cli/task_app_deploy.py +1 -10
- synth_ai/cli/task_app_modal_serve.py +4 -9
- synth_ai/cli/task_app_serve.py +4 -11
- synth_ai/cli/task_apps.py +51 -1480
- synth_ai/cli/train/__init__.py +12 -0
- synth_ai/cli/train/core.py +21 -0
- synth_ai/cli/train/errors.py +8 -0
- synth_ai/cli/train/validation.py +24 -0
- synth_ai/cli/train.py +1 -14
- synth_ai/demos/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
- synth_ai/environments/examples/red/engine.py +33 -12
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
- synth_ai/environments/examples/red/environment.py +26 -0
- synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
- synth_ai/http.py +12 -0
- synth_ai/judge_schemas.py +10 -10
- synth_ai/learning/__init__.py +10 -0
- synth_ai/learning/prompt_learning_client.py +276 -0
- synth_ai/learning/prompt_learning_types.py +184 -0
- synth_ai/learning/rl/client.py +3 -1
- synth_ai/pricing/__init__.py +2 -0
- synth_ai/pricing/model_pricing.py +57 -0
- synth_ai/streaming/__init__.py +29 -0
- synth_ai/streaming/config.py +94 -0
- synth_ai/streaming/handlers.py +518 -0
- synth_ai/streaming/streamer.py +320 -0
- synth_ai/streaming/types.py +95 -0
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +2 -0
- synth_ai/task/tracing_utils.py +25 -25
- synth_ai/task/validators.py +45 -9
- synth_ai/task_app_cfgs.py +21 -0
- synth_ai/tracing_v3/config.py +162 -19
- synth_ai/tracing_v3/constants.py +1 -1
- synth_ai/tracing_v3/db_config.py +24 -38
- synth_ai/tracing_v3/migration_helper.py +1 -2
- synth_ai/tracing_v3/storage/config.py +47 -13
- synth_ai/tracing_v3/storage/factory.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +113 -11
- synth_ai/tracing_v3/turso/native_manager.py +92 -16
- synth_ai/types.py +8 -0
- synth_ai/urls.py +11 -0
- synth_ai/utils/__init__.py +30 -1
- synth_ai/utils/agents.py +74 -0
- synth_ai/utils/bin.py +39 -0
- synth_ai/utils/cli.py +149 -5
- synth_ai/utils/env.py +40 -33
- synth_ai/utils/http.py +4 -1
- synth_ai/utils/json.py +72 -0
- synth_ai/utils/modal.py +285 -3
- synth_ai/utils/paths.py +48 -0
- synth_ai/utils/uvicorn.py +113 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/METADATA +109 -6
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/RECORD +291 -142
- examples/qwen_vl/configs/eval_qwen2vl_vision.toml +0 -44
- synth_ai/cli/tui.py +0 -62
- synth_ai/tui/__init__.py +0 -5
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -164
- synth_ai/tui/cli/query_experiments_v3.py +0 -164
- synth_ai/tui/dashboard.py +0 -911
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,326 @@
|
|
|
1
|
+
"""Pokemon Red baseline file for Game Boy emulation evaluation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations

import json
import os
from typing import Any, Dict, List, Optional

import httpx

from synth_ai.baseline import BaselineConfig, BaselineTaskRunner, DataSplit, TaskResult
from synth_ai.inference import InferenceClient
|
|
11
|
+
|
|
12
|
+
# Optional-dependency guard: the Pokemon Red environment may not be installed,
# so record availability as a flag instead of failing at import time.
# PokemonRedTaskRunner.__init__ raises a descriptive ImportError when this
# flag is False, and the baseline config below is only defined when it is True.
try:
    from synth_ai.environments.examples.red.environment import PokemonRedEnvironment
    from synth_ai.environments.examples.red.taskset import (
        PokemonRedTaskInstance,
        PokemonRedTaskInstanceMetadata,
    )
    POKEMON_RED_AVAILABLE = True
except ImportError:
    # Any missing module in the chain above disables the baseline cleanly.
    POKEMON_RED_AVAILABLE = False
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class PokemonRedTaskRunner(BaselineTaskRunner):
    """Task runner for Pokemon Red Game Boy emulation.

    Each episode formats the emulator observation into an LLM prompt
    (optionally with a base64 screen image), forces an ``execute_sequence``
    tool call to obtain button presses, and steps the environment with
    each returned press until termination or the step budget is spent.
    """

    def __init__(self, policy_config: Dict[str, Any], env_config: Dict[str, Any]):
        """Initialize the runner.

        Args:
            policy_config: Must contain ``model``; may contain ``temperature``
                (default 0.0), ``max_tokens`` (default 512), ``inference_url``.
            env_config: Environment settings; ``rom_path`` is validated lazily
                in :meth:`run_task`.

        Raises:
            ImportError: If the optional Pokemon Red environment is missing.
        """
        super().__init__(policy_config, env_config)

        if not POKEMON_RED_AVAILABLE:
            raise ImportError(
                "Pokemon Red environment not available. "
                "Install synth-ai with Pokemon Red support."
            )

        # Store config for inference
        self.model = policy_config["model"]
        self.temperature = policy_config.get("temperature", 0.0)
        self.max_tokens = policy_config.get("max_tokens", 512)
        self.inference_url = policy_config.get("inference_url")

        # Tool definition sent with every completion request; the schema caps a
        # single response at 20 button presses of 1-120 frames each.
        self.tools = [{
            "type": "function",
            "function": {
                "name": "execute_sequence",
                "description": "Execute multiple button presses in sequence",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "actions": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "button": {
                                        "type": "string",
                                        "enum": ["UP", "DOWN", "LEFT", "RIGHT", "A", "B", "START", "SELECT"],
                                    },
                                    "frames": {
                                        "type": "integer",
                                        "minimum": 1,
                                        "maximum": 120,
                                        "description": "Frames to hold button (60fps)",
                                    },
                                },
                                "required": ["button", "frames"],
                            },
                            "minItems": 1,
                            "maxItems": 20,
                        },
                    },
                    "required": ["actions"],
                },
            },
        }]

    def _format_observation(self, obs: Dict[str, Any], step: int, max_steps: int) -> str:
        """Format observation for LLM.

        Renders whichever of position, party, battle, progress and dialogue
        fields are present in *obs* into a compact text prompt.
        """
        lines = [
            f"Pokemon Red - Step {step}/{max_steps}",
            "",
        ]

        # Position
        if "map_id" in obs:
            lines.append(f"Location: Map {obs['map_id']}")
        if "player_x" in obs and "player_y" in obs:
            lines.append(f"Position: ({obs['player_x']}, {obs['player_y']})")

        # Party
        if "party_count" in obs:
            lines.append(f"Party Size: {obs['party_count']}")
        if "party_pokemon" in obs and obs["party_pokemon"]:
            pokemon = obs["party_pokemon"][0]
            lines.append(
                f"First Pokemon: Level {pokemon.get('level', '?')}, "
                f"HP {pokemon.get('hp_current', '?')}/{pokemon.get('hp_max', '?')}"
            )

        # Battle
        if obs.get("in_battle"):
            lines.append("=== IN BATTLE ===")
            if "enemy_hp_current" in obs:
                lines.append(
                    f"Enemy HP: {obs['enemy_hp_current']}/{obs.get('enemy_hp_max', '?')}"
                )
            if "battle_turn" in obs:
                lines.append(f"Battle Turn: {obs['battle_turn']}")

        # Progress
        if "badges" in obs:
            lines.append(f"Badges: {obs['badges']}")
        if "money" in obs:
            lines.append(f"Money: ${obs['money']}")

        # Dialogue
        if obs.get("text_box_active"):
            lines.append("Text box is active - press A to advance dialogue")

        lines.append("")
        lines.append("What actions should we take?")

        return "\n".join(lines)

    def _build_messages(self, obs_dict: Dict[str, Any], prompt: str) -> List[Dict[str, Any]]:
        """Build the chat message list, attaching the screen image when present."""
        messages: List[Dict[str, Any]] = [{"role": "user", "content": prompt}]
        img_b64 = obs_dict.get("observation_image_base64")
        if img_b64:
            messages[0]["content"] = [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{img_b64}"
                    },
                },
                {"type": "text", "text": prompt},
            ]
        return messages

    async def _request_completion(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Request one chat completion, forcing the execute_sequence tool.

        Uses the Synth InferenceClient when ``inference_url`` is configured;
        otherwise falls back to a raw OpenAI/Groq-compatible HTTP call.
        """
        tool_choice = {"type": "function", "function": {"name": "execute_sequence"}}

        if self.inference_url and self.inference_url.startswith("http"):
            api_key = os.getenv("SYNTH_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
            base_url = self.inference_url.rstrip("/")
            if not base_url.endswith("/api"):
                base_url = f"{base_url}/api" if "/api" not in base_url else base_url
            client = InferenceClient(base_url=base_url, api_key=api_key)
            return await client.create_chat_completion(
                model=self.model,
                messages=messages,
                tools=self.tools,
                tool_choice=tool_choice,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )

        api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GROQ_API_KEY") or ""
        base_url = "https://api.openai.com/v1" if "openai" in self.model.lower() else "https://api.groq.com/openai/v1"
        async with httpx.AsyncClient() as http_client:
            resp = await http_client.post(
                f"{base_url}/chat/completions",
                json={
                    "model": self.model,
                    "messages": messages,
                    "tools": self.tools,
                    "tool_choice": tool_choice,
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                },
                headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
            )
            # Surface HTTP failures instead of silently parsing an error body
            # into an empty tool-call list.
            resp.raise_for_status()
            return resp.json()

    @staticmethod
    def _extract_actions(response: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Pull the execute_sequence actions out of a completion response.

        OpenAI-compatible APIs return ``function.arguments`` as a JSON string;
        some clients pre-parse it into a dict. Both forms are handled, and any
        malformed payload yields an empty list (which ends the episode loop).
        """
        tool_calls: List[Dict[str, Any]] = []
        choices = response.get("choices") or []
        if choices:
            message = choices[0].get("message", {}) or {}
            tool_calls = message.get("tool_calls", []) or []
        elif "tool_calls" in response:
            tool_calls = response["tool_calls"] or []

        if not tool_calls:
            return []

        raw_args = (tool_calls[0].get("function") or {}).get("arguments", {})
        if isinstance(raw_args, str):
            try:
                raw_args = json.loads(raw_args)
            except json.JSONDecodeError:
                return []
        if not isinstance(raw_args, dict):
            return []
        actions = raw_args.get("actions", [])
        return actions if isinstance(actions, list) else []

    async def run_task(self, seed: int) -> TaskResult:
        """Run a single Pokemon Red episode.

        Args:
            seed: Episode seed; also used in the task-instance id.

        Returns:
            TaskResult with cumulative reward, per-event rewards, and episode
            metadata (battle_won, game_over, final_map, badges, party_size).

        Raises:
            ValueError: If ``rom_path`` is missing from env_config.
        """
        rom_path = self.env_config.get("rom_path")
        if not rom_path:
            raise ValueError("rom_path required in env_config for Pokemon Red")

        init_state_path = self.env_config.get("init_state_path")
        max_steps = self.env_config.get("max_steps", 500)

        metadata = PokemonRedTaskInstanceMetadata(
            seed=seed,
            rom_path=rom_path,
            init_state_path=init_state_path,
            reward_type=self.env_config.get("reward_type", "pallet_town_progression"),
        )
        task_instance = PokemonRedTaskInstance(
            id=f"pokemon-red-{seed}",
            metadata=metadata,
        )

        env = PokemonRedEnvironment(task_instance=task_instance)

        raw_obs = await env.initialize()
        # getattr with a default already covers the missing-attribute case.
        observation = getattr(raw_obs, "observation", raw_obs)
        obs_dict = observation if isinstance(observation, dict) else {}

        total_reward = 0.0
        total_steps = 0
        event_rewards: List[Dict[str, Any]] = []
        battle_won = False
        game_over = False

        # Hoisted out of the per-action loop (was re-imported every iteration).
        from synth_ai.environments.environment.tools import EnvToolCall

        try:
            for step in range(max_steps):
                prompt = self._format_observation(obs_dict, step, max_steps)
                messages = self._build_messages(obs_dict, prompt)

                response = await self._request_completion(messages)
                actions = self._extract_actions(response)
                if not actions:
                    break

                # Execute each button press as its own environment step.
                for action_spec in actions:
                    if total_steps >= max_steps:
                        break

                    tool_call = EnvToolCall(
                        name="execute_sequence",
                        arguments={"actions": [action_spec]},
                    )
                    step_result = await env.step([tool_call])
                    total_steps += 1

                    step_obs = getattr(step_result, "observation", step_result)
                    obs_dict = step_obs if isinstance(step_obs, dict) else {}

                    # Guard against a None reward attribute on the step result.
                    reward = getattr(step_result, "reward", 0.0) or 0.0
                    total_reward += reward

                    if reward > 0:
                        event_rewards.append({
                            "step": total_steps,
                            "reward": reward,
                        })

                    if getattr(step_result, "terminated", False) or getattr(step_result, "truncated", False):
                        game_over = True
                        break

                    # battle_outcome: 1 = won, 2 = lost (treated as game over).
                    # NOTE(review): encoding inferred from usage here — confirm
                    # against the red environment's observation contract.
                    if obs_dict.get("battle_outcome") == 1:
                        battle_won = True
                    elif obs_dict.get("battle_outcome") == 2:
                        game_over = True

                if game_over:
                    break
        finally:
            # Always release the emulator, even when inference or stepping raised.
            await env.terminate()

        return TaskResult(
            seed=seed,
            success=True,
            outcome_reward=total_reward,
            event_rewards=event_rewards,
            total_steps=total_steps,
            metadata={
                "battle_won": battle_won,
                "game_over": game_over,
                "final_map": obs_dict.get("map_id"),
                "badges": obs_dict.get("badges", 0),
                "party_size": obs_dict.get("party_count", 0),
            },
        )
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
# Define baseline config (only if Pokemon Red is available)
|
|
298
|
+
if POKEMON_RED_AVAILABLE:
    # Register the baseline only when the optional environment imported cleanly.
    pokemon_vl_baseline = BaselineConfig(
        baseline_id="pokemon_vl",
        name="Pokemon VL - Pokemon Red",
        description="Pokemon Red Game Boy emulation baseline for vision-language agents",
        task_runner=PokemonRedTaskRunner,
        # Disjoint seed ranges: 0-19 train, 20-24 val, 25-29 test.
        splits={
            split_name: DataSplit(name=split_name, seeds=list(seed_range))
            for split_name, seed_range in (
                ("train", range(20)),
                ("val", range(20, 25)),
                ("test", range(25, 30)),
            )
        },
        default_policy_config={
            "model": "groq:llama-3.1-70b-versatile",
            "temperature": 0.0,
            "max_tokens": 512,
        },
        default_env_config={
            "rom_path": None,  # Must be provided
            "init_state_path": None,  # Optional
            "reward_type": "pallet_town_progression",
            "max_steps": 500,
        },
        metadata={
            "environment": "pokemon_red",
            "task_type": "emulation",
            "requires_rom": True,
        },
    )
|
|
326
|
+
|
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Run pokemon_vl eval with gpt-5-nano and extract images from trajectory response.
|
|
3
|
+
|
|
4
|
+
This script bypasses the trace validation issue by extracting images directly from
|
|
5
|
+
the trajectory steps in the rollout response.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import argparse
|
|
9
|
+
import asyncio
|
|
10
|
+
import base64
|
|
11
|
+
import json
|
|
12
|
+
import os
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
|
|
15
|
+
import httpx
|
|
16
|
+
from dotenv import load_dotenv
|
|
17
|
+
|
|
18
|
+
load_dotenv()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _decode_step_images(steps: list) -> list[dict]:
    """Decode base64 observation frames from trajectory steps.

    Returns one dict per step that carried an image, holding the raw PNG
    bytes plus the position/text-box state used later for filtering and
    filename construction. Steps without an image are skipped silently;
    decode failures are reported and skipped.
    """
    frames: list[dict] = []
    for idx, step in enumerate(steps):
        obs = step.get("obs", {})
        img_b64 = obs.get("observation_image_base64")
        if not img_b64:
            continue
        try:
            frames.append({
                "idx": idx,
                "img_data": base64.b64decode(img_b64),
                "map_id": obs.get("map_id", "?"),
                "player_x": obs.get("player_x", "?"),
                "player_y": obs.get("player_y", "?"),
                "text_box_active": obs.get("text_box_active", False),
            })
        except Exception as e:
            print(f"  Error decoding step {idx}: {e}")
    return frames


def _filter_text_box_frames(frames: list[dict]) -> list[dict]:
    """Drop intermediate text-box frames, keeping only meaningful ones.

    A frame is kept when it is not inside a text box, or when it is the
    final frame of a text-box sequence (the next frame leaves the text box,
    or the trajectory ends). Intermediate text-box "loading" frames are
    discarded.
    """
    kept: list[dict] = []
    for i, frame in enumerate(frames):
        in_text_box = frame["text_box_active"]
        next_in_text_box = (
            frames[i + 1]["text_box_active"] if i + 1 < len(frames) else False
        )
        # Keep non-text-box frames, plus the last frame of each text-box run.
        if not in_text_box or not next_in_text_box:
            kept.append(frame)
    return kept


def _save_frames(frames: list[dict], output_dir: Path) -> int:
    """Write decoded frames to *output_dir* as PNGs; return the count saved.

    Filenames encode step index, map position, and text-box state so the
    images sort chronologically and are self-describing.
    """
    saved = 0
    for frame in frames:
        try:
            pos_str = f"Map{frame['map_id']}_{frame['player_x']},{frame['player_y']}"
            textbox_str = "True" if frame["text_box_active"] else "False"
            filename = f"step_{frame['idx']:03d}_pos_{pos_str}_textbox_{textbox_str}.png"
            (output_dir / filename).write_bytes(frame["img_data"])
            # Fixed: original printed a literal "(unknown)" instead of the filename.
            print(f"  Saved: {filename}")
            saved += 1
        except Exception as e:
            print(f"  Error saving step {frame['idx']}: {e}")
    return saved


async def run_eval_and_extract_images(
    task_app_url: str,
    output_dir: Path,
    seed: int = 0,
    max_turns: int = 10,
    model: str = "gpt-5-nano",
):
    """Run a rollout eval against the task app and save trajectory frames.

    Posts a single rollout request to ``{task_app_url}/rollout``, decodes the
    base64 frames from the returned trajectory steps, filters out intermediate
    text-box frames, and writes the survivors to *output_dir* as PNGs. Rollout
    metrics, when present in the response, are saved as ``metrics.json``.

    Args:
        task_app_url: Base URL of the running task app.
        output_dir: Directory for PNG frames and metrics (created if missing).
        seed: Environment seed and dataset index for the rollout.
        max_turns: Number of policy ops (LLM calls) to request.
        model: Model name forwarded to the policy config.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    async with httpx.AsyncClient(timeout=300.0) as client:
        # Build the rollout request expected by the task app's /rollout endpoint.
        rollout_request = {
            "run_id": f"gpt5nano_eval_seed_{seed}",
            "env": {
                "env_name": "pokemon_red",
                "seed": seed,
                "config": {
                    "split": "train",
                    "index": seed,
                    "env_params": {"max_steps_per_episode": 100},
                },
            },
            "policy": {
                "policy_name": "pokemon_vl_qwen3_vl",
                "config": {
                    "model": model,
                    "provider": "openai",
                    "inference_url": "https://api.openai.com/v1",
                    "temperature": 0.7,
                    "top_p": 0.95,
                    "max_tokens": 512,
                    "use_vision": True,
                    "image_only_mode": False,
                    "max_llm_calls": max_turns,
                },
            },
            "ops": ["policy"] * max_turns,
            "mode": "eval",
            "record": {
                "return_trace": True,
                "trace_format": "full",
            },
        }

        # Fixed: message previously hard-coded "gpt-5-nano" regardless of `model`.
        print(f"Running eval with {model} (seed={seed})...")
        response = await client.post(f"{task_app_url}/rollout", json=rollout_request)
        response.raise_for_status()
        result = response.json()

        trajectories = result.get("trajectories", [])
        if not trajectories:
            print("Error: No trajectories in response")
            return

        steps = trajectories[0].get("steps", [])
        print(f"✓ Received {len(steps)} steps")
        print("Extracting images (filtering intermediate text box frames)...")

        image_data = _decode_step_images(steps)
        filtered_images = _filter_text_box_frames(image_data)
        image_count = _save_frames(filtered_images, output_dir)

        removed = len(image_data) - len(filtered_images)
        print(f"\n Filtered: {len(image_data)} -> {len(filtered_images)} images (removed {removed} intermediate text box frames)")
        print(f"\n✓ Extracted {image_count} images to {output_dir}/")

        # Persist rollout metrics alongside the images when the app returned any.
        metrics = result.get("metrics", {})
        if metrics:
            metrics_file = output_dir / "metrics.json"
            with open(metrics_file, "w") as f:
                json.dump(metrics, f, indent=2)
            print(f"✓ Saved metrics to {metrics_file}")
|
167
|
+
async def main():
    """Parse CLI arguments and kick off the eval + image extraction."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--task-app-url", default="http://127.0.0.1:8914", help="Task app URL")
    parser.add_argument(
        "--output-dir",
        default="examples/blog_posts/pokemon_vl/images_gpt5",
        help="Output directory for images",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--max-turns", type=int, default=10, help="Maximum turns")
    parser.add_argument("--model", default="gpt-5-nano", help="Model name")
    args = parser.parse_args()

    await run_eval_and_extract_images(
        args.task_app_url,
        Path(args.output_dir),
        args.seed,
        args.max_turns,
        args.model,
    )
+
|
|
207
|
+
# Script entry point: run the async workflow on the default event loop.
if __name__ == "__main__":
    asyncio.run(main())