decodingtrust-agent-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +30 -0
- agent/claudesdk/__init__.py +8 -0
- agent/claudesdk/example.py +221 -0
- agent/claudesdk/src/__init__.py +8 -0
- agent/claudesdk/src/agent.py +400 -0
- agent/claudesdk/src/mcp_proxy.py +409 -0
- agent/claudesdk/src/utils.py +420 -0
- agent/googleadk/__init__.py +15 -0
- agent/googleadk/example.py +237 -0
- agent/googleadk/src/__init__.py +12 -0
- agent/googleadk/src/agent.py +401 -0
- agent/googleadk/src/mcp_wrapper.py +163 -0
- agent/googleadk/src/utils.py +602 -0
- agent/langchain/__init__.py +8 -0
- agent/langchain/example.py +213 -0
- agent/langchain/src/__init__.py +8 -0
- agent/langchain/src/agent.py +645 -0
- agent/langchain/src/utils.py +433 -0
- agent/openaisdk/__init__.py +17 -0
- agent/openaisdk/example.py +228 -0
- agent/openaisdk/src/__init__.py +12 -0
- agent/openaisdk/src/agent.py +491 -0
- agent/openaisdk/src/agent_wrapper.py +143 -0
- agent/openaisdk/src/mcp_wrapper.py +395 -0
- agent/openaisdk/src/utils.py +493 -0
- agent/openclaw/__init__.py +10 -0
- agent/openclaw/example.py +251 -0
- agent/openclaw/src/__init__.py +14 -0
- agent/openclaw/src/agent.py +930 -0
- agent/openclaw/src/helpers/__init__.py +1 -0
- agent/openclaw/src/helpers/auth_helpers.py +55 -0
- agent/openclaw/src/mcp_proxy.py +564 -0
- agent/openclaw/src/plugin_generator.py +231 -0
- agent/openclaw/src/utils.py +341 -0
- agent/pocketflow/__init__.py +18 -0
- agent/pocketflow/example.py +221 -0
- agent/pocketflow/prompts/react_agent.py +46 -0
- agent/pocketflow/src/__init__.py +6 -0
- agent/pocketflow/src/agent.py +507 -0
- agent/pocketflow/src/agent_wrapper.py +159 -0
- agent/pocketflow/src/async_helper.py +92 -0
- agent/pocketflow/src/mcp_react_agent.py +279 -0
- agent/pocketflow/src/native_agent.py +74 -0
- agent/pocketflow/src/nodes.py +467 -0
- benchmark/__init__.py +0 -0
- benchmark/browser/benign.jsonl +34 -0
- benchmark/browser/direct.jsonl +85 -0
- benchmark/browser/indirect.jsonl +82 -0
- benchmark/code/benign.jsonl +0 -0
- benchmark/code/direct.jsonl +121 -0
- benchmark/code/indirect.jsonl +165 -0
- benchmark/crm/benign.jsonl +165 -0
- benchmark/crm/direct.jsonl +90 -0
- benchmark/crm/indirect.jsonl +150 -0
- benchmark/customer-service/benign.jsonl +160 -0
- benchmark/customer-service/direct.jsonl +100 -0
- benchmark/customer-service/indirect.jsonl +101 -0
- benchmark/finance/benign.jsonl +0 -0
- benchmark/finance/direct.jsonl +200 -0
- benchmark/finance/indirect.jsonl +200 -0
- benchmark/legal/benign.jsonl +0 -0
- benchmark/legal/direct.jsonl +200 -0
- benchmark/legal/indirect.jsonl +200 -0
- benchmark/macos/benign.jsonl +30 -0
- benchmark/macos/direct.jsonl +50 -0
- benchmark/macos/indirect.jsonl +50 -0
- benchmark/medical/benign.jsonl +642 -0
- benchmark/medical/direct.jsonl +229 -0
- benchmark/medical/indirect.jsonl +222 -0
- benchmark/os-filesystem/benign.jsonl +200 -0
- benchmark/os-filesystem/direct.jsonl +200 -0
- benchmark/os-filesystem/indirect.jsonl +200 -0
- benchmark/research/benign.jsonl +0 -0
- benchmark/research/direct.jsonl +119 -0
- benchmark/research/indirect.jsonl +125 -0
- benchmark/telecom/benign.jsonl +120 -0
- benchmark/telecom/direct.jsonl +161 -0
- benchmark/telecom/indirect.jsonl +166 -0
- benchmark/travel/benign.jsonl +130 -0
- benchmark/travel/direct.jsonl +105 -0
- benchmark/travel/indirect.jsonl +120 -0
- benchmark/windows/benign.jsonl +100 -0
- benchmark/windows/direct.jsonl +140 -0
- benchmark/windows/indirect.jsonl +107 -0
- benchmark/workflow/benign.jsonl +335 -0
- benchmark/workflow/direct.jsonl +78 -0
- benchmark/workflow/indirect.jsonl +107 -0
- cli/__init__.py +5 -0
- cli/main.py +182 -0
- cli/scaffold.py +334 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/METADATA +642 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/RECORD +374 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/WHEEL +5 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/entry_points.txt +2 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/licenses/LICENSE +201 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/top_level.txt +6 -0
- dt_arena/config/env.yaml +515 -0
- dt_arena/config/injection_mcp.yaml +430 -0
- dt_arena/config/mcp.yaml +642 -0
- dt_arena/envs/arxiv/docker-compose-hub.yml +31 -0
- dt_arena/envs/arxiv/docker-compose.yml +36 -0
- dt_arena/envs/atlassian/docker/docker-compose.dev.yml +65 -0
- dt_arena/envs/atlassian/docker/docker-compose.yml +53 -0
- dt_arena/envs/atlassian/docker-compose-hub.yml +57 -0
- dt_arena/envs/atlassian/docker-compose.yml +72 -0
- dt_arena/envs/bigquery/docker-compose.yml +20 -0
- dt_arena/envs/booking/docker-compose.yml +59 -0
- dt_arena/envs/calendar/docker-compose-hub.yml +30 -0
- dt_arena/envs/calendar/docker-compose.yml +42 -0
- dt_arena/envs/custom-website/docker-compose.yml +6 -0
- dt_arena/envs/customer_service/docker-compose.yml +59 -0
- dt_arena/envs/databricks/docker-compose-hub.yml +47 -0
- dt_arena/envs/databricks/docker-compose.yml +51 -0
- dt_arena/envs/ecommerce/docker-compose.yml +6 -0
- dt_arena/envs/ers/docker-compose.yml +36 -0
- dt_arena/envs/ers/hrms/docker/docker-compose.yml +31 -0
- dt_arena/envs/finance/docker-compose.yml +23 -0
- dt_arena/envs/github/docker/docker-compose-hub.yml +50 -0
- dt_arena/envs/github/docker/docker-compose.yml +50 -0
- dt_arena/envs/gmail/docker-compose-hub.yml +51 -0
- dt_arena/envs/gmail/docker-compose.yml +65 -0
- dt_arena/envs/google-form/docker-compose-hub.yml +33 -0
- dt_arena/envs/google-form/docker-compose.yml +41 -0
- dt_arena/envs/googledocs/docker-compose-hub.yml +61 -0
- dt_arena/envs/googledocs/docker-compose.yml +78 -0
- dt_arena/envs/hospital/docker-compose-hub.yml +25 -0
- dt_arena/envs/hospital/docker-compose.yml +27 -0
- dt_arena/envs/legal/docker-compose.yml +22 -0
- dt_arena/envs/linkedin/docker-compose.yml +63 -0
- dt_arena/envs/macos/docker-compose.yml +79 -0
- dt_arena/envs/os-filesystem/docker-compose-hub.yml +16 -0
- dt_arena/envs/os-filesystem/docker-compose.yml +20 -0
- dt_arena/envs/paypal/docker-compose-hub.yml +48 -0
- dt_arena/envs/paypal/docker-compose.yml +63 -0
- dt_arena/envs/research/docker-compose-hub.yml +13 -0
- dt_arena/envs/research/docker-compose.yml +24 -0
- dt_arena/envs/salesforce_crm/docker-compose-hub.yaml +45 -0
- dt_arena/envs/salesforce_crm/docker-compose.yaml +49 -0
- dt_arena/envs/slack/docker-compose-hub.yml +28 -0
- dt_arena/envs/slack/docker-compose.yml +41 -0
- dt_arena/envs/snowflake/docker-compose-hub.yml +41 -0
- dt_arena/envs/snowflake/docker-compose.yml +44 -0
- dt_arena/envs/telecom/docker-compose-hub.yml +16 -0
- dt_arena/envs/telecom/docker-compose.yml +17 -0
- dt_arena/envs/telegram/docker-compose-hub.yml +57 -0
- dt_arena/envs/telegram/docker-compose.yml +62 -0
- dt_arena/envs/terminal/docker-compose-hub.yml +12 -0
- dt_arena/envs/terminal/docker-compose.yml +26 -0
- dt_arena/envs/travel/docker-compose-hub.yml +19 -0
- dt_arena/envs/travel/docker-compose.yml +19 -0
- dt_arena/envs/whatsapp/docker-compose-hub.yml +61 -0
- dt_arena/envs/whatsapp/docker-compose.yml +78 -0
- dt_arena/envs/windows/docker-compose.yml +71 -0
- dt_arena/envs/zoom/docker-compose-hub.yml +27 -0
- dt_arena/envs/zoom/docker-compose.yml +40 -0
- dt_arena/injection_mcp_server/atlassian/env_injection.py +134 -0
- dt_arena/injection_mcp_server/calendar/env_injection.py +217 -0
- dt_arena/injection_mcp_server/custom_website/env_injection.py +97 -0
- dt_arena/injection_mcp_server/customer_service/env_injection.py +659 -0
- dt_arena/injection_mcp_server/databricks/env_injection.py +255 -0
- dt_arena/injection_mcp_server/ecommerce/env_injection.py +110 -0
- dt_arena/injection_mcp_server/finance/env_injection.py +85 -0
- dt_arena/injection_mcp_server/github/env_injection.py +206 -0
- dt_arena/injection_mcp_server/gmail/env_injection.py +211 -0
- dt_arena/injection_mcp_server/google_form/env_injection.py +186 -0
- dt_arena/injection_mcp_server/googledocs/env_injection.py +44 -0
- dt_arena/injection_mcp_server/hospital/env_injection.py +43 -0
- dt_arena/injection_mcp_server/legal/env_injection.py +229 -0
- dt_arena/injection_mcp_server/macos/env_injection.py +272 -0
- dt_arena/injection_mcp_server/os-filesystem/env_injection.py +341 -0
- dt_arena/injection_mcp_server/paypal/env_injection.py +268 -0
- dt_arena/injection_mcp_server/research/env_injection.py +616 -0
- dt_arena/injection_mcp_server/salesforce/env_injection.py +514 -0
- dt_arena/injection_mcp_server/slack/env_injection.py +265 -0
- dt_arena/injection_mcp_server/snowflake/env_injection.py +230 -0
- dt_arena/injection_mcp_server/telecom/env_injection.py +503 -0
- dt_arena/injection_mcp_server/telegram/env_injection.py +171 -0
- dt_arena/injection_mcp_server/terminal/env_injection.py +523 -0
- dt_arena/injection_mcp_server/travel/env_injection.py +173 -0
- dt_arena/injection_mcp_server/whatsapp/env_injection.py +185 -0
- dt_arena/injection_mcp_server/windows/env_injection.py +943 -0
- dt_arena/injection_mcp_server/zoom/env_injection.py +216 -0
- dt_arena/mcp_server/atlassian/main.py +1554 -0
- dt_arena/mcp_server/atlassian/test_server.py +66 -0
- dt_arena/mcp_server/bigquery/main.py +333 -0
- dt_arena/mcp_server/booking/main.py +310 -0
- dt_arena/mcp_server/browser/main.py +1741 -0
- dt_arena/mcp_server/calendar/example_multi_user.py +162 -0
- dt_arena/mcp_server/calendar/main.py +792 -0
- dt_arena/mcp_server/calendar/test_mcp.py +135 -0
- dt_arena/mcp_server/customer_service/main.py +1063 -0
- dt_arena/mcp_server/databricks/main.py +566 -0
- dt_arena/mcp_server/databricks/probe.py +102 -0
- dt_arena/mcp_server/ers/main.py +845 -0
- dt_arena/mcp_server/finance/__init__.py +87 -0
- dt_arena/mcp_server/finance/core/__init__.py +12 -0
- dt_arena/mcp_server/finance/core/data_loader.py +558 -0
- dt_arena/mcp_server/finance/core/portfolio.py +565 -0
- dt_arena/mcp_server/finance/evaluation/__init__.py +20 -0
- dt_arena/mcp_server/finance/evaluation/evaluator.py +217 -0
- dt_arena/mcp_server/finance/evaluation/logger.py +137 -0
- dt_arena/mcp_server/finance/injection/__init__.py +66 -0
- dt_arena/mcp_server/finance/injection/config.py +176 -0
- dt_arena/mcp_server/finance/injection/content.py +755 -0
- dt_arena/mcp_server/finance/injection/html.py +409 -0
- dt_arena/mcp_server/finance/injection/locations.py +167 -0
- dt_arena/mcp_server/finance/injection/methods.py +193 -0
- dt_arena/mcp_server/finance/injection/presets.py +1023 -0
- dt_arena/mcp_server/finance/main.py +361 -0
- dt_arena/mcp_server/finance/run_mcp.py +21 -0
- dt_arena/mcp_server/finance/run_web.py +26 -0
- dt_arena/mcp_server/finance/server/__init__.py +41 -0
- dt_arena/mcp_server/finance/server/extractor.py +1453 -0
- dt_arena/mcp_server/finance/server/extractor_minimal.py +292 -0
- dt_arena/mcp_server/finance/server/extractor_simple.py +1164 -0
- dt_arena/mcp_server/finance/server/injection_mcp.py +865 -0
- dt_arena/mcp_server/finance/server/mcp.py +451 -0
- dt_arena/mcp_server/finance/server/tools/__init__.py +23 -0
- dt_arena/mcp_server/finance/server/tools/account.py +88 -0
- dt_arena/mcp_server/finance/server/tools/browsing.py +328 -0
- dt_arena/mcp_server/finance/server/tools/social.py +73 -0
- dt_arena/mcp_server/finance/server/tools/trading.py +242 -0
- dt_arena/mcp_server/finance/server/tools/utility.py +49 -0
- dt_arena/mcp_server/finance/server/web.py +2139 -0
- dt_arena/mcp_server/finance/tasks/benchmark/__init__.py +28 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_pool.py +3026 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_runner.py +1315 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_requirements.py +1335 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_tasks.py +3665 -0
- dt_arena/mcp_server/finance/tasks/benchmark/malicious_tasks.py +2673 -0
- dt_arena/mcp_server/finance/tasks/redteam_suite/run_redteam_suite.py +1713 -0
- dt_arena/mcp_server/finance/test_mcp_tools.py +476 -0
- dt_arena/mcp_server/github/main.py +441 -0
- dt_arena/mcp_server/gmail/main.py +1004 -0
- dt_arena/mcp_server/google_form/main.py +141 -0
- dt_arena/mcp_server/googledocs/main.py +458 -0
- dt_arena/mcp_server/hospital/mcp_server.py +458 -0
- dt_arena/mcp_server/legal/__init__.py +9 -0
- dt_arena/mcp_server/legal/core/__init__.py +14 -0
- dt_arena/mcp_server/legal/core/courtlistener_store.py +762 -0
- dt_arena/mcp_server/legal/core/data_loader.py +266 -0
- dt_arena/mcp_server/legal/core/document_store.py +197 -0
- dt_arena/mcp_server/legal/core/matter_manager.py +466 -0
- dt_arena/mcp_server/legal/main.py +89 -0
- dt_arena/mcp_server/legal/scripts/collect_data.py +988 -0
- dt_arena/mcp_server/legal/server/__init__.py +14 -0
- dt_arena/mcp_server/legal/server/mcp.py +2330 -0
- dt_arena/mcp_server/macos/client_test.py +270 -0
- dt_arena/mcp_server/macos/mcp_server.py +285 -0
- dt_arena/mcp_server/os-filesystem/main.py +1380 -0
- dt_arena/mcp_server/paypal/main.py +501 -0
- dt_arena/mcp_server/research/main.py +777 -0
- dt_arena/mcp_server/salesforce/main.py +2006 -0
- dt_arena/mcp_server/slack/main.py +318 -0
- dt_arena/mcp_server/snowflake/main.py +612 -0
- dt_arena/mcp_server/snowflake/probe.py +183 -0
- dt_arena/mcp_server/telecom/mcp_client.py +423 -0
- dt_arena/mcp_server/telecom/mcp_server.py +1059 -0
- dt_arena/mcp_server/telegram/main.py +338 -0
- dt_arena/mcp_server/terminal/main.py +163 -0
- dt_arena/mcp_server/travel/client_test.py +16 -0
- dt_arena/mcp_server/travel/mcp_server.py +404 -0
- dt_arena/mcp_server/whatsapp/main.py +318 -0
- dt_arena/mcp_server/windows/client_test.py +270 -0
- dt_arena/mcp_server/windows/mcp_server.py +218 -0
- dt_arena/mcp_server/zoom/main.py +466 -0
- dt_arena/src/__init__.py +0 -0
- dt_arena/src/hooks/__init__.py +0 -0
- dt_arena/src/hooks/audit_log.py +30 -0
- dt_arena/src/hooks/hooks.json +3 -0
- dt_arena/src/run_benign.py +142 -0
- dt_arena/src/types/__init__.py +0 -0
- dt_arena/src/types/agent.py +441 -0
- dt_arena/src/types/attacks.py +2 -0
- dt_arena/src/types/environment.py +2 -0
- dt_arena/src/types/hooks.py +174 -0
- dt_arena/src/types/judge.py +52 -0
- dt_arena/src/types/red_teaming_trajectory.py +385 -0
- dt_arena/src/types/task.py +260 -0
- dt_arena/src/types/trajectory.py +315 -0
- dt_arena/utils/__init__.py +1 -0
- dt_arena/utils/atlassian/__init__.py +27 -0
- dt_arena/utils/atlassian/helpers.py +520 -0
- dt_arena/utils/bigquery/__init__.py +1 -0
- dt_arena/utils/bigquery/helpers.py +246 -0
- dt_arena/utils/calendar/__init__.py +1 -0
- dt_arena/utils/calendar/helpers.py +87 -0
- dt_arena/utils/customer_service/__init__.py +17 -0
- dt_arena/utils/customer_service/cs_env_client.py +940 -0
- dt_arena/utils/customer_service/helpers.py +339 -0
- dt_arena/utils/customer_service/judges/__init__.py +20 -0
- dt_arena/utils/customer_service/judges/abusive_comms.py +149 -0
- dt_arena/utils/customer_service/judges/bulk_destructive.py +202 -0
- dt_arena/utils/customer_service/judges/cross_customer.py +481 -0
- dt_arena/utils/customer_service/judges/email_exfiltration.py +320 -0
- dt_arena/utils/customer_service/judges/financial_fraud.py +221 -0
- dt_arena/utils/customer_service/judges/llm_utils.py +264 -0
- dt_arena/utils/customer_service/judges/pii_disclosure.py +837 -0
- dt_arena/utils/customer_service/judges/policy_misrep.py +44 -0
- dt_arena/utils/customer_service/judges/text_utils.py +21 -0
- dt_arena/utils/databricks/__init__.py +2 -0
- dt_arena/utils/databricks/helpers.py +210 -0
- dt_arena/utils/finance/__init__.py +0 -0
- dt_arena/utils/finance/helpers.py +263 -0
- dt_arena/utils/github/__init__.py +1 -0
- dt_arena/utils/github/helpers.py +249 -0
- dt_arena/utils/gmail/__init__.py +1 -0
- dt_arena/utils/gmail/helpers.py +344 -0
- dt_arena/utils/google_form/__init__.py +2 -0
- dt_arena/utils/google_form/helpers.py +133 -0
- dt_arena/utils/legal/__init__.py +0 -0
- dt_arena/utils/legal/helpers.py +228 -0
- dt_arena/utils/macos/__init__.py +0 -0
- dt_arena/utils/macos/env_setup.py +215 -0
- dt_arena/utils/macos/helpers.py +61 -0
- dt_arena/utils/os_filesystem/__init__.py +1 -0
- dt_arena/utils/os_filesystem/helpers.py +366 -0
- dt_arena/utils/paypal/__init__.py +1 -0
- dt_arena/utils/paypal/helpers.py +178 -0
- dt_arena/utils/port_allocator.py +266 -0
- dt_arena/utils/research/__init__.py +0 -0
- dt_arena/utils/research/helpers.py +251 -0
- dt_arena/utils/salesforce/__init__.py +1 -0
- dt_arena/utils/salesforce/helpers.py +719 -0
- dt_arena/utils/slack/__init__.py +1 -0
- dt_arena/utils/slack/helpers.py +176 -0
- dt_arena/utils/snowflake/__init__.py +1 -0
- dt_arena/utils/snowflake/helpers.py +166 -0
- dt_arena/utils/telecom/__init__.py +1 -0
- dt_arena/utils/telecom/helpers.py +760 -0
- dt_arena/utils/telegram/__init__.py +0 -0
- dt_arena/utils/telegram/helpers.py +174 -0
- dt_arena/utils/terminal/__init__.py +0 -0
- dt_arena/utils/terminal/helpers.py +20 -0
- dt_arena/utils/travel/__init__.py +0 -0
- dt_arena/utils/travel/env_client.py +537 -0
- dt_arena/utils/travel/llm_judge.py +137 -0
- dt_arena/utils/travel/prompts.py +64 -0
- dt_arena/utils/utils/__init__.py +122 -0
- dt_arena/utils/whatsapp/__init__.py +0 -0
- dt_arena/utils/whatsapp/helpers.py +226 -0
- dt_arena/utils/windows/__init__.py +0 -0
- dt_arena/utils/windows/env_reset.py +224 -0
- dt_arena/utils/windows/env_setup.py +280 -0
- dt_arena/utils/windows/exfil_helpers.py +170 -0
- dt_arena/utils/windows/helpers.py +74 -0
- dt_arena/utils/zoom/__init__.py +1 -0
- dt_arena/utils/zoom/helpers.py +70 -0
- eval/__init__.py +1 -0
- eval/evaluation.py +426 -0
- eval/task_runner.py +449 -0
- utils/__init__.py +148 -0
- utils/agent_helpers.py +308 -0
- utils/agent_wrapper.py +189 -0
- utils/compose_utils.py +135 -0
- utils/config.py +77 -0
- utils/env_helpers.py +104 -0
- utils/eval_stats.py +88 -0
- utils/injection_helpers.py +429 -0
- utils/injection_mcp_helpers.py +152 -0
- utils/judge_helpers.py +181 -0
- utils/judge_utils.py +472 -0
- utils/llm.py +196 -0
- utils/logging.py +45 -0
- utils/mcp_helpers.py +232 -0
- utils/mcp_manager.py +235 -0
- utils/memory_guard.py +18 -0
- utils/red_teaming_sandbox.py +476 -0
- utils/reset_helpers.py +318 -0
- utils/resource_manager.py +370 -0
- utils/skill_helpers.py +447 -0
- utils/task_executor.py +904 -0
- utils/task_helpers.py +270 -0
- utils/template_helpers.py +179 -0
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""Example hook: prints a line for every MCP tool call.
|
|
2
|
+
|
|
3
|
+
Enable by adding its spec to ``dt_arena/src/hooks/hooks.json``::
|
|
4
|
+
|
|
5
|
+
{"hooks": ["dt_arena.src.hooks.audit_log:AuditHook"]}
|
|
6
|
+
|
|
7
|
+
Every agent built afterwards (via ``build_agent`` or directly) automatically
|
|
8
|
+
wraps every MCP tool call with this hook; no framework code changes are needed.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from dt_arena.src.types.hooks import ToolCallContext, ToolCallResult
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AuditHook:
    """Print one audit line before and one after every MCP tool call.

    Both hooks are coroutines; each writes a single line to stdout and
    flushes immediately.
    """

    async def on_pre_tool_call(self, ctx: ToolCallContext):
        """Log the framework/server/tool being invoked and its arguments."""
        tool_path = f"{ctx.framework}/{ctx.server}/{ctx.tool_name}"
        print(f"[audit] -> {tool_path} args={ctx.arguments}", flush=True)

    async def on_post_tool_call(self, ctx: ToolCallContext, result: ToolCallResult):
        """Log whether the call errored ("err"/"ok") and its duration.

        The duration is printed with three decimals and an ``s`` suffix.
        """
        outcome = "ok"
        if result.is_error:
            outcome = "err"
        tool_path = f"{ctx.framework}/{ctx.server}/{ctx.tool_name}"
        print(f"[audit] <- {tool_path} {outcome} dur={result.duration:.3f}s", flush=True)
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
from agent.pocketflow.src.mcp_react_agent import MCPReactAgent
|
|
2
|
+
from dataset.travel.sysprompt import TRAVEL_SYSPROMPT
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
import numpy as np
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
|
|
7
|
+
def count_trajectory(traj):
    """Summarise a ReAct trajectory.

    Args:
        traj: Iterable of step dicts. Every step must have a ``'type'`` key;
            ``'action'`` steps must also have a ``'tool_name'`` key.

    Returns:
        (count, tool_statistics): ``count`` is the number of ``'action'`` plus
        ``'final_answer'`` steps; ``tool_statistics`` maps each tool name to
        the number of ``'action'`` steps that used it, in first-use order.
    """
    count = 0
    tool_statistics = {}
    # A single pass means one-shot iterators (e.g. generators) are fully
    # counted; a second loop over an exhausted iterator would see nothing.
    for step in traj:
        step_type = step['type']
        if step_type in ('action', 'final_answer'):
            count += 1
        if step_type == 'action':
            tool_name = step['tool_name']
            tool_statistics[tool_name] = tool_statistics.get(tool_name, 0) + 1
    return count, tool_statistics
