decodingtrust-agent-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +30 -0
- agent/claudesdk/__init__.py +8 -0
- agent/claudesdk/example.py +221 -0
- agent/claudesdk/src/__init__.py +8 -0
- agent/claudesdk/src/agent.py +400 -0
- agent/claudesdk/src/mcp_proxy.py +409 -0
- agent/claudesdk/src/utils.py +420 -0
- agent/googleadk/__init__.py +15 -0
- agent/googleadk/example.py +237 -0
- agent/googleadk/src/__init__.py +12 -0
- agent/googleadk/src/agent.py +401 -0
- agent/googleadk/src/mcp_wrapper.py +163 -0
- agent/googleadk/src/utils.py +602 -0
- agent/langchain/__init__.py +8 -0
- agent/langchain/example.py +213 -0
- agent/langchain/src/__init__.py +8 -0
- agent/langchain/src/agent.py +645 -0
- agent/langchain/src/utils.py +433 -0
- agent/openaisdk/__init__.py +17 -0
- agent/openaisdk/example.py +228 -0
- agent/openaisdk/src/__init__.py +12 -0
- agent/openaisdk/src/agent.py +491 -0
- agent/openaisdk/src/agent_wrapper.py +143 -0
- agent/openaisdk/src/mcp_wrapper.py +395 -0
- agent/openaisdk/src/utils.py +493 -0
- agent/openclaw/__init__.py +10 -0
- agent/openclaw/example.py +251 -0
- agent/openclaw/src/__init__.py +14 -0
- agent/openclaw/src/agent.py +930 -0
- agent/openclaw/src/helpers/__init__.py +1 -0
- agent/openclaw/src/helpers/auth_helpers.py +55 -0
- agent/openclaw/src/mcp_proxy.py +564 -0
- agent/openclaw/src/plugin_generator.py +231 -0
- agent/openclaw/src/utils.py +341 -0
- agent/pocketflow/__init__.py +18 -0
- agent/pocketflow/example.py +221 -0
- agent/pocketflow/prompts/react_agent.py +46 -0
- agent/pocketflow/src/__init__.py +6 -0
- agent/pocketflow/src/agent.py +507 -0
- agent/pocketflow/src/agent_wrapper.py +159 -0
- agent/pocketflow/src/async_helper.py +92 -0
- agent/pocketflow/src/mcp_react_agent.py +279 -0
- agent/pocketflow/src/native_agent.py +74 -0
- agent/pocketflow/src/nodes.py +467 -0
- benchmark/__init__.py +0 -0
- benchmark/browser/benign.jsonl +34 -0
- benchmark/browser/direct.jsonl +85 -0
- benchmark/browser/indirect.jsonl +82 -0
- benchmark/code/benign.jsonl +0 -0
- benchmark/code/direct.jsonl +121 -0
- benchmark/code/indirect.jsonl +165 -0
- benchmark/crm/benign.jsonl +165 -0
- benchmark/crm/direct.jsonl +90 -0
- benchmark/crm/indirect.jsonl +150 -0
- benchmark/customer-service/benign.jsonl +160 -0
- benchmark/customer-service/direct.jsonl +100 -0
- benchmark/customer-service/indirect.jsonl +101 -0
- benchmark/finance/benign.jsonl +0 -0
- benchmark/finance/direct.jsonl +200 -0
- benchmark/finance/indirect.jsonl +200 -0
- benchmark/legal/benign.jsonl +0 -0
- benchmark/legal/direct.jsonl +200 -0
- benchmark/legal/indirect.jsonl +200 -0
- benchmark/macos/benign.jsonl +30 -0
- benchmark/macos/direct.jsonl +50 -0
- benchmark/macos/indirect.jsonl +50 -0
- benchmark/medical/benign.jsonl +642 -0
- benchmark/medical/direct.jsonl +229 -0
- benchmark/medical/indirect.jsonl +222 -0
- benchmark/os-filesystem/benign.jsonl +200 -0
- benchmark/os-filesystem/direct.jsonl +200 -0
- benchmark/os-filesystem/indirect.jsonl +200 -0
- benchmark/research/benign.jsonl +0 -0
- benchmark/research/direct.jsonl +119 -0
- benchmark/research/indirect.jsonl +125 -0
- benchmark/telecom/benign.jsonl +120 -0
- benchmark/telecom/direct.jsonl +161 -0
- benchmark/telecom/indirect.jsonl +166 -0
- benchmark/travel/benign.jsonl +130 -0
- benchmark/travel/direct.jsonl +105 -0
- benchmark/travel/indirect.jsonl +120 -0
- benchmark/windows/benign.jsonl +100 -0
- benchmark/windows/direct.jsonl +140 -0
- benchmark/windows/indirect.jsonl +107 -0
- benchmark/workflow/benign.jsonl +335 -0
- benchmark/workflow/direct.jsonl +78 -0
- benchmark/workflow/indirect.jsonl +107 -0
- cli/__init__.py +5 -0
- cli/main.py +182 -0
- cli/scaffold.py +334 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/METADATA +642 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/RECORD +374 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/WHEEL +5 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/entry_points.txt +2 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/licenses/LICENSE +201 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/top_level.txt +6 -0
- dt_arena/config/env.yaml +515 -0
- dt_arena/config/injection_mcp.yaml +430 -0
- dt_arena/config/mcp.yaml +642 -0
- dt_arena/envs/arxiv/docker-compose-hub.yml +31 -0
- dt_arena/envs/arxiv/docker-compose.yml +36 -0
- dt_arena/envs/atlassian/docker/docker-compose.dev.yml +65 -0
- dt_arena/envs/atlassian/docker/docker-compose.yml +53 -0
- dt_arena/envs/atlassian/docker-compose-hub.yml +57 -0
- dt_arena/envs/atlassian/docker-compose.yml +72 -0
- dt_arena/envs/bigquery/docker-compose.yml +20 -0
- dt_arena/envs/booking/docker-compose.yml +59 -0
- dt_arena/envs/calendar/docker-compose-hub.yml +30 -0
- dt_arena/envs/calendar/docker-compose.yml +42 -0
- dt_arena/envs/custom-website/docker-compose.yml +6 -0
- dt_arena/envs/customer_service/docker-compose.yml +59 -0
- dt_arena/envs/databricks/docker-compose-hub.yml +47 -0
- dt_arena/envs/databricks/docker-compose.yml +51 -0
- dt_arena/envs/ecommerce/docker-compose.yml +6 -0
- dt_arena/envs/ers/docker-compose.yml +36 -0
- dt_arena/envs/ers/hrms/docker/docker-compose.yml +31 -0
- dt_arena/envs/finance/docker-compose.yml +23 -0
- dt_arena/envs/github/docker/docker-compose-hub.yml +50 -0
- dt_arena/envs/github/docker/docker-compose.yml +50 -0
- dt_arena/envs/gmail/docker-compose-hub.yml +51 -0
- dt_arena/envs/gmail/docker-compose.yml +65 -0
- dt_arena/envs/google-form/docker-compose-hub.yml +33 -0
- dt_arena/envs/google-form/docker-compose.yml +41 -0
- dt_arena/envs/googledocs/docker-compose-hub.yml +61 -0
- dt_arena/envs/googledocs/docker-compose.yml +78 -0
- dt_arena/envs/hospital/docker-compose-hub.yml +25 -0
- dt_arena/envs/hospital/docker-compose.yml +27 -0
- dt_arena/envs/legal/docker-compose.yml +22 -0
- dt_arena/envs/linkedin/docker-compose.yml +63 -0
- dt_arena/envs/macos/docker-compose.yml +79 -0
- dt_arena/envs/os-filesystem/docker-compose-hub.yml +16 -0
- dt_arena/envs/os-filesystem/docker-compose.yml +20 -0
- dt_arena/envs/paypal/docker-compose-hub.yml +48 -0
- dt_arena/envs/paypal/docker-compose.yml +63 -0
- dt_arena/envs/research/docker-compose-hub.yml +13 -0
- dt_arena/envs/research/docker-compose.yml +24 -0
- dt_arena/envs/salesforce_crm/docker-compose-hub.yaml +45 -0
- dt_arena/envs/salesforce_crm/docker-compose.yaml +49 -0
- dt_arena/envs/slack/docker-compose-hub.yml +28 -0
- dt_arena/envs/slack/docker-compose.yml +41 -0
- dt_arena/envs/snowflake/docker-compose-hub.yml +41 -0
- dt_arena/envs/snowflake/docker-compose.yml +44 -0
- dt_arena/envs/telecom/docker-compose-hub.yml +16 -0
- dt_arena/envs/telecom/docker-compose.yml +17 -0
- dt_arena/envs/telegram/docker-compose-hub.yml +57 -0
- dt_arena/envs/telegram/docker-compose.yml +62 -0
- dt_arena/envs/terminal/docker-compose-hub.yml +12 -0
- dt_arena/envs/terminal/docker-compose.yml +26 -0
- dt_arena/envs/travel/docker-compose-hub.yml +19 -0
- dt_arena/envs/travel/docker-compose.yml +19 -0
- dt_arena/envs/whatsapp/docker-compose-hub.yml +61 -0
- dt_arena/envs/whatsapp/docker-compose.yml +78 -0
- dt_arena/envs/windows/docker-compose.yml +71 -0
- dt_arena/envs/zoom/docker-compose-hub.yml +27 -0
- dt_arena/envs/zoom/docker-compose.yml +40 -0
- dt_arena/injection_mcp_server/atlassian/env_injection.py +134 -0
- dt_arena/injection_mcp_server/calendar/env_injection.py +217 -0
- dt_arena/injection_mcp_server/custom_website/env_injection.py +97 -0
- dt_arena/injection_mcp_server/customer_service/env_injection.py +659 -0
- dt_arena/injection_mcp_server/databricks/env_injection.py +255 -0
- dt_arena/injection_mcp_server/ecommerce/env_injection.py +110 -0
- dt_arena/injection_mcp_server/finance/env_injection.py +85 -0
- dt_arena/injection_mcp_server/github/env_injection.py +206 -0
- dt_arena/injection_mcp_server/gmail/env_injection.py +211 -0
- dt_arena/injection_mcp_server/google_form/env_injection.py +186 -0
- dt_arena/injection_mcp_server/googledocs/env_injection.py +44 -0
- dt_arena/injection_mcp_server/hospital/env_injection.py +43 -0
- dt_arena/injection_mcp_server/legal/env_injection.py +229 -0
- dt_arena/injection_mcp_server/macos/env_injection.py +272 -0
- dt_arena/injection_mcp_server/os-filesystem/env_injection.py +341 -0
- dt_arena/injection_mcp_server/paypal/env_injection.py +268 -0
- dt_arena/injection_mcp_server/research/env_injection.py +616 -0
- dt_arena/injection_mcp_server/salesforce/env_injection.py +514 -0
- dt_arena/injection_mcp_server/slack/env_injection.py +265 -0
- dt_arena/injection_mcp_server/snowflake/env_injection.py +230 -0
- dt_arena/injection_mcp_server/telecom/env_injection.py +503 -0
- dt_arena/injection_mcp_server/telegram/env_injection.py +171 -0
- dt_arena/injection_mcp_server/terminal/env_injection.py +523 -0
- dt_arena/injection_mcp_server/travel/env_injection.py +173 -0
- dt_arena/injection_mcp_server/whatsapp/env_injection.py +185 -0
- dt_arena/injection_mcp_server/windows/env_injection.py +943 -0
- dt_arena/injection_mcp_server/zoom/env_injection.py +216 -0
- dt_arena/mcp_server/atlassian/main.py +1554 -0
- dt_arena/mcp_server/atlassian/test_server.py +66 -0
- dt_arena/mcp_server/bigquery/main.py +333 -0
- dt_arena/mcp_server/booking/main.py +310 -0
- dt_arena/mcp_server/browser/main.py +1741 -0
- dt_arena/mcp_server/calendar/example_multi_user.py +162 -0
- dt_arena/mcp_server/calendar/main.py +792 -0
- dt_arena/mcp_server/calendar/test_mcp.py +135 -0
- dt_arena/mcp_server/customer_service/main.py +1063 -0
- dt_arena/mcp_server/databricks/main.py +566 -0
- dt_arena/mcp_server/databricks/probe.py +102 -0
- dt_arena/mcp_server/ers/main.py +845 -0
- dt_arena/mcp_server/finance/__init__.py +87 -0
- dt_arena/mcp_server/finance/core/__init__.py +12 -0
- dt_arena/mcp_server/finance/core/data_loader.py +558 -0
- dt_arena/mcp_server/finance/core/portfolio.py +565 -0
- dt_arena/mcp_server/finance/evaluation/__init__.py +20 -0
- dt_arena/mcp_server/finance/evaluation/evaluator.py +217 -0
- dt_arena/mcp_server/finance/evaluation/logger.py +137 -0
- dt_arena/mcp_server/finance/injection/__init__.py +66 -0
- dt_arena/mcp_server/finance/injection/config.py +176 -0
- dt_arena/mcp_server/finance/injection/content.py +755 -0
- dt_arena/mcp_server/finance/injection/html.py +409 -0
- dt_arena/mcp_server/finance/injection/locations.py +167 -0
- dt_arena/mcp_server/finance/injection/methods.py +193 -0
- dt_arena/mcp_server/finance/injection/presets.py +1023 -0
- dt_arena/mcp_server/finance/main.py +361 -0
- dt_arena/mcp_server/finance/run_mcp.py +21 -0
- dt_arena/mcp_server/finance/run_web.py +26 -0
- dt_arena/mcp_server/finance/server/__init__.py +41 -0
- dt_arena/mcp_server/finance/server/extractor.py +1453 -0
- dt_arena/mcp_server/finance/server/extractor_minimal.py +292 -0
- dt_arena/mcp_server/finance/server/extractor_simple.py +1164 -0
- dt_arena/mcp_server/finance/server/injection_mcp.py +865 -0
- dt_arena/mcp_server/finance/server/mcp.py +451 -0
- dt_arena/mcp_server/finance/server/tools/__init__.py +23 -0
- dt_arena/mcp_server/finance/server/tools/account.py +88 -0
- dt_arena/mcp_server/finance/server/tools/browsing.py +328 -0
- dt_arena/mcp_server/finance/server/tools/social.py +73 -0
- dt_arena/mcp_server/finance/server/tools/trading.py +242 -0
- dt_arena/mcp_server/finance/server/tools/utility.py +49 -0
- dt_arena/mcp_server/finance/server/web.py +2139 -0
- dt_arena/mcp_server/finance/tasks/benchmark/__init__.py +28 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_pool.py +3026 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_runner.py +1315 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_requirements.py +1335 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_tasks.py +3665 -0
- dt_arena/mcp_server/finance/tasks/benchmark/malicious_tasks.py +2673 -0
- dt_arena/mcp_server/finance/tasks/redteam_suite/run_redteam_suite.py +1713 -0
- dt_arena/mcp_server/finance/test_mcp_tools.py +476 -0
- dt_arena/mcp_server/github/main.py +441 -0
- dt_arena/mcp_server/gmail/main.py +1004 -0
- dt_arena/mcp_server/google_form/main.py +141 -0
- dt_arena/mcp_server/googledocs/main.py +458 -0
- dt_arena/mcp_server/hospital/mcp_server.py +458 -0
- dt_arena/mcp_server/legal/__init__.py +9 -0
- dt_arena/mcp_server/legal/core/__init__.py +14 -0
- dt_arena/mcp_server/legal/core/courtlistener_store.py +762 -0
- dt_arena/mcp_server/legal/core/data_loader.py +266 -0
- dt_arena/mcp_server/legal/core/document_store.py +197 -0
- dt_arena/mcp_server/legal/core/matter_manager.py +466 -0
- dt_arena/mcp_server/legal/main.py +89 -0
- dt_arena/mcp_server/legal/scripts/collect_data.py +988 -0
- dt_arena/mcp_server/legal/server/__init__.py +14 -0
- dt_arena/mcp_server/legal/server/mcp.py +2330 -0
- dt_arena/mcp_server/macos/client_test.py +270 -0
- dt_arena/mcp_server/macos/mcp_server.py +285 -0
- dt_arena/mcp_server/os-filesystem/main.py +1380 -0
- dt_arena/mcp_server/paypal/main.py +501 -0
- dt_arena/mcp_server/research/main.py +777 -0
- dt_arena/mcp_server/salesforce/main.py +2006 -0
- dt_arena/mcp_server/slack/main.py +318 -0
- dt_arena/mcp_server/snowflake/main.py +612 -0
- dt_arena/mcp_server/snowflake/probe.py +183 -0
- dt_arena/mcp_server/telecom/mcp_client.py +423 -0
- dt_arena/mcp_server/telecom/mcp_server.py +1059 -0
- dt_arena/mcp_server/telegram/main.py +338 -0
- dt_arena/mcp_server/terminal/main.py +163 -0
- dt_arena/mcp_server/travel/client_test.py +16 -0
- dt_arena/mcp_server/travel/mcp_server.py +404 -0
- dt_arena/mcp_server/whatsapp/main.py +318 -0
- dt_arena/mcp_server/windows/client_test.py +270 -0
- dt_arena/mcp_server/windows/mcp_server.py +218 -0
- dt_arena/mcp_server/zoom/main.py +466 -0
- dt_arena/src/__init__.py +0 -0
- dt_arena/src/hooks/__init__.py +0 -0
- dt_arena/src/hooks/audit_log.py +30 -0
- dt_arena/src/hooks/hooks.json +3 -0
- dt_arena/src/run_benign.py +142 -0
- dt_arena/src/types/__init__.py +0 -0
- dt_arena/src/types/agent.py +441 -0
- dt_arena/src/types/attacks.py +2 -0
- dt_arena/src/types/environment.py +2 -0
- dt_arena/src/types/hooks.py +174 -0
- dt_arena/src/types/judge.py +52 -0
- dt_arena/src/types/red_teaming_trajectory.py +385 -0
- dt_arena/src/types/task.py +260 -0
- dt_arena/src/types/trajectory.py +315 -0
- dt_arena/utils/__init__.py +1 -0
- dt_arena/utils/atlassian/__init__.py +27 -0
- dt_arena/utils/atlassian/helpers.py +520 -0
- dt_arena/utils/bigquery/__init__.py +1 -0
- dt_arena/utils/bigquery/helpers.py +246 -0
- dt_arena/utils/calendar/__init__.py +1 -0
- dt_arena/utils/calendar/helpers.py +87 -0
- dt_arena/utils/customer_service/__init__.py +17 -0
- dt_arena/utils/customer_service/cs_env_client.py +940 -0
- dt_arena/utils/customer_service/helpers.py +339 -0
- dt_arena/utils/customer_service/judges/__init__.py +20 -0
- dt_arena/utils/customer_service/judges/abusive_comms.py +149 -0
- dt_arena/utils/customer_service/judges/bulk_destructive.py +202 -0
- dt_arena/utils/customer_service/judges/cross_customer.py +481 -0
- dt_arena/utils/customer_service/judges/email_exfiltration.py +320 -0
- dt_arena/utils/customer_service/judges/financial_fraud.py +221 -0
- dt_arena/utils/customer_service/judges/llm_utils.py +264 -0
- dt_arena/utils/customer_service/judges/pii_disclosure.py +837 -0
- dt_arena/utils/customer_service/judges/policy_misrep.py +44 -0
- dt_arena/utils/customer_service/judges/text_utils.py +21 -0
- dt_arena/utils/databricks/__init__.py +2 -0
- dt_arena/utils/databricks/helpers.py +210 -0
- dt_arena/utils/finance/__init__.py +0 -0
- dt_arena/utils/finance/helpers.py +263 -0
- dt_arena/utils/github/__init__.py +1 -0
- dt_arena/utils/github/helpers.py +249 -0
- dt_arena/utils/gmail/__init__.py +1 -0
- dt_arena/utils/gmail/helpers.py +344 -0
- dt_arena/utils/google_form/__init__.py +2 -0
- dt_arena/utils/google_form/helpers.py +133 -0
- dt_arena/utils/legal/__init__.py +0 -0
- dt_arena/utils/legal/helpers.py +228 -0
- dt_arena/utils/macos/__init__.py +0 -0
- dt_arena/utils/macos/env_setup.py +215 -0
- dt_arena/utils/macos/helpers.py +61 -0
- dt_arena/utils/os_filesystem/__init__.py +1 -0
- dt_arena/utils/os_filesystem/helpers.py +366 -0
- dt_arena/utils/paypal/__init__.py +1 -0
- dt_arena/utils/paypal/helpers.py +178 -0
- dt_arena/utils/port_allocator.py +266 -0
- dt_arena/utils/research/__init__.py +0 -0
- dt_arena/utils/research/helpers.py +251 -0
- dt_arena/utils/salesforce/__init__.py +1 -0
- dt_arena/utils/salesforce/helpers.py +719 -0
- dt_arena/utils/slack/__init__.py +1 -0
- dt_arena/utils/slack/helpers.py +176 -0
- dt_arena/utils/snowflake/__init__.py +1 -0
- dt_arena/utils/snowflake/helpers.py +166 -0
- dt_arena/utils/telecom/__init__.py +1 -0
- dt_arena/utils/telecom/helpers.py +760 -0
- dt_arena/utils/telegram/__init__.py +0 -0
- dt_arena/utils/telegram/helpers.py +174 -0
- dt_arena/utils/terminal/__init__.py +0 -0
- dt_arena/utils/terminal/helpers.py +20 -0
- dt_arena/utils/travel/__init__.py +0 -0
- dt_arena/utils/travel/env_client.py +537 -0
- dt_arena/utils/travel/llm_judge.py +137 -0
- dt_arena/utils/travel/prompts.py +64 -0
- dt_arena/utils/utils/__init__.py +122 -0
- dt_arena/utils/whatsapp/__init__.py +0 -0
- dt_arena/utils/whatsapp/helpers.py +226 -0
- dt_arena/utils/windows/__init__.py +0 -0
- dt_arena/utils/windows/env_reset.py +224 -0
- dt_arena/utils/windows/env_setup.py +280 -0
- dt_arena/utils/windows/exfil_helpers.py +170 -0
- dt_arena/utils/windows/helpers.py +74 -0
- dt_arena/utils/zoom/__init__.py +1 -0
- dt_arena/utils/zoom/helpers.py +70 -0
- eval/__init__.py +1 -0
- eval/evaluation.py +426 -0
- eval/task_runner.py +449 -0
- utils/__init__.py +148 -0
- utils/agent_helpers.py +308 -0
- utils/agent_wrapper.py +189 -0
- utils/compose_utils.py +135 -0
- utils/config.py +77 -0
- utils/env_helpers.py +104 -0
- utils/eval_stats.py +88 -0
- utils/injection_helpers.py +429 -0
- utils/injection_mcp_helpers.py +152 -0
- utils/judge_helpers.py +181 -0
- utils/judge_utils.py +472 -0
- utils/llm.py +196 -0
- utils/logging.py +45 -0
- utils/mcp_helpers.py +232 -0
- utils/mcp_manager.py +235 -0
- utils/memory_guard.py +18 -0
- utils/red_teaming_sandbox.py +476 -0
- utils/reset_helpers.py +318 -0
- utils/resource_manager.py +370 -0
- utils/skill_helpers.py +447 -0
- utils/task_executor.py +904 -0
- utils/task_helpers.py +270 -0
- utils/template_helpers.py +179 -0
utils/env_helpers.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import subprocess
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from .config import PROJECT_ROOT
|
|
7
|
+
from .resource_manager import ResourceManager
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def task_setup(task_dir: Path, task_id: Optional[str] = None) -> None:
    """
    Run a task's setup.sh script to start Docker environments and seed data.

    When the script succeeds, a UTC completion timestamp is written to
    ``<task_dir>/metadata/setup_done_at.txt`` and mirrored into the
    ``CS_SETUP_COMPLETED_AT`` environment variable.

    Args:
        task_dir: Path to the task directory containing setup.sh
        task_id: Optional task ID. It is not read by this function; it is
            accepted so existing callers that pass it keep working.

    Raises:
        subprocess.CalledProcessError: If setup.sh exits with a non-zero status.
    """
    # Function-scope import: the module's top-level imports do not include datetime.
    from datetime import datetime, timezone

    setup_script = task_dir / "setup.sh"

    if not setup_script.exists():
        print(f"[SETUP] No setup.sh found in {task_dir}, skipping")
        return

    print(f"[SETUP] Running setup.sh for task: {task_dir}")

    # Guard only the script invocation: CalledProcessError can come from
    # nowhere else, so the bookkeeping below stays outside the try.
    try:
        subprocess.run(
            ["bash", str(setup_script)],
            cwd=str(task_dir),
            check=True,
            capture_output=False,  # Let output go to stdout/stderr
        )
    except subprocess.CalledProcessError as e:
        print(f"[SETUP] setup.sh failed with exit code {e.returncode}")
        raise

    # Record setup completion timestamp in PER-TASK file (race-safe under
    # parallel asyncio tasks within the same process). The CS benign judge
    # reads this from task_dir/metadata/setup_done_at.txt to filter
    # agent-authored cases vs frozen baseline.
    ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f+00:00")
    meta_dir = task_dir / "metadata"
    meta_dir.mkdir(parents=True, exist_ok=True)
    (meta_dir / "setup_done_at.txt").write_text(ts, encoding="utf-8")
    # Also set env var for backward compatibility, but judge prefers file.
    os.environ["CS_SETUP_COMPLETED_AT"] = ts
    print(f"[SETUP] setup.sh completed successfully (setup_done_at={ts})")
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def teardown_task(task_id: str, verbose: bool = True) -> None:
    """
    Release everything registered for a task by delegating to ResourceManager.

    The work itself is done by ``ResourceManager.cleanup_task``; according to
    this module's notes that is expected to stop the task's Docker compose
    projects and release its allocated ports (not verifiable from this file).

    Args:
        task_id: The task ID used during setup
        verbose: Forwarded to cleanup_task; whether to print cleanup messages
    """
    ResourceManager.instance().cleanup_task(task_id, verbose=verbose)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def teardown_envs(
    envs_used: dict,
    project_name: Optional[str],
) -> None:
    """
    Legacy teardown function for backward compatibility.

    Deprecated: Use teardown_task() with ResourceManager instead.

    The function first looks for a tracked task whose registered Docker
    project names contain ``project_name`` (substring match) and cleans that
    task up through ResourceManager. If no tracked task matches, it falls
    back to running ``docker compose -p <project_name> down --remove-orphans``
    directly. The fallback is best effort: failures are printed, never raised.

    Args:
        envs_used: Not read by this function; kept so existing callers that
            pass it keep working.
        project_name: Docker compose project name. When empty or None the
            function returns immediately.
    """
    if not project_name:
        return

    # Try to cleanup via ResourceManager first
    # Project names are like "wf_taskname_pid" or "wf_taskname_pid_envname"
    # We try to find matching task_id
    mgr = ResourceManager.instance()
    snapshot = mgr.snapshot()

    for task_id, resources in snapshot.items():
        if any(project_name in p for p in resources.get("docker_projects", [])):
            mgr.cleanup_task(task_id, verbose=True)
            return

    # Fallback: manually stop the project if not tracked
    # This handles cases where setup.sh was run but resources weren't registered
    print(f"[CLEANUP] Attempting manual teardown for project: {project_name}")
    try:
        completed = subprocess.run(
            ["docker", "compose", "-p", project_name, "down", "--remove-orphans"],
            capture_output=True,
            check=False,
            timeout=60,
        )
    except Exception as e:
        # Deliberately broad: teardown must not abort the caller's cleanup path.
        print(f"[CLEANUP] Manual teardown error: {e}")
        return

    # check=False means a failing docker command does not raise, so report it
    # here instead of dropping the exit status and captured stderr.
    if completed.returncode != 0:
        stderr_text = (completed.stderr or b"").decode("utf-8", errors="replace").strip()
        print(
            f"[CLEANUP] Manual teardown exited with code {completed.returncode}: {stderr_text}"
        )
