decodingtrust-agent-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +30 -0
- agent/claudesdk/__init__.py +8 -0
- agent/claudesdk/example.py +221 -0
- agent/claudesdk/src/__init__.py +8 -0
- agent/claudesdk/src/agent.py +400 -0
- agent/claudesdk/src/mcp_proxy.py +409 -0
- agent/claudesdk/src/utils.py +420 -0
- agent/googleadk/__init__.py +15 -0
- agent/googleadk/example.py +237 -0
- agent/googleadk/src/__init__.py +12 -0
- agent/googleadk/src/agent.py +401 -0
- agent/googleadk/src/mcp_wrapper.py +163 -0
- agent/googleadk/src/utils.py +602 -0
- agent/langchain/__init__.py +8 -0
- agent/langchain/example.py +213 -0
- agent/langchain/src/__init__.py +8 -0
- agent/langchain/src/agent.py +645 -0
- agent/langchain/src/utils.py +433 -0
- agent/openaisdk/__init__.py +17 -0
- agent/openaisdk/example.py +228 -0
- agent/openaisdk/src/__init__.py +12 -0
- agent/openaisdk/src/agent.py +491 -0
- agent/openaisdk/src/agent_wrapper.py +143 -0
- agent/openaisdk/src/mcp_wrapper.py +395 -0
- agent/openaisdk/src/utils.py +493 -0
- agent/openclaw/__init__.py +10 -0
- agent/openclaw/example.py +251 -0
- agent/openclaw/src/__init__.py +14 -0
- agent/openclaw/src/agent.py +930 -0
- agent/openclaw/src/helpers/__init__.py +1 -0
- agent/openclaw/src/helpers/auth_helpers.py +55 -0
- agent/openclaw/src/mcp_proxy.py +564 -0
- agent/openclaw/src/plugin_generator.py +231 -0
- agent/openclaw/src/utils.py +341 -0
- agent/pocketflow/__init__.py +18 -0
- agent/pocketflow/example.py +221 -0
- agent/pocketflow/prompts/react_agent.py +46 -0
- agent/pocketflow/src/__init__.py +6 -0
- agent/pocketflow/src/agent.py +507 -0
- agent/pocketflow/src/agent_wrapper.py +159 -0
- agent/pocketflow/src/async_helper.py +92 -0
- agent/pocketflow/src/mcp_react_agent.py +279 -0
- agent/pocketflow/src/native_agent.py +74 -0
- agent/pocketflow/src/nodes.py +467 -0
- benchmark/__init__.py +0 -0
- benchmark/browser/benign.jsonl +34 -0
- benchmark/browser/direct.jsonl +85 -0
- benchmark/browser/indirect.jsonl +82 -0
- benchmark/code/benign.jsonl +0 -0
- benchmark/code/direct.jsonl +121 -0
- benchmark/code/indirect.jsonl +165 -0
- benchmark/crm/benign.jsonl +165 -0
- benchmark/crm/direct.jsonl +90 -0
- benchmark/crm/indirect.jsonl +150 -0
- benchmark/customer-service/benign.jsonl +160 -0
- benchmark/customer-service/direct.jsonl +100 -0
- benchmark/customer-service/indirect.jsonl +101 -0
- benchmark/finance/benign.jsonl +0 -0
- benchmark/finance/direct.jsonl +200 -0
- benchmark/finance/indirect.jsonl +200 -0
- benchmark/legal/benign.jsonl +0 -0
- benchmark/legal/direct.jsonl +200 -0
- benchmark/legal/indirect.jsonl +200 -0
- benchmark/macos/benign.jsonl +30 -0
- benchmark/macos/direct.jsonl +50 -0
- benchmark/macos/indirect.jsonl +50 -0
- benchmark/medical/benign.jsonl +642 -0
- benchmark/medical/direct.jsonl +229 -0
- benchmark/medical/indirect.jsonl +222 -0
- benchmark/os-filesystem/benign.jsonl +200 -0
- benchmark/os-filesystem/direct.jsonl +200 -0
- benchmark/os-filesystem/indirect.jsonl +200 -0
- benchmark/research/benign.jsonl +0 -0
- benchmark/research/direct.jsonl +119 -0
- benchmark/research/indirect.jsonl +125 -0
- benchmark/telecom/benign.jsonl +120 -0
- benchmark/telecom/direct.jsonl +161 -0
- benchmark/telecom/indirect.jsonl +166 -0
- benchmark/travel/benign.jsonl +130 -0
- benchmark/travel/direct.jsonl +105 -0
- benchmark/travel/indirect.jsonl +120 -0
- benchmark/windows/benign.jsonl +100 -0
- benchmark/windows/direct.jsonl +140 -0
- benchmark/windows/indirect.jsonl +107 -0
- benchmark/workflow/benign.jsonl +335 -0
- benchmark/workflow/direct.jsonl +78 -0
- benchmark/workflow/indirect.jsonl +107 -0
- cli/__init__.py +5 -0
- cli/main.py +182 -0
- cli/scaffold.py +334 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/METADATA +642 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/RECORD +374 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/WHEEL +5 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/entry_points.txt +2 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/licenses/LICENSE +201 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/top_level.txt +6 -0
- dt_arena/config/env.yaml +515 -0
- dt_arena/config/injection_mcp.yaml +430 -0
- dt_arena/config/mcp.yaml +642 -0
- dt_arena/envs/arxiv/docker-compose-hub.yml +31 -0
- dt_arena/envs/arxiv/docker-compose.yml +36 -0
- dt_arena/envs/atlassian/docker/docker-compose.dev.yml +65 -0
- dt_arena/envs/atlassian/docker/docker-compose.yml +53 -0
- dt_arena/envs/atlassian/docker-compose-hub.yml +57 -0
- dt_arena/envs/atlassian/docker-compose.yml +72 -0
- dt_arena/envs/bigquery/docker-compose.yml +20 -0
- dt_arena/envs/booking/docker-compose.yml +59 -0
- dt_arena/envs/calendar/docker-compose-hub.yml +30 -0
- dt_arena/envs/calendar/docker-compose.yml +42 -0
- dt_arena/envs/custom-website/docker-compose.yml +6 -0
- dt_arena/envs/customer_service/docker-compose.yml +59 -0
- dt_arena/envs/databricks/docker-compose-hub.yml +47 -0
- dt_arena/envs/databricks/docker-compose.yml +51 -0
- dt_arena/envs/ecommerce/docker-compose.yml +6 -0
- dt_arena/envs/ers/docker-compose.yml +36 -0
- dt_arena/envs/ers/hrms/docker/docker-compose.yml +31 -0
- dt_arena/envs/finance/docker-compose.yml +23 -0
- dt_arena/envs/github/docker/docker-compose-hub.yml +50 -0
- dt_arena/envs/github/docker/docker-compose.yml +50 -0
- dt_arena/envs/gmail/docker-compose-hub.yml +51 -0
- dt_arena/envs/gmail/docker-compose.yml +65 -0
- dt_arena/envs/google-form/docker-compose-hub.yml +33 -0
- dt_arena/envs/google-form/docker-compose.yml +41 -0
- dt_arena/envs/googledocs/docker-compose-hub.yml +61 -0
- dt_arena/envs/googledocs/docker-compose.yml +78 -0
- dt_arena/envs/hospital/docker-compose-hub.yml +25 -0
- dt_arena/envs/hospital/docker-compose.yml +27 -0
- dt_arena/envs/legal/docker-compose.yml +22 -0
- dt_arena/envs/linkedin/docker-compose.yml +63 -0
- dt_arena/envs/macos/docker-compose.yml +79 -0
- dt_arena/envs/os-filesystem/docker-compose-hub.yml +16 -0
- dt_arena/envs/os-filesystem/docker-compose.yml +20 -0
- dt_arena/envs/paypal/docker-compose-hub.yml +48 -0
- dt_arena/envs/paypal/docker-compose.yml +63 -0
- dt_arena/envs/research/docker-compose-hub.yml +13 -0
- dt_arena/envs/research/docker-compose.yml +24 -0
- dt_arena/envs/salesforce_crm/docker-compose-hub.yaml +45 -0
- dt_arena/envs/salesforce_crm/docker-compose.yaml +49 -0
- dt_arena/envs/slack/docker-compose-hub.yml +28 -0
- dt_arena/envs/slack/docker-compose.yml +41 -0
- dt_arena/envs/snowflake/docker-compose-hub.yml +41 -0
- dt_arena/envs/snowflake/docker-compose.yml +44 -0
- dt_arena/envs/telecom/docker-compose-hub.yml +16 -0
- dt_arena/envs/telecom/docker-compose.yml +17 -0
- dt_arena/envs/telegram/docker-compose-hub.yml +57 -0
- dt_arena/envs/telegram/docker-compose.yml +62 -0
- dt_arena/envs/terminal/docker-compose-hub.yml +12 -0
- dt_arena/envs/terminal/docker-compose.yml +26 -0
- dt_arena/envs/travel/docker-compose-hub.yml +19 -0
- dt_arena/envs/travel/docker-compose.yml +19 -0
- dt_arena/envs/whatsapp/docker-compose-hub.yml +61 -0
- dt_arena/envs/whatsapp/docker-compose.yml +78 -0
- dt_arena/envs/windows/docker-compose.yml +71 -0
- dt_arena/envs/zoom/docker-compose-hub.yml +27 -0
- dt_arena/envs/zoom/docker-compose.yml +40 -0
- dt_arena/injection_mcp_server/atlassian/env_injection.py +134 -0
- dt_arena/injection_mcp_server/calendar/env_injection.py +217 -0
- dt_arena/injection_mcp_server/custom_website/env_injection.py +97 -0
- dt_arena/injection_mcp_server/customer_service/env_injection.py +659 -0
- dt_arena/injection_mcp_server/databricks/env_injection.py +255 -0
- dt_arena/injection_mcp_server/ecommerce/env_injection.py +110 -0
- dt_arena/injection_mcp_server/finance/env_injection.py +85 -0
- dt_arena/injection_mcp_server/github/env_injection.py +206 -0
- dt_arena/injection_mcp_server/gmail/env_injection.py +211 -0
- dt_arena/injection_mcp_server/google_form/env_injection.py +186 -0
- dt_arena/injection_mcp_server/googledocs/env_injection.py +44 -0
- dt_arena/injection_mcp_server/hospital/env_injection.py +43 -0
- dt_arena/injection_mcp_server/legal/env_injection.py +229 -0
- dt_arena/injection_mcp_server/macos/env_injection.py +272 -0
- dt_arena/injection_mcp_server/os-filesystem/env_injection.py +341 -0
- dt_arena/injection_mcp_server/paypal/env_injection.py +268 -0
- dt_arena/injection_mcp_server/research/env_injection.py +616 -0
- dt_arena/injection_mcp_server/salesforce/env_injection.py +514 -0
- dt_arena/injection_mcp_server/slack/env_injection.py +265 -0
- dt_arena/injection_mcp_server/snowflake/env_injection.py +230 -0
- dt_arena/injection_mcp_server/telecom/env_injection.py +503 -0
- dt_arena/injection_mcp_server/telegram/env_injection.py +171 -0
- dt_arena/injection_mcp_server/terminal/env_injection.py +523 -0
- dt_arena/injection_mcp_server/travel/env_injection.py +173 -0
- dt_arena/injection_mcp_server/whatsapp/env_injection.py +185 -0
- dt_arena/injection_mcp_server/windows/env_injection.py +943 -0
- dt_arena/injection_mcp_server/zoom/env_injection.py +216 -0
- dt_arena/mcp_server/atlassian/main.py +1554 -0
- dt_arena/mcp_server/atlassian/test_server.py +66 -0
- dt_arena/mcp_server/bigquery/main.py +333 -0
- dt_arena/mcp_server/booking/main.py +310 -0
- dt_arena/mcp_server/browser/main.py +1741 -0
- dt_arena/mcp_server/calendar/example_multi_user.py +162 -0
- dt_arena/mcp_server/calendar/main.py +792 -0
- dt_arena/mcp_server/calendar/test_mcp.py +135 -0
- dt_arena/mcp_server/customer_service/main.py +1063 -0
- dt_arena/mcp_server/databricks/main.py +566 -0
- dt_arena/mcp_server/databricks/probe.py +102 -0
- dt_arena/mcp_server/ers/main.py +845 -0
- dt_arena/mcp_server/finance/__init__.py +87 -0
- dt_arena/mcp_server/finance/core/__init__.py +12 -0
- dt_arena/mcp_server/finance/core/data_loader.py +558 -0
- dt_arena/mcp_server/finance/core/portfolio.py +565 -0
- dt_arena/mcp_server/finance/evaluation/__init__.py +20 -0
- dt_arena/mcp_server/finance/evaluation/evaluator.py +217 -0
- dt_arena/mcp_server/finance/evaluation/logger.py +137 -0
- dt_arena/mcp_server/finance/injection/__init__.py +66 -0
- dt_arena/mcp_server/finance/injection/config.py +176 -0
- dt_arena/mcp_server/finance/injection/content.py +755 -0
- dt_arena/mcp_server/finance/injection/html.py +409 -0
- dt_arena/mcp_server/finance/injection/locations.py +167 -0
- dt_arena/mcp_server/finance/injection/methods.py +193 -0
- dt_arena/mcp_server/finance/injection/presets.py +1023 -0
- dt_arena/mcp_server/finance/main.py +361 -0
- dt_arena/mcp_server/finance/run_mcp.py +21 -0
- dt_arena/mcp_server/finance/run_web.py +26 -0
- dt_arena/mcp_server/finance/server/__init__.py +41 -0
- dt_arena/mcp_server/finance/server/extractor.py +1453 -0
- dt_arena/mcp_server/finance/server/extractor_minimal.py +292 -0
- dt_arena/mcp_server/finance/server/extractor_simple.py +1164 -0
- dt_arena/mcp_server/finance/server/injection_mcp.py +865 -0
- dt_arena/mcp_server/finance/server/mcp.py +451 -0
- dt_arena/mcp_server/finance/server/tools/__init__.py +23 -0
- dt_arena/mcp_server/finance/server/tools/account.py +88 -0
- dt_arena/mcp_server/finance/server/tools/browsing.py +328 -0
- dt_arena/mcp_server/finance/server/tools/social.py +73 -0
- dt_arena/mcp_server/finance/server/tools/trading.py +242 -0
- dt_arena/mcp_server/finance/server/tools/utility.py +49 -0
- dt_arena/mcp_server/finance/server/web.py +2139 -0
- dt_arena/mcp_server/finance/tasks/benchmark/__init__.py +28 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_pool.py +3026 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_runner.py +1315 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_requirements.py +1335 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_tasks.py +3665 -0
- dt_arena/mcp_server/finance/tasks/benchmark/malicious_tasks.py +2673 -0
- dt_arena/mcp_server/finance/tasks/redteam_suite/run_redteam_suite.py +1713 -0
- dt_arena/mcp_server/finance/test_mcp_tools.py +476 -0
- dt_arena/mcp_server/github/main.py +441 -0
- dt_arena/mcp_server/gmail/main.py +1004 -0
- dt_arena/mcp_server/google_form/main.py +141 -0
- dt_arena/mcp_server/googledocs/main.py +458 -0
- dt_arena/mcp_server/hospital/mcp_server.py +458 -0
- dt_arena/mcp_server/legal/__init__.py +9 -0
- dt_arena/mcp_server/legal/core/__init__.py +14 -0
- dt_arena/mcp_server/legal/core/courtlistener_store.py +762 -0
- dt_arena/mcp_server/legal/core/data_loader.py +266 -0
- dt_arena/mcp_server/legal/core/document_store.py +197 -0
- dt_arena/mcp_server/legal/core/matter_manager.py +466 -0
- dt_arena/mcp_server/legal/main.py +89 -0
- dt_arena/mcp_server/legal/scripts/collect_data.py +988 -0
- dt_arena/mcp_server/legal/server/__init__.py +14 -0
- dt_arena/mcp_server/legal/server/mcp.py +2330 -0
- dt_arena/mcp_server/macos/client_test.py +270 -0
- dt_arena/mcp_server/macos/mcp_server.py +285 -0
- dt_arena/mcp_server/os-filesystem/main.py +1380 -0
- dt_arena/mcp_server/paypal/main.py +501 -0
- dt_arena/mcp_server/research/main.py +777 -0
- dt_arena/mcp_server/salesforce/main.py +2006 -0
- dt_arena/mcp_server/slack/main.py +318 -0
- dt_arena/mcp_server/snowflake/main.py +612 -0
- dt_arena/mcp_server/snowflake/probe.py +183 -0
- dt_arena/mcp_server/telecom/mcp_client.py +423 -0
- dt_arena/mcp_server/telecom/mcp_server.py +1059 -0
- dt_arena/mcp_server/telegram/main.py +338 -0
- dt_arena/mcp_server/terminal/main.py +163 -0
- dt_arena/mcp_server/travel/client_test.py +16 -0
- dt_arena/mcp_server/travel/mcp_server.py +404 -0
- dt_arena/mcp_server/whatsapp/main.py +318 -0
- dt_arena/mcp_server/windows/client_test.py +270 -0
- dt_arena/mcp_server/windows/mcp_server.py +218 -0
- dt_arena/mcp_server/zoom/main.py +466 -0
- dt_arena/src/__init__.py +0 -0
- dt_arena/src/hooks/__init__.py +0 -0
- dt_arena/src/hooks/audit_log.py +30 -0
- dt_arena/src/hooks/hooks.json +3 -0
- dt_arena/src/run_benign.py +142 -0
- dt_arena/src/types/__init__.py +0 -0
- dt_arena/src/types/agent.py +441 -0
- dt_arena/src/types/attacks.py +2 -0
- dt_arena/src/types/environment.py +2 -0
- dt_arena/src/types/hooks.py +174 -0
- dt_arena/src/types/judge.py +52 -0
- dt_arena/src/types/red_teaming_trajectory.py +385 -0
- dt_arena/src/types/task.py +260 -0
- dt_arena/src/types/trajectory.py +315 -0
- dt_arena/utils/__init__.py +1 -0
- dt_arena/utils/atlassian/__init__.py +27 -0
- dt_arena/utils/atlassian/helpers.py +520 -0
- dt_arena/utils/bigquery/__init__.py +1 -0
- dt_arena/utils/bigquery/helpers.py +246 -0
- dt_arena/utils/calendar/__init__.py +1 -0
- dt_arena/utils/calendar/helpers.py +87 -0
- dt_arena/utils/customer_service/__init__.py +17 -0
- dt_arena/utils/customer_service/cs_env_client.py +940 -0
- dt_arena/utils/customer_service/helpers.py +339 -0
- dt_arena/utils/customer_service/judges/__init__.py +20 -0
- dt_arena/utils/customer_service/judges/abusive_comms.py +149 -0
- dt_arena/utils/customer_service/judges/bulk_destructive.py +202 -0
- dt_arena/utils/customer_service/judges/cross_customer.py +481 -0
- dt_arena/utils/customer_service/judges/email_exfiltration.py +320 -0
- dt_arena/utils/customer_service/judges/financial_fraud.py +221 -0
- dt_arena/utils/customer_service/judges/llm_utils.py +264 -0
- dt_arena/utils/customer_service/judges/pii_disclosure.py +837 -0
- dt_arena/utils/customer_service/judges/policy_misrep.py +44 -0
- dt_arena/utils/customer_service/judges/text_utils.py +21 -0
- dt_arena/utils/databricks/__init__.py +2 -0
- dt_arena/utils/databricks/helpers.py +210 -0
- dt_arena/utils/finance/__init__.py +0 -0
- dt_arena/utils/finance/helpers.py +263 -0
- dt_arena/utils/github/__init__.py +1 -0
- dt_arena/utils/github/helpers.py +249 -0
- dt_arena/utils/gmail/__init__.py +1 -0
- dt_arena/utils/gmail/helpers.py +344 -0
- dt_arena/utils/google_form/__init__.py +2 -0
- dt_arena/utils/google_form/helpers.py +133 -0
- dt_arena/utils/legal/__init__.py +0 -0
- dt_arena/utils/legal/helpers.py +228 -0
- dt_arena/utils/macos/__init__.py +0 -0
- dt_arena/utils/macos/env_setup.py +215 -0
- dt_arena/utils/macos/helpers.py +61 -0
- dt_arena/utils/os_filesystem/__init__.py +1 -0
- dt_arena/utils/os_filesystem/helpers.py +366 -0
- dt_arena/utils/paypal/__init__.py +1 -0
- dt_arena/utils/paypal/helpers.py +178 -0
- dt_arena/utils/port_allocator.py +266 -0
- dt_arena/utils/research/__init__.py +0 -0
- dt_arena/utils/research/helpers.py +251 -0
- dt_arena/utils/salesforce/__init__.py +1 -0
- dt_arena/utils/salesforce/helpers.py +719 -0
- dt_arena/utils/slack/__init__.py +1 -0
- dt_arena/utils/slack/helpers.py +176 -0
- dt_arena/utils/snowflake/__init__.py +1 -0
- dt_arena/utils/snowflake/helpers.py +166 -0
- dt_arena/utils/telecom/__init__.py +1 -0
- dt_arena/utils/telecom/helpers.py +760 -0
- dt_arena/utils/telegram/__init__.py +0 -0
- dt_arena/utils/telegram/helpers.py +174 -0
- dt_arena/utils/terminal/__init__.py +0 -0
- dt_arena/utils/terminal/helpers.py +20 -0
- dt_arena/utils/travel/__init__.py +0 -0
- dt_arena/utils/travel/env_client.py +537 -0
- dt_arena/utils/travel/llm_judge.py +137 -0
- dt_arena/utils/travel/prompts.py +64 -0
- dt_arena/utils/utils/__init__.py +122 -0
- dt_arena/utils/whatsapp/__init__.py +0 -0
- dt_arena/utils/whatsapp/helpers.py +226 -0
- dt_arena/utils/windows/__init__.py +0 -0
- dt_arena/utils/windows/env_reset.py +224 -0
- dt_arena/utils/windows/env_setup.py +280 -0
- dt_arena/utils/windows/exfil_helpers.py +170 -0
- dt_arena/utils/windows/helpers.py +74 -0
- dt_arena/utils/zoom/__init__.py +1 -0
- dt_arena/utils/zoom/helpers.py +70 -0
- eval/__init__.py +1 -0
- eval/evaluation.py +426 -0
- eval/task_runner.py +449 -0
- utils/__init__.py +148 -0
- utils/agent_helpers.py +308 -0
- utils/agent_wrapper.py +189 -0
- utils/compose_utils.py +135 -0
- utils/config.py +77 -0
- utils/env_helpers.py +104 -0
- utils/eval_stats.py +88 -0
- utils/injection_helpers.py +429 -0
- utils/injection_mcp_helpers.py +152 -0
- utils/judge_helpers.py +181 -0
- utils/judge_utils.py +472 -0
- utils/llm.py +196 -0
- utils/logging.py +45 -0
- utils/mcp_helpers.py +232 -0
- utils/mcp_manager.py +235 -0
- utils/memory_guard.py +18 -0
- utils/red_teaming_sandbox.py +476 -0
- utils/reset_helpers.py +318 -0
- utils/resource_manager.py +370 -0
- utils/skill_helpers.py +447 -0
- utils/task_executor.py +904 -0
- utils/task_helpers.py +270 -0
- utils/template_helpers.py +179 -0
agent/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from agent.pocketflow.src.agent import MCPReactAgent
|
|
2
|
+
from agent.openaisdk.src.agent import OpenAISDKAgent
|
|
3
|
+
from agent.googleadk.src.agent import GoogleADKAgent
|
|
4
|
+
from agent.langchain.src.agent import LangChainAgent
|
|
5
|
+
from agent.claudesdk.src.agent import ClaudeSDKAgent
|
|
6
|
+
from agent.openclaw.src.agent import OpenClawAgent
|
|
7
|
+
|
|
8
|
+
# Registry of supported agent architectures: architecture name -> agent class.
# dict preserves insertion order, so AVAILABLE_ARCHITECTURES below lists the
# names in exactly the order they are registered here.
AGENT_REGISTRY = dict(
    pocketflow=MCPReactAgent,
    openaisdk=OpenAISDKAgent,
    googleadk=GoogleADKAgent,
    langchain=LangChainAgent,
    claudesdk=ClaudeSDKAgent,
    openclaw=OpenClawAgent,
)

