synth-ai 0.2.13.dev1__py3-none-any.whl → 0.2.13.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +12 -1
- examples/swe/task_app/grpo_swe_mini.py +55 -26
- examples/swe/task_app/hosted/rollout.py +40 -0
- examples/swe/task_app/hosted/test_service.py +5 -6
- examples/task_apps/TESTING.md +275 -0
- examples/task_apps/__init__.py +0 -0
- examples/task_apps/crafter/__init__.py +0 -0
- examples/task_apps/crafter/task_app/__init__.py +2 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +18 -13
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +60 -4
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +25 -3
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +10 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
- examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
- examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
- examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
- examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
- examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
- examples/task_apps/enron/task_app/README.md +14 -0
- examples/task_apps/enron/task_app/__init__.py +1 -0
- examples/task_apps/enron/task_app/grpo_enron.py +906 -0
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/conftest.py +115 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +177 -0
- examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
- examples/task_apps/math/__init__.py +0 -0
- examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
- examples/task_apps/pokemon_battle/__init__.py +2 -0
- examples/task_apps/pokemon_battle/modal_app.py +104 -0
- examples/task_apps/pokemon_battle/task_app/README.md +68 -0
- examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
- examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
- examples/task_apps/pokemon_red/README.md +357 -0
- examples/task_apps/pokemon_red/__init__.py +3 -0
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +73 -0
- examples/task_apps/pokemon_red/task_app.py +606 -0
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +191 -0
- examples/task_apps/sokoban/README.md +307 -0
- examples/task_apps/sokoban/__init__.py +3 -0
- examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
- examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
- examples/task_apps/sokoban/task_app.py +1058 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/conftest.py +113 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
- examples/task_apps/verilog/__init__.py +1 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +20 -0
- examples/task_apps/verilog/task_app/README.md +12 -0
- examples/task_apps/verilog/task_app/__init__.py +1 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +931 -0
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/conftest.py +115 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +179 -0
- examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
- examples/vlm/crafter_openai_vlm_agent.py +4 -4
- examples/vlm/run_crafter_vlm_benchmark.py +4 -4
- examples/workflows/__init__.py +0 -0
- examples/workflows/math_rl/__init__.py +0 -0
- examples/workflows/math_rl/download_dataset.py +80 -0
- synth_ai/__init__.py +2 -2
- synth_ai/api/train/builders.py +25 -11
- synth_ai/api/train/cli.py +12 -6
- synth_ai/api/train/configs/__init__.py +10 -10
- synth_ai/api/train/configs/rl.py +5 -4
- synth_ai/api/train/configs/sft.py +4 -3
- synth_ai/api/train/env_resolver.py +5 -2
- synth_ai/api/train/supported_algos.py +10 -5
- synth_ai/api/train/utils.py +7 -4
- synth_ai/cli/__init__.py +7 -51
- synth_ai/cli/_storage.py +4 -3
- synth_ai/cli/_validate_task_app.py +11 -0
- synth_ai/cli/balance.py +4 -3
- synth_ai/cli/calc.py +2 -2
- synth_ai/cli/demo.py +14 -7
- synth_ai/cli/legacy_root_backup.py +1 -1
- synth_ai/cli/rl_demo.py +8 -7
- synth_ai/cli/root.py +0 -97
- synth_ai/cli/task_apps.py +1707 -186
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +28 -16
- synth_ai/environments/examples/enron/engine.py +7 -2
- synth_ai/environments/examples/enron/environment.py +68 -0
- synth_ai/environments/examples/red/engine.py +27 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
- synth_ai/environments/examples/red/environment.py +60 -0
- synth_ai/environments/examples/sokoban/taskset.py +116 -0
- synth_ai/environments/examples/verilog/engine.py +30 -4
- synth_ai/evals/client.py +58 -61
- synth_ai/jobs/client.py +16 -4
- synth_ai/judge_schemas.py +16 -16
- synth_ai/py.typed +0 -0
- synth_ai/task/__init__.py +14 -5
- synth_ai/task/contracts.py +124 -38
- synth_ai/task/proxy.py +48 -56
- synth_ai/task/rubrics/__init__.py +53 -0
- synth_ai/task/rubrics/loaders.py +133 -0
- synth_ai/task/rubrics/models.py +57 -0
- synth_ai/task/rubrics/scoring.py +113 -0
- synth_ai/{rubrics/validators.py → task/rubrics/strict.py} +53 -30
- synth_ai/task/server.py +8 -7
- synth_ai/task/validators.py +269 -6
- synth_ai/tracing_v3/decorators.py +7 -3
- synth_ai/tracing_v3/replica_sync.py +4 -4
- synth_ai/tracing_v3/serialization.py +5 -5
- synth_ai/tracing_v3/trace_utils.py +317 -0
- synth_ai/tracing_v3/turso/native_manager.py +3 -3
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/METADATA +4 -1
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/RECORD +214 -101
- examples/agora_ex/README_MoE.md +0 -224
- examples/agora_ex/__init__.py +0 -7
- examples/agora_ex/agora_ex.py +0 -65
- examples/agora_ex/agora_ex_task_app.py +0 -590
- examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +0 -121
- examples/agora_ex/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/system_prompt_CURRENT.md +0 -63
- examples/agora_ex/task_app/agora_ex_task_app.py +0 -590
- examples/agora_ex/task_app/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/task_app/system_prompt_CURRENT.md +0 -63
- synth_ai/rubrics/__init__.py +0 -22
- synth_ai/task/rubrics.py +0 -219
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/utils.py +0 -0
- /examples/{rl/task_app → task_apps/math}/README.md +0 -0
- /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
- /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,932 @@
|
|
|
1
|
+
"""Task App configuration for a Horizons-backed Pokémon Showdown battle environment."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import math
|
|
7
|
+
import os
|
|
8
|
+
import random
|
|
9
|
+
import sys
|
|
10
|
+
from copy import deepcopy
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any, Iterable, Sequence
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import numpy as np
|
|
17
|
+
from fastapi import HTTPException, Request
|
|
18
|
+
from poke_env.data.gen_data import GenData
|
|
19
|
+
from poke_env.environment.battle import Battle
|
|
20
|
+
from poke_env.environment.move import Move
|
|
21
|
+
from poke_env.environment.pokemon import Pokemon
|
|
22
|
+
from poke_env.environment.pokemon_type import PokemonType
|
|
23
|
+
from poke_env.player.battle_order import BattleOrder
|
|
24
|
+
from poke_env.player.local_simulation import LocalSim
|
|
25
|
+
from poke_env.player.baselines import AbyssalPlayer
|
|
26
|
+
from poke_env.teambuilder.teambuilder import Teambuilder
|
|
27
|
+
from poke_env.teambuilder.teambuilder_pokemon import TeambuilderPokemon
|
|
28
|
+
|
|
29
|
+
from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
30
|
+
from synth_ai.task.contracts import (
|
|
31
|
+
RolloutMetrics,
|
|
32
|
+
RolloutRequest,
|
|
33
|
+
RolloutResponse,
|
|
34
|
+
RolloutStep,
|
|
35
|
+
RolloutTrajectory,
|
|
36
|
+
TaskInfo,
|
|
37
|
+
)
|
|
38
|
+
from synth_ai.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
|
|
39
|
+
from synth_ai.task.server import ProxyConfig, TaskAppConfig
|
|
40
|
+
|
|
41
|
+
logger = logging.getLogger(__name__)


# Dataset descriptor shared by the registry and by every TaskInfo this module builds.
DATASET_SPEC = TaskDatasetSpec(
    id="pokemon_showdown_reference",
    name="Pokémon Showdown Reference Matches",
    version="0.1.0",
    splits=["train", "eval"],
    default_split="train",
    description=(
        "Seeded Gen 9 OU matches derived from the PokeChamp benchmark packs and "
        "PokéAgent Track 1 starter kit."
    ),
)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _resolve_repo_root(env_key: str, repo_dir: str) -> Path | None:
    """Locate an external repository checkout on disk.

    The environment variable ``env_key`` wins when it names an existing path.
    Otherwise every ancestor directory of this file is probed, nearest first,
    for ``<ancestor>/external/<repo_dir>`` and then ``<ancestor>/<repo_dir>``.

    Args:
        env_key: Name of the environment variable that may hold the path.
        repo_dir: Directory name of the repository to search for.

    Returns:
        The resolved path of the first existing candidate, or ``None`` when
        nothing is found.
    """
    env_path = os.getenv(env_key)
    if env_path:
        candidate = Path(env_path).expanduser()
        if candidate.exists():
            return candidate.resolve()

    here = Path(__file__).resolve()
    for ancestor in here.parents:
        # Probe lazily in the same order the candidates were originally listed.
        for candidate in (ancestor / "external" / repo_dir, ancestor / repo_dir):
            try:
                resolved = candidate.resolve()
            except Exception:  # pragma: no cover - path resolution edge cases
                continue
            if resolved.exists():
                return resolved
    return None
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _ensure_on_path(path: Path | None) -> None:
    """Put ``path`` at the front of ``sys.path`` unless it is already listed.

    A falsy ``path`` (for example ``None``) is ignored.
    """
    if not path:
        return
    path_str = str(path)
    if path_str not in sys.path:
        sys.path.insert(0, path_str)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _read_text_if_exists(path: Path | None) -> str | None:
    """Return the text content of ``path``, or ``None`` when it cannot be read.

    The file is decoded as UTF-8 so the result does not depend on the platform
    default encoding. Read failures are reported as ``None`` instead of being
    raised.
    """
    if not path:
        return None
    try:
        return path.read_text(encoding="utf-8")
    except (OSError, ValueError):
        # OSError: missing, unreadable, or a directory.
        # ValueError: undecodable bytes (UnicodeDecodeError) or an invalid path string.
        return None
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@dataclass(frozen=True)
class PokemonBattleScenario:
    """Immutable description of one seeded battle scenario."""

    seed: int  # key under which the dataset stores the scenario
    name: str
    format_id: str
    player_team_ref: str  # team reference resolved by the dataset loader
    opponent_team_ref: str
    description: str
    source: str
    tags: tuple[str, ...] = ()
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class PokemonBattleDataset:
    """In-memory catalogue of deterministic battle scenarios."""

    def __init__(self, spec: TaskDatasetSpec) -> None:
        """Resolve the pokechamp checkout and register the built-in scenarios."""
        self.spec = spec
        self.repo_root = _resolve_repo_root("POKECHAMP_ROOT", "pokechamp")
        _ensure_on_path(self.repo_root)

        # Directories searched for team files in addition to the repo root itself.
        self._team_roots: list[Path] = []
        if self.repo_root:
            self._team_roots.extend(
                [
                    self.repo_root / "poke_env" / "data" / "static" / "gen9" / "ou",
                    self.repo_root / "poke_env" / "data" / "static" / "teams",
                    self.repo_root / "resource" / "teams",
                ]
            )

        scenarios: list[PokemonBattleScenario] = [
            PokemonBattleScenario(
                seed=1001,
                name="hazard_balance_vs_pivot_offense",
                format_id="gen9ou",
                player_team_ref="poke_env/data/static/gen9/ou/gen9ou-1825.txt",
                opponent_team_ref="poke_env/data/static/gen9/ou/gen9ou-1500.txt",
                description=(
                    "Balanced hazard stack roster into a pivot-heavy opponent. "
                    "Mirrors the PokeChamp ICML 2025 evaluation seed."
                ),
                source="https://github.com/sethkarten/pokechamp",
                tags=("benchmark", "pokechamp", "gen9"),
            ),
            PokemonBattleScenario(
                seed=2002,
                name="sunroom_vs_rainroom",
                format_id="gen9ou",
                player_team_ref="poke_env/data/static/gen9/ou/gen9ou-1500.txt",
                opponent_team_ref="poke_env/data/static/gen9/ou/gen9ou-0.txt",
                description="Weather control showdown drawn from the PokéAgent ladder starter.",
                source="https://pokeagent.github.io/track1.html",
                tags=("ladder", "weather", "gen9"),
            ),
            PokemonBattleScenario(
                seed=3003,
                name="stall_vs_hyper_offense",
                format_id="gen9ou",
                player_team_ref="poke_env/data/static/gen9/ou/gen9ou-0.txt",
                opponent_team_ref="poke_env/data/static/gen9/ou/gen9ou-1825.txt",
                description="Long-horizon stall versus hyper-offense curriculum seed.",
                source="https://github.com/sethkarten/pokechamp",
                tags=("curriculum", "gen9"),
            ),
        ]
        self._scenarios: dict[int, PokemonBattleScenario] = {s.seed: s for s in scenarios}
        self.default_seed = scenarios[0].seed

    @property
    def seeds(self) -> list[int]:
        """All registered seeds in ascending order."""
        return sorted(self._scenarios)

    @property
    def formats(self) -> list[str]:
        """Distinct format ids used by the registered scenarios, sorted."""
        return sorted({scenario.format_id for scenario in self._scenarios.values()})

    @property
    def count(self) -> int:
        """Number of registered scenarios."""
        return len(self._scenarios)

    def resolve_seed(self, seed: int | None) -> int:
        """Return ``seed`` if registered, or the default seed when ``seed`` is ``None``.

        Raises:
            KeyError: If ``seed`` is not a registered scenario seed.
        """
        if seed is None:
            return self.default_seed
        if seed not in self._scenarios:
            raise KeyError(f"Unknown battle seed: {seed}")
        return seed

    def describe_seed(self, seed: int) -> dict[str, Any]:
        """Return the scenario for ``seed`` as a plain dict, with team text loaded.

        ``assets_ready`` is true only when both team files yielded non-empty text.

        Raises:
            KeyError: If ``seed`` is not a registered scenario seed.
        """
        scenario = self._scenarios.get(seed)
        if not scenario:
            raise KeyError(f"Unknown battle seed: {seed}")

        player_team_text = self._load_team_text(scenario.player_team_ref)
        opponent_team_text = self._load_team_text(scenario.opponent_team_ref)

        return {
            "seed": seed,
            "name": scenario.name,
            "format_id": scenario.format_id,
            "player_team_ref": scenario.player_team_ref,
            "player_team": player_team_text,
            "opponent_team_ref": scenario.opponent_team_ref,
            "opponent_team": opponent_team_text,
            "description": scenario.description,
            "source": scenario.source,
            "tags": list(scenario.tags),
            "assets_ready": bool(player_team_text and opponent_team_text),
        }

    def _load_team_text(self, reference: str) -> str | None:
        """Resolve a team reference to its text, or ``None`` if nothing usable is found.

        A reference starting with ``text:`` carries the team inline. Anything
        else is treated as a path and tried as given (when absolute), under
        the repo root, and under each team root by bare file name and by full
        reference.
        """
        raw_ref = reference.strip()
        if not raw_ref:
            return None

        if raw_ref.startswith("text:"):
            return raw_ref.split(":", 1)[1]

        candidates: list[Path] = []
        ref_path = Path(raw_ref)
        if ref_path.is_absolute():
            candidates.append(ref_path)

        if self.repo_root:
            candidates.append(self.repo_root / raw_ref)

        for base in self._team_roots:
            candidates.append(base / ref_path.name)
            candidates.append(base / raw_ref)

        for candidate in candidates:
            # A directory or unreadable file must not hide a later, valid candidate.
            if not candidate.is_file():
                continue
            text = _read_text_if_exists(candidate)
            if text is not None:
                return text
        return None
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def _build_dataset_registry() -> tuple[TaskDatasetRegistry, PokemonBattleDataset]:
    """Create a registry holding one shared ``PokemonBattleDataset``.

    Returns:
        The registry and the dataset instance its factory returns.
    """
    registry = TaskDatasetRegistry()
    dataset = PokemonBattleDataset(DATASET_SPEC)
    # The factory ignores its spec argument and always hands back this instance.
    registry.register(DATASET_SPEC, lambda _spec: dataset, cache=True)
    return registry, dataset
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def _base_task_info(dataset: PokemonBattleDataset) -> TaskInfo:
    """Build the seed-independent ``TaskInfo`` describing this task app.

    Dataset-derived fields (seed list, formats, repo root) are read from
    ``dataset``; everything else is static.
    """
    return TaskInfo(
        task={"id": "pokemon_showdown", "name": "Pokémon Showdown Battle", "version": "0.1.0"},
        environment="pokemon_showdown",
        action_space={
            "type": "structured",
            "schema": {
                "type": "object",
                "properties": {
                    "action": {"enum": ["move", "switch", "team-preview"]},
                    "index": {"type": "integer", "minimum": 0},
                    "target": {"type": "integer", "minimum": 0, "nullable": True},
                    "metadata": {"type": "object"},
                },
                "required": ["action"],
            },
            "notes": "Legal indices are surfaced in observation['legal_actions'].",
        },
        observation={
            "summary": "Structured Showdown state and a text rendering per turn.",
            "keys": ["structured", "legal_actions", "text"],
            "text_role": "Battle transcript for language agents.",
        },
        dataset={
            **DATASET_SPEC.model_dump(),
            "seed_count": dataset.count,
            "seeds": dataset.seeds,
            "formats": dataset.formats,
            "source_repos": [
                "https://github.com/sethkarten/pokechamp",
                "https://pokeagent.github.io/track1.html",
            ],
            "pokechamp_root": str(dataset.repo_root) if dataset.repo_root else None,
        },
        rubric={
            "version": "1",
            "criteria_count": 2,
            "source": "inline",
            "summary": "Win/loss outcome plus faint differential.",
        },
        inference={
            "supports_proxy": True,
            "tool": {"name": "battle_action", "parallel_tool_calls": False},
            "endpoints": {
                "openai": "/proxy/v1/chat/completions",
                "groq": "/proxy/groq/v1/chat/completions",
            },
        },
        limits={"max_turns": 200, "max_time_s": 1800, "max_ops": 4096},
        task_metadata={
            "preferred_engine": "pokechamp",
            "supports_remote_server": True,
            "documentation": "https://github.com/sethkarten/pokechamp",
        },
    )
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def describe_taskset(dataset: PokemonBattleDataset) -> dict[str, Any]:
    """Summarise the dataset: spec fields plus seed count, seeds and formats.

    ``assets_ready`` is true only when every seed reports its team assets as
    ready, which means each seed's team references are loaded to compute it.
    """
    return {
        **DATASET_SPEC.model_dump(),
        "count": dataset.count,
        "seeds": dataset.seeds,
        "formats": dataset.formats,
        "assets_ready": all(dataset.describe_seed(seed)["assets_ready"] for seed in dataset.seeds),
    }
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def provide_task_instances(
    dataset: PokemonBattleDataset, base_info: TaskInfo, seeds: Sequence[int]
) -> Iterable[TaskInfo]:
    """Expand ``base_info`` into one ``TaskInfo`` per requested seed.

    Each result copies the base task description and overlays the scenario
    details for its seed onto ``observation``, ``dataset`` and
    ``task_metadata``.

    Args:
        dataset: Scenario catalogue used to resolve and describe seeds.
        base_info: Seed-independent template.
        seeds: Seeds to expand; each is passed through ``dataset.resolve_seed``.

    Returns:
        A list of ``TaskInfo`` objects in the same order as ``seeds``.

    Raises:
        KeyError: If a seed is not registered in ``dataset``.
    """

    def _as_plain_dict(value: Any) -> dict[str, Any]:
        """Normalise a section that may be a model, a dict, or absent."""
        if hasattr(value, "model_dump"):
            return value.model_dump()
        if isinstance(value, dict):
            return dict(value)
        return {}

    # Both sections get the same normalisation, so a dict-valued ``dataset``
    # no longer fails on a missing ``model_dump``.
    observation_template = _as_plain_dict(getattr(base_info, "observation", None))
    dataset_template = _as_plain_dict(getattr(base_info, "dataset", None))

    infos: list[TaskInfo] = []
    for seed_value in seeds:
        resolved_seed = dataset.resolve_seed(seed_value)
        details = dataset.describe_seed(resolved_seed)
        infos.append(
            TaskInfo(
                task=base_info.task,
                environment=base_info.environment,
                action_space=base_info.action_space,
                observation={
                    **observation_template,
                    "seed": resolved_seed,
                    "format_id": details["format_id"],
                    "player_team_ref": details["player_team_ref"],
                    "opponent_team_ref": details["opponent_team_ref"],
                    "description": details["description"],
                },
                dataset={
                    **dataset_template,
                    "seed": resolved_seed,
                    "scenario": details,
                },
                rubric=base_info.rubric,
                inference=base_info.inference,
                limits=base_info.limits,
                task_metadata={
                    **base_info.task_metadata,
                    "source": details["source"],
                    "tags": details["tags"],
                    "assets_ready": details["assets_ready"],
                },
            )
        )
    return infos
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
class PokechampAssets:
    """Lazy loader for pokechamp static data used by the local simulator.

    The tables are stored on the class, so they are loaded at most once and
    shared by every caller. Once ``loaded`` is set, later ``ensure_loaded``
    calls return immediately whatever ``repo_root`` they are given.
    """

    move_effect: dict[str, Any] = {}
    pokemon_move_dict: dict[str, Any] = {}
    ability_effect: dict[str, Any] = {}
    pokemon_ability_dict: dict[str, Any] = {}
    item_effect: dict[str, Any] = {}
    pokemon_item_dict: dict[str, Any] = {}
    loaded = False

    @classmethod
    def ensure_loaded(cls, repo_root: Path) -> None:
        """Load the JSON tables from ``repo_root`` unless already loaded.

        Raises:
            FileNotFoundError: If any required asset file is missing. In that
                case no class attribute is modified.
        """
        if cls.loaded:
            return

        def _require_json(rel_path: str) -> dict[str, Any]:
            path = repo_root / rel_path
            if not path.exists():
                raise FileNotFoundError(
                    f"Required pokechamp asset missing at {path}. "
                    "Ensure POKECHAMP_ROOT is mounted with the repository assets."
                )
            with path.open("r", encoding="utf-8") as f:
                return json.load(f)

        # Read every file before publishing anything, so a missing asset cannot
        # leave the shared tables half-populated.
        move_effect = _require_json("poke_env/data/static/moves/moves_effect.json")
        pokemon_move_dict = _require_json("poke_env/data/static/moves/gen8pokemon_move_dict.json")
        ability_effect = _require_json("poke_env/data/static/abilities/ability_effect.json")
        pokemon_ability_dict = _require_json(
            "poke_env/data/static/abilities/gen8pokemon_ability_dict.json"
        )
        item_effect = _require_json("poke_env/data/static/items/item_effect.json")

        cls.move_effect = move_effect
        cls.pokemon_move_dict = pokemon_move_dict
        cls.ability_effect = ability_effect
        cls.pokemon_ability_dict = pokemon_ability_dict
        cls.item_effect = item_effect
        # No file is read for the item table; it is reset to an empty dict.
        cls.pokemon_item_dict = {}
        cls.loaded = True
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
class PokemonShowdownAdapter:
    """Local deterministic battle adapter powered by pokechamp's LocalSim.

    The adapter owns one simulated battle between the agent (side ``p1``) and
    an opponent (side ``p2``) and exposes ``reset`` / ``step`` /
    ``snapshot`` / ``restore``.

    Per-step reward is ``-STEP_PENALTY`` plus the change in :meth:`_score`
    (sum of ally HP fractions minus sum of opponent HP fractions); the step
    that ends the battle additionally adds ``WIN_REWARD``, ``LOSS_PENALTY``
    or ``0.0`` (both sides wiped out).
    """

    STEP_PENALTY = 0.05
    WIN_REWARD = 1.0
    LOSS_PENALTY = -1.0
    # Level used for team members whose Showdown export does not specify one.
    DEFAULT_LEVEL = 80

    def __init__(self, *, scenario: dict[str, Any], repo_root: Path, seed: int | None = None):
        """Build the base battle for ``scenario`` and start the first episode.

        Args:
            scenario: Scenario description. Reads the keys ``name``,
                ``format_id``, ``assets_ready``, ``player_team``,
                ``opponent_team`` and, optionally, ``seed``.
            repo_root: Root of the pokechamp checkout holding the static data.
            seed: RNG seed; when ``None`` the scenario's ``seed`` (default 0)
                is used.

        Raises:
            FileNotFoundError: If ``repo_root`` or a static asset is missing.
            ValueError: If the scenario lacks assets or team definitions.
        """
        if not repo_root.exists():
            raise FileNotFoundError(
                f"Pokechamp repository root not found at {repo_root}. "
                "Set POKECHAMP_ROOT to the cloned repository."
            )

        if not scenario.get("assets_ready"):
            raise ValueError(
                f"Scenario '{scenario['name']}' is missing team assets. "
                "Ensure the pokechamp dataset files are present."
            )

        PokechampAssets.ensure_loaded(repo_root)
        self.scenario = scenario
        self.repo_root = repo_root
        self.format_id = scenario["format_id"]

        seed_value = seed if seed is not None else scenario.get("seed", 0)
        # NOTE: these seed the process-global RNGs, not just this adapter.
        random.seed(seed_value)
        np.random.seed(seed_value)
        try:  # pragma: no cover - optional dependency
            import torch

            torch.manual_seed(seed_value)
        except Exception as exc:
            # torch is optional; seeding it is best-effort, but leave a trace
            # instead of discarding the reason silently.
            logger.debug("Skipping torch seeding: %s", exc)

        self.random = random.Random(seed_value)
        self.gen_data = GenData.from_format(self.format_id)

        player_team_text = scenario.get("player_team")
        opponent_team_text = scenario.get("opponent_team")
        if not player_team_text or not opponent_team_text:
            raise ValueError(
                f"Scenario '{scenario['name']}' is missing team definitions."
            )

        self._base_battle = self._build_base_battle(
            player_team_text=player_team_text,
            opponent_team_text=opponent_team_text,
        )
        self._start_episode(deepcopy(self._base_battle))

    def reset(self) -> dict[str, Any]:
        """Restart from a fresh copy of the base battle and return the first observation.

        The opponent player object is rebuilt as well, so nothing it may have
        kept from the previous episode carries over.
        """
        self._start_episode(deepcopy(self._base_battle))
        self._abyssal = self._make_opponent()
        return self._build_observation()

    def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], float, bool, dict[str, Any]]:
        """Apply one agent action together with the opponent's reply.

        Args:
            action: ``{"action": "move" | "switch", "index": int, ...}``; see
                :meth:`_action_to_order`.

        Returns:
            ``(observation, reward, done, info)`` where ``info`` carries the
            legal actions for the next turn and the turn counter.

        Raises:
            RuntimeError: If the battle has already finished.
            ValueError / IndexError: If ``action`` is malformed or out of range.
        """
        if self.done:
            raise RuntimeError("Battle finished. Call reset() before stepping again.")

        agent_order = self._action_to_order(action)
        opponent_order = self._opponent_policy()

        self.sim.step(agent_order, opponent_order)
        self.battle = self.sim.battle
        self._sync_available_actions()

        # Shaped reward: constant step cost plus the swing in HP advantage.
        reward = -self.STEP_PENALTY
        current = self._score()
        reward += current - self._prev_score
        self._prev_score = current
        self.turn += 1

        self.done = self._check_finished()
        if self.done:
            reward += self.outcome

        observation = self._build_observation()
        info = {"legal_actions": self._legal_actions(), "turn": self.turn}
        return observation, reward, self.done, info

    def snapshot(self) -> bytes:
        """Serialize the current battle, counters, RNG state and done flag with pickle."""
        import pickle

        state = {
            "battle": deepcopy(self.battle),
            "turn": self.turn,
            "prev_score": self._prev_score,
            "random_state": self.random.getstate(),
            "outcome": self.outcome,
            # Persisted so a finished battle stays finished after restore().
            "done": self.done,
        }
        return pickle.dumps(state)

    def restore(self, snapshot_bytes: bytes) -> dict[str, Any]:
        """Load a state produced by :meth:`snapshot` and return its observation.

        SECURITY: ``pickle.loads`` executes arbitrary code embedded in the
        payload. Only pass bytes that this process (or another trusted one)
        obtained from :meth:`snapshot`.
        """
        import pickle

        state = pickle.loads(snapshot_bytes)
        self.sim = self._create_sim(state["battle"])
        self.battle = self.sim.battle
        self._sync_available_actions()
        self.turn = state["turn"]
        self._prev_score = state["prev_score"]
        self.random.setstate(state["random_state"])
        self.outcome = state["outcome"]
        # Snapshots written before the flag was stored have no "done" key;
        # treat those as still running, which is what restore() used to assume.
        self.done = bool(state.get("done", False))
        return self._build_observation()

    # -- battle construction -------------------------------------------------
    def _start_episode(self, battle: Battle) -> None:
        """Wrap ``battle`` in a new simulator and zero the per-episode bookkeeping."""
        self.sim = self._create_sim(battle)
        self.battle = self.sim.battle
        self._sync_available_actions()

        self._prev_score = self._score()
        self.turn = 0
        self.done = False
        self.outcome = 0.0

    def _make_opponent(self) -> AbyssalPlayer:
        """Create the opponent player for the opponent's packed team."""
        return AbyssalPlayer(
            battle_format=self.format_id,
            team=self._opponent_packed_team,
            save_replays=False,
            log_level=logging.WARNING,
        )

    def _build_base_battle(self, *, player_team_text: str, opponent_team_text: str) -> Battle:
        """Create the pristine battle that every episode is deep-copied from.

        Also stores the packed team strings and the opponent player on ``self``.
        """
        battle = Battle(
            battle_tag="battle-local",
            username="agent",
            logger=logging.getLogger("pokemon_showdown_env"),
            gen=self.gen_data.gen,
            save_replays=False,
        )
        battle._format = self.format_id
        battle._player_role = "p1"

        self._player_packed_team = self._apply_team(
            battle, player_team_text, prefix="p1", is_player=True
        )
        self._opponent_packed_team = self._apply_team(
            battle, opponent_team_text, prefix="p2", is_player=False
        )

        battle._opponent_username = "opponent"
        battle._player_username = "agent"

        self._abyssal = self._make_opponent()
        return battle

    def _apply_team(self, battle: Battle, team_text: str, *, prefix: str, is_player: bool) -> str:
        """Register a Showdown-format team on ``battle`` and return its packed form.

        The first team member is switched in as the active Pokémon.

        Raises:
            ValueError: If ``team_text`` parses to an empty team.
        """
        mons = Teambuilder.parse_showdown_team(team_text or "")
        if not mons:
            raise ValueError(f"Showdown team for prefix {prefix} is empty.")

        team_dict = getattr(battle, "_team" if is_player else "_opponent_team")

        packed_entries: list[str] = []
        for mon in mons:
            entry = TeambuilderPokemon()
            entry.species = mon.species
            entry.item = mon.item
            entry.ability = mon.ability
            entry.moves = mon.moves[:]
            entry.level = mon.level or self.DEFAULT_LEVEL
            entry.shiny = mon.shiny
            packed_entries.append(entry.formatted)

        packed_team = "]".join(packed_entries)

        for idx, mon in enumerate(mons):
            level = mon.level or self.DEFAULT_LEVEL
            species = mon.species or mon.nickname or f"Slot{idx+1}"
            ident = f"{prefix}: {species}"
            details = f"{species}, L{level}"
            pokemon = battle.get_pokemon(
                ident,
                force_self_team=is_player,
                force_opp_team=not is_player,
                details=details,
            )

            pokemon._level = level
            if mon.item:
                pokemon.item = mon.item
            if mon.ability:
                pokemon.ability = mon.ability
            for move in mon.moves:
                pokemon._add_move(move)
            # Every team member starts at a fixed "300/300" HP string.
            pokemon.set_hp_status("300/300")
            pokemon._shiny = mon.shiny
            team_dict[ident] = pokemon

            if idx == 0:
                battle.switch(f"{prefix}a: {species}", details, "300/300")

        return packed_team

    def _create_sim(self, battle: Battle) -> LocalSim:
        """Wrap ``battle`` in a ``LocalSim`` that uses the shared static tables."""
        return LocalSim(
            battle=battle,
            move_effect=PokechampAssets.move_effect,
            pokemon_move_dict=PokechampAssets.pokemon_move_dict,
            ability_effect=PokechampAssets.ability_effect,
            pokemon_ability_dict=PokechampAssets.pokemon_ability_dict,
            item_effect=PokechampAssets.item_effect,
            pokemon_item_dict=PokechampAssets.pokemon_item_dict,
            gen=self.gen_data,
            _dynamax_disable=True,
            format=self.format_id,
            prompt_translate=None,
        )

    # -- helpers -------------------------------------------------------------
    def _sync_available_actions(self) -> None:
        """Recompute the agent's available moves and switches on ``self.battle``."""
        active = self.battle.active_pokemon
        if active:
            self.battle._available_moves = list(active.available_moves)
        else:
            self.battle._available_moves = []

        # Any benched team member that has not fainted can be switched in.
        self.battle._available_switches = [
            mon
            for mon in self.battle.team.values()
            if mon and not mon.active and not mon.fainted
        ]

    def _action_to_order(self, action: dict[str, Any]) -> BattleOrder:
        """Translate an action payload into a ``BattleOrder``.

        ``action["action"]`` selects ``"move"`` or ``"switch"`` and
        ``action["index"]`` (default 0) indexes into the matching list returned
        by :meth:`_legal_actions`. An integer ``target`` is forwarded for moves.

        Raises:
            ValueError: Missing ``action`` key, unsupported action type, or no
                move/switch available.
            IndexError: ``index`` outside the available range.
        """
        if not action or "action" not in action:
            raise ValueError("Action payload must include 'action' key.")

        action_type = action["action"]
        if action_type == "move":
            moves = self.battle.available_moves
            if not moves:
                raise ValueError("No moves available.")
            index = int(action.get("index", 0))
            if index < 0 or index >= len(moves):
                raise IndexError(f"Move index {index} out of range.")
            move = moves[index]
            target = action.get("target")
            move_target = target if isinstance(target, int) else None
            return BattleOrder(move, move_target=move_target)

        if action_type == "switch":
            switches = self.battle.available_switches
            if not switches:
                raise ValueError("No switches available.")
            index = int(action.get("index", 0))
            if index < 0 or index >= len(switches):
                raise IndexError(f"Switch index {index} out of range.")
            pokemon = switches[index]
            return BattleOrder(pokemon)

        raise ValueError(f"Unsupported action type '{action_type}'.")

    def _opponent_policy(self) -> BattleOrder:
        """Choose the opponent's order for this turn.

        Preference order:
        1. whatever the Abyssal player returns, if it is a ``BattleOrder``;
        2. the opponent move with the best (type multiplier, base power)
           against the agent's active Pokémon;
        3. a switch to the first non-fainted benched opponent;
        4. the opponent's first available move (agent has no active Pokémon);
        5. legacy last resort: the agent's first available move.

        Raises:
            RuntimeError: If none of the above yields an order.
        """
        try:
            order = self._abyssal.choose_move(self.battle)
            if isinstance(order, BattleOrder):
                return order
        except Exception as exc:
            logger.warning("Abyssal opponent failed: %s", exc, exc_info=True)

        opponent = self.battle.opponent_active_pokemon
        agent_active = self.battle.active_pokemon

        moves = list(opponent.available_moves) if opponent else []
        if moves and agent_active:
            def _move_score(move: Move) -> tuple[float, float]:
                try:
                    multiplier = move.type.damage_multiplier(
                        agent_active.type_1 or PokemonType.NORMAL,
                        agent_active.type_2,
                    )
                except Exception:
                    # Type data unavailable: treat the move as neutral.
                    multiplier = 1.0
                base_power = move.base_power or 0
                return multiplier, base_power

            best_move = max(moves, key=_move_score)
            return BattleOrder(best_move)

        bench = [
            mon for mon in self.battle.opponent_team.values() if mon and not mon.active and not mon.fainted
        ]
        if bench:
            return BattleOrder(bench[0])

        if moves:
            # The opponent can act but there is no agent Pokémon to score
            # against; use its own first move rather than borrowing the agent's.
            return BattleOrder(moves[0])

        # NOTE(review): these are the *agent's* moves, not the opponent's. Kept
        # as the original last resort so this path does not start raising;
        # confirm whether LocalSim accepts such an order for the opponent.
        fallback_moves = self.battle.available_moves
        if fallback_moves:
            return BattleOrder(fallback_moves[0])
        raise RuntimeError("Opponent policy could not determine a valid action.")

    def _legal_actions(self) -> dict[str, Any]:
        """Describe the agent's currently available moves and switches.

        The ``index`` of each entry is the value :meth:`_action_to_order`
        expects for that move or switch.
        """
        moves = [
            {
                "index": idx,
                "id": move.id,
                "name": move.id,
                "type": str(move.type),
                "base_power": move.base_power or 0,
                "accuracy": move.accuracy or 0,
                "pp": move.current_pp if move.current_pp is not None else move.max_pp,
            }
            for idx, move in enumerate(self.battle.available_moves)
        ]
        switches = [
            {
                "index": idx,
                "species": mon.species,
                "hp_fraction": mon.current_hp_fraction,
                "status": mon.status.name if mon.status else None,
            }
            for idx, mon in enumerate(self.battle.available_switches)
        ]
        return {"moves": moves, "switches": switches}

    def _build_observation(self) -> dict[str, Any]:
        """Assemble the structured, legal-action and text views of the battle."""
        active = self.battle.active_pokemon
        opponent = self.battle.opponent_active_pokemon

        observation = {
            "structured": {
                "scenario": self.scenario["name"],
                "format_id": self.format_id,
                "turn": self.turn,
                "ally_active": self._serialize_pokemon(active),
                "opponent_active": self._serialize_pokemon(opponent),
                "ally_team": self._serialize_team(self.battle.team),
                "opponent_team": self._serialize_team(self.battle.opponent_team),
            },
            "legal_actions": self._legal_actions(),
            "text": self._build_text_summary(active, opponent),
        }
        return observation

    def _serialize_pokemon(self, pokemon: Pokemon | None) -> dict[str, Any] | None:
        """Return species, HP fraction, status and move ids, or ``None`` for no Pokémon."""
        if pokemon is None:
            return None
        return {
            "species": pokemon.species,
            "hp_fraction": pokemon.current_hp_fraction,
            "status": pokemon.status.name if pokemon.status else None,
            "moves": [move.id for move in pokemon.moves.values()],
        }

    def _serialize_team(self, team: dict[str, Pokemon]) -> list[dict[str, Any]]:
        """Return one summary dict per non-``None`` team member."""
        return [
            {
                "species": mon.species,
                "hp_fraction": mon.current_hp_fraction,
                "status": mon.status.name if mon.status else None,
                "active": mon.active,
                "fainted": mon.fainted,
            }
            for mon in team.values()
            if mon is not None
        ]

    def _build_text_summary(self, active: Pokemon | None, opponent: Pokemon | None) -> str:
        """Return a one-line ``Turn N: ally (HP%, status) vs opponent (HP%, status)`` summary."""
        def _hp_percent(mon: Pokemon | None) -> float:
            # current_hp_fraction is a 0-1 fraction (it is summed as such in
            # _score); scale it so the "%" suffix in the text is truthful, and
            # tolerate None the same way _score does.
            if not mon:
                return 0.0
            return (mon.current_hp_fraction or 0.0) * 100.0

        player_hp = _hp_percent(active)
        opponent_hp = _hp_percent(opponent)
        player_status = active.status.name if active and active.status else "OK"
        opponent_status = opponent.status.name if opponent and opponent.status else "OK"
        return (
            f"Turn {self.turn}: "
            f"{active.species if active else 'None'} ({player_hp:.1f}%, {player_status}) "
            f"vs {opponent.species if opponent else 'None'} ({opponent_hp:.1f}%, {opponent_status})"
        )

    def _score(self) -> float:
        """Return the sum of ally HP fractions minus the sum of opponent HP fractions."""
        def team_score(team: dict[str, Pokemon]) -> float:
            return sum((mon.current_hp_fraction or 0.0) for mon in team.values() if mon is not None)

        return team_score(self.battle.team) - team_score(self.battle.opponent_team)

    def _check_finished(self) -> bool:
        """Return whether either side has no Pokémon left; sets ``self.outcome`` if so."""
        ours_alive = any(mon and not mon.fainted for mon in self.battle.team.values())
        opp_alive = any(mon and not mon.fainted for mon in self.battle.opponent_team.values())

        if not ours_alive and not opp_alive:
            self.outcome = 0.0
            return True
        if not ours_alive:
            self.outcome = self.LOSS_PENALTY
            return True
        if not opp_alive:
            self.outcome = self.WIN_REWARD
            return True
        return False
