synth-ai 0.2.13.dev1__py3-none-any.whl → 0.2.13.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (226) hide show
  1. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +12 -1
  2. examples/swe/task_app/grpo_swe_mini.py +55 -26
  3. examples/swe/task_app/hosted/rollout.py +40 -0
  4. examples/swe/task_app/hosted/test_service.py +5 -6
  5. examples/task_apps/TESTING.md +275 -0
  6. examples/task_apps/__init__.py +0 -0
  7. examples/task_apps/crafter/__init__.py +0 -0
  8. examples/task_apps/crafter/task_app/__init__.py +2 -0
  9. examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +18 -13
  10. examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
  11. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +60 -4
  12. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +25 -3
  13. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +10 -0
  14. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
  15. examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
  16. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
  17. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
  18. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
  19. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
  20. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
  21. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
  22. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
  23. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
  24. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
  25. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
  26. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
  27. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
  28. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
  29. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
  30. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
  31. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
  32. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
  33. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
  34. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
  35. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
  36. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
  37. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
  38. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
  39. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
  40. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
  41. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
  42. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
  43. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
  44. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
  45. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
  46. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
  47. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
  48. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
  49. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
  50. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
  51. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
  52. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
  53. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
  54. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
  55. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
  56. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
  57. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
  58. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
  59. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
  60. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
  61. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
  62. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
  63. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
  64. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
  65. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
  66. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
  67. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
  68. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
  69. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
  70. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
  71. examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
  72. examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
  73. examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
  74. examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
  75. examples/task_apps/enron/__init__.py +1 -0
  76. examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
  77. examples/task_apps/enron/task_app/README.md +14 -0
  78. examples/task_apps/enron/task_app/__init__.py +1 -0
  79. examples/task_apps/enron/task_app/grpo_enron.py +906 -0
  80. examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
  81. examples/task_apps/enron/tests/__init__.py +2 -0
  82. examples/task_apps/enron/tests/conftest.py +115 -0
  83. examples/task_apps/enron/tests/integration/__init__.py +2 -0
  84. examples/task_apps/enron/tests/integration/test_enron_eval.py +177 -0
  85. examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
  86. examples/task_apps/enron/tests/unit/__init__.py +2 -0
  87. examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
  88. examples/task_apps/math/__init__.py +0 -0
  89. examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
  90. examples/task_apps/pokemon_battle/__init__.py +2 -0
  91. examples/task_apps/pokemon_battle/modal_app.py +104 -0
  92. examples/task_apps/pokemon_battle/task_app/README.md +68 -0
  93. examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
  94. examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
  95. examples/task_apps/pokemon_red/README.md +357 -0
  96. examples/task_apps/pokemon_red/__init__.py +3 -0
  97. examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
  98. examples/task_apps/pokemon_red/pallet_town_rl_config.toml +73 -0
  99. examples/task_apps/pokemon_red/task_app.py +606 -0
  100. examples/task_apps/pokemon_red/test_pallet_town_rewards.py +191 -0
  101. examples/task_apps/sokoban/README.md +307 -0
  102. examples/task_apps/sokoban/__init__.py +3 -0
  103. examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
  104. examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
  105. examples/task_apps/sokoban/task_app.py +1058 -0
  106. examples/task_apps/sokoban/tests/__init__.py +2 -0
  107. examples/task_apps/sokoban/tests/conftest.py +113 -0
  108. examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
  109. examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
  110. examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
  111. examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
  112. examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
  113. examples/task_apps/verilog/__init__.py +1 -0
  114. examples/task_apps/verilog/eval_groq_qwen32b.toml +20 -0
  115. examples/task_apps/verilog/task_app/README.md +12 -0
  116. examples/task_apps/verilog/task_app/__init__.py +1 -0
  117. examples/task_apps/verilog/task_app/grpo_verilog.py +931 -0
  118. examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
  119. examples/task_apps/verilog/tests/__init__.py +2 -0
  120. examples/task_apps/verilog/tests/conftest.py +115 -0
  121. examples/task_apps/verilog/tests/integration/__init__.py +2 -0
  122. examples/task_apps/verilog/tests/integration/test_verilog_eval.py +179 -0
  123. examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
  124. examples/task_apps/verilog/tests/unit/__init__.py +2 -0
  125. examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
  126. examples/vlm/crafter_openai_vlm_agent.py +4 -4
  127. examples/vlm/run_crafter_vlm_benchmark.py +4 -4
  128. examples/workflows/__init__.py +0 -0
  129. examples/workflows/math_rl/__init__.py +0 -0
  130. examples/workflows/math_rl/download_dataset.py +80 -0
  131. synth_ai/__init__.py +2 -2
  132. synth_ai/api/train/builders.py +25 -11
  133. synth_ai/api/train/cli.py +12 -6
  134. synth_ai/api/train/configs/__init__.py +10 -10
  135. synth_ai/api/train/configs/rl.py +5 -4
  136. synth_ai/api/train/configs/sft.py +4 -3
  137. synth_ai/api/train/env_resolver.py +5 -2
  138. synth_ai/api/train/supported_algos.py +10 -5
  139. synth_ai/api/train/utils.py +7 -4
  140. synth_ai/cli/__init__.py +7 -51
  141. synth_ai/cli/_storage.py +4 -3
  142. synth_ai/cli/_validate_task_app.py +11 -0
  143. synth_ai/cli/balance.py +4 -3
  144. synth_ai/cli/calc.py +2 -2
  145. synth_ai/cli/demo.py +14 -7
  146. synth_ai/cli/legacy_root_backup.py +1 -1
  147. synth_ai/cli/rl_demo.py +8 -7
  148. synth_ai/cli/root.py +0 -97
  149. synth_ai/cli/task_apps.py +1707 -186
  150. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +28 -16
  151. synth_ai/environments/examples/enron/engine.py +7 -2
  152. synth_ai/environments/examples/enron/environment.py +68 -0
  153. synth_ai/environments/examples/red/engine.py +27 -0
  154. synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
  155. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
  156. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
  157. synth_ai/environments/examples/red/environment.py +60 -0
  158. synth_ai/environments/examples/sokoban/taskset.py +116 -0
  159. synth_ai/environments/examples/verilog/engine.py +30 -4
  160. synth_ai/evals/client.py +58 -61
  161. synth_ai/jobs/client.py +16 -4
  162. synth_ai/judge_schemas.py +16 -16
  163. synth_ai/py.typed +0 -0
  164. synth_ai/task/__init__.py +14 -5
  165. synth_ai/task/contracts.py +124 -38
  166. synth_ai/task/proxy.py +48 -56
  167. synth_ai/task/rubrics/__init__.py +53 -0
  168. synth_ai/task/rubrics/loaders.py +133 -0
  169. synth_ai/task/rubrics/models.py +57 -0
  170. synth_ai/task/rubrics/scoring.py +113 -0
  171. synth_ai/{rubrics/validators.py → task/rubrics/strict.py} +53 -30
  172. synth_ai/task/server.py +8 -7
  173. synth_ai/task/validators.py +269 -6
  174. synth_ai/tracing_v3/decorators.py +7 -3
  175. synth_ai/tracing_v3/replica_sync.py +4 -4
  176. synth_ai/tracing_v3/serialization.py +5 -5
  177. synth_ai/tracing_v3/trace_utils.py +317 -0
  178. synth_ai/tracing_v3/turso/native_manager.py +3 -3
  179. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/METADATA +4 -1
  180. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/RECORD +214 -101
  181. examples/agora_ex/README_MoE.md +0 -224
  182. examples/agora_ex/__init__.py +0 -7
  183. examples/agora_ex/agora_ex.py +0 -65
  184. examples/agora_ex/agora_ex_task_app.py +0 -590
  185. examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +0 -121
  186. examples/agora_ex/reward_fn_grpo-human.py +0 -129
  187. examples/agora_ex/system_prompt_CURRENT.md +0 -63
  188. examples/agora_ex/task_app/agora_ex_task_app.py +0 -590
  189. examples/agora_ex/task_app/reward_fn_grpo-human.py +0 -129
  190. examples/agora_ex/task_app/system_prompt_CURRENT.md +0 -63
  191. synth_ai/rubrics/__init__.py +0 -22
  192. synth_ai/task/rubrics.py +0 -219
  193. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
  194. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
  195. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
  196. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
  197. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
  198. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
  199. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
  200. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
  201. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +0 -0
  202. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +0 -0
  203. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
  204. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
  205. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
  206. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
  207. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +0 -0
  208. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
  209. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
  210. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
  211. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
  212. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
  213. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/utils.py +0 -0
  214. /examples/{rl/task_app → task_apps/math}/README.md +0 -0
  215. /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
  216. /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
  217. /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
  218. /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
  219. /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
  220. /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
  221. /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
  222. /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
  223. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/WHEEL +0 -0
  224. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/entry_points.txt +0 -0
  225. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/licenses/LICENSE +0 -0
  226. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1058 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import re
