synth-ai 0.2.9.dev7__py3-none-any.whl → 0.2.9.dev9__py3-none-any.whl
This diff shows the changes between two package versions that have been publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
Potentially problematic release: this version of synth-ai has been flagged as possibly problematic.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +8 -11
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +64 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +18 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +21 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +38 -0
- examples/qwen_coder/validate_jsonl.py +59 -0
- examples/rl/run_eval.py +36 -37
- examples/rl/run_rl_and_save.py +5 -5
- examples/rl/task_app/math_single_step.py +65 -43
- examples/rl/task_app/math_task_app.py +3 -3
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +5 -5
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +78 -21
- examples/warming_up_to_rl/groq_test.py +4 -4
- examples/warming_up_to_rl/manage_secrets.py +13 -18
- examples/warming_up_to_rl/run_eval.py +42 -44
- examples/warming_up_to_rl/run_fft_and_save.py +11 -16
- examples/warming_up_to_rl/run_local_rollout.py +1 -3
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -4
- examples/warming_up_to_rl/run_local_rollout_parallel.py +1 -4
- examples/warming_up_to_rl/run_local_rollout_traced.py +3 -5
- examples/warming_up_to_rl/run_rl_and_save.py +5 -6
- examples/warming_up_to_rl/run_rollout_remote.py +8 -10
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +234 -35
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +2 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +131 -114
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +101 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +73 -51
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +14 -6
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +16 -16
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +32 -34
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +94 -31
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +303 -203
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +328 -225
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +13 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth/__init__.py +14 -0
- synth_ai/__init__.py +26 -4
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +128 -21
- synth_ai/api/train/cli.py +80 -64
- synth_ai/api/train/config_finder.py +7 -2
- synth_ai/api/train/env_resolver.py +1 -1
- synth_ai/api/train/pollers.py +2 -1
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +1 -2
- synth_ai/api/train/utils.py +13 -44
- synth_ai/cli/__init__.py +8 -0
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +1 -2
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +2 -1
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +2 -1
- synth_ai/cli/root.py +11 -13
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +529 -179
- synth_ai/cli/traces.py +6 -4
- synth_ai/cli/watch.py +12 -18
- synth_ai/demo_registry.py +1 -1
- synth_ai/demos/core/cli.py +36 -43
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +17 -25
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +3 -4
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -4
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +16 -18
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +2 -5
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +4 -7
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/handshake.py +9 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +18 -10
- synth_ai/inference/client.py +15 -5
- synth_ai/jobs/client.py +78 -83
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +91 -24
- synth_ai/learning/config.py +2 -38
- synth_ai/learning/ft_client.py +4 -59
- synth_ai/learning/health.py +5 -6
- synth_ai/learning/jobs.py +31 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -4
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -8
- synth_ai/{rl → learning/rl}/env_keys.py +39 -15
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -281
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -24
- synth_ai/learning/validators.py +25 -28
- synth_ai/lm/__init__.py +21 -47
- synth_ai/main.py +6 -0
- synth_ai/task/__init__.py +25 -27
- synth_ai/task/apps/__init__.py +7 -8
- synth_ai/task/auth.py +8 -8
- synth_ai/task/client.py +14 -14
- synth_ai/task/contracts.py +36 -35
- synth_ai/task/datasets.py +6 -5
- synth_ai/task/errors.py +10 -10
- synth_ai/task/health.py +17 -9
- synth_ai/task/json.py +58 -23
- synth_ai/task/proxy.py +13 -9
- synth_ai/task/rubrics.py +16 -15
- synth_ai/task/server.py +12 -12
- synth_ai/task/tracing_utils.py +4 -4
- synth_ai/task/vendors.py +5 -6
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/decorators.py +18 -16
- synth_ai/tracing_v3/hooks.py +5 -5
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/session_tracer.py +40 -14
- synth_ai/tracing_v3/storage/base.py +85 -0
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -7
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +2 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -4
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +2 -2
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/main.py +6 -6
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -14
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +2 -2
- synth_ai/{lm → v0/lm}/overrides.py +2 -2
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +9 -9
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +10 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +8 -8
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +1 -1
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- synth_ai-0.2.9.dev9.dist-info/METADATA +191 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/RECORD +268 -238
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/top_level.txt +1 -0
- examples/common_old/backend.py +0 -20
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1038
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -243
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -119
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -243
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -36
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -46
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1933
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -210
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -237
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -152
- examples/rl_old/task_app.py +0 -1131
- examples/warming_up_to_rl/old/event_rewards.md +0 -234
- examples/warming_up_to_rl/old/notes.md +0 -73
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +0 -738
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +0 -580
- synth_ai/experimental/synth_oss.py +0 -445
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -211
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -249
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -329
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -838
- synth_ai/zyk/__init__.py +0 -30
- synth_ai-0.2.9.dev7.dist-info/METADATA +0 -131
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/core/exceptions.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/licenses/LICENSE +0 -0
examples/rl_old/task_app.py
DELETED
@@ -1,1131 +0,0 @@
import modal
from typing import Any, Optional
from dataclasses import dataclass, field
import os as _os
import sys as _sys
from pathlib import Path as _Path
import time

# Make local 'crafter' importable when running locally
_HERE = _Path(__file__).resolve()
_LOCAL_CRAFTER_PARENT = _HERE.parent.parent  # points to examples/rl
if str(_LOCAL_CRAFTER_PARENT) not in _sys.path:
    _sys.path.insert(0, str(_LOCAL_CRAFTER_PARENT))
if "/opt" not in _sys.path:
    _sys.path.insert(0, "/opt")

