synth-ai 0.2.16__py3-none-any.whl → 0.2.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (299) hide show
  1. examples/analyze_semantic_words.sh +2 -2
  2. examples/baseline/banking77_baseline.py +204 -0
  3. examples/baseline/crafter_baseline.py +407 -0
  4. examples/baseline/pokemon_red_baseline.py +326 -0
  5. examples/baseline/simple_baseline.py +56 -0
  6. examples/baseline/warming_up_to_rl_baseline.py +239 -0
  7. examples/blog_posts/gepa/README.md +355 -0
  8. examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
  9. examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
  10. examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
  11. examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
  12. examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
  13. examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
  14. examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
  15. examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
  16. examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
  17. examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
  18. examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
  19. examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
  20. examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
  21. examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
  22. examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
  23. examples/blog_posts/gepa/gepa_baseline.py +204 -0
  24. examples/blog_posts/gepa/query_prompts_example.py +97 -0
  25. examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
  26. examples/blog_posts/gepa/task_apps.py +105 -0
  27. examples/blog_posts/gepa/test_gepa_local.sh +67 -0
  28. examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
  29. examples/blog_posts/pokemon_vl/README.md +98 -0
  30. examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
  31. examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +27 -0
  32. examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
  33. examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
  34. examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +43 -0
  35. examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
  36. examples/blog_posts/pokemon_vl/extract_images.py +239 -0
  37. examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
  38. examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
  39. examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
  40. examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
  41. examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
  42. examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
  43. examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
  44. examples/blog_posts/warming_up_to_rl/README.md +158 -0
  45. examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
  46. examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
  47. examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
  48. examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
  49. examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
  50. examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
  51. examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
  52. examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
  53. examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
  54. examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +91 -0
  55. examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
  56. examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
  57. examples/dev/qwen3_32b_qlora_4xh100.toml +5 -0
  58. examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
  59. examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
  60. examples/multi_step/configs/crafter_rl_outcome.toml +2 -1
  61. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +65 -107
  62. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +2 -1
  63. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +2 -1
  64. examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
  65. examples/multi_step/configs/verilog_rl_lora.toml +80 -123
  66. examples/qwen_coder/configs/coder_lora_30b.toml +1 -3
  67. examples/qwen_coder/configs/coder_lora_4b.toml +4 -1
  68. examples/qwen_coder/configs/coder_lora_small.toml +1 -3
  69. examples/qwen_vl/README.md +10 -12
  70. examples/qwen_vl/SETUP_COMPLETE.md +7 -8
  71. examples/qwen_vl/VISION_TESTS_COMPLETE.md +2 -3
  72. examples/qwen_vl/collect_data_via_cli.md +76 -84
  73. examples/qwen_vl/collect_vision_traces.py +4 -4
  74. examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +40 -57
  75. examples/qwen_vl/configs/crafter_vlm_sft_example.toml +1 -2
  76. examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +20 -37
  77. examples/qwen_vl/configs/eval_gpt5nano_vision.toml +21 -40
  78. examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
  79. examples/qwen_vl/configs/{filter_qwen2vl_sft.toml → filter_qwen3vl_sft.toml} +4 -5
  80. examples/qwen_vl/configs/filter_vision_sft.toml +2 -3
  81. examples/qwen_vl/crafter_qwen_vl_agent.py +5 -5
  82. examples/qwen_vl/run_vision_comparison.sh +6 -7
  83. examples/rl/README.md +5 -5
  84. examples/rl/configs/rl_from_base_qwen.toml +26 -1
  85. examples/rl/configs/rl_from_base_qwen17.toml +6 -2
  86. examples/rl/task_app/README.md +1 -2
  87. examples/rl/task_app/math_single_step.py +2 -2
  88. examples/run_crafter_demo.sh +2 -2
  89. examples/sft/README.md +1 -1
  90. examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -1
  91. examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -1
  92. examples/swe/task_app/README.md +32 -2
  93. examples/swe/task_app/grpo_swe_mini.py +4 -0
  94. examples/swe/task_app/hosted/envs/crafter/react_agent.py +1 -1
  95. examples/swe/task_app/hosted/envs/mini_swe/environment.py +37 -10
  96. examples/swe/task_app/hosted/inference/openai_client.py +4 -38
  97. examples/swe/task_app/hosted/policy_routes.py +17 -0
  98. examples/swe/task_app/hosted/rollout.py +4 -2
  99. examples/swe/task_app/morph_backend.py +178 -0
  100. examples/task_apps/banking77/__init__.py +6 -0
  101. examples/task_apps/banking77/banking77_task_app.py +841 -0
  102. examples/task_apps/banking77/deploy_wrapper.py +46 -0
  103. examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
  104. examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
  105. examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
  106. examples/task_apps/crafter/task_app/README.md +1 -1
  107. examples/task_apps/crafter/task_app/grpo_crafter.py +90 -5
  108. examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +1 -1
  109. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +4 -26
  110. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -2
  111. examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
  112. examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +372 -107
  113. examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +81 -12
  114. examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +82 -11
  115. examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
  116. examples/task_apps/enron/task_app/grpo_enron_task_app.py +1 -1
  117. examples/task_apps/gepa_benchmarks/__init__.py +7 -0
  118. examples/task_apps/gepa_benchmarks/common.py +260 -0
  119. examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
  120. examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
  121. examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
  122. examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
  123. examples/task_apps/math/README.md +1 -2
  124. examples/task_apps/pokemon_red/README.md +3 -4
  125. examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
  126. examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +6 -5
  127. examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +1 -2
  128. examples/task_apps/pokemon_red/task_app.py +288 -39
  129. examples/task_apps/sokoban/README.md +2 -3
  130. examples/task_apps/verilog/eval_groq_qwen32b.toml +12 -14
  131. examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +1 -1
  132. examples/vlm/configs/crafter_vlm_gpt4o.toml +4 -1
  133. examples/warming_up_to_rl/configs/crafter_fft.toml +4 -1
  134. examples/warming_up_to_rl/configs/crafter_fft_4b.toml +0 -2
  135. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +3 -2
  136. examples/warming_up_to_rl/run_local_rollout_traced.py +1 -1
  137. examples/warming_up_to_rl/task_app/README.md +1 -1
  138. examples/warming_up_to_rl/task_app/grpo_crafter.py +185 -5
  139. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +1 -1
  140. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +3 -27
  141. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -1
  142. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
  143. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +156 -45
  144. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +37 -4
  145. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
  146. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
  147. examples/workflows/math_rl/configs/rl_from_base_qwen.toml +27 -0
  148. examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +6 -0
  149. synth_ai/api/train/builders.py +99 -4
  150. synth_ai/api/train/cli.py +516 -26
  151. synth_ai/api/train/config_finder.py +13 -2
  152. synth_ai/api/train/configs/__init__.py +23 -2
  153. synth_ai/api/train/configs/prompt_learning.py +442 -0
  154. synth_ai/api/train/configs/rl.py +61 -7
  155. synth_ai/api/train/configs/sft.py +6 -2
  156. synth_ai/api/train/configs/shared.py +59 -2
  157. synth_ai/api/train/task_app.py +1 -1
  158. synth_ai/api/train/validators.py +277 -0
  159. synth_ai/auth/credentials.py +119 -0
  160. synth_ai/baseline/__init__.py +25 -0
  161. synth_ai/baseline/config.py +209 -0
  162. synth_ai/baseline/discovery.py +214 -0
  163. synth_ai/baseline/execution.py +146 -0
  164. synth_ai/cli/__init__.py +94 -18
  165. synth_ai/cli/__main__.py +0 -0
  166. synth_ai/cli/claude.py +70 -0
  167. synth_ai/cli/codex.py +84 -0
  168. synth_ai/cli/commands/__init__.py +18 -0
  169. synth_ai/cli/commands/baseline/__init__.py +12 -0
  170. synth_ai/cli/commands/baseline/core.py +637 -0
  171. synth_ai/cli/commands/baseline/list.py +93 -0
  172. synth_ai/cli/commands/demo/__init__.py +6 -0
  173. synth_ai/cli/commands/demo/core.py +163 -0
  174. synth_ai/cli/commands/eval/__init__.py +19 -0
  175. synth_ai/cli/commands/eval/core.py +1112 -0
  176. synth_ai/cli/commands/eval/errors.py +81 -0
  177. synth_ai/cli/commands/eval/validation.py +133 -0
  178. synth_ai/cli/commands/filter/__init__.py +12 -0
  179. synth_ai/cli/commands/filter/core.py +424 -0
  180. synth_ai/cli/commands/filter/errors.py +55 -0
  181. synth_ai/cli/commands/filter/validation.py +77 -0
  182. synth_ai/cli/commands/help/__init__.py +177 -0
  183. synth_ai/cli/commands/help/core.py +72 -0
  184. synth_ai/cli/commands/smoke/__init__.py +7 -0
  185. synth_ai/cli/commands/smoke/core.py +1436 -0
  186. synth_ai/cli/commands/status/__init__.py +64 -0
  187. synth_ai/cli/commands/status/client.py +192 -0
  188. synth_ai/cli/commands/status/config.py +92 -0
  189. synth_ai/cli/commands/status/errors.py +20 -0
  190. synth_ai/cli/commands/status/formatters.py +164 -0
  191. synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
  192. synth_ai/cli/commands/status/subcommands/files.py +79 -0
  193. synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
  194. synth_ai/cli/commands/status/subcommands/models.py +79 -0
  195. synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
  196. synth_ai/cli/commands/status/subcommands/runs.py +81 -0
  197. synth_ai/cli/commands/status/subcommands/summary.py +47 -0
  198. synth_ai/cli/commands/status/subcommands/usage.py +203 -0
  199. synth_ai/cli/commands/status/utils.py +114 -0
  200. synth_ai/cli/commands/train/__init__.py +53 -0
  201. synth_ai/cli/commands/train/core.py +21 -0
  202. synth_ai/cli/commands/train/errors.py +117 -0
  203. synth_ai/cli/commands/train/judge_schemas.py +200 -0
  204. synth_ai/cli/commands/train/judge_validation.py +305 -0
  205. synth_ai/cli/commands/train/validation.py +386 -0
  206. synth_ai/cli/demo.py +30 -158
  207. synth_ai/cli/deploy/__init__.py +43 -0
  208. synth_ai/cli/deploy.py +162 -0
  209. synth_ai/cli/eval/__init__.py +36 -0
  210. synth_ai/cli/eval/core.py +5 -0
  211. synth_ai/cli/eval/errors.py +31 -0
  212. synth_ai/cli/eval/validation.py +5 -0
  213. synth_ai/cli/filter/__init__.py +28 -0
  214. synth_ai/cli/filter/core.py +5 -0
  215. synth_ai/cli/filter/errors.py +23 -0
  216. synth_ai/cli/filter/validation.py +5 -0
  217. synth_ai/cli/legacy_root_backup.py +14 -8
  218. synth_ai/cli/modal_serve/__init__.py +12 -0
  219. synth_ai/cli/modal_serve/core.py +14 -0
  220. synth_ai/cli/modal_serve/errors.py +8 -0
  221. synth_ai/cli/modal_serve/validation.py +11 -0
  222. synth_ai/cli/opencode.py +107 -0
  223. synth_ai/cli/root.py +9 -5
  224. synth_ai/cli/serve/__init__.py +12 -0
  225. synth_ai/cli/serve/core.py +14 -0
  226. synth_ai/cli/serve/errors.py +8 -0
  227. synth_ai/cli/serve/validation.py +11 -0
  228. synth_ai/cli/setup.py +20 -265
  229. synth_ai/cli/status.py +7 -126
  230. synth_ai/cli/task_app_deploy.py +1 -10
  231. synth_ai/cli/task_app_modal_serve.py +4 -9
  232. synth_ai/cli/task_app_serve.py +4 -11
  233. synth_ai/cli/task_apps.py +51 -1480
  234. synth_ai/cli/train/__init__.py +12 -0
  235. synth_ai/cli/train/core.py +21 -0
  236. synth_ai/cli/train/errors.py +8 -0
  237. synth_ai/cli/train/validation.py +24 -0
  238. synth_ai/cli/train.py +1 -14
  239. synth_ai/demos/crafter/grpo_crafter_task_app.py +1 -1
  240. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
  241. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
  242. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
  243. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
  244. synth_ai/environments/examples/red/engine.py +33 -12
  245. synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
  246. synth_ai/environments/examples/red/environment.py +26 -0
  247. synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
  248. synth_ai/http.py +12 -0
  249. synth_ai/judge_schemas.py +10 -10
  250. synth_ai/learning/__init__.py +10 -0
  251. synth_ai/learning/prompt_learning_client.py +276 -0
  252. synth_ai/learning/prompt_learning_types.py +184 -0
  253. synth_ai/learning/rl/client.py +3 -1
  254. synth_ai/pricing/__init__.py +2 -0
  255. synth_ai/pricing/model_pricing.py +57 -0
  256. synth_ai/streaming/__init__.py +29 -0
  257. synth_ai/streaming/config.py +94 -0
  258. synth_ai/streaming/handlers.py +518 -0
  259. synth_ai/streaming/streamer.py +320 -0
  260. synth_ai/streaming/types.py +95 -0
  261. synth_ai/task/apps/__init__.py +1 -0
  262. synth_ai/task/config.py +2 -0
  263. synth_ai/task/tracing_utils.py +25 -25
  264. synth_ai/task/validators.py +45 -9
  265. synth_ai/task_app_cfgs.py +21 -0
  266. synth_ai/tracing_v3/config.py +162 -19
  267. synth_ai/tracing_v3/constants.py +1 -1
  268. synth_ai/tracing_v3/db_config.py +24 -38
  269. synth_ai/tracing_v3/migration_helper.py +1 -2
  270. synth_ai/tracing_v3/storage/config.py +47 -13
  271. synth_ai/tracing_v3/storage/factory.py +3 -3
  272. synth_ai/tracing_v3/turso/daemon.py +113 -11
  273. synth_ai/tracing_v3/turso/native_manager.py +92 -16
  274. synth_ai/types.py +8 -0
  275. synth_ai/urls.py +11 -0
  276. synth_ai/utils/__init__.py +30 -1
  277. synth_ai/utils/agents.py +74 -0
  278. synth_ai/utils/bin.py +39 -0
  279. synth_ai/utils/cli.py +149 -5
  280. synth_ai/utils/env.py +40 -33
  281. synth_ai/utils/http.py +4 -1
  282. synth_ai/utils/json.py +72 -0
  283. synth_ai/utils/modal.py +285 -3
  284. synth_ai/utils/paths.py +48 -0
  285. synth_ai/utils/uvicorn.py +113 -0
  286. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/METADATA +109 -6
  287. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/RECORD +291 -142
  288. examples/qwen_vl/configs/eval_qwen2vl_vision.toml +0 -44
  289. synth_ai/cli/tui.py +0 -62
  290. synth_ai/tui/__init__.py +0 -5
  291. synth_ai/tui/__main__.py +0 -13
  292. synth_ai/tui/cli/__init__.py +0 -1
  293. synth_ai/tui/cli/query_experiments.py +0 -164
  294. synth_ai/tui/cli/query_experiments_v3.py +0 -164
  295. synth_ai/tui/dashboard.py +0 -911
  296. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
  297. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
  298. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
  299. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,7 @@