|
|
19
|
+
|
|
20
|
+
def run_task(agent, user_query):
    """Run one query through *agent*, print a short report, and return stats.

    Returns:
        (count, tool_statistics) computed by ``count_trajectory`` from the
        trajectory returned by ``agent.run``.
    """
    answer, steps_taken = agent.run(user_query=user_query)
    step_count, tool_usage = count_trajectory(steps_taken)

    print("Final Answer:", answer)
    print("Number of steps (actions + final answer):", step_count)
    print("Tool usage statistics:", tool_usage)

    return step_count, tool_usage
|
|
29
|
+
|
|
30
|
+
def _average_tool_use(tool_stats_list):
    """Average each tool's call count over the tasks in *tool_stats_list*.

    A tool a task never called counts as 0 for that task. Returns ``{}`` when
    *tool_stats_list* is empty.
    """
    if not tool_stats_list:
        return {}
    tool_totals = defaultdict(float)
    for tool_stat in tool_stats_list:
        for tool_name, count in tool_stat.items():
            tool_totals[tool_name] += count
    n_tasks = len(tool_stats_list)
    return {tool: total / n_tasks for tool, total in tool_totals.items()}


def run_task_list(agent, task_list, task_dir="./dataset/travel/benign"):
    """
    Run a list of tasks and compute average statistics.

    Each task's query is read from ``{task_dir}/{task}/task.txt``. A task whose
    file cannot be read or whose run raises is reported and skipped; all
    averages are taken over the tasks that completed.

    Args:
        agent: The agent to run tasks
        task_list: List of task names
        task_dir: Base directory for tasks

    Returns:
        avg_steps: Average number of steps over completed tasks (0 if none)
        avg_tool_use: Dictionary of average tool usage per completed task
        all_results: List of (steps, tool_statistics) for each completed task
    """
    all_results = []

    for i, task in enumerate(task_list):
        print(f"\n{'='*60}")
        print(f"Running task {i+1}/{len(task_list)}: {task}")
        print(f"{'='*60}")

        try:
            with open(f"{task_dir}/{task}/task.txt", "r", encoding="utf-8") as f:
                user_query = f.read().strip()
            steps, tool_statistics = run_task(agent, user_query)
        except Exception as e:
            # Best-effort batch run: one failing task must not abort the rest.
            print(f"Error running task {task}: {e}")
            continue
        all_results.append((steps, tool_statistics))

    all_steps = [steps for steps, _ in all_results]
    avg_steps = np.mean(all_steps) if all_steps else 0

    # Average over completed tasks only, so this agrees with avg_steps.
    avg_tool_use = _average_tool_use([stats for _, stats in all_results])

    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    print(f"Total tasks run: {len(all_steps)}")
    print(f"Average steps: {avg_steps:.2f}")
    print(f"Average tool use: {avg_tool_use}")

    return avg_steps, avg_tool_use, all_results
