synth-ai 0.2.9.dev7__py3-none-any.whl → 0.2.9.dev9__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +8 -11
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +64 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +18 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +21 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +38 -0
- examples/qwen_coder/validate_jsonl.py +59 -0
- examples/rl/run_eval.py +36 -37
- examples/rl/run_rl_and_save.py +5 -5
- examples/rl/task_app/math_single_step.py +65 -43
- examples/rl/task_app/math_task_app.py +3 -3
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +5 -5
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +78 -21
- examples/warming_up_to_rl/groq_test.py +4 -4
- examples/warming_up_to_rl/manage_secrets.py +13 -18
- examples/warming_up_to_rl/run_eval.py +42 -44
- examples/warming_up_to_rl/run_fft_and_save.py +11 -16
- examples/warming_up_to_rl/run_local_rollout.py +1 -3
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -4
- examples/warming_up_to_rl/run_local_rollout_parallel.py +1 -4
- examples/warming_up_to_rl/run_local_rollout_traced.py +3 -5
- examples/warming_up_to_rl/run_rl_and_save.py +5 -6
- examples/warming_up_to_rl/run_rollout_remote.py +8 -10
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +234 -35
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +2 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +131 -114
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +101 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +73 -51
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +14 -6
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +16 -16
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +32 -34
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +94 -31
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +303 -203
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +328 -225
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +13 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth/__init__.py +14 -0
- synth_ai/__init__.py +26 -4
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +128 -21
- synth_ai/api/train/cli.py +80 -64
- synth_ai/api/train/config_finder.py +7 -2
- synth_ai/api/train/env_resolver.py +1 -1
- synth_ai/api/train/pollers.py +2 -1
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +1 -2
- synth_ai/api/train/utils.py +13 -44
- synth_ai/cli/__init__.py +8 -0
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +1 -2
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +2 -1
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +2 -1
- synth_ai/cli/root.py +11 -13
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +529 -179
- synth_ai/cli/traces.py +6 -4
- synth_ai/cli/watch.py +12 -18
- synth_ai/demo_registry.py +1 -1
- synth_ai/demos/core/cli.py +36 -43
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +17 -25
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +3 -4
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -4
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +16 -18
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +2 -5
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +4 -7
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/handshake.py +9 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +18 -10
- synth_ai/inference/client.py +15 -5
- synth_ai/jobs/client.py +78 -83
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +91 -24
- synth_ai/learning/config.py +2 -38
- synth_ai/learning/ft_client.py +4 -59
- synth_ai/learning/health.py +5 -6
- synth_ai/learning/jobs.py +31 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -4
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -8
- synth_ai/{rl → learning/rl}/env_keys.py +39 -15
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -281
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -24
- synth_ai/learning/validators.py +25 -28
- synth_ai/lm/__init__.py +21 -47
- synth_ai/main.py +6 -0
- synth_ai/task/__init__.py +25 -27
- synth_ai/task/apps/__init__.py +7 -8
- synth_ai/task/auth.py +8 -8
- synth_ai/task/client.py +14 -14
- synth_ai/task/contracts.py +36 -35
- synth_ai/task/datasets.py +6 -5
- synth_ai/task/errors.py +10 -10
- synth_ai/task/health.py +17 -9
- synth_ai/task/json.py +58 -23
- synth_ai/task/proxy.py +13 -9
- synth_ai/task/rubrics.py +16 -15
- synth_ai/task/server.py +12 -12
- synth_ai/task/tracing_utils.py +4 -4
- synth_ai/task/vendors.py +5 -6
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/decorators.py +18 -16
- synth_ai/tracing_v3/hooks.py +5 -5
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/session_tracer.py +40 -14
- synth_ai/tracing_v3/storage/base.py +85 -0
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -7
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +2 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -4
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +2 -2
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/main.py +6 -6
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -14
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +2 -2
- synth_ai/{lm → v0/lm}/overrides.py +2 -2
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +9 -9
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +10 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +8 -8
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +1 -1
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- synth_ai-0.2.9.dev9.dist-info/METADATA +191 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/RECORD +268 -238
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/top_level.txt +1 -0
- examples/common_old/backend.py +0 -20
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1038
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -243
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -119
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -243
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -36
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -46
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1933
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -210
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -237
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -152
- examples/rl_old/task_app.py +0 -1131
- examples/warming_up_to_rl/old/event_rewards.md +0 -234
- examples/warming_up_to_rl/old/notes.md +0 -73
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +0 -738
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +0 -580
- synth_ai/experimental/synth_oss.py +0 -445
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -211
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -249
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -329
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -838
- synth_ai/zyk/__init__.py +0 -30
- synth_ai-0.2.9.dev7.dist-info/METADATA +0 -131
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/core/exceptions.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,19 +1,56 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
import base64
|
|
4
4
|
import logging
|
|
5
|
+
from io import BytesIO
|
|
6
|
+
from typing import Any
|
|
5
7
|
|
|
6
|
-
|
|
8
|
+
import numpy as np
|
|
9
|
+
from PIL import Image
|
|
7
10
|
from synth_ai.environments.environment.tools import EnvToolCall
|
|
11
|
+
from synth_ai.environments.stateful.core import StatefulEnvironment
|
|
8
12
|
|
|
9
13
|
from ...utils import convert_numpy_to_python
|
|
10
|
-
from .tools import TOOLS_SCHEMA
|
|
11
14
|
from .shared import CRAFTER_ACTIONS, _format_semantic_map_view
|
|
12
|
-
|
|
15
|
+
from .tools import TOOLS_SCHEMA
|
|
13
16
|
|
|
14
17
|
logger = logging.getLogger(__name__)
|
|
15
18
|
|
|
16
19
|
|
|
20
|
+
def _encode_image_to_base64(image_array: Any) -> dict[str, Any] | None:
|
|
21
|
+
"""Encode an RGB ndarray into a base64 PNG payload with metadata."""
|
|
22
|
+
|
|
23
|
+
if not isinstance(image_array, np.ndarray):
|
|
24
|
+
return None
|
|
25
|
+
if image_array.ndim != 3 or image_array.shape[-1] not in (1, 3, 4):
|
|
26
|
+
return None
|
|
27
|
+
try:
|
|
28
|
+
# Ensure uint8 for PIL compatibility
|
|
29
|
+
array_uint8 = (
|
|
30
|
+
image_array.astype("uint8")
|
|
31
|
+
if image_array.dtype != np.uint8
|
|
32
|
+
else image_array # pragma: no cover - fast path
|
|
33
|
+
)
|
|
34
|
+
mode = "L" if array_uint8.shape[-1] == 1 else "RGB"
|
|
35
|
+
if array_uint8.shape[-1] == 4:
|
|
36
|
+
mode = "RGBA"
|
|
37
|
+
img = Image.fromarray(array_uint8, mode=mode)
|
|
38
|
+
buffer = BytesIO()
|
|
39
|
+
img.save(buffer, format="PNG")
|
|
40
|
+
encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
|
|
41
|
+
width = int(array_uint8.shape[1])
|
|
42
|
+
height = int(array_uint8.shape[0])
|
|
43
|
+
return {
|
|
44
|
+
"format": "png",
|
|
45
|
+
"width": width,
|
|
46
|
+
"height": height,
|
|
47
|
+
"data": encoded,
|
|
48
|
+
"data_url": f"data:image/png;base64,{encoded}",
|
|
49
|
+
}
|
|
50
|
+
except Exception:
|
|
51
|
+
return None
|
|
52
|
+
|
|
53
|
+
|
|
17
54
|
class CrafterEnvironmentWrapper:
|
|
18
55
|
"""Host-side environment wrapper matching the sketch contract.
|
|
19
56
|
|
|
@@ -25,20 +62,20 @@ class CrafterEnvironmentWrapper:
|
|
|
25
62
|
- snapshot()/restore() handled at route level; this wrapper exposes checkpoint via synth-ai
|
|
26
63
|
"""
|
|
27
64
|
|
|
28
|
-
def __init__(self, env: StatefulEnvironment, seed:
|
|
65
|
+
def __init__(self, env: StatefulEnvironment, seed: int | None = None) -> None:
|
|
29
66
|
self.env = env
|
|
30
67
|
self.seed = seed
|
|
31
68
|
self.step_idx = 0
|
|
32
|
-
self.last_observation:
|
|
33
|
-
self.last_info:
|
|
69
|
+
self.last_observation: dict[str, Any] | None = None
|
|
70
|
+
self.last_info: dict[str, Any] | None = None
|
|
34
71
|
|
|
35
|
-
async def initialize(self) ->
|
|
72
|
+
async def initialize(self) -> dict[str, Any]:
|
|
36
73
|
obs = await self.env.initialize()
|
|
37
74
|
# synth-ai InternalObservation expected to expose .observation (dict-like)
|
|
38
75
|
self.step_idx = 0
|
|
39
76
|
self.last_observation = getattr(obs, "observation", obs) # tolerate dict-like
|
|
40
77
|
self.last_info = getattr(obs, "info", None)
|
|
41
|
-
out_obs
|
|
78
|
+
out_obs = self._prepare_observation(self.last_observation)
|
|
42
79
|
# Attach a 7x7 semantic map patch centered on player for client-side rendering
|
|
43
80
|
try:
|
|
44
81
|
pub = self.env.engine._get_public_state_from_env() # type: ignore[attr-defined]
|
|
@@ -47,13 +84,13 @@ class CrafterEnvironmentWrapper:
|
|
|
47
84
|
size = 7
|
|
48
85
|
half = size // 2
|
|
49
86
|
patch = []
|
|
50
|
-
|
|
51
|
-
|
|
87
|
+
height = len(sem) if hasattr(sem, "__len__") else 0
|
|
88
|
+
width = len(sem[0]) if height and hasattr(sem[0], "__len__") else 0
|
|
52
89
|
for dy in range(-half, half + 1):
|
|
53
90
|
row = []
|
|
54
91
|
for dx in range(-half, half + 1):
|
|
55
92
|
x, y = int(px) + dx, int(py) + dy
|
|
56
|
-
if 0 <= x <
|
|
93
|
+
if 0 <= x < height and 0 <= y < width:
|
|
57
94
|
row.append(int(sem[x][y]))
|
|
58
95
|
else:
|
|
59
96
|
row.append(0)
|
|
@@ -68,7 +105,7 @@ class CrafterEnvironmentWrapper:
|
|
|
68
105
|
"step_idx": self.step_idx,
|
|
69
106
|
}
|
|
70
107
|
|
|
71
|
-
async def step(self, tool_calls:
|
|
108
|
+
async def step(self, tool_calls: list[dict[str, Any]] | list[EnvToolCall]) -> dict[str, Any]:
|
|
72
109
|
# Normalize JSON tool_calls into EnvToolCall instances if needed
|
|
73
110
|
# Underlying synth-ai environment expects only tool="interact" with args={"action": <action_name>}.
|
|
74
111
|
# LLM may emit:
|
|
@@ -79,9 +116,9 @@ class CrafterEnvironmentWrapper:
|
|
|
79
116
|
allowed_actions = set(
|
|
80
117
|
TOOLS_SCHEMA[0]["function"]["parameters"]["properties"]["actions"]["items"]["enum"]
|
|
81
118
|
)
|
|
82
|
-
normalized:
|
|
119
|
+
normalized: list[EnvToolCall] = []
|
|
83
120
|
|
|
84
|
-
def _action_to_int(action: Any) ->
|
|
121
|
+
def _action_to_int(action: Any) -> int | None:
|
|
85
122
|
# Handle invalid actions gracefully instead of failing
|
|
86
123
|
if isinstance(action, int):
|
|
87
124
|
return action
|
|
@@ -153,10 +190,8 @@ class CrafterEnvironmentWrapper:
|
|
|
153
190
|
if isinstance(args, dict) and "action" in args:
|
|
154
191
|
candidate_action = args["action"]
|
|
155
192
|
# If the caller provided a numeric action id, accept it directly
|
|
156
|
-
action_int:
|
|
157
|
-
if isinstance(candidate_action, int)
|
|
158
|
-
action_int = _action_to_int(candidate_action)
|
|
159
|
-
elif (
|
|
193
|
+
action_int: int | None
|
|
194
|
+
if isinstance(candidate_action, int) or (
|
|
160
195
|
isinstance(candidate_action, str)
|
|
161
196
|
and candidate_action in allowed_actions
|
|
162
197
|
):
|
|
@@ -175,7 +210,7 @@ class CrafterEnvironmentWrapper:
|
|
|
175
210
|
normalized.append(EnvToolCall(tool="interact", args={"action": 0})) # noop action
|
|
176
211
|
|
|
177
212
|
# Pre-step logging: capture current public state and print concise summary
|
|
178
|
-
before_state:
|
|
213
|
+
before_state: dict[str, Any] | None = None
|
|
179
214
|
try:
|
|
180
215
|
pub_before = self.env.engine._get_public_state_from_env() # type: ignore[attr-defined]
|
|
181
216
|
before_state = {
|
|
@@ -231,7 +266,7 @@ class CrafterEnvironmentWrapper:
|
|
|
231
266
|
ach_added_latest: list[str] | None = None
|
|
232
267
|
try:
|
|
233
268
|
pub_after = self.env.engine._get_public_state_from_env() # type: ignore[attr-defined]
|
|
234
|
-
after_dict:
|
|
269
|
+
after_dict: dict[str, Any] = {
|
|
235
270
|
"inventory": pub_after.inventory,
|
|
236
271
|
"achievements_status": pub_after.achievements_status,
|
|
237
272
|
"player_position": list(pub_after.player_position),
|
|
@@ -255,8 +290,8 @@ class CrafterEnvironmentWrapper:
|
|
|
255
290
|
# Position delta
|
|
256
291
|
pb = before_state.get("player_position", [0, 0])
|
|
257
292
|
pa = after_dict.get("player_position", [0, 0])
|
|
258
|
-
pb_t = (int(pb[0]), int(pb[1])) if isinstance(pb,
|
|
259
|
-
pa_t = (int(pa[0]), int(pa[1])) if isinstance(pa,
|
|
293
|
+
pb_t = (int(pb[0]), int(pb[1])) if isinstance(pb, list | tuple) else (0, 0)
|
|
294
|
+
pa_t = (int(pa[0]), int(pa[1])) if isinstance(pa, list | tuple) else (0, 0)
|
|
260
295
|
delta = (pa_t[0] - pb_t[0], pa_t[1] - pb_t[1])
|
|
261
296
|
|
|
262
297
|
# Inventory changes
|
|
@@ -280,9 +315,9 @@ class CrafterEnvironmentWrapper:
|
|
|
280
315
|
ach_a = {
|
|
281
316
|
k for k, v in (after_dict.get("achievements_status", {}) or {}).items() if v
|
|
282
317
|
}
|
|
283
|
-
ach_added = sorted(
|
|
318
|
+
ach_added = sorted(ach_a - ach_b)
|
|
284
319
|
ach_added_latest = ach_added
|
|
285
|
-
ach_removed = sorted(
|
|
320
|
+
ach_removed = sorted(ach_b - ach_a)
|
|
286
321
|
|
|
287
322
|
logger.info(
|
|
288
323
|
"Changes: pos %s->%s Δ=%s | inv %s | ach +%s -%s",
|
|
@@ -312,8 +347,8 @@ class CrafterEnvironmentWrapper:
|
|
|
312
347
|
)
|
|
313
348
|
except Exception as _:
|
|
314
349
|
pass
|
|
315
|
-
result:
|
|
316
|
-
"observation":
|
|
350
|
+
result: dict[str, Any] = {
|
|
351
|
+
"observation": self._prepare_observation(observation),
|
|
317
352
|
"step_idx": self.step_idx,
|
|
318
353
|
"done": bool(done) if done is not None else False, # Ensure boolean
|
|
319
354
|
}
|
|
@@ -325,13 +360,13 @@ class CrafterEnvironmentWrapper:
|
|
|
325
360
|
size = 7
|
|
326
361
|
half = size // 2
|
|
327
362
|
patch = []
|
|
328
|
-
|
|
329
|
-
|
|
363
|
+
height = len(sem) if hasattr(sem, "__len__") else 0
|
|
364
|
+
width = len(sem[0]) if height and hasattr(sem[0], "__len__") else 0
|
|
330
365
|
for dy in range(-half, half + 1):
|
|
331
366
|
row = []
|
|
332
367
|
for dx in range(-half, half + 1):
|
|
333
368
|
x, y = px + dx, py + dy
|
|
334
|
-
if 0 <= x <
|
|
369
|
+
if 0 <= x < height and 0 <= y < width:
|
|
335
370
|
row.append(int(sem[x][y]))
|
|
336
371
|
else:
|
|
337
372
|
row.append(0)
|
|
@@ -341,10 +376,7 @@ class CrafterEnvironmentWrapper:
|
|
|
341
376
|
obs_out["semantic_map_patch7"] = patch
|
|
342
377
|
except Exception:
|
|
343
378
|
pass
|
|
344
|
-
if info is not None
|
|
345
|
-
result_info = convert_numpy_to_python(info)
|
|
346
|
-
else:
|
|
347
|
-
result_info = {}
|
|
379
|
+
result_info = convert_numpy_to_python(info) if info is not None else {}
|
|
348
380
|
# Attach achievements delta for downstream metrics if useful
|
|
349
381
|
if ach_added_latest is not None:
|
|
350
382
|
try:
|
|
@@ -404,9 +436,37 @@ class CrafterEnvironmentWrapper:
|
|
|
404
436
|
)
|
|
405
437
|
except Exception:
|
|
406
438
|
pass
|
|
439
|
+
|
|
407
440
|
return result
|
|
408
441
|
|
|
409
|
-
|
|
442
|
+
def _prepare_observation(self, observation: Any) -> dict[str, Any]:
|
|
443
|
+
"""Convert raw observation into a JSON-serializable dict with encoded image."""
|
|
444
|
+
|
|
445
|
+
obs_dict: dict[str, Any]
|
|
446
|
+
image_payload: dict[str, Any] | None = None
|
|
447
|
+
|
|
448
|
+
if isinstance(observation, dict):
|
|
449
|
+
image_payload = _encode_image_to_base64(observation.get("observation_image"))
|
|
450
|
+
# Work on a shallow copy to avoid mutating engine state
|
|
451
|
+
sanitized = dict(observation)
|
|
452
|
+
sanitized.pop("observation_image", None)
|
|
453
|
+
obs_dict = convert_numpy_to_python(sanitized) or {}
|
|
454
|
+
else:
|
|
455
|
+
obs_dict = convert_numpy_to_python(observation) or {}
|
|
456
|
+
|
|
457
|
+
if not isinstance(obs_dict, dict):
|
|
458
|
+
obs_dict = {"value": obs_dict}
|
|
459
|
+
|
|
460
|
+
if image_payload:
|
|
461
|
+
obs_dict["observation_image_base64"] = image_payload["data"]
|
|
462
|
+
obs_dict["observation_image_format"] = image_payload["format"]
|
|
463
|
+
obs_dict["observation_image_width"] = image_payload["width"]
|
|
464
|
+
obs_dict["observation_image_height"] = image_payload["height"]
|
|
465
|
+
obs_dict["observation_image_data_url"] = image_payload["data_url"]
|
|
466
|
+
|
|
467
|
+
return obs_dict
|
|
468
|
+
|
|
469
|
+
async def checkpoint(self) -> dict[str, Any]:
|
|
410
470
|
obs = await self.env.checkpoint()
|
|
411
471
|
observation = getattr(obs, "observation", obs)
|
|
412
472
|
info = getattr(obs, "info", None)
|
|
@@ -416,7 +476,7 @@ class CrafterEnvironmentWrapper:
|
|
|
416
476
|
"step_idx": self.step_idx,
|
|
417
477
|
}
|
|
418
478
|
|
|
419
|
-
async def terminate(self) ->
|
|
479
|
+
async def terminate(self) -> dict[str, Any]:
|
|
420
480
|
obs = await self.env.terminate()
|
|
421
481
|
observation = getattr(obs, "observation", obs)
|
|
422
482
|
info = getattr(obs, "info", None)
|
|
@@ -426,7 +486,7 @@ class CrafterEnvironmentWrapper:
|
|
|
426
486
|
"step_idx": self.step_idx,
|
|
427
487
|
}
|
|
428
488
|
|
|
429
|
-
def state_dict(self) ->
|
|
489
|
+
def state_dict(self) -> dict[str, Any]:
|
|
430
490
|
return {
|
|
431
491
|
"seed": self.seed,
|
|
432
492
|
"step_idx": self.step_idx,
|
|
@@ -434,13 +494,13 @@ class CrafterEnvironmentWrapper:
|
|
|
434
494
|
"last_info": self.last_info,
|
|
435
495
|
}
|
|
436
496
|
|
|
437
|
-
def load_state_dict(self, state:
|
|
497
|
+
def load_state_dict(self, state: dict[str, Any]) -> None:
|
|
438
498
|
self.seed = state["seed"]
|
|
439
499
|
self.step_idx = int(state["step_idx"])
|
|
440
500
|
self.last_observation = state["last_observation"]
|
|
441
501
|
self.last_info = state["last_info"]
|
|
442
502
|
|
|
443
|
-
async def serialize(self) ->
|
|
503
|
+
async def serialize(self) -> dict[str, Any]:
|
|
444
504
|
return {
|
|
445
505
|
"name": "crafter",
|
|
446
506
|
"config": {"seed": self.seed},
|
|
@@ -450,9 +510,9 @@ class CrafterEnvironmentWrapper:
|
|
|
450
510
|
@classmethod
|
|
451
511
|
async def deserialize(
|
|
452
512
|
cls,
|
|
453
|
-
payload:
|
|
513
|
+
payload: dict[str, Any],
|
|
454
514
|
env: StatefulEnvironment,
|
|
455
|
-
) ->
|
|
515
|
+
) -> CrafterEnvironmentWrapper:
|
|
456
516
|
seed = payload["config"]["seed"]
|
|
457
517
|
wrapper = cls(env=env, seed=seed)
|
|
458
518
|
wrapper.load_state_dict(payload["state"])
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from typing import Any, Dict, List, Optional, Tuple
|
|
4
3
|
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
5
6
|
from .react_agent import CrafterReActAgent
|
|
6
7
|
from .tools import TOOLS_SCHEMA
|
|
7
8
|
|
|
@@ -12,15 +13,15 @@ class Policy(ABC):
|
|
|
12
13
|
|
|
13
14
|
@abstractmethod
|
|
14
15
|
def prepare_inference_request(
|
|
15
|
-
self, observation:
|
|
16
|
-
) ->
|
|
16
|
+
self, observation: dict[str, Any], history: list[dict[str, Any]] = None
|
|
17
|
+
) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
|
|
17
18
|
"""Prepare an inference request."""
|
|
18
19
|
pass
|
|
19
20
|
|
|
20
21
|
@abstractmethod
|
|
21
22
|
def parse_model_response(
|
|
22
|
-
self, response: str, observation:
|
|
23
|
-
) ->
|
|
23
|
+
self, response: str, observation: dict[str, Any]
|
|
24
|
+
) -> list[dict[str, Any]]:
|
|
24
25
|
"""Parse model response into tool calls."""
|
|
25
26
|
pass
|
|
26
27
|
|
|
@@ -39,23 +40,23 @@ class CrafterPolicy(Policy):
|
|
|
39
40
|
|
|
40
41
|
name: str = "crafter-react"
|
|
41
42
|
|
|
42
|
-
def __init__(self, inference_url: str, model:
|
|
43
|
+
def __init__(self, inference_url: str, model: str | None = None) -> None:
|
|
43
44
|
self.inference_url = inference_url
|
|
44
45
|
self.model = model
|
|
45
46
|
self.use_tools = True
|
|
46
47
|
# Sampling parameters (populated via initialize(config))
|
|
47
|
-
self.temperature:
|
|
48
|
-
self.top_p:
|
|
49
|
-
self.max_tokens:
|
|
48
|
+
self.temperature: float | None = None
|
|
49
|
+
self.top_p: float | None = None
|
|
50
|
+
self.max_tokens: int | None = None
|
|
50
51
|
# Thinking controls (populated via initialize(config))
|
|
51
|
-
self.thinking_mode:
|
|
52
|
-
self.thinking_budget:
|
|
52
|
+
self.thinking_mode: str | None = None
|
|
53
|
+
self.thinking_budget: int | None = None
|
|
53
54
|
# Rolling conversation and action history for non-Markov policies
|
|
54
|
-
self.history_messages:
|
|
55
|
+
self.history_messages: list[dict[str, str]] = [] # chat-style without system
|
|
55
56
|
self.turn_index: int = 0
|
|
56
|
-
self.trajectory_history:
|
|
57
|
+
self.trajectory_history: list[dict[str, Any]] = [] # env/policy step records
|
|
57
58
|
|
|
58
|
-
async def initialize(self, config:
|
|
59
|
+
async def initialize(self, config: dict[str, Any]) -> None:
|
|
59
60
|
if "inference_url" in config:
|
|
60
61
|
self.inference_url = config["inference_url"]
|
|
61
62
|
if "model" in config:
|
|
@@ -91,15 +92,15 @@ class CrafterPolicy(Policy):
|
|
|
91
92
|
|
|
92
93
|
def _append_assistant_turn(
|
|
93
94
|
self,
|
|
94
|
-
assistant_text:
|
|
95
|
-
tool_calls:
|
|
96
|
-
env_result:
|
|
95
|
+
assistant_text: str | None,
|
|
96
|
+
tool_calls: list[dict[str, Any]] | None,
|
|
97
|
+
env_result: dict[str, Any] | None,
|
|
97
98
|
) -> None:
|
|
98
99
|
# Record assistant content (if any)
|
|
99
100
|
if assistant_text is not None:
|
|
100
101
|
self.history_messages.append({"role": "assistant", "content": assistant_text})
|
|
101
102
|
# Keep structured step record for training/analysis
|
|
102
|
-
record:
|
|
103
|
+
record: dict[str, Any] = {"turn": self.turn_index}
|
|
103
104
|
if tool_calls is not None:
|
|
104
105
|
record["tool_calls"] = tool_calls
|
|
105
106
|
if env_result is not None:
|
|
@@ -109,13 +110,17 @@ class CrafterPolicy(Policy):
|
|
|
109
110
|
def build_inference_request(
|
|
110
111
|
self,
|
|
111
112
|
observation_text: str,
|
|
112
|
-
history:
|
|
113
|
-
turn:
|
|
114
|
-
|
|
113
|
+
history: list[dict[str, Any]] | None = None,
|
|
114
|
+
turn: int | None = None,
|
|
115
|
+
image_parts: list[dict[str, Any]] | None = None,
|
|
116
|
+
) -> dict[str, Any]:
|
|
115
117
|
messages = CrafterReActAgent.build_messages(
|
|
116
|
-
observation=observation_text,
|
|
118
|
+
observation=observation_text,
|
|
119
|
+
history=history,
|
|
120
|
+
turn=turn,
|
|
121
|
+
image_parts=image_parts,
|
|
117
122
|
)
|
|
118
|
-
payload:
|
|
123
|
+
payload: dict[str, Any] = {
|
|
119
124
|
"messages": messages,
|
|
120
125
|
}
|
|
121
126
|
if self.model is not None:
|
|
@@ -150,9 +155,9 @@ class CrafterPolicy(Policy):
|
|
|
150
155
|
|
|
151
156
|
@staticmethod
|
|
152
157
|
def parse_response_to_tool_calls(
|
|
153
|
-
response:
|
|
158
|
+
response: dict[str, Any],
|
|
154
159
|
use_tools: bool = True,
|
|
155
|
-
) ->
|
|
160
|
+
) -> list[dict[str, Any]]:
|
|
156
161
|
"""Turn an inference response into environment tool calls.
|
|
157
162
|
|
|
158
163
|
- If tools were used, expect tool_calls-compatible output and forward as-is
|
|
@@ -162,7 +167,7 @@ class CrafterPolicy(Policy):
|
|
|
162
167
|
"""
|
|
163
168
|
# First check if we got actual tool calls
|
|
164
169
|
choices = response.get("choices", [])
|
|
165
|
-
tool_calls:
|
|
170
|
+
tool_calls: list[dict[str, Any]] = []
|
|
166
171
|
|
|
167
172
|
for choice in choices:
|
|
168
173
|
msg = choice.get("message", {})
|
|
@@ -192,7 +197,7 @@ class CrafterPolicy(Policy):
|
|
|
192
197
|
if tool_calls:
|
|
193
198
|
# Normalize common degenerate pattern ["move_right", "do"] when nothing is nearby.
|
|
194
199
|
# If previous env_result indicates no interaction target, drop trailing 'do'.
|
|
195
|
-
normalized:
|
|
200
|
+
normalized: list[dict[str, Any]] = []
|
|
196
201
|
for tc in tool_calls:
|
|
197
202
|
if tc and isinstance(tc, dict) and tc.get("tool_name") == "interact_many":
|
|
198
203
|
args = tc.get("arguments")
|
|
@@ -242,9 +247,9 @@ class CrafterPolicy(Policy):
|
|
|
242
247
|
async def step(
|
|
243
248
|
self,
|
|
244
249
|
observation_text: str,
|
|
245
|
-
state:
|
|
246
|
-
metadata:
|
|
247
|
-
) ->
|
|
250
|
+
state: dict[str, Any] | None = None,
|
|
251
|
+
metadata: dict[str, Any] | None = None,
|
|
252
|
+
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
|
248
253
|
"""Stateful step: update policy history and prepare inference request.
|
|
249
254
|
|
|
250
255
|
Inputs (via metadata, optional):
|
|
@@ -261,9 +266,9 @@ class CrafterPolicy(Policy):
|
|
|
261
266
|
"""
|
|
262
267
|
# If caller provided results from previous cycle, record them first
|
|
263
268
|
if metadata is not None:
|
|
264
|
-
prev_assistant_text:
|
|
265
|
-
prev_tool_calls:
|
|
266
|
-
prev_env_result:
|
|
269
|
+
prev_assistant_text: str | None = None
|
|
270
|
+
prev_tool_calls: list[dict[str, Any]] | None = None
|
|
271
|
+
prev_env_result: dict[str, Any] | None = None
|
|
267
272
|
if "prev_assistant_text" in metadata:
|
|
268
273
|
prev_assistant_text = metadata["prev_assistant_text"]
|
|
269
274
|
if "prev_tool_calls" in metadata:
|
|
@@ -283,7 +288,7 @@ class CrafterPolicy(Policy):
|
|
|
283
288
|
# Build user message by combining the current observation text
|
|
284
289
|
# (formatted surroundings/inventory) with the previous 3 tool calls as context.
|
|
285
290
|
# Most recent first.
|
|
286
|
-
lines:
|
|
291
|
+
lines: list[str] = []
|
|
287
292
|
|
|
288
293
|
def _format_tool_call_line_for_context(
|
|
289
294
|
tool_name: str, arguments: Any, max_chars: int = 500
|
|
@@ -291,7 +296,7 @@ class CrafterPolicy(Policy):
|
|
|
291
296
|
import json as _json
|
|
292
297
|
|
|
293
298
|
# Render arguments compactly, then clip to max_chars
|
|
294
|
-
if isinstance(arguments,
|
|
299
|
+
if isinstance(arguments, dict | list):
|
|
295
300
|
try:
|
|
296
301
|
rendered = _json.dumps(arguments, ensure_ascii=False, separators=(",", ":"))
|
|
297
302
|
except Exception:
|
|
@@ -321,7 +326,7 @@ class CrafterPolicy(Policy):
|
|
|
321
326
|
|
|
322
327
|
# If trajectory history is empty (first few turns), fall back to metadata once
|
|
323
328
|
if not lines and metadata is not None and metadata.get("prev_tool_calls"):
|
|
324
|
-
calls:
|
|
329
|
+
calls: list[dict[str, Any]] = metadata["prev_tool_calls"]
|
|
325
330
|
for call in reversed(calls):
|
|
326
331
|
if len(lines) >= 3:
|
|
327
332
|
break
|
|
@@ -338,10 +343,18 @@ class CrafterPolicy(Policy):
|
|
|
338
343
|
# Combine observation with context so the model always sees surroundings/inventory
|
|
339
344
|
combined_text = f"{observation_text}\n\n{context_text}"
|
|
340
345
|
|
|
346
|
+
raw_observation: dict[str, Any] | None = None
|
|
347
|
+
if metadata is not None:
|
|
348
|
+
raw_candidate = metadata.get("raw_observation")
|
|
349
|
+
if isinstance(raw_candidate, dict):
|
|
350
|
+
raw_observation = raw_candidate
|
|
351
|
+
image_parts = self._extract_image_parts(raw_observation)
|
|
352
|
+
|
|
341
353
|
payload = self.build_inference_request(
|
|
342
354
|
combined_text,
|
|
343
355
|
history=[], # no prior user/assistant history
|
|
344
356
|
turn=self.turn_index,
|
|
357
|
+
image_parts=image_parts,
|
|
345
358
|
)
|
|
346
359
|
# print("Debugging only:; ", payload)
|
|
347
360
|
meta_out = {
|
|
@@ -352,19 +365,19 @@ class CrafterPolicy(Policy):
|
|
|
352
365
|
}
|
|
353
366
|
return [], meta_out
|
|
354
367
|
|
|
355
|
-
def state_dict(self) ->
|
|
368
|
+
def state_dict(self) -> dict[str, Any]:
|
|
356
369
|
return {
|
|
357
370
|
"turn_index": self.turn_index,
|
|
358
371
|
"history_messages": self.history_messages,
|
|
359
372
|
"trajectory_history": self.trajectory_history,
|
|
360
373
|
}
|
|
361
374
|
|
|
362
|
-
def load_state_dict(self, state:
|
|
375
|
+
def load_state_dict(self, state: dict[str, Any]) -> None:
|
|
363
376
|
self.turn_index = int(state["turn_index"])
|
|
364
377
|
self.history_messages = state["history_messages"]
|
|
365
378
|
self.trajectory_history = state["trajectory_history"]
|
|
366
379
|
|
|
367
|
-
async def serialize(self) ->
|
|
380
|
+
async def serialize(self) -> dict[str, Any]:
|
|
368
381
|
return {
|
|
369
382
|
"name": self.name,
|
|
370
383
|
"config": {
|
|
@@ -376,7 +389,7 @@ class CrafterPolicy(Policy):
|
|
|
376
389
|
}
|
|
377
390
|
|
|
378
391
|
@classmethod
|
|
379
|
-
async def deserialize(cls, payload:
|
|
392
|
+
async def deserialize(cls, payload: dict[str, Any]) -> CrafterPolicy:
|
|
380
393
|
config = payload["config"]
|
|
381
394
|
state = payload["state"]
|
|
382
395
|
policy = cls(
|
|
@@ -391,22 +404,26 @@ class CrafterPolicy(Policy):
|
|
|
391
404
|
return None
|
|
392
405
|
|
|
393
406
|
def prepare_inference_request(
|
|
394
|
-
self, observation:
|
|
395
|
-
) ->
|
|
407
|
+
self, observation: dict[str, Any], history: list[dict[str, Any]] = None
|
|
408
|
+
) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
|
|
396
409
|
"""Prepare an inference request (implementing abstract method)."""
|
|
397
410
|
# Format observation with rich contextual information
|
|
398
411
|
observation_text = self._format_observation_for_llm(observation)
|
|
412
|
+
image_parts = self._extract_image_parts(observation)
|
|
399
413
|
|
|
400
414
|
# Build messages (observation_text already formatted; no raw matrices)
|
|
401
415
|
messages = CrafterReActAgent.build_messages(
|
|
402
|
-
observation=observation_text,
|
|
416
|
+
observation=observation_text,
|
|
417
|
+
history=history,
|
|
418
|
+
turn=self.turn_index,
|
|
419
|
+
image_parts=image_parts,
|
|
403
420
|
)
|
|
404
421
|
|
|
405
422
|
# Return messages and tools schema
|
|
406
423
|
tools = TOOLS_SCHEMA if self.use_tools else None
|
|
407
424
|
return messages, tools
|
|
408
425
|
|
|
409
|
-
def _format_observation_for_llm(self, observation:
|
|
426
|
+
def _format_observation_for_llm(self, observation: dict[str, Any]) -> str:
|
|
410
427
|
"""Format observation with rich contextual information for the LLM using the shared formatter."""
|
|
411
428
|
from .shared import format_observation
|
|
412
429
|
|
|
@@ -423,17 +440,22 @@ class CrafterPolicy(Policy):
|
|
|
423
440
|
|
|
424
441
|
# Get additional info from the observation wrapper
|
|
425
442
|
info = observation.get("info", {})
|
|
426
|
-
if isinstance(info, dict):
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
obs_data = dict(obs_data) # Make a copy
|
|
430
|
-
obs_data["health"] = info["health"]
|
|
443
|
+
if isinstance(info, dict) and "health" in info and "health" not in obs_data:
|
|
444
|
+
obs_data = dict(obs_data) # Make a copy
|
|
445
|
+
obs_data["health"] = info["health"]
|
|
431
446
|
|
|
432
447
|
return format_observation(obs_data, step_count=step_idx, max_steps=max_steps)
|
|
433
448
|
|
|
449
|
+
def _extract_image_parts(
|
|
450
|
+
self, observation: dict[str, Any] | None
|
|
451
|
+
) -> list[dict[str, Any]]:
|
|
452
|
+
"""Crafter policy uses text-only prompts; do not attach image parts."""
|
|
453
|
+
|
|
454
|
+
return []
|
|
455
|
+
|
|
434
456
|
def parse_model_response(
|
|
435
|
-
self, response: str, observation:
|
|
436
|
-
) ->
|
|
457
|
+
self, response: str, observation: dict[str, Any]
|
|
458
|
+
) -> list[dict[str, Any]]:
|
|
437
459
|
"""Parse model response into tool calls (implementing abstract method).
|
|
438
460
|
|
|
439
461
|
Note: Despite the type hint, vLLM actually returns a dict response,
|
|
@@ -7,7 +7,7 @@ utilities to keep a single parser.
|
|
|
7
7
|
|
|
8
8
|
from __future__ import annotations
|
|
9
9
|
|
|
10
|
-
from typing import
|
|
10
|
+
from typing import Any
|
|
11
11
|
|
|
12
12
|
from .shared import parse_actions
|
|
13
13
|
|
|
@@ -81,19 +81,27 @@ class CrafterReActAgent:
|
|
|
81
81
|
|
|
82
82
|
@staticmethod
|
|
83
83
|
def build_messages(
|
|
84
|
-
observation: str,
|
|
85
|
-
|
|
84
|
+
observation: str,
|
|
85
|
+
history: list[dict[str, Any]] | None = None,
|
|
86
|
+
turn: int | None = None,
|
|
87
|
+
image_parts: list[dict[str, Any]] | None = None,
|
|
88
|
+
) -> list[dict[str, Any]]:
|
|
86
89
|
"""Construct OpenAI-style messages list for vLLM generation."""
|
|
87
|
-
msgs:
|
|
90
|
+
msgs: list[dict[str, Any]] = [
|
|
88
91
|
{"role": "system", "content": CrafterReActAgent.get_system_prompt()}
|
|
89
92
|
]
|
|
90
93
|
if history:
|
|
91
94
|
msgs.extend(history)
|
|
92
|
-
|
|
95
|
+
user_content: Any
|
|
96
|
+
if image_parts:
|
|
97
|
+
user_content = [{"type": "text", "text": observation}] + list(image_parts)
|
|
98
|
+
else:
|
|
99
|
+
user_content = observation
|
|
100
|
+
msgs.append({"role": "user", "content": user_content})
|
|
93
101
|
return msgs
|
|
94
102
|
|
|
95
103
|
@staticmethod
|
|
96
|
-
def parse_actions_from_response(response_text: str) ->
|
|
104
|
+
def parse_actions_from_response(response_text: str) -> list[str]:
|
|
97
105
|
return parse_actions(response_text)
|
|
98
106
|
|
|
99
107
|
|