@elizaos/sweagent-root 2.0.0-alpha

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. package/LICENSE +21 -0
  2. package/README.md +270 -0
  3. package/package.json +71 -0
  4. package/python/LICENSE +21 -0
  5. package/python/config/README.md +15 -0
  6. package/python/config/bash_only.yaml +222 -0
  7. package/python/config/benchmarks/250212_sweagent_heavy_sbl.yaml +188 -0
  8. package/python/config/benchmarks/250225_anthropic_filemap_simple_review.yaml +75 -0
  9. package/python/config/benchmarks/250522_anthropic_filemap_simple_review.yaml +92 -0
  10. package/python/config/benchmarks/250526_anthropic_filemap_simple_review_sbl.yaml +93 -0
  11. package/python/config/benchmarks/anthropic_filemap_multilingual.yaml +66 -0
  12. package/python/config/coding_challenge.yaml +104 -0
  13. package/python/config/default.yaml +69 -0
  14. package/python/config/default_backticks.yaml +69 -0
  15. package/python/config/default_mm_no_images.yaml +82 -0
  16. package/python/config/default_mm_with_images.yaml +83 -0
  17. package/python/config/demo/default.yaml +80 -0
  18. package/python/config/demo/no_instructions.yaml +69 -0
  19. package/python/config/demo/only_bash.yaml +60 -0
  20. package/python/config/exotic/default_shell.yaml +52 -0
  21. package/python/config/exotic/windowed_replace.yaml +125 -0
  22. package/python/config/exotic/windowed_replace_late_repro.yaml +127 -0
  23. package/python/config/human/human.yaml +24 -0
  24. package/python/config/human/human_demo.yaml +52 -0
  25. package/python/config/sweagent_0_7/07.yaml +101 -0
  26. package/python/config/sweagent_0_7/07_fcalling.yaml +100 -0
  27. package/python/config/sweagent_0_7/07_from_url.yaml +114 -0
  28. package/python/config/sweagent_0_7/07_thought_action.yaml +102 -0
  29. package/python/config/sweagent_0_7/07_thought_action_xml.yaml +96 -0
  30. package/python/mlc_config.json +44 -0
  31. package/python/pyproject.toml +262 -0
  32. package/python/sweagent/__init__.py +114 -0
  33. package/python/sweagent/__main__.py +4 -0
  34. package/python/sweagent/agent/__init__.py +0 -0
  35. package/python/sweagent/agent/action_sampler.py +317 -0
  36. package/python/sweagent/agent/agents.py +1294 -0
  37. package/python/sweagent/agent/extra/shell_agent.py +106 -0
  38. package/python/sweagent/agent/history_processors.py +399 -0
  39. package/python/sweagent/agent/hooks/__init__.py +0 -0
  40. package/python/sweagent/agent/hooks/abstract.py +139 -0
  41. package/python/sweagent/agent/hooks/status.py +34 -0
  42. package/python/sweagent/agent/models.py +896 -0
  43. package/python/sweagent/agent/problem_statement.py +312 -0
  44. package/python/sweagent/agent/reviewer.py +664 -0
  45. package/python/sweagent/environment/__init__.py +0 -0
  46. package/python/sweagent/environment/hooks/__init__.py +0 -0
  47. package/python/sweagent/environment/hooks/abstract.py +60 -0
  48. package/python/sweagent/environment/hooks/status.py +28 -0
  49. package/python/sweagent/environment/repo.py +219 -0
  50. package/python/sweagent/environment/swe_env.py +276 -0
  51. package/python/sweagent/exceptions.py +54 -0
  52. package/python/sweagent/inspector/README.md +6 -0
  53. package/python/sweagent/inspector/__init__.py +0 -0
  54. package/python/sweagent/inspector/favicon.ico +0 -0
  55. package/python/sweagent/inspector/fileViewer.js +354 -0
  56. package/python/sweagent/inspector/icons/computer.png +0 -0
  57. package/python/sweagent/inspector/icons/edit_icon.svg +11 -0
  58. package/python/sweagent/inspector/icons/swe-agent-logo-50.png +0 -0
  59. package/python/sweagent/inspector/icons/swellama_blue.png +0 -0
  60. package/python/sweagent/inspector/icons/swellama_brown.png +0 -0
  61. package/python/sweagent/inspector/icons/swellama_grey.png +0 -0
  62. package/python/sweagent/inspector/icons/swellama_tan.png +0 -0
  63. package/python/sweagent/inspector/index.html +25 -0
  64. package/python/sweagent/inspector/server.py +354 -0
  65. package/python/sweagent/inspector/static.py +169 -0
  66. package/python/sweagent/inspector/style.css +454 -0
  67. package/python/sweagent/run/__init__.py +0 -0
  68. package/python/sweagent/run/_progress.py +158 -0
  69. package/python/sweagent/run/batch_instances.py +419 -0
  70. package/python/sweagent/run/common.py +387 -0
  71. package/python/sweagent/run/compare_runs.py +123 -0
  72. package/python/sweagent/run/extract_pred.py +19 -0
  73. package/python/sweagent/run/hooks/__init__.py +0 -0
  74. package/python/sweagent/run/hooks/abstract.py +67 -0
  75. package/python/sweagent/run/hooks/apply_patch.py +106 -0
  76. package/python/sweagent/run/hooks/open_pr.py +244 -0
  77. package/python/sweagent/run/hooks/swe_bench_evaluate.py +113 -0
  78. package/python/sweagent/run/inspector_cli.py +493 -0
  79. package/python/sweagent/run/merge_predictions.py +64 -0
  80. package/python/sweagent/run/quick_stats.py +96 -0
  81. package/python/sweagent/run/remove_unfinished.py +63 -0
  82. package/python/sweagent/run/rich_test.py +91 -0
  83. package/python/sweagent/run/run.py +147 -0
  84. package/python/sweagent/run/run_batch.py +442 -0
  85. package/python/sweagent/run/run_replay.py +219 -0
  86. package/python/sweagent/run/run_shell.py +155 -0
  87. package/python/sweagent/run/run_single.py +225 -0
  88. package/python/sweagent/run/run_traj_to_demo.py +85 -0
  89. package/python/sweagent/tools/__init__.py +0 -0
  90. package/python/sweagent/tools/bundle.py +57 -0
  91. package/python/sweagent/tools/commands.py +220 -0
  92. package/python/sweagent/tools/parsing.py +619 -0
  93. package/python/sweagent/tools/tools.py +430 -0
  94. package/python/sweagent/tools/utils.py +108 -0
  95. package/python/sweagent/types.py +102 -0
  96. package/python/sweagent/utils/__init__.py +0 -0
  97. package/python/sweagent/utils/config.py +80 -0
  98. package/python/sweagent/utils/files.py +27 -0
  99. package/python/sweagent/utils/github.py +118 -0
  100. package/python/sweagent/utils/jinja_warnings.py +14 -0
  101. package/python/sweagent/utils/log.py +175 -0
  102. package/python/sweagent/utils/patch_formatter.py +152 -0
  103. package/python/sweagent/utils/serialization.py +45 -0
  104. package/python/tests/__init__.py +0 -0
  105. package/python/tests/conftest.py +191 -0
  106. package/python/tests/test_agent.py +258 -0
  107. package/python/tests/test_batch_instance.py +43 -0
  108. package/python/tests/test_commands/_interactive_dummy.py +35 -0
  109. package/python/tests/test_commands/interactive_dummy_wrapper.sh +29 -0
  110. package/python/tests/test_data/config_files/dummy_interactive.yaml +62 -0
  111. package/python/tests/test_data/data_sources/ctf/crypto/Katy/Dockerfile +20 -0
  112. package/python/tests/test_data/data_sources/ctf/crypto/Katy/README.md +13 -0
  113. package/python/tests/test_data/data_sources/ctf/crypto/Katy/challenge.json +12 -0
  114. package/python/tests/test_data/data_sources/ctf/crypto/Katy/customrandom.c +50 -0
  115. package/python/tests/test_data/data_sources/ctf/crypto/Katy/docker-compose.yml +14 -0
  116. package/python/tests/test_data/data_sources/ctf/crypto/Katy/release +0 -0
  117. package/python/tests/test_data/data_sources/ctf/crypto/Katy/server +0 -0
  118. package/python/tests/test_data/data_sources/ctf/crypto/Katy/solver.py +12 -0
  119. package/python/tests/test_data/data_sources/ctf/forensics/flash/README.md +16 -0
  120. package/python/tests/test_data/data_sources/ctf/forensics/flash/challenge.json +9 -0
  121. package/python/tests/test_data/data_sources/ctf/forensics/flash/flash_c8429a430278283c0e571baebca3d139.zip +0 -0
  122. package/python/tests/test_data/data_sources/ctf/misc/networking_1/README.md +15 -0
  123. package/python/tests/test_data/data_sources/ctf/misc/networking_1/challenge.json +10 -0
  124. package/python/tests/test_data/data_sources/ctf/misc/networking_1/networking.pcap +0 -0
  125. package/python/tests/test_data/data_sources/ctf/pwn/warmup/Dockerfile +28 -0
  126. package/python/tests/test_data/data_sources/ctf/pwn/warmup/README.md +14 -0
  127. package/python/tests/test_data/data_sources/ctf/pwn/warmup/challenge.json +14 -0
  128. package/python/tests/test_data/data_sources/ctf/pwn/warmup/docker-compose.yml +14 -0
  129. package/python/tests/test_data/data_sources/ctf/pwn/warmup/flag.txt +1 -0
  130. package/python/tests/test_data/data_sources/ctf/pwn/warmup/warmup +0 -0
  131. package/python/tests/test_data/data_sources/ctf/pwn/warmup/warmup.c +26 -0
  132. package/python/tests/test_data/data_sources/ctf/pwn/warmup/warmup.py +9 -0
  133. package/python/tests/test_data/data_sources/ctf/rev/rock/README.md +14 -0
  134. package/python/tests/test_data/data_sources/ctf/rev/rock/challenge.json +8 -0
  135. package/python/tests/test_data/data_sources/ctf/rev/rock/rock +0 -0
  136. package/python/tests/test_data/data_sources/ctf/rev/rock/rock.cpp +167 -0
  137. package/python/tests/test_data/data_sources/ctf/rev/rock/solution.cpp +24 -0
  138. package/python/tests/test_data/data_sources/ctf/rev/rock/test_solver/solution.py +6 -0
  139. package/python/tests/test_data/data_sources/ctf/rev/rock/test_solver/test.sh +10 -0
  140. package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/000-default.conf +18 -0
  141. package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/Dockerfile +20 -0
  142. package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/cgi/file.pl +38 -0
  143. package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/cgi/forms.pl +40 -0
  144. package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/cgi/hello.pl +11 -0
  145. package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/challenge.json +12 -0
  146. package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/docker-compose.yml +14 -0
  147. package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/flag +1 -0
  148. package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/index.html +11 -0
  149. package/python/tests/test_data/data_sources/ctf/web/i_got_id_demo/solution.txt +1 -0
  150. package/python/tests/test_data/data_sources/debug_20240322.json +1 -0
  151. package/python/tests/test_data/data_sources/expert_instances.yaml +16 -0
  152. package/python/tests/test_data/data_sources/human_eval.json +1 -0
  153. package/python/tests/test_data/data_sources/simple_instances.yaml +3 -0
  154. package/python/tests/test_data/data_sources/simple_instances_long.yaml +30 -0
  155. package/python/tests/test_data/data_sources/swe-bench-dev-easy.json +1 -0
  156. package/python/tests/test_data/data_sources/swe-bench-dev-easy_first_only.json +1 -0
  157. package/python/tests/test_data/data_sources/swe-bench-lite-test.json +1 -0
  158. package/python/tests/test_data/trajectories/gpt4__swe-agent-test-repo__default_from_url__t-0.00__p-0.95__c-3.00__install-1/6e44b9__sweagenttestrepo-1c2844.traj +342 -0
  159. package/python/tests/test_data/trajectories/gpt4__swe-agent-test-repo__default_from_url__t-0.00__p-0.95__c-3.00__install-1/solution_missing_colon.py +15 -0
  160. package/python/tests/test_data/trajectories/gpt4__swe-agent__test-repo__default_from_url__t-0.00__p-0.95__c-3.00__install-1/args.yaml +518 -0
  161. package/python/tests/test_data/trajectories/gpt4__swe-agent__test-repo__default_from_url__t-0.00__p-0.95__c-3.00__install-1/swe-agent__test-repo-i1.traj +124 -0
  162. package/python/tests/test_data/trajectories/gpt4__swe-bench-dev-easy_first_only__default__t-0.00__p-0.95__c-3.00__install-1/all_preds.jsonl +1 -0
  163. package/python/tests/test_data/trajectories/gpt4__swe-bench-dev-easy_first_only__default__t-0.00__p-0.95__c-3.00__install-1/args.yaml +520 -0
  164. package/python/tests/test_data/trajectories/gpt4__swe-bench-dev-easy_first_only__default__t-0.00__p-0.95__c-3.00__install-1/patches/pydicom__pydicom-1458.patch +18 -0
  165. package/python/tests/test_data/trajectories/gpt4__swe-bench-dev-easy_first_only__default__t-0.00__p-0.95__c-3.00__install-1/pydicom__pydicom-1458.traj +257 -0
  166. package/python/tests/test_env.py +66 -0
  167. package/python/tests/test_env_utils.py +129 -0
  168. package/python/tests/test_history_processors.py +40 -0
  169. package/python/tests/test_models.py +23 -0
  170. package/python/tests/test_openai_live.py +164 -0
  171. package/python/tests/test_packaging.py +7 -0
  172. package/python/tests/test_parsing.py +131 -0
  173. package/python/tests/test_problem_statement_multimodal.py +111 -0
  174. package/python/tests/test_quick_stats.py +42 -0
  175. package/python/tests/test_run.py +37 -0
  176. package/python/tests/test_run_batch.py +110 -0
  177. package/python/tests/test_run_hooks.py +114 -0
  178. package/python/tests/test_run_replay.py +33 -0
  179. package/python/tests/test_run_single.py +125 -0
  180. package/python/tests/test_tools_command_parsing.py +193 -0
  181. package/python/tests/test_utils.py +15 -0
  182. package/python/tests/tools/__init__.py +0 -0
  183. package/python/tests/tools/conftest.py +12 -0
  184. package/python/tests/tools/test_default_utils.py +153 -0
  185. package/python/tests/tools/test_edit_replace.py +0 -0
  186. package/python/tests/tools/test_split_string.py +82 -0
  187. package/python/tests/utils.py +29 -0
  188. package/python/tools/diff_state/bin/_state_diff_state +52 -0
  189. package/python/tools/diff_state/config.yaml +2 -0
  190. package/python/tools/edit_anthropic/bin/_state_anthropic +21 -0
  191. package/python/tools/edit_anthropic/bin/str_replace_editor +710 -0
  192. package/python/tools/edit_anthropic/config.yaml +56 -0
  193. package/python/tools/edit_anthropic/install.sh +3 -0
  194. package/python/tools/filemap/bin/filemap +45 -0
  195. package/python/tools/filemap/config.yaml +9 -0
  196. package/python/tools/filemap/install.sh +2 -0
  197. package/python/tools/forfeit/bin/exit_forfeit +5 -0
  198. package/python/tools/forfeit/config.yaml +5 -0
  199. package/python/tools/image_tools/bin/view_image +36 -0
  200. package/python/tools/image_tools/config.yaml +9 -0
  201. package/python/tools/multilingual_setup/bin/do_nothing +2 -0
  202. package/python/tools/multilingual_setup/config.yaml +1 -0
  203. package/python/tools/multilingual_setup/install.sh +45 -0
  204. package/python/tools/registry/bin/_read_env +10 -0
  205. package/python/tools/registry/bin/_write_env +10 -0
  206. package/python/tools/registry/config.yaml +1 -0
  207. package/python/tools/registry/install.sh +6 -0
  208. package/python/tools/registry/lib/__init__.py +0 -0
  209. package/python/tools/registry/lib/registry.py +56 -0
  210. package/python/tools/review_on_submit_m/README.md +6 -0
  211. package/python/tools/review_on_submit_m/bin/submit +54 -0
  212. package/python/tools/review_on_submit_m/config.yaml +6 -0
  213. package/python/tools/review_on_submit_m/install.sh +0 -0
  214. package/python/tools/search/bin/find_file +31 -0
  215. package/python/tools/search/bin/search_dir +39 -0
  216. package/python/tools/search/bin/search_file +55 -0
  217. package/python/tools/search/config.yaml +37 -0
  218. package/python/tools/search/install.sh +3 -0
  219. package/python/tools/submit/bin/submit +17 -0
  220. package/python/tools/submit/config.yaml +5 -0
  221. package/python/tools/web_browser/bin/click_mouse +41 -0
  222. package/python/tools/web_browser/bin/close_site +28 -0
  223. package/python/tools/web_browser/bin/double_click_mouse +37 -0
  224. package/python/tools/web_browser/bin/drag_mouse +46 -0
  225. package/python/tools/web_browser/bin/execute_script_on_page +39 -0
  226. package/python/tools/web_browser/bin/get_console_output +48 -0
  227. package/python/tools/web_browser/bin/move_mouse +35 -0
  228. package/python/tools/web_browser/bin/navigate_back +33 -0
  229. package/python/tools/web_browser/bin/navigate_forward +33 -0
  230. package/python/tools/web_browser/bin/open_site +36 -0
  231. package/python/tools/web_browser/bin/press_keys_on_page +51 -0
  232. package/python/tools/web_browser/bin/reload_page +33 -0
  233. package/python/tools/web_browser/bin/run_web_browser_server +394 -0
  234. package/python/tools/web_browser/bin/screenshot_site +38 -0
  235. package/python/tools/web_browser/bin/scroll_on_page +40 -0
  236. package/python/tools/web_browser/bin/set_browser_window_size +40 -0
  237. package/python/tools/web_browser/bin/type_text +34 -0
  238. package/python/tools/web_browser/bin/wait_time +39 -0
  239. package/python/tools/web_browser/config.yaml +155 -0
  240. package/python/tools/web_browser/install.sh +22 -0
  241. package/python/tools/web_browser/lib/browser_manager.py +404 -0
  242. package/python/tools/web_browser/lib/web_browser_config.py +33 -0
  243. package/python/tools/web_browser/lib/web_browser_utils.py +126 -0
  244. package/python/tools/web_browser/test_console.html +1 -0
  245. package/python/tools/windowed/bin/_state +25 -0
  246. package/python/tools/windowed/bin/create +29 -0
  247. package/python/tools/windowed/bin/goto +37 -0
  248. package/python/tools/windowed/bin/open +49 -0
  249. package/python/tools/windowed/bin/scroll_down +12 -0
  250. package/python/tools/windowed/bin/scroll_up +13 -0
  251. package/python/tools/windowed/config.yaml +38 -0
  252. package/python/tools/windowed/install.sh +15 -0
  253. package/python/tools/windowed/lib/__init__.py +0 -0
  254. package/python/tools/windowed/lib/flake8_utils.py +147 -0
  255. package/python/tools/windowed/lib/windowed_file.py +312 -0
  256. package/python/tools/windowed_edit_linting/bin/edit +128 -0
  257. package/python/tools/windowed_edit_linting/config.yaml +31 -0
  258. package/python/tools/windowed_edit_linting/install.sh +5 -0
  259. package/python/tools/windowed_edit_replace/bin/edit +172 -0
  260. package/python/tools/windowed_edit_replace/bin/insert +77 -0
  261. package/python/tools/windowed_edit_replace/config.yaml +60 -0
  262. package/python/tools/windowed_edit_replace/install.sh +5 -0
  263. package/python/tools/windowed_edit_rewrite/bin/edit +78 -0
  264. package/python/tools/windowed_edit_rewrite/config.yaml +11 -0
  265. package/python/tools/windowed_edit_rewrite/install.sh +5 -0
  266. package/python/trajectories/demonstrations/ctf/crypto/BabyEncryption.traj +318 -0
  267. package/python/trajectories/demonstrations/ctf/crypto/BabyTimeCapsule.traj +197 -0
  268. package/python/trajectories/demonstrations/ctf/crypto/eps.traj +289 -0
  269. package/python/trajectories/demonstrations/ctf/crypto/katy.traj +368 -0
  270. package/python/trajectories/demonstrations/ctf/forensics/flash.traj +102 -0
  271. package/python/trajectories/demonstrations/ctf/misc/networking_1.traj +102 -0
  272. package/python/trajectories/demonstrations/ctf/pwn/warmup.traj +159 -0
  273. package/python/trajectories/demonstrations/ctf/rev/rock.traj +251 -0
  274. package/python/trajectories/demonstrations/ctf/web/i_got_id_demo.traj +422 -0
  275. package/python/trajectories/demonstrations/function_calling_simple.traj +151 -0
  276. package/python/trajectories/demonstrations/human_thought__swe-bench-HumanEvalFix-python__lcb__t-0.00__p-0.95__c-4.00__install-0/humanevalfix-python-0.traj +129 -0
  277. package/python/trajectories/demonstrations/replay__marshmallow-code__marshmallow-1867__default__t-0.20__p-0.95__c-2.00__install-1___install_from_source/marshmallow-code__marshmallow-1867.traj +318 -0
  278. package/python/trajectories/demonstrations/replay__marshmallow-code__marshmallow-1867__default_sys-env_cursors_window100__t-0.20__p-0.95__c-2.00__install-1/marshmallow-code__marshmallow-1867.traj +251 -0
  279. package/python/trajectories/demonstrations/replay__marshmallow-code__marshmallow-1867__default_sys-env_window100__t-0.20__p-0.95__c-2.00__install-1/marshmallow-code__marshmallow-1867.traj +399 -0
  280. package/python/trajectories/demonstrations/replay__marshmallow-code__marshmallow-1867__function_calling__install-1/marshmallow-code__marshmallow-1867.traj +594 -0
  281. package/python/trajectories/demonstrations/replay__marshmallow-code__marshmallow-1867__function_calling_replace__install-1/marshmallow-code__marshmallow-1867.traj +592 -0
  282. package/python/trajectories/demonstrations/replay__marshmallow-code__marshmallow-1867__function_calling_replace_from_source/marshmallow-code__marshmallow-1867.traj +3316 -0
  283. package/python/trajectories/demonstrations/replay__marshmallow-code__marshmallow-1867__xml_sys-env_cursors_window100__t-0.20__p-0.95__c-2.00__install-1/marshmallow-code__marshmallow-1867.traj +251 -0
  284. package/python/trajectories/demonstrations/replay__marshmallow-code__marshmallow-1867__xml_sys-env_window100__t-0.20__p-0.95__c-2.00__install-1/marshmallow-code__marshmallow-1867.traj +399 -0
  285. package/python/trajectories/demonstrations/str_replace_anthropic_demo.yaml +432 -0
  286. package/rust/Cargo.toml +100 -0
  287. package/rust/README.md +49 -0
  288. package/rust/src/agent/action_sampler.rs +130 -0
  289. package/rust/src/agent/agents.rs +1029 -0
  290. package/rust/src/agent/history_processors.rs +277 -0
  291. package/rust/src/agent/hooks/mod.rs +208 -0
  292. package/rust/src/agent/mod.rs +24 -0
  293. package/rust/src/agent/models.rs +837 -0
  294. package/rust/src/agent/problem_statement.rs +355 -0
  295. package/rust/src/agent/reviewer.rs +505 -0
  296. package/rust/src/bin/sweagent.rs +784 -0
  297. package/rust/src/environment/deployment.rs +631 -0
  298. package/rust/src/environment/hooks/mod.rs +114 -0
  299. package/rust/src/environment/mod.rs +16 -0
  300. package/rust/src/environment/repo.rs +265 -0
  301. package/rust/src/environment/runtime.rs +237 -0
  302. package/rust/src/environment/swe_env.rs +248 -0
  303. package/rust/src/exceptions.rs +228 -0
  304. package/rust/src/lib.rs +68 -0
  305. package/rust/src/monitoring.rs +482 -0
  306. package/rust/src/run/hooks/mod.rs +134 -0
  307. package/rust/src/run/mod.rs +12 -0
  308. package/rust/src/run/run_batch.rs +563 -0
  309. package/rust/src/run/run_single.rs +196 -0
  310. package/rust/src/tools/bundle.rs +224 -0
  311. package/rust/src/tools/commands.rs +173 -0
  312. package/rust/src/tools/mod.rs +295 -0
  313. package/rust/src/tools/parsing.rs +354 -0
  314. package/rust/src/tools/registry.rs +143 -0
  315. package/rust/src/types.rs +554 -0
  316. package/rust/src/utils/config.rs +105 -0
  317. package/rust/src/utils/files.rs +137 -0
  318. package/rust/src/utils/github.rs +171 -0
  319. package/rust/src/utils/log.rs +65 -0
  320. package/rust/src/utils/mod.rs +17 -0
  321. package/rust/src/utils/serialization.rs +181 -0
  322. package/rust/src/utils/template.rs +173 -0
  323. package/typescript/README.md +335 -0