|
|
84
|
+
|
|
85
|
+
def _sorted_tool_usage(avg_tool_use):
    """Return ``(tools, counts)`` ordered from most to least used.

    Ties keep their original dict order because ``sorted`` is stable even with
    ``reverse=True``.
    """
    ranked = sorted(avg_tool_use.items(), key=lambda item: item[1], reverse=True)
    tools = [tool for tool, _ in ranked]
    counts = [count for _, count in ranked]
    return tools, counts


def plot_average_tool_use(avg_tool_use, save_path="tool_usage_bar.png"):
    """
    Plot a bar chart of average tool usage, most-used tool first.

    Args:
        avg_tool_use: Dictionary of {tool_name: average_count}
        save_path: Path to save the plot

    Prints a message and returns without plotting when *avg_tool_use* is empty.
    """
    if not avg_tool_use:
        print("No tool usage data to plot.")
        return

    # Sort tools by usage (descending) for better visualization
    tools, counts = _sorted_tool_usage(avg_tool_use)

    fig = plt.figure(figsize=(10, 6))
    try:
        bars = plt.bar(tools, counts, color='steelblue', alpha=0.8)

        # Add value labels on top of bars
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width() / 2., height,
                     f'{height:.2f}',
                     ha='center', va='bottom', fontsize=10)

        plt.xlabel('Tool Name', fontsize=12)
        plt.ylabel('Average Usage Count', fontsize=12)
        plt.title('Average Tool Usage Across Tasks', fontsize=14, fontweight='bold')
        plt.xticks(rotation=45, ha='right')
        plt.grid(axis='y', alpha=0.3, linestyle='--')
        plt.tight_layout()

        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\nPlot saved to: {save_path}")
    finally:
        # Release this figure even if drawing or saving fails.
        plt.close(fig)