|
|
785
|
+
|
|
786
|
+
|
|
787
|
+
async def rollout_executor(request: RolloutRequest, fastapi_request: Request) -> RolloutResponse:
    """Run the scripted ops of ``request`` against a fresh local battle.

    The scenario is chosen from the app's battle dataset by seed, an adapter is
    built and reset, and each entry of ``request.ops`` is applied as one
    ``adapter.step`` until the ops run out or the battle ends.

    Args:
        request: Rollout request; reads ``env.seed``, ``ops``, ``policy`` and
            ``run_id``.
        fastapi_request: Incoming request, used to reach the application state
            that holds the ``battle_dataset`` entry.

    Returns:
        A ``RolloutResponse`` with a single trajectory whose first step is the
        reset observation, plus return metrics for the episode.

    Raises:
        HTTPException: 500 when the dataset is missing or the scenario's assets
            are not ready; 400 when an op payload or action is invalid.
    """
    app_state = fastapi_request.app.state
    # Starlette's ``State`` exposes entries as attributes and has no ``.get``;
    # a plain dict is still honoured in case the framework supplies one.
    dataset: PokemonBattleDataset | None = getattr(app_state, "battle_dataset", None)
    if dataset is None and isinstance(app_state, dict):
        dataset = app_state.get("battle_dataset")
    if dataset is None:
        raise HTTPException(status_code=500, detail="Battle dataset missing from app state.")

    seed = dataset.resolve_seed(request.env.seed)
    scenario = dataset.describe_seed(seed)
    if not scenario["assets_ready"]:
        raise HTTPException(
            status_code=500,
            detail=f"Scenario '{scenario['name']}' is missing required assets. "
            "Ensure pokechamp static files are present.",
        )

    adapter = PokemonShowdownAdapter(
        scenario=scenario,
        repo_root=dataset.repo_root if dataset.repo_root else Path("."),
        seed=request.env.seed,
    )

    obs0 = adapter.reset()
    # Step 0 is the reset observation; every later entry corresponds to one op.
    steps: list[RolloutStep] = [
        RolloutStep(
            obs=obs0,
            tool_calls=[],
            reward=0.0,
            done=False,
            info={"legal_actions": adapter._legal_actions()},
        ),
    ]

    total_reward = 0.0
    done = False

    def _normalise_op(raw: Any) -> dict[str, Any]:
        """Return the action dict for ``raw``, unwrapping an ``arguments`` field if present."""
        payload = raw
        if isinstance(raw, dict) and "arguments" in raw:
            arguments = raw["arguments"]
            if isinstance(arguments, (str, bytes, bytearray)):
                try:
                    payload = json.loads(arguments)
                except json.JSONDecodeError as exc:
                    raise ValueError(f"Invalid JSON in arguments field: {exc}") from exc
            else:
                # Already decoded by the caller: use as-is; non-dict values are
                # rejected below instead of crashing inside json.loads.
                payload = arguments
        if isinstance(payload, dict):
            return payload
        raise ValueError(f"Unsupported op payload: {payload!r}")

    for op in request.ops or []:
        if done:
            break
        try:
            action_payload = _normalise_op(op)
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        try:
            obs, reward, done, info = adapter.step(action_payload)
        except (ValueError, IndexError) as exc:
            # The adapter rejects malformed or out-of-range actions with these
            # exception types; report them as a client error.
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        total_reward += reward
        steps.append(
            RolloutStep(obs=obs, tool_calls=[], reward=reward, done=done, info=info),
        )

    # The loop may stop early once the battle ends, so count the ops that were
    # actually applied rather than the ones that were requested.
    executed_ops = len(steps) - 1

    final_obs = steps[-1].obs
    metrics = RolloutMetrics(
        episode_returns=[total_reward],
        mean_return=total_reward,
        num_steps=executed_ops,
        num_episodes=1,
        outcome_score=total_reward,
        details={
            "seed": seed,
            "scenario": scenario["name"],
            "assets_ready": scenario["assets_ready"],
        },
    )

    # Extract inference_url from policy config
    inference_url = (request.policy.config or {}).get("inference_url")

    trajectory = RolloutTrajectory(
        env_id="pokemon_showdown",
        policy_id=request.policy.policy_id or "policy",
        steps=steps,
        final={"observation": final_obs, "reward": total_reward, "done": done},
        length=len(steps),
        inference_url=inference_url,  # NEW: Required for trace correlation
    )

    return RolloutResponse(
        run_id=request.run_id,
        trajectories=[trajectory],
        branches={},
        metrics=metrics,
        aborted=False,
        ops_executed=executed_ops,
        trace=None,
    )