@@ -0,0 +1,896 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import json
5
+ import os
6
+ import random
7
+ import shlex
8
+ import threading
9
+ import time
10
+ from abc import ABC, abstractmethod
11
+ from pathlib import Path
12
+ from threading import Lock
13
+ from typing import Annotated, Any, Literal
14
+
15
+ import litellm
16
+ import litellm.types.utils
17
+ from pydantic import BaseModel as PydanticBaseModel
18
+ from pydantic import ConfigDict, Field, SecretStr
19
+ from swerex.exceptions import SwerexException
20
+ from tenacity import (
21
+ RetryCallState,
22
+ Retrying,
23
+ retry_if_not_exception_type,
24
+ stop_after_attempt,
25
+ wait_random_exponential,
26
+ )
27
+
28
+ from sweagent import REPO_ROOT
29
+ from sweagent.exceptions import (
30
+ ContentPolicyViolationError,
31
+ ContextWindowExceededError,
32
+ CostLimitExceededError,
33
+ FunctionCallingFormatError,
34
+ InstanceCallLimitExceededError,
35
+ InstanceCostLimitExceededError,
36
+ ModelConfigurationError,
37
+ TotalCostLimitExceededError,
38
+ )
39
+ from sweagent.tools.tools import ToolConfig
40
+ from sweagent.types import History, HistoryItem
41
+ from sweagent.utils.log import get_logger
42
+
43
+ try:
44
+ import readline # noqa: F401
45
+ except ImportError:
46
+ readline = None
47
+
48
# Silence litellm's extra debug banner printed alongside API errors.
litellm.suppress_debug_info = True


