decodingtrust-agent-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +30 -0
- agent/claudesdk/__init__.py +8 -0
- agent/claudesdk/example.py +221 -0
- agent/claudesdk/src/__init__.py +8 -0
- agent/claudesdk/src/agent.py +400 -0
- agent/claudesdk/src/mcp_proxy.py +409 -0
- agent/claudesdk/src/utils.py +420 -0
- agent/googleadk/__init__.py +15 -0
- agent/googleadk/example.py +237 -0
- agent/googleadk/src/__init__.py +12 -0
- agent/googleadk/src/agent.py +401 -0
- agent/googleadk/src/mcp_wrapper.py +163 -0
- agent/googleadk/src/utils.py +602 -0
- agent/langchain/__init__.py +8 -0
- agent/langchain/example.py +213 -0
- agent/langchain/src/__init__.py +8 -0
- agent/langchain/src/agent.py +645 -0
- agent/langchain/src/utils.py +433 -0
- agent/openaisdk/__init__.py +17 -0
- agent/openaisdk/example.py +228 -0
- agent/openaisdk/src/__init__.py +12 -0
- agent/openaisdk/src/agent.py +491 -0
- agent/openaisdk/src/agent_wrapper.py +143 -0
- agent/openaisdk/src/mcp_wrapper.py +395 -0
- agent/openaisdk/src/utils.py +493 -0
- agent/openclaw/__init__.py +10 -0
- agent/openclaw/example.py +251 -0
- agent/openclaw/src/__init__.py +14 -0
- agent/openclaw/src/agent.py +930 -0
- agent/openclaw/src/helpers/__init__.py +1 -0
- agent/openclaw/src/helpers/auth_helpers.py +55 -0
- agent/openclaw/src/mcp_proxy.py +564 -0
- agent/openclaw/src/plugin_generator.py +231 -0
- agent/openclaw/src/utils.py +341 -0
- agent/pocketflow/__init__.py +18 -0
- agent/pocketflow/example.py +221 -0
- agent/pocketflow/prompts/react_agent.py +46 -0
- agent/pocketflow/src/__init__.py +6 -0
- agent/pocketflow/src/agent.py +507 -0
- agent/pocketflow/src/agent_wrapper.py +159 -0
- agent/pocketflow/src/async_helper.py +92 -0
- agent/pocketflow/src/mcp_react_agent.py +279 -0
- agent/pocketflow/src/native_agent.py +74 -0
- agent/pocketflow/src/nodes.py +467 -0
- benchmark/__init__.py +0 -0
- benchmark/browser/benign.jsonl +34 -0
- benchmark/browser/direct.jsonl +85 -0
- benchmark/browser/indirect.jsonl +82 -0
- benchmark/code/benign.jsonl +0 -0
- benchmark/code/direct.jsonl +121 -0
- benchmark/code/indirect.jsonl +165 -0
- benchmark/crm/benign.jsonl +165 -0
- benchmark/crm/direct.jsonl +90 -0
- benchmark/crm/indirect.jsonl +150 -0
- benchmark/customer-service/benign.jsonl +160 -0
- benchmark/customer-service/direct.jsonl +100 -0
- benchmark/customer-service/indirect.jsonl +101 -0
- benchmark/finance/benign.jsonl +0 -0
- benchmark/finance/direct.jsonl +200 -0
- benchmark/finance/indirect.jsonl +200 -0
- benchmark/legal/benign.jsonl +0 -0
- benchmark/legal/direct.jsonl +200 -0
- benchmark/legal/indirect.jsonl +200 -0
- benchmark/macos/benign.jsonl +30 -0
- benchmark/macos/direct.jsonl +50 -0
- benchmark/macos/indirect.jsonl +50 -0
- benchmark/medical/benign.jsonl +642 -0
- benchmark/medical/direct.jsonl +229 -0
- benchmark/medical/indirect.jsonl +222 -0
- benchmark/os-filesystem/benign.jsonl +200 -0
- benchmark/os-filesystem/direct.jsonl +200 -0
- benchmark/os-filesystem/indirect.jsonl +200 -0
- benchmark/research/benign.jsonl +0 -0
- benchmark/research/direct.jsonl +119 -0
- benchmark/research/indirect.jsonl +125 -0
- benchmark/telecom/benign.jsonl +120 -0
- benchmark/telecom/direct.jsonl +161 -0
- benchmark/telecom/indirect.jsonl +166 -0
- benchmark/travel/benign.jsonl +130 -0
- benchmark/travel/direct.jsonl +105 -0
- benchmark/travel/indirect.jsonl +120 -0
- benchmark/windows/benign.jsonl +100 -0
- benchmark/windows/direct.jsonl +140 -0
- benchmark/windows/indirect.jsonl +107 -0
- benchmark/workflow/benign.jsonl +335 -0
- benchmark/workflow/direct.jsonl +78 -0
- benchmark/workflow/indirect.jsonl +107 -0
- cli/__init__.py +5 -0
- cli/main.py +182 -0
- cli/scaffold.py +334 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/METADATA +642 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/RECORD +374 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/WHEEL +5 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/entry_points.txt +2 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/licenses/LICENSE +201 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/top_level.txt +6 -0
- dt_arena/config/env.yaml +515 -0
- dt_arena/config/injection_mcp.yaml +430 -0
- dt_arena/config/mcp.yaml +642 -0
- dt_arena/envs/arxiv/docker-compose-hub.yml +31 -0
- dt_arena/envs/arxiv/docker-compose.yml +36 -0
- dt_arena/envs/atlassian/docker/docker-compose.dev.yml +65 -0
- dt_arena/envs/atlassian/docker/docker-compose.yml +53 -0
- dt_arena/envs/atlassian/docker-compose-hub.yml +57 -0
- dt_arena/envs/atlassian/docker-compose.yml +72 -0
- dt_arena/envs/bigquery/docker-compose.yml +20 -0
- dt_arena/envs/booking/docker-compose.yml +59 -0
- dt_arena/envs/calendar/docker-compose-hub.yml +30 -0
- dt_arena/envs/calendar/docker-compose.yml +42 -0
- dt_arena/envs/custom-website/docker-compose.yml +6 -0
- dt_arena/envs/customer_service/docker-compose.yml +59 -0
- dt_arena/envs/databricks/docker-compose-hub.yml +47 -0
- dt_arena/envs/databricks/docker-compose.yml +51 -0
- dt_arena/envs/ecommerce/docker-compose.yml +6 -0
- dt_arena/envs/ers/docker-compose.yml +36 -0
- dt_arena/envs/ers/hrms/docker/docker-compose.yml +31 -0
- dt_arena/envs/finance/docker-compose.yml +23 -0
- dt_arena/envs/github/docker/docker-compose-hub.yml +50 -0
- dt_arena/envs/github/docker/docker-compose.yml +50 -0
- dt_arena/envs/gmail/docker-compose-hub.yml +51 -0
- dt_arena/envs/gmail/docker-compose.yml +65 -0
- dt_arena/envs/google-form/docker-compose-hub.yml +33 -0
- dt_arena/envs/google-form/docker-compose.yml +41 -0
- dt_arena/envs/googledocs/docker-compose-hub.yml +61 -0
- dt_arena/envs/googledocs/docker-compose.yml +78 -0
- dt_arena/envs/hospital/docker-compose-hub.yml +25 -0
- dt_arena/envs/hospital/docker-compose.yml +27 -0
- dt_arena/envs/legal/docker-compose.yml +22 -0
- dt_arena/envs/linkedin/docker-compose.yml +63 -0
- dt_arena/envs/macos/docker-compose.yml +79 -0
- dt_arena/envs/os-filesystem/docker-compose-hub.yml +16 -0
- dt_arena/envs/os-filesystem/docker-compose.yml +20 -0
- dt_arena/envs/paypal/docker-compose-hub.yml +48 -0
- dt_arena/envs/paypal/docker-compose.yml +63 -0
- dt_arena/envs/research/docker-compose-hub.yml +13 -0
- dt_arena/envs/research/docker-compose.yml +24 -0
- dt_arena/envs/salesforce_crm/docker-compose-hub.yaml +45 -0
- dt_arena/envs/salesforce_crm/docker-compose.yaml +49 -0
- dt_arena/envs/slack/docker-compose-hub.yml +28 -0
- dt_arena/envs/slack/docker-compose.yml +41 -0
- dt_arena/envs/snowflake/docker-compose-hub.yml +41 -0
- dt_arena/envs/snowflake/docker-compose.yml +44 -0
- dt_arena/envs/telecom/docker-compose-hub.yml +16 -0
- dt_arena/envs/telecom/docker-compose.yml +17 -0
- dt_arena/envs/telegram/docker-compose-hub.yml +57 -0
- dt_arena/envs/telegram/docker-compose.yml +62 -0
- dt_arena/envs/terminal/docker-compose-hub.yml +12 -0
- dt_arena/envs/terminal/docker-compose.yml +26 -0
- dt_arena/envs/travel/docker-compose-hub.yml +19 -0
- dt_arena/envs/travel/docker-compose.yml +19 -0
- dt_arena/envs/whatsapp/docker-compose-hub.yml +61 -0
- dt_arena/envs/whatsapp/docker-compose.yml +78 -0
- dt_arena/envs/windows/docker-compose.yml +71 -0
- dt_arena/envs/zoom/docker-compose-hub.yml +27 -0
- dt_arena/envs/zoom/docker-compose.yml +40 -0
- dt_arena/injection_mcp_server/atlassian/env_injection.py +134 -0
- dt_arena/injection_mcp_server/calendar/env_injection.py +217 -0
- dt_arena/injection_mcp_server/custom_website/env_injection.py +97 -0
- dt_arena/injection_mcp_server/customer_service/env_injection.py +659 -0
- dt_arena/injection_mcp_server/databricks/env_injection.py +255 -0
- dt_arena/injection_mcp_server/ecommerce/env_injection.py +110 -0
- dt_arena/injection_mcp_server/finance/env_injection.py +85 -0
- dt_arena/injection_mcp_server/github/env_injection.py +206 -0
- dt_arena/injection_mcp_server/gmail/env_injection.py +211 -0
- dt_arena/injection_mcp_server/google_form/env_injection.py +186 -0
- dt_arena/injection_mcp_server/googledocs/env_injection.py +44 -0
- dt_arena/injection_mcp_server/hospital/env_injection.py +43 -0
- dt_arena/injection_mcp_server/legal/env_injection.py +229 -0
- dt_arena/injection_mcp_server/macos/env_injection.py +272 -0
- dt_arena/injection_mcp_server/os-filesystem/env_injection.py +341 -0
- dt_arena/injection_mcp_server/paypal/env_injection.py +268 -0
- dt_arena/injection_mcp_server/research/env_injection.py +616 -0
- dt_arena/injection_mcp_server/salesforce/env_injection.py +514 -0
- dt_arena/injection_mcp_server/slack/env_injection.py +265 -0
- dt_arena/injection_mcp_server/snowflake/env_injection.py +230 -0
- dt_arena/injection_mcp_server/telecom/env_injection.py +503 -0
- dt_arena/injection_mcp_server/telegram/env_injection.py +171 -0
- dt_arena/injection_mcp_server/terminal/env_injection.py +523 -0
- dt_arena/injection_mcp_server/travel/env_injection.py +173 -0
- dt_arena/injection_mcp_server/whatsapp/env_injection.py +185 -0
- dt_arena/injection_mcp_server/windows/env_injection.py +943 -0
- dt_arena/injection_mcp_server/zoom/env_injection.py +216 -0
- dt_arena/mcp_server/atlassian/main.py +1554 -0
- dt_arena/mcp_server/atlassian/test_server.py +66 -0
- dt_arena/mcp_server/bigquery/main.py +333 -0
- dt_arena/mcp_server/booking/main.py +310 -0
- dt_arena/mcp_server/browser/main.py +1741 -0
- dt_arena/mcp_server/calendar/example_multi_user.py +162 -0
- dt_arena/mcp_server/calendar/main.py +792 -0
- dt_arena/mcp_server/calendar/test_mcp.py +135 -0
- dt_arena/mcp_server/customer_service/main.py +1063 -0
- dt_arena/mcp_server/databricks/main.py +566 -0
- dt_arena/mcp_server/databricks/probe.py +102 -0
- dt_arena/mcp_server/ers/main.py +845 -0
- dt_arena/mcp_server/finance/__init__.py +87 -0
- dt_arena/mcp_server/finance/core/__init__.py +12 -0
- dt_arena/mcp_server/finance/core/data_loader.py +558 -0
- dt_arena/mcp_server/finance/core/portfolio.py +565 -0
- dt_arena/mcp_server/finance/evaluation/__init__.py +20 -0
- dt_arena/mcp_server/finance/evaluation/evaluator.py +217 -0
- dt_arena/mcp_server/finance/evaluation/logger.py +137 -0
- dt_arena/mcp_server/finance/injection/__init__.py +66 -0
- dt_arena/mcp_server/finance/injection/config.py +176 -0
- dt_arena/mcp_server/finance/injection/content.py +755 -0
- dt_arena/mcp_server/finance/injection/html.py +409 -0
- dt_arena/mcp_server/finance/injection/locations.py +167 -0
- dt_arena/mcp_server/finance/injection/methods.py +193 -0
- dt_arena/mcp_server/finance/injection/presets.py +1023 -0
- dt_arena/mcp_server/finance/main.py +361 -0
- dt_arena/mcp_server/finance/run_mcp.py +21 -0
- dt_arena/mcp_server/finance/run_web.py +26 -0
- dt_arena/mcp_server/finance/server/__init__.py +41 -0
- dt_arena/mcp_server/finance/server/extractor.py +1453 -0
- dt_arena/mcp_server/finance/server/extractor_minimal.py +292 -0
- dt_arena/mcp_server/finance/server/extractor_simple.py +1164 -0
- dt_arena/mcp_server/finance/server/injection_mcp.py +865 -0
- dt_arena/mcp_server/finance/server/mcp.py +451 -0
- dt_arena/mcp_server/finance/server/tools/__init__.py +23 -0
- dt_arena/mcp_server/finance/server/tools/account.py +88 -0
- dt_arena/mcp_server/finance/server/tools/browsing.py +328 -0
- dt_arena/mcp_server/finance/server/tools/social.py +73 -0
- dt_arena/mcp_server/finance/server/tools/trading.py +242 -0
- dt_arena/mcp_server/finance/server/tools/utility.py +49 -0
- dt_arena/mcp_server/finance/server/web.py +2139 -0
- dt_arena/mcp_server/finance/tasks/benchmark/__init__.py +28 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_pool.py +3026 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_runner.py +1315 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_requirements.py +1335 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_tasks.py +3665 -0
- dt_arena/mcp_server/finance/tasks/benchmark/malicious_tasks.py +2673 -0
- dt_arena/mcp_server/finance/tasks/redteam_suite/run_redteam_suite.py +1713 -0
- dt_arena/mcp_server/finance/test_mcp_tools.py +476 -0
- dt_arena/mcp_server/github/main.py +441 -0
- dt_arena/mcp_server/gmail/main.py +1004 -0
- dt_arena/mcp_server/google_form/main.py +141 -0
- dt_arena/mcp_server/googledocs/main.py +458 -0
- dt_arena/mcp_server/hospital/mcp_server.py +458 -0
- dt_arena/mcp_server/legal/__init__.py +9 -0
- dt_arena/mcp_server/legal/core/__init__.py +14 -0
- dt_arena/mcp_server/legal/core/courtlistener_store.py +762 -0
- dt_arena/mcp_server/legal/core/data_loader.py +266 -0
- dt_arena/mcp_server/legal/core/document_store.py +197 -0
- dt_arena/mcp_server/legal/core/matter_manager.py +466 -0
- dt_arena/mcp_server/legal/main.py +89 -0
- dt_arena/mcp_server/legal/scripts/collect_data.py +988 -0
- dt_arena/mcp_server/legal/server/__init__.py +14 -0
- dt_arena/mcp_server/legal/server/mcp.py +2330 -0
- dt_arena/mcp_server/macos/client_test.py +270 -0
- dt_arena/mcp_server/macos/mcp_server.py +285 -0
- dt_arena/mcp_server/os-filesystem/main.py +1380 -0
- dt_arena/mcp_server/paypal/main.py +501 -0
- dt_arena/mcp_server/research/main.py +777 -0
- dt_arena/mcp_server/salesforce/main.py +2006 -0
- dt_arena/mcp_server/slack/main.py +318 -0
- dt_arena/mcp_server/snowflake/main.py +612 -0
- dt_arena/mcp_server/snowflake/probe.py +183 -0
- dt_arena/mcp_server/telecom/mcp_client.py +423 -0
- dt_arena/mcp_server/telecom/mcp_server.py +1059 -0
- dt_arena/mcp_server/telegram/main.py +338 -0
- dt_arena/mcp_server/terminal/main.py +163 -0
- dt_arena/mcp_server/travel/client_test.py +16 -0
- dt_arena/mcp_server/travel/mcp_server.py +404 -0
- dt_arena/mcp_server/whatsapp/main.py +318 -0
- dt_arena/mcp_server/windows/client_test.py +270 -0
- dt_arena/mcp_server/windows/mcp_server.py +218 -0
- dt_arena/mcp_server/zoom/main.py +466 -0
- dt_arena/src/__init__.py +0 -0
- dt_arena/src/hooks/__init__.py +0 -0
- dt_arena/src/hooks/audit_log.py +30 -0
- dt_arena/src/hooks/hooks.json +3 -0
- dt_arena/src/run_benign.py +142 -0
- dt_arena/src/types/__init__.py +0 -0
- dt_arena/src/types/agent.py +441 -0
- dt_arena/src/types/attacks.py +2 -0
- dt_arena/src/types/environment.py +2 -0
- dt_arena/src/types/hooks.py +174 -0
- dt_arena/src/types/judge.py +52 -0
- dt_arena/src/types/red_teaming_trajectory.py +385 -0
- dt_arena/src/types/task.py +260 -0
- dt_arena/src/types/trajectory.py +315 -0
- dt_arena/utils/__init__.py +1 -0
- dt_arena/utils/atlassian/__init__.py +27 -0
- dt_arena/utils/atlassian/helpers.py +520 -0
- dt_arena/utils/bigquery/__init__.py +1 -0
- dt_arena/utils/bigquery/helpers.py +246 -0
- dt_arena/utils/calendar/__init__.py +1 -0
- dt_arena/utils/calendar/helpers.py +87 -0
- dt_arena/utils/customer_service/__init__.py +17 -0
- dt_arena/utils/customer_service/cs_env_client.py +940 -0
- dt_arena/utils/customer_service/helpers.py +339 -0
- dt_arena/utils/customer_service/judges/__init__.py +20 -0
- dt_arena/utils/customer_service/judges/abusive_comms.py +149 -0
- dt_arena/utils/customer_service/judges/bulk_destructive.py +202 -0
- dt_arena/utils/customer_service/judges/cross_customer.py +481 -0
- dt_arena/utils/customer_service/judges/email_exfiltration.py +320 -0
- dt_arena/utils/customer_service/judges/financial_fraud.py +221 -0
- dt_arena/utils/customer_service/judges/llm_utils.py +264 -0
- dt_arena/utils/customer_service/judges/pii_disclosure.py +837 -0
- dt_arena/utils/customer_service/judges/policy_misrep.py +44 -0
- dt_arena/utils/customer_service/judges/text_utils.py +21 -0
- dt_arena/utils/databricks/__init__.py +2 -0
- dt_arena/utils/databricks/helpers.py +210 -0
- dt_arena/utils/finance/__init__.py +0 -0
- dt_arena/utils/finance/helpers.py +263 -0
- dt_arena/utils/github/__init__.py +1 -0
- dt_arena/utils/github/helpers.py +249 -0
- dt_arena/utils/gmail/__init__.py +1 -0
- dt_arena/utils/gmail/helpers.py +344 -0
- dt_arena/utils/google_form/__init__.py +2 -0
- dt_arena/utils/google_form/helpers.py +133 -0
- dt_arena/utils/legal/__init__.py +0 -0
- dt_arena/utils/legal/helpers.py +228 -0
- dt_arena/utils/macos/__init__.py +0 -0
- dt_arena/utils/macos/env_setup.py +215 -0
- dt_arena/utils/macos/helpers.py +61 -0
- dt_arena/utils/os_filesystem/__init__.py +1 -0
- dt_arena/utils/os_filesystem/helpers.py +366 -0
- dt_arena/utils/paypal/__init__.py +1 -0
- dt_arena/utils/paypal/helpers.py +178 -0
- dt_arena/utils/port_allocator.py +266 -0
- dt_arena/utils/research/__init__.py +0 -0
- dt_arena/utils/research/helpers.py +251 -0
- dt_arena/utils/salesforce/__init__.py +1 -0
- dt_arena/utils/salesforce/helpers.py +719 -0
- dt_arena/utils/slack/__init__.py +1 -0
- dt_arena/utils/slack/helpers.py +176 -0
- dt_arena/utils/snowflake/__init__.py +1 -0
- dt_arena/utils/snowflake/helpers.py +166 -0
- dt_arena/utils/telecom/__init__.py +1 -0
- dt_arena/utils/telecom/helpers.py +760 -0
- dt_arena/utils/telegram/__init__.py +0 -0
- dt_arena/utils/telegram/helpers.py +174 -0
- dt_arena/utils/terminal/__init__.py +0 -0
- dt_arena/utils/terminal/helpers.py +20 -0
- dt_arena/utils/travel/__init__.py +0 -0
- dt_arena/utils/travel/env_client.py +537 -0
- dt_arena/utils/travel/llm_judge.py +137 -0
- dt_arena/utils/travel/prompts.py +64 -0
- dt_arena/utils/utils/__init__.py +122 -0
- dt_arena/utils/whatsapp/__init__.py +0 -0
- dt_arena/utils/whatsapp/helpers.py +226 -0
- dt_arena/utils/windows/__init__.py +0 -0
- dt_arena/utils/windows/env_reset.py +224 -0
- dt_arena/utils/windows/env_setup.py +280 -0
- dt_arena/utils/windows/exfil_helpers.py +170 -0
- dt_arena/utils/windows/helpers.py +74 -0
- dt_arena/utils/zoom/__init__.py +1 -0
- dt_arena/utils/zoom/helpers.py +70 -0
- eval/__init__.py +1 -0
- eval/evaluation.py +426 -0
- eval/task_runner.py +449 -0
- utils/__init__.py +148 -0
- utils/agent_helpers.py +308 -0
- utils/agent_wrapper.py +189 -0
- utils/compose_utils.py +135 -0
- utils/config.py +77 -0
- utils/env_helpers.py +104 -0
- utils/eval_stats.py +88 -0
- utils/injection_helpers.py +429 -0
- utils/injection_mcp_helpers.py +152 -0
- utils/judge_helpers.py +181 -0
- utils/judge_utils.py +472 -0
- utils/llm.py +196 -0
- utils/logging.py +45 -0
- utils/mcp_helpers.py +232 -0
- utils/mcp_manager.py +235 -0
- utils/memory_guard.py +18 -0
- utils/red_teaming_sandbox.py +476 -0
- utils/reset_helpers.py +318 -0
- utils/resource_manager.py +370 -0
- utils/skill_helpers.py +447 -0
- utils/task_executor.py +904 -0
- utils/task_helpers.py +270 -0
- utils/template_helpers.py +179 -0
|
@@ -0,0 +1,507 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Dict, Any, List, Optional, Union
|
|
6
|
+
import uuid
|
|
7
|
+
from pocketflow import Flow
|
|
8
|
+
from fastmcp import Client
|
|
9
|
+
|
|
10
|
+
from agent.pocketflow.src.nodes import DecideActionNode, ExecuteToolNode, FinalAnswerNode
|
|
11
|
+
from agent.pocketflow.src.async_helper import AsyncHelper
|
|
12
|
+
|
|
13
|
+
from dt_arena.src.types.agent import Agent, AgentConfig, RuntimeConfig, MCPServerConfig, AgentResult
|
|
14
|
+
from dt_arena.src.types.trajectory import Trajectory
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class MCPServerInfo:
    """Bookkeeping record for one configured MCP server.

    The stored client has been created but is not held open; connections are
    opened on demand (e.g. with ``async with client:``).
    """

    # Logical server name taken from the agent configuration.
    name: str
    # FastMCP Client instance (not yet connected).
    client: Any
    # Server URL for creating new connections. May be None for transports
    # that do not require a URL (e.g. stdio servers launched from a command).
    url: Optional[str]
    # Tool descriptors discovered on this server
    # (dicts with "name", "server", "description", "inputSchema").
    tools: List[Dict[str, Any]] = field(default_factory=list)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class MCPReactAgent(Agent):
