decodingtrust-agent-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +30 -0
- agent/claudesdk/__init__.py +8 -0
- agent/claudesdk/example.py +221 -0
- agent/claudesdk/src/__init__.py +8 -0
- agent/claudesdk/src/agent.py +400 -0
- agent/claudesdk/src/mcp_proxy.py +409 -0
- agent/claudesdk/src/utils.py +420 -0
- agent/googleadk/__init__.py +15 -0
- agent/googleadk/example.py +237 -0
- agent/googleadk/src/__init__.py +12 -0
- agent/googleadk/src/agent.py +401 -0
- agent/googleadk/src/mcp_wrapper.py +163 -0
- agent/googleadk/src/utils.py +602 -0
- agent/langchain/__init__.py +8 -0
- agent/langchain/example.py +213 -0
- agent/langchain/src/__init__.py +8 -0
- agent/langchain/src/agent.py +645 -0
- agent/langchain/src/utils.py +433 -0
- agent/openaisdk/__init__.py +17 -0
- agent/openaisdk/example.py +228 -0
- agent/openaisdk/src/__init__.py +12 -0
- agent/openaisdk/src/agent.py +491 -0
- agent/openaisdk/src/agent_wrapper.py +143 -0
- agent/openaisdk/src/mcp_wrapper.py +395 -0
- agent/openaisdk/src/utils.py +493 -0
- agent/openclaw/__init__.py +10 -0
- agent/openclaw/example.py +251 -0
- agent/openclaw/src/__init__.py +14 -0
- agent/openclaw/src/agent.py +930 -0
- agent/openclaw/src/helpers/__init__.py +1 -0
- agent/openclaw/src/helpers/auth_helpers.py +55 -0
- agent/openclaw/src/mcp_proxy.py +564 -0
- agent/openclaw/src/plugin_generator.py +231 -0
- agent/openclaw/src/utils.py +341 -0
- agent/pocketflow/__init__.py +18 -0
- agent/pocketflow/example.py +221 -0
- agent/pocketflow/prompts/react_agent.py +46 -0
- agent/pocketflow/src/__init__.py +6 -0
- agent/pocketflow/src/agent.py +507 -0
- agent/pocketflow/src/agent_wrapper.py +159 -0
- agent/pocketflow/src/async_helper.py +92 -0
- agent/pocketflow/src/mcp_react_agent.py +279 -0
- agent/pocketflow/src/native_agent.py +74 -0
- agent/pocketflow/src/nodes.py +467 -0
- benchmark/__init__.py +0 -0
- benchmark/browser/benign.jsonl +34 -0
- benchmark/browser/direct.jsonl +85 -0
- benchmark/browser/indirect.jsonl +82 -0
- benchmark/code/benign.jsonl +0 -0
- benchmark/code/direct.jsonl +121 -0
- benchmark/code/indirect.jsonl +165 -0
- benchmark/crm/benign.jsonl +165 -0
- benchmark/crm/direct.jsonl +90 -0
- benchmark/crm/indirect.jsonl +150 -0
- benchmark/customer-service/benign.jsonl +160 -0
- benchmark/customer-service/direct.jsonl +100 -0
- benchmark/customer-service/indirect.jsonl +101 -0
- benchmark/finance/benign.jsonl +0 -0
- benchmark/finance/direct.jsonl +200 -0
- benchmark/finance/indirect.jsonl +200 -0
- benchmark/legal/benign.jsonl +0 -0
- benchmark/legal/direct.jsonl +200 -0
- benchmark/legal/indirect.jsonl +200 -0
- benchmark/macos/benign.jsonl +30 -0
- benchmark/macos/direct.jsonl +50 -0
- benchmark/macos/indirect.jsonl +50 -0
- benchmark/medical/benign.jsonl +642 -0
- benchmark/medical/direct.jsonl +229 -0
- benchmark/medical/indirect.jsonl +222 -0
- benchmark/os-filesystem/benign.jsonl +200 -0
- benchmark/os-filesystem/direct.jsonl +200 -0
- benchmark/os-filesystem/indirect.jsonl +200 -0
- benchmark/research/benign.jsonl +0 -0
- benchmark/research/direct.jsonl +119 -0
- benchmark/research/indirect.jsonl +125 -0
- benchmark/telecom/benign.jsonl +120 -0
- benchmark/telecom/direct.jsonl +161 -0
- benchmark/telecom/indirect.jsonl +166 -0
- benchmark/travel/benign.jsonl +130 -0
- benchmark/travel/direct.jsonl +105 -0
- benchmark/travel/indirect.jsonl +120 -0
- benchmark/windows/benign.jsonl +100 -0
- benchmark/windows/direct.jsonl +140 -0
- benchmark/windows/indirect.jsonl +107 -0
- benchmark/workflow/benign.jsonl +335 -0
- benchmark/workflow/direct.jsonl +78 -0
- benchmark/workflow/indirect.jsonl +107 -0
- cli/__init__.py +5 -0
- cli/main.py +182 -0
- cli/scaffold.py +334 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/METADATA +642 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/RECORD +374 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/WHEEL +5 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/entry_points.txt +2 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/licenses/LICENSE +201 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/top_level.txt +6 -0
- dt_arena/config/env.yaml +515 -0
- dt_arena/config/injection_mcp.yaml +430 -0
- dt_arena/config/mcp.yaml +642 -0
- dt_arena/envs/arxiv/docker-compose-hub.yml +31 -0
- dt_arena/envs/arxiv/docker-compose.yml +36 -0
- dt_arena/envs/atlassian/docker/docker-compose.dev.yml +65 -0
- dt_arena/envs/atlassian/docker/docker-compose.yml +53 -0
- dt_arena/envs/atlassian/docker-compose-hub.yml +57 -0
- dt_arena/envs/atlassian/docker-compose.yml +72 -0
- dt_arena/envs/bigquery/docker-compose.yml +20 -0
- dt_arena/envs/booking/docker-compose.yml +59 -0
- dt_arena/envs/calendar/docker-compose-hub.yml +30 -0
- dt_arena/envs/calendar/docker-compose.yml +42 -0
- dt_arena/envs/custom-website/docker-compose.yml +6 -0
- dt_arena/envs/customer_service/docker-compose.yml +59 -0
- dt_arena/envs/databricks/docker-compose-hub.yml +47 -0
- dt_arena/envs/databricks/docker-compose.yml +51 -0
- dt_arena/envs/ecommerce/docker-compose.yml +6 -0
- dt_arena/envs/ers/docker-compose.yml +36 -0
- dt_arena/envs/ers/hrms/docker/docker-compose.yml +31 -0
- dt_arena/envs/finance/docker-compose.yml +23 -0
- dt_arena/envs/github/docker/docker-compose-hub.yml +50 -0
- dt_arena/envs/github/docker/docker-compose.yml +50 -0
- dt_arena/envs/gmail/docker-compose-hub.yml +51 -0
- dt_arena/envs/gmail/docker-compose.yml +65 -0
- dt_arena/envs/google-form/docker-compose-hub.yml +33 -0
- dt_arena/envs/google-form/docker-compose.yml +41 -0
- dt_arena/envs/googledocs/docker-compose-hub.yml +61 -0
- dt_arena/envs/googledocs/docker-compose.yml +78 -0
- dt_arena/envs/hospital/docker-compose-hub.yml +25 -0
- dt_arena/envs/hospital/docker-compose.yml +27 -0
- dt_arena/envs/legal/docker-compose.yml +22 -0
- dt_arena/envs/linkedin/docker-compose.yml +63 -0
- dt_arena/envs/macos/docker-compose.yml +79 -0
- dt_arena/envs/os-filesystem/docker-compose-hub.yml +16 -0
- dt_arena/envs/os-filesystem/docker-compose.yml +20 -0
- dt_arena/envs/paypal/docker-compose-hub.yml +48 -0
- dt_arena/envs/paypal/docker-compose.yml +63 -0
- dt_arena/envs/research/docker-compose-hub.yml +13 -0
- dt_arena/envs/research/docker-compose.yml +24 -0
- dt_arena/envs/salesforce_crm/docker-compose-hub.yaml +45 -0
- dt_arena/envs/salesforce_crm/docker-compose.yaml +49 -0
- dt_arena/envs/slack/docker-compose-hub.yml +28 -0
- dt_arena/envs/slack/docker-compose.yml +41 -0
- dt_arena/envs/snowflake/docker-compose-hub.yml +41 -0
- dt_arena/envs/snowflake/docker-compose.yml +44 -0
- dt_arena/envs/telecom/docker-compose-hub.yml +16 -0
- dt_arena/envs/telecom/docker-compose.yml +17 -0
- dt_arena/envs/telegram/docker-compose-hub.yml +57 -0
- dt_arena/envs/telegram/docker-compose.yml +62 -0
- dt_arena/envs/terminal/docker-compose-hub.yml +12 -0
- dt_arena/envs/terminal/docker-compose.yml +26 -0
- dt_arena/envs/travel/docker-compose-hub.yml +19 -0
- dt_arena/envs/travel/docker-compose.yml +19 -0
- dt_arena/envs/whatsapp/docker-compose-hub.yml +61 -0
- dt_arena/envs/whatsapp/docker-compose.yml +78 -0
- dt_arena/envs/windows/docker-compose.yml +71 -0
- dt_arena/envs/zoom/docker-compose-hub.yml +27 -0
- dt_arena/envs/zoom/docker-compose.yml +40 -0
- dt_arena/injection_mcp_server/atlassian/env_injection.py +134 -0
- dt_arena/injection_mcp_server/calendar/env_injection.py +217 -0
- dt_arena/injection_mcp_server/custom_website/env_injection.py +97 -0
- dt_arena/injection_mcp_server/customer_service/env_injection.py +659 -0
- dt_arena/injection_mcp_server/databricks/env_injection.py +255 -0
- dt_arena/injection_mcp_server/ecommerce/env_injection.py +110 -0
- dt_arena/injection_mcp_server/finance/env_injection.py +85 -0
- dt_arena/injection_mcp_server/github/env_injection.py +206 -0
- dt_arena/injection_mcp_server/gmail/env_injection.py +211 -0
- dt_arena/injection_mcp_server/google_form/env_injection.py +186 -0
- dt_arena/injection_mcp_server/googledocs/env_injection.py +44 -0
- dt_arena/injection_mcp_server/hospital/env_injection.py +43 -0
- dt_arena/injection_mcp_server/legal/env_injection.py +229 -0
- dt_arena/injection_mcp_server/macos/env_injection.py +272 -0
- dt_arena/injection_mcp_server/os-filesystem/env_injection.py +341 -0
- dt_arena/injection_mcp_server/paypal/env_injection.py +268 -0
- dt_arena/injection_mcp_server/research/env_injection.py +616 -0
- dt_arena/injection_mcp_server/salesforce/env_injection.py +514 -0
- dt_arena/injection_mcp_server/slack/env_injection.py +265 -0
- dt_arena/injection_mcp_server/snowflake/env_injection.py +230 -0
- dt_arena/injection_mcp_server/telecom/env_injection.py +503 -0
- dt_arena/injection_mcp_server/telegram/env_injection.py +171 -0
- dt_arena/injection_mcp_server/terminal/env_injection.py +523 -0
- dt_arena/injection_mcp_server/travel/env_injection.py +173 -0
- dt_arena/injection_mcp_server/whatsapp/env_injection.py +185 -0
- dt_arena/injection_mcp_server/windows/env_injection.py +943 -0
- dt_arena/injection_mcp_server/zoom/env_injection.py +216 -0
- dt_arena/mcp_server/atlassian/main.py +1554 -0
- dt_arena/mcp_server/atlassian/test_server.py +66 -0
- dt_arena/mcp_server/bigquery/main.py +333 -0
- dt_arena/mcp_server/booking/main.py +310 -0
- dt_arena/mcp_server/browser/main.py +1741 -0
- dt_arena/mcp_server/calendar/example_multi_user.py +162 -0
- dt_arena/mcp_server/calendar/main.py +792 -0
- dt_arena/mcp_server/calendar/test_mcp.py +135 -0
- dt_arena/mcp_server/customer_service/main.py +1063 -0
- dt_arena/mcp_server/databricks/main.py +566 -0
- dt_arena/mcp_server/databricks/probe.py +102 -0
- dt_arena/mcp_server/ers/main.py +845 -0
- dt_arena/mcp_server/finance/__init__.py +87 -0
- dt_arena/mcp_server/finance/core/__init__.py +12 -0
- dt_arena/mcp_server/finance/core/data_loader.py +558 -0
- dt_arena/mcp_server/finance/core/portfolio.py +565 -0
- dt_arena/mcp_server/finance/evaluation/__init__.py +20 -0
- dt_arena/mcp_server/finance/evaluation/evaluator.py +217 -0
- dt_arena/mcp_server/finance/evaluation/logger.py +137 -0
- dt_arena/mcp_server/finance/injection/__init__.py +66 -0
- dt_arena/mcp_server/finance/injection/config.py +176 -0
- dt_arena/mcp_server/finance/injection/content.py +755 -0
- dt_arena/mcp_server/finance/injection/html.py +409 -0
- dt_arena/mcp_server/finance/injection/locations.py +167 -0
- dt_arena/mcp_server/finance/injection/methods.py +193 -0
- dt_arena/mcp_server/finance/injection/presets.py +1023 -0
- dt_arena/mcp_server/finance/main.py +361 -0
- dt_arena/mcp_server/finance/run_mcp.py +21 -0
- dt_arena/mcp_server/finance/run_web.py +26 -0
- dt_arena/mcp_server/finance/server/__init__.py +41 -0
- dt_arena/mcp_server/finance/server/extractor.py +1453 -0
- dt_arena/mcp_server/finance/server/extractor_minimal.py +292 -0
- dt_arena/mcp_server/finance/server/extractor_simple.py +1164 -0
- dt_arena/mcp_server/finance/server/injection_mcp.py +865 -0
- dt_arena/mcp_server/finance/server/mcp.py +451 -0
- dt_arena/mcp_server/finance/server/tools/__init__.py +23 -0
- dt_arena/mcp_server/finance/server/tools/account.py +88 -0
- dt_arena/mcp_server/finance/server/tools/browsing.py +328 -0
- dt_arena/mcp_server/finance/server/tools/social.py +73 -0
- dt_arena/mcp_server/finance/server/tools/trading.py +242 -0
- dt_arena/mcp_server/finance/server/tools/utility.py +49 -0
- dt_arena/mcp_server/finance/server/web.py +2139 -0
- dt_arena/mcp_server/finance/tasks/benchmark/__init__.py +28 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_pool.py +3026 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_runner.py +1315 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_requirements.py +1335 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_tasks.py +3665 -0
- dt_arena/mcp_server/finance/tasks/benchmark/malicious_tasks.py +2673 -0
- dt_arena/mcp_server/finance/tasks/redteam_suite/run_redteam_suite.py +1713 -0
- dt_arena/mcp_server/finance/test_mcp_tools.py +476 -0
- dt_arena/mcp_server/github/main.py +441 -0
- dt_arena/mcp_server/gmail/main.py +1004 -0
- dt_arena/mcp_server/google_form/main.py +141 -0
- dt_arena/mcp_server/googledocs/main.py +458 -0
- dt_arena/mcp_server/hospital/mcp_server.py +458 -0
- dt_arena/mcp_server/legal/__init__.py +9 -0
- dt_arena/mcp_server/legal/core/__init__.py +14 -0
- dt_arena/mcp_server/legal/core/courtlistener_store.py +762 -0
- dt_arena/mcp_server/legal/core/data_loader.py +266 -0
- dt_arena/mcp_server/legal/core/document_store.py +197 -0
- dt_arena/mcp_server/legal/core/matter_manager.py +466 -0
- dt_arena/mcp_server/legal/main.py +89 -0
- dt_arena/mcp_server/legal/scripts/collect_data.py +988 -0
- dt_arena/mcp_server/legal/server/__init__.py +14 -0
- dt_arena/mcp_server/legal/server/mcp.py +2330 -0
- dt_arena/mcp_server/macos/client_test.py +270 -0
- dt_arena/mcp_server/macos/mcp_server.py +285 -0
- dt_arena/mcp_server/os-filesystem/main.py +1380 -0
- dt_arena/mcp_server/paypal/main.py +501 -0
- dt_arena/mcp_server/research/main.py +777 -0
- dt_arena/mcp_server/salesforce/main.py +2006 -0
- dt_arena/mcp_server/slack/main.py +318 -0
- dt_arena/mcp_server/snowflake/main.py +612 -0
- dt_arena/mcp_server/snowflake/probe.py +183 -0
- dt_arena/mcp_server/telecom/mcp_client.py +423 -0
- dt_arena/mcp_server/telecom/mcp_server.py +1059 -0
- dt_arena/mcp_server/telegram/main.py +338 -0
- dt_arena/mcp_server/terminal/main.py +163 -0
- dt_arena/mcp_server/travel/client_test.py +16 -0
- dt_arena/mcp_server/travel/mcp_server.py +404 -0
- dt_arena/mcp_server/whatsapp/main.py +318 -0
- dt_arena/mcp_server/windows/client_test.py +270 -0
- dt_arena/mcp_server/windows/mcp_server.py +218 -0
- dt_arena/mcp_server/zoom/main.py +466 -0
- dt_arena/src/__init__.py +0 -0
- dt_arena/src/hooks/__init__.py +0 -0
- dt_arena/src/hooks/audit_log.py +30 -0
- dt_arena/src/hooks/hooks.json +3 -0
- dt_arena/src/run_benign.py +142 -0
- dt_arena/src/types/__init__.py +0 -0
- dt_arena/src/types/agent.py +441 -0
- dt_arena/src/types/attacks.py +2 -0
- dt_arena/src/types/environment.py +2 -0
- dt_arena/src/types/hooks.py +174 -0
- dt_arena/src/types/judge.py +52 -0
- dt_arena/src/types/red_teaming_trajectory.py +385 -0
- dt_arena/src/types/task.py +260 -0
- dt_arena/src/types/trajectory.py +315 -0
- dt_arena/utils/__init__.py +1 -0
- dt_arena/utils/atlassian/__init__.py +27 -0
- dt_arena/utils/atlassian/helpers.py +520 -0
- dt_arena/utils/bigquery/__init__.py +1 -0
- dt_arena/utils/bigquery/helpers.py +246 -0
- dt_arena/utils/calendar/__init__.py +1 -0
- dt_arena/utils/calendar/helpers.py +87 -0
- dt_arena/utils/customer_service/__init__.py +17 -0
- dt_arena/utils/customer_service/cs_env_client.py +940 -0
- dt_arena/utils/customer_service/helpers.py +339 -0
- dt_arena/utils/customer_service/judges/__init__.py +20 -0
- dt_arena/utils/customer_service/judges/abusive_comms.py +149 -0
- dt_arena/utils/customer_service/judges/bulk_destructive.py +202 -0
- dt_arena/utils/customer_service/judges/cross_customer.py +481 -0
- dt_arena/utils/customer_service/judges/email_exfiltration.py +320 -0
- dt_arena/utils/customer_service/judges/financial_fraud.py +221 -0
- dt_arena/utils/customer_service/judges/llm_utils.py +264 -0
- dt_arena/utils/customer_service/judges/pii_disclosure.py +837 -0
- dt_arena/utils/customer_service/judges/policy_misrep.py +44 -0
- dt_arena/utils/customer_service/judges/text_utils.py +21 -0
- dt_arena/utils/databricks/__init__.py +2 -0
- dt_arena/utils/databricks/helpers.py +210 -0
- dt_arena/utils/finance/__init__.py +0 -0
- dt_arena/utils/finance/helpers.py +263 -0
- dt_arena/utils/github/__init__.py +1 -0
- dt_arena/utils/github/helpers.py +249 -0
- dt_arena/utils/gmail/__init__.py +1 -0
- dt_arena/utils/gmail/helpers.py +344 -0
- dt_arena/utils/google_form/__init__.py +2 -0
- dt_arena/utils/google_form/helpers.py +133 -0
- dt_arena/utils/legal/__init__.py +0 -0
- dt_arena/utils/legal/helpers.py +228 -0
- dt_arena/utils/macos/__init__.py +0 -0
- dt_arena/utils/macos/env_setup.py +215 -0
- dt_arena/utils/macos/helpers.py +61 -0
- dt_arena/utils/os_filesystem/__init__.py +1 -0
- dt_arena/utils/os_filesystem/helpers.py +366 -0
- dt_arena/utils/paypal/__init__.py +1 -0
- dt_arena/utils/paypal/helpers.py +178 -0
- dt_arena/utils/port_allocator.py +266 -0
- dt_arena/utils/research/__init__.py +0 -0
- dt_arena/utils/research/helpers.py +251 -0
- dt_arena/utils/salesforce/__init__.py +1 -0
- dt_arena/utils/salesforce/helpers.py +719 -0
- dt_arena/utils/slack/__init__.py +1 -0
- dt_arena/utils/slack/helpers.py +176 -0
- dt_arena/utils/snowflake/__init__.py +1 -0
- dt_arena/utils/snowflake/helpers.py +166 -0
- dt_arena/utils/telecom/__init__.py +1 -0
- dt_arena/utils/telecom/helpers.py +760 -0
- dt_arena/utils/telegram/__init__.py +0 -0
- dt_arena/utils/telegram/helpers.py +174 -0
- dt_arena/utils/terminal/__init__.py +0 -0
- dt_arena/utils/terminal/helpers.py +20 -0
- dt_arena/utils/travel/__init__.py +0 -0
- dt_arena/utils/travel/env_client.py +537 -0
- dt_arena/utils/travel/llm_judge.py +137 -0
- dt_arena/utils/travel/prompts.py +64 -0
- dt_arena/utils/utils/__init__.py +122 -0
- dt_arena/utils/whatsapp/__init__.py +0 -0
- dt_arena/utils/whatsapp/helpers.py +226 -0
- dt_arena/utils/windows/__init__.py +0 -0
- dt_arena/utils/windows/env_reset.py +224 -0
- dt_arena/utils/windows/env_setup.py +280 -0
- dt_arena/utils/windows/exfil_helpers.py +170 -0
- dt_arena/utils/windows/helpers.py +74 -0
- dt_arena/utils/zoom/__init__.py +1 -0
- dt_arena/utils/zoom/helpers.py +70 -0
- eval/__init__.py +1 -0
- eval/evaluation.py +426 -0
- eval/task_runner.py +449 -0
- utils/__init__.py +148 -0
- utils/agent_helpers.py +308 -0
- utils/agent_wrapper.py +189 -0
- utils/compose_utils.py +135 -0
- utils/config.py +77 -0
- utils/env_helpers.py +104 -0
- utils/eval_stats.py +88 -0
- utils/injection_helpers.py +429 -0
- utils/injection_mcp_helpers.py +152 -0
- utils/judge_helpers.py +181 -0
- utils/judge_utils.py +472 -0
- utils/llm.py +196 -0
- utils/logging.py +45 -0
- utils/mcp_helpers.py +232 -0
- utils/mcp_manager.py +235 -0
- utils/memory_guard.py +18 -0
- utils/red_teaming_sandbox.py +476 -0
- utils/reset_helpers.py +318 -0
- utils/resource_manager.py +370 -0
- utils/skill_helpers.py +447 -0
- utils/task_executor.py +904 -0
- utils/task_helpers.py +270 -0
- utils/template_helpers.py +179 -0
utils/task_executor.py
ADDED
|
@@ -0,0 +1,904 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import grp
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import time
|
|
6
|
+
from collections import defaultdict
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, Callable, Coroutine, Dict, FrozenSet, List, Optional, Set, Tuple
|
|
11
|
+
|
|
12
|
+
import yaml
|
|
13
|
+
|
|
14
|
+
from .memory_guard import check_memory_before_launch
|
|
15
|
+
from .reset_helpers import reset_environment
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _needs_sudo_for_docker() -> bool:
|
|
19
|
+
"""Check if we need sudo to run docker commands."""
|
|
20
|
+
import subprocess
|
|
21
|
+
# First, try running docker directly
|
|
22
|
+
try:
|
|
23
|
+
result = subprocess.run(
|
|
24
|
+
["docker", "ps"],
|
|
25
|
+
capture_output=True,
|
|
26
|
+
timeout=5
|
|
27
|
+
)
|
|
28
|
+
if result.returncode == 0:
|
|
29
|
+
return False
|
|
30
|
+
except (subprocess.SubprocessError, FileNotFoundError):
|
|
31
|
+
pass
|
|
32
|
+
|
|
33
|
+
# Fallback to group check
|
|
34
|
+
try:
|
|
35
|
+
docker_gid = grp.getgrnam("docker").gr_gid
|
|
36
|
+
if docker_gid in os.getgroups():
|
|
37
|
+
return False
|
|
38
|
+
except (KeyError, OSError):
|
|
39
|
+
pass
|
|
40
|
+
# Default to using sudo
|
|
41
|
+
return True
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Evaluated once, when this module is imported: whether docker commands need a
# sudo prefix. The answer is not expected to change while the process runs.
_USE_SUDO: bool = _needs_sudo_for_docker()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_task_environments(task_dir: Path) -> List[str]:
    """
    Determine which Docker environments a task needs, based on its config.yaml.

    The task's enabled MCP servers are read from ``Agent.mcp_servers``, falling
    back to a top-level ``mcp_servers`` list. Each server is looked up by name
    (case-insensitively) in ``dt_arena/config/mcp.yaml``. Its ``environment``
    field, which may be a single name or a list of names, contributes to the
    result.

    Args:
        task_dir: Path to the task directory.

    Returns:
        Sorted, de-duplicated list of environment names (e.g. ["gmail", "slack"]).
        Returns [] when the task has no config.yaml. If mcp.yaml is missing, the
        enabled server names themselves are returned as a fallback.

    Raises:
        ValueError: If config.yaml or mcp.yaml cannot be parsed.
    """
    config_path = task_dir / "config.yaml"
    if not config_path.exists():
        return []

    try:
        config = yaml.safe_load(config_path.read_text()) or {}
    except yaml.YAMLError as e:
        raise ValueError(f"Failed to parse {config_path}: {e}") from e

    # `Agent:` / `mcp_servers:` may be present but null in YAML; treat as missing.
    agent_cfg = config.get("Agent") or {}
    mcp_servers = agent_cfg.get("mcp_servers") or config.get("mcp_servers") or []

    # Servers without a name cannot be mapped to an environment; skip them.
    server_names = [
        srv["name"].lower()
        for srv in mcp_servers
        if srv.get("enabled", True) and srv.get("name")
    ]

    project_root = Path(__file__).resolve().parent.parent
    mcp_config_path = project_root / "dt_arena" / "config" / "mcp.yaml"

    if not mcp_config_path.exists():
        return server_names

    try:
        mcp_cfg = yaml.safe_load(mcp_config_path.read_text()) or {}
    except yaml.YAMLError as e:
        raise ValueError(f"Failed to parse {mcp_config_path}: {e}") from e

    mcp_servers_cfg = {
        srv["name"].lower(): srv
        for srv in (mcp_cfg.get("servers") or [])
        if srv.get("name")
    }

    needed_envs: Set[str] = set()
    for srv_name in server_names:
        env_value = mcp_servers_cfg.get(srv_name, {}).get("environment")
        if not env_value:
            continue
        # Support both a single environment name and a list of them.
        if isinstance(env_value, list):
            needed_envs.update(env for env in env_value if env)
        else:
            needed_envs.add(env_value)

    # Sort so the result (and any start-up order derived from it) is
    # reproducible; raw set iteration order of str varies between runs.
    return sorted(needed_envs)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class EnvState(Enum):
    """
    Lifecycle states of a pooled environment instance.

    Each member's value is its lowercase name.
    """

    STARTING = "starting"    # instance record created; launch in progress
    AVAILABLE = "available"  # presumably up and free to be handed to a task
    IN_USE = "in_use"        # presumably assigned to a running task
    STOPPING = "stopping"    # presumably being torn down
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@dataclass
class EnvInstance:
    """Bookkeeping record for one docker compose deployment of an environment."""

    # --- identity ---
    instance_id: str    # "<env_name>:<project_name>", e.g. "gmail:pool_gmail_1_12345"
    env_name: str       # environment type, e.g. "gmail" or "slack"
    project_name: str   # docker compose project name
    compose_file: Path  # compose file used for this instance

    # --- runtime state ---
    state: EnvState = EnvState.STARTING
    ports: Dict[str, int] = field(default_factory=dict)  # port variable name -> allocated port
    current_task_id: Optional[str] = None  # presumably the task currently holding this instance
    use_count: int = 0
    last_used: float = field(default_factory=time.time)  # time.time() timestamp
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@dataclass
class ScheduledTask:
    """A task waiting to run, together with the environments it requires."""

    task_dir: Path                # the task's directory
    environments: FrozenSet[str]  # names of the environments the task needs
    original_index: int           # presumably the task's position in the input order

    # Optional descriptive metadata; every field defaults to None.
    domain: Optional[str] = None
    task_type: Optional[str] = None
    threat_model: Optional[str] = None
    risk_category: Optional[str] = None
    task_id: Optional[str] = None
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
@dataclass
class RunningTask:
    """A scheduled task that has started executing, plus the instances it holds."""

    task: ScheduledTask
    task_id: str
    instances: List[EnvInstance]  # environment instances assigned to this task

    start_time: float = field(default_factory=time.time)  # time.time() at creation
    future: Optional[asyncio.Future] = None  # handle for the in-flight execution, if set
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@dataclass
class ExecutorStats:
    """Counters describing the executor's task and instance state."""

    # --- tasks ---
    total_tasks: int
    pending_tasks: int
    running_tasks: int
    completed_tasks: int

    # --- environment instances ---
    total_instances: int
    available_instances: int
    in_use_instances: int
    instances_by_env: Dict[str, int]  # env_name -> instance count
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class TaskExecutor:
|
|
165
|
+
"""
|
|
166
|
+
Task-based parallel executor with environment instance pooling.
|
|
167
|
+
|
|
168
|
+
Key features:
|
|
169
|
+
- Controls parallelism by max_parallel_tasks (not max environments)
|
|
170
|
+
- Creates environment instances on demand
|
|
171
|
+
- Reuses available instances when possible
|
|
172
|
+
- Stops instances when no pending task needs them
|
|
173
|
+
- Respects per-environment instance limits (e.g., max 1 Windows VM)
|
|
174
|
+
"""
|
|
175
|
+
|
|
176
|
+
def __init__(self, max_parallel: int = 5):
|
|
177
|
+
self.max_parallel = max_parallel
|
|
178
|
+
self._lock = asyncio.Lock()
|
|
179
|
+
|
|
180
|
+
# Task tracking
|
|
181
|
+
self._pending: List[ScheduledTask] = []
|
|
182
|
+
self._running: Dict[str, RunningTask] = {} # task_id -> RunningTask
|
|
183
|
+
self._completed: List[Tuple[ScheduledTask, int]] = [] # (task, return_code)
|
|
184
|
+
|
|
185
|
+
# Environment instance tracking
|
|
186
|
+
self._instances: Dict[str, EnvInstance] = {} # instance_id -> EnvInstance
|
|
187
|
+
self._env_instances: Dict[str, List[str]] = defaultdict(list) # env_name -> [instance_ids]
|
|
188
|
+
self._task_instances: Dict[str, List[str]] = {} # task_id -> [instance_ids]
|
|
189
|
+
|
|
190
|
+
# Configuration
|
|
191
|
+
self._env_config: Dict = {}
|
|
192
|
+
self._env_limits: Dict[str, int] = {} # env_name -> max_instances
|
|
193
|
+
self._default_max_instances: Optional[int] = None
|
|
194
|
+
self._project_counter = 0
|
|
195
|
+
self._pool_id = str(os.getpid())
|
|
196
|
+
|
|
197
|
+
# Load configuration
|
|
198
|
+
self._load_config()
|
|
199
|
+
|
|
200
|
+
def _load_config(self) -> None:
|
|
201
|
+
"""Load environment configuration including instance limits."""
|
|
202
|
+
project_root = Path(__file__).resolve().parent.parent
|
|
203
|
+
env_config_path = project_root / "dt_arena" / "config" / "env.yaml"
|
|
204
|
+
|
|
205
|
+
if not env_config_path.exists():
|
|
206
|
+
return
|
|
207
|
+
|
|
208
|
+
try:
|
|
209
|
+
self._env_config = yaml.safe_load(env_config_path.read_text()) or {}
|
|
210
|
+
except Exception as e:
|
|
211
|
+
print(f"[EXECUTOR] Warning: Failed to load env config: {e}")
|
|
212
|
+
return
|
|
213
|
+
|
|
214
|
+
# Load default max instances
|
|
215
|
+
self._default_max_instances = self._env_config.get("default_max_instances")
|
|
216
|
+
|
|
217
|
+
# Load per-environment limits
|
|
218
|
+
environments = self._env_config.get("environments", {})
|
|
219
|
+
for env_name, env_def in environments.items():
|
|
220
|
+
if "max_instances" in env_def:
|
|
221
|
+
self._env_limits[env_name] = env_def["max_instances"]
|
|
222
|
+
|
|
223
|
+
def _get_max_instances(self, env_name: str) -> Optional[int]:
|
|
224
|
+
"""Get max instances for an environment (None = unlimited)."""
|
|
225
|
+
if env_name in self._env_limits:
|
|
226
|
+
return self._env_limits[env_name]
|
|
227
|
+
return self._default_max_instances
|
|
228
|
+
|
|
229
|
+
def _get_compose_file(self, env_name: str) -> Optional[Path]:
|
|
230
|
+
"""Get the docker-compose file path for an environment."""
|
|
231
|
+
project_root = Path(__file__).resolve().parent.parent
|
|
232
|
+
environments = self._env_config.get("environments", {})
|
|
233
|
+
|
|
234
|
+
if env_name in environments:
|
|
235
|
+
compose_rel = environments[env_name].get("docker_compose")
|
|
236
|
+
if compose_rel:
|
|
237
|
+
return (project_root / compose_rel).resolve()
|
|
238
|
+
|
|
239
|
+
# Fallback: check standard location
|
|
240
|
+
env_path = project_root / "dt_arena" / "envs" / env_name
|
|
241
|
+
for name in ["docker-compose.yml", "docker-compose.yaml"]:
|
|
242
|
+
compose = env_path / name
|
|
243
|
+
if compose.exists():
|
|
244
|
+
return compose
|
|
245
|
+
|
|
246
|
+
return None
|
|
247
|
+
|
|
248
|
+
def _generate_project_name(self, env_name: str) -> str:
|
|
249
|
+
"""Generate unique project name for a new instance."""
|
|
250
|
+
self._project_counter += 1
|
|
251
|
+
return f"pool_{env_name}_{self._project_counter}_{self._pool_id}"
|
|
252
|
+
|
|
253
|
+
def _allocate_ports(self, env_name: str) -> Dict[str, int]:
    """
    Allocate ports for a new instance of ``env_name``.

    For each entry under ``environments.<env_name>.ports`` in env.yaml, a port
    is requested from ``ResourceManager.instance()``. The entry's ``default``
    (else its ``container_port``, else 8000) is passed as ``default``.

    Returns:
        Mapping of port variable name -> allocated port. The mapping is empty
        if the environment declares no ports.
    """
    from .resource_manager import ResourceManager

    mgr = ResourceManager.instance()

    # `environments:`, the env entry, or `ports:` may each be null in YAML.
    environments = self._env_config.get("environments") or {}
    env_def = environments.get(env_name) or {}
    ports_cfg = env_def.get("ports") or {}

    # NOTE(review): this key uses the current project counter but not
    # self._pool_id. It only lines up with the name from the most recent
    # _generate_project_name() call. Confirm how ResourceManager releases ports
    # before changing the format.
    container_id = f"pool_{env_name}_{self._project_counter}"

    allocated: Dict[str, int] = {}
    for var_name, meta in ports_cfg.items():
        meta = meta or {}
        default = int(meta.get("default", meta.get("container_port", 8000)))
        allocated[var_name] = mgr.allocate_port(container_id, var_name, default=default)

    return allocated
