synth-ai 0.2.16__py3-none-any.whl → 0.2.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (299)
  1. examples/analyze_semantic_words.sh +2 -2
  2. examples/baseline/banking77_baseline.py +204 -0
  3. examples/baseline/crafter_baseline.py +407 -0
  4. examples/baseline/pokemon_red_baseline.py +326 -0
  5. examples/baseline/simple_baseline.py +56 -0
  6. examples/baseline/warming_up_to_rl_baseline.py +239 -0
  7. examples/blog_posts/gepa/README.md +355 -0
  8. examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
  9. examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
  10. examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
  11. examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
  12. examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
  13. examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
  14. examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
  15. examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
  16. examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
  17. examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
  18. examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
  19. examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
  20. examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
  21. examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
  22. examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
  23. examples/blog_posts/gepa/gepa_baseline.py +204 -0
  24. examples/blog_posts/gepa/query_prompts_example.py +97 -0
  25. examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
  26. examples/blog_posts/gepa/task_apps.py +105 -0
  27. examples/blog_posts/gepa/test_gepa_local.sh +67 -0
  28. examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
  29. examples/blog_posts/pokemon_vl/README.md +98 -0
  30. examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
  31. examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +27 -0
  32. examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
  33. examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
  34. examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +43 -0
  35. examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
  36. examples/blog_posts/pokemon_vl/extract_images.py +239 -0
  37. examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
  38. examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
  39. examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
  40. examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
  41. examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
  42. examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
  43. examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
  44. examples/blog_posts/warming_up_to_rl/README.md +158 -0
  45. examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
  46. examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
  47. examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
  48. examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
  49. examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
  50. examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
  51. examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
  52. examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
  53. examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
  54. examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +91 -0
  55. examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
  56. examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
  57. examples/dev/qwen3_32b_qlora_4xh100.toml +5 -0
  58. examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
  59. examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
  60. examples/multi_step/configs/crafter_rl_outcome.toml +2 -1
  61. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +65 -107
  62. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +2 -1
  63. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +2 -1
  64. examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
  65. examples/multi_step/configs/verilog_rl_lora.toml +80 -123
  66. examples/qwen_coder/configs/coder_lora_30b.toml +1 -3
  67. examples/qwen_coder/configs/coder_lora_4b.toml +4 -1
  68. examples/qwen_coder/configs/coder_lora_small.toml +1 -3
  69. examples/qwen_vl/README.md +10 -12
  70. examples/qwen_vl/SETUP_COMPLETE.md +7 -8
  71. examples/qwen_vl/VISION_TESTS_COMPLETE.md +2 -3
  72. examples/qwen_vl/collect_data_via_cli.md +76 -84
  73. examples/qwen_vl/collect_vision_traces.py +4 -4
  74. examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +40 -57
  75. examples/qwen_vl/configs/crafter_vlm_sft_example.toml +1 -2
  76. examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +20 -37
  77. examples/qwen_vl/configs/eval_gpt5nano_vision.toml +21 -40
  78. examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
  79. examples/qwen_vl/configs/{filter_qwen2vl_sft.toml → filter_qwen3vl_sft.toml} +4 -5
  80. examples/qwen_vl/configs/filter_vision_sft.toml +2 -3
  81. examples/qwen_vl/crafter_qwen_vl_agent.py +5 -5
  82. examples/qwen_vl/run_vision_comparison.sh +6 -7
  83. examples/rl/README.md +5 -5
  84. examples/rl/configs/rl_from_base_qwen.toml +26 -1
  85. examples/rl/configs/rl_from_base_qwen17.toml +6 -2
  86. examples/rl/task_app/README.md +1 -2
  87. examples/rl/task_app/math_single_step.py +2 -2
  88. examples/run_crafter_demo.sh +2 -2
  89. examples/sft/README.md +1 -1
  90. examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -1
  91. examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -1
  92. examples/swe/task_app/README.md +32 -2
  93. examples/swe/task_app/grpo_swe_mini.py +4 -0
  94. examples/swe/task_app/hosted/envs/crafter/react_agent.py +1 -1
  95. examples/swe/task_app/hosted/envs/mini_swe/environment.py +37 -10
  96. examples/swe/task_app/hosted/inference/openai_client.py +4 -38
  97. examples/swe/task_app/hosted/policy_routes.py +17 -0
  98. examples/swe/task_app/hosted/rollout.py +4 -2
  99. examples/swe/task_app/morph_backend.py +178 -0
  100. examples/task_apps/banking77/__init__.py +6 -0
  101. examples/task_apps/banking77/banking77_task_app.py +841 -0
  102. examples/task_apps/banking77/deploy_wrapper.py +46 -0
  103. examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
  104. examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
  105. examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
  106. examples/task_apps/crafter/task_app/README.md +1 -1
  107. examples/task_apps/crafter/task_app/grpo_crafter.py +90 -5
  108. examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +1 -1
  109. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +4 -26
  110. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -2
  111. examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
  112. examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +372 -107
  113. examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +81 -12
  114. examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +82 -11
  115. examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
  116. examples/task_apps/enron/task_app/grpo_enron_task_app.py +1 -1
  117. examples/task_apps/gepa_benchmarks/__init__.py +7 -0
  118. examples/task_apps/gepa_benchmarks/common.py +260 -0
  119. examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
  120. examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
  121. examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
  122. examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
  123. examples/task_apps/math/README.md +1 -2
  124. examples/task_apps/pokemon_red/README.md +3 -4
  125. examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
  126. examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +6 -5
  127. examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +1 -2
  128. examples/task_apps/pokemon_red/task_app.py +288 -39
  129. examples/task_apps/sokoban/README.md +2 -3
  130. examples/task_apps/verilog/eval_groq_qwen32b.toml +12 -14
  131. examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +1 -1
  132. examples/vlm/configs/crafter_vlm_gpt4o.toml +4 -1
  133. examples/warming_up_to_rl/configs/crafter_fft.toml +4 -1
  134. examples/warming_up_to_rl/configs/crafter_fft_4b.toml +0 -2
  135. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +3 -2
  136. examples/warming_up_to_rl/run_local_rollout_traced.py +1 -1
  137. examples/warming_up_to_rl/task_app/README.md +1 -1
  138. examples/warming_up_to_rl/task_app/grpo_crafter.py +185 -5
  139. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +1 -1
  140. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +3 -27
  141. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -1
  142. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
  143. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +156 -45
  144. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +37 -4
  145. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
  146. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
  147. examples/workflows/math_rl/configs/rl_from_base_qwen.toml +27 -0
  148. examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +6 -0
  149. synth_ai/api/train/builders.py +99 -4
  150. synth_ai/api/train/cli.py +516 -26
  151. synth_ai/api/train/config_finder.py +13 -2
  152. synth_ai/api/train/configs/__init__.py +23 -2
  153. synth_ai/api/train/configs/prompt_learning.py +442 -0
  154. synth_ai/api/train/configs/rl.py +61 -7
  155. synth_ai/api/train/configs/sft.py +6 -2
  156. synth_ai/api/train/configs/shared.py +59 -2
  157. synth_ai/api/train/task_app.py +1 -1
  158. synth_ai/api/train/validators.py +277 -0
  159. synth_ai/auth/credentials.py +119 -0
  160. synth_ai/baseline/__init__.py +25 -0
  161. synth_ai/baseline/config.py +209 -0
  162. synth_ai/baseline/discovery.py +214 -0
  163. synth_ai/baseline/execution.py +146 -0
  164. synth_ai/cli/__init__.py +94 -18
  165. synth_ai/cli/__main__.py +0 -0
  166. synth_ai/cli/claude.py +70 -0
  167. synth_ai/cli/codex.py +84 -0
  168. synth_ai/cli/commands/__init__.py +18 -0
  169. synth_ai/cli/commands/baseline/__init__.py +12 -0
  170. synth_ai/cli/commands/baseline/core.py +637 -0
  171. synth_ai/cli/commands/baseline/list.py +93 -0
  172. synth_ai/cli/commands/demo/__init__.py +6 -0
  173. synth_ai/cli/commands/demo/core.py +163 -0
  174. synth_ai/cli/commands/eval/__init__.py +19 -0
  175. synth_ai/cli/commands/eval/core.py +1112 -0
  176. synth_ai/cli/commands/eval/errors.py +81 -0
  177. synth_ai/cli/commands/eval/validation.py +133 -0
  178. synth_ai/cli/commands/filter/__init__.py +12 -0
  179. synth_ai/cli/commands/filter/core.py +424 -0
  180. synth_ai/cli/commands/filter/errors.py +55 -0
  181. synth_ai/cli/commands/filter/validation.py +77 -0
  182. synth_ai/cli/commands/help/__init__.py +177 -0
  183. synth_ai/cli/commands/help/core.py +72 -0
  184. synth_ai/cli/commands/smoke/__init__.py +7 -0
  185. synth_ai/cli/commands/smoke/core.py +1436 -0
  186. synth_ai/cli/commands/status/__init__.py +64 -0
  187. synth_ai/cli/commands/status/client.py +192 -0
  188. synth_ai/cli/commands/status/config.py +92 -0
  189. synth_ai/cli/commands/status/errors.py +20 -0
  190. synth_ai/cli/commands/status/formatters.py +164 -0
  191. synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
  192. synth_ai/cli/commands/status/subcommands/files.py +79 -0
  193. synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
  194. synth_ai/cli/commands/status/subcommands/models.py +79 -0
  195. synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
  196. synth_ai/cli/commands/status/subcommands/runs.py +81 -0
  197. synth_ai/cli/commands/status/subcommands/summary.py +47 -0
  198. synth_ai/cli/commands/status/subcommands/usage.py +203 -0
  199. synth_ai/cli/commands/status/utils.py +114 -0
  200. synth_ai/cli/commands/train/__init__.py +53 -0
  201. synth_ai/cli/commands/train/core.py +21 -0
  202. synth_ai/cli/commands/train/errors.py +117 -0
  203. synth_ai/cli/commands/train/judge_schemas.py +200 -0
  204. synth_ai/cli/commands/train/judge_validation.py +305 -0
  205. synth_ai/cli/commands/train/validation.py +386 -0
  206. synth_ai/cli/demo.py +30 -158
  207. synth_ai/cli/deploy/__init__.py +43 -0
  208. synth_ai/cli/deploy.py +162 -0
  209. synth_ai/cli/eval/__init__.py +36 -0
  210. synth_ai/cli/eval/core.py +5 -0
  211. synth_ai/cli/eval/errors.py +31 -0
  212. synth_ai/cli/eval/validation.py +5 -0
  213. synth_ai/cli/filter/__init__.py +28 -0
  214. synth_ai/cli/filter/core.py +5 -0
  215. synth_ai/cli/filter/errors.py +23 -0
  216. synth_ai/cli/filter/validation.py +5 -0
  217. synth_ai/cli/legacy_root_backup.py +14 -8
  218. synth_ai/cli/modal_serve/__init__.py +12 -0
  219. synth_ai/cli/modal_serve/core.py +14 -0
  220. synth_ai/cli/modal_serve/errors.py +8 -0
  221. synth_ai/cli/modal_serve/validation.py +11 -0
  222. synth_ai/cli/opencode.py +107 -0
  223. synth_ai/cli/root.py +9 -5
  224. synth_ai/cli/serve/__init__.py +12 -0
  225. synth_ai/cli/serve/core.py +14 -0
  226. synth_ai/cli/serve/errors.py +8 -0
  227. synth_ai/cli/serve/validation.py +11 -0
  228. synth_ai/cli/setup.py +20 -265
  229. synth_ai/cli/status.py +7 -126
  230. synth_ai/cli/task_app_deploy.py +1 -10
  231. synth_ai/cli/task_app_modal_serve.py +4 -9
  232. synth_ai/cli/task_app_serve.py +4 -11
  233. synth_ai/cli/task_apps.py +51 -1480
  234. synth_ai/cli/train/__init__.py +12 -0
  235. synth_ai/cli/train/core.py +21 -0
  236. synth_ai/cli/train/errors.py +8 -0
  237. synth_ai/cli/train/validation.py +24 -0
  238. synth_ai/cli/train.py +1 -14
  239. synth_ai/demos/crafter/grpo_crafter_task_app.py +1 -1
  240. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
  241. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
  242. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
  243. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
  244. synth_ai/environments/examples/red/engine.py +33 -12
  245. synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
  246. synth_ai/environments/examples/red/environment.py +26 -0
  247. synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
  248. synth_ai/http.py +12 -0
  249. synth_ai/judge_schemas.py +10 -10
  250. synth_ai/learning/__init__.py +10 -0
  251. synth_ai/learning/prompt_learning_client.py +276 -0
  252. synth_ai/learning/prompt_learning_types.py +184 -0
  253. synth_ai/learning/rl/client.py +3 -1
  254. synth_ai/pricing/__init__.py +2 -0
  255. synth_ai/pricing/model_pricing.py +57 -0
  256. synth_ai/streaming/__init__.py +29 -0
  257. synth_ai/streaming/config.py +94 -0
  258. synth_ai/streaming/handlers.py +518 -0
  259. synth_ai/streaming/streamer.py +320 -0
  260. synth_ai/streaming/types.py +95 -0
  261. synth_ai/task/apps/__init__.py +1 -0
  262. synth_ai/task/config.py +2 -0
  263. synth_ai/task/tracing_utils.py +25 -25
  264. synth_ai/task/validators.py +45 -9
  265. synth_ai/task_app_cfgs.py +21 -0
  266. synth_ai/tracing_v3/config.py +162 -19
  267. synth_ai/tracing_v3/constants.py +1 -1
  268. synth_ai/tracing_v3/db_config.py +24 -38
  269. synth_ai/tracing_v3/migration_helper.py +1 -2
  270. synth_ai/tracing_v3/storage/config.py +47 -13
  271. synth_ai/tracing_v3/storage/factory.py +3 -3
  272. synth_ai/tracing_v3/turso/daemon.py +113 -11
  273. synth_ai/tracing_v3/turso/native_manager.py +92 -16
  274. synth_ai/types.py +8 -0
  275. synth_ai/urls.py +11 -0
  276. synth_ai/utils/__init__.py +30 -1
  277. synth_ai/utils/agents.py +74 -0
  278. synth_ai/utils/bin.py +39 -0
  279. synth_ai/utils/cli.py +149 -5
  280. synth_ai/utils/env.py +40 -33
  281. synth_ai/utils/http.py +4 -1
  282. synth_ai/utils/json.py +72 -0
  283. synth_ai/utils/modal.py +285 -3
  284. synth_ai/utils/paths.py +48 -0
  285. synth_ai/utils/uvicorn.py +113 -0
  286. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/METADATA +109 -6
  287. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/RECORD +291 -142
  288. examples/qwen_vl/configs/eval_qwen2vl_vision.toml +0 -44
  289. synth_ai/cli/tui.py +0 -62
  290. synth_ai/tui/__init__.py +0 -5
  291. synth_ai/tui/__main__.py +0 -13
  292. synth_ai/tui/cli/__init__.py +0 -1
  293. synth_ai/tui/cli/query_experiments.py +0 -164
  294. synth_ai/tui/cli/query_experiments_v3.py +0 -164
  295. synth_ai/tui/dashboard.py +0 -911
  296. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
  297. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
  298. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
  299. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,326 @@