6
+ import contextlib
7
+ import time
8
+ from datetime import datetime, timezone
9
+ from typing import Any, Dict, List, Mapping, Optional, Sequence
10
+
11
+ import httpx
12
+
13
+ from fastapi import APIRouter, HTTPException, Request
14
+ from fastapi.exceptions import RequestValidationError
15
+ from fastapi.responses import JSONResponse
16
+
17
+ from synth_ai.environments.environment.tools import EnvToolCall
18
+ from synth_ai.environments.examples.sokoban.environment import SokobanEnvironment
19
+ from synth_ai.environments.examples.sokoban.taskset import (
20
+ SokobanTaskInstance,
21
+ SokobanTaskSet,
22
+ create_task_instance_from_seed,
23
+ )
24
+ from synth_ai.task.apps import TaskAppEntry, register_task_app
25
+ from synth_ai.task.contracts import (
26
+ RolloutMetrics,
27
+ RolloutRequest,
28
+ RolloutResponse,
29
+ RolloutStep,
30
+ RolloutTrajectory,
31
+ TaskInfo,
32
+ )
33
+ from synth_ai.task.auth import is_api_key_header_authorized, normalize_environment_api_key
34
+ from synth_ai.task.server import TaskAppConfig, create_task_app
35
+
36
+
37
# Canonical Sokoban action ids and the direction each one stands for.
ACTION_ID_TO_NAME = {0: "left", 1: "up", 2: "right", 3: "down"}

# Every accepted spelling of an action mapped to its id: the bare digit first,
# then, per direction, its name, "move_<name>", compass word and first letter.
ACTION_TOKEN_TO_ID = {
    **{str(action_id): action_id for action_id in ACTION_ID_TO_NAME},
    **{
        alias: action_id
        for action_id, direction, compass in (
            (0, "left", "west"),
            (1, "up", "north"),
            (2, "right", "east"),
            (3, "down", "south"),
        )
        for alias in (direction, f"move_{direction}", compass, direction[0])
    },
}

# System message text describing the grid legend and the expected tool call.
SOKOBAN_SYSTEM_PROMPT = """You are an agent playing Sokoban.
The grid uses characters: '#' wall, '_' floor, 'O' box, '√' box on target, 'X' target, 'P' player.
Always respond with a single tool call named interact_many containing 1-5 actions.
Valid action tokens are digits 0/1/2/3 or their direction words (left/up/right/down).
Mapping: 0=left, 1=up, 2=right, 3=down. Avoid undoing progress and focus on pushing boxes onto targets."""
66
+
67
+
68
+ def _short_text(value: Any, *, limit: int = 280) -> str:
69
+ if value is None:
70
+ return ""
71
+ if isinstance(value, str):
72
+ text = value.strip()
73
+ return text if len(text) <= limit else text[: limit - 1] + "…"
74
+ if isinstance(value, (int, float, bool)):
75
+ return str(value)
76
+ try:
77
+ text = json.dumps(value, ensure_ascii=False)
78
+ except Exception:
79
+ text = str(value)
80
+ text = text.strip()
81
+ return text if len(text) <= limit else text[: limit - 1] + "…"
82
+
83
+
84
def _summarize_observation(observation: Any) -> str:
    """Return a short (≤512 character) text summary of an observation.

    For a dict, the first non-blank string found under ``room_text``,
    ``observation`` or ``grid`` is used.  Failing that, a preview of whichever
    of ``player_position``, ``boxes_on_target``, ``num_boxes`` and
    ``steps_taken`` are present is rendered.  Anything else (including a dict
    with none of those keys) is rendered whole.
    """
    if not isinstance(observation, dict):
        return _short_text(observation, limit=512)

    candidates = (observation.get(name) for name in ("room_text", "observation", "grid"))
    rendered = next(
        (text for text in candidates if isinstance(text, str) and text.strip()),
        None,
    )
    if rendered is not None:
        return _short_text(rendered, limit=512)

    summary = {
        name: observation.get(name)
        for name in ("player_position", "boxes_on_target", "num_boxes", "steps_taken")
        if name in observation
    }
    return _short_text(summary if summary else observation, limit=512)
