synth-ai 0.2.9.dev7__py3-none-any.whl → 0.2.9.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +8 -11
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +64 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +18 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +21 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +38 -0
- examples/qwen_coder/validate_jsonl.py +59 -0
- examples/rl/run_eval.py +36 -37
- examples/rl/run_rl_and_save.py +5 -5
- examples/rl/task_app/math_single_step.py +65 -43
- examples/rl/task_app/math_task_app.py +3 -3
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +5 -5
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +78 -21
- examples/warming_up_to_rl/groq_test.py +4 -4
- examples/warming_up_to_rl/manage_secrets.py +13 -18
- examples/warming_up_to_rl/run_eval.py +42 -44
- examples/warming_up_to_rl/run_fft_and_save.py +11 -16
- examples/warming_up_to_rl/run_local_rollout.py +1 -3
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -4
- examples/warming_up_to_rl/run_local_rollout_parallel.py +1 -4
- examples/warming_up_to_rl/run_local_rollout_traced.py +3 -5
- examples/warming_up_to_rl/run_rl_and_save.py +5 -6
- examples/warming_up_to_rl/run_rollout_remote.py +8 -10
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +234 -35
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +2 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +131 -114
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +101 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +73 -51
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +14 -6
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +16 -16
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +32 -34
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +94 -31
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +303 -203
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +328 -225
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +13 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth/__init__.py +14 -0
- synth_ai/__init__.py +26 -4
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +128 -21
- synth_ai/api/train/cli.py +80 -64
- synth_ai/api/train/config_finder.py +7 -2
- synth_ai/api/train/env_resolver.py +1 -1
- synth_ai/api/train/pollers.py +2 -1
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +1 -2
- synth_ai/api/train/utils.py +13 -44
- synth_ai/cli/__init__.py +8 -0
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +1 -2
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +2 -1
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +2 -1
- synth_ai/cli/root.py +11 -13
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +529 -179
- synth_ai/cli/traces.py +6 -4
- synth_ai/cli/watch.py +12 -18
- synth_ai/demo_registry.py +1 -1
- synth_ai/demos/core/cli.py +36 -43
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +17 -25
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +3 -4
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -4
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +16 -18
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +2 -5
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +4 -7
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/handshake.py +9 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +18 -10
- synth_ai/inference/client.py +15 -5
- synth_ai/jobs/client.py +78 -83
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +91 -24
- synth_ai/learning/config.py +2 -38
- synth_ai/learning/ft_client.py +4 -59
- synth_ai/learning/health.py +5 -6
- synth_ai/learning/jobs.py +31 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -4
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -8
- synth_ai/{rl → learning/rl}/env_keys.py +39 -15
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -281
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -24
- synth_ai/learning/validators.py +25 -28
- synth_ai/lm/__init__.py +21 -47
- synth_ai/main.py +6 -0
- synth_ai/task/__init__.py +25 -27
- synth_ai/task/apps/__init__.py +7 -8
- synth_ai/task/auth.py +8 -8
- synth_ai/task/client.py +14 -14
- synth_ai/task/contracts.py +36 -35
- synth_ai/task/datasets.py +6 -5
- synth_ai/task/errors.py +10 -10
- synth_ai/task/health.py +17 -9
- synth_ai/task/json.py +58 -23
- synth_ai/task/proxy.py +13 -9
- synth_ai/task/rubrics.py +16 -15
- synth_ai/task/server.py +12 -12
- synth_ai/task/tracing_utils.py +4 -4
- synth_ai/task/vendors.py +5 -6
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/decorators.py +18 -16
- synth_ai/tracing_v3/hooks.py +5 -5
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/session_tracer.py +40 -14
- synth_ai/tracing_v3/storage/base.py +85 -0
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -7
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +2 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -4
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +2 -2
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/main.py +6 -6
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -14
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +2 -2
- synth_ai/{lm → v0/lm}/overrides.py +2 -2
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +9 -9
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +10 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +8 -8
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +1 -1
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- synth_ai-0.2.9.dev9.dist-info/METADATA +191 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/RECORD +268 -238
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/top_level.txt +1 -0
- examples/common_old/backend.py +0 -20
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1038
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -243
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -119
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -243
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -36
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -46
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1933
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -210
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -237
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -152
- examples/rl_old/task_app.py +0 -1131
- examples/warming_up_to_rl/old/event_rewards.md +0 -234
- examples/warming_up_to_rl/old/notes.md +0 -73
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +0 -738
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +0 -580
- synth_ai/experimental/synth_oss.py +0 -445
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -211
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -249
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -329
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -838
- synth_ai/zyk/__init__.py +0 -30
- synth_ai-0.2.9.dev7.dist-info/METADATA +0 -131
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/core/exceptions.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,316 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Benchmark Crafter performance across prompt modalities (text-only, image-only, both).
|
|
4
|
+
|
|
5
|
+
For each mode we:
|
|
6
|
+
* Run 20 seeded episodes (configurable) with GPT-4o mini via OpenAI Chat Completions.
|
|
7
|
+
* Execute the returned tool calls in the local Crafter environment.
|
|
8
|
+
* Record achievements/steps and save every rendered frame under `examples/vlm/temp/`.
|
|
9
|
+
|
|
10
|
+
Concurrency is capped by an asyncio semaphore (default parallelism = 10).
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import argparse
|
|
16
|
+
import asyncio
|
|
17
|
+
import base64
|
|
18
|
+
import json
|
|
19
|
+
import os
|
|
20
|
+
from collections import Counter, defaultdict
|
|
21
|
+
from dataclasses import dataclass
|
|
22
|
+
from enum import Enum
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
from typing import Any
|
|
25
|
+
from uuid import uuid4
|
|
26
|
+
|
|
27
|
+
from examples.warming_up_to_rl.task_app.synth_envs_hosted.envs.crafter.environment import (
|
|
28
|
+
CrafterEnvironmentWrapper,
|
|
29
|
+
)
|
|
30
|
+
from examples.warming_up_to_rl.task_app.synth_envs_hosted.envs.crafter.policy import CrafterPolicy
|
|
31
|
+
from openai import AsyncOpenAI
|
|
32
|
+
from synth_ai.environments.examples.crafter_classic.environment import CrafterClassicEnvironment
|
|
33
|
+
from synth_ai.environments.examples.crafter_classic.taskset import (
|
|
34
|
+
CrafterTaskInstance,
|
|
35
|
+
CrafterTaskInstanceMetadata,
|
|
36
|
+
)
|
|
37
|
+
from synth_ai.environments.tasks.core import Impetus, Intent
|
|
38
|
+
|
|
39
|
+
OUTPUT_ROOT = Path("examples/vlm/temp")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class Mode(str, Enum):
    """Prompt modality for a benchmark episode: text-only, image-only, or both."""

    TEXT = "text"
    IMAGE = "image"
    BOTH = "both"
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class EpisodeResult:
    """Outcome of one Crafter episode run under a single prompt modality."""

    # Modality the episode was run with (text / image / both).
    mode: Mode
    # Environment seed used for reproducibility.
    seed: int
    # Number of environment steps actually executed before termination.
    steps_taken: int
    # Names of achievements unlocked at any point during the episode.
    achievements: set[str]
    # Sum of per-step rewards reported by the environment.
    total_reward: float
    # Total number of tool calls the model emitted across all steps.
    tool_calls: int
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _ensure_openai_client(api_key: str | None) -> AsyncOpenAI:
|
|
59
|
+
if not api_key:
|
|
60
|
+
raise RuntimeError(
|
|
61
|
+
"OPENAI_API_KEY must be set to run the VLM benchmark (export the key or add to your .env)."
|
|
62
|
+
)
|
|
63
|
+
return AsyncOpenAI(api_key=api_key)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _build_task_instance(seed: int) -> CrafterTaskInstance:
    """Create a reproducible Crafter task instance for the given *seed*."""
    task_metadata = CrafterTaskInstanceMetadata(
        difficulty="custom",
        seed=seed,
        num_trees_radius=0,
        num_cows_radius=0,
        num_hostiles_radius=0,
    )
    task = CrafterTaskInstance(
        id=uuid4(),
        impetus=Impetus(instructions="Explore, survive, and unlock achievements."),
        intent=Intent(rubric={"goal": "Unlock achievements"}, gold_trajectories=None, gold_state_diff={}),
        metadata=task_metadata,
        is_reproducible=True,
        initial_engine_snapshot=None,
    )
    # Engine expects these config keys
    task.config = {"seed": seed, "length": 256, "area": [64, 64]}
    return task
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _save_observation_frame(observation_packet: dict[str, Any], dest_path: Path) -> None:
|
|
90
|
+
obs = observation_packet.get("observation")
|
|
91
|
+
if not isinstance(obs, dict):
|
|
92
|
+
return
|
|
93
|
+
image_b64 = obs.get("observation_image_base64")
|
|
94
|
+
if not isinstance(image_b64, str) or not image_b64:
|
|
95
|
+
return
|
|
96
|
+
try:
|
|
97
|
+
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
|
98
|
+
dest_path.write_bytes(base64.b64decode(image_b64))
|
|
99
|
+
except Exception:
|
|
100
|
+
pass # best effort
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _strip_image_fields(observation_packet: dict[str, Any]) -> dict[str, Any]:
|
|
104
|
+
stripped = json.loads(json.dumps(observation_packet))
|
|
105
|
+
obs = stripped.get("observation")
|
|
106
|
+
if isinstance(obs, dict):
|
|
107
|
+
for key in list(obs.keys()):
|
|
108
|
+
if key.startswith("observation_image"):
|
|
109
|
+
obs.pop(key, None)
|
|
110
|
+
return stripped
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _make_image_only_request(request: dict[str, Any]) -> dict[str, Any]:
|
|
114
|
+
cloned = json.loads(json.dumps(request))
|
|
115
|
+
for message in cloned.get("messages", []):
|
|
116
|
+
if message.get("role") != "user":
|
|
117
|
+
continue
|
|
118
|
+
content = message.get("content")
|
|
119
|
+
if isinstance(content, list):
|
|
120
|
+
image_parts = [
|
|
121
|
+
item
|
|
122
|
+
for item in content
|
|
123
|
+
if isinstance(item, dict) and item.get("type") in {"image_url", "image"}
|
|
124
|
+
]
|
|
125
|
+
message["content"] = image_parts or content
|
|
126
|
+
elif isinstance(content, str):
|
|
127
|
+
# No structured parts available; leave as empty string
|
|
128
|
+
message["content"] = ""
|
|
129
|
+
return cloned
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
async def _run_episode(
    *,
    mode: Mode,
    seed: int,
    client: AsyncOpenAI,
    model: str,
    max_steps: int,
    temperature: float,
    semaphore: asyncio.Semaphore,
) -> EpisodeResult:
    """Run one Crafter episode under *mode*, driving the model via Chat Completions.

    The semaphore caps concurrent episodes (and thus concurrent OpenAI calls).
    Every rendered frame is saved under OUTPUT_ROOT/<mode>_frames/seed_XXXX/.

    Returns an EpisodeResult with step, reward, achievement and tool-call totals.
    """
    async with semaphore:
        task_instance = _build_task_instance(seed)
        env = CrafterClassicEnvironment(task_instance)
        wrapper = CrafterEnvironmentWrapper(env, seed=seed)

        # The inference URL is a placeholder; requests are issued directly via `client`.
        policy = CrafterPolicy(inference_url="openai://chat-completions", model=model)
        await policy.initialize({"use_tools": True, "model": model})

        observation_packet = await wrapper.initialize()
        achievements: set[str] = set()
        total_reward = 0.0
        steps_taken = 0
        tool_calls_total = 0

        frames_dir = OUTPUT_ROOT / f"{mode.value}_frames" / f"seed_{seed:04d}"
        _save_observation_frame(observation_packet, frames_dir / "step_000.png")

        for step_idx in range(max_steps):
            obs_dict = observation_packet.get("observation")
            if not isinstance(obs_dict, dict):
                break

            observation_for_policy: dict[str, Any]
            metadata_payload: dict[str, Any] = {}

            # TEXT mode drops image fields; IMAGE/BOTH keep the full packet and
            # also pass the raw observation through metadata for the policy.
            if mode == Mode.TEXT:
                observation_for_policy = _strip_image_fields(observation_packet)
            else:
                observation_for_policy = json.loads(json.dumps(observation_packet))
                metadata_payload["raw_observation"] = observation_packet

            obs_text = policy._format_observation_for_llm(observation_for_policy)  # noqa: SLF001
            _, meta = await policy.step(
                observation_text=obs_text,
                metadata=metadata_payload,
            )
            # Deep-copy so our edits below don't mutate the policy's request.
            inference_request = json.loads(json.dumps(meta["inference_request"]))

            if mode == Mode.IMAGE:
                inference_request = _make_image_only_request(inference_request)

            # Override sampling params and strip fields the OpenAI API rejects.
            inference_request.update(
                {
                    "model": model,
                    "temperature": temperature,
                    "max_tokens": inference_request.get("max_tokens", 512),
                }
            )
            inference_request.pop("stop_after_tool_calls", None)
            inference_request.pop("thinking_mode", None)
            inference_request.pop("thinking_budget", None)

            response = await client.chat.completions.create(**inference_request)
            response_dict = response.model_dump()

            assistant_tool_calls = CrafterPolicy.parse_response_to_tool_calls(
                response_dict,
                use_tools=policy.use_tools,
            )
            # No tool calls means the model produced no actionable step — stop early.
            if not assistant_tool_calls:
                break

            tool_calls_total += len(assistant_tool_calls)
            assistant_message = response_dict["choices"][0].get("message") or {}
            assistant_text = assistant_message.get("content")

            env_response = await wrapper.step(assistant_tool_calls)
            if not isinstance(env_response, dict):
                raise RuntimeError(f"Unexpected environment response type: {type(env_response)!r}")

            # Keep the policy's conversation history in sync with what was executed.
            policy._append_assistant_turn(  # noqa: SLF001
                assistant_text,
                assistant_tool_calls,
                env_response,
            )

            steps_taken += 1
            obs = env_response.get("observation")
            if isinstance(obs, dict):
                # achievements_status maps achievement name -> unlocked flag.
                ach = obs.get("achievements_status")
                if isinstance(ach, dict):
                    for name, unlocked in ach.items():
                        if unlocked:
                            achievements.add(str(name))
                reward = obs.get("reward_last_step")
                if isinstance(reward, (int, float)):
                    total_reward += float(reward)

            _save_observation_frame(env_response, frames_dir / f"step_{step_idx + 1:03d}.png")

            if env_response.get("done"):
                break
            observation_packet = env_response

        await wrapper.terminate()
        return EpisodeResult(
            mode=mode,
            seed=seed,
            steps_taken=steps_taken,
            achievements=achievements,
            total_reward=total_reward,
            tool_calls=tool_calls_total,
        )
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _summarise(results: list[EpisodeResult]) -> dict[str, Any]:
|
|
248
|
+
grouped: dict[Mode, list[EpisodeResult]] = defaultdict(list)
|
|
249
|
+
for result in results:
|
|
250
|
+
grouped[result.mode].append(result)
|
|
251
|
+
|
|
252
|
+
summary: dict[str, Any] = {}
|
|
253
|
+
for mode, mode_results in grouped.items():
|
|
254
|
+
if not mode_results:
|
|
255
|
+
continue
|
|
256
|
+
mean_steps = sum(r.steps_taken for r in mode_results) / len(mode_results)
|
|
257
|
+
mean_achievements = sum(len(r.achievements) for r in mode_results) / len(mode_results)
|
|
258
|
+
achievement_counts = Counter()
|
|
259
|
+
for res in mode_results:
|
|
260
|
+
achievement_counts.update(res.achievements)
|
|
261
|
+
summary[mode.value] = {
|
|
262
|
+
"episodes": len(mode_results),
|
|
263
|
+
"mean_steps": round(mean_steps, 2),
|
|
264
|
+
"mean_achievements": round(mean_achievements, 2),
|
|
265
|
+
"total_tool_calls": sum(r.tool_calls for r in mode_results),
|
|
266
|
+
"achievements": {name: count for name, count in sorted(achievement_counts.items())},
|
|
267
|
+
}
|
|
268
|
+
return summary
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
async def main() -> None:
    """CLI entry point: run every (mode, seed) episode concurrently and report results."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model", default="gpt-4o-mini-2024-07-18", help="OpenAI model id to benchmark")
    parser.add_argument("--seeds", type=int, default=20, help="Number of seeds per mode")
    parser.add_argument("--steps", type=int, default=10, help="Max steps per episode")
    parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature")
    parser.add_argument("--concurrency", type=int, default=10, help="Max concurrent OpenAI calls")
    args = parser.parse_args()

    client = _ensure_openai_client(os.getenv("OPENAI_API_KEY"))
    gate = asyncio.Semaphore(max(1, args.concurrency))

    OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)

    # One task per (mode, seed) pair; the semaphore inside _run_episode caps parallelism.
    pending: list[asyncio.Task[EpisodeResult]] = [
        asyncio.create_task(
            _run_episode(
                mode=mode,
                seed=seed,
                client=client,
                model=args.model,
                max_steps=args.steps,
                temperature=args.temperature,
                semaphore=gate,
            )
        )
        for mode in (Mode.TEXT, Mode.IMAGE, Mode.BOTH)
        for seed in range(args.seeds)
    ]

    summary = _summarise(await asyncio.gather(*pending))

    summary_path = OUTPUT_ROOT / "vlm_benchmark_summary.json"
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")

    print("\nBenchmark Summary")
    print("-----------------")
    print(json.dumps(summary, indent=2))
    print(f"\nFrames stored under: {OUTPUT_ROOT}/<mode>_frames/seed_xxxx/")
    print(f"Summary saved to: {summary_path}")
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
# Script entry point: drive the async benchmark with a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
|
|
@@ -9,7 +9,7 @@ import sqlite3
|
|
|
9
9
|
import sys
|
|
10
10
|
from collections import Counter, defaultdict
|
|
11
11
|
from pathlib import Path
|
|
12
|
-
from typing import Any
|
|
12
|
+
from typing import Any
|
|
13
13
|
|
|
14
14
|
Row = sqlite3.Row
|
|
15
15
|
|
|
@@ -56,7 +56,7 @@ def fetch_model_usage(conn: sqlite3.Connection) -> list[dict[str, Any]]:
|
|
|
56
56
|
def _parse_json(value: Any) -> Any:
|
|
57
57
|
if value is None:
|
|
58
58
|
return None
|
|
59
|
-
if isinstance(value,
|
|
59
|
+
if isinstance(value, dict | list):
|
|
60
60
|
return value
|
|
61
61
|
try:
|
|
62
62
|
return json.loads(value)
|
|
@@ -64,7 +64,7 @@ def _parse_json(value: Any) -> Any:
|
|
|
64
64
|
return None
|
|
65
65
|
|
|
66
66
|
|
|
67
|
-
AchievementMap = dict[
|
|
67
|
+
AchievementMap = dict[tuple[str, int], dict[str, list[str]]]
|
|
68
68
|
|
|
69
69
|
|
|
70
70
|
def fetch_achievement_data(
|
|
@@ -162,7 +162,7 @@ def fetch_achievement_data(
|
|
|
162
162
|
achievement_name_counts.update(achievement_set)
|
|
163
163
|
|
|
164
164
|
achievement_size_counts: Counter = Counter()
|
|
165
|
-
for
|
|
165
|
+
for _session_id, count in unique_counts_per_session.items():
|
|
166
166
|
achievement_size_counts[count] += 1
|
|
167
167
|
|
|
168
168
|
return (
|
|
@@ -295,7 +295,7 @@ def format_reward_summary(outcome: dict[str, Any], breakdown: list[dict[str, Any
|
|
|
295
295
|
|
|
296
296
|
|
|
297
297
|
def compute_model_achievement_stats(
|
|
298
|
-
conn: sqlite3.Connection, session_unique_sets: dict[str,
|
|
298
|
+
conn: sqlite3.Connection, session_unique_sets: dict[str, set[str]]
|
|
299
299
|
) -> dict[str, dict[str, Any]]:
|
|
300
300
|
"""Aggregate unique-achievement stats per model."""
|
|
301
301
|
|
|
@@ -42,9 +42,13 @@ base = "Qwen/Qwen3-4B"
|
|
|
42
42
|
label = "crafter-rl-from-base"
|
|
43
43
|
|
|
44
44
|
[rollout]
|
|
45
|
+
env_name = "crafter"
|
|
45
46
|
max_turns = 10
|
|
46
47
|
episodes_per_batch = 64
|
|
47
|
-
policy_name = "crafter"
|
|
48
|
+
policy_name = "crafter-react"
|
|
49
|
+
max_concurrent_rollouts = 8
|
|
50
|
+
batches_per_step = 2
|
|
51
|
+
ops = ["agent", "env"]
|
|
48
52
|
|
|
49
53
|
[evaluation]
|
|
50
54
|
# Run baseline evaluation over the first 100 seeds every 20 training iterations
|
|
@@ -55,6 +59,12 @@ seeds = [
|
|
|
55
59
|
]
|
|
56
60
|
|
|
57
61
|
[training]
|
|
62
|
+
num_epochs = 1
|
|
63
|
+
iterations_per_epoch = 10
|
|
64
|
+
batch_size = 16
|
|
65
|
+
group_size = 4
|
|
66
|
+
gradient_accumulation_steps = 1
|
|
67
|
+
learning_rate = 5e-5
|
|
58
68
|
log_interval = 1
|
|
59
69
|
weight_sync_interval = 1
|
|
60
70
|
# Additional RL hyperparameters can go here
|
|
@@ -8,8 +8,9 @@ import json
|
|
|
8
8
|
import sqlite3
|
|
9
9
|
import sys
|
|
10
10
|
from collections import Counter, defaultdict
|
|
11
|
+
from collections.abc import Iterable
|
|
11
12
|
from pathlib import Path
|
|
12
|
-
from typing import Any
|
|
13
|
+
from typing import Any
|
|
13
14
|
|
|
14
15
|
Row = sqlite3.Row
|
|
15
16
|
|
|
@@ -23,7 +24,7 @@ def connect(db_path: Path) -> sqlite3.Connection:
|
|
|
23
24
|
def _parse_json(value: Any) -> Any:
|
|
24
25
|
if value is None:
|
|
25
26
|
return None
|
|
26
|
-
if isinstance(value,
|
|
27
|
+
if isinstance(value, dict | list):
|
|
27
28
|
return value
|
|
28
29
|
try:
|
|
29
30
|
return json.loads(value)
|
|
@@ -31,7 +32,7 @@ def _parse_json(value: Any) -> Any:
|
|
|
31
32
|
return None
|
|
32
33
|
|
|
33
34
|
|
|
34
|
-
AchievementMap = dict[
|
|
35
|
+
AchievementMap = dict[tuple[str, int], dict[str, list[str]]]
|
|
35
36
|
|
|
36
37
|
|
|
37
38
|
def fetch_achievement_data(
|
|
@@ -116,7 +117,7 @@ def fetch_achievement_data(
|
|
|
116
117
|
achievement_name_counts.update(achievement_set)
|
|
117
118
|
|
|
118
119
|
achievement_size_counts: Counter = Counter()
|
|
119
|
-
for
|
|
120
|
+
for _session_id, count in unique_counts_per_session.items():
|
|
120
121
|
achievement_size_counts[count] += 1
|
|
121
122
|
|
|
122
123
|
return (
|
|
@@ -203,25 +204,71 @@ def parse_event_filters(specs: list[str] | None) -> list[tuple[str, float]]:
|
|
|
203
204
|
if min_val_str:
|
|
204
205
|
try:
|
|
205
206
|
min_val = float(min_val_str)
|
|
206
|
-
except ValueError:
|
|
207
|
+
except ValueError as e:
|
|
207
208
|
print(f"Invalid event reward specification '{spec}'", file=sys.stderr)
|
|
208
|
-
raise SystemExit(1)
|
|
209
|
+
raise SystemExit(1) from e
|
|
209
210
|
filters.append((reward_type, min_val))
|
|
210
211
|
return filters
|
|
211
212
|
|
|
212
213
|
|
|
213
|
-
def
|
|
214
|
-
|
|
214
|
+
def _collect_content(
|
|
215
|
+
parts: Iterable[dict[str, Any]] | None,
|
|
216
|
+
) -> tuple[Any, bool]:
|
|
217
|
+
"""Normalise multimodal content parts into OpenAI-style segments."""
|
|
218
|
+
|
|
215
219
|
if not parts:
|
|
216
|
-
return ""
|
|
220
|
+
return "", False
|
|
221
|
+
|
|
222
|
+
segments: list[dict[str, Any]] = []
|
|
223
|
+
has_image = False
|
|
224
|
+
|
|
217
225
|
for part in parts:
|
|
218
226
|
if not isinstance(part, dict):
|
|
219
227
|
continue
|
|
220
|
-
|
|
228
|
+
ptype = part.get("type")
|
|
229
|
+
if ptype == "text":
|
|
221
230
|
text = part.get("text")
|
|
222
|
-
if isinstance(text, str)
|
|
223
|
-
|
|
224
|
-
|
|
231
|
+
if isinstance(text, str):
|
|
232
|
+
segments.append({"type": "text", "text": text})
|
|
233
|
+
elif ptype == "image":
|
|
234
|
+
uri = part.get("uri")
|
|
235
|
+
mime_type = part.get("mime_type") or "image/png"
|
|
236
|
+
data_url = None
|
|
237
|
+
if isinstance(uri, str) and uri.startswith("data:"):
|
|
238
|
+
data_url = uri
|
|
239
|
+
else:
|
|
240
|
+
source = part.get("data") or part.get("source")
|
|
241
|
+
if isinstance(source, dict):
|
|
242
|
+
base64_data = source.get("data")
|
|
243
|
+
media_type = source.get("media_type") or mime_type
|
|
244
|
+
if isinstance(base64_data, str) and base64_data:
|
|
245
|
+
data_url = f"data:{media_type};base64,{base64_data}"
|
|
246
|
+
if data_url:
|
|
247
|
+
has_image = True
|
|
248
|
+
segments.append({"type": "image_url", "image_url": {"url": data_url}})
|
|
249
|
+
elif ptype == "image_url":
|
|
250
|
+
image_url = part.get("image_url", {})
|
|
251
|
+
if isinstance(image_url, dict):
|
|
252
|
+
url = image_url.get("url")
|
|
253
|
+
if isinstance(url, str) and url:
|
|
254
|
+
has_image = True
|
|
255
|
+
segments.append({"type": "image_url", "image_url": {"url": url}})
|
|
256
|
+
|
|
257
|
+
if not segments:
|
|
258
|
+
return "", False
|
|
259
|
+
if not has_image and len(segments) == 1 and segments[0]["type"] == "text":
|
|
260
|
+
return segments[0]["text"], False
|
|
261
|
+
return segments, has_image
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def _normalise_output_content(content: Any) -> tuple[Any, bool]:
|
|
265
|
+
if isinstance(content, list):
|
|
266
|
+
return _collect_content(content)
|
|
267
|
+
if isinstance(content, str):
|
|
268
|
+
return content, False
|
|
269
|
+
if content is None:
|
|
270
|
+
return "", False
|
|
271
|
+
return str(content), False
|
|
225
272
|
|
|
226
273
|
|
|
227
274
|
def _normalise_tool_calls(tool_calls: list[dict[str, Any]] | None) -> list[dict[str, Any]]:
|
|
@@ -251,7 +298,7 @@ def _normalise_tool_calls(tool_calls: list[dict[str, Any]] | None) -> list[dict[
|
|
|
251
298
|
except Exception:
|
|
252
299
|
args = raw
|
|
253
300
|
|
|
254
|
-
if isinstance(args,
|
|
301
|
+
if isinstance(args, dict | list):
|
|
255
302
|
args_str = json.dumps(args, ensure_ascii=False)
|
|
256
303
|
elif isinstance(args, str):
|
|
257
304
|
args_str = args
|
|
@@ -279,7 +326,7 @@ def _normalise_tool_calls(tool_calls: list[dict[str, Any]] | None) -> list[dict[
|
|
|
279
326
|
def build_sft_dataset(
|
|
280
327
|
conn: sqlite3.Connection,
|
|
281
328
|
achievements_map: AchievementMap,
|
|
282
|
-
sessions_filter:
|
|
329
|
+
sessions_filter: set[str],
|
|
283
330
|
*,
|
|
284
331
|
allowed_models: set[str] | None = None,
|
|
285
332
|
limit: int | None = None,
|
|
@@ -329,14 +376,18 @@ def build_sft_dataset(
|
|
|
329
376
|
|
|
330
377
|
for record in call_records:
|
|
331
378
|
messages: list[dict[str, Any]] = []
|
|
379
|
+
input_has_image = False
|
|
332
380
|
for message in record.get("input_messages", []):
|
|
333
381
|
role = message.get("role", "unknown")
|
|
334
|
-
content =
|
|
335
|
-
if not
|
|
382
|
+
content, has_image = _collect_content(message.get("parts"))
|
|
383
|
+
if (content == "" or content is None) and not has_image:
|
|
336
384
|
continue
|
|
385
|
+
if has_image and role == "user":
|
|
386
|
+
input_has_image = True
|
|
337
387
|
messages.append({"role": role, "content": content})
|
|
338
388
|
|
|
339
|
-
|
|
389
|
+
assistant_content_value: Any = ""
|
|
390
|
+
assistant_has_image = False
|
|
340
391
|
assistant_tool_calls: list[dict[str, Any]] = []
|
|
341
392
|
|
|
342
393
|
output_text = record.get("output_text")
|
|
@@ -351,7 +402,9 @@ def build_sft_dataset(
|
|
|
351
402
|
choices = parsed_response.get("choices") or []
|
|
352
403
|
if choices:
|
|
353
404
|
message = choices[0].get("message") or {}
|
|
354
|
-
|
|
405
|
+
assistant_content_value, assistant_has_image = _normalise_output_content(
|
|
406
|
+
message.get("content")
|
|
407
|
+
)
|
|
355
408
|
assistant_tool_calls = _normalise_tool_calls(message.get("tool_calls"))
|
|
356
409
|
|
|
357
410
|
if not assistant_tool_calls:
|
|
@@ -359,12 +412,13 @@ def build_sft_dataset(
|
|
|
359
412
|
|
|
360
413
|
assistant_message: dict[str, Any] = {
|
|
361
414
|
"role": "assistant",
|
|
362
|
-
"content":
|
|
415
|
+
"content": assistant_content_value,
|
|
363
416
|
}
|
|
364
417
|
if assistant_tool_calls:
|
|
365
418
|
assistant_message["tool_calls"] = assistant_tool_calls
|
|
366
419
|
|
|
367
|
-
|
|
420
|
+
content_empty = assistant_message.get("content") in ("", None)
|
|
421
|
+
if content_empty and not assistant_message.get("tool_calls"):
|
|
368
422
|
continue
|
|
369
423
|
|
|
370
424
|
messages.append(assistant_message)
|
|
@@ -385,6 +439,9 @@ def build_sft_dataset(
|
|
|
385
439
|
"turned_true": achievements.get("all", []),
|
|
386
440
|
"cumulative_unique": cumulative_unique[session_id],
|
|
387
441
|
},
|
|
442
|
+
"user_has_image": input_has_image,
|
|
443
|
+
"assistant_has_image": assistant_has_image,
|
|
444
|
+
"has_image": input_has_image or assistant_has_image,
|
|
388
445
|
}
|
|
389
446
|
|
|
390
447
|
dataset.append({"messages": messages, "metadata": metadata})
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
1
|
"""Quick smoke test that drives a rollout through the Groq proxy-backed Crafter Task App."""
|
|
4
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
5
|
import argparse
|
|
6
6
|
import asyncio
|
|
7
7
|
import os
|
|
@@ -29,8 +29,8 @@ def _build_policy_payload(seed: int, model: str) -> dict[str, Any]:
|
|
|
29
29
|
{
|
|
30
30
|
"role": "user",
|
|
31
31
|
"content": (
|
|
32
|
-
"Environment seed {seed}. Plan initial survival/crafting steps and then call interact with concrete actions."
|
|
33
|
-
)
|
|
32
|
+
f"Environment seed {seed}. Plan initial survival/crafting steps and then call interact with concrete actions."
|
|
33
|
+
),
|
|
34
34
|
},
|
|
35
35
|
],
|
|
36
36
|
}
|
|
@@ -8,11 +8,10 @@ import subprocess
|
|
|
8
8
|
import sys
|
|
9
9
|
import tempfile
|
|
10
10
|
from pathlib import Path
|
|
11
|
-
from typing import Dict, Tuple
|
|
12
11
|
|
|
13
12
|
|
|
14
|
-
def load_env_file(path: Path) ->
|
|
15
|
-
env:
|
|
13
|
+
def load_env_file(path: Path) -> dict[str, str]:
|
|
14
|
+
env: dict[str, str] = {}
|
|
16
15
|
if not path.exists():
|
|
17
16
|
raise FileNotFoundError(f".env not found at {path}")
|
|
18
17
|
for line in path.read_text(encoding="utf-8").splitlines():
|
|
@@ -24,7 +23,7 @@ def load_env_file(path: Path) -> Dict[str, str]:
|
|
|
24
23
|
return env
|
|
25
24
|
|
|
26
25
|
|
|
27
|
-
def write_temp_env(kv:
|
|
26
|
+
def write_temp_env(kv: dict[str, str]) -> Path:
|
|
28
27
|
fd, p = tempfile.mkstemp(prefix="modal_secret_", suffix=".env")
|
|
29
28
|
path = Path(p)
|
|
30
29
|
with os.fdopen(fd, "w", encoding="utf-8") as fh:
|
|
@@ -33,14 +32,14 @@ def write_temp_env(kv: Dict[str, str]) -> Path:
|
|
|
33
32
|
return path
|
|
34
33
|
|
|
35
34
|
|
|
36
|
-
def run(cmd: str) ->
|
|
35
|
+
def run(cmd: str) -> tuple[int, str]:
|
|
37
36
|
proc = subprocess.run(
|
|
38
37
|
cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
|
39
38
|
)
|
|
40
39
|
return proc.returncode, proc.stdout
|
|
41
40
|
|
|
42
41
|
|
|
43
|
-
def ensure_secret(secret_name: str, kv:
|
|
42
|
+
def ensure_secret(secret_name: str, kv: dict[str, str]) -> None:
|
|
44
43
|
if not kv:
|
|
45
44
|
print(f"[skip] {secret_name}: no values provided")
|
|
46
45
|
return
|
|
@@ -48,10 +47,10 @@ def ensure_secret(secret_name: str, kv: Dict[str, str]) -> None:
|
|
|
48
47
|
kv_args = " ".join([f"{shlex.quote(k)}={shlex.quote(v)}" for k, v in kv.items()])
|
|
49
48
|
|
|
50
49
|
# Try plain modal first; fallback to uv run modal
|
|
51
|
-
def _create() ->
|
|
50
|
+
def _create() -> tuple[int, str]:
|
|
52
51
|
return run(f"modal secret create {shlex.quote(secret_name)} {kv_args}")
|
|
53
52
|
|
|
54
|
-
def _delete() ->
|
|
53
|
+
def _delete() -> tuple[int, str]:
|
|
55
54
|
return run(f"printf 'y\n' | modal secret delete {shlex.quote(secret_name)}")
|
|
56
55
|
|
|
57
56
|
rc, out = _create()
|
|
@@ -86,15 +85,6 @@ def main() -> None:
|
|
|
86
85
|
env = load_env_file(Path(args.env_path))
|
|
87
86
|
|
|
88
87
|
# Secrets used by the task app
|
|
89
|
-
env_secret = {
|
|
90
|
-
k: v
|
|
91
|
-
for k, v in {
|
|
92
|
-
"ENVIRONMENT_API_KEY": env.get("ENVIRONMENT_API_KEY", ""),
|
|
93
|
-
"dev_environment_api_key": env.get("ENVIRONMENT_API_KEY", ""),
|
|
94
|
-
}.items()
|
|
95
|
-
if v
|
|
96
|
-
}
|
|
97
|
-
|
|
98
88
|
groq_secret = {
|
|
99
89
|
k: v
|
|
100
90
|
for k, v in {
|
|
@@ -118,7 +108,12 @@ def main() -> None:
|
|
|
118
108
|
{"SYNTH_API_KEY": env.get("SYNTH_API_KEY", "")} if env.get("SYNTH_API_KEY") else {}
|
|
119
109
|
)
|
|
120
110
|
|
|
121
|
-
|
|
111
|
+
env_key = env.get("ENVIRONMENT_API_KEY", "")
|
|
112
|
+
if env_key:
|
|
113
|
+
print(
|
|
114
|
+
"Skipping Modal secret 'crafter-environment-sdk'; the task app now expects "
|
|
115
|
+
"ENVIRONMENT_API_KEY via --env-file so the CLI-minted value stays in sync."
|
|
116
|
+
)
|
|
122
117
|
ensure_secret("groq-api-key", groq_secret)
|
|
123
118
|
ensure_secret("openai-api-key", openai_secret)
|
|
124
119
|
if synth_secret:
|