decodingtrust-agent-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (374) hide show
  1. agent/__init__.py +30 -0
  2. agent/claudesdk/__init__.py +8 -0
  3. agent/claudesdk/example.py +221 -0
  4. agent/claudesdk/src/__init__.py +8 -0
  5. agent/claudesdk/src/agent.py +400 -0
  6. agent/claudesdk/src/mcp_proxy.py +409 -0
  7. agent/claudesdk/src/utils.py +420 -0
  8. agent/googleadk/__init__.py +15 -0
  9. agent/googleadk/example.py +237 -0
  10. agent/googleadk/src/__init__.py +12 -0
  11. agent/googleadk/src/agent.py +401 -0
  12. agent/googleadk/src/mcp_wrapper.py +163 -0
  13. agent/googleadk/src/utils.py +602 -0
  14. agent/langchain/__init__.py +8 -0
  15. agent/langchain/example.py +213 -0
  16. agent/langchain/src/__init__.py +8 -0
  17. agent/langchain/src/agent.py +645 -0
  18. agent/langchain/src/utils.py +433 -0
  19. agent/openaisdk/__init__.py +17 -0
  20. agent/openaisdk/example.py +228 -0
  21. agent/openaisdk/src/__init__.py +12 -0
  22. agent/openaisdk/src/agent.py +491 -0
  23. agent/openaisdk/src/agent_wrapper.py +143 -0
  24. agent/openaisdk/src/mcp_wrapper.py +395 -0
  25. agent/openaisdk/src/utils.py +493 -0
  26. agent/openclaw/__init__.py +10 -0
  27. agent/openclaw/example.py +251 -0
  28. agent/openclaw/src/__init__.py +14 -0
  29. agent/openclaw/src/agent.py +930 -0
  30. agent/openclaw/src/helpers/__init__.py +1 -0
  31. agent/openclaw/src/helpers/auth_helpers.py +55 -0
  32. agent/openclaw/src/mcp_proxy.py +564 -0
  33. agent/openclaw/src/plugin_generator.py +231 -0
  34. agent/openclaw/src/utils.py +341 -0
  35. agent/pocketflow/__init__.py +18 -0
  36. agent/pocketflow/example.py +221 -0
  37. agent/pocketflow/prompts/react_agent.py +46 -0
  38. agent/pocketflow/src/__init__.py +6 -0
  39. agent/pocketflow/src/agent.py +507 -0
  40. agent/pocketflow/src/agent_wrapper.py +159 -0
  41. agent/pocketflow/src/async_helper.py +92 -0
  42. agent/pocketflow/src/mcp_react_agent.py +279 -0
  43. agent/pocketflow/src/native_agent.py +74 -0
  44. agent/pocketflow/src/nodes.py +467 -0
  45. benchmark/__init__.py +0 -0
  46. benchmark/browser/benign.jsonl +34 -0
  47. benchmark/browser/direct.jsonl +85 -0
  48. benchmark/browser/indirect.jsonl +82 -0
  49. benchmark/code/benign.jsonl +0 -0
  50. benchmark/code/direct.jsonl +121 -0
  51. benchmark/code/indirect.jsonl +165 -0
  52. benchmark/crm/benign.jsonl +165 -0
  53. benchmark/crm/direct.jsonl +90 -0
  54. benchmark/crm/indirect.jsonl +150 -0
  55. benchmark/customer-service/benign.jsonl +160 -0
  56. benchmark/customer-service/direct.jsonl +100 -0
  57. benchmark/customer-service/indirect.jsonl +101 -0
  58. benchmark/finance/benign.jsonl +0 -0
  59. benchmark/finance/direct.jsonl +200 -0
  60. benchmark/finance/indirect.jsonl +200 -0
  61. benchmark/legal/benign.jsonl +0 -0
  62. benchmark/legal/direct.jsonl +200 -0
  63. benchmark/legal/indirect.jsonl +200 -0
  64. benchmark/macos/benign.jsonl +30 -0
  65. benchmark/macos/direct.jsonl +50 -0
  66. benchmark/macos/indirect.jsonl +50 -0
  67. benchmark/medical/benign.jsonl +642 -0
  68. benchmark/medical/direct.jsonl +229 -0
  69. benchmark/medical/indirect.jsonl +222 -0
  70. benchmark/os-filesystem/benign.jsonl +200 -0
  71. benchmark/os-filesystem/direct.jsonl +200 -0
  72. benchmark/os-filesystem/indirect.jsonl +200 -0
  73. benchmark/research/benign.jsonl +0 -0
  74. benchmark/research/direct.jsonl +119 -0
  75. benchmark/research/indirect.jsonl +125 -0
  76. benchmark/telecom/benign.jsonl +120 -0
  77. benchmark/telecom/direct.jsonl +161 -0
  78. benchmark/telecom/indirect.jsonl +166 -0
  79. benchmark/travel/benign.jsonl +130 -0
  80. benchmark/travel/direct.jsonl +105 -0
  81. benchmark/travel/indirect.jsonl +120 -0
  82. benchmark/windows/benign.jsonl +100 -0
  83. benchmark/windows/direct.jsonl +140 -0
  84. benchmark/windows/indirect.jsonl +107 -0
  85. benchmark/workflow/benign.jsonl +335 -0
  86. benchmark/workflow/direct.jsonl +78 -0
  87. benchmark/workflow/indirect.jsonl +107 -0
  88. cli/__init__.py +5 -0
  89. cli/main.py +182 -0
  90. cli/scaffold.py +334 -0
  91. decodingtrust_agent_sdk-0.1.0.dist-info/METADATA +642 -0
  92. decodingtrust_agent_sdk-0.1.0.dist-info/RECORD +374 -0
  93. decodingtrust_agent_sdk-0.1.0.dist-info/WHEEL +5 -0
  94. decodingtrust_agent_sdk-0.1.0.dist-info/entry_points.txt +2 -0
  95. decodingtrust_agent_sdk-0.1.0.dist-info/licenses/LICENSE +201 -0
  96. decodingtrust_agent_sdk-0.1.0.dist-info/top_level.txt +6 -0
  97. dt_arena/config/env.yaml +515 -0
  98. dt_arena/config/injection_mcp.yaml +430 -0
  99. dt_arena/config/mcp.yaml +642 -0
  100. dt_arena/envs/arxiv/docker-compose-hub.yml +31 -0
  101. dt_arena/envs/arxiv/docker-compose.yml +36 -0
  102. dt_arena/envs/atlassian/docker/docker-compose.dev.yml +65 -0
  103. dt_arena/envs/atlassian/docker/docker-compose.yml +53 -0
  104. dt_arena/envs/atlassian/docker-compose-hub.yml +57 -0
  105. dt_arena/envs/atlassian/docker-compose.yml +72 -0
  106. dt_arena/envs/bigquery/docker-compose.yml +20 -0
  107. dt_arena/envs/booking/docker-compose.yml +59 -0
  108. dt_arena/envs/calendar/docker-compose-hub.yml +30 -0
  109. dt_arena/envs/calendar/docker-compose.yml +42 -0
  110. dt_arena/envs/custom-website/docker-compose.yml +6 -0
  111. dt_arena/envs/customer_service/docker-compose.yml +59 -0
  112. dt_arena/envs/databricks/docker-compose-hub.yml +47 -0
  113. dt_arena/envs/databricks/docker-compose.yml +51 -0
  114. dt_arena/envs/ecommerce/docker-compose.yml +6 -0
  115. dt_arena/envs/ers/docker-compose.yml +36 -0
  116. dt_arena/envs/ers/hrms/docker/docker-compose.yml +31 -0
  117. dt_arena/envs/finance/docker-compose.yml +23 -0
  118. dt_arena/envs/github/docker/docker-compose-hub.yml +50 -0
  119. dt_arena/envs/github/docker/docker-compose.yml +50 -0
  120. dt_arena/envs/gmail/docker-compose-hub.yml +51 -0
  121. dt_arena/envs/gmail/docker-compose.yml +65 -0
  122. dt_arena/envs/google-form/docker-compose-hub.yml +33 -0
  123. dt_arena/envs/google-form/docker-compose.yml +41 -0
  124. dt_arena/envs/googledocs/docker-compose-hub.yml +61 -0
  125. dt_arena/envs/googledocs/docker-compose.yml +78 -0
  126. dt_arena/envs/hospital/docker-compose-hub.yml +25 -0
  127. dt_arena/envs/hospital/docker-compose.yml +27 -0
  128. dt_arena/envs/legal/docker-compose.yml +22 -0
  129. dt_arena/envs/linkedin/docker-compose.yml +63 -0
  130. dt_arena/envs/macos/docker-compose.yml +79 -0
  131. dt_arena/envs/os-filesystem/docker-compose-hub.yml +16 -0
  132. dt_arena/envs/os-filesystem/docker-compose.yml +20 -0
  133. dt_arena/envs/paypal/docker-compose-hub.yml +48 -0
  134. dt_arena/envs/paypal/docker-compose.yml +63 -0
  135. dt_arena/envs/research/docker-compose-hub.yml +13 -0
  136. dt_arena/envs/research/docker-compose.yml +24 -0
  137. dt_arena/envs/salesforce_crm/docker-compose-hub.yaml +45 -0
  138. dt_arena/envs/salesforce_crm/docker-compose.yaml +49 -0
  139. dt_arena/envs/slack/docker-compose-hub.yml +28 -0
  140. dt_arena/envs/slack/docker-compose.yml +41 -0
  141. dt_arena/envs/snowflake/docker-compose-hub.yml +41 -0
  142. dt_arena/envs/snowflake/docker-compose.yml +44 -0
  143. dt_arena/envs/telecom/docker-compose-hub.yml +16 -0
  144. dt_arena/envs/telecom/docker-compose.yml +17 -0
  145. dt_arena/envs/telegram/docker-compose-hub.yml +57 -0
  146. dt_arena/envs/telegram/docker-compose.yml +62 -0
  147. dt_arena/envs/terminal/docker-compose-hub.yml +12 -0
  148. dt_arena/envs/terminal/docker-compose.yml +26 -0
  149. dt_arena/envs/travel/docker-compose-hub.yml +19 -0
  150. dt_arena/envs/travel/docker-compose.yml +19 -0
  151. dt_arena/envs/whatsapp/docker-compose-hub.yml +61 -0
  152. dt_arena/envs/whatsapp/docker-compose.yml +78 -0
  153. dt_arena/envs/windows/docker-compose.yml +71 -0
  154. dt_arena/envs/zoom/docker-compose-hub.yml +27 -0
  155. dt_arena/envs/zoom/docker-compose.yml +40 -0
  156. dt_arena/injection_mcp_server/atlassian/env_injection.py +134 -0
  157. dt_arena/injection_mcp_server/calendar/env_injection.py +217 -0
  158. dt_arena/injection_mcp_server/custom_website/env_injection.py +97 -0
  159. dt_arena/injection_mcp_server/customer_service/env_injection.py +659 -0
  160. dt_arena/injection_mcp_server/databricks/env_injection.py +255 -0
  161. dt_arena/injection_mcp_server/ecommerce/env_injection.py +110 -0
  162. dt_arena/injection_mcp_server/finance/env_injection.py +85 -0
  163. dt_arena/injection_mcp_server/github/env_injection.py +206 -0
  164. dt_arena/injection_mcp_server/gmail/env_injection.py +211 -0
  165. dt_arena/injection_mcp_server/google_form/env_injection.py +186 -0
  166. dt_arena/injection_mcp_server/googledocs/env_injection.py +44 -0
  167. dt_arena/injection_mcp_server/hospital/env_injection.py +43 -0
  168. dt_arena/injection_mcp_server/legal/env_injection.py +229 -0
  169. dt_arena/injection_mcp_server/macos/env_injection.py +272 -0
  170. dt_arena/injection_mcp_server/os-filesystem/env_injection.py +341 -0
  171. dt_arena/injection_mcp_server/paypal/env_injection.py +268 -0
  172. dt_arena/injection_mcp_server/research/env_injection.py +616 -0
  173. dt_arena/injection_mcp_server/salesforce/env_injection.py +514 -0
  174. dt_arena/injection_mcp_server/slack/env_injection.py +265 -0
  175. dt_arena/injection_mcp_server/snowflake/env_injection.py +230 -0
  176. dt_arena/injection_mcp_server/telecom/env_injection.py +503 -0
  177. dt_arena/injection_mcp_server/telegram/env_injection.py +171 -0
  178. dt_arena/injection_mcp_server/terminal/env_injection.py +523 -0
  179. dt_arena/injection_mcp_server/travel/env_injection.py +173 -0
  180. dt_arena/injection_mcp_server/whatsapp/env_injection.py +185 -0
  181. dt_arena/injection_mcp_server/windows/env_injection.py +943 -0
  182. dt_arena/injection_mcp_server/zoom/env_injection.py +216 -0
  183. dt_arena/mcp_server/atlassian/main.py +1554 -0
  184. dt_arena/mcp_server/atlassian/test_server.py +66 -0
  185. dt_arena/mcp_server/bigquery/main.py +333 -0
  186. dt_arena/mcp_server/booking/main.py +310 -0
  187. dt_arena/mcp_server/browser/main.py +1741 -0
  188. dt_arena/mcp_server/calendar/example_multi_user.py +162 -0
  189. dt_arena/mcp_server/calendar/main.py +792 -0
  190. dt_arena/mcp_server/calendar/test_mcp.py +135 -0
  191. dt_arena/mcp_server/customer_service/main.py +1063 -0
  192. dt_arena/mcp_server/databricks/main.py +566 -0
  193. dt_arena/mcp_server/databricks/probe.py +102 -0
  194. dt_arena/mcp_server/ers/main.py +845 -0
  195. dt_arena/mcp_server/finance/__init__.py +87 -0
  196. dt_arena/mcp_server/finance/core/__init__.py +12 -0
  197. dt_arena/mcp_server/finance/core/data_loader.py +558 -0
  198. dt_arena/mcp_server/finance/core/portfolio.py +565 -0
  199. dt_arena/mcp_server/finance/evaluation/__init__.py +20 -0
  200. dt_arena/mcp_server/finance/evaluation/evaluator.py +217 -0
  201. dt_arena/mcp_server/finance/evaluation/logger.py +137 -0
  202. dt_arena/mcp_server/finance/injection/__init__.py +66 -0
  203. dt_arena/mcp_server/finance/injection/config.py +176 -0
  204. dt_arena/mcp_server/finance/injection/content.py +755 -0
  205. dt_arena/mcp_server/finance/injection/html.py +409 -0
  206. dt_arena/mcp_server/finance/injection/locations.py +167 -0
  207. dt_arena/mcp_server/finance/injection/methods.py +193 -0
  208. dt_arena/mcp_server/finance/injection/presets.py +1023 -0
  209. dt_arena/mcp_server/finance/main.py +361 -0
  210. dt_arena/mcp_server/finance/run_mcp.py +21 -0
  211. dt_arena/mcp_server/finance/run_web.py +26 -0
  212. dt_arena/mcp_server/finance/server/__init__.py +41 -0
  213. dt_arena/mcp_server/finance/server/extractor.py +1453 -0
  214. dt_arena/mcp_server/finance/server/extractor_minimal.py +292 -0
  215. dt_arena/mcp_server/finance/server/extractor_simple.py +1164 -0
  216. dt_arena/mcp_server/finance/server/injection_mcp.py +865 -0
  217. dt_arena/mcp_server/finance/server/mcp.py +451 -0
  218. dt_arena/mcp_server/finance/server/tools/__init__.py +23 -0
  219. dt_arena/mcp_server/finance/server/tools/account.py +88 -0
  220. dt_arena/mcp_server/finance/server/tools/browsing.py +328 -0
  221. dt_arena/mcp_server/finance/server/tools/social.py +73 -0
  222. dt_arena/mcp_server/finance/server/tools/trading.py +242 -0
  223. dt_arena/mcp_server/finance/server/tools/utility.py +49 -0
  224. dt_arena/mcp_server/finance/server/web.py +2139 -0
  225. dt_arena/mcp_server/finance/tasks/benchmark/__init__.py +28 -0
  226. dt_arena/mcp_server/finance/tasks/benchmark/attack_pool.py +3026 -0
  227. dt_arena/mcp_server/finance/tasks/benchmark/attack_runner.py +1315 -0
  228. dt_arena/mcp_server/finance/tasks/benchmark/finra_requirements.py +1335 -0
  229. dt_arena/mcp_server/finance/tasks/benchmark/finra_tasks.py +3665 -0
  230. dt_arena/mcp_server/finance/tasks/benchmark/malicious_tasks.py +2673 -0
  231. dt_arena/mcp_server/finance/tasks/redteam_suite/run_redteam_suite.py +1713 -0
  232. dt_arena/mcp_server/finance/test_mcp_tools.py +476 -0
  233. dt_arena/mcp_server/github/main.py +441 -0
  234. dt_arena/mcp_server/gmail/main.py +1004 -0
  235. dt_arena/mcp_server/google_form/main.py +141 -0
  236. dt_arena/mcp_server/googledocs/main.py +458 -0
  237. dt_arena/mcp_server/hospital/mcp_server.py +458 -0
  238. dt_arena/mcp_server/legal/__init__.py +9 -0
  239. dt_arena/mcp_server/legal/core/__init__.py +14 -0
  240. dt_arena/mcp_server/legal/core/courtlistener_store.py +762 -0
  241. dt_arena/mcp_server/legal/core/data_loader.py +266 -0
  242. dt_arena/mcp_server/legal/core/document_store.py +197 -0
  243. dt_arena/mcp_server/legal/core/matter_manager.py +466 -0
  244. dt_arena/mcp_server/legal/main.py +89 -0
  245. dt_arena/mcp_server/legal/scripts/collect_data.py +988 -0
  246. dt_arena/mcp_server/legal/server/__init__.py +14 -0
  247. dt_arena/mcp_server/legal/server/mcp.py +2330 -0
  248. dt_arena/mcp_server/macos/client_test.py +270 -0
  249. dt_arena/mcp_server/macos/mcp_server.py +285 -0
  250. dt_arena/mcp_server/os-filesystem/main.py +1380 -0
  251. dt_arena/mcp_server/paypal/main.py +501 -0
  252. dt_arena/mcp_server/research/main.py +777 -0
  253. dt_arena/mcp_server/salesforce/main.py +2006 -0
  254. dt_arena/mcp_server/slack/main.py +318 -0
  255. dt_arena/mcp_server/snowflake/main.py +612 -0
  256. dt_arena/mcp_server/snowflake/probe.py +183 -0
  257. dt_arena/mcp_server/telecom/mcp_client.py +423 -0
  258. dt_arena/mcp_server/telecom/mcp_server.py +1059 -0
  259. dt_arena/mcp_server/telegram/main.py +338 -0
  260. dt_arena/mcp_server/terminal/main.py +163 -0
  261. dt_arena/mcp_server/travel/client_test.py +16 -0
  262. dt_arena/mcp_server/travel/mcp_server.py +404 -0
  263. dt_arena/mcp_server/whatsapp/main.py +318 -0
  264. dt_arena/mcp_server/windows/client_test.py +270 -0
  265. dt_arena/mcp_server/windows/mcp_server.py +218 -0
  266. dt_arena/mcp_server/zoom/main.py +466 -0
  267. dt_arena/src/__init__.py +0 -0
  268. dt_arena/src/hooks/__init__.py +0 -0
  269. dt_arena/src/hooks/audit_log.py +30 -0
  270. dt_arena/src/hooks/hooks.json +3 -0
  271. dt_arena/src/run_benign.py +142 -0
  272. dt_arena/src/types/__init__.py +0 -0
  273. dt_arena/src/types/agent.py +441 -0
  274. dt_arena/src/types/attacks.py +2 -0
  275. dt_arena/src/types/environment.py +2 -0
  276. dt_arena/src/types/hooks.py +174 -0
  277. dt_arena/src/types/judge.py +52 -0
  278. dt_arena/src/types/red_teaming_trajectory.py +385 -0
  279. dt_arena/src/types/task.py +260 -0
  280. dt_arena/src/types/trajectory.py +315 -0
  281. dt_arena/utils/__init__.py +1 -0
  282. dt_arena/utils/atlassian/__init__.py +27 -0
  283. dt_arena/utils/atlassian/helpers.py +520 -0
  284. dt_arena/utils/bigquery/__init__.py +1 -0
  285. dt_arena/utils/bigquery/helpers.py +246 -0
  286. dt_arena/utils/calendar/__init__.py +1 -0
  287. dt_arena/utils/calendar/helpers.py +87 -0
  288. dt_arena/utils/customer_service/__init__.py +17 -0
  289. dt_arena/utils/customer_service/cs_env_client.py +940 -0
  290. dt_arena/utils/customer_service/helpers.py +339 -0
  291. dt_arena/utils/customer_service/judges/__init__.py +20 -0
  292. dt_arena/utils/customer_service/judges/abusive_comms.py +149 -0
  293. dt_arena/utils/customer_service/judges/bulk_destructive.py +202 -0
  294. dt_arena/utils/customer_service/judges/cross_customer.py +481 -0
  295. dt_arena/utils/customer_service/judges/email_exfiltration.py +320 -0
  296. dt_arena/utils/customer_service/judges/financial_fraud.py +221 -0
  297. dt_arena/utils/customer_service/judges/llm_utils.py +264 -0
  298. dt_arena/utils/customer_service/judges/pii_disclosure.py +837 -0
  299. dt_arena/utils/customer_service/judges/policy_misrep.py +44 -0
  300. dt_arena/utils/customer_service/judges/text_utils.py +21 -0
  301. dt_arena/utils/databricks/__init__.py +2 -0
  302. dt_arena/utils/databricks/helpers.py +210 -0
  303. dt_arena/utils/finance/__init__.py +0 -0
  304. dt_arena/utils/finance/helpers.py +263 -0
  305. dt_arena/utils/github/__init__.py +1 -0
  306. dt_arena/utils/github/helpers.py +249 -0
  307. dt_arena/utils/gmail/__init__.py +1 -0
  308. dt_arena/utils/gmail/helpers.py +344 -0
  309. dt_arena/utils/google_form/__init__.py +2 -0
  310. dt_arena/utils/google_form/helpers.py +133 -0
  311. dt_arena/utils/legal/__init__.py +0 -0
  312. dt_arena/utils/legal/helpers.py +228 -0
  313. dt_arena/utils/macos/__init__.py +0 -0
  314. dt_arena/utils/macos/env_setup.py +215 -0
  315. dt_arena/utils/macos/helpers.py +61 -0
  316. dt_arena/utils/os_filesystem/__init__.py +1 -0
  317. dt_arena/utils/os_filesystem/helpers.py +366 -0
  318. dt_arena/utils/paypal/__init__.py +1 -0
  319. dt_arena/utils/paypal/helpers.py +178 -0
  320. dt_arena/utils/port_allocator.py +266 -0
  321. dt_arena/utils/research/__init__.py +0 -0
  322. dt_arena/utils/research/helpers.py +251 -0
  323. dt_arena/utils/salesforce/__init__.py +1 -0
  324. dt_arena/utils/salesforce/helpers.py +719 -0
  325. dt_arena/utils/slack/__init__.py +1 -0
  326. dt_arena/utils/slack/helpers.py +176 -0
  327. dt_arena/utils/snowflake/__init__.py +1 -0
  328. dt_arena/utils/snowflake/helpers.py +166 -0
  329. dt_arena/utils/telecom/__init__.py +1 -0
  330. dt_arena/utils/telecom/helpers.py +760 -0
  331. dt_arena/utils/telegram/__init__.py +0 -0
  332. dt_arena/utils/telegram/helpers.py +174 -0
  333. dt_arena/utils/terminal/__init__.py +0 -0
  334. dt_arena/utils/terminal/helpers.py +20 -0
  335. dt_arena/utils/travel/__init__.py +0 -0
  336. dt_arena/utils/travel/env_client.py +537 -0
  337. dt_arena/utils/travel/llm_judge.py +137 -0
  338. dt_arena/utils/travel/prompts.py +64 -0
  339. dt_arena/utils/utils/__init__.py +122 -0
  340. dt_arena/utils/whatsapp/__init__.py +0 -0
  341. dt_arena/utils/whatsapp/helpers.py +226 -0
  342. dt_arena/utils/windows/__init__.py +0 -0
  343. dt_arena/utils/windows/env_reset.py +224 -0
  344. dt_arena/utils/windows/env_setup.py +280 -0
  345. dt_arena/utils/windows/exfil_helpers.py +170 -0
  346. dt_arena/utils/windows/helpers.py +74 -0
  347. dt_arena/utils/zoom/__init__.py +1 -0
  348. dt_arena/utils/zoom/helpers.py +70 -0
  349. eval/__init__.py +1 -0
  350. eval/evaluation.py +426 -0
  351. eval/task_runner.py +449 -0
  352. utils/__init__.py +148 -0
  353. utils/agent_helpers.py +308 -0
  354. utils/agent_wrapper.py +189 -0
  355. utils/compose_utils.py +135 -0
  356. utils/config.py +77 -0
  357. utils/env_helpers.py +104 -0
  358. utils/eval_stats.py +88 -0
  359. utils/injection_helpers.py +429 -0
  360. utils/injection_mcp_helpers.py +152 -0
  361. utils/judge_helpers.py +181 -0
  362. utils/judge_utils.py +472 -0
  363. utils/llm.py +196 -0
  364. utils/logging.py +45 -0
  365. utils/mcp_helpers.py +232 -0
  366. utils/mcp_manager.py +235 -0
  367. utils/memory_guard.py +18 -0
  368. utils/red_teaming_sandbox.py +476 -0
  369. utils/reset_helpers.py +318 -0
  370. utils/resource_manager.py +370 -0
  371. utils/skill_helpers.py +447 -0
  372. utils/task_executor.py +904 -0
  373. utils/task_helpers.py +270 -0
  374. utils/template_helpers.py +179 -0