_THREADS_THAT_USED_API_KEYS: list[str] = []
"""Keeps track of thread orders so that we can choose the same API key for the same thread."""
53
+
54
+
55
class RetryConfig(PydanticBaseModel):
    """This configuration object specifies how many times to retry a failed LM API call."""

    # NOTE(review): min_wait/max_wait look like they feed tenacity's
    # wait_random_exponential (imported at the top of this file) — confirm at the call site.
    retries: int = 20
    """Number of retries"""
    min_wait: float = 10
    """Minimum wait time between retries (random exponential wait)"""
    max_wait: float = 120
    """Maximum wait time between retries (random exponential wait)"""
64
+
65
+
66
class GenericAPIModelConfig(PydanticBaseModel):
    """This configuration object specifies a LM like GPT4 or similar.
    The model will be served with the help of the `litellm` library.
    """

    name: str = Field(description="Name of the model.")

    per_instance_cost_limit: float = Field(
        default=3.0,
        description="Cost limit for every instance (task).",
    )
    total_cost_limit: float = Field(default=0.0, description="Total cost limit.")
    per_instance_call_limit: int = Field(default=0, description="Per instance call limit.")
    temperature: float = 0.0
    """Sampling temperature"""
    top_p: float | None = 1.0
    """Sampling top-p"""
    api_base: str | None = None
    api_version: str | None = None
    api_key: SecretStr | None = None
    """API key to the model. We recommend using environment variables to set this instead
    or putting your environment variables in a `.env` file.
    You can concatenate more than one key by separating them with `:::`, e.g.,
    `key1:::key2`.
    If field starts with `$`, it will be interpreted as an environment variable.
    """
    stop: list[str] = []
    """Custom stop sequences"""

    completion_kwargs: dict[str, Any] = {}
    """Additional kwargs to pass to `litellm.completion`"""

    convert_system_to_user: bool = False
    """Whether to convert system messages to user messages. This is useful for
    models that do not support system messages like o1.
    """

    retry: RetryConfig = RetryConfig()
    """Retry configuration: How often to retry after a failure (e.g., from a rate limit)
    etc.
    """

    delay: float = 0.0
    """Minimum delay before querying (this can help to avoid overusing the API if sharing
    it with other people).
    """

    fallbacks: list[dict[str, Any]] = []
    """List of fallbacks to try if the main model fails
    See https://docs.litellm.ai/docs/completion/reliable_completions#fallbacks-sdk
    for more information.
    """

    choose_api_key_by_thread: bool = True
    """Whether to choose the API key based on the thread name (if multiple are configured).
    This ensures that with
    run-batch, we use the same API key within a single-thread so that prompt caching still works.
    """

    max_input_tokens: int | None = None
    """If set, this will override the max input tokens for the model that we usually look
    up from `litellm.model_cost`.
    Use this for local models or if you want to set a custom max input token limit.
    If this value is exceeded, a `ContextWindowExceededError` will be raised.
    Set this to 0 to disable this check.
    """

    max_output_tokens: int | None = None
    """If set, this will override the max output tokens for the model that we usually look
    up from `litellm.model_cost`.
    Use this for local models or if you want to set a custom max output token limit.
    If this value is exceeded, a `ContextWindowExceededError` will be raised.
    Set this to 0 to disable this check.
    """

    litellm_model_registry: str | None = None
    """If set, this will override the default model registry for litellm.
    Use this for local models or models not (yet) in the default litellm model registry for tracking costs.
    """

    custom_tokenizer: dict[str, Any] | None = None
    """Override the default tokenizer for the model.
    Use the arguments of `litellm.create_pretrained_tokenizer`.
    Basic example: `{"identifier": "hf-internal-testing/llama-tokenizer"}`
    """

    # pydantic
    model_config = ConfigDict(extra="forbid")

    def get_api_keys(self) -> list[str]:
        """Returns a list of API keys that were explicitly set in this config.
        Does not return API keys that were set via environment variables/.env
        """
        if self.api_key is None:
            return []
        api_key = self.api_key.get_secret_value()
        if not api_key:
            return []
        # A leading `$` means the configured value is the *name* of an environment
        # variable holding the actual key(s).
        if api_key.startswith("$"):
            env_var_name = api_key[1:]
            api_key = os.getenv(env_var_name, "")
            if not api_key:
                get_logger("swea-config", emoji="🔧").warning(f"Environment variable {env_var_name} not set")
                return []
        # Multiple keys may be concatenated with `:::`; a single key yields a 1-element list.
        return api_key.split(":::")

    def choose_api_key(self) -> str | None:
        """Chooses an API key based on the API keys explicitly set in this config.
        If no API keys are set, returns None (which means that the API key will be
        taken from the environment variables/.env file).
        """
        api_keys = self.get_api_keys()
        if not api_keys:
            return None
        if not self.choose_api_key_by_thread:
            return random.choice(api_keys)
        # Assign keys round-robin by the order in which threads first ask for one,
        # so the same thread keeps getting the same key (keeps prompt caching effective).
        # NOTE(review): this check-then-append on the module-level list is not guarded
        # by a lock; concurrent first calls from two threads could interleave — confirm
        # whether callers serialize this (e.g. via GLOBAL_STATS_LOCK) or accept the race.
        thread_name = threading.current_thread().name
        if thread_name not in _THREADS_THAT_USED_API_KEYS:
            _THREADS_THAT_USED_API_KEYS.append(thread_name)
        thread_idx = _THREADS_THAT_USED_API_KEYS.index(thread_name)
        key_idx = thread_idx % len(api_keys)
        get_logger("config", emoji="🔧").debug(
            f"Choosing API key {key_idx} for thread {thread_name} (idx {thread_idx})"
        )
        return api_keys[key_idx]

    @property
    def id(self) -> str:
        """Identifier string built from model name (with `/` replaced by `--`),
        temperature, top-p, and per-instance cost limit.
        """
        name = self.name.replace("/", "--")
        if self.top_p is not None:
            top_p = f"{self.top_p:.2f}"
        else:
            top_p = "None"
        temperature = f"{self.temperature:.2f}"
        per_instance_cost_limit = f"{self.per_instance_cost_limit:.2f}"
        return f"{name}__t-{temperature}__p-{top_p}__c-{per_instance_cost_limit}"