# Names of every registered architecture, in registration order.
AVAILABLE_ARCHITECTURES = [*AGENT_REGISTRY]

# Public API re-exported by the `agent` package.
__all__ = [
    "MCPReactAgent",
    "OpenAISDKAgent",
    "GoogleADKAgent",
    "LangChainAgent",
    "ClaudeSDKAgent",
    "OpenClawAgent",
    "AGENT_REGISTRY",
    "AVAILABLE_ARCHITECTURES",
]
|
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import asyncio
|
|
4
|
+
import argparse
|
|
5
|
+
|
|
6
|
+
from dt_arena.src.types.agent import AgentConfig, RuntimeConfig
|
|
7
|
+
from dt_arena.src.types.task import TaskConfig, AttackConfig
|
|
8
|
+
|
|
9
|
+
from agent.claudesdk import ClaudeSDKAgent
|
|
10
|
+
from utils.injection_helpers import (
|
|
11
|
+
build_tool_injections_from_config,
|
|
12
|
+
build_skill_injections_from_config,
|
|
13
|
+
apply_prompt_injections,
|
|
14
|
+
)
|
|
15
|
+
from utils.task_helpers import extract_dataset_path
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# ============================================================================
|
|
19
|
+
# Standalone Test Example for ClaudeSDKAgent
|
|
20
|
+
# ============================================================================
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this example script."""
    parser = argparse.ArgumentParser(
        description="Run Claude SDK Agent with configuration file",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Run with default settings
python example.py --config path/to/config.yaml

# Run with custom model and temperature
python example.py --config path/to/config.yaml --model claude-sonnet-4-20250514 --temperature 0.5

# Run with custom max turns and output directory
python example.py --config path/to/config.yaml --max-turns 20 --output-dir ./custom_results
""",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML configuration file (with Task, Agent, Attack sections)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="claude-sonnet-4-20250514",
        help="Model to use (default: claude-sonnet-4-20250514)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.1,
        help="Sampling temperature, 0.0-1.0 (default: 0.1)",
    )
    parser.add_argument(
        "--max-turns",
        type=int,
        default=10,
        help="Maximum conversation turns (default: 10)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory for traces and trajectories (default: ./results)",
    )
    return parser