utils/task_executor.py ADDED
@@ -0,0 +1,904 @@
1
+ import asyncio
2
+ import grp
3
+ import json
4
+ import os
5
+ import time
6
+ from collections import defaultdict
7
+ from dataclasses import dataclass, field
8
+ from enum import Enum
9
+ from pathlib import Path
10
+ from typing import Any, Callable, Coroutine, Dict, FrozenSet, List, Optional, Set, Tuple
11
+
12
+ import yaml
13
+
14
+ from .memory_guard import check_memory_before_launch
15
+ from .reset_helpers import reset_environment
16
+
17
+
18
def _needs_sudo_for_docker() -> bool:
    """Check if we need sudo to run docker commands.

    Two probes are tried in order:

    1. Run ``docker ps`` directly; a zero exit status means the current
       user can already use docker without elevation.
    2. Check whether the current process belongs to the ``docker`` group.

    If neither probe succeeds, the function defaults to requiring sudo.

    Returns:
        False if docker is usable without sudo, True otherwise.
    """
    import subprocess

    # Probe 1: try running docker directly.
    try:
        result = subprocess.run(
            ["docker", "ps"],
            capture_output=True,
            timeout=5,
        )
    except (subprocess.SubprocessError, OSError):
        # OSError covers a missing binary (FileNotFoundError) as well as a
        # binary that exists but cannot be executed (PermissionError).
        # The result of this function is computed at module import, so a
        # failed probe must fall through instead of raising.
        pass
    else:
        if result.returncode == 0:
            return False

    # Probe 2: fall back to docker group membership.
    try:
        docker_gid = grp.getgrnam("docker").gr_gid
        if docker_gid in os.getgroups():
            return False
    except (KeyError, OSError):
        # No "docker" group on this host, or group lookup failed.
        pass

    # Default to using sudo
    return True