202
+
203
+
204
class ReplayModelConfig(GenericAPIModelConfig):
    """Configuration for the `replay` model, which reads its outputs from a replay
    file (`replay_path`) instead of querying an API. Cost limits are dummies here.
    """

    replay_path: Path = Field(description="Path to replay file when using the replay model.")

    per_instance_cost_limit: float = Field(
        default=0.0, description="Cost limit for every instance (task). This is a dummy value here."
    )
    total_cost_limit: float = Field(
        default=0.0, description="Cost limit for all instances (tasks). This is a dummy value here."
    )

    name: Literal["replay"] = Field(default="replay", description="Model name.")

    model_config = ConfigDict(extra="forbid")
217
+
218
+
219
class InstantEmptySubmitModelConfig(GenericAPIModelConfig):
    """Model that immediately submits an empty patch"""

    name: Literal["instant_empty_submit"] = Field(default="instant_empty_submit", description="Model name.")

    # Cost limits are inherited from GenericAPIModelConfig but irrelevant for this
    # model, so they default to 0 (disabled) here.
    per_instance_cost_limit: float = Field(
        default=0.0, description="Cost limit for every instance (task). This is a dummy value here."
    )
    total_cost_limit: float = Field(
        default=0.0, description="Cost limit for all instances (tasks). This is a dummy value here."
    )
    delay: float = 0.0
    """Delay before answering"""

    model_config = ConfigDict(extra="forbid")