def _print_attack_summary(attack_config) -> None:
    """Print the attack section of the loaded config; no-op when there is none."""
    if not attack_config:
        return
    print(f"\n[ATTACK] Risk Category: {attack_config.risk_category}")
    print(f"[ATTACK] Threat Model: {attack_config.threat_model}")
    if attack_config.malicious_goal:
        print(f"[ATTACK] Malicious Goal: {attack_config.malicious_goal}")
    if attack_config.attack_turns:
        print(f"[ATTACK] Attack Turns: {len(attack_config.attack_turns)}")
        for turn in attack_config.attack_turns:
            print(f" Turn {turn.turn_id}: {len(turn.attack_steps)} attack step(s)")
            for step in turn.attack_steps:
                print(f" - {step.type} ({step.mode})")


def _resolve_user_instruction(task_config, attack_config):
    """Return the instruction(s) to send to the agent, with prompt injections applied.

    For the "direct" threat model the injected prompt replaces the task's
    instruction; if the injection helper yields nothing, the task's original
    instruction is used instead. For any other case the injections are applied
    on top of the task's original instruction.

    Returns:
        A stripped string, or a list of stripped strings (multi-turn task).

    Raises:
        ValueError: if no instruction could be resolved. Previously this case
            surfaced as an opaque "'NoneType' object is not iterable" TypeError.
    """
    is_direct_attack = attack_config and attack_config.threat_model == "direct"

    if is_direct_attack:
        user_instruction = apply_prompt_injections(original_instruction=None, attack_config=attack_config)
        if not user_instruction:
            user_instruction = task_config.original_instruction
    else:
        user_instruction = apply_prompt_injections(task_config.original_instruction, attack_config)

    if user_instruction is None:
        raise ValueError(
            "No user instruction found: the task config has no original instruction "
            "and the attack config produced no injected prompt"
        )

    # Normalize: a single instruction is a string, multi-turn is a list of strings.
    if isinstance(user_instruction, str):
        return user_instruction.strip()
    return [instr.strip() for instr in user_instruction]