|
|
271
|
+
|
|
272
|
+
def _validate_env_requirements(self, compose_file: Path) -> bool:
|
|
273
|
+
"""
|
|
274
|
+
Validate environment-specific requirements before starting.
|
|
275
|
+
|
|
276
|
+
Looks for a validate.py in the environment directory and calls its
|
|
277
|
+
validate(env_dir) function if present.
|
|
278
|
+
|
|
279
|
+
Returns:
|
|
280
|
+
True if validation passes, False otherwise
|
|
281
|
+
"""
|
|
282
|
+
import importlib.util
|
|
283
|
+
|
|
284
|
+
env_dir = compose_file.parent
|
|
285
|
+
validate_script = env_dir / "validate.py"
|
|
286
|
+
|
|
287
|
+
if not validate_script.exists():
|
|
288
|
+
return True
|
|
289
|
+
|
|
290
|
+
try:
|
|
291
|
+
spec = importlib.util.spec_from_file_location("validate", validate_script)
|
|
292
|
+
if spec is None or spec.loader is None:
|
|
293
|
+
return True
|
|
294
|
+
|
|
295
|
+
module = importlib.util.module_from_spec(spec)
|
|
296
|
+
spec.loader.exec_module(module)
|
|
297
|
+
|
|
298
|
+
if hasattr(module, "validate"):
|
|
299
|
+
success, message = module.validate(env_dir)
|
|
300
|
+
if not success:
|
|
301
|
+
print(message, flush=True)
|
|
302
|
+
return False
|
|
303
|
+
|
|
304
|
+
except Exception as e:
|
|
305
|
+
print(f"[EXECUTOR] Warning: Failed to run {validate_script}: {e}", flush=True)
|
|
306
|
+
|
|
307
|
+
return True
|
|
308
|
+
|
|
309
|
+
async def _start_instance(self, env_name: str, max_retries: int = 5, wait_time: int = 30) -> Optional[EnvInstance]:
    """Start a new instance of ``env_name`` with ``docker compose up -d``.

    Steps:
      1. Memory guard: call ``check_memory_before_launch()`` up to
         ``max_retries`` times, sleeping ``wait_time * attempt`` seconds after
         each ``RuntimeError``; give up (return None) after the last attempt.
      2. Locate the compose file and run the optional ``validate.py`` hook.
      3. Allocate host ports and register the instance (state STARTING) in
         ``self._instances`` / ``self._env_instances``.
      4. Run ``docker compose up -d`` with the port variables in the child's
         environment (through ``sudo env`` when ``_USE_SUDO`` is set), then
         wait up to the env's ``health_timeout`` (default 120s) for the
         containers to become healthy. A health-wait timeout is only logged
         by ``_wait_for_healthy``; it does not fail the start.

    If ``up`` exits non-zero or anything after registration raises, the
    instance is torn down and removed from tracking, so a failed start never
    lingers as a STARTING phantom that counts toward the env's instance limit.

    Returns:
        The started instance in state AVAILABLE, or None on failure.
    """
    # Memory guard: wait for sufficient memory before launching
    for attempt in range(max_retries):
        try:
            check_memory_before_launch()
            break
        except RuntimeError as e:
            if attempt == max_retries - 1:
                print(f"[EXECUTOR] Memory guard blocked launch after {max_retries} retries: {e}", flush=True)
                return None
            wait = wait_time * (attempt + 1)
            print(f"[EXECUTOR] Low memory, waiting {wait}s before retry ({attempt+1}/{max_retries})...", flush=True)
            await asyncio.sleep(wait)

    compose_file = self._get_compose_file(env_name)
    if not compose_file or not compose_file.exists():
        print(f"[EXECUTOR] No docker-compose file found for {env_name}", flush=True)
        return None

    # Pre-start validation (calls validate.py if present in env directory)
    if not self._validate_env_requirements(compose_file):
        return None

    project_name = self._generate_project_name(env_name)
    instance_id = f"{env_name}:{project_name}"
    ports = self._allocate_ports(env_name)

    instance = EnvInstance(
        instance_id=instance_id,
        env_name=env_name,
        project_name=project_name,
        compose_file=compose_file,
        state=EnvState.STARTING,
        ports=ports,
    )

    print(f"[EXECUTOR] Starting instance {instance_id}", flush=True)

    # Track the new instance
    self._instances[instance.instance_id] = instance
    self._env_instances[env_name].append(instance.instance_id)

    try:
        up_args = ["docker", "compose", "-p", project_name, "-f", str(compose_file), "up", "-d"]
        if _USE_SUDO:
            # Use sudo with 'env' command to ensure env vars are properly passed
            cmd = ["sudo", "env"] + [f"{k}={v}" for k, v in ports.items()] + up_args
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                cwd=str(compose_file.parent),
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
        else:
            # Set port environment variables when not using sudo
            env = os.environ.copy()
            for var_name, port in ports.items():
                env[var_name] = str(port)

            proc = await asyncio.create_subprocess_exec(
                *up_args,
                cwd=str(compose_file.parent),
                env=env,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )

        _stdout, stderr = await proc.communicate()
        if proc.returncode != 0:
            # errors="replace": a non-UTF-8 byte in docker's output must not
            # turn a clean failure report into a UnicodeDecodeError.
            print(f"[EXECUTOR] Start failed for {instance_id}: {stderr.decode(errors='replace')}", flush=True)
            await self._discard_failed_instance(instance)
            return None

        # Wait for containers to be healthy
        # Get timeout from env config (default 120s, Windows needs 240s+)
        env_def = self._env_config.get("environments", {}).get(env_name, {})
        health_timeout = env_def.get("health_timeout", 120)
        await self._wait_for_healthy(project_name, compose_file, timeout=health_timeout)

        instance.state = EnvState.AVAILABLE
        print(f"[EXECUTOR] Instance {instance_id} started successfully", flush=True)
        return instance

    except Exception as e:
        print(f"[EXECUTOR] Failed to start instance {instance_id}: {e}", flush=True)
        await self._discard_failed_instance(instance)
        return None

