synth-ai 0.2.12__py3-none-any.whl → 0.2.13.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (229)
  1. examples/multi_step/configs/crafter_rl_outcome.toml +74 -0
  2. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +186 -0
  3. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +83 -0
  4. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +78 -0
  5. examples/multi_step/crafter_rl_lora.md +51 -10
  6. examples/multi_step/sse_metrics_streaming_notes.md +357 -0
  7. examples/multi_step/task_app_config_notes.md +7 -1
  8. examples/swe/task_app/grpo_swe_mini.py +55 -26
  9. examples/swe/task_app/hosted/rollout.py +40 -0
  10. examples/swe/task_app/hosted/test_service.py +5 -6
  11. examples/task_apps/TESTING.md +275 -0
  12. examples/task_apps/__init__.py +0 -0
  13. examples/task_apps/crafter/__init__.py +0 -0
  14. examples/task_apps/crafter/task_app/__init__.py +2 -0
  15. examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +21 -46
  16. examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
  17. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +60 -4
  18. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +109 -45
  19. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +67 -49
  20. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +242 -193
  21. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
  22. examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
  23. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
  24. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
  25. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
  26. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
  27. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
  28. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
  29. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
  30. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
  31. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
  32. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
  33. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
  34. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
  35. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
  36. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
  37. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
  38. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
  39. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
  40. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
  41. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
  42. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
  43. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
  44. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
  45. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
  46. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
  47. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
  48. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
  49. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
  50. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
  51. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
  52. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
  53. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
  54. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
  55. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
  56. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
  57. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
  58. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
  59. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
  60. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
  61. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
  62. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
  63. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
  64. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
  65. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
  66. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
  67. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
  68. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
  69. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
  70. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
  71. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
  72. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
  73. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
  74. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
  75. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
  76. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
  77. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
  78. examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
  79. examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
  80. examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
  81. examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
  82. examples/task_apps/enron/__init__.py +1 -0
  83. examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
  84. examples/task_apps/enron/task_app/README.md +14 -0
  85. examples/task_apps/enron/task_app/__init__.py +1 -0
  86. examples/task_apps/enron/task_app/grpo_enron.py +906 -0
  87. examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
  88. examples/task_apps/enron/tests/__init__.py +2 -0
  89. examples/task_apps/enron/tests/conftest.py +115 -0
  90. examples/task_apps/enron/tests/integration/__init__.py +2 -0
  91. examples/task_apps/enron/tests/integration/test_enron_eval.py +177 -0
  92. examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
  93. examples/task_apps/enron/tests/unit/__init__.py +2 -0
  94. examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
  95. examples/task_apps/math/__init__.py +0 -0
  96. examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
  97. examples/task_apps/pokemon_battle/__init__.py +2 -0
  98. examples/task_apps/pokemon_battle/modal_app.py +104 -0
  99. examples/task_apps/pokemon_battle/task_app/README.md +68 -0
  100. examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
  101. examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
  102. examples/task_apps/pokemon_red/README.md +357 -0
  103. examples/task_apps/pokemon_red/__init__.py +3 -0
  104. examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
  105. examples/task_apps/pokemon_red/pallet_town_rl_config.toml +73 -0
  106. examples/task_apps/pokemon_red/task_app.py +606 -0
  107. examples/task_apps/pokemon_red/test_pallet_town_rewards.py +191 -0
  108. examples/task_apps/sokoban/README.md +307 -0
  109. examples/task_apps/sokoban/__init__.py +3 -0
  110. examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
  111. examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
  112. examples/task_apps/sokoban/task_app.py +1058 -0
  113. examples/task_apps/sokoban/tests/__init__.py +2 -0
  114. examples/task_apps/sokoban/tests/conftest.py +113 -0
  115. examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
  116. examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
  117. examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
  118. examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
  119. examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
  120. examples/task_apps/verilog/__init__.py +1 -0
  121. examples/task_apps/verilog/eval_groq_qwen32b.toml +20 -0
  122. examples/task_apps/verilog/task_app/README.md +12 -0
  123. examples/task_apps/verilog/task_app/__init__.py +1 -0
  124. examples/task_apps/verilog/task_app/grpo_verilog.py +931 -0
  125. examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
  126. examples/task_apps/verilog/tests/__init__.py +2 -0
  127. examples/task_apps/verilog/tests/conftest.py +115 -0
  128. examples/task_apps/verilog/tests/integration/__init__.py +2 -0
  129. examples/task_apps/verilog/tests/integration/test_verilog_eval.py +179 -0
  130. examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
  131. examples/task_apps/verilog/tests/unit/__init__.py +2 -0
  132. examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
  133. examples/vlm/crafter_openai_vlm_agent.py +4 -4
  134. examples/vlm/run_crafter_vlm_benchmark.py +4 -4
  135. examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +4 -2
  136. examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +4 -2
  137. examples/warming_up_to_rl/run_eval.py +127 -18
  138. examples/workflows/__init__.py +0 -0
  139. examples/workflows/math_rl/__init__.py +0 -0
  140. examples/workflows/math_rl/download_dataset.py +80 -0
  141. synth_ai/__init__.py +41 -1
  142. synth_ai/api/train/builders.py +73 -29
  143. synth_ai/api/train/cli.py +12 -6
  144. synth_ai/api/train/configs/__init__.py +44 -0
  145. synth_ai/api/train/configs/rl.py +134 -0
  146. synth_ai/api/train/configs/sft.py +95 -0
  147. synth_ai/api/train/configs/shared.py +24 -0
  148. synth_ai/api/train/env_resolver.py +5 -2
  149. synth_ai/api/train/supported_algos.py +10 -5
  150. synth_ai/api/train/utils.py +7 -4
  151. synth_ai/cli/__init__.py +7 -51
  152. synth_ai/cli/_storage.py +4 -3
  153. synth_ai/cli/_validate_task_app.py +11 -0
  154. synth_ai/cli/balance.py +4 -3
  155. synth_ai/cli/calc.py +2 -2
  156. synth_ai/cli/demo.py +49 -43
  157. synth_ai/cli/legacy_root_backup.py +1 -1
  158. synth_ai/cli/rl_demo.py +86 -106
  159. synth_ai/cli/root.py +0 -97
  160. synth_ai/cli/task_apps.py +1710 -186
  161. synth_ai/demos/core/cli.py +121 -159
  162. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +28 -16
  163. synth_ai/environments/examples/crafter_classic/environment.py +16 -0
  164. synth_ai/environments/examples/enron/engine.py +7 -2
  165. synth_ai/environments/examples/enron/environment.py +68 -0
  166. synth_ai/environments/examples/red/engine.py +27 -0
  167. synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
  168. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
  169. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
  170. synth_ai/environments/examples/red/environment.py +60 -0
  171. synth_ai/environments/examples/sokoban/taskset.py +116 -0
  172. synth_ai/environments/examples/verilog/engine.py +30 -4
  173. synth_ai/evals/__init__.py +15 -0
  174. synth_ai/evals/client.py +82 -0
  175. synth_ai/evals/types.py +42 -0
  176. synth_ai/jobs/client.py +16 -4
  177. synth_ai/judge_schemas.py +127 -0
  178. synth_ai/py.typed +0 -0
  179. synth_ai/task/__init__.py +14 -5
  180. synth_ai/task/contracts.py +124 -38
  181. synth_ai/task/proxy.py +48 -56
  182. synth_ai/task/rubrics/__init__.py +53 -0
  183. synth_ai/task/rubrics/loaders.py +133 -0
  184. synth_ai/task/rubrics/models.py +57 -0
  185. synth_ai/task/rubrics/scoring.py +113 -0
  186. synth_ai/task/rubrics/strict.py +149 -0
  187. synth_ai/task/server.py +8 -7
  188. synth_ai/task/validators.py +269 -6
  189. synth_ai/tracing_v3/decorators.py +7 -3
  190. synth_ai/tracing_v3/replica_sync.py +4 -4
  191. synth_ai/tracing_v3/serialization.py +130 -0
  192. synth_ai/tracing_v3/trace_utils.py +317 -0
  193. synth_ai/tracing_v3/turso/native_manager.py +3 -3
  194. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/METADATA +4 -1
  195. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/RECORD +228 -89
  196. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/entry_points.txt +0 -1
  197. synth_ai/task/rubrics.py +0 -219
  198. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
  199. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
  200. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
  201. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
  202. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
  203. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
  204. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
  205. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
  206. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +0 -0
  207. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +0 -0
  208. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
  209. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
  210. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
  211. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
  212. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
  213. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
  214. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
  215. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
  216. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
  217. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/utils.py +0 -0
  218. /examples/{rl/task_app → task_apps/math}/README.md +0 -0
  219. /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
  220. /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
  221. /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
  222. /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
  223. /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
  224. /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
  225. /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
  226. /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
  227. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/WHEEL +0 -0
  228. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/licenses/LICENSE +0 -0
  229. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,606 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, Iterable, Mapping, Sequence