5
5
  # Output: Markdown tables and JSON data (no plotting dependencies)
6
6
 
7
7
  echo "🔍 Analyzing semantic map words from Crafter agent..."
8
- echo "Make sure the synth-ai service is running: uvx synth-ai serve"
8
+ echo "Make sure the synth-ai service is running: uvx synth-ai deploy --runtime uvicorn"
9
9
  echo ""
10
10
 
11
11
  cd synth_ai/environments/examples/crafter_classic/agent_demos/
@@ -14,4 +14,4 @@ cd synth_ai/environments/examples/crafter_classic/agent_demos/
14
14
  python analyze_semantic_words_markdown.py --model gemini-1.5-flash --episodes 3 --max-turns 30
15
15
 
16
16
  echo ""
17
- echo "✅ Analysis complete! Check the generated markdown report and JSON files."
17
+ echo "✅ Analysis complete! Check the generated markdown report and JSON files."
@@ -0,0 +1,204 @@
1
+ """Banking77 baseline file for intent classification evaluation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict
6
+
7
+ from datasets import load_dataset
8
+
9
+ from synth_ai.baseline import BaselineConfig, BaselineTaskRunner, DataSplit, TaskResult
10
+ from synth_ai.inference import InferenceClient
11
+ import os
12
+ import httpx
13
+
14
+
15
+ # Load dataset once at module level
16
+ _dataset = None
17
+ _label_names = None
18
+
19
+
20
+ def _load_dataset():
21
+ """Load Banking77 dataset."""
22
+ global _dataset, _label_names
23
+ if _dataset is None:
24
+ try:
25
+ _dataset = load_dataset("PolyAI/banking77")
26
+ except Exception:
27
+ # Fallback: try without org prefix
28
+ _dataset = load_dataset("banking77")
29
+ _label_names = _dataset["train"].features["label"].names
30
+ return _dataset, _label_names
31
+
32
+
33
+ class Banking77TaskRunner(BaselineTaskRunner):
34
+ """Task runner for Banking77 intent classification."""
35
+
36
+ def __init__(self, policy_config: Dict[str, Any], env_config: Dict[str, Any]):
37
+ super().__init__(policy_config, env_config)
38
+
39
+ # Load dataset
40
+ self.dataset, self.label_names = _load_dataset()
41
+
42
+ # Store config for inference
43
+ self.model = policy_config["model"]
44
+ self.temperature = policy_config.get("temperature", 0.0)
45
+ self.max_tokens = policy_config.get("max_tokens", 128)
46
+ self.inference_url = policy_config.get("inference_url")
47
+
48
+ # Tool definition
49
+ self.tool = {
50
+ "type": "function",
51
+ "function": {
52
+ "name": "banking77_classify",
53
+ "description": "Classify a banking query into an intent",
54
+ "parameters": {
55
+ "type": "object",
56
+ "properties": {
57
+ "label": {
58
+ "type": "string",
59
+ "enum": self.label_names,
60
+ "description": "The intent label",
61
+ }
62
+ },
63
+ "required": ["label"],
64
+ },
65
+ },
66
+ }
67
+
68
+ async def run_task(self, seed: int) -> TaskResult:
69
+ """Run a single Banking77 classification task."""
70
+
71
+ # Get split
72
+ split = self.env_config.get("split", "train")
73
+
74
+ # Get example from dataset
75
+ example = self.dataset[split][seed]
76
+
77
+ # Build prompt
78
+ system_prompt = f"""You are an expert banking assistant that classifies customer queries.
79
+ Given a customer message, respond with exactly one intent label using the tool call.
80
+
81
+ Valid intents: {', '.join(self.label_names)}"""
82
+
83
+ user_prompt = f"Customer Query: {example['text']}\n\nClassify this query."
84
+
85
+ # Run inference
86
+ messages = [
87
+ {"role": "system", "content": system_prompt},
88
+ {"role": "user", "content": user_prompt},
89
+ ]
90
+
91
+ # Use InferenceClient if URL provided, otherwise use OpenAI-compatible API
92
+ if self.inference_url and self.inference_url.startswith("http"):
93
+ api_key = os.getenv("SYNTH_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
94
+ base_url = self.inference_url.rstrip("/")
95
+ if not base_url.endswith("/api"):
96
+ base_url = f"{base_url}/api" if "/api" not in base_url else base_url
97
+ client = InferenceClient(base_url=base_url, api_key=api_key)
98
+ response = await client.create_chat_completion(
99
+ model=self.model,
100
+ messages=messages,
101
+ tools=[self.tool],
102
+ tool_choice={"type": "function", "function": {"name": "banking77_classify"}},
103
+ temperature=self.temperature,
104
+ max_tokens=self.max_tokens,
105
+ )
106
+ else:
107
+ # Use OpenAI/Groq directly
108
+ # Check if model starts with groq: prefix
109
+ model_name = self.model
110
+ use_groq = model_name.startswith("groq:")
111
+ if use_groq:
112
+ model_name = model_name[5:] # Remove "groq:" prefix
113
+
114
+ api_key = os.getenv("GROQ_API_KEY") if use_groq else os.getenv("OPENAI_API_KEY") or ""
115
+ base_url = "https://api.groq.com/openai/v1" if use_groq else "https://api.openai.com/v1"
116
+ async with httpx.AsyncClient() as http_client:
117
+ resp = await http_client.post(
118
+ f"{base_url}/chat/completions",
119
+ json={
120
+ "model": model_name,
121
+ "messages": messages,
122
+ "tools": [self.tool],
123
+ "tool_choice": {"type": "function", "function": {"name": "banking77_classify"}},
124
+ "temperature": self.temperature,
125
+ "max_tokens": self.max_tokens,
126
+ },
127
+ headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
128
+ )
129
+ response = resp.json()
130
+
131
+ # Extract prediction
132
+ predicted_label = ""
133
+ tool_calls = []
134
+ if "choices" in response and len(response["choices"]) > 0:
135
+ message = response["choices"][0].get("message", {})
136
+ tool_calls = message.get("tool_calls", [])
137
+ elif "tool_calls" in response:
138
+ tool_calls = response["tool_calls"]
139
+
140
+ if tool_calls:
141
+ # Handle both string and dict arguments
142
+ args = tool_calls[0]["function"].get("arguments", "")
143
+ if isinstance(args, str):
144
+ import json
145
+ args = json.loads(args)
146
+ predicted_label = args.get("label", "") if isinstance(args, dict) else ""
147
+
148
+ # Evaluate
149
+ expected_label = self.label_names[example["label"]]
150
+ correct = predicted_label == expected_label
151
+
152
+ return TaskResult(
153
+ seed=seed,
154
+ success=True,
155
+ outcome_reward=1.0 if correct else 0.0,
156
+ total_steps=1,
157
+ metadata={
158
+ "query": example["text"],
159
+ "expected": expected_label,
160
+ "predicted": predicted_label,
161
+ "correct": correct,
162
+ "split": split,
163
+ },
164
+ )
165
+
166
+
167
+ # Define baseline config
168
+ # Note: We need to load the dataset first to get the label names
169
+ _load_dataset()
170
+ banking77_baseline = BaselineConfig(
171
+ baseline_id="banking77",
172
+ name="Banking77 Intent Classification",
173
+ description="Banking intent classification from customer queries",
174
+ task_runner=Banking77TaskRunner,
175
+ splits={
176
+ "train": DataSplit(
177
+ name="train",
178
+ seeds=list(range(min(10000, len(_dataset["train"]))) if _dataset else range(10000)),
179
+ ),
180
+ "val": DataSplit(
181
+ name="val",
182
+ seeds=list(range(min(1000, len(_dataset["test"]))) if _dataset else range(1000)),
183
+ ),
184
+ "test": DataSplit(
185
+ name="test",
186
+ seeds=list(range(min(3000, len(_dataset["test"]))) if _dataset else range(3000)),
187
+ ),
188
+ },
189
+ default_policy_config={
190
+ "model": "groq:llama-3.1-70b-versatile",
191
+ "temperature": 0.0,
192
+ "max_tokens": 128,
193
+ },
194
+ default_env_config={
195
+ "split": "train",
196
+ },
197
+ metadata={
198
+ "dataset": "PolyAI/banking77",
199
+ "num_classes": 77,
200
+ "task_type": "classification",
201
+ },
202
+ tags=["classification", "nlp", "intent"],
203
+ )
204
+
@@ -0,0 +1,407 @@
1
+ """Crafter baseline file for self-contained evaluation.
2
+
3
+ This baseline file defines how to evaluate agents on Crafter without
4
+ requiring a deployed task app. It includes train/val/test splits and
5
+ computes both event rewards (achievement deltas) and outcome rewards
6
+ (total unique achievements).
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Any, Dict, List, Optional, Set
12
+ from uuid import uuid4
13
+
14
+ from synth_ai.baseline import BaselineConfig, BaselineTaskRunner, DataSplit, TaskResult
15
+ from synth_ai.environments.examples.crafter_classic.environment import (
16
+ CrafterClassicEnvironment,
17
+ )
18
+ from synth_ai.environments.examples.crafter_classic.taskset import (
19
+ CrafterTaskInstance,
20
+ CrafterTaskInstanceMetadata,
21
+ )
22
+ from synth_ai.environments.tasks.core import Impetus, Intent
23
+ from synth_ai.inference import InferenceClient
24
+ from synth_ai.tracing_v3.session_tracer import SessionTracer
25
+ import os
26
+
27
+
28
+ # Action mapping: string names to action indices
29
+ CRAFTER_ACTION_MAP: Dict[str, int] = {
30
+ "noop": 0,
31
+ "move_left": 1,
32
+ "move_right": 2,
33
+ "move_up": 3,
34
+ "move_down": 4,
35
+ "do": 5,
36
+ "sleep": 6,
37
+ "place_stone": 7,
38
+ "place_table": 8,
39
+ "place_furnace": 9,
40
+ "place_plant": 10,
41
+ "make_wood_pickaxe": 11,
42
+ "make_stone_pickaxe": 12,
43
+ "make_iron_pickaxe": 13,
44
+ "make_wood_sword": 14,
45
+ "make_stone_sword": 15,
46
+ "make_iron_sword": 16,
47
+ }
48
+
49
+
50
+ def format_crafter_observation(obs: Dict[str, Any]) -> str:
51
+ """Format Crafter observation as text for LLM."""
52
+ health = obs.get("health") or obs.get("inventory", {}).get("health", 0)
53
+ inventory = obs.get("inventory", {})
54
+ pos = obs.get("player_position", [0, 0])
55
+ achievements_status = obs.get("achievements_status", {})
56
+
57
+ # Format inventory (skip health)
58
+ inv_items = [f"{k}:{v}" for k, v in inventory.items() if v > 0 and k != "health"]
59
+ inventory_str = ", ".join(inv_items) if inv_items else "empty"
60
+
61
+ # Format achievements
62
+ achieved_list = [k for k, v in achievements_status.items() if v]
63
+ achievements_str = ", ".join(achieved_list) if achieved_list else "none"
64
+
65
+ return f"""Crafter Game State:
66
+ - Health: {health}/10
67
+ - Hunger: {inventory.get('hunger', 0)}/10
68
+ - Position: {pos}
69
+ - Inventory: {inventory_str}
70
+ - Achievements unlocked: {len(achieved_list)}/22
71
+ - Achievements: {achievements_str}
72
+
73
+ What actions should we take?"""
74
+
75
+
76
+ class CrafterTaskRunner(BaselineTaskRunner):
77
+ """Task runner for Crafter survival game."""
78
+
79
+ def __init__(self, policy_config: Dict[str, Any], env_config: Dict[str, Any]):
80
+ super().__init__(policy_config, env_config)
81
+
82
+ # Initialize inference client
83
+ inference_url = policy_config.get("inference_url")
84
+ if inference_url and inference_url.startswith("http"):
85
+ # External URL - use InferenceClient
86
+ api_key = os.getenv("SYNTH_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
87
+ base_url = inference_url.rstrip("/")
88
+ if not base_url.endswith("/api"):
89
+ base_url = f"{base_url}/api" if "/api" not in base_url else base_url
90
+ self.client = InferenceClient(base_url=base_url, api_key=api_key)
91
+ self.use_inference_client = True
92
+ else:
93
+ # For OpenAI/Groq direct APIs, we'll use httpx
94
+ import httpx
95
+ self.http_client = httpx.AsyncClient()
96
+ self.use_inference_client = False
97
+
98
+ self.model = policy_config["model"]
99
+ self.temperature = policy_config.get("temperature", 0.0)
100
+ self.max_tokens = policy_config.get("max_tokens", 512)
101
+
102
+ # System prompt
103
+ self.system_prompt = """You are playing Crafter, a survival game. Your goal is to unlock achievements.
104
+
105
+ Core rules:
106
+ - The world contains trees (wood), stone, coal, iron, plants, cows, zombies, and water.
107
+ - Movement constraints: you cannot walk onto blocking tiles (tree, stone, water, lava, coal, iron).
108
+ - You start with empty hands and low health/hunger.
109
+ - Interact ('do') only when adjacent to a resource.
110
+ - Movement is essential: move multiple steps in one turn to explore.
111
+
112
+ Available actions: noop, move_up, move_down, move_left, move_right, do, sleep,
113
+ place_stone, place_table, place_furnace, place_plant, make_wood_pickaxe,
114
+ make_stone_pickaxe, make_iron_pickaxe, make_wood_sword, make_stone_sword, make_iron_sword
115
+
116
+ Always return a tool call: interact_many({actions: [...]})
117
+ Use 2-5 actions per call. Prefer long movement sequences."""
118
+
119
+ # Tool definition
120
+ self.tools = [{
121
+ "type": "function",
122
+ "function": {
123
+ "name": "interact_many",
124
+ "description": "Execute multiple Crafter actions in sequence",
125
+ "parameters": {
126
+ "type": "object",
127
+ "properties": {
128
+ "actions": {
129
+ "type": "array",
130
+ "items": {"type": "string", "enum": list(CRAFTER_ACTION_MAP.keys())},
131
+ "description": "List of actions to execute",
132
+ }
133
+ },
134
+ "required": ["actions"],
135
+ },
136
+ },
137
+ }]
138
+
139
async def run_task(self, seed: int) -> TaskResult:
    """Run a single Crafter episode and return its results.

    Each loop iteration asks the policy model for one ``interact_many`` tool
    call and steps the environment once per returned action. The episode
    stops when any of the following happens:

    * the model returns no tool call;
    * the step budget (``env_config["max_steps"]``, default 100) is spent;
    * the environment reports ``terminated`` or ``truncated``.

    Args:
        seed: Seed for the Crafter task instance.

    Returns:
        A TaskResult with these fields:

        * ``outcome_reward``: the number of achievements unlocked in the
          final observation.
        * ``event_rewards``: one entry per environment step that unlocked
          new achievements.
        * ``trace``: the exported tracing session, or ``None`` when tracing
          is disabled.
    """
    difficulty = self.env_config.get("difficulty", "normal")
    max_steps = self.env_config.get("max_steps", 100)

    # --- Task instance -------------------------------------------------
    impetus = Impetus(instructions="Survive and unlock achievements.")
    intent = Intent(
        rubric={"goal": "Unlock achievements"},
        gold_trajectories=None,
        gold_state_diff={},
    )
    metadata = CrafterTaskInstanceMetadata(
        difficulty=difficulty,
        seed=seed,
        num_trees_radius=0,
        num_cows_radius=0,
        num_hostiles_radius=0,
    )
    task_instance = CrafterTaskInstance(
        id=uuid4(),
        impetus=impetus,
        intent=intent,
        metadata=metadata,
        is_reproducible=True,
        initial_engine_snapshot=None,
    )
    # World settings attached after construction.
    # NOTE(review): presumably read by CrafterClassicEnvironment — confirm.
    task_instance.config = {"seed": seed, "length": 256, "area": [64, 64]}

    env = CrafterClassicEnvironment(task_instance=task_instance)

    # --- Tracing (optional) --------------------------------------------
    tracer: Optional[SessionTracer] = None
    session_id: Optional[str] = None
    if self.env_config.get("enable_tracing", True):
        tracer = SessionTracer(db_url=None, auto_save=False)
        await tracer.initialize()
        session_id = tracer.create_session(metadata={
            "seed": seed,
            "difficulty": difficulty,
            "model": self.policy_config["model"],
        })

    # Environment results may wrap the observation in an `.observation` attribute.
    raw_obs = await env.initialize()
    observation = getattr(raw_obs, "observation", raw_obs)
    obs_dict = observation if isinstance(observation, dict) else {}

    prev_achievements: Set[str] = self._unlocked_achievements(obs_dict)
    event_rewards: List[Dict[str, Any]] = []
    tool_calls_history: List[Dict[str, Any]] = []
    total_steps = 0
    episode_over = False

    # --- Episode loop --------------------------------------------------
    for step in range(max_steps):
        obs_text = format_crafter_observation(obs_dict)
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": f"{obs_text}\n\nPrevious tool calls: {tool_calls_history[-3:]}"},
        ]

        if tracer and session_id:
            tracer.record_event(
                session_id=session_id,
                event_type="cais",
                data={"messages": messages, "step": step},
            )

        response = await self._request_policy_completion(messages)
        actions = self._extract_tool_actions(response)
        if actions is None:
            # The model produced no tool call at all: end the episode.
            break
        tool_calls_history.append({"step": step, "actions": actions})

        for action_name in actions:
            if total_steps >= max_steps:
                break

            # Unknown action names fall back to index 0.
            step_result = await env.step(CRAFTER_ACTION_MAP.get(action_name, 0))
            total_steps += 1

            step_obs = getattr(step_result, "observation", step_result)
            obs_dict = step_obs if isinstance(step_obs, dict) else {}
            terminated = getattr(step_result, "terminated", False)
            truncated = getattr(step_result, "truncated", False)

            env_event_id = None
            if tracer and session_id:
                env_event_id = tracer.record_event(
                    session_id=session_id,
                    event_type="environment",
                    data={
                        "action": action_name,
                        "reward": getattr(step_result, "reward", 0.0),
                        "terminated": terminated,
                        "step": total_steps,
                    },
                )

            current_achievements = self._unlocked_achievements(obs_dict)
            new_achievements = current_achievements - prev_achievements
            if new_achievements:
                event_reward_value = len(new_achievements)
                if tracer and session_id and env_event_id:
                    tracer.record_event_reward(
                        session_id=session_id,
                        event_id=env_event_id,
                        reward_value=float(event_reward_value),
                        reward_type="achievement_delta",
                        key="achievements",
                        annotation={"new_achievements": list(new_achievements)},
                        source="environment",
                    )
                event_rewards.append({
                    "step": total_steps,
                    "reward": event_reward_value,
                    "achievements": list(new_achievements),
                })
            prev_achievements = current_achievements

            if terminated or truncated:
                episode_over = True
                break

        # Stop querying the model once the episode has ended or the step
        # budget is spent: further calls could not step the environment.
        # (The old code also read a possibly-unbound `step_result` here when
        # the model returned an empty action list.)
        if episode_over or total_steps >= max_steps:
            break

    # --- Outcome -------------------------------------------------------
    unique_achievements = len(prev_achievements)
    if tracer and session_id:
        tracer.record_outcome_reward(
            session_id=session_id,
            total_reward=unique_achievements,
            achievements_count=unique_achievements,
            total_steps=total_steps,
            reward_metadata={
                "achievements": list(prev_achievements),
            },
        )
        trace_dict = await tracer.export_session(session_id)
    else:
        trace_dict = None

    return TaskResult(
        seed=seed,
        success=True,
        outcome_reward=float(unique_achievements),
        event_rewards=event_rewards,
        total_steps=total_steps,
        metadata={
            "achievements": list(prev_achievements),
            "achievement_count": unique_achievements,
            "difficulty": difficulty,
        },
        trace=trace_dict,
    )

