synth-ai 0.2.13.dev1__py3-none-any.whl → 0.2.13.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +12 -1
- examples/swe/task_app/grpo_swe_mini.py +55 -26
- examples/swe/task_app/hosted/rollout.py +40 -0
- examples/swe/task_app/hosted/test_service.py +5 -6
- examples/task_apps/TESTING.md +275 -0
- examples/task_apps/__init__.py +0 -0
- examples/task_apps/crafter/__init__.py +0 -0
- examples/task_apps/crafter/task_app/__init__.py +2 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +18 -13
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +60 -4
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +25 -3
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +10 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
- examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
- examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
- examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
- examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
- examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
- examples/task_apps/enron/task_app/README.md +14 -0
- examples/task_apps/enron/task_app/__init__.py +1 -0
- examples/task_apps/enron/task_app/grpo_enron.py +906 -0
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/conftest.py +115 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +177 -0
- examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
- examples/task_apps/math/__init__.py +0 -0
- examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
- examples/task_apps/pokemon_battle/__init__.py +2 -0
- examples/task_apps/pokemon_battle/modal_app.py +104 -0
- examples/task_apps/pokemon_battle/task_app/README.md +68 -0
- examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
- examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
- examples/task_apps/pokemon_red/README.md +357 -0
- examples/task_apps/pokemon_red/__init__.py +3 -0
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +73 -0
- examples/task_apps/pokemon_red/task_app.py +606 -0
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +191 -0
- examples/task_apps/sokoban/README.md +307 -0
- examples/task_apps/sokoban/__init__.py +3 -0
- examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
- examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
- examples/task_apps/sokoban/task_app.py +1058 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/conftest.py +113 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
- examples/task_apps/verilog/__init__.py +1 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +20 -0
- examples/task_apps/verilog/task_app/README.md +12 -0
- examples/task_apps/verilog/task_app/__init__.py +1 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +931 -0
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/conftest.py +115 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +179 -0
- examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
- examples/vlm/crafter_openai_vlm_agent.py +4 -4
- examples/vlm/run_crafter_vlm_benchmark.py +4 -4
- examples/workflows/__init__.py +0 -0
- examples/workflows/math_rl/__init__.py +0 -0
- examples/workflows/math_rl/download_dataset.py +80 -0
- synth_ai/__init__.py +2 -2
- synth_ai/api/train/builders.py +25 -11
- synth_ai/api/train/cli.py +12 -6
- synth_ai/api/train/configs/__init__.py +10 -10
- synth_ai/api/train/configs/rl.py +5 -4
- synth_ai/api/train/configs/sft.py +4 -3
- synth_ai/api/train/env_resolver.py +5 -2
- synth_ai/api/train/supported_algos.py +10 -5
- synth_ai/api/train/utils.py +7 -4
- synth_ai/cli/__init__.py +7 -51
- synth_ai/cli/_storage.py +4 -3
- synth_ai/cli/_validate_task_app.py +11 -0
- synth_ai/cli/balance.py +4 -3
- synth_ai/cli/calc.py +2 -2
- synth_ai/cli/demo.py +14 -7
- synth_ai/cli/legacy_root_backup.py +1 -1
- synth_ai/cli/rl_demo.py +8 -7
- synth_ai/cli/root.py +0 -97
- synth_ai/cli/task_apps.py +1707 -186
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +28 -16
- synth_ai/environments/examples/enron/engine.py +7 -2
- synth_ai/environments/examples/enron/environment.py +68 -0
- synth_ai/environments/examples/red/engine.py +27 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
- synth_ai/environments/examples/red/environment.py +60 -0
- synth_ai/environments/examples/sokoban/taskset.py +116 -0
- synth_ai/environments/examples/verilog/engine.py +30 -4
- synth_ai/evals/client.py +58 -61
- synth_ai/jobs/client.py +16 -4
- synth_ai/judge_schemas.py +16 -16
- synth_ai/py.typed +0 -0
- synth_ai/task/__init__.py +14 -5
- synth_ai/task/contracts.py +124 -38
- synth_ai/task/proxy.py +48 -56
- synth_ai/task/rubrics/__init__.py +53 -0
- synth_ai/task/rubrics/loaders.py +133 -0
- synth_ai/task/rubrics/models.py +57 -0
- synth_ai/task/rubrics/scoring.py +113 -0
- synth_ai/{rubrics/validators.py → task/rubrics/strict.py} +53 -30
- synth_ai/task/server.py +8 -7
- synth_ai/task/validators.py +269 -6
- synth_ai/tracing_v3/decorators.py +7 -3
- synth_ai/tracing_v3/replica_sync.py +4 -4
- synth_ai/tracing_v3/serialization.py +5 -5
- synth_ai/tracing_v3/trace_utils.py +317 -0
- synth_ai/tracing_v3/turso/native_manager.py +3 -3
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/METADATA +4 -1
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/RECORD +214 -101
- examples/agora_ex/README_MoE.md +0 -224
- examples/agora_ex/__init__.py +0 -7
- examples/agora_ex/agora_ex.py +0 -65
- examples/agora_ex/agora_ex_task_app.py +0 -590
- examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +0 -121
- examples/agora_ex/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/system_prompt_CURRENT.md +0 -63
- examples/agora_ex/task_app/agora_ex_task_app.py +0 -590
- examples/agora_ex/task_app/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/task_app/system_prompt_CURRENT.md +0 -63
- synth_ai/rubrics/__init__.py +0 -22
- synth_ai/task/rubrics.py +0 -219
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/utils.py +0 -0
- /examples/{rl/task_app → task_apps/math}/README.md +0 -0
- /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
- /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,685 @@
|
|
|
1
|
+
"""Task App configuration for the PokéAgent Emerald speedrun environment."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import base64
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import sys
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from io import BytesIO
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any, Iterable, Sequence
|
|
13
|
+
|
|
14
|
+
from fastapi import HTTPException, Request
|
|
15
|
+
|
|
16
|
+
try: # Optional dependency resolved at runtime during reset()
|
|
17
|
+
from pokemon_env.emulator import EmeraldEmulator # type: ignore
|
|
18
|
+
except Exception: # pragma: no cover - handled later with explicit error
|
|
19
|
+
EmeraldEmulator = None
|
|
20
|
+
|
|
21
|
+
from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
22
|
+
from synth_ai.task.contracts import (
|
|
23
|
+
RolloutMetrics,
|
|
24
|
+
RolloutRequest,
|
|
25
|
+
RolloutResponse,
|
|
26
|
+
RolloutStep,
|
|
27
|
+
RolloutTrajectory,
|
|
28
|
+
TaskInfo,
|
|
29
|
+
)
|
|
30
|
+
from synth_ai.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
|
|
31
|
+
from synth_ai.task.server import ProxyConfig, TaskAppConfig
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Dataset descriptor shared by the dataset registry, the base TaskInfo and the
# taskset summary. Keyword order is irrelevant: all fields are passed by name.
DATASET_SPEC = TaskDatasetSpec(
    id="pokemon_emerald_objectives",
    version="0.1.0",
    name="Pokémon Emerald Speedrun Objectives",
    description="Savestate checkpoints for the PokéAgent Track 2 Emerald speedrun starter.",
    default_split="train",
    splits=["train", "eval"],
)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _resolve_repo_root(env_key: str, repo_dir: str) -> Path | None:
    """Locate a checkout of ``repo_dir`` on disk.

    Resolution order:

    1. The path named by the ``env_key`` environment variable, if it is a
       directory.
    2. For each ancestor of this file (nearest first), the directories
       ``<ancestor>/external/<repo_dir>`` and then ``<ancestor>/<repo_dir>``.

    Only directories qualify: the result is placed on ``sys.path`` as an
    import root, so a regular file that merely shares the name is skipped.

    Args:
        env_key: Environment variable that may hold an explicit override path.
        repo_dir: Directory name of the repository to search for.

    Returns:
        The resolved directory, or ``None`` when no candidate is found.
    """
    env_path = os.getenv(env_key)
    if env_path:
        override = Path(env_path).expanduser()
        if override.is_dir():
            return override.resolve()

    for ancestor in Path(__file__).resolve().parents:
        for candidate in (ancestor / "external" / repo_dir, ancestor / repo_dir):
            # Cheap existence check first; only resolve the paths that matter.
            if not candidate.is_dir():
                continue
            try:
                return candidate.resolve()
            except (OSError, RuntimeError):  # pragma: no cover - e.g. symlink loops
                continue
    return None
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _ensure_on_path(path: Path | None) -> None:
    """Prepend ``path`` to ``sys.path`` unless it is falsy or already present."""
    if path:
        entry = str(path)
        if entry in sys.path:
            return
        sys.path.insert(0, entry)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _maybe_resolve(path: Path | None) -> Path | None:
    """Best-effort ``Path.resolve()``.

    Returns ``None`` for a falsy input, the resolved path when resolution
    succeeds, and the input unchanged when resolution raises.
    """
    if path:
        try:
            resolved = path.resolve()
        except Exception:
            resolved = path
        return resolved
    return None
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass(frozen=True)
class EmeraldScenario:
    """Immutable description of one seeded Emerald scenario.

    Field order defines the positional constructor and must not change.
    """

    # Integer key under which the dataset registers this scenario.
    seed: int
    # Short slug identifying the scenario.
    name: str
    # Savestate location, relative to a repo/asset root or absolute.
    checkpoint_ref: str
    # Goal text copied into per-seed task observations.
    objective: str
    # Human-readable summary of the starting situation.
    description: str
    # Step budget; 0 disables the adapter's timeout.
    timeout_steps: int
    # Free-form labels copied into task metadata.
    tags: tuple[str, ...]
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class PokemonEmeraldDataset:
    """In-memory catalogue of Emerald checkpoints and objectives.

    Scenarios are hard-coded and keyed by seed. Savestates are looked up on
    demand: an absolute reference is tried first, then ``<repo_root>/<ref>``,
    then each state root (the ``POKEMON_EMERALD_ASSETS`` directory first, then
    directories inside the pokeagent-speedrun checkout).
    """

    def __init__(self, spec: TaskDatasetSpec) -> None:
        self.spec = spec
        self.repo_root = _resolve_repo_root("POKEAGENT_SPEEDRUN_ROOT", "pokeagent-speedrun")
        # Presumably so modules from the checkout (e.g. ``pokemon_env``) are importable.
        _ensure_on_path(self.repo_root)

        # Search roots for savestates and ROMs; an explicit assets directory wins.
        self._state_roots: list[Path] = []
        self._rom_roots: list[Path] = []
        assets_root_env = os.getenv("POKEMON_EMERALD_ASSETS")
        if assets_root_env:
            assets_path = Path(assets_root_env).expanduser()
            self._state_roots.append(assets_path)
            self._rom_roots.append(assets_path)
        if self.repo_root:
            self._state_roots.extend(
                [
                    self.repo_root / "Emerald-GBAdvance",
                    self.repo_root / "tests" / "states",
                    self.repo_root / "pokemon_env" / "states",
                ]
            )
            self._rom_roots.extend(
                [
                    self.repo_root / "Emerald-GBAdvance",
                    self.repo_root / "roms",
                ]
            )

        scenarios: list[EmeraldScenario] = [
            EmeraldScenario(
                seed=4001,
                name="littleroot_intro",
                checkpoint_ref="Emerald-GBAdvance/truck_start.state",
                objective="Exit the moving truck and meet May.",
                description="Spawn inside the Littleroot truck with dialogue mid-sequence.",
                timeout_steps=1800,
                tags=("tutorial", "movement"),
            ),
            EmeraldScenario(
                seed=4002,
                name="rustboro_split",
                checkpoint_ref="Emerald-GBAdvance/start.state",
                objective="Defeat Roxanne (Badge 1).",
                description="Start at Rustboro gym entrance with levelled Torchic party.",
                timeout_steps=7200,
                tags=("badge", "combat", "routing"),
            ),
            EmeraldScenario(
                seed=4003,
                name="mauville_goal",
                checkpoint_ref="Emerald-GBAdvance/quick_start_save.state",
                objective="Acquire HM06 Rock Smash and return to Wally.",
                description="Begins after Slateport with prepared inventory routing.",
                timeout_steps=5400,
                tags=("hm", "quest"),
            ),
        ]
        self._scenarios: dict[int, EmeraldScenario] = {s.seed: s for s in scenarios}
        # The first listed scenario is used when a request carries no seed.
        self.default_seed = scenarios[0].seed

    @property
    def seeds(self) -> list[int]:
        """All registered seeds in ascending order."""
        return sorted(self._scenarios)

    @property
    def count(self) -> int:
        """Number of registered scenarios."""
        return len(self._scenarios)

    def resolve_seed(self, seed: int | None) -> int:
        """Map ``None`` to the default seed and validate explicit seeds.

        Raises:
            KeyError: If ``seed`` is not a registered scenario.
        """
        if seed is None:
            return self.default_seed
        if seed not in self._scenarios:
            raise KeyError(f"Unknown Emerald seed: {seed}")
        return seed

    def describe_seed(self, seed: int) -> dict[str, Any]:
        """Return a JSON-friendly description of the scenario for ``seed``.

        ``checkpoint_path`` is ``None`` (and ``assets_ready`` false) when no
        savestate file could be located.

        Raises:
            KeyError: If ``seed`` is not a registered scenario.
        """
        scenario = self._scenarios.get(seed)
        if scenario is None:
            raise KeyError(f"Unknown Emerald seed: {seed}")
        checkpoint_path = self._resolve_checkpoint(scenario.checkpoint_ref)
        return {
            "seed": seed,
            "name": scenario.name,
            "checkpoint_ref": scenario.checkpoint_ref,
            "checkpoint_path": str(checkpoint_path) if checkpoint_path else None,
            "objective": scenario.objective,
            "description": scenario.description,
            "timeout_steps": scenario.timeout_steps,
            "tags": list(scenario.tags),
            "assets_ready": checkpoint_path is not None,
        }

    def _resolve_checkpoint(self, reference: str) -> Path | None:
        """Return the first existing savestate *file* for ``reference``, else ``None``.

        Candidates, in order: ``reference`` itself when absolute,
        ``<repo_root>/<reference>``, then for every state root both
        ``<root>/<basename>`` and ``<root>/<reference>``. Directories are
        rejected because a savestate must be a regular file.
        """
        ref = Path(reference)
        candidates: list[Path] = []
        if ref.is_absolute():
            candidates.append(ref)
        if self.repo_root:
            candidates.append(self.repo_root / reference)
        for base in self._state_roots:
            candidates.append(base / ref.name)
            candidates.append(base / reference)
        # dict.fromkeys de-duplicates while preserving the search order
        # (e.g. repo_root/Emerald-GBAdvance/x.state appears twice otherwise).
        for candidate in dict.fromkeys(candidates):
            if candidate.is_file():
                return _maybe_resolve(candidate)
        return None
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _build_dataset_registry() -> tuple[TaskDatasetRegistry, PokemonEmeraldDataset]:
    """Create a registry whose factory always yields one shared dataset instance.

    The same instance is also returned to the caller, so the registry and the
    caller operate on identical state.
    """
    registry = TaskDatasetRegistry()
    shared = PokemonEmeraldDataset(DATASET_SPEC)

    def _same_dataset(_spec: TaskDatasetSpec) -> PokemonEmeraldDataset:
        return shared

    registry.register(DATASET_SPEC, _same_dataset, cache=True)
    return registry, shared
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _base_task_info(dataset: PokemonEmeraldDataset) -> TaskInfo:
    """Build the seed-independent ``TaskInfo`` for the Emerald task app.

    Per-seed details (objective, checkpoint, tags) are layered on top of this
    by ``provide_task_instances``.
    """
    macro_names = [
        "noop",
        "step_up",
        "step_down",
        "step_left",
        "step_right",
        "press_a",
        "press_b",
        "press_start",
        "press_select",
        "open_menu",
        "close_menu",
        "mash_a",
    ]
    action_schema = {
        "type": "object",
        "properties": {
            "macro": {"enum": macro_names},
            "frames": {"type": "integer", "minimum": 1, "maximum": 120},
            "metadata": {"type": "object"},
        },
        "required": ["macro"],
    }
    action_space = {
        "type": "structured",
        "schema": action_schema,
        "notes": "Macros expand to mGBA button sequences inside the Horizons adapter.",
    }
    observation = {
        "summary": "Memory-derived game state plus base64-encoded RGB frame.",
        "keys": ["player_state", "party", "inventory", "flags", "frame_png"],
        "player_state": ["map_id", "x", "y", "facing", "badges"],
    }
    repo_root_str = str(dataset.repo_root) if dataset.repo_root else None
    dataset_info = {
        **DATASET_SPEC.model_dump(),
        "seed_count": dataset.count,
        "seeds": dataset.seeds,
        "source_repos": [
            "https://github.com/sethkarten/pokeagent-speedrun",
            "https://pokeagent.github.io/track2.html",
        ],
        "pokeagent_speedrun_root": repo_root_str,
    }
    rubric = {
        "version": "1",
        "criteria_count": 3,
        "source": "inline",
        "summary": "Milestone completion, time penalties, and soft-lock avoidance.",
    }
    inference = {
        "supports_proxy": True,
        "tool": {"name": "emerald_macro", "parallel_tool_calls": False},
        "endpoints": {
            "openai": "/proxy/v1/chat/completions",
            "groq": "/proxy/groq/v1/chat/completions",
        },
    }
    capabilities = {
        "supports_rollout": True,
        "supports_env_lifecycle": True,
        "requires_api_key_header": True,
    }
    task_metadata = {
        "preferred_backend": "pokeagent-speedrun",
        "emulator": "mGBA",
        "documentation": "https://github.com/sethkarten/pokeagent-speedrun",
    }
    return TaskInfo(
        task={"id": "pokemon_emerald", "name": "Pokémon Emerald Speedrun", "version": "0.1.0"},
        environments=["pokemon_emerald"],
        action_space=action_space,
        observation=observation,
        dataset=dataset_info,
        rubric=rubric,
        inference=inference,
        capabilities=capabilities,
        limits={"max_steps": 10000, "max_time_s": 7200, "max_ops": 8192},
        task_metadata=task_metadata,
    )
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def describe_taskset(dataset: PokemonEmeraldDataset) -> dict[str, Any]:
    """Summarise the dataset spec together with its seeds and asset status.

    ``assets_ready`` is true only when every seed's savestate resolves.
    """
    summary: dict[str, Any] = dict(DATASET_SPEC.model_dump())
    summary["count"] = dataset.count
    summary["seeds"] = dataset.seeds
    summary["assets_ready"] = all(
        dataset.describe_seed(seed)["assets_ready"] for seed in dataset.seeds
    )
    return summary
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def provide_task_instances(
    dataset: PokemonEmeraldDataset, base_info: TaskInfo, seeds: Sequence[int]
) -> Iterable[TaskInfo]:
    """Materialise one ``TaskInfo`` per requested seed.

    Each instance reuses ``base_info`` and extends its observation, dataset and
    task metadata with the scenario details for that seed.

    Raises:
        KeyError: If a seed is not registered (raised by ``resolve_seed``).
    """

    def _instance_for(seed_value: int) -> TaskInfo:
        seed = dataset.resolve_seed(seed_value)
        scenario = dataset.describe_seed(seed)
        observation = {
            **base_info.observation,
            "seed": seed,
            "checkpoint_ref": scenario["checkpoint_ref"],
            "objective": scenario["objective"],
            "timeout_steps": scenario["timeout_steps"],
        }
        dataset_info = {**base_info.dataset, "seed": seed, "scenario": scenario}
        metadata = {
            **base_info.task_metadata,
            "tags": scenario["tags"],
            "assets_ready": scenario["assets_ready"],
        }
        return TaskInfo(
            task=base_info.task,
            environments=base_info.environments,
            action_space=base_info.action_space,
            observation=observation,
            dataset=dataset_info,
            rubric=base_info.rubric,
            inference=base_info.inference,
            capabilities=base_info.capabilities,
            limits=base_info.limits,
            task_metadata=metadata,
        )

    return [_instance_for(seed_value) for seed_value in seeds]
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
class PokemonEmeraldAdapter:
|
|
332
|
+
"""Adapter around pokeagent-speedrun's mGBA wrapper with snapshot support."""
|
|
333
|
+
|
|
334
|
+
DEFAULT_STEP_PENALTY = 0.01
|
|
335
|
+
BADGE_REWARD = 10.0
|
|
336
|
+
LOCATION_REWARD = 0.5
|
|
337
|
+
|
|
338
|
+
MACRO_BUTTONS: dict[str, list[str]] = {
|
|
339
|
+
"noop": [],
|
|
340
|
+
"press_a": ["a"],
|
|
341
|
+
"press_b": ["b"],
|
|
342
|
+
"press_start": ["start"],
|
|
343
|
+
"press_select": ["select"],
|
|
344
|
+
"step_up": ["up"],
|
|
345
|
+
"step_down": ["down"],
|
|
346
|
+
"step_left": ["left"],
|
|
347
|
+
"step_right": ["right"],
|
|
348
|
+
"open_menu": ["start"],
|
|
349
|
+
"close_menu": ["b"],
|
|
350
|
+
"mash_a": ["a"],
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
def __init__(
    self,
    *,
    scenario: dict[str, Any],
    rom_path: Path,
    frames_per_step: int = 6,
    step_penalty: float = DEFAULT_STEP_PENALTY,
) -> None:
    """Validate the emulator dependency and assets, and prepare an unstarted adapter.

    Args:
        scenario: Scenario mapping as produced by ``describe_seed`` (reads
            ``name``, ``checkpoint_ref``, ``checkpoint_path`` and ``timeout_steps``).
        rom_path: Location of the Emerald ROM; a ``str`` is accepted as well.
        frames_per_step: Frames used by ``step`` when an action omits ``frames``.
        step_penalty: Per-step penalty, stored on ``self.step_penalty``.

    Raises:
        RuntimeError: If ``EmeraldEmulator`` could not be imported.
        FileNotFoundError: If the ROM or the scenario savestate is not a file.
    """
    if EmeraldEmulator is None:
        raise RuntimeError(
            "pokemon_env.emulator.EmeraldEmulator import failed. "
            "Install pokeagent-speedrun with mgba dependencies before running this adapter."
        )

    # Accept plain strings (e.g. values read from env vars) as well as Paths;
    # previously a str crashed with AttributeError on ``.exists()``.
    rom_path = Path(rom_path)
    if not rom_path.is_file():
        raise FileNotFoundError(
            f"Pokémon Emerald ROM not found at {rom_path}. "
            "Set POKEMON_EMERALD_ROM or upload emerald.gba to the deployment volume."
        )

    checkpoint_ref = scenario.get("checkpoint_ref")
    checkpoint_path = scenario.get("checkpoint_path")
    # A savestate must be a regular file; a same-named directory would
    # presumably only fail later inside the emulator's load_state().
    if not checkpoint_path or not Path(checkpoint_path).is_file():
        # ``.get`` so a scenario without ``name`` still yields this error
        # instead of masking it with a KeyError.
        raise FileNotFoundError(
            f"Savestate for scenario '{scenario.get('name')}' not found at {checkpoint_path} "
            f"(reference: {checkpoint_ref})."
        )

    self.scenario = scenario
    self.rom_path = rom_path
    self.checkpoint_path = Path(checkpoint_path)
    self.frames_per_step = frames_per_step
    self.step_penalty = step_penalty
    # 0 (or missing) disables the step timeout in ``step``.
    self.timeout_steps = int(scenario.get("timeout_steps") or 0)

    # Runtime state; populated by reset().
    self._emu: EmeraldEmulator | None = None
    self._step_count = 0
    self._episode_return = 0.0
    self._prev_badges = 0
    self._prev_location: str | None = None
|
|
393
|
+
|
|
394
|
+
def reset(self) -> dict[str, Any]:
    """Boot a fresh emulator from the scenario savestate and return the first observation.

    Any existing emulator is closed first. The step counter and episode return
    are zeroed, and the ``_prev_badges`` / ``_prev_location`` baselines are
    re-captured from the initial observation.
    """
    self.close()
    emulator = EmeraldEmulator(str(self.rom_path), headless=True, sound=False)
    self._emu = emulator
    emulator.initialize()
    emulator.load_state(path=str(self.checkpoint_path))

    self._step_count, self._episode_return = 0, 0.0

    first_obs = self._build_observation()
    player = first_obs["player_state"]
    self._prev_badges = len(player.get("badges", []))
    self._prev_location = player.get("location")
    return first_obs
|
|
407
|
+
|
|
408
|
+
def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], float, bool, dict[str, Any]]:
    """Apply one macro action and advance the emulator.

    Args:
        action: Either a mapping with a ``"macro"`` key (one of
            ``MACRO_BUTTONS``) and an optional ``"frames"`` count, or a bare
            macro name string. A missing or falsy ``frames`` falls back to
            ``self.frames_per_step``. The frame count is clamped to ``[1, 120]``.

    Returns:
        ``(observation, reward, done, info)``. ``done`` is true once
        ``timeout_steps`` is non-zero and the step count has reached it.
        ``info`` carries the macro, frame count, step count and running
        episode return.

    Raises:
        RuntimeError: if ``reset()`` has not been called.
        ValueError: if the macro is not in ``MACRO_BUTTONS``.
    """
    if self._emu is None:
        raise RuntimeError("Environment not initialised. Call reset() first.")

    # Rollout ops may be forwarded as bare strings. Treat a string as the
    # macro name instead of failing on ``str.get``.
    if isinstance(action, str):
        action = {"macro": action}
    action = action or {}

    macro = action.get("macro")
    if macro not in self.MACRO_BUTTONS:
        raise ValueError(
            f"Unsupported macro '{macro}'. "
            f"Valid macros: {sorted(self.MACRO_BUTTONS)}"
        )
    frames = int(action.get("frames") or self.frames_per_step)
    frames = max(1, min(frames, 120))

    # The same button set is held for every emulated frame of this step.
    buttons = self.MACRO_BUTTONS[macro]
    for _ in range(frames):
        self._emu.run_frame_with_buttons(buttons)

    obs = self._build_observation()
    reward = self._compute_reward(obs, macro)
    self._episode_return += reward
    self._step_count += 1

    done = bool(self.timeout_steps and self._step_count >= self.timeout_steps)
    info = {
        "macro": macro,
        "frames": frames,
        "step_count": self._step_count,
        "episode_return": self._episode_return,
    }
    return obs, reward, done, info