|
utils/eval_stats.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
@dataclass
class TaskTiming:
    """Timing measurements recorded for one task."""
    # Identifier shown in the timing summary.
    task_name: str
    # Start/end markers; an end_time of 0.0 (the default) means "not finished".
    start_time: float = 0.0
    end_time: float = 0.0
    # Durations reported alongside the total (presumably seconds, matching the
    # "s" suffix used when they are printed).
    container_acquire_time: float = 0.0
    execution_time: float = 0.0
    # Outcome flag for the task.
    success: bool = False

    @property
    def total_time(self) -> float:
        """Return end_time - start_time, or 0.0 while end_time is not positive."""
        if self.end_time > 0:
            return self.end_time - self.start_time
        return 0.0
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class EvaluationStats:
    """Aggregate counters and timings for one evaluation run."""
    # Task counters.
    total_tasks: int = 0
    successful_tasks: int = 0
    failed_tasks: int = 0

    # Run-level timing; an eval_end_time of 0.0 means the run has not ended.
    eval_start_time: float = 0.0
    eval_end_time: float = 0.0
    scheduling_time: float = 0.0

    # One TaskTiming entry per task.
    task_timings: List[TaskTiming] = field(default_factory=list)

    @property
    def total_time(self) -> float:
        """Return eval_end_time - eval_start_time, or 0.0 if the end is not positive."""
        if self.eval_end_time > 0:
            return self.eval_end_time - self.eval_start_time
        return 0.0

    @property
    def avg_task_time(self) -> float:
        """Mean of total_time over task_timings; 0.0 when there are none."""
        timings = self.task_timings
        if timings:
            return sum(entry.total_time for entry in timings) / len(timings)
        return 0.0

    @property
    def avg_container_acquire_time(self) -> float:
        """Mean of container_acquire_time over task_timings; 0.0 when there are none."""
        timings = self.task_timings
        if timings:
            return sum(entry.container_acquire_time for entry in timings) / len(timings)
        return 0.0

    @property
    def avg_execution_time(self) -> float:
        """Mean of execution_time over task_timings; 0.0 when there are none."""
        timings = self.task_timings
        if timings:
            return sum(entry.execution_time for entry in timings) / len(timings)
        return 0.0

    def print_summary(self) -> None:
        """Print the timing report to stdout."""
        rule = "=" * 70
        print("\n" + rule)
        print("[TIMING] Performance Statistics")
        print(rule)

        # Run-level figures
        elapsed = self.total_time
        print(f"\n Total evaluation time : {elapsed:.1f}s ({elapsed/60:.1f}min)")
        print(f" Scheduling time : {self.scheduling_time:.2f}s")

        # Counters
        print(f"\n Tasks completed : {self.total_tasks}")
        print(f" Successful : {self.successful_tasks}")
        print(f" Failed : {self.failed_tasks}")

        # Per-task figures are only meaningful when at least one task was timed
        if self.task_timings:
            print(f"\n Avg task time : {self.avg_task_time:.1f}s")
            print(f" Avg container acquire : {self.avg_container_acquire_time:.1f}s")
            print(f" Avg execution time : {self.avg_execution_time:.1f}s")

            # Extremes
            durations = [entry.total_time for entry in self.task_timings]
            print(f" Min task time : {min(durations):.1f}s")
            print(f" Max task time : {max(durations):.1f}s")

            # Name the tasks at both ends of the ordering
            by_duration = sorted(self.task_timings, key=lambda entry: entry.total_time)
            fastest = by_duration[0]
            slowest = by_duration[-1]
            print(f"\n Fastest task : {fastest.task_name} ({fastest.total_time:.1f}s)")
            print(f" Slowest task : {slowest.task_name} ({slowest.total_time:.1f}s)")

        print(rule)