async def _request_policy_completion(self, messages: "List[Dict[str, Any]]") -> "Any":
    """Ask the policy model for a forced ``interact_many`` tool call.

    When ``self.use_inference_client`` is set, the request goes through
    ``self.client.create_chat_completion``. Otherwise it POSTs to an
    OpenAI-compatible ``/chat/completions`` endpoint: OpenAI if the model
    name contains "openai", Groq otherwise.

    Returns:
        The completion response. On the HTTP fallback path this is the
        decoded JSON body.
    """
    tool_choice = {"type": "function", "function": {"name": "interact_many"}}
    if self.use_inference_client:
        return await self.client.create_chat_completion(
            model=self.model,
            messages=messages,
            tools=self.tools,
            tool_choice=tool_choice,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )

    import httpx

    api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GROQ_API_KEY") or ""
    base_url = (
        "https://api.openai.com/v1"
        if "openai" in self.model.lower()
        else "https://api.groq.com/openai/v1"
    )
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    # httpx's default timeout is 5 s, which long completions can exceed.
    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(
            f"{base_url}/chat/completions",
            json={
                "model": self.model,
                "messages": messages,
                "tools": self.tools,
                "tool_choice": tool_choice,
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
            },
            headers=headers,
        )
        return resp.json()

@staticmethod
def _extract_tool_actions(response: "Any") -> "Optional[List[str]]":
    """Return the action names from the first tool call in ``response``.

    Both response shapes are accepted: OpenAI-style
    ``choices[0].message.tool_calls`` and a top-level ``tool_calls`` key.
    ``function.arguments`` may be a JSON-encoded string, as in the OpenAI
    chat-completions API, or an already-decoded dict.

    Returns:
        ``None`` when the response contains no tool call at all.
        Otherwise the list of string actions, which may be empty when the
        arguments are missing or malformed.
    """
    import json

    choices = response.get("choices")
    if choices:
        message = choices[0].get("message") or {}
        tool_calls = message.get("tool_calls") or []
    else:
        tool_calls = response.get("tool_calls") or []
    if not tool_calls:
        return None

    function = tool_calls[0].get("function") or {}
    arguments = function.get("arguments") or {}
    if isinstance(arguments, str):
        try:
            arguments = json.loads(arguments)
        except json.JSONDecodeError:
            arguments = {}
    actions = arguments.get("actions", []) if isinstance(arguments, dict) else []
    if not isinstance(actions, list):
        return []
    return [action for action in actions if isinstance(action, str)]