234
+
235
+
236
class HumanModelConfig(GenericAPIModelConfig):
    """Configuration for the `human` model (interactive input; see `catch_eof`)."""

    name: Literal["human"] = Field(default="human", description="Model name.")

    per_instance_cost_limit: float = Field(
        default=0.0, description="Cost limit for every instance (task). This is a dummy value here."
    )
    total_cost_limit: float = Field(default=0.0, description="Cost limit for all instances (tasks).")
    # NOTE(review): presumably added to the instance cost per query — confirm in the
    # corresponding model implementation.
    cost_per_call: float = 0.0
    catch_eof: bool = True
    """Whether to catch EOF and return 'exit' when ^D is pressed. Set to False when used in human_step_in mode."""
    model_config = ConfigDict(extra="forbid")
247
+
248
+
249
class HumanThoughtModelConfig(HumanModelConfig):
    """Configuration for the human model variant that also prompts for a thought before each action."""

    name: Literal["human_thought"] = Field(default="human_thought", description="Model name.")

    # Cost limits are meaningless for a human operator, so both are pinned to 0 (disabled).
    per_instance_cost_limit: float = Field(
        default=0.0, description="Cost limit for every instance (task). This is a dummy value here."
    )
    total_cost_limit: float = Field(
        default=0.0, description="Cost limit for all instances (tasks). This is a dummy value here."
    )
    cost_per_call: float = 0.0

    model_config = ConfigDict(extra="forbid")
261
+
262
+
263
# Union of all model configurations. `union_mode="left_to_right"` makes pydantic
# try each member in declaration order and keep the first one that validates.
ModelConfig = Annotated[
    GenericAPIModelConfig
    | ReplayModelConfig
    | InstantEmptySubmitModelConfig
    | HumanModelConfig
    | HumanThoughtModelConfig,
    Field(union_mode="left_to_right"),
]
271
+
272
+
273
class GlobalStats(PydanticBaseModel):
    """This class tracks usage numbers (costs etc.) across all instances.

    Mutated through the module-level `GLOBAL_STATS` singleton; hold
    `GLOBAL_STATS_LOCK` when updating it.
    """

    total_cost: float = 0
    """Cumulative cost for all instances so far"""

    last_query_timestamp: float = 0
    """Timestamp of the last query. Currently only used with API models."""
281
+
282
+
283
# Module-level singletons shared by every model instance in this process.
GLOBAL_STATS = GlobalStats()
"""This object tracks usage numbers (costs etc.) across all instances.
Please use the `GLOBAL_STATS_LOCK` lock when accessing this object to avoid race conditions.
"""

GLOBAL_STATS_LOCK = Lock()
"""Lock for accessing `GLOBAL_STATS` without race conditions"""
290
+
291
+
292
class InstanceStats(PydanticBaseModel):
    """This object tracks usage numbers (costs etc.) for a single instance."""

    instance_cost: float = 0
    tokens_sent: int = 0
    tokens_received: int = 0
    api_calls: int = 0

    def __add__(self, other: InstanceStats) -> InstanceStats:
        """Return a new `InstanceStats` with every field summed pairwise."""
        # `model_fields` must be read from the class: instance access is deprecated
        # since pydantic 2.11 and slated for removal in pydantic 3.
        return InstanceStats(
            **{field: getattr(self, field) + getattr(other, field) for field in type(self).model_fields},
        )

    def __sub__(self, other: InstanceStats) -> InstanceStats:
        """Return a new `InstanceStats` with every field of `other` subtracted pairwise."""
        return InstanceStats(
            **{field: getattr(self, field) - getattr(other, field) for field in type(self).model_fields},
        )
309
+
310
+
311
class AbstractModel(ABC):
    """Interface that every model implementation (API, human, replay, test) must satisfy."""

    def __init__(self, config: ModelConfig, tools: ToolConfig):
        # Only declares the attribute types; concrete subclasses are responsible
        # for actually assigning `config` and `stats` in their own `__init__`.
        self.config: ModelConfig
        self.stats: InstanceStats

    def reset_stats(self):
        """Reset the per-instance usage statistics to zero."""
        self.stats = InstanceStats()

    @abstractmethod
    def query(self, history: History, action_prompt: str = "> ") -> dict: ...

    @property
    def instance_cost_limit(self) -> float:
        """Cost limit for the model. Returns 0 if there is no limit."""
        return 0