42
+
43
+
44
# Probe once when the module is imported and cache the answer, since it
# won't change during execution.
_USE_SUDO = _needs_sudo_for_docker()
46
+
47
+
48
def get_task_environments(task_dir: Path) -> List[str]:
    """
    Determine which environments a task needs based on its config.yaml.

    The task's ``config.yaml`` lists the MCP servers the agent uses (under
    ``Agent.mcp_servers``, falling back to a top-level ``mcp_servers`` key).
    Each enabled server is looked up by lower-cased name in
    ``dt_arena/config/mcp.yaml`` and its ``environment`` field (a single
    name or a list of names) is collected.

    Args:
        task_dir: Path to the task directory

    Returns:
        Sorted list of unique environment names (e.g., ["gmail", "slack"]).
        Empty when the task has no config.yaml. When mcp.yaml does not
        exist, the lower-cased names of the enabled servers are returned
        as-is instead.

    Raises:
        ValueError: If config.yaml or mcp.yaml cannot be parsed as YAML.
    """
    config_path = task_dir / "config.yaml"
    if not config_path.exists():
        return []

    try:
        config = yaml.safe_load(config_path.read_text()) or {}
    except yaml.YAMLError as e:
        raise ValueError(f"Failed to parse {config_path}: {e}") from e

    # A key written with no value (``Agent:``) parses as None, which a
    # .get() default does not cover; fall back with ``or`` instead.
    agent_cfg = config.get("Agent") or {}
    mcp_servers = agent_cfg.get("mcp_servers") or config.get("mcp_servers") or []

    server_names = [
        (srv.get("name") or "").lower()
        for srv in mcp_servers
        if srv.get("enabled", True)
    ]

    project_root = Path(__file__).resolve().parent.parent
    mcp_config_path = project_root / "dt_arena" / "config" / "mcp.yaml"

    if not mcp_config_path.exists():
        return server_names

    try:
        mcp_cfg = yaml.safe_load(mcp_config_path.read_text()) or {}
    except yaml.YAMLError as e:
        raise ValueError(f"Failed to parse {mcp_config_path}: {e}") from e

    # Entries without a name can never match a task's server, so skip them
    # rather than failing the whole lookup with a KeyError.
    mcp_servers_cfg = {
        srv["name"].lower(): srv
        for srv in (mcp_cfg.get("servers") or [])
        if srv.get("name")
    }

    needed_envs: Set[str] = set()
    for srv_name in server_names:
        env_value = mcp_servers_cfg.get(srv_name, {}).get("environment")
        if not env_value:
            continue
        # Support both single string and list of environments
        if isinstance(env_value, list):
            needed_envs.update(env_value)
        else:
            needed_envs.add(env_value)

    # Sort so the result does not depend on set iteration order.
    return sorted(needed_envs)