4
+
5
+ from fastapi import HTTPException, Request
6
+ import httpx
7
+
8
+ from synth_ai.environments.examples.red.environment import PokemonRedEnvironment
9
+ from synth_ai.environments.environment.tools import EnvToolCall
10
+ from synth_ai.environments.examples.red.taskset import INSTANCE as RED_DEFAULT_INSTANCE
11
+ from synth_ai.environments.examples.red.engine_helpers.reward_library.pallet_town_progression import (
12
+ PalletTownProgressionCompositeReward,
13
+ )
14
+ from synth_ai.task.apps import TaskAppEntry, register_task_app
15
+ from synth_ai.task.contracts import (
16
+ RolloutMetrics,
17
+ RolloutRequest,
18
+ RolloutResponse,
19
+ RolloutStep,
20
+ RolloutTrajectory,
21
+ TaskInfo,
22
+ )
23
+ from synth_ai.task.server import ProxyConfig, TaskAppConfig
24
+
25
+
26
def _base_task_info() -> TaskInfo:
    """Build the task descriptor shared by every Pokémon Red task instance.

    The descriptor advertises two button tools (``press_button`` and
    ``execute_sequence``), the observation keys, dataset and rubric metadata,
    the proxy chat-completion endpoints, and a 1000-step limit.

    Both tool entries describe their arguments as JSON Schema objects, the
    same shape ``rollout_executor`` uses for the tool definitions it sends to
    the model.

    Returns:
        A freshly built ``TaskInfo``.
    """
    # A new list is built for each schema so no two nested dicts share the
    # same mutable ``enum`` list.
    button_names = ("UP", "DOWN", "LEFT", "RIGHT", "A", "B", "START", "SELECT")
    return TaskInfo(
        task={"id": "pokemon_red", "name": "Pokémon Red", "version": "0.1.0"},
        environment="pokemon_red",
        action_space={
            "type": "tool_call",
            "tools": [
                {
                    "name": "press_button",
                    "description": (
                        "Press a single Game Boy button for N frames "
                        "(use execute_sequence for multiple actions)"
                    ),
                    # JSON Schema, consistent with the execute_sequence entry
                    # below; "frames" is optional, as in the press_button tool
                    # definition sent to the model.
                    "schema": {
                        "type": "object",
                        "properties": {
                            "button": {"type": "string", "enum": list(button_names)},
                            "frames": {"type": "integer", "minimum": 1, "maximum": 120},
                        },
                        "required": ["button"],
                    },
                },
                {
                    "name": "execute_sequence",
                    "description": "Execute multiple button presses in sequence. More efficient than separate calls. Recommended: 5-10 actions per call.",
                    "schema": {
                        "type": "object",
                        "properties": {
                            "actions": {
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "button": {"type": "string", "enum": list(button_names)},
                                        "frames": {"type": "integer", "minimum": 1, "maximum": 120},
                                    },
                                    "required": ["button", "frames"],
                                },
                                "minItems": 1,
                                "maxItems": 20,
                            }
                        },
                        "required": ["actions"],
                    },
                },
            ],
            "max_calls": 1,
        },
        observation={
            "summary": "GB memory-derived state with reward fields.",
            # NOTE(review): the reward helpers in this module read keys such as
            # "map_id", "player_x" and "party_count" from observations; confirm
            # this advertised key list matches what the environment returns.
            "keys": [
                "position",
                "badges_earned",
                "badges_bitfield",
                "hp_status",
                "party_level",
                "party_xp",
                "in_battle",
                "step_count",
                "reward_last_step",
                "total_reward",
                "terminated",
            ],
        },
        dataset={"id": "pokemon_red_default", "name": "Pokémon Red Default", "version": "0.1.0"},
        rubric={"version": "1", "criteria_count": 1, "source": "inline"},
        inference={
            "supports_proxy": True,
            # NOTE(review): rollout_executor forces tool_choice to
            # "execute_sequence" when it calls the model; confirm whether this
            # advertised default tool should follow suit.
            "tool": {"name": "press_button", "parallel_tool_calls": False},
            "endpoints": {
                "openai": "/proxy/v1/chat/completions",
                "groq": "/proxy/groq/v1/chat/completions",
            },
        },
        limits={"max_steps": 1000},
    )