|
|
438
|
+
|
|
439
|
+
def snapshot(self) -> bytes:
    """Capture the emulator's current savestate.

    Returns:
        The data returned by the emulator's ``save_state()``, copied into an
        immutable ``bytes`` object.

    Raises:
        RuntimeError: if ``reset()`` has not been called, or if the emulator
            returned ``None`` instead of savestate data.
    """
    emulator = self._emu
    if emulator is None:
        raise RuntimeError("Environment not initialised. Call reset() first.")
    raw_state = emulator.save_state()
    if raw_state is not None:
        return bytes(raw_state)
    raise RuntimeError("Failed to capture Emerald savestate bytes.")
|
|
446
|
+
|
|
447
|
+
def restore(self, snapshot_bytes: bytes) -> dict[str, Any]:
    """Load a savestate previously produced by ``snapshot()``.

    The reward baselines (badge count and location) are re-derived from the
    restored observation. Step count and episode return are left unchanged.

    Returns:
        The observation for the restored state.

    Raises:
        RuntimeError: if ``reset()`` has not been called.
    """
    emulator = self._emu
    if emulator is None:
        raise RuntimeError("Environment not initialised. Call reset() first.")
    emulator.load_state(state_bytes=snapshot_bytes)

    observation = self._build_observation()
    player_state = observation["player_state"]
    self._prev_badges = len(player_state.get("badges", []))
    self._prev_location = player_state.get("location")
    return observation