104
+
105
+
106
class EnvState(Enum):
    """Environment instance lifecycle states."""
    # Default state of a newly created EnvInstance.
    STARTING = "starting"
    # Set by the executor once an instance has started successfully.
    AVAILABLE = "available"
    # NOTE(review): presumably "assigned to a running task"; the transition
    # is not visible in this part of the file.
    IN_USE = "in_use"
    # NOTE(review): presumably "being torn down"; the transition is not
    # visible in this part of the file.
    STOPPING = "stopping"
112
+
113
+
114
@dataclass
class EnvInstance:
    """A running instance of a Docker environment.

    Holds the identity of one docker compose project together with its
    lifecycle state and the ports allocated for it.
    """
    instance_id: str  # Unique ID like "gmail:pool_gmail_1_12345" (env name + ":" + project name)
    env_name: str  # Environment type like "gmail", "slack"
    project_name: str  # Docker compose project name
    compose_file: Path  # docker-compose file the instance is started from
    state: EnvState = EnvState.STARTING  # New instances begin in STARTING
    ports: Dict[str, int] = field(default_factory=dict)  # port variable name -> allocated port
    current_task_id: Optional[str] = None  # presumably the task holding this instance; None by default
    use_count: int = 0  # presumably number of tasks served so far - TODO confirm
    last_used: float = field(default_factory=time.time)  # defaults to the creation timestamp
126
+
127
+
128
@dataclass
class ScheduledTask:
    """A task with its environment requirements.

    Only ``task_dir``, ``environments`` and ``original_index`` are required;
    the remaining fields are optional metadata that default to None.
    """
    task_dir: Path  # Task directory
    environments: FrozenSet[str]  # Names of the environments the task requires
    original_index: int  # presumably the task's position in the submitted order - TODO confirm
    # Optional descriptive metadata; its meaning is not established in this
    # part of the file.
    domain: Optional[str] = None
    task_type: Optional[str] = None
    threat_model: Optional[str] = None
    risk_category: Optional[str] = None
    task_id: Optional[str] = None
