synth-ai 0.2.13.dev1__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai may be problematic; see the package's registry page for more details.
- examples/multi_step/configs/README_verilog_rl.md +77 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +90 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +183 -0
- examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
- examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +17 -5
- examples/multi_step/configs/crafter_synth_backend.md +40 -0
- examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
- examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
- examples/multi_step/configs/verilog_rl_lora.toml +190 -0
- examples/multi_step/judges/crafter_backend_judge.py +220 -0
- examples/multi_step/judges/verilog_backend_judge.py +234 -0
- examples/multi_step/readme.md +48 -0
- examples/multi_step/verilog_rl_lora.md +218 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -1
- examples/sft/evaluate.py +2 -0
- examples/sft/generate_traces.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +56 -26
- examples/swe/task_app/hosted/rollout.py +42 -0
- examples/swe/task_app/hosted/test_service.py +5 -6
- examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
- examples/task_apps/TESTING.md +275 -0
- examples/task_apps/__init__.py +0 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +273 -0
- examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +174 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +268 -0
- examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
- examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
- examples/task_apps/crafter/__init__.py +0 -0
- examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
- examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
- examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
- examples/task_apps/crafter/task_app/__init__.py +5 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +324 -21
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +10 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +76 -7
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +17 -2
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +25 -3
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +77 -4
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +117 -9
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +218 -0
- examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
- examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
- examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
- examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
- examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
- examples/task_apps/enron/filter_sft.toml +5 -0
- examples/task_apps/enron/task_app/README.md +14 -0
- examples/task_apps/enron/task_app/__init__.py +1 -0
- examples/task_apps/enron/task_app/grpo_enron.py +906 -0
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
- examples/task_apps/enron/tests/__init__.py +4 -0
- examples/task_apps/enron/tests/conftest.py +115 -0
- examples/task_apps/enron/tests/integration/__init__.py +4 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +179 -0
- examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
- examples/task_apps/enron/tests/unit/__init__.py +4 -0
- examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
- examples/task_apps/math/__init__.py +0 -0
- examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
- examples/task_apps/pokemon_battle/__init__.py +2 -0
- examples/task_apps/pokemon_battle/modal_app.py +104 -0
- examples/task_apps/pokemon_battle/task_app/README.md +68 -0
- examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
- examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
- examples/task_apps/pokemon_red/README.md +357 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +415 -0
- examples/task_apps/pokemon_red/__init__.py +3 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +29 -0
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +75 -0
- examples/task_apps/pokemon_red/task_app.py +799 -0
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +193 -0
- examples/task_apps/sokoban/README.md +307 -0
- examples/task_apps/sokoban/__init__.py +3 -0
- examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
- examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
- examples/task_apps/sokoban/filter_sft.toml +5 -0
- examples/task_apps/sokoban/task_app.py +1058 -0
- examples/task_apps/sokoban/tests/__init__.py +4 -0
- examples/task_apps/sokoban/tests/conftest.py +113 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +4 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +4 -0
- examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
- examples/task_apps/verilog/__init__.py +1 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +24 -0
- examples/task_apps/verilog/filter_sft.toml +5 -0
- examples/task_apps/verilog/task_app/README.md +12 -0
- examples/task_apps/verilog/task_app/__init__.py +1 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +1166 -0
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
- examples/task_apps/verilog/tests/__init__.py +4 -0
- examples/task_apps/verilog/tests/conftest.py +115 -0
- examples/task_apps/verilog/tests/integration/__init__.py +4 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +181 -0
- examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
- examples/task_apps/verilog/tests/unit/__init__.py +4 -0
- examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
- examples/vlm/crafter_openai_vlm_agent.py +4 -4
- examples/vlm/run_crafter_vlm_benchmark.py +4 -4
- examples/warming_up_to_rl/groq_test.py +2 -0
- examples/warming_up_to_rl/run_local_rollout.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +2 -0
- examples/warming_up_to_rl/run_rollout_remote.py +2 -0
- examples/workflows/__init__.py +0 -0
- examples/workflows/math_rl/__init__.py +0 -0
- examples/workflows/math_rl/download_dataset.py +80 -0
- synth_ai/__init__.py +2 -2
- synth_ai/api/models/supported.py +1 -0
- synth_ai/api/train/builders.py +25 -11
- synth_ai/api/train/cli.py +12 -6
- synth_ai/api/train/configs/__init__.py +10 -10
- synth_ai/api/train/configs/rl.py +5 -4
- synth_ai/api/train/configs/sft.py +4 -3
- synth_ai/api/train/env_resolver.py +5 -2
- synth_ai/api/train/supported_algos.py +10 -5
- synth_ai/api/train/utils.py +7 -4
- synth_ai/cli/__init__.py +48 -59
- synth_ai/cli/_modal_wrapper.py +3 -2
- synth_ai/cli/_storage.py +4 -3
- synth_ai/cli/_validate_task_app.py +11 -0
- synth_ai/cli/balance.py +4 -3
- synth_ai/cli/calc.py +2 -2
- synth_ai/cli/demo.py +14 -7
- synth_ai/cli/legacy_root_backup.py +1 -1
- synth_ai/cli/recent.py +1 -1
- synth_ai/cli/rl_demo.py +8 -7
- synth_ai/cli/root.py +0 -97
- synth_ai/cli/status.py +1 -1
- synth_ai/cli/task_apps.py +1922 -190
- synth_ai/cli/traces.py +1 -1
- synth_ai/cli/tui.py +57 -0
- synth_ai/cli/turso.py +1 -1
- synth_ai/cli/watch.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +29 -17
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/enron/engine.py +7 -2
- synth_ai/environments/examples/enron/environment.py +68 -0
- synth_ai/environments/examples/red/engine.py +27 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
- synth_ai/environments/examples/red/environment.py +60 -0
- synth_ai/environments/examples/sokoban/taskset.py +116 -0
- synth_ai/environments/examples/verilog/engine.py +104 -12
- synth_ai/evals/client.py +58 -61
- synth_ai/jobs/client.py +16 -4
- synth_ai/judge_schemas.py +9 -9
- synth_ai/py.typed +0 -0
- synth_ai/task/__init__.py +24 -5
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +257 -0
- synth_ai/task/contracts.py +138 -39
- synth_ai/task/proxy.py +48 -56
- synth_ai/task/rubrics/__init__.py +56 -0
- synth_ai/task/rubrics/loaders.py +152 -0
- synth_ai/task/rubrics/models.py +57 -0
- synth_ai/task/rubrics/scoring.py +116 -0
- synth_ai/{rubrics/validators.py → task/rubrics/strict.py} +53 -30
- synth_ai/task/server.py +8 -7
- synth_ai/task/trace_correlation_helpers.py +315 -0
- synth_ai/task/validators.py +413 -6
- synth_ai/tracing_v3/abstractions.py +3 -3
- synth_ai/tracing_v3/decorators.py +7 -3
- synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
- synth_ai/tracing_v3/replica_sync.py +4 -4
- synth_ai/tracing_v3/serialization.py +5 -5
- synth_ai/tracing_v3/session_tracer.py +16 -6
- synth_ai/tracing_v3/storage/base.py +29 -29
- synth_ai/tracing_v3/storage/config.py +3 -3
- synth_ai/tracing_v3/trace_utils.py +317 -0
- synth_ai/tracing_v3/turso/daemon.py +8 -7
- synth_ai/tracing_v3/turso/native_manager.py +66 -43
- synth_ai/tracing_v3/utils.py +3 -3
- synth_ai/tui/__init__.py +5 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/cli/__init__.py +1 -0
- synth_ai/tui/cli/query_experiments.py +164 -0
- synth_ai/tui/cli/query_experiments_v3.py +164 -0
- synth_ai/tui/dashboard.py +906 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/METADATA +4 -1
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/RECORD +278 -126
- examples/agora_ex/README_MoE.md +0 -224
- examples/agora_ex/__init__.py +0 -7
- examples/agora_ex/agora_ex.py +0 -65
- examples/agora_ex/agora_ex_task_app.py +0 -590
- examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +0 -121
- examples/agora_ex/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/system_prompt_CURRENT.md +0 -63
- examples/agora_ex/task_app/agora_ex_task_app.py +0 -590
- examples/agora_ex/task_app/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/task_app/system_prompt_CURRENT.md +0 -63
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +0 -62
- synth_ai/rubrics/__init__.py +0 -22
- synth_ai/task/rubrics.py +0 -219
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
- /examples/{rl/task_app → task_apps/math}/README.md +0 -0
- /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
- /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1166 @@
|
|
|
1
|
+
"""Task App configuration for the GRPO Verilog spec-to-RTL example."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import contextlib
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
import re
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any, Iterable, Optional, Sequence
|
|
14
|
+
|
|
15
|
+
import httpx
|
|
16
|
+
|
|
17
|
+
from synth_ai.environments.environment.tools import EnvToolCall
|
|
18
|
+
from synth_ai.environments.examples.verilog.environment import VerilogEnvironment
|
|
19
|
+
from synth_ai.environments.examples.verilog.taskset import (
|
|
20
|
+
VerilogTaskInstance,
|
|
21
|
+
VerilogTaskInstanceMetadata,
|
|
22
|
+
create_verilog_taskset,
|
|
23
|
+
)
|
|
24
|
+
from synth_ai.environments.tasks.core import TaskInstanceSet
|
|
25
|
+
from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
26
|
+
from synth_ai.task.contracts import (
|
|
27
|
+
RolloutMetrics,
|
|
28
|
+
RolloutRequest,
|
|
29
|
+
RolloutResponse,
|
|
30
|
+
RolloutTrajectory,
|
|
31
|
+
RolloutStep,
|
|
32
|
+
TaskInfo,
|
|
33
|
+
)
|
|
34
|
+
from synth_ai.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
|
|
35
|
+
from synth_ai.task.rubrics import load_rubric
|
|
36
|
+
from synth_ai.task.server import ProxyConfig, RubricBundle, TaskAppConfig
|
|
37
|
+
from synth_ai.task.validators import normalize_inference_url
|
|
38
|
+
from synth_ai.task.tracing_utils import (
|
|
39
|
+
build_tracer_factory,
|
|
40
|
+
resolve_sft_output_dir,
|
|
41
|
+
resolve_tracing_db_url,
|
|
42
|
+
tracing_env_enabled,
|
|
43
|
+
)
|
|
44
|
+
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
45
|
+
|
|
46
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Absolute, symlink-resolved path of this file; used as the starting point when
# searching upward for the repository root.
_HERE = Path(__file__).resolve()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _resolve_repo_root() -> Path:
    """Locate the synth-ai repository root.

    Resolution order:

    1. ``SYNTH_AI_REPO_ROOT`` (with ``~`` expanded), if it names a directory that
       contains ``synth_ai``. An explicit override always wins over auto-detection.
    2. Walking up to six parents of this file, the first directory that contains
       both ``synth_ai/`` and ``examples/``.
    3. The Modal mount point ``/opt/synth_ai_repo``, then each traversed parent in
       order, if it is a directory containing ``synth_ai``.
    4. Fallback: the fourth parent of this file (may be wrong inside Modal).

    Returns:
        The chosen root path. In the fallback case it is not validated.
    """

    def _has_package(path: Path) -> bool:
        # Same acceptance test the candidate scan has always used.
        return path.is_dir() and (path / "synth_ai").exists()

    env_root = os.getenv("SYNTH_AI_REPO_ROOT")
    if env_root:
        env_path = Path(env_root).expanduser()
        # Check the override before the parent walk. Previously the walk below
        # returned first, so a valid override was ignored inside any checkout.
        if _has_package(env_path):
            return env_path

    # Try Modal mount point
    candidates: list[Path] = [Path("/opt/synth_ai_repo")]

    # Traverse up from current file; a dir with both synth_ai/ and examples/ is
    # accepted immediately.
    current = _HERE
    for _ in range(6):
        current = current.parent
        candidates.append(current)
        if (current / "synth_ai").is_dir() and (current / "examples").is_dir():
            return current

    # Return first existing candidate
    for candidate in candidates:
        if _has_package(candidate):
            return candidate

    # Fallback to current parent structure (may not work in Modal)
    return _HERE.parent.parent.parent.parent
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Resolved once, at import time.
REPO_ROOT = _resolve_repo_root()

# Metadata for the VerilogEval v2 dataset advertised by this task app.
DATASET_SPEC = TaskDatasetSpec(
    id="verilog_eval_v2",
    name="VerilogEval Spec-to-RTL",
    version="1.0.0",
    splits=["train", "val", "test"],
    default_split="val",
    description="Spec-to-RTL problems sourced from the VerilogEval v2 benchmark.",
)

# The tunables below are read from the environment once, at import time.
# A value that int()/float() cannot parse raises ValueError during import.

# Upper bound on task instances loaded into the dataset (forwarded to the taskset loader).
MAX_INSTANCES = int(os.getenv("VERILOG_MAX_INSTANCES", "10"))
# Tool names advertised in the TaskInfo action space.
TOOLS = ["write_file", "compile", "simulate", "submit"]
# Default chat-completions endpoint and model (Groq's OpenAI-compatible API by default).
DEFAULT_INFERENCE_URL = os.getenv(
    "VERILOG_INFERENCE_URL", "https://api.groq.com/openai/v1/chat/completions"
)
DEFAULT_MODEL = os.getenv("VERILOG_DEFAULT_MODEL", "qwen/qwen3-32b")
# Sampling/episode defaults; NOTE(review): presumably consumed by the rollout code
# later in this file — confirm.
DEFAULT_TEMPERATURE = float(os.getenv("VERILOG_DEFAULT_TEMPERATURE", "0.2"))
DEFAULT_MAX_TOKENS = int(os.getenv("VERILOG_DEFAULT_MAX_TOKENS", "768"))
DEFAULT_MAX_STEPS = int(os.getenv("VERILOG_DEFAULT_MAX_STEPS", "10"))
# Per-file character limit for workspace previews embedded in the prompt.
FILE_PREVIEW_CHARS = int(os.getenv("VERILOG_FILE_PREVIEW_CHARS", "600"))
# Timeout in seconds; NOTE(review): presumably applied to inference HTTP calls
# (env var VERILOG_INFERENCE_TIMEOUT) — confirm.
HTTP_TIMEOUT_SECONDS = float(os.getenv("VERILOG_INFERENCE_TIMEOUT", "90"))
|
|
100
|
+
|
|
101
|
+
# System prompt for the policy model. It demands exactly one JSON tool call per
# reply ({"tool": ..., "args": ...}), optionally inside a ```json fence, and
# describes the intended write -> compile -> simulate -> submit workflow.
VERILOG_SYSTEM_PROMPT = (
    "You are an expert digital design engineer helping with Verilog spec-to-RTL tasks. "
    "Choose between these tools: write_file, compile, simulate, submit. "
    "Always respond with a JSON object describing exactly one tool call in the form "
    "{\"tool\": \"<tool_name>\", \"args\": { ... }}. "
    "You may wrap the JSON inside a ```json``` block but MUST NOT include any other prose outside it. "
    "When editing files, rewrite the entire file content. Compile after code changes, simulate to verify behavior, "
    "and submit only after the tests pass. If compilation reports errors (missing ports, mismatched interfaces, etc.), "
    "fix the design with write_file before compiling again—never repeat compile without modifying the source first."
)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _load_taskset_blocking(max_instances: int) -> TaskInstanceSet:
    """Synchronously build the Verilog taskset, even when called from async code.

    ``create_verilog_taskset`` is a coroutine function.

    - If no event loop is running in the current thread, it is driven with
      ``asyncio.run``.
    - If a loop *is* running (for example, when the dataset is built from inside
      an async server), neither ``asyncio.run`` nor ``run_until_complete`` may be
      used on this thread. The coroutine is instead run on a fresh loop in a
      one-off worker thread, and this call blocks the caller until it finishes.

    Args:
        max_instances: Forwarded to ``create_verilog_taskset``.

    Returns:
        The loaded ``TaskInstanceSet``.

    Raises:
        Whatever ``create_verilog_taskset`` raises, propagated once and unchanged.
        Previously a ``RuntimeError`` from the loader triggered a second attempt.
    """

    def _load() -> TaskInstanceSet:
        # Create the coroutine inside the thread that awaits it, so it is never
        # left un-awaited.
        return asyncio.run(create_verilog_taskset(max_instances=max_instances))

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: safe to run one directly.
        return _load()

    # A loop is already running on this thread. A second loop cannot run here,
    # so delegate to a worker thread with its own loop.
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(_load).result()
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@dataclass
class VerilogDataset:
    """Eagerly loaded collection of VerilogEval task instances.

    Attributes:
        spec: Dataset metadata merged into ``describe()`` output.
        max_instances: Upper bound handed to the taskset loader.
    """

    spec: TaskDatasetSpec
    max_instances: int

    def __post_init__(self) -> None:
        """Load the taskset and derive instance ids and the seed range from it."""
        taskset = _load_taskset_blocking(self.max_instances)
        loaded: list[VerilogTaskInstance] = list(taskset.instances)
        self._taskset = taskset
        self.instances = loaded
        self.instance_ids = [str(item.id) for item in loaded]
        self.default_seed = 0
        self.seed_min = 0
        # Seeds wrap modulo the instance count, so the largest distinct seed is
        # n - 1; an empty dataset clamps to 0.
        self.seed_max = len(loaded) - 1 if loaded else 0

    def describe(self) -> dict[str, Any]:
        """Return spec fields plus the instance count and the first 50 instance ids."""
        summary = dict(self.spec.model_dump())
        summary["instance_count"] = len(self.instances)
        summary["instance_ids"] = self.instance_ids[:50]
        return summary

    def instance_by_seed(self, seed: int | None) -> VerilogTaskInstance:
        """Map ``seed`` onto an instance.

        ``None`` selects the first instance. Any other seed is taken modulo the
        instance count, so negative seeds wrap from the end.

        Raises:
            ValueError: If no instances were loaded.
        """
        count = len(self.instances)
        if count == 0:
            raise ValueError("Verilog dataset is empty.")
        position = 0 if seed is None else int(seed) % count
        return self.instances[position]
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def build_dataset() -> tuple[TaskDatasetRegistry, VerilogDataset]:
    """Create the dataset registry and the single shared Verilog dataset.

    The dataset is loaded once. The factory registered for ``DATASET_SPEC``
    ignores its spec argument and always returns that same object; it is
    registered with ``cache=True``.

    Returns:
        ``(registry, dataset)``.
    """
    registry = TaskDatasetRegistry()
    shared_dataset = VerilogDataset(DATASET_SPEC, MAX_INSTANCES)

    def _dataset_factory(_spec: object) -> VerilogDataset:
        return shared_dataset

    registry.register(DATASET_SPEC, _dataset_factory, cache=True)
    return registry, shared_dataset
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _base_task_info(dataset: VerilogDataset) -> TaskInfo:
    """Assemble the static ``TaskInfo`` advertised by the Verilog task app.

    Args:
        dataset: Loaded dataset. Its ``describe()`` output, plus ``default_seed``,
            forms the ``dataset`` section.

    Returns:
        A ``TaskInfo`` covering task identity, the tool-call action space, the
        observation format, rubric metadata, inference proxy endpoints and limits.
    """
    task_identity = {"id": "verilog_eval_v2", "name": "VerilogEval Spec-to-RTL", "version": "1.0.0"}
    action_space = {
        "type": "tool_calls",
        "tools": TOOLS,
        "description": "Filesystem editing, compilation, simulation, and submission tools.",
    }
    observation_spec = {
        "summary": "Dictionary observations describing files, compilation status, simulation results, and rewards.",
        "format": "dict",
        "keys": ["files", "compile_status", "simulate_status", "reward_last"],
    }
    dataset_section = dict(dataset.describe(), default_seed=dataset.default_seed)
    rubric_meta = {
        "version": "1",
        "criteria_count": 1,
        "source": "inline",
        "aggregation": "weighted_sum",
    }
    proxy_endpoints = {
        "openai": "/proxy/v1/chat/completions",
        "groq": "/proxy/groq/v1/chat/completions",
    }
    inference_meta = {
        "supports_proxy": True,
        "endpoints": proxy_endpoints,
        "tool": {"name": "verilog_tools", "parallel_tool_calls": False},
    }
    return TaskInfo(
        task=task_identity,
        environment="verilog",
        action_space=action_space,
        observation=observation_spec,
        dataset=dataset_section,
        rubric=rubric_meta,
        inference=inference_meta,
        limits={"max_ops": 0, "max_time_s": 3600},
    )
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _format_file_previews(files: dict[str, str]) -> str:
    """Render workspace files as ``name:`` headed sections for the prompt.

    Files are listed in sorted name order. Each body is whitespace-stripped; a
    falsy body (``None`` or ``""``) renders empty. A body longer than
    ``FILE_PREVIEW_CHARS`` is cut to that length and followed by a line
    containing ``...``.

    Returns:
        The sections separated by blank lines, or a placeholder sentence when
        ``files`` is empty or falsy.
    """
    if not files:
        return "No files in the workspace yet."

    def _preview(body: str) -> str:
        text = (body or "").strip()
        if len(text) <= FILE_PREVIEW_CHARS:
            return text
        return text[:FILE_PREVIEW_CHARS] + "\n..."

    return "\n\n".join(f"{name}:\n{_preview(files[name])}" for name in sorted(files))
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _format_observation_text(
    *,
    observation: dict[str, Any],
    step_index: int,
    instructions: str | None,
    action_feedback: str | None,
    guidance: str | None = None,
) -> str:
    """Build the user-turn prompt text for one agent step.

    Sections, in order:

    1. Task instructions (step 0 only, when provided).
    2. A status block: the reward line when either reward is present, then the
       completion flag, then compile/simulate/build details when truthy.
    3. The previous action's feedback, if any.
    4. Workspace file previews.
    5. The tool-selection reminder.
    6. Optional trailing guidance.

    Args:
        observation: Observation dict. The keys read are ``reward_last``,
            ``total_reward``, ``task_completed``, ``compile_status``,
            ``simulate_status``, ``build_dir`` and ``files``.
        step_index: Step counter shown in the status header; 0 also enables the
            instructions section.
        instructions: Task instructions (stripped before use).
        action_feedback: Summary of the previous action.
        guidance: Extra hint appended last (stripped before use).

    Returns:
        All lines joined with newlines.
    """
    out: list[str] = []
    if step_index == 0 and instructions:
        out += ["Task instructions:", instructions.strip(), ""]

    out.append(f"Step {step_index} status:")
    reward_last = observation.get("reward_last")
    total_reward = observation.get("total_reward")
    if not (reward_last is None and total_reward is None):
        out.append(f"- reward_last={reward_last!r}, total_reward={total_reward!r}")
    out.append(f"- task_completed={bool(observation.get('task_completed'))}")

    compile_status = observation.get("compile_status")
    if compile_status:
        out.append(f"- compile_status: {compile_status}")
    simulate_status = observation.get("simulate_status")
    if simulate_status:
        out.append(f"- simulate_status: {simulate_status}")
    build_dir = observation.get("build_dir")
    if build_dir:
        out.append(f"- build_directory: {build_dir}")

    if action_feedback:
        out += ["", action_feedback]

    out += [
        "",
        "Workspace files:",
        _format_file_previews(observation.get("files") or {}),
        "",
        "Select the single most helpful tool for the next step (write_file, compile, simulate, submit).",
        "Respond with JSON only: {\"tool\": \"<tool_name>\", \"args\": {...}}.",
    ]
    if guidance:
        out += ["", guidance.strip()]
    return "\n".join(out)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def _summarize_action_feedback(
    tool_name: str, args: dict[str, Any], observation: dict[str, Any], reward: float
) -> str:
    """Describe the previous tool call and its outcome for the next prompt.

    Always emits the call (arguments JSON-encoded with non-ASCII kept verbatim)
    and the reward delta to four decimals. Then, in order, it emits compile
    status, simulation status, a completion marker and the total reward, each
    only when present.

    Returns:
        The lines joined with newlines.
    """
    compile_status = observation.get("compile_status")
    simulate_status = observation.get("simulate_status")
    total_reward = observation.get("total_reward")
    header = [
        f"Previous action: {tool_name}({json.dumps(args, ensure_ascii=False)})",
        f"Reward delta: {reward:.4f}",
    ]
    optional_lines = [
        f"Compile status: {compile_status}" if compile_status else None,
        f"Simulation status: {simulate_status}" if simulate_status else None,
        "Task completed ✅" if observation.get("task_completed") else None,
        f"Total reward: {total_reward}" if total_reward is not None else None,
    ]
    return "\n".join(header + [line for line in optional_lines if line is not None])
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
# Matches a fenced code block (optionally tagged "json") whose body is a
# brace-delimited object; group 1 captures the object text. DOTALL lets the
# object span multiple lines.
JSON_BLOCK_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def _parse_tool_json(text: str) -> list[dict[str, Any]]:
|
|
290
|
+
candidates: list[dict[str, Any]] = []
|
|
291
|
+
|
|
292
|
+
try:
|
|
293
|
+
parsed = json.loads(text)
|
|
294
|
+
if isinstance(parsed, dict):
|
|
295
|
+
candidates.append(parsed)
|
|
296
|
+
except Exception:
|
|
297
|
+
pass
|
|
298
|
+
|
|
299
|
+
if not candidates:
|
|
300
|
+
for match in JSON_BLOCK_PATTERN.finditer(text):
|
|
301
|
+
snippet = match.group(1)
|
|
302
|
+
try:
|
|
303
|
+
parsed = json.loads(snippet)
|
|
304
|
+
except Exception:
|
|
305
|
+
continue
|
|
306
|
+
if isinstance(parsed, dict):
|
|
307
|
+
candidates.append(parsed)
|
|
308
|
+
|
|
309
|
+
if not candidates:
|
|
310
|
+
brace_match = re.search(r"\{.*\}", text, re.DOTALL)
|
|
311
|
+
if brace_match:
|
|
312
|
+
try:
|
|
313
|
+
parsed = json.loads(brace_match.group(0))
|
|
314
|
+
if isinstance(parsed, dict):
|
|
315
|
+
candidates.append(parsed)
|
|
316
|
+
except Exception:
|
|
317
|
+
pass
|
|
318
|
+
|
|
319
|
+
for candidate in candidates:
|
|
320
|
+
tool_name = candidate.get("tool") if isinstance(candidate, dict) else None
|
|
321
|
+
if not isinstance(tool_name, str):
|
|
322
|
+
continue
|
|
323
|
+
raw_args = candidate.get("args") if isinstance(candidate, dict) else None
|
|
324
|
+
args = raw_args if isinstance(raw_args, dict) else {}
|
|
325
|
+
tool_name = tool_name.strip()
|
|
326
|
+
normalized_args: dict[str, Any] = dict(args)
|
|
327
|
+
if tool_name == "write_file":
|
|
328
|
+
if "file_path" in normalized_args and "path" not in normalized_args:
|
|
329
|
+
normalized_args["path"] = normalized_args.pop("file_path")
|
|
330
|
+
if "file" in normalized_args and "path" not in normalized_args:
|
|
331
|
+
normalized_args["path"] = normalized_args.pop("file")
|
|
332
|
+
if "contents" in normalized_args and "content" not in normalized_args:
|
|
333
|
+
normalized_args["content"] = normalized_args.pop("contents")
|
|
334
|
+
return [{"tool": tool_name, "args": normalized_args}]
|
|
335
|
+
|
|
336
|
+
return []
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
class VerilogLLMAgent:
    """Minimal ReAct-style agent that communicates with a chat-completions API.

    ``self.messages`` holds the full transcript sent to the model, seeded with
    ``VERILOG_SYSTEM_PROMPT``. ``self.history`` mirrors every user/assistant
    turn but omits the system prompt.
    """

    # (URL substrings, env var holding the bearer token, error if unset).
    # Checked in order; the first provider whose substring appears in the
    # lower-cased inference URL supplies the Authorization header. URLs that
    # match none are sent without Authorization.
    _PROVIDER_KEYS: tuple[tuple[tuple[str, ...], str, str], ...] = (
        (
            ("groq",),
            "GROQ_API_KEY",
            "GROQ_API_KEY is not configured for Verilog inference.",
        ),
        (
            # Synth backend (any deployment, including local dev servers).
            ("synth-backend", "synth.run", "agent-learning", "localhost:8000", "127.0.0.1:8000"),
            "SYNTH_API_KEY",
            "SYNTH_API_KEY is not configured for Verilog inference with Synth backend.",
        ),
        (
            ("openai",),  # also covers "api.openai.com"
            "OPENAI_API_KEY",
            "OPENAI_API_KEY is not configured for Verilog inference.",
        ),
    )

    def __init__(
        self,
        *,
        instructions: str,
        inference_url: str | None,
        model: str | None,
        temperature: float,
        max_tokens: int,
    ) -> None:
        """Configure the endpoint, sampling parameters and auth headers.

        Raises:
            RuntimeError: if the URL matches a known provider whose API-key
                environment variable is unset or empty.
        """
        # Task metadata may supply None instead of a string; treat it as empty.
        self.instructions = (instructions or "").strip()
        self.inference_url = normalize_inference_url(inference_url, default=DEFAULT_INFERENCE_URL)
        self.model = model or DEFAULT_MODEL
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.messages: list[dict[str, Any]] = [{"role": "system", "content": VERILOG_SYSTEM_PROMPT}]
        self.headers: dict[str, str] = {"Content-Type": "application/json"}

        lowered = self.inference_url.lower()
        for patterns, env_var, missing_message in self._PROVIDER_KEYS:
            if any(pattern in lowered for pattern in patterns):
                api_key = os.getenv(env_var)
                if not api_key:
                    raise RuntimeError(missing_message)
                self.headers["Authorization"] = f"Bearer {api_key.strip()}"
                break

        self.history: list[dict[str, Any]] = []

    def append_observation(
        self,
        *,
        observation: dict[str, Any],
        step_index: int,
        action_feedback: str | None,
        guidance: str | None = None,
    ) -> str:
        """Render ``observation`` as a user turn and append it to the transcript.

        Task instructions are included only when ``step_index == 0``.
        Returns the rendered text.
        """
        text = _format_observation_text(
            observation=observation,
            step_index=step_index,
            instructions=self.instructions if step_index == 0 else None,
            action_feedback=action_feedback,
            guidance=guidance,
        )
        self.messages.append({"role": "user", "content": text})
        self.history.append({"role": "user", "content": text})
        return text

    async def invoke(
        self, client: httpx.AsyncClient
    ) -> tuple[str, list[dict[str, Any]], dict[str, Any], dict[str, Any]]:
        """Send the current transcript to the endpoint and parse the reply.

        The assistant reply is appended to ``messages`` and ``history``.

        Returns:
            ``(assistant_text, parsed_tool_calls, response_json, request_payload)``.
            ``request_payload`` holds a snapshot of the messages actually sent,
            so later transcript growth does not alter it.

        Raises:
            RuntimeError: if the endpoint is unreachable, responds with an HTTP
                error status, or returns a body that is not valid JSON.
        """
        payload: dict[str, Any] = {
            "model": self.model,
            # Copy so the returned payload (recorded as the step's llm_request)
            # is not mutated by later appends to self.messages.
            "messages": list(self.messages),
            "temperature": self.temperature,
        }
        if self.max_tokens > 0:
            payload["max_tokens"] = self.max_tokens

        try:
            response = await client.post(self.inference_url, json=payload, headers=self.headers)
        except Exception as exc:  # pragma: no cover - network failure
            raise RuntimeError(f"Failed to reach inference endpoint: {exc}") from exc

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:  # pragma: no cover - inference error
            preview = exc.response.text[:2000]
            raise RuntimeError(
                f"Inference call failed with status {exc.response.status_code}: {preview}"
            ) from exc

        try:
            data = response.json()
        except ValueError as exc:  # pragma: no cover - malformed upstream reply
            raise RuntimeError(
                f"Inference endpoint returned a non-JSON body: {response.text[:2000]}"
            ) from exc

        # Tolerate missing/empty choices and a null message: treat as empty reply.
        choices = data.get("choices") if isinstance(data, dict) else None
        first_choice = choices[0] if isinstance(choices, list) and choices else None
        message = first_choice.get("message") if isinstance(first_choice, dict) else None
        content = message.get("content") if isinstance(message, dict) else None
        assistant_text = content or ""

        self.messages.append({"role": "assistant", "content": assistant_text})
        self.history.append({"role": "assistant", "content": assistant_text})

        parsed_calls = _parse_tool_json(assistant_text)

        return assistant_text, parsed_calls, data, payload
|
|
435
|
+
|
|
436
|
+
def _single_criterion_rubric_spec(
    goal_text: str, criterion_id: str, description: str
) -> dict[str, Any]:
    """Build a version "1", weighted-sum rubric spec with one criterion of weight 1.0."""
    criterion = {"id": criterion_id, "description": description, "weight": 1.0}
    return {
        "version": "1",
        "goal_text": goal_text,
        "aggregation": "weighted_sum",
        "criteria": [criterion],
    }


# Outcome rubric: single criterion about passing compile and simulation checks.
OUTCOME_RUBRIC = load_rubric(
    _single_criterion_rubric_spec(
        goal_text="Produce a Verilog implementation that passes the provided testbench.",
        criterion_id="tests_pass",
        description="Submission passes all compile and simulation checks.",
    )
)

# Events rubric: single criterion about deliberate use of the design tools.
EVENTS_RUBRIC = load_rubric(
    _single_criterion_rubric_spec(
        goal_text="Encourage deliberate hardware design iterations.",
        criterion_id="efficient_iterations",
        description="Use write/compile/simulate tools strategically before submitting.",
    )
)
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def describe_taskset(dataset: VerilogDataset) -> dict[str, Any]:
    """Return the task-set summary produced by ``dataset.describe()``."""
    summary = dataset.describe()
    return summary
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
def _model_to_plain_dict(value: Any) -> dict[str, Any]:
    """Dict form of ``value``: ``model_dump()`` if available, a shallow copy of a dict, else ``{}``."""
    if hasattr(value, "model_dump"):
        return value.model_dump()
    if isinstance(value, dict):
        return dict(value)
    return {}


def provide_task_instances(
    dataset: VerilogDataset, base_info: TaskInfo, seeds: Sequence[int]
) -> Iterable[TaskInfo]:
    """Build one ``TaskInfo`` per seed, specialised to that seed's problem.

    Each result copies ``task``, ``environment``, ``action_space``, ``rubric``,
    ``inference`` and ``limits`` from ``base_info``.

    - ``observation``: the base observation (as a dict) plus ``problem_name``
      and ``difficulty``.
    - ``dataset``: the base dataset (as a dict) plus ``instance_id`` and a
      ``metadata`` dict.

    Base observation and dataset are converted the same way, whether they are
    pydantic-style models, plain dicts, or absent.

    Returns:
        A list with one entry per seed, in seed order.
    """
    infos: list[TaskInfo] = []
    observation_template = _model_to_plain_dict(getattr(base_info, "observation", None))

    for seed in seeds:
        instance = dataset.instance_by_seed(seed)
        metadata: VerilogTaskInstanceMetadata = instance.metadata  # type: ignore[assignment]
        meta_dict = {
            "problem_name": getattr(metadata, "problem_name", None),
            "difficulty": getattr(metadata, "difficulty", None),
            "description": getattr(metadata, "description", None),
            "files_provided": getattr(metadata, "files_provided", None),
        }
        infos.append(
            TaskInfo(
                task=base_info.task,
                environment=base_info.environment,
                action_space=base_info.action_space,
                observation={
                    **observation_template,
                    "problem_name": meta_dict["problem_name"],
                    "difficulty": meta_dict["difficulty"],
                },
                dataset={
                    # Converted per seed rather than hoisted, so each TaskInfo
                    # receives a freshly built dataset mapping.
                    **_model_to_plain_dict(getattr(base_info, "dataset", None)),
                    "instance_id": str(instance.id),
                    "metadata": meta_dict,
                },
                rubric=base_info.rubric,
                inference=base_info.inference,
                limits=base_info.limits,
            )
        )
    return infos
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
def _ensure_dataset_from_state(fastapi_request, fallback: VerilogDataset) -> VerilogDataset:
    """Prefer the dataset attached to ``request.app.state``; otherwise use ``fallback``.

    A missing request, app, state or ``dataset`` attribute, or a falsy
    ``dataset`` value, all result in ``fallback``.
    """
    if fastapi_request is None:
        return fallback
    app_state = getattr(getattr(fastapi_request, "app", None), "state", None)
    attached_dataset = getattr(app_state, "dataset", None)
    if attached_dataset:
        return attached_dataset
    return fallback
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
def _normalise_observation(value: Any) -> dict[str, Any]:
    """Coerce an environment observation into a dict.

    A dict is returned unchanged (the same object). A non-dict exposing an
    ``observation`` attribute is unwrapped first. A dict payload is then
    returned as-is; any other payload is wrapped as ``{"text": str(payload)}``.
    """
    payload = value
    if not isinstance(payload, dict) and hasattr(payload, "observation"):
        payload = getattr(payload, "observation")
    if isinstance(payload, dict):
        return payload
    return {"text": str(payload)}
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
async def rollout_executor(
|
|
535
|
+
request: RolloutRequest, fastapi_request
|
|
536
|
+
) -> RolloutResponse:
|
|
537
|
+
dataset = _ensure_dataset_from_state(fastapi_request, RUNTIME_DATASET)
|
|
538
|
+
env_seed = getattr(request.env, "seed", None) if request and request.env else None
|
|
539
|
+
instance = dataset.instance_by_seed(env_seed)
|
|
540
|
+
env = VerilogEnvironment(task_instance=instance)
|
|
541
|
+
|
|
542
|
+
policy_config_raw = getattr(request.policy, "config", {}) if request.policy else {}
|
|
543
|
+
policy_config = dict(policy_config_raw) if isinstance(policy_config_raw, dict) else {}
|
|
544
|
+
|
|
545
|
+
policy_model = policy_config.get("model")
|
|
546
|
+
if not isinstance(policy_model, str) or not policy_model.strip():
|
|
547
|
+
policy_model = getattr(request.policy, "policy_name", None) or DEFAULT_MODEL
|
|
548
|
+
policy_model = policy_model.strip()
|
|
549
|
+
|
|
550
|
+
temperature = policy_config.get("temperature", DEFAULT_TEMPERATURE)
|
|
551
|
+
try:
|
|
552
|
+
temperature = float(temperature)
|
|
553
|
+
except (TypeError, ValueError):
|
|
554
|
+
temperature = DEFAULT_TEMPERATURE
|
|
555
|
+
|
|
556
|
+
max_tokens = policy_config.get("max_tokens", DEFAULT_MAX_TOKENS)
|
|
557
|
+
try:
|
|
558
|
+
max_tokens = int(max_tokens)
|
|
559
|
+
except (TypeError, ValueError):
|
|
560
|
+
max_tokens = DEFAULT_MAX_TOKENS
|
|
561
|
+
|
|
562
|
+
max_steps_candidate = (
|
|
563
|
+
policy_config.get("max_steps")
|
|
564
|
+
or policy_config.get("max_llm_calls")
|
|
565
|
+
or DEFAULT_MAX_STEPS
|
|
566
|
+
)
|
|
567
|
+
try:
|
|
568
|
+
max_steps = int(max_steps_candidate)
|
|
569
|
+
except (TypeError, ValueError):
|
|
570
|
+
max_steps = DEFAULT_MAX_STEPS
|
|
571
|
+
max_steps = max(1, min(25, max_steps))
|
|
572
|
+
|
|
573
|
+
inference_url = policy_config.get("inference_url")
|
|
574
|
+
if isinstance(inference_url, str) and inference_url.strip():
|
|
575
|
+
resolved_inference = inference_url.strip()
|
|
576
|
+
else:
|
|
577
|
+
resolved_inference = os.getenv("VERILOG_INFERENCE_URL", DEFAULT_INFERENCE_URL)
|
|
578
|
+
|
|
579
|
+
instructions = getattr(getattr(instance, "impetus", None), "instructions", "")
|
|
580
|
+
agent = VerilogLLMAgent(
|
|
581
|
+
instructions=getattr(getattr(instance, "impetus", None), "instructions", ""),
|
|
582
|
+
inference_url=resolved_inference,
|
|
583
|
+
model=policy_model,
|
|
584
|
+
temperature=temperature,
|
|
585
|
+
max_tokens=max_tokens,
|
|
586
|
+
)
|
|
587
|
+
|
|
588
|
+
policy_id = (
|
|
589
|
+
getattr(request.policy, "policy_id", None)
|
|
590
|
+
or getattr(request.policy, "policy_name", None)
|
|
591
|
+
or policy_model
|
|
592
|
+
)
|
|
593
|
+
env_id = getattr(request.env, "env_id", None) or getattr(request.env, "env_name", None) or "verilog"
|
|
594
|
+
|
|
595
|
+
steps: list[RolloutStep] = []
|
|
596
|
+
total_reward = 0.0
|
|
597
|
+
final_observation: dict[str, Any] | None = None
|
|
598
|
+
truncated_due_to_limit = False
|
|
599
|
+
|
|
600
|
+
# Log episode start
|
|
601
|
+
problem_id = getattr(instance, "problem_id", "unknown")
|
|
602
|
+
logger.info("=" * 80)
|
|
603
|
+
logger.info(f"[EPISODE START] run_id={request.run_id}")
|
|
604
|
+
logger.info(f" Problem ID: {problem_id}")
|
|
605
|
+
logger.info(f" Policy: {policy_id}")
|
|
606
|
+
logger.info(f" Model: {policy_model}")
|
|
607
|
+
logger.info(f" Max steps: {max_steps}")
|
|
608
|
+
logger.info(f" Temperature: {temperature}")
|
|
609
|
+
logger.info(f" Max tokens: {max_tokens}")
|
|
610
|
+
if instructions:
|
|
611
|
+
instructions_preview = instructions[:150] + "..." if len(instructions) > 150 else instructions
|
|
612
|
+
logger.info(f" Instructions: {instructions_preview}")
|
|
613
|
+
logger.info("=" * 80)
|
|
614
|
+
code_dirty = False
|
|
615
|
+
last_compile_success = False
|
|
616
|
+
simulate_since_last_compile = False
|
|
617
|
+
last_compile_failed = False
|
|
618
|
+
needs_design_update = False
|
|
619
|
+
|
|
620
|
+
def _build_guidance(step_idx: int) -> str | None:
|
|
621
|
+
hints: list[str] = []
|
|
622
|
+
if step_idx == 0 and not last_compile_success:
|
|
623
|
+
hints.append("Begin by using write_file to implement TopModule according to the problem instructions before compiling.")
|
|
624
|
+
if last_compile_failed or needs_design_update:
|
|
625
|
+
hints.append("Compilation failed; update the design with write_file to match the required ports and behavior before compiling again.")
|
|
626
|
+
if code_dirty and not last_compile_success:
|
|
627
|
+
hints.append("Source was modified; run compile before simulate or submit.")
|
|
628
|
+
if (not code_dirty) and last_compile_success and not simulate_since_last_compile:
|
|
629
|
+
hints.append("Compilation succeeded; run simulate to verify before other actions.")
|
|
630
|
+
if (not code_dirty) and last_compile_success and simulate_since_last_compile:
|
|
631
|
+
hints.append("Simulation already ran after the latest compile; submit if the checks passed or make new edits first.")
|
|
632
|
+
return " ".join(hints) if hints else None
|
|
633
|
+
|
|
634
|
+
try:
|
|
635
|
+
initial_raw_observation = await env.initialize()
|
|
636
|
+
current_observation = _normalise_observation(initial_raw_observation)
|
|
637
|
+
final_observation = current_observation
|
|
638
|
+
agent.append_observation(
|
|
639
|
+
observation=current_observation,
|
|
640
|
+
step_index=0,
|
|
641
|
+
action_feedback=None,
|
|
642
|
+
guidance=_build_guidance(0),
|
|
643
|
+
)
|
|
644
|
+
|
|
645
|
+
total_reward = float(current_observation.get("total_reward") or 0.0)
|
|
646
|
+
already_done = bool(
|
|
647
|
+
current_observation.get("terminated") or current_observation.get("task_completed")
|
|
648
|
+
)
|
|
649
|
+
|
|
650
|
+
timeout = httpx.Timeout(
|
|
651
|
+
HTTP_TIMEOUT_SECONDS,
|
|
652
|
+
connect=HTTP_TIMEOUT_SECONDS,
|
|
653
|
+
read=HTTP_TIMEOUT_SECONDS,
|
|
654
|
+
)
|
|
655
|
+
|
|
656
|
+
async with httpx.AsyncClient(timeout=timeout) as client:
|
|
657
|
+
if not already_done:
|
|
658
|
+
for step_index in range(1, max_steps + 1):
|
|
659
|
+
assistant_text, tool_calls, raw_response, request_payload = await agent.invoke(client)
|
|
660
|
+
override_info: dict[str, Any] | None = None
|
|
661
|
+
if not tool_calls:
|
|
662
|
+
fallback_tool = (
|
|
663
|
+
"submit" if current_observation.get("task_completed") else "compile"
|
|
664
|
+
)
|
|
665
|
+
tool_calls = [{"tool": fallback_tool, "args": {}}]
|
|
666
|
+
|
|
667
|
+
primary_call = dict(tool_calls[0])
|
|
668
|
+
tool_name_raw = str(primary_call.get("tool", ""))
|
|
669
|
+
normalized_tool = tool_name_raw.strip().lower()
|
|
670
|
+
if normalized_tool == "compile":
|
|
671
|
+
if (not code_dirty) and last_compile_success and not simulate_since_last_compile:
|
|
672
|
+
override_info = {
|
|
673
|
+
"from": dict(primary_call),
|
|
674
|
+
"reason": "compile_after_success_without_changes",
|
|
675
|
+
}
|
|
676
|
+
primary_call = {"tool": "simulate", "args": {}}
|
|
677
|
+
tool_calls = [primary_call]
|
|
678
|
+
override_info["to"] = dict(primary_call)
|
|
679
|
+
env_call = EnvToolCall(tool=primary_call["tool"], args=primary_call["args"])
|
|
680
|
+
|
|
681
|
+
try:
|
|
682
|
+
skip_env_step = (
|
|
683
|
+
normalized_tool == "compile"
|
|
684
|
+
and needs_design_update
|
|
685
|
+
and not code_dirty
|
|
686
|
+
)
|
|
687
|
+
if skip_env_step:
|
|
688
|
+
reward_last = 0.0 # No reward for blocked operations
|
|
689
|
+
total_reward += reward_last
|
|
690
|
+
current_observation = dict(current_observation)
|
|
691
|
+
current_observation["reward_last"] = reward_last
|
|
692
|
+
current_observation["total_reward"] = total_reward
|
|
693
|
+
final_observation = current_observation
|
|
694
|
+
done_flag = False
|
|
695
|
+
truncated_flag = False
|
|
696
|
+
else:
|
|
697
|
+
step_observation = await env.step(env_call)
|
|
698
|
+
current_observation = _normalise_observation(step_observation)
|
|
699
|
+
final_observation = current_observation
|
|
700
|
+
reward_last = float(current_observation.get("reward_last") or 0.0)
|
|
701
|
+
total_reward = float(
|
|
702
|
+
current_observation.get("total_reward") or (total_reward + reward_last)
|
|
703
|
+
)
|
|
704
|
+
done_flag = bool(
|
|
705
|
+
current_observation.get("terminated")
|
|
706
|
+
or current_observation.get("task_completed")
|
|
707
|
+
)
|
|
708
|
+
truncated_flag = bool(current_observation.get("truncated"))
|
|
709
|
+
|
|
710
|
+
# Log what the environment returned
|
|
711
|
+
print(f"\n{'='*80}")
|
|
712
|
+
print(f"[STEP {step_index}] TOOL CALL:")
|
|
713
|
+
print(f" Tool: {env_call.tool}")
|
|
714
|
+
print(f" Args: {env_call.args}")
|
|
715
|
+
print(f"\n[STEP {step_index}] ENVIRONMENT RESPONSE:")
|
|
716
|
+
print(f" Reward: {reward_last:.4f} (cumulative: {total_reward:.4f})")
|
|
717
|
+
print(f" Task completed: {step_observation.get('task_completed')}")
|
|
718
|
+
print(f" Done: {done_flag} | Truncated: {truncated_flag}")
|
|
719
|
+
if 'compile_status' in step_observation and step_observation.get('compile_status'):
|
|
720
|
+
print(f" Compile status:\n{step_observation.get('compile_status')}")
|
|
721
|
+
if 'simulate_status' in step_observation and step_observation.get('simulate_status'):
|
|
722
|
+
print(f" Simulate status:\n{step_observation.get('simulate_status')}")
|
|
723
|
+
if 'files' in step_observation:
|
|
724
|
+
print(f" Files: {list(step_observation.get('files', {}).keys())}")
|
|
725
|
+
print(f"{'='*80}\n")
|
|
726
|
+
|
|
727
|
+
executed_tool_name = str(primary_call["tool"])
|
|
728
|
+
normalized_executed_tool = executed_tool_name.strip().lower()
|
|
729
|
+
|
|
730
|
+
if normalized_executed_tool == "write_file":
|
|
731
|
+
code_dirty = True
|
|
732
|
+
last_compile_success = False
|
|
733
|
+
simulate_since_last_compile = False
|
|
734
|
+
last_compile_failed = False
|
|
735
|
+
needs_design_update = False
|
|
736
|
+
elif normalized_executed_tool == "compile":
|
|
737
|
+
compile_status_text = str(current_observation.get("compile_status") or "")
|
|
738
|
+
if "success" in compile_status_text.lower():
|
|
739
|
+
code_dirty = False
|
|
740
|
+
last_compile_success = True
|
|
741
|
+
simulate_since_last_compile = False
|
|
742
|
+
last_compile_failed = False
|
|
743
|
+
needs_design_update = False
|
|
744
|
+
else:
|
|
745
|
+
last_compile_success = False
|
|
746
|
+
last_compile_failed = True
|
|
747
|
+
needs_design_update = True
|
|
748
|
+
elif normalized_executed_tool == "simulate":
|
|
749
|
+
simulate_since_last_compile = True
|
|
750
|
+
|
|
751
|
+
tool_call_records = [
|
|
752
|
+
{"tool_name": call["tool"], "arguments": call["args"]}
|
|
753
|
+
for call in tool_calls
|
|
754
|
+
]
|
|
755
|
+
|
|
756
|
+
# Print tool calls for debugging
|
|
757
|
+
logger.info(f"[STEP {step_index}] Tool calls executed:")
|
|
758
|
+
for call in tool_calls:
|
|
759
|
+
tool_name = call["tool"]
|
|
760
|
+
args = call["args"]
|
|
761
|
+
# Truncate long arguments for readability
|
|
762
|
+
if "code" in args or "content" in args:
|
|
763
|
+
args_preview = {k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v)
|
|
764
|
+
for k, v in args.items()}
|
|
765
|
+
else:
|
|
766
|
+
args_preview = args
|
|
767
|
+
logger.info(f" └─ {tool_name}({args_preview})")
|
|
768
|
+
|
|
769
|
+
# Log reward details for debugging
|
|
770
|
+
logger.info(f"[STEP {step_index}] Reward details:")
|
|
771
|
+
logger.info(f" └─ reward_last: {reward_last:.4f}")
|
|
772
|
+
logger.info(f" └─ total_reward: {total_reward:.4f}")
|
|
773
|
+
logger.info(f" └─ skip_env_step: {skip_env_step}")
|
|
774
|
+
if not skip_env_step:
|
|
775
|
+
logger.info(f" └─ obs.task_completed: {current_observation.get('task_completed', False)}")
|
|
776
|
+
logger.info(f" └─ obs.compile_status: {current_observation.get('compile_status', 'N/A')}")
|
|
777
|
+
logger.info(f" └─ obs.simulate_status: {current_observation.get('simulate_status', 'N/A')}")
|
|
778
|
+
logger.info(f" └─ obs.terminated: {current_observation.get('terminated', False)}")
|
|
779
|
+
else:
|
|
780
|
+
logger.info(f" └─ (blocked operation - no env step)")
|
|
781
|
+
|
|
782
|
+
step_info = {
|
|
783
|
+
"assistant_message": assistant_text,
|
|
784
|
+
"model_response": raw_response,
|
|
785
|
+
"llm_request": request_payload,
|
|
786
|
+
"meta": {
|
|
787
|
+
"inference_url": policy_config.get("inference_url") or resolved_inference, # CRITICAL: Required by RL trainer for trace extraction (must have ?cid=...)
|
|
788
|
+
},
|
|
789
|
+
}
|
|
790
|
+
if override_info:
|
|
791
|
+
step_info["auto_override"] = override_info
|
|
792
|
+
if normalized_tool == "compile" and skip_env_step:
|
|
793
|
+
step_info["compile_blocked"] = {
|
|
794
|
+
"reason": "design_requires_update_before_compile",
|
|
795
|
+
"hint": "Use write_file to match required ports/behavior before compiling again.",
|
|
796
|
+
}
|
|
797
|
+
steps.append(
|
|
798
|
+
RolloutStep(
|
|
799
|
+
obs=current_observation,
|
|
800
|
+
tool_calls=tool_call_records,
|
|
801
|
+
reward=reward_last,
|
|
802
|
+
done=done_flag,
|
|
803
|
+
truncated=truncated_flag,
|
|
804
|
+
info=step_info,
|
|
805
|
+
)
|
|
806
|
+
)
|
|
807
|
+
|
|
808
|
+
if normalized_tool == "compile" and skip_env_step:
|
|
809
|
+
action_feedback = (
|
|
810
|
+
"Compilation blocked: update the design with write_file (declare required ports and logic) before compiling again."
|
|
811
|
+
)
|
|
812
|
+
else:
|
|
813
|
+
action_feedback = _summarize_action_feedback(
|
|
814
|
+
primary_call["tool"], primary_call["args"], current_observation, reward_last
|
|
815
|
+
)
|
|
816
|
+
agent.append_observation(
|
|
817
|
+
observation=current_observation,
|
|
818
|
+
step_index=step_index,
|
|
819
|
+
action_feedback=action_feedback,
|
|
820
|
+
guidance=_build_guidance(step_index),
|
|
821
|
+
)
|
|
822
|
+
|
|
823
|
+
if done_flag:
|
|
824
|
+
break
|
|
825
|
+
|
|
826
|
+
if step_index == max_steps:
|
|
827
|
+
truncated_due_to_limit = True
|
|
828
|
+
break
|
|
829
|
+
except Exception as exc: # pragma: no cover - defensive path
|
|
830
|
+
error_text = str(exc)
|
|
831
|
+
logger.exception("Verilog environment step failed: %s", exc)
|
|
832
|
+
failure_observation = dict(current_observation)
|
|
833
|
+
failure_observation["error"] = error_text
|
|
834
|
+
final_observation = failure_observation
|
|
835
|
+
tool_call_records = [
|
|
836
|
+
{"tool_name": primary_call["tool"], "arguments": primary_call["args"]}
|
|
837
|
+
]
|
|
838
|
+
step_info = {
|
|
839
|
+
"assistant_message": assistant_text,
|
|
840
|
+
"model_response": raw_response,
|
|
841
|
+
"llm_request": request_payload,
|
|
842
|
+
"error": error_text,
|
|
843
|
+
"meta": {
|
|
844
|
+
"inference_url": policy_config.get("inference_url") or resolved_inference, # CRITICAL: Required by RL trainer
|
|
845
|
+
},
|
|
846
|
+
}
|
|
847
|
+
steps.append(
|
|
848
|
+
RolloutStep(
|
|
849
|
+
obs=failure_observation,
|
|
850
|
+
tool_calls=tool_call_records,
|
|
851
|
+
reward=0.0,
|
|
852
|
+
done=True,
|
|
853
|
+
truncated=True,
|
|
854
|
+
info=step_info,
|
|
855
|
+
)
|
|
856
|
+
)
|
|
857
|
+
truncated_due_to_limit = True
|
|
858
|
+
break
|
|
859
|
+
finally:
|
|
860
|
+
with contextlib.suppress(Exception):
|
|
861
|
+
await env.terminate()
|
|
862
|
+
|
|
863
|
+
if final_observation is None:
|
|
864
|
+
final_observation = {}
|
|
865
|
+
|
|
866
|
+
final_total_reward = float(final_observation.get("total_reward") or total_reward)
|
|
867
|
+
final_done = bool(
|
|
868
|
+
final_observation.get("terminated") or final_observation.get("task_completed")
|
|
869
|
+
)
|
|
870
|
+
final_truncated = truncated_due_to_limit or bool(final_observation.get("truncated"))
|
|
871
|
+
|
|
872
|
+
metrics = RolloutMetrics(
|
|
873
|
+
episode_returns=[final_total_reward],
|
|
874
|
+
mean_return=final_total_reward,
|
|
875
|
+
num_steps=len(steps),
|
|
876
|
+
num_episodes=1,
|
|
877
|
+
outcome_score=final_total_reward,
|
|
878
|
+
events_score=None,
|
|
879
|
+
details={
|
|
880
|
+
"task_completed": bool(final_observation.get("task_completed")),
|
|
881
|
+
"total_reward": final_total_reward,
|
|
882
|
+
"steps": len(steps),
|
|
883
|
+
"truncated": final_truncated,
|
|
884
|
+
},
|
|
885
|
+
)
|
|
886
|
+
|
|
887
|
+
# Extract inference_url from policy config (REQUIRED for RL trace correlation)
|
|
888
|
+
# The trainer injects this with ?cid=trace_xxxxx parameter for trace linking
|
|
889
|
+
final_inference_url = policy_config.get("inference_url")
|
|
890
|
+
if not isinstance(final_inference_url, str) or not final_inference_url.strip():
|
|
891
|
+
# Fallback to agent's inference_url if not in policy config
|
|
892
|
+
final_inference_url = agent.inference_url
|
|
893
|
+
logger.warning(
|
|
894
|
+
"VERILOG_ROLLOUT: inference_url not found in policy_config, using agent.inference_url run_id=%s url=%s",
|
|
895
|
+
request.run_id,
|
|
896
|
+
final_inference_url,
|
|
897
|
+
)
|
|
898
|
+
else:
|
|
899
|
+
logger.info(
|
|
900
|
+
"VERILOG_ROLLOUT: using inference_url from policy_config run_id=%s url=%s has_cid=%s",
|
|
901
|
+
request.run_id,
|
|
902
|
+
final_inference_url,
|
|
903
|
+
"?cid=" in final_inference_url,
|
|
904
|
+
)
|
|
905
|
+
|
|
906
|
+
trajectory = RolloutTrajectory(
|
|
907
|
+
env_id=str(env_id),
|
|
908
|
+
policy_id=str(policy_id),
|
|
909
|
+
steps=steps,
|
|
910
|
+
final={
|
|
911
|
+
"observation": final_observation,
|
|
912
|
+
"reward": final_total_reward,
|
|
913
|
+
"done": final_done,
|
|
914
|
+
"truncated": final_truncated,
|
|
915
|
+
"info": {
|
|
916
|
+
"total_reward": final_total_reward,
|
|
917
|
+
"task_completed": bool(final_observation.get("task_completed")),
|
|
918
|
+
"policy_model": policy_model,
|
|
919
|
+
"inference_url": final_inference_url,
|
|
920
|
+
},
|
|
921
|
+
},
|
|
922
|
+
length=len(steps),
|
|
923
|
+
inference_url=final_inference_url, # CRITICAL: Must contain ?cid=... for trace correlation
|
|
924
|
+
decision_samples=None,
|
|
925
|
+
)
|
|
926
|
+
|
|
927
|
+
# Build trace payload
|
|
928
|
+
trace_payload = {
|
|
929
|
+
"session_trace": {
|
|
930
|
+
"session_id": request.run_id,
|
|
931
|
+
"created_at": None,
|
|
932
|
+
"metadata": {
|
|
933
|
+
"task": "verilog",
|
|
934
|
+
"provider": "groq",
|
|
935
|
+
"model": policy_model,
|
|
936
|
+
"total_reward": final_total_reward,
|
|
937
|
+
"task_completed": bool(final_observation.get("task_completed")),
|
|
938
|
+
},
|
|
939
|
+
"session_time_steps": [],
|
|
940
|
+
"event_history": [],
|
|
941
|
+
"markov_blanket_message_history": [],
|
|
942
|
+
}
|
|
943
|
+
}
|
|
944
|
+
|
|
945
|
+
# Build pipeline_metadata (required for RL training)
|
|
946
|
+
pipeline_metadata = {
|
|
947
|
+
"reward_score": final_total_reward,
|
|
948
|
+
"policy_id": policy_id,
|
|
949
|
+
"inference_url": final_inference_url, # CRITICAL: Must be at top level for RL trainer (expects ?cid=...)
|
|
950
|
+
"inference": {
|
|
951
|
+
"provider": "groq",
|
|
952
|
+
"model": policy_model,
|
|
953
|
+
"url": final_inference_url, # Use final_inference_url (has ?cid=...)
|
|
954
|
+
},
|
|
955
|
+
"env_name": env_id,
|
|
956
|
+
"task_id": getattr(instance, "problem_id", None),
|
|
957
|
+
"task_split": getattr(instance, "split", "val"),
|
|
958
|
+
}
|
|
959
|
+
|
|
960
|
+
# Log episode summary with reward breakdown
|
|
961
|
+
compile_status = final_observation.get("compile_status", "N/A")
|
|
962
|
+
simulate_status = final_observation.get("simulate_status", "N/A")
|
|
963
|
+
task_completed = bool(final_observation.get("task_completed", False))
|
|
964
|
+
|
|
965
|
+
logger.info("=" * 80)
|
|
966
|
+
logger.info(f"[EPISODE COMPLETE] run_id={request.run_id}")
|
|
967
|
+
logger.info(f" Steps taken: {len(steps)}")
|
|
968
|
+
logger.info(f" Total reward: {final_total_reward:.3f}")
|
|
969
|
+
logger.info(f" Task completed: {task_completed}")
|
|
970
|
+
logger.info(f" Compile status: {compile_status}")
|
|
971
|
+
logger.info(f" Simulate status: {simulate_status}")
|
|
972
|
+
logger.info(f" Done/Truncated: {final_done}/{final_truncated}")
|
|
973
|
+
logger.info(f" Problem ID: {getattr(instance, 'problem_id', 'N/A')}")
|
|
974
|
+
|
|
975
|
+
# DEBUG: Log each step's reward for RL debugging
|
|
976
|
+
print(f"\n[REWARD DEBUG] Step-by-step breakdown:")
|
|
977
|
+
for idx, step in enumerate(steps):
|
|
978
|
+
print(f" Step {idx}: reward={step.reward:.4f} tool_calls={[tc.get('tool_name') for tc in step.tool_calls]}")
|
|
979
|
+
print(f"[REWARD DEBUG] Final observation keys: {list(final_observation.keys())}")
|
|
980
|
+
print(f"[REWARD DEBUG] Final obs total_reward: {final_observation.get('total_reward')}")
|
|
981
|
+
print(f"[REWARD DEBUG] Metrics outcome_score: {metrics.outcome_score}")
|
|
982
|
+
print(f"[REWARD DEBUG] Metrics mean_return: {metrics.mean_return}")
|
|
983
|
+
|
|
984
|
+
# Reward breakdown for debugging
|
|
985
|
+
logger.info("\n[REWARD BREAKDOWN]")
|
|
986
|
+
compile_count = sum(1 for s in steps if any(tc.get("tool_name") == "compile" for tc in s.tool_calls))
|
|
987
|
+
simulate_count = sum(1 for s in steps if any(tc.get("tool_name") == "simulate" for tc in s.tool_calls))
|
|
988
|
+
submit_count = sum(1 for s in steps if any(tc.get("tool_name") == "submit" for tc in s.tool_calls))
|
|
989
|
+
write_count = sum(1 for s in steps if any(tc.get("tool_name") == "write_file" for tc in s.tool_calls))
|
|
990
|
+
|
|
991
|
+
logger.info(f" Tool usage: write_file={write_count}, compile={compile_count}, simulate={simulate_count}, submit={submit_count}")
|
|
992
|
+
|
|
993
|
+
# Show per-step rewards
|
|
994
|
+
step_rewards = [s.reward for s in steps]
|
|
995
|
+
nonzero_rewards = [r for r in step_rewards if r != 0.0]
|
|
996
|
+
logger.info(f" Step rewards: {step_rewards}")
|
|
997
|
+
if nonzero_rewards:
|
|
998
|
+
logger.info(f" Non-zero rewards: {nonzero_rewards}")
|
|
999
|
+
else:
|
|
1000
|
+
logger.info(f" ⚠️ ALL REWARDS ZERO! Possible reasons:")
|
|
1001
|
+
logger.info(f" - No successful compiles (compile reward = 0.01)")
|
|
1002
|
+
logger.info(f" - No successful simulations (simulate reward = 0.1)")
|
|
1003
|
+
logger.info(f" - No successful submits (submit reward = 1.0)")
|
|
1004
|
+
logger.info(f" - Check if task_completed={task_completed}")
|
|
1005
|
+
logger.info(f" - Check compile_status='{compile_status}'")
|
|
1006
|
+
logger.info(f" - Check simulate_status='{simulate_status}'")
|
|
1007
|
+
logger.info("=" * 80)
|
|
1008
|
+
|
|
1009
|
+
# Log for debugging RL training
|
|
1010
|
+
logger.info(
|
|
1011
|
+
"VERILOG_ROLLOUT: pipeline_metadata run_id=%s reward=%.3f inference_url=%s",
|
|
1012
|
+
request.run_id,
|
|
1013
|
+
final_total_reward,
|
|
1014
|
+
final_inference_url,
|
|
1015
|
+
)
|
|
1016
|
+
|
|
1017
|
+
# DEBUG: Log what we're returning to the RL trainer
|
|
1018
|
+
print(f"\n[RETURN DEBUG] Trajectory structure being returned:")
|
|
1019
|
+
print(f" trajectory.steps count: {len(steps)}")
|
|
1020
|
+
print(f" trajectory.final.reward: {trajectory.final.get('reward') if trajectory.final else 'None'}")
|
|
1021
|
+
print(f" trajectory.length: {trajectory.length}")
|
|
1022
|
+
print(f" metrics.outcome_score: {metrics.outcome_score}")
|
|
1023
|
+
print(f" metrics.mean_return: {metrics.mean_return}")
|
|
1024
|
+
print(f" metrics.episode_returns: {metrics.episode_returns}")
|
|
1025
|
+
print(f" pipeline_metadata.reward_score: {pipeline_metadata.get('reward_score')}")
|
|
1026
|
+
|
|
1027
|
+
# ASSERTIONS: Validate RL-required fields before returning
|
|
1028
|
+
# These catch structural issues early (before they reach the backend trainer)
|
|
1029
|
+
# Only enforce for RL mode, not EVAL mode
|
|
1030
|
+
is_rl_mode = hasattr(request, 'mode') and str(getattr(request, 'mode', '')).lower() == 'rl'
|
|
1031
|
+
|
|
1032
|
+
assert isinstance(pipeline_metadata, dict), (
|
|
1033
|
+
f"VERILOG_ROLLOUT_VALIDATION: pipeline_metadata must be dict, got {type(pipeline_metadata).__name__}"
|
|
1034
|
+
)
|
|
1035
|
+
assert "inference_url" in pipeline_metadata, (
|
|
1036
|
+
f"VERILOG_ROLLOUT_VALIDATION: pipeline_metadata missing 'inference_url' (REQUIRED for RL training)"
|
|
1037
|
+
)
|
|
1038
|
+
assert isinstance(pipeline_metadata["inference_url"], str), (
|
|
1039
|
+
f"VERILOG_ROLLOUT_VALIDATION: pipeline_metadata['inference_url'] must be string, got {type(pipeline_metadata['inference_url']).__name__}"
|
|
1040
|
+
)
|
|
1041
|
+
# Only require ?cid= for RL mode (not needed for EVAL)
|
|
1042
|
+
if is_rl_mode:
|
|
1043
|
+
assert "?cid=" in pipeline_metadata["inference_url"], (
|
|
1044
|
+
f"VERILOG_ROLLOUT_VALIDATION: pipeline_metadata['inference_url'] must contain '?cid=' for trace correlation in RL mode. "
|
|
1045
|
+
f"Got: {pipeline_metadata['inference_url'][:100]}"
|
|
1046
|
+
)
|
|
1047
|
+
|
|
1048
|
+
# Validate each step has meta.inference_url (backend expects this nested structure)
|
|
1049
|
+
for step_idx, step in enumerate(steps):
|
|
1050
|
+
step_dict = step if isinstance(step, dict) else (step.model_dump() if hasattr(step, "model_dump") else {})
|
|
1051
|
+
step_info = step_dict.get("info", {})
|
|
1052
|
+
assert isinstance(step_info, dict), (
|
|
1053
|
+
f"VERILOG_ROLLOUT_VALIDATION: step[{step_idx}].info must be dict, got {type(step_info).__name__}"
|
|
1054
|
+
)
|
|
1055
|
+
step_meta = step_info.get("meta", {})
|
|
1056
|
+
assert isinstance(step_meta, dict), (
|
|
1057
|
+
f"VERILOG_ROLLOUT_VALIDATION: step[{step_idx}].info.meta must be dict, got {type(step_meta).__name__}"
|
|
1058
|
+
)
|
|
1059
|
+
assert "inference_url" in step_meta, (
|
|
1060
|
+
f"VERILOG_ROLLOUT_VALIDATION: step[{step_idx}].info.meta missing 'inference_url' (REQUIRED for RL training)"
|
|
1061
|
+
)
|
|
1062
|
+
assert isinstance(step_meta["inference_url"], str), (
|
|
1063
|
+
f"VERILOG_ROLLOUT_VALIDATION: step[{step_idx}].info.meta['inference_url'] must be string, got {type(step_meta['inference_url']).__name__}"
|
|
1064
|
+
)
|
|
1065
|
+
|
|
1066
|
+
logger.info(
|
|
1067
|
+
"VERILOG_ROLLOUT_VALIDATION: ✓ All RL-required fields present run_id=%s steps=%d",
|
|
1068
|
+
request.run_id,
|
|
1069
|
+
len(steps),
|
|
1070
|
+
)
|
|
1071
|
+
|
|
1072
|
+
return RolloutResponse(
|
|
1073
|
+
run_id=request.run_id,
|
|
1074
|
+
trajectories=[trajectory],
|
|
1075
|
+
branches={},
|
|
1076
|
+
metrics=metrics,
|
|
1077
|
+
aborted=False,
|
|
1078
|
+
ops_executed=len(steps),
|
|
1079
|
+
trace=trace_payload,
|
|
1080
|
+
pipeline_metadata=pipeline_metadata,
|
|
1081
|
+
)
|
|
1082
|
+
|
|
1083
|
+
|
|
1084
|
+
# Module-level dataset wiring, executed once when this module is imported.
# The result of `build_dataset()` is unpacked into the dataset registry
# (handed to TaskAppConfig as `dataset_registry` in build_config()) and the
# concrete dataset exposed as RUNTIME_DATASET. The bare annotation declares
# the dataset's type ahead of the tuple-unpacking assignment.
RUNTIME_DATASET: VerilogDataset
registry, RUNTIME_DATASET = build_dataset()
# Base task metadata derived from the dataset; build_config() reuses it both as
# `base_task_info` and when providing task instances for requested seeds.
BASE_INFO = _base_task_info(RUNTIME_DATASET)
|
|
1087
|
+
|
|
1088
|
+
|
|
1089
|
+
def build_config() -> TaskAppConfig:
    """Create the ``TaskAppConfig`` that serves the Verilog task app.

    In order, this:

    1. reads the tracing switch and tracing DB URL and builds a session-tracer
       factory from them (the factory may come back as ``None``);
    2. resolves the optional SFT JSONL output directory;
    3. fills ``app_state`` with the module-level dataset, the allowed
       environment list and the tracing flag, plus the tracer factory and SFT
       output directory when they are available;
    4. logs the tracing / SFT configuration;
    5. returns a config wired to this module's dataset helpers, rollout
       executor, dataset registry, rubrics and OpenAI/Groq proxy settings.
    """
    tracing_on = tracing_env_enabled()
    db_url = resolve_tracing_db_url()
    make_tracer = build_tracer_factory(
        SessionTracer, enabled=tracing_on, db_url=db_url
    )
    sft_dir = resolve_sft_output_dir()

    state: dict[str, Any] = dict(
        dataset=RUNTIME_DATASET,
        allowed_environments=["verilog"],
        tracing_enabled=tracing_on,
    )
    if make_tracer is not None:
        state["session_tracer_factory"] = make_tracer

    if not tracing_on:
        logger.info("[verilog:tracing] disabled")
    else:
        # Show the literal "default" when no explicit tracing DB URL is set.
        logger.info("[verilog:tracing] enabled (db=%s)", db_url or "default")

    # The SFT output directory is optional; only record and announce it when
    # one was resolved.
    if sft_dir:
        state["sft_output_dir"] = sft_dir
        logger.info("[verilog:sft] writing JSONL to %s", sft_dir)

    # These closures look up RUNTIME_DATASET / BASE_INFO in module globals at
    # call time (not captured here), so later rebinding is still observed.
    def describe():
        return describe_taskset(RUNTIME_DATASET)

    def provide(seeds):
        return provide_task_instances(RUNTIME_DATASET, BASE_INFO, seeds)

    rubrics = RubricBundle(outcome=OUTCOME_RUBRIC, events=EVENTS_RUBRIC)
    proxy = ProxyConfig(
        enable_openai=True,
        enable_groq=True,
        system_hint=VERILOG_SYSTEM_PROMPT,
    )

    return TaskAppConfig(
        app_id="grpo-verilog",
        name="GRPO Verilog Task App",
        description="Spec-to-RTL Verilog environment with GRPO-compatible metadata endpoints.",
        base_task_info=BASE_INFO,
        describe_taskset=describe,
        provide_task_instances=provide,
        rollout=rollout_executor,
        dataset_registry=registry,
        rubrics=rubrics,
        proxy=proxy,
        routers=(),
        app_state=state,
        cors_origins=["*"],
    )
|
|
1134
|
+
|
|
1135
|
+
|
|
1136
|
+
# Register this task app at import time under the id "grpo-verilog" with the
# aliases "verilog" and "verilog-task". `config_factory` is build_config, so the
# TaskAppConfig is built by calling it — presumably on demand by whoever looks
# the entry up; confirm against register_task_app.
register_task_app(
    entry=TaskAppEntry(
        app_id="grpo-verilog",
        description="Verilog spec-to-RTL task app with rollout metadata endpoints.",
        config_factory=build_config,
        aliases=("verilog", "verilog-task"),
        # Env file path under REPO_ROOT; presumably loaded before the app
        # starts — confirm in TaskAppEntry.
        env_files=(str(REPO_ROOT / "backend" / ".env.dev"),),
        # Modal deployment settings for the hosted variant of this app.
        modal=ModalDeploymentConfig(
            app_name="grpo-verilog-task-app",
            python_version="3.11",
            pip_packages=(
                "fastapi>=0.100.0",
                "uvicorn>=0.23.0",
                "pydantic>=2.0.0",
                "httpx>=0.24.0",
                "python-dotenv>=1.0.1",
                "datasets>=2.10.0",
            ),
            apt_packages=("iverilog",),  # Icarus Verilog compiler and simulator (provides iverilog and vvp)
            # Pairs appear to be (local path, container path).
            # NOTE(review): REPO_ROOT/synth_ai and the task_app dir look like they
            # already sit inside the REPO_ROOT mount; presumably intentional for
            # Modal's mount handling — confirm before pruning.
            extra_local_dirs=(
                (str(REPO_ROOT), "/opt/synth_ai_repo"),
                (str(REPO_ROOT / "synth_ai"), "/opt/synth_ai_repo/synth_ai"),
                (str(_HERE.parent), "/opt/synth_ai_repo/examples/task_apps/verilog/task_app"),
            ),
            secret_names=("groq-api-key", "openai-api-key"),
            # NOTE(review): memory unit assumed to be MiB per Modal convention — confirm.
            memory=8192,
            cpu=2.0,
            max_containers=4,
        ),
    )
)
|