|
|
455
|
+
|
|
456
|
+
def close(self) -> None:
    """Stop the running emulator, if any, and detach it from the adapter.

    Exceptions raised by the emulator's ``stop()`` are swallowed because this
    is best-effort clean-up. Afterwards ``self._emu`` is always ``None``.
    """
    emulator = self._emu
    if emulator is None:
        return
    try:
        emulator.stop()
    except Exception:  # pragma: no cover - best effort clean-up
        pass
    self._emu = None
|
|
463
|
+
|
|
464
|
+
# -- helpers -------------------------------------------------------
|
|
465
|
+
def _encode_frame(self, image) -> str | None:
|
|
466
|
+
if image is None:
|
|
467
|
+
return None
|
|
468
|
+
buffer = BytesIO()
|
|
469
|
+
image.save(buffer, format="PNG")
|
|
470
|
+
encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
|
|
471
|
+
return f"data:image/png;base64,{encoded}"
|
|
472
|
+
|
|
473
|
+
def _build_observation(self) -> dict[str, Any]:
|
|
474
|
+
if self._emu is None or self._emu.memory_reader is None:
|
|
475
|
+
raise RuntimeError("Emerald emulator not initialised.")
|
|
476
|
+
|
|
477
|
+
state = self._emu.get_comprehensive_state()
|
|
478
|
+
player = state.get("player", {})
|
|
479
|
+
game = state.get("game", {})
|
|
480
|
+
visual = state.get("visual", {})
|
|
481
|
+
frame_encoded = self._encode_frame(visual.get("screenshot"))
|
|
482
|
+
|
|
483
|
+
badges = game.get("badges") or []
|
|
484
|
+
location = player.get("location")
|
|
485
|
+
coords = player.get("position") or {}
|
|
486
|
+
|
|
487
|
+
summary_bits = [
|
|
488
|
+
f"Location: {location}",
|
|
489
|
+
f"Position: ({coords.get('x')}, {coords.get('y')})",
|
|
490
|
+
f"Badges: {len(badges)}",
|
|
491
|
+
]
|
|
492
|
+
if game.get("game_state"):
|
|
493
|
+
summary_bits.append(f"State: {game['game_state']}")
|
|
494
|
+
if game.get("is_in_battle"):
|
|
495
|
+
summary_bits.append("In battle")
|
|
496
|
+
|
|
497
|
+
observation = {
|
|
498
|
+
"player_state": {
|
|
499
|
+
"name": player.get("name"),
|
|
500
|
+
"position": coords,
|
|
501
|
+
"facing": player.get("facing"),
|
|
502
|
+
"location": location,
|
|
503
|
+
"badges": badges,
|
|
504
|
+
"game_time": game.get("time"),
|
|
505
|
+
},
|
|
506
|
+
"party": game.get("party"),
|
|
507
|
+
"inventory": game.get("items"),
|
|
508
|
+
"flags": {
|
|
509
|
+
"game_state": game.get("game_state"),
|
|
510
|
+
"in_battle": bool(game.get("is_in_battle")),
|
|
511
|
+
},
|
|
512
|
+
"frame_png": frame_encoded,
|
|
513
|
+
"text": " | ".join(filter(None, summary_bits)),
|
|
514
|
+
}
|
|
515
|
+
return observation
|
|
516
|
+
|
|
517
|
+
def _compute_reward(self, observation: dict[str, Any], macro: str) -> float:
|
|
518
|
+
reward = -self.step_penalty
|
|
519
|
+
|
|
520
|
+
badge_count = len(observation["player_state"].get("badges", []))
|
|
521
|
+
if badge_count > self._prev_badges:
|
|
522
|
+
reward += (badge_count - self._prev_badges) * self.BADGE_REWARD
|
|
523
|
+
|
|
524
|
+
location = observation["player_state"].get("location")
|
|
525
|
+
if location and location != self._prev_location:
|
|
526
|
+
reward += self.LOCATION_REWARD
|
|
527
|
+
|
|
528
|
+
if macro in {"press_a", "open_menu"}:
|
|
529
|
+
reward += 0.02
|
|
530
|
+
|
|
531
|
+
self._prev_badges = badge_count
|
|
532
|
+
self._prev_location = location
|
|
533
|
+
return reward
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def _emerald_rom_candidates(scenario: dict[str, Any]) -> list[Path]:
    """Return ROM locations to probe for *scenario*, highest priority first.

    The order is:

    1. ``$POKEMON_EMERALD_ROM``.
    2. ``$POKEMON_EMERALD_ASSETS/emerald.gba``.
    3. ``rom.gba`` then ``emerald.gba`` in the directory of the scenario's
       ``checkpoint_ref`` and ``checkpoint_path``, for whichever are set.
    """
    candidates: list[Path] = []

    env_rom = os.getenv("POKEMON_EMERALD_ROM")
    if env_rom:
        candidates.append(Path(env_rom).expanduser())

    assets_root = os.getenv("POKEMON_EMERALD_ASSETS")
    if assets_root:
        candidates.append(Path(assets_root).expanduser() / "emerald.gba")

    # Fall back to a ROM stored next to the savestate. ``rom.gba`` is probed
    # first for compatibility. ``emerald.gba`` matches the file name the
    # "ROM not found" error tells operators to use. Unset references are
    # skipped rather than crashing on ``Path(None)``.
    for key in ("checkpoint_ref", "checkpoint_path"):
        ref = scenario.get(key)
        if not ref:
            continue
        parent = Path(ref).expanduser().resolve().parent
        candidates.extend(parent / name for name in ("rom.gba", "emerald.gba"))
    return candidates