98
+
99
+
100
def _resolve_action_id(raw: Any) -> Optional[int]:
    """Map one raw action value to an integer id, or ``None`` if unrecognised.

    Values accepted by ``int()`` are converted directly (the result is not
    range-checked here).  Anything else is stringified, stripped, lower-cased
    and looked up in ``ACTION_TOKEN_TO_ID``.
    """
    try:
        return int(raw)
    except (TypeError, ValueError, OverflowError):
        return ACTION_TOKEN_TO_ID.get(str(raw).strip().lower())


def _format_tool_calls(tool_calls: Sequence[Dict[str, Any]] | None) -> str:
    """Render a step's tool calls as human-readable direction names.

    Each call whose ``args`` is a dict contributes either a bracketed list
    (``"[left, up]"``) built from a non-empty ``args["actions"]`` list, or a
    single name built from ``args["action"]``.  Values that cannot be mapped
    to a known direction are shown via ``str()``.  Calls that are not dicts,
    or carry no usable arguments, are skipped.

    Returns:
        The rendered calls joined with ``", "``, or ``"<noop>"`` when there is
        nothing to show.
    """
    if not tool_calls:
        return "<noop>"

    formatted: list[str] = []
    for call in tool_calls:
        args = call.get("args") if isinstance(call, dict) else None
        if not isinstance(args, dict):
            continue

        batch = args.get("actions")
        if isinstance(batch, list):
            labels: list[str] = []
            for item in batch:
                action_id = _resolve_action_id(item)
                if action_id is None:
                    labels.append(str(item))
                else:
                    labels.append(ACTION_ID_TO_NAME.get(action_id, str(item)))
            if labels:
                formatted.append("[" + ", ".join(labels) + "]")
                continue
            # An empty "actions" list falls through to the single-action form.

        action = args.get("action")
        if action is None:
            continue
        action_id = _resolve_action_id(action)
        if action_id is None:
            # Unknown token or non-scalar value (e.g. a dict or list): never use
            # it as a dict key, just show it verbatim.
            formatted.append(str(action))
        else:
            formatted.append(ACTION_ID_TO_NAME.get(action_id, str(action_id)))

    return ", ".join(formatted) if formatted else "<noop>"
132
+
133
+
134
def _build_trace_payload(
    request: RolloutRequest,
    steps: Sequence[RolloutStep],
    metrics: RolloutMetrics,
    *,
    difficulty: str,
    initial_observation: Any,
    provider: str = "local",
) -> Dict[str, Any]:
    """Assemble the version-3 trace dictionary for one Sokoban rollout.

    Each rollout step yields one event plus an observation message and an
    action message; synthetic timestamps are spaced 0.01 apart starting from
    the current time, with the action message offset by 0.0005.  When *steps*
    is empty, a single terminated step describing *initial_observation* is
    emitted instead.

    Args:
        request: The rollout request; its ``run_id``, ``env`` and ``policy``
            are copied into the trace.
        steps: Steps recorded during the rollout (may be empty).
        metrics: Source of ``mean_return``, ``episode_returns`` and ``num_steps``.
        difficulty: Difficulty label stored in the session metadata.
        initial_observation: Observation summarised when there are no steps.
        provider: Provider label stored in the session metadata.

    Returns:
        A dict with ``version``, ``session_trace`` and summary fields.
    """
    created_at = datetime.now(timezone.utc)
    base_time = time.time()

    def _message(kind: str, text: str, stamp: float, index: int) -> dict[str, Any]:
        # One markov-blanket message in the layout used throughout the trace.
        return {
            "content": {"text": text},
            "message_type": kind,
            "time_record": {"event_time": stamp},
            "metadata": {"step_index": index},
        }

    event_history: list[dict[str, Any]] = []
    markov_messages: list[dict[str, Any]] = []
    session_steps: list[dict[str, Any]] = []

    if steps:
        for index, step in enumerate(steps):
            stamp = base_time + index * 0.01
            obs_text = _summarize_observation(step.obs)
            act_text = _format_tool_calls(step.tool_calls)
            obs_msg = _message("observation", obs_text, stamp, index)
            act_msg = _message("action", act_text, stamp + 0.0005, index)
            markov_messages.extend([obs_msg, act_msg])

            reward = float(step.reward or 0.0)
            event = {
                "system_instance_id": f"sokoban.step.{index}",
                "time_record": {"event_time": stamp},
                "reward": reward,
                "terminated": bool(step.done),
                "truncated": bool(step.truncated),
                "metadata": {
                    "tool_calls": step.tool_calls,
                    "info": step.info or {},
                },
            }
            event_history.append(event)
            # The step record references the same event/message objects that
            # were appended to the flat histories above.
            session_steps.append(
                {
                    "step_id": f"step_{index}",
                    "step_index": index,
                    "events": [event],
                    "markov_blanket_messages": [obs_msg, act_msg],
                    "step_metadata": {
                        "reward": reward,
                        "done": bool(step.done),
                        "truncated": bool(step.truncated),
                    },
                }
            )
    else:
        # No steps were taken: emit one terminated placeholder step.
        obs_msg = _message(
            "observation", _summarize_observation(initial_observation), base_time, 0
        )
        markov_messages.append(obs_msg)
        event = {
            "system_instance_id": "sokoban.step.0",
            "time_record": {"event_time": base_time},
            "reward": 0.0,
            "terminated": True,
            "truncated": False,
            "metadata": {
                "tool_calls": [],
            },
        }
        event_history.append(event)
        session_steps.append(
            {
                "step_id": "step_0",
                "step_index": 0,
                "events": [event],
                "markov_blanket_messages": [obs_msg],
                "step_metadata": {"reward": 0.0, "done": True, "truncated": False},
            }
        )

    session_metadata = {
        "task": "sokoban",
        "difficulty": difficulty,
        "seed": request.env.seed,
        "provider": provider,
        "env": request.env.model_dump(),
        "policy": request.policy.model_dump(),
    }
    session_trace = {
        "session_id": str(request.run_id),
        "created_at": created_at.isoformat(),
        "metadata": session_metadata,
        "session_time_steps": session_steps,
        "event_history": event_history,
        "markov_blanket_message_history": markov_messages,
    }

    return {
        "version": 3,
        "session_trace": session_trace,
        "run_id": request.run_id,
        "policy_id": request.policy.policy_id or request.policy.policy_name,
        "reward": metrics.mean_return,
        "episode_returns": metrics.episode_returns,
        "mean_return": metrics.mean_return,
        "num_steps": metrics.num_steps,
    }