|
|
880
|
+
|
|
881
|
+
|
|
882
|
+
def build_config() -> TaskAppConfig:
    """Assemble the ``TaskAppConfig`` for the Pokémon Showdown task app.

    Builds the dataset registry and battle dataset, derives the base task info
    from the dataset, and wires the taskset description, task-instance
    provider, rollout executor, proxy settings and app state into one config.
    """
    dataset_registry, battle_dataset = _build_dataset_registry()
    task_info = _base_task_info(battle_dataset)

    def _describe():
        # Bound to the dataset created above.
        return describe_taskset(battle_dataset)

    def _instances(seeds):
        # Bound to the dataset and base task info created above.
        return provide_task_instances(battle_dataset, task_info, seeds)

    llm_proxy = ProxyConfig(
        enable_openai=True,
        enable_groq=True,
        system_hint="Respond with legal Pokémon Showdown actions encoded as JSON.",
    )

    return TaskAppConfig(
        app_id="pokemon_showdown",
        name="Pokémon Showdown Task App",
        description="Expose deterministic Pokémon Showdown battles via the Synth AI task framework.",
        base_task_info=task_info,
        describe_taskset=_describe,
        provide_task_instances=_instances,
        rollout=rollout_executor,
        dataset_registry=dataset_registry,
        proxy=llm_proxy,
        # rollout_executor looks the dataset up under this key.
        app_state={"battle_dataset": battle_dataset},
        require_api_key=True,
        expose_debug_env=True,
        cors_origins=["*"],
    )