91
+
92
+
93
def _describe_taskset() -> dict[str, Any]:
    """Return the identifier and display name of the default Pokémon Red taskset."""
    taskset_id = "pokemon_red_default"
    display_name = "Pokémon Red Default"
    return {"id": taskset_id, "name": display_name}
95
+
96
+
97
def _provide_task_instances(seeds: Sequence[int]) -> Iterable[TaskInfo]:
    """Lazily yield one task descriptor per seed.

    A single ``_base_task_info()`` result is built when iteration starts. Its
    fields are passed as-is to every yielded descriptor; whether ``TaskInfo``
    copies them depends on its definition. Only ``observation`` differs per
    seed: it is a new dict extending the base observation with ``"seed"``.

    Args:
        seeds: Seeds to generate descriptors for, in order.

    Yields:
        One ``TaskInfo`` per seed.
    """
    template = _base_task_info()
    for seed in seeds:
        seeded_observation = {**template.observation, "seed": seed}
        yield TaskInfo(
            task=template.task,
            environment=template.environment,
            action_space=template.action_space,
            observation=seeded_observation,
            dataset=template.dataset,
            rubric=template.rubric,
            inference=template.inference,
            limits=template.limits,
        )
110
+
111
+
112
def _build_action_context(prev_state: dict[str, Any], current_state: dict[str, Any]) -> dict[str, Any]:
    """Snapshot selected fields of the previous observation for reward calculation.

    Each tracked field ``name`` is copied from ``prev_state`` into the result
    under the key ``prev_<name>``. A per-field default is used when the field
    is absent.

    Args:
        prev_state: Observation dict from before the action.
        current_state: Observation dict from after the action; accepted but
            not read.

    Returns:
        A new dict with exactly eleven ``prev_*`` keys, in a fixed order.
    """
    # (field name in prev_state, default when missing)
    tracked_fields = (
        ("map_id", 0),
        ("player_x", 0),
        ("player_y", 0),
        ("party_count", 0),
        ("in_battle", False),
        ("text_box_active", False),
        ("enemy_hp_current", 0),
        ("enemy_hp_percentage", 0.0),
        ("badges", 0),
        ("party_level", 0),
        ("party_xp", 0),
    )
    return {f"prev_{name}": prev_state.get(name, default) for name, default in tracked_fields}