139
+
140
+
141
@dataclass
class RunningTask:
    """A task currently being executed.

    Pairs a ScheduledTask with the environment instances it was given and
    the time at which the record was created.
    """
    task: ScheduledTask  # The scheduled task being run
    task_id: str  # Identifier of this run
    instances: List[EnvInstance]  # Environment instances associated with the run
    start_time: float = field(default_factory=time.time)  # defaults to the creation timestamp
    future: Optional[asyncio.Future] = None  # presumably the handle of the in-flight execution - TODO confirm
149
+
150
+
151
@dataclass
class ExecutorStats:
    """Statistics about the executor state.

    A plain record of counters; every field is required.
    """
    # Task counters.
    total_tasks: int
    pending_tasks: int
    running_tasks: int
    completed_tasks: int
    # Environment instance counters.
    total_instances: int
    available_instances: int
    in_use_instances: int
    instances_by_env: Dict[str, int]  # presumably env name -> number of instances - TODO confirm
162
+
163
+
164
+ class TaskExecutor:
165
+ """
166
+ Task-based parallel executor with environment instance pooling.
167
+
168
+ Key features:
169
+ - Controls parallelism by max_parallel_tasks (not max environments)
170
+ - Creates environment instances on demand
171
+ - Reuses available instances when possible
172
+ - Stops instances when no pending task needs them
173
+ - Respects per-environment instance limits (e.g., max 1 Windows VM)
174
+ """
175
+
176
+ def __init__(self, max_parallel: int = 5):
177
+ self.max_parallel = max_parallel
178
+ self._lock = asyncio.Lock()
179
+
180
+ # Task tracking
181
+ self._pending: List[ScheduledTask] = []
182
+ self._running: Dict[str, RunningTask] = {} # task_id -> RunningTask
183
+ self._completed: List[Tuple[ScheduledTask, int]] = [] # (task, return_code)
184
+
185
+ # Environment instance tracking
186
+ self._instances: Dict[str, EnvInstance] = {} # instance_id -> EnvInstance
187
+ self._env_instances: Dict[str, List[str]] = defaultdict(list) # env_name -> [instance_ids]
188
+ self._task_instances: Dict[str, List[str]] = {} # task_id -> [instance_ids]
189
+
190
+ # Configuration
191
+ self._env_config: Dict = {}
192
+ self._env_limits: Dict[str, int] = {} # env_name -> max_instances
193
+ self._default_max_instances: Optional[int] = None
194
+ self._project_counter = 0
195
+ self._pool_id = str(os.getpid())
196
+
197
+ # Load configuration
198
+ self._load_config()
199
+
200
+ def _load_config(self) -> None:
201
+ """Load environment configuration including instance limits."""
202
+ project_root = Path(__file__).resolve().parent.parent
203
+ env_config_path = project_root / "dt_arena" / "config" / "env.yaml"
204
+
205
+ if not env_config_path.exists():
206
+ return
207
+
208
+ try:
209
+ self._env_config = yaml.safe_load(env_config_path.read_text()) or {}
210
+ except Exception as e:
211
+ print(f"[EXECUTOR] Warning: Failed to load env config: {e}")
212
+ return
213
+
214
+ # Load default max instances
215
+ self._default_max_instances = self._env_config.get("default_max_instances")
216
+
217
+ # Load per-environment limits
218
+ environments = self._env_config.get("environments", {})
219
+ for env_name, env_def in environments.items():
220
+ if "max_instances" in env_def:
221
+ self._env_limits[env_name] = env_def["max_instances"]
222
+
223
+ def _get_max_instances(self, env_name: str) -> Optional[int]:
224
+ """Get max instances for an environment (None = unlimited)."""
225
+ if env_name in self._env_limits:
226
+ return self._env_limits[env_name]
227
+ return self._default_max_instances
228
+
229
+ def _get_compose_file(self, env_name: str) -> Optional[Path]:
230
+ """Get the docker-compose file path for an environment."""
231
+ project_root = Path(__file__).resolve().parent.parent
232
+ environments = self._env_config.get("environments", {})
233
+
234
+ if env_name in environments:
235
+ compose_rel = environments[env_name].get("docker_compose")
236
+ if compose_rel:
237
+ return (project_root / compose_rel).resolve()
238
+
239
+ # Fallback: check standard location
240
+ env_path = project_root / "dt_arena" / "envs" / env_name
241
+ for name in ["docker-compose.yml", "docker-compose.yaml"]:
242
+ compose = env_path / name
243
+ if compose.exists():
244
+ return compose
245
+
246
+ return None
247
+
248
+ def _generate_project_name(self, env_name: str) -> str:
249
+ """Generate unique project name for a new instance."""
250
+ self._project_counter += 1
251
+ return f"pool_{env_name}_{self._project_counter}_{self._pool_id}"
252
+
253
def _allocate_ports(self, env_name: str) -> Dict[str, int]:
    """Allocate one port per port variable configured for ``env_name``.

    Port variables come from the ``ports`` mapping of the environment's
    entry in the loaded env config. Each one is requested from the
    ResourceManager with a preferred default taken from the entry's
    ``default`` key, then ``container_port``, then 8000.

    The allocation owner id is built from the current value of
    ``self._project_counter``.

    Returns:
        Mapping of port variable name to the allocated port; empty when
        the environment declares no ports.
    """
    from .resource_manager import ResourceManager

    manager = ResourceManager.instance()
    port_specs = (
        self._env_config.get("environments", {})
        .get(env_name, {})
        .get("ports", {})
    )
    owner = f"pool_{env_name}_{self._project_counter}"

    return {
        var_name: manager.allocate_port(
            owner,
            var_name,
            default=int(spec.get("default", spec.get("container_port", 8000))),
        )
        for var_name, spec in port_specs.items()
    }