async def rollout_executor(request: RolloutRequest, fastapi_request: Request) -> RolloutResponse:
    """Run one Emerald episode for *request* and package it as a rollout.

    The scenario is resolved from the dataset stored on the app state as
    ``emerald_dataset``. The ROM is the first existing path from
    ``_emerald_rom_candidates``. Each op carrying an action (a dict's
    ``"action"`` entry, or the op itself) is applied with
    ``PokemonEmeraldAdapter.step`` until the adapter reports ``done``. The
    adapter is always closed before returning.

    Raises:
        HTTPException: 500 if the dataset or ROM cannot be found; 400 if an
            op's action is rejected by the adapter (e.g. an unknown macro).
    """
    # ``app.state`` is a Starlette ``State`` object: values are attributes and
    # it has no ``dict.get`` method, so look the dataset up with ``getattr``.
    dataset: PokemonEmeraldDataset | None = getattr(
        fastapi_request.app.state, "emerald_dataset", None
    )
    if dataset is None:
        raise HTTPException(status_code=500, detail="Emerald dataset missing from app state.")

    seed = dataset.resolve_seed(request.env.seed)
    scenario = dataset.describe_seed(seed)

    rom_path = next(
        (candidate for candidate in _emerald_rom_candidates(scenario) if candidate.exists()),
        None,
    )
    if rom_path is None:
        raise HTTPException(
            status_code=500,
            detail=(
                "Unable to locate Pokémon Emerald ROM. "
                "Set POKEMON_EMERALD_ROM or place emerald.gba alongside the savestates."
            ),
        )

    env_config = request.env.config or {}
    frames_per_step = int(env_config.get("frames_per_step", 6))
    adapter = PokemonEmeraldAdapter(
        scenario=scenario,
        rom_path=rom_path,
        frames_per_step=frames_per_step,
    )

    # NOTE(review): the adapter methods are synchronous, so the whole episode
    # runs on the event-loop thread.
    try:
        obs0 = adapter.reset()
        steps: list[RolloutStep] = [
            RolloutStep(
                obs=obs0,
                tool_calls=[],
                reward=0.0,
                done=False,
                info={"available_macros": sorted(PokemonEmeraldAdapter.MACRO_BUTTONS)},
            ),
        ]

        total_reward = 0.0
        done = False
        ops_executed = 0  # only ops that actually stepped the emulator

        for op in request.ops or []:
            if done:
                break
            action_payload = op.get("action") if isinstance(op, dict) else op
            if action_payload is None:
                continue
            try:
                obs, reward, done, info = adapter.step(action_payload)
            except ValueError as exc:
                # The adapter raises ValueError for an unsupported macro; that
                # is a problem with the request, not the server.
                raise HTTPException(status_code=400, detail=str(exc)) from exc
            ops_executed += 1
            total_reward += reward
            steps.append(
                RolloutStep(obs=obs, tool_calls=[], reward=reward, done=done, info=info),
            )

        final_obs = steps[-1].obs
        metrics = RolloutMetrics(
            episode_returns=[total_reward],
            mean_return=total_reward,
            num_steps=len(steps) - 1,  # exclude the initial reset observation
            num_episodes=1,
            outcome_score=total_reward,
            details={
                "seed": seed,
                "scenario": scenario["name"],
                "checkpoint_ref": scenario.get("checkpoint_ref"),
                "assets_ready": scenario["assets_ready"],
            },
        )

        trajectory = RolloutTrajectory(
            env_id="pokemon_emerald",
            policy_id=request.policy.policy_id or "policy",
            steps=steps,
            final={"observation": final_obs, "reward": total_reward, "done": done},
            length=len(steps),
        )

        return RolloutResponse(
            run_id=request.run_id,
            trajectories=[trajectory],
            branches={},
            metrics=metrics,
            aborted=False,
            ops_executed=ops_executed,
            trace=None,
        )
    finally:
        adapter.close()
