synth-ai 0.2.9.dev7__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (323)
  1. examples/__init__.py +16 -0
  2. examples/crafter_debug_render.py +8 -11
  3. examples/dev/qwen3_32b_qlora_4xh100.toml +40 -0
  4. examples/multi_step/crafter_rl_lora.md +29 -0
  5. examples/qwen_coder/README.md +102 -0
  6. examples/qwen_coder/_shared.py +113 -0
  7. examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
  8. examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
  9. examples/qwen_coder/configs/coder_lora_small.toml +58 -0
  10. examples/qwen_coder/generate_dataset.py +98 -0
  11. examples/qwen_coder/infer_ft_smoke.py +65 -0
  12. examples/qwen_coder/infer_prod_proxy.py +73 -0
  13. examples/qwen_coder/infer_via_synth.py +87 -0
  14. examples/qwen_coder/scripts/infer_coder.sh +19 -0
  15. examples/qwen_coder/scripts/train_coder_30b.sh +22 -0
  16. examples/qwen_coder/sft_full_17b.py +103 -0
  17. examples/qwen_coder/sft_lora_30b.py +110 -0
  18. examples/qwen_coder/subset_jsonl.py +39 -0
  19. examples/qwen_coder/todos.md +38 -0
  20. examples/qwen_coder/validate_jsonl.py +60 -0
  21. examples/rl/run_eval.py +36 -37
  22. examples/rl/run_rl_and_save.py +5 -5
  23. examples/rl/task_app/math_single_step.py +65 -43
  24. examples/rl/task_app/math_task_app.py +3 -3
  25. examples/sft/README.md +139 -0
  26. examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
  27. examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
  28. examples/sft/evaluate.py +117 -0
  29. examples/sft/export_dataset.py +117 -0
  30. examples/sft/generate_traces.py +162 -0
  31. examples/swe/__init__.py +12 -0
  32. examples/swe/task_app/README.md +105 -0
  33. examples/swe/task_app/__init__.py +2 -0
  34. examples/swe/task_app/grpo_swe_mini.py +571 -0
  35. examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
  36. examples/swe/task_app/hosted/README.md +173 -0
  37. examples/swe/task_app/hosted/__init__.py +5 -0
  38. examples/swe/task_app/hosted/branching.py +143 -0
  39. examples/swe/task_app/hosted/environment_routes.py +1289 -0
  40. examples/swe/task_app/hosted/envs/__init__.py +1 -0
  41. examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
  42. examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
  43. examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
  44. examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
  45. examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
  46. examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
  47. examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
  48. examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
  49. examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
  50. examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
  51. examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
  52. examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
  53. examples/swe/task_app/hosted/hosted_app.py +204 -0
  54. examples/swe/task_app/hosted/inference/__init__.py +5 -0
  55. examples/swe/task_app/hosted/inference/openai_client.py +618 -0
  56. examples/swe/task_app/hosted/main.py +100 -0
  57. examples/swe/task_app/hosted/policy_routes.py +1079 -0
  58. examples/swe/task_app/hosted/registry.py +195 -0
  59. examples/swe/task_app/hosted/rollout.py +1869 -0
  60. examples/swe/task_app/hosted/storage/__init__.py +5 -0
  61. examples/swe/task_app/hosted/storage/volume.py +211 -0
  62. examples/swe/task_app/hosted/test_agents.py +161 -0
  63. examples/swe/task_app/hosted/test_service.py +137 -0
  64. examples/swe/task_app/hosted/utils.py +62 -0
  65. examples/vlm/PROPOSAL.md +53 -0
  66. examples/vlm/README.md +68 -0
  67. examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
  68. examples/vlm/crafter_image_only_agent.py +207 -0
  69. examples/vlm/crafter_openai_vlm_agent.py +277 -0
  70. examples/vlm/filter_image_rows.py +63 -0
  71. examples/vlm/run_crafter_vlm_benchmark.py +316 -0
  72. examples/warming_up_to_rl/analyze_trace_db.py +5 -5
  73. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
  74. examples/warming_up_to_rl/export_trace_sft.py +78 -21
  75. examples/warming_up_to_rl/groq_test.py +4 -4
  76. examples/warming_up_to_rl/manage_secrets.py +13 -18
  77. examples/warming_up_to_rl/run_eval.py +42 -44
  78. examples/warming_up_to_rl/run_fft_and_save.py +11 -16
  79. examples/warming_up_to_rl/run_local_rollout.py +1 -3
  80. examples/warming_up_to_rl/run_local_rollout_modal.py +2 -4
  81. examples/warming_up_to_rl/run_local_rollout_parallel.py +1 -4
  82. examples/warming_up_to_rl/run_local_rollout_traced.py +3 -5
  83. examples/warming_up_to_rl/run_rl_and_save.py +5 -6
  84. examples/warming_up_to_rl/run_rollout_remote.py +8 -10
  85. examples/warming_up_to_rl/task_app/README.md +6 -2
  86. examples/warming_up_to_rl/task_app/grpo_crafter.py +234 -35
  87. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +2 -3
  88. examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
  89. examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
  90. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +131 -114
  91. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +101 -41
  92. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +73 -51
  93. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +14 -6
  94. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +16 -16
  95. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +32 -34
  96. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +94 -31
  97. examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
  98. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +303 -203
  99. examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
  100. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +328 -225
  101. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +13 -13
  102. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
  103. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +1 -0
  104. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
  105. synth_ai/api/models/supported.py +376 -0
  106. synth_ai/api/train/builders.py +128 -21
  107. synth_ai/api/train/cli.py +80 -64
  108. synth_ai/api/train/config_finder.py +7 -2
  109. synth_ai/api/train/env_resolver.py +1 -1
  110. synth_ai/api/train/pollers.py +2 -1
  111. synth_ai/api/train/supported_algos.py +139 -0
  112. synth_ai/api/train/task_app.py +1 -2
  113. synth_ai/api/train/utils.py +13 -44
  114. synth_ai/cli/__init__.py +8 -0
  115. synth_ai/cli/_modal_wrapper.py +28 -0
  116. synth_ai/cli/_typer_patch.py +49 -0
  117. synth_ai/cli/balance.py +1 -2
  118. synth_ai/cli/calc.py +1 -1
  119. synth_ai/cli/demo.py +2 -1
  120. synth_ai/cli/recent.py +2 -2
  121. synth_ai/cli/rl_demo.py +2 -1
  122. synth_ai/cli/root.py +11 -13
  123. synth_ai/cli/status.py +2 -2
  124. synth_ai/cli/task_apps.py +529 -179
  125. synth_ai/cli/traces.py +6 -4
  126. synth_ai/cli/watch.py +12 -18
  127. synth_ai/demo_registry.py +1 -1
  128. synth_ai/demos/core/cli.py +36 -43
  129. synth_ai/demos/demo_task_apps/__init__.py +3 -3
  130. synth_ai/demos/demo_task_apps/core.py +17 -25
  131. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +3 -4
  132. synth_ai/demos/demo_task_apps/math/app.py +2 -1
  133. synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -4
  134. synth_ai/demos/demo_task_apps/math/modal_task_app.py +16 -18
  135. synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
  136. synth_ai/environments/examples/crafter_classic/environment.py +76 -1
  137. synth_ai/environments/reproducibility/tree.py +2 -5
  138. synth_ai/environments/service/app.py +11 -12
  139. synth_ai/environments/service/core_routes.py +4 -7
  140. synth_ai/environments/stateful/engine.py +1 -1
  141. synth_ai/environments/tasks/core.py +1 -0
  142. synth_ai/environments/tasks/filters.py +5 -6
  143. synth_ai/environments/tasks/utils.py +4 -5
  144. synth_ai/handshake.py +9 -9
  145. synth_ai/http.py +1 -1
  146. synth_ai/http_client.py +18 -10
  147. synth_ai/inference/client.py +15 -5
  148. synth_ai/jobs/client.py +78 -83
  149. synth_ai/learning/__init__.py +41 -6
  150. synth_ai/learning/algorithms.py +14 -0
  151. synth_ai/learning/client.py +91 -24
  152. synth_ai/learning/config.py +2 -38
  153. synth_ai/learning/ft_client.py +4 -59
  154. synth_ai/learning/health.py +5 -6
  155. synth_ai/learning/jobs.py +31 -47
  156. synth_ai/{rl → learning/rl}/__init__.py +14 -4
  157. synth_ai/learning/rl/client.py +267 -0
  158. synth_ai/learning/rl/config.py +31 -0
  159. synth_ai/{rl → learning/rl}/contracts.py +5 -8
  160. synth_ai/{rl → learning/rl}/env_keys.py +39 -15
  161. synth_ai/learning/rl/secrets.py +13 -0
  162. synth_ai/learning/rl_client.py +2 -281
  163. synth_ai/learning/sft/__init__.py +29 -0
  164. synth_ai/learning/sft/client.py +68 -0
  165. synth_ai/learning/sft/config.py +270 -0
  166. synth_ai/learning/sft/data.py +295 -0
  167. synth_ai/learning/sse.py +25 -24
  168. synth_ai/learning/validators.py +25 -28
  169. synth_ai/lm/__init__.py +21 -47
  170. synth_ai/task/__init__.py +25 -27
  171. synth_ai/task/apps/__init__.py +7 -8
  172. synth_ai/task/auth.py +8 -8
  173. synth_ai/task/client.py +14 -14
  174. synth_ai/task/contracts.py +36 -35
  175. synth_ai/task/datasets.py +6 -5
  176. synth_ai/task/errors.py +10 -10
  177. synth_ai/task/health.py +17 -9
  178. synth_ai/task/json.py +58 -23
  179. synth_ai/task/proxy.py +13 -9
  180. synth_ai/task/rubrics.py +16 -15
  181. synth_ai/task/server.py +12 -12
  182. synth_ai/task/tracing_utils.py +4 -4
  183. synth_ai/task/vendors.py +5 -6
  184. synth_ai/tracing_v3/__init__.py +2 -0
  185. synth_ai/tracing_v3/abstractions.py +21 -4
  186. synth_ai/tracing_v3/decorators.py +18 -16
  187. synth_ai/tracing_v3/hooks.py +5 -5
  188. synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
  189. synth_ai/tracing_v3/session_tracer.py +40 -14
  190. synth_ai/tracing_v3/storage/base.py +85 -0
  191. synth_ai/tracing_v3/storage/config.py +21 -8
  192. synth_ai/tracing_v3/storage/factory.py +10 -7
  193. synth_ai/tracing_v3/storage/utils.py +4 -2
  194. synth_ai/tracing_v3/turso/daemon.py +7 -2
  195. synth_ai/tracing_v3/turso/models.py +2 -2
  196. synth_ai/tracing_v3/turso/native_manager.py +1173 -0
  197. synth_ai/tracing_v3/utils.py +4 -4
  198. synth_ai/v0/api/__init__.py +8 -0
  199. synth_ai/v0/api/models/__init__.py +8 -0
  200. synth_ai/v0/api/models/supported.py +8 -0
  201. synth_ai/v0/config/__init__.py +15 -0
  202. synth_ai/v0/config/base_url.py +12 -0
  203. synth_ai/v0/lm/__init__.py +51 -0
  204. synth_ai/{lm → v0/lm}/caching/ephemeral.py +2 -2
  205. synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
  206. synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
  207. synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
  208. synth_ai/{lm → v0/lm}/config.py +6 -1
  209. synth_ai/{lm → v0/lm}/core/all.py +9 -9
  210. synth_ai/{lm → v0/lm}/core/main.py +6 -6
  211. synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
  212. synth_ai/{lm → v0/lm}/core/synth_models.py +2 -14
  213. synth_ai/{lm → v0/lm}/core/vendor_clients.py +2 -2
  214. synth_ai/{lm → v0/lm}/overrides.py +2 -2
  215. synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
  216. synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
  217. synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
  218. synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
  219. synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +9 -9
  220. synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
  221. synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
  222. synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +10 -10
  223. synth_ai/{lm → v0/lm}/vendors/openai_standard.py +8 -8
  224. synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +2 -2
  225. synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +3 -3
  226. synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
  227. synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
  228. synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
  229. synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
  230. synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
  231. synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
  232. synth_ai/{lm → v0/lm}/vendors/synth_client.py +1 -1
  233. synth_ai/v0/tracing_v3/__init__.py +10 -0
  234. synth_ai/v0/tracing_v3/abstractions.py +3 -0
  235. synth_ai/v0/tracing_v3/decorators.py +3 -0
  236. synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
  237. synth_ai/v0/tracing_v3/session_tracer.py +3 -0
  238. {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/METADATA +10 -7
  239. {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/RECORD +269 -233
  240. examples/common_old/backend.py +0 -20
  241. examples/evals_old/README.md +0 -98
  242. examples/evals_old/__init__.py +0 -6
  243. examples/evals_old/compare_models.py +0 -1038
  244. examples/evals_old/example_log.md +0 -145
  245. examples/evals_old/run_demo.sh +0 -126
  246. examples/evals_old/trace_analysis.py +0 -270
  247. examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
  248. examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
  249. examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
  250. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -243
  251. examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
  252. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
  253. examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
  254. examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
  255. examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
  256. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -119
  257. examples/finetuning_old/synth_qwen_v1/README.md +0 -68
  258. examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
  259. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -243
  260. examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
  261. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
  262. examples/finetuning_old/synth_qwen_v1/infer.py +0 -36
  263. examples/finetuning_old/synth_qwen_v1/poll.py +0 -46
  264. examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
  265. examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
  266. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1933
  267. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -210
  268. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -237
  269. examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
  270. examples/finetuning_old/synth_qwen_v1/util.py +0 -152
  271. examples/rl_old/task_app.py +0 -1131
  272. synth_ai/experimental/synth_oss.py +0 -445
  273. synth_ai/learning/filtering.py +0 -0
  274. synth_ai/learning/offline/dpo.py +0 -0
  275. synth_ai/learning/offline/providers.py +0 -7
  276. synth_ai/learning/offline/sft.py +0 -0
  277. synth_ai/learning/offline/shared.py +0 -0
  278. synth_ai/learning/online/grpo.py +0 -0
  279. synth_ai/learning/online/irft.py +0 -0
  280. synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
  281. synth_ai/learning/prompts/gepa.py +0 -0
  282. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -211
  283. synth_ai/learning/prompts/mipro.py +0 -289
  284. synth_ai/learning/prompts/random_search.py +0 -249
  285. synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
  286. synth_ai/learning/prompts/run_random_search_banking77.py +0 -329
  287. synth_ai/rl/secrets.py +0 -19
  288. synth_ai/scripts/verify_rewards.py +0 -100
  289. synth_ai/tracing/__init__.py +0 -30
  290. synth_ai/tracing_v1/__init__.py +0 -33
  291. synth_ai/tracing_v3/turso/__init__.py +0 -25
  292. synth_ai/tracing_v3/turso/manager.py +0 -838
  293. synth_ai/zyk/__init__.py +0 -30
  294. /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
  295. /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
  296. /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
  297. /synth_ai/{lm → v0/lm}/constants.py +0 -0
  298. /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
  299. /synth_ai/{lm → v0/lm}/core/exceptions.py +0 -0
  300. /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
  301. /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
  302. /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
  303. /synth_ai/{lm → v0/lm}/injection.py +0 -0
  304. /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
  305. /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
  306. /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
  307. /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
  308. /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
  309. /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
  310. /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
  311. /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
  312. /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
  313. /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
  314. /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
  315. /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
  316. /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
  317. /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
  318. /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
  319. /synth_ai/{lm → v0/lm}/warmup.py +0 -0
  320. {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/WHEEL +0 -0
  321. {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/entry_points.txt +0 -0
  322. {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/licenses/LICENSE +0 -0
  323. {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/top_level.txt +0 -0
@@ -1,19 +1,56 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any, Dict, List, Optional
3
+ import base64
4
4
  import logging
5
+ from io import BytesIO
6
+ from typing import Any
5
7
 
6
- from synth_ai.environments.stateful.core import StatefulEnvironment
8
+ import numpy as np
9
+ from PIL import Image
7
10
  from synth_ai.environments.environment.tools import EnvToolCall
11
+ from synth_ai.environments.stateful.core import StatefulEnvironment
8
12
 
9
13
  from ...utils import convert_numpy_to_python
10
- from .tools import TOOLS_SCHEMA
11
14
  from .shared import CRAFTER_ACTIONS, _format_semantic_map_view
12
-
15
+ from .tools import TOOLS_SCHEMA
13
16
 
14
17
  logger = logging.getLogger(__name__)
15
18
 
16
19
 
20
+ def _encode_image_to_base64(image_array: Any) -> dict[str, Any] | None:
21
+ """Encode an RGB ndarray into a base64 PNG payload with metadata."""
22
+
23
+ if not isinstance(image_array, np.ndarray):
24
+ return None
25
+ if image_array.ndim != 3 or image_array.shape[-1] not in (1, 3, 4):
26
+ return None
27
+ try:
28
+ # Ensure uint8 for PIL compatibility
29
+ array_uint8 = (
30
+ image_array.astype("uint8")
31
+ if image_array.dtype != np.uint8
32
+ else image_array # pragma: no cover - fast path
33
+ )
34
+ mode = "L" if array_uint8.shape[-1] == 1 else "RGB"
35
+ if array_uint8.shape[-1] == 4:
36
+ mode = "RGBA"
37
+ img = Image.fromarray(array_uint8, mode=mode)
38
+ buffer = BytesIO()
39
+ img.save(buffer, format="PNG")
40
+ encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
41
+ width = int(array_uint8.shape[1])
42
+ height = int(array_uint8.shape[0])
43
+ return {
44
+ "format": "png",
45
+ "width": width,
46
+ "height": height,
47
+ "data": encoded,
48
+ "data_url": f"data:image/png;base64,{encoded}",
49
+ }
50
+ except Exception:
51
+ return None
52
+
53
+
17
54
  class CrafterEnvironmentWrapper:
18
55
  """Host-side environment wrapper matching the sketch contract.
19
56
 
@@ -25,20 +62,20 @@ class CrafterEnvironmentWrapper:
25
62
  - snapshot()/restore() handled at route level; this wrapper exposes checkpoint via synth-ai
26
63
  """
27
64
 
28
- def __init__(self, env: StatefulEnvironment, seed: Optional[int] = None) -> None:
65
+ def __init__(self, env: StatefulEnvironment, seed: int | None = None) -> None:
29
66
  self.env = env
30
67
  self.seed = seed
31
68
  self.step_idx = 0
32
- self.last_observation: Optional[Dict[str, Any]] = None
33
- self.last_info: Optional[Dict[str, Any]] = None
69
+ self.last_observation: dict[str, Any] | None = None
70
+ self.last_info: dict[str, Any] | None = None
34
71
 
35
- async def initialize(self) -> Dict[str, Any]:
72
+ async def initialize(self) -> dict[str, Any]:
36
73
  obs = await self.env.initialize()
37
74
  # synth-ai InternalObservation expected to expose .observation (dict-like)
38
75
  self.step_idx = 0
39
76
  self.last_observation = getattr(obs, "observation", obs) # tolerate dict-like
40
77
  self.last_info = getattr(obs, "info", None)
41
- out_obs: Dict[str, Any] = convert_numpy_to_python(self.last_observation) or {}
78
+ out_obs = self._prepare_observation(self.last_observation)
42
79
  # Attach a 7x7 semantic map patch centered on player for client-side rendering
43
80
  try:
44
81
  pub = self.env.engine._get_public_state_from_env() # type: ignore[attr-defined]
@@ -47,13 +84,13 @@ class CrafterEnvironmentWrapper:
47
84
  size = 7
48
85
  half = size // 2
49
86
  patch = []
50
- H = len(sem) if hasattr(sem, "__len__") else 0
51
- W = len(sem[0]) if H and hasattr(sem[0], "__len__") else 0
87
+ height = len(sem) if hasattr(sem, "__len__") else 0
88
+ width = len(sem[0]) if height and hasattr(sem[0], "__len__") else 0
52
89
  for dy in range(-half, half + 1):
53
90
  row = []
54
91
  for dx in range(-half, half + 1):
55
92
  x, y = int(px) + dx, int(py) + dy
56
- if 0 <= x < H and 0 <= y < W:
93
+ if 0 <= x < height and 0 <= y < width:
57
94
  row.append(int(sem[x][y]))
58
95
  else:
59
96
  row.append(0)
@@ -68,7 +105,7 @@ class CrafterEnvironmentWrapper:
68
105
  "step_idx": self.step_idx,
69
106
  }
70
107
 
71
- async def step(self, tool_calls: List[Dict[str, Any]] | List[EnvToolCall]) -> Dict[str, Any]:
108
+ async def step(self, tool_calls: list[dict[str, Any]] | list[EnvToolCall]) -> dict[str, Any]:
72
109
  # Normalize JSON tool_calls into EnvToolCall instances if needed
73
110
  # Underlying synth-ai environment expects only tool="interact" with args={"action": <action_name>}.
74
111
  # LLM may emit:
@@ -79,9 +116,9 @@ class CrafterEnvironmentWrapper:
79
116
  allowed_actions = set(
80
117
  TOOLS_SCHEMA[0]["function"]["parameters"]["properties"]["actions"]["items"]["enum"]
81
118
  )
82
- normalized: List[EnvToolCall] = []
119
+ normalized: list[EnvToolCall] = []
83
120
 
84
- def _action_to_int(action: Any) -> Optional[int]:
121
+ def _action_to_int(action: Any) -> int | None:
85
122
  # Handle invalid actions gracefully instead of failing
86
123
  if isinstance(action, int):
87
124
  return action
@@ -153,10 +190,8 @@ class CrafterEnvironmentWrapper:
153
190
  if isinstance(args, dict) and "action" in args:
154
191
  candidate_action = args["action"]
155
192
  # If the caller provided a numeric action id, accept it directly
156
- action_int: Optional[int]
157
- if isinstance(candidate_action, int):
158
- action_int = _action_to_int(candidate_action)
159
- elif (
193
+ action_int: int | None
194
+ if isinstance(candidate_action, int) or (
160
195
  isinstance(candidate_action, str)
161
196
  and candidate_action in allowed_actions
162
197
  ):
@@ -175,7 +210,7 @@ class CrafterEnvironmentWrapper:
175
210
  normalized.append(EnvToolCall(tool="interact", args={"action": 0})) # noop action
176
211
 
177
212
  # Pre-step logging: capture current public state and print concise summary
178
- before_state: Optional[Dict[str, Any]] = None
213
+ before_state: dict[str, Any] | None = None
179
214
  try:
180
215
  pub_before = self.env.engine._get_public_state_from_env() # type: ignore[attr-defined]
181
216
  before_state = {
@@ -231,7 +266,7 @@ class CrafterEnvironmentWrapper:
231
266
  ach_added_latest: list[str] | None = None
232
267
  try:
233
268
  pub_after = self.env.engine._get_public_state_from_env() # type: ignore[attr-defined]
234
- after_dict: Dict[str, Any] = {
269
+ after_dict: dict[str, Any] = {
235
270
  "inventory": pub_after.inventory,
236
271
  "achievements_status": pub_after.achievements_status,
237
272
  "player_position": list(pub_after.player_position),
@@ -255,8 +290,8 @@ class CrafterEnvironmentWrapper:
255
290
  # Position delta
256
291
  pb = before_state.get("player_position", [0, 0])
257
292
  pa = after_dict.get("player_position", [0, 0])
258
- pb_t = (int(pb[0]), int(pb[1])) if isinstance(pb, (list, tuple)) else (0, 0)
259
- pa_t = (int(pa[0]), int(pa[1])) if isinstance(pa, (list, tuple)) else (0, 0)
293
+ pb_t = (int(pb[0]), int(pb[1])) if isinstance(pb, list | tuple) else (0, 0)
294
+ pa_t = (int(pa[0]), int(pa[1])) if isinstance(pa, list | tuple) else (0, 0)
260
295
  delta = (pa_t[0] - pb_t[0], pa_t[1] - pb_t[1])
261
296
 
262
297
  # Inventory changes
@@ -280,9 +315,9 @@ class CrafterEnvironmentWrapper:
280
315
  ach_a = {
281
316
  k for k, v in (after_dict.get("achievements_status", {}) or {}).items() if v
282
317
  }
283
- ach_added = sorted(list(ach_a - ach_b))
318
+ ach_added = sorted(ach_a - ach_b)
284
319
  ach_added_latest = ach_added
285
- ach_removed = sorted(list(ach_b - ach_a))
320
+ ach_removed = sorted(ach_b - ach_a)
286
321
 
287
322
  logger.info(
288
323
  "Changes: pos %s->%s Δ=%s | inv %s | ach +%s -%s",
@@ -312,8 +347,8 @@ class CrafterEnvironmentWrapper:
312
347
  )
313
348
  except Exception as _:
314
349
  pass
315
- result: Dict[str, Any] = {
316
- "observation": convert_numpy_to_python(observation),
350
+ result: dict[str, Any] = {
351
+ "observation": self._prepare_observation(observation),
317
352
  "step_idx": self.step_idx,
318
353
  "done": bool(done) if done is not None else False, # Ensure boolean
319
354
  }
@@ -325,13 +360,13 @@ class CrafterEnvironmentWrapper:
325
360
  size = 7
326
361
  half = size // 2
327
362
  patch = []
328
- H = len(sem) if hasattr(sem, "__len__") else 0
329
- W = len(sem[0]) if H and hasattr(sem[0], "__len__") else 0
363
+ height = len(sem) if hasattr(sem, "__len__") else 0
364
+ width = len(sem[0]) if height and hasattr(sem[0], "__len__") else 0
330
365
  for dy in range(-half, half + 1):
331
366
  row = []
332
367
  for dx in range(-half, half + 1):
333
368
  x, y = px + dx, py + dy
334
- if 0 <= x < H and 0 <= y < W:
369
+ if 0 <= x < height and 0 <= y < width:
335
370
  row.append(int(sem[x][y]))
336
371
  else:
337
372
  row.append(0)
@@ -341,10 +376,7 @@ class CrafterEnvironmentWrapper:
341
376
  obs_out["semantic_map_patch7"] = patch
342
377
  except Exception:
343
378
  pass
344
- if info is not None:
345
- result_info = convert_numpy_to_python(info)
346
- else:
347
- result_info = {}
379
+ result_info = convert_numpy_to_python(info) if info is not None else {}
348
380
  # Attach achievements delta for downstream metrics if useful
349
381
  if ach_added_latest is not None:
350
382
  try:
@@ -404,9 +436,37 @@ class CrafterEnvironmentWrapper:
404
436
  )
405
437
  except Exception:
406
438
  pass
439
+
407
440
  return result
408
441
 
409
- async def checkpoint(self) -> Dict[str, Any]:
442
+ def _prepare_observation(self, observation: Any) -> dict[str, Any]:
443
+ """Convert raw observation into a JSON-serializable dict with encoded image."""
444
+
445
+ obs_dict: dict[str, Any]
446
+ image_payload: dict[str, Any] | None = None
447
+
448
+ if isinstance(observation, dict):
449
+ image_payload = _encode_image_to_base64(observation.get("observation_image"))
450
+ # Work on a shallow copy to avoid mutating engine state
451
+ sanitized = dict(observation)
452
+ sanitized.pop("observation_image", None)
453
+ obs_dict = convert_numpy_to_python(sanitized) or {}
454
+ else:
455
+ obs_dict = convert_numpy_to_python(observation) or {}
456
+
457
+ if not isinstance(obs_dict, dict):
458
+ obs_dict = {"value": obs_dict}
459
+
460
+ if image_payload:
461
+ obs_dict["observation_image_base64"] = image_payload["data"]
462
+ obs_dict["observation_image_format"] = image_payload["format"]
463
+ obs_dict["observation_image_width"] = image_payload["width"]
464
+ obs_dict["observation_image_height"] = image_payload["height"]
465
+ obs_dict["observation_image_data_url"] = image_payload["data_url"]
466
+
467
+ return obs_dict
468
+
469
+ async def checkpoint(self) -> dict[str, Any]:
410
470
  obs = await self.env.checkpoint()
411
471
  observation = getattr(obs, "observation", obs)
412
472
  info = getattr(obs, "info", None)
@@ -416,7 +476,7 @@ class CrafterEnvironmentWrapper:
416
476
  "step_idx": self.step_idx,
417
477
  }
418
478
 
419
- async def terminate(self) -> Dict[str, Any]:
479
+ async def terminate(self) -> dict[str, Any]:
420
480
  obs = await self.env.terminate()
421
481
  observation = getattr(obs, "observation", obs)
422
482
  info = getattr(obs, "info", None)
@@ -426,7 +486,7 @@ class CrafterEnvironmentWrapper:
426
486
  "step_idx": self.step_idx,
427
487
  }
428
488
 
429
- def state_dict(self) -> Dict[str, Any]:
489
+ def state_dict(self) -> dict[str, Any]:
430
490
  return {
431
491
  "seed": self.seed,
432
492
  "step_idx": self.step_idx,
@@ -434,13 +494,13 @@ class CrafterEnvironmentWrapper:
434
494
  "last_info": self.last_info,
435
495
  }
436
496
 
437
- def load_state_dict(self, state: Dict[str, Any]) -> None:
497
+ def load_state_dict(self, state: dict[str, Any]) -> None:
438
498
  self.seed = state["seed"]
439
499
  self.step_idx = int(state["step_idx"])
440
500
  self.last_observation = state["last_observation"]
441
501
  self.last_info = state["last_info"]
442
502
 
443
- async def serialize(self) -> Dict[str, Any]:
503
+ async def serialize(self) -> dict[str, Any]:
444
504
  return {
445
505
  "name": "crafter",
446
506
  "config": {"seed": self.seed},
@@ -450,9 +510,9 @@ class CrafterEnvironmentWrapper:
450
510
  @classmethod
451
511
  async def deserialize(
452
512
  cls,
453
- payload: Dict[str, Any],
513
+ payload: dict[str, Any],
454
514
  env: StatefulEnvironment,
455
- ) -> "CrafterEnvironmentWrapper":
515
+ ) -> CrafterEnvironmentWrapper:
456
516
  seed = payload["config"]["seed"]
457
517
  wrapper = cls(env=env, seed=seed)
458
518
  wrapper.load_state_dict(payload["state"])
@@ -1,7 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any, Dict, List, Optional, Tuple
4
3
  from abc import ABC, abstractmethod
4
+ from typing import Any
5
+
5
6
  from .react_agent import CrafterReActAgent
6
7
  from .tools import TOOLS_SCHEMA
7
8
 
@@ -12,15 +13,15 @@ class Policy(ABC):
12
13
 
13
14
  @abstractmethod
14
15
  def prepare_inference_request(
15
- self, observation: Dict[str, Any], history: List[Dict[str, Any]] = None
16
- ) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]:
16
+ self, observation: dict[str, Any], history: list[dict[str, Any]] = None
17
+ ) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
17
18
  """Prepare an inference request."""
18
19
  pass
19
20
 
20
21
  @abstractmethod
21
22
  def parse_model_response(
22
- self, response: str, observation: Dict[str, Any]
23
- ) -> List[Dict[str, Any]]:
23
+ self, response: str, observation: dict[str, Any]
24
+ ) -> list[dict[str, Any]]:
24
25
  """Parse model response into tool calls."""
25
26
  pass
26
27
 
@@ -39,23 +40,23 @@ class CrafterPolicy(Policy):
39
40
 
40
41
  name: str = "crafter-react"
41
42
 
42
- def __init__(self, inference_url: str, model: Optional[str] = None) -> None:
43
+ def __init__(self, inference_url: str, model: str | None = None) -> None:
43
44
  self.inference_url = inference_url
44
45
  self.model = model
45
46
  self.use_tools = True
46
47
  # Sampling parameters (populated via initialize(config))
47
- self.temperature: Optional[float] = None
48
- self.top_p: Optional[float] = None
49
- self.max_tokens: Optional[int] = None
48
+ self.temperature: float | None = None
49
+ self.top_p: float | None = None
50
+ self.max_tokens: int | None = None
50
51
  # Thinking controls (populated via initialize(config))
51
- self.thinking_mode: Optional[str] = None
52
- self.thinking_budget: Optional[int] = None
52
+ self.thinking_mode: str | None = None
53
+ self.thinking_budget: int | None = None
53
54
  # Rolling conversation and action history for non-Markov policies
54
- self.history_messages: List[Dict[str, str]] = [] # chat-style without system
55
+ self.history_messages: list[dict[str, str]] = [] # chat-style without system
55
56
  self.turn_index: int = 0
56
- self.trajectory_history: List[Dict[str, Any]] = [] # env/policy step records
57
+ self.trajectory_history: list[dict[str, Any]] = [] # env/policy step records
57
58
 
58
- async def initialize(self, config: Dict[str, Any]) -> None:
59
+ async def initialize(self, config: dict[str, Any]) -> None:
59
60
  if "inference_url" in config:
60
61
  self.inference_url = config["inference_url"]
61
62
  if "model" in config:
@@ -91,15 +92,15 @@ class CrafterPolicy(Policy):
91
92
 
92
93
  def _append_assistant_turn(
93
94
  self,
94
- assistant_text: Optional[str],
95
- tool_calls: Optional[List[Dict[str, Any]]],
96
- env_result: Optional[Dict[str, Any]],
95
+ assistant_text: str | None,
96
+ tool_calls: list[dict[str, Any]] | None,
97
+ env_result: dict[str, Any] | None,
97
98
  ) -> None:
98
99
  # Record assistant content (if any)
99
100
  if assistant_text is not None:
100
101
  self.history_messages.append({"role": "assistant", "content": assistant_text})
101
102
  # Keep structured step record for training/analysis
102
- record: Dict[str, Any] = {"turn": self.turn_index}
103
+ record: dict[str, Any] = {"turn": self.turn_index}
103
104
  if tool_calls is not None:
104
105
  record["tool_calls"] = tool_calls
105
106
  if env_result is not None:
@@ -109,13 +110,17 @@ class CrafterPolicy(Policy):
109
110
  def build_inference_request(
110
111
  self,
111
112
  observation_text: str,
112
- history: Optional[List[Dict[str, str]]] = None,
113
- turn: Optional[int] = None,
114
- ) -> Dict[str, Any]:
113
+ history: list[dict[str, Any]] | None = None,
114
+ turn: int | None = None,
115
+ image_parts: list[dict[str, Any]] | None = None,
116
+ ) -> dict[str, Any]:
115
117
  messages = CrafterReActAgent.build_messages(
116
- observation=observation_text, history=history, turn=turn
118
+ observation=observation_text,
119
+ history=history,
120
+ turn=turn,
121
+ image_parts=image_parts,
117
122
  )
118
- payload: Dict[str, Any] = {
123
+ payload: dict[str, Any] = {
119
124
  "messages": messages,
120
125
  }
121
126
  if self.model is not None:
@@ -150,9 +155,9 @@ class CrafterPolicy(Policy):
150
155
 
151
156
  @staticmethod
152
157
  def parse_response_to_tool_calls(
153
- response: Dict[str, Any],
158
+ response: dict[str, Any],
154
159
  use_tools: bool = True,
155
- ) -> List[Dict[str, Any]]:
160
+ ) -> list[dict[str, Any]]:
156
161
  """Turn an inference response into environment tool calls.
157
162
 
158
163
  - If tools were used, expect tool_calls-compatible output and forward as-is
@@ -162,7 +167,7 @@ class CrafterPolicy(Policy):
162
167
  """
163
168
  # First check if we got actual tool calls
164
169
  choices = response.get("choices", [])
165
- tool_calls: List[Dict[str, Any]] = []
170
+ tool_calls: list[dict[str, Any]] = []
166
171
 
167
172
  for choice in choices:
168
173
  msg = choice.get("message", {})
@@ -192,7 +197,7 @@ class CrafterPolicy(Policy):
192
197
  if tool_calls:
193
198
  # Normalize common degenerate pattern ["move_right", "do"] when nothing is nearby.
194
199
  # If previous env_result indicates no interaction target, drop trailing 'do'.
195
- normalized: List[Dict[str, Any]] = []
200
+ normalized: list[dict[str, Any]] = []
196
201
  for tc in tool_calls:
197
202
  if tc and isinstance(tc, dict) and tc.get("tool_name") == "interact_many":
198
203
  args = tc.get("arguments")
@@ -242,9 +247,9 @@ class CrafterPolicy(Policy):
242
247
  async def step(
243
248
  self,
244
249
  observation_text: str,
245
- state: Optional[Dict[str, Any]] = None,
246
- metadata: Optional[Dict[str, Any]] = None,
247
- ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
250
+ state: dict[str, Any] | None = None,
251
+ metadata: dict[str, Any] | None = None,
252
+ ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
248
253
  """Stateful step: update policy history and prepare inference request.
249
254
 
250
255
  Inputs (via metadata, optional):
@@ -261,9 +266,9 @@ class CrafterPolicy(Policy):
261
266
  """
262
267
  # If caller provided results from previous cycle, record them first
263
268
  if metadata is not None:
264
- prev_assistant_text: Optional[str] = None
265
- prev_tool_calls: Optional[List[Dict[str, Any]]] = None
266
- prev_env_result: Optional[Dict[str, Any]] = None
269
+ prev_assistant_text: str | None = None
270
+ prev_tool_calls: list[dict[str, Any]] | None = None
271
+ prev_env_result: dict[str, Any] | None = None
267
272
  if "prev_assistant_text" in metadata:
268
273
  prev_assistant_text = metadata["prev_assistant_text"]
269
274
  if "prev_tool_calls" in metadata:
@@ -283,7 +288,7 @@ class CrafterPolicy(Policy):
283
288
  # Build user message by combining the current observation text
284
289
  # (formatted surroundings/inventory) with the previous 3 tool calls as context.
285
290
  # Most recent first.
286
- lines: List[str] = []
291
+ lines: list[str] = []
287
292
 
288
293
  def _format_tool_call_line_for_context(
289
294
  tool_name: str, arguments: Any, max_chars: int = 500
@@ -291,7 +296,7 @@ class CrafterPolicy(Policy):
291
296
  import json as _json
292
297
 
293
298
  # Render arguments compactly, then clip to max_chars
294
- if isinstance(arguments, (dict, list)):
299
+ if isinstance(arguments, dict | list):
295
300
  try:
296
301
  rendered = _json.dumps(arguments, ensure_ascii=False, separators=(",", ":"))
297
302
  except Exception:
@@ -321,7 +326,7 @@ class CrafterPolicy(Policy):
321
326
 
322
327
  # If trajectory history is empty (first few turns), fall back to metadata once
323
328
  if not lines and metadata is not None and metadata.get("prev_tool_calls"):
324
- calls: List[Dict[str, Any]] = metadata["prev_tool_calls"]
329
+ calls: list[dict[str, Any]] = metadata["prev_tool_calls"]
325
330
  for call in reversed(calls):
326
331
  if len(lines) >= 3:
327
332
  break
@@ -338,10 +343,18 @@ class CrafterPolicy(Policy):
338
343
  # Combine observation with context so the model always sees surroundings/inventory
339
344
  combined_text = f"{observation_text}\n\n{context_text}"
340
345
 
346
+ raw_observation: dict[str, Any] | None = None
347
+ if metadata is not None:
348
+ raw_candidate = metadata.get("raw_observation")
349
+ if isinstance(raw_candidate, dict):
350
+ raw_observation = raw_candidate
351
+ image_parts = self._extract_image_parts(raw_observation)
352
+
341
353
  payload = self.build_inference_request(
342
354
  combined_text,
343
355
  history=[], # no prior user/assistant history
344
356
  turn=self.turn_index,
357
+ image_parts=image_parts,
345
358
  )
346
359
  # print("Debugging only:; ", payload)
347
360
  meta_out = {
@@ -352,19 +365,19 @@ class CrafterPolicy(Policy):
352
365
  }
353
366
  return [], meta_out
354
367
 
355
- def state_dict(self) -> Dict[str, Any]:
368
+ def state_dict(self) -> dict[str, Any]:
356
369
  return {
357
370
  "turn_index": self.turn_index,
358
371
  "history_messages": self.history_messages,
359
372
  "trajectory_history": self.trajectory_history,
360
373
  }
361
374
 
362
- def load_state_dict(self, state: Dict[str, Any]) -> None:
375
+ def load_state_dict(self, state: dict[str, Any]) -> None:
363
376
  self.turn_index = int(state["turn_index"])
364
377
  self.history_messages = state["history_messages"]
365
378
  self.trajectory_history = state["trajectory_history"]
366
379
 
367
- async def serialize(self) -> Dict[str, Any]:
380
+ async def serialize(self) -> dict[str, Any]:
368
381
  return {
369
382
  "name": self.name,
370
383
  "config": {
@@ -376,7 +389,7 @@ class CrafterPolicy(Policy):
376
389
  }
377
390
 
378
391
  @classmethod
379
- async def deserialize(cls, payload: Dict[str, Any]) -> "CrafterPolicy":
392
+ async def deserialize(cls, payload: dict[str, Any]) -> CrafterPolicy:
380
393
  config = payload["config"]
381
394
  state = payload["state"]
382
395
  policy = cls(
@@ -391,22 +404,26 @@ class CrafterPolicy(Policy):
391
404
  return None
392
405
 
393
406
  def prepare_inference_request(
394
- self, observation: Dict[str, Any], history: List[Dict[str, Any]] = None
395
- ) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]:
407
+ self, observation: dict[str, Any], history: list[dict[str, Any]] = None
408
+ ) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
396
409
  """Prepare an inference request (implementing abstract method)."""
397
410
  # Format observation with rich contextual information
398
411
  observation_text = self._format_observation_for_llm(observation)
412
+ image_parts = self._extract_image_parts(observation)
399
413
 
400
414
  # Build messages (observation_text already formatted; no raw matrices)
401
415
  messages = CrafterReActAgent.build_messages(
402
- observation=observation_text, history=history, turn=self.turn_index
416
+ observation=observation_text,
417
+ history=history,
418
+ turn=self.turn_index,
419
+ image_parts=image_parts,
403
420
  )
404
421
 
405
422
  # Return messages and tools schema
406
423
  tools = TOOLS_SCHEMA if self.use_tools else None
407
424
  return messages, tools
408
425
 
409
- def _format_observation_for_llm(self, observation: Dict[str, Any]) -> str:
426
+ def _format_observation_for_llm(self, observation: dict[str, Any]) -> str:
410
427
  """Format observation with rich contextual information for the LLM using the shared formatter."""
411
428
  from .shared import format_observation
412
429
 
@@ -423,17 +440,22 @@ class CrafterPolicy(Policy):
423
440
 
424
441
  # Get additional info from the observation wrapper
425
442
  info = observation.get("info", {})
426
- if isinstance(info, dict):
427
- # Merge health from info into obs_data for the formatter
428
- if "health" in info and "health" not in obs_data:
429
- obs_data = dict(obs_data) # Make a copy
430
- obs_data["health"] = info["health"]
443
+ if isinstance(info, dict) and "health" in info and "health" not in obs_data:
444
+ obs_data = dict(obs_data) # Make a copy
445
+ obs_data["health"] = info["health"]
431
446
 
432
447
  return format_observation(obs_data, step_count=step_idx, max_steps=max_steps)
433
448
 
449
+ def _extract_image_parts(
450
+ self, observation: dict[str, Any] | None
451
+ ) -> list[dict[str, Any]]:
452
+ """Crafter policy uses text-only prompts; do not attach image parts."""
453
+
454
+ return []
455
+
434
456
  def parse_model_response(
435
- self, response: str, observation: Dict[str, Any]
436
- ) -> List[Dict[str, Any]]:
457
+ self, response: str, observation: dict[str, Any]
458
+ ) -> list[dict[str, Any]]:
437
459
  """Parse model response into tool calls (implementing abstract method).
438
460
 
439
461
  Note: Despite the type hint, vLLM actually returns a dict response,
@@ -7,7 +7,7 @@ utilities to keep a single parser.
7
7
 
8
8
  from __future__ import annotations
9
9
 
10
- from typing import Dict, List, Optional
10
+ from typing import Any
11
11
 
12
12
  from .shared import parse_actions
13
13
 
@@ -81,19 +81,27 @@ class CrafterReActAgent:
81
81
 
82
82
  @staticmethod
83
83
  def build_messages(
84
- observation: str, history: Optional[List[Dict[str, str]]] = None, turn: Optional[int] = None
85
- ) -> List[Dict[str, str]]:
84
+ observation: str,
85
+ history: list[dict[str, Any]] | None = None,
86
+ turn: int | None = None,
87
+ image_parts: list[dict[str, Any]] | None = None,
88
+ ) -> list[dict[str, Any]]:
86
89
  """Construct OpenAI-style messages list for vLLM generation."""
87
- msgs: List[Dict[str, str]] = [
90
+ msgs: list[dict[str, Any]] = [
88
91
  {"role": "system", "content": CrafterReActAgent.get_system_prompt()}
89
92
  ]
90
93
  if history:
91
94
  msgs.extend(history)
92
- msgs.append({"role": "user", "content": observation})
95
+ user_content: Any
96
+ if image_parts:
97
+ user_content = [{"type": "text", "text": observation}] + list(image_parts)
98
+ else:
99
+ user_content = observation
100
+ msgs.append({"role": "user", "content": user_content})
93
101
  return msgs
94
102
 
95
103
  @staticmethod
96
- def parse_actions_from_response(response_text: str) -> List[str]:
104
+ def parse_actions_from_response(response_text: str) -> list[str]:
97
105
  return parse_actions(response_text)
98
106
 
99
107