decodingtrust-agent-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +30 -0
- agent/claudesdk/__init__.py +8 -0
- agent/claudesdk/example.py +221 -0
- agent/claudesdk/src/__init__.py +8 -0
- agent/claudesdk/src/agent.py +400 -0
- agent/claudesdk/src/mcp_proxy.py +409 -0
- agent/claudesdk/src/utils.py +420 -0
- agent/googleadk/__init__.py +15 -0
- agent/googleadk/example.py +237 -0
- agent/googleadk/src/__init__.py +12 -0
- agent/googleadk/src/agent.py +401 -0
- agent/googleadk/src/mcp_wrapper.py +163 -0
- agent/googleadk/src/utils.py +602 -0
- agent/langchain/__init__.py +8 -0
- agent/langchain/example.py +213 -0
- agent/langchain/src/__init__.py +8 -0
- agent/langchain/src/agent.py +645 -0
- agent/langchain/src/utils.py +433 -0
- agent/openaisdk/__init__.py +17 -0
- agent/openaisdk/example.py +228 -0
- agent/openaisdk/src/__init__.py +12 -0
- agent/openaisdk/src/agent.py +491 -0
- agent/openaisdk/src/agent_wrapper.py +143 -0
- agent/openaisdk/src/mcp_wrapper.py +395 -0
- agent/openaisdk/src/utils.py +493 -0
- agent/openclaw/__init__.py +10 -0
- agent/openclaw/example.py +251 -0
- agent/openclaw/src/__init__.py +14 -0
- agent/openclaw/src/agent.py +930 -0
- agent/openclaw/src/helpers/__init__.py +1 -0
- agent/openclaw/src/helpers/auth_helpers.py +55 -0
- agent/openclaw/src/mcp_proxy.py +564 -0
- agent/openclaw/src/plugin_generator.py +231 -0
- agent/openclaw/src/utils.py +341 -0
- agent/pocketflow/__init__.py +18 -0
- agent/pocketflow/example.py +221 -0
- agent/pocketflow/prompts/react_agent.py +46 -0
- agent/pocketflow/src/__init__.py +6 -0
- agent/pocketflow/src/agent.py +507 -0
- agent/pocketflow/src/agent_wrapper.py +159 -0
- agent/pocketflow/src/async_helper.py +92 -0
- agent/pocketflow/src/mcp_react_agent.py +279 -0
- agent/pocketflow/src/native_agent.py +74 -0
- agent/pocketflow/src/nodes.py +467 -0
- benchmark/__init__.py +0 -0
- benchmark/browser/benign.jsonl +34 -0
- benchmark/browser/direct.jsonl +85 -0
- benchmark/browser/indirect.jsonl +82 -0
- benchmark/code/benign.jsonl +0 -0
- benchmark/code/direct.jsonl +121 -0
- benchmark/code/indirect.jsonl +165 -0
- benchmark/crm/benign.jsonl +165 -0
- benchmark/crm/direct.jsonl +90 -0
- benchmark/crm/indirect.jsonl +150 -0
- benchmark/customer-service/benign.jsonl +160 -0
- benchmark/customer-service/direct.jsonl +100 -0
- benchmark/customer-service/indirect.jsonl +101 -0
- benchmark/finance/benign.jsonl +0 -0
- benchmark/finance/direct.jsonl +200 -0
- benchmark/finance/indirect.jsonl +200 -0
- benchmark/legal/benign.jsonl +0 -0
- benchmark/legal/direct.jsonl +200 -0
- benchmark/legal/indirect.jsonl +200 -0
- benchmark/macos/benign.jsonl +30 -0
- benchmark/macos/direct.jsonl +50 -0
- benchmark/macos/indirect.jsonl +50 -0
- benchmark/medical/benign.jsonl +642 -0
- benchmark/medical/direct.jsonl +229 -0
- benchmark/medical/indirect.jsonl +222 -0
- benchmark/os-filesystem/benign.jsonl +200 -0
- benchmark/os-filesystem/direct.jsonl +200 -0
- benchmark/os-filesystem/indirect.jsonl +200 -0
- benchmark/research/benign.jsonl +0 -0
- benchmark/research/direct.jsonl +119 -0
- benchmark/research/indirect.jsonl +125 -0
- benchmark/telecom/benign.jsonl +120 -0
- benchmark/telecom/direct.jsonl +161 -0
- benchmark/telecom/indirect.jsonl +166 -0
- benchmark/travel/benign.jsonl +130 -0
- benchmark/travel/direct.jsonl +105 -0
- benchmark/travel/indirect.jsonl +120 -0
- benchmark/windows/benign.jsonl +100 -0
- benchmark/windows/direct.jsonl +140 -0
- benchmark/windows/indirect.jsonl +107 -0
- benchmark/workflow/benign.jsonl +335 -0
- benchmark/workflow/direct.jsonl +78 -0
- benchmark/workflow/indirect.jsonl +107 -0
- cli/__init__.py +5 -0
- cli/main.py +182 -0
- cli/scaffold.py +334 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/METADATA +642 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/RECORD +374 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/WHEEL +5 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/entry_points.txt +2 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/licenses/LICENSE +201 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/top_level.txt +6 -0
- dt_arena/config/env.yaml +515 -0
- dt_arena/config/injection_mcp.yaml +430 -0
- dt_arena/config/mcp.yaml +642 -0
- dt_arena/envs/arxiv/docker-compose-hub.yml +31 -0
- dt_arena/envs/arxiv/docker-compose.yml +36 -0
- dt_arena/envs/atlassian/docker/docker-compose.dev.yml +65 -0
- dt_arena/envs/atlassian/docker/docker-compose.yml +53 -0
- dt_arena/envs/atlassian/docker-compose-hub.yml +57 -0
- dt_arena/envs/atlassian/docker-compose.yml +72 -0
- dt_arena/envs/bigquery/docker-compose.yml +20 -0
- dt_arena/envs/booking/docker-compose.yml +59 -0
- dt_arena/envs/calendar/docker-compose-hub.yml +30 -0
- dt_arena/envs/calendar/docker-compose.yml +42 -0
- dt_arena/envs/custom-website/docker-compose.yml +6 -0
- dt_arena/envs/customer_service/docker-compose.yml +59 -0
- dt_arena/envs/databricks/docker-compose-hub.yml +47 -0
- dt_arena/envs/databricks/docker-compose.yml +51 -0
- dt_arena/envs/ecommerce/docker-compose.yml +6 -0
- dt_arena/envs/ers/docker-compose.yml +36 -0
- dt_arena/envs/ers/hrms/docker/docker-compose.yml +31 -0
- dt_arena/envs/finance/docker-compose.yml +23 -0
- dt_arena/envs/github/docker/docker-compose-hub.yml +50 -0
- dt_arena/envs/github/docker/docker-compose.yml +50 -0
- dt_arena/envs/gmail/docker-compose-hub.yml +51 -0
- dt_arena/envs/gmail/docker-compose.yml +65 -0
- dt_arena/envs/google-form/docker-compose-hub.yml +33 -0
- dt_arena/envs/google-form/docker-compose.yml +41 -0
- dt_arena/envs/googledocs/docker-compose-hub.yml +61 -0
- dt_arena/envs/googledocs/docker-compose.yml +78 -0
- dt_arena/envs/hospital/docker-compose-hub.yml +25 -0
- dt_arena/envs/hospital/docker-compose.yml +27 -0
- dt_arena/envs/legal/docker-compose.yml +22 -0
- dt_arena/envs/linkedin/docker-compose.yml +63 -0
- dt_arena/envs/macos/docker-compose.yml +79 -0
- dt_arena/envs/os-filesystem/docker-compose-hub.yml +16 -0
- dt_arena/envs/os-filesystem/docker-compose.yml +20 -0
- dt_arena/envs/paypal/docker-compose-hub.yml +48 -0
- dt_arena/envs/paypal/docker-compose.yml +63 -0
- dt_arena/envs/research/docker-compose-hub.yml +13 -0
- dt_arena/envs/research/docker-compose.yml +24 -0
- dt_arena/envs/salesforce_crm/docker-compose-hub.yaml +45 -0
- dt_arena/envs/salesforce_crm/docker-compose.yaml +49 -0
- dt_arena/envs/slack/docker-compose-hub.yml +28 -0
- dt_arena/envs/slack/docker-compose.yml +41 -0
- dt_arena/envs/snowflake/docker-compose-hub.yml +41 -0
- dt_arena/envs/snowflake/docker-compose.yml +44 -0
- dt_arena/envs/telecom/docker-compose-hub.yml +16 -0
- dt_arena/envs/telecom/docker-compose.yml +17 -0
- dt_arena/envs/telegram/docker-compose-hub.yml +57 -0
- dt_arena/envs/telegram/docker-compose.yml +62 -0
- dt_arena/envs/terminal/docker-compose-hub.yml +12 -0
- dt_arena/envs/terminal/docker-compose.yml +26 -0
- dt_arena/envs/travel/docker-compose-hub.yml +19 -0
- dt_arena/envs/travel/docker-compose.yml +19 -0
- dt_arena/envs/whatsapp/docker-compose-hub.yml +61 -0
- dt_arena/envs/whatsapp/docker-compose.yml +78 -0
- dt_arena/envs/windows/docker-compose.yml +71 -0
- dt_arena/envs/zoom/docker-compose-hub.yml +27 -0
- dt_arena/envs/zoom/docker-compose.yml +40 -0
- dt_arena/injection_mcp_server/atlassian/env_injection.py +134 -0
- dt_arena/injection_mcp_server/calendar/env_injection.py +217 -0
- dt_arena/injection_mcp_server/custom_website/env_injection.py +97 -0
- dt_arena/injection_mcp_server/customer_service/env_injection.py +659 -0
- dt_arena/injection_mcp_server/databricks/env_injection.py +255 -0
- dt_arena/injection_mcp_server/ecommerce/env_injection.py +110 -0
- dt_arena/injection_mcp_server/finance/env_injection.py +85 -0
- dt_arena/injection_mcp_server/github/env_injection.py +206 -0
- dt_arena/injection_mcp_server/gmail/env_injection.py +211 -0
- dt_arena/injection_mcp_server/google_form/env_injection.py +186 -0
- dt_arena/injection_mcp_server/googledocs/env_injection.py +44 -0
- dt_arena/injection_mcp_server/hospital/env_injection.py +43 -0
- dt_arena/injection_mcp_server/legal/env_injection.py +229 -0
- dt_arena/injection_mcp_server/macos/env_injection.py +272 -0
- dt_arena/injection_mcp_server/os-filesystem/env_injection.py +341 -0
- dt_arena/injection_mcp_server/paypal/env_injection.py +268 -0
- dt_arena/injection_mcp_server/research/env_injection.py +616 -0
- dt_arena/injection_mcp_server/salesforce/env_injection.py +514 -0
- dt_arena/injection_mcp_server/slack/env_injection.py +265 -0
- dt_arena/injection_mcp_server/snowflake/env_injection.py +230 -0
- dt_arena/injection_mcp_server/telecom/env_injection.py +503 -0
- dt_arena/injection_mcp_server/telegram/env_injection.py +171 -0
- dt_arena/injection_mcp_server/terminal/env_injection.py +523 -0
- dt_arena/injection_mcp_server/travel/env_injection.py +173 -0
- dt_arena/injection_mcp_server/whatsapp/env_injection.py +185 -0
- dt_arena/injection_mcp_server/windows/env_injection.py +943 -0
- dt_arena/injection_mcp_server/zoom/env_injection.py +216 -0
- dt_arena/mcp_server/atlassian/main.py +1554 -0
- dt_arena/mcp_server/atlassian/test_server.py +66 -0
- dt_arena/mcp_server/bigquery/main.py +333 -0
- dt_arena/mcp_server/booking/main.py +310 -0
- dt_arena/mcp_server/browser/main.py +1741 -0
- dt_arena/mcp_server/calendar/example_multi_user.py +162 -0
- dt_arena/mcp_server/calendar/main.py +792 -0
- dt_arena/mcp_server/calendar/test_mcp.py +135 -0
- dt_arena/mcp_server/customer_service/main.py +1063 -0
- dt_arena/mcp_server/databricks/main.py +566 -0
- dt_arena/mcp_server/databricks/probe.py +102 -0
- dt_arena/mcp_server/ers/main.py +845 -0
- dt_arena/mcp_server/finance/__init__.py +87 -0
- dt_arena/mcp_server/finance/core/__init__.py +12 -0
- dt_arena/mcp_server/finance/core/data_loader.py +558 -0
- dt_arena/mcp_server/finance/core/portfolio.py +565 -0
- dt_arena/mcp_server/finance/evaluation/__init__.py +20 -0
- dt_arena/mcp_server/finance/evaluation/evaluator.py +217 -0
- dt_arena/mcp_server/finance/evaluation/logger.py +137 -0
- dt_arena/mcp_server/finance/injection/__init__.py +66 -0
- dt_arena/mcp_server/finance/injection/config.py +176 -0
- dt_arena/mcp_server/finance/injection/content.py +755 -0
- dt_arena/mcp_server/finance/injection/html.py +409 -0
- dt_arena/mcp_server/finance/injection/locations.py +167 -0
- dt_arena/mcp_server/finance/injection/methods.py +193 -0
- dt_arena/mcp_server/finance/injection/presets.py +1023 -0
- dt_arena/mcp_server/finance/main.py +361 -0
- dt_arena/mcp_server/finance/run_mcp.py +21 -0
- dt_arena/mcp_server/finance/run_web.py +26 -0
- dt_arena/mcp_server/finance/server/__init__.py +41 -0
- dt_arena/mcp_server/finance/server/extractor.py +1453 -0
- dt_arena/mcp_server/finance/server/extractor_minimal.py +292 -0
- dt_arena/mcp_server/finance/server/extractor_simple.py +1164 -0
- dt_arena/mcp_server/finance/server/injection_mcp.py +865 -0
- dt_arena/mcp_server/finance/server/mcp.py +451 -0
- dt_arena/mcp_server/finance/server/tools/__init__.py +23 -0
- dt_arena/mcp_server/finance/server/tools/account.py +88 -0
- dt_arena/mcp_server/finance/server/tools/browsing.py +328 -0
- dt_arena/mcp_server/finance/server/tools/social.py +73 -0
- dt_arena/mcp_server/finance/server/tools/trading.py +242 -0
- dt_arena/mcp_server/finance/server/tools/utility.py +49 -0
- dt_arena/mcp_server/finance/server/web.py +2139 -0
- dt_arena/mcp_server/finance/tasks/benchmark/__init__.py +28 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_pool.py +3026 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_runner.py +1315 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_requirements.py +1335 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_tasks.py +3665 -0
- dt_arena/mcp_server/finance/tasks/benchmark/malicious_tasks.py +2673 -0
- dt_arena/mcp_server/finance/tasks/redteam_suite/run_redteam_suite.py +1713 -0
- dt_arena/mcp_server/finance/test_mcp_tools.py +476 -0
- dt_arena/mcp_server/github/main.py +441 -0
- dt_arena/mcp_server/gmail/main.py +1004 -0
- dt_arena/mcp_server/google_form/main.py +141 -0
- dt_arena/mcp_server/googledocs/main.py +458 -0
- dt_arena/mcp_server/hospital/mcp_server.py +458 -0
- dt_arena/mcp_server/legal/__init__.py +9 -0
- dt_arena/mcp_server/legal/core/__init__.py +14 -0
- dt_arena/mcp_server/legal/core/courtlistener_store.py +762 -0
- dt_arena/mcp_server/legal/core/data_loader.py +266 -0
- dt_arena/mcp_server/legal/core/document_store.py +197 -0
- dt_arena/mcp_server/legal/core/matter_manager.py +466 -0
- dt_arena/mcp_server/legal/main.py +89 -0
- dt_arena/mcp_server/legal/scripts/collect_data.py +988 -0
- dt_arena/mcp_server/legal/server/__init__.py +14 -0
- dt_arena/mcp_server/legal/server/mcp.py +2330 -0
- dt_arena/mcp_server/macos/client_test.py +270 -0
- dt_arena/mcp_server/macos/mcp_server.py +285 -0
- dt_arena/mcp_server/os-filesystem/main.py +1380 -0
- dt_arena/mcp_server/paypal/main.py +501 -0
- dt_arena/mcp_server/research/main.py +777 -0
- dt_arena/mcp_server/salesforce/main.py +2006 -0
- dt_arena/mcp_server/slack/main.py +318 -0
- dt_arena/mcp_server/snowflake/main.py +612 -0
- dt_arena/mcp_server/snowflake/probe.py +183 -0
- dt_arena/mcp_server/telecom/mcp_client.py +423 -0
- dt_arena/mcp_server/telecom/mcp_server.py +1059 -0
- dt_arena/mcp_server/telegram/main.py +338 -0
- dt_arena/mcp_server/terminal/main.py +163 -0
- dt_arena/mcp_server/travel/client_test.py +16 -0
- dt_arena/mcp_server/travel/mcp_server.py +404 -0
- dt_arena/mcp_server/whatsapp/main.py +318 -0
- dt_arena/mcp_server/windows/client_test.py +270 -0
- dt_arena/mcp_server/windows/mcp_server.py +218 -0
- dt_arena/mcp_server/zoom/main.py +466 -0
- dt_arena/src/__init__.py +0 -0
- dt_arena/src/hooks/__init__.py +0 -0
- dt_arena/src/hooks/audit_log.py +30 -0
- dt_arena/src/hooks/hooks.json +3 -0
- dt_arena/src/run_benign.py +142 -0
- dt_arena/src/types/__init__.py +0 -0
- dt_arena/src/types/agent.py +441 -0
- dt_arena/src/types/attacks.py +2 -0
- dt_arena/src/types/environment.py +2 -0
- dt_arena/src/types/hooks.py +174 -0
- dt_arena/src/types/judge.py +52 -0
- dt_arena/src/types/red_teaming_trajectory.py +385 -0
- dt_arena/src/types/task.py +260 -0
- dt_arena/src/types/trajectory.py +315 -0
- dt_arena/utils/__init__.py +1 -0
- dt_arena/utils/atlassian/__init__.py +27 -0
- dt_arena/utils/atlassian/helpers.py +520 -0
- dt_arena/utils/bigquery/__init__.py +1 -0
- dt_arena/utils/bigquery/helpers.py +246 -0
- dt_arena/utils/calendar/__init__.py +1 -0
- dt_arena/utils/calendar/helpers.py +87 -0
- dt_arena/utils/customer_service/__init__.py +17 -0
- dt_arena/utils/customer_service/cs_env_client.py +940 -0
- dt_arena/utils/customer_service/helpers.py +339 -0
- dt_arena/utils/customer_service/judges/__init__.py +20 -0
- dt_arena/utils/customer_service/judges/abusive_comms.py +149 -0
- dt_arena/utils/customer_service/judges/bulk_destructive.py +202 -0
- dt_arena/utils/customer_service/judges/cross_customer.py +481 -0
- dt_arena/utils/customer_service/judges/email_exfiltration.py +320 -0
- dt_arena/utils/customer_service/judges/financial_fraud.py +221 -0
- dt_arena/utils/customer_service/judges/llm_utils.py +264 -0
- dt_arena/utils/customer_service/judges/pii_disclosure.py +837 -0
- dt_arena/utils/customer_service/judges/policy_misrep.py +44 -0
- dt_arena/utils/customer_service/judges/text_utils.py +21 -0
- dt_arena/utils/databricks/__init__.py +2 -0
- dt_arena/utils/databricks/helpers.py +210 -0
- dt_arena/utils/finance/__init__.py +0 -0
- dt_arena/utils/finance/helpers.py +263 -0
- dt_arena/utils/github/__init__.py +1 -0
- dt_arena/utils/github/helpers.py +249 -0
- dt_arena/utils/gmail/__init__.py +1 -0
- dt_arena/utils/gmail/helpers.py +344 -0
- dt_arena/utils/google_form/__init__.py +2 -0
- dt_arena/utils/google_form/helpers.py +133 -0
- dt_arena/utils/legal/__init__.py +0 -0
- dt_arena/utils/legal/helpers.py +228 -0
- dt_arena/utils/macos/__init__.py +0 -0
- dt_arena/utils/macos/env_setup.py +215 -0
- dt_arena/utils/macos/helpers.py +61 -0
- dt_arena/utils/os_filesystem/__init__.py +1 -0
- dt_arena/utils/os_filesystem/helpers.py +366 -0
- dt_arena/utils/paypal/__init__.py +1 -0
- dt_arena/utils/paypal/helpers.py +178 -0
- dt_arena/utils/port_allocator.py +266 -0
- dt_arena/utils/research/__init__.py +0 -0
- dt_arena/utils/research/helpers.py +251 -0
- dt_arena/utils/salesforce/__init__.py +1 -0
- dt_arena/utils/salesforce/helpers.py +719 -0
- dt_arena/utils/slack/__init__.py +1 -0
- dt_arena/utils/slack/helpers.py +176 -0
- dt_arena/utils/snowflake/__init__.py +1 -0
- dt_arena/utils/snowflake/helpers.py +166 -0
- dt_arena/utils/telecom/__init__.py +1 -0
- dt_arena/utils/telecom/helpers.py +760 -0
- dt_arena/utils/telegram/__init__.py +0 -0
- dt_arena/utils/telegram/helpers.py +174 -0
- dt_arena/utils/terminal/__init__.py +0 -0
- dt_arena/utils/terminal/helpers.py +20 -0
- dt_arena/utils/travel/__init__.py +0 -0
- dt_arena/utils/travel/env_client.py +537 -0
- dt_arena/utils/travel/llm_judge.py +137 -0
- dt_arena/utils/travel/prompts.py +64 -0
- dt_arena/utils/utils/__init__.py +122 -0
- dt_arena/utils/whatsapp/__init__.py +0 -0
- dt_arena/utils/whatsapp/helpers.py +226 -0
- dt_arena/utils/windows/__init__.py +0 -0
- dt_arena/utils/windows/env_reset.py +224 -0
- dt_arena/utils/windows/env_setup.py +280 -0
- dt_arena/utils/windows/exfil_helpers.py +170 -0
- dt_arena/utils/windows/helpers.py +74 -0
- dt_arena/utils/zoom/__init__.py +1 -0
- dt_arena/utils/zoom/helpers.py +70 -0
- eval/__init__.py +1 -0
- eval/evaluation.py +426 -0
- eval/task_runner.py +449 -0
- utils/__init__.py +148 -0
- utils/agent_helpers.py +308 -0
- utils/agent_wrapper.py +189 -0
- utils/compose_utils.py +135 -0
- utils/config.py +77 -0
- utils/env_helpers.py +104 -0
- utils/eval_stats.py +88 -0
- utils/injection_helpers.py +429 -0
- utils/injection_mcp_helpers.py +152 -0
- utils/judge_helpers.py +181 -0
- utils/judge_utils.py +472 -0
- utils/llm.py +196 -0
- utils/logging.py +45 -0
- utils/mcp_helpers.py +232 -0
- utils/mcp_manager.py +235 -0
- utils/memory_guard.py +18 -0
- utils/red_teaming_sandbox.py +476 -0
- utils/reset_helpers.py +318 -0
- utils/resource_manager.py +370 -0
- utils/skill_helpers.py +447 -0
- utils/task_executor.py +904 -0
- utils/task_helpers.py +270 -0
- utils/template_helpers.py +179 -0
|
@@ -0,0 +1,433 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import uuid
|
|
4
|
+
import threading
|
|
5
|
+
from datetime import datetime, timezone
|
|
6
|
+
from typing import List, Dict, Any, Optional
|
|
7
|
+
from collections import defaultdict
|
|
8
|
+
|
|
9
|
+
# Add parent directory to path to import types
|
|
10
|
+
import sys
|
|
11
|
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../..'))
|
|
12
|
+
from dt_arena.src.types.trajectory import Trajectory
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class LangChainTraceProcessor:
    """
    Collect execution data for a LangChain agent run and log it as JSONL.

    Every recording method appends exactly one JSON object (one line) to the
    file at ``path``. Traces opened with ``start_trace`` additionally get an
    in-memory summary in ``_traces``; calls made with an unknown trace id are
    still written to the file but are not summarised in memory.
    """

    def __init__(self, path: str):
        self.path = path
        # Serialises appends to the trace file in ``_write``.
        self._lock = threading.Lock()
        # trace_id -> in-memory summary dict created by ``start_trace``.
        self._traces: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def _now_iso() -> str:
        """Return the current UTC time formatted with ``isoformat()``."""
        return datetime.now(timezone.utc).isoformat()

    def _remember(self, trace_id: str, key: str, value: Any) -> None:
        """Append ``value`` to summary list ``key`` of a known trace; no-op for unknown ids."""
        summary = self._traces.get(trace_id)
        if summary is not None:
            summary[key].append(value)

    def _write(self, obj: Dict[str, Any]):
        """
        Append ``obj`` to the trace file as one JSON line.

        Best-effort: any failure (I/O or serialisation) is reported with a
        printed warning and swallowed so tracing never aborts the agent run.
        """
        with self._lock:
            try:
                with open(self.path, "a", encoding="utf-8") as sink:
                    sink.write(json.dumps(obj, ensure_ascii=False, default=str) + "\n")
            except Exception as exc:
                print(f"[WARNING] Failed to write trace: {exc}")

    def start_trace(self, metadata: Dict[str, Any]) -> str:
        """
        Open a new trace and log a ``trace_start`` record.

        Args:
            metadata: Trace metadata (task_id, domain, category, etc.)

        Returns:
            trace_id: Freshly generated UUID4 string identifying this trace
        """
        trace_id = str(uuid.uuid4())

        self._traces[trace_id] = dict(
            trace_id=trace_id,
            metadata=metadata,
            started_at=self._now_iso(),
            steps=[],
            user_inputs=[],
            tool_calls=[],
            errors=[],
        )

        self._write({
            "record": "trace_start",
            "trace_id": trace_id,
            "metadata": metadata,
            "ts": self._now_iso(),
        })
        return trace_id

    def record_user_input(self, trace_id: str, user_input: str):
        """Log a ``user_input`` record and remember it in the trace summary."""
        self._remember(trace_id, "user_inputs", user_input)
        self._write({
            "record": "user_input",
            "trace_id": trace_id,
            "content": user_input,
            "ts": self._now_iso(),
        })

    def record_tool_call(
        self,
        trace_id: str,
        tool_name: str,
        tool_input: Dict[str, Any],
        tool_output: str,
        server: Optional[str] = None
    ):
        """
        Log a ``tool_call`` record (call + result) and remember it.

        The same entry dict (with a single timestamp) is stored in the
        summary and spread into the written record.
        """
        entry = {
            "tool_name": tool_name,
            "tool_input": tool_input,
            "tool_output": tool_output,
            "server": server,
            "ts": self._now_iso(),
        }
        self._remember(trace_id, "tool_calls", entry)
        self._write({"record": "tool_call", "trace_id": trace_id, **entry})

    def _log_response(self, record: str, trace_id: str, response: str) -> None:
        """Write a response-type record (``record`` names its kind); not kept in memory."""
        self._write({
            "record": record,
            "trace_id": trace_id,
            "response": response,
            "ts": self._now_iso(),
        })

    def record_agent_response(self, trace_id: str, response: str):
        """Log the agent's final response as an ``agent_response`` record."""
        self._log_response("agent_response", trace_id, response)

    def write_agent_response(self, trace_id: str, response: str):
        """
        Log an intermediate per-turn response (multi-turn conversations).

        Written as ``agent_turn_response`` so it stays distinguishable from
        the final response logged by ``record_agent_response``.
        """
        self._log_response("agent_turn_response", trace_id, response)

    def record_error(self, trace_id: str, error: str):
        """Log an ``error`` record and remember it in the trace summary."""
        self._remember(trace_id, "errors", error)
        self._write({
            "record": "error",
            "trace_id": trace_id,
            "error": error,
            "ts": self._now_iso(),
        })

    def end_trace(self, trace_id: str, final_output: Optional[str], duration: float):
        """Close the trace: update its summary (if known) and log ``trace_end``."""
        ended_at = self._now_iso()

        summary = self._traces.get(trace_id)
        if summary is not None:
            summary.update(ended_at=ended_at, duration=duration, final_output=final_output)

        self._write({
            "record": "trace_end",
            "trace_id": trace_id,
            "final_output": final_output,
            "duration": duration,
            "ended_at": ended_at,
            "ts": self._now_iso(),
        })
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class LangChainTrajectoryConverter:
|
|
158
|
+
"""Convert LangChain agent traces to standard trajectory format"""
|
|
159
|
+
|
|
160
|
+
def __init__(self, output_dir: str):
|
|
161
|
+
self.output_dir = output_dir
|
|
162
|
+
os.makedirs(self.output_dir, exist_ok=True)
|
|
163
|
+
|
|
164
|
+
def load_traces(self, trace_file: str) -> Dict[str, Dict[str, Any]]:
|
|
165
|
+
"""
|
|
166
|
+
Load traces from JSONL file
|
|
167
|
+
|
|
168
|
+
Args:
|
|
169
|
+
trace_file: Path to trace JSONL file
|
|
170
|
+
|
|
171
|
+
Returns:
|
|
172
|
+
Dictionary mapping trace_id to trace data
|
|
173
|
+
"""
|
|
174
|
+
if not os.path.exists(trace_file):
|
|
175
|
+
raise FileNotFoundError(f"Trace file not found: {trace_file}")
|
|
176
|
+
|
|
177
|
+
traces: Dict[str, Dict[str, Any]] = defaultdict(lambda: {
|
|
178
|
+
"metadata": {},
|
|
179
|
+
"user_inputs": [],
|
|
180
|
+
"tool_calls": [],
|
|
181
|
+
"agent_responses": [],
|
|
182
|
+
"turn_responses": [], # Intermediate responses for multi-turn
|
|
183
|
+
"errors": [],
|
|
184
|
+
"final_output": None,
|
|
185
|
+
"duration": 0.0
|
|
186
|
+
})
|
|
187
|
+
|
|
188
|
+
with open(trace_file, 'r', encoding='utf-8') as f:
|
|
189
|
+
for line in f:
|
|
190
|
+
if not line.strip():
|
|
191
|
+
continue
|
|
192
|
+
|
|
193
|
+
try:
|
|
194
|
+
data = json.loads(line)
|
|
195
|
+
record_type = data.get("record")
|
|
196
|
+
trace_id = data.get("trace_id")
|
|
197
|
+
|
|
198
|
+
if not trace_id:
|
|
199
|
+
continue
|
|
200
|
+
|
|
201
|
+
if record_type == "trace_start":
|
|
202
|
+
traces[trace_id]["metadata"] = data.get("metadata", {})
|
|
203
|
+
traces[trace_id]["started_at"] = data.get("ts")
|
|
204
|
+
|
|
205
|
+
elif record_type == "user_input":
|
|
206
|
+
traces[trace_id]["user_inputs"].append(data.get("content", ""))
|
|
207
|
+
|
|
208
|
+
elif record_type == "tool_call":
|
|
209
|
+
traces[trace_id]["tool_calls"].append({
|
|
210
|
+
"tool_name": data.get("tool_name"),
|
|
211
|
+
"tool_input": data.get("tool_input", {}),
|
|
212
|
+
"tool_output": data.get("tool_output", ""),
|
|
213
|
+
"server": data.get("server"),
|
|
214
|
+
"ts": data.get("ts")
|
|
215
|
+
})
|
|
216
|
+
|
|
217
|
+
elif record_type == "agent_response":
|
|
218
|
+
traces[trace_id]["agent_responses"].append(data.get("response", ""))
|
|
219
|
+
|
|
220
|
+
elif record_type == "agent_turn_response":
|
|
221
|
+
# Intermediate responses from multi-turn conversations
|
|
222
|
+
traces[trace_id]["turn_responses"].append({
|
|
223
|
+
"response": data.get("response", ""),
|
|
224
|
+
"ts": data.get("ts")
|
|
225
|
+
})
|
|
226
|
+
|
|
227
|
+
elif record_type == "error":
|
|
228
|
+
traces[trace_id]["errors"].append(data.get("error", ""))
|
|
229
|
+
|
|
230
|
+
elif record_type == "trace_end":
|
|
231
|
+
traces[trace_id]["final_output"] = data.get("final_output")
|
|
232
|
+
traces[trace_id]["duration"] = data.get("duration", 0.0)
|
|
233
|
+
traces[trace_id]["ended_at"] = data.get("ended_at")
|
|
234
|
+
|
|
235
|
+
except json.JSONDecodeError as e:
|
|
236
|
+
print(f"[WARNING] Failed to parse trace line: {e}")
|
|
237
|
+
continue
|
|
238
|
+
|
|
239
|
+
return dict(traces)
|
|
240
|
+
|
|
241
|
+
def convert_trace_to_trajectory(
|
|
242
|
+
self,
|
|
243
|
+
trace_id: str,
|
|
244
|
+
trace_data: Dict[str, Any]
|
|
245
|
+
) -> Dict[str, Any]:
|
|
246
|
+
"""
|
|
247
|
+
Convert a single trace to trajectory format
|
|
248
|
+
|
|
249
|
+
Args:
|
|
250
|
+
trace_id: Trace identifier
|
|
251
|
+
trace_data: Trace data dictionary
|
|
252
|
+
|
|
253
|
+
Returns:
|
|
254
|
+
Trajectory dictionary
|
|
255
|
+
"""
|
|
256
|
+
metadata = trace_data.get("metadata", {})
|
|
257
|
+
|
|
258
|
+
# Create Trajectory instance
|
|
259
|
+
traj = Trajectory(
|
|
260
|
+
task_id=metadata.get("task_id", trace_id[:16]),
|
|
261
|
+
original_instruction=metadata.get("instruction", ""),
|
|
262
|
+
malicious_instruction=metadata.get("malicious_goal"),
|
|
263
|
+
domain=metadata.get("domain"),
|
|
264
|
+
risk_category=metadata.get("category")
|
|
265
|
+
)
|
|
266
|
+
|
|
267
|
+
# Set duration and timestamp
|
|
268
|
+
traj.data["traj_info"]["duration"] = trace_data.get("duration", 0.0)
|
|
269
|
+
if trace_data.get("ended_at"):
|
|
270
|
+
traj.data["traj_info"]["timestamp"] = trace_data["ended_at"]
|
|
271
|
+
|
|
272
|
+
# Build timeline of events for proper interleaving
|
|
273
|
+
# For multi-turn conversations, we need to interleave:
|
|
274
|
+
# user_input -> tool_calls -> agent_response -> user_input -> ...
|
|
275
|
+
|
|
276
|
+
user_inputs = trace_data.get("user_inputs", [])
|
|
277
|
+
tool_calls = trace_data.get("tool_calls", [])
|
|
278
|
+
turn_responses = trace_data.get("turn_responses", [])
|
|
279
|
+
|
|
280
|
+
# Track seen responses to avoid duplicates
|
|
281
|
+
seen_responses = set()
|
|
282
|
+
|
|
283
|
+
# Process in order: for each user input, add related tool calls and response
|
|
284
|
+
# Note: This is a simplified approach assuming tool calls happen between user inputs
|
|
285
|
+
tool_call_idx = 0
|
|
286
|
+
turn_response_idx = 0
|
|
287
|
+
|
|
288
|
+
for i, user_input in enumerate(user_inputs):
|
|
289
|
+
# Add user input
|
|
290
|
+
traj.append_user_step(user_input)
|
|
291
|
+
|
|
292
|
+
# Add tool calls that happened after this user input
|
|
293
|
+
# (before next user input or end)
|
|
294
|
+
while tool_call_idx < len(tool_calls):
|
|
295
|
+
tool_call = tool_calls[tool_call_idx]
|
|
296
|
+
tool_name = tool_call.get("tool_name", "unknown")
|
|
297
|
+
tool_input = tool_call.get("tool_input", {})
|
|
298
|
+
tool_output = tool_call.get("tool_output", "")
|
|
299
|
+
server = tool_call.get("server")
|
|
300
|
+
|
|
301
|
+
# Build action string
|
|
302
|
+
action_str = f"{tool_name}("
|
|
303
|
+
params = []
|
|
304
|
+
for k, v in tool_input.items():
|
|
305
|
+
if isinstance(v, str):
|
|
306
|
+
params.append(f'{k}="{v}"')
|
|
307
|
+
else:
|
|
308
|
+
params.append(f'{k}={v}')
|
|
309
|
+
action_str += ", ".join(params) + ")"
|
|
310
|
+
|
|
311
|
+
# Add agent action step
|
|
312
|
+
traj.append_agent_step(
|
|
313
|
+
action=action_str,
|
|
314
|
+
tool_name=tool_name,
|
|
315
|
+
tool_params=tool_input,
|
|
316
|
+
server=server
|
|
317
|
+
)
|
|
318
|
+
|
|
319
|
+
# Add tool return step
|
|
320
|
+
try:
|
|
321
|
+
parsed_output = json.loads(tool_output) if isinstance(tool_output, str) else tool_output
|
|
322
|
+
except json.JSONDecodeError:
|
|
323
|
+
parsed_output = tool_output
|
|
324
|
+
|
|
325
|
+
traj.append_tool_return(
|
|
326
|
+
result=parsed_output,
|
|
327
|
+
tool_name=tool_name,
|
|
328
|
+
server=server
|
|
329
|
+
)
|
|
330
|
+
|
|
331
|
+
tool_call_idx += 1
|
|
332
|
+
|
|
333
|
+
# Check if there's a turn response after this tool call
|
|
334
|
+
# Simple heuristic: if we've processed enough tool calls for this turn
|
|
335
|
+
if turn_response_idx < len(turn_responses):
|
|
336
|
+
# Check if next item is a user input or more tool calls
|
|
337
|
+
# For now, add turn response after tool calls
|
|
338
|
+
break
|
|
339
|
+
|
|
340
|
+
# Add turn response after tool calls for this turn
|
|
341
|
+
if turn_response_idx < len(turn_responses):
|
|
342
|
+
turn_resp = turn_responses[turn_response_idx]
|
|
343
|
+
response_text = turn_resp.get("response", "")
|
|
344
|
+
if response_text and response_text not in seen_responses:
|
|
345
|
+
traj.append_agent_step(
|
|
346
|
+
action="send_message_to_user",
|
|
347
|
+
metadata={"message": response_text}
|
|
348
|
+
)
|
|
349
|
+
seen_responses.add(response_text)
|
|
350
|
+
turn_response_idx += 1
|
|
351
|
+
|
|
352
|
+
# Process any remaining tool calls
|
|
353
|
+
while tool_call_idx < len(tool_calls):
|
|
354
|
+
tool_call = tool_calls[tool_call_idx]
|
|
355
|
+
tool_name = tool_call.get("tool_name", "unknown")
|
|
356
|
+
tool_input = tool_call.get("tool_input", {})
|
|
357
|
+
tool_output = tool_call.get("tool_output", "")
|
|
358
|
+
server = tool_call.get("server")
|
|
359
|
+
|
|
360
|
+
action_str = f"{tool_name}("
|
|
361
|
+
params = []
|
|
362
|
+
for k, v in tool_input.items():
|
|
363
|
+
if isinstance(v, str):
|
|
364
|
+
params.append(f'{k}="{v}"')
|
|
365
|
+
else:
|
|
366
|
+
params.append(f'{k}={v}')
|
|
367
|
+
action_str += ", ".join(params) + ")"
|
|
368
|
+
|
|
369
|
+
traj.append_agent_step(
|
|
370
|
+
action=action_str,
|
|
371
|
+
tool_name=tool_name,
|
|
372
|
+
tool_params=tool_input,
|
|
373
|
+
server=server
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
try:
|
|
377
|
+
parsed_output = json.loads(tool_output) if isinstance(tool_output, str) else tool_output
|
|
378
|
+
except json.JSONDecodeError:
|
|
379
|
+
parsed_output = tool_output
|
|
380
|
+
|
|
381
|
+
traj.append_tool_return(
|
|
382
|
+
result=parsed_output,
|
|
383
|
+
tool_name=tool_name,
|
|
384
|
+
server=server
|
|
385
|
+
)
|
|
386
|
+
|
|
387
|
+
tool_call_idx += 1
|
|
388
|
+
|
|
389
|
+
# Add final agent response (if different from turn responses)
|
|
390
|
+
final_output = trace_data.get("final_output")
|
|
391
|
+
if final_output and final_output not in seen_responses:
|
|
392
|
+
traj.append_agent_step(
|
|
393
|
+
action="send_message_to_user",
|
|
394
|
+
metadata={"message": final_output}
|
|
395
|
+
)
|
|
396
|
+
|
|
397
|
+
return traj.to_dict()
|
|
398
|
+
|
|
399
|
+
def process_trace_file(
    self,
    trace_file: str,
    output_name: Optional[str] = None
) -> List[str]:
    """
    Process a trace file and write one trajectory JSON file per trace.

    Each trace returned by ``self.load_traces`` is converted with
    ``self.convert_trace_to_trajectory`` and written (UTF-8, indented JSON)
    into ``self.output_dir``.

    Args:
        trace_file: Path to trace JSONL file
        output_name: Optional prefix for output file names; when empty or
            None the prefix "trajectory" is used.

    Returns:
        List of saved trajectory file paths, in the order the traces were
        yielded by ``load_traces``.
    """
    traces = self.load_traces(trace_file)
    prefix = output_name or "trajectory"
    saved_files = []

    for trace_id, trace_data in traces.items():
        # A trace may carry an explicit "metadata": null; `.get("metadata", {})`
        # would then return None and the chained `.get` would raise.
        metadata = trace_data.get("metadata") or {}
        task_id = metadata.get("task_id")

        trajectory = self.convert_trace_to_trajectory(trace_id, trace_data)

        # Name the file after the task id, falling back to a short trace id.
        filename = f"{prefix}_{task_id or trace_id[:8]}.json"
        output_path = os.path.join(self.output_dir, filename)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(trajectory, f, indent=2, ensure_ascii=False)

        saved_files.append(output_path)

    return saved_files
|
|
433
|
+
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""
|
|
2
|
+
OpenAI SDK Agent Implementation
|
|
3
|
+
|
|
4
|
+
This module provides an agent implementation using the OpenAI Agents SDK
|
|
5
|
+
with trajectory tracking and MCP server support.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from .src.agent import OpenAISDKAgent
|
|
9
|
+
from .src.agent_wrapper import OpenAISDKNativeWrapper
|
|
10
|
+
from .src.utils import OpenAITraceProcessor, OpenAIAgentTrajectory
|
|
11
|
+
|
|
12
|
+
# Names exported by ``from agent.openaisdk import *``; every entry is
# re-exported from the ``.src`` subpackage imports of this module.
__all__ = [
    "OpenAISDKAgent", "OpenAISDKNativeWrapper",
    "OpenAITraceProcessor", "OpenAIAgentTrajectory",
]
|
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import asyncio
|
|
4
|
+
import argparse
|
|
5
|
+
|
|
6
|
+
from dt_arena.src.types.agent import AgentConfig, RuntimeConfig
|
|
7
|
+
from dt_arena.src.types.task import TaskConfig, AttackConfig
|
|
8
|
+
|
|
9
|
+
from agent.openaisdk import OpenAISDKAgent
|
|
10
|
+
from utils.injection_helpers import (
|
|
11
|
+
build_tool_injections_from_config,
|
|
12
|
+
build_skill_injections_from_config,
|
|
13
|
+
apply_prompt_injections,
|
|
14
|
+
)
|
|
15
|
+
from utils.task_helpers import extract_dataset_path
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# ============================================================================
|
|
19
|
+
# Standalone Test Example for OpenAISDKAgent
|
|
20
|
+
# Supports tool injection and prompt injection from attack config
|
|
21
|
+
# ============================================================================
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _build_arg_parser():
    """Return the command-line parser for this example script."""
    parser = argparse.ArgumentParser(
        description="Run OpenAI SDK Agent with configuration file",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Run with default settings
python example.py --config path/to/config.yaml

# Run with custom model and temperature
python example.py --config path/to/config.yaml --model gpt-4o-mini --temperature 0.5

# Run with custom max turns and output directory
python example.py --config path/to/config.yaml --max-turns 20 --output-dir ./custom_results
"""
    )

    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML configuration file (with Task, Agent, Attack sections)"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="gpt-4o",
        help="Model to use (default: gpt-4o)"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.1,
        help="Sampling temperature, 0.0-1.0 (default: 0.1)"
    )
    parser.add_argument(
        "--max-turns",
        type=int,
        default=10,
        help="Maximum conversation turns (default: 10)"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory for traces and trajectories (default: ./results)"
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug mode for tracing"
    )
    return parser


def _print_attack_summary(attack_config):
    """Print the attack's risk category, threat model, goal and per-turn steps."""
    print(f"\n[ATTACK] Risk Category: {attack_config.risk_category}")
    print(f"[ATTACK] Threat Model: {attack_config.threat_model}")
    if attack_config.malicious_goal:
        print(f"[ATTACK] Malicious Goal: {attack_config.malicious_goal}")
    if attack_config.attack_turns:
        print(f"[ATTACK] Attack Turns: {len(attack_config.attack_turns)}")
        for turn in attack_config.attack_turns:
            print(f" Turn {turn.turn_id}: {len(turn.attack_steps)} attack step(s)")
            for step in turn.attack_steps:
                print(f" - {step.type} ({step.mode})")


def _resolve_user_instruction(task_config, attack_config):
    """Return the (possibly injected) user instruction, whitespace-stripped.

    For a "direct" threat model the instruction comes from the attack config,
    falling back to the task's original instruction when none is produced.
    Otherwise the task's original instruction is passed through
    ``apply_prompt_injections``.

    Returns:
        A stripped string, or a list of stripped strings for multi-turn tasks.

    Raises:
        ValueError: if no instruction could be resolved at all.
    """
    is_direct_attack = attack_config and attack_config.threat_model == "direct"

    if is_direct_attack:
        user_instruction = apply_prompt_injections(original_instruction=None, attack_config=attack_config)
        if not user_instruction:
            user_instruction = task_config.original_instruction
    else:
        user_instruction = task_config.original_instruction
        user_instruction = apply_prompt_injections(user_instruction, attack_config)

    # Without this check a missing instruction surfaced as an opaque
    # "'NoneType' object is not iterable" from the list branch below.
    if user_instruction is None:
        raise ValueError("No user instruction found in task or attack configuration")

    if isinstance(user_instruction, str):
        return user_instruction.strip()
    return [instr.strip() for instr in user_instruction]


def _print_instruction(user_instruction):
    """Print the final instruction (single string or numbered multi-turn list)."""
    print("\n" + "=" * 80)
    print("[INSTRUCTION (after injection)]")
    print("=" * 80)
    if isinstance(user_instruction, list):
        print(f"Multi-turn task with {len(user_instruction)} queries:")
        for i, instr in enumerate(user_instruction, 1):
            print(f" {i}. {instr}")
    else:
        print(user_instruction)
    print("=" * 80)


async def main():
    """Run an OpenAISDKAgent on a YAML task/attack config and report results.

    Parses command-line arguments, builds tool/skill injections from the
    attack config, runs the agent (single- or multi-turn) and prints the
    response plus where traces and trajectories were saved.

    Returns:
        The result object returned by ``agent.run``.

    Exits:
        With status 1 when the config file is missing or any step fails.
    """
    args = _build_arg_parser().parse_args()

    # Setup
    config_path = args.config
    if not os.path.exists(config_path):
        print(f"[ERROR] Configuration file not found: {config_path}")
        sys.exit(1)

    # Extract dataset path structure (e.g., crm/benign/1)
    dataset_path = extract_dataset_path(config_path)

    # Build output directory
    base_output_dir = args.output_dir or os.path.join(os.getcwd(), "results")
    output_dir = os.path.join(base_output_dir, "openaisdk", dataset_path)
    os.makedirs(output_dir, exist_ok=True)

    print(f"[INFO] Loading configuration from: {config_path}")
    print(f"[INFO] Model: {args.model}")
    print(f"[INFO] Temperature: {args.temperature}")
    print(f"[INFO] Max Turns: {args.max_turns}")
    print(f"[INFO] Output Directory: {output_dir}")
    print("-" * 80)

    try:
        # Load configuration
        task_config = TaskConfig.from_yaml(config_path)
        agent_config = AgentConfig.from_yaml(config_path)
        attack_config = AttackConfig.from_yaml(config_path)

        # Build tool and skill injections from all attack turns
        mcp_injection = build_tool_injections_from_config(attack_config)
        skill_injection = build_skill_injections_from_config(attack_config)

        runtime_config = RuntimeConfig(
            model=args.model,
            temperature=args.temperature,
            max_turns=args.max_turns,
            output_dir=output_dir,
            mcp_injection=mcp_injection,
            skill_injection=skill_injection,
            debug=args.debug,
        )

        agent = OpenAISDKAgent(
            agent_config=agent_config,
            runtime_config=runtime_config,
        )

        # Display loaded configuration
        print(f"\n[TASK] Task ID: {task_config.task_id}")
        print(f"[TASK] Domain: {task_config.domain}")

        if attack_config:
            _print_attack_summary(attack_config)

        user_instruction = _resolve_user_instruction(task_config, attack_config)
        _print_instruction(user_instruction)

        if attack_config and attack_config.malicious_goal:
            print("\n[MALICIOUS GOAL] Testing Against:")
            print("-" * 80)
            print(attack_config.malicious_goal.strip())
            print("-" * 80)

        # Run agent with async context manager
        async with agent:
            print("\n[INFO] Running agent...")
            if isinstance(user_instruction, list):
                print(f"[INFO] Processing {len(user_instruction)} turns...")
            # Pass task metadata via run() method
            metadata = {
                "task_id": task_config.task_id,
                "domain": task_config.domain,
                "category": attack_config.risk_category if attack_config else None,
                "malicious_goal": attack_config.malicious_goal if attack_config else None,
            }

            result = await agent.run(user_instruction, metadata=metadata)

            print("\n" + "=" * 80)
            print("[AGENT RESPONSE]")
            print("=" * 80)
            print(result.final_output)
            print("=" * 80)

            print("\n[SUCCESS] Task completed")
            print(f"[INFO] Turns: {result.turn_count}")
            print(f"[INFO] Trace ID: {result.trace_id}")
            print(f"[INFO] Traces saved to: {agent.traces_dir}")
            print(f"[INFO] Trajectories saved to: {agent.trajectories_dir}")
            if result.trajectory:
                print(f"[INFO] Trajectory steps: {len(result.trajectory.data['trajectory'])}")

            return result

    except Exception as e:
        print(f"\n[ERROR] {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
if __name__ == "__main__":
    try:
        # main() already prints "[SUCCESS] Task completed" and exits with
        # status 1 on failure, so nothing more is printed here on success.
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\n[INFO] Interrupted by user")
        # Non-zero (128 + SIGINT) so batch callers can tell an interrupted
        # run apart from a successful one.
        sys.exit(130)
    except Exception as e:
        print(f"\n[ERROR] {e}")
        sys.exit(1)
|