1
+ """Pokemon Red baseline file for Game Boy emulation evaluation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ from synth_ai.baseline import BaselineConfig, BaselineTaskRunner, DataSplit, TaskResult
8
+ from synth_ai.inference import InferenceClient
9
+ import os
10
+ import httpx
11
+
12
+ try:
13
+ from synth_ai.environments.examples.red.environment import PokemonRedEnvironment
14
+ from synth_ai.environments.examples.red.taskset import (
15
+ PokemonRedTaskInstance,
16
+ PokemonRedTaskInstanceMetadata,
17
+ )
18
+ POKEMON_RED_AVAILABLE = True
19
+ except ImportError:
20
+ POKEMON_RED_AVAILABLE = False
21
+
22
+
23
class PokemonRedTaskRunner(BaselineTaskRunner):
    """Task runner that plays Pokemon Red episodes with an LLM-driven policy.

    For each seed the runner builds a ``PokemonRedEnvironment``, repeatedly asks
    the configured model for an ``execute_sequence`` tool call, and feeds the
    returned button presses to the environment one at a time until the step
    budget is used up, the episode ends, or the model returns no actions.
    """

    def __init__(self, policy_config: Dict[str, Any], env_config: Dict[str, Any]):
        """Store inference settings and build the tool schema sent to the model.

        Args:
            policy_config: Must contain ``"model"``. May contain
                ``"temperature"`` (default ``0.0``), ``"max_tokens"``
                (default ``512``) and ``"inference_url"``.
            env_config: Environment settings, forwarded to the base class.

        Raises:
            ImportError: If the Pokemon Red environment modules failed to import.
            KeyError: If ``policy_config`` has no ``"model"`` entry.
        """
        super().__init__(policy_config, env_config)

        if not POKEMON_RED_AVAILABLE:
            raise ImportError(
                "Pokemon Red environment not available. "
                "Install synth-ai with Pokemon Red support."
            )

        # Store config for inference
        self.model = policy_config["model"]
        self.temperature = policy_config.get("temperature", 0.0)
        self.max_tokens = policy_config.get("max_tokens", 512)
        self.inference_url = policy_config.get("inference_url")

        # Tool definition: one function taking 1-20 (button, frames) presses.
        self.tools = [{
            "type": "function",
            "function": {
                "name": "execute_sequence",
                "description": "Execute multiple button presses in sequence",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "actions": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "button": {
                                        "type": "string",
                                        "enum": ["UP", "DOWN", "LEFT", "RIGHT", "A", "B", "START", "SELECT"],
                                    },
                                    "frames": {
                                        "type": "integer",
                                        "minimum": 1,
                                        "maximum": 120,
                                        "description": "Frames to hold button (60fps)",
                                    },
                                },
                                "required": ["button", "frames"],
                            },
                            "minItems": 1,
                            "maxItems": 20,
                        },
                    },
                    "required": ["actions"],
                },
            },
        }]

    def _format_observation(self, obs: Dict[str, Any], step: int, max_steps: int) -> str:
        """Render an observation dict as the text prompt shown to the model.

        Only keys that are present in ``obs`` produce a line, so an empty dict
        yields just the header and the closing question.

        Args:
            obs: Observation mapping; keys read here include ``map_id``,
                ``player_x``/``player_y``, ``party_count``, ``party_pokemon``,
                ``in_battle``, ``enemy_hp_current``/``enemy_hp_max``,
                ``battle_turn``, ``badges``, ``money`` and ``text_box_active``.
            step: Index of the current model request, shown in the header.
            max_steps: Step budget, shown in the header.

        Returns:
            The prompt as a newline-joined string.
        """
        lines = [
            f"Pokemon Red - Step {step}/{max_steps}",
            "",
        ]

        # Position
        if "map_id" in obs:
            lines.append(f"Location: Map {obs['map_id']}")
        if "player_x" in obs and "player_y" in obs:
            lines.append(f"Position: ({obs['player_x']}, {obs['player_y']})")

        # Party
        if "party_count" in obs:
            lines.append(f"Party Size: {obs['party_count']}")
        if "party_pokemon" in obs and obs["party_pokemon"]:
            pokemon = obs["party_pokemon"][0]
            lines.append(
                f"First Pokemon: Level {pokemon.get('level', '?')}, "
                f"HP {pokemon.get('hp_current', '?')}/{pokemon.get('hp_max', '?')}"
            )

        # Battle
        if obs.get("in_battle"):
            lines.append("=== IN BATTLE ===")
            if "enemy_hp_current" in obs:
                lines.append(
                    f"Enemy HP: {obs['enemy_hp_current']}/{obs.get('enemy_hp_max', '?')}"
                )
            if "battle_turn" in obs:
                lines.append(f"Battle Turn: {obs['battle_turn']}")

        # Progress
        if "badges" in obs:
            lines.append(f"Badges: {obs['badges']}")
        if "money" in obs:
            lines.append(f"Money: ${obs['money']}")

        # Dialogue
        if obs.get("text_box_active"):
            lines.append("Text box is active - press A to advance dialogue")

        lines.append("")
        lines.append("What actions should we take?")

        return "\n".join(lines)

    @staticmethod
    def _build_messages(obs_dict: Dict[str, Any], prompt: str) -> List[Dict[str, Any]]:
        """Build the chat messages, attaching the screenshot when one is present."""
        image_b64 = obs_dict.get("observation_image_base64")
        if not image_b64:
            return [{"role": "user", "content": prompt}]
        return [{
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                },
                {"type": "text", "text": prompt},
            ],
        }]

    async def _request_completion(self, messages: List[Dict[str, Any]]) -> Any:
        """Request one forced ``execute_sequence`` tool call from the model.

        Uses ``InferenceClient`` when ``inference_url`` is an http(s) URL and
        otherwise posts directly to the OpenAI or Groq chat-completions endpoint.

        Raises:
            httpx.HTTPStatusError: If the direct provider call returns a 4xx/5xx
                status. Without this check an error body would be read as
                "no actions" and the episode reported as a successful empty run.
        """
        tool_choice = {"type": "function", "function": {"name": "execute_sequence"}}

        if self.inference_url and self.inference_url.startswith("http"):
            api_key = os.getenv("SYNTH_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
            base_url = self.inference_url.rstrip("/")
            # Append "/api" only when the URL does not already contain it.
            if "/api" not in base_url:
                base_url = f"{base_url}/api"
            client = InferenceClient(base_url=base_url, api_key=api_key)
            return await client.create_chat_completion(
                model=self.model,
                messages=messages,
                tools=self.tools,
                tool_choice=tool_choice,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )

        # NOTE(review): routing keys off the substring "openai" in the model
        # name, so a name such as "gpt-4o-mini" is sent to Groq - confirm intended.
        if "openai" in self.model.lower():
            base_url = "https://api.openai.com/v1"
            api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GROQ_API_KEY") or ""
        else:
            base_url = "https://api.groq.com/openai/v1"
            # Prefer the key that belongs to the endpoint being called; the
            # other key is kept as a fallback for backward compatibility.
            api_key = os.getenv("GROQ_API_KEY") or os.getenv("OPENAI_API_KEY") or ""

        async with httpx.AsyncClient() as http_client:
            resp = await http_client.post(
                f"{base_url}/chat/completions",
                json={
                    "model": self.model,
                    "messages": messages,
                    "tools": self.tools,
                    "tool_choice": tool_choice,
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                },
                headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
            )
            resp.raise_for_status()
            return resp.json()

    @staticmethod
    def _extract_actions(response: Any) -> List[Dict[str, Any]]:
        """Return the ``actions`` list of the first tool call, or ``[]`` if none.

        Accepts tool-call arguments either as an already-decoded dict or as the
        JSON string that OpenAI-compatible endpoints return. Arguments that
        cannot be decoded are treated the same as a reply with no actions.
        """
        import json  # local import: the module header does not import json

        tool_calls: Any = []
        if response.get("choices"):
            message = response["choices"][0].get("message") or {}
            tool_calls = message.get("tool_calls") or []
        elif "tool_calls" in response:
            tool_calls = response["tool_calls"] or []

        if not tool_calls:
            return []

        arguments = tool_calls[0]["function"]["arguments"]
        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments)
            except json.JSONDecodeError:
                return []
        if not isinstance(arguments, dict):
            return []

        actions = arguments.get("actions", [])
        return actions if isinstance(actions, list) else []

    async def run_task(self, seed: int) -> TaskResult:
        """Run a single Pokemon Red episode for ``seed``.

        Args:
            seed: Seed stored in the task metadata and used in the task id.

        Returns:
            A ``TaskResult`` carrying the summed reward, the positive per-step
            rewards, the number of environment steps taken, and a metadata
            snapshot taken from the last observation.

        Raises:
            ValueError: If ``env_config`` has no ``rom_path``.
        """
        # Create task instance
        rom_path = self.env_config.get("rom_path")
        if not rom_path:
            raise ValueError("rom_path required in env_config for Pokemon Red")

        # Imported once per episode rather than once per executed action.
        from synth_ai.environments.environment.tools import EnvToolCall

        init_state_path = self.env_config.get("init_state_path")
        max_steps = self.env_config.get("max_steps", 500)

        metadata = PokemonRedTaskInstanceMetadata(
            seed=seed,
            rom_path=rom_path,
            init_state_path=init_state_path,
            reward_type=self.env_config.get("reward_type", "pallet_town_progression"),
        )

        task_instance = PokemonRedTaskInstance(
            id=f"pokemon-red-{seed}",
            metadata=metadata,
        )

        # Create and initialize environment
        env = PokemonRedEnvironment(task_instance=task_instance)
        raw_obs = await env.initialize()
        observation = getattr(raw_obs, "observation", raw_obs)
        obs_dict = observation if isinstance(observation, dict) else {}

        total_reward = 0.0
        total_steps = 0
        event_rewards: List[Dict[str, Any]] = []
        battle_won = False
        game_over = False

        try:
            for step in range(max_steps):
                # Stop asking the model once the step budget is spent; further
                # requests could not execute any action.
                if total_steps >= max_steps:
                    break

                prompt = self._format_observation(obs_dict, step, max_steps)
                messages = self._build_messages(obs_dict, prompt)
                response = await self._request_completion(messages)

                actions = self._extract_actions(response)
                if not actions:
                    break

                for action_spec in actions:
                    if total_steps >= max_steps:
                        break

                    # Each button press is sent as its own one-action tool call.
                    tool_call = EnvToolCall(
                        name="execute_sequence",
                        arguments={"actions": [action_spec]},
                    )
                    step_result = await env.step([tool_call])
                    total_steps += 1

                    step_obs = getattr(step_result, "observation", step_result)
                    obs_dict = step_obs if isinstance(step_obs, dict) else {}

                    # A missing or None reward counts as zero.
                    reward = getattr(step_result, "reward", 0.0) or 0.0
                    total_reward += reward
                    if reward > 0:
                        event_rewards.append({
                            "step": total_steps,
                            "reward": reward,
                        })

                    # Check termination
                    if getattr(step_result, "terminated", False) or getattr(step_result, "truncated", False):
                        game_over = True
                        break

                    # Check battle outcome
                    if obs_dict.get("battle_outcome") == 1:
                        battle_won = True
                    elif obs_dict.get("battle_outcome") == 2:
                        game_over = True

                if game_over:
                    break
        finally:
            # Always release the environment, even when a step or an inference
            # call raised.
            await env.terminate()

        return TaskResult(
            seed=seed,
            success=True,
            outcome_reward=total_reward,
            event_rewards=event_rewards,
            total_steps=total_steps,
            metadata={
                "battle_won": battle_won,
                "game_over": game_over,
                "final_map": obs_dict.get("map_id"),
                "badges": obs_dict.get("badges", 0),
                "party_size": obs_dict.get("party_count", 0),
            },
        )
295
+
296
+
297
# Baseline registration; skipped entirely when the Pokemon Red environment
# could not be imported.
if POKEMON_RED_AVAILABLE:
    pokemon_red_baseline = BaselineConfig(
        baseline_id="pokemon_red",
        name="Pokemon Red",
        description="Pokemon Red Game Boy emulation with PyBoy",
        task_runner=PokemonRedTaskRunner,
        # Seeds 0-19 train, 20-24 val, 25-29 test.
        splits={
            split_name: DataSplit(name=split_name, seeds=list(seed_range))
            for split_name, seed_range in (
                ("train", range(20)),
                ("val", range(20, 25)),
                ("test", range(25, 30)),
            )
        },
        default_policy_config=dict(
            model="groq:llama-3.1-70b-versatile",
            temperature=0.0,
            max_tokens=512,
        ),
        default_env_config={
            "rom_path": None,  # Must be provided
            "init_state_path": None,  # Optional
            "reward_type": "pallet_town_progression",
            "max_steps": 500,
        },
        metadata=dict(
            environment="pokemon_red",
            task_type="emulation",
            requires_rom=True,
        ),
    )
326
+
@@ -0,0 +1,56 @@
1
+ """Simple example baseline file for testing."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from synth_ai.baseline import BaselineConfig, BaselineTaskRunner, DataSplit, TaskResult
6
+
7
+
8
class SimpleTaskRunner(BaselineTaskRunner):
    """Trivial runner for testing: every seed produces a successful result."""

    async def run_task(self, seed: int) -> TaskResult:
        """Return a fixed successful ``TaskResult`` for ``seed`` without doing any work."""
        result_metadata = {
            "seed": seed,
            "message": f"Task completed successfully for seed {seed}",
        }
        return TaskResult(
            seed=seed,
            success=True,
            outcome_reward=1.0,
            total_steps=1,
            metadata=result_metadata,
        )