|
|
@@ -0,0 +1,429 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
2
|
+
|
|
3
|
+
from dt_arena.src.types.agent import ToolInjection, SkillInjection
|
|
4
|
+
from dt_arena.src.types.task import AttackConfig, AttackStepConfig
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def build_tool_injections_from_config(
    attack_config: Optional[AttackConfig]
) -> Optional[Dict[str, Dict[str, ToolInjection]]]:
    """
    Collect the tool injection steps of every attack turn into one mapping.

    Steps from all turns are merged into a single dict; the module's notes
    give the reason that tool descriptions are set once when the agent
    initializes. When two steps target the same server and tool, the one
    processed last replaces the earlier one.

    Args:
        attack_config: The attack configuration containing attack turns

    Returns:
        Dict mapping server_name -> tool_name -> ToolInjection
        Returns None if no tool injections are configured
    """
    if not (attack_config and attack_config.attack_turns):
        return None

    accepted_modes = ("suffix", "override")
    injections_by_server: Dict[str, Dict[str, ToolInjection]] = {}

    for turn in attack_config.attack_turns:
        for step in turn.get_tool_injections():
            server_name, tool_name = step.get_tool_server_and_name()
            if not (server_name and tool_name):
                # Skip steps whose target cannot be resolved to both parts.
                print(f"[WARNING] Invalid injected_tool format: {step.injected_tool}")
                continue

            # The step's mode doubles as the ToolInjection type; anything
            # outside the accepted pair falls back to "suffix".
            injection_type = step.mode
            if injection_type not in accepted_modes:
                print(f"[WARNING] Invalid injection mode: {step.mode}, defaulting to 'suffix'")
                injection_type = "suffix"

            server_entry = injections_by_server.setdefault(server_name, {})
            server_entry[tool_name] = ToolInjection(
                type=injection_type,
                content=step.content or ""
            )

            print(f"[INJECTION] Tool injection configured: {server_name}:{tool_name} ({injection_type})")

    return injections_by_server or None
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def build_skill_injections_from_config(
    attack_config: Optional[AttackConfig]
) -> Optional[Dict[str, List[SkillInjection]]]:
    """
    Collect the skill injection steps of every attack turn into one mapping.

    Steps from all turns are merged; the module's notes give the reason that
    skill files are set up once when the agent initializes. Steps for the
    same skill accumulate in processing order.

    Args:
        attack_config: The attack configuration containing attack turns

    Returns:
        Dict mapping skill_name -> list of SkillInjection objects
        Returns None if no skill injections are configured

    Raises:
        ValueError: If a step has no skill_name, an unknown mode, or a row
            value that is not allowed for its mode.
    """
    if not (attack_config and attack_config.attack_turns):
        return None

    injections_by_skill: Dict[str, List[SkillInjection]] = {}

    for turn in attack_config.attack_turns:
        for step in turn.get_skill_injections():
            skill_name = step.skill_name
            if not skill_name:
                raise ValueError(f"Skill injection missing skill_name in turn {turn.turn_id}")

            mode = step.mode
            if mode not in ("insert", "append", "create"):
                raise ValueError(f"Invalid skill injection mode: '{step.mode}'. Must be one of: insert, append, create")

            content = step.content or ""
            # A missing row is normalised to -1.
            row = -1 if step.row is None else step.row

            # Row rules: insert accepts -1 or any value >= 1; append and
            # create accept only -1.
            if mode == "insert":
                if row != -1 and row < 1:
                    raise ValueError(f"Invalid row {row} for insert mode. Must be >= 1 or -1 (append)")
            elif row != -1:
                raise ValueError(f"Row must be -1 for {mode} mode, got {row}")

            injections_by_skill.setdefault(skill_name, []).append(SkillInjection(
                mode=mode,
                content=content,
                row=row
            ))

            print(f"[INJECTION] Skill injection configured: {skill_name} ({mode}, row={row})")

    return injections_by_skill or None
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def apply_prompt_injections(
    original_instruction: Optional[Union[str, List[str]]],
    attack_config: Optional[AttackConfig]
) -> Optional[Union[str, List[str]]]:
    """
    Apply prompt injections to instructions based on mode.

    Modes:
    - "suffix": Append content to the original instruction (indirect threat model)
    - "override": Replace the instruction entirely with content (indirect threat model)
    - "jailbreak": Replace the instruction entirely with content (direct threat model)

    Attack turns are processed in ascending ``turn_id`` order; ``turn_id`` is
    1-indexed, so turn N targets ``instructions[N - 1]``.

    Direct attack (``original_instruction`` is None, "" or []): the content of
    every "jailbreak" injection, across all turns, becomes the instruction
    list. The result is always a list, even when it has a single element.

    Indirect attack (an instruction is given): only the FIRST prompt injection
    of each attack turn is applied.
    - "override"/"jailbreak" replace the targeted instruction; when the
      targeted turn does not exist, the content is appended as a new turn.
    - "suffix" appends ``"\\n" + content`` to the targeted instruction; when
      the targeted turn does not exist, a warning is printed and it is skipped.

    Args:
        original_instruction: Original user instruction (str or list for multi-turn).
                              Can be None for direct attacks.
        attack_config: The attack configuration containing attack turns

    Returns:
        Modified instruction(s) with prompt injections applied.
        - ``original_instruction`` unchanged (possibly None) when there is no
          attack config or it has no attack turns.
        - A list of jailbreak prompts for direct attacks.
        - For indirect attacks: a str when a str was passed in and no extra
          turn was appended, otherwise a list.

    Raises:
        ValueError: If a direct attack has no "jailbreak" injection, or an
            injection uses an unknown mode.
    """
    if not attack_config or not attack_config.attack_turns:
        return original_instruction

    # Sort attack turns by turn_id to ensure correct ordering
    sorted_turns = sorted(attack_config.attack_turns, key=lambda t: t.turn_id)

    # Direct attack: there is no benign instruction, so the jailbreak prompts
    # themselves are the instructions.
    if original_instruction is None or original_instruction == "" or original_instruction == []:
        jailbreak_instructions: List[str] = []
        for attack_turn in sorted_turns:
            for injection in attack_turn.get_prompt_injections() or []:
                if injection.mode == "jailbreak":
                    print(f"[INJECTION] Direct attack jailbreak prompt collected from turn {attack_turn.turn_id}")
                    jailbreak_instructions.append(injection.content or "")

        if not jailbreak_instructions:
            raise ValueError("No jailbreak prompts found for direct attack")
        return jailbreak_instructions

    # Indirect attack: apply injections to a copy of the existing instructions
    is_single = isinstance(original_instruction, str)
    instructions = [original_instruction] if is_single else list(original_instruction)

    for attack_turn in sorted_turns:
        turn_index = attack_turn.turn_id - 1

        prompt_injections = attack_turn.get_prompt_injections()
        if not prompt_injections:
            continue

        # Only the first prompt injection of a turn is applied
        injection = prompt_injections[0]
        has_target = 0 <= turn_index < len(instructions)

        if injection.mode in ("override", "jailbreak"):
            # Both override and jailbreak replace the instruction entirely
            print(f"[INJECTION] Prompt {injection.mode} applied to turn {attack_turn.turn_id}")
            if has_target:
                instructions[turn_index] = injection.content or ""
            else:
                # The targeted turn does not exist yet: add it as a new turn
                instructions.append(injection.content or "")

        elif injection.mode == "suffix":
            if not has_target:
                print(f"[WARNING] Attack turn {attack_turn.turn_id} has no corresponding instruction (only {len(instructions)} instructions)")
                continue
            print(f"[INJECTION] Prompt suffix applied to turn {attack_turn.turn_id}")
            instructions[turn_index] = f"{instructions[turn_index]}\n{injection.content or ''}"

        else:
            raise ValueError(f"Unknown prompt injection mode: {injection.mode} for turn {attack_turn.turn_id}")

    # Return in the original format. If an override/jailbreak appended an extra
    # turn to a single-string instruction, return the list so that the
    # injected turn is not silently discarded.
    if is_single and len(instructions) == 1:
        return instructions[0]
    return instructions
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def get_env_injections_from_attack(
    attack_config: Optional[AttackConfig],
    turn_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    Collect the environment injection steps declared in an attack config.

    Walks ``attack_config.attack_turns`` in their given order and turns each
    environment injection step into a plain dict. A step's
    ``injection_mcp_tool`` must look like ``"server_name:tool_name"``; it is
    split on the first ``":"`` only. Steps with an empty or malformed
    ``injection_mcp_tool`` are reported with a printed warning and skipped.

    Args:
        attack_config: The attack configuration containing attack turns
        turn_id: Optional turn ID (1-indexed) restricting the output to the
                 attack turns carrying that ID. If None, steps from all turns
                 are returned.

    Returns:
        List of environment injection specs (empty when there is no config or
        no attack turns):
        [
            {
                "server_name": "salesforce-injection",
                "tool_name": "inject_comment",
                "kwargs": {"comment": "..."},
                "turn_id": 1,
            },
            ...
        ]
    """
    collected: List[Dict[str, Any]] = []
    if not attack_config or not attack_config.attack_turns:
        return collected

    for turn in attack_config.attack_turns:
        # Honour the optional per-turn filter
        selected = turn_id is None or turn.turn_id == turn_id
        if not selected:
            continue

        for step in turn.get_environment_injections():
            if not step.injection_mcp_tool:
                print(f"[WARNING] Environment injection missing injection_mcp_tool in turn {turn.turn_id}")
                continue

            # Expected format is "server_name:tool_name"
            server_name, separator, tool_name = step.injection_mcp_tool.partition(":")
            if not separator:
                print(f"[WARNING] Invalid injection_mcp_tool format: {step.injection_mcp_tool}")
                continue

            spec = {
                "server_name": server_name,
                "tool_name": tool_name,
                "kwargs": step.kwargs or {},
                "turn_id": turn.turn_id,
            }
            collected.append(spec)

    return collected
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def get_required_injection_servers(
    env_injections: List[Dict[str, Any]],
) -> Dict[str, List[str]]:
    """
    Work out which injection MCP servers (and which of their tools) are needed.

    Args:
        env_injections: List of environment injection specs from get_env_injections_from_attack()

    Returns:
        Dict mapping server_name -> list of distinct tool names, both in order
        of first appearance,
        e.g., {"salesforce-injection": ["inject_comment", "inject_lead"]}
    """
    required: Dict[str, List[str]] = {}

    for spec in env_injections:
        server, tool = spec["server_name"], spec["tool_name"]
        tools = required.setdefault(server, [])
        # Keep each tool once per server while preserving first-seen order
        if tool not in tools:
            tools.append(tool)

    return required
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def _injection_result(
    server_name: str,
    injection: Dict[str, Any],
    *,
    success: bool,
    result: Optional[str] = None,
    error: Optional[str] = None,
) -> Dict[str, Any]:
    """Build the result record for one attempted environment injection."""
    return {
        "server_name": server_name,
        "tool_name": injection["tool_name"],
        "kwargs": injection["kwargs"],
        "turn_id": injection["turn_id"],
        "success": success,
        "result": result,
        "error": error,
    }


async def apply_environment_injections_async(
    env_injections: List[Dict[str, Any]],
    server_urls: Dict[str, str],
) -> List[Dict[str, Any]]:
    """
    Execute all environment injection steps via MCP calls (async version).

    Injections are grouped by server so that a single client connection is
    opened per server; within a server they run in their original order.
    Results are therefore ordered server by server (servers in order of first
    appearance), not necessarily in the order of ``env_injections``.

    Failures never raise out of this function; they are reported per
    injection in the returned records:
    - no URL for the server in ``server_urls`` -> every injection of that
      server fails with a "URL not available" error;
    - a tool call raises -> that injection fails with ``str(exception)``;
    - the client connection itself fails -> every injection of that server
      that has no result yet fails with a "Failed to connect" error.
      Injections already recorded are left untouched, so each injection
      yields exactly one result record.

    Args:
        env_injections: List of environment injection specs from get_env_injections_from_attack()
        server_urls: Dict mapping server_name -> URL (from start_injection_mcp_servers)

    Returns:
        List of results:
        [
            {
                "server_name": "salesforce-injection",
                "tool_name": "inject_comment",
                "kwargs": {...},
                "turn_id": 1,
                "success": True,
                "result": "...",
                "error": None,
            },
            ...
        ]
    """
    from fastmcp import Client

    results: List[Dict[str, Any]] = []

    # Group injections by server to reuse connections
    injections_by_server: Dict[str, List[Dict[str, Any]]] = {}
    for injection in env_injections:
        injections_by_server.setdefault(injection["server_name"], []).append(injection)

    # Execute injections for each server
    for server_name, server_injections in injections_by_server.items():
        server_url = server_urls.get(server_name)
        if not server_url:
            # Server not started, record errors for all injections
            for injection in server_injections:
                results.append(_injection_result(
                    server_name,
                    injection,
                    success=False,
                    error=f"Injection server '{server_name}' URL not available",
                ))
            continue

        # Remember where this server's records start so that a connection
        # level failure only fills in the injections that have no record yet.
        first_record = len(results)

        try:
            async with Client(server_url, timeout=30.0) as client:
                for injection in server_injections:
                    tool_name = injection["tool_name"]

                    try:
                        raw_result = await client.call_tool(tool_name, injection["kwargs"])
                        result_text = _extract_mcp_result(raw_result)
                    except Exception as e:
                        # A failing tool must not abort the remaining injections
                        results.append(_injection_result(
                            server_name, injection, success=False, error=str(e),
                        ))
                        print(f"[ENV INJECTION] {server_name}:{tool_name} - Failed: {e}")
                    else:
                        results.append(_injection_result(
                            server_name, injection, success=True, result=result_text,
                        ))
                        print(f"[ENV INJECTION] {server_name}:{tool_name} - Success")

        except Exception as e:
            # Connection error - record errors for the injections on this
            # server that were not attempted. When the failure happens while
            # closing the connection, everything is already recorded and
            # nothing is added here.
            already_recorded = len(results) - first_record
            for injection in server_injections[already_recorded:]:
                results.append(_injection_result(
                    server_name,
                    injection,
                    success=False,
                    error=f"Failed to connect to server: {e}",
                ))
            print(f"[ENV INJECTION] Failed to connect to {server_name} at {server_url}: {e}")

    return results
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def apply_environment_injections(
    env_injections: List[Dict[str, Any]],
    server_urls: Dict[str, str],
) -> List[Dict[str, Any]]:
    """
    Execute all environment injection steps via MCP calls (sync wrapper).

    Runs apply_environment_injections_async() to completion with
    ``asyncio.run`` so it can be used from synchronous code. ``asyncio.run``
    cannot be called while another event loop is running in the same thread;
    from async code, await the async version instead.

    Args:
        env_injections: List of environment injection specs from get_env_injections_from_attack()
        server_urls: Dict mapping server_name -> URL (from start_injection_mcp_servers)

    Returns:
        List of results (see apply_environment_injections_async for format)
    """
    import asyncio

    pending = apply_environment_injections_async(env_injections, server_urls)
    return asyncio.run(pending)
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def _extract_mcp_result(result: Any) -> str:
|
|
418
|
+
"""Extract text content from MCP tool result."""
|
|
419
|
+
if hasattr(result, "content"):
|
|
420
|
+
content_parts = []
|
|
421
|
+
for item in result.content:
|
|
422
|
+
if hasattr(item, "text"):
|
|
423
|
+
content_parts.append(item.text)
|
|
424
|
+
elif hasattr(item, "data"):
|
|
425
|
+
content_parts.append(str(item.data))
|
|
426
|
+
else:
|
|
427
|
+
content_parts.append(str(item))
|
|
428
|
+
return "\n".join(content_parts) if content_parts else str(result)
|
|
429
|
+
return str(result)
|