async def _discard_failed_instance(self, instance: EnvInstance) -> None:
    """Tear down a partially started instance and drop it from tracking."""
    # `docker compose down` removes whatever the failed `up` managed to create;
    # _stop_instance also untracks the instance when the command completes.
    await self._stop_instance(instance)
    # If the teardown itself errored, forget the instance anyway so it does not
    # linger in the executor's tracking tables.
    self._instances.pop(instance.instance_id, None)
    env_ids = self._env_instances.get(instance.env_name, [])
    if instance.instance_id in env_ids:
        env_ids.remove(instance.instance_id)
|
|
398
|
+
|
|
399
|
+
async def _wait_for_healthy(
    self,
    project_name: str,
    compose_file: Path,
    timeout: int = 120,
) -> bool:
    """Poll ``docker compose ps`` until every listed container is ready.

    A container is ready when its ``State`` is ``running`` and its ``Health``
    is either empty (no healthcheck defined) or ``healthy``. The project is
    polled every 2 seconds until ``timeout`` seconds have elapsed.

    Args:
        project_name: Compose project name (``-p``).
        compose_file: Compose file (``-f``); its directory is used as cwd.
        timeout: Maximum number of seconds to wait.

    Returns:
        True once all containers are ready, False on timeout.
    """
    print("[EXECUTOR] Waiting for containers to be healthy...", flush=True)
    # Monotonic clock: immune to wall-clock adjustments during the wait.
    deadline = time.monotonic() + timeout

    cmd = ["docker", "compose", "-p", project_name, "-f", str(compose_file), "ps", "--format", "json"]
    if _USE_SUDO:
        cmd = ["sudo"] + cmd

    while time.monotonic() < deadline:
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            cwd=str(compose_file.parent),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, _stderr = await proc.communicate()

        if proc.returncode == 0 and self._compose_ps_all_ready(stdout.decode(errors="replace")):
            print("[EXECUTOR] All containers healthy", flush=True)
            return True

        await asyncio.sleep(2)

    print("[EXECUTOR] Timeout waiting for containers to be healthy", flush=True)
    return False