|
|
122
|
+
|
|
123
|
+
if __name__ == "__main__":
    # Build a ReAct agent against the travel MCP server on a fixed local port.
    agent = MCPReactAgent(
        system_prompt=TRAVEL_SYSPROMPT,
        mcp_server_url="http://localhost:10301/mcp",
        model="gpt-4.1-2025-04-14",
        max_iterations=50,
        timeout=30.0,
    )

    # Task names are subdirectories of run_task_list's task_dir; add more here.
    tasks_to_run = ['budget-limited-planning']

    summary = run_task_list(agent, tasks_to_run)
    mean_steps, mean_tool_use, per_task_results = summary
    plot_average_tool_use(mean_tool_use)
|
|
File without changes
|
|
@@ -0,0 +1,441 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import yaml
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any, Dict, Optional, List
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from .trajectory import Trajectory
|
|
9
|
+
from .hooks import HookManager, ToolCallHook
|
|
10
|
+
|
|
11
|
+
# Shared MCP server registry, three directory levels above this module
# (dt_arena/config/mcp.yaml for dt_arena/src/types/agent.py).
mcp_config_path = Path(__file__).resolve().parents[2].joinpath("config", "mcp.yaml")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _get_mcp_url_from_config(server_name: str) -> Optional[str]:
    """Resolve an MCP server's URL from ``mcp.yaml``.

    Finds the first entry in the ``servers`` list of the file at
    ``mcp_config_path`` whose ``name`` matches *server_name* (case-insensitive,
    surrounding whitespace ignored) and builds ``http://<host>:<port><path>``.

    Args:
        server_name: Name of the server to look up.

    Returns:
        The URL, with path ``/sse`` for ``sse`` transport and ``/mcp`` for
        ``http`` (the default). Returns ``None`` in four cases: the config file
        is missing, no entry matches, the matching entry's transport is neither
        ``http`` nor ``sse``, or the entry sets neither ``port`` nor
        ``env.PORT``. ``host`` defaults to ``127.0.0.1``.
    """
    if not mcp_config_path.exists():
        return None

    mcp_cfg = yaml.safe_load(mcp_config_path.read_text(encoding="utf-8")) or {}
    wanted = server_name.strip().lower()

    for srv in mcp_cfg.get("servers") or []:
        # str() guards against non-string YAML scalars (e.g. a numeric name).
        if str(srv.get("name") or "").strip().lower() != wanted:
            continue

        transport = str(srv.get("transport") or "http").lower()
        if transport not in ("http", "sse"):
            # Only http and sse transports are supported for auto URL resolution
            return None

        # Prefer the top-level port field; fall back to env.PORT.
        port = srv.get("port") or (srv.get("env") or {}).get("PORT")
        if not port:
            return None

        # `or` rather than a .get default so an explicit null/empty host
        # also falls back instead of producing "http://None:...".
        host = srv.get("host") or "127.0.0.1"

        # Path depends on transport type
        path = "/sse" if transport == "sse" else "/mcp"

        return f"http://{host}:{str(port).strip()}{path}"

    return None