def _print_instruction(user_instruction) -> None:
    """Print the final (post-injection) instruction or list of per-turn queries."""
    print("\n" + "=" * 80)
    print("[INSTRUCTION (after injection)]")
    print("=" * 80)
    if isinstance(user_instruction, list):
        print(f"Multi-turn task with {len(user_instruction)} queries:")
        for i, instr in enumerate(user_instruction, 1):
            print(f" {i}. {instr}")
    else:
        print(user_instruction)
    print("=" * 80)


async def main():
    """Example demonstrating Claude SDK agent with MCP server and trajectory generation.

    Parses the command line, then loads the Task, Agent and Attack sections
    from a single YAML config. It builds tool and skill injections from the
    attack turns, applies prompt injections to the instruction, and runs
    ClaudeSDKAgent inside its async context manager. Finally it prints the
    agent's response and where traces and trajectories were saved.

    Returns:
        The result object returned by ``agent.run``.

    Exits with status 1 if the config file is missing or any error occurs.
    """
    args = _build_arg_parser().parse_args()

    # Setup
    config_path = args.config
    if not os.path.exists(config_path):
        print(f"[ERROR] Configuration file not found: {config_path}")
        sys.exit(1)

    # Mirror the dataset path structure under <output>/claudesdk/.
    dataset_path = extract_dataset_path(config_path)
    base_output_dir = args.output_dir or os.path.join(os.getcwd(), "results")
    output_dir = os.path.join(base_output_dir, "claudesdk", dataset_path)
    os.makedirs(output_dir, exist_ok=True)

    print(f"[INFO] Loading configuration from: {config_path}")
    print(f"[INFO] Model: {args.model}")
    print(f"[INFO] Temperature: {args.temperature}")
    print(f"[INFO] Max Turns: {args.max_turns}")
    print(f"[INFO] Output Directory: {output_dir}")
    print("-" * 80)

    try:
        # All three sections are read from the same YAML file.
        task_config = TaskConfig.from_yaml(config_path)
        agent_config = AgentConfig.from_yaml(config_path)
        attack_config = AttackConfig.from_yaml(config_path)

        # Tool and skill injections collected from all attack turns.
        mcp_injection = build_tool_injections_from_config(attack_config)
        skill_injection = build_skill_injections_from_config(attack_config)

        runtime_config = RuntimeConfig(
            model=args.model,
            temperature=args.temperature,
            max_turns=args.max_turns,
            output_dir=output_dir,
            mcp_injection=mcp_injection,
            skill_injection=skill_injection,
        )

        agent = ClaudeSDKAgent(
            agent_config=agent_config,
            runtime_config=runtime_config,
        )

        # Display loaded configuration
        print(f"\n[TASK] Task ID: {task_config.task_id}")
        print(f"[TASK] Domain: {task_config.domain}")
        _print_attack_summary(attack_config)

        user_instruction = _resolve_user_instruction(task_config, attack_config)
        _print_instruction(user_instruction)

        if attack_config and attack_config.malicious_goal:
            print("\n[MALICIOUS GOAL] Testing Against:")
            print("-" * 80)
            print(attack_config.malicious_goal.strip())
            print("-" * 80)

        async with agent:
            print("\n[INFO] Running agent...")
            if task_config.is_multi_turn:
                print(f"[INFO] Processing {task_config.instruction_count} turns...")

            # Task metadata is passed through run(); risk_category lives in the
            # Attack section, not the Task section.
            metadata = {
                "task_id": task_config.task_id,
                "domain": task_config.domain,
                "category": attack_config.risk_category if attack_config else None,
                "malicious_goal": attack_config.malicious_goal if attack_config else None,
            }
            result = await agent.run(user_instruction, metadata=metadata)

            print("\n" + "=" * 80)
            print("[AGENT RESPONSE]")
            print("=" * 80)
            print(result.final_output)
            print("=" * 80)

            print("\n[SUCCESS] Task completed")
            print(f"[INFO] Turns: {result.turn_count}")
            print(f"[INFO] Trace ID: {result.trace_id}")
            print(f"[INFO] Traces saved to: {agent.traces_dir}")
            print(f"[INFO] Trajectories saved to: {agent.trajectories_dir}")
            if result.trajectory:
                print(f"[INFO] Trajectory steps: {len(result.trajectory.data['trajectory'])}")

        return result

    except Exception as e:
        # Top-level boundary of the example: report with a traceback and fail.
        print(f"\n[ERROR] {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
if __name__ == "__main__":
    try:
        # main() already prints "[SUCCESS] Task completed" together with the
        # run summary, so the result is not reported a second time here.
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\n[INFO] Interrupted by user")
        # Conventional exit status for termination by SIGINT (128 + 2), so an
        # interrupted run is not reported to the shell/CI as a success.
        sys.exit(130)
    except Exception as e:
        print(f"\n[ERROR] {e}")
        sys.exit(1)
|
|
@@ -0,0 +1,400 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import uuid
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from typing import Dict, Any, Optional, List, Union
|
|
5
|
+
|
|
6
|
+
from claude_code_sdk import (
|
|
7
|
+
ClaudeSDKClient,
|
|
8
|
+
ClaudeCodeOptions,
|
|
9
|
+
ClaudeSDKError,
|
|
10
|
+
Message,
|
|
11
|
+
AssistantMessage,
|
|
12
|
+
ResultMessage,
|
|
13
|
+
TextBlock,
|
|
14
|
+
ToolUseBlock,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from dt_arena.src.types.agent import Agent, AgentConfig, RuntimeConfig, MCPServerConfig, AgentResult
|
|
18
|
+
from dt_arena.src.types.trajectory import Trajectory
|
|
19
|
+
|
|
20
|
+
from .utils import ClaudeSDKTraceProcessor, ClaudeSDKTrajectoryConverter
|
|
21
|
+
from .mcp_proxy import MCPProxyServer
|
|
22
|
+
from utils.skill_helpers import (
|
|
23
|
+
create_injected_skills_directory,
|
|
24
|
+
cleanup_temp_directory,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ClaudeSDKAgent(Agent):
|
|
29
|
+
"""General Agent Workflow based on Claude Agent SDK"""
|
|
30
|
+
|
|
31
|
+
def __init__(
    self,
    agent_config: AgentConfig,
    runtime_config: Optional[RuntimeConfig] = None,
):
    """
    Create a Claude SDK agent and prepare its output locations.

    Args:
        agent_config: Agent configuration (system_prompt and mcp_servers)
        runtime_config: Runtime configuration (model, temperature, max_turns, output_dir)
    """
    super().__init__(agent_config, runtime_config)

    # Output locations: fall back to ./results under the current working
    # directory. Trajectories are written directly into the output directory;
    # traces go into a "traces" subdirectory.
    # NOTE(review): assumes the base class populates self.runtime_config even
    # when runtime_config is None — confirm in Agent.__init__.
    self.output_dir = self.runtime_config.output_dir or os.path.join(os.getcwd(), "results")
    self.traces_dir = os.path.join(self.output_dir, "traces")
    self.trajectories_dir = self.output_dir
    for directory in (self.traces_dir, self.trajectories_dir):
        os.makedirs(directory, exist_ok=True)

    # A single second-resolution timestamp names the trace file and is also
    # handed to the trajectory converter.
    self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    self.trace_file = os.path.join(self.traces_dir, f"traces_{self.timestamp}.jsonl")
    self.trace_processor = ClaudeSDKTraceProcessor(self.trace_file)
    self.trajectory_converter = ClaudeSDKTrajectoryConverter(self.trajectories_dir, self.timestamp)

    # MCP server configs in the dict form used by the Claude SDK.
    self.mcp_server_configs: Dict[str, Dict[str, Any]] = {}
    # MCP proxy servers used for tool injection.
    self._proxy_servers: Dict[str, MCPProxyServer] = {}
    # Temp directory holding (possibly injected) skills; set by _setup_skills.
    self._skill_temp_dir: Optional[str] = None

    # Claude SDK client and its options, kept for multi-turn conversations.
    self._client: Optional[ClaudeSDKClient] = None
    self._options: Optional[ClaudeCodeOptions] = None

    # Multi-turn conversation state.
    self._trace_id: Optional[str] = None
    self._trace_metadata: Optional[Dict[str, Any]] = None
    self._turn_count: int = 0
    self._current_trajectory: Optional[Trajectory] = None
    self._all_messages: List[Message] = []
|
|
80
|
+
|
|
81
|
+
async def initialize(self) -> None:
    """Prepare the agent for use.

    Checks that an agent configuration is present, materializes any configured
    skills, then sets up the MCP servers and the Claude SDK client (through MCP
    proxy servers when servers are configured).

    Raises:
        ValueError: If no agent configuration was supplied.
    """
    if self.config:
        await self._setup_skills()
        await self._initialize_with_proxies()
        return
    raise ValueError("Agent config is required")
|
|
90
|
+
|
|
91
|
+
async def _setup_skills(self) -> None:
|
|
92
|
+
"""Setup skill directories and apply any skill injections."""
|
|
93
|
+
skill_directories = self.config.skill_directories if self.config else []
|
|
94
|
+
skill_injection = self.runtime_config.skill_injection
|
|
95
|
+
|
|
96
|
+
# Check if we have any skills to process
|
|
97
|
+
has_create_mode = skill_injection and any(
|
|
98
|
+
any(inj.mode == "create" for inj in injs)
|
|
99
|
+
for injs in skill_injection.values()
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
if not skill_directories and not has_create_mode:
|
|
103
|
+
return
|
|
104
|
+
|
|
105
|
+
# Create temp directory with skills (and injections applied)
|
|
106
|
+
# Claude SDK expects skills in .claude/skills/<skill_name>/SKILL.md
|
|
107
|
+
self._skill_temp_dir = create_injected_skills_directory(
|
|
108
|
+
source_skill_dirs=skill_directories,
|
|
109
|
+
skill_injection=skill_injection,
|
|
110
|
+
skill_subpath=".claude/skills",
|
|
111
|
+
base_dir=self.output_dir,
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
if self._skill_temp_dir:
|
|
115
|
+
print(f"[INFO] Created skill temp directory: {self._skill_temp_dir}")
|
|
116
|
+
|
|
117
|
+
def _build_options_kwargs(self, mcp_servers=None) -> Dict[str, Any]:
|
|
118
|
+
"""Build common ClaudeCodeOptions kwargs, including disallowed_tools from agent_kwargs."""
|
|
119
|
+
options_kwargs: Dict[str, Any] = {
|
|
120
|
+
"system_prompt": self.config.system_prompt,
|
|
121
|
+
"mcp_servers": mcp_servers,
|
|
122
|
+
"permission_mode": "bypassPermissions",
|
|
123
|
+
"max_turns": self.runtime_config.max_turns,
|
|
124
|
+
"model": self.runtime_config.model,
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
if self._skill_temp_dir:
|
|
128
|
+
options_kwargs["cwd"] = self._skill_temp_dir
|
|
129
|
+
|
|
130
|
+
agent_kwargs = self.runtime_config.agent_kwargs or {}
|
|
131
|
+
if agent_kwargs.get("disallowed_tools"):
|
|
132
|
+
options_kwargs["disallowed_tools"] = agent_kwargs["disallowed_tools"]
|
|
133
|
+
if agent_kwargs.get("allowed_tools"):
|
|
134
|
+
options_kwargs["allowed_tools"] = agent_kwargs["allowed_tools"]
|
|
135
|
+
|
|
136
|
+
return options_kwargs
|
|
137
|
+
|
|
138
|
+
async def _initialize_standard(self) -> None:
    """Connect the Claude SDK client using plain MCP server configs (no proxies).

    Loads the configured MCP servers with ``load_mcp_servers()`` and records
    each one in ``self.mcp_server_configs``, keyed by server name. It then
    builds ``ClaudeCodeOptions`` and connects a ``ClaudeSDKClient`` with them,
    for use across turns.
    """
    await self.load_mcp_servers()

    # Pair each loaded config with its name; later duplicates win, as before.
    self.mcp_server_configs.update(zip(self.mcp_server_names, self.mcp_servers))

    servers = self.mcp_server_configs or None
    self._options = ClaudeCodeOptions(**self._build_options_kwargs(mcp_servers=servers))

    client = ClaudeSDKClient(options=self._options)
    self._client = client
    await client.connect()
|
|
155
|
+
|
|
156
|
+
async def _initialize_with_proxies(self) -> None:
|
|
157
|
+
"""Initialize with MCP proxy servers for tool injection."""
|
|
158
|
+
if not self.config.mcp_servers:
|
|
159
|
+
# No MCP servers configured, fall back to standard
|
|
160
|
+
await self._initialize_standard()
|
|
161
|
+
return
|
|
162
|
+
|
|
163
|
+
# Create and connect proxy servers for each MCP server config
|
|
164
|
+
sdk_mcp_servers = {}
|
|
165
|
+
|
|
166
|
+
for mcp_config in self.config.mcp_servers:
|
|
167
|
+
server_name = mcp_config.name
|
|
168
|
+
|
|
169
|
+
# Get tool injections for this server
|
|
170
|
+
tool_injections = None
|
|
171
|
+
if self.runtime_config.mcp_injection:
|
|
172
|
+
tool_injections = self.runtime_config.mcp_injection.get(server_name, {})
|
|
173
|
+
|
|
174
|
+
# Create proxy server
|
|
175
|
+
proxy = MCPProxyServer(
|
|
176
|
+
name=server_name,
|
|
177
|
+
transport=mcp_config.transport,
|
|
178
|
+
url=mcp_config.url,
|
|
179
|
+
command=mcp_config.command,
|
|
180
|
+
args=mcp_config.args,
|
|
181
|
+
env=mcp_config.env,
|
|
182
|
+
tool_injections=tool_injections,
|
|
183
|
+
hook_manager=self.hook_manager,
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
# Connect to real MCP server
|
|
187
|
+
try:
|
|
188
|
+
await proxy.connect()
|
|
189
|
+
self._proxy_servers[server_name] = proxy
|
|
190
|
+
|
|
191
|
+
# Get SDK server config from proxy
|
|
192
|
+
sdk_config = proxy.get_sdk_server_config()
|
|
193
|
+
sdk_mcp_servers[server_name] = sdk_config
|
|
194
|
+
|
|
195
|
+
print(f"[INFO] Connected to MCP server '{server_name}' via proxy with {len(proxy.tools)} tools")
|
|
196
|
+
|
|
197
|
+
except Exception as e:
|
|
198
|
+
# Clean up any already connected proxies
|
|
199
|
+
for p in self._proxy_servers.values():
|
|
200
|
+
await p.disconnect()
|
|
201
|
+
self._proxy_servers.clear()
|
|
202
|
+
raise ConnectionError(f"Failed to connect to MCP server '{server_name}': {e}") from e
|
|
203
|
+
|
|
204
|
+
# Store server names for reference
|
|
205
|
+
self.mcp_server_names = list(sdk_mcp_servers.keys())
|
|
206
|
+
|
|
207
|
+
# Build Claude SDK options with SDK MCP servers
|
|
208
|
+
mcp = sdk_mcp_servers if sdk_mcp_servers else None
|
|
209
|
+
options_kwargs = self._build_options_kwargs(mcp_servers=mcp)
|
|
210
|
+
|
|
211
|
+
self._options = ClaudeCodeOptions(**options_kwargs)
|
|
212
|
+
|
|
213
|
+
# Create and connect the client
|
|
214
|
+
self._client = ClaudeSDKClient(options=self._options)
|
|
215
|
+
await self._client.connect()
|
|
216
|
+
|
|
217
|
+
def _create_mcp_server(self, server_config: MCPServerConfig) -> Dict[str, Any]:
    """
    Translate an ``MCPServerConfig`` into the dict form the Claude SDK accepts.

    Args:
        server_config: Configuration for the MCP server

    Returns:
        ``{"type": "http" | "sse", "url": ...}`` for URL-based transports, or
        ``{"type": "stdio", "command": ...}`` plus ``args`` / ``env`` when those
        are non-empty.

    Raises:
        ValueError: If the transport is not ``http``, ``sse`` or ``stdio``.
    """
    transport = server_config.transport

    if transport == "stdio":
        stdio_config = {"type": "stdio", "command": server_config.command}
        # Optional fields are only included when set, as the SDK treats them as optional.
        for key in ("args", "env"):
            value = getattr(server_config, key)
            if value:
                stdio_config[key] = value
        return stdio_config

    if transport == "http":
        return {"type": "http", "url": server_config.url}

    if transport == "sse":
        return {"type": "sse", "url": server_config.url}

    raise ValueError(f"Unsupported transport type: {transport}")
|
|
249
|
+
|
|
250
|
+
async def run(
    self,
    user_input: Union[str, List[str]],
    metadata: Optional[Dict[str, Any]] = None
) -> AgentResult:
    """
    Run the agent with given input. Supports multi-turn conversations.

    Each input is sent in order to the connected Claude SDK client on the same
    conversation, and every response message is recorded in the trace. The
    trace ID persists across calls until ``reset_conversation()`` is called.

    Args:
        user_input: User instruction/query. Can be:
            - str: Single query (backward compatible)
            - List[str]: Multiple queries processed sequentially with context preserved
        metadata: Optional metadata (task_id, domain, category, instruction).
            Used only when a new trace is started. It is copied, so the
            caller's dict is never modified. ``"instruction"`` defaults to the
            input(s) when absent.

    Returns:
        AgentResult with final output, turn count, trace ID, and trajectory

    Raises:
        RuntimeError: If the agent has not been initialized.
        Exception: Any error raised while querying the client or reading its
            response is recorded in the trace and re-raised unchanged.
    """
    if not self._client:
        raise RuntimeError("Agent not initialized. Call initialize() first or use async context manager.")

    # Normalize input to list for uniform processing
    inputs = [user_input] if isinstance(user_input, str) else user_input

    # Initialize or reuse trace context for multi-turn
    if self._trace_id is None:
        self._trace_id = str(uuid.uuid4())
        # Copy so that adding "instruction" below never leaks into the caller's
        # dict (a reused metadata dict would otherwise keep a stale instruction).
        self._trace_metadata = dict(metadata) if metadata else {}

        if "instruction" not in self._trace_metadata:
            self._trace_metadata["instruction"] = inputs[0] if len(inputs) == 1 else inputs

        self.trace_processor.start_trace(self._trace_id, self._trace_metadata)

    max_turns = self.runtime_config.max_turns

    for current_input in inputs:
        self.trace_processor.record_user_input(self._trace_id, current_input)

        try:
            await self._client.query(current_input)

            async for message in self._client.receive_response():
                self._all_messages.append(message)
                self.trace_processor.record_message(self._trace_id, message)

                # A "turn" is an assistant message that invokes at least one tool.
                if isinstance(message, AssistantMessage) and any(
                    isinstance(block, ToolUseBlock) for block in message.content
                ):
                    self._turn_count += 1

        except Exception as e:
            # Covers ClaudeSDKError too: record in the trace, then propagate.
            self.trace_processor.record_error(self._trace_id, str(e))
            raise

        # The limit is only checked between inputs; None means "no limit".
        if max_turns is not None and self._turn_count >= max_turns:
            print(f"[WARNING] Max turns ({max_turns}) reached, stopping agent")
            break

    self.trace_processor.end_trace(self._trace_id)

    # Generate trajectory at end of run
    self._generate_trajectory()

    return self.get_result()
|
|
330
|
+
|
|
331
|
+
def _generate_trajectory(self) -> None:
|
|
332
|
+
"""Generate trajectory from current trace file"""
|
|
333
|
+
try:
|
|
334
|
+
if os.path.exists(self.trace_file):
|
|
335
|
+
trajectories = self.trajectory_converter.process_trace_file(
|
|
336
|
+
self.trace_file,
|
|
337
|
+
output_name="trajectory"
|
|
338
|
+
)
|
|
339
|
+
if trajectories:
|
|
340
|
+
self._current_trajectory = trajectories[-1] # Get the latest trajectory
|
|
341
|
+
except Exception as e:
|
|
342
|
+
print(f"[WARNING] Failed to generate trajectory: {e}")
|
|
343
|
+
|
|
344
|
+
def get_result(self) -> AgentResult:
    """
    Build an ``AgentResult`` snapshot of the current conversation.

    Returns:
        AgentResult with these fields:
        - ``final_output``: the current trajectory's ``final_response``, or
          ``None`` when no trajectory has been generated.
        - ``turn_count``: the tool-use turn count.
        - ``trajectory``: the current trajectory itself.
        - ``trace_id``: the trace ID, or an empty string when no trace has
          been started.
    """
    trajectory = self._current_trajectory
    return AgentResult(
        final_output=trajectory.final_response if trajectory else None,
        turn_count=self._turn_count,
        trajectory=trajectory,
        trace_id=self._trace_id or "",
    )
|
|
362
|
+
|
|
363
|
+
def reset_conversation(self) -> None:
    """
    Forget the current multi-turn session while keeping the agent initialized.

    Clears the following so the next ``run()`` call starts a fresh trace:
    - the trace ID and metadata,
    - the turn counter,
    - the current trajectory,
    - the accumulated messages (rebound to a new list).

    The SDK client and proxy servers are not touched.
    """
    self._trace_id = self._trace_metadata = None
    self._turn_count = 0
    self._current_trajectory = None
    self._all_messages = []
|
|
375
|
+
|
|
376
|
+
async def cleanup(self) -> None:
    """Release what ``initialize()`` acquired and reset conversation state.

    Steps, in order:
    1. Reset the conversation.
    2. Disconnect the Claude SDK client.
    3. Disconnect every MCP proxy server.
    4. Remove the skill temp directory, if any.

    Client and proxy disconnect failures are printed as warnings and do not
    stop the remaining steps.
    """
    self.reset_conversation()

    # Client goes first: in proxy mode its MCP server configs come from the proxies.
    client = self._client
    if client:
        try:
            await client.disconnect()
        except Exception as e:
            print(f"[WARNING] Failed to disconnect client: {e}")
        self._client = None

    for name, proxy in self._proxy_servers.items():
        try:
            await proxy.disconnect()
        except Exception as e:
            print(f"[WARNING] Failed to disconnect proxy server '{name}': {e}")
    self._proxy_servers.clear()

    temp_dir = self._skill_temp_dir
    if temp_dir:
        cleanup_temp_directory(temp_dir)
        self._skill_temp_dir = None
|