271
+
272
+ def _validate_env_requirements(self, compose_file: Path) -> bool:
273
+ """
274
+ Validate environment-specific requirements before starting.
275
+
276
+ Looks for a validate.py in the environment directory and calls its
277
+ validate(env_dir) function if present.
278
+
279
+ Returns:
280
+ True if validation passes, False otherwise
281
+ """
282
+ import importlib.util
283
+
284
+ env_dir = compose_file.parent
285
+ validate_script = env_dir / "validate.py"
286
+
287
+ if not validate_script.exists():
288
+ return True
289
+
290
+ try:
291
+ spec = importlib.util.spec_from_file_location("validate", validate_script)
292
+ if spec is None or spec.loader is None:
293
+ return True
294
+
295
+ module = importlib.util.module_from_spec(spec)
296
+ spec.loader.exec_module(module)
297
+
298
+ if hasattr(module, "validate"):
299
+ success, message = module.validate(env_dir)
300
+ if not success:
301
+ print(message, flush=True)
302
+ return False
303
+
304
+ except Exception as e:
305
+ print(f"[EXECUTOR] Warning: Failed to run {validate_script}: {e}", flush=True)
306
+
307
+ return True
308
+
309
async def _start_instance(self, env_name: str, max_retries: int = 5, wait_time: int = 30) -> Optional[EnvInstance]:
    """Start a new environment instance with ``docker compose up -d``.

    The method waits for the memory guard, resolves and validates the
    compose file, allocates ports, registers the instance in the tracking
    tables, brings the compose project up and then waits for its containers.

    Args:
        env_name: Environment type to start.
        max_retries: Number of memory-guard attempts before giving up.
        wait_time: Base delay in seconds between memory-guard attempts;
            the n-th failed attempt (1-based) sleeps ``wait_time * n``.

    Returns:
        The new instance in the AVAILABLE state, or None when the memory
        guard, the compose lookup, validation or the start itself failed.
    """
    # Memory guard: wait for sufficient memory before launching
    for attempt in range(max_retries):
        try:
            check_memory_before_launch()
            break
        except RuntimeError as e:
            if attempt == max_retries - 1:
                print(f"[EXECUTOR] Memory guard blocked launch after {max_retries} retries: {e}", flush=True)
                return None
            wait = wait_time * (attempt + 1)
            print(f"[EXECUTOR] Low memory, waiting {wait}s before retry ({attempt+1}/{max_retries})...", flush=True)
            await asyncio.sleep(wait)

    compose_file = self._get_compose_file(env_name)
    if not compose_file or not compose_file.exists():
        print(f"[EXECUTOR] No docker-compose file found for {env_name}", flush=True)
        return None

    # Pre-start validation (calls validate.py if present in env directory)
    if not self._validate_env_requirements(compose_file):
        return None

    project_name = self._generate_project_name(env_name)
    instance_id = f"{env_name}:{project_name}"
    ports = self._allocate_ports(env_name)

    instance = EnvInstance(
        instance_id=instance_id,
        env_name=env_name,
        project_name=project_name,
        compose_file=compose_file,
        state=EnvState.STARTING,
        ports=ports,
    )

    print(f"[EXECUTOR] Starting instance {instance_id}", flush=True)

    # Track the new instance
    self._instances[instance.instance_id] = instance
    self._env_instances[env_name].append(instance.instance_id)

    started = False
    try:
        # Build docker compose command
        if _USE_SUDO:
            # Use sudo with 'env' command to ensure env vars are properly passed
            cmd = ["sudo", "env"] + [f"{k}={v}" for k, v in ports.items()] + [
                "docker", "compose", "-p", project_name, "-f", str(compose_file), "up", "-d"
            ]
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                cwd=str(compose_file.parent),
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
        else:
            # Set port environment variables when not using sudo
            env = os.environ.copy()
            for var_name, port in ports.items():
                env[var_name] = str(port)

            proc = await asyncio.create_subprocess_exec(
                "docker", "compose", "-p", project_name, "-f", str(compose_file), "up", "-d",
                cwd=str(compose_file.parent),
                env=env,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )

        stdout, stderr = await proc.communicate()
        if proc.returncode != 0:
            print(f"[EXECUTOR] Start failed for {instance_id}: {stderr.decode()}", flush=True)
            return None

        # Wait for containers to be healthy
        # Get timeout from env config (default 120s, Windows needs 240s+)
        environments = self._env_config.get("environments", {})
        env_def = environments.get(env_name, {})
        health_timeout = env_def.get("health_timeout", 120)
        # The boolean result is not checked: a timeout is already logged by
        # _wait_for_healthy and the instance is still handed out afterwards.
        await self._wait_for_healthy(project_name, compose_file, timeout=health_timeout)

        instance.state = EnvState.AVAILABLE
        print(f"[EXECUTOR] Instance {instance_id} started successfully", flush=True)
        started = True
        return instance

    except Exception as e:
        print(f"[EXECUTOR] Failed to start instance {instance_id}: {e}", flush=True)
        return None

    finally:
        if not started:
            # A failed start must not stay registered: a STARTING entry is
            # counted by _count_env_instances and would occupy a
            # max-instances slot forever. Tear down whatever compose managed
            # to create (best effort) and then drop the tracking entries.
            await self._stop_instance(instance)
            self._instances.pop(instance_id, None)
            tracked_ids = self._env_instances.get(env_name, [])
            if instance_id in tracked_ids:
                tracked_ids.remove(instance_id)
            # NOTE(review): the ports returned by _allocate_ports are not
            # handed back here; confirm whether the allocator needs a release.
398
+
399
async def _wait_for_healthy(
    self,
    project_name: str,
    compose_file: Path,
    timeout: int = 120,
) -> bool:
    """Poll ``docker compose ps`` until every container of the project is ready.

    A container counts as ready when its ``State`` is ``running`` and its
    ``Health`` field is either empty (no healthcheck) or ``healthy``.

    Args:
        project_name: Compose project name passed via ``-p``.
        compose_file: Compose file passed via ``-f``; its parent directory
            is used as the working directory.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True once at least one container is listed and all listed
        containers are ready; False when ``timeout`` elapses first.
    """
    print(f"[EXECUTOR] Waiting for containers to be healthy...", flush=True)
    start_time = time.time()

    # The command is identical on every poll, so build it once.
    cmd = ["docker", "compose", "-p", project_name, "-f", str(compose_file), "ps", "--format", "json"]
    if _USE_SUDO:
        cmd = ["sudo"] + cmd

    while time.time() - start_time < timeout:
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            cwd=str(compose_file.parent),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, _stderr = await proc.communicate()

        if proc.returncode == 0:
            containers = self._parse_compose_ps(stdout.decode(errors="replace"))
            # Require at least one parsed container: empty or unparseable
            # output must not be reported as "all healthy".
            if containers and all(
                c.get("State", "") == "running" and c.get("Health", "") in ("", None, "healthy")
                for c in containers
            ):
                print(f"[EXECUTOR] All containers healthy", flush=True)
                return True

        await asyncio.sleep(2)

    print(f"[EXECUTOR] Timeout waiting for containers to be healthy", flush=True)
    return False

@staticmethod
def _parse_compose_ps(output: str) -> List[Dict[str, Any]]:
    """Return the container records found in ``docker compose ps --format json`` output.

    Accepts both shapes the command may produce: one JSON object per line,
    or a single JSON array of objects. Lines that are not valid JSON and
    values that are not objects are ignored.
    """
    text = output.strip()
    if not text:
        return []

    try:
        # Whole-document parse covers the array form (and a single object).
        documents = [json.loads(text)]
    except json.JSONDecodeError:
        documents = []
        for line in text.split('\n'):
            if not line.strip():
                continue
            try:
                documents.append(json.loads(line))
            except json.JSONDecodeError:
                continue

    containers: List[Dict[str, Any]] = []
    for doc in documents:
        if isinstance(doc, dict):
            containers.append(doc)
        elif isinstance(doc, list):
            containers.extend(item for item in doc if isinstance(item, dict))
    return containers
459
+
460
async def _stop_instance(self, instance: EnvInstance) -> bool:
    """Stop and remove an environment instance.

    Runs ``docker compose down --remove-orphans --volumes`` for the
    instance's project (60 second limit) and removes the instance from the
    tracking tables.

    Args:
        instance: Instance to stop; its state is set to STOPPING first.

    Returns:
        True when the command finished within the time limit and the
        instance was untracked; False when launching or waiting for the
        command raised (including a timeout), in which case the instance
        stays tracked in the STOPPING state.
    """
    instance.state = EnvState.STOPPING
    print(f"[EXECUTOR] Stopping instance {instance.instance_id}", flush=True)

    try:
        cmd = ["docker", "compose", "-p", instance.project_name,
               "-f", str(instance.compose_file), "down", "--remove-orphans", "--volumes"]
        if _USE_SUDO:
            cmd = ["sudo"] + cmd

        proc = await asyncio.create_subprocess_exec(
            *cmd,
            cwd=str(instance.compose_file.parent),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            # communicate() drains both pipes; wait() alone can block forever
            # once the child fills a pipe buffer nobody is reading.
            _stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=60)
        except asyncio.TimeoutError:
            # Do not leave the timed-out child running in the background.
            try:
                proc.kill()
            except ProcessLookupError:
                pass
            await proc.wait()
            raise

        if proc.returncode != 0:
            detail = stderr.decode(errors="replace").strip()
            print(
                f"[EXECUTOR] Warning: docker compose down exited with code {proc.returncode} "
                f"for {instance.instance_id}: {detail}",
                flush=True,
            )

        # Remove from tracking
        self._instances.pop(instance.instance_id, None)
        tracked_ids = self._env_instances.get(instance.env_name, [])
        if instance.instance_id in tracked_ids:
            tracked_ids.remove(instance.instance_id)

        print(f"[EXECUTOR] Instance {instance.instance_id} stopped", flush=True)
        return True

    except Exception as e:
        print(f"[EXECUTOR] Error stopping instance {instance.instance_id}: {e}", flush=True)
        return False