|
|
634
|
+
|
|
635
|
+
|
|
636
|
+
def build_config() -> TaskAppConfig:
    """Assemble the ``TaskAppConfig`` for the Pokémon Emerald task app.

    The dataset registry and dataset come from ``_build_dataset_registry``.
    The dataset is also passed as ``app_state["emerald_dataset"]``, which is
    the key ``rollout_executor`` looks up. The proxy enables OpenAI and Groq
    with a JSON-macro system hint, and the app requires an API key.
    """
    registry, dataset = _build_dataset_registry()
    base_info = _base_task_info(dataset)

    proxy = ProxyConfig(
        enable_openai=True,
        enable_groq=True,
        system_hint="Respond with Emerald macro actions encoded as JSON.",
    )
    return TaskAppConfig(
        app_id="pokemon_emerald",
        name="Pokémon Emerald Task App",
        description="Expose Emerald speedrun checkpoints via the Synth AI task framework.",
        base_task_info=base_info,
        describe_taskset=lambda: describe_taskset(dataset),
        provide_task_instances=lambda seeds: provide_task_instances(dataset, base_info, seeds),
        rollout=rollout_executor,
        dataset_registry=registry,
        proxy=proxy,
        app_state={"emerald_dataset": dataset},
        require_api_key=True,
        expose_debug_env=True,
        cors_origins=["*"],
    )
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
# Register the Emerald task app, with its aliases and Modal deployment
# settings, when this module is imported.
register_task_app(
    entry=TaskAppEntry(
        app_id="pokemon_emerald",
        aliases=("pokemon_speedrun", "pokemon_track2"),
        description="Pokémon Emerald (Track 2) task app skeleton.",
        config_factory=build_config,
        env_files=(),
        modal=ModalDeploymentConfig(
            app_name="pokemon-emerald-task-app",
            python_version="3.11",
            pip_packages=("horizons-ai",),
            # NOTE(review): presumably (local directory, container path)
            # pairs to mount; confirm against ModalDeploymentConfig.
            extra_local_dirs=(
                ("repo", "/opt/synth_ai_repo"),
                ("pokeagent_speedrun", "/external/pokeagent-speedrun"),
            ),
            secret_names=("ENVIRONMENT_API_KEY", "OPENAI_API_KEY", "GROQ_API_KEY"),
            cpu=4.0,
            memory=9216,
            timeout=900,
        ),
    )
)


__all__ = ["build_config"]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Evaluation config for running Groq Qwen/Qwen3-32B on the Enron email QA task app.
|
|
2
|
+
|
|
3
|
+
provider = "groq"
|
|
4
|
+
task_app_url = "http://127.0.0.1:8102"
|
|
5
|
+
model = "qwen/qwen3-32b"
|
|
6
|
+
seeds = [0, 1, 2]
|
|
7
|
+
max_turns = 8
|
|
8
|
+
concurrency = 1
|
|
9
|
+
|
|
10
|
+
[policy]
|
|
11
|
+
provider = "groq"
|
|
12
|
+
model = "qwen/qwen3-32b"
|
|
13
|
+
temperature = 0.2
|
|
14
|
+
top_p = 0.8
|
|
15
|
+
max_tokens = 1024
|
|
16
|
+
max_turns = 8
|