326
+
327
+
328
+ def _handle_raise_commands(action: str) -> None:
329
+ if action == "raise_runtime":
330
+ raise SwerexException()
331
+ elif action == "raise_cost":
332
+ raise CostLimitExceededError()
333
+ elif action == "raise_context":
334
+ raise ContextWindowExceededError()
335
+ elif action.startswith("raise_function_calling"):
336
+ parts = shlex.split(action)
337
+ error_code = parts[1]
338
+ if len(parts) == 3:
339
+ error_message = parts[2]
340
+ assert len(parts) < 4
341
+ raise FunctionCallingFormatError(error_message, error_code) # type: ignore
342
+
343
+
344
class HumanModel(AbstractModel):
    """Human-in-the-loop "model": actions are typed interactively by the user."""

    def __init__(self, config: HumanModelConfig, tools: ToolConfig):
        """Model that allows for human-in-the-loop"""
        self.logger = get_logger("swea-lm", emoji="🤖")
        self.config: HumanModelConfig = config
        self.stats = InstanceStats()

        # Determine which commands require multi-line input
        self.multi_line_command_endings = {
            command.name: command.end_name for command in tools.commands if command.end_name is not None
        }
        self._readline_histfile = REPO_ROOT / ".swe-agent-human-history"
        self._load_readline_history()

    def _load_readline_history(self) -> None:
        """Load autocomplete history from file"""
        # `readline` can be None (e.g. on platforms without the module); skip silently.
        if readline is None:
            return
        if self._readline_histfile.is_file():
            self.logger.debug(f"Loading readline history from {self._readline_histfile}")
            readline.read_history_file(self._readline_histfile)

    def _save_readline_history(self) -> None:
        """Save autocomplete history to file"""
        if readline is None:
            return
        readline.write_history_file(self._readline_histfile)

    def _update_stats(
        self,
    ) -> None:
        """Charge `cost_per_call` for this input and enforce the (usually disabled) cost limits."""
        self.stats.instance_cost += self.config.cost_per_call
        self.stats.api_calls += 1
        # A limit of 0 means "disabled" — the checks only fire for positive limits.
        if 0 < self.config.per_instance_cost_limit < self.stats.instance_cost:
            msg = f"Instance cost limit exceeded: {self.stats.instance_cost} > {self.config.per_instance_cost_limit}"
            raise InstanceCostLimitExceededError(msg)
        # NOTE(review): this compares the *instance* cost against the *total* cost
        # limit — looks like it should track a cross-instance total; confirm intent.
        if 0 < self.config.total_cost_limit < self.stats.instance_cost:
            msg = f"Total cost limit exceeded: {self.stats.instance_cost} > {self.config.total_cost_limit}"
            raise TotalCostLimitExceededError(msg)

    def _query(
        self,
        history: History,
        action_prompt: str = "> ",
    ) -> dict:
        """Logic for handling user input to pass to SWEEnv"""
        action = input(action_prompt)
        self._save_readline_history()
        command_name = action.split()[0] if action.strip() else ""

        # Special handling for multi-line input actions (i.e. edit)
        if command_name in self.multi_line_command_endings:
            buffer = [action]
            end_keyword = self.multi_line_command_endings[command_name]
            while True:
                action = input("... ")
                buffer.append(action)
                if action.rstrip() == end_keyword:
                    # Continue reading input until terminating keyword inputted
                    break
            action = "\n".join(buffer)
        elif action.strip() == "start_multiline_command":  # do arbitrary multi-line input
            buffer = []
            while True:
                action = input("... ")
                if action.rstrip() == "end_multiline_command":
                    break
                buffer.append(action)
            action = "\n".join(buffer)
        else:
            # Input has escaped things like \n, so we need to unescape it
            action = action.encode("utf8").decode("unicode_escape")
        # Magic command: add to the instance cost without executing anything real.
        if action.strip() and action.strip().split()[0] == "spend_money":
            money = float(action.strip().split()[1])
            self.stats.instance_cost += money
            action = f"echo 'Spent {money} dollars'"
        # Magic `raise_*` actions trigger specific exceptions (for testing).
        _handle_raise_commands(action)
        self._update_stats()
        return {"message": action}

    def query(self, history: History, action_prompt: str = "> ", n: int | None = None, **kwargs) -> dict | list[dict]:
        """Wrapper to separate action prompt from formatting"""
        out = []
        n_samples = n or 1
        for _ in range(n_samples):
            try:
                out.append(self._query(history, action_prompt))
            except KeyboardInterrupt:
                # ^C re-prompts instead of aborting; the recursive call returns a
                # single dict because it is made with n=None.
                print("^C (exit with ^D)")
                out.append(self.query(history, action_prompt))
            except EOFError:
                if self.config.catch_eof:
                    print("\nGoodbye!")
                    out.append({"message": "exit"})
                else:
                    # Re-raise EOFError when catch_eof is disabled
                    raise
        # Preserve the historical contract: plain dict when n was not given.
        if n is None:
            return out[0]
        return out
444
+
445
+
446
class HumanThoughtModel(HumanModel):
    """Human-in-the-loop model that asks for a free-form thought before the action."""

    def query(self, history: History, **kwargs) -> dict:
        """Logic for handling user input (both thought + action) to pass to SWEEnv"""
        thought_all = ""
        thought = input("Thought (end w/ END_THOUGHT): ")
        while True:
            if "END_THOUGHT" in thought:
                # Keep only the text before the terminator.
                thought = thought.split("END_THOUGHT")[0]
                thought_all += thought
                break
            # NOTE(review): successive lines are concatenated without separators,
            # so multi-line thoughts lose their line breaks — confirm intended.
            thought_all += thought
            thought = input("... ")

        action = super()._query(history, action_prompt="Action: ")["message"]

        # Format as thought followed by a fenced action block.
        return {"message": f"{thought_all}\n```\n{action}\n```"}
462
+
463
+
464
class ReplayModel(AbstractModel):
    def __init__(self, config: ReplayModelConfig, tools: ToolConfig):
        """Model used for replaying a trajectory (i.e., taking all the actions for the `.traj` file
        and re-issuing them.
        """
        self.config = config
        self.stats = InstanceStats()

        if not self.config.replay_path.exists():
            msg = f"Replay file {self.config.replay_path} not found"
            raise FileNotFoundError(msg)

        # One JSON object per line; only the first value of each object (its list
        # of actions) is kept, in file order.
        self._replays = [
            list(json.loads(x).values())[0] for x in Path(self.config.replay_path).read_text().splitlines(keepends=True)
        ]
        self._replay_idx = 0  # which trajectory we are replaying
        self._action_idx = 0  # next action within the current trajectory
        self.use_function_calling = tools.use_function_calling
        self.submit_command = tools.submit_command
        self.logger = get_logger("swea-lm", emoji="🤖")

    def _next_replay(self) -> None:
        """Called after last action"""
        self._replay_idx += 1
        self._action_idx = 0

    def query(self, history: History) -> dict:
        """Logic for tracking which replay action to pass to SWEEnv"""
        self.stats.api_calls += 1
        actions = self._replays[self._replay_idx]
        try:
            action = actions[self._action_idx]
        except IndexError:
            # Trajectory exhausted without a submit: synthesize one so the run ends.
            self.logger.error("Reached end of replay trajectory without submitting. Submitting now.")
            if self.use_function_calling:
                # Function-calling mode expects a tool_calls payload, not plain text.
                action = {
                    "message": f"Calling `{self.submit_command}` to submit.",
                    "tool_calls": [
                        {
                            "type": "function",
                            "id": "call_submit",
                            "function": {
                                "name": self.submit_command,
                                "arguments": "{}",
                            },
                        }
                    ],
                }
            else:
                action = f"```\n{self.submit_command}\n```"

        self._action_idx += 1

        # Assuming `submit` is always last action of replay trajectory
        if isinstance(action, str) and action == "submit":
            self._next_replay()
            return {"message": action}

        # Handle both dict and string actions
        if isinstance(action, dict):
            return action
        return {"message": action}
527
+
528
+
529
class PredeterminedTestModel(AbstractModel):
    """Model that replays a fixed sequence of canned outputs. Useful for testing."""

    def __init__(self, outputs: list[dict | str]):
        """Model that outputs a predetermined sequence of messages. Useful for testing."""
        self._outputs = outputs
        self._idx = -1
        self.stats = InstanceStats()

    def query(self, *args, **kwargs) -> dict:
        """Return the next canned output, normalized to the standard output dict."""
        self._idx += 1
        current = self._outputs[self._idx]
        if isinstance(current, str):
            # Strings may encode magic `raise_*` test actions.
            _handle_raise_commands(current)
            return {"message": current}
        if isinstance(current, dict):
            result = {"message": current["message"]}
            if "tool_calls" in current:
                result["tool_calls"] = current["tool_calls"]
            return result
        msg = f"Output must be string or dict, got {type(current)}"
        raise ValueError(msg)