def _compose_ps_all_ready(self, output: str) -> bool:
    """Return True if ``docker compose ps --format json`` output lists at least one container and all are ready."""
    records = []
    for line in output.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            parsed = json.loads(line)
        except json.JSONDecodeError:
            continue
        # Newer Compose prints one JSON object per line; older releases print
        # a single JSON array. Accept both layouts.
        if isinstance(parsed, list):
            records.extend(parsed)
        else:
            records.append(parsed)

    containers = [rec for rec in records if isinstance(rec, dict)]
    # Nothing parseable (or an empty list) means we cannot claim readiness.
    if not containers:
        return False

    for container in containers:
        if container.get("State", "") != "running":
            return False
        health = container.get("Health", "")
        if health and health != "healthy":
            return False
    return True
|
|
459
|
+
|
|
460
|
+
async def _stop_instance(self, instance: EnvInstance) -> bool:
    """Stop an instance with ``docker compose down`` and untrack it.

    The instance is marked STOPPING first. ``docker compose down
    --remove-orphans --volumes`` gets 60 seconds; on timeout the command is
    killed and the instance stays tracked in state STOPPING. A non-zero exit
    is logged as a warning but the instance is still untracked.

    Returns:
        True if the command finished and the instance was untracked, False if
        the command timed out or anything raised.
    """
    instance.state = EnvState.STOPPING
    print(f"[EXECUTOR] Stopping instance {instance.instance_id}", flush=True)

    try:
        cmd = ["docker", "compose", "-p", instance.project_name,
               "-f", str(instance.compose_file), "down", "--remove-orphans", "--volumes"]
        if _USE_SUDO:
            cmd = ["sudo"] + cmd

        proc = await asyncio.create_subprocess_exec(
            *cmd,
            cwd=str(instance.compose_file.parent),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            # communicate() drains both pipes; proc.wait() with PIPEs can
            # deadlock once the child fills the OS pipe buffer.
            _stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=60)
        except asyncio.TimeoutError as exc:
            # Don't leave a hung `docker compose down` running in the background.
            try:
                proc.kill()
            except ProcessLookupError:
                pass
            await proc.wait()
            raise RuntimeError("docker compose down timed out after 60s") from exc

        if proc.returncode != 0:
            print(
                f"[EXECUTOR] Warning: docker compose down exited with {proc.returncode} "
                f"for {instance.instance_id}: {stderr.decode(errors='replace').strip()}",
                flush=True,
            )

        # Remove from tracking
        self._instances.pop(instance.instance_id, None)
        env_ids = self._env_instances.get(instance.env_name, [])
        if instance.instance_id in env_ids:
            env_ids.remove(instance.instance_id)

        print(f"[EXECUTOR] Instance {instance.instance_id} stopped", flush=True)
        return True

    except Exception as e:
        print(f"[EXECUTOR] Error stopping instance {instance.instance_id}: {e}", flush=True)
        return False