|
|
46
|
+
|
|
47
|
+
@dataclass
class MCPServerConfig:
    """Connection settings for a single MCP server.

    ``http``/``sse`` servers must set ``url``; ``stdio`` servers must set
    ``command`` (``args`` and ``env`` are optional). Other transport strings
    are not validated here.
    """
    name: str
    transport: str = "http"  # Transport type: "http", "sse", or "stdio"
    # Used by http/sse transports
    url: Optional[str] = None
    # Used by the stdio transport
    command: Optional[str] = None
    args: List[str] = field(default_factory=list)
    env: Optional[Dict[str, str]] = None
    # Options shared by every transport
    enabled: bool = True
    cache_tools_list: bool = True

    def __post_init__(self):
        """Reject a config that lacks the field its transport requires.

        Raises:
            ValueError: ``url`` is empty for http/sse, or ``command`` is empty
                for stdio.
        """
        if self.transport == "stdio":
            if not self.command:
                raise ValueError(f"MCPServerConfig '{self.name}': 'command' is required for transport 'stdio'")
        elif self.transport in ("http", "sse") and not self.url:
            raise ValueError(f"MCPServerConfig '{self.name}': 'url' is required for transport '{self.transport}'")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass
class ToolInjection:
    """Text to inject into an MCP tool's description.

    ``type`` must be ``"suffix"`` or ``"override"``; how each is applied is
    decided by whatever consumes this config (not shown in this class).
    """
    type: str  # "suffix" or "override"
    content: str

    def __post_init__(self):
        """Raise ValueError unless ``type`` is "suffix" or "override"."""
        is_known_type = self.type in ("suffix", "override")
        if is_known_type:
            return
        raise ValueError(f"Invalid injection type '{self.type}'. Must be 'suffix' or 'override'")
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass
class SkillInjection:
    """Configuration for skill injection.

    Supports three modes:
    - insert: Insert content at a specific row in an existing skill
        - row >= 1: Insert before this line (1-indexed)
        - row == -1: Append to end of file
    - append: Insert content at the end of an existing skill (row must be -1)
    - create: Create an entirely new skill (row must be -1)

    Raises:
        ValueError: on an unknown mode, an insert row that is neither -1 nor
            >= 1, or an append/create row other than -1.
    """
    mode: str  # "insert", "append", or "create"
    content: str  # Content to inject or full SKILL.md content (create mode)
    row: int  # Line number: >= 1 for specific row, -1 for end/append/create

    def __post_init__(self):
        """Validate injection configuration"""
        if self.mode not in ("insert", "append", "create"):
            raise ValueError(f"Invalid skill injection mode '{self.mode}'. Must be 'insert', 'append', or 'create'")
        if self.mode == "insert":
            if self.row != -1 and self.row < 1:
                raise ValueError(f"Invalid row number '{self.row}'. Must be >= 1 or -1 (append) for insert mode")
        elif self.row != -1:
            # Remaining modes are append/create, which only accept the -1 sentinel.
            raise ValueError(f"Row must be -1 for {self.mode} mode, got '{self.row}'")
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@dataclass
class AgentConfig:
    """Configuration for the evaluated agent (system prompt and MCP servers)"""
    system_prompt: str
    name: str = "Assistant"
    mcp_servers: List[MCPServerConfig] = field(default_factory=list)
    skill_directories: List[str] = field(default_factory=list)  # Paths to skill directories

    @classmethod
    def from_yaml(cls, config_path: str) -> 'AgentConfig':
        """Load agent configuration from the ``Agent`` section of a YAML file.

        Each ``mcp_servers`` entry whose transport is not ``stdio`` (default
        ``http``) gets its URL from mcp.yaml via ``_get_mcp_url_from_config``.
        ``stdio`` entries need no URL and rely on ``command``/``args``/``env_vars``.
        Relative ``skill_directories`` paths are resolved against the config
        file's directory. Dict entries with ``enabled: false`` are skipped.

        Raises:
            FileNotFoundError: ``config_path`` does not exist.
            KeyError: a non-stdio server's URL cannot be resolved from mcp.yaml.
            ValueError: propagated from ``MCPServerConfig`` validation.
        """
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config file not found: {config_path}")
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty file.
            data = yaml.safe_load(f) or {}

        # `or` fallbacks also cover keys that are present but null in the YAML.
        agent_data = data.get('Agent') or {}
        system_prompt = (agent_data.get('system_prompt') or '').strip()
        agent_name = agent_data.get('name') or 'Assistant'

        mcp_servers = []
        for server_data in agent_data.get('mcp_servers') or []:
            name = server_data['name']
            # NOTE(review): transport defaults to 'http' here even when mcp.yaml
            # declares 'sse' for this server (the resolved URL then ends in
            # /sse) — confirm consumers handle that combination.
            transport = server_data.get('transport', 'http')

            resolved_url = None
            # stdio servers are launched via `command`, so they need no URL.
            if transport != 'stdio':
                resolved_url = _get_mcp_url_from_config(name)
                if not resolved_url:
                    raise KeyError(
                        f"Missing MCP server URL for '{name}' and unable to resolve from mcp.yaml"
                    )

            mcp_servers.append(
                MCPServerConfig(
                    name=name,
                    transport=transport,
                    url=resolved_url,
                    command=server_data.get('command'),
                    args=list(server_data.get('args') or []),
                    env=server_data.get('env_vars'),
                    enabled=server_data.get('enabled', True),
                    cache_tools_list=server_data.get('cache_tools_list', True),
                )
            )

        # Parse skill directories (list of paths) - resolve relative to config file
        config_dir = Path(config_path).parent
        skill_directories = []
        for skill_data in agent_data.get('skill_directories') or []:
            if isinstance(skill_data, str):
                path = skill_data
            elif isinstance(skill_data, dict) and skill_data.get('enabled', True):
                path = skill_data['path']
            else:
                continue
            # Resolve relative paths to absolute
            skill_path = Path(path)
            if not skill_path.is_absolute():
                skill_path = config_dir / path
            skill_directories.append(str(skill_path.resolve()))

        return cls(
            system_prompt=system_prompt,
            name=agent_name,
            mcp_servers=mcp_servers,
            skill_directories=skill_directories
        )
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
@dataclass
class RuntimeConfig:
    """Runtime configuration for agent execution (model settings, limits, output paths).

    Normally loaded from the ``Runtime`` section of the agent YAML file via
    :meth:`from_yaml`. Every field has a default, so the class can also be
    constructed directly.
    """
    model: str = "gpt-4o"
    temperature: Optional[float] = None  # None = use model default; some models (o1, o3) don't support temperature
    max_turns: int = 10
    output_dir: Optional[str] = None
    mcp_injection: Optional[Dict[str, Dict[str, ToolInjection]]] = None  # server_name -> tool_name -> injection
    skill_injection: Optional[Dict[str, List[SkillInjection]]] = None  # skill_name -> list of injections
    debug: bool = False  # Enable debug mode to save extra info like tool descriptions
    agent_kwargs: Optional[Dict[str, Any]] = None  # Agent-specific parameters

    @classmethod
    def from_yaml(cls, config_path: str) -> 'RuntimeConfig':
        """Load runtime configuration from the ``Runtime`` section of a YAML file.

        Keys missing from the section fall back to the same defaults as the
        dataclass fields. An empty file, or an absent/empty ``Runtime``
        section, yields a default configuration.

        Args:
            config_path: Path to the YAML configuration file.

        Returns:
            A populated ``RuntimeConfig``.

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
            KeyError: If an injection entry lacks a required key
                (``type``/``content`` for MCP injections,
                ``mode``/``content``/``row`` for skill injections).
        """
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config file not found: {config_path}")
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            data = yaml.safe_load(f) or {}

        # A bare "Runtime:" key parses to None, so a .get() default alone is not enough.
        runtime_data = data.get('Runtime') or {}

        return cls(
            model=runtime_data.get('model', 'gpt-4o'),
            temperature=runtime_data.get('temperature'),  # None if not specified
            max_turns=runtime_data.get('max_turns', 10),
            output_dir=runtime_data.get('output_dir'),
            mcp_injection=cls._parse_mcp_injection(runtime_data.get('mcp_injection')),
            skill_injection=cls._parse_skill_injection(runtime_data.get('skill_injection')),
            # Previously the declared `debug` field was never read from YAML.
            debug=runtime_data.get('debug', False),
            agent_kwargs=runtime_data.get('agent_kwargs'),
        )

    @staticmethod
    def _parse_mcp_injection(
        mcp_injection_data: Optional[Dict[str, Any]]
    ) -> Optional[Dict[str, Dict[str, ToolInjection]]]:
        """Convert the raw ``mcp_injection`` mapping into ``ToolInjection`` objects.

        Returns None when the section is absent or empty. A server entry with
        no tools (YAML null) maps to an empty dict.
        """
        if not mcp_injection_data:
            return None
        return {
            server_name: {
                tool_name: ToolInjection(
                    type=injection_data['type'],
                    content=injection_data['content'],
                )
                for tool_name, injection_data in (tools_data or {}).items()
            }
            for server_name, tools_data in mcp_injection_data.items()
        }

    @staticmethod
    def _parse_skill_injection(
        skill_injection_data: Optional[Dict[str, Any]]
    ) -> Optional[Dict[str, List[SkillInjection]]]:
        """Convert the raw ``skill_injection`` mapping into lists of ``SkillInjection`` objects.

        Returns None when the section is absent or empty. A skill entry with
        no injections (YAML null) maps to an empty list.
        """
        if not skill_injection_data:
            return None
        return {
            skill_name: [
                SkillInjection(
                    mode=injection_data['mode'],
                    content=injection_data['content'],
                    row=injection_data['row'],
                )
                for injection_data in (injections or [])
            ]
            for skill_name, injections in skill_injection_data.items()
        }
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
@dataclass
class AgentResult:
    """Unified result from agent execution across all frameworks"""
    final_output: Optional[str]
    turn_count: int
    trajectory: Optional[Trajectory]
    # Optional fields
    trace_id: Optional[str] = None
    duration: Optional[float] = None

    def __repr__(self) -> str:
        """Return a compact summary with the turn count and a short trace-id preview.

        Only the first 8 characters of ``trace_id`` are shown. An ellipsis is
        appended only when the id was actually truncated, and a missing id is
        shown as ``none``.
        """
        if not self.trace_id:
            trace_preview = "none"
        elif len(self.trace_id) > 8:
            trace_preview = f"{self.trace_id[:8]}..."
        else:
            trace_preview = self.trace_id
        return f"AgentResult(turns={self.turn_count}, trace_id={trace_preview})"
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
class Agent(ABC):
    """Abstract base class for all agent implementations.

    Framework-specific subclasses implement ``initialize``,
    ``_create_mcp_server``, ``run`` and ``cleanup``. This base class holds the
    parsed configuration, builds the MCP server list from it, offers a
    cross-server duplicate-tool check, owns the shared tool-call hook manager,
    and makes the agent usable as an ``async with`` context manager.
    """

    def __init__(
        self,
        config: Optional[AgentConfig] = None,
        runtime_config: Optional[RuntimeConfig] = None
    ):
        """
        Initialize agent with configuration

        Args:
            config: Agent configuration object (system prompt, MCP servers)
            runtime_config: Runtime configuration (model, temperature, etc.).
                A default ``RuntimeConfig()`` is used when omitted.
        """
        self.config = config
        self.runtime_config = runtime_config or RuntimeConfig()
        # Framework-specific handles; None here, presumably populated by subclasses.
        self.agent = None
        self.trace_processor = None
        # Parallel lists filled by load_mcp_servers(): mcp_servers[i] was
        # created from the config entry named mcp_server_names[i].
        self.mcp_servers: List[Any] = []
        self.mcp_server_names: List[str] = []  # Track server names from config for tool mapping
        self.hook_manager: HookManager = HookManager()

    def register_tool_hook(self, hook: ToolCallHook) -> None:
        """Register a pre/post MCP tool-call hook shared across frameworks."""
        self.hook_manager.register(hook)

    @classmethod
    def from_config(cls, config_path: str) -> 'Agent':
        """
        Factory method to create agent from configuration file

        Both the agent section and the ``Runtime`` section are read from the
        same YAML file.

        Args:
            config_path: Path to YAML configuration file

        Returns:
            A constructed agent instance. It is NOT yet initialized: await
            ``initialize()`` (or use ``async with``) before calling ``run``.
        """
        config = AgentConfig.from_yaml(config_path)
        runtime_config = RuntimeConfig.from_yaml(config_path)
        return cls(config, runtime_config)

    @abstractmethod
    async def initialize(self) -> None:
        """
        Initialize the agent and connect to MCP servers

        Must be called before running the agent
        """
        pass

    async def load_mcp_servers(self) -> List[Any]:
        """
        Load MCP servers based on their transport type.

        This method handles the common logic of iterating through server configs,
        creating servers, and tracking server names. Disabled servers, and
        servers for which ``_create_mcp_server`` returns None, are skipped.
        Any previously loaded servers are discarded.

        Returns:
            List of MCP server instances

        Raises:
            ValueError: If the agent has no config.
        """
        if not self.config:
            raise ValueError("Agent config is required")

        self.mcp_servers = []
        self.mcp_server_names = []

        for server_config in self.config.mcp_servers:
            if not server_config.enabled:
                continue

            server = self._create_mcp_server(server_config)
            if server is not None:
                # Append to both lists together so they stay index-aligned.
                self.mcp_servers.append(server)
                self.mcp_server_names.append(server_config.name)

        return self.mcp_servers

    async def _get_server_tools(self, server: Any, server_name: str) -> List[str]:
        """
        Get the list of tool names from an MCP server.

        Subclasses should override this to use SDK-specific methods to retrieve
        the tool list from a connected MCP server. The default returns an
        empty list, which makes ``_check_duplicate_tools`` a no-op.

        Args:
            server: The MCP server instance
            server_name: Name of the server from config

        Returns:
            List of tool names provided by this server
        """
        return []

    async def _check_duplicate_tools(self) -> None:
        """
        Check for duplicate tool names across MCP servers.

        A server whose tool list cannot be retrieved is reported with a
        warning and skipped rather than aborting the whole check.

        Raises:
            ValueError: If duplicate tool names are found across servers. The
                message lists each duplicated tool with the servers providing it.
        """
        tool_to_servers: Dict[str, List[str]] = {}

        for server, server_name in zip(self.mcp_servers, self.mcp_server_names):
            try:
                tools = await self._get_server_tools(server, server_name)
                for tool_name in tools:
                    tool_to_servers.setdefault(tool_name, []).append(server_name)
            except Exception as e:
                # Best-effort: one unreachable server must not block checking the others.
                print(f"[WARNING] Failed to get tools from {server_name}: {e}")

        # Sorted so the error message is deterministic (a set's order is not).
        duplicates = {
            tool_name: servers
            for tool_name, servers in sorted(tool_to_servers.items())
            if len(servers) > 1
        }

        if duplicates:
            raise ValueError(
                f"Duplicate tool names found across MCP servers: {duplicates}"
            )

    async def _build_tool_to_server_mapping(self) -> Dict[str, str]:
        """
        Build a mapping from tool names to server names.

        This discovers which tools belong to which MCP server based on
        the server names from the config file. Subclasses can override
        this using SDK-specific methods to retrieve tool lists.

        Default implementation returns an empty dict.

        Returns:
            Dict mapping tool_name -> server_name (from config)
        """
        return {}

    @abstractmethod
    def _create_mcp_server(self, server_config: MCPServerConfig) -> Any:
        """
        Create an SDK-specific MCP server instance.

        Subclasses must implement this to create the appropriate MCP server
        type for their SDK (e.g., MCPServerSse, MCPServerStdio, etc.)

        Args:
            server_config: Configuration for the MCP server

        Returns:
            SDK-specific MCP server instance, or None to skip this server
        """
        pass

    @abstractmethod
    async def run(self, user_input: str, metadata: Optional[Dict[str, Any]] = None) -> Any:
        """
        Run the agent with given input

        Args:
            user_input: User instruction/query
            metadata: Optional metadata for tracking (task_id, domain, category, etc.)

        Returns:
            Agent execution result
        """
        pass

    @abstractmethod
    async def cleanup(self) -> None:
        """
        Clean up resources (close connections, etc.)
        """
        pass

    async def __aenter__(self) -> 'Agent':
        """Async context manager entry: initialize the agent.

        If ``initialize()`` fails, ``cleanup()`` is attempted before the error
        is re-raised. ``__aexit__`` is not called when ``__aenter__`` raises,
        so connections opened before the failure would otherwise leak.
        """
        try:
            await self.initialize()
        except BaseException:
            try:
                await self.cleanup()
            except Exception:
                # Best-effort: let the original initialization error propagate.
                pass
            raise
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool:
        """Async context manager exit: clean up; never suppresses exceptions."""
        await self.cleanup()
        return False
|