# Use environment-aware names to avoid collisions across dev/prod
_env_flag = (
    (
        _os.getenv("SYNTH_BACKEND_URL_OVERRIDE", "")
        or _os.getenv("ENVIRONMENT", "")
        or _os.getenv("APP_ENVIRONMENT", "")
    )
    .strip()
    .lower()
)
_is_prod = _env_flag in ("prod", "production")

# Secret name must be provided explicitly via TASK_APP_SECRET_NAME
MODAL_SECRET_NAME = _os.getenv("TASK_APP_SECRET_NAME")
assert MODAL_SECRET_NAME, "TASK_APP_SECRET_NAME must be set before launching the task app"


# Modal app name (overridable via TASK_APP_NAME)
_default_app_name = "grpo-task-service-sdk-prod" if _is_prod else "grpo-task-service-sdk"
app = modal.App(_os.getenv("TASK_APP_NAME", _default_app_name))

image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install(
        [
            "fastapi",
            "uvicorn",
            "pydantic>=2",
            "httpx",
            "requests",
            "tqdm",
            "urllib3>=2.3.0",
            "jsonschema>=4.23.0",
            "typing_extensions>=4.0.0",
            "numpy",
            "pandas",
            "sqlalchemy",
            "aiosqlite",
            "asyncpg>=0.30.0",
            "crafter",
            "pillow",
            "imageio",
            "opensimplex",
            "ruamel.yaml",
            "networkx>=3.4.2",
            "redis>=6.2.0",
            "duckdb>=1.0.0",
            "ty>=0.0.1a5",
            "toml>=0.10.2",
            "libsql>=0.1.8",
            "python-dotenv",
            "anthropic",
            "openai",
            "diskcache",
            "backoff",
            "groq",
            "google-genai",
            "google-generativeai",
            "google-api-python-client",
            "google-api-core>=2.25.1",
            "google-auth",
            "google-auth-httplib2",
            "opentelemetry-api>=1.26.0,<1.27.0",
            "opentelemetry-sdk>=1.26.0,<1.27.0",
            "opentelemetry-exporter-otlp-proto-http>=1.26.0,<1.27.0",
            "wrapt",
            "langfuse>=2.53.9,<3.0.0",
            "together",
            "mistralai>=1.9.2",
            "click>=8.1.0",
            "textual>=1.1.0",
            "openai-harmony>=0.0.1",
            "aiohttp>=3.8.0",
            "datasets>=4.0.0",
            "gymnasium>=0.29.1",
            "minigrid>=2.3.1",
        ]
    )
    # Bundle the crafter module into the image for imports at runtime (absolute path)
    .add_local_dir(
        str((_HERE.parent / "crafter_task_app_helpers").resolve()), "/opt/crafter_task_app_helpers"
    )
    # Bundle synth_ai package to import full environment implementation.
    # Resolve repo root robustly (examples/rl/task_app.py -> repo_root = examples/rl/../../..)
    .add_local_dir(str((_HERE.parent.parent.parent / "synth_ai").resolve()), "/opt/synth_ai")
)

# --- OpenAI payload sanitizer (local) ---
OPENAI_MAX_COMPLETION_TOKENS_MIN = 16000
OPENAI_REMOVE_FIELDS = (
    "stop_after_tool_calls",
    "thinking_mode",
    "thinking_budget",
    "reasoning",
)
OPENAI_REMOVE_SAMPLING_FIELDS = ("temperature", "top_p")
OPENAI_TOOL_CHOICE_FORCED = {"type": "function", "function": {"name": "interact"}}


def prepare_inference_payload_for_model(
    model: str | None, payload: dict[str, Any]
) -> dict[str, Any]:
    """Sanitize payload for OpenAI API.

    - Always strip Synth-specific fields not supported by OpenAI (e.g., stop_after_tool_calls).
    - For gpt-5 family: map max_tokens->max_completion_tokens, enforce tool_choice, disable parallel tools,
      and remove vendor-specific sampling fields.
    """
    out = dict(payload)
    # Always remove unsupported fields for OpenAI
    for k in OPENAI_REMOVE_FIELDS:
        if k in out:
            out.pop(k)

    # gpt-5 family specific adjustments
    if model and "gpt-5" in model:
        if "max_completion_tokens" not in out and "max_tokens" in out:
            out["max_completion_tokens"] = out.pop("max_tokens")
        # Ensure we don't send both
        if "max_tokens" in out:
            out.pop("max_tokens")
        for k in OPENAI_REMOVE_SAMPLING_FIELDS:
            if k in out:
                out.pop(k)
        mct = out.get("max_completion_tokens")
        if not isinstance(mct, int) or mct < OPENAI_MAX_COMPLETION_TOKENS_MIN:
            out["max_completion_tokens"] = OPENAI_MAX_COMPLETION_TOKENS_MIN
        out["tool_choice"] = OPENAI_TOOL_CHOICE_FORCED
        out["parallel_tool_calls"] = False
    return out


