@elizaos/sweagent-root 2.0.0-alpha
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +270 -0
- package/package.json +71 -0
- package/python/LICENSE +21 -0
- package/python/config/README.md +15 -0
- package/python/config/bash_only.yaml +222 -0
- package/python/config/benchmarks/250212_sweagent_heavy_sbl.yaml +188 -0
- package/python/config/benchmarks/250225_anthropic_filemap_simple_review.yaml +75 -0
- package/python/config/benchmarks/250522_anthropic_filemap_simple_review.yaml +92 -0
- package/python/config/benchmarks/250526_anthropic_filemap_simple_review_sbl.yaml +93 -0
- package/python/config/benchmarks/anthropic_filemap_multilingual.yaml +66 -0
- package/python/config/coding_challenge.yaml +104 -0
- package/python/config/default.yaml +69 -0
- package/python/config/default_backticks.yaml +69 -0
- package/python/config/default_mm_no_images.yaml +82 -0
- package/python/config/default_mm_with_images.yaml +83 -0
- package/python/config/demo/default.yaml +80 -0
- package/python/config/demo/no_instructions.yaml +69 -0
- package/python/config/demo/only_bash.yaml +60 -0
- package/python/config/exotic/default_shell.yaml +52 -0
- package/python/config/exotic/windowed_replace.yaml +125 -0
- package/python/config/exotic/windowed_replace_late_repro.yaml +127 -0
- package/python/config/human/human.yaml +24 -0
- package/python/config/human/human_demo.yaml +52 -0
- package/python/config/sweagent_0_7/07.yaml +101 -0
- package/python/config/sweagent_0_7/07_fcalling.yaml +100 -0
- package/python/config/sweagent_0_7/07_from_url.yaml +114 -0
- package/python/config/sweagent_0_7/07_thought_action.yaml +102 -0
- package/python/config/sweagent_0_7/07_thought_action_xml.yaml +96 -0
- package/python/mlc_config.json +44 -0
- package/python/pyproject.toml +262 -0
- package/python/sweagent/__init__.py +114 -0
- package/python/sweagent/__main__.py +4 -0
- package/python/sweagent/agent/__init__.py +0 -0
- package/python/sweagent/agent/action_sampler.py +317 -0
- package/python/sweagent/agent/agents.py +1294 -0
- package/python/sweagent/agent/extra/shell_agent.py +106 -0
- package/python/sweagent/agent/history_processors.py +399 -0
- package/python/sweagent/agent/hooks/__init__.py +0 -0
- package/python/sweagent/agent/hooks/abstract.py +139 -0
- package/python/sweagent/agent/hooks/status.py +34 -0
- package/python/sweagent/agent/models.py +896 -0
- package/python/sweagent/agent/problem_statement.py +312 -0
- package/python/sweagent/agent/reviewer.py +664 -0
- package/python/sweagent/environment/__init__.py +0 -0
- package/python/sweagent/environment/hooks/__init__.py +0 -0
- package/python/sweagent/environment/hooks/abstract.py +60 -0
- package/python/sweagent/environment/hooks/status.py +28 -0
- package/python/sweagent/environment/repo.py +219 -0
- package/python/sweagent/environment/swe_env.py +276 -0
- package/python/sweagent/exceptions.py +54 -0
- package/python/sweagent/inspector/README.md +6 -0
- package/python/sweagent/inspector/__init__.py +0 -0
- package/python/sweagent/inspector/favicon.ico +0 -0
- package/python/sweagent/inspector/fileViewer.js +354 -0
- package/python/sweagent/inspector/icons/computer.png +0 -0
- package/python/sweagent/inspector/icons/edit_icon.svg +11 -0
- package/python/sweagent/inspector/icons/swe-agent-logo-50.png +0 -0
- package/python/sweagent/inspector/icons/swellama_blue.png +0 -0
- package/python/sweagent/inspector/icons/swellama_brown.png +0 -0
- package/python/sweagent/inspector/icons/swellama_grey.png +0 -0
- package/python/sweagent/inspector/icons/swellama_tan.png +0 -0
- package/python/sweagent/inspector/index.html +25 -0
- package/python/sweagent/inspector/server.py +354 -0
- package/python/sweagent/inspector/static.py +169 -0
- package/python/sweagent/inspector/style.css +454 -0
- package/python/sweagent/run/__init__.py +0 -0
- package/python/sweagent/run/_progress.py +158 -0
- package/python/sweagent/run/batch_instances.py +419 -0
- package/python/sweagent/run/common.py +387 -0
- package/python/sweagent/run/compare_runs.py +123 -0
- package/python/sweagent/run/extract_pred.py +19 -0
- package/python/sweagent/run/hooks/__init__.py +0 -0
- package/python/sweagent/run/hooks/abstract.py +67 -0
- package/python/sweagent/run/hooks/apply_patch.py +106 -0
- package/python/sweagent/run/hooks/open_pr.py +244 -0
- package/python/sweagent/run/hooks/swe_bench_evaluate.py +113 -0
- package/python/sweagent/run/inspector_cli.py +493 -0
- package/python/sweagent/run/merge_predictions.py +64 -0
- package/python/sweagent/run/quick_stats.py +96 -0
- package/python/sweagent/run/remove_unfinished.py +63 -0
- package/python/sweagent/run/rich_test.py +91 -0
- package/python/sweagent/run/run.py +147 -0
- package/python/sweagent/run/run_batch.py +442 -0
- package/python/sweagent/run/run_replay.py +219 -0
- package/python/sweagent/run/run_shell.py +155 -0
- package/python/sweagent/run/run_single.py +225 -0
- package/python/sweagent/run/run_traj_to_demo.py +85 -0
- package/python/sweagent/tools/__init__.py +0 -0
- package/python/sweagent/tools/bundle.py +57 -0
- package/python/sweagent/tools/commands.py +220 -0
- package/python/sweagent/tools/parsing.py +619 -0
- package/python/sweagent/tools/tools.py +430 -0
- package/python/sweagent/tools/utils.py +108 -0
- package/python/sweagent/types.py +102 -0
- package/python/sweagent/utils/__init__.py +0 -0
- package/python/sweagent/utils/config.py +80 -0
- package/python/sweagent/utils/files.py +27 -0
- package/python/sweagent/utils/github.py +118 -0
- package/python/sweagent/utils/jinja_warnings.py +14 -0
- package/python/sweagent/utils/log.py +175 -0
- package/python/sweagent/utils/patch_formatter.py +152 -0
- package/python/sweagent/utils/serialization.py +45 -0
- package/python/tests/__init__.py +0 -0
- package/python/tests/conftest.py +191 -0
- package/python/tests/test_agent.py +258 -0
- package/python/tests/test_batch_instance.py +43 -0
- package/python/tests/test_commands/_interactive_dummy.py +35 -0
- package/python/tests/test_commands/interactive_dummy_wrapper.sh +29 -0
- package/python/tests/test_data/config_files/dummy_interactive.yaml +62 -0
- package/python/tests/test_data/data_sources/ctf/crypto/Katy/Dockerfile +20 -0
- package/python/tests/test_data/data_sources/ctf/crypto/Katy/README.md +13 -0
- package/python/tests/test_data/data_sources/ctf/crypto/Katy/challenge.json +12 -0
- package/python/tests/test_data/data_sources/ctf/crypto/Katy/customrandom.c +50 -0
- package/python/tests/test_data/data_sources/ctf/crypto/Katy/docker-compose.yml +14 -0
- package/python/tests/test_data/data_sources/ctf/crypto/Katy/release +0 -0
- package/python/tests/test_data/data_sources/ctf/crypto/Katy/server +0 -0
- package/python/tests/test_data/data_sources/ctf/crypto/Katy/solver.py +12 -0
- package/python/tests/test_data/data_sources/ctf/forensics/flash/README.md +16 -0
- package/python/tests/test_data/data_sources/ctf/forensics/flash/challenge.json +9 -0
- package/python/tests/test_data/data_sources/ctf/forensics/flash/flash_c8429a430278283c0e571baebca3d139.zip +0 -0
- package/python/tests/test_data/data_sources/ctf/misc/networking_1/README.md +15 -0
- package/python/tests/test_data/data_sources/ctf/misc/networking_1/challenge.json +10 -0
- package/python/tests/test_data/data_sources/ctf/misc/networking_1/networking.pcap +0 -0
- package/python/tests/test_data/data_sources/ctf/pwn/warmup/Dockerfile +28 -0
- package/python/tests/test_data/data_sources/ctf/pwn/warmup/README.md +14 -0
- package/python/tests/test_data/data_sources/ctf/pwn/warmup/challenge.json +14 -0
- package/python/tests/test_data/data_sources/ctf/pwn/warmup/docker-compose.yml +14 -0
- package/python/tests/test_data/data_sources/ctf/pwn/warmup/flag.txt +1 -0
- package/python/tests/test_data/data_sources/ctf/pwn/warmup/warmup +0 -0
- package/python/tests/test_data/data_sources/ctf/pwn/warmup/warmup.c +26 -0
- package/python/tests/test_data/data_sources/ctf/pwn/warmup/warmup.py +9 -0
- package/python/tests/test_data/data_sources/ctf/rev/rock/README.md +14 -0
- package/python/tests/test_data/data_sources/ctf/rev/rock/challenge.json +8 -0
- package/python/tests/test_data/data_sources/ctf/rev/rock/rock +0 -0
- package/python/tests/test_data/data_sources/ctf/rev/rock/rock.cpp +167 -0
- package/python/tests/test_data/data_sources/ctf/rev/rock/solution.cpp +24 -0
- package/python/tests/test_data/data_sources/ctf/rev/rock/test_solver/solution.py +6 -0
- package/python/tests/test_data/data_sources/ctf/rev/rock/test_solver/test.sh +10 -0
- package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/000-default.conf +18 -0
- package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/Dockerfile +20 -0
- package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/cgi/file.pl +38 -0
- package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/cgi/forms.pl +40 -0
- package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/cgi/hello.pl +11 -0
- package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/challenge.json +12 -0
- package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/docker-compose.yml +14 -0
- package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/flag +1 -0
- package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/index.html +11 -0
- package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/solution.txt +1 -0
- package/python/tests/test_data/data_sources/debug_20240322.json +1 -0
- package/python/tests/test_data/data_sources/expert_instances.yaml +16 -0
- package/python/tests/test_data/data_sources/human_eval.json +1 -0
- package/python/tests/test_data/data_sources/simple_instances.yaml +3 -0
- package/python/tests/test_data/data_sources/simple_instances_long.yaml +30 -0
- package/python/tests/test_data/data_sources/swe-bench-dev-easy.json +1 -0
- package/python/tests/test_data/data_sources/swe-bench-dev-easy_first_only.json +1 -0
- package/python/tests/test_data/data_sources/swe-bench-lite-test.json +1 -0
- package/python/tests/test_data/trajectories/gpt4__swe-agent-test-repo__default_from_url__t-0.00__p-0.95__c-3.00__install-1/6e44b9__sweagenttestrepo-1c2844.traj +342 -0
- package/python/tests/test_data/trajectories/gpt4__swe-agent-test-repo__default_from_url__t-0.00__p-0.95__c-3.00__install-1/solution_missing_colon.py +15 -0
- package/python/tests/test_data/trajectories/gpt4__swe-agent__test-repo__default_from_url__t-0.00__p-0.95__c-3.00__install-1/args.yaml +518 -0
- package/python/tests/test_data/trajectories/gpt4__swe-agent__test-repo__default_from_url__t-0.00__p-0.95__c-3.00__install-1/swe-agent__test-repo-i1.traj +124 -0
- package/python/tests/test_data/trajectories/gpt4__swe-bench-dev-easy_first_only__default__t-0.00__p-0.95__c-3.00__install-1/all_preds.jsonl +1 -0
- package/python/tests/test_data/trajectories/gpt4__swe-bench-dev-easy_first_only__default__t-0.00__p-0.95__c-3.00__install-1/args.yaml +520 -0
- package/python/tests/test_data/trajectories/gpt4__swe-bench-dev-easy_first_only__default__t-0.00__p-0.95__c-3.00__install-1/patches/pydicom__pydicom-1458.patch +18 -0
- package/python/tests/test_data/trajectories/gpt4__swe-bench-dev-easy_first_only__default__t-0.00__p-0.95__c-3.00__install-1/pydicom__pydicom-1458.traj +257 -0
- package/python/tests/test_env.py +66 -0
- package/python/tests/test_env_utils.py +129 -0
- package/python/tests/test_history_processors.py +40 -0
- package/python/tests/test_models.py +23 -0
- package/python/tests/test_openai_live.py +164 -0
- package/python/tests/test_packaging.py +7 -0
- package/python/tests/test_parsing.py +131 -0
- package/python/tests/test_problem_statement_multimodal.py +111 -0
- package/python/tests/test_quick_stats.py +42 -0
- package/python/tests/test_run.py +37 -0
- package/python/tests/test_run_batch.py +110 -0
- package/python/tests/test_run_hooks.py +114 -0
- package/python/tests/test_run_replay.py +33 -0
- package/python/tests/test_run_single.py +125 -0
- package/python/tests/test_tools_command_parsing.py +193 -0
- package/python/tests/test_utils.py +15 -0
- package/python/tests/tools/__init__.py +0 -0
- package/python/tests/tools/conftest.py +12 -0
- package/python/tests/tools/test_default_utils.py +153 -0
- package/python/tests/tools/test_edit_replace.py +0 -0
- package/python/tests/tools/test_split_string.py +82 -0
- package/python/tests/utils.py +29 -0
- package/python/tools/diff_state/bin/_state_diff_state +52 -0
- package/python/tools/diff_state/config.yaml +2 -0
- package/python/tools/edit_anthropic/bin/_state_anthropic +21 -0
- package/python/tools/edit_anthropic/bin/str_replace_editor +710 -0
- package/python/tools/edit_anthropic/config.yaml +56 -0
- package/python/tools/edit_anthropic/install.sh +3 -0
- package/python/tools/filemap/bin/filemap +45 -0
- package/python/tools/filemap/config.yaml +9 -0
- package/python/tools/filemap/install.sh +2 -0
- package/python/tools/forfeit/bin/exit_forfeit +5 -0
- package/python/tools/forfeit/config.yaml +5 -0
- package/python/tools/image_tools/bin/view_image +36 -0
- package/python/tools/image_tools/config.yaml +9 -0
- package/python/tools/multilingual_setup/bin/do_nothing +2 -0
- package/python/tools/multilingual_setup/config.yaml +1 -0
- package/python/tools/multilingual_setup/install.sh +45 -0
- package/python/tools/registry/bin/_read_env +10 -0
- package/python/tools/registry/bin/_write_env +10 -0
- package/python/tools/registry/config.yaml +1 -0
- package/python/tools/registry/install.sh +6 -0
- package/python/tools/registry/lib/__init__.py +0 -0
- package/python/tools/registry/lib/registry.py +56 -0
- package/python/tools/review_on_submit_m/README.md +6 -0
- package/python/tools/review_on_submit_m/bin/submit +54 -0
- package/python/tools/review_on_submit_m/config.yaml +6 -0
- package/python/tools/review_on_submit_m/install.sh +0 -0
- package/python/tools/search/bin/find_file +31 -0
- package/python/tools/search/bin/search_dir +39 -0
- package/python/tools/search/bin/search_file +55 -0
- package/python/tools/search/config.yaml +37 -0
- package/python/tools/search/install.sh +3 -0
- package/python/tools/submit/bin/submit +17 -0
- package/python/tools/submit/config.yaml +5 -0
- package/python/tools/web_browser/bin/click_mouse +41 -0
- package/python/tools/web_browser/bin/close_site +28 -0
- package/python/tools/web_browser/bin/double_click_mouse +37 -0
- package/python/tools/web_browser/bin/drag_mouse +46 -0
- package/python/tools/web_browser/bin/execute_script_on_page +39 -0
- package/python/tools/web_browser/bin/get_console_output +48 -0
- package/python/tools/web_browser/bin/move_mouse +35 -0
- package/python/tools/web_browser/bin/navigate_back +33 -0
- package/python/tools/web_browser/bin/navigate_forward +33 -0
- package/python/tools/web_browser/bin/open_site +36 -0
- package/python/tools/web_browser/bin/press_keys_on_page +51 -0
- package/python/tools/web_browser/bin/reload_page +33 -0
- package/python/tools/web_browser/bin/run_web_browser_server +394 -0
- package/python/tools/web_browser/bin/screenshot_site +38 -0
- package/python/tools/web_browser/bin/scroll_on_page +40 -0
- package/python/tools/web_browser/bin/set_browser_window_size +40 -0
- package/python/tools/web_browser/bin/type_text +34 -0
- package/python/tools/web_browser/bin/wait_time +39 -0
- package/python/tools/web_browser/config.yaml +155 -0
- package/python/tools/web_browser/install.sh +22 -0
- package/python/tools/web_browser/lib/browser_manager.py +404 -0
- package/python/tools/web_browser/lib/web_browser_config.py +33 -0
- package/python/tools/web_browser/lib/web_browser_utils.py +126 -0
- package/python/tools/web_browser/test_console.html +1 -0
- package/python/tools/windowed/bin/_state +25 -0
- package/python/tools/windowed/bin/create +29 -0
- package/python/tools/windowed/bin/goto +37 -0
- package/python/tools/windowed/bin/open +49 -0
- package/python/tools/windowed/bin/scroll_down +12 -0
- package/python/tools/windowed/bin/scroll_up +13 -0
- package/python/tools/windowed/config.yaml +38 -0
- package/python/tools/windowed/install.sh +15 -0
- package/python/tools/windowed/lib/__init__.py +0 -0
- package/python/tools/windowed/lib/flake8_utils.py +147 -0
- package/python/tools/windowed/lib/windowed_file.py +312 -0
- package/python/tools/windowed_edit_linting/bin/edit +128 -0
- package/python/tools/windowed_edit_linting/config.yaml +31 -0
- package/python/tools/windowed_edit_linting/install.sh +5 -0
- package/python/tools/windowed_edit_replace/bin/edit +172 -0
- package/python/tools/windowed_edit_replace/bin/insert +77 -0
- package/python/tools/windowed_edit_replace/config.yaml +60 -0
- package/python/tools/windowed_edit_replace/install.sh +5 -0
- package/python/tools/windowed_edit_rewrite/bin/edit +78 -0
- package/python/tools/windowed_edit_rewrite/config.yaml +11 -0
- package/python/tools/windowed_edit_rewrite/install.sh +5 -0
- package/python/trajectories/demonstrations/ctf/crypto/BabyEncryption.traj +318 -0
- package/python/trajectories/demonstrations/ctf/crypto/BabyTimeCapsule.traj +197 -0
- package/python/trajectories/demonstrations/ctf/crypto/eps.traj +289 -0
- package/python/trajectories/demonstrations/ctf/crypto/katy.traj +368 -0
- package/python/trajectories/demonstrations/ctf/forensics/flash.traj +102 -0
- package/python/trajectories/demonstrations/ctf/misc/networking_1.traj +102 -0
- package/python/trajectories/demonstrations/ctf/pwn/warmup.traj +159 -0
- package/python/trajectories/demonstrations/ctf/rev/rock.traj +251 -0
- package/python/trajectories/demonstrations/ctf/web/i_got_id_demo.traj +422 -0
- package/python/trajectories/demonstrations/function_calling_simple.traj +151 -0
- package/python/trajectories/demonstrations/human_thought__swe-bench-HumanEvalFix-python__lcb__t-0.00__p-0.95__c-4.00__install-0/humanevalfix-python-0.traj +129 -0
- package/python/trajectories/demonstrations/replay__marshmallow-code__marshmallow-1867__default__t-0.20__p-0.95__c-2.00__install-1___install_from_source/marshmallow-code__marshmallow-1867.traj +318 -0
- package/python/trajectories/demonstrations/replay__marshmallow-code__marshmallow-1867__default_sys-env_cursors_window100__t-0.20__p-0.95__c-2.00__install-1/marshmallow-code__marshmallow-1867.traj +251 -0
- package/python/trajectories/demonstrations/replay__marshmallow-code__marshmallow-1867__default_sys-env_window100__t-0.20__p-0.95__c-2.00__install-1/marshmallow-code__marshmallow-1867.traj +399 -0
- package/python/trajectories/demonstrations/replay__marshmallow-code__marshmallow-1867__function_calling__install-1/marshmallow-code__marshmallow-1867.traj +594 -0
- package/python/trajectories/demonstrations/replay__marshmallow-code__marshmallow-1867__function_calling_replace__install-1/marshmallow-code__marshmallow-1867.traj +592 -0
- package/python/trajectories/demonstrations/replay__marshmallow-code__marshmallow-1867__function_calling_replace_from_source/marshmallow-code__marshmallow-1867.traj +3316 -0
- package/python/trajectories/demonstrations/replay__marshmallow-code__marshmallow-1867__xml_sys-env_cursors_window100__t-0.20__p-0.95__c-2.00__install-1/marshmallow-code__marshmallow-1867.traj +251 -0
- package/python/trajectories/demonstrations/replay__marshmallow-code__marshmallow-1867__xml_sys-env_window100__t-0.20__p-0.95__c-2.00__install-1/marshmallow-code__marshmallow-1867.traj +399 -0
- package/python/trajectories/demonstrations/str_replace_anthropic_demo.yaml +432 -0
- package/rust/Cargo.toml +100 -0
- package/rust/README.md +49 -0
- package/rust/src/agent/action_sampler.rs +130 -0
- package/rust/src/agent/agents.rs +1029 -0
- package/rust/src/agent/history_processors.rs +277 -0
- package/rust/src/agent/hooks/mod.rs +208 -0
- package/rust/src/agent/mod.rs +24 -0
- package/rust/src/agent/models.rs +837 -0
- package/rust/src/agent/problem_statement.rs +355 -0
- package/rust/src/agent/reviewer.rs +505 -0
- package/rust/src/bin/sweagent.rs +784 -0
- package/rust/src/environment/deployment.rs +631 -0
- package/rust/src/environment/hooks/mod.rs +114 -0
- package/rust/src/environment/mod.rs +16 -0
- package/rust/src/environment/repo.rs +265 -0
- package/rust/src/environment/runtime.rs +237 -0
- package/rust/src/environment/swe_env.rs +248 -0
- package/rust/src/exceptions.rs +228 -0
- package/rust/src/lib.rs +68 -0
- package/rust/src/monitoring.rs +482 -0
- package/rust/src/run/hooks/mod.rs +134 -0
- package/rust/src/run/mod.rs +12 -0
- package/rust/src/run/run_batch.rs +563 -0
- package/rust/src/run/run_single.rs +196 -0
- package/rust/src/tools/bundle.rs +224 -0
- package/rust/src/tools/commands.rs +173 -0
- package/rust/src/tools/mod.rs +295 -0
- package/rust/src/tools/parsing.rs +354 -0
- package/rust/src/tools/registry.rs +143 -0
- package/rust/src/types.rs +554 -0
- package/rust/src/utils/config.rs +105 -0
- package/rust/src/utils/files.rs +137 -0
- package/rust/src/utils/github.rs +171 -0
- package/rust/src/utils/log.rs +65 -0
- package/rust/src/utils/mod.rs +17 -0
- package/rust/src/utils/serialization.rs +181 -0
- package/rust/src/utils/template.rs +173 -0
- package/typescript/README.md +335 -0
|
@@ -0,0 +1,664 @@
|
|
|
1
|
+
"""The reviewer implements a retry loop for the agent to retry
|
|
2
|
+
solving the issue and to select the best solution.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import copy
|
|
8
|
+
import re
|
|
9
|
+
from abc import ABC, abstractmethod
|
|
10
|
+
from typing import Any, Literal
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
from jinja2 import Template
|
|
14
|
+
from pydantic import BaseModel, ConfigDict
|
|
15
|
+
|
|
16
|
+
from sweagent.agent.history_processors import _set_cache_control
|
|
17
|
+
from sweagent.agent.models import (
|
|
18
|
+
AbstractModel,
|
|
19
|
+
InstanceStats,
|
|
20
|
+
ModelConfig,
|
|
21
|
+
get_model,
|
|
22
|
+
)
|
|
23
|
+
from sweagent.agent.problem_statement import ProblemStatement
|
|
24
|
+
from sweagent.tools.parsing import ActionParser
|
|
25
|
+
from sweagent.tools.tools import ToolConfig
|
|
26
|
+
from sweagent.types import AgentInfo, Trajectory, TrajectoryStep
|
|
27
|
+
from sweagent.utils.log import get_logger
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ReviewSubmission(BaseModel):
    """Everything the reviewer gets to see about one attempt."""

    #: Total trajectory (including several retries)
    trajectory: Trajectory
    #: Aggregate info dict (including several retries)
    info: AgentInfo
    #: Model stats for this attempt
    model_stats: InstanceStats

    def to_format_dict(self, *, suffix="") -> dict[str, Any]:
        """Flatten ``info`` into template variables, each key tagged with
        ``suffix``. The trajectory is excluded because it needs special
        treatment.
        """
        data = copy.deepcopy(self.info)
        # Not every exit_cost run ends in an autosubmission, so the
        # submission entry can be missing or empty — normalize it.
        if not data.get("submission"):
            data["submission"] = ""
        formatted: dict[str, Any] = {}
        for key, value in data.items():
            if isinstance(value, dict):
                # One level of nesting is flattened with an underscore.
                formatted.update({f"{key}_{inner}{suffix}": nested for inner, nested in value.items()})
            elif isinstance(value, str):
                formatted[f"{key}{suffix}"] = value
        return formatted
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class ReviewerResult(BaseModel):
    """Outcome of reviewing a single submission."""

    #: Verdict: a hard boolean accept/reject or a numeric score
    #: (presumably higher is better — verify against the Reviewer implementation)
    accept: bool | float
    #: Raw model outputs produced during the review
    outputs: list[str]
    #: The chat messages that were sent to the review model
    messages: list[dict[str, Any]]
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class PreselectorOutput(BaseModel):
    """Result of the preselection stage: which submissions stay in the running."""

    #: Indices of the submissions kept for the final choice
    chosen_idx: list[int]
    #: Raw model response the indices were parsed from
    response: str
    #: Chat messages sent to the preselector model
    messages: list[dict[str, Any]]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class ChooserOutput(BaseModel):
    """Result of the chooser: the single submission index that was selected."""

    #: Index of the chosen submission
    chosen_idx: int
    #: Raw model response the index was parsed from
    response: str
    #: Output of the optional preselection stage, if one ran
    preselector_output: PreselectorOutput | None = None
    #: Chat messages sent to the chooser model
    messages: list[dict[str, Any]]
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# --- INTERFACES ---
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class AbstractReviewer(ABC):
    """The reviewer checks a single solution and tries to predict
    if it successfully solves the issue.
    """

    @abstractmethod
    def review(self, instance: ProblemStatement, submission: ReviewSubmission) -> ReviewerResult:
        """Review a single submission against the problem statement.

        Returns:
            A `ReviewerResult` whose ``accept`` field carries the
            verdict/score for the submission.
        """
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class AbstractRetryLoop(ABC):
    """Controls how many times the agent attempts to solve the issue and
    how the best of the resulting solutions is selected.
    """

    def retry(self) -> bool:
        """Whether the agent should start another attempt at the issue."""
        return False

    def on_submit(self, submission: ReviewSubmission) -> None:
        """Hook invoked whenever the agent submits a solution."""

    def on_model_query(self, attempt_stats: InstanceStats):
        """Hook invoked right before the model is queried. Useful for
        implementing stop conditions based on per-attempt cost etc.
        """

    def on_attempt_started(self, i_attempt: int, agent):
        """Hook invoked when a fresh attempt begins."""

    @abstractmethod
    def get_best(self) -> int:
        """Index of the best solution."""

    def get_forwarded_vars(self) -> dict[str, Any]:
        """Get the variables that should be forwarded to the next iteration.

        Returns:
            A dictionary of variables that should be forwarded to the next
            iteration.
        """
        return {}
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
# --- CONFIGS ---
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class PreselectorConfig(BaseModel):
    """Configuration for the `Preselector` model and its prompts."""

    #: Model used for the preselection query
    model: ModelConfig
    #: System prompt for the preselection query
    system_template: str
    #: Jinja2 template for the user message (rendered with
    #: ``problem_statement`` and the formatted ``submissions``)
    instance_template: str
    #: Jinja2 template used to render a single submission
    submission_template: str
    #: Submissions longer than this many characters are rendered as invalid;
    #: a value <= 0 disables the limit
    max_len_submission: int = 5000
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class ChooserConfig(BaseModel):
    """Configuration for the `Chooser` model and its prompts."""

    #: Model used for the choosing query
    model: ModelConfig
    #: System prompt for the choosing query
    system_template: str
    #: Jinja2 template for the user message
    instance_template: str
    #: Jinja2 template used to render a single submission
    submission_template: str
    #: Submissions longer than this many characters are treated as invalid;
    #: a value <= 0 disables the limit
    max_len_submission: int = 5000
    #: Optional preselection stage that narrows the candidates before choosing
    preselector: PreselectorConfig | None = None
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class TrajFormatterConfig(BaseModel):
    """Controls how a trajectory is rendered into review prompts."""

    #: Filter the following actions from the trajectory
    filter: list[str] = []
    #: Filter outputs from the following actions from the trajectory
    output_filter: list[str] = []
    #: Format of the trajectory item (Jinja2; receives ``response`` and ``observation``)
    item_template: str = "Model: {{response}}\n\nObservation: {{observation}}"
    #: presumably only the last N observations are shown in full (0 = no limit)
    #: — verify against the trajectory formatter implementation
    only_show_last_n_output: int = 0

    model_config = ConfigDict(extra="forbid")
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class ReviewerConfig(BaseModel):
    """The configuration for the reviewer"""

    #: System prompt for the review model
    system_template: str
    #: Jinja2 template for the user message of the review query
    instance_template: str
    #: If a submission autosubmits because of total cost or a similar exit status,
    #: it will get this malus to its score
    failure_score_penalty: float = 0.0
    #: Controls how trajectories are rendered into the review prompt
    traj_formatter: TrajFormatterConfig
    #: presumably the number of review samples drawn per submission
    #: — verify against the Reviewer implementation
    n_sample: int = 5
    #: presumably reduces the score based on the sample standard deviation
    #: — verify against the Reviewer implementation
    reduce_by_std: float = 0.0
    #: If set, we assume that the score is in the range [score_range[0], score_range[1]]
    #: Reviews that are outside this range will be ignored
    score_range: tuple[float | None, float | None] = (None, None)

    type: Literal["reviewer"] = "reviewer"

    model_config = ConfigDict(extra="forbid")

    def get_reviewer(self, model: AbstractModel) -> AbstractReviewer:
        """Build the concrete `Reviewer` (defined elsewhere in this module)
        for this configuration.
        """
        return Reviewer(self, model)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class ChooserRetryLoopConfig(BaseModel):
    """Configuration for the chooser-based retry loop: run several attempts,
    then let a chooser model pick the best submission.
    """

    type: Literal["chooser"] = "chooser"
    #: How the best submission is chosen once the attempts are finished
    chooser: ChooserConfig

    #: Hard cap on the number of attempts
    max_attempts: int
    min_budget_for_new_attempt: float = 0.0
    """Minimal $ that need to be left in order for us to start a new attempt.
    If set to 0: Always.
    """

    cost_limit: float
    """The maximum cost to spend on all attempts. Does not include cost of choosing.
    """

    model_config = ConfigDict(extra="forbid")

    def get_retry_loop(self, problem_statement: ProblemStatement) -> ChooserRetryLoop:
        """Build the concrete `ChooserRetryLoop` (defined elsewhere in this
        module) for this configuration.
        """
        return ChooserRetryLoop(self, problem_statement)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class ScoreRetryLoopConfig(BaseModel):
    """The configuration for the review loop"""

    type: Literal["score"] = "score"

    #: How each submission is scored
    reviewer_config: ReviewerConfig

    #: Submissions scoring at least this value are accepted
    accept_score: float
    #: Stop retrying once this many submissions have been accepted
    max_accepts: int = 1
    #: Hard cap on the number of attempts
    max_attempts: int

    min_budget_for_new_attempt: float = 0.0
    """Minimal $ that need to be left in order for us to start a new attempt.
    If set to 0: Always.
    """

    cost_limit: float
    """The maximum cost to spend on all attempts and reviews except the last review.
    The last review is not included in the cost limit, because we would waste the last
    attempt if we couldn't score it.
    """

    #: Model used for reviewing
    model: ModelConfig

    model_config = ConfigDict(extra="forbid")

    def validate(self):
        """Checks config. Raises `ValueError` in case of misconfiguration"""
        # NOTE(review): shadows pydantic's deprecated `validate` classmethod;
        # currently a no-op placeholder.
        ...

    def model_post_init(self, __context: Any) -> None:
        # Bug fix: the previous `__post_init__` is a dataclass hook that
        # pydantic BaseModel never invokes, so `validate()` was never called.
        # `model_post_init` is the pydantic-v2 post-construction hook.
        self.validate()

    def get_retry_loop(self, problem_statement: ProblemStatement) -> ScoreRetryLoop:
        """Build the concrete `ScoreRetryLoop` (defined elsewhere in this
        module) for this configuration.
        """
        return ScoreRetryLoop(self, problem_statement)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
#: Union of the available retry-loop configurations; each member carries a
#: distinct ``type`` literal (presumably used as a discriminator — verify at use site)
RetryLoopConfig = ScoreRetryLoopConfig | ChooserRetryLoopConfig
|
|
238
|
+
|
|
239
|
+
# --- IMPLEMENTATIONS ---
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
class Preselector:
    """Narrows a list of candidate submissions down to a promising subset
    by querying a model once and parsing indices from its answer.
    """

    def __init__(self, config: PreselectorConfig):
        self.config = config
        self.model = get_model(config.model, ToolConfig(parse_function=ActionParser()))
        # NOTE(review): shares the "chooser" logger name with `Chooser`.
        self.logger = get_logger("chooser", emoji="🧠")

    def interpret(self, response: str) -> list[int]:
        """Parse the chosen indices from the model response.

        Every integer on the final line of the response is taken as an
        index; an empty or unparsable response yields ``[]``.
        """
        if not response:
            self.logger.warning("No response from preselector")
            return []
        final_line = response.splitlines()[-1]
        try:
            parsed = [int(token) for token in re.findall(r"\d+", final_line)]
        except Exception as e:
            self.logger.error(f"Error interpreting response: {e}")
            return []
        return parsed

    def format_submission(self, problem_statement: str, submission: ReviewSubmission) -> str:
        """Render one submission for the prompt, or flag it as invalid."""
        patch = submission.info.get("submission")
        limit = self.config.max_len_submission
        # Missing patch, or patch exceeding a positive length limit -> invalid.
        if patch is None or 0 < limit < len(submission.info.get("submission", "")):  # type: ignore
            return "Solution invalid."
        return Template(self.config.submission_template).render(
            **submission.to_format_dict(),
        )

    def build_messages(self, problem_statement: str, input: list[ReviewSubmission]) -> list[dict[str, Any]]:
        """Assemble the system+user chat messages for the preselection query."""
        rendered_submissions = [self.format_submission(problem_statement, s) for s in input]
        instance_message = Template(self.config.instance_template).render(
            problem_statement=problem_statement,
            submissions=rendered_submissions,
        )
        self.logger.debug(f"MODEL INPUT (user)\n{instance_message}")
        system_message = {"role": "system", "content": self.config.system_template}
        user_message = {"role": "user", "content": instance_message}
        return [system_message, user_message]

    def choose(self, problem_statement: str, input: list[ReviewSubmission]) -> PreselectorOutput:
        """Query the model once and return the preselected indices."""
        messages = self.build_messages(problem_statement, input)
        response = self.model.query(messages)["message"]  # type: ignore
        chosen = self.interpret(response)
        if not chosen:
            # Fall back to keeping every candidate rather than dropping all.
            self.logger.warning("No indices found in response, using all indices")
            chosen = list(range(len(input)))
        return PreselectorOutput(chosen_idx=chosen, response=response, messages=messages)
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
class Chooser:
    """Uses a model to pick the single best submission out of a list of candidates."""

    def __init__(self, config: ChooserConfig):
        self.config = config
        self.model = get_model(config.model, ToolConfig(parse_function=ActionParser()))
        self.logger = get_logger("chooser", emoji="🧠")

    def interpret(self, response: str) -> int:
        """Extract the chosen index (the last number anywhere in `response`).

        Returns 0 (i.e. the first candidate) when no number can be parsed.
        """
        # Use regex to extract the last number of the response
        try:
            return int(re.findall(r"\d+", response)[-1])
        except Exception as e:
            self.logger.error(f"Error interpreting response: {e}")
            return 0

    def format_submission(self, problem_statement: str, submission: ReviewSubmission) -> str:
        """Render one submission for the prompt.

        Returns "Solution invalid." when the submission is missing or exceeds
        `max_len_submission` (limit only applies when configured > 0).
        """
        if (
            submission.info.get("submission") is None
            or len(submission.info.get("submission", "")) > self.config.max_len_submission > 0  # type: ignore
        ):
            return "Solution invalid."
        return Template(self.config.submission_template).render(
            **submission.to_format_dict(),
        )

    def build_messages(self, problem_statement: str, input: list[ReviewSubmission]) -> list[dict[str, Any]]:
        """Build the system+user message pair sent to the chooser model."""
        instance_message = Template(self.config.instance_template).render(
            problem_statement=problem_statement,
            submissions=[self.format_submission(problem_statement, s) for s in input],
        )
        self.logger.debug(f"MODEL INPUT (user)\n{instance_message}")
        return [
            {"role": "system", "content": self.config.system_template},
            {"role": "user", "content": instance_message},
        ]

    def choose(self, problem_statement: str, input: list[ReviewSubmission]) -> ChooserOutput:
        """Select the best submission index.

        Pipeline: optionally restrict to "submitted" attempts, optionally narrow
        the field via a `Preselector`, then ask the model for the final pick.
        Any failure degrades gracefully towards "use the first candidate".
        """
        preselector_output = None
        selected_indices = list(range(len(input)))
        # Prefer attempts that exited cleanly, but only if there are >= 2 of them.
        n_submitted = sum(s.info.get("exit_status", "") == "submitted" for s in input)
        if n_submitted >= 2:
            self.logger.debug(f"Got {n_submitted} submitted submissions, only using them")
            selected_indices = [i for i, s in enumerate(input) if s.info.get("exit_status", "") == "submitted"]
        else:
            self.logger.debug(f"Got only {n_submitted} submitted submissions, disabling exit status filtering")
        if self.config.preselector and len(selected_indices) > 2:
            preselector = Preselector(self.config.preselector)
            try:
                preselector_output = preselector.choose(problem_statement, [input[i] for i in selected_indices])
            except Exception as e:
                self.logger.critical(f"Preselector failed: {e}", exc_info=True)
                preselector_output = None
            if preselector_output and preselector_output.chosen_idx:
                try:
                    # Preselector indices are relative to the filtered list; map them back.
                    _preselected_indices = [selected_indices[i] for i in preselector_output.chosen_idx]
                except IndexError:
                    _preselected_indices = []
                    self.logger.error("Preselector gave invalid indices, ignoring it.")
                if not _preselected_indices:
                    self.logger.error("Preselector gave no valid indices, ignoring it.")
                else:
                    selected_indices = _preselected_indices
            else:
                self.logger.error("Preselector must have failed, ignoring it.")
        messages = self.build_messages(problem_statement, [input[i] for i in selected_indices])
        chosen_idx = None
        # Fix: `response` must be initialized before the try block; otherwise a
        # failing query would make the `ChooserOutput(...)` below raise NameError.
        response = ""
        try:
            response = self.model.query(messages)["message"]  # type: ignore
            chosen_idx = self.interpret(response)
        except Exception as e:
            self.logger.critical(f"Chooser failed: {e}", exc_info=True)
            chosen_idx = None
        if chosen_idx is None or not (0 <= chosen_idx < len(selected_indices)):
            self.logger.error(f"Invalid chosen index: {chosen_idx}, using first index")
            chosen_idx = selected_indices[0]
        else:
            # Map the model's answer (relative to the filtered list) back to the
            # index in the original `input` list.
            chosen_idx = selected_indices[chosen_idx]
        return ChooserOutput(
            chosen_idx=chosen_idx, response=response, preselector_output=preselector_output, messages=messages
        )
class Reviewer(AbstractReviewer):
    """Scores a single submission by prompting a model and parsing a numeric score."""

    def __init__(self, config: ReviewerConfig, model):
        self._config = config
        self._model = model
        self._traj_formatter = TrajectoryFormatter(config=config.traj_formatter)
        self.logger = get_logger("reviewer", emoji="🧑⚖️")

    def format_messages(self, instance: ProblemStatement, submission: ReviewSubmission):
        """Build the system+user message pair asking the model to review `submission`."""
        system_message = self._config.system_template
        self.logger.debug(f"MODEL INPUT (system)\n{system_message}")
        ps_format_dict = {
            "problem_statement": instance.get_problem_statement(),
            **instance.get_extra_fields(),
        }
        user_message = Template(self._config.instance_template).render(
            **ps_format_dict,
            **submission.to_format_dict(),
            traj=self._traj_formatter.format_trajectory(submission.trajectory),
        )
        self.logger.debug(f"MODEL INPUT (user)\n{user_message}")
        return [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
        ]

    def interpret(self, response: str) -> bool | float:
        """Parse the score from the last line of `response`.

        Raises:
            ValueError: if no number is found on the last line, or the number is
                outside the configured `score_range`.
        """
        last_line = response.strip().split("\n")[-1].strip()
        # Find all numbers in the last line and take the last one
        numbers = re.findall(r"-?\d+\.?\d*", last_line)
        if not numbers:
            msg = f"Could not interpret response: {last_line!r}"
            raise ValueError(msg)
        number = float(numbers[-1])
        if self._config.score_range[0] is not None and number < self._config.score_range[0]:
            msg = f"Score {number} is below the minimum score {self._config.score_range[0]}"
            raise ValueError(msg)
        if self._config.score_range[1] is not None and number > self._config.score_range[1]:
            msg = f"Score {number} is above the maximum score {self._config.score_range[1]}"
            raise ValueError(msg)
        return number

    def review(self, instance: ProblemStatement, submission: ReviewSubmission) -> ReviewerResult:
        """Sample up to `n_sample` scores for `submission` and aggregate them.

        The final score is the mean of all successfully parsed samples, minus
        `failure_score_penalty` when the exit status is not "submitted", and
        optionally minus `reduce_by_std` times the sample standard deviation.
        If every sample fails, the submission gets a fixed score of -100.
        """
        exit_status = submission.info.get("exit_status")
        messages = []
        penalty = 0.0
        if not exit_status or exit_status.strip() != "submitted":
            penalty = self._config.failure_score_penalty
        messages = self.format_messages(instance, submission)
        if self._config.n_sample > 1:
            # NOTE(review): presumably marks the last message for prompt caching so
            # repeated samples are cheap — confirm `_set_cache_control` semantics.
            _set_cache_control(messages[-1])  # type: ignore
        answers = []
        accepts = []
        for _ in range(self._config.n_sample):
            try:
                answer = self._model.query(messages)["message"]
            except Exception as e:
                # A failed query only loses this one sample.
                self.logger.warning(f"Query failed: {e}", exc_info=True)
                continue
            try:
                score = self.interpret(answer)
            except ValueError as e:
                self.logger.warning(f"Could not interpret response: {answer!r}, got {e}")
                continue
            answers.append(answer)
            accepts.append(score)
        if not accepts:
            # No sample produced a usable score: fail the submission outright.
            answers = ["No valid scores found, failing submission"]
            accepts = [-100.0]
        accept = sum(accepts) / len(accepts) - penalty
        std = np.std(accepts).item()
        if self._config.reduce_by_std > 0:
            # Penalize noisy (high-variance) score samples.
            accept -= std * self._config.reduce_by_std
        self.logger.info(f"First answer: {answers[0]}")
        self.logger.info(f"Final score: {accept} (penalty: {penalty}, std: {std}), individual: {accepts}")
        return ReviewerResult(accept=accept, outputs=answers, messages=messages)
# todo: Couldn't I just replace the whole thing with Jinja templates?
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
class TrajectoryFormatter:
    def __init__(
        self,
        config: TrajFormatterConfig,
    ):
        """Formats trajectories for the use in prompts"""
        self._config = config

    def _include_step(self, item: TrajectoryStep) -> bool:
        # A step is kept unless its action starts with one of the filter prefixes.
        command = item["action"].strip()
        return not any(command.startswith(prefix) for prefix in self._config.filter)

    def _include_step_output(self, item: TrajectoryStep, i_step: int, n_steps: int) -> bool:
        # Optionally hide everything except the last N outputs, then apply the
        # output prefix filters.
        keep_last = self._config.only_show_last_n_output
        if keep_last > 0 and i_step < n_steps - keep_last:
            return False
        command = item["action"].strip()
        return not any(command.startswith(prefix) for prefix in self._config.output_filter)

    def _format_trajectory_step(self, step: TrajectoryStep, i_step: int, *, n_steps: int, i_traj: int = 1) -> str:
        # Deep-copy so the caller's trajectory is never mutated.
        rendered = copy.deepcopy(step)
        if not self._include_step_output(rendered, i_step, n_steps=n_steps):
            rendered["observation"] = "[Output omitted]"
        return Template(self._config.item_template).render(
            **rendered,
            i_step=i_step,
            i_traj=i_traj,
        )

    def format_trajectory(self, trajectory: Trajectory, i_traj: int = 1) -> str:
        kept = [s for s in trajectory if self._include_step(s)]
        rendered_steps = (
            self._format_trajectory_step(s, idx, i_traj=i_traj, n_steps=len(kept)) for idx, s in enumerate(kept)
        )
        return "\n\n".join(rendered_steps)
class ChooserRetryLoop(AbstractRetryLoop):
    def __init__(self, config: ChooserRetryLoopConfig, problem_statement: ProblemStatement):
        self._config = config
        self._problem_statement = problem_statement
        self._chooser = Chooser(config.chooser)
        self._submissions: list[ReviewSubmission] = []
        self._n_consec_exit_cost: int = 0
        self.logger = get_logger("chooser_loop", emoji="🔄")
        # Cached result of the single chooser call; populated lazily by `get_best`.
        self._chooser_output: ChooserOutput | None = None

    @property
    def _total_stats(self) -> InstanceStats:
        # Aggregate the model stats of every submission collected so far.
        total = InstanceStats()
        for submission in self._submissions:
            total = total + submission.model_stats
        return total

    @property
    def review_model_stats(self) -> InstanceStats:
        return InstanceStats()

    @property
    def _n_attempts(self) -> int:
        return len(self._submissions)

    def on_submit(self, submission: ReviewSubmission) -> None:
        self._submissions.append(submission)

    def retry(self) -> bool:
        """Decide whether another attempt should be started (limits only apply when > 0)."""
        stat_str = f"n_samples={self._n_attempts}"
        cost = self._total_stats.instance_cost
        cost_limit = self._config.cost_limit
        if 0 < cost_limit < cost:
            self.logger.info(
                f"Exiting retry loop ({stat_str}): Total attempt cost ({cost}) "
                f"exceeds cost limit ({cost_limit})"
            )
            return False

        max_attempts = self._config.max_attempts
        if 0 < max_attempts <= self._n_attempts:
            self.logger.info(f"Exiting retry loop ({stat_str}): max_attempts={max_attempts} reached")
            return False

        remaining_budget = cost_limit - cost
        min_budget = self._config.min_budget_for_new_attempt
        if min_budget > 0 and remaining_budget < min_budget:
            self.logger.info(
                f"Exiting retry loop ({stat_str}): Not enough budget left for a new attempt "
                f"({remaining_budget} remaining, {min_budget} required)"
            )
            return False

        return True

    def get_best(self) -> int | None:
        """Important note: This is cached. Only call this at the end."""
        if self._chooser_output is None:
            if not self._submissions:
                return None
            self._chooser_output = self._chooser.choose(
                self._problem_statement.get_problem_statement(), self._submissions
            )
        return self._chooser_output.chosen_idx
# todo: The model shouldn't be defined here, it should be defined as part of the scorer
|
|
559
|
+
class ScoreRetryLoop(AbstractRetryLoop):
    """Retry loop that reviews (scores) every submission as soon as it arrives."""

    def __init__(
        self,
        config: ScoreRetryLoopConfig,
        problem_statement: ProblemStatement,
    ):
        # This model will not share instance cost with the parent agent
        self._model = get_model(config.model, tools=ToolConfig())
        self._problem_statement = problem_statement
        self._reviewer: AbstractReviewer = config.reviewer_config.get_reviewer(self._model)
        self._config = config
        # Note: These are "cumulative" submissions, i.e., they include all retries
        # up to that point.
        self._submissions: list[ReviewSubmission] = []
        self._reviews: list[ReviewerResult] = []
        #: Number of consecutive exit cost submissions
        self._n_consec_exit_cost: int = 0
        self.logger = get_logger("review_loop", emoji="🔄")

    # Properties
    # ----------

    @property
    def review_model_stats(self) -> InstanceStats:
        # Stats of the dedicated review model (separate from attempt costs).
        return self._model.stats

    @property
    def reviews(self) -> list[ReviewerResult]:
        return self._reviews

    @property
    def _n_attempts(self) -> int:
        return len(self._submissions)

    @property
    def _n_accepted(self) -> int:
        # A review counts as accepted once its score reaches `accept_score`.
        return sum(r.accept >= self._config.accept_score for r in self._reviews)

    @property
    def _total_stats(self) -> InstanceStats:
        # Cost of all attempts plus the review model's own cost.
        return sum((s.model_stats for s in self._submissions), start=InstanceStats()) + self._model.stats

    # -------

    def on_submit(self, submission: ReviewSubmission) -> None:
        # Every submission is reviewed immediately upon arrival.
        self._submissions.append(submission)
        self._review()

    def _review(self) -> float:
        """Review the most recent submission and return its score."""
        review = self._reviewer.review(self._problem_statement, self._submissions[-1])
        self._reviews.append(review)
        exit_status = self._submissions[-1].info.get("exit_status", "")
        # Track consecutive attempts that ran out of budget ("exit_cost").
        if exit_status and "exit_cost" in exit_status.lower():
            self._n_consec_exit_cost += 1
        else:
            self._n_consec_exit_cost = 0
        return review.accept

    def retry(self) -> bool:
        """Decide whether another attempt should be started.

        Stops when: total cost exceeds `cost_limit`, `max_attempts` is reached,
        `max_accepts` is reached, or the remaining budget is below
        `min_budget_for_new_attempt` (each limit only applies when configured > 0).
        """
        max_score = max([r.accept for r in self._reviews], default=-100.0)
        stat_str = f"n_samples={self._n_attempts}, max_score={max_score}, n_accepted={self._n_accepted}"

        if self._total_stats.instance_cost > self._config.cost_limit > 0:
            self.logger.info(
                f"Exiting retry loop ({stat_str}): Total attempt cost ({self._total_stats.instance_cost}) "
                f"exceeds cost limit ({self._config.cost_limit})"
            )
            return False

        if self._n_attempts >= self._config.max_attempts > 0:
            self.logger.info(f"Exiting retry loop ({stat_str}): max_attempts={self._config.max_attempts} reached")
            return False

        if self._n_accepted >= self._config.max_accepts > 0:
            self.logger.info(f"Exiting retry loop ({stat_str}): max_accepts={self._config.max_accepts} reached")
            return False

        remaining_budget = self._config.cost_limit - self._total_stats.instance_cost
        if self._config.min_budget_for_new_attempt > 0 and remaining_budget < self._config.min_budget_for_new_attempt:
            msg = (
                f"Exiting retry loop ({stat_str}): Not enough budget left for a new attempt "
                f"({remaining_budget} remaining, {self._config.min_budget_for_new_attempt} required)"
            )
            self.logger.info(msg)
            return False

        return True

    def get_best(self) -> int | None:
        """Return the index of the best submission, or None when nothing was reviewed.

        Highest score wins; ties are broken in favor of the submission with the
        fewest API calls.
        """
        if len(self._reviews) == 0:
            return None
        scores = [r.accept for r in self._reviews]
        self.logger.debug(f"Scores: {scores}")
        max_score = np.max(scores)
        max_indices = [i for i, s in enumerate(scores) if np.isclose(s, max_score)]
        # If there are multiple submissions with the same score, choose the shortest one
        max_indices = sorted(max_indices, key=lambda i: self._submissions[i].model_stats.api_calls or float("inf"))
        chosen_idx = max_indices[0]
        self.logger.info(f"Best submission: {chosen_idx}")
        return chosen_idx
def get_retry_loop_from_config(
    config: RetryLoopConfig, problem_statement: ProblemStatement
) -> ScoreRetryLoop | ChooserRetryLoop:
    """Build the retry loop described by `config` (delegates to the config's factory)."""
    retry_loop = config.get_retry_loop(problem_statement=problem_statement)
    return retry_loop
|
|
File without changes
|
|
File without changes
|