252
+
253
+
254
+
255
def _task_info() -> TaskInfo:
    """Build the static ``TaskInfo`` descriptor for the Sokoban task app.

    NOTE(review): the advertised tool here is ``interact`` (single int action),
    whereas ``SOKOBAN_SYSTEM_PROMPT`` asks the model for ``interact_many`` —
    confirm which one consumers of this descriptor expect.
    """
    identity = {"id": "sokoban", "name": "Sokoban", "version": "1.0.0"}
    interact_tool = {"name": "interact", "schema": {"action": "int"}}
    action_space = {
        "type": "tool_call",
        "tools": [interact_tool],
        "max_calls": 1,
    }
    observation = {"summary": "Sokoban grid observation", "keys": ["grid", "player"]}
    rubric = {"version": "1", "criteria_count": 1, "source": "inline"}

    return TaskInfo(
        task=dict(identity),
        environment="sokoban",
        action_space=action_space,
        observation=observation,
        dataset=dict(identity),
        rubric=rubric,
        inference={"supports_proxy": False},
        limits={"max_turns": 200},
    )
270
+
271
+
272
# Module-level FastAPI router. No routes are attached to it in this part of
# the file — NOTE(review): confirm where (or whether) it is mounted.
router = APIRouter()
273
+
274
+
275
async def rollout_executor(request: RolloutRequest, fastapi_request: Request) -> RolloutResponse:
    """Run one Sokoban rollout and return its trajectory, metrics and trace.

    When ``policy.config["provider"]`` is ``"groq"`` or ``"openai"`` the
    request is delegated to the corresponding LLM-driven rollout helper.
    Otherwise a deterministic instance is created from the request's seed and
    difficulty, and the integer actions listed under
    ``policy.config["actions"]`` (if any) are replayed, capped at
    ``env.config["max_steps"]`` (default 50).

    Every step in this scripted path is recorded with ``reward=0.0`` and
    ``done=False``, and the final reward is reported as ``0.0``.  If no action
    is executed, a single ``done=True`` step holding the initial observation
    is returned.

    Args:
        request: Rollout request carrying env and policy configuration.
        fastapi_request: Incoming HTTP request, forwarded to provider helpers.

    Returns:
        A ``RolloutResponse`` with one trajectory and a v3 trace payload.
    """
    policy_cfg = dict(request.policy.config or {})
    provider = str(policy_cfg.get("provider") or "").strip().lower()
    if provider == "groq":
        return await _rollout_with_groq(request, fastapi_request, policy_cfg)
    if provider == "openai":
        return await _rollout_with_openai(request, fastapi_request, policy_cfg)

    env_cfg = request.env.config or {}
    seed = request.env.seed or 0
    difficulty = env_cfg.get("difficulty") or "easy"

    # Create deterministic instance from seed
    instance: SokobanTaskInstance = await create_task_instance_from_seed(
        str(difficulty), int(seed)
    )
    env = SokobanEnvironment(instance)
    initial_observation = await env.initialize()

    # If a predefined action sequence is provided, execute it (evaluation-style).
    # A list containing anything that is not int-convertible is ignored as a
    # whole, which results in a zero-action rollout.
    actions: List[int] = []
    scripted = policy_cfg.get("actions")
    if isinstance(scripted, list):
        try:
            actions = [int(a) for a in scripted]
        except (TypeError, ValueError, OverflowError):
            actions = []

    max_steps = int(env_cfg.get("max_steps") or 50)

    last_obs: Any = initial_observation
    steps: List[RolloutStep] = []
    for action in actions[:max_steps]:
        last_obs = await env.step(EnvToolCall(tool="interact", args={"action": action}))
        steps.append(
            RolloutStep(
                obs=last_obs,
                tool_calls=[{"tool": "interact", "args": {"action": action}}],
                reward=0.0,
                done=False,
                info={},
            )
        )
    # Count only real environment steps, before any placeholder is added.
    executed = len(steps)

    # Mark episode end (single-episode trajectory)
    final_reward = 0.0
    final = {"observation": last_obs, "reward": final_reward}
    if not steps:
        steps = [RolloutStep(obs=last_obs, tool_calls=[], reward=0.0, done=True, info={})]

    # Extract inference_url from policy config (None for manual rollouts)
    inference_url = policy_cfg.get("inference_url")

    traj = RolloutTrajectory(
        env_id="sokoban",
        policy_id=request.policy.policy_id or "policy",
        steps=steps,
        final=final,
        length=len(steps),
        inference_url=inference_url,  # Required for trace correlation
    )
    metrics = RolloutMetrics(
        episode_returns=[final_reward],
        mean_return=final_reward,
        num_steps=len(steps),
        num_episodes=1,
        outcome_score=None,
        events_score=None,
        details={},
    )
    trace_payload = _build_trace_payload(
        request,
        steps,
        metrics,
        difficulty=str(difficulty),
        initial_observation=initial_observation,
    )
    return RolloutResponse(
        run_id=request.run_id,
        trajectories=[traj],
        branches={},
        metrics=metrics,
        aborted=False,
        # One op for environment initialisation plus one per executed action.
        ops_executed=1 + executed,
        trace=trace_payload,
    )
354
+
355
+
356
def _format_sokoban_prompt(observation: dict[str, Any], last_actions: list[int]) -> str:
    """Build the per-turn user prompt from an observation dict.

    Reads ``steps_taken``, ``max_steps``, ``player_position``,
    ``boxes_on_target``, ``num_boxes`` (defaulting to the boxes-on-target
    count), ``reward_last`` and ``room_text`` from *observation*, and lists
    *last_actions* by direction name (or ``"none"`` when empty).
    """
    boxes = observation.get("boxes_on_target", 0)
    total_boxes = observation.get("num_boxes", boxes)

    if last_actions:
        history = ", ".join(ACTION_ID_TO_NAME.get(a, str(a)) for a in last_actions)
    else:
        history = "none"

    lines = [
        f"Step {observation.get('steps_taken', 0)} / {observation.get('max_steps', 0)}",
        f"Player position: {observation.get('player_position', ())}",
        f"Boxes on target: {boxes} / {total_boxes}",
        f"Last reward: {observation.get('reward_last', 0.0)}",
        f"Previous actions: {history}",
        "Grid:",
        f"{observation.get('room_text', '')}",
        "Select up to five next actions via the interact_many tool.",
    ]
    return "\n".join(lines)