@staticmethod
def _unlocked_achievements(obs: "Dict[str, Any]") -> "Set[str]":
    """Return the achievement names marked truthy in ``obs["achievements_status"]``.

    Returns an empty set when that key is missing or is not a dict.
    """
    status = obs.get("achievements_status")
    if not isinstance(status, dict):
        return set()
    return {name for name, unlocked in status.items() if unlocked}
365
+
366
+
367
# Crafter baseline definition. The train/val/test splits use disjoint seed
# ranges: train 0-99, val 100-149, test 150-199.
# NOTE(review): the test split's metadata says difficulty "hard", but
# CrafterTaskRunner.run_task reads difficulty from env_config (default
# "normal"). Confirm whether split metadata is merged into env_config.
crafter_baseline = BaselineConfig(
    baseline_id="crafter",
    name="Crafter Survival",
    description="Crafter survival game with achievement tracking",
    task_runner=CrafterTaskRunner,
    splits={
        split_name: DataSplit(
            name=split_name,
            seeds=list(range(first_seed, end_seed)),
            metadata={"difficulty": split_difficulty},
        )
        for split_name, first_seed, end_seed, split_difficulty in (
            ("train", 0, 100, "normal"),
            ("val", 100, 150, "normal"),
            ("test", 150, 200, "hard"),
        )
    },
    default_policy_config=dict(
        model="groq:llama-3.1-70b-versatile",
        temperature=0.0,
        max_tokens=1024,
    ),
    default_env_config=dict(
        difficulty="normal",
        max_steps=100,
        enable_tracing=True,
    ),
    metadata=dict(
        environment="crafter",
        reward_type="achievements",
        max_achievements=22,
    ),
    tags=["rl", "gym", "survival", "achievements"],
)
407
+