549
+
550
+
551
class InstantEmptySubmitTestModel(AbstractModel):
    """This model immediately submits. Useful for testing purposes"""

    def __init__(self, args: InstantEmptySubmitModelConfig, tools: ToolConfig):
        """This model immediately submits. Useful for testing purposes"""
        super().__init__(args, tools)
        self.config: InstantEmptySubmitModelConfig = args
        self.stats = InstanceStats()
        self._action_idx = 0

    def query(self, history: list[dict[str, str]]) -> dict:
        """Alternate between a dummy `touch` action and a `submit` action."""
        # Optional artificial latency, uniformly sampled up to the configured delay.
        time.sleep(random.uniform(0, self.config.delay))
        # Need to at least do _something_ to submit
        canned_actions = (
            "DISCUSSION\n"
            "Let's reproduce the bug by creating a `reproduce.py` file.\n\n"
            "```\n"
            "touch reproduce.py\n"
            "```\n",
            "DISCUSSION\nThe task should be resolved, so let's submit the patch.\n\n```\nsubmit\n```\n",
        )
        action = canned_actions[self._action_idx]
        self._action_idx = (self._action_idx + 1) % 2
        self.stats.api_calls += 1
        return {"message": action}
576
+
577
+
578
class LiteLLMModel(AbstractModel):
    """Model served by the `litellm` library (covers all supported API providers)."""

    def __init__(self, args: GenericAPIModelConfig, tools: ToolConfig):
        """Model served by the `litellm` library."""
        # Always copy config to avoid shared state between different instances
        self.config: GenericAPIModelConfig = args.model_copy(deep=True)
        self.stats = InstanceStats()
        self.tools = tools
        self.logger = get_logger("swea-lm", emoji="🤖")

        if tools.use_function_calling:
            if not litellm.utils.supports_function_calling(model=self.config.name):
                msg = (
                    f"Model {self.config.name} does not support function calling. If your model"
                    " does not support function calling, you can use `parse_function='thought_action'` instead. "
                    "See https://swe-agent.com/latest/faq/ for more information."
                )
                self.logger.warning(msg)
        if self.config.litellm_model_registry is not None:
            # Register custom cost/limit data with litellm (e.g. for local models).
            with open(self.config.litellm_model_registry) as f:
                model_costs = json.load(f)
                litellm.register_model(model_costs)
        if self.config.max_input_tokens is not None:
            self.model_max_input_tokens = self.config.max_input_tokens
        else:
            self.model_max_input_tokens = litellm.model_cost.get(self.config.name, {}).get("max_input_tokens")

        if self.config.max_output_tokens is not None:
            self.model_max_output_tokens = self.config.max_output_tokens
        else:
            self.model_max_output_tokens = litellm.model_cost.get(self.config.name, {}).get("max_output_tokens")
            # Special handling for Claude 3.7 models to set 64k context by default when beta header not present
            # See https://github.com/SWE-agent/SWE-agent/pull/1016
            is_claude_3_7 = "claude-3-7-sonnet" in self.config.name or "claude-sonnet-4" in self.config.name
            has_128k_beta_header = (
                self.config.completion_kwargs.get("extra_headers", {}).get("anthropic-beta") == "output-128k-2025-02-19"
            )
            if is_claude_3_7 and not has_128k_beta_header:
                self.model_max_output_tokens = 64000
                self.logger.warning(
                    "Claude 3.7/4 models do not support 128k context by default. "
                    "Setting max output tokens to 64k. To enable 128k context, please set the "
                    "completion_kwargs to {'extra_headers': {'anthropic-beta': 'output-128k-2025-02-19'}}."
                )

        self.lm_provider = litellm.model_cost.get(self.config.name, {}).get("litellm_provider", self.config.name)
        self.custom_tokenizer = None
        if self.config.custom_tokenizer is not None:
            self.custom_tokenizer = litellm.utils.create_pretrained_tokenizer(**self.config.custom_tokenizer)

    @property
    def instance_cost_limit(self) -> float:
        """Cost limit for the model. Returns 0 if there is no limit."""
        return self.config.per_instance_cost_limit

    def _update_stats(self, *, input_tokens: int, output_tokens: int, cost: float) -> None:
        """Record usage for one API call and enforce cost/call limits.

        Raises:
            TotalCostLimitExceededError: global cost limit exceeded.
            InstanceCostLimitExceededError: per-instance cost limit exceeded.
            InstanceCallLimitExceededError: per-instance call limit exceeded.
        """
        with GLOBAL_STATS_LOCK:
            GLOBAL_STATS.total_cost += cost
        self.stats.instance_cost += cost
        self.stats.tokens_sent += input_tokens
        self.stats.tokens_received += output_tokens
        self.stats.api_calls += 1

        # Log updated cost values to std. err
        self.logger.debug(
            f"input_tokens={input_tokens:,}, "
            f"output_tokens={output_tokens:,}, "
            f"instance_cost={self.stats.instance_cost:.2f}, "
            f"cost={cost:.2f}",
        )
        self.logger.debug(
            f"total_tokens_sent={self.stats.tokens_sent:,}, "
            f"total_tokens_received={self.stats.tokens_received:,}, "
            f"total_cost={GLOBAL_STATS.total_cost:.2f}, "
            f"total_api_calls={self.stats.api_calls:,}",
        )

        # Check whether total cost or instance cost limits have been exceeded
        # (a limit of 0 means the check is disabled).
        if 0 < self.config.total_cost_limit < GLOBAL_STATS.total_cost:
            self.logger.warning(f"Cost {GLOBAL_STATS.total_cost:.2f} exceeds limit {self.config.total_cost_limit:.2f}")
            msg = "Total cost limit exceeded"
            raise TotalCostLimitExceededError(msg)

        if 0 < self.config.per_instance_cost_limit < self.stats.instance_cost:
            self.logger.warning(
                f"Cost {self.stats.instance_cost:.2f} exceeds limit {self.config.per_instance_cost_limit:.2f}"
            )
            msg = "Instance cost limit exceeded"
            raise InstanceCostLimitExceededError(msg)

        if 0 < self.config.per_instance_call_limit < self.stats.api_calls:
            self.logger.warning(f"API calls {self.stats.api_calls} exceeds limit {self.config.per_instance_call_limit}")
            msg = "Per instance call limit exceeded"
            raise InstanceCallLimitExceededError(msg)

    def _sleep(self) -> None:
        """Throttle: keep consecutive API calls at least `config.delay` seconds apart."""
        elapsed_time = time.time() - GLOBAL_STATS.last_query_timestamp
        if elapsed_time < self.config.delay:
            time.sleep(self.config.delay - elapsed_time)
        with GLOBAL_STATS_LOCK:
            GLOBAL_STATS.last_query_timestamp = time.time()

    def _single_query(
        self, messages: list[dict[str, str]], n: int | None = None, temperature: float | None = None
    ) -> list[dict]:
        """Issue one `litellm.completion` call and return the parsed output choices."""
        self._sleep()
        # Workaround for litellm bug https://github.com/SWE-agent/SWE-agent/issues/1109
        messages_no_cache_control = copy.deepcopy(messages)
        for message in messages_no_cache_control:
            if "cache_control" in message:
                del message["cache_control"]
            if "thinking_blocks" in message:
                del message["thinking_blocks"]
        input_tokens: int = litellm.utils.token_counter(
            messages=messages_no_cache_control,
            model=self.custom_tokenizer["identifier"] if self.custom_tokenizer is not None else self.config.name,
            custom_tokenizer=self.custom_tokenizer,
        )
        if self.model_max_input_tokens is None:
            msg = (
                f"No max input tokens found for model {self.config.name!r}. "
                "If you are using a local model, you can set `max_input_token` in the model config to override this."
            )
            self.logger.warning(msg)
        elif input_tokens > self.model_max_input_tokens > 0:
            msg = f"Input tokens {input_tokens} exceed max tokens {self.model_max_input_tokens}"
            raise ContextWindowExceededError(msg)
        extra_args = {}
        if self.config.api_base:
            # Not assigned a default value in litellm, so only pass this if it's set
            extra_args["api_base"] = self.config.api_base
        if self.tools.use_function_calling:
            extra_args["tools"] = self.tools.tools
        # We need to always set max_tokens for anthropic models
        completion_kwargs = self.config.completion_kwargs
        if self.lm_provider == "anthropic":
            completion_kwargs["max_tokens"] = self.model_max_output_tokens
        try:
            response: litellm.types.utils.ModelResponse = litellm.completion(  # type: ignore
                model=self.config.name,
                messages=messages,
                temperature=self.config.temperature if temperature is None else temperature,
                top_p=self.config.top_p,
                api_version=self.config.api_version,
                api_key=self.config.choose_api_key(),
                fallbacks=self.config.fallbacks,
                **completion_kwargs,
                **extra_args,
                n=n,
            )
        except litellm.exceptions.ContextWindowExceededError as e:
            raise ContextWindowExceededError from e
        except litellm.exceptions.ContentPolicyViolationError as e:
            raise ContentPolicyViolationError from e
        except litellm.exceptions.BadRequestError as e:
            if "is longer than the model's context length" in str(e):
                raise ContextWindowExceededError from e
            raise
        self.logger.debug(f"Response: {response}")
        try:
            cost = litellm.cost_calculator.completion_cost(response, model=self.config.name)
        except Exception as e:
            self.logger.debug(f"Error calculating cost: {e}, setting cost to 0.")
            if self.config.per_instance_cost_limit > 0 or self.config.total_cost_limit > 0:
                # Unknown costs with active limits would silently disable the safety
                # net, so this is treated as a configuration error.
                msg = (
                    f"Error calculating cost: {e} for your model {self.config.name}. If this is ok "
                    "(local models, etc.), please make sure you set `per_instance_cost_limit` and "
                    "`total_cost_limit` to 0 to disable this safety check."
                )
                self.logger.error(msg)
                raise ModelConfigurationError(msg)
            cost = 0
        choices: litellm.types.utils.Choices = response.choices  # type: ignore
        n_choices = n if n is not None else 1
        outputs = []
        output_tokens = 0
        for i in range(n_choices):
            output = choices[i].message.content or ""
            output_tokens += litellm.utils.token_counter(
                text=output,
                model=self.custom_tokenizer["identifier"] if self.custom_tokenizer is not None else self.config.name,
                custom_tokenizer=self.custom_tokenizer,
            )
            output_dict = {"message": output}
            if self.tools.use_function_calling:
                if response.choices[i].message.tool_calls:  # type: ignore
                    tool_calls = [call.to_dict() for call in response.choices[i].message.tool_calls]  # type: ignore
                else:
                    tool_calls = []
                output_dict["tool_calls"] = tool_calls
            if (
                hasattr(response.choices[i].message, "thinking_blocks")  # type: ignore
                and response.choices[i].message.thinking_blocks  # type: ignore
            ):
                output_dict["thinking_blocks"] = response.choices[i].message.thinking_blocks  # type: ignore
            outputs.append(output_dict)
        self._update_stats(input_tokens=input_tokens, output_tokens=output_tokens, cost=cost)
        return outputs

    def _query(
        self, messages: list[dict[str, str]], n: int | None = None, temperature: float | None = None
    ) -> list[dict]:
        """Collect n completions (one call each); returns a list of output dicts."""
        if n is None:
            return self._single_query(messages, temperature=temperature)
        outputs = []
        # not needed for openai, but oh well.
        for _ in range(n):
            # Fix: forward `temperature` — it was previously dropped here, so any
            # per-call temperature override was ignored whenever `n` was given.
            outputs.extend(self._single_query(messages, temperature=temperature))
        return outputs

    def query(self, history: History, n: int = 1, temperature: float | None = None) -> list[dict] | dict:
        """Query the model with retries.

        Returns a single output dict when n <= 1, otherwise a list of dicts.
        Fatal errors (cost/context limits, auth, configuration) are not retried.
        """
        messages = self._history_to_messages(history)

        def retry_warning(retry_state: RetryCallState):
            # Use the callback's own retry_state (tenacity passes the current one)
            # rather than reaching into the enclosing loop variable.
            exception_info = ""
            if retry_state.outcome is not None and retry_state.outcome.exception() is not None:
                exception = retry_state.outcome.exception()
                exception_info = f" due to {exception.__class__.__name__}: {str(exception)}"

            self.logger.warning(
                f"Retrying LM query: attempt {retry_state.attempt_number} "
                f"(slept for {retry_state.idle_for:.2f}s)"
                f"{exception_info}"
            )

        for attempt in Retrying(
            stop=stop_after_attempt(self.config.retry.retries),
            wait=wait_random_exponential(min=self.config.retry.min_wait, max=self.config.retry.max_wait),
            reraise=True,
            # These exceptions will not get better on retry, so fail fast.
            retry=retry_if_not_exception_type(
                (
                    ContextWindowExceededError,
                    CostLimitExceededError,
                    RuntimeError,
                    litellm.exceptions.UnsupportedParamsError,
                    litellm.exceptions.NotFoundError,
                    litellm.exceptions.PermissionDeniedError,
                    litellm.exceptions.ContextWindowExceededError,
                    litellm.exceptions.APIError,
                    litellm.exceptions.ContentPolicyViolationError,
                    TypeError,
                    litellm.exceptions.AuthenticationError,
                    ContentPolicyViolationError,
                    ModelConfigurationError,
                    KeyboardInterrupt,
                    IndexError,
                )
            ),
            before_sleep=retry_warning,
        ):
            with attempt:
                result = self._query(messages, n=n, temperature=temperature)
        if n is None or n == 1:
            return result[0]
        return result

    def _history_to_messages(
        self,
        history: History,
    ) -> list[dict[str, str]]:
        """Convert agent history items into litellm-compatible chat messages."""
        history = copy.deepcopy(history)

        def get_role(history_item: HistoryItem) -> str:
            if history_item["role"] == "system":
                # Some models have no system role; optionally downgrade it to user.
                return "user" if self.config.convert_system_to_user else "system"
            return history_item["role"]

        messages = []
        for history_item in history:
            role = get_role(history_item)
            if role == "tool":
                message = {
                    "role": role,
                    "content": history_item["content"],
                    # Only one tool call per observations
                    "tool_call_id": history_item["tool_call_ids"][0],  # type: ignore
                }
            elif (tool_calls := history_item.get("tool_calls")) is not None:
                message = {"role": role, "content": history_item["content"], "tool_calls": tool_calls}
                if thinking_blocks := history_item.get("thinking_blocks"):
                    message["thinking_blocks"] = thinking_blocks
            else:
                message = {"role": role, "content": history_item["content"]}
            if "cache_control" in history_item:
                message["cache_control"] = history_item["cache_control"]
            messages.append(message)
        n_cache_control = str(messages).count("cache_control")
        self.logger.debug(f"n_cache_control: {n_cache_control}")
        return messages