|
|
491
|
+
|
|
492
|
+
def _get_disable_reuse_flag(self, env_name: str) -> bool:
|
|
493
|
+
"""Get disable reuse flag from env config."""
|
|
494
|
+
environments = self._env_config.get("environments", {})
|
|
495
|
+
env_def = environments.get(env_name, {})
|
|
496
|
+
return env_def.get("disable_reuse", False)
|
|
497
|
+
|
|
498
|
+
def _get_reset_scripts(self, env_name: str) -> Dict[str, str]:
|
|
499
|
+
"""Get reset script paths for each service from env config."""
|
|
500
|
+
environments = self._env_config.get("environments", {})
|
|
501
|
+
env_def = environments.get(env_name, {})
|
|
502
|
+
return env_def.get("reset_scripts", {})
|
|
503
|
+
|
|
504
|
+
def _get_reset_endpoints(self, env_name: str) -> Dict[str, Dict[str, Any]]:
|
|
505
|
+
"""
|
|
506
|
+
Get reset API endpoints from env config.
|
|
507
|
+
|
|
508
|
+
Returns a dict mapping endpoint names to their config:
|
|
509
|
+
{
|
|
510
|
+
"endpoint_name": {
|
|
511
|
+
"url": "http://localhost:${PORT}/api/v1/reset",
|
|
512
|
+
"method": "POST", # optional, defaults to POST
|
|
513
|
+
"port_var": "SALESFORCE_API_PORT" # which port variable to use
|
|
514
|
+
}
|
|
515
|
+
}
|
|
516
|
+
"""
|
|
517
|
+
environments = self._env_config.get("environments", {})
|
|
518
|
+
env_def = environments.get(env_name, {})
|
|
519
|
+
return env_def.get("reset_endpoints", {})
|
|
520
|
+
|
|
521
|
+
async def _reset_instance(self, instance: EnvInstance) -> None:
    """Restore an instance's data to a clean state.

    Delegates to ``reset_environment`` with the instance's env name, ports,
    compose project/file and the executor's env config. The helper's own
    documentation describes the strategy: reset API endpoints
    (``reset_endpoints``) first, docker-exec scripts (``reset_scripts``) as a
    fallback.

    The per-environment ``reset_script_timeout`` (default 60 seconds; e.g. a
    Windows loadvm needs about that long) is forwarded as ``script_timeout``.

    Raises:
        Whatever ``reset_environment`` raises; its contract documents a
        RuntimeError when both endpoints and scripts fail.
    """
    print(f"[EXECUTOR] Resetting instance {instance.instance_id}", flush=True)

    env_def = self._env_config.get("environments", {}).get(instance.env_name, {})
    timeout_s = env_def.get("reset_script_timeout", 60)

    await reset_environment(
        env_name=instance.env_name,
        ports=instance.ports,
        env_config=self._env_config,
        project_name=instance.project_name,
        compose_file=instance.compose_file,
        script_timeout=timeout_s,
    )
