synth-ai 0.2.9.dev0__py3-none-any.whl → 0.2.23.dev3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
- examples/README.md +1 -0
- examples/__init__.py +16 -0
- examples/analyze_semantic_words.sh +17 -0
- examples/baseline/banking77_baseline.py +243 -0
- examples/baseline/banking77_pipeline_baseline.py +294 -0
- examples/baseline/crafter_baseline.py +407 -0
- examples/baseline/pokemon_red_baseline.py +326 -0
- examples/baseline/simple_baseline.py +56 -0
- examples/baseline/warming_up_to_rl_baseline.py +239 -0
- examples/blog_posts/gepa/README.md +355 -0
- examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
- examples/blog_posts/gepa/configs/banking77_gepa_test.toml +80 -0
- examples/blog_posts/gepa/configs/banking77_mipro_local.toml +50 -0
- examples/blog_posts/gepa/configs/banking77_pipeline_gepa_local.toml +101 -0
- examples/blog_posts/gepa/configs/banking77_pipeline_gepa_test.toml +96 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +57 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +35 -0
- examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +51 -0
- examples/blog_posts/gepa/configs/hover_gepa_local.toml +57 -0
- examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +35 -0
- examples/blog_posts/gepa/configs/hover_mipro_local.toml +51 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +57 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +35 -0
- examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +51 -0
- examples/blog_posts/gepa/configs/pupa_gepa_local.toml +58 -0
- examples/blog_posts/gepa/configs/pupa_mipro_local.toml +52 -0
- examples/blog_posts/gepa/deploy_banking77_task_app.sh +54 -0
- examples/blog_posts/gepa/gepa_baseline.py +204 -0
- examples/blog_posts/gepa/query_prompts_example.py +97 -0
- examples/blog_posts/gepa/run_gepa_banking77.sh +112 -0
- examples/blog_posts/gepa/run_gepa_banking77_pipeline.sh +163 -0
- examples/blog_posts/gepa/task_apps.py +105 -0
- examples/blog_posts/gepa/test_gepa_local.sh +67 -0
- examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
- examples/blog_posts/mipro/README.md +415 -0
- examples/blog_posts/mipro/configs/banking77_mipro_local.toml +91 -0
- examples/blog_posts/mipro/configs/banking77_mipro_test.toml +87 -0
- examples/blog_posts/mipro/configs/banking77_pipeline_mipro_gemini_flash_lite_local.toml +98 -0
- examples/blog_posts/mipro/configs/banking77_pipeline_mipro_gpt41mini_local.toml +96 -0
- examples/blog_posts/mipro/configs/banking77_pipeline_mipro_local.toml +94 -0
- examples/blog_posts/mipro/configs/banking77_pipeline_mipro_test.toml +170 -0
- examples/blog_posts/mipro/deploy_banking77_pipeline_task_app.sh +59 -0
- examples/blog_posts/mipro/deploy_banking77_task_app.sh +41 -0
- examples/blog_posts/mipro/multi_step.md +79 -0
- examples/blog_posts/mipro/run_mipro_banking77.sh +191 -0
- examples/blog_posts/mipro/run_mipro_banking77_pipeline.sh +171 -0
- examples/blog_posts/mipro/run_mipro_banking77_pipeline_gemini_flash_lite.sh +177 -0
- examples/blog_posts/mipro/run_mipro_banking77_pipeline_gpt41mini.sh +173 -0
- examples/blog_posts/mipro/verify_banking77_setup.sh +117 -0
- examples/blog_posts/pokemon_vl/README.md +98 -0
- examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +27 -0
- examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
- examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +43 -0
- examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
- examples/blog_posts/pokemon_vl/extract_images.py +239 -0
- examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
- examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
- examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
- examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
- examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
- examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
- examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
- examples/blog_posts/warming_up_to_rl/README.md +158 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
- examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +91 -0
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
- examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
- examples/crafter_debug_render.py +186 -0
- examples/dev/qwen3_32b_qlora_4xh100.toml +45 -0
- examples/gepa/banking77_pipeline_gepa.toml +96 -0
- examples/gepa/multi_stage_gepa_example.toml +84 -0
- examples/gepa/run_gepa_banking77_pipeline.sh +157 -0
- examples/multi_step/SFT_README.md +147 -0
- examples/multi_step/configs/README_verilog_rl.md +77 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +103 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +196 -0
- examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
- examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +75 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +145 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +84 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +79 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
- examples/multi_step/configs/crafter_sft_qwen30b_lora.toml +62 -0
- examples/multi_step/configs/crafter_synth_backend.md +40 -0
- examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
- examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
- examples/multi_step/configs/verilog_rl_lora.toml +147 -0
- examples/multi_step/convert_traces_to_sft.py +84 -0
- examples/multi_step/crafter_rl_lora.md +70 -0
- examples/multi_step/judges/crafter_backend_judge.py +220 -0
- examples/multi_step/judges/verilog_backend_judge.py +234 -0
- examples/multi_step/readme.md +48 -0
- examples/multi_step/run_sft_qwen30b.sh +45 -0
- examples/multi_step/sse_metrics_streaming_notes.md +357 -0
- examples/multi_step/task_app_config_notes.md +494 -0
- examples/multi_step/verilog_rl_lora.md +218 -0
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +60 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_small.toml +57 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +65 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +19 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +22 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +39 -0
- examples/qwen_coder/todos.md +38 -0
- examples/qwen_coder/validate_jsonl.py +60 -0
- examples/qwen_vl/BUGS_AND_FIXES.md +232 -0
- examples/qwen_vl/IMAGE_VALIDATION_COMPLETE.md +271 -0
- examples/qwen_vl/IMAGE_VALIDATION_SUMMARY.md +260 -0
- examples/qwen_vl/INFERENCE_SFT_TESTS.md +412 -0
- examples/qwen_vl/NEXT_STEPS_2B.md +325 -0
- examples/qwen_vl/QUICKSTART.md +327 -0
- examples/qwen_vl/QUICKSTART_RL_VISION.md +110 -0
- examples/qwen_vl/README.md +152 -0
- examples/qwen_vl/RL_VISION_COMPLETE.md +475 -0
- examples/qwen_vl/RL_VISION_TESTING.md +333 -0
- examples/qwen_vl/SDK_VISION_INTEGRATION.md +328 -0
- examples/qwen_vl/SETUP_COMPLETE.md +274 -0
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +489 -0
- examples/qwen_vl/VLM_PIPELINE_COMPLETE.md +242 -0
- examples/qwen_vl/__init__.py +2 -0
- examples/qwen_vl/collect_data_via_cli.md +415 -0
- examples/qwen_vl/collect_vision_traces.py +368 -0
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +110 -0
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +59 -0
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +26 -0
- examples/qwen_vl/configs/eval_gpt4o_vision_proper.toml +29 -0
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +26 -0
- examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
- examples/qwen_vl/configs/filter_qwen3vl_sft.toml +49 -0
- examples/qwen_vl/configs/filter_vision_sft.toml +52 -0
- examples/qwen_vl/configs/filter_vision_test.toml +8 -0
- examples/qwen_vl/configs/sft_qwen3_vl_2b_test.toml +54 -0
- examples/qwen_vl/crafter_gpt5nano_agent.py +308 -0
- examples/qwen_vl/crafter_qwen_vl_agent.py +300 -0
- examples/qwen_vl/run_vision_comparison.sh +61 -0
- examples/qwen_vl/run_vision_sft_pipeline.sh +175 -0
- examples/qwen_vl/test_image_validation.py +201 -0
- examples/qwen_vl/test_sft_vision_data.py +110 -0
- examples/rl/README.md +169 -0
- examples/rl/configs/eval_base_qwen.toml +17 -0
- examples/rl/configs/eval_rl_qwen.toml +13 -0
- examples/rl/configs/rl_from_base_qwen.toml +62 -0
- examples/rl/configs/rl_from_base_qwen17.toml +80 -0
- examples/rl/configs/rl_from_ft_qwen.toml +37 -0
- examples/rl/download_dataset.py +80 -0
- examples/rl/run_eval.py +436 -0
- examples/rl/run_rl_and_save.py +111 -0
- examples/rl/task_app/README.md +21 -0
- {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +188 -50
- examples/rl/task_app/math_task_app.py +111 -0
- examples/run_crafter_demo.sh +10 -0
- examples/sdk_prompt_learning_example.py +55 -0
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +49 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +49 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +120 -0
- examples/sft/generate_traces.py +164 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +135 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +604 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +124 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1191 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +584 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1094 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1905 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +136 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/swe/task_app/morph_backend.py +178 -0
- examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
- examples/task_apps/TESTING.md +275 -0
- examples/task_apps/banking77/__init__.py +6 -0
- examples/task_apps/banking77/banking77_task_app.py +912 -0
- examples/task_apps/banking77/deploy_wrapper.py +46 -0
- examples/task_apps/banking77_pipeline/__init__.py +6 -0
- examples/task_apps/banking77_pipeline/banking77_pipeline_task_app.py +489 -0
- examples/task_apps/banking77_pipeline/deploy_wrapper.py +50 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +286 -0
- examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +187 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +281 -0
- examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
- examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
- examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
- examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
- examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
- examples/task_apps/crafter/task_app/README.md +42 -0
- examples/task_apps/crafter/task_app/__init__.py +5 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +1055 -0
- examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +146 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/README.md +173 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/branching.py +143 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/environment_routes.py +1226 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/environment.py +532 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +583 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +122 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/shared.py +305 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +253 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +999 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/main.py +100 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +1252 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/registry.py +195 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +2233 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/storage/volume.py +211 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/test_agents.py +161 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/test_service.py +136 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +411 -0
- examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
- examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
- examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
- examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
- examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
- examples/task_apps/enron/__init__.py +2 -0
- examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
- examples/task_apps/enron/filter_sft.toml +5 -0
- examples/task_apps/enron/task_app/README.md +14 -0
- examples/task_apps/enron/task_app/__init__.py +1 -0
- examples/task_apps/enron/task_app/grpo_enron.py +906 -0
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
- examples/task_apps/enron/tests/__init__.py +4 -0
- examples/task_apps/enron/tests/conftest.py +115 -0
- examples/task_apps/enron/tests/integration/__init__.py +4 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +179 -0
- examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
- examples/task_apps/enron/tests/unit/__init__.py +4 -0
- examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
- examples/task_apps/gepa_benchmarks/__init__.py +7 -0
- examples/task_apps/gepa_benchmarks/common.py +260 -0
- examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
- examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
- examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
- examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
- examples/task_apps/math/README.md +21 -0
- examples/task_apps/math/math_single_step.py +1000 -0
- examples/task_apps/math/math_task_app.py +115 -0
- examples/task_apps/pokemon_battle/__init__.py +2 -0
- examples/task_apps/pokemon_battle/modal_app.py +104 -0
- examples/task_apps/pokemon_battle/task_app/README.md +68 -0
- examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
- examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
- examples/task_apps/pokemon_red/README.md +356 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +428 -0
- examples/task_apps/pokemon_red/__init__.py +3 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +30 -0
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +224 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +75 -0
- examples/task_apps/pokemon_red/task_app.py +1048 -0
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +193 -0
- examples/task_apps/sokoban/README.md +306 -0
- examples/task_apps/sokoban/__init__.py +3 -0
- examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
- examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
- examples/task_apps/sokoban/filter_sft.toml +5 -0
- examples/task_apps/sokoban/task_app.py +1058 -0
- examples/task_apps/sokoban/tests/__init__.py +4 -0
- examples/task_apps/sokoban/tests/conftest.py +113 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +4 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +4 -0
- examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
- examples/task_apps/verilog/__init__.py +1 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +22 -0
- examples/task_apps/verilog/filter_sft.toml +5 -0
- examples/task_apps/verilog/task_app/README.md +12 -0
- examples/task_apps/verilog/task_app/__init__.py +1 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +1166 -0
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
- examples/task_apps/verilog/tests/__init__.py +4 -0
- examples/task_apps/verilog/tests/conftest.py +115 -0
- examples/task_apps/verilog/tests/integration/__init__.py +4 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +181 -0
- examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
- examples/task_apps/verilog/tests/unit/__init__.py +4 -0
- examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
- examples/tunnel_gepa_banking77/README.md +106 -0
- examples/tunnel_gepa_banking77/banking77_gepa_tunnel.toml +95 -0
- examples/tunnel_gepa_banking77/keep_tunnel_running.py +60 -0
- examples/tunnel_gepa_banking77/run_gepa_with_tunnel.sh +226 -0
- examples/vlm/PROPOSAL.md +53 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +49 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +275 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/_utils.py +92 -0
- examples/warming_up_to_rl/analyze_trace_db.py +422 -0
- examples/warming_up_to_rl/configs/crafter_fft.toml +53 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +22 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +15 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +24 -0
- examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +35 -0
- examples/warming_up_to_rl/configs/eval_stepwise_consistent.toml +26 -0
- examples/warming_up_to_rl/configs/eval_stepwise_per_achievement.toml +36 -0
- examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +32 -0
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +85 -0
- examples/warming_up_to_rl/configs/rl_from_ft.toml +58 -0
- examples/warming_up_to_rl/export_trace_sft.py +837 -0
- examples/warming_up_to_rl/groq_test.py +97 -0
- examples/warming_up_to_rl/manage_secrets.py +131 -0
- examples/warming_up_to_rl/old/event_rewards.md +234 -0
- examples/warming_up_to_rl/old/notes.md +73 -0
- examples/warming_up_to_rl/readme.md +110 -0
- examples/warming_up_to_rl/run_eval.py +736 -0
- examples/warming_up_to_rl/run_fft_and_save.py +380 -0
- examples/warming_up_to_rl/run_local_rollout.py +239 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +248 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +405 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +477 -0
- examples/warming_up_to_rl/run_rl_and_save.py +124 -0
- examples/warming_up_to_rl/run_rollout_remote.py +156 -0
- examples/warming_up_to_rl/task_app/README.md +42 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +876 -0
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +135 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +143 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1226 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +522 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +454 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +108 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +305 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +253 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +729 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +100 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +1114 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +195 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1891 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +211 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +161 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +137 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +129 -0
- examples/workflows/math_rl/configs/eval_base_qwen.toml +15 -0
- examples/workflows/math_rl/configs/eval_rl_qwen.toml +11 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen.toml +62 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +80 -0
- examples/workflows/math_rl/configs/rl_from_ft_qwen.toml +35 -0
- examples/workflows/math_rl/download_dataset.py +80 -0
- examples/workflows/math_rl/run_eval.py +436 -0
- examples/workflows/math_rl/run_rl_and_save.py +111 -0
- synth_ai/__init__.py +47 -23
- synth_ai/_utils/__init__.py +47 -0
- synth_ai/_utils/base_url.py +10 -0
- synth_ai/_utils/http.py +10 -0
- synth_ai/_utils/prompts.py +10 -0
- synth_ai/_utils/task_app_state.py +12 -0
- synth_ai/_utils/user_config.py +10 -0
- synth_ai/api/models/supported.py +514 -0
- synth_ai/api/train/__init__.py +60 -2
- synth_ai/api/train/builders.py +347 -39
- synth_ai/api/train/cli.py +895 -160
- synth_ai/api/train/config_finder.py +103 -25
- synth_ai/api/train/configs/__init__.py +65 -0
- synth_ai/api/train/configs/prompt_learning.py +496 -0
- synth_ai/api/train/configs/rl.py +188 -0
- synth_ai/api/train/configs/sft.py +99 -0
- synth_ai/api/train/configs/shared.py +81 -0
- synth_ai/api/train/env_resolver.py +70 -20
- synth_ai/api/train/pollers.py +29 -4
- synth_ai/api/train/prompt_learning.py +425 -0
- synth_ai/api/train/sft.py +390 -0
- synth_ai/api/train/supported_algos.py +147 -0
- synth_ai/api/train/task_app.py +6 -4
- synth_ai/api/train/utils.py +64 -52
- synth_ai/api/train/validators.py +1117 -0
- synth_ai/api/tunnel.py +49 -0
- synth_ai/auth/credentials.py +94 -0
- synth_ai/baseline/__init__.py +25 -0
- synth_ai/baseline/config.py +209 -0
- synth_ai/baseline/discovery.py +214 -0
- synth_ai/baseline/execution.py +146 -0
- synth_ai/cfgs.py +227 -0
- synth_ai/cli/__init__.py +85 -63
- synth_ai/cli/_modal_wrapper.py +31 -0
- synth_ai/cli/_storage.py +20 -0
- synth_ai/cli/_typer_patch.py +47 -0
- synth_ai/cli/_validate_task_app.py +29 -0
- synth_ai/cli/balance.py +16 -4
- synth_ai/cli/calc.py +36 -21
- synth_ai/cli/claude.py +70 -0
- synth_ai/cli/codex.py +267 -0
- synth_ai/cli/commands/__init__.py +18 -0
- synth_ai/cli/commands/baseline/__init__.py +12 -0
- synth_ai/cli/commands/baseline/core.py +637 -0
- synth_ai/cli/commands/baseline/list.py +93 -0
- synth_ai/cli/commands/demo/__init__.py +6 -0
- synth_ai/cli/commands/demo/core.py +163 -0
- synth_ai/cli/commands/eval/__init__.py +19 -0
- synth_ai/cli/commands/eval/core.py +1112 -0
- synth_ai/cli/commands/eval/errors.py +81 -0
- synth_ai/cli/commands/eval/validation.py +133 -0
- synth_ai/cli/commands/filter/__init__.py +12 -0
- synth_ai/cli/commands/filter/core.py +424 -0
- synth_ai/cli/commands/filter/errors.py +55 -0
- synth_ai/cli/commands/filter/validation.py +77 -0
- synth_ai/cli/commands/help/__init__.py +185 -0
- synth_ai/cli/commands/help/core.py +72 -0
- synth_ai/cli/commands/smoke/__init__.py +7 -0
- synth_ai/cli/commands/smoke/core.py +1437 -0
- synth_ai/cli/commands/status/__init__.py +66 -0
- synth_ai/cli/commands/status/client.py +192 -0
- synth_ai/cli/commands/status/config.py +92 -0
- synth_ai/cli/commands/status/errors.py +20 -0
- synth_ai/cli/commands/status/formatters.py +164 -0
- synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
- synth_ai/cli/commands/status/subcommands/files.py +79 -0
- synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
- synth_ai/cli/commands/status/subcommands/models.py +79 -0
- synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
- synth_ai/cli/commands/status/subcommands/runs.py +81 -0
- synth_ai/cli/commands/status/subcommands/session.py +183 -0
- synth_ai/cli/commands/status/subcommands/summary.py +47 -0
- synth_ai/cli/commands/status/subcommands/usage.py +203 -0
- synth_ai/cli/commands/status/utils.py +114 -0
- synth_ai/cli/commands/train/__init__.py +53 -0
- synth_ai/cli/commands/train/core.py +21 -0
- synth_ai/cli/commands/train/errors.py +117 -0
- synth_ai/cli/commands/train/judge_schemas.py +200 -0
- synth_ai/cli/commands/train/judge_validation.py +305 -0
- synth_ai/cli/commands/train/validation.py +386 -0
- synth_ai/cli/demo.py +32 -140
- synth_ai/cli/deploy.py +233 -0
- synth_ai/cli/eval/__init__.py +36 -0
- synth_ai/cli/eval/core.py +5 -0
- synth_ai/cli/eval/errors.py +31 -0
- synth_ai/cli/eval/validation.py +5 -0
- synth_ai/cli/filter/__init__.py +28 -0
- synth_ai/cli/filter/core.py +5 -0
- synth_ai/cli/filter/errors.py +23 -0
- synth_ai/cli/filter/validation.py +5 -0
- synth_ai/cli/legacy_root_backup.py +28 -22
- synth_ai/cli/lib/__init__.py +10 -0
- synth_ai/cli/lib/task_app_discovery.py +7 -0
- synth_ai/cli/lib/task_app_env.py +518 -0
- synth_ai/cli/mcp.py +34 -0
- synth_ai/cli/modal_serve/__init__.py +12 -0
- synth_ai/cli/modal_serve/core.py +14 -0
- synth_ai/cli/modal_serve/errors.py +8 -0
- synth_ai/cli/modal_serve/validation.py +11 -0
- synth_ai/cli/opencode.py +256 -0
- synth_ai/cli/recent.py +13 -7
- synth_ai/cli/rl_demo.py +156 -116
- synth_ai/cli/root.py +131 -132
- synth_ai/cli/serve/__init__.py +12 -0
- synth_ai/cli/serve/core.py +14 -0
- synth_ai/cli/serve/errors.py +8 -0
- synth_ai/cli/serve/validation.py +11 -0
- synth_ai/cli/setup.py +49 -0
- synth_ai/cli/status.py +7 -125
- synth_ai/cli/task_app_deploy.py +7 -0
- synth_ai/cli/task_app_list.py +25 -0
- synth_ai/cli/task_app_modal_serve.py +11 -0
- synth_ai/cli/task_app_serve.py +11 -0
- synth_ai/cli/task_apps.py +2284 -257
- synth_ai/cli/traces.py +9 -5
- synth_ai/cli/train/__init__.py +12 -0
- synth_ai/cli/train/core.py +21 -0
- synth_ai/cli/train/errors.py +8 -0
- synth_ai/cli/train/validation.py +24 -0
- synth_ai/cli/train.py +5 -0
- synth_ai/cli/turso.py +73 -0
- synth_ai/cli/watch.py +13 -18
- synth_ai/demos/__init__.py +10 -0
- synth_ai/demos/core/__init__.py +28 -1
- synth_ai/demos/core/cli.py +579 -291
- synth_ai/demos/crafter/__init__.py +1 -0
- synth_ai/demos/crafter/crafter_fft_4b.toml +55 -0
- synth_ai/demos/crafter/grpo_crafter_task_app.py +185 -0
- synth_ai/demos/crafter/rl_from_base_qwen4b.toml +74 -0
- synth_ai/demos/demo_registry.py +176 -0
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +64 -28
- synth_ai/demos/demo_task_apps/crafter/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +53 -0
- synth_ai/demos/demo_task_apps/crafter/configs/rl_from_base_qwen4b.toml +73 -0
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +184 -0
- synth_ai/demos/demo_task_apps/math/_common.py +1 -2
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -6
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +185 -83
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -2
- synth_ai/demos/math/__init__.py +1 -0
- synth_ai/demos/math/_common.py +16 -0
- synth_ai/demos/math/app.py +38 -0
- synth_ai/demos/math/config.toml +76 -0
- synth_ai/demos/math/deploy_modal.py +54 -0
- synth_ai/demos/math/modal_task_app.py +703 -0
- synth_ai/demos/math/task_app_entry.py +51 -0
- synth_ai/environments/environment/core.py +7 -1
- synth_ai/environments/examples/bandit/engine.py +12 -5
- synth_ai/environments/examples/bandit/environment.py +0 -1
- synth_ai/environments/examples/bandit/taskset.py +4 -4
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
- synth_ai/environments/examples/crafter_classic/environment.py +93 -2
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
- synth_ai/environments/examples/enron/engine.py +7 -2
- synth_ai/environments/examples/enron/environment.py +68 -0
- synth_ai/environments/examples/red/engine.py +60 -12
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
- synth_ai/environments/examples/red/environment.py +86 -0
- synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
- synth_ai/environments/examples/sokoban/taskset.py +116 -0
- synth_ai/environments/examples/verilog/engine.py +104 -12
- synth_ai/environments/examples/wordle/environment.py +0 -1
- synth_ai/environments/reproducibility/tree.py +5 -6
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +10 -9
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/evals/__init__.py +15 -0
- synth_ai/evals/base.py +14 -5
- synth_ai/evals/client.py +82 -0
- synth_ai/evals/types.py +42 -0
- synth_ai/http.py +8 -22
- synth_ai/http_client.py +45 -12
- synth_ai/inference/__init__.py +0 -2
- synth_ai/inference/client.py +21 -7
- synth_ai/jobs/client.py +129 -80
- synth_ai/judge_schemas.py +127 -0
- synth_ai/learning/__init__.py +51 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +122 -30
- synth_ai/learning/config.py +2 -40
- synth_ai/learning/constants.py +0 -2
- synth_ai/learning/ft_client.py +4 -56
- synth_ai/learning/health.py +14 -8
- synth_ai/learning/jobs.py +43 -47
- synth_ai/learning/prompt_learning_client.py +276 -0
- synth_ai/learning/prompt_learning_types.py +185 -0
- synth_ai/{rl → learning/rl}/__init__.py +14 -5
- synth_ai/learning/rl/client.py +269 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -10
- synth_ai/{rl → learning/rl}/env_keys.py +45 -16
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -253
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +698 -0
- synth_ai/learning/sse.py +25 -26
- synth_ai/learning/validators.py +29 -25
- synth_ai/mcp/__init__.py +5 -0
- synth_ai/mcp/__main__.py +8 -0
- synth_ai/mcp/main.py +254 -0
- synth_ai/mcp/setup.py +100 -0
- synth_ai/modal.py +257 -0
- synth_ai/pricing/__init__.py +3 -0
- synth_ai/pricing/model_pricing.py +64 -0
- synth_ai/session/__init__.py +75 -0
- synth_ai/session/client.py +383 -0
- synth_ai/session/constants.py +63 -0
- synth_ai/session/exceptions.py +105 -0
- synth_ai/session/manager.py +139 -0
- synth_ai/session/models.py +89 -0
- synth_ai/session/query.py +110 -0
- synth_ai/spec/__init__.py +46 -0
- synth_ai/spec/dataclasses.py +149 -0
- synth_ai/spec/loader.py +144 -0
- synth_ai/spec/serializer.py +199 -0
- synth_ai/spec/validation.py +250 -0
- synth_ai/streaming/__init__.py +29 -0
- synth_ai/streaming/config.py +94 -0
- synth_ai/streaming/handlers.py +589 -0
- synth_ai/streaming/streamer.py +320 -0
- synth_ai/streaming/types.py +95 -0
- synth_ai/task/__init__.py +50 -30
- synth_ai/task/apps/__init__.py +63 -19
- synth_ai/task/auth.py +35 -23
- synth_ai/task/client.py +15 -13
- synth_ai/task/config.py +261 -0
- synth_ai/task/contracts.py +165 -64
- synth_ai/task/datasets.py +9 -6
- synth_ai/task/errors.py +11 -10
- synth_ai/task/health.py +17 -11
- synth_ai/task/inference_api.py +101 -0
- synth_ai/task/json.py +58 -24
- synth_ai/task/proxy.py +59 -66
- synth_ai/task/rubrics/__init__.py +55 -0
- synth_ai/task/rubrics/loaders.py +156 -0
- synth_ai/task/rubrics/models.py +57 -0
- synth_ai/task/rubrics/scoring.py +116 -0
- synth_ai/task/rubrics/strict.py +149 -0
- synth_ai/task/rubrics.py +22 -15
- synth_ai/task/server.py +65 -31
- synth_ai/task/trace_correlation_helpers.py +328 -0
- synth_ai/task/tracing_utils.py +44 -28
- synth_ai/task/validators.py +449 -6
- synth_ai/task/vendors.py +5 -7
- synth_ai/tracing_v3/__init__.py +4 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/config.py +167 -22
- synth_ai/tracing_v3/constants.py +21 -0
- synth_ai/tracing_v3/db_config.py +42 -29
- synth_ai/tracing_v3/decorators.py +80 -45
- synth_ai/tracing_v3/examples/basic_usage.py +15 -9
- synth_ai/tracing_v3/hooks.py +6 -4
- synth_ai/tracing_v3/llm_call_record_helpers.py +161 -61
- synth_ai/tracing_v3/migration_helper.py +1 -2
- synth_ai/tracing_v3/replica_sync.py +12 -7
- synth_ai/tracing_v3/serialization.py +130 -0
- synth_ai/tracing_v3/session_tracer.py +73 -16
- synth_ai/tracing_v3/storage/base.py +89 -1
- synth_ai/tracing_v3/storage/config.py +63 -16
- synth_ai/tracing_v3/storage/factory.py +11 -9
- synth_ai/tracing_v3/storage/utils.py +15 -11
- synth_ai/tracing_v3/trace_utils.py +317 -0
- synth_ai/tracing_v3/turso/__init__.py +8 -21
- synth_ai/tracing_v3/turso/daemon.py +123 -15
- synth_ai/tracing_v3/turso/models.py +5 -2
- synth_ai/tracing_v3/turso/native_manager.py +1293 -0
- synth_ai/tracing_v3/utils.py +5 -4
- synth_ai/tunnel.py +143 -0
- synth_ai/tunnel_deploy.py +278 -0
- synth_ai/types.py +8 -0
- synth_ai/urls.py +11 -0
- synth_ai/utils/__init__.py +166 -0
- synth_ai/utils/agents.py +74 -0
- synth_ai/utils/apps.py +152 -0
- synth_ai/utils/base_url.py +94 -0
- synth_ai/utils/bin.py +39 -0
- synth_ai/utils/claude.py +36 -0
- synth_ai/utils/cli.py +284 -0
- synth_ai/utils/config.py +81 -0
- synth_ai/utils/env.py +346 -0
- synth_ai/utils/errors.py +85 -0
- synth_ai/utils/http.py +172 -0
- synth_ai/utils/json.py +72 -0
- synth_ai/utils/log_filter.py +99 -0
- synth_ai/utils/logging.py +198 -0
- synth_ai/utils/modal.py +299 -0
- synth_ai/utils/paths.py +95 -0
- synth_ai/utils/process.py +233 -0
- synth_ai/utils/prompts.py +39 -0
- synth_ai/utils/sqld.py +122 -0
- synth_ai/utils/ssl.py +25 -0
- synth_ai/utils/task_app_discovery.py +882 -0
- synth_ai/utils/task_app_env.py +186 -0
- synth_ai/utils/task_app_state.py +318 -0
- synth_ai/utils/tunnel/__init__.py +12 -0
- synth_ai/utils/tunnel/config.py +55 -0
- synth_ai/utils/user_config.py +137 -0
- synth_ai/uvicorn.py +77 -0
- synth_ai-0.2.23.dev3.dist-info/METADATA +357 -0
- synth_ai-0.2.23.dev3.dist-info/RECORD +983 -0
- {synth_ai-0.2.9.dev0.dist-info → synth_ai-0.2.23.dev3.dist-info}/entry_points.txt +0 -1
- {synth_ai-0.2.9.dev0.dist-info → synth_ai-0.2.23.dev3.dist-info}/top_level.txt +1 -0
- synth_ai/cli/man.py +0 -106
- synth_ai/core/experiment.py +0 -15
- synth_ai/core/system.py +0 -15
- synth_ai/demo_registry.py +0 -258
- synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
- synth_ai/experimental/synth_oss.py +0 -446
- synth_ai/handshake.py +0 -107
- synth_ai/install_sqld.sh +0 -40
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -213
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -246
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -324
- synth_ai/lm/__init__.py +0 -51
- synth_ai/lm/caching/constants.py +0 -6
- synth_ai/lm/caching/dbs.py +0 -0
- synth_ai/lm/caching/ephemeral.py +0 -102
- synth_ai/lm/caching/handler.py +0 -137
- synth_ai/lm/caching/initialize.py +0 -11
- synth_ai/lm/caching/persistent.py +0 -114
- synth_ai/lm/config.py +0 -110
- synth_ai/lm/constants.py +0 -32
- synth_ai/lm/core/__init__.py +0 -8
- synth_ai/lm/core/all.py +0 -73
- synth_ai/lm/core/exceptions.py +0 -7
- synth_ai/lm/core/main.py +0 -319
- synth_ai/lm/core/main_v3.py +0 -594
- synth_ai/lm/core/synth_models.py +0 -48
- synth_ai/lm/core/vendor_clients.py +0 -188
- synth_ai/lm/cost/monitor.py +0 -1
- synth_ai/lm/cost/statefulness.py +0 -1
- synth_ai/lm/injection.py +0 -80
- synth_ai/lm/overrides.py +0 -206
- synth_ai/lm/provider_support/__init__.py +0 -8
- synth_ai/lm/provider_support/anthropic.py +0 -972
- synth_ai/lm/provider_support/openai.py +0 -1139
- synth_ai/lm/provider_support/suppress_logging.py +0 -31
- synth_ai/lm/structured_outputs/handler.py +0 -440
- synth_ai/lm/structured_outputs/inject.py +0 -297
- synth_ai/lm/structured_outputs/rehabilitate.py +0 -185
- synth_ai/lm/tools/__init__.py +0 -3
- synth_ai/lm/tools/base.py +0 -172
- synth_ai/lm/unified_interface.py +0 -202
- synth_ai/lm/vendors/base.py +0 -81
- synth_ai/lm/vendors/core/anthropic_api.py +0 -387
- synth_ai/lm/vendors/core/gemini_api.py +0 -292
- synth_ai/lm/vendors/core/mistral_api.py +0 -322
- synth_ai/lm/vendors/core/openai_api.py +0 -225
- synth_ai/lm/vendors/core/synth_dev_api.py +0 -0
- synth_ai/lm/vendors/local/ollama.py +0 -0
- synth_ai/lm/vendors/openai_standard.py +0 -780
- synth_ai/lm/vendors/openai_standard_responses.py +0 -256
- synth_ai/lm/vendors/retries.py +0 -22
- synth_ai/lm/vendors/supported/custom_endpoint.py +0 -417
- synth_ai/lm/vendors/supported/deepseek.py +0 -69
- synth_ai/lm/vendors/supported/grok.py +0 -75
- synth_ai/lm/vendors/supported/groq.py +0 -16
- synth_ai/lm/vendors/supported/ollama.py +0 -15
- synth_ai/lm/vendors/supported/openrouter.py +0 -74
- synth_ai/lm/vendors/supported/together.py +0 -11
- synth_ai/lm/vendors/synth_client.py +0 -808
- synth_ai/lm/warmup.py +0 -186
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/task/apps/grpo_crafter.py +0 -438
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/manager.py +0 -774
- synth_ai/v0/tracing/abstractions.py +0 -224
- synth_ai/v0/tracing/base_client.py +0 -91
- synth_ai/v0/tracing/client_manager.py +0 -131
- synth_ai/v0/tracing/config.py +0 -142
- synth_ai/v0/tracing/context.py +0 -146
- synth_ai/v0/tracing/decorators.py +0 -682
- synth_ai/v0/tracing/events/__init__.py +0 -0
- synth_ai/v0/tracing/events/manage.py +0 -147
- synth_ai/v0/tracing/events/scope.py +0 -86
- synth_ai/v0/tracing/events/store.py +0 -228
- synth_ai/v0/tracing/immediate_client.py +0 -151
- synth_ai/v0/tracing/local.py +0 -18
- synth_ai/v0/tracing/log_client_base.py +0 -73
- synth_ai/v0/tracing/retry_queue.py +0 -186
- synth_ai/v0/tracing/trackers.py +0 -515
- synth_ai/v0/tracing/upload.py +0 -512
- synth_ai/v0/tracing/utils.py +0 -9
- synth_ai/v0/tracing_v1/__init__.py +0 -16
- synth_ai/v0/tracing_v1/abstractions.py +0 -224
- synth_ai/v0/tracing_v1/base_client.py +0 -91
- synth_ai/v0/tracing_v1/client_manager.py +0 -131
- synth_ai/v0/tracing_v1/config.py +0 -142
- synth_ai/v0/tracing_v1/context.py +0 -146
- synth_ai/v0/tracing_v1/decorators.py +0 -703
- synth_ai/v0/tracing_v1/events/__init__.py +0 -0
- synth_ai/v0/tracing_v1/events/manage.py +0 -147
- synth_ai/v0/tracing_v1/events/scope.py +0 -86
- synth_ai/v0/tracing_v1/events/store.py +0 -228
- synth_ai/v0/tracing_v1/immediate_client.py +0 -151
- synth_ai/v0/tracing_v1/local.py +0 -18
- synth_ai/v0/tracing_v1/log_client_base.py +0 -73
- synth_ai/v0/tracing_v1/retry_queue.py +0 -186
- synth_ai/v0/tracing_v1/trackers.py +0 -515
- synth_ai/v0/tracing_v1/upload.py +0 -527
- synth_ai/v0/tracing_v1/utils.py +0 -9
- synth_ai/zyk/__init__.py +0 -30
- synth_ai-0.2.9.dev0.dist-info/METADATA +0 -131
- synth_ai-0.2.9.dev0.dist-info/RECORD +0 -444
- {synth_ai/lm/caching → examples/task_apps}/__init__.py +0 -0
- {synth_ai/lm/cost → examples/task_apps/crafter}/__init__.py +0 -0
- {synth_ai/lm/structured_outputs → examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server}/__init__.py +0 -0
- {synth_ai/lm/vendors → examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests}/__init__.py +0 -0
- {synth_ai/lm/vendors/core → examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils}/__init__.py +0 -0
- {synth_ai/lm/vendors/local → examples/task_apps/math}/__init__.py +0 -0
- {synth_ai/lm/vendors/supported → examples/workflows}/__init__.py +0 -0
- {synth_ai/v0/tracing → examples/workflows/math_rl}/__init__.py +0 -0
- /synth_ai/{compound/cais.py → cli/__main__.py} +0 -0
- /synth_ai/{learning/filtering.py → py.typed} +0 -0
- {synth_ai-0.2.9.dev0.dist-info → synth_ai-0.2.23.dev3.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev0.dist-info → synth_ai-0.2.23.dev3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,2233 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import time as _time
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from typing import Any, Mapping
|
|
10
|
+
|
|
11
|
+
from fastapi import APIRouter, HTTPException, Request, status
|
|
12
|
+
from pydantic import BaseModel, Field
|
|
13
|
+
from synth_ai.tracing_v3 import BaseLMResponse
|
|
14
|
+
from synth_ai.task.tracing_utils import unique_sft_path
|
|
15
|
+
from synth_ai.tracing_v3.abstractions import EnvironmentEvent, LMCAISEvent, TimeRecord
|
|
16
|
+
from synth_ai.task.contracts import RolloutMode
|
|
17
|
+
from synth_ai.tracing_v3.llm_call_record_helpers import create_llm_call_record_from_response
|
|
18
|
+
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
19
|
+
|
|
20
|
+
from .registry import registry
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# --- Seeding utilities (robust, optional deps) ---
|
|
26
|
+
def _set_global_seed(seed_value: int) -> dict[str, Any]:
|
|
27
|
+
"""Set global RNG seeds across common libraries; return details for logging/restoration.
|
|
28
|
+
|
|
29
|
+
Returns a dict containing which libraries were seeded and prior states if obtainable.
|
|
30
|
+
"""
|
|
31
|
+
seeded: dict[str, Any] = {"seed": int(seed_value), "libs": []}
|
|
32
|
+
with contextlib.suppress(Exception):
|
|
33
|
+
import random as _random # type: ignore
|
|
34
|
+
|
|
35
|
+
_random.seed(seed_value)
|
|
36
|
+
seeded["libs"].append("random")
|
|
37
|
+
with contextlib.suppress(Exception):
|
|
38
|
+
import numpy as _np # type: ignore
|
|
39
|
+
|
|
40
|
+
_np.random.seed(seed_value)
|
|
41
|
+
seeded["libs"].append("numpy")
|
|
42
|
+
with contextlib.suppress(Exception):
|
|
43
|
+
import torch as _torch # type: ignore
|
|
44
|
+
|
|
45
|
+
if hasattr(_torch, "manual_seed"):
|
|
46
|
+
_torch.manual_seed(seed_value)
|
|
47
|
+
seeded["libs"].append("torch")
|
|
48
|
+
# Make CUDA deterministic if present (best-effort)
|
|
49
|
+
with contextlib.suppress(Exception):
|
|
50
|
+
if getattr(_torch, "cuda", None) and _torch.cuda.is_available():
|
|
51
|
+
_torch.cuda.manual_seed_all(seed_value)
|
|
52
|
+
seeded.setdefault("cuda", True)
|
|
53
|
+
# CUDNN deterministic flags (optional)
|
|
54
|
+
with contextlib.suppress(Exception):
|
|
55
|
+
if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
|
|
56
|
+
_torch.backends.cudnn.deterministic = True # type: ignore[attr-defined]
|
|
57
|
+
_torch.backends.cudnn.benchmark = False # type: ignore[attr-defined]
|
|
58
|
+
return seeded
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _clear_seed_side_effects() -> None:
|
|
62
|
+
"""Best-effort cleanup to avoid global deterministic side-effects between requests."""
|
|
63
|
+
# We cannot truly restore prior RNG states without capturing them; we just avoid
|
|
64
|
+
# leaving aggressive deterministic flags enabled where it matters.
|
|
65
|
+
with contextlib.suppress(Exception):
|
|
66
|
+
import torch as _torch # type: ignore
|
|
67
|
+
|
|
68
|
+
with contextlib.suppress(Exception):
|
|
69
|
+
if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
|
|
70
|
+
# Re-enable cudnn.benchmark default True only if it was True; safest is False -> leave as is.
|
|
71
|
+
# We'll keep deterministic False to avoid global impact; benchmark left False for stability.
|
|
72
|
+
_torch.backends.cudnn.deterministic = False # type: ignore[attr-defined]
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
router = APIRouter()
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class RolloutEnvSpec(BaseModel):
|
|
79
|
+
env_id: str | None = None
|
|
80
|
+
env_name: str | None = None
|
|
81
|
+
config: dict[str, Any] = {}
|
|
82
|
+
seed: int | None = None
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class RolloutPolicySpec(BaseModel):
|
|
86
|
+
policy_id: str | None = None
|
|
87
|
+
policy_name: str | None = None
|
|
88
|
+
config: dict[str, Any] = {}
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class RolloutBranchConfig(BaseModel):
|
|
92
|
+
branch_every_n_steps: int = 0
|
|
93
|
+
branch_on_condition: str | None = None
|
|
94
|
+
max_branches: int = 0
|
|
95
|
+
branch_policy: bool = False
|
|
96
|
+
branch_env: bool = False
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class RolloutRecordConfig(BaseModel):
|
|
100
|
+
trajectories: bool = True
|
|
101
|
+
logprobs: bool = False
|
|
102
|
+
value: bool = False
|
|
103
|
+
return_trace: bool = False
|
|
104
|
+
trace_format: str = "compact"
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class RolloutSafetyConfig(BaseModel):
|
|
108
|
+
max_ops: int = 100000
|
|
109
|
+
max_time_s: float = 3600.0
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class RolloutRequest(BaseModel):
|
|
113
|
+
run_id: str
|
|
114
|
+
env: RolloutEnvSpec
|
|
115
|
+
policy: RolloutPolicySpec
|
|
116
|
+
ops: list[str] # ["agent", "env", ...]
|
|
117
|
+
record: RolloutRecordConfig = RolloutRecordConfig()
|
|
118
|
+
on_done: str = "reset" # "reset" | "terminate"
|
|
119
|
+
branch: RolloutBranchConfig | None = None
|
|
120
|
+
safety: RolloutSafetyConfig = RolloutSafetyConfig()
|
|
121
|
+
# Optional run/session context
|
|
122
|
+
training_session_id: str | None = None
|
|
123
|
+
synth_base_url: str | None = None
|
|
124
|
+
# Mode controls URL transformation: REQUIRED to make intent explicit
|
|
125
|
+
mode: RolloutMode
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class RolloutStep(BaseModel):
|
|
129
|
+
obs: dict[str, Any]
|
|
130
|
+
tool_calls: list[dict[str, Any]]
|
|
131
|
+
reward: float | None = None
|
|
132
|
+
done: bool = False
|
|
133
|
+
truncated: bool | None = None
|
|
134
|
+
logprob: float | None = None
|
|
135
|
+
value: float | None = None
|
|
136
|
+
info: dict[str, Any] | None = None
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class RolloutTrajectory(BaseModel):
|
|
140
|
+
env_id: str
|
|
141
|
+
policy_id: str
|
|
142
|
+
steps: list[RolloutStep]
|
|
143
|
+
final: dict[str, Any] | None = None
|
|
144
|
+
length: int
|
|
145
|
+
decision_samples: list[dict[str, Any]] | None = None
|
|
146
|
+
inference_url: str | None = None
|
|
147
|
+
|
|
148
|
+
|
|
def _normalize_step_strategy(raw_strategy: Any) -> str:
    if not isinstance(raw_strategy, str):
        return "consistent"
    candidate = raw_strategy.strip().lower()
    if not candidate:
        return "consistent"
    mapping = {
        "simple": "consistent",
        "consistent": "consistent",
        "consistent_stepwise": "consistent",
        "decision_consistent": "consistent",
        "per_achievement": "per_achievement",
        "per-achievement": "per_achievement",
        "perachievement": "per_achievement",
        "achievement_weighted": "per_achievement",
        "complex": "per_achievement",
    }
    return mapping.get(candidate, "consistent")


def _coerce_weights(raw_weights: Any) -> dict[str, float]:
    weights: dict[str, float] = {}
    if isinstance(raw_weights, dict):
        for key, value in raw_weights.items():
            try:
                weights[str(key)] = float(value)
            except Exception:
                continue
    return weights


def _coerce_k_limits(raw_limits: Any) -> dict[str, int]:
    limits: dict[str, int] = {}
    if isinstance(raw_limits, dict):
        for key, value in raw_limits.items():
            try:
                limits[str(key)] = int(value)
            except Exception:
                continue
    return limits


def _coerce_int_value(value: Any) -> int | None:
    if isinstance(value, bool):
        return int(value)
    try:
        return int(value)  # type: ignore[arg-type]
    except Exception:
        try:
            return int(float(value))  # type: ignore[arg-type]
        except Exception:
            return None

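A quick sketch of how these coercion helpers behave on messy config input; the specific keys and values are made up for illustration.

```python
# Illustrative only: defensive coercion performed by the helpers above.
assert _normalize_step_strategy("Complex") == "per_achievement"
assert _normalize_step_strategy(None) == "consistent"

# Non-numeric entries are silently dropped rather than raising.
assert _coerce_weights({"collect_wood": "0.25", "bad": object()}) == {"collect_wood": 0.25}
assert _coerce_k_limits({"collect_wood": "3"}) == {"collect_wood": 3}
assert _coerce_int_value("7.9") == 7 and _coerce_int_value("n/a") is None
```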
def _compute_resource_reward(
    prev_inventory: Mapping[str, Any] | None,
    new_inventory: Mapping[str, Any] | None,
    prev_counts: Mapping[str, Any] | None,
    new_counts: Mapping[str, Any] | None,
) -> tuple[float, list[dict[str, Any]], dict[str, int], dict[str, int]]:
    reward_total = 0.0
    components: list[dict[str, Any]] = []
    inventory_deltas: dict[str, int] = {}
    achievement_deltas: dict[str, int] = {}

    resource_weights = {
        "wood": 0.10,
        "sapling": 0.08,
        "stone": 0.15,
        "coal": 0.18,
        "iron": 0.22,
        "plant": 0.06,
        "meat": 0.12,
        "drink": 0.07,
        "food": 0.07,
        "water": 0.07,
        "energy": 0.04,
    }
    tool_weights = {
        "wood_pickaxe": 0.40,
        "stone_pickaxe": 0.55,
        "iron_pickaxe": 0.75,
        "wood_sword": 0.35,
        "stone_sword": 0.50,
        "iron_sword": 0.70,
        "furnace": 0.45,
        "table": 0.30,
        "bow": 0.45,
    }
    achievement_weights = {
        "collect_wood": 0.08,
        "collect_sapling": 0.06,
        "collect_stone": 0.10,
        "collect_coal": 0.12,
        "collect_iron": 0.14,
        "collect_drink": 0.06,
        "collect_food": 0.06,
        "collect_plant": 0.06,
    }
    default_resource_weight = 0.05
    default_achievement_weight = 0.05

    prev_inv = prev_inventory or {}
    new_inv = new_inventory or {}
    for key, raw_value in new_inv.items():
        new_val = _coerce_int_value(raw_value)
        if new_val is None:
            continue
        prev_val = _coerce_int_value(prev_inv.get(key, 0)) or 0
        delta = new_val - prev_val
        if delta <= 0:
            continue
        weight = resource_weights.get(key)
        if weight is None and key in tool_weights:
            weight = tool_weights[key]
        if weight is None:
            weight = default_resource_weight
        gain = weight * delta
        reward_total += gain
        inventory_deltas[str(key)] = delta
        components.append(
            {
                "type": "inventory",
                "item": str(key),
                "delta": delta,
                "weight": weight,
                "reward": gain,
            }
        )

    prev_ct = prev_counts or {}
    new_ct = new_counts or {}
    for key, raw_value in new_ct.items():
        new_val = _coerce_int_value(raw_value)
        if new_val is None:
            continue
        prev_val = _coerce_int_value(prev_ct.get(key, 0)) or 0
        delta = new_val - prev_val
        if delta <= 0:
            continue
        weight = achievement_weights.get(key, default_achievement_weight)
        gain = weight * delta
        reward_total += gain
        achievement_deltas[str(key)] = delta
        components.append(
            {
                "type": "achievement_count",
                "name": str(key),
                "delta": delta,
                "weight": weight,
                "reward": gain,
            }
        )

    return reward_total, components, inventory_deltas, achievement_deltas

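As a concrete illustration of the weighting above (a sketch, not output recorded from the package): picking up one wood, crafting a wood pickaxe, and registering one `collect_wood` count between two decisions is credited as 0.10 + 0.40 + 0.08 = 0.58.

```python
reward, components, inv_deltas, ach_deltas = _compute_resource_reward(
    prev_inventory={"wood": 2},
    new_inventory={"wood": 3, "wood_pickaxe": 1},
    prev_counts={},
    new_counts={"collect_wood": 1},
)
# reward == 0.10 (wood delta) + 0.40 (wood_pickaxe, tool weight) + 0.08 (collect_wood count) == 0.58
# inv_deltas == {"wood": 1, "wood_pickaxe": 1}; ach_deltas == {"collect_wood": 1}
```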
def compute_stepwise_reward(
    prev_achievements: dict[str, bool],
    new_achievements: dict[str, bool],
    decision_index: int,
    actions_summary: list[dict[str, Any]],
    indicator_lambda: float,
    *,
    strategy: str | None = None,
    weights: dict[str, float] | None = None,
    k_limits: dict[str, int] | None = None,
    episode_counts: dict[str, int] | None = None,
    prev_inventory: dict[str, int] | None = None,
    new_inventory: dict[str, int] | None = None,
    prev_counts: dict[str, int] | None = None,
    new_counts: dict[str, int] | None = None,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, float]]:
    """Compute stepwise reward metadata given achievement states before/after a decision."""

    prev_map = prev_achievements or {}
    next_map = new_achievements or {}

    unlocked = [name for name, value in next_map.items() if value and not prev_map.get(name, False)]
    indicator_from_achievements = 1 if unlocked else 0
    normalized_strategy = _normalize_step_strategy(strategy)
    base_reward = 0.0
    reward_components: list[dict[str, Any]] = []
    credited: list[str] = []

    if indicator_from_achievements:
        if normalized_strategy == "per_achievement":
            weight_map = weights or {}
            limit_map = k_limits or {}
            counts = episode_counts if isinstance(episode_counts, dict) else {}
            for name in unlocked:
                try:
                    limit_val = int(limit_map.get(name, 1))
                except Exception:
                    limit_val = 1
                # limit_val <= 0 implies unlimited rewards
                unlimited = limit_val <= 0
                try:
                    prev_count = int(counts.get(name, 0))
                except Exception:
                    prev_count = 0
                should_credit = unlimited or (prev_count < max(limit_val, 0))
                if should_credit:
                    try:
                        weight_val = float(weight_map.get(name, 1.0))
                    except Exception:
                        weight_val = 1.0
                    base_reward += weight_val
                    reward_components.append(
                        {
                            "achievement": name,
                            "weight": weight_val,
                            "count_prior": prev_count,
                            "count_limit": limit_val,
                        }
                    )
                    credited.append(name)
                if episode_counts is not None:
                    episode_counts[name] = prev_count + 1
        else:
            base_reward = 1.0
            reward_components.append(
                {
                    "achievement": "__indicator__",
                    "weight": 1.0,
                    "count_prior": 0,
                    "count_limit": 1,
                }
            )

    resource_reward = 0.0
    resource_components: list[dict[str, Any]] = []
    inventory_deltas: dict[str, int] = {}
    achievement_deltas: dict[str, int] = {}
    if normalized_strategy == "per_achievement":
        (
            resource_reward,
            resource_components,
            inventory_deltas,
            achievement_deltas,
        ) = _compute_resource_reward(prev_inventory, new_inventory, prev_counts, new_counts)
        if resource_components:
            reward_components.extend(resource_components)
            base_reward += resource_reward

    indicator = 1 if base_reward > 0 else 0
    if indicator == 0 and indicator_from_achievements:
        indicator = indicator_from_achievements
    lambda_effective = indicator_lambda if indicator_lambda not in (None, 0) else 1.0
    reward_value = float(lambda_effective) * float(base_reward)

    stepwise_info = {
        "decision_index": decision_index,
        "indicator": indicator,
        "new_achievements": unlocked,
        "reward": reward_value,
        "strategy": normalized_strategy,
        "base_reward": float(base_reward),
    }
    if indicator_from_achievements and not unlocked:
        stepwise_info["indicator_from_achievements"] = indicator_from_achievements
    if reward_components:
        stepwise_info["components"] = reward_components
    if credited:
        stepwise_info["credited_achievements"] = credited
    if resource_reward:
        stepwise_info["resource_reward"] = float(resource_reward)
    if inventory_deltas:
        stepwise_info["inventory_deltas"] = inventory_deltas
    if achievement_deltas:
        stepwise_info["achievement_count_deltas"] = achievement_deltas

    decision_sample = {
        "decision_index": decision_index,
        "indicator": indicator,
        "r_i": reward_value,
        "base": float(base_reward),
        "strategy": normalized_strategy,
        "actions": actions_summary,
    }
    if reward_components:
        decision_sample["components"] = reward_components
    if resource_reward:
        decision_sample["resource_reward"] = float(resource_reward)

    stats = {
        "indicator": float(indicator),
        "reward": reward_value,
        "new_achievements_count": float(len(unlocked)),
        "base_reward": float(base_reward),
        "credited_achievements_count": float(len(credited)),
    }
    if resource_reward:
        stats["resource_reward"] = float(resource_reward)
    return stepwise_info, decision_sample, stats

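A sketch of how the per-achievement path credits unlocks against `k_limits` while mutating the shared `episode_counts` dict; all values here are illustrative.

```python
episode_counts: dict[str, int] = {}

info, sample, stats = compute_stepwise_reward(
    prev_achievements={"collect_wood": False, "collect_stone": False},
    new_achievements={"collect_wood": True, "collect_stone": False},
    decision_index=3,
    actions_summary=[{"tool": "interact", "args": {"action": "do"}}],
    indicator_lambda=0.5,
    strategy="per_achievement",
    weights={"collect_wood": 0.08},
    k_limits={"collect_wood": 1},
    episode_counts=episode_counts,
)
# First unlock is credited: base_reward == 0.08, reward == 0.5 * 0.08 == 0.04,
# and episode_counts["collect_wood"] is now 1, so a later re-unlock in the same
# episode would no longer be credited (the k_limit of 1 is exhausted).
```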
class RolloutMetrics(BaseModel):
    episode_returns: list[float]
    mean_return: float
    num_steps: int
    num_episodes: int = 0
    outcome_score: float | None = None
    events_score: float | None = None
    details: dict[str, Any] = Field(default_factory=dict)


class RolloutResponse(BaseModel):
    run_id: str
    trajectories: list[RolloutTrajectory]
    branches: dict[str, list[str]] = Field(default_factory=dict)
    metrics: RolloutMetrics
    aborted: bool = False
    ops_executed: int = 0
    trace: dict[str, Any] | None = None
    pipeline_metadata: dict[str, Any] = Field(default_factory=dict)

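Callers that receive the JSON body back can rehydrate it into these models; a sketch assuming Pydantic v2 (`model_validate`; on v1 this would be `parse_obj`).

```python
# Illustrative: parse a /rollout response body back into the typed models above.
body = resp.json()  # e.g., the response from the client sketch shown earlier
rollout = RolloutResponse.model_validate(body)

print(rollout.metrics.mean_return, rollout.metrics.num_steps)
for traj in rollout.trajectories:
    last = traj.steps[-1] if traj.steps else None
    print(traj.env_id, traj.length, None if last is None else last.done)
```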
467
|
+
class RolloutTracingContext:
|
|
468
|
+
"""Helper managing tracing_v3 recording and optional SFT dumps for a rollout."""
|
|
469
|
+
|
|
470
|
+
def __init__(
|
|
471
|
+
self,
|
|
472
|
+
tracer: SessionTracer | None,
|
|
473
|
+
request: RolloutRequest,
|
|
474
|
+
fastapi_request: Request,
|
|
475
|
+
) -> None:
|
|
476
|
+
self.tracer = tracer
|
|
477
|
+
self.enabled = tracer is not None
|
|
478
|
+
self.request = request
|
|
479
|
+
self.fastapi_request = fastapi_request
|
|
480
|
+
self.run_id = request.run_id
|
|
481
|
+
self.current_step_id: str | None = None
|
|
482
|
+
self.current_turn: int | None = None
|
|
483
|
+
self.lm_calls_summary: list[dict[str, Any]] = []
|
|
484
|
+
self.decision_rewards: list[dict[str, Any]] = []
|
|
485
|
+
self.sft_records: list[dict[str, Any]] = []
|
|
486
|
+
self.latest_system_messages: list[str] = []
|
|
487
|
+
self.latest_user_messages: list[str] = []
|
|
488
|
+
self.latest_system_prompt_content: list[Any] = []
|
|
489
|
+
self.latest_user_prompt_content: list[Any] = []
|
|
490
|
+
self.trace_format = (
|
|
491
|
+
getattr(request.record, "trace_format", "compact") or "compact"
|
|
492
|
+
).lower()
|
|
493
|
+
self.return_trace = bool(getattr(request.record, "return_trace", False))
|
|
494
|
+
print(
|
|
495
|
+
f"[TRACE_DEBUG] RolloutTracingContext init: trace_format={self.trace_format} return_trace={self.return_trace}",
|
|
496
|
+
flush=True,
|
|
497
|
+
)
|
|
498
|
+
self.sft_output_dir = getattr(fastapi_request.app.state, "sft_output_dir", None)
|
|
499
|
+
self.session_trace = None
|
|
500
|
+
self.metadata_updates: dict[str, Any] = {}
|
|
501
|
+
self.policy_name = request.policy.policy_name or ""
|
|
502
|
+
self.env_name = request.env.env_name or ""
|
|
503
|
+
self.metadata_base: dict[str, Any] = {
|
|
504
|
+
"run_id": self.run_id,
|
|
505
|
+
"policy_name": self.policy_name,
|
|
506
|
+
"policy_id": request.policy.policy_id,
|
|
507
|
+
"env_name": self.env_name,
|
|
508
|
+
"env_id": request.env.env_id,
|
|
509
|
+
"seed": request.env.seed,
|
|
510
|
+
"training_session_id": request.training_session_id,
|
|
511
|
+
"synth_base_url": request.synth_base_url,
|
|
512
|
+
}
|
|
513
|
+
|
|
514
|
+
# Expose context for downstream calls inside this request lifecycle
|
|
515
|
+
fastapi_request.state.rollout_tracing = self
|
|
516
|
+
fastapi_request.state.rollout_run_id = self.run_id
|
|
517
|
+
|
|
518
|
+
async def start_session(self) -> None:
|
|
519
|
+
if not self.enabled or self.tracer is None:
|
|
520
|
+
print("[TRACE_DEBUG] start_session skipped: tracer disabled", flush=True)
|
|
521
|
+
return
|
|
522
|
+
try:
|
|
523
|
+
await self.tracer.initialize()
|
|
524
|
+
print("[TRACE_DEBUG] tracer initialized", flush=True)
|
|
525
|
+
except Exception as exc:
|
|
526
|
+
logger.debug("TRACING_INIT_FAIL: %s", exc)
|
|
527
|
+
# Hard fail: tracing requested but cannot initialize
|
|
528
|
+
raise
|
|
529
|
+
try:
|
|
530
|
+
await self.tracer.start_session(
|
|
531
|
+
session_id=self.run_id, metadata=dict(self.metadata_base)
|
|
532
|
+
)
|
|
533
|
+
print(f"[TRACE_DEBUG] start_session succeeded for run_id={self.run_id}", flush=True)
|
|
534
|
+
except Exception as exc:
|
|
535
|
+
logger.info("TRACING_START_FAIL: %s", exc)
|
|
536
|
+
# Hard fail: tracing requested but cannot start session
|
|
537
|
+
raise
|
|
538
|
+
|
|
539
|
+
async def start_decision(self, turn_number: int) -> None:
|
|
540
|
+
self.current_turn = turn_number
|
|
541
|
+
self.current_step_id = f"decision_{turn_number}"
|
|
542
|
+
if not self.enabled or self.tracer is None:
|
|
543
|
+
return
|
|
544
|
+
try:
|
|
545
|
+
await self.tracer.start_timestep(step_id=self.current_step_id, turn_number=turn_number)
|
|
546
|
+
except Exception as exc:
|
|
547
|
+
logger.debug("TRACING_STEP_START_FAIL: %s", exc)
|
|
548
|
+
|
|
549
|
+
async def end_decision(self) -> None:
|
|
550
|
+
if not self.enabled or self.tracer is None:
|
|
551
|
+
return
|
|
552
|
+
try:
|
|
553
|
+
await self.tracer.end_timestep(step_id=self.current_step_id)
|
|
554
|
+
except Exception as exc:
|
|
555
|
+
logger.debug("TRACING_STEP_END_FAIL: %s", exc)
|
|
556
|
+
finally:
|
|
557
|
+
self.current_step_id = None
|
|
558
|
+
|
|
559
|
+
def _message_metadata(self) -> dict[str, Any]:
|
|
560
|
+
return {
|
|
561
|
+
"turn": self.current_turn,
|
|
562
|
+
"step_id": self.current_step_id,
|
|
563
|
+
}
|
|
564
|
+
|
|
565
|
+
async def record_policy_prompts(
|
|
566
|
+
self,
|
|
567
|
+
system_messages: list[Any],
|
|
568
|
+
user_messages: list[Any],
|
|
569
|
+
) -> None:
|
|
570
|
+
self.latest_system_messages = [self._prompt_text(entry) for entry in system_messages]
|
|
571
|
+
self.latest_user_messages = [self._prompt_text(entry) for entry in user_messages]
|
|
572
|
+
self.latest_system_prompt_content = [
|
|
573
|
+
self._prompt_content(entry, role="system") for entry in system_messages
|
|
574
|
+
]
|
|
575
|
+
self.latest_user_prompt_content = [
|
|
576
|
+
self._prompt_content(entry, role="user") for entry in user_messages
|
|
577
|
+
]
|
|
578
|
+
if not self.enabled or self.tracer is None:
|
|
579
|
+
return
|
|
580
|
+
for entry in system_messages:
|
|
581
|
+
try:
|
|
582
|
+
await self.tracer.record_message(
|
|
583
|
+
content=self._prompt_payload(entry, role="system"),
|
|
584
|
+
message_type="system", # Use standard message type
|
|
585
|
+
metadata=self._message_metadata(),
|
|
586
|
+
)
|
|
587
|
+
except Exception as exc:
|
|
588
|
+
logger.debug("TRACING_SYSTEM_MSG_FAIL: %s", exc)
|
|
589
|
+
for entry in user_messages:
|
|
590
|
+
try:
|
|
591
|
+
await self.tracer.record_message(
|
|
592
|
+
content=self._prompt_payload(entry, role="user"),
|
|
593
|
+
message_type="user", # Use standard message type
|
|
594
|
+
metadata=self._message_metadata(),
|
|
595
|
+
)
|
|
596
|
+
except Exception as exc:
|
|
597
|
+
logger.debug("TRACING_USER_MSG_FAIL: %s", exc)
|
|
598
|
+
|
|
599
|
+
# Debug: Check message count
|
|
600
|
+
if self.tracer and self.tracer._current_trace:
|
|
601
|
+
msg_count = len(self.tracer._current_trace.markov_blanket_message_history)
|
|
602
|
+
print(f"[TRACE_DEBUG] After record_policy_prompts: {msg_count} messages", flush=True)
|
|
603
|
+
|
|
604
|
+
def _content_to_text(self, content: Any) -> str:
|
|
605
|
+
if isinstance(content, str):
|
|
606
|
+
return content
|
|
607
|
+
if isinstance(content, list):
|
|
608
|
+
parts: list[str] = []
|
|
609
|
+
for seg in content:
|
|
610
|
+
if isinstance(seg, dict):
|
|
611
|
+
text_val = seg.get("text") or seg.get("content")
|
|
612
|
+
if isinstance(text_val, str):
|
|
613
|
+
parts.append(text_val)
|
|
614
|
+
return "".join(parts)
|
|
615
|
+
if content is None:
|
|
616
|
+
return ""
|
|
617
|
+
return str(content)
|
|
618
|
+
|
|
619
|
+
def _prompt_text(self, entry: Any) -> str:
|
|
620
|
+
if isinstance(entry, dict):
|
|
621
|
+
text = entry.get("text")
|
|
622
|
+
if isinstance(text, str):
|
|
623
|
+
return text
|
|
624
|
+
content = entry.get("content")
|
|
625
|
+
return self._content_to_text(content)
|
|
626
|
+
return self._content_to_text(entry)
|
|
627
|
+
|
|
628
|
+
def _prompt_payload(self, entry: Any, *, role: str) -> dict[str, Any]:
|
|
629
|
+
if isinstance(entry, dict):
|
|
630
|
+
payload = dict(entry)
|
|
631
|
+
payload.setdefault("role", role)
|
|
632
|
+
return payload
|
|
633
|
+
return {
|
|
634
|
+
"role": role,
|
|
635
|
+
"text": self._prompt_text(entry),
|
|
636
|
+
"content": entry,
|
|
637
|
+
}
|
|
638
|
+
|
|
639
|
+
def _prompt_content(self, entry: Any, *, role: str) -> Any:
|
|
640
|
+
payload = self._prompt_payload(entry, role=role)
|
|
641
|
+
return payload.get("content", payload.get("text"))
|
|
642
|
+
|
|
643
|
+
def _content_has_image(self, content: Any) -> bool:
|
|
644
|
+
if isinstance(content, list):
|
|
645
|
+
return any(
|
|
646
|
+
isinstance(seg, dict)
|
|
647
|
+
and seg.get("type") in {"image", "image_url"}
|
|
648
|
+
for seg in content
|
|
649
|
+
)
|
|
650
|
+
if isinstance(content, dict):
|
|
651
|
+
if content.get("type") in {"image", "image_url"}:
|
|
652
|
+
return True
|
|
653
|
+
inner = content.get("content")
|
|
654
|
+
if isinstance(inner, list):
|
|
655
|
+
return any(
|
|
656
|
+
isinstance(seg, dict)
|
|
657
|
+
and seg.get("type") in {"image", "image_url"}
|
|
658
|
+
for seg in inner
|
|
659
|
+
)
|
|
660
|
+
return False
|
|
661
|
+
|
|
662
|
+
def _safe_json(self, payload: Any, limit: int = 4000) -> str:
|
|
663
|
+
try:
|
|
664
|
+
text = json.dumps(payload, ensure_ascii=False)
|
|
665
|
+
except Exception:
|
|
666
|
+
text = str(payload)
|
|
667
|
+
if len(text) > limit:
|
|
668
|
+
return text[:limit] + "…"
|
|
669
|
+
return text
|
|
670
|
+
|
|
671
|
+
async def record_tool_invocation(self, tool_calls: list[dict[str, Any]] | None) -> None:
|
|
672
|
+
if tool_calls is None:
|
|
673
|
+
return
|
|
674
|
+
if self.enabled and self.tracer is not None:
|
|
675
|
+
try:
|
|
676
|
+
payload = {
|
|
677
|
+
"role": "assistant",
|
|
678
|
+
"tool_calls": tool_calls,
|
|
679
|
+
}
|
|
680
|
+
await self.tracer.record_message(
|
|
681
|
+
content=payload,
|
|
682
|
+
message_type="assistant",
|
|
683
|
+
metadata={**self._message_metadata(), "is_tool_call": True},
|
|
684
|
+
)
|
|
685
|
+
if self.tracer._current_trace:
|
|
686
|
+
print(
|
|
687
|
+
f"[TRACE_DEBUG] After tool invocation: messages={len(self.tracer._current_trace.markov_blanket_message_history)}",
|
|
688
|
+
flush=True,
|
|
689
|
+
)
|
|
690
|
+
except Exception as exc:
|
|
691
|
+
logger.debug("TRACING_TOOL_MSG_FAIL: %s", exc)
|
|
692
|
+
|
|
693
|
+
async def _record_event(self, event: Any) -> int | None:
|
|
694
|
+
if not self.enabled or self.tracer is None:
|
|
695
|
+
return None
|
|
696
|
+
try:
|
|
697
|
+
return await self.tracer.record_event(event)
|
|
698
|
+
except Exception as exc:
|
|
699
|
+
logger.debug("TRACING_EVENT_FAIL: %s", exc)
|
|
700
|
+
return None
|
|
701
|
+
|
|
702
|
+
async def record_llm_call(
|
|
703
|
+
self,
|
|
704
|
+
*,
|
|
705
|
+
inference_request: dict[str, Any],
|
|
706
|
+
inference_response: dict[str, Any],
|
|
707
|
+
tool_calls: list[dict[str, Any]] | None,
|
|
708
|
+
provider: str,
|
|
709
|
+
model_name: str,
|
|
710
|
+
started_at: datetime,
|
|
711
|
+
completed_at: datetime,
|
|
712
|
+
latency_ms: int | None,
|
|
713
|
+
) -> None:
|
|
714
|
+
usage = inference_response.get("usage") or {}
|
|
715
|
+
input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens")
|
|
716
|
+
output_tokens = usage.get("output_tokens") or usage.get("completion_tokens")
|
|
717
|
+
total_tokens = usage.get("total_tokens")
|
|
718
|
+
cost_usd = usage.get("cost_usd") or usage.get("cost") or usage.get("total_cost")
|
|
719
|
+
|
|
720
|
+
assistant_message = None
|
|
721
|
+
choices = inference_response.get("choices") or []
|
|
722
|
+
if choices:
|
|
723
|
+
assistant_message = choices[0].get("message") or {}
|
|
724
|
+
assistant_content = (
|
|
725
|
+
assistant_message.get("content") if isinstance(assistant_message, dict) else None
|
|
726
|
+
)
|
|
727
|
+
|
|
728
|
+
raw_response = self._content_to_text(assistant_content)
|
|
729
|
+
if not raw_response:
|
|
730
|
+
raw_response = self._safe_json(inference_response, limit=2000)
|
|
731
|
+
|
|
732
|
+
base_response = BaseLMResponse(
|
|
733
|
+
raw_response=raw_response,
|
|
734
|
+
tool_calls=assistant_message.get("tool_calls")
|
|
735
|
+
if isinstance(assistant_message, dict)
|
|
736
|
+
else None,
|
|
737
|
+
usage=usage or None,
|
|
738
|
+
api_type="chat_completions",
|
|
739
|
+
)
|
|
740
|
+
|
|
741
|
+
request_messages = inference_request.get("messages") or []
|
|
742
|
+
try:
|
|
743
|
+
temperature = float(inference_request.get("temperature"))
|
|
744
|
+
except Exception:
|
|
745
|
+
temperature = 0.0
|
|
746
|
+
|
|
747
|
+
call_record = create_llm_call_record_from_response(
|
|
748
|
+
response=base_response,
|
|
749
|
+
model_name=model_name,
|
|
750
|
+
provider=provider,
|
|
751
|
+
messages=request_messages,
|
|
752
|
+
temperature=temperature,
|
|
753
|
+
request_params=inference_request,
|
|
754
|
+
tools=inference_request.get("tools"),
|
|
755
|
+
started_at=started_at,
|
|
756
|
+
completed_at=completed_at,
|
|
757
|
+
latency_ms=latency_ms,
|
|
758
|
+
)
|
|
759
|
+
|
|
760
|
+
event_metadata = {
|
|
761
|
+
"policy_id": self.request.policy.policy_id,
|
|
762
|
+
"turn": self.current_turn,
|
|
763
|
+
"run_id": self.run_id,
|
|
764
|
+
}
|
|
765
|
+
|
|
766
|
+
event = LMCAISEvent(
|
|
767
|
+
system_instance_id=f"policy:{self.policy_name or 'unknown'}",
|
|
768
|
+
time_record=TimeRecord(event_time=completed_at.timestamp()),
|
|
769
|
+
model_name=model_name,
|
|
770
|
+
provider=provider,
|
|
771
|
+
input_tokens=input_tokens,
|
|
772
|
+
output_tokens=output_tokens,
|
|
773
|
+
total_tokens=total_tokens,
|
|
774
|
+
cost_usd=cost_usd,
|
|
775
|
+
latency_ms=latency_ms,
|
|
776
|
+
call_records=[call_record],
|
|
777
|
+
metadata=event_metadata,
|
|
778
|
+
)
|
|
779
|
+
|
|
780
|
+
await self._record_event(event)
|
|
781
|
+
|
|
782
|
+
self.lm_calls_summary.append(
|
|
783
|
+
{
|
|
784
|
+
"turn": self.current_turn,
|
|
785
|
+
"model": model_name,
|
|
786
|
+
"provider": provider,
|
|
787
|
+
"total_tokens": total_tokens,
|
|
788
|
+
"input_tokens": input_tokens,
|
|
789
|
+
"output_tokens": output_tokens,
|
|
790
|
+
"latency_ms": latency_ms,
|
|
791
|
+
"tool_calls": len(tool_calls or []),
|
|
792
|
+
}
|
|
793
|
+
)
|
|
794
|
+
|
|
795
|
+
assistant_structured = assistant_content if assistant_content is not None else ""
|
|
796
|
+
assistant_text = self._content_to_text(assistant_content)
|
|
797
|
+
|
|
798
|
+
if self.enabled and self.tracer is not None:
|
|
799
|
+
assistant_payload: dict[str, Any] = {
|
|
800
|
+
"role": "assistant",
|
|
801
|
+
"content": assistant_structured,
|
|
802
|
+
"text": assistant_text,
|
|
803
|
+
}
|
|
804
|
+
if isinstance(assistant_message, dict):
|
|
805
|
+
if assistant_message.get("tool_calls"):
|
|
806
|
+
assistant_payload["tool_calls"] = assistant_message.get("tool_calls")
|
|
807
|
+
if assistant_message.get("reasoning"):
|
|
808
|
+
assistant_payload["reasoning"] = assistant_message.get("reasoning")
|
|
809
|
+
if assistant_message.get("thinking"):
|
|
810
|
+
assistant_payload["thinking"] = assistant_message.get("thinking")
|
|
811
|
+
try:
|
|
812
|
+
await self.tracer.record_message(
|
|
813
|
+
content=assistant_payload,
|
|
814
|
+
message_type="assistant",
|
|
815
|
+
metadata=self._message_metadata(),
|
|
816
|
+
)
|
|
817
|
+
except Exception as exc:
|
|
818
|
+
logger.debug("TRACING_ASSISTANT_MSG_FAIL: %s", exc)
|
|
819
|
+
|
|
820
|
+
if self.sft_output_dir is not None:
|
|
821
|
+
assistant_structured = assistant_content if assistant_content is not None else ""
|
|
822
|
+
dialogue_structured: list[dict[str, Any]] = []
|
|
823
|
+
for content in self.latest_system_prompt_content:
|
|
824
|
+
if content is None:
|
|
825
|
+
continue
|
|
826
|
+
dialogue_structured.append({"role": "system", "content": content})
|
|
827
|
+
for content in self.latest_user_prompt_content:
|
|
828
|
+
if content is None:
|
|
829
|
+
continue
|
|
830
|
+
dialogue_structured.append({"role": "user", "content": content})
|
|
831
|
+
dialogue_text = (
|
|
832
|
+
[{"role": "system", "content": s} for s in self.latest_system_messages]
|
|
833
|
+
+ [{"role": "user", "content": u} for u in self.latest_user_messages]
|
|
834
|
+
)
|
|
835
|
+
user_has_image = any(
|
|
836
|
+
self._content_has_image(content) for content in self.latest_user_prompt_content
|
|
837
|
+
)
|
|
838
|
+
assistant_has_image = self._content_has_image(assistant_structured)
|
|
839
|
+
record = {
|
|
840
|
+
"run_id": self.run_id,
|
|
841
|
+
"turn": self.current_turn,
|
|
842
|
+
"model": model_name,
|
|
843
|
+
"provider": provider,
|
|
844
|
+
"dialogue": dialogue_structured,
|
|
845
|
+
"dialogue_text": dialogue_text,
|
|
846
|
+
"assistant": {
|
|
847
|
+
"content": assistant_structured,
|
|
848
|
+
"content_text": assistant_text,
|
|
849
|
+
"tool_calls": assistant_message.get("tool_calls")
|
|
850
|
+
if isinstance(assistant_message, dict)
|
|
851
|
+
else [],
|
|
852
|
+
"has_image": assistant_has_image,
|
|
853
|
+
},
|
|
854
|
+
"metadata": {
|
|
855
|
+
"user_has_image": user_has_image,
|
|
856
|
+
"assistant_has_image": assistant_has_image,
|
|
857
|
+
"has_image": user_has_image or assistant_has_image,
|
|
858
|
+
},
|
|
859
|
+
"timestamp": datetime.utcnow().isoformat(),
|
|
860
|
+
}
|
|
861
|
+
self.sft_records.append(record)
|
|
862
|
+
|
|
863
|
+
async def record_environment_event(
|
|
864
|
+
self,
|
|
865
|
+
*,
|
|
866
|
+
env_handle: Any,
|
|
867
|
+
prev_obs: dict[str, Any] | None,
|
|
868
|
+
env_response: Any,
|
|
869
|
+
next_obs: dict[str, Any] | None,
|
|
870
|
+
metadata: dict[str, Any] | None = None,
|
|
871
|
+
) -> int | None:
|
|
872
|
+
if not self.enabled or self.tracer is None:
|
|
873
|
+
return None
|
|
874
|
+
|
|
875
|
+
try:
|
|
876
|
+
prev_summary = (
|
|
877
|
+
_summarize_observation_for_storage(env_handle, prev_obs or {})
|
|
878
|
+
if prev_obs is not None
|
|
879
|
+
else None
|
|
880
|
+
)
|
|
881
|
+
except Exception:
|
|
882
|
+
prev_summary = None
|
|
883
|
+
try:
|
|
884
|
+
next_summary = (
|
|
885
|
+
_summarize_observation_for_storage(env_handle, next_obs or {})
|
|
886
|
+
if next_obs is not None
|
|
887
|
+
else None
|
|
888
|
+
)
|
|
889
|
+
except Exception:
|
|
890
|
+
next_summary = None
|
|
891
|
+
|
|
892
|
+
reward_val = getattr(env_response, "reward", None)
|
|
893
|
+
try:
|
|
894
|
+
reward_float = float(reward_val) if reward_val is not None else 0.0
|
|
895
|
+
except Exception:
|
|
896
|
+
reward_float = 0.0
|
|
897
|
+
|
|
898
|
+
event = EnvironmentEvent(
|
|
899
|
+
system_instance_id=f"environment:{self.env_name or 'unknown'}",
|
|
900
|
+
time_record=TimeRecord(event_time=datetime.utcnow().timestamp()),
|
|
901
|
+
reward=reward_float,
|
|
902
|
+
terminated=bool(getattr(env_response, "done", False)),
|
|
903
|
+
truncated=bool(getattr(env_response, "truncated", False)),
|
|
904
|
+
system_state_before=prev_summary,
|
|
905
|
+
system_state_after=next_summary,
|
|
906
|
+
metadata={
|
|
907
|
+
"turn": self.current_turn,
|
|
908
|
+
"run_id": self.run_id,
|
|
909
|
+
**(metadata or {}),
|
|
910
|
+
},
|
|
911
|
+
)
|
|
912
|
+
|
|
913
|
+
return await self._record_event(event)
|
|
914
|
+
|
|
915
|
+
async def record_decision_reward(
|
|
916
|
+
self,
|
|
917
|
+
*,
|
|
918
|
+
event_id: int | None,
|
|
919
|
+
decision_meta: dict[str, Any] | None,
|
|
920
|
+
) -> None:
|
|
921
|
+
decision_meta = decision_meta or {}
|
|
922
|
+
ach_delta = int(decision_meta.get("ach_delta", 0))
|
|
923
|
+
unique_delta = int(decision_meta.get("unique_delta", 0))
|
|
924
|
+
all_ach = list(decision_meta.get("all") or [])
|
|
925
|
+
unique_ach = list(decision_meta.get("unique") or [])
|
|
926
|
+
|
|
927
|
+
self.decision_rewards.append(
|
|
928
|
+
{
|
|
929
|
+
"turn": self.current_turn,
|
|
930
|
+
"ach_delta": ach_delta,
|
|
931
|
+
"unique_delta": unique_delta,
|
|
932
|
+
"achievements": all_ach,
|
|
933
|
+
"unique_achievements": unique_ach,
|
|
934
|
+
}
|
|
935
|
+
)
|
|
936
|
+
|
|
937
|
+
if not self.enabled or self.tracer is None or event_id is None:
|
|
938
|
+
return
|
|
939
|
+
try:
|
|
940
|
+
await self.tracer.record_event_reward(
|
|
941
|
+
event_id=event_id,
|
|
942
|
+
turn_number=self.current_turn,
|
|
943
|
+
reward_value=float(ach_delta),
|
|
944
|
+
reward_type="achievement_delta",
|
|
945
|
+
annotation={"achievements": all_ach},
|
|
946
|
+
source="environment",
|
|
947
|
+
)
|
|
948
|
+
if unique_delta:
|
|
949
|
+
await self.tracer.record_event_reward(
|
|
950
|
+
event_id=event_id,
|
|
951
|
+
turn_number=self.current_turn,
|
|
952
|
+
reward_value=float(unique_delta),
|
|
953
|
+
reward_type="unique_achievement_delta",
|
|
954
|
+
annotation={"achievements": unique_ach},
|
|
955
|
+
source="environment",
|
|
956
|
+
)
|
|
957
|
+
except Exception as exc:
|
|
958
|
+
logger.debug("TRACING_REWARD_FAIL: %s", exc)
|
|
959
|
+
|
|
960
|
+
def update_metadata(self, **kwargs: Any) -> None:
|
|
961
|
+
self.metadata_updates.update({k: v for k, v in kwargs.items() if v is not None})
|
|
962
|
+
|
|
963
|
+
async def finalize(
|
|
964
|
+
self,
|
|
965
|
+
*,
|
|
966
|
+
total_reward: float,
|
|
967
|
+
achievement_state: dict[str, bool] | None,
|
|
968
|
+
total_steps: int,
|
|
969
|
+
) -> Any:
|
|
970
|
+
final_achievements = [key for key, val in (achievement_state or {}).items() if val]
|
|
971
|
+
self.metadata_updates.setdefault("final_achievements", final_achievements)
|
|
972
|
+
if self.enabled and self.tracer is not None:
|
|
973
|
+
try:
|
|
974
|
+
await self.tracer.record_outcome_reward(
|
|
975
|
+
total_reward=int(total_reward),
|
|
976
|
+
achievements_count=len(final_achievements),
|
|
977
|
+
total_steps=int(total_steps),
|
|
978
|
+
reward_metadata=dict(self.metadata_updates),
|
|
979
|
+
)
|
|
980
|
+
except Exception as exc:
|
|
981
|
+
logger.debug("TRACING_OUTCOME_FAIL: %s", exc)
|
|
982
|
+
try:
|
|
983
|
+
# Debug: Check message count before end_session
|
|
984
|
+
if self.tracer._current_trace:
|
|
985
|
+
msg_count = len(self.tracer._current_trace.markov_blanket_message_history)
|
|
986
|
+
print(f"[TRACE_DEBUG] Before end_session: {msg_count} messages in trace", flush=True)
|
|
987
|
+
|
|
988
|
+
self.session_trace = await self.tracer.end_session()
|
|
989
|
+
|
|
990
|
+
# Debug: Check if session was saved
|
|
991
|
+
if self.session_trace:
|
|
992
|
+
print(
|
|
993
|
+
f"[TRACE_DEBUG] Session ended successfully, session_id={self.session_trace.session_id}",
|
|
994
|
+
flush=True,
|
|
995
|
+
)
|
|
996
|
+
self.session_trace.metadata.update(self.metadata_updates)
|
|
997
|
+
print(
|
|
998
|
+
f"[TRACE_DEBUG] session_trace.metadata keys: {list(self.session_trace.metadata.keys())}",
|
|
999
|
+
flush=True,
|
|
1000
|
+
)
|
|
1001
|
+
else:
|
|
1002
|
+
print("[TRACE_DEBUG] end_session returned None!", flush=True)
|
|
1003
|
+
except Exception as exc:
|
|
1004
|
+
logger.warning(f"TRACING_END_SESSION_FAIL: {exc}", exc_info=True)
|
|
1005
|
+
self.session_trace = None
|
|
1006
|
+
with contextlib.suppress(Exception):
|
|
1007
|
+
await self.tracer.close()
|
|
1008
|
+
|
|
1009
|
+
if self.sft_records and self.sft_output_dir:
|
|
1010
|
+
self.write_sft_records()
|
|
1011
|
+
|
|
1012
|
+
# Clear context from request state to avoid leaks
|
|
1013
|
+
self.fastapi_request.state.rollout_tracing = None
|
|
1014
|
+
|
|
1015
|
+
return self.session_trace
|
|
1016
|
+
|
|
1017
|
+
def write_sft_records(self) -> None:
|
|
1018
|
+
if not self.sft_output_dir or not self.sft_records:
|
|
1019
|
+
return
|
|
1020
|
+
try:
|
|
1021
|
+
path = unique_sft_path(self.sft_output_dir, run_id=self.run_id)
|
|
1022
|
+
path.parent.mkdir(parents=True, exist_ok=True)
|
|
1023
|
+
with path.open("w", encoding="utf-8") as fh:
|
|
1024
|
+
for record in self.sft_records:
|
|
1025
|
+
json.dump(record, fh, ensure_ascii=False)
|
|
1026
|
+
fh.write("\n")
|
|
1027
|
+
logger.info(f"SFT_WRITTEN: {path}")
|
|
1028
|
+
except Exception as exc:
|
|
1029
|
+
logger.warning(f"SFT_WRITE_FAIL: {exc}")
|
|
1030
|
+
finally:
|
|
1031
|
+
self.sft_records.clear()
|
|
1032
|
+
|
|
1033
|
+
def build_trace_payload(self, session_trace: Any) -> dict[str, Any] | None:
|
|
1034
|
+
if not self.return_trace or session_trace is None:
|
|
1035
|
+
return None
|
|
1036
|
+
|
|
1037
|
+
# For both "full" and "structured" formats, return the complete session trace
|
|
1038
|
+
# The CLI (synth-ai eval) expects this for proper trace storage
|
|
1039
|
+
if self.trace_format in ("full", "structured"):
|
|
1040
|
+
payload = session_trace.to_dict()
|
|
1041
|
+
payload.setdefault("metadata", {}).update(self.metadata_updates)
|
|
1042
|
+
print(
|
|
1043
|
+
f"[TRACE_DEBUG] build_trace_payload returning structured trace with messages={len(payload.get('markov_blanket_message_history') or [])}",
|
|
1044
|
+
flush=True,
|
|
1045
|
+
)
|
|
1046
|
+
return payload
|
|
1047
|
+
|
|
1048
|
+
# For "compact" format, return only summary stats
|
|
1049
|
+
metadata = dict(session_trace.metadata)
|
|
1050
|
+
metadata.update(self.metadata_updates)
|
|
1051
|
+
return {
|
|
1052
|
+
"session_id": session_trace.session_id,
|
|
1053
|
+
"created_at": session_trace.created_at.isoformat(),
|
|
1054
|
+
"metadata": metadata,
|
|
1055
|
+
"events_count": len(session_trace.event_history),
|
|
1056
|
+
"messages_count": len(session_trace.markov_blanket_message_history),
|
|
1057
|
+
"lm_calls": self.lm_calls_summary,
|
|
1058
|
+
"decision_rewards": self.decision_rewards,
|
|
1059
|
+
}
|
|
1060
|
+
|
|
1061
|
+
|
|
def _summarize_observation_for_storage(
    env_handle: Any, observation: dict[str, Any]
) -> dict[str, Any]:
    """Return a compact dict for trajectory storage instead of the raw observation.

    - For Crafter, use the same summary used for the policy user prompt
    - For others, keep a minimal subset or plain text preview
    """
    # Try Crafter-specific formatter
    crafter_wrapper = None
    with contextlib.suppress(Exception):
        from .envs.crafter.environment import (
            CrafterEnvironmentWrapper as _CrafterWrapper,  # type: ignore
        )

        crafter_wrapper = _CrafterWrapper  # type: ignore[assignment]

    if crafter_wrapper is not None and isinstance(
        getattr(env_handle, "env", None), crafter_wrapper
    ):
        with contextlib.suppress(Exception):
            from .envs.crafter.shared import format_observation as _fmt  # type: ignore

            text = _fmt(observation or {})
            return {"text": text}

    # Generic fallback: extract a few small fields if present; avoid huge arrays
    with contextlib.suppress(Exception):
        inv = observation.get("inventory") if isinstance(observation, dict) else None
        ach = observation.get("achievements_status") if isinstance(observation, dict) else None
        pos = observation.get("player_position") if isinstance(observation, dict) else None
        health = None
        if isinstance(inv, dict):
            health = inv.get("health")
        summary = {
            "position": pos,
            "health": health,
            "inventory_keys": sorted(k for k, v in (inv or {}).items() if v)[:10]
            if isinstance(inv, dict)
            else None,
            "achievements_unlocked": sorted(k for k, v in (ach or {}).items() if v)[:10]
            if isinstance(ach, dict)
            else None,
        }
        return {"text": json.dumps(summary, ensure_ascii=False)}

    # Last resort: plain string preview
    try:
        return {"text": str(observation)[:10000]}
    except Exception:
        return {"text": ""}

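For a non-Crafter observation the generic fallback produces a small JSON text summary rather than the raw arrays; a sketch with made-up values follows (passing `env_handle=None` simply skips the Crafter branch).

```python
# Illustrative only: env_handle=None exercises the generic fallback path.
obs = {
    "inventory": {"health": 9, "wood": 2, "stone": 0},
    "achievements_status": {"collect_wood": True, "collect_stone": False},
    "player_position": [3, 4],
}
print(_summarize_observation_for_storage(None, obs))
# -> {"text": '{"position": [3, 4], "health": 9, "inventory_keys": ["health", "wood"], ...}'}
```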
class RunAbortRequest(BaseModel):
    run_id: str


class RunAbortResponse(BaseModel):
    ok: bool
    run_id: str


class RunStatusResponse(BaseModel):
    run_id: str
    status: str
    started_at: datetime
    finished_at: datetime | None = None

1131
|
+
@router.post("/rollout", response_model=RolloutResponse)
|
|
1132
|
+
async def execute_rollout(
|
|
1133
|
+
request: RolloutRequest,
|
|
1134
|
+
req: Request,
|
|
1135
|
+
) -> RolloutResponse:
|
|
1136
|
+
"""Execute a rollout with coordinated environment and policy steps."""
|
|
1137
|
+
logger.info("ROLLOUT: mode = %s", request.mode)
|
|
1138
|
+
|
|
1139
|
+
# Emit rollout identifier early for correlation
|
|
1140
|
+
with contextlib.suppress(Exception):
|
|
1141
|
+
_rid = getattr(request, "run_id", None)
|
|
1142
|
+
_pol = getattr(request.policy, "policy_name", None) or getattr(request.policy, "policy_id", None)
|
|
1143
|
+
_env = getattr(request.env, "env_name", None) or getattr(request.env, "env_id", None)
|
|
1144
|
+
logger.info("ROLLOUT_BEGIN: run_id=%s policy=%s env=%s mode=%s", _rid, _pol, _env, request.mode)
|
|
1145
|
+
print(f"[rollout] begin run_id={_rid} policy={_pol} env={_env}", flush=True)
|
|
1146
|
+
# Enforce per-episode step cap via env-specific parameters; default to 20 if omitted
|
|
1147
|
+
try:
|
|
1148
|
+
_env_params = {}
|
|
1149
|
+
if isinstance(request.env, RolloutEnvSpec) and isinstance(request.env.config, dict):
|
|
1150
|
+
_env_params = dict(request.env.config.get("env_params") or {})
|
|
1151
|
+
max_steps_per_episode = int(_env_params.get("max_steps_per_episode") or 20)
|
|
1152
|
+
assert max_steps_per_episode > 0, "max_steps_per_episode must be a positive integer"
|
|
1153
|
+
except Exception as _mse:
|
|
1154
|
+
raise HTTPException(
|
|
1155
|
+
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
|
1156
|
+
detail={
|
|
1157
|
+
"error": "invalid_env_params",
|
|
1158
|
+
"message": f"Invalid or missing env_params.max_steps_per_episode: {_mse}",
|
|
1159
|
+
},
|
|
1160
|
+
) from _mse
|
|
1161
|
+
# Truncate incoming ops to the enforced cap (each step is [agent, env])
|
|
1162
|
+
ops_seq: list[str] = list(request.ops or [])
|
|
1163
|
+
allowed_ops = max(0, int(max_steps_per_episode) * 2)
|
|
1164
|
+
if len(ops_seq) > allowed_ops:
|
|
1165
|
+
with contextlib.suppress(Exception):
|
|
1166
|
+
logger.info(
|
|
1167
|
+
"ROLL_OUT: truncating ops to cap: requested_ops=%s allowed_ops=%s",
|
|
1168
|
+
str(len(ops_seq)),
|
|
1169
|
+
str(allowed_ops),
|
|
1170
|
+
)
|
|
1171
|
+
ops_seq = ops_seq[:allowed_ops]
|
|
1172
|
+
# Simple API key auth for inbound rollout
|
|
1173
|
+
header_key = req.headers.get("x-api-key")
|
|
1174
|
+
env_key = os.getenv("ENVIRONMENT_API_KEY")
|
|
1175
|
+
dev_key = os.getenv("DEV_ENVIRONMENT_API_KEY")
|
|
1176
|
+
# Accept either ENVIRONMENT_API_KEY or DEV_ENVIRONMENT_API_KEY
|
|
1177
|
+
expected_keys = [k for k in (env_key, dev_key) if k]
|
|
1178
|
+
if not expected_keys:
|
|
1179
|
+
missing = []
|
|
1180
|
+
if not env_key:
|
|
1181
|
+
missing.append("ENVIRONMENT_API_KEY")
|
|
1182
|
+
if not dev_key:
|
|
1183
|
+
missing.append("DEV_ENVIRONMENT_API_KEY")
|
|
1184
|
+
msg = f"Auth not configured: missing {', '.join(missing)} in task service environment"
|
|
1185
|
+
logger.error(msg)
|
|
1186
|
+
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=msg)
|
|
1187
|
+
if not header_key:
|
|
1188
|
+
raise HTTPException(
|
|
1189
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
1190
|
+
detail="Invalid or missing API key: X-API-Key header not provided",
|
|
1191
|
+
)
|
|
1192
|
+
if header_key not in expected_keys:
|
|
1193
|
+
# Do not leak secrets; include short prefix for diagnostics
|
|
1194
|
+
exp_src = env_key if env_key else (dev_key or "")
|
|
1195
|
+
exp_prefix = (exp_src[:7] + "…") if len(exp_src) >= 7 else "set"
|
|
1196
|
+
got_prefix = (header_key[:7] + "…") if len(header_key) >= 7 else "set"
|
|
1197
|
+
raise HTTPException(
|
|
1198
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
1199
|
+
detail=f"Invalid API key: header does not match expected (got={got_prefix}, expected_prefix={exp_prefix})",
|
|
1200
|
+
)
|
|
1201
|
+
|
|
1202
|
+
# Log contextual fields for traceability
|
|
1203
|
+
if request.training_session_id:
|
|
1204
|
+
logger.info(f"ROLL_OUT: training_session_id={request.training_session_id}")
|
|
1205
|
+
if request.synth_base_url:
|
|
1206
|
+
logger.info(f"ROLL_OUT: synth_base_url={request.synth_base_url}")
|
|
1207
|
+
|
|
1208
|
+
# Log masked OpenAI API key presence for diagnostics
|
|
1209
|
+
with contextlib.suppress(Exception):
|
|
1210
|
+
_oa = os.getenv("OPENAI_API_KEY")
|
|
1211
|
+
if _oa:
|
|
1212
|
+
_pref = (_oa[:6] + "…") if len(_oa) >= 6 else "set"
|
|
1213
|
+
logger.info(f"ROLL_OUT: OPENAI_API_KEY present (prefix={_pref})")
|
|
1214
|
+
else:
|
|
1215
|
+
logger.warning("ROLL_OUT: OPENAI_API_KEY missing")
|
|
1216
|
+
|
|
1217
|
+
# Make synth_base_url available for outbound calls in this app
|
|
1218
|
+
with contextlib.suppress(Exception):
|
|
1219
|
+
task_app = req.app.state.task_app
|
|
1220
|
+
if request.synth_base_url:
|
|
1221
|
+
task_app.synth_base_url = request.synth_base_url
|
|
1222
|
+
|
|
1223
|
+
tracer_factory = getattr(req.app.state, "session_tracer_factory", None)
|
|
1224
|
+
tracer_instance: SessionTracer | None = None
|
|
1225
|
+
if callable(tracer_factory):
|
|
1226
|
+
try:
|
|
1227
|
+
inst = tracer_factory()
|
|
1228
|
+
tracer_instance = inst if isinstance(inst, SessionTracer) else None
|
|
1229
|
+
except Exception as exc:
|
|
1230
|
+
logger.debug(f"TRACER_FACTORY_FAIL: {exc}")
|
|
1231
|
+
tracing_context = RolloutTracingContext(tracer_instance, request, req)
|
|
1232
|
+
await tracing_context.start_session()
|
|
1233
|
+
|
|
1234
|
+
# Register run
|
|
1235
|
+
registry.register_run(request.run_id)
|
|
1236
|
+
|
|
1237
|
+
# Track resources created during this rollout so we can guarantee cleanup
|
|
1238
|
+
created_env_id: str | None = None
|
|
1239
|
+
created_policy_id: str | None = None
|
|
1240
|
+
env_seed_used: int | None = None
|
|
1241
|
+
trajectory_steps: list[RolloutStep] = []
|
|
1242
|
+
decision_samples: list[dict[str, Any]] = []
|
|
1243
|
+
pending_tool_calls: Any = None
|
|
1244
|
+
current_obs: Any = {}
|
|
1245
|
+
total_reward: float = 0.0
|
|
1246
|
+
ops_executed = 0
|
|
1247
|
+
last_agent_response_ts: float | None = None
|
|
1248
|
+
last_policy_meta: dict[str, Any] | None = None
|
|
1249
|
+
last_env_step_ms: float | None = None
|
|
1250
|
+
last_env_step_completed_ts: float | None = None
|
|
1251
|
+
decision_open = False
|
|
1252
|
+
finalized = False
|
|
1253
|
+
prev_achievements: dict[str, bool] = {}
|
|
1254
|
+
session_trace = None
|
|
1255
|
+
step_rewards_active = False
|
|
1256
|
+
|
|
1257
|
+
try:
|
|
1258
|
+
# Initialize deterministic seed early for the entire rollout
|
|
1259
|
+
seed_value: int | None = None
|
|
1260
|
+
try:
|
|
1261
|
+
if request.env and request.env.seed is not None:
|
|
1262
|
+
seed_value = int(request.env.seed)
|
|
1263
|
+
else:
|
|
1264
|
+
# Derive a stable seed from run_id
|
|
1265
|
+
import hashlib as _hashlib # local import to avoid global deps
|
|
1266
|
+
|
|
1267
|
+
_digest = _hashlib.sha256(request.run_id.encode("utf-8")).hexdigest()
|
|
1268
|
+
# Use lower 32 bits to fit common RNG ranges
|
|
1269
|
+
seed_value = int(_digest[:8], 16)
|
|
1270
|
+
except Exception:
|
|
1271
|
+
# Fallback to time-based seed if anything goes wrong
|
|
1272
|
+
try:
|
|
1273
|
+
seed_value = int((_time.time_ns() // 1_000_000) % (2**31 - 1))
|
|
1274
|
+
except Exception:
|
|
1275
|
+
seed_value = 42
|
|
1276
|
+
|
|
1277
|
+
_seed_info = _set_global_seed(int(seed_value))
|
|
1278
|
+
with contextlib.suppress(Exception):
|
|
1279
|
+
logger.info(
|
|
1280
|
+
"ROLL_OUT: RNG seeded seed=%s libs=%s",
|
|
1281
|
+
str(_seed_info.get("seed")),
|
|
1282
|
+
",".join(_seed_info.get("libs", [])),
|
|
1283
|
+
)
|
|
1284
|
+
# Resolve or create environment
|
|
1285
|
+
if request.env.env_id:
|
|
1286
|
+
env_handle = registry.get_env(request.env.env_id)
|
|
1287
|
+
if not env_handle:
|
|
1288
|
+
raise HTTPException(
|
|
1289
|
+
status_code=404,
|
|
1290
|
+
detail=f"Environment {request.env.env_id} not found",
|
|
1291
|
+
)
|
|
1292
|
+
env_id = request.env.env_id
|
|
1293
|
+
else:
|
|
1294
|
+
# Create new environment
|
|
1295
|
+
from .environment_routes import EnvCreateRequest, create_environment
|
|
1296
|
+
|
|
1297
|
+
if not request.env.env_name:
|
|
1298
|
+
raise ValueError("FATAL: env_name is required - NO FALLBACKS!")
|
|
1299
|
+
|
|
1300
|
+
# Propagate training_session_id via env config for downstream usage
|
|
1301
|
+
_env_config = dict(request.env.config or {})
|
|
1302
|
+
if request.training_session_id is not None:
|
|
1303
|
+
_env_config.setdefault("training_session_id", request.training_session_id)
|
|
1304
|
+
env_response = await create_environment(
|
|
1305
|
+
EnvCreateRequest(
|
|
1306
|
+
env_name=request.env.env_name,
|
|
1307
|
+
config=_env_config,
|
|
1308
|
+
seed=request.env.seed,
|
|
1309
|
+
rl_run_id=request.run_id,
|
|
1310
|
+
)
|
|
1311
|
+
)
|
|
1312
|
+
env_id = env_response.env_id
|
|
1313
|
+
env_handle = registry.get_env(env_id)
|
|
1314
|
+
created_env_id = env_id
|
|
1315
|
+
|
|
1316
|
+
tracing_context.update_metadata(env_id=env_id)
|
|
1317
|
+
|
|
1318
|
+
# Resolve or create policy
|
|
1319
|
+
if request.policy.policy_id:
|
|
1320
|
+
policy_handle = registry.get_policy(request.policy.policy_id)
|
|
1321
|
+
if not policy_handle:
|
|
1322
|
+
raise HTTPException(
|
|
1323
|
+
status_code=404,
|
|
1324
|
+
detail=f"Policy {request.policy.policy_id} not found",
|
|
1325
|
+
)
|
|
1326
|
+
policy_id = request.policy.policy_id
|
|
1327
|
+
else:
|
|
1328
|
+
# Create new policy
|
|
1329
|
+
from .policy_routes import PolicyCreateRequest, create_policy
|
|
1330
|
+
|
|
1331
|
+
if not request.policy.policy_name:
|
|
1332
|
+
raise ValueError("FATAL: policy_name is required - NO FALLBACKS!")
|
|
1333
|
+
|
|
1334
|
+
# Propagate training_session_id and synth_base_url via policy config
|
|
1335
|
+
_policy_config = dict(request.policy.config or {})
|
|
1336
|
+
if request.training_session_id is not None:
|
|
1337
|
+
_policy_config.setdefault("training_session_id", request.training_session_id)
|
|
1338
|
+
if request.synth_base_url is not None:
|
|
1339
|
+
_policy_config.setdefault("synth_base_url", request.synth_base_url)
|
|
1340
|
+
policy_response = await create_policy(
|
|
1341
|
+
PolicyCreateRequest(
|
|
1342
|
+
policy_name=request.policy.policy_name,
|
|
1343
|
+
config=_policy_config,
|
|
1344
|
+
rl_run_id=request.run_id,
|
|
1345
|
+
bound_env_id=env_id,
|
|
1346
|
+
mode=request.mode, # Pass through mode for URL transformation control
|
|
1347
|
+
),
|
|
1348
|
+
req,
|
|
1349
|
+
)
|
|
1350
|
+
policy_id = policy_response.policy_id
|
|
1351
|
+
policy_handle = registry.get_policy(policy_id)
|
|
1352
|
+
created_policy_id = policy_id
|
|
1353
|
+
|
|
1354
|
+
tracing_context.update_metadata(policy_id=policy_id)
|
|
1355
|
+
|
|
1356
|
+
# Bind policy to environment if not already bound
|
|
1357
|
+
if policy_handle and not policy_handle.bound_env_id:
|
|
1358
|
+
policy_handle.bound_env_id = env_id
|
|
1359
|
+
|
|
1360
|
+
# Record seed bound to environment for end-of-rollout verification/logging
|
|
1361
|
+
try:
|
|
1362
|
+
env_seed_used = int(getattr(env_handle, "seed", 0) or 0)
|
|
1363
|
+
except Exception:
|
|
1364
|
+
env_seed_used = None
|
|
1365
|
+
tracing_context.update_metadata(env_seed=env_seed_used)
|
|
1366
|
+
# Initialize trajectory
|
|
1367
|
+
trajectory_steps = []
|
|
1368
|
+
pending_tool_calls = None
|
|
1369
|
+
current_obs = env_handle.last_observation
|
|
1370
|
+
total_reward = 0.0
|
|
1371
|
+
ops_executed = 0
|
|
1372
|
+
last_agent_response_ts = None
|
|
1373
|
+
last_policy_meta = None
|
|
1374
|
+
last_env_step_ms = None
|
|
1375
|
+
last_env_step_completed_ts = None
|
|
1376
|
+
|
|
1377
|
+
# Stepwise reward configuration (Crafter shaping; gate on explicit enable)
|
|
1378
|
+
step_rewards_cfg_raw: dict[str, Any] = {}
|
|
1379
|
+
try:
|
|
1380
|
+
if isinstance(request.policy.config, dict):
|
|
1381
|
+
step_rewards_cfg_raw = dict(request.policy.config.get("step_rewards") or {})
|
|
1382
|
+
except Exception:
|
|
1383
|
+
step_rewards_cfg_raw = {}
|
|
1384
|
+
if not step_rewards_cfg_raw:
|
|
1385
|
+
try:
|
|
1386
|
+
if isinstance(request.env.config, dict):
|
|
1387
|
+
step_rewards_cfg_raw = dict(request.env.config.get("step_rewards") or {})
|
|
1388
|
+
except Exception:
|
|
1389
|
+
step_rewards_cfg_raw = {}
|
|
1390
|
+
|
|
1391
|
+
step_rewards_enabled = bool(step_rewards_cfg_raw.get("enabled", False))
|
|
1392
|
+
step_rewards_mode = str(step_rewards_cfg_raw.get("mode") or "off").lower()
|
|
1393
|
+
step_rewards_strategy = _normalize_step_strategy(step_rewards_cfg_raw.get("strategy"))
|
|
1394
|
+
step_rewards_weights = _coerce_weights(step_rewards_cfg_raw.get("weights"))
|
|
1395
|
+
step_rewards_k_limits = _coerce_k_limits(step_rewards_cfg_raw.get("k_limits"))
|
|
1396
|
+
try:
|
|
1397
|
+
step_rewards_indicator_lambda = float(
|
|
1398
|
+
step_rewards_cfg_raw.get("indicator_lambda") or 0.0
|
|
1399
|
+
)
|
|
1400
|
+
except Exception:
|
|
1401
|
+
step_rewards_indicator_lambda = 0.0
|
|
1402
|
+
try:
|
|
1403
|
+
step_rewards_beta = float(step_rewards_cfg_raw.get("step_beta") or 0.0)
|
|
1404
|
+
except Exception:
|
|
1405
|
+
step_rewards_beta = 0.0
|
|
1406
|
+
step_rewards_active = step_rewards_enabled and step_rewards_mode == "decision_stepwise"
|
|
1407
|
+
|
|
1408
|
+
def _extract_achievements(obs: Any) -> dict[str, bool]:
|
|
1409
|
+
if not isinstance(obs, dict):
|
|
1410
|
+
return {}
|
|
1411
|
+
ach = obs.get("achievements_status")
|
|
1412
|
+
if isinstance(ach, dict):
|
|
1413
|
+
return {str(k): bool(v) for k, v in ach.items()}
|
|
1414
|
+
return {}
|
|
1415
|
+
|
|
1416
|
+
def _extract_inventory(obs: Any) -> dict[str, int]:
|
|
1417
|
+
if not isinstance(obs, dict):
|
|
1418
|
+
return {}
|
|
1419
|
+
inv = obs.get("inventory")
|
|
1420
|
+
if not isinstance(inv, dict):
|
|
1421
|
+
return {}
|
|
1422
|
+
cleaned: dict[str, int] = {}
|
|
1423
|
+
for key, value in inv.items():
|
|
1424
|
+
coerced = _coerce_int_value(value)
|
|
1425
|
+
if coerced is None:
|
|
1426
|
+
continue
|
|
1427
|
+
cleaned[str(key)] = coerced
|
|
1428
|
+
return cleaned
|
|
1429
|
+
|
|
1430
|
+
def _extract_achievement_counts(obs: Any) -> dict[str, int]:
|
|
1431
|
+
if not isinstance(obs, dict):
|
|
1432
|
+
return {}
|
|
1433
|
+
counts = obs.get("achievements_counts")
|
|
1434
|
+
if not isinstance(counts, dict):
|
|
1435
|
+
return {}
|
|
1436
|
+
cleaned: dict[str, int] = {}
|
|
1437
|
+
for key, value in counts.items():
|
|
1438
|
+
coerced = _coerce_int_value(value)
|
|
1439
|
+
if coerced is None:
|
|
1440
|
+
continue
|
|
1441
|
+
cleaned[str(key)] = coerced
|
|
1442
|
+
return cleaned
|
|
1443
|
+
|
|
1444
|
+
def _summarize_tool_calls(tool_calls: Any) -> list[dict[str, Any]]:
|
|
1445
|
+
if not tool_calls:
|
|
1446
|
+
return []
|
|
1447
|
+
try:
|
|
1448
|
+
items = (
|
|
1449
|
+
tool_calls
|
|
1450
|
+
if isinstance(tool_calls, list)
|
|
1451
|
+
else list(tool_calls) # tolerates tuples or pydantic lists
|
|
1452
|
+
)
|
|
1453
|
+
except Exception:
|
|
1454
|
+
return []
|
|
1455
|
+
summary: list[dict[str, Any]] = []
|
|
1456
|
+
for tc in items:
|
|
1457
|
+
tool_name = None
|
|
1458
|
+
args: Any = {}
|
|
1459
|
+
if isinstance(tc, dict):
|
|
1460
|
+
tool_name = tc.get("tool") or tc.get("tool_name") or tc.get("name")
|
|
1461
|
+
raw_args = tc.get("arguments") or tc.get("args") or {}
|
|
1462
|
+
else:
|
|
1463
|
+
tool_name = getattr(tc, "tool", None) or getattr(tc, "tool_name", None)
|
|
1464
|
+
raw_args = getattr(tc, "arguments", None) or getattr(tc, "args", None) or {}
|
|
1465
|
+
args = raw_args
|
|
1466
|
+
if isinstance(raw_args, str):
|
|
1467
|
+
try:
|
|
1468
|
+
args = json.loads(raw_args)
|
|
1469
|
+
except Exception:
|
|
1470
|
+
args = raw_args
|
|
1471
|
+
summary.append({"tool": tool_name, "args": args})
|
|
1472
|
+
return summary
|
|
1473
|
+
|
+        decision_samples: list[dict[str, Any]] = []
+        decision_index = 0
+        decision_open = False
+        session_trace = None
+        finalized = False
+        prev_achievements = _extract_achievements(current_obs)
+        prev_inventory_state = _extract_inventory(current_obs)
+        prev_achievement_counts_state = _extract_achievement_counts(current_obs)
+        # Track episode-level achievements that have been seen as true at any point so far
+        episode_seen_achievements: set[str] = {
+            k for k, v in (prev_achievements or {}).items() if bool(v)
+        }
+        episode_achievement_counts: dict[str, int] = {}
+        stepwise_indicator_sum = 0.0
+        stepwise_reward_sum = 0.0
+        stepwise_resource_reward_sum = 0.0
+        stepwise_new_achievements_total = 0
+        final_achievement_count = sum(1 for v in prev_achievements.values() if v)
+
+        # Execute ops sequence (capped by env_params.max_steps_per_episode)
+        for op_idx, op in enumerate(ops_seq):
+            # Check for abort
+            if registry.is_run_aborted(request.run_id):
+                logger.info(f"Run {request.run_id} aborted at op {op_idx}")
+                break
+
+            # Check safety limits
+            if ops_executed >= request.safety.max_ops:
+                logger.warning(f"Reached max_ops limit ({request.safety.max_ops})")
+                break
+
+            if op == "agent":
+                # Policy step
+                from .policy_routes import PolicyStepRequest, step_policy
+
+                if not decision_open:
+                    await tracing_context.start_decision(decision_index)
+                    decision_open = True
+
+                agent_request_start = _time.perf_counter()
+                if last_agent_response_ts is not None and last_policy_meta is not None:
+                    with contextlib.suppress(Exception):
+                        timing_prev = last_policy_meta.setdefault("timing", {})
+                        decision_ms = max(
+                            0.0,
+                            (agent_request_start - float(last_agent_response_ts)) * 1000.0,
+                        )
+                        # Update timing on prior policy meta (kept by previous env step)
+                        timing_prev["decision_ms"] = decision_ms
+                        if last_env_step_ms is not None:
+                            timing_prev["env_step_ms"] = float(last_env_step_ms)
+                            timing_prev["overhead_ms"] = max(
+                                0.0, decision_ms - float(last_env_step_ms)
+                            )
+                        else:
+                            timing_prev.setdefault("overhead_ms", 0.0)
+                        timing_prev["decision_ready_s"] = agent_request_start
+                        # Also backfill the last appended trajectory step so the trainer
+                        # can always see decision_ms without relying on shared dict refs.
+                        if trajectory_steps:
+                            with contextlib.suppress(Exception):
+                                _last = trajectory_steps[-1]
+                                _info = dict(_last.info or {})
+                                _meta = dict(_info.get("meta") or {})
+                                _timing = dict(_meta.get("timing") or {})
+                                _timing["decision_ms"] = decision_ms
+                                if last_env_step_ms is not None:
+                                    _timing.setdefault("env_step_ms", float(last_env_step_ms))
+                                    _timing.setdefault(
+                                        "overhead_ms",
+                                        max(0.0, decision_ms - float(last_env_step_ms)),
+                                    )
+                                else:
+                                    _timing.setdefault("overhead_ms", 0.0)
+                                _meta["timing"] = _timing
+                                _info["meta"] = _meta
+                                _last.info = _info
+                last_env_step_ms = None
+                last_env_step_completed_ts = None
+
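For orientation, the timing backfill above treats `decision_ms` as the wall-clock gap between the previous policy response and the next policy request, and `overhead_ms` as whatever part of that gap the environment step does not account for; a small arithmetic sketch with made-up numbers:

```python
# Illustrative values only (perf_counter() seconds and milliseconds).
last_agent_response_ts = 10.000  # when the previous policy response arrived
agent_request_start = 10.250     # when the next policy request is issued
last_env_step_ms = 180.0         # duration of the env step in between

decision_ms = max(0.0, (agent_request_start - last_agent_response_ts) * 1000.0)  # 250.0
overhead_ms = max(0.0, decision_ms - last_env_step_ms)                           # 70.0
print(decision_ms, overhead_ms)
```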
+                # Build metadata for policy (carry previous tool_calls and env result)
+                metadata = {}
+                if pending_tool_calls:
+                    metadata["prev_tool_calls"] = pending_tool_calls
+                if len(trajectory_steps) > 0:
+                    last_step = trajectory_steps[-1]
+                    # Prefer the last executed tool calls to seed history
+                    if last_step.tool_calls:
+                        metadata["prev_tool_calls"] = last_step.tool_calls
+                    # Provide a compact env result snapshot
+                    metadata["prev_env_result"] = {
+                        "observation": last_step.obs,
+                        "reward": last_step.reward,
+                        "done": last_step.done,
+                        "truncated": last_step.truncated,
+                        "info": last_step.info,
+                    }
+
+                # Log compact metadata summary to confirm history threading
+                with contextlib.suppress(Exception):
+                    _prev_calls = metadata.get("prev_tool_calls")
+                    _count = len(_prev_calls) if isinstance(_prev_calls, list) else 0
+                    _first_guess = None
+                    if _count > 0 and isinstance(_prev_calls[0], dict):
+                        _args = _prev_calls[0].get("arguments", None)
+                        if isinstance(_args, str):
+                            import json as _json
+                            with contextlib.suppress(Exception):
+                                _args = _json.loads(_args)
+                            if not isinstance(_args, dict):
+                                _args = {}
+                        _first_guess = _args.get("guess") or _args.get("word")
+                    logger.info(
+                        "POLICY_METADATA: prev_tool_calls=%d first_guess=%r has_prev_env_result=%s",
+                        _count,
+                        _first_guess,
+                        str("prev_env_result" in metadata),
+                    )
+
+                try:
+                    policy_response = await step_policy(
+                        PolicyStepRequest(
+                            policy_id=policy_id,
+                            observation=current_obs,
+                            metadata=metadata,
+                        ),
+                        req,
+                    )
+                except Exception as _pe:
+                    # Hard fail the rollout on policy step error (e.g., inference auth 4xx)
+                    logger.error(
+                        "POLICY_STEP_HARD_FAIL: run_id=%s op_idx=%s err=%s",
+                        request.run_id,
+                        str(op_idx),
+                        str(_pe),
+                    )
+                    raise HTTPException(status_code=500, detail=f"policy_step_failed: {str(_pe)}")
+
+                agent_response_ts = _time.perf_counter()
+                if isinstance(policy_response.meta, dict):
+                    with contextlib.suppress(Exception):
+                        timing_cur = policy_response.meta.setdefault("timing", {})
+                        timing_cur["agent_request_start_s"] = agent_request_start
+                        timing_cur["agent_response_s"] = agent_response_ts
+                        if "inference_ms" in policy_response.meta:
+                            with contextlib.suppress(Exception):
+                                timing_cur.setdefault(
+                                    "inference_ms",
+                                    float(policy_response.meta["inference_ms"]),
+                                )
+                                timing_cur.setdefault(
+                                    "inference_s",
+                                    float(policy_response.meta["inference_ms"]) / 1000.0,
+                                )
+                    last_policy_meta = policy_response.meta
+                else:
+                    last_policy_meta = None
+                last_agent_response_ts = agent_response_ts
+
+                # Diagnostic: summarize policy step target and tool calls
+                try:
+                    model_name = None
+                    target_url = None
+                    if isinstance(policy_response.meta, dict):
+                        req_body = policy_response.meta.get("inference_request") or {}
+                        model_name = req_body.get("model")
+                        target_url = policy_response.meta.get("inference_url")
+                    _tc = policy_response.tool_calls or []
+                    print(
+                        {
+                            "rollout.policy_step": True,
+                            "run_id": request.run_id,
+                            "model": model_name,
+                            "inference_url": target_url,
+                            "tool_calls_count": len(_tc) if isinstance(_tc, list) else 0,
+                        },
+                        flush=True,
+                    )
+                except Exception:
+                    pass
+
+                pending_tool_calls = policy_response.tool_calls
+                # Log summarized agent tool calls
+                with contextlib.suppress(Exception):
+                    _tc = pending_tool_calls or []
+                    _summary = []
+                    for _item in (_tc if isinstance(_tc, list) else []):
+                        try:
+                            if isinstance(_item, dict):
+                                _tool = _item.get("tool")
+                                _args = _item.get("args")
+                                _keys = list(_args.keys()) if isinstance(_args, dict) else []
+                                _summary.append({"tool": _tool, "args_keys": _keys})
+                        except Exception:
+                            continue
+                    _rid = getattr(request, "run_id", None)
+                    logger.info("AGENT_TOOL_CALLS: run_id=%s count=%d summary=%s", _rid, len(_tc), _summary)
+                    print(f"[rollout] agent tool_calls run_id={_rid} count={len(_tc)} summary={_summary}", flush=True)
+                await tracing_context.record_tool_invocation(pending_tool_calls)
+                ops_executed += 1
+
+            elif op == "env":
+                if not pending_tool_calls:
+                    # Instead of failing, inject a no-op action to keep the rollout going
+                    with contextlib.suppress(Exception):
+                        logger.warning(
+                            "POLICY_STEP_NOOP: missing tool_calls; injecting noop action run_id=%s op_idx=%s",
+                            request.run_id,
+                            str(op_idx),
+                        )
+                    # Create a noop tool call in the format expected by the environment
+                    pending_tool_calls = [
+                        {
+                            "id": f"noop_{op_idx}",
+                            "tool": "interact",
+                            "arguments": {"action": "noop"},
+                        }
+                    ]
+
+                # Environment step
+                from .environment_routes import EnvStepRequest, step_environment
+
+                env_step_error: Exception | None = None
+                env_response = None
+                env_step_start = _time.perf_counter()
+                try:
+                    env_response = await step_environment(
+                        EnvStepRequest(
+                            env_id=env_id,
+                            tool_calls=pending_tool_calls,
+                        )
+                    )
+                except Exception as _ee:
+                    env_step_error = _ee
+                env_step_end = _time.perf_counter()
+                env_step_duration_ms = (env_step_end - env_step_start) * 1000.0
+                last_env_step_ms = env_step_duration_ms
+                last_env_step_completed_ts = env_step_end
+                if last_policy_meta is not None:
+                    with contextlib.suppress(Exception):
+                        timing_env = last_policy_meta.setdefault("timing", {})
+                        timing_env["env_step_ms"] = env_step_duration_ms
+                        timing_env["env_step_end_s"] = env_step_end
+
+                if env_step_error is not None:
+                    with contextlib.suppress(Exception):
+                        logger.warning(
+                            "ENV_STEP_FAIL: failing rollout run_id=%s op_idx=%s err=%s",
+                            request.run_id,
+                            str(op_idx),
+                            str(env_step_error),
+                        )
+                    raise HTTPException(
+                        status_code=500,
+                        detail=f"env_step_failed: {str(env_step_error)}",
+                    )
+
+                # Reaching here means env step succeeded
+                assert env_response is not None
+
+                # Record step, including policy meta if present for timing/tokens observability
+                _info = env_response.info if isinstance(env_response.info, dict) else {}
+                # Attach policy meta from the immediately preceding agent step
+                with contextlib.suppress(Exception):
+                    prev_meta = {}
+                    if "policy_response" in locals() and isinstance(policy_response.meta, dict):  # type: ignore[name-defined]
+                        prev_meta = policy_response.meta
+                    if prev_meta:
+                        _info = dict(_info)
+                        _info["meta"] = prev_meta
+
+                event_metadata = {
+                    "op_index": op_idx,
+                }
+                event_id = await tracing_context.record_environment_event(
+                    env_handle=env_handle,
+                    prev_obs=current_obs,
+                    env_response=env_response,
+                    next_obs=getattr(env_response, "observation", None),
+                    metadata=event_metadata,
+                )
+
+                decision_index += 1
+                next_obs = env_response.observation
+                new_achievement_state = _extract_achievements(next_obs)
+                new_inventory_state = _extract_inventory(next_obs)
+                new_achievement_counts_state = _extract_achievement_counts(next_obs)
+                final_achievement_count = sum(
+                    1 for _, unlocked in new_achievement_state.items() if unlocked
+                )
+                indicator_val = 0
+                reward_stepwise = 0.0
+                decision_rewards_meta: dict[str, Any] | None = None
+                decision_record = None
+                _info = {} if not isinstance(_info, dict) else dict(_info)
+                if step_rewards_active:
+                    decision_actions = _summarize_tool_calls(pending_tool_calls)
+                    stepwise_info, decision_record, stats = compute_stepwise_reward(
+                        prev_achievements or {},
+                        new_achievement_state,
+                        decision_index,
+                        decision_actions,
+                        step_rewards_indicator_lambda,
+                        strategy=step_rewards_strategy,
+                        weights=step_rewards_weights,
+                        k_limits=step_rewards_k_limits,
+                        episode_counts=episode_achievement_counts,
+                        prev_inventory=prev_inventory_state,
+                        new_inventory=new_inventory_state,
+                        prev_counts=prev_achievement_counts_state,
+                        new_counts=new_achievement_counts_state,
+                    )
+                    indicator_val = int(stats.get("indicator", 0.0))
+                    reward_stepwise = float(stats.get("reward", 0.0))
+                    stepwise_indicator_sum += float(stats.get("indicator", 0.0))
+                    stepwise_reward_sum += reward_stepwise
+                    stepwise_new_achievements_total += int(stats.get("new_achievements_count", 0.0))
+                    with contextlib.suppress(Exception):
+                        resource_component = stats.get("resource_reward")
+                        if resource_component is not None:
+                            stepwise_resource_reward_sum += float(resource_component)
+                    _info["stepwise"] = stepwise_info
+                    # Compute decision-level rewards (absolute vs unique) and attach to metadata
+                    with contextlib.suppress(Exception):
+                        turned_true = set(stepwise_info.get("new_achievements") or [])
+                        seen_before = set(episode_seen_achievements)
+                        new_unique = sorted(turned_true - seen_before)
+                        ach_delta = int(len(turned_true))
+                        unique_delta = int(len(new_unique))
+                        # Prepare stable lists for logging/metadata
+                        all_list = sorted(turned_true)
+                        # Ensure nested meta exists
+                        meta_block = (
+                            _info.get("meta") if isinstance(_info.get("meta"), dict) else {}
+                        )
+                        decision_rewards = {
+                            "turn": int(decision_index),
+                            "ach_delta": ach_delta,
+                            "unique_delta": unique_delta,
+                            "all": all_list,
+                            "unique": new_unique,
+                        }
+                        decision_rewards_meta = decision_rewards
+                        meta_block["decision_rewards"] = decision_rewards
+                        _info["meta"] = meta_block
+                        # Update episode-level seen set after attributing uniqueness to this decision
+                        episode_seen_achievements.update(turned_true)
+                    if decision_record is not None:
+                        decision_samples.append(decision_record)
+                prev_achievements = new_achievement_state
+                prev_inventory_state = new_inventory_state
+                prev_achievement_counts_state = new_achievement_counts_state
+
+                await tracing_context.record_decision_reward(
+                    event_id=event_id,
+                    decision_meta=decision_rewards_meta,
+                )
+
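The decision-reward block above distinguishes achievements that flipped true on this decision (`ach_delta`) from those seen for the first time in the episode (`unique_delta`); a small sketch of that bookkeeping with hypothetical Crafter achievement names:

```python
# Illustrative only: episode-level seen set vs. per-decision deltas.
episode_seen = {"collect_wood"}
turned_true = {"collect_wood", "place_table"}  # flipped to true on this decision

new_unique = sorted(turned_true - episode_seen)
decision_rewards = {
    "turn": 3,
    "ach_delta": len(turned_true),    # 2: everything that flipped true this turn
    "unique_delta": len(new_unique),  # 1: only first-time-in-episode achievements
    "all": sorted(turned_true),
    "unique": new_unique,
}
episode_seen.update(turned_true)  # uniqueness is attributed before updating the set
print(decision_rewards)
```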
+                step = RolloutStep(
+                    obs=_summarize_observation_for_storage(env_handle, current_obs),
+                    tool_calls=pending_tool_calls,
+                    reward=env_response.reward,
+                    done=env_response.done,
+                    truncated=env_response.truncated,
+                    info=_info,
+                )
+                # Log summarized env application of tool calls and immediate reward/done
+                with contextlib.suppress(Exception):
+                    _tc = pending_tool_calls or []
+                    _summary = []
+                    for _item in (_tc if isinstance(_tc, list) else []):
+                        try:
+                            if isinstance(_item, dict):
+                                _tool = _item.get("tool")
+                                _args = _item.get("args")
+                                _keys = list(_args.keys()) if isinstance(_args, dict) else []
+                                _summary.append({"tool": _tool, "args_keys": _keys})
+                        except Exception:
+                            continue
+                    _rid = getattr(request, "run_id", None)
+                    logger.info(
+                        "ENV_APPLY: run_id=%s tool_calls=%d reward=%s done=%s summary=%s",
+                        _rid,
+                        len(_tc),
+                        str(env_response.reward),
+                        str(env_response.done),
+                        _summary,
+                    )
+                    print(
+                        f"[rollout] env apply run_id={_rid} tool_calls={len(_tc)} reward={env_response.reward} done={env_response.done} summary={_summary}",
+                        flush=True,
+                    )
+                trajectory_steps.append(step)
+
+                if env_response.reward is not None:
+                    total_reward += env_response.reward
+
+                # Update state
+                current_obs = next_obs
+                pending_tool_calls = None
+                ops_executed += 1
+
+                # Handle episode end
+                if env_response.done:
+                    if request.on_done == "reset":
+                        # Reset environment
+                        from .environment_routes import (
+                            EnvResetRequest,
+                            reset_environment,
+                        )
+
+                        reset_response = await reset_environment(EnvResetRequest(env_id=env_id))
+                        current_obs = reset_response.observation
+                        prev_achievements = _extract_achievements(current_obs)
+                        episode_seen_achievements = {
+                            k for k, v in (prev_achievements or {}).items() if bool(v)
+                        }
+                        episode_achievement_counts.clear()
+                    elif request.on_done == "terminate":
+                        break
+
+                if decision_open:
+                    await tracing_context.end_decision()
+                    decision_open = False
+
+            else:
+                logger.warning(f"Unknown op: {op}")
+
+        if (
+            last_policy_meta is not None
+            and last_agent_response_ts is not None
+            and "timing" in last_policy_meta
+            and isinstance(last_policy_meta["timing"], dict)
+            and "decision_ms" not in last_policy_meta["timing"]
+        ):
+            with contextlib.suppress(Exception):
+                final_now = last_env_step_completed_ts or _time.perf_counter()
+                final_decision_ms = max(0.0, (final_now - float(last_agent_response_ts)) * 1000.0)
+                timing_final = last_policy_meta.setdefault("timing", {})
+                timing_final["decision_ms"] = final_decision_ms
+                if last_env_step_ms is not None:
+                    timing_final.setdefault("env_step_ms", float(last_env_step_ms))
+                    timing_final.setdefault(
+                        "overhead_ms",
+                        max(0.0, final_decision_ms - float(last_env_step_ms)),
+                    )
+                else:
+                    timing_final.setdefault("overhead_ms", 0.0)
+
+        # Build trajectory
+        # Extract inference_url from policy config (REQUIRED for trace correlation)
+        # The trainer sets this in policy config with ?cid=... parameter
+        inference_url = None
+
+        # Try policy config from request first (most reliable source)
+        try:
+            policy_config_snapshot = (
+                request.policy.config if isinstance(request.policy.config, dict) else {}
+            )
+            inference_url = policy_config_snapshot.get("inference_url")
+            if inference_url:
+                logger.info(
+                    "ROLLOUT_TRAJECTORY: extracted inference_url from request.policy.config run_id=%s url=%s",
+                    request.run_id,
+                    inference_url,
+                )
+        except Exception as exc:
+            logger.warning(
+                "ROLLOUT_TRAJECTORY: failed to get inference_url from request.policy.config run_id=%s: %s",
+                request.run_id,
+                exc,
+            )
+
+        # Fallback: Try policy handle snapshot (if request.policy.config failed)
+        if not inference_url and policy_handle is not None:
+            try:
+                policy_snapshot = policy_handle.snapshot()
+                inference_url = policy_snapshot.get("config", {}).get("inference_url")
+                if inference_url:
+                    logger.info(
+                        "ROLLOUT_TRAJECTORY: extracted inference_url from policy_handle.snapshot run_id=%s url=%s",
+                        request.run_id,
+                        inference_url,
+                    )
+            except Exception as exc:
+                logger.warning(
+                    "ROLLOUT_TRAJECTORY: failed to snapshot policy for run_id=%s policy_id=%s: %s",
+                    request.run_id,
+                    policy_id,
+                    exc,
+                )
+
+        # ASSERTION: inference_url MUST be present (required by RolloutTrajectory schema)
+        if not inference_url:
+            raise ValueError(
+                f"FATAL: inference_url is required but not found!\n"
+                f"\n"
+                f"run_id: {request.run_id}\n"
+                f"policy_id: {policy_id}\n"
+                f"policy_config_keys: {list(policy_config_snapshot.keys()) if 'policy_config_snapshot' in locals() else 'N/A'}\n"
+                f"\n"
+                f"The trainer MUST set inference_url in policy config with ?cid=... parameter.\n"
+                f"This is required for trace correlation and hydration.\n"
+            )
+
+        # policy_config_snapshot already set above in try block (line 1876-1878)
+        # Ensure it exists for logging below
+        if 'policy_config_snapshot' not in locals():
+            policy_config_snapshot = {}
+
+        # Normalize inference URL for trajectory (and ensure no path in query)
+        try:
+            from .utils import force_normalize_chat_completions_url, ensure_chat_completions_url
+            inference_url = force_normalize_chat_completions_url(inference_url)
+            # apply mode-aware normalization too (keeps cid, appends path if missing)
+            inference_url = ensure_chat_completions_url(inference_url, mode=request.mode)
+        except Exception:
+            pass
+
+        logger.info(
+            "ROLLOUT_TRAJECTORY: run_id=%s policy_id=%s inference_url=%s trace_id=%s",
+            request.run_id,
+            policy_id,
+            inference_url,
+            policy_config_snapshot.get("trace_correlation_id"),
+        )
+
+        trajectory = RolloutTrajectory(
+            env_id=env_id,
+            policy_id=policy_id,
+            steps=trajectory_steps,
+            final={"observation": _summarize_observation_for_storage(env_handle, current_obs)},
+            length=len(trajectory_steps),
+            inference_url=inference_url,  # NEW: Required for trace correlation
+            decision_samples=decision_samples if step_rewards_active else None,
+        )
+
+        # Build metrics
+        metrics = RolloutMetrics(
+            episode_returns=[total_reward],
+            mean_return=total_reward,
+            num_steps=len(trajectory_steps),
+            num_episodes=1,
+        )
+        if step_rewards_active:
+            stepwise_summary: dict[str, Any] = {
+                "indicator_sum": float(stepwise_indicator_sum),
+                "reward_sum": float(stepwise_reward_sum),
+                "resource_reward": float(stepwise_resource_reward_sum),
+                "new_achievements_total": int(stepwise_new_achievements_total),
+                "mode": step_rewards_mode,
+                "strategy": step_rewards_strategy,
+                "indicator_lambda": float(step_rewards_indicator_lambda),
+            }
+            if step_rewards_beta:
+                stepwise_summary["step_beta"] = float(step_rewards_beta)
+            if step_rewards_strategy == "per_achievement":
+                if step_rewards_weights:
+                    stepwise_summary["weights"] = dict(step_rewards_weights)
+                if step_rewards_k_limits:
+                    stepwise_summary["k_limits"] = dict(step_rewards_k_limits)
+            final_achievements_list = sorted(
+                key for key, val in (prev_achievements or {}).items() if bool(val)
+            )
+            stepwise_summary["unique_achievements_total"] = int(len(episode_seen_achievements))
+            stepwise_summary["unique_achievements"] = sorted(episode_seen_achievements)
+            stepwise_summary["final_achievements"] = final_achievements_list
+            metrics.details["stepwise"] = stepwise_summary
+
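Roughly, the `stepwise` block attached to `metrics.details` above ends up shaped like the dict below; the concrete values and the `mode` string are illustrative rather than taken from the package, and `weights`/`k_limits` only appear for the `per_achievement` strategy:

```python
# Hypothetical example of metrics.details["stepwise"] after a rollout.
stepwise_summary = {
    "indicator_sum": 4.0,
    "reward_sum": 4.5,
    "resource_reward": 0.5,
    "new_achievements_total": 5,
    "mode": "decision_stepwise",        # assumption: whatever step_rewards_mode was configured
    "strategy": "per_achievement",
    "indicator_lambda": 1.0,
    "step_beta": 0.0,
    "weights": {"collect_wood": 1.0},
    "k_limits": {"collect_wood": 3},
    "unique_achievements_total": 2,
    "unique_achievements": ["collect_wood", "place_table"],
    "final_achievements": ["collect_wood", "place_table"],
}
print(sorted(stepwise_summary))
```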
+        # Environment-specific: Log summary if available
+        try:
+            # Check if this is a Wordle environment and use Wordle helpers (lazy import)
+            wordle_wrapper_cls = None
+            try:
+                from .envs.wordle.environment import WordleEnvironmentWrapper
+                from .envs.wordle.helpers import (
+                    get_wordle_rollout_summary,
+                    log_wordle_rollout_summary,
+                )
+
+                wordle_wrapper_cls = WordleEnvironmentWrapper
+            except Exception:
+                wordle_wrapper_cls = None  # type: ignore[assignment]
+                get_wordle_rollout_summary = None  # type: ignore
+                log_wordle_rollout_summary = None  # type: ignore
+
+            is_wordle = wordle_wrapper_cls is not None and isinstance(
+                env_handle.env,
+                wordle_wrapper_cls,  # type: ignore[arg-type]
+            )
+            if is_wordle:
+                # Convert trajectory steps to expected format
+                formatted_steps = []
+                for step in trajectory_steps:
+                    formatted_steps.append({"tool_calls": step.tool_calls or []})
+
+                if (
+                    get_wordle_rollout_summary is not None
+                    and log_wordle_rollout_summary is not None
+                ):
+                    summary = get_wordle_rollout_summary(formatted_steps, current_obs, env_handle)
+                    log_wordle_rollout_summary(request.run_id, summary)
+        except ImportError:
+            # Wordle helpers not available, skip Wordle-specific logging
+            pass
+        except Exception as e:
+            logger.warning(f"Failed to generate environment-specific summary: {e}")
+
+        # Mark run as completed
+        aborted = registry.is_run_aborted(request.run_id)
+        if not aborted:
+            registry.complete_run(request.run_id)
+        if decision_open:
+            await tracing_context.end_decision()
+            decision_open = False
+        if not finalized:
+            session_trace = await tracing_context.finalize(
+                total_reward=total_reward,
+                achievement_state=prev_achievements,
+                total_steps=len(trajectory_steps),
+            )
+            finalized = True
+        trace_payload = tracing_context.build_trace_payload(session_trace)
+
+        # Debug: Check trace payload
+        logger.info(f"[TRACE_DEBUG] trace_payload is None: {trace_payload is None}, return_trace={tracing_context.return_trace}")
+        if trace_payload:
+            logger.info(f"[TRACE_DEBUG] trace_payload keys: {list(trace_payload.keys())}")
+
+        # Hard-fail if no steps executed (avg_turns == 0 scenario)
+        if metrics.num_steps <= 0:
+            raise HTTPException(status_code=500, detail="no_steps_executed: avg_turns == 0")
+
+        # Ensure at least one tool call executed successfully
+        tool_call_executed = any(
+            isinstance(step.tool_calls, list) and len(step.tool_calls) > 0 for step in trajectory_steps
+        )
+        if not tool_call_executed:
+            raise HTTPException(
+                status_code=502,
+                detail="no_tool_calls_executed: model failed to produce actionable tool calls.",
+            )
+
+        response = RolloutResponse(
+            run_id=request.run_id,
+            trajectories=[trajectory],
+            branches={},
+            metrics=metrics,
+            aborted=aborted,
+            ops_executed=ops_executed,
+            trace=trace_payload,
+        )
+        logger.info(
+            "ROLLOUT_RESPONSE: run_id=%s aborted=%s ops_executed=%s metrics_steps=%s trace_present=%s pipeline_metadata=%s",
+            request.run_id,
+            aborted,
+            ops_executed,
+            metrics.num_steps,
+            bool(trace_payload),
+            response.pipeline_metadata,
+        )
+        return response
+
+    except Exception as e:
+        logger.error(f"Rollout failed for run {request.run_id}: {e}")
+        registry.abort_run(request.run_id)
+        if decision_open:
+            with contextlib.suppress(Exception):
+                await tracing_context.end_decision()
+            decision_open = False
+        if not finalized:
+            session_trace = None
+            with contextlib.suppress(Exception):
+                session_trace = await tracing_context.finalize(
+                    total_reward=total_reward,
+                    achievement_state=prev_achievements,
+                    total_steps=len(trajectory_steps),
+                )
+            finalized = True
+        raise HTTPException(status_code=500, detail=str(e)) from e
+    finally:
+        # Ensure any environment created for this rollout is terminated (no reuse across rollouts)
+        try:
+            if created_env_id:
+                from .environment_routes import EnvTerminateRequest, terminate_environment
+
+                await terminate_environment(EnvTerminateRequest(env_id=created_env_id))
+                logger.info(
+                    "ROLL_OUT: terminated environment env_id=%s seed=%s",
+                    str(created_env_id),
+                    str(env_seed_used) if env_seed_used is not None else "unknown",
+                )
+                # Verify removal from registry
+                with contextlib.suppress(Exception):
+                    _post = registry.get_env(created_env_id)
+                    logger.info(
+                        "ROLL_OUT: env_killed=%s (post_lookup=%s)",
+                        str(_post is None),
+                        str(_post),
+                    )
+        except Exception as _te:
+            logger.warning(f"ROLL_OUT: failed to terminate environment {created_env_id}: {_te}")
+
+        # Best-effort policy cleanup if we created one (avoid reuse across rollouts)
+        with contextlib.suppress(Exception):
+            if created_policy_id:
+                from .policy_routes import PolicyTerminateRequest, terminate_policy
+
+                await terminate_policy(PolicyTerminateRequest(policy_id=created_policy_id))
+                logger.info("ROLL_OUT: terminated policy policy_id=%s", str(created_policy_id))
+
+        if not finalized:
+            session_trace = None
+            with contextlib.suppress(Exception):
+                session_trace = await tracing_context.finalize(
+                    total_reward=total_reward,
+                    achievement_state=prev_achievements,
+                    total_steps=len(trajectory_steps),
+                )
+            finalized = True
+
+        with contextlib.suppress(Exception):
+            _clear_seed_side_effects()
+            logger.info("ROLL_OUT: RNG seed terminated/cleared before conclusion")
+
+
+@router.post("/run/abort", response_model=RunAbortResponse)
+async def abort_run(request: RunAbortRequest) -> RunAbortResponse:
+    """Abort a running rollout."""
+    success = registry.abort_run(request.run_id)
+
+    if not success:
+        raise HTTPException(
+            status_code=404,
+            detail=f"Run {request.run_id} not found",
+        )
+
+    return RunAbortResponse(
+        ok=True,
+        run_id=request.run_id,
+    )
+
+
+@router.get("/run/status/{run_id}", response_model=RunStatusResponse)
+async def get_run_status(run_id: str) -> RunStatusResponse:
+    """Get the status of a run."""
+    run_handle = registry.get_run(run_id)
+
+    if not run_handle:
+        raise HTTPException(
+            status_code=404,
+            detail=f"Run {run_id} not found",
+        )
+
+    return RunStatusResponse(
+        run_id=run_id,
+        status=run_handle.status,
+        started_at=run_handle.started_at,
+        finished_at=run_handle.finished_at,
+    )