377
+
378
+
379
def _extract_actions_from_response(
    response: dict[str, Any], max_actions: int
) -> list[int]:
    """Pull up to *max_actions* Sokoban action ids out of a chat-completion response.

    Tool calls are inspected first: the first tool call (across choices, in
    order) whose arguments decode to a dict with an ``actions`` list containing
    at least one valid action wins.  Integer items must be known action ids;
    string items are matched case-insensitively against ``ACTION_TOKEN_TO_ID``.
    Booleans and other item types are ignored.

    If no tool call yields an action, the first choice's text content is
    scanned for action tokens as a fallback.

    Args:
        response: Decoded JSON body of the chat-completion call.
        max_actions: Maximum number of actions to return.

    Returns:
        The extracted action ids, possibly empty.
    """
    import logging

    log = logging.getLogger(__name__)
    choices = response.get("choices") or []

    # Diagnostics go through logging (lazily formatted, off by default) rather
    # than being printed to stdout on every call.
    if log.isEnabledFor(logging.DEBUG):
        log.debug("[extract] response: %.2000s", response)
        log.debug("[extract] %d choices", len(choices))
        if choices:
            first_message = choices[0].get("message") or {}
            log.debug("[extract] tool_calls: %r", first_message.get("tool_calls"))
            log.debug("[extract] content: %r", first_message.get("content"))
            log.debug("[extract] finish_reason: %r", choices[0].get("finish_reason"))

    actions: list[int] = []
    for choice in choices:
        message = choice.get("message") or {}
        for tool_call in message.get("tool_calls") or []:
            function = tool_call.get("function") or {}
            arguments = function.get("arguments")
            if isinstance(arguments, str):
                try:
                    arguments = json.loads(arguments)
                except json.JSONDecodeError:
                    continue
            # Valid JSON is not necessarily an object (could be a list/number).
            if not isinstance(arguments, dict):
                continue
            raw_actions = arguments.get("actions")
            if not isinstance(raw_actions, list):
                continue
            for item in raw_actions:
                if isinstance(item, bool):
                    # bool is an int subclass; True/False are not action ids.
                    continue
                if isinstance(item, int):
                    if item in ACTION_ID_TO_NAME:
                        actions.append(item)
                elif isinstance(item, str):
                    action_id = ACTION_TOKEN_TO_ID.get(item.strip().lower())
                    if action_id is not None:
                        actions.append(action_id)
            if actions:
                break
        if actions:
            break

    if not actions and choices:
        # Fallback: parse tokens from assistant text
        text = (choices[0].get("message") or {}).get("content")
        if isinstance(text, str):
            for tok in re.findall(r"[0-3a-zA-Z_]+", text):
                action_id = ACTION_TOKEN_TO_ID.get(tok.strip().lower())
                if action_id is not None:
                    actions.append(action_id)

    return actions[:max_actions]