491
+
492
+ def _get_disable_reuse_flag(self, env_name: str) -> bool:
493
+ """Get disable reuse flag from env config."""
494
+ environments = self._env_config.get("environments", {})
495
+ env_def = environments.get(env_name, {})
496
+ return env_def.get("disable_reuse", False)
497
+
498
+ def _get_reset_scripts(self, env_name: str) -> Dict[str, str]:
499
+ """Get reset script paths for each service from env config."""
500
+ environments = self._env_config.get("environments", {})
501
+ env_def = environments.get(env_name, {})
502
+ return env_def.get("reset_scripts", {})
503
+
504
+ def _get_reset_endpoints(self, env_name: str) -> Dict[str, Dict[str, Any]]:
505
+ """
506
+ Get reset API endpoints from env config.
507
+
508
+ Returns a dict mapping endpoint names to their config:
509
+ {
510
+ "endpoint_name": {
511
+ "url": "http://localhost:${PORT}/api/v1/reset",
512
+ "method": "POST", # optional, defaults to POST
513
+ "port_var": "SALESFORCE_API_PORT" # which port variable to use
514
+ }
515
+ }
516
+ """
517
+ environments = self._env_config.get("environments", {})
518
+ env_def = environments.get(env_name, {})
519
+ return env_def.get("reset_endpoints", {})
520
+
521
async def _reset_instance(self, instance: EnvInstance) -> None:
    """
    Reset the data state of *instance* by delegating to ``reset_environment``.

    The helper receives the instance's environment name, ports, project
    name and compose file, the full environment config, and the
    per-environment ``reset_script_timeout`` (60 when unset, e.g. Windows
    loadvm needs ~60s). According to the original notes it tries API reset
    endpoints first and docker-exec reset scripts as a fallback.

    Raises:
        RuntimeError: If reset fails (both endpoints and scripts failed),
            as raised by ``reset_environment``.
    """
    print(f"[EXECUTOR] Resetting instance {instance.instance_id}", flush=True)

    env_settings = self._env_config.get("environments", {}).get(instance.env_name, {})

    await reset_environment(
        env_name=instance.env_name,
        ports=instance.ports,
        env_config=self._env_config,
        project_name=instance.project_name,
        compose_file=instance.compose_file,
        script_timeout=env_settings.get("reset_script_timeout", 60),
    )
547
+
548
+ def _count_env_instances(self, env_name: str, exclude_stopping: bool = True) -> int:
549
+ """Count running instances of an environment type."""
550
+ count = 0
551
+ for inst_id in self._env_instances.get(env_name, []):
552
+ inst = self._instances.get(inst_id)
553
+ if inst:
554
+ if exclude_stopping and inst.state == EnvState.STOPPING:
555
+ continue
556
+ count += 1
557
+ return count
558
+
559
def _find_available_instance(self, env_name: str) -> Optional[EnvInstance]:
    """Return the first tracked instance of *env_name* in the AVAILABLE state, or None."""
    tracked = (self._instances.get(inst_id) for inst_id in self._env_instances.get(env_name, []))
    return next(
        (inst for inst in tracked if inst and inst.state == EnvState.AVAILABLE),
        None,
    )
566
+
567
+ def _get_needed_envs(self) -> Set[str]:
568
+ """Get all environment types needed by pending tasks."""
569
+ needed = set()
570
+ for task in self._pending:
571
+ needed.update(task.environments)
572
+ return needed
573
+
574
def _can_start_task(self, task: ScheduledTask) -> bool:
    """Tell whether every environment of *task* can be served right now.

    An environment is servable when an available instance exists, or when
    a new one may be created: either no maximum is configured or the
    current instance count is below it.
    """
    def is_blocked(env_name: str) -> bool:
        # An idle instance can be reused directly.
        if self._find_available_instance(env_name):
            return False
        # Otherwise a new instance is needed; respect the configured cap.
        limit = self._get_max_instances(env_name)
        return limit is not None and self._count_env_instances(env_name) >= limit

    return not any(is_blocked(env_name) for env_name in task.environments)
589
+
590
def _calculate_task_priority(
    self,
    task: ScheduledTask,
    available_envs: Set[str],
) -> Tuple[int, int]:
    """
    Build the sort key of a task.

    The key is (negative_reuse_count, original_index): tasks that can
    reuse more of the currently available environments sort first, ties
    keep their original order. Lower values = higher priority.
    """
    reuse_count = len([name for name in task.environments if name in available_envs])
    return (-reuse_count, task.original_index)
604
+
605
def _pick_next_task(self) -> Optional[ScheduledTask]:
    """Remove and return the best startable pending task, or None if there is none."""
    if not self._pending:
        return None

    # Environment types that currently have an idle instance
    idle_envs = {
        inst.env_name
        for inst in self._instances.values()
        if inst.state == EnvState.AVAILABLE
    }

    # Only tasks that can actually start are eligible
    startable = [pending for pending in self._pending if self._can_start_task(pending)]
    if not startable:
        return None

    # Lowest key wins (max reuse, then FIFO); min() keeps the first of equal keys
    best = min(startable, key=lambda t: self._calculate_task_priority(t, idle_envs))
    self._pending.remove(best)
    return best
627
+
628
async def _acquire_instances_for_task(
    self,
    task: ScheduledTask,
    task_id: str,
) -> Optional[List[EnvInstance]]:
    """Acquire one instance per environment required by a task.

    For every name in ``task.environments`` an available instance is reused
    (after a reset) or, if none exists or reuse is disabled, a new instance
    is started and reset. Acquired instances are marked IN_USE and recorded
    in ``self._task_instances[task_id]``.

    The method expects ``self._lock`` to be held by the caller
    (``_run_single_task`` calls it inside ``async with self._lock``): it
    releases the lock around ``_start_instance`` and re-acquires it.

    Args:
        task: Task whose ``environments`` are to be acquired.
        task_id: Identifier stored on each acquired instance and used as
            the key in ``self._task_instances``.

    Returns:
        The acquired instances in ``task.environments`` order, or None when
        starting a new instance failed (already-acquired instances are put
        back to AVAILABLE in that case).

    Raises:
        Exception: Whatever ``_reset_instance`` raises, after the
            configured retries on the reuse path, or immediately on the
            new-instance path.
    """
    acquired: List[EnvInstance] = []

    for env_name in task.environments:
        # Try to find an available instance
        instance = self._find_available_instance(env_name)
        disable_reuse = self._get_disable_reuse_flag(env_name)
        if instance and disable_reuse:
            # Reuse is switched off for this environment: discard the idle
            # instance and fall through to the "start new instance" branch.
            # (The stop runs while the lock is still held.)
            print(f"[EXECUTOR] Reuse disabled for {env_name}, stopping instance {instance.instance_id}", flush=True)
            await self._stop_instance(instance)
            instance = None

        if instance:
            # Reuse existing instance
            instance.state = EnvState.IN_USE
            instance.current_task_id = task_id
            instance.use_count += 1
            instance.last_used = time.time()
            acquired.append(instance)

            # Track early so release works even if reset fails
            self._task_instances[task_id] = [inst.instance_id for inst in acquired]

            # Reset before reuse (with configurable retries)
            environments = self._env_config.get("environments", {})
            env_def = environments.get(env_name, {})
            max_retries = env_def.get("reset_retries", 1)
            retry_delay = env_def.get("reset_retry_delay", 10)

            # One initial attempt plus max_retries retries; the last failure
            # is re-raised to the caller.
            for attempt in range(1 + max_retries):
                try:
                    await self._reset_instance(instance)
                    break
                except Exception as e:
                    if attempt < max_retries:
                        print(f"[EXECUTOR] Reset failed for {instance.instance_id} (attempt {attempt + 1}/{1 + max_retries}), retrying in {retry_delay}s: {e}", flush=True)
                        await asyncio.sleep(retry_delay)
                    else:
                        raise

            print(f"[EXECUTOR] Reusing instance {instance.instance_id} for task {task_id} (use #{instance.use_count})", flush=True)
        else:
            # Start new instance (release lock during slow operation)
            self._lock.release()
            try:
                instance = await self._start_instance(env_name)
            finally:
                # Always re-take the lock so the caller's `async with` can
                # release it normally.
                await self._lock.acquire()

            # NOTE(review): _start_instance marks the instance AVAILABLE
            # before the lock is re-acquired above; if another task holds
            # the lock at that moment it could presumably claim the same
            # instance via _find_available_instance -- confirm.

            if not instance:
                # Rollback acquired instances
                # (self._task_instances[task_id] may still list them; the
                # later release only touches instances that are IN_USE.)
                for inst in acquired:
                    inst.state = EnvState.AVAILABLE
                    inst.current_task_id = None
                return None

            instance.state = EnvState.IN_USE
            instance.current_task_id = task_id
            instance.use_count = 1
            acquired.append(instance)

            # Track early so release works even if reset fails
            self._task_instances[task_id] = [inst.instance_id for inst in acquired]

            # Reset after start (single attempt, no retry loop on this path)
            await self._reset_instance(instance)

            print(f"[EXECUTOR] Started new instance {instance.instance_id} for task {task_id}", flush=True)

    # Final update of tracking (in case multiple environments)
    self._task_instances[task_id] = [inst.instance_id for inst in acquired]

    return acquired