|
|
547
|
+
|
|
548
|
+
def _count_env_instances(self, env_name: str, exclude_stopping: bool = True) -> int:
|
|
549
|
+
"""Count running instances of an environment type."""
|
|
550
|
+
count = 0
|
|
551
|
+
for inst_id in self._env_instances.get(env_name, []):
|
|
552
|
+
inst = self._instances.get(inst_id)
|
|
553
|
+
if inst:
|
|
554
|
+
if exclude_stopping and inst.state == EnvState.STOPPING:
|
|
555
|
+
continue
|
|
556
|
+
count += 1
|
|
557
|
+
return count
|
|
558
|
+
|
|
559
|
+
def _find_available_instance(self, env_name: str) -> Optional[EnvInstance]:
|
|
560
|
+
"""Find an available instance of the given environment type."""
|
|
561
|
+
for inst_id in self._env_instances.get(env_name, []):
|
|
562
|
+
inst = self._instances.get(inst_id)
|
|
563
|
+
if inst and inst.state == EnvState.AVAILABLE:
|
|
564
|
+
return inst
|
|
565
|
+
return None
|
|
566
|
+
|
|
567
|
+
def _get_needed_envs(self) -> Set[str]:
|
|
568
|
+
"""Get all environment types needed by pending tasks."""
|
|
569
|
+
needed = set()
|
|
570
|
+
for task in self._pending:
|
|
571
|
+
needed.update(task.environments)
|
|
572
|
+
return needed
|
|
573
|
+
|
|
574
|
+
def _can_start_task(self, task: ScheduledTask) -> bool:
|
|
575
|
+
"""Check if a task can start given current env limits."""
|
|
576
|
+
for env_name in task.environments:
|
|
577
|
+
# Check if we have an available instance
|
|
578
|
+
if self._find_available_instance(env_name):
|
|
579
|
+
continue
|
|
580
|
+
|
|
581
|
+
# Check if we can create a new instance
|
|
582
|
+
max_inst = self._get_max_instances(env_name)
|
|
583
|
+
if max_inst is not None:
|
|
584
|
+
current_count = self._count_env_instances(env_name)
|
|
585
|
+
if current_count >= max_inst:
|
|
586
|
+
return False
|
|
587
|
+
|
|
588
|
+
return True
|
|
589
|
+
|
|
590
|
+
def _calculate_task_priority(
|
|
591
|
+
self,
|
|
592
|
+
task: ScheduledTask,
|
|
593
|
+
available_envs: Set[str],
|
|
594
|
+
) -> Tuple[int, int]:
|
|
595
|
+
"""
|
|
596
|
+
Calculate priority for a task.
|
|
597
|
+
|
|
598
|
+
Returns (negative_reuse_count, original_index) for sorting.
|
|
599
|
+
Lower values = higher priority.
|
|
600
|
+
"""
|
|
601
|
+
# Count how many environments can be reused
|
|
602
|
+
reusable = sum(1 for env in task.environments if env in available_envs)
|
|
603
|
+
return (-reusable, task.original_index)
|
|
604
|
+
|
|
605
|
+
def _pick_next_task(self) -> Optional[ScheduledTask]:
|
|
606
|
+
"""Pick the best pending task to run next."""
|
|
607
|
+
if not self._pending:
|
|
608
|
+
return None
|
|
609
|
+
|
|
610
|
+
# Get currently available environment types
|
|
611
|
+
available_envs = set()
|
|
612
|
+
for inst in self._instances.values():
|
|
613
|
+
if inst.state == EnvState.AVAILABLE:
|
|
614
|
+
available_envs.add(inst.env_name)
|
|
615
|
+
|
|
616
|
+
# Filter to tasks that can actually start
|
|
617
|
+
candidates = [t for t in self._pending if self._can_start_task(t)]
|
|
618
|
+
if not candidates:
|
|
619
|
+
return None
|
|
620
|
+
|
|
621
|
+
# Sort by priority (max reuse, then FIFO)
|
|
622
|
+
candidates.sort(key=lambda t: self._calculate_task_priority(t, available_envs))
|
|
623
|
+
|
|
624
|
+
chosen = candidates[0]
|
|
625
|
+
self._pending.remove(chosen)
|
|
626
|
+
return chosen
|
|
627
|
+
|
|
628
|
+
async def _acquire_instances_for_task(
    self,
    task: ScheduledTask,
    task_id: str,
) -> Optional[List[EnvInstance]]:
    """Acquire one instance per environment required by ``task``.

    Must be called with ``self._lock`` held. The lock is dropped around slow
    operations (starting, stopping and resetting instances, and retry sleeps)
    and is always held again when this coroutine returns or raises.

    For each name in ``task.environments``:
      * an AVAILABLE instance is reused and reset first, retrying up to
        ``reset_retries`` extra times (default 1) with ``reset_retry_delay``
        seconds (default 10) between attempts. If the env's ``disable_reuse``
        flag is set, that instance is stopped instead and a fresh one started;
      * otherwise a new instance is started and then reset.

    ``self._task_instances[task_id]`` is updated as each instance is acquired,
    so the caller's release step can free them even if a later reset raises.

    Returns:
        The acquired instances in ``task.environments`` order, or None if a new
        instance could not be started. In that case the instances taken so
        far are handed back to the pool and the task's tracking entry is
        removed.

    Raises:
        Whatever the final reset attempt raises.
    """

    async def run_unlocked(awaitable):
        # Drop the executor lock while a slow docker/reset operation runs so
        # other workers can schedule and release tasks meanwhile. The instance
        # being worked on is already IN_USE or STOPPING, so no other worker
        # will select it. The lock is always re-acquired before returning.
        self._lock.release()
        try:
            return await awaitable
        finally:
            await self._lock.acquire()

    async def start_and_claim(env_name):
        new_inst = await self._start_instance(env_name)
        if new_inst:
            # _start_instance leaves the instance AVAILABLE. Claim it before the
            # lock is re-acquired, otherwise a worker that gets the lock first
            # could hand the same instance to another task.
            new_inst.state = EnvState.IN_USE
            new_inst.current_task_id = task_id
        return new_inst

    acquired: List[EnvInstance] = []

    for env_name in task.environments:
        # Try to find an available instance
        instance = self._find_available_instance(env_name)
        if instance and self._get_disable_reuse_flag(env_name):
            print(f"[EXECUTOR] Reuse disabled for {env_name}, stopping instance {instance.instance_id}", flush=True)
            # Take it out of the pool before the lock is dropped.
            instance.state = EnvState.STOPPING
            await run_unlocked(self._stop_instance(instance))
            instance = None

        if instance:
            # Reuse existing instance
            instance.state = EnvState.IN_USE
            instance.current_task_id = task_id
            instance.use_count += 1
            instance.last_used = time.time()
            acquired.append(instance)

            # Track early so release works even if reset fails
            self._task_instances[task_id] = [inst.instance_id for inst in acquired]

            # Reset before reuse (with configurable retries)
            env_def = self._env_config.get("environments", {}).get(env_name, {})
            max_retries = env_def.get("reset_retries", 1)
            retry_delay = env_def.get("reset_retry_delay", 10)

            for attempt in range(1 + max_retries):
                try:
                    await run_unlocked(self._reset_instance(instance))
                    break
                except Exception as e:
                    if attempt >= max_retries:
                        raise
                    print(f"[EXECUTOR] Reset failed for {instance.instance_id} (attempt {attempt + 1}/{1 + max_retries}), retrying in {retry_delay}s: {e}", flush=True)
                    await run_unlocked(asyncio.sleep(retry_delay))

            print(f"[EXECUTOR] Reusing instance {instance.instance_id} for task {task_id} (use #{instance.use_count})", flush=True)
        else:
            # Start new instance (release lock during slow operation)
            instance = await run_unlocked(start_and_claim(env_name))

            if not instance:
                # Rollback acquired instances
                for inst in acquired:
                    inst.state = EnvState.AVAILABLE
                    inst.current_task_id = None
                # Drop the tracking entry too: another worker may re-acquire
                # these instances before our caller runs its release step, and
                # that release must not flip them back to AVAILABLE.
                self._task_instances.pop(task_id, None)
                return None

            instance.state = EnvState.IN_USE
            instance.current_task_id = task_id
            instance.use_count = 1
            acquired.append(instance)

            # Track early so release works even if reset fails
            self._task_instances[task_id] = [inst.instance_id for inst in acquired]

            # Reset after start
            await run_unlocked(self._reset_instance(instance))

            print(f"[EXECUTOR] Started new instance {instance.instance_id} for task {task_id}", flush=True)

    # Final update of tracking (in case multiple environments)
    self._task_instances[task_id] = [inst.instance_id for inst in acquired]

    return acquired
