synth-ai 0.2.13.dev1__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/multi_step/configs/README_verilog_rl.md +77 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +90 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +183 -0
- examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
- examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +17 -5
- examples/multi_step/configs/crafter_synth_backend.md +40 -0
- examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
- examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
- examples/multi_step/configs/verilog_rl_lora.toml +190 -0
- examples/multi_step/judges/crafter_backend_judge.py +220 -0
- examples/multi_step/judges/verilog_backend_judge.py +234 -0
- examples/multi_step/readme.md +48 -0
- examples/multi_step/verilog_rl_lora.md +218 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -1
- examples/sft/evaluate.py +2 -0
- examples/sft/generate_traces.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +56 -26
- examples/swe/task_app/hosted/rollout.py +42 -0
- examples/swe/task_app/hosted/test_service.py +5 -6
- examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
- examples/task_apps/TESTING.md +275 -0
- examples/task_apps/__init__.py +0 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +273 -0
- examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +174 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +268 -0
- examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
- examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
- examples/task_apps/crafter/__init__.py +0 -0
- examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
- examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
- examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
- examples/task_apps/crafter/task_app/__init__.py +5 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +324 -21
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +10 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +76 -7
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +17 -2
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +25 -3
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +77 -4
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +117 -9
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +218 -0
- examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
- examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
- examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
- examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
- examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
- examples/task_apps/enron/filter_sft.toml +5 -0
- examples/task_apps/enron/task_app/README.md +14 -0
- examples/task_apps/enron/task_app/__init__.py +1 -0
- examples/task_apps/enron/task_app/grpo_enron.py +906 -0
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
- examples/task_apps/enron/tests/__init__.py +4 -0
- examples/task_apps/enron/tests/conftest.py +115 -0
- examples/task_apps/enron/tests/integration/__init__.py +4 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +179 -0
- examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
- examples/task_apps/enron/tests/unit/__init__.py +4 -0
- examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
- examples/task_apps/math/__init__.py +0 -0
- examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
- examples/task_apps/pokemon_battle/__init__.py +2 -0
- examples/task_apps/pokemon_battle/modal_app.py +104 -0
- examples/task_apps/pokemon_battle/task_app/README.md +68 -0
- examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
- examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
- examples/task_apps/pokemon_red/README.md +357 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +415 -0
- examples/task_apps/pokemon_red/__init__.py +3 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +29 -0
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +75 -0
- examples/task_apps/pokemon_red/task_app.py +799 -0
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +193 -0
- examples/task_apps/sokoban/README.md +307 -0
- examples/task_apps/sokoban/__init__.py +3 -0
- examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
- examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
- examples/task_apps/sokoban/filter_sft.toml +5 -0
- examples/task_apps/sokoban/task_app.py +1058 -0
- examples/task_apps/sokoban/tests/__init__.py +4 -0
- examples/task_apps/sokoban/tests/conftest.py +113 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +4 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +4 -0
- examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
- examples/task_apps/verilog/__init__.py +1 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +24 -0
- examples/task_apps/verilog/filter_sft.toml +5 -0
- examples/task_apps/verilog/task_app/README.md +12 -0
- examples/task_apps/verilog/task_app/__init__.py +1 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +1166 -0
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
- examples/task_apps/verilog/tests/__init__.py +4 -0
- examples/task_apps/verilog/tests/conftest.py +115 -0
- examples/task_apps/verilog/tests/integration/__init__.py +4 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +181 -0
- examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
- examples/task_apps/verilog/tests/unit/__init__.py +4 -0
- examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
- examples/vlm/crafter_openai_vlm_agent.py +4 -4
- examples/vlm/run_crafter_vlm_benchmark.py +4 -4
- examples/warming_up_to_rl/groq_test.py +2 -0
- examples/warming_up_to_rl/run_local_rollout.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +2 -0
- examples/warming_up_to_rl/run_rollout_remote.py +2 -0
- examples/workflows/__init__.py +0 -0
- examples/workflows/math_rl/__init__.py +0 -0
- examples/workflows/math_rl/download_dataset.py +80 -0
- synth_ai/__init__.py +2 -2
- synth_ai/api/models/supported.py +1 -0
- synth_ai/api/train/builders.py +25 -11
- synth_ai/api/train/cli.py +12 -6
- synth_ai/api/train/configs/__init__.py +10 -10
- synth_ai/api/train/configs/rl.py +5 -4
- synth_ai/api/train/configs/sft.py +4 -3
- synth_ai/api/train/env_resolver.py +5 -2
- synth_ai/api/train/supported_algos.py +10 -5
- synth_ai/api/train/utils.py +7 -4
- synth_ai/cli/__init__.py +48 -59
- synth_ai/cli/_modal_wrapper.py +3 -2
- synth_ai/cli/_storage.py +4 -3
- synth_ai/cli/_validate_task_app.py +11 -0
- synth_ai/cli/balance.py +4 -3
- synth_ai/cli/calc.py +2 -2
- synth_ai/cli/demo.py +14 -7
- synth_ai/cli/legacy_root_backup.py +1 -1
- synth_ai/cli/recent.py +1 -1
- synth_ai/cli/rl_demo.py +8 -7
- synth_ai/cli/root.py +0 -97
- synth_ai/cli/status.py +1 -1
- synth_ai/cli/task_apps.py +1922 -190
- synth_ai/cli/traces.py +1 -1
- synth_ai/cli/tui.py +57 -0
- synth_ai/cli/turso.py +1 -1
- synth_ai/cli/watch.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +29 -17
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/enron/engine.py +7 -2
- synth_ai/environments/examples/enron/environment.py +68 -0
- synth_ai/environments/examples/red/engine.py +27 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
- synth_ai/environments/examples/red/environment.py +60 -0
- synth_ai/environments/examples/sokoban/taskset.py +116 -0
- synth_ai/environments/examples/verilog/engine.py +104 -12
- synth_ai/evals/client.py +58 -61
- synth_ai/jobs/client.py +16 -4
- synth_ai/judge_schemas.py +9 -9
- synth_ai/py.typed +0 -0
- synth_ai/task/__init__.py +24 -5
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +257 -0
- synth_ai/task/contracts.py +138 -39
- synth_ai/task/proxy.py +48 -56
- synth_ai/task/rubrics/__init__.py +56 -0
- synth_ai/task/rubrics/loaders.py +152 -0
- synth_ai/task/rubrics/models.py +57 -0
- synth_ai/task/rubrics/scoring.py +116 -0
- synth_ai/{rubrics/validators.py → task/rubrics/strict.py} +53 -30
- synth_ai/task/server.py +8 -7
- synth_ai/task/trace_correlation_helpers.py +315 -0
- synth_ai/task/validators.py +413 -6
- synth_ai/tracing_v3/abstractions.py +3 -3
- synth_ai/tracing_v3/decorators.py +7 -3
- synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
- synth_ai/tracing_v3/replica_sync.py +4 -4
- synth_ai/tracing_v3/serialization.py +5 -5
- synth_ai/tracing_v3/session_tracer.py +16 -6
- synth_ai/tracing_v3/storage/base.py +29 -29
- synth_ai/tracing_v3/storage/config.py +3 -3
- synth_ai/tracing_v3/trace_utils.py +317 -0
- synth_ai/tracing_v3/turso/daemon.py +8 -7
- synth_ai/tracing_v3/turso/native_manager.py +66 -43
- synth_ai/tracing_v3/utils.py +3 -3
- synth_ai/tui/__init__.py +5 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/cli/__init__.py +1 -0
- synth_ai/tui/cli/query_experiments.py +164 -0
- synth_ai/tui/cli/query_experiments_v3.py +164 -0
- synth_ai/tui/dashboard.py +906 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/METADATA +4 -1
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/RECORD +278 -126
- examples/agora_ex/README_MoE.md +0 -224
- examples/agora_ex/__init__.py +0 -7
- examples/agora_ex/agora_ex.py +0 -65
- examples/agora_ex/agora_ex_task_app.py +0 -590
- examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +0 -121
- examples/agora_ex/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/system_prompt_CURRENT.md +0 -63
- examples/agora_ex/task_app/agora_ex_task_app.py +0 -590
- examples/agora_ex/task_app/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/task_app/system_prompt_CURRENT.md +0 -63
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +0 -62
- synth_ai/rubrics/__init__.py +0 -22
- synth_ai/task/rubrics.py +0 -219
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
- /examples/{rl/task_app → task_apps/math}/README.md +0 -0
- /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
- /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/top_level.txt +0 -0
|
@@ -85,8 +85,17 @@ class CrafterReActAgent:
|
|
|
85
85
|
history: list[dict[str, Any]] | None = None,
|
|
86
86
|
turn: int | None = None,
|
|
87
87
|
image_parts: list[dict[str, Any]] | None = None,
|
|
88
|
+
image_only_mode: bool = False,
|
|
88
89
|
) -> list[dict[str, Any]]:
|
|
89
|
-
"""Construct OpenAI-style messages list for vLLM generation.
|
|
90
|
+
"""Construct OpenAI-style messages list for vLLM generation.
|
|
91
|
+
|
|
92
|
+
Args:
|
|
93
|
+
observation: Text observation to include
|
|
94
|
+
history: Previous conversation history
|
|
95
|
+
turn: Current turn number
|
|
96
|
+
image_parts: Image content parts in OpenAI format
|
|
97
|
+
image_only_mode: If True, only include images without text observation
|
|
98
|
+
"""
|
|
90
99
|
msgs: list[dict[str, Any]] = [
|
|
91
100
|
{"role": "system", "content": CrafterReActAgent.get_system_prompt()}
|
|
92
101
|
]
|
|
@@ -94,8 +103,14 @@ class CrafterReActAgent:
|
|
|
94
103
|
msgs.extend(history)
|
|
95
104
|
user_content: Any
|
|
96
105
|
if image_parts:
|
|
97
|
-
|
|
106
|
+
# Image-only mode: send only images without text observation
|
|
107
|
+
if image_only_mode:
|
|
108
|
+
user_content = list(image_parts)
|
|
109
|
+
else:
|
|
110
|
+
# Normal vision mode: send both text and images
|
|
111
|
+
user_content = [{"type": "text", "text": observation}] + list(image_parts)
|
|
98
112
|
else:
|
|
113
|
+
# Text-only mode (default): no images
|
|
99
114
|
user_content = observation
|
|
100
115
|
msgs.append({"role": "user", "content": user_content})
|
|
101
116
|
return msgs
|
|
@@ -149,7 +149,11 @@ class OpenAIClient:
|
|
|
149
149
|
OpenAI-compatible chat completion response
|
|
150
150
|
"""
|
|
151
151
|
base = (base_url or self.base_url).rstrip("/")
|
|
152
|
-
|
|
152
|
+
# Don't append /v1/chat/completions if the URL already contains it
|
|
153
|
+
if "/v1/chat/completions" in base:
|
|
154
|
+
url = base
|
|
155
|
+
else:
|
|
156
|
+
url = base + "/v1/chat/completions"
|
|
153
157
|
timeout = timeout_s or self.timeout_s
|
|
154
158
|
|
|
155
159
|
# Merge headers
|
|
@@ -164,10 +168,28 @@ class OpenAIClient:
|
|
|
164
168
|
except Exception:
|
|
165
169
|
pass
|
|
166
170
|
|
|
167
|
-
#
|
|
171
|
+
# Set Authorization header based on the target URL
|
|
168
172
|
try:
|
|
169
173
|
low_url = (url or "").lower()
|
|
170
|
-
|
|
174
|
+
|
|
175
|
+
# If calling OpenAI directly (api.openai.com)
|
|
176
|
+
if "api.openai.com" in low_url:
|
|
177
|
+
openai_key = os.getenv("OPENAI_API_KEY")
|
|
178
|
+
if openai_key and isinstance(openai_key, str):
|
|
179
|
+
headers["Authorization"] = f"Bearer {openai_key}"
|
|
180
|
+
|
|
181
|
+
# If target is Synth backend (any deployment), use SYNTH_API_KEY
|
|
182
|
+
# Matches: synth-backend-*, agent-learning*, localhost:8000, 127.0.0.1:8000
|
|
183
|
+
elif any(pattern in low_url for pattern in [
|
|
184
|
+
"synth-backend", "synth.run", "agent-learning",
|
|
185
|
+
"localhost:8000", "127.0.0.1:8000"
|
|
186
|
+
]):
|
|
187
|
+
synth_key = os.getenv("SYNTH_API_KEY")
|
|
188
|
+
if synth_key and isinstance(synth_key, str):
|
|
189
|
+
headers["Authorization"] = f"Bearer {synth_key}"
|
|
190
|
+
|
|
191
|
+
# If target is Groq, use GROQ_API_KEY
|
|
192
|
+
elif "/proxy/groq" in low_url or "api.groq.com" in low_url:
|
|
171
193
|
gk = os.getenv("GROQ_API_KEY")
|
|
172
194
|
if gk and isinstance(gk, str):
|
|
173
195
|
headers["Authorization"] = f"Bearer {gk}"
|
|
@@ -10,11 +10,13 @@ from fastapi import APIRouter, HTTPException, Request
|
|
|
10
10
|
from pydantic import BaseModel
|
|
11
11
|
|
|
12
12
|
from synth_ai.task.auth import allowed_environment_api_keys, normalize_environment_api_key
|
|
13
|
+
from synth_ai.task.contracts import RolloutMode
|
|
13
14
|
|
|
14
15
|
from .envs.crafter.policy import CrafterPolicy
|
|
15
16
|
from .inference.openai_client import create_inference_client
|
|
16
17
|
from .registry import registry
|
|
17
18
|
from .storage.volume import storage
|
|
19
|
+
from .utils import ensure_chat_completions_url
|
|
18
20
|
|
|
19
21
|
# Token budgeting (shared logic with inference server)
|
|
20
22
|
try:
|
|
@@ -40,6 +42,7 @@ class PolicyCreateRequest(BaseModel):
|
|
|
40
42
|
parent_policy_id: str | None = None
|
|
41
43
|
rl_run_id: str
|
|
42
44
|
bound_env_id: str | None = None
|
|
45
|
+
mode: RolloutMode
|
|
43
46
|
|
|
44
47
|
|
|
45
48
|
class PolicyCreateResponse(BaseModel):
|
|
@@ -97,10 +100,40 @@ async def create_policy(
|
|
|
97
100
|
|
|
98
101
|
# Set defaults from TaskApp / environment if not provided
|
|
99
102
|
config = dict(request.config or {})
|
|
103
|
+
provider_raw = config.get("provider") or config.get("vendor")
|
|
104
|
+
provider = str(provider_raw).strip().lower() if provider_raw else None
|
|
105
|
+
|
|
106
|
+
# Resolve base URL for proxy endpoints (strip trailing slash)
|
|
107
|
+
base_url = str(req.base_url).rstrip("/")
|
|
108
|
+
|
|
109
|
+
if provider == "groq":
|
|
110
|
+
# Route through in-app Groq proxy by default
|
|
111
|
+
config.setdefault("inference_url", f"{base_url}/proxy/groq")
|
|
112
|
+
# Default to a recent Groq-hosted Qwen unless caller overrides
|
|
113
|
+
preferred_model = "qwen/qwen3-32b"
|
|
114
|
+
config.setdefault("model", preferred_model)
|
|
115
|
+
# Groq Qwen defaults tuned for deterministic tool use
|
|
116
|
+
config.setdefault("temperature", 0.0)
|
|
117
|
+
config.setdefault("top_p", 0.95)
|
|
118
|
+
config.setdefault("max_tokens", 256)
|
|
119
|
+
# Avoid leaking provider in downstream policy if unset
|
|
120
|
+
config["provider"] = "groq"
|
|
121
|
+
elif provider == "openai":
|
|
122
|
+
config.setdefault("inference_url", f"{base_url}/proxy")
|
|
123
|
+
config["provider"] = "openai"
|
|
124
|
+
|
|
125
|
+
received_url = config.get("inference_url")
|
|
126
|
+
logger.info(
|
|
127
|
+
"POLICY_CREATE: policy=%s provider=%s raw_inference_url=%s",
|
|
128
|
+
request.policy_name,
|
|
129
|
+
provider,
|
|
130
|
+
received_url,
|
|
131
|
+
)
|
|
132
|
+
|
|
100
133
|
if "inference_url" not in config and task_app is not None:
|
|
101
|
-
|
|
102
|
-
if
|
|
103
|
-
config["inference_url"] =
|
|
134
|
+
task_base_url = getattr(task_app, "vllm_base_url", None)
|
|
135
|
+
if task_base_url:
|
|
136
|
+
config["inference_url"] = task_base_url
|
|
104
137
|
if "model" not in config and task_app is not None:
|
|
105
138
|
default_model = getattr(task_app, "default_model", None)
|
|
106
139
|
if default_model:
|
|
@@ -111,6 +144,31 @@ async def create_policy(
|
|
|
111
144
|
detail="Policy configuration must include 'inference_url' and 'model'.",
|
|
112
145
|
)
|
|
113
146
|
|
|
147
|
+
# Get mode from PolicyCreateRequest (defaults to "rl" for backward compatibility)
|
|
148
|
+
mode = request.mode
|
|
149
|
+
logger.info("POLICY_CREATE: Using mode=%s for URL processing", mode)
|
|
150
|
+
|
|
151
|
+
sanitized_url = ensure_chat_completions_url(config.get("inference_url"), mode=mode)
|
|
152
|
+
if isinstance(sanitized_url, str) and sanitized_url:
|
|
153
|
+
if sanitized_url != config.get("inference_url"):
|
|
154
|
+
logger.warning(
|
|
155
|
+
"POLICY_CREATE: normalized inference_url for policy=%s provider=%s mode=%s from %s to %s",
|
|
156
|
+
request.policy_name,
|
|
157
|
+
provider,
|
|
158
|
+
mode,
|
|
159
|
+
config.get("inference_url"),
|
|
160
|
+
sanitized_url,
|
|
161
|
+
)
|
|
162
|
+
config["inference_url"] = sanitized_url
|
|
163
|
+
else:
|
|
164
|
+
logger.warning(
|
|
165
|
+
"POLICY_CREATE: unable to normalize inference_url for policy=%s provider=%s mode=%s raw=%s",
|
|
166
|
+
request.policy_name,
|
|
167
|
+
mode,
|
|
168
|
+
provider,
|
|
169
|
+
config.get("inference_url"),
|
|
170
|
+
)
|
|
171
|
+
|
|
114
172
|
# Create policy instance based on name
|
|
115
173
|
pname = request.policy_name.lower()
|
|
116
174
|
if pname in ["crafter-react", "crafter"]:
|
|
@@ -485,7 +543,22 @@ async def step_policy(
|
|
|
485
543
|
|
|
486
544
|
# Ensure meta carries the final target URL for downstream logging/clients
|
|
487
545
|
with contextlib.suppress(Exception):
|
|
488
|
-
|
|
546
|
+
sanitized_target = ensure_chat_completions_url(target_url)
|
|
547
|
+
if sanitized_target and sanitized_target != target_url:
|
|
548
|
+
logger.warning(
|
|
549
|
+
"POLICY_STEP: normalized inference_url mid-flight policy=%s from %s to %s",
|
|
550
|
+
policy_name,
|
|
551
|
+
target_url,
|
|
552
|
+
sanitized_target,
|
|
553
|
+
)
|
|
554
|
+
elif not sanitized_target:
|
|
555
|
+
logger.info(
|
|
556
|
+
"POLICY_STEP: inference_url unchanged policy=%s target=%s",
|
|
557
|
+
policy_name,
|
|
558
|
+
target_url,
|
|
559
|
+
)
|
|
560
|
+
meta["inference_url"] = sanitized_target if sanitized_target else target_url
|
|
561
|
+
target_url = sanitized_target or target_url
|
|
489
562
|
|
|
490
563
|
# Select API key based on resolved target URL
|
|
491
564
|
api_key_override = None
|
|
@@ -13,6 +13,7 @@ from pydantic import BaseModel, Field
|
|
|
13
13
|
from synth_ai.lm.vendors.base import BaseLMResponse
|
|
14
14
|
from synth_ai.task.tracing_utils import unique_sft_path
|
|
15
15
|
from synth_ai.tracing_v3.abstractions import EnvironmentEvent, LMCAISEvent, TimeRecord
|
|
16
|
+
from synth_ai.task.contracts import RolloutMode
|
|
16
17
|
from synth_ai.tracing_v3.llm_call_record_helpers import create_llm_call_record_from_response
|
|
17
18
|
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
18
19
|
|
|
@@ -120,6 +121,8 @@ class RolloutRequest(BaseModel):
|
|
|
120
121
|
# Optional run/session context
|
|
121
122
|
training_session_id: str | None = None
|
|
122
123
|
synth_base_url: str | None = None
|
|
124
|
+
# Mode controls URL transformation: REQUIRED to make intent explicit
|
|
125
|
+
mode: RolloutMode
|
|
123
126
|
|
|
124
127
|
|
|
125
128
|
class RolloutStep(BaseModel):
|
|
@@ -140,6 +143,7 @@ class RolloutTrajectory(BaseModel):
|
|
|
140
143
|
final: dict[str, Any] | None = None
|
|
141
144
|
length: int
|
|
142
145
|
decision_samples: list[dict[str, Any]] | None = None
|
|
146
|
+
inference_url: str | None = None
|
|
143
147
|
|
|
144
148
|
|
|
145
149
|
def _normalize_step_strategy(raw_strategy: Any) -> str:
|
|
@@ -452,11 +456,12 @@ class RolloutMetrics(BaseModel):
|
|
|
452
456
|
class RolloutResponse(BaseModel):
|
|
453
457
|
run_id: str
|
|
454
458
|
trajectories: list[RolloutTrajectory]
|
|
455
|
-
branches: dict[str, list[str]] =
|
|
459
|
+
branches: dict[str, list[str]] = Field(default_factory=dict)
|
|
456
460
|
metrics: RolloutMetrics
|
|
457
461
|
aborted: bool = False
|
|
458
462
|
ops_executed: int = 0
|
|
459
463
|
trace: dict[str, Any] | None = None
|
|
464
|
+
pipeline_metadata: dict[str, Any] = Field(default_factory=dict)
|
|
460
465
|
|
|
461
466
|
|
|
462
467
|
class RolloutTracingContext:
|
|
@@ -567,7 +572,7 @@ class RolloutTracingContext:
|
|
|
567
572
|
try:
|
|
568
573
|
await self.tracer.record_message(
|
|
569
574
|
content=self._prompt_payload(entry, role="system"),
|
|
570
|
-
message_type="
|
|
575
|
+
message_type="system", # Use standard message type
|
|
571
576
|
metadata=self._message_metadata(),
|
|
572
577
|
)
|
|
573
578
|
except Exception as exc:
|
|
@@ -576,11 +581,16 @@ class RolloutTracingContext:
|
|
|
576
581
|
try:
|
|
577
582
|
await self.tracer.record_message(
|
|
578
583
|
content=self._prompt_payload(entry, role="user"),
|
|
579
|
-
message_type="
|
|
584
|
+
message_type="user", # Use standard message type
|
|
580
585
|
metadata=self._message_metadata(),
|
|
581
586
|
)
|
|
582
587
|
except Exception as exc:
|
|
583
588
|
logger.debug("TRACING_USER_MSG_FAIL: %s", exc)
|
|
589
|
+
|
|
590
|
+
# Debug: Check message count
|
|
591
|
+
if self.tracer and self.tracer._current_trace:
|
|
592
|
+
msg_count = len(self.tracer._current_trace.markov_blanket_message_history)
|
|
593
|
+
logger.info(f"[TRACE_DEBUG] After record_policy_prompts: {msg_count} messages in trace")
|
|
584
594
|
|
|
585
595
|
def _content_to_text(self, content: Any) -> str:
|
|
586
596
|
if isinstance(content, str):
|
|
@@ -656,8 +666,8 @@ class RolloutTracingContext:
|
|
|
656
666
|
try:
|
|
657
667
|
await self.tracer.record_message(
|
|
658
668
|
content=self._safe_json(tool_calls),
|
|
659
|
-
message_type="
|
|
660
|
-
metadata=self._message_metadata(),
|
|
669
|
+
message_type="assistant", # Map to standard assistant message type
|
|
670
|
+
metadata={**self._message_metadata(), "is_tool_call": True},
|
|
661
671
|
)
|
|
662
672
|
except Exception as exc:
|
|
663
673
|
logger.debug("TRACING_TOOL_MSG_FAIL: %s", exc)
|
|
@@ -928,11 +938,22 @@ class RolloutTracingContext:
|
|
|
928
938
|
except Exception as exc:
|
|
929
939
|
logger.debug("TRACING_OUTCOME_FAIL: %s", exc)
|
|
930
940
|
try:
|
|
941
|
+
# Debug: Check message count before end_session
|
|
942
|
+
if self.tracer._current_trace:
|
|
943
|
+
msg_count = len(self.tracer._current_trace.markov_blanket_message_history)
|
|
944
|
+
logger.info(f"[TRACE_DEBUG] Before end_session: {msg_count} messages in trace")
|
|
945
|
+
|
|
931
946
|
self.session_trace = await self.tracer.end_session()
|
|
932
|
-
|
|
947
|
+
|
|
948
|
+
# Debug: Check if session was saved
|
|
949
|
+
if self.session_trace:
|
|
950
|
+
logger.info(f"[TRACE_DEBUG] Session ended successfully, session_id={self.session_trace.session_id}")
|
|
933
951
|
self.session_trace.metadata.update(self.metadata_updates)
|
|
952
|
+
logger.info(f"[TRACE_DEBUG] session_trace.metadata keys: {list(self.session_trace.metadata.keys())}")
|
|
953
|
+
else:
|
|
954
|
+
logger.warning("[TRACE_DEBUG] end_session returned None!")
|
|
934
955
|
except Exception as exc:
|
|
935
|
-
logger.
|
|
956
|
+
logger.warning(f"TRACING_END_SESSION_FAIL: {exc}", exc_info=True)
|
|
936
957
|
self.session_trace = None
|
|
937
958
|
with contextlib.suppress(Exception):
|
|
938
959
|
await self.tracer.close()
|
|
@@ -1056,12 +1077,14 @@ async def execute_rollout(
|
|
|
1056
1077
|
req: Request,
|
|
1057
1078
|
) -> RolloutResponse:
|
|
1058
1079
|
"""Execute a rollout with coordinated environment and policy steps."""
|
|
1080
|
+
logger.info("ROLLOUT: mode = %s", request.mode)
|
|
1081
|
+
|
|
1059
1082
|
# Emit rollout identifier early for correlation
|
|
1060
1083
|
with contextlib.suppress(Exception):
|
|
1061
1084
|
_rid = getattr(request, "run_id", None)
|
|
1062
1085
|
_pol = getattr(request.policy, "policy_name", None) or getattr(request.policy, "policy_id", None)
|
|
1063
1086
|
_env = getattr(request.env, "env_name", None) or getattr(request.env, "env_id", None)
|
|
1064
|
-
logger.info("ROLLOUT_BEGIN: run_id=%s policy=%s env=%s", _rid, _pol, _env)
|
|
1087
|
+
logger.info("ROLLOUT_BEGIN: run_id=%s policy=%s env=%s mode=%s", _rid, _pol, _env, request.mode)
|
|
1065
1088
|
print(f"[rollout] begin run_id={_rid} policy={_pol} env={_env}", flush=True)
|
|
1066
1089
|
# Enforce per-episode step cap via env-specific parameters; default to 20 if omitted
|
|
1067
1090
|
try:
|
|
@@ -1271,6 +1294,7 @@ async def execute_rollout(
|
|
|
1271
1294
|
config=_policy_config,
|
|
1272
1295
|
rl_run_id=request.run_id,
|
|
1273
1296
|
bound_env_id=env_id,
|
|
1297
|
+
mode=request.mode, # Pass through mode for URL transformation control
|
|
1274
1298
|
),
|
|
1275
1299
|
req,
|
|
1276
1300
|
)
|
|
@@ -1843,12 +1867,81 @@ async def execute_rollout(
|
|
|
1843
1867
|
timing_final.setdefault("overhead_ms", 0.0)
|
|
1844
1868
|
|
|
1845
1869
|
# Build trajectory
|
|
1870
|
+
# Extract inference_url from policy config (REQUIRED for trace correlation)
|
|
1871
|
+
# The trainer sets this in policy config with ?cid=... parameter
|
|
1872
|
+
inference_url = None
|
|
1873
|
+
|
|
1874
|
+
# Try policy config from request first (most reliable source)
|
|
1875
|
+
try:
|
|
1876
|
+
policy_config_snapshot = (
|
|
1877
|
+
request.policy.config if isinstance(request.policy.config, dict) else {}
|
|
1878
|
+
)
|
|
1879
|
+
inference_url = policy_config_snapshot.get("inference_url")
|
|
1880
|
+
if inference_url:
|
|
1881
|
+
logger.info(
|
|
1882
|
+
"ROLLOUT_TRAJECTORY: extracted inference_url from request.policy.config run_id=%s url=%s",
|
|
1883
|
+
request.run_id,
|
|
1884
|
+
inference_url,
|
|
1885
|
+
)
|
|
1886
|
+
except Exception as exc:
|
|
1887
|
+
logger.warning(
|
|
1888
|
+
"ROLLOUT_TRAJECTORY: failed to get inference_url from request.policy.config run_id=%s: %s",
|
|
1889
|
+
request.run_id,
|
|
1890
|
+
exc,
|
|
1891
|
+
)
|
|
1892
|
+
|
|
1893
|
+
# Fallback: Try policy handle snapshot (if request.policy.config failed)
|
|
1894
|
+
if not inference_url and policy_handle is not None:
|
|
1895
|
+
try:
|
|
1896
|
+
policy_snapshot = policy_handle.snapshot()
|
|
1897
|
+
inference_url = policy_snapshot.get("config", {}).get("inference_url")
|
|
1898
|
+
if inference_url:
|
|
1899
|
+
logger.info(
|
|
1900
|
+
"ROLLOUT_TRAJECTORY: extracted inference_url from policy_handle.snapshot run_id=%s url=%s",
|
|
1901
|
+
request.run_id,
|
|
1902
|
+
inference_url,
|
|
1903
|
+
)
|
|
1904
|
+
except Exception as exc:
|
|
1905
|
+
logger.warning(
|
|
1906
|
+
"ROLLOUT_TRAJECTORY: failed to snapshot policy for run_id=%s policy_id=%s: %s",
|
|
1907
|
+
request.run_id,
|
|
1908
|
+
policy_id,
|
|
1909
|
+
exc,
|
|
1910
|
+
)
|
|
1911
|
+
|
|
1912
|
+
# ASSERTION: inference_url MUST be present (required by RolloutTrajectory schema)
|
|
1913
|
+
if not inference_url:
|
|
1914
|
+
raise ValueError(
|
|
1915
|
+
f"FATAL: inference_url is required but not found!\n"
|
|
1916
|
+
f"\n"
|
|
1917
|
+
f"run_id: {request.run_id}\n"
|
|
1918
|
+
f"policy_id: {policy_id}\n"
|
|
1919
|
+
f"policy_config_keys: {list(policy_config_snapshot.keys()) if 'policy_config_snapshot' in locals() else 'N/A'}\n"
|
|
1920
|
+
f"\n"
|
|
1921
|
+
f"The trainer MUST set inference_url in policy config with ?cid=... parameter.\n"
|
|
1922
|
+
f"This is required for trace correlation and hydration.\n"
|
|
1923
|
+
)
|
|
1924
|
+
|
|
1925
|
+
# policy_config_snapshot already set above in try block (line 1876-1878)
|
|
1926
|
+
# Ensure it exists for logging below
|
|
1927
|
+
if 'policy_config_snapshot' not in locals():
|
|
1928
|
+
policy_config_snapshot = {}
|
|
1929
|
+
|
|
1930
|
+
logger.info(
|
|
1931
|
+
"ROLLOUT_TRAJECTORY: run_id=%s policy_id=%s inference_url=%s trace_id=%s",
|
|
1932
|
+
request.run_id,
|
|
1933
|
+
policy_id,
|
|
1934
|
+
inference_url,
|
|
1935
|
+
policy_config_snapshot.get("trace_correlation_id"),
|
|
1936
|
+
)
|
|
1937
|
+
|
|
1846
1938
|
trajectory = RolloutTrajectory(
|
|
1847
1939
|
env_id=env_id,
|
|
1848
1940
|
policy_id=policy_id,
|
|
1849
1941
|
steps=trajectory_steps,
|
|
1850
1942
|
final={"observation": _summarize_observation_for_storage(env_handle, current_obs)},
|
|
1851
1943
|
length=len(trajectory_steps),
|
|
1944
|
+
inference_url=inference_url, # NEW: Required for trace correlation
|
|
1852
1945
|
decision_samples=decision_samples if step_rewards_active else None,
|
|
1853
1946
|
)
|
|
1854
1947
|
|
|
@@ -1938,12 +2031,17 @@ async def execute_rollout(
|
|
|
1938
2031
|
)
|
|
1939
2032
|
finalized = True
|
|
1940
2033
|
trace_payload = tracing_context.build_trace_payload(session_trace)
|
|
2034
|
+
|
|
2035
|
+
# Debug: Check trace payload
|
|
2036
|
+
logger.info(f"[TRACE_DEBUG] trace_payload is None: {trace_payload is None}, return_trace={tracing_context.return_trace}")
|
|
2037
|
+
if trace_payload:
|
|
2038
|
+
logger.info(f"[TRACE_DEBUG] trace_payload keys: {list(trace_payload.keys())}")
|
|
1941
2039
|
|
|
1942
2040
|
# Hard-fail if no steps executed (avg_turns == 0 scenario)
|
|
1943
2041
|
if metrics.num_steps <= 0:
|
|
1944
2042
|
raise HTTPException(status_code=500, detail="no_steps_executed: avg_turns == 0")
|
|
1945
2043
|
|
|
1946
|
-
|
|
2044
|
+
response = RolloutResponse(
|
|
1947
2045
|
run_id=request.run_id,
|
|
1948
2046
|
trajectories=[trajectory],
|
|
1949
2047
|
branches={},
|
|
@@ -1952,6 +2050,16 @@ async def execute_rollout(
|
|
|
1952
2050
|
ops_executed=ops_executed,
|
|
1953
2051
|
trace=trace_payload,
|
|
1954
2052
|
)
|
|
2053
|
+
logger.info(
|
|
2054
|
+
"ROLLOUT_RESPONSE: run_id=%s aborted=%s ops_executed=%s metrics_steps=%s trace_present=%s pipeline_metadata=%s",
|
|
2055
|
+
request.run_id,
|
|
2056
|
+
aborted,
|
|
2057
|
+
ops_executed,
|
|
2058
|
+
metrics.num_steps,
|
|
2059
|
+
bool(trace_payload),
|
|
2060
|
+
response.pipeline_metadata,
|
|
2061
|
+
)
|
|
2062
|
+
return response
|
|
1955
2063
|
|
|
1956
2064
|
except Exception as e:
|
|
1957
2065
|
logger.error(f"Rollout failed for run {request.run_id}: {e}")
|
|
@@ -1,15 +1,14 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
|
-
"""
|
|
3
|
-
Simple test script for the GRPO Synth Envs Hosted Service.
|
|
4
|
-
|
|
5
|
-
Run this after starting the service with:
|
|
6
|
-
python main.py
|
|
7
|
-
"""
|
|
2
|
+
"""Manual smoke script for the GRPO Synth Envs Hosted Service."""
|
|
8
3
|
|
|
9
4
|
import asyncio
|
|
10
5
|
import json
|
|
11
6
|
|
|
12
7
|
import httpx
|
|
8
|
+
import pytest
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
pytestmark = pytest.mark.skip(reason="Requires running hosted service on localhost:8000")
|
|
13
12
|
|
|
14
13
|
|
|
15
14
|
async def test_service():
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
"""Utility functions for the task service."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any
|
|
5
|
+
from urllib.parse import parse_qs, urlparse, urlunparse
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
_CHAT_COMPLETIONS_SUFFIX = "/v1/chat/completions"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def ensure_chat_completions_url(raw_url: Any, mode: str | None = None) -> Any:
|
|
15
|
+
"""
|
|
16
|
+
Ensure inference URLs point at the chat completions endpoint.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
raw_url: The inference URL to process
|
|
20
|
+
mode: "rl" applies URL transformations, "eval" uses URLs as-is (deprecated - use RolloutMode enum)
|
|
21
|
+
|
|
22
|
+
Returns:
|
|
23
|
+
Processed URL (transformed in RL mode, unchanged in EVAL mode)
|
|
24
|
+
"""
|
|
25
|
+
# In EVAL mode, use URLs exactly as provided - no transformations
|
|
26
|
+
# Accept both string "eval" (legacy) and RolloutMode.EVAL
|
|
27
|
+
from synth_ai.task.contracts import RolloutMode
|
|
28
|
+
is_eval_mode = (mode == "eval" or mode == RolloutMode.EVAL or
|
|
29
|
+
(hasattr(mode, 'value') and mode.value == "eval"))
|
|
30
|
+
|
|
31
|
+
if is_eval_mode:
|
|
32
|
+
logger.info("ensure_chat_completions_url: EVAL mode - using URL as-is: %s", raw_url)
|
|
33
|
+
return raw_url
|
|
34
|
+
|
|
35
|
+
# RL mode: apply transformations for compatibility
|
|
36
|
+
if not isinstance(raw_url, str):
|
|
37
|
+
logger.debug("ensure_chat_completions_url: non-string input %r (type=%s)", raw_url, type(raw_url))
|
|
38
|
+
return raw_url
|
|
39
|
+
url = raw_url.strip()
|
|
40
|
+
if not url:
|
|
41
|
+
logger.debug("ensure_chat_completions_url: blank/whitespace URL input")
|
|
42
|
+
return raw_url
|
|
43
|
+
|
|
44
|
+
parsed = urlparse(url)
|
|
45
|
+
path = (parsed.path or "").rstrip("/")
|
|
46
|
+
if path.endswith("/v1/chat/completions"):
|
|
47
|
+
logger.debug("ensure_chat_completions_url: URL already normalized %s", url)
|
|
48
|
+
# Already targeting the desired endpoint; keep original to preserve trailing slash.
|
|
49
|
+
return url
|
|
50
|
+
|
|
51
|
+
if not path:
|
|
52
|
+
new_path = _CHAT_COMPLETIONS_SUFFIX
|
|
53
|
+
else:
|
|
54
|
+
new_path = f"{path}{_CHAT_COMPLETIONS_SUFFIX}"
|
|
55
|
+
|
|
56
|
+
rebuilt = parsed._replace(path=new_path)
|
|
57
|
+
normalized = urlunparse(rebuilt)
|
|
58
|
+
logger.info(
|
|
59
|
+
"ensure_chat_completions_url: RL mode - normalized inference URL from %s to %s",
|
|
60
|
+
url,
|
|
61
|
+
normalized,
|
|
62
|
+
)
|
|
63
|
+
return normalized
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def inference_url_to_trace_correlation_id(raw_url: Any, *, required: bool = False, mode: Any = None) -> str | None:
|
|
67
|
+
"""
|
|
68
|
+
Extract trace_correlation_id from inference URL query params.
|
|
69
|
+
|
|
70
|
+
The inference URL should contain ?cid=trace_xxxxx parameter.
|
|
71
|
+
This is THE canonical source for trace_correlation_id - it's what the
|
|
72
|
+
inference server uses to tag traces, so we extract it here.
|
|
73
|
+
|
|
74
|
+
Args:
|
|
75
|
+
raw_url: Inference URL (should contain ?cid=... query param)
|
|
76
|
+
required: If True, raises AssertionError if trace_correlation_id not found
|
|
77
|
+
mode: RolloutMode or string ("rl" or "eval"). Controls warning behavior -
|
|
78
|
+
warnings only logged for RL mode, not EVAL mode.
|
|
79
|
+
|
|
80
|
+
Returns:
|
|
81
|
+
trace_correlation_id if found in URL, None otherwise
|
|
82
|
+
|
|
83
|
+
Raises:
|
|
84
|
+
AssertionError: If required=True and trace_correlation_id not found
|
|
85
|
+
"""
|
|
86
|
+
if not isinstance(raw_url, str):
|
|
87
|
+
logger.debug(
|
|
88
|
+
"inference_url_to_trace_correlation_id: non-string input %r (type=%s)",
|
|
89
|
+
raw_url,
|
|
90
|
+
type(raw_url)
|
|
91
|
+
)
|
|
92
|
+
if required:
|
|
93
|
+
raise AssertionError(
|
|
94
|
+
f"FATAL: inference_url_to_trace_correlation_id requires string URL, got {type(raw_url)}: {raw_url!r}"
|
|
95
|
+
)
|
|
96
|
+
return None
|
|
97
|
+
|
|
98
|
+
parsed = urlparse(raw_url)
|
|
99
|
+
query_params = parse_qs(parsed.query or "")
|
|
100
|
+
|
|
101
|
+
# Check all possible parameter names (cid is primary)
|
|
102
|
+
candidates = (
|
|
103
|
+
query_params.get("cid") or
|
|
104
|
+
query_params.get("trace") or
|
|
105
|
+
query_params.get("trace_correlation_id") or
|
|
106
|
+
[]
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
for value in candidates:
|
|
110
|
+
if isinstance(value, str) and value.strip():
|
|
111
|
+
correlation_id = value.strip()
|
|
112
|
+
logger.info(
|
|
113
|
+
"inference_url_to_trace_correlation_id: ✅ extracted id=%s from url=%s",
|
|
114
|
+
correlation_id,
|
|
115
|
+
raw_url,
|
|
116
|
+
)
|
|
117
|
+
# ASSERTION: Correlation ID should look like trace_xxxxx
|
|
118
|
+
assert correlation_id.startswith("trace_"), (
|
|
119
|
+
f"FATAL: trace_correlation_id has unexpected format: {correlation_id!r}. "
|
|
120
|
+
f"Expected to start with 'trace_'"
|
|
121
|
+
)
|
|
122
|
+
return correlation_id
|
|
123
|
+
|
|
124
|
+
# Not found - check if we're in EVAL mode (trace_correlation_id not required for eval)
|
|
125
|
+
from synth_ai.task.contracts import RolloutMode
|
|
126
|
+
is_eval_mode = (mode == "eval" or mode == RolloutMode.EVAL or
|
|
127
|
+
(hasattr(mode, 'value') and mode.value == "eval"))
|
|
128
|
+
|
|
129
|
+
if is_eval_mode:
|
|
130
|
+
# For EVAL mode, missing trace_correlation_id is expected - log as debug, not warning
|
|
131
|
+
logger.debug(
|
|
132
|
+
"inference_url_to_trace_correlation_id: No trace_correlation_id in EVAL mode (expected) url=%s query_params=%s",
|
|
133
|
+
raw_url,
|
|
134
|
+
list(query_params.keys())
|
|
135
|
+
)
|
|
136
|
+
else:
|
|
137
|
+
# For RL mode, missing trace_correlation_id is concerning
|
|
138
|
+
logger.warning(
|
|
139
|
+
"inference_url_to_trace_correlation_id: ❌ NO trace_correlation_id found in url=%s query_params=%s",
|
|
140
|
+
raw_url,
|
|
141
|
+
list(query_params.keys())
|
|
142
|
+
)
|
|
143
|
+
|
|
144
|
+
if required:
|
|
145
|
+
raise AssertionError(
|
|
146
|
+
f"FATAL: trace_correlation_id REQUIRED but not found in inference_url!\n"
|
|
147
|
+
f"\n"
|
|
148
|
+
f"URL: {raw_url}\n"
|
|
149
|
+
f"Query params found: {list(query_params.keys())}\n"
|
|
150
|
+
f"\n"
|
|
151
|
+
f"The inference_url MUST contain ?cid=trace_xxxxx parameter.\n"
|
|
152
|
+
f"This is set by the trainer when generating rollout requests.\n"
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
return None
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
# Legacy alias for backward compatibility
|
|
159
|
+
def extract_trace_correlation_id(raw_url: Any, mode: Any = None) -> str | None:
    """DEPRECATED: Use inference_url_to_trace_correlation_id instead.

    Thin wrapper that forwards ``raw_url`` and ``mode`` to
    ``inference_url_to_trace_correlation_id`` with ``required=False`` and
    returns its result unchanged.
    """
    correlation_id = inference_url_to_trace_correlation_id(
        raw_url,
        mode=mode,
        required=False,
    )
    return correlation_id
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def convert_numpy_to_python(obj: Any) -> Any:
|
|
165
|
+
"""
|
|
166
|
+
Recursively convert numpy types to Python native types for JSON serialization.
|
|
167
|
+
|
|
168
|
+
Args:
|
|
169
|
+
obj: Object that may contain numpy types
|
|
170
|
+
|
|
171
|
+
Returns:
|
|
172
|
+
Object with numpy types converted to Python native types
|
|
173
|
+
"""
|
|
174
|
+
if isinstance(obj, np.integer):
|
|
175
|
+
return int(obj)
|
|
176
|
+
elif isinstance(obj, np.floating):
|
|
177
|
+
return float(obj)
|
|
178
|
+
elif isinstance(obj, np.ndarray):
|
|
179
|
+
return obj.tolist()
|
|
180
|
+
elif isinstance(obj, dict):
|
|
181
|
+
return {key: convert_numpy_to_python(value) for key, value in obj.items()}
|
|
182
|
+
elif isinstance(obj, list | tuple):
|
|
183
|
+
return [convert_numpy_to_python(item) for item in obj]
|
|
184
|
+
else:
|
|
185
|
+
return obj
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def sanitize_observation(observation: dict[str, Any]) -> dict[str, Any]:
    """
    Sanitize observation data for JSON serialization.

    Drops the map/image entries and converts numpy types in everything else.

    Args:
        observation: Raw observation from environment

    Returns:
        A new dict without the "semantic_map", "world_material_map" and
        "observation_image" keys, with every remaining value passed through
        ``convert_numpy_to_python``. A non-dict input is returned unchanged.
    """
    if not isinstance(observation, dict):
        return observation

    dropped_keys = ("semantic_map", "world_material_map", "observation_image")

    cleaned = {}
    for name, raw in observation.items():
        if name in dropped_keys:
            # Arrays under these keys are skipped on purpose (large, likely
            # debug info). NOTE(review): non-array values under these keys are
            # dropped as well, because no branch ever stores them - confirm
            # that this is intended.
            continue
        if name == "player_position" and isinstance(raw, tuple):
            # Convert tuple with potential numpy types, element by element.
            cleaned[name] = [convert_numpy_to_python(part) for part in raw]
            continue
        cleaned[name] = convert_numpy_to_python(raw)

    return cleaned
|