|
|
27
|
+
"""
|
|
28
|
+
A ReAct (Reasoning and Acting) agent that uses MCP tools through FastMCP.
|
|
29
|
+
|
|
30
|
+
This agent follows the ReAct pattern:
|
|
31
|
+
1. Think: Reason about the current state and decide what to do
|
|
32
|
+
2. Act: Execute a tool through MCP
|
|
33
|
+
3. Observe: Process the tool's result
|
|
34
|
+
4. Repeat until the task is complete
|
|
35
|
+
|
|
36
|
+
The agent is built using PocketFlow for workflow orchestration and
|
|
37
|
+
FastMCP for MCP server communication.
|
|
38
|
+
|
|
39
|
+
Supports multiple MCP servers:
|
|
40
|
+
- Tools from all servers are merged into a single list
|
|
41
|
+
- Duplicate tool names across servers will raise an error
|
|
42
|
+
- Trajectory records which server handled each tool call
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
def __init__(
    self,
    agent_config: AgentConfig,
    runtime_config: Optional[RuntimeConfig] = None,
):
    """Create a PocketFlow-based MCP ReAct agent.

    Only local state is prepared here; MCP servers are connected later by
    ``initialize()``.

    Args:
        agent_config: Agent configuration (system prompt and MCP server list).
        runtime_config: Optional runtime settings (model, temperature,
            max_turns, output_dir); passed through to the ``Agent`` base class.
    """
    super().__init__(agent_config, runtime_config)

    # Results live under runtime_config.output_dir, falling back to ./results.
    base_dir = self.runtime_config.output_dir
    if not base_dir:
        base_dir = os.path.join(os.getcwd(), "results")
    self.output_dir = base_dir
    self.trajectories_dir = os.path.join(base_dir, "trajectories")
    os.makedirs(self.trajectories_dir, exist_ok=True)

    # Registry of configured servers, keyed by server name.
    self._mcp_servers: Dict[str, MCPServerInfo] = {}
    # Reverse index: tool name -> name of the server that provides it.
    self._tool_to_server: Dict[str, str] = {}
    # Flat list of tool descriptors gathered from every server.
    self._all_tools: List[Dict[str, Any]] = []

    # PocketFlow graph that drives the think / act / observe loop.
    self._flow = self._build_flow()

    # Helper for sync-to-async operations; created in initialize().
    self._async_helper: Optional[AsyncHelper] = None

    # Multi-turn conversation state.
    self._message_history: List[Dict[str, str]] = []  # conversation history
    self._turn_count: int = 0  # total turns across the conversation
    self._current_trajectory: Optional[Trajectory] = None  # active trajectory
|
|
85
|
+
|
|
86
|
+
def _build_flow(self) -> Flow:
    """Assemble and return the ReAct workflow graph.

    Transitions (action name in quotes):
        decide  --"use_tool"--> execute  (execute a tool)
        decide  --"answer"-->   final    (produce the final answer)
        decide  --"retry"-->    decide   (parse-error recovery)
        execute --"decide"-->   decide   (next iteration)
    """
    decide = DecideActionNode()
    execute = ExecuteToolNode()
    final = FinalAnswerNode()

    transitions = (
        (decide, "use_tool", execute),
        (decide, "answer", final),
        (decide, "retry", decide),
        (execute, "decide", decide),
    )
    for source, action, target in transitions:
        source - action >> target

    # The final-answer node has no outgoing edges, so reaching it ends the flow.
    return Flow(start=decide)
|
|
110
|
+
|
|
111
|
+
def _create_mcp_client(self, server_config: MCPServerConfig) -> Client:
    """Build an (unconnected) FastMCP ``Client`` for one MCP server.

    Args:
        server_config: Configuration for the MCP server. ``transport``
            selects how the client reaches the server:
            * ``"http"`` / ``"sse"`` / ``"streamable-http"``: ``url`` is
              required and passed to ``Client``.
            * ``"stdio"``: ``command`` is required; ``args`` and ``env`` are
              optional and used to launch the server process.

    Returns:
        FastMCP Client instance (connect it with ``async with``).

    Raises:
        ValueError: If a required field is missing or the transport type is
            not supported.
    """
    transport = server_config.transport

    if transport in ("http", "sse", "streamable-http"):
        if not server_config.url:
            raise ValueError(f"URL required for transport '{transport}'")
        return Client(server_config.url)

    if transport == "stdio":
        if not server_config.command:
            raise ValueError("Command required for stdio transport")
        # FastMCP's Client expects one of its own transport objects (or a
        # URL/path), not the low-level ``mcp.StdioServerParameters``; wrap
        # the command in FastMCP's StdioTransport instead.
        from fastmcp.client.transports import StdioTransport

        stdio_transport = StdioTransport(
            command=server_config.command,
            args=list(server_config.args or []),
            env=server_config.env,
        )
        return Client(stdio_transport)

    raise ValueError(f"Unsupported transport type: {transport}")
|
|
141
|
+
|
|
142
|
+
def _create_mcp_server(self, server_config: MCPServerConfig) -> Any:
    """Compatibility alias that forwards to ``_create_mcp_client``.

    Args:
        server_config: Configuration for the MCP server.

    Returns:
        Whatever ``_create_mcp_client`` returns for this configuration.
    """
    client = self._create_mcp_client(server_config)
    return client
|
|
145
|
+
|
|
146
|
+
def _apply_tool_injection(self, tool_name: str, server_name: str, description: str) -> str:
|
|
147
|
+
"""
|
|
148
|
+
Apply tool description injection if configured.
|
|
149
|
+
|
|
150
|
+
Args:
|
|
151
|
+
tool_name: Name of the tool
|
|
152
|
+
server_name: Name of the MCP server
|
|
153
|
+
description: Original tool description
|
|
154
|
+
|
|
155
|
+
Returns:
|
|
156
|
+
Modified description with injection applied, or original if no injection
|
|
157
|
+
"""
|
|
158
|
+
if not self.runtime_config or not self.runtime_config.mcp_injection:
|
|
159
|
+
return description
|
|
160
|
+
|
|
161
|
+
server_injections = self.runtime_config.mcp_injection.get(server_name, {})
|
|
162
|
+
injection = server_injections.get(tool_name)
|
|
163
|
+
|
|
164
|
+
if injection:
|
|
165
|
+
if injection.type == "suffix":
|
|
166
|
+
return f"{description}\n{injection.content}"
|
|
167
|
+
elif injection.type == "override":
|
|
168
|
+
return injection.content
|
|
169
|
+
else:
|
|
170
|
+
print(f"[WARNING] Unknown injection type '{injection.type}' for tool '{tool_name}' on server '{server_name}'")
|
|
171
|
+
|
|
172
|
+
return description
|
|
173
|
+
|
|
174
|
+
async def load_mcp_servers(self) -> Dict[str, MCPServerInfo]:
    """
    Load and connect to MCP servers.
    Merges all tools into a single list and checks for duplicates.

    For every enabled entry in ``self.config.mcp_servers`` a client is
    created, a short-lived connection lists the server's tools, any
    configured description injection is applied, and the tools are recorded
    in ``self._all_tools`` and ``self._tool_to_server``. A server whose tool
    listing fails is still registered (with no tools) after a warning.

    Returns:
        Dict mapping server names to MCPServerInfo

    Raises:
        ValueError: If the agent config is missing, a server config is
            invalid (raised by ``_create_mcp_client``), or duplicate tool
            names are found across MCP servers.
    """
    if not self.config:
        raise ValueError("Agent config is required")

    self._mcp_servers = {}
    self._tool_to_server = {}
    self._all_tools = []

    for server_config in self.config.mcp_servers:
        if not server_config.enabled:
            continue

        server_name = server_config.name
        server_url = server_config.url  # Store URL for later use
        client = self._create_mcp_client(server_config)

        # Get tools from this server using a temporary connection; a new
        # connection is created each time to avoid cross-thread issues.
        # Only the network round-trip is guarded: previously a blanket
        # ``except ValueError: raise`` also re-raised connection/validation
        # errors (pydantic's ValidationError subclasses ValueError), aborting
        # the whole load instead of warning about this one server.
        try:
            async with client:
                tools_response = await client.list_tools()
        except Exception as e:
            print(f"[WARNING] Failed to get tools from MCP server '{server_name}': {e}")
            tools_response = []

        tools: List[Dict[str, Any]] = []
        for tool in tools_response:
            tool_name = tool.name

            # Tool names must be unique across servers so each call routes
            # to exactly one server.
            if tool_name in self._tool_to_server:
                existing_server = self._tool_to_server[tool_name]
                raise ValueError(
                    f"Duplicate tool names found across MCP servers: "
                    f"Tool '{tool_name}' exists in both '{existing_server}' and '{server_name}'"
                )

            # Get base description and apply injection
            description = self._apply_tool_injection(
                tool_name=tool_name,
                server_name=server_name,
                description=tool.description or "",
            )

            tool_info = {
                "name": tool_name,
                "server": server_name,
                "description": description,
                "inputSchema": tool.inputSchema or {},
            }
            tools.append(tool_info)
            self._all_tools.append(tool_info)
            # Map tool name to server
            self._tool_to_server[tool_name] = server_name

        # Store server info (client is not connected; a new connection is
        # created when needed). The URL is kept so fresh connections can be
        # made in the worker thread.
        self._mcp_servers[server_name] = MCPServerInfo(
            name=server_name,
            client=client,
            url=server_url,
            tools=tools,
        )

        print(f"[INFO] Connected to MCP server '{server_name}' with {len(tools)} tools")

    print(f"[INFO] Total tools available: {len(self._all_tools)}")
    return self._mcp_servers
|
|
250
|
+
|
|
251
|
+
async def initialize(self) -> None:
    """Prepare the agent for use.

    Creates and starts an ``AsyncHelper`` (used for sync-to-async
    operations), then loads and connects the configured MCP servers.

    Raises:
        ValueError: If the agent has no config.
    """
    if not self.config:
        raise ValueError("Agent config is required")

    helper = AsyncHelper()
    self._async_helper = helper
    helper.start()

    await self.load_mcp_servers()
|
|
262
|
+
|
|
263
|
+
def _get_all_tools(self) -> List[Dict[str, Any]]:
|
|
264
|
+
"""Get the list of all available tools from all connected MCP servers."""
|
|
265
|
+
return self._all_tools
|
|
266
|
+
|
|
267
|
+
def get_server_for_tool(self, tool_name: str) -> Optional[str]:
    """Look up which MCP server provides ``tool_name``.

    Returns:
        The server name, or None if no loaded server exposes the tool.
    """
    mapping = self._tool_to_server
    return mapping[tool_name] if tool_name in mapping else None
|
|
270
|
+
|
|
271
|
+
def get_client_for_tool(self, tool_name: str) -> Optional[Client]:
    """Return the stored MCP client for the server that owns ``tool_name``.

    Returns:
        The server's client, or None if the tool is unknown or its server
        is not registered.
    """
    owner = self.get_server_for_tool(tool_name)
    if not owner:
        return None
    if owner not in self._mcp_servers:
        return None
    return self._mcp_servers[owner].client
|
|
277
|
+
|
|
278
|
+
async def run(
    self,
    user_input: Union[str, List[str]],
    metadata: Optional[Dict[str, Any]] = None
) -> AgentResult:
    """
    Run the agent with given input. Supports multi-turn conversations.

    Each query is executed through the PocketFlow flow in order. The message
    history produced by one query seeds the next one, and it persists across
    calls to ``run()`` until ``reset_conversation()`` is called. After all
    queries have been processed, the trajectory is saved as
    ``<task_id>_<timestamp>.json`` in ``self.trajectories_dir``.

    Args:
        user_input: User instruction/query. Can be:
            - str: Single query (backward compatible, supports sequential multi-turn)
            - List[str]: Multiple queries processed sequentially with context preserved
              (any iterable of strings is accepted and materialized into a list)
        metadata: Optional metadata. Reads ``task_id``, ``malicious_goal``,
            ``domain`` and ``category``. The trajectory fields are only taken
            from the first ``run()`` of a session (when the trajectory is
            created); ``task_id`` is also used in the saved file name on
            every call.

    Returns:
        AgentResult with final_output, trajectory and turn_count.
        ``final_output`` is the answer for the last query processed
        ("No answer provided" if the flow did not set one), or None when
        ``user_input`` contains no queries.

    Raises:
        RuntimeError: If no MCP servers are registered (agent not initialized).
    """
    if not self._mcp_servers:
        raise RuntimeError("Agent not initialized. Call initialize() first or use async context manager.")

    # Normalize input to a list for uniform processing (also materializes
    # tuples/generators so indexing below works).
    if isinstance(user_input, str):
        user_inputs = [user_input]
    else:
        user_inputs = list(user_input)

    # Extract metadata for trajectory
    meta = metadata or {}
    task_id = meta.get("task_id", "unknown")
    malicious_goal = meta.get("malicious_goal", "")
    domain = meta.get("domain", None)
    category = meta.get("category", None)

    # Initialize trajectory on first call of the session
    if self._current_trajectory is None:
        self._current_trajectory = Trajectory(
            task_id=task_id,
            original_instruction=user_inputs[0] if user_inputs else "",
            malicious_instruction=malicious_goal,
            domain=domain,
            risk_category=category
        )
        self._current_trajectory.start_timer()

    final_answer = None

    try:
        # Get available tools from all MCP servers (already namespaced)
        available_tools = self._get_all_tools()

        for query in user_inputs:
            # Add user input to trajectory
            self._current_trajectory.append_user_step(query)

            # Fresh shared store per query. message_history is a copy of the
            # accumulated history; the flow's resulting history is adopted
            # below for multi-turn context.
            shared = {
                "system_prompt": self.config.system_prompt,
                "user_query": query,
                "trajectory": [],
                "message_history": self._message_history.copy(),
                "available_tools": available_tools,
                "max_turns": self.runtime_config.max_turns,
                "turn_count": 0,
                "model": self.runtime_config.model,
                # Multi-server support
                "mcp_servers": self._mcp_servers,  # Dict[server_name, MCPServerInfo]
                "tool_to_server": self._tool_to_server,  # Dict[namespaced_tool, server_name]
                "async_helper": self._async_helper,
                "hook_manager": self.hook_manager,
                "final_answer": None,
                "current_decision": None,
            }

            # Run the flow using standard PocketFlow Flow.run() (not awaited).
            self._flow.run(shared)

            # "final_answer" is pre-seeded with None above, so a .get() default
            # would never apply; fall back explicitly when the flow set nothing.
            final_answer = shared.get("final_answer")
            if final_answer is None:
                final_answer = "No answer provided"
            turn_trajectory = shared.get("trajectory", [])

            # Save message_history for next turn (multi-turn support)
            self._message_history = shared.get("message_history", [])

            # Increment turn count for each user query processed
            # (consistent with OpenAI SDK where turn_count = number of user messages)
            self._turn_count += 1

            # Convert turn trajectory to Trajectory format. "thought" entries
            # (internal reasoning) and unknown entry types are not recorded.
            for entry in turn_trajectory:
                entry_type = entry.get("type", "unknown")

                if entry_type == "action":
                    tool_name = entry.get("tool_name", "unknown")
                    tool_args = entry.get("tool_arguments", {})
                    server_name = entry.get("server")
                    action_str = f"{tool_name}({tool_args})"
                    self._current_trajectory.append_agent_step(
                        action=action_str,
                        tool_name=tool_name,
                        tool_params=tool_args,
                        server=server_name
                    )
                elif entry_type == "observation":
                    content = entry.get("content", "")
                    self._current_trajectory.append_tool_return(
                        result=content,
                        tool_name=entry.get("tool_name"),
                        server=entry.get("server")
                    )
                elif entry_type == "final_answer":
                    content = entry.get("content", "")
                    self._current_trajectory.append_agent_step(
                        action='send_message_to_user',
                        metadata={"message": content}
                    )

        # Stop timer and save trajectory at end of run
        self._current_trajectory.stop_timer()

        # Add debug info if enabled
        if self.runtime_config.debug:
            tool_descriptions = {
                f"{t.get('server', 'unknown')}:{t.get('name', 'unknown')}": t.get('description', '')
                for t in self._all_tools
            }
            self._current_trajectory.data["debug"] = {
                "tool_descriptions": tool_descriptions
            }

        # Save trajectory
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        trajectory_file = os.path.join(
            self.trajectories_dir,
            f"{task_id}_{timestamp}.json"
        )

        self._current_trajectory.save(trajectory_file)
        print(f"[INFO] Trajectory saved to: {trajectory_file}")

        return AgentResult(
            final_output=final_answer,
            turn_count=self._turn_count,
            trajectory=self._current_trajectory,
        )

    except Exception:
        if self._current_trajectory:
            self._current_trajectory.stop_timer()
        # Bare raise re-raises the active exception with its original traceback.
        raise
|
|
435
|
+
|
|
436
|
+
def get_result(self) -> Optional[AgentResult]:
    """
    Get the current execution result.

    The final answer is recovered from the conversation history. Assistant
    messages are scanned newest-first for a YAML payload that contains a
    ``final_answer`` key. The payload may be wrapped in a ```yaml fence or a
    plain ``` fence. Messages whose content is missing, not a string, or not
    parseable are skipped.

    Returns:
        AgentResult with final output, turn count, and trajectory.
        ``final_output`` is "" when no parseable final answer is found.
        Returns None if no execution has been performed yet.
    """
    if not self._message_history:
        return None

    # Find the last final_answer in message_history
    final_output = None
    for msg in reversed(self._message_history):
        if msg.get("role") != "assistant":
            continue

        # Content may be absent or None (or a non-text payload); only strings
        # can carry the YAML answer. The original `in` test crashed on None.
        content = msg.get("content")
        if not isinstance(content, str) or "final_answer:" not in content:
            continue

        import yaml  # lazy: only needed once a candidate answer is found

        # Prefer an explicit ```yaml fence, then any ``` fence, else raw text.
        if "```yaml" in content:
            yaml_str = content.split("```yaml")[1].split("```")[0].strip()
        elif "```" in content:
            yaml_str = content.split("```")[1].strip()
        else:
            yaml_str = content.strip()

        try:
            parsed = yaml.safe_load(yaml_str)
        except Exception:
            # Best effort: malformed YAML (YAMLError, or errors raised by value
            # constructors) means we keep looking at older messages. Unlike a
            # bare except, this does not swallow KeyboardInterrupt/SystemExit.
            continue

        if isinstance(parsed, dict) and "final_answer" in parsed:
            final_output = parsed["final_answer"]
            break

    return AgentResult(
        final_output=final_output or "",
        turn_count=self._turn_count,
        trajectory=self._current_trajectory,
    )
|
|
474
|
+
|
|
475
|
+
def reset_conversation(self) -> None:
    """
    Start a fresh conversation session.

    Drops the accumulated context so the next ``run()`` begins from scratch;
    ``cleanup()`` invokes this automatically.

    Effects:
        - the current trajectory is discarded (set to None)
        - the turn counter goes back to 0
        - the message history becomes an empty list
    """
    self._current_trajectory = None
    self._turn_count = 0
    self._message_history = []
    print("[INFO] Conversation reset")
|
|
491
|
+
|
|
492
|
+
async def cleanup(self) -> None:
    """Release the agent's resources.

    Steps, in order:
    1. reset the conversation state via ``reset_conversation()``;
    2. stop the AsyncHelper, if one is set, and clear the reference;
    3. forget all MCP server info, tool routing and tool descriptors.

    No MCP connection is closed here, since connections are opened fresh
    per call.
    """
    self.reset_conversation()

    helper = self._async_helper
    if helper:
        helper.stop()
        self._async_helper = None

    # Nothing to close for MCP; just drop the registries.
    self._mcp_servers, self._tool_to_server, self._all_tools = {}, {}, []
    print("[INFO] Agent cleanup completed")
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PocketFlow Native Agent Wrapper
|
|
3
|
+
|
|
4
|
+
Wraps a user-defined NativeMCPReactAgent for evaluation purposes.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Dict, List, Optional
|
|
8
|
+
|
|
9
|
+
from agent.pocketflow.src.agent import MCPReactAgent, MCPServerInfo
|
|
10
|
+
from agent.pocketflow.src.native_agent import NativeMCPReactAgent
|
|
11
|
+
from agent.pocketflow.src.async_helper import AsyncHelper
|
|
12
|
+
|
|
13
|
+
from dt_arena.src.types.agent import AgentConfig, RuntimeConfig
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MCPReactAgentNativeWrapper(MCPReactAgent):
    """
    Wraps a NativeMCPReactAgent for evaluation.

    Overrides:
    - __init__: Accept native agent instead of building from config
    - initialize(): Copy native agent's components and add red-teaming MCP servers

    MCP servers listed in ``agent_config`` are added on top of the native
    agent's servers. Tool descriptions from both sources pass through the
    parent's ``_apply_tool_injection()``.
    """

    def __init__(
        self,
        native_agent: NativeMCPReactAgent,
        agent_config: Optional[AgentConfig] = None,
        runtime_config: Optional[RuntimeConfig] = None,
    ):
        """
        Initialize the wrapper with a native agent.

        Args:
            native_agent: The NativeMCPReactAgent to wrap
            agent_config: AgentConfig with red-teaming MCP servers to add.
                Defaults to ``AgentConfig(system_prompt="")``.
            runtime_config: RuntimeConfig with model settings, output_dir, mcp_injection
        """
        self._native_agent = native_agent
        # Servers added by _add_redteaming_mcp_servers(); each one is also
        # registered in self._mcp_servers.
        self._redteaming_mcp_servers: Dict[str, MCPServerInfo] = {}

        super().__init__(
            agent_config=agent_config or AgentConfig(system_prompt=""),
            runtime_config=runtime_config,
        )

    async def initialize(self) -> None:
        """
        Initialize the wrapper.

        Steps:
        1. Start a new AsyncHelper.
        2. Adopt the native agent's Flow if it has a custom one; otherwise keep
           this instance's existing flow.
        3. Copy the native agent's MCP servers, tool routing and tools, with
           tool injection applied to the copied descriptions.
        4. Add the red-teaming MCP servers from ``agent_config``.
        """
        # Start async helper
        self._async_helper = AsyncHelper()
        self._async_helper.start()

        # Copy Flow from native agent (or use default from _build_flow())
        if self._native_agent.has_custom_flow():
            self._flow = self._native_agent.flow
            print("[INFO] Using custom Flow from native agent")
        else:
            print("[INFO] Using default ReAct Flow")

        # Copy MCP servers and tool routing so later additions do not
        # mutate the native agent's dicts.
        self._mcp_servers = dict(self._native_agent.mcp_servers)
        self._tool_to_server = dict(self._native_agent.tool_to_server)

        # Shallow-copy each tool dict so rewriting its description does not
        # touch the native agent's tools. Nested values such as inputSchema
        # remain shared but are not modified here.
        self._all_tools = []
        for tool in self._native_agent.all_tools:
            tool_copy = dict(tool)
            # Apply injection using parent class method
            tool_copy["description"] = self._apply_tool_injection(
                tool_name=tool_copy.get("name", ""),
                server_name=tool_copy.get("server", ""),
                description=tool_copy.get("description", "")
            )
            self._all_tools.append(tool_copy)

        print(f"[INFO] Copied {len(self._mcp_servers)} MCP server(s) from native agent")
        print(f"[INFO] Copied {len(self._all_tools)} tool(s) from native agent")

        # Add red-teaming MCP servers (injection applied in the method)
        await self._add_redteaming_mcp_servers()

        print(f"[INFO] Total tools available: {len(self._all_tools)}")

    async def _add_redteaming_mcp_servers(self) -> None:
        """
        Add MCP servers from agent_config for red-teaming.

        Tool injection is applied using parent class's _apply_tool_injection() method.

        Skip rules:
        - Disabled servers, and servers whose name is already registered, are
          skipped.
        - A tool whose name is already routed to a server (including an
          earlier tool of the same server) is skipped with a warning rather
          than raising.

        A server's tools are committed to ``_all_tools`` / ``_tool_to_server``
        only after the whole server has been processed successfully. A failure
        part-way therefore never leaves tools routed to a server that was not
        registered. Per-server failures are printed as warnings and do not
        stop the remaining servers.
        """
        if not self.config or not self.config.mcp_servers:
            return

        for server_config in self.config.mcp_servers:
            if not server_config.enabled:
                continue

            server_name = server_config.name

            if server_name in self._mcp_servers:
                print(f"[WARNING] Server '{server_name}' already exists, skipping")
                continue

            try:
                client = self._create_mcp_client(server_config)
                # Staged for this server only; committed after success below.
                tools = []
                staged_names = set()

                async with client:
                    tools_response = await client.list_tools()
                    for tool in tools_response:
                        tool_name = tool.name

                        if tool_name in self._tool_to_server or tool_name in staged_names:
                            # A staged duplicate belongs to this same server.
                            existing_server = self._tool_to_server.get(tool_name, server_name)
                            print(f"[WARNING] Tool '{tool_name}' exists in '{existing_server}', skipping")
                            continue

                        # Apply injection using parent class method
                        description = self._apply_tool_injection(
                            tool_name=tool_name,
                            server_name=server_name,
                            description=tool.description or ""
                        )

                        tool_info = {
                            "name": tool_name,
                            "server": server_name,
                            "description": description,
                            "inputSchema": tool.inputSchema or {},
                        }

                        tools.append(tool_info)
                        staged_names.add(tool_name)

                # Commit the staged tools now that the server was processed without errors.
                self._all_tools.extend(tools)
                for tool_info in tools:
                    self._tool_to_server[tool_info["name"]] = server_name

                server_info = MCPServerInfo(
                    name=server_name,
                    client=client,
                    url=server_config.url,
                    tools=tools,
                )
                self._mcp_servers[server_name] = server_info
                self._redteaming_mcp_servers[server_name] = server_info

                print(f"[INFO] Added red-teaming server '{server_name}' with {len(tools)} tools")

            except Exception as e:
                print(f"[WARNING] Failed to add server '{server_name}': {e}")

    @property
    def native_agent(self) -> NativeMCPReactAgent:
        """Access the original native agent."""
        return self._native_agent

    @property
    def redteaming_servers(self) -> Dict[str, MCPServerInfo]:
        """Access the red-teaming MCP servers (the internal dict, not a copy)."""
        return self._redteaming_mcp_servers
|