127
+
128
+
129
def _describe_milestone(current_state: dict[str, Any], prev_state: dict[str, Any], reward: float) -> str:
    """Summarise what changed between two observations as a short label.

    The checks run in this order:

    * a map change;
    * party growth;
    * battle entry or exit, where an exit is labelled only when
      ``battle_outcome`` is 1 (won) or 2 (lost);
    * an enemy HP drop that leaves the enemy above 0 HP.

    Matching descriptions are joined with ``" | "``. If nothing matches, the
    result is ``"Progress (+<reward>)"`` with the reward rounded to a whole
    number.
    """
    parts: list[str] = []

    # NOTE(review): these labels assume the environment numbers maps 0-3 as
    # Pallet Town, bedroom, house and Oak's lab. In the original game's map
    # numbering (pret/pokered), ids 1-3 are Viridian/Pewter/Cerulean and Red's
    # house / Oak's lab are 0x25-0x28 — confirm against the environment.
    old_map = prev_state.get("map_id", -1)
    new_map = current_state.get("map_id", -1)
    if old_map != new_map:
        map_labels = {0: "Pallet Town", 1: "Bedroom", 2: "House", 3: "Oak's Lab"}
        old_label = map_labels.get(old_map, f"Map{old_map}")
        new_label = map_labels.get(new_map, f"Map{new_map}")
        parts.append(f"Moved from {old_label} to {new_label}")

    old_party = prev_state.get("party_count", 0)
    new_party = current_state.get("party_count", 0)
    if new_party > old_party:
        parts.append(f"Received Pokémon (party: {old_party}→{new_party})")

    was_fighting = prev_state.get("in_battle", False)
    is_fighting = current_state.get("in_battle", False)
    if not was_fighting and is_fighting:
        parts.append("Entered battle")
    elif was_fighting and not is_fighting:
        outcome = current_state.get("battle_outcome", 0)
        if outcome == 1:
            parts.append("Won battle")
        elif outcome == 2:
            parts.append("Lost battle")

    # A drop to exactly 0 HP is deliberately not reported as damage here.
    old_enemy_hp = prev_state.get("enemy_hp_current", 0)
    new_enemy_hp = current_state.get("enemy_hp_current", 0)
    if old_enemy_hp > new_enemy_hp > 0:
        parts.append(f"Dealt {old_enemy_hp - new_enemy_hp} damage to enemy")

    if parts:
        return " | ".join(parts)
    return f"Progress (+{reward:.0f})"