|
|
905
|
+
|
|
906
|
+
|
|
907
|
+
# Register the app at import time under the id "pokemon_showdown" (plus two
# aliases), using build_config as the factory for its TaskAppConfig.
register_task_app(
    entry=TaskAppEntry(
        app_id="pokemon_showdown",
        description="Pokémon Showdown (Track 1) task app skeleton.",
        config_factory=build_config,
        aliases=("pokemon_battle", "pokemon_track1"),
        env_files=(),
        # Modal deployment settings for this app.
        modal=ModalDeploymentConfig(
            app_name="pokemon-showdown-task-app",
            python_version="3.11",
            pip_packages=("horizons-ai",),
            # NOTE(review): presumably (local directory, mount path) pairs;
            # confirm against ModalDeploymentConfig.
            extra_local_dirs=(
                ("repo", "/opt/synth_ai_repo"),
                ("pokechamp", "/external/pokechamp"),
                ("pokemon_showdown", "/external/pokemon-showdown"),
            ),
            secret_names=("ENVIRONMENT_API_KEY", "OPENAI_API_KEY", "GROQ_API_KEY"),
            # NOTE(review): units are not stated here (likely seconds / MiB /
            # cores) - verify before changing these values.
            timeout=900,
            memory=8192,
            cpu=4.0,
        ),
    )
)


# Only the config factory is part of the star-import surface of this module.
__all__ = ["build_config"]
|