437
+
438
+
439
async def _post_chat_completion(
    client: httpx.AsyncClient,
    url: str,
    api_key: str,
    payload: dict[str, Any],
    *,
    vendor: str,
    error_log_tag: Optional[str] = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """POST a chat-completions payload and normalise failures to HTTPException.

    Args:
        client: Async HTTP client used for the request.
        url: Chat-completions endpoint to call.
        api_key: Bearer token sent in the Authorization header.
        payload: JSON body of the request.
        vendor: Human-readable vendor name used in error messages.
        error_log_tag: When set, HTTP status errors are also printed to stdout
            prefixed with this tag.

    Returns:
        ``(data, meta)`` where ``data`` is the decoded JSON body and ``meta``
        holds the response ``status``, ``headers`` and ``body``.

    Raises:
        HTTPException: With the upstream status code for HTTP status errors,
            or 502 for transport errors and non-JSON success bodies.
    """
    try:
        response = await client.post(
            url,
            json=payload,
            headers={"Authorization": f"Bearer {api_key}"},
        )
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        try:
            body = exc.response.json()
        except Exception:
            # Error bodies are not guaranteed to be JSON; keep the raw text.
            body = {"raw": exc.response.text}
        error_detail = {
            "status": exc.response.status_code,
            "body": body,
            "headers": dict(exc.response.headers),
        }
        if error_log_tag:
            try:
                print(error_log_tag, error_detail, flush=True)
            except Exception:
                # Logging is best-effort and must not mask the upstream error.
                pass
        raise HTTPException(status_code=exc.response.status_code, detail=error_detail) from exc
    except httpx.RequestError as exc:
        raise HTTPException(status_code=502, detail=f"{vendor} request error: {exc}") from exc

    try:
        data = response.json()
    except ValueError as exc:
        # A 2xx response with an undecodable body is an upstream failure, not
        # an internal error of this service.
        raise HTTPException(
            status_code=502,
            detail=f"{vendor} returned a non-JSON response body",
        ) from exc
    return data, {
        "status": response.status_code,
        "headers": dict(response.headers),
        "body": data,
    }


async def _call_groq_chat(
    client: httpx.AsyncClient,
    api_key: str,
    payload: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Send ``payload`` to Groq's OpenAI-compatible chat-completions endpoint.

    Returns the decoded body plus response metadata; raises HTTPException on
    any HTTP or transport failure.
    """
    return await _post_chat_completion(
        client,
        "https://api.groq.com/openai/v1/chat/completions",
        api_key,
        payload,
        vendor="Groq",
    )


async def _call_openai_chat(
    client: httpx.AsyncClient,
    api_key: str,
    payload: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Send ``payload`` to OpenAI's chat-completions endpoint.

    Returns the decoded body plus response metadata; raises HTTPException on
    any HTTP or transport failure. HTTP status errors are additionally printed
    with the ``[openai:error]`` tag.
    """
    return await _post_chat_completion(
        client,
        "https://api.openai.com/v1/chat/completions",
        api_key,
        payload,
        vendor="OpenAI",
        error_log_tag="[openai:error]",
    )
507
+
508
+
509
async def _rollout_with_groq(
    request: RolloutRequest,
    fastapi_request: Request,
    config: dict[str, Any],
) -> RolloutResponse:
    """Run one Sokoban episode, asking a Groq-hosted model for each batch of moves.

    Each loop iteration sends the current observation to the model, extracts up
    to ``max_actions_per_call`` actions from the reply, applies them to the
    environment one by one and records the batch as a single ``RolloutStep``.
    The episode stops when the environment reports termination/truncation,
    when the model returns no usable action, or when ``max_steps`` environment
    actions have been executed.

    Args:
        request: Rollout request; ``env.seed``, ``env.config`` ("difficulty",
            "max_steps"), the env/policy identifiers and ``run_id`` are read.
        fastapi_request: Not used by this function.
        config: Policy settings; "model", "temperature", "top_p", "max_tokens"
            and "max_actions_per_call" are honoured.

    Returns:
        A ``RolloutResponse`` with a single trajectory, metrics and trace.

    Raises:
        HTTPException: 503 when ``GROQ_API_KEY`` is unset; failures of the Groq
            call propagate unchanged.
    """
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        raise HTTPException(
            status_code=503,
            detail="GROQ_API_KEY environment variable is required for Groq rollouts.",
        )

    env_config = request.env.config or {}
    seed = request.env.seed or 0
    difficulty = env_config.get("difficulty") or "easy"
    instance: SokobanTaskInstance = await create_task_instance_from_seed(str(difficulty), int(seed))
    env = SokobanEnvironment(instance)
    observation = await env.initialize()
    initial_observation = observation

    model = config.get("model") or "qwen/qwen3-32b"
    temperature = float(config.get("temperature", 0.0) or 0.0)
    top_p = float(config.get("top_p", 0.95) or 0.95)
    max_tokens = int(config.get("max_tokens", 128) or 128)
    actions_per_call = int(config.get("max_actions_per_call", 4) or 4)
    actions_per_call = max(1, min(8, actions_per_call))

    max_steps = int(env_config.get("max_steps") or 50)

    steps: List[RolloutStep] = []
    last_actions: list[int] = []
    total_reward = float(observation.get("total_reward") or 0.0)
    executed = 0

    tool_items_enum = sorted(set(ACTION_TOKEN_TO_ID.keys()))
    tool_schema = {
        "type": "function",
        "function": {
            "name": "interact_many",
            "description": "Execute a short sequence of Sokoban moves in order.",
            "parameters": {
                "type": "object",
                "properties": {
                    "actions": {
                        "type": "array",
                        "items": {"type": "string", "enum": tool_items_enum},
                        "minItems": 1,
                        "maxItems": actions_per_call,
                    }
                },
                "required": ["actions"],
                "additionalProperties": False,
            },
        },
    }

    async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
        for _ in range(max_steps):
            # Check the action budget before paying for another model call.
            if executed >= max_steps:
                break

            user_prompt = _format_sokoban_prompt(observation, last_actions)
            messages = [
                {"role": "system", "content": SOKOBAN_SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ]
            payload = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "top_p": top_p,
                "max_tokens": max_tokens,
                "tools": [tool_schema],
                "tool_choice": {"type": "function", "function": {"name": "interact_many"}},
            }
            # A failed call raises HTTPException and aborts the rollout.
            response, response_meta = await _call_groq_chat(client, api_key, payload)
            vendor_attempts: list[dict[str, Any]] = [
                {"request": payload, "response": response_meta}
            ]

            actions = _extract_actions_from_response(response, actions_per_call)
            if not actions:
                break

            aggregated_actions: list[int] = []
            aggregated_reward = 0.0
            done = False
            truncated = False
            intermediate_rewards: list[float] = []

            for action in actions:
                if executed >= max_steps:
                    break
                aggregated_actions.append(int(action))
                observation = await env.step(
                    EnvToolCall(tool="interact", args={"action": int(action)})
                )
                # Only fall back to the previous total when the observation
                # omits it; a reported total of 0.0 is a real value.
                reported_total = observation.get("total_reward")
                current_total = total_reward if reported_total is None else float(reported_total)
                reward_delta = current_total - total_reward
                total_reward = current_total
                aggregated_reward += reward_delta
                intermediate_rewards.append(reward_delta)
                done = bool(observation.get("terminated"))
                truncated = bool(observation.get("truncated"))
                executed += 1
                if done or truncated:
                    break

            # The budget guard at the top of the loop plus the non-empty
            # ``actions`` check guarantee at least one action was executed.
            last_actions = aggregated_actions
            step = RolloutStep(
                obs=observation,
                tool_calls=[
                    {
                        "tool": "interact_many",
                        "args": {"actions": [int(a) for a in aggregated_actions]},
                        "source": "groq",
                    }
                ],
                reward=aggregated_reward,
                done=done,
                truncated=truncated if truncated else None,
                info={
                    "provider": "groq",
                    "model": model,
                    "actions_executed": aggregated_actions,
                    "prompt": user_prompt,
                    "reward_deltas": intermediate_rewards,
                    "vendor_attempts": vendor_attempts,
                    "groq_attempts": vendor_attempts,
                },
            )
            steps.append(step)

            if step.done or (step.truncated or False):
                break

    final = {"observation": observation, "reward": total_reward}
    inference_url_groq = "https://api.groq.com/openai/v1/chat/completions"

    trajectory = RolloutTrajectory(
        env_id=request.env.env_id or request.env.env_name or "sokoban",
        policy_id=request.policy.policy_id or request.policy.policy_name or "sokoban-groq",
        steps=steps,
        final=final,
        length=len(steps),
        inference_url=inference_url_groq,  # Required for trace correlation
    )
    metrics = RolloutMetrics(
        episode_returns=[total_reward],
        mean_return=total_reward if steps else 0.0,
        num_steps=len(steps),
        num_episodes=1,
        outcome_score=None,
        events_score=None,
        details={"provider": "groq", "model": model},
    )
    trace_payload = _build_trace_payload(
        request,
        steps,
        metrics,
        difficulty=str(difficulty),
        initial_observation=initial_observation,
        provider="groq",
    )
    return RolloutResponse(
        run_id=request.run_id,
        trajectories=[trajectory],
        branches={},
        metrics=metrics,
        aborted=False,
        ops_executed=executed,
        trace=trace_payload,
    )
690
+
691
+
692
async def _rollout_with_openai(
    request: RolloutRequest,
    fastapi_request: Request,
    config: dict[str, Any],
) -> RolloutResponse:
    """Run one Sokoban episode, asking an OpenAI model for each batch of moves.

    Every "policy" entry in ``request.ops`` (or ``max_steps`` implicit policy
    ops when none are given) triggers one model call. The reply is turned into
    up to ``max_actions_per_call`` actions that are applied to the environment
    and recorded as a single ``RolloutStep``. If the API rejects
    ``temperature`` or ``top_p`` as unsupported, the call is retried without
    the offending parameter.

    Args:
        request: Rollout request; ``ops``, ``env.seed``, ``env.config``
            ("difficulty", "max_steps"), the env/policy identifiers and
            ``run_id`` are read.
        fastapi_request: Not used by this function.
        config: Policy settings; "model", "temperature", "top_p",
            "max_completion_tokens"/"max_tokens" and "max_actions_per_call"
            are honoured.

    Returns:
        A ``RolloutResponse`` with a single trajectory, metrics and trace.

    Raises:
        HTTPException: 503 when ``OPENAI_API_KEY`` is unset; failures of the
            OpenAI call that cannot be retried propagate unchanged.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise HTTPException(
            status_code=503,
            detail="OPENAI_API_KEY environment variable is required for OpenAI rollouts.",
        )

    env_config = request.env.config or {}
    seed = request.env.seed or 0
    difficulty = env_config.get("difficulty") or "easy"
    instance: SokobanTaskInstance = await create_task_instance_from_seed(str(difficulty), int(seed))
    env = SokobanEnvironment(instance)
    observation = await env.initialize()
    initial_observation = observation

    model = config.get("model") or "gpt-5"
    temperature_cfg = config.get("temperature")
    top_p_cfg = config.get("top_p")
    completion_tokens = int(
        config.get("max_completion_tokens")
        or config.get("max_tokens")
        or 4000
    )
    actions_per_call = int(config.get("max_actions_per_call", 4) or 4)
    actions_per_call = max(1, min(8, actions_per_call))

    max_steps = int(env_config.get("max_steps") or 50)

    steps: List[RolloutStep] = []
    last_actions: list[int] = []
    total_reward = float(observation.get("total_reward") or 0.0)
    executed = 0

    tool_items_enum = sorted(set(ACTION_TOKEN_TO_ID.keys()))
    tool_schema = {
        "type": "function",
        "function": {
            "name": "interact_many",
            "description": "Execute a short sequence of Sokoban moves in order.",
            "parameters": {
                "type": "object",
                "properties": {
                    "actions": {
                        "type": "array",
                        "items": {"type": "string", "enum": tool_items_enum},
                        "minItems": 1,
                        "maxItems": actions_per_call,
                    }
                },
                "required": ["actions"],
                "additionalProperties": False,
            },
        },
    }

    # GPT-5 models don't support temperature/top_p (only default value of 1).
    # These values do not change between calls, so resolve them once.
    sampling_overrides: dict[str, float] = {}
    if "gpt-5" not in model.lower():
        for key, raw_value in (("temperature", temperature_cfg), ("top_p", top_p_cfg)):
            if raw_value is not None:
                # Unparseable values are silently dropped from the request.
                with contextlib.suppress(TypeError, ValueError, OverflowError):
                    sampling_overrides[key] = float(raw_value)

    async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as client:
        # Process ops array - each "policy" op triggers one LLM call
        ops_to_process = request.ops or []
        if not ops_to_process:
            # If no ops provided, default to max_steps policy calls
            ops_to_process = ["policy"] * max_steps

        for op in ops_to_process:
            # Only process "policy" ops, skip explicit actions for now
            if not (isinstance(op, str) and op.lower() == "policy"):
                continue
            # Check the action budget before paying for another model call.
            if executed >= max_steps:
                break

            user_prompt = _format_sokoban_prompt(observation, last_actions)
            messages = [
                {"role": "system", "content": SOKOBAN_SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ]
            payload_base: dict[str, Any] = {
                "model": model,
                "messages": messages,
                "max_completion_tokens": completion_tokens,
                "tools": [tool_schema],
                "tool_choice": {"type": "function", "function": {"name": "interact_many"}},
                **sampling_overrides,
            }

            vendor_attempts: list[dict[str, Any]] = []
            attempt_payload = dict(payload_base)
            # Retry loop: each retry removes one rejected sampling parameter,
            # so at most two retries can happen before the error is re-raised.
            while True:
                attempt_record: dict[str, Any] = {"request": dict(attempt_payload)}
                try:
                    response, response_meta = await _call_openai_chat(client, api_key, attempt_payload)
                    attempt_record["response"] = response_meta
                    vendor_attempts.append(attempt_record)
                    break
                except HTTPException as exc:
                    detail = exc.detail
                    attempt_record["error"] = detail if isinstance(detail, dict) else {"message": str(detail)}
                    vendor_attempts.append(attempt_record)
                    body = detail.get("body") if isinstance(detail, dict) else None
                    error_info = body.get("error") if isinstance(body, dict) else None
                    code = error_info.get("code") if isinstance(error_info, dict) else None
                    param = error_info.get("param") if isinstance(error_info, dict) else None
                    retryable = (
                        code in {"unsupported_parameter", "unsupported_value"}
                        and param in {"temperature", "top_p"}
                        and param in attempt_payload
                    )
                    if not retryable:
                        raise
                    attempt_payload = dict(attempt_payload)
                    attempt_payload.pop(param, None)

            actions = _extract_actions_from_response(response, actions_per_call)
            if not actions:
                break

            aggregated_actions: list[int] = []
            aggregated_reward = 0.0
            done = False
            truncated = False
            intermediate_rewards: list[float] = []

            print(f"[debug] Processing {len(actions)} actions from LLM", flush=True)
            for action in actions:
                if executed >= max_steps:
                    break
                aggregated_actions.append(int(action))
                observation = await env.step(
                    EnvToolCall(tool="interact", args={"action": int(action)})
                )
                # Only fall back to the previous total when the observation
                # omits it; a reported total of 0.0 is a real value.
                reported_total = observation.get("total_reward")
                current_total = total_reward if reported_total is None else float(reported_total)
                reward_delta = current_total - total_reward
                total_reward = current_total
                aggregated_reward += reward_delta
                intermediate_rewards.append(reward_delta)
                done = bool(observation.get("terminated"))
                truncated = bool(observation.get("truncated"))
                executed += 1
                if done or truncated:
                    break

            # Emitted once per batch, reporting the last action that was tried.
            print(f"[debug] After action {action}: done={done}, trunc={truncated}, exec={executed}", flush=True)

            # The budget guard above plus the non-empty ``actions`` check
            # guarantee at least one action was executed in this batch.
            last_actions = aggregated_actions
            step = RolloutStep(
                obs=observation,
                tool_calls=[
                    {
                        "tool": "interact_many",
                        "args": {"actions": [int(a) for a in aggregated_actions]},
                        "source": "openai",
                    }
                ],
                reward=aggregated_reward,
                done=done,
                truncated=truncated if truncated else None,
                info={
                    "provider": "openai",
                    "model": model,
                    "actions_executed": aggregated_actions,
                    "prompt": user_prompt,
                    "reward_deltas": intermediate_rewards,
                    "vendor_attempts": vendor_attempts,
                    "openai_attempts": vendor_attempts,
                    "max_completion_tokens": completion_tokens,
                    "temperature_requested": temperature_cfg,
                    "top_p_requested": top_p_cfg,
                },
            )
            steps.append(step)

            if step.done or (step.truncated or False):
                break

    final = {"observation": observation, "reward": total_reward}
    inference_url_openai = "https://api.openai.com/v1/chat/completions"

    trajectory = RolloutTrajectory(
        env_id=request.env.env_id or request.env.env_name or "sokoban",
        policy_id=request.policy.policy_id or request.policy.policy_name or "sokoban-openai",
        steps=steps,
        final=final,
        length=len(steps),
        inference_url=inference_url_openai,  # Required for trace correlation
    )
    metrics = RolloutMetrics(
        episode_returns=[total_reward],
        mean_return=total_reward if steps else 0.0,
        num_steps=len(steps),
        num_episodes=1,
        outcome_score=None,
        events_score=None,
        details={"provider": "openai", "model": model},
    )
    trace_payload = _build_trace_payload(
        request,
        steps,
        metrics,
        difficulty=str(difficulty),
        initial_observation=initial_observation,
        provider="openai",
    )
    return RolloutResponse(
        run_id=request.run_id,
        trajectories=[trajectory],
        branches={},
        metrics=metrics,
        aborted=False,
        ops_executed=executed,
        trace=trace_payload,
    )
918
+
919
+
920
def build_config() -> TaskAppConfig:
    """Assemble the ``TaskAppConfig`` describing the Sokoban task app.

    A fresh ``SokobanTaskSet`` is created per call and shared between the
    task-instance provider and the app state; the app state also starts with
    an empty ``sokoban_envs`` registry.
    """
    sokoban_tasks = SokobanTaskSet()
    task_info = _task_info()

    def _describe() -> dict[str, str]:
        return {"id": "sokoban", "name": "Sokoban"}

    def _provide(seeds):
        return sokoban_tasks.provide_task_instances(seeds)

    initial_state: dict[str, Any] = {
        "sokoban_taskset": sokoban_tasks,
        "sokoban_envs": {},
    }
    return TaskAppConfig(
        app_id="sokoban",
        name="Sokoban Task App",
        description="Sokoban environment exposed as a Synth task app.",
        base_task_info=task_info,
        describe_taskset=_describe,
        provide_task_instances=_provide,
        rollout=rollout_executor,
        dataset_registry=None,
        rubrics=None,
        proxy=None,
        routers=(router,),
        app_state=initial_state,
        cors_origins=["*"],
    )
940
+
941
+
942
+ # --- Health routes (auth-tolerant) ---
943
def fastapi_app():
    """Build the Sokoban FastAPI app with custom health and env routes.

    The app produced by ``create_task_app`` has its GET ``/health`` and
    ``/health/rollout`` routes replaced by auth-tolerant versions, gains
    ``/env/sokoban/{initialize,step,terminate}`` routes backed by the
    ``sokoban_envs`` registry on the app state, and reports request validation
    errors as a compact 422 payload.

    Returns:
        The configured application object.
    """
    app = create_task_app(build_config())

    # Replace default health handlers to log expected ENVIRONMENT_API_KEY when unauthorized
    overridden_paths = {"/health", "/health/rollout"}
    app.router.routes = [
        route
        for route in app.router.routes
        if not (
            getattr(route, "path", None) in overridden_paths
            and "GET" in (getattr(route, "methods", set()) or set())
        )
    ]

    def _key_prefix() -> Optional[str]:
        key = normalize_environment_api_key()
        return key[: max(1, len(key) // 2)] if key else None

    def _health_result(request: Request, authorized_body: Dict[str, Any]):
        """Shared body of both health routes; only the authorized payload differs."""
        env_key = normalize_environment_api_key()
        if not env_key:
            return JSONResponse(
                status_code=503,
                content={"status": "unhealthy", "detail": "Missing ENVIRONMENT_API_KEY"},
            )
        if not is_api_key_header_authorized(request):
            content: Dict[str, Any] = {"status": "healthy", "authorized": False}
            # NOTE(review): this discloses the first half of the expected key
            # to unauthenticated callers. Confirm that is acceptable for every
            # deployment of this app before exposing it publicly.
            prefix = _key_prefix()
            if prefix:
                content["expected_api_key_prefix"] = prefix
            return JSONResponse(status_code=200, content=content)
        return authorized_body

    @app.get("/health")
    async def health(request: Request):
        return _health_result(request, {"status": "healthy", "authorized": True})

    @app.get("/health/rollout")
    async def health_rollout(request: Request):
        return _health_result(request, {"ok": True, "authorized": True})

    # Basic env lifecycle routes (for local eval only)
    @app.post("/env/sokoban/initialize")
    async def initialize_env(request: Request, payload: Dict[str, Any]):
        difficulty = str((payload.get("config") or {}).get("difficulty") or "easy")
        raw_seed = payload.get("seed")
        try:
            seed = int(raw_seed) if raw_seed is not None else 0
            instance: SokobanTaskInstance = await create_task_instance_from_seed(difficulty, seed)
        except Exception as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        env = SokobanEnvironment(instance)
        obs = await env.initialize()
        envs: Dict[str, SokobanEnvironment] = request.app.state.sokoban_envs
        # Key on the normalised seed that was actually used for the instance.
        env_id = f"{difficulty}:{seed}"
        envs[env_id] = env
        return {"env_id": env_id, "observation": obs}

    @app.post("/env/sokoban/step")
    async def step_env(request: Request, payload: Dict[str, Any]):
        env_id = str(payload.get("env_id") or "")
        if not env_id:
            raise HTTPException(status_code=400, detail="env_id required")
        envs: Dict[str, SokobanEnvironment] = request.app.state.sokoban_envs
        env = envs.get(env_id)
        if not env:
            raise HTTPException(status_code=404, detail="Unknown env_id")

        # The action may arrive inside the first tool call or as a top-level
        # field; malformed values fall through to the 400 below.
        action = None
        tool_calls = payload.get("tool_calls") or []
        if tool_calls:
            try:
                first = tool_calls[0] or {}
                args = first.get("args") or {}
                action = int(args.get("action")) if "action" in args else None
            except Exception:
                action = None
        if action is None and "action" in payload:
            try:
                action = int(payload.get("action"))
            except Exception:
                action = None
        if action is None:
            raise HTTPException(status_code=400, detail="action required")
        obs = await env.step(EnvToolCall(tool="interact", args={"action": int(action)}))
        return {"observation": obs}

    @app.post("/env/sokoban/terminate")
    async def terminate_env(request: Request, payload: Dict[str, Any]):
        env_id = str(payload.get("env_id") or "")
        envs: Dict[str, SokobanEnvironment] = request.app.state.sokoban_envs
        env = envs.pop(env_id, None)
        if env:
            obs = await env.terminate()
        else:
            # Unknown ids are treated as already terminated.
            obs = {"terminated": True}
        return {"ok": True, "observation": obs}

    @app.exception_handler(RequestValidationError)
    async def _on_validation_error(_request: Request, exc: RequestValidationError):
        # Only the first five validation errors are returned.
        return JSONResponse(status_code=422, content={"status": "invalid", "detail": exc.errors()[:5]})

    return app
1047
+
1048
+
1049
# Register the Sokoban task app under "sokoban" (alias "sokoban-rl") when this
# module is imported; no env files or Modal settings are attached.
register_task_app(
    entry=TaskAppEntry(
        app_id="sokoban",
        aliases=("sokoban-rl",),
        description="Sokoban task app",
        config_factory=build_config,
        modal=None,
        env_files=(),
    )
)