166
+
167
+
168
def _calculate_outcome_score(
    final_state: dict[str, Any],
    total_reward: float,
    *,
    max_reward: float = 700.0,
) -> float:
    """Blend the episode's total reward with end-state bonuses into one score.

    Components, with their weights:

    * 0.7 — ``total_reward / max_reward``, clamped to [0, 1];
    * 0.2 — 1.0 if ``final_state["party_count"] > 0``, else 0.0;
    * 0.1 — 0.5 if ``final_state["map_id"]`` is 0 or 3, else 0.0.

    The result therefore lies in [0, 0.95].

    Args:
        final_state: Last observation dict of the rollout.
        total_reward: Sum of per-step rewards for the rollout.
        max_reward: Reward that counts as full marks for the reward component.
            The default matches the expected maximum of roughly 700. Must be
            positive.

    Returns:
        The weighted score.

    Raises:
        ValueError: If ``max_reward`` is not positive.
    """
    if max_reward <= 0:
        raise ValueError(f"max_reward must be positive, got {max_reward!r}")

    # Clamp below as well as above. Without the lower bound, a net-negative
    # reward would produce a negative score despite the 0-1 normalisation.
    reward_score = max(0.0, min(total_reward / max_reward, 1.0))

    # Bonus for having at least one Pokémon in the party.
    has_pokemon = 1.0 if final_state.get("party_count", 0) > 0 else 0.0

    # Bonus for ending in Pallet Town (0) or Oak's Lab (3).
    # NOTE(review): confirm these ids against the environment's map numbering.
    map_id = final_state.get("map_id", -1)
    map_bonus = 0.5 if map_id in (0, 3) else 0.0

    return (reward_score * 0.7) + (has_pokemon * 0.2) + (map_bonus * 0.1)