866
+
867
+
868
def get_model(args: ModelConfig, tools: ToolConfig) -> AbstractModel:
    """Returns correct model object given arguments and commands"""
    # name -> (expected config class, model class) for the special built-in models;
    # everything else is served through litellm.
    special_models = {
        "human": (HumanModelConfig, HumanModel),
        "human_thought": (HumanThoughtModelConfig, HumanThoughtModel),
        "replay": (ReplayModelConfig, ReplayModel),
        "instant_empty_submit": (InstantEmptySubmitModelConfig, InstantEmptySubmitTestModel),
    }

    # Convert GenericAPIModelConfig to specific model config if needed
    if isinstance(args, GenericAPIModelConfig) and not isinstance(
        args, HumanModelConfig | HumanThoughtModelConfig | ReplayModelConfig | InstantEmptySubmitModelConfig
    ):
        entry = special_models.get(args.name)
        if entry is not None:
            args = entry[0](**args.model_dump())

    entry = special_models.get(args.name)
    if entry is not None:
        config_cls, model_cls = entry
        assert isinstance(args, config_cls), f"Expected {config_cls}, got {args}"
        return model_cls(args, tools)

    assert isinstance(args, GenericAPIModelConfig), f"Expected {GenericAPIModelConfig}, got {args}"
    return LiteLLMModel(args, tools)