synth-ai 0.2.13.dev1__py3-none-any.whl → 0.2.13.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +12 -1
- examples/swe/task_app/grpo_swe_mini.py +55 -26
- examples/swe/task_app/hosted/rollout.py +40 -0
- examples/swe/task_app/hosted/test_service.py +5 -6
- examples/task_apps/TESTING.md +275 -0
- examples/task_apps/__init__.py +0 -0
- examples/task_apps/crafter/__init__.py +0 -0
- examples/task_apps/crafter/task_app/__init__.py +2 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +18 -13
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +60 -4
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +25 -3
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +10 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
- examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
- examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
- examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
- examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
- examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
- examples/task_apps/enron/task_app/README.md +14 -0
- examples/task_apps/enron/task_app/__init__.py +1 -0
- examples/task_apps/enron/task_app/grpo_enron.py +906 -0
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/conftest.py +115 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +177 -0
- examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
- examples/task_apps/math/__init__.py +0 -0
- examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
- examples/task_apps/pokemon_battle/__init__.py +2 -0
- examples/task_apps/pokemon_battle/modal_app.py +104 -0
- examples/task_apps/pokemon_battle/task_app/README.md +68 -0
- examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
- examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
- examples/task_apps/pokemon_red/README.md +357 -0
- examples/task_apps/pokemon_red/__init__.py +3 -0
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +73 -0
- examples/task_apps/pokemon_red/task_app.py +606 -0
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +191 -0
- examples/task_apps/sokoban/README.md +307 -0
- examples/task_apps/sokoban/__init__.py +3 -0
- examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
- examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
- examples/task_apps/sokoban/task_app.py +1058 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/conftest.py +113 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
- examples/task_apps/verilog/__init__.py +1 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +20 -0
- examples/task_apps/verilog/task_app/README.md +12 -0
- examples/task_apps/verilog/task_app/__init__.py +1 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +931 -0
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/conftest.py +115 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +179 -0
- examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
- examples/vlm/crafter_openai_vlm_agent.py +4 -4
- examples/vlm/run_crafter_vlm_benchmark.py +4 -4
- examples/workflows/__init__.py +0 -0
- examples/workflows/math_rl/__init__.py +0 -0
- examples/workflows/math_rl/download_dataset.py +80 -0
- synth_ai/__init__.py +2 -2
- synth_ai/api/train/builders.py +25 -11
- synth_ai/api/train/cli.py +12 -6
- synth_ai/api/train/configs/__init__.py +10 -10
- synth_ai/api/train/configs/rl.py +5 -4
- synth_ai/api/train/configs/sft.py +4 -3
- synth_ai/api/train/env_resolver.py +5 -2
- synth_ai/api/train/supported_algos.py +10 -5
- synth_ai/api/train/utils.py +7 -4
- synth_ai/cli/__init__.py +7 -51
- synth_ai/cli/_storage.py +4 -3
- synth_ai/cli/_validate_task_app.py +11 -0
- synth_ai/cli/balance.py +4 -3
- synth_ai/cli/calc.py +2 -2
- synth_ai/cli/demo.py +14 -7
- synth_ai/cli/legacy_root_backup.py +1 -1
- synth_ai/cli/rl_demo.py +8 -7
- synth_ai/cli/root.py +0 -97
- synth_ai/cli/task_apps.py +1707 -186
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +28 -16
- synth_ai/environments/examples/enron/engine.py +7 -2
- synth_ai/environments/examples/enron/environment.py +68 -0
- synth_ai/environments/examples/red/engine.py +27 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
- synth_ai/environments/examples/red/environment.py +60 -0
- synth_ai/environments/examples/sokoban/taskset.py +116 -0
- synth_ai/environments/examples/verilog/engine.py +30 -4
- synth_ai/evals/client.py +58 -61
- synth_ai/jobs/client.py +16 -4
- synth_ai/judge_schemas.py +16 -16
- synth_ai/py.typed +0 -0
- synth_ai/task/__init__.py +14 -5
- synth_ai/task/contracts.py +124 -38
- synth_ai/task/proxy.py +48 -56
- synth_ai/task/rubrics/__init__.py +53 -0
- synth_ai/task/rubrics/loaders.py +133 -0
- synth_ai/task/rubrics/models.py +57 -0
- synth_ai/task/rubrics/scoring.py +113 -0
- synth_ai/{rubrics/validators.py → task/rubrics/strict.py} +53 -30
- synth_ai/task/server.py +8 -7
- synth_ai/task/validators.py +269 -6
- synth_ai/tracing_v3/decorators.py +7 -3
- synth_ai/tracing_v3/replica_sync.py +4 -4
- synth_ai/tracing_v3/serialization.py +5 -5
- synth_ai/tracing_v3/trace_utils.py +317 -0
- synth_ai/tracing_v3/turso/native_manager.py +3 -3
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/METADATA +4 -1
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/RECORD +214 -101
- examples/agora_ex/README_MoE.md +0 -224
- examples/agora_ex/__init__.py +0 -7
- examples/agora_ex/agora_ex.py +0 -65
- examples/agora_ex/agora_ex_task_app.py +0 -590
- examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +0 -121
- examples/agora_ex/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/system_prompt_CURRENT.md +0 -63
- examples/agora_ex/task_app/agora_ex_task_app.py +0 -590
- examples/agora_ex/task_app/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/task_app/system_prompt_CURRENT.md +0 -63
- synth_ai/rubrics/__init__.py +0 -22
- synth_ai/task/rubrics.py +0 -219
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/utils.py +0 -0
- /examples/{rl/task_app → task_apps/math}/README.md +0 -0
- /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
- /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1058 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import re
|
|
6
|
+
import contextlib
|
|
7
|
+
import time
|
|
8
|
+
from datetime import datetime, timezone
|
|
9
|
+
from typing import Any, Dict, List, Mapping, Optional, Sequence
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
|
|
13
|
+
from fastapi import APIRouter, HTTPException, Request
|
|
14
|
+
from fastapi.exceptions import RequestValidationError
|
|
15
|
+
from fastapi.responses import JSONResponse
|
|
16
|
+
|
|
17
|
+
from synth_ai.environments.environment.tools import EnvToolCall
|
|
18
|
+
from synth_ai.environments.examples.sokoban.environment import SokobanEnvironment
|
|
19
|
+
from synth_ai.environments.examples.sokoban.taskset import (
|
|
20
|
+
SokobanTaskInstance,
|
|
21
|
+
SokobanTaskSet,
|
|
22
|
+
create_task_instance_from_seed,
|
|
23
|
+
)
|
|
24
|
+
from synth_ai.task.apps import TaskAppEntry, register_task_app
|
|
25
|
+
from synth_ai.task.contracts import (
|
|
26
|
+
RolloutMetrics,
|
|
27
|
+
RolloutRequest,
|
|
28
|
+
RolloutResponse,
|
|
29
|
+
RolloutStep,
|
|
30
|
+
RolloutTrajectory,
|
|
31
|
+
TaskInfo,
|
|
32
|
+
)
|
|
33
|
+
from synth_ai.task.auth import is_api_key_header_authorized, normalize_environment_api_key
|
|
34
|
+
from synth_ai.task.server import TaskAppConfig, create_task_app
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Canonical Sokoban action ids, in the order the environment numbers them.
ACTION_ID_TO_NAME = dict(enumerate(("left", "up", "right", "down")))