182
+
183
+
184
+ async def rollout_executor(request: RolloutRequest, fastapi_request: Request) -> RolloutResponse:
185
+ async def _call_inference(policy_cfg: Mapping[str, Any], observation: Mapping[str, Any]) -> Mapping[str, Any]:
186
+ messages = [
187
+ {
188
+ "role": "system",
189
+ "content": (
190
+ "You are controlling Pokémon Red. Respond with a single tool call named 'press_button' "
191
+ "with JSON arguments {button: 'A|B|UP|DOWN|LEFT|RIGHT|START|SELECT', frames: 1-120}."
192
+ ),
193
+ },
194
+ {
195
+ "role": "user",
196
+ "content": (
197
+ "State summary: " + str({k: observation.get(k) for k in observation.keys() if k != "error"})
198
+ ),
199
+ },
200
+ ]
201
+ payload = {
202
+ "model": policy_cfg.get("model") or "qwen-2.5-7b",
203
+ "messages": messages,
204
+ "tools": [
205
+ {
206
+ "type": "function",
207
+ "function": {
208
+ "name": "execute_sequence",
209
+ "description": "Execute multiple button presses in sequence. More efficient than separate calls. Recommended: 5-10 actions per call.",
210
+ "parameters": {
211
+ "type": "object",
212
+ "properties": {
213
+ "actions": {
214
+ "type": "array",
215
+ "items": {
216
+ "type": "object",
217
+ "properties": {
218
+ "button": {
219
+ "type": "string",
220
+ "enum": ["UP", "DOWN", "LEFT", "RIGHT", "A", "B", "START", "SELECT"],
221
+ "description": "Game Boy button to press"
222
+ },
223
+ "frames": {
224
+ "type": "integer",
225
+ "minimum": 1,
226
+ "maximum": 120,
227
+ "description": "Number of frames to hold the button (30 frames = 0.5 seconds)"
228
+ }
229
+ },
230
+ "required": ["button", "frames"]
231
+ },
232
+ "minItems": 1,
233
+ "maxItems": 20,
234
+ "description": "Sequence of button presses to execute"
235
+ }
236
+ },
237
+ "required": ["actions"],
238
+ "additionalProperties": False,
239
+ },
240
+ },
241
+ },
242
+ {
243
+ "type": "function",
244
+ "function": {
245
+ "name": "press_button",
246
+ "description": "Press a single Game Boy button for N frames (use execute_sequence for multiple actions)",
247
+ "parameters": {
248
+ "type": "object",
249
+ "properties": {
250
+ "button": {"type": "string", "enum": ["UP", "DOWN", "LEFT", "RIGHT", "A", "B", "START", "SELECT"]},
251
+ "frames": {"type": "integer", "minimum": 1, "maximum": 120},
252
+ },
253
+ "required": ["button"],
254
+ "additionalProperties": False,
255
+ },
256
+ },
257
+ }
258
+ ],
259
+ "tool_choice": {"type": "function", "function": {"name": "execute_sequence"}},
260
+ "temperature": float(policy_cfg.get("temperature") or 0.0),
261
+ "top_p": float(policy_cfg.get("top_p") or 1.0),
262
+ "max_tokens": int(policy_cfg.get("max_tokens") or 500),
263
+ }
264
+ inference_url = str(policy_cfg.get("inference_url") or "").rstrip("/")
265
+ if not inference_url:
266
+ # Prefer built-in proxy endpoints from app if no external URL
267
+ provider = (policy_cfg.get("provider") or "").lower()
268
+ if provider == "groq":
269
+ inference_url = "/proxy/groq/v1/chat/completions"
270
+ else:
271
+ inference_url = "/proxy/v1/chat/completions"
272
+ async with httpx.AsyncClient(base_url="http://127.0.0.1:" + str(fastapi_request.url.port or 8913), timeout=httpx.Timeout(60.0)) as client: # best-effort
273
+ resp = await client.post(inference_url, json=payload)
274
+ resp.raise_for_status()
275
+ data = resp.json()
276
+ # Extract first tool call
277
+ choices = data.get("choices") or []
278
+ if not choices:
279
+ return {}
280
+ message = choices[0].get("message") or {}
281
+ raw_calls = message.get("tool_calls") or []
282
+ if not raw_calls:
283
+ return {}
284
+ f = raw_calls[0].get("function") or {}
285
+ tool_name = f.get("name", "")
286
+ args = f.get("arguments")
287
+ import json as _json
288
+ try:
289
+ parsed_args = _json.loads(args) if isinstance(args, str) else dict(args or {})
290
+ except Exception:
291
+ parsed_args = {}
292
+
293
+ # Handle execute_sequence tool
294
+ if tool_name == "execute_sequence":
295
+ return {"actions": parsed_args.get("actions", [])}
296
+
297
+ # Handle press_button tool (legacy single action)
298
+ return {"button": parsed_args.get("button"), "frames": int(parsed_args.get("frames") or 30)}
299
+
300
+ # Initialize reward function
301
+ reward_fn = PalletTownProgressionCompositeReward()
302
+
303
+ env = PokemonRedEnvironment(RED_DEFAULT_INSTANCE)
304
+ obs0 = await env.initialize()
305
+
306
+ # Track cumulative stats
307
+ total_reward = 0.0
308
+ all_reward_components: list[dict[str, Any]] = []
309
+ milestone_events: list[dict[str, Any]] = []
310
+
311
+ steps: list[RolloutStep] = [
312
+ RolloutStep(obs=obs0, tool_calls=[], reward=0.0, done=False, info={"step_type": "initial"}),
313
+ ]
314
+
315
+ # Track previous state for reward calculation
316
+ prev_state = dict(obs0) if isinstance(obs0, Mapping) else {}
317
+
318
+ # Process all ops (explicit actions)
319
+ final_obs = obs0
320
+ for step_idx, op in enumerate(request.ops or []):
321
+ macro = None
322
+ if isinstance(op, dict):
323
+ macro = op.get("action") or op
324
+
325
+ if isinstance(macro, dict):
326
+ # Check if this is an execute_sequence call
327
+ if "actions" in macro:
328
+ # Handle execute_sequence: multiple actions in one call
329
+ actions_list = macro.get("actions", [])
330
+ sequence_reward = 0.0
331
+ sequence_tool_calls = []
332
+
333
+ for action_item in actions_list:
334
+ button = action_item.get("button", "A")
335
+ frames = int(action_item.get("frames", 1))
336
+
337
+ obs1 = await env.step(EnvToolCall(tool="press_button", args={"button": button, "frames": frames}))
338
+ current_state = dict(obs1) if isinstance(obs1, Mapping) else {}
339
+ action_context = _build_action_context(prev_state, current_state)
340
+ step_reward = await reward_fn.score(current_state, action_context)
341
+
342
+ sequence_reward += step_reward
343
+ sequence_tool_calls.append({"tool": "press_button", "args": {"button": button, "frames": frames}})
344
+
345
+ if step_reward > 0:
346
+ reward_component = {
347
+ "step": step_idx + 1,
348
+ "reward": step_reward,
349
+ "button": button,
350
+ "map_id": current_state.get("map_id"),
351
+ "position": f"({current_state.get('player_x')},{current_state.get('player_y')})",
352
+ }
353
+ all_reward_components.append(reward_component)
354
+ milestone_events.append({
355
+ "type": "milestone",
356
+ "step": step_idx + 1,
357
+ "reward": step_reward,
358
+ "description": _describe_milestone(current_state, prev_state, step_reward),
359
+ })
360
+
361
+ final_obs = obs1
362
+ prev_state = current_state
363
+
364
+ total_reward += sequence_reward
365
+ step_info = {
366
+ "step_type": "sequence",
367
+ "step_idx": step_idx,
368
+ "actions_count": len(actions_list),
369
+ "cumulative_reward": total_reward,
370
+ }
371
+ if sequence_reward > 0:
372
+ step_info["sequence_reward"] = sequence_reward
373
+
374
+ steps.append(
375
+ RolloutStep(
376
+ obs=final_obs,
377
+ tool_calls=sequence_tool_calls,
378
+ reward=sequence_reward,
379
+ done=False,
380
+ info=step_info,
381
+ )
382
+ )
383
+ else:
384
+ # Handle single press_button call
385
+ button = macro.get("button") or "A"
386
+ frames = int(macro.get("frames") or 1)
387
+ obs1 = await env.step(EnvToolCall(tool="press_button", args={"button": button, "frames": frames}))
388
+
389
+ # Calculate step reward
390
+ current_state = dict(obs1) if isinstance(obs1, Mapping) else {}
391
+ action_context = _build_action_context(prev_state, current_state)
392
+ step_reward = await reward_fn.score(current_state, action_context)
393
+ total_reward += step_reward
394
+
395
+ # Track reward components if non-zero
396
+ step_info: dict[str, Any] = {"step_type": "action", "step_idx": step_idx}
397
+ if step_reward > 0:
398
+ reward_component = {
399
+ "step": step_idx + 1,
400
+ "reward": step_reward,
401
+ "button": button,
402
+ "map_id": current_state.get("map_id"),
403
+ "position": f"({current_state.get('player_x')},{current_state.get('player_y')})",
404
+ }
405
+ all_reward_components.append(reward_component)
406
+ step_info["reward_component"] = reward_component
407
+
408
+ # Track milestone events
409
+ milestone_events.append({
410
+ "type": "milestone",
411
+ "step": step_idx + 1,
412
+ "reward": step_reward,
413
+ "description": _describe_milestone(current_state, prev_state, step_reward),
414
+ })
415
+
416
+ step_info["cumulative_reward"] = total_reward
417
+
418
+ steps.append(
419
+ RolloutStep(
420
+ obs=obs1,
421
+ tool_calls=[{"tool": "press_button", "args": {"button": button, "frames": frames}}],
422
+ reward=step_reward,
423
+ done=False,
424
+ info=step_info,
425
+ )
426
+ )
427
+ final_obs = obs1
428
+ prev_state = current_state
429
+ else:
430
+ # Attempt policy-driven step if policy.config present
431
+ policy_cfg = request.policy.config or {}
432
+ if policy_cfg:
433
+ try:
434
+ action = await _call_inference(policy_cfg, final_obs if isinstance(final_obs, Mapping) else {})
435
+
436
+ # Handle execute_sequence from policy
437
+ if "actions" in action:
438
+ actions_list = action.get("actions", [])
439
+ sequence_reward = 0.0
440
+ sequence_tool_calls = []
441
+
442
+ for action_item in actions_list:
443
+ button = action_item.get("button", "A")
444
+ frames = int(action_item.get("frames", 30))
445
+
446
+ obs1 = await env.step(EnvToolCall(tool="press_button", args={"button": button, "frames": frames}))
447
+ current_state = dict(obs1) if isinstance(obs1, Mapping) else {}
448
+ action_context = _build_action_context(prev_state, current_state)
449
+ step_reward = await reward_fn.score(current_state, action_context)
450
+
451
+ sequence_reward += step_reward
452
+ sequence_tool_calls.append({"tool": "press_button", "args": {"button": button, "frames": frames}})
453
+
454
+ if step_reward > 0:
455
+ reward_component = {
456
+ "step": step_idx + 1,
457
+ "reward": step_reward,
458
+ "button": button,
459
+ "map_id": current_state.get("map_id"),
460
+ "position": f"({current_state.get('player_x')},{current_state.get('player_y')})",
461
+ }
462
+ all_reward_components.append(reward_component)
463
+ milestone_events.append({
464
+ "type": "milestone",
465
+ "step": step_idx + 1,
466
+ "reward": step_reward,
467
+ "description": _describe_milestone(current_state, prev_state, step_reward),
468
+ })
469
+
470
+ final_obs = obs1
471
+ prev_state = current_state
472
+
473
+ total_reward += sequence_reward
474
+ step_info = {
475
+ "step_type": "policy_sequence",
476
+ "step_idx": step_idx,
477
+ "actions_count": len(actions_list),
478
+ "cumulative_reward": total_reward,
479
+ }
480
+ if sequence_reward > 0:
481
+ step_info["sequence_reward"] = sequence_reward
482
+
483
+ steps.append(
484
+ RolloutStep(
485
+ obs=final_obs,
486
+ tool_calls=sequence_tool_calls,
487
+ reward=sequence_reward,
488
+ done=False,
489
+ info=step_info,
490
+ )
491
+ )
492
+
493
+ # Handle single button press from policy
494
+ elif action.get("button"):
495
+ obs1 = await env.step(EnvToolCall(tool="press_button", args=action))
496
+
497
+ # Calculate step reward
498
+ current_state = dict(obs1) if isinstance(obs1, Mapping) else {}
499
+ action_context = _build_action_context(prev_state, current_state)
500
+ step_reward = await reward_fn.score(current_state, action_context)
501
+ total_reward += step_reward
502
+
503
+ step_info_policy: dict[str, Any] = {
504
+ "step_type": "policy",
505
+ "step_idx": step_idx,
506
+ "cumulative_reward": total_reward,
507
+ "proxy": True,
508
+ }
509
+ if step_reward > 0:
510
+ step_info_policy["reward_earned"] = step_reward
511
+
512
+ steps.append(
513
+ RolloutStep(
514
+ obs=obs1,
515
+ tool_calls=[{"tool": "press_button", "args": action}],
516
+ reward=step_reward,
517
+ done=False,
518
+ info=step_info_policy,
519
+ )
520
+ )
521
+ final_obs = obs1
522
+ prev_state = current_state
523
+ except Exception:
524
+ pass
525
+
526
+ # Calculate outcome score based on milestones achieved
527
+ final_state = dict(final_obs) if isinstance(final_obs, Mapping) else {}
528
+ outcome_score = _calculate_outcome_score(final_state, total_reward)
529
+
530
+ metrics = RolloutMetrics(
531
+ episode_returns=[total_reward],
532
+ mean_return=total_reward,
533
+ num_steps=len(steps),
534
+ num_episodes=1,
535
+ outcome_score=outcome_score,
536
+ details={
537
+ "total_reward": total_reward,
538
+ "reward_components": all_reward_components,
539
+ "milestone_events": milestone_events,
540
+ "final_map": final_state.get("map_id"),
541
+ "party_count": final_state.get("party_count", 0),
542
+ "badges": final_state.get("badges", 0),
543
+ },
544
+ )
545
+
546
+ # Extract inference_url from policy config
547
+ inference_url = (policy_cfg or {}).get("inference_url")
548
+
549
+ trajectory = RolloutTrajectory(
550
+ env_id="pokemon_red",
551
+ policy_id=request.policy.policy_id or "policy",
552
+ steps=steps,
553
+ final={"observation": final_obs, "reward": total_reward},
554
+ length=len(steps),
555
+ inference_url=inference_url, # NEW: Required for trace correlation
556
+ )
557
+
558
+ return RolloutResponse(
559
+ run_id=request.run_id,
560
+ trajectories=[trajectory],
561
+ branches={},
562
+ metrics=metrics,
563
+ aborted=False,
564
+ ops_executed=len(request.ops or []),
565
+ )
566
+
567
+
568
def build_config() -> TaskAppConfig:
    """Assemble the ``TaskAppConfig`` for the Pokémon Red demo task app.

    Wires this module's task info, taskset description, task-instance
    provider and rollout executor into a single config. It also enables the
    OpenAI and Groq proxies with a gameplay system hint. No dataset registry
    is attached. No API key is required, the debug env is exposed, and CORS
    is open to all origins.
    """
    task_info = _base_task_info()

    # System hint handed to the proxy config; it steers the model toward
    # batched ``execute_sequence`` calls (5-10 actions) instead of single presses.
    hint = (
        "You control Pokémon Red. Use 'execute_sequence' with 5-10 actions to play efficiently. "
        "Plan ahead: navigate rooms, advance dialogue, battle strategically. "
        'Example: {"tool": "execute_sequence", "args": {"actions": [{"button": "DOWN", "frames": 30}, ...]}}'
    )
    proxy_settings = ProxyConfig(
        enable_openai=True,
        enable_groq=True,
        system_hint=hint,
    )

    return TaskAppConfig(
        app_id="pokemon_red",
        name="Pokémon Red Task App",
        description="Expose Pokémon Red via Synth task framework (demo).",
        base_task_info=task_info,
        describe_taskset=_describe_taskset,
        provide_task_instances=_provide_task_instances,
        rollout=rollout_executor,
        dataset_registry=None,
        proxy=proxy_settings,
        app_state={},
        require_api_key=False,
        expose_debug_env=True,
        cors_origins=["*"],
    )
593
+
594
+
595
# Register this task app (plus its "pokemon_red_demo" alias) via
# register_task_app when the module is imported. The config is produced lazily
# by build_config; no env files or Modal settings are attached.
register_task_app(
    entry=TaskAppEntry(
        app_id="pokemon_red",
        aliases=("pokemon_red_demo",),
        description="Pokémon Red demo task app",
        config_factory=build_config,
        modal=None,
        env_files=(),
    )
)
605
+
606
+