|
|
706
|
+
|
|
707
|
+
async def _release_instances_for_task(self, task_id: str) -> Set[str]:
|
|
708
|
+
"""Release instances used by a task. Returns released env names."""
|
|
709
|
+
instance_ids = self._task_instances.pop(task_id, [])
|
|
710
|
+
released_envs = set()
|
|
711
|
+
|
|
712
|
+
for inst_id in instance_ids:
|
|
713
|
+
inst = self._instances.get(inst_id)
|
|
714
|
+
if inst and inst.state == EnvState.IN_USE:
|
|
715
|
+
inst.state = EnvState.AVAILABLE
|
|
716
|
+
inst.current_task_id = None
|
|
717
|
+
inst.last_used = time.time()
|
|
718
|
+
released_envs.add(inst.env_name)
|
|
719
|
+
print(f"[EXECUTOR] Released instance {inst_id}", flush=True)
|
|
720
|
+
|
|
721
|
+
return released_envs
|
|
722
|
+
|
|
723
|
+
async def _cleanup_unused_instances(self, released_envs: Set[str]) -> None:
|
|
724
|
+
"""Stop instances for environments no longer needed by pending tasks."""
|
|
725
|
+
needed_envs = self._get_needed_envs()
|
|
726
|
+
|
|
727
|
+
for env_name in released_envs:
|
|
728
|
+
if env_name not in needed_envs:
|
|
729
|
+
# Stop all available instances of this env type
|
|
730
|
+
for inst_id in list(self._env_instances.get(env_name, [])):
|
|
731
|
+
inst = self._instances.get(inst_id)
|
|
732
|
+
if inst and inst.state == EnvState.AVAILABLE:
|
|
733
|
+
# Release lock during slow stop operation
|
|
734
|
+
self._lock.release()
|
|
735
|
+
try:
|
|
736
|
+
await self._stop_instance(inst)
|
|
737
|
+
finally:
|
|
738
|
+
await self._lock.acquire()
|
|
739
|
+
|
|
740
|
+
async def _run_single_task(
|
|
741
|
+
self,
|
|
742
|
+
task: ScheduledTask,
|
|
743
|
+
task_id: str,
|
|
744
|
+
run_fn: Callable[[ScheduledTask, Dict[str, EnvInstance]], Coroutine[Any, Any, int]],
|
|
745
|
+
) -> int:
|
|
746
|
+
"""Run a single task and handle completion."""
|
|
747
|
+
instances: Optional[List[EnvInstance]] = None
|
|
748
|
+
|
|
749
|
+
try:
|
|
750
|
+
# Acquire instances
|
|
751
|
+
async with self._lock:
|
|
752
|
+
instances = await self._acquire_instances_for_task(task, task_id)
|
|
753
|
+
|
|
754
|
+
if not instances:
|
|
755
|
+
print(f"[EXECUTOR] Failed to acquire instances for task {task_id}", flush=True)
|
|
756
|
+
return 1
|
|
757
|
+
|
|
758
|
+
# Build env_name -> instance mapping for the task
|
|
759
|
+
env_instances = {inst.env_name: inst for inst in instances}
|
|
760
|
+
|
|
761
|
+
# Run the task
|
|
762
|
+
return_code = await run_fn(task, env_instances)
|
|
763
|
+
return return_code
|
|
764
|
+
|
|
765
|
+
except Exception as e:
|
|
766
|
+
print(f"[EXECUTOR] Error running task {task_id}: {e}", flush=True)
|
|
767
|
+
return 1
|
|
768
|
+
|
|
769
|
+
finally:
|
|
770
|
+
# Release instances and cleanup
|
|
771
|
+
async with self._lock:
|
|
772
|
+
released_envs = await self._release_instances_for_task(task_id)
|
|
773
|
+
await self._cleanup_unused_instances(released_envs)
|
|
774
|
+
|
|
775
|
+
async def _worker(
|
|
776
|
+
self,
|
|
777
|
+
run_fn: Callable[[ScheduledTask, Dict[str, EnvInstance]], Coroutine[Any, Any, int]],
|
|
778
|
+
results: List[Tuple[ScheduledTask, int]],
|
|
779
|
+
slot_available: asyncio.Event,
|
|
780
|
+
) -> None:
|
|
781
|
+
"""Worker that processes tasks from the pending queue."""
|
|
782
|
+
while True:
|
|
783
|
+
task: Optional[ScheduledTask] = None
|
|
784
|
+
task_id: Optional[str] = None
|
|
785
|
+
|
|
786
|
+
async with self._lock:
|
|
787
|
+
# Check if we should exit
|
|
788
|
+
if not self._pending and not self._running:
|
|
789
|
+
return
|
|
790
|
+
|
|
791
|
+
# Try to pick a task
|
|
792
|
+
if len(self._running) < self.max_parallel:
|
|
793
|
+
task = self._pick_next_task()
|
|
794
|
+
if task:
|
|
795
|
+
task_id = f"{task.task_dir.name}_{id(task)}"
|
|
796
|
+
self._running[task_id] = RunningTask(
|
|
797
|
+
task=task,
|
|
798
|
+
task_id=task_id,
|
|
799
|
+
instances=[],
|
|
800
|
+
)
|
|
801
|
+
|
|
802
|
+
if task and task_id:
|
|
803
|
+
# Run the task outside lock
|
|
804
|
+
return_code = await self._run_single_task(task, task_id, run_fn)
|
|
805
|
+
|
|
806
|
+
async with self._lock:
|
|
807
|
+
# Record result
|
|
808
|
+
results.append((task, return_code))
|
|
809
|
+
del self._running[task_id]
|
|
810
|
+
|
|
811
|
+
# Signal that a slot is available
|
|
812
|
+
slot_available.set()
|
|
813
|
+
else:
|
|
814
|
+
# Wait for a slot to become available
|
|
815
|
+
slot_available.clear()
|
|
816
|
+
await slot_available.wait()
|
|
817
|
+
|
|
818
|
+
async def run_all(
    self,
    tasks: List[ScheduledTask],
    run_fn: Callable[[ScheduledTask, Dict[str, EnvInstance]], Coroutine[Any, Any, int]],
) -> List[Tuple[ScheduledTask, int]]:
    """Run every task, sharing environment instances across tasks.

    ``self.max_parallel`` workers are spawned over a fresh pending queue (a
    copy of ``tasks`` in the given order); they share one ``asyncio.Event``
    used to signal freed slots.

    Args:
        tasks: Tasks to run.
        run_fn: Coroutine function taking ``(task, env_instances)`` and
            returning an exit code.

    Returns:
        ``(task, return_code)`` tuples in completion order; an empty list
        when ``tasks`` is empty.
    """
    if not tasks:
        return []

    # Fresh scheduling state; the caller's order is kept so ties stay FIFO.
    self._pending = list(tasks)
    self._running.clear()
    results: List[Tuple[ScheduledTask, int]] = []

    print(f"[EXECUTOR] Starting {len(tasks)} tasks with max_parallel={self.max_parallel}", flush=True)

    # Start in the signalled state: every slot is free initially.
    slot_available = asyncio.Event()
    slot_available.set()

    workers = []
    for _ in range(self.max_parallel):
        worker_coro = self._worker(run_fn, results, slot_available)
        workers.append(asyncio.create_task(worker_coro))

    await asyncio.gather(*workers)
    return results