# Direction words (and their short aliases) accepted from models or scripted
# configs, grouped by the action id they stand for.
_SOKOBAN_DIRECTION_ALIASES = (
    (0, ("left", "move_left", "west", "l")),
    (1, ("up", "move_up", "north", "u")),
    (2, ("right", "move_right", "east", "r")),
    (3, ("down", "move_down", "south", "d")),
)

# Lower-case token -> action id: the bare digits first, then every alias.
ACTION_TOKEN_TO_ID = {
    **{str(action_id): action_id for action_id in ACTION_ID_TO_NAME},
    **{
        alias: action_id
        for action_id, aliases in _SOKOBAN_DIRECTION_ALIASES
        for alias in aliases
    },
}

# System prompt for LLM-driven rollouts; the action vocabulary it describes
# matches the two tables above.
SOKOBAN_SYSTEM_PROMPT = """You are an agent playing Sokoban.
The grid uses characters: '#' wall, '_' floor, 'O' box, '√' box on target, 'X' target, 'P' player.
Always respond with a single tool call named interact_many containing 1-5 actions.
Valid action tokens are digits 0/1/2/3 or their direction words (left/up/right/down).
Mapping: 0=left, 1=up, 2=right, 3=down. Avoid undoing progress and focus on pushing boxes onto targets."""
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _short_text(value: Any, *, limit: int = 280) -> str:
|
|
69
|
+
if value is None:
|
|
70
|
+
return ""
|
|
71
|
+
if isinstance(value, str):
|
|
72
|
+
text = value.strip()
|
|
73
|
+
return text if len(text) <= limit else text[: limit - 1] + "…"
|
|
74
|
+
if isinstance(value, (int, float, bool)):
|
|
75
|
+
return str(value)
|
|
76
|
+
try:
|
|
77
|
+
text = json.dumps(value, ensure_ascii=False)
|
|
78
|
+
except Exception:
|
|
79
|
+
text = str(value)
|
|
80
|
+
text = text.strip()
|
|
81
|
+
return text if len(text) <= limit else text[: limit - 1] + "…"
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _summarize_observation(observation: Any) -> str:
    """Produce a short text rendering of an environment observation.

    Rendering goes through ``_short_text`` with ``limit=512``. For dicts, the
    first non-blank string among ``room_text``, ``observation`` and ``grid`` is
    used. Failing that, a preview of whichever of ``player_position``,
    ``boxes_on_target``, ``num_boxes`` and ``steps_taken`` are present is used.
    Non-dicts, and dicts with none of those keys, are rendered whole.
    """
    if not isinstance(observation, dict):
        return _short_text(observation, limit=512)

    text_keys = ("room_text", "observation", "grid")
    grid_text = next(
        (
            candidate
            for candidate in map(observation.get, text_keys)
            if isinstance(candidate, str) and candidate.strip()
        ),
        None,
    )
    if grid_text is not None:
        return _short_text(grid_text, limit=512)

    preview_keys = ("player_position", "boxes_on_target", "num_boxes", "steps_taken")
    preview = {key: observation.get(key) for key in preview_keys if key in observation}
    # An empty preview means none of the known keys exist: fall back to the whole dict.
    return _short_text(preview or observation, limit=512)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _format_tool_calls(tool_calls: Sequence[Dict[str, Any]] | None) -> str:
|
|
101
|
+
if not tool_calls:
|
|
102
|
+
return "<noop>"
|
|
103
|
+
formatted: list[str] = []
|
|
104
|
+
for call in tool_calls:
|
|
105
|
+
args = call.get("args") if isinstance(call, dict) else None
|
|
106
|
+
if not isinstance(args, dict):
|
|
107
|
+
continue
|
|
108
|
+
if "actions" in args and isinstance(args["actions"], list):
|
|
109
|
+
parts: list[str] = []
|
|
110
|
+
for item in args["actions"]:
|
|
111
|
+
try:
|
|
112
|
+
val = int(item)
|
|
113
|
+
except Exception:
|
|
114
|
+
token = str(item).strip().lower()
|
|
115
|
+
val = ACTION_TOKEN_TO_ID.get(token)
|
|
116
|
+
name = ACTION_ID_TO_NAME.get(val, str(item)) if val is not None else str(item)
|
|
117
|
+
parts.append(str(name))
|
|
118
|
+
if parts:
|
|
119
|
+
formatted.append("[" + ", ".join(parts) + "]")
|
|
120
|
+
continue
|
|
121
|
+
action = args.get("action")
|
|
122
|
+
if action is None:
|
|
123
|
+
continue
|
|
124
|
+
try:
|
|
125
|
+
action = int(action)
|
|
126
|
+
except Exception:
|
|
127
|
+
token = str(action).strip().lower()
|
|
128
|
+
action = ACTION_TOKEN_TO_ID.get(token, action)
|
|
129
|
+
name = ACTION_ID_TO_NAME.get(action, str(action))
|
|
130
|
+
formatted.append(str(name))
|
|
131
|
+
return ", ".join(formatted) if formatted else "<noop>"
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _build_trace_payload(
    request: RolloutRequest,
    steps: Sequence[RolloutStep],
    metrics: RolloutMetrics,
    *,
    difficulty: str,
    initial_observation: Any,
    provider: str = "local",
) -> Dict[str, Any]:
    """Build the version-3 trace payload attached to a Sokoban rollout response.

    Each rollout step becomes one session time-step holding a single event plus
    its observation and action messages, spaced 10 ms apart. When *steps* is
    empty, a single terminal step is synthesised from *initial_observation*. It
    has zero reward and no tool calls.

    Args:
        request: The originating rollout request (run id, env and policy specs).
        steps: Executed rollout steps, in order.
        metrics: Aggregate rollout metrics; ``mean_return`` is also reported as ``reward``.
        difficulty: Puzzle difficulty recorded in the session metadata.
        initial_observation: Observation used only when *steps* is empty.
        provider: Inference provider label recorded in the session metadata.

    Returns:
        A dict with ``version``, ``session_trace`` and headline rollout metrics.
    """
    created_at = datetime.now(timezone.utc)
    base_time = time.time()

    event_history: list[dict[str, Any]] = []
    markov_messages: list[dict[str, Any]] = []
    session_steps: list[dict[str, Any]] = []

    def make_message(text: str, message_type: str, event_time: float, step_index: int) -> dict[str, Any]:
        # Observation and action messages share this exact shape.
        return {
            "content": {"text": text},
            "message_type": message_type,
            "time_record": {"event_time": event_time},
            "metadata": {"step_index": step_index},
        }

    def record_step(
        step_index: int,
        event_time: float,
        messages: list[dict[str, Any]],
        reward: float,
        done: bool,
        truncated: bool,
        event_metadata: dict[str, Any],
    ) -> None:
        # The same event object is referenced from both event_history and the step.
        event = {
            "system_instance_id": f"sokoban.step.{step_index}",
            "time_record": {"event_time": event_time},
            "reward": reward,
            "terminated": done,
            "truncated": truncated,
            "metadata": event_metadata,
        }
        markov_messages.extend(messages)
        event_history.append(event)
        session_steps.append(
            {
                "step_id": f"step_{step_index}",
                "step_index": step_index,
                "events": [event],
                "markov_blanket_messages": list(messages),
                "step_metadata": {"reward": reward, "done": done, "truncated": truncated},
            }
        )

    if steps:
        for step_index, step in enumerate(steps):
            event_time = base_time + step_index * 0.01
            observation_message = make_message(
                _summarize_observation(step.obs), "observation", event_time, step_index
            )
            action_message = make_message(
                _format_tool_calls(step.tool_calls), "action", event_time + 0.0005, step_index
            )
            step_reward = float(step.reward or 0.0)
            record_step(
                step_index,
                event_time,
                [observation_message, action_message],
                step_reward,
                bool(step.done),
                bool(step.truncated),
                {"tool_calls": step.tool_calls, "info": step.info or {}},
            )
    else:
        # No steps executed: emit one terminal step carrying only the initial observation.
        record_step(
            0,
            base_time,
            [make_message(_summarize_observation(initial_observation), "observation", base_time, 0)],
            0.0,
            True,
            False,
            {"tool_calls": []},
        )

    session_trace = {
        "session_id": str(request.run_id),
        "created_at": created_at.isoformat(),
        "metadata": {
            "task": "sokoban",
            "difficulty": difficulty,
            "seed": request.env.seed,
            "provider": provider,
            "env": request.env.model_dump(),
            "policy": request.policy.model_dump(),
        },
        "session_time_steps": session_steps,
        "event_history": event_history,
        "markov_blanket_message_history": markov_messages,
    }

    return {
        "version": 3,
        "session_trace": session_trace,
        "run_id": request.run_id,
        "policy_id": request.policy.policy_id or request.policy.policy_name,
        "reward": metrics.mean_return,
        "episode_returns": metrics.episode_returns,
        "mean_return": metrics.mean_return,
        "num_steps": metrics.num_steps,
    }
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def _task_info() -> TaskInfo:
    """Return the static ``TaskInfo`` advertised by the Sokoban task app.

    NOTE(review): the action space advertises a single ``interact`` tool with
    ``max_calls=1``, while ``SOKOBAN_SYSTEM_PROMPT`` asks models for
    ``interact_many`` with 1-5 actions. Likewise, the observation keys listed
    here (``grid``/``player``) differ from the ``room_text``/``player_position``
    keys read by ``_format_sokoban_prompt``. Confirm which contract consumers
    rely on.
    """
    identity = {"id": "sokoban", "name": "Sokoban", "version": "1.0.0"}
    interact_tool = {"name": "interact", "schema": {"action": "int"}}
    return TaskInfo(
        task=dict(identity),
        environment="sokoban",
        action_space={"type": "tool_call", "tools": [interact_tool], "max_calls": 1},
        observation={"summary": "Sokoban grid observation", "keys": ["grid", "player"]},
        dataset=dict(identity),
        rubric={"version": "1", "criteria_count": 1, "source": "inline"},
        inference={"supports_proxy": False},
        limits={"max_turns": 200},
    )
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
# Module-level FastAPI router for this task app; no routes are attached in the
# visible part of the file, presumably they are registered further down — TODO confirm.
router = APIRouter()
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
async def rollout_executor(request: RolloutRequest, fastapi_request: Request) -> RolloutResponse:
|
|
276
|
+
policy_cfg = dict(request.policy.config or {})
|
|
277
|
+
provider = str(policy_cfg.get("provider") or "").strip().lower()
|
|
278
|
+
if provider == "groq":
|
|
279
|
+
return await _rollout_with_groq(request, fastapi_request, policy_cfg)
|
|
280
|
+
if provider == "openai":
|
|
281
|
+
return await _rollout_with_openai(request, fastapi_request, policy_cfg)
|
|
282
|
+
|
|
283
|
+
taskset: SokobanTaskSet = fastapi_request.app.state.sokoban_taskset
|
|
284
|
+
seed = request.env.seed or 0
|
|
285
|
+
difficulty = (request.env.config or {}).get("difficulty") or "easy"
|
|
286
|
+
# Create deterministic instance from seed
|
|
287
|
+
instance: SokobanTaskInstance = await create_task_instance_from_seed(str(difficulty), int(seed))
|
|
288
|
+
env = SokobanEnvironment(instance)
|
|
289
|
+
obs = await env.initialize()
|
|
290
|
+
initial_observation = obs
|
|
291
|
+
|
|
292
|
+
tool_calls: List[Dict[str, Any]] = []
|
|
293
|
+
# If a predefined action sequence is provided, execute it (evaluation-style)
|
|
294
|
+
actions: Optional[Sequence[int]] = None
|
|
295
|
+
try:
|
|
296
|
+
cfg = request.policy.config or {}
|
|
297
|
+
if isinstance(cfg.get("actions"), list):
|
|
298
|
+
actions = [int(a) for a in cfg["actions"]]
|
|
299
|
+
except Exception:
|
|
300
|
+
actions = None
|
|
301
|
+
|
|
302
|
+
last_obs: Any = obs
|
|
303
|
+
steps: List[RolloutStep] = []
|
|
304
|
+
max_steps = int((request.env.config or {}).get("max_steps") or 50)
|
|
305
|
+
executed = 0
|
|
306
|
+
if actions:
|
|
307
|
+
for a in actions[:max_steps]:
|
|
308
|
+
last_obs = await env.step(EnvToolCall(tool="interact", args={"action": int(a)}))
|
|
309
|
+
executed += 1
|
|
310
|
+
steps.append(
|
|
311
|
+
RolloutStep(obs=last_obs, tool_calls=[{"tool": "interact", "args": {"action": int(a)}}], reward=0.0, done=False, info={})
|
|
312
|
+
)
|
|
313
|
+
# Mark episode end (single-episode trajectory)
|
|
314
|
+
final = {"observation": last_obs, "reward": 0.0}
|
|
315
|
+
if not steps:
|
|
316
|
+
steps = [RolloutStep(obs=last_obs, tool_calls=[], reward=0.0, done=True, info={})]
|
|
317
|
+
|
|
318
|
+
# Extract inference_url from policy config (None for manual rollouts)
|
|
319
|
+
inference_url = policy_cfg.get("inference_url")
|
|
320
|
+
|
|
321
|
+
traj = RolloutTrajectory(
|
|
322
|
+
env_id="sokoban",
|
|
323
|
+
policy_id=request.policy.policy_id or "policy",
|
|
324
|
+
steps=steps,
|
|
325
|
+
final=final,
|
|
326
|
+
length=len(steps),
|
|
327
|
+
inference_url=inference_url, # NEW: Required for trace correlation
|
|
328
|
+
)
|
|
329
|
+
metrics = RolloutMetrics(
|
|
330
|
+
episode_returns=[final.get("reward", 0.0) or 0.0],
|
|
331
|
+
mean_return=final.get("reward", 0.0) or 0.0,
|
|
332
|
+
num_steps=len(steps),
|
|
333
|
+
num_episodes=1,
|
|
334
|
+
outcome_score=None,
|
|
335
|
+
events_score=None,
|
|
336
|
+
details={},
|
|
337
|
+
)
|
|
338
|
+
trace_payload = _build_trace_payload(
|
|
339
|
+
request,
|
|
340
|
+
steps,
|
|
341
|
+
metrics,
|
|
342
|
+
difficulty=str(difficulty),
|
|
343
|
+
initial_observation=initial_observation,
|
|
344
|
+
)
|
|
345
|
+
return RolloutResponse(
|
|
346
|
+
run_id=request.run_id,
|
|
347
|
+
trajectories=[traj],
|
|
348
|
+
branches={},
|
|
349
|
+
metrics=metrics,
|
|
350
|
+
aborted=False,
|
|
351
|
+
ops_executed=1 + executed,
|
|
352
|
+
trace=trace_payload,
|
|
353
|
+
)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def _format_sokoban_prompt(observation: dict[str, Any], last_actions: list[int]) -> str:
|
|
357
|
+
grid = observation.get("room_text", "")
|
|
358
|
+
boxes = observation.get("boxes_on_target", 0)
|
|
359
|
+
total_boxes = observation.get("num_boxes", boxes)
|
|
360
|
+
position = observation.get("player_position", ())
|
|
361
|
+
reward_last = observation.get("reward_last", 0.0)
|
|
362
|
+
steps_taken = observation.get("steps_taken", 0)
|
|
363
|
+
max_steps = observation.get("max_steps", 0)
|
|
364
|
+
last_str = (
|
|
365
|
+
", ".join(ACTION_ID_TO_NAME.get(a, str(a)) for a in last_actions) if last_actions else "none"
|
|
366
|
+
)
|
|
367
|
+
return (
|
|
368
|
+
f"Step {steps_taken} / {max_steps}\n"
|
|
369
|
+
f"Player position: {position}\n"
|
|
370
|
+
f"Boxes on target: {boxes} / {total_boxes}\n"
|
|
371
|
+
f"Last reward: {reward_last}\n"
|
|
372
|
+
f"Previous actions: {last_str}\n"
|
|
373
|
+
"Grid:\n"
|
|
374
|
+
f"{grid}\n"
|
|
375
|
+
"Select up to five next actions via the interact_many tool."
|
|
376
|
+
)
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def _extract_actions_from_response(
|
|
380
|
+
response: dict[str, Any], max_actions: int
|
|
381
|
+
) -> list[int]:
|
|
382
|
+
import json as json_lib
|
|
383
|
+
print(f"[extract] FULL RESPONSE:", flush=True)
|
|
384
|
+
print(json_lib.dumps(response, indent=2)[:2000], flush=True)
|
|
385
|
+
|
|
386
|
+
actions: list[int] = []
|
|
387
|
+
choices = response.get("choices") or []
|
|
388
|
+
print(f"[extract] {len(choices)} choices", flush=True)
|
|
389
|
+
if choices:
|
|
390
|
+
msg = choices[0].get("message", {})
|
|
391
|
+
print(f"[extract] tool_calls: {msg.get('tool_calls')}", flush=True)
|
|
392
|
+
print(f"[extract] content: {msg.get('content')}", flush=True)
|
|
393
|
+
print(f"[extract] finish_reason: {choices[0].get('finish_reason')}", flush=True)
|
|
394
|
+
for choice in choices:
|
|
395
|
+
message = choice.get("message") or {}
|
|
396
|
+
tool_calls = message.get("tool_calls") or []
|
|
397
|
+
for tool_call in tool_calls:
|
|
398
|
+
function = tool_call.get("function") or {}
|
|
399
|
+
arguments = function.get("arguments")
|
|
400
|
+
payload: dict[str, Any] | None = None
|
|
401
|
+
if isinstance(arguments, str):
|
|
402
|
+
try:
|
|
403
|
+
payload = json.loads(arguments)
|
|
404
|
+
except json.JSONDecodeError:
|
|
405
|
+
payload = None
|
|
406
|
+
elif isinstance(arguments, dict):
|
|
407
|
+
payload = arguments
|
|
408
|
+
if not payload:
|
|
409
|
+
continue
|
|
410
|
+
raw_actions = payload.get("actions")
|
|
411
|
+
if isinstance(raw_actions, list):
|
|
412
|
+
for item in raw_actions:
|
|
413
|
+
if isinstance(item, int) and item in ACTION_ID_TO_NAME:
|
|
414
|
+
actions.append(int(item))
|
|
415
|
+
continue
|
|
416
|
+
if isinstance(item, str):
|
|
417
|
+
token = item.strip().lower()
|
|
418
|
+
if token in ACTION_TOKEN_TO_ID:
|
|
419
|
+
actions.append(ACTION_TOKEN_TO_ID[token])
|
|
420
|
+
if actions:
|
|
421
|
+
break
|
|
422
|
+
if actions:
|
|
423
|
+
break
|
|
424
|
+
|
|
425
|
+
if not actions and choices:
|
|
426
|
+
# Fallback: parse tokens from assistant text
|
|
427
|
+
text = choices[0].get("message", {}).get("content") or ""
|
|
428
|
+
tokens = re.findall(r"[0-3a-zA-Z_]+", text)
|
|
429
|
+
for tok in tokens:
|
|
430
|
+
token = tok.strip().lower()
|
|
431
|
+
if token in ACTION_TOKEN_TO_ID:
|
|
432
|
+
actions.append(ACTION_TOKEN_TO_ID[token])
|
|
433
|
+
|
|
434
|
+
if len(actions) > max_actions:
|
|
435
|
+
return actions[:max_actions]
|
|
436
|
+
return actions
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
async def _call_groq_chat(
    client: httpx.AsyncClient,
    api_key: str,
    payload: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """POST a chat-completions request to Groq's OpenAI-compatible endpoint.

    Args:
        client: Shared async HTTP client (its timeout applies).
        api_key: Groq API key, sent as a bearer token.
        payload: Chat-completions request body.

    Returns:
        ``(data, meta)``: the decoded JSON body, and a record of the HTTP
        status, response headers and body kept for tracing.

    Raises:
        HTTPException: with Groq's status code and a
            ``{"status", "body", "headers"}`` detail for error responses.
        HTTPException: with 502 for transport failures, or for a success
            response whose body is not valid JSON.
    """
    try:
        response = await client.post(
            "https://api.groq.com/openai/v1/chat/completions",
            json=payload,
            headers={"Authorization": f"Bearer {api_key}"},
        )
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        try:
            body = exc.response.json()
        except ValueError:
            # Error bodies are not guaranteed to be JSON; keep the raw text.
            body = {"raw": exc.response.text}
        error_detail = {
            "status": exc.response.status_code,
            "body": body,
            "headers": dict(exc.response.headers),
        }
        raise HTTPException(status_code=exc.response.status_code, detail=error_detail) from exc
    except httpx.RequestError as exc:
        raise HTTPException(status_code=502, detail=f"Groq request error: {exc}") from exc

    try:
        data = response.json()
    except ValueError as exc:
        # A 2xx with a non-JSON body is an upstream fault, not a server crash.
        raise HTTPException(
            status_code=502, detail=f"Groq returned a non-JSON response: {exc}"
        ) from exc
    return data, {
        "status": response.status_code,
        "headers": dict(response.headers),
        "body": data,
    }
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
async def _call_openai_chat(
    client: httpx.AsyncClient,
    api_key: str,
    payload: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """POST a chat-completions request to the OpenAI API.

    Args:
        client: Shared async HTTP client (its timeout applies).
        api_key: OpenAI API key, sent as a bearer token.
        payload: Chat-completions request body.

    Returns:
        ``(data, meta)``: the decoded JSON body, and a record of the HTTP
        status, response headers and body kept for tracing.

    Raises:
        HTTPException: with OpenAI's status code and a
            ``{"status", "body", "headers"}`` detail (also printed to stdout)
            for error responses.
        HTTPException: with 502 for transport failures, or for a success
            response whose body is not valid JSON.
    """
    try:
        response = await client.post(
            "https://api.openai.com/v1/chat/completions",
            json=payload,
            headers={"Authorization": f"Bearer {api_key}"},
        )
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        try:
            body = exc.response.json()
        except ValueError:
            # Error bodies are not guaranteed to be JSON; keep the raw text.
            body = {"raw": exc.response.text}
        error_detail = {
            "status": exc.response.status_code,
            "body": body,
            "headers": dict(exc.response.headers),
        }
        # Best-effort diagnostic; a broken stdout must not mask the HTTP error.
        with contextlib.suppress(Exception):
            print("[openai:error]", error_detail, flush=True)
        raise HTTPException(status_code=exc.response.status_code, detail=error_detail) from exc
    except httpx.RequestError as exc:
        raise HTTPException(status_code=502, detail=f"OpenAI request error: {exc}") from exc

    try:
        data = response.json()
    except ValueError as exc:
        # A 2xx with a non-JSON body is an upstream fault, not a server crash.
        raise HTTPException(
            status_code=502, detail=f"OpenAI returned a non-JSON response: {exc}"
        ) from exc
    return data, {
        "status": response.status_code,
        "headers": dict(response.headers),
        "body": data,
    }
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
async def _rollout_with_groq(
    request: RolloutRequest,
    fastapi_request: Request,
    config: dict[str, Any],
) -> RolloutResponse:
    """Run one Sokoban episode with a Groq-hosted model choosing the moves.

    Each round sends the rendered board to Groq with a forced ``interact_many``
    tool call. It then executes the returned actions one by one (at most
    ``max_actions_per_call``, clamped to 1..8) and records the round as a
    single ``RolloutStep``.

    The episode stops when any of these happens:

    * the environment reports ``terminated`` or ``truncated``;
    * ``max_steps`` environment actions have been executed;
    * the model's reply contains no usable action.

    Args:
        request: Rollout request. ``env.seed`` and ``env.config`` select and
            bound the episode: ``difficulty`` (default ``"easy"``) and
            ``max_steps`` (default 50).
        fastapi_request: Incoming HTTP request; not used by this function.
        config: Policy settings: ``model`` (default ``"qwen/qwen3-32b"``),
            ``temperature``, ``top_p``, ``max_tokens``,
            ``max_actions_per_call``.

    Returns:
        A ``RolloutResponse`` with one trajectory, metrics and a trace payload.

    Raises:
        HTTPException: 503 if ``GROQ_API_KEY`` is unset; errors raised by
            ``_call_groq_chat`` propagate unchanged.
    """
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        raise HTTPException(
            status_code=503,
            detail="GROQ_API_KEY environment variable is required for Groq rollouts.",
        )

    env_config = request.env.config or {}
    seed = request.env.seed or 0
    difficulty = env_config.get("difficulty") or "easy"
    instance: SokobanTaskInstance = await create_task_instance_from_seed(str(difficulty), int(seed))
    env = SokobanEnvironment(instance)
    observation = await env.initialize()
    initial_observation = observation

    model = config.get("model") or "qwen/qwen3-32b"
    temperature = float(config.get("temperature", 0.0) or 0.0)
    top_p = float(config.get("top_p", 0.95) or 0.95)
    max_tokens = int(config.get("max_tokens", 128) or 128)
    actions_per_call = max(1, min(8, int(config.get("max_actions_per_call", 4) or 4)))

    max_steps = int(env_config.get("max_steps") or 50)

    steps: List[RolloutStep] = []
    last_actions: list[int] = []
    total_reward = float(observation.get("total_reward") or 0.0)
    executed = 0

    tool_items_enum = sorted(set(ACTION_TOKEN_TO_ID.keys()))
    tool_schema = {
        "type": "function",
        "function": {
            "name": "interact_many",
            "description": "Execute a short sequence of Sokoban moves in order.",
            "parameters": {
                "type": "object",
                "properties": {
                    "actions": {
                        "type": "array",
                        "items": {"type": "string", "enum": tool_items_enum},
                        "minItems": 1,
                        "maxItems": actions_per_call,
                    }
                },
                "required": ["actions"],
                "additionalProperties": False,
            },
        },
    }

    async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
        for _ in range(max_steps):
            # Check the action budget before paying for a model call whose
            # actions could not be executed anyway.
            if executed >= max_steps:
                break

            user_prompt = _format_sokoban_prompt(observation, last_actions)
            payload = {
                "model": model,
                "messages": [
                    {"role": "system", "content": SOKOBAN_SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                "temperature": temperature,
                "top_p": top_p,
                "max_tokens": max_tokens,
                "tools": [tool_schema],
                "tool_choice": {"type": "function", "function": {"name": "interact_many"}},
            }
            # Errors propagate as HTTPException from _call_groq_chat.
            response, response_meta = await _call_groq_chat(client, api_key, payload)
            vendor_attempts: list[dict[str, Any]] = [
                {"request": payload, "response": response_meta}
            ]

            actions = _extract_actions_from_response(response, actions_per_call)
            if not actions:
                break

            aggregated_actions: list[int] = []
            aggregated_reward = 0.0
            intermediate_rewards: list[float] = []
            done = False
            truncated = False
            for action in actions:
                if executed >= max_steps:
                    break
                aggregated_actions.append(int(action))
                observation = await env.step(
                    EnvToolCall(tool="interact", args={"action": int(action)})
                )
                # Compare against None so a genuine 0.0 running total is not
                # mistaken for a missing value.
                reported_total = observation.get("total_reward")
                current_total = (
                    float(reported_total) if reported_total is not None else total_reward
                )
                reward_delta = current_total - total_reward
                total_reward = current_total
                aggregated_reward += reward_delta
                intermediate_rewards.append(reward_delta)
                done = bool(observation.get("terminated"))
                truncated = bool(observation.get("truncated"))
                executed += 1
                if done or truncated:
                    break

            last_actions = aggregated_actions
            steps.append(
                RolloutStep(
                    obs=observation,
                    tool_calls=[
                        {
                            "tool": "interact_many",
                            "args": {"actions": [int(a) for a in aggregated_actions]},
                            "source": "groq",
                        }
                    ],
                    reward=aggregated_reward,
                    done=done,
                    truncated=truncated if truncated else None,
                    info={
                        "provider": "groq",
                        "model": model,
                        "actions_executed": aggregated_actions,
                        "prompt": user_prompt,
                        "reward_deltas": intermediate_rewards,
                        "vendor_attempts": vendor_attempts,
                        "groq_attempts": vendor_attempts,
                    },
                )
            )

            if done or truncated:
                break

    final = {"observation": observation, "reward": total_reward}
    inference_url_groq = "https://api.groq.com/openai/v1/chat/completions"

    trajectory = RolloutTrajectory(
        env_id=request.env.env_id or request.env.env_name or "sokoban",
        policy_id=request.policy.policy_id or request.policy.policy_name or "sokoban-groq",
        steps=steps,
        final=final,
        length=len(steps),
        inference_url=inference_url_groq,  # Required for trace correlation
    )
    metrics = RolloutMetrics(
        episode_returns=[total_reward],
        mean_return=total_reward if steps else 0.0,
        num_steps=len(steps),
        num_episodes=1,
        outcome_score=None,
        events_score=None,
        details={"provider": "groq", "model": model},
    )
    trace_payload = _build_trace_payload(
        request,
        steps,
        metrics,
        difficulty=str(difficulty),
        initial_observation=initial_observation,
        provider="groq",
    )
    return RolloutResponse(
        run_id=request.run_id,
        trajectories=[trajectory],
        branches={},
        metrics=metrics,
        aborted=False,
        ops_executed=executed,
        trace=trace_payload,
    )
|
|
690
|
+
|
|
691
|
+
|
|
692
|
+
async def _rollout_with_openai(
    request: RolloutRequest,
    fastapi_request: Request,
    config: dict[str, Any],
) -> RolloutResponse:
    """Run one Sokoban episode with an OpenAI model choosing the moves.

    Every ``"policy"`` entry in ``request.ops`` triggers one model call with a
    forced ``interact_many`` tool call; other ops are skipped. When no ops are
    given, ``max_steps`` policy calls are assumed. The returned actions are
    executed one by one (at most ``max_actions_per_call``, clamped to 1..8),
    and each call is recorded as a single ``RolloutStep``.

    If OpenAI rejects ``temperature`` or ``top_p`` as unsupported, the request
    is retried without that parameter. GPT-5 family models never receive them.

    The episode stops when any of these happens:

    * the environment reports ``terminated`` or ``truncated``;
    * ``max_steps`` environment actions have been executed;
    * the model's reply contains no usable action;
    * the ops are exhausted.

    Args:
        request: Rollout request. ``env.seed`` and ``env.config`` select and
            bound the episode: ``difficulty`` (default ``"easy"``) and
            ``max_steps`` (default 50). ``ops`` drives the policy calls.
        fastapi_request: Incoming HTTP request; not used by this function.
        config: Policy settings: ``model`` (default ``"gpt-5"``),
            ``temperature``, ``top_p``, ``max_completion_tokens`` /
            ``max_tokens`` (default 4000), ``max_actions_per_call``.

    Returns:
        A ``RolloutResponse`` with one trajectory, metrics and a trace payload.

    Raises:
        HTTPException: 503 if ``OPENAI_API_KEY`` is unset; other errors
            raised by ``_call_openai_chat`` propagate after retries.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise HTTPException(
            status_code=503,
            detail="OPENAI_API_KEY environment variable is required for OpenAI rollouts.",
        )

    env_config = request.env.config or {}
    seed = request.env.seed or 0
    difficulty = env_config.get("difficulty") or "easy"
    instance: SokobanTaskInstance = await create_task_instance_from_seed(str(difficulty), int(seed))
    env = SokobanEnvironment(instance)
    observation = await env.initialize()
    initial_observation = observation

    model = config.get("model") or "gpt-5"
    temperature_cfg = config.get("temperature")
    top_p_cfg = config.get("top_p")
    completion_tokens = int(
        config.get("max_completion_tokens") or config.get("max_tokens") or 4000
    )
    actions_per_call = max(1, min(8, int(config.get("max_actions_per_call", 4) or 4)))

    max_steps = int(env_config.get("max_steps") or 50)

    # GPT-5 models don't support temperature/top_p (only default value of 1),
    # so optional sampling params are resolved once and omitted for them.
    is_gpt5 = "gpt-5" in str(model).lower()
    sampling: dict[str, float] = {}
    if not is_gpt5:
        if temperature_cfg is not None:
            with contextlib.suppress(TypeError, ValueError):
                sampling["temperature"] = float(temperature_cfg)
        if top_p_cfg is not None:
            with contextlib.suppress(TypeError, ValueError):
                sampling["top_p"] = float(top_p_cfg)

    steps: List[RolloutStep] = []
    last_actions: list[int] = []
    total_reward = float(observation.get("total_reward") or 0.0)
    executed = 0

    tool_items_enum = sorted(set(ACTION_TOKEN_TO_ID.keys()))
    tool_schema = {
        "type": "function",
        "function": {
            "name": "interact_many",
            "description": "Execute a short sequence of Sokoban moves in order.",
            "parameters": {
                "type": "object",
                "properties": {
                    "actions": {
                        "type": "array",
                        "items": {"type": "string", "enum": tool_items_enum},
                        "minItems": 1,
                        "maxItems": actions_per_call,
                    }
                },
                "required": ["actions"],
                "additionalProperties": False,
            },
        },
    }

    # Each "policy" op triggers one LLM call; default to max_steps of them.
    ops_to_process = request.ops or ["policy"] * max_steps

    async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as client:
        for op in ops_to_process:
            # Check the action budget before paying for a model call whose
            # actions could not be executed anyway.
            if executed >= max_steps:
                break
            # Only process "policy" ops; explicit actions are skipped for now.
            if not (isinstance(op, str) and op.lower() == "policy"):
                continue

            user_prompt = _format_sokoban_prompt(observation, last_actions)
            payload_base: dict[str, Any] = {
                "model": model,
                "messages": [
                    {"role": "system", "content": SOKOBAN_SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                "max_completion_tokens": completion_tokens,
                "tools": [tool_schema],
                "tool_choice": {"type": "function", "function": {"name": "interact_many"}},
                **sampling,
            }

            vendor_attempts: list[dict[str, Any]] = []
            attempt_payload = dict(payload_base)
            while True:
                attempt_record: dict[str, Any] = {"request": dict(attempt_payload)}
                try:
                    response, response_meta = await _call_openai_chat(
                        client, api_key, attempt_payload
                    )
                except HTTPException as exc:
                    detail = exc.detail
                    attempt_record["error"] = (
                        detail if isinstance(detail, dict) else {"message": str(detail)}
                    )
                    vendor_attempts.append(attempt_record)
                    body = detail.get("body") if isinstance(detail, dict) else None
                    error_info = body.get("error") if isinstance(body, dict) else None
                    code = error_info.get("code") if isinstance(error_info, dict) else None
                    param = error_info.get("param") if isinstance(error_info, dict) else None
                    # Drop a rejected sampling param and retry. Each param can be
                    # dropped only once, so this loop makes at most three attempts.
                    if (
                        code in {"unsupported_parameter", "unsupported_value"}
                        and param in {"temperature", "top_p"}
                        and param in attempt_payload
                    ):
                        attempt_payload = {
                            key: value for key, value in attempt_payload.items() if key != param
                        }
                        continue
                    raise
                attempt_record["response"] = response_meta
                vendor_attempts.append(attempt_record)
                break

            actions = _extract_actions_from_response(response, actions_per_call)
            if not actions:
                break

            aggregated_actions: list[int] = []
            aggregated_reward = 0.0
            intermediate_rewards: list[float] = []
            done = False
            truncated = False
            for action in actions:
                if executed >= max_steps:
                    break
                aggregated_actions.append(int(action))
                observation = await env.step(
                    EnvToolCall(tool="interact", args={"action": int(action)})
                )
                # Compare against None so a genuine 0.0 running total is not
                # mistaken for a missing value.
                reported_total = observation.get("total_reward")
                current_total = (
                    float(reported_total) if reported_total is not None else total_reward
                )
                reward_delta = current_total - total_reward
                total_reward = current_total
                aggregated_reward += reward_delta
                intermediate_rewards.append(reward_delta)
                done = bool(observation.get("terminated"))
                truncated = bool(observation.get("truncated"))
                executed += 1
                if done or truncated:
                    break

            last_actions = aggregated_actions
            steps.append(
                RolloutStep(
                    obs=observation,
                    tool_calls=[
                        {
                            "tool": "interact_many",
                            "args": {"actions": [int(a) for a in aggregated_actions]},
                            "source": "openai",
                        }
                    ],
                    reward=aggregated_reward,
                    done=done,
                    truncated=truncated if truncated else None,
                    info={
                        "provider": "openai",
                        "model": model,
                        "actions_executed": aggregated_actions,
                        "prompt": user_prompt,
                        "reward_deltas": intermediate_rewards,
                        "vendor_attempts": vendor_attempts,
                        "openai_attempts": vendor_attempts,
                        "max_completion_tokens": completion_tokens,
                        "temperature_requested": temperature_cfg,
                        "top_p_requested": top_p_cfg,
                    },
                )
            )

            if done or truncated:
                break

    final = {"observation": observation, "reward": total_reward}
    inference_url_openai = "https://api.openai.com/v1/chat/completions"

    trajectory = RolloutTrajectory(
        env_id=request.env.env_id or request.env.env_name or "sokoban",
        policy_id=request.policy.policy_id or request.policy.policy_name or "sokoban-openai",
        steps=steps,
        final=final,
        length=len(steps),
        inference_url=inference_url_openai,  # Required for trace correlation
    )
    metrics = RolloutMetrics(
        episode_returns=[total_reward],
        mean_return=total_reward if steps else 0.0,
        num_steps=len(steps),
        num_episodes=1,
        outcome_score=None,
        events_score=None,
        details={"provider": "openai", "model": model},
    )
    trace_payload = _build_trace_payload(
        request,
        steps,
        metrics,
        difficulty=str(difficulty),
        initial_observation=initial_observation,
        provider="openai",
    )
    return RolloutResponse(
        run_id=request.run_id,
        trajectories=[trajectory],
        branches={},
        metrics=metrics,
        aborted=False,
        ops_executed=executed,
        trace=trace_payload,
    )
|
|
918
|
+
|
|
919
|
+
|
|
920
|
+
def build_config() -> TaskAppConfig:
    """Assemble the ``TaskAppConfig`` describing the Sokoban task app.

    One ``SokobanTaskSet`` is created per call. It serves task instances and
    is also stored in ``app_state["sokoban_taskset"]``. ``app_state`` also
    starts with an empty ``"sokoban_envs"`` dict. The module-level ``router``
    and ``rollout_executor`` are wired in, and CORS is open to all origins.
    """
    sokoban_tasks = SokobanTaskSet()
    task_info = _task_info()
    shared_state: dict[str, Any] = {
        "sokoban_taskset": sokoban_tasks,
        "sokoban_envs": {},
    }

    def _provide_instances(seeds):
        return sokoban_tasks.provide_task_instances(seeds)

    def _describe():
        return {"id": "sokoban", "name": "Sokoban"}

    return TaskAppConfig(
        app_id="sokoban",
        name="Sokoban Task App",
        description="Sokoban environment exposed as a Synth task app.",
        base_task_info=task_info,
        describe_taskset=_describe,
        provide_task_instances=_provide_instances,
        rollout=rollout_executor,
        dataset_registry=None,
        rubrics=None,
        proxy=None,
        routers=(router,),
        app_state=shared_state,
        cors_origins=["*"],
    )
|
|
940
|
+
|
|
941
|
+
|
|
942
|
+
# --- Health routes (auth-tolerant) ---
|
|
943
|
+
def fastapi_app():
    """Build the Sokoban FastAPI app with auth-tolerant health checks and env routes.

    Starts from ``create_task_app(build_config())`` and drops that app's GET
    ``/health`` and ``/health/rollout`` routes. Replacement versions are
    registered that:

    * return 503 when ``ENVIRONMENT_API_KEY`` is missing;
    * return 200 with ``authorized: False`` (plus a key-prefix hint) for
      unauthorized callers.

    It then adds ``/env/sokoban/{initialize,step,terminate}`` endpoints backed
    by ``request.app.state.sokoban_envs``, plus a compact 422 handler for
    request-validation errors.
    """
    app = create_task_app(build_config())

    # Drop the default GET health handlers so the auth-tolerant ones below are used.
    app.router.routes = [
        route
        for route in app.router.routes
        if not (
            getattr(route, "path", None) in {"/health", "/health/rollout"}
            and "GET" in (getattr(route, "methods", set()) or set())
        )
    ]

    def _key_prefix() -> Optional[str]:
        # First half (at least one character) of the configured key, or None.
        key = normalize_environment_api_key()
        return key[: max(1, len(key) // 2)] if key else None

    def _health_precheck(request: Request) -> Optional[JSONResponse]:
        """Return the response for a misconfigured or unauthorized health probe.

        ``None`` means the key is configured and the caller is authorized, so
        the route should return its own success payload.
        """
        if not normalize_environment_api_key():
            return JSONResponse(
                status_code=503,
                content={"status": "unhealthy", "detail": "Missing ENVIRONMENT_API_KEY"},
            )
        if not is_api_key_header_authorized(request):
            content: Dict[str, Any] = {"status": "healthy", "authorized": False}
            # NOTE(review): this sends half of the configured key to an
            # unauthenticated caller as a debugging hint; consider dropping or
            # shortening it for deployments reachable by untrusted clients.
            prefix = _key_prefix()
            if prefix:
                content["expected_api_key_prefix"] = prefix
            return JSONResponse(status_code=200, content=content)
        return None

    @app.get("/health")
    async def health(request: Request):
        early = _health_precheck(request)
        if early is not None:
            return early
        return {"status": "healthy", "authorized": True}

    @app.get("/health/rollout")
    async def health_rollout(request: Request):
        early = _health_precheck(request)
        if early is not None:
            return early
        return {"ok": True, "authorized": True}

    # Basic env lifecycle routes (for local eval only)
    @app.post("/env/sokoban/initialize")
    async def initialize_env(request: Request, payload: Dict[str, Any]):
        difficulty = str((payload.get("config") or {}).get("difficulty") or "easy")
        seed = payload.get("seed")
        try:
            instance: SokobanTaskInstance = await create_task_instance_from_seed(
                difficulty, int(seed) if seed is not None else 0
            )
        except Exception as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        env = SokobanEnvironment(instance)
        obs = await env.initialize()
        envs: Dict[str, SokobanEnvironment] = request.app.state.sokoban_envs
        # NOTE(review): re-initializing the same difficulty/seed replaces the
        # env previously stored under this id.
        env_id = f"{difficulty}:{seed or 0}"
        envs[env_id] = env
        return {"env_id": env_id, "observation": obs}

    @app.post("/env/sokoban/step")
    async def step_env(request: Request, payload: Dict[str, Any]):
        env_id = str(payload.get("env_id") or "")
        if not env_id:
            raise HTTPException(status_code=400, detail="env_id required")
        envs: Dict[str, SokobanEnvironment] = request.app.state.sokoban_envs
        env = envs.get(env_id)
        if env is None:
            raise HTTPException(status_code=404, detail="Unknown env_id")

        # Accept the action from the first tool call's args, else a top-level "action".
        action = None
        tool_calls = payload.get("tool_calls") or []
        if tool_calls:
            try:
                first = tool_calls[0] or {}
                args = first.get("args") or {}
                action = int(args.get("action")) if "action" in args else None
            except Exception:
                action = None
        if action is None and "action" in payload:
            try:
                action = int(payload.get("action"))
            except Exception:
                action = None
        if action is None:
            raise HTTPException(status_code=400, detail="action required")
        obs = await env.step(EnvToolCall(tool="interact", args={"action": int(action)}))
        return {"observation": obs}

    @app.post("/env/sokoban/terminate")
    async def terminate_env(request: Request, payload: Dict[str, Any]):
        env_id = str(payload.get("env_id") or "")
        envs: Dict[str, SokobanEnvironment] = request.app.state.sokoban_envs
        env = envs.pop(env_id, None)
        # Unknown ids are treated as already terminated.
        obs = await env.terminate() if env is not None else {"terminated": True}
        return {"ok": True, "observation": obs}

    @app.exception_handler(RequestValidationError)
    async def _on_validation_error(_request: Request, exc: RequestValidationError):
        # Only the first five validation errors are returned to keep responses small.
        return JSONResponse(status_code=422, content={"status": "invalid", "detail": exc.errors()[:5]})

    return app
|
|
1047
|
+
|
|
1048
|
+
|
|
1049
|
+
# Register the Sokoban task app under id "sokoban" with alias "sokoban-rl";
# build_config supplies its configuration, with no extra env files or Modal settings.
register_task_app(
    entry=TaskAppEntry(
        app_id="sokoban",
        aliases=("sokoban-rl",),
        description="Sokoban task app",
        config_factory=build_config,
        env_files=(),
        modal=None,
    )
)
|