@app.function(
    image=image,
    secrets=[modal.Secret.from_name(MODAL_SECRET_NAME)],
    min_containers=1,
    max_containers=1,
)
@modal.asgi_app()
def fastapi_app():
    # Import FastAPI/Pydantic inside the container runtime to avoid local import errors
    from fastapi import FastAPI, Body, HTTPException, status
    from starlette.requests import Request
    from fastapi.responses import JSONResponse
    from pydantic import BaseModel
    import logging
    import sys
    import os
    import httpx

    # Logger for debug output
    logger = logging.getLogger(__name__)

    # Preload synth_ai modules and vendor deps so missing packages surface early
    if "/opt/synth_ai" not in sys.path:
        sys.path.insert(0, "/opt/synth_ai")
    # Ensure tracing DB points to a writable location in the container
    os.environ.setdefault("TURSO_LOCAL_DB_URL", "sqlite+aiosqlite:////tmp/synth_ai.db")

    import importlib

    preload_modules = [
        # synth_ai core
        "synth_ai",
        "synth_ai.lm",
        "synth_ai.lm.core.main",
        "synth_ai.lm.core.main_v3",
        "synth_ai.lm.core.vendor_clients",
        "synth_ai.lm.core.all",
        # vendors
        "synth_ai.lm.vendors.core.anthropic_api",
        "synth_ai.lm.vendors.core.openai_api",
        "synth_ai.lm.vendors.openai_standard",
        "synth_ai.lm.vendors.core.gemini_api",
        # environments
        "synth_ai.environments",
        "synth_ai.environments.environment.rewards.core",
        "synth_ai.environments.examples.crafter_classic.environment",
        # tracing
        "synth_ai.tracing_v3.turso.models",
        "synth_ai.tracing_v3.turso.manager",
        # common 3p libs these modules rely on
        "anthropic",
        "openai",
        "groq",
        "google.genai",
        "google.generativeai",
        "googleapiclient.discovery",
        "google.auth",
        "google_auth_httplib2",
        "requests",
        "tqdm",
        "langfuse",
        "diskcache",
        "backoff",
        "together",
        "dotenv",
        "grpc",
    ]
    for mod in preload_modules:
        try:
            importlib.import_module(mod)
        except Exception as _e:
            print(f"[task:crafter] preload missing/err: {mod}: {_e}", flush=True)

    # Make packaged local crafter modules importable ahead of site-packages 'crafter'
    if "/opt/crafter_task_app_helpers" not in sys.path:
        sys.path.insert(0, "/opt/crafter_task_app_helpers")
    if "/opt" not in sys.path:
        sys.path.insert(0, "/opt")
    if "/opt/synth_ai" not in sys.path:
        sys.path.insert(0, "/opt/synth_ai")
    from crafter_task_app_helpers.env import EnvRegistry
    from crafter_task_app_helpers.rewards import compute_decision_rewards
    from crafter_task_app_helpers.config import ACTION_SPACE, ENV_NAME
    from crafter_task_app_helpers.policy import CrafterPolicy

    _registry = EnvRegistry()

    # --- JSON sanitization for responses (convert numpy -> python primitives, arrays -> shapes) ---
    import numpy as _np

    def _to_jsonable(value):
        # Numpy types first: scalars vs arrays
        if isinstance(value, (_np.generic,)):
            return value.item()
        if isinstance(value, _np.ndarray):
            return f"<ndarray shape={tuple(value.shape)} dtype={str(value.dtype)}>"
        # Basic containers
        if isinstance(value, dict):
            return {k: _to_jsonable(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [_to_jsonable(v) for v in value]
        # Sets to lists
        if isinstance(value, set):
            return [_to_jsonable(v) for v in value]
        return value

    class InitRequest(BaseModel):
        env_name: str | None = None
        env_config: dict[str, Any] | None = None

    class StepRequest(BaseModel):
        env_id: str
        action: str

    api = FastAPI(debug=True)

    # Basic root endpoints so HEAD/GET / succeeds for preflight checks
    @api.head("/")
    def head_root():  # type: ignore[empty-body]
        return JSONResponse(status_code=status.HTTP_200_OK, content=None)

    @api.get("/")
    def get_root():
        return {"ok": True, "service": "synth-ai task app"}

    @api.get("/health")
    def health(request: Request):
        env_key = os.environ.get("ENVIRONMENT_API_KEY")
        if not env_key:
            raise HTTPException(
                status_code=503,
                detail="Auth not configured: missing ENVIRONMENT_API_KEY in task service environment",
            )
        # Authorize using all header variants; avoid typed Header to prevent 422s
        try:
            from synth_ai.task.auth import is_api_key_header_authorized

            authorized = is_api_key_header_authorized(request)
        except Exception:
            # Fallback: check only x-api-key
            header_key = request.headers.get("x-api-key")
            authorized = bool(header_key) and (header_key == env_key)
        if not authorized:
            # Soft 200 with authorized flag so CLI preflight can proceed
            prefix = env_key[: max(1, len(env_key) // 2)]
            content = {"status": "healthy", "authorized": False, "expected_api_key_prefix": prefix}
            return JSONResponse(status_code=200, content=content)
        return {"healthy": True, "authorized": True}

    # Rollout health endpoint used by CLI configure flow
    @api.get("/health/rollout")
    def health_rollout(request: Request):
        expected = os.environ.get("ENVIRONMENT_API_KEY")
        if not expected:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Missing ENVIRONMENT_API_KEY in service env",
            )
        try:
            from synth_ai.task.auth import is_api_key_header_authorized

            authorized = is_api_key_header_authorized(request)
        except Exception:
            header_key = request.headers.get("x-api-key")
            authorized = bool(header_key) and (header_key == expected)
        if not authorized:
            prefix = expected[: max(1, len(expected) // 2)]
            content = {"status": "healthy", "authorized": False, "expected_api_key_prefix": prefix}
            return JSONResponse(status_code=200, content=content)
        return {"ok": True, "authorized": True}

    # Log and surface 422 validation errors with header presence
    from fastapi.exceptions import RequestValidationError

    @api.exception_handler(RequestValidationError)
    async def _on_validation_error(request: Request, exc: RequestValidationError):
        try:
            hdr = request.headers
            snapshot = {
                "path": str(getattr(request, "url").path),
                "have_x_api_key": bool(hdr.get("x-api-key")),
                "have_x_api_keys": bool(hdr.get("x-api-keys")),
                "have_authorization": bool(hdr.get("authorization")),
                "errors": exc.errors()[:5],
            }
            print("[422] validation", snapshot, flush=True)
        except Exception:
            pass
        return JSONResponse(
            status_code=422, content={"status": "invalid", "detail": exc.errors()[:5]}
        )

    @api.post(f"/env/{ENV_NAME}/initialize")
    async def initialize(req: InitRequest, request: Request):
        # Optionally tie the environment to a run_id header so we can guarantee isolation
        run_id_hdr = request.headers.get("X-Run-Id") or request.headers.get("X-Run-ID")
        env_id, obs = await _registry.initialize(req.env_config, run_id=run_id_hdr)
        return {"env_id": env_id, "observation": _to_jsonable(obs)}

    @api.post(f"/env/{ENV_NAME}/step")
    async def step(req: StepRequest):
        obs, reward, done, info = await _registry.step(req.env_id, req.action)
        return {
            "observation": _to_jsonable(obs),
            "reward": float(reward) if isinstance(reward, (int, float)) else reward,
            "done": bool(done),
            "info": _to_jsonable(info) if info is not None else None,
        }

    @api.post(f"/env/{ENV_NAME}/terminate")
    async def terminate(req: dict[str, str] = Body(...)):
        env_id = str(req.get("env_id"))
        return await _registry.terminate(env_id)

    @api.get("/actions")
    def actions():
        return {"actions": ACTION_SPACE}

    # OpenAI proxy: forward chat/completions to OpenAI using env OPENAI_API_KEY
    @api.post("/proxy/v1/chat/completions")
    def proxy_chat_completions(req: dict[str, Any]):
        openai_key = os.environ.get("OPENAI_API_KEY")
        if not openai_key:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Missing OPENAI_API_KEY in task service environment",
            )
        # Sanitize payload for OpenAI models (e.g., gpt-5-*)
        model = req.get("model")
        payload = prepare_inference_payload_for_model(model, req)
        headers = {"Authorization": f"Bearer {openai_key}"}
        # Increase timeout for proxy calls (models may be slower)
        with httpx.Client(timeout=120.0) as client:
            resp = client.post(
                "https://api.openai.com/v1/chat/completions", json=payload, headers=headers
            )
            try:
                data = resp.json()
            except Exception:
                data = {"error": "invalid_json", "raw": resp.text[:800]}
            if resp.status_code >= 400:
                return JSONResponse(status_code=resp.status_code, content=data)
            return data

    # Unified rollout schema imported from SDK task contracts
    from synth_ai.task.contracts import (
        RolloutEnvSpec,
        RolloutPolicySpec,
        RolloutRecordConfig,
        RolloutSafetyConfig,
        RolloutRequest,
        RolloutStep,
        RolloutTrajectory,
        RolloutMetrics,
        RolloutResponse,
    )

    @api.post("/rollout", response_model=RolloutResponse)
    async def rollout(req: RolloutRequest, request: Request):
        expected = os.environ.get("ENVIRONMENT_API_KEY")
        if not expected:
            logger.error("rollout.auth.misconfigured: missing ENVIRONMENT_API_KEY")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Auth not configured: missing ENVIRONMENT_API_KEY",
            )
        # Compute masked diagnostics (never log full keys)
        try:
            exp_len = len(expected)
            exp_suf = expected[-5:] if exp_len >= 5 else ""  # last 5 chars
            # Collect candidates from headers: X-API-Key, X-API-Keys (CSV), Authorization: Bearer
            hdr = request.headers
            single = hdr.get("x-api-key") or ""
            multi = [p.strip() for p in (hdr.get("x-api-keys") or "").split(",") if p.strip()]
            auth = hdr.get("authorization") or ""
            bearer = auth.split(" ", 1)[1].strip() if auth.lower().startswith("bearer ") else ""
            candidates = [c for c in [single, bearer, *multi] if c]
            # Assert server sees ALL keys sent by client
            if multi:
                logger.info(
                    "rollout.auth.candidates: n=%s first15=%s",
                    len(candidates),
                    [c[:15] for c in candidates],
                )
            got_len = len(single or bearer or "")
            got_suf = (single or bearer or "")[-5:] if got_len >= 5 else ""
        except Exception:
            exp_len = -1
            exp_suf = ""
            got_len = -1
            got_suf = ""
        # Authorize if ANY candidate matches expected
        authorized = any(c == expected for c in candidates)
        if not authorized:
            logger.warning(
                "rollout.auth.failed: have_any=%s expect_len=%s expect_last5=%s got_len=%s got_last5=%s",
                bool(candidates),
                exp_len,
                exp_suf,
                got_len,
                got_suf,
            )
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or missing API key"
            )
        else:
            logger.info(
                "rollout.auth.ok: expect_len=%s expect_last5=%s got_len=%s got_last5=%s",
                exp_len,
                exp_suf,
                got_len,
                got_suf,
            )

        # Extract policy config
        inference_url = req.policy.config["inference_url"]
        model = req.policy.config.get("model")
        max_steps = int(req.env.config.get("max_steps_per_episode", 10))
        policy = CrafterPolicy(inference_url=inference_url, model=model)

        # Debug: request summary
        print(
            "[task:crafter] ROLLOUT req: ",
            {
                "run_id": req.run_id,
                "env": req.env.env_name,
                "seed": req.env.seed,
                "ops": len(req.ops),
                "model": model,
                "inference_url": inference_url,
                "max_steps": max_steps,
            },
            flush=True,
        )

        # Initialize env (preemptively terminate any prior instance for this run to avoid sharing)
        cfg = dict(req.env.config or {})
        if req.env.seed is not None:
            cfg["seed"] = int(req.env.seed)
        env_id, observation = await _registry.initialize(cfg, run_id=req.run_id)

        trajectory_steps: list[RolloutStep] = []
        # Track per-decision achievement flips for stepwise shaping
        decision_summaries: list[dict[str, Any]] = []
        prev_ach: dict[str, bool] | None = None
        total_reward = 0.0
        ops_executed = 0
        pending_tool_calls: list[dict[str, Any]] | None = None
        try:
            for op in req.ops:
                if ops_executed >= req.safety.max_ops:
                    break
                if op == "agent":
                    # Format current observation for the prompt
                    # Cache for mapping semantic ids to names
                    _id_to_item_cache: list[str] | None = None

                    def _ensure_semantic_mapping() -> list[str] | None:
                        nonlocal _id_to_item_cache
                        if _id_to_item_cache is not None:
                            return _id_to_item_cache
                        # Build mapping using crafter's internal ids
                        import itertools as _it
                        import crafter as _crafter

                        dummy = None
                        try:
                            dummy = _crafter.Env()
                            max_id = (
                                max(
                                    max(dummy._world._mat_ids.values()),
                                    max(dummy._sem_view._obj_ids.values()),
                                )
                                + 1
                            )
                            id_to_item = ["void"] * max_id
                            for name, ind in _it.chain(
                                dummy._world._mat_ids.items(), dummy._sem_view._obj_ids.items()
                            ):
                                if name is None:
                                    clean = "none"
                                elif hasattr(name, "__name__"):
                                    clean = name.__name__
                                else:
                                    clean = str(name)
                                id_to_item[ind] = clean.lower()
                            _id_to_item_cache = id_to_item
                        finally:
                            if dummy is not None:
                                try:
                                    dummy.close()
                                except Exception:
                                    pass
                        return _id_to_item_cache

                    def _format_obs(obs: dict[str, Any]) -> str:
                        if not isinstance(obs, dict):
                            # Avoid dumping raw matrices; encourage exploration to gather context
                            return "no salient state; explore to gather context"
                        inv = obs.get("inventory") or {}
                        pos = obs.get("player_position")
                        steps = obs.get("num_steps_taken")
                        direction = obs.get("player_direction")
                        ach = obs.get("achievements_status") or {}
                        inv_lines = ", ".join(f"{k}:{v}" for k, v in inv.items() if v)
                        ach_on = [k for k, v in ach.items() if v]
                        lines = []
                        if pos is not None:
                            px, py = int(pos[0]), int(pos[1])
                            lines.append(f"position: (x={px}, y={py})")
                        if direction is not None:
                            dx, dy = int(direction[0]), int(direction[1])
                            dir_label = {
                                (1, 0): "→ east/right",
                                (-1, 0): "← west/left",
                                (0, 1): "↓ south/down",
                                (0, -1): "↑ north/up",
                                (0, 0): "• idle",
                            }.get((dx, dy), f"({dx},{dy})")
                            lines.append(f"direction: {dir_label}")
                        if steps is not None:
                            lines.append(f"steps: {int(steps)}")
                        if inv_lines:
                            lines.append(f"inventory: {inv_lines}")
                        if ach:
                            all_achievements = list(ach.keys())
                            lines.append(f"achievements_available: {', '.join(all_achievements)}")
                            lines.append(
                                f"achievements_unlocked: {', '.join(ach_on)}"
                                if ach_on
                                else "achievements_unlocked: "
                            )
                            lines.append(
                                f"achievements_progress: {len(ach_on)}/{len(all_achievements)}"
                            )
                        # Local surroundings (7x7) using semantic_map
                        smap = obs.get("semantic_map")
                        if smap is not None and pos is not None:
                            try:
                                px, py = int(pos[0]), int(pos[1])
                                view_size = 7
                                half = view_size // 2
                                id_to_item = _ensure_semantic_mapping() or []
                                grid_rows: list[str] = []
                                # Build matrix centered at player, then transpose for human-friendly view
                                matrix: list[list[str]] = []
                                for dy in range(-half, half + 1):
                                    row: list[str] = []
                                    for dx in range(-half, half + 1):
                                        x, y = px + dx, py + dy
                                        if not (0 <= x < smap.shape[0] and 0 <= y < smap.shape[1]):
                                            row.append("void")
                                        elif dx == 0 and dy == 0:
                                            row.append("player")
                                        else:
                                            idx = int(smap[x, y])
                                            name = (
                                                id_to_item[idx]
                                                if 0 <= idx < len(id_to_item)
                                                else str(idx)
                                            )
                                            row.append(name)
                                    matrix.append(row)
                                # Transpose to match visual orientation
                                transposed = list(zip(*matrix))
                                for row in transposed:
                                    grid_rows.append(" ".join(row))
                                if grid_rows:
                                    lines.append(f"Local Map View (7x7):\n" + "\n".join(grid_rows))
                            except Exception:
                                # If any issue occurs, skip map rendering without crashing
                                pass
                        if not lines:
                            lines.append("no salient state; explore to gather context")
                        return "\n".join(lines)

                    # Build compact context from last few tool calls (gpt-5-nano friendly)
                    lines: list[str] = []
                    for rec in reversed(trajectory_steps):
                        if len(lines) >= 3:
                            break
                        tcs = rec.tool_calls
                        if not tcs:
                            continue
                        tc0 = tcs[0] if isinstance(tcs, list) and tcs else None
                        if not isinstance(tc0, dict):
                            continue
                        name = tc0.get("tool_name") or tc0.get("name") or "unknown"
                        args = tc0.get("arguments")
                        lines.append(f"- {name}: {args}")
                    context_text = "Previous tool calls (most recent first):\n" + (
                        "\n".join(lines) if lines else "- none"
                    )
                    obs_text = _format_obs(observation)
                    combined_text = f"Current observation:\n{obs_text}\n\n{context_text}"
                    payload = policy.build_inference_request(
                        combined_text, history=[], turn=len(trajectory_steps)
                    )
                    # Debug: print the full prompt content in a stable labeled block for grepability
                    try:
                        print("PROMPT_DUMP_BEGIN")
                        print(combined_text)
                        print("PROMPT_DUMP_END")
                    except Exception:
                        pass
                    # Debug: print user prompt and achievements unlocked list
                    try:
                        _msgs = payload.get("messages", [])
                        _last_user = None
                        for _m in reversed(_msgs):
                            if isinstance(_m, dict) and _m.get("role") == "user":
                                _last_user = _m
                                break
                        if _last_user is not None:
                            _content = _last_user.get("content")
                            print("[task:crafter] user prompt:", _content, flush=True)
                    except Exception:
                        pass
                    try:
                        _ach = (
                            observation.get("achievements_status")
                            if isinstance(observation, dict)
                            else {}
                        )
                        _ach_on = [k for k, v in (_ach or {}).items() if v]
                        print(f"[task:crafter] achievements_unlocked: {_ach_on}", flush=True)
                    except Exception:
                        pass

                    # Prepare payload based on model family (OpenAI vs vLLM)
                    def _prepare_payload(p: dict, mdl: str | None) -> dict:
                        return prepare_inference_payload_for_model(mdl, p)

                    # Debug: payload shape
                    print(
                        "[task:crafter] inference payload: ",
                        {
                            "has_model": bool(payload.get("model") is not None),
                            "messages": payload.get("messages", []),
                            "tools": isinstance(payload.get("tools"), list),
                            "tool_choice": payload.get("tool_choice"),
                            "stop_after_tool_calls": payload.get("stop_after_tool_calls"),
                        },
                        flush=True,
                    )
                    headers: dict[str, str] = {}
                    _okey = os.environ.get("OPENAI_API_KEY")
                    # Configure granular timeouts for slow model/tool runs
                    _timeouts = httpx.Timeout(connect=10.0, read=180.0, write=60.0, pool=60.0)
                    with httpx.Client(timeout=_timeouts) as client:
                        # Decide endpoint: avoid calling our own /proxy inside the same request
                        _direct = "api.openai.com" in inference_url
                        if _direct:
                            # Call OpenAI directly
                            if _okey:
                                headers["Authorization"] = f"Bearer {_okey}"
                            to_send = _prepare_payload(payload, model)
                            endpoint_base = "https://api.openai.com"
                        else:
                            # Non-OpenAI inference endpoint
                            to_send = payload
                            endpoint_base = inference_url
                            # If targeting Synth proxy, attach backend auth
                            if "/proxy" in endpoint_base:
                                _skey = os.environ.get("SYNTH_API_KEY")
                                if _skey:
                                    headers["Authorization"] = f"Bearer {_skey}"

                        # Debug: outbound request diagnostics
                        try:
                            import json as _json

                            _size = len(_json.dumps(to_send))
                        except Exception:
                            _size = -1
                        print(
                            "[task:crafter] inference dispatch:",
                            {
                                "endpoint": f"{endpoint_base.rstrip('/')}/v1/chat/completions",
                                "direct_openai": bool(_direct),
                                "timeout": {
                                    "read": 180.0,
                                    "connect": 10.0,
                                    "write": 60.0,
                                    "pool": 60.0,
                                },
                                "payload_bytes": _size,
                                "has_auth": bool(headers.get("Authorization")),
                            },
                            flush=True,
                        )

                        _t0 = time.time()
                        try:
                            resp = client.post(
                                f"{endpoint_base.rstrip('/')}/v1/chat/completions",
                                json=to_send,
                                headers=headers,
                            )
                        except httpx.ReadTimeout as rte:
                            _elapsed = time.time() - _t0
                            print(
                                f"[task:crafter][timeout] read timeout after {_elapsed:.1f}s: {rte}",
                                flush=True,
                            )
                            raise
                        except Exception as re:
                            _elapsed = time.time() - _t0
                            print(
                                f"[task:crafter][error] request failed after {_elapsed:.1f}s: {type(re).__name__}: {re}",
                                flush=True,
                            )
                            raise
                        _elapsed = time.time() - _t0
                        print(
                            f"[task:crafter] inference status= {resp.status_code} elapsed={_elapsed:.2f}s",
                            flush=True,
                        )
                        # Emit a light-weight perf snapshot for visibility
                        try:
                            print(
                                "[metric] perf ",
                                "tok/s=n/a",
                                f"decision p50=n/a p95=n/a",
                                "roll n/a",
                                flush=True,
                            )
                        except Exception:
                            pass
                        # Debug: response status and body (on errors)
                        print("[task:crafter] inference status=", resp.status_code, flush=True)
                        if resp.status_code >= 400:
                            body_preview = resp.text[:800]
                            print("[task:crafter] inference error body:", body_preview, flush=True)
                        data = resp.json()
                    print(f"[task:crafter] inference response: {data}")
                    parsed = CrafterPolicy.parse_response_to_tool_calls(data, use_tools=True) or []
                    # Debug: parsed tool call summary
                    print(
                        "[task:crafter] parsed tool_calls: ",
                        {
                            "n": len(parsed),
                            "first": (parsed[0] if isinstance(parsed, list) and parsed else None),
                        },
                        flush=True,
                    )
                    # Print full tool call payloads for inspection
                    try:
                        import json as _json

                        for _i, _tc in enumerate(parsed):
                            try:
                                print(
                                    f"[task:crafter] tool_call[{_i}]:",
                                    _json.dumps(_tc, separators=(",", ":")),
                                    flush=True,
                                )
                            except Exception:
                                print(f"[task:crafter] tool_call[{_i}]: {_tc}", flush=True)
                    except Exception:
                        pass
                    if not parsed:
                        # Dump compact body preview to understand schema when no tools parsed
                        try:
                            import json as _json

                            preview = _json.dumps(data, separators=(",", ":"))
                            print(
                                "[task:crafter] body(no_tools) preview:", preview[:800], flush=True
                            )
                        except Exception:
                            pass
                        # Early terminate the episode to avoid hanging on empty tool calls
                        print("[task:crafter] NO_TOOL_CALLS: terminating episode early", flush=True)
                        break
                    pending_tool_calls = parsed
                    ops_executed += 1
                elif op == "env":
                    if not pending_tool_calls:
                        print("[task:crafter] no tool_calls; skipping env step", flush=True)
                        continue
                    info: dict[str, Any] | None = None
                    for tc in pending_tool_calls:
                        name = tc.get("tool_name")
                        if name == "interact":
                            # Parse the JSON arguments string
                            import json

                            args_str = tc.get("arguments", "{}")
                            try:
                                args_dict = json.loads(args_str)
                                actions = args_dict.get("actions", [])
                                reasoning = args_dict.get("reasoning", "")
                                print(f"[task:crafter] reasoning: {reasoning}", flush=True)
                            except (json.JSONDecodeError, TypeError):
                                print(
                                    f"[task:crafter] ERROR: Failed to parse arguments: {args_str}",
                                    flush=True,
                                )
                                actions = []
                                reasoning = "Parse error"

                            print(f"[task:crafter] env actions: {actions}", flush=True)
                            # Print a compact echo of the current prompt + tool call for easier triage
                            try:
                                import json as _json

                                print(
                                    "TOOLCALL_CONFIG:",
                                    _json.dumps(
                                        {
                                            "policy": req.policy.policy_name,
                                            "tools_present": True,
                                            "tool_choice": "required",
                                            "stop_after": 1,
                                        }
                                    ),
                                )
                            except Exception:
                                pass

                            # Execute each action individually
                            # Reset decision-level flip set for this decision
                            decision_flips: set[str] = set()
                            for act in actions:
                                observation, reward, done, _info = await _registry.step(env_id, act)
                                total_reward += float(reward)
                                # Debug: print step outcome (compact)
                                try:
                                    ok = (
                                        list(observation.keys())
                                        if isinstance(observation, dict)
                                        else []
                                    )
                                    print(
                                        f"[task:crafter] step => a={act} r={float(reward)} done={bool(done)} obs_keys={ok[:5]}",
                                        flush=True,
                                    )
                                except Exception:
                                    pass
                                step = RolloutStep(
                                    obs=observation,
                                    tool_calls=pending_tool_calls,
                                    reward=float(reward),
                                    done=bool(done),
                                    truncated=False,
                                    info=info,
                                )
                                trajectory_steps.append(step)
                                ops_executed += 1

                                # Check for achievement-based termination
                                if isinstance(observation, dict):
                                    current_achievements = observation.get(
                                        "achievements_status", {}
                                    )
                                    # Track flips 0→1 within this decision
                                    try:
                                        if not isinstance(current_achievements, dict):
                                            current_achievements = {}
                                        if prev_ach is None:
                                            prev_ach = {
                                                k: bool(v)
                                                for k, v in (current_achievements or {}).items()
                                            }
                                        else:
                                            for name, on in (current_achievements or {}).items():
                                                if bool(on) and not bool(prev_ach.get(name, False)):
                                                    decision_flips.add(str(name))
                                            # Update prev_ach to latest snapshot
                                            prev_ach = {
                                                k: bool(v)
                                                for k, v in (current_achievements or {}).items()
                                            }
                                    except Exception:
                                        pass
                                    achieved_count = sum(
                                        1 for v in current_achievements.values() if v
                                    )
                                    total_achievements = len(current_achievements)

                                    # Terminate if we've achieved a significant portion of available achievements
                                    if total_achievements > 0 and achieved_count >= max(
                                        3, total_achievements // 2
                                    ):
                                        print(
                                            f"[task:crafter] achievement_termination: {achieved_count}/{total_achievements} achievements reached",
                                            flush=True,
                                        )
                                        print(
                                            f"[task:crafter] achieved: {[k for k, v in current_achievements.items() if v]}",
                                            flush=True,
                                        )
                                        break

                                if done or len(trajectory_steps) >= max_steps:
                                    print(
                                        f"[task:crafter] episode_end: done={bool(done)} steps={len(trajectory_steps)} total_reward={total_reward}",
                                        flush=True,
                                    )
                                    break
                        elif name == "terminate":
                            # Handle termination
                            print("[task:crafter] Agent requested termination", flush=True)
                            break
                        else:
                            # Non-interact tool call: count as a step without env change
                            print("[task:crafter] non-interact tool_call:", name, flush=True)
                            step = RolloutStep(
                                obs=observation,
                                tool_calls=pending_tool_calls,
                                reward=None,
                                done=False,
                                truncated=False,
                                info=info,
                            )
                            trajectory_steps.append(step)
                            ops_executed += 1
                    # End of decision: record indicator_i for shaping
                    try:
                        indicator_i = 1 if decision_flips else 0
                        decision_summaries.append({"indicator_i": indicator_i})
                    except Exception:
                        pass
                    pending_tool_calls = None
                    if len(trajectory_steps) >= max_steps:
                        print(
                            f"[task:crafter] max_steps_reached: steps={len(trajectory_steps)} total_reward={total_reward}",
                            flush=True,
                        )
                        break
                else:
                    # Unknown op: skip
                    continue
                if len(trajectory_steps) >= max_steps:
                    break
        finally:
            await _registry.terminate(env_id)

        # Sanitize steps for JSON
        safe_steps = [
            RolloutStep(
                obs=_to_jsonable(s.obs),
                tool_calls=s.tool_calls,
                reward=float(s.reward) if s.reward is not None else None,
                done=bool(s.done),
                truncated=bool(s.truncated) if s.truncated is not None else None,
                info=_to_jsonable(s.info) if s.info is not None else None,
            )
            for s in trajectory_steps
        ]

        trajectory = RolloutTrajectory(
            env_id=env_id,
            policy_id=req.policy.policy_name or "crafter-policy",
            steps=safe_steps,
            final={"observation": _to_jsonable(observation)},
            length=len(safe_steps),
        )
        # Calculate achievements for this episode
        final_obs = observation
        if isinstance(final_obs, dict):
            final_achievements = final_obs.get("achievements_status", {})
        else:
            # Handle numpy array case - no achievements available
            final_achievements = {}
        total_achievements = sum(1 for v in final_achievements.values() if v)

        # Step-reward shaping: compute decision-level rewards if enabled
        branches: dict[str, Any] = {}
        try:
            sr_cfg = (
                (req.record.config or {}).get("step_rewards")
                if isinstance(req.record, RolloutRecordConfig)
                else None
            )
        except Exception:
            sr_cfg = None
        try:
            enabled = False
            mode = None
            step_beta = 0.0
            indicator_lambda = 0.0
            if isinstance(sr_cfg, dict):
                enabled = bool(sr_cfg.get("enabled", False))
                mode = (sr_cfg.get("mode") or "off").strip().lower()
                step_beta = float(sr_cfg.get("step_beta", 0.0))
                indicator_lambda = float(sr_cfg.get("indicator_lambda", 0.0))
            # Env overrides
            import os as _os2

            if _os2.getenv("STEP_BETA"):
                step_beta = float(_os2.getenv("STEP_BETA"))
            if _os2.getenv("STEP_LAMBDA"):
                indicator_lambda = float(_os2.getenv("STEP_LAMBDA"))
            if enabled and mode == "decision_stepwise" and decision_summaries:
                dec_rewards = compute_decision_rewards(
                    decision_summaries=decision_summaries,
                    total_achievements=total_achievements,
                    step_beta=step_beta,
                    indicator_lambda=indicator_lambda,
                )
                branches["decision_rewards"] = dec_rewards
                print(
                    "[task:crafter] step_rewards: ",
                    {
                        "enabled": True,
                        "mode": mode,
                        "step_beta": step_beta,
                        "indicator_lambda": indicator_lambda,
                        "decisions": len(dec_rewards),
                    },
                    flush=True,
                )
        except Exception as _e_sr:
            print(f"[task:crafter] step_rewards_error: {_e_sr}", flush=True)

        # Optional tracing of episode/rewards (gated)
        try:
            import os as _os3

            if _os3.getenv("TRACE_RL", "0") == "1":
                from synth_ai.tracing_v3.session_tracer import SessionTracer  # type: ignore

                tracer = SessionTracer()
                await tracer.initialize()
                meta = {
                    "env": req.env.env_name,
                    "policy": req.policy.policy_name,
                    "step_rewards": {
                        "enabled": bool(sr_cfg.get("enabled", False))
                        if isinstance(sr_cfg, dict)
                        else False,
                        "mode": (sr_cfg.get("mode") if isinstance(sr_cfg, dict) else None),
                    },
                }
                async with tracer.session(metadata=meta):
                    # Record episode outcome at end
                    await tracer.record_outcome_reward(
                        total_reward=int(total_reward),
                        achievements_count=int(total_achievements),
                        total_steps=int(len(trajectory_steps)),
                    )
        except Exception as _te:
            print(f"[task:crafter] tracing_error: {_te}", flush=True)

        metrics = RolloutMetrics(
            episode_returns=[total_reward],
            mean_return=float(total_achievements),
            num_steps=len(trajectory_steps),
            num_episodes=1,
        )
        # Debug: print reward and achievement metrics
        print(
            f"[task:crafter] Rollout metrics: total_reward={total_reward}, total_achievements={total_achievements}, mean_return={metrics.mean_return}, episode_returns={metrics.episode_returns}",
            flush=True,
        )
        return RolloutResponse(
            run_id=req.run_id,
            trajectories=[trajectory],
            branches=branches,
            metrics=metrics,
            aborted=False,
            ops_executed=ops_executed,
        )

    @api.get("/test_auth")
    def test_auth(request: Request):
        expected = os.environ.get("ENVIRONMENT_API_KEY")
        if not expected:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Missing ENVIRONMENT_API_KEY in service env",
            )
        header_key = request.headers.get("x-api-key") or request.headers.get("X-API-Key")
        ok = bool(header_key) and (header_key == expected)
        if not ok:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or missing API key"
            )
        return {"ok": True}

    return api