23
+
24
+
25
# Baseline registration for the always-succeeding test runner.
simple_baseline = BaselineConfig(
    baseline_id="simple",
    name="Simple Baseline",
    description="A simple baseline for testing",
    task_runner=SimpleTaskRunner,
    # Seeds 0-9 train (easy), 10-14 val (medium), 15-19 test (hard).
    splits={
        split_name: DataSplit(
            name=split_name,
            seeds=list(seed_range),
            metadata={"difficulty": difficulty},
        )
        for split_name, seed_range, difficulty in (
            ("train", range(10), "easy"),
            ("val", range(10, 15), "medium"),
            ("test", range(15, 20), "hard"),
        )
    },
    default_policy_config=dict(model="gpt-4o-mini", temperature=0.0),
    default_env_config=dict(max_steps=10),
)
56
+
@@ -0,0 +1,239 @@
1
+ """Warming Up to RL baseline file for Gymnasium environments."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict
6
+
7
+ import gymnasium as gym
8
+
9
+ from synth_ai.baseline import BaselineConfig, BaselineTaskRunner, DataSplit, TaskResult
10
+ from synth_ai.inference import InferenceClient
11
+ import os
12
+ import httpx
13
+
14
+
15
class WarmingUpToRLTaskRunner(BaselineTaskRunner):
    """Task runner that plays one Gymnasium episode per seed with an LLM policy.

    At every step the current observation is rendered into a prompt, the model
    is asked to call the ``take_action`` tool, and the returned action index is
    fed to ``env.step``. When the model response contains no usable tool call
    (or an action outside a discrete action space) a random action is sampled
    instead, so an episode always runs to completion.

    Policy config keys read: ``model`` (required), ``temperature``,
    ``max_tokens``, ``inference_url``. Env config keys read: ``env_name``,
    ``max_steps``.
    """

    # Direct-provider routing table: provider name -> (base URL, name of the
    # environment variable holding that provider's API key).
    _DIRECT_PROVIDERS = {
        "openai": ("https://api.openai.com/v1", "OPENAI_API_KEY"),
        "groq": ("https://api.groq.com/openai/v1", "GROQ_API_KEY"),
    }

    def __init__(self, policy_config: Dict[str, Any], env_config: Dict[str, Any]):
        """Store inference settings and the target environment name."""
        super().__init__(policy_config, env_config)

        # Inference settings
        self.model = policy_config["model"]
        self.temperature = policy_config.get("temperature", 0.0)
        self.max_tokens = policy_config.get("max_tokens", 128)
        self.inference_url = policy_config.get("inference_url")

        # Gymnasium environment id
        self.env_name = env_config.get("env_name", "CartPole-v1")

    def _get_action_tool(self, env: gym.Env) -> Dict[str, Any]:
        """Build the ``take_action`` tool schema for ``env``'s action space.

        For ``Discrete`` spaces the schema bounds the action to ``0..n-1``;
        any other space gets an unbounded integer parameter.
        """
        if isinstance(env.action_space, gym.spaces.Discrete):
            description = f"Take action in {env.spec.id if env.spec else self.env_name}"
            action_schema: Dict[str, Any] = {
                "type": "integer",
                "minimum": 0,
                # ``Discrete.n`` is a NumPy integer, which ``json.dumps`` cannot
                # encode; convert so the schema survives request serialization.
                "maximum": int(env.action_space.n) - 1,
                "description": "Action index",
            }
        else:
            # Default for unknown action spaces
            description = "Take action in the environment"
            action_schema = {
                "type": "integer",
                "description": "Action index",
            }

        return {
            "type": "function",
            "function": {
                "name": "take_action",
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": {"action": action_schema},
                    "required": ["action"],
                },
            },
        }

    def _format_observation(self, obs: Any, env: gym.Env, step: int, max_steps: int) -> str:
        """Render the current observation as the user prompt sent to the model."""
        obs_str = str(obs)
        if hasattr(env, "spec") and env.spec:
            env_id = env.spec.id
        else:
            env_id = self.env_name

        return (
            f"Environment: {env_id}\n"
            f"Step: {step}/{max_steps}\n"
            f"Observation: {obs_str}\n"
            "\n"
            "What action should we take?"
        )

    def _resolve_direct_endpoint(self) -> tuple[str, str, str]:
        """Pick base URL, API key and model id for a direct provider call.

        An explicit ``groq:`` / ``openai:`` prefix on the model selects the
        provider and is stripped from the model id that is sent.
        NOTE(review): assumes provider APIs expect the bare model id (the
        prefix looks like an internal routing hint) - confirm.
        Without such a prefix the original heuristic applies: OpenAI when the
        model name contains ``"openai"``, Groq otherwise.

        Returns:
            ``(base_url, api_key, model)``; ``api_key`` is ``""`` when no key
            is configured.
        """
        prefix, sep, remainder = self.model.partition(":")
        provider = prefix.lower()
        if sep and provider in self._DIRECT_PROVIDERS:
            model = remainder
        else:
            provider = "openai" if "openai" in self.model.lower() else "groq"
            model = self.model

        base_url, key_var = self._DIRECT_PROVIDERS[provider]
        # Prefer the key that belongs to the chosen provider; fall back to the
        # original lookup order so single-key setups behave as before.
        api_key = (
            os.getenv(key_var)
            or os.getenv("OPENAI_API_KEY")
            or os.getenv("GROQ_API_KEY")
            or ""
        )
        return base_url, api_key, model

    async def _request_completion(
        self, messages: list[Dict[str, Any]], action_tool: Dict[str, Any]
    ) -> Any:
        """Ask the model for a forced ``take_action`` tool call and return the raw response."""
        tool_choice = {"type": "function", "function": {"name": "take_action"}}

        if self.inference_url and self.inference_url.startswith("http"):
            api_key = os.getenv("SYNTH_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
            base_url = self.inference_url.rstrip("/")
            # Append "/api" only when the URL does not already contain it.
            if "/api" not in base_url:
                base_url = f"{base_url}/api"
            client = InferenceClient(base_url=base_url, api_key=api_key)
            return await client.create_chat_completion(
                model=self.model,
                messages=messages,
                tools=[action_tool],
                tool_choice=tool_choice,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )

        base_url, api_key, model = self._resolve_direct_endpoint()
        async with httpx.AsyncClient() as http_client:
            resp = await http_client.post(
                f"{base_url}/chat/completions",
                json={
                    "model": model,
                    "messages": messages,
                    "tools": [action_tool],
                    "tool_choice": tool_choice,
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                },
                headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
            )
            return resp.json()

    def _extract_action(self, response: Any, env: gym.Env) -> Any:
        """Pull the action index out of a chat-completion response.

        Accepts either a full completion (``choices[0].message.tool_calls``)
        or a response carrying ``tool_calls`` at the top level. Falls back to
        ``env.action_space.sample()`` when there is no tool call, the
        arguments cannot be parsed, or the action is not valid for a
        ``Discrete`` action space.
        """
        import json  # function-scope: only needed to decode tool-call arguments

        tool_calls = None
        if "choices" in response and response["choices"]:
            message = response["choices"][0].get("message") or {}
            tool_calls = message.get("tool_calls")
        elif "tool_calls" in response:
            tool_calls = response["tool_calls"]

        if not tool_calls:
            # Fallback: sample random action
            return env.action_space.sample()

        try:
            arguments = tool_calls[0]["function"]["arguments"]
            # OpenAI-compatible APIs deliver arguments as a JSON-encoded string;
            # tolerate an already-decoded dict as well.
            if isinstance(arguments, str):
                arguments = json.loads(arguments) if arguments.strip() else {}
            action = int(arguments.get("action", 0))
        except (AttributeError, IndexError, KeyError, TypeError, ValueError):
            return env.action_space.sample()

        # Never hand an out-of-range index to a discrete environment.
        if isinstance(env.action_space, gym.spaces.Discrete) and not env.action_space.contains(action):
            return env.action_space.sample()
        return action

    async def run_task(self, seed: int) -> TaskResult:
        """Run a single Gymnasium episode.

        Args:
            seed: Seed passed to ``env.reset``.

        Returns:
            A ``TaskResult`` whose ``outcome_reward`` is the summed episode
            reward and whose metadata records the termination flags.
        """
        env = gym.make(self.env_name)

        total_reward = 0.0
        total_steps = 0
        terminated = False
        truncated = False

        try:
            obs, _info = env.reset(seed=seed)
            action_tool = self._get_action_tool(env)
            max_steps = self.env_config.get("max_steps", 500)

            for step in range(max_steps):
                prompt = self._format_observation(obs, env, step, max_steps)
                messages = [{"role": "user", "content": prompt}]

                response = await self._request_completion(messages, action_tool)
                action = self._extract_action(response, env)

                obs, reward, terminated, truncated, _info = env.step(action)
                # Keep the accumulator a plain float even for NumPy rewards.
                total_reward += float(reward)
                total_steps += 1

                if terminated or truncated:
                    break
        finally:
            # Release the environment even when inference or stepping fails.
            env.close()

        return TaskResult(
            seed=seed,
            success=True,
            outcome_reward=total_reward,
            total_steps=total_steps,
            metadata={
                "env_name": self.env_name,
                "episode_length": total_steps,
                "terminated": bool(terminated),
                "truncated": bool(truncated),
                "final_reward": total_reward,
            },
        )
183
+
184
+
185
# CartPole baseline: seeds 0-99 train, 100-119 val, 120-139 test.
cartpole_baseline = BaselineConfig(
    baseline_id="cartpole",
    name="CartPole-v1",
    description="Balance a pole on a cart using Gymnasium",
    task_runner=WarmingUpToRLTaskRunner,
    splits={
        label: DataSplit(name=label, seeds=list(range(first, past_last)))
        for label, first, past_last in (
            ("train", 0, 100),
            ("val", 100, 120),
            ("test", 120, 140),
        )
    },
    default_policy_config=dict(
        model="groq:llama-3.1-70b-versatile",
        temperature=0.0,
        max_tokens=128,
    ),
    default_env_config=dict(env_name="CartPole-v1", max_steps=500),
    metadata=dict(
        environment="CartPole-v1",
        task_type="control",
        max_reward=500,
    ),
    tags=["rl", "gymnasium", "control"],
)
212
+
213
# FrozenLake baseline: same seed layout as CartPole, shorter episodes.
frozenlake_baseline = BaselineConfig(
    baseline_id="frozenlake",
    name="FrozenLake-v1",
    description="Navigate a frozen lake to reach goal using Gymnasium",
    task_runner=WarmingUpToRLTaskRunner,
    splits={
        label: DataSplit(name=label, seeds=list(range(first, past_last)))
        for label, first, past_last in (
            ("train", 0, 100),
            ("val", 100, 120),
            ("test", 120, 140),
        )
    },
    default_policy_config=dict(
        model="groq:llama-3.1-70b-versatile",
        temperature=0.0,
        max_tokens=128,
    ),
    default_env_config=dict(env_name="FrozenLake-v1", max_steps=100),
    metadata=dict(
        environment="FrozenLake-v1",
        task_type="navigation",
        max_reward=1,
    ),
    tags=["rl", "gymnasium", "navigation"],
)
239
+