706
+
707
+ async def _release_instances_for_task(self, task_id: str) -> Set[str]:
708
+ """Release instances used by a task. Returns released env names."""
709
+ instance_ids = self._task_instances.pop(task_id, [])
710
+ released_envs = set()
711
+
712
+ for inst_id in instance_ids:
713
+ inst = self._instances.get(inst_id)
714
+ if inst and inst.state == EnvState.IN_USE:
715
+ inst.state = EnvState.AVAILABLE
716
+ inst.current_task_id = None
717
+ inst.last_used = time.time()
718
+ released_envs.add(inst.env_name)
719
+ print(f"[EXECUTOR] Released instance {inst_id}", flush=True)
720
+
721
+ return released_envs
722
+
723
+ async def _cleanup_unused_instances(self, released_envs: Set[str]) -> None:
724
+ """Stop instances for environments no longer needed by pending tasks."""
725
+ needed_envs = self._get_needed_envs()
726
+
727
+ for env_name in released_envs:
728
+ if env_name not in needed_envs:
729
+ # Stop all available instances of this env type
730
+ for inst_id in list(self._env_instances.get(env_name, [])):
731
+ inst = self._instances.get(inst_id)
732
+ if inst and inst.state == EnvState.AVAILABLE:
733
+ # Release lock during slow stop operation
734
+ self._lock.release()
735
+ try:
736
+ await self._stop_instance(inst)
737
+ finally:
738
+ await self._lock.acquire()
739
+
740
async def _run_single_task(
    self,
    task: ScheduledTask,
    task_id: str,
    run_fn: Callable[[ScheduledTask, Dict[str, EnvInstance]], Coroutine[Any, Any, int]],
) -> int:
    """Acquire instances for a task, run it, and always release afterwards.

    Returns:
        The exit code returned by ``run_fn``, or 1 when the instances could
        not be acquired or any exception was raised on the way.
    """
    try:
        # Acquisition happens under the lock
        async with self._lock:
            acquired = await self._acquire_instances_for_task(task, task_id)

        if not acquired:
            print(f"[EXECUTOR] Failed to acquire instances for task {task_id}", flush=True)
            return 1

        # env_name -> instance mapping handed to the task
        by_env = {}
        for inst in acquired:
            by_env[inst.env_name] = inst

        return await run_fn(task, by_env)

    except Exception as e:
        print(f"[EXECUTOR] Error running task {task_id}: {e}", flush=True)
        return 1

    finally:
        # Release whatever was acquired, then stop instances nobody needs
        async with self._lock:
            freed_envs = await self._release_instances_for_task(task_id)
            await self._cleanup_unused_instances(freed_envs)
774
+
775
async def _worker(
    self,
    run_fn: Callable[[ScheduledTask, Dict[str, EnvInstance]], Coroutine[Any, Any, int]],
    results: List[Tuple[ScheduledTask, int]],
    slot_available: asyncio.Event,
) -> None:
    """Worker that processes tasks from the pending queue.

    Loops until there are neither pending nor running tasks. In each round
    it tries, under ``self._lock``, to pick a startable task; if it gets
    one it runs it and records the result, otherwise it sleeps until
    another worker signals ``slot_available``.

    Args:
        run_fn: Coroutine function executed for each task with the task
            and its env_name -> instance mapping; returns an exit code.
        results: Shared list that receives ``(task, return_code)`` tuples
            in completion order.
        slot_available: Event shared by all workers; set after a task
            finishes to wake workers that found nothing to start.
    """
    while True:
        task: Optional[ScheduledTask] = None
        task_id: Optional[str] = None

        async with self._lock:
            # Check if we should exit
            if not self._pending and not self._running:
                return

            # Try to pick a task
            if len(self._running) < self.max_parallel:
                task = self._pick_next_task()
                if task:
                    # id(task) keeps ids distinct for tasks that share a
                    # directory name.
                    task_id = f"{task.task_dir.name}_{id(task)}"
                    # Registered before the lock is dropped so other workers
                    # see the task in self._running.
                    self._running[task_id] = RunningTask(
                        task=task,
                        task_id=task_id,
                        instances=[],
                    )

        if task and task_id:
            # Run the task outside lock
            return_code = await self._run_single_task(task, task_id, run_fn)

            async with self._lock:
                # Record result
                results.append((task, return_code))
                del self._running[task_id]

            # Signal that a slot is available
            slot_available.set()
        else:
            # Wait for a slot to become available
            # (nothing could be started: either max_parallel is reached or
            # no pending task passed _can_start_task)
            slot_available.clear()
            await slot_available.wait()
817
+
818
async def run_all(
    self,
    tasks: List[ScheduledTask],
    run_fn: Callable[[ScheduledTask, Dict[str, EnvInstance]], Coroutine[Any, Any, int]],
) -> List[Tuple[ScheduledTask, int]]:
    """
    Execute every task using up to ``max_parallel`` concurrent workers.

    Args:
        tasks: Tasks to run; their order seeds the pending queue.
        run_fn: Async function called as ``run_fn(task, env_instances)``
            that returns the task's exit code.

    Returns:
        ``(task, return_code)`` tuples, appended as tasks finish. An empty
        list is returned immediately when ``tasks`` is empty.
    """
    if not tasks:
        return []

    # Fresh bookkeeping for this run (pending keeps the original order)
    self._pending = list(tasks)
    self._running.clear()
    results: List[Tuple[ScheduledTask, int]] = []

    print(f"[EXECUTOR] Starting {len(tasks)} tasks with max_parallel={self.max_parallel}", flush=True)

    # Shared wake-up signal; starts set so the first round never blocks
    slot_available = asyncio.Event()
    slot_available.set()

    # One worker per parallel slot
    workers = []
    for _ in range(self.max_parallel):
        workers.append(asyncio.create_task(self._worker(run_fn, results, slot_available)))

    # Block until every worker has drained the queue
    await asyncio.gather(*workers)

    return results
857
+
858
def stats(self) -> ExecutorStats:
    """Build a snapshot of task and instance counters.

    Task counts come from the pending, running and completed collections;
    instance counts are derived from ``self._instances``, both in total and
    per environment name.
    """
    per_env: Dict[str, int] = {}
    idle_count = 0
    busy_count = 0

    for inst in self._instances.values():
        per_env[inst.env_name] = per_env.get(inst.env_name, 0) + 1
        if inst.state == EnvState.AVAILABLE:
            idle_count += 1
        elif inst.state == EnvState.IN_USE:
            busy_count += 1

    n_pending = len(self._pending)
    n_running = len(self._running)
    n_completed = len(self._completed)

    return ExecutorStats(
        total_tasks=n_pending + n_running + n_completed,
        pending_tasks=n_pending,
        running_tasks=n_running,
        completed_tasks=n_completed,
        total_instances=len(self._instances),
        available_instances=idle_count,
        in_use_instances=busy_count,
        instances_by_env=per_env,
    )
880
+
881
async def shutdown(self) -> None:
    """Stop every tracked instance and clear all tracking tables.

    Instances already in the STOPPING state are not stopped again. The
    stops run concurrently with ``self._lock`` released; exceptions from
    individual stops are collected by ``gather`` and not re-raised.
    """
    async with self._lock:
        print(f"[EXECUTOR] Shutting down {len(self._instances)} instance(s)...", flush=True)

        # One stop coroutine per instance that is not already stopping
        pending_stops = [
            self._stop_instance(inst)
            for inst in list(self._instances.values())
            if inst.state != EnvState.STOPPING
        ]

        if pending_stops:
            # Release lock during shutdown
            self._lock.release()
            try:
                await asyncio.gather(*pending_stops, return_exceptions=True)
            finally:
                await self._lock.acquire()

        for table in (self._instances, self._env_instances, self._task_instances):
            table.clear()

    print("[EXECUTOR] Shutdown complete", flush=True)