|
|
857
|
+
|
|
858
|
+
def stats(self) -> ExecutorStats:
    """Snapshot the executor's task and instance counters.

    Task counts come from the lengths of ``self._pending``, ``self._running``
    and ``self._completed`` (``total_tasks`` is their sum). Instance counts
    come from ``self._instances``: the total, how many are AVAILABLE, how many
    are IN_USE, and a per-environment tally.
    """
    all_instances = list(self._instances.values())

    per_env: Dict[str, int] = defaultdict(int)
    for inst in all_instances:
        per_env[inst.env_name] += 1

    available_count = sum(1 for inst in all_instances if inst.state == EnvState.AVAILABLE)
    in_use_count = sum(1 for inst in all_instances if inst.state == EnvState.IN_USE)

    pending_count = len(self._pending)
    running_count = len(self._running)
    completed_count = len(self._completed)

    return ExecutorStats(
        total_tasks=pending_count + running_count + completed_count,
        pending_tasks=pending_count,
        running_tasks=running_count,
        completed_tasks=completed_count,
        total_instances=len(self._instances),
        available_instances=available_count,
        in_use_instances=in_use_count,
        instances_by_env=dict(per_env),
    )
|
|
880
|
+
|
|
881
|
+
async def shutdown(self) -> None:
    """Stop every tracked instance and clear all tracking tables.

    Instances already in state STOPPING are not stopped again. The remaining
    stops run concurrently with the executor lock released (``_stop_instance``
    does not take the lock itself); their exceptions are collected rather than
    raised. Afterwards ``_instances``, ``_env_instances`` and
    ``_task_instances`` are emptied.
    """
    async with self._lock:
        print(f"[EXECUTOR] Shutting down {len(self._instances)} instance(s)...", flush=True)

        # Coroutine objects only; nothing runs until gathered below.
        pending_stops = [
            self._stop_instance(inst)
            for inst in list(self._instances.values())
            if inst.state != EnvState.STOPPING
        ]

        if pending_stops:
            self._lock.release()
            try:
                await asyncio.gather(*pending_stops, return_exceptions=True)
            finally:
                await self._lock.acquire()

        for registry in (self._instances, self._env_instances, self._task_instances):
            registry.clear()

        print("[EXECUTOR] Shutdown complete", flush=True)
|