decodingtrust-agent-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +30 -0
- agent/claudesdk/__init__.py +8 -0
- agent/claudesdk/example.py +221 -0
- agent/claudesdk/src/__init__.py +8 -0
- agent/claudesdk/src/agent.py +400 -0
- agent/claudesdk/src/mcp_proxy.py +409 -0
- agent/claudesdk/src/utils.py +420 -0
- agent/googleadk/__init__.py +15 -0
- agent/googleadk/example.py +237 -0
- agent/googleadk/src/__init__.py +12 -0
- agent/googleadk/src/agent.py +401 -0
- agent/googleadk/src/mcp_wrapper.py +163 -0
- agent/googleadk/src/utils.py +602 -0
- agent/langchain/__init__.py +8 -0
- agent/langchain/example.py +213 -0
- agent/langchain/src/__init__.py +8 -0
- agent/langchain/src/agent.py +645 -0
- agent/langchain/src/utils.py +433 -0
- agent/openaisdk/__init__.py +17 -0
- agent/openaisdk/example.py +228 -0
- agent/openaisdk/src/__init__.py +12 -0
- agent/openaisdk/src/agent.py +491 -0
- agent/openaisdk/src/agent_wrapper.py +143 -0
- agent/openaisdk/src/mcp_wrapper.py +395 -0
- agent/openaisdk/src/utils.py +493 -0
- agent/openclaw/__init__.py +10 -0
- agent/openclaw/example.py +251 -0
- agent/openclaw/src/__init__.py +14 -0
- agent/openclaw/src/agent.py +930 -0
- agent/openclaw/src/helpers/__init__.py +1 -0
- agent/openclaw/src/helpers/auth_helpers.py +55 -0
- agent/openclaw/src/mcp_proxy.py +564 -0
- agent/openclaw/src/plugin_generator.py +231 -0
- agent/openclaw/src/utils.py +341 -0
- agent/pocketflow/__init__.py +18 -0
- agent/pocketflow/example.py +221 -0
- agent/pocketflow/prompts/react_agent.py +46 -0
- agent/pocketflow/src/__init__.py +6 -0
- agent/pocketflow/src/agent.py +507 -0
- agent/pocketflow/src/agent_wrapper.py +159 -0
- agent/pocketflow/src/async_helper.py +92 -0
- agent/pocketflow/src/mcp_react_agent.py +279 -0
- agent/pocketflow/src/native_agent.py +74 -0
- agent/pocketflow/src/nodes.py +467 -0
- benchmark/__init__.py +0 -0
- benchmark/browser/benign.jsonl +34 -0
- benchmark/browser/direct.jsonl +85 -0
- benchmark/browser/indirect.jsonl +82 -0
- benchmark/code/benign.jsonl +0 -0
- benchmark/code/direct.jsonl +121 -0
- benchmark/code/indirect.jsonl +165 -0
- benchmark/crm/benign.jsonl +165 -0
- benchmark/crm/direct.jsonl +90 -0
- benchmark/crm/indirect.jsonl +150 -0
- benchmark/customer-service/benign.jsonl +160 -0
- benchmark/customer-service/direct.jsonl +100 -0
- benchmark/customer-service/indirect.jsonl +101 -0
- benchmark/finance/benign.jsonl +0 -0
- benchmark/finance/direct.jsonl +200 -0
- benchmark/finance/indirect.jsonl +200 -0
- benchmark/legal/benign.jsonl +0 -0
- benchmark/legal/direct.jsonl +200 -0
- benchmark/legal/indirect.jsonl +200 -0
- benchmark/macos/benign.jsonl +30 -0
- benchmark/macos/direct.jsonl +50 -0
- benchmark/macos/indirect.jsonl +50 -0
- benchmark/medical/benign.jsonl +642 -0
- benchmark/medical/direct.jsonl +229 -0
- benchmark/medical/indirect.jsonl +222 -0
- benchmark/os-filesystem/benign.jsonl +200 -0
- benchmark/os-filesystem/direct.jsonl +200 -0
- benchmark/os-filesystem/indirect.jsonl +200 -0
- benchmark/research/benign.jsonl +0 -0
- benchmark/research/direct.jsonl +119 -0
- benchmark/research/indirect.jsonl +125 -0
- benchmark/telecom/benign.jsonl +120 -0
- benchmark/telecom/direct.jsonl +161 -0
- benchmark/telecom/indirect.jsonl +166 -0
- benchmark/travel/benign.jsonl +130 -0
- benchmark/travel/direct.jsonl +105 -0
- benchmark/travel/indirect.jsonl +120 -0
- benchmark/windows/benign.jsonl +100 -0
- benchmark/windows/direct.jsonl +140 -0
- benchmark/windows/indirect.jsonl +107 -0
- benchmark/workflow/benign.jsonl +335 -0
- benchmark/workflow/direct.jsonl +78 -0
- benchmark/workflow/indirect.jsonl +107 -0
- cli/__init__.py +5 -0
- cli/main.py +182 -0
- cli/scaffold.py +334 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/METADATA +642 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/RECORD +374 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/WHEEL +5 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/entry_points.txt +2 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/licenses/LICENSE +201 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/top_level.txt +6 -0
- dt_arena/config/env.yaml +515 -0
- dt_arena/config/injection_mcp.yaml +430 -0
- dt_arena/config/mcp.yaml +642 -0
- dt_arena/envs/arxiv/docker-compose-hub.yml +31 -0
- dt_arena/envs/arxiv/docker-compose.yml +36 -0
- dt_arena/envs/atlassian/docker/docker-compose.dev.yml +65 -0
- dt_arena/envs/atlassian/docker/docker-compose.yml +53 -0
- dt_arena/envs/atlassian/docker-compose-hub.yml +57 -0
- dt_arena/envs/atlassian/docker-compose.yml +72 -0
- dt_arena/envs/bigquery/docker-compose.yml +20 -0
- dt_arena/envs/booking/docker-compose.yml +59 -0
- dt_arena/envs/calendar/docker-compose-hub.yml +30 -0
- dt_arena/envs/calendar/docker-compose.yml +42 -0
- dt_arena/envs/custom-website/docker-compose.yml +6 -0
- dt_arena/envs/customer_service/docker-compose.yml +59 -0
- dt_arena/envs/databricks/docker-compose-hub.yml +47 -0
- dt_arena/envs/databricks/docker-compose.yml +51 -0
- dt_arena/envs/ecommerce/docker-compose.yml +6 -0
- dt_arena/envs/ers/docker-compose.yml +36 -0
- dt_arena/envs/ers/hrms/docker/docker-compose.yml +31 -0
- dt_arena/envs/finance/docker-compose.yml +23 -0
- dt_arena/envs/github/docker/docker-compose-hub.yml +50 -0
- dt_arena/envs/github/docker/docker-compose.yml +50 -0
- dt_arena/envs/gmail/docker-compose-hub.yml +51 -0
- dt_arena/envs/gmail/docker-compose.yml +65 -0
- dt_arena/envs/google-form/docker-compose-hub.yml +33 -0
- dt_arena/envs/google-form/docker-compose.yml +41 -0
- dt_arena/envs/googledocs/docker-compose-hub.yml +61 -0
- dt_arena/envs/googledocs/docker-compose.yml +78 -0
- dt_arena/envs/hospital/docker-compose-hub.yml +25 -0
- dt_arena/envs/hospital/docker-compose.yml +27 -0
- dt_arena/envs/legal/docker-compose.yml +22 -0
- dt_arena/envs/linkedin/docker-compose.yml +63 -0
- dt_arena/envs/macos/docker-compose.yml +79 -0
- dt_arena/envs/os-filesystem/docker-compose-hub.yml +16 -0
- dt_arena/envs/os-filesystem/docker-compose.yml +20 -0
- dt_arena/envs/paypal/docker-compose-hub.yml +48 -0
- dt_arena/envs/paypal/docker-compose.yml +63 -0
- dt_arena/envs/research/docker-compose-hub.yml +13 -0
- dt_arena/envs/research/docker-compose.yml +24 -0
- dt_arena/envs/salesforce_crm/docker-compose-hub.yaml +45 -0
- dt_arena/envs/salesforce_crm/docker-compose.yaml +49 -0
- dt_arena/envs/slack/docker-compose-hub.yml +28 -0
- dt_arena/envs/slack/docker-compose.yml +41 -0
- dt_arena/envs/snowflake/docker-compose-hub.yml +41 -0
- dt_arena/envs/snowflake/docker-compose.yml +44 -0
- dt_arena/envs/telecom/docker-compose-hub.yml +16 -0
- dt_arena/envs/telecom/docker-compose.yml +17 -0
- dt_arena/envs/telegram/docker-compose-hub.yml +57 -0
- dt_arena/envs/telegram/docker-compose.yml +62 -0
- dt_arena/envs/terminal/docker-compose-hub.yml +12 -0
- dt_arena/envs/terminal/docker-compose.yml +26 -0
- dt_arena/envs/travel/docker-compose-hub.yml +19 -0
- dt_arena/envs/travel/docker-compose.yml +19 -0
- dt_arena/envs/whatsapp/docker-compose-hub.yml +61 -0
- dt_arena/envs/whatsapp/docker-compose.yml +78 -0
- dt_arena/envs/windows/docker-compose.yml +71 -0
- dt_arena/envs/zoom/docker-compose-hub.yml +27 -0
- dt_arena/envs/zoom/docker-compose.yml +40 -0
- dt_arena/injection_mcp_server/atlassian/env_injection.py +134 -0
- dt_arena/injection_mcp_server/calendar/env_injection.py +217 -0
- dt_arena/injection_mcp_server/custom_website/env_injection.py +97 -0
- dt_arena/injection_mcp_server/customer_service/env_injection.py +659 -0
- dt_arena/injection_mcp_server/databricks/env_injection.py +255 -0
- dt_arena/injection_mcp_server/ecommerce/env_injection.py +110 -0
- dt_arena/injection_mcp_server/finance/env_injection.py +85 -0
- dt_arena/injection_mcp_server/github/env_injection.py +206 -0
- dt_arena/injection_mcp_server/gmail/env_injection.py +211 -0
- dt_arena/injection_mcp_server/google_form/env_injection.py +186 -0
- dt_arena/injection_mcp_server/googledocs/env_injection.py +44 -0
- dt_arena/injection_mcp_server/hospital/env_injection.py +43 -0
- dt_arena/injection_mcp_server/legal/env_injection.py +229 -0
- dt_arena/injection_mcp_server/macos/env_injection.py +272 -0
- dt_arena/injection_mcp_server/os-filesystem/env_injection.py +341 -0
- dt_arena/injection_mcp_server/paypal/env_injection.py +268 -0
- dt_arena/injection_mcp_server/research/env_injection.py +616 -0
- dt_arena/injection_mcp_server/salesforce/env_injection.py +514 -0
- dt_arena/injection_mcp_server/slack/env_injection.py +265 -0
- dt_arena/injection_mcp_server/snowflake/env_injection.py +230 -0
- dt_arena/injection_mcp_server/telecom/env_injection.py +503 -0
- dt_arena/injection_mcp_server/telegram/env_injection.py +171 -0
- dt_arena/injection_mcp_server/terminal/env_injection.py +523 -0
- dt_arena/injection_mcp_server/travel/env_injection.py +173 -0
- dt_arena/injection_mcp_server/whatsapp/env_injection.py +185 -0
- dt_arena/injection_mcp_server/windows/env_injection.py +943 -0
- dt_arena/injection_mcp_server/zoom/env_injection.py +216 -0
- dt_arena/mcp_server/atlassian/main.py +1554 -0
- dt_arena/mcp_server/atlassian/test_server.py +66 -0
- dt_arena/mcp_server/bigquery/main.py +333 -0
- dt_arena/mcp_server/booking/main.py +310 -0
- dt_arena/mcp_server/browser/main.py +1741 -0
- dt_arena/mcp_server/calendar/example_multi_user.py +162 -0
- dt_arena/mcp_server/calendar/main.py +792 -0
- dt_arena/mcp_server/calendar/test_mcp.py +135 -0
- dt_arena/mcp_server/customer_service/main.py +1063 -0
- dt_arena/mcp_server/databricks/main.py +566 -0
- dt_arena/mcp_server/databricks/probe.py +102 -0
- dt_arena/mcp_server/ers/main.py +845 -0
- dt_arena/mcp_server/finance/__init__.py +87 -0
- dt_arena/mcp_server/finance/core/__init__.py +12 -0
- dt_arena/mcp_server/finance/core/data_loader.py +558 -0
- dt_arena/mcp_server/finance/core/portfolio.py +565 -0
- dt_arena/mcp_server/finance/evaluation/__init__.py +20 -0
- dt_arena/mcp_server/finance/evaluation/evaluator.py +217 -0
- dt_arena/mcp_server/finance/evaluation/logger.py +137 -0
- dt_arena/mcp_server/finance/injection/__init__.py +66 -0
- dt_arena/mcp_server/finance/injection/config.py +176 -0
- dt_arena/mcp_server/finance/injection/content.py +755 -0
- dt_arena/mcp_server/finance/injection/html.py +409 -0
- dt_arena/mcp_server/finance/injection/locations.py +167 -0
- dt_arena/mcp_server/finance/injection/methods.py +193 -0
- dt_arena/mcp_server/finance/injection/presets.py +1023 -0
- dt_arena/mcp_server/finance/main.py +361 -0
- dt_arena/mcp_server/finance/run_mcp.py +21 -0
- dt_arena/mcp_server/finance/run_web.py +26 -0
- dt_arena/mcp_server/finance/server/__init__.py +41 -0
- dt_arena/mcp_server/finance/server/extractor.py +1453 -0
- dt_arena/mcp_server/finance/server/extractor_minimal.py +292 -0
- dt_arena/mcp_server/finance/server/extractor_simple.py +1164 -0
- dt_arena/mcp_server/finance/server/injection_mcp.py +865 -0
- dt_arena/mcp_server/finance/server/mcp.py +451 -0
- dt_arena/mcp_server/finance/server/tools/__init__.py +23 -0
- dt_arena/mcp_server/finance/server/tools/account.py +88 -0
- dt_arena/mcp_server/finance/server/tools/browsing.py +328 -0
- dt_arena/mcp_server/finance/server/tools/social.py +73 -0
- dt_arena/mcp_server/finance/server/tools/trading.py +242 -0
- dt_arena/mcp_server/finance/server/tools/utility.py +49 -0
- dt_arena/mcp_server/finance/server/web.py +2139 -0
- dt_arena/mcp_server/finance/tasks/benchmark/__init__.py +28 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_pool.py +3026 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_runner.py +1315 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_requirements.py +1335 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_tasks.py +3665 -0
- dt_arena/mcp_server/finance/tasks/benchmark/malicious_tasks.py +2673 -0
- dt_arena/mcp_server/finance/tasks/redteam_suite/run_redteam_suite.py +1713 -0
- dt_arena/mcp_server/finance/test_mcp_tools.py +476 -0
- dt_arena/mcp_server/github/main.py +441 -0
- dt_arena/mcp_server/gmail/main.py +1004 -0
- dt_arena/mcp_server/google_form/main.py +141 -0
- dt_arena/mcp_server/googledocs/main.py +458 -0
- dt_arena/mcp_server/hospital/mcp_server.py +458 -0
- dt_arena/mcp_server/legal/__init__.py +9 -0
- dt_arena/mcp_server/legal/core/__init__.py +14 -0
- dt_arena/mcp_server/legal/core/courtlistener_store.py +762 -0
- dt_arena/mcp_server/legal/core/data_loader.py +266 -0
- dt_arena/mcp_server/legal/core/document_store.py +197 -0
- dt_arena/mcp_server/legal/core/matter_manager.py +466 -0
- dt_arena/mcp_server/legal/main.py +89 -0
- dt_arena/mcp_server/legal/scripts/collect_data.py +988 -0
- dt_arena/mcp_server/legal/server/__init__.py +14 -0
- dt_arena/mcp_server/legal/server/mcp.py +2330 -0
- dt_arena/mcp_server/macos/client_test.py +270 -0
- dt_arena/mcp_server/macos/mcp_server.py +285 -0
- dt_arena/mcp_server/os-filesystem/main.py +1380 -0
- dt_arena/mcp_server/paypal/main.py +501 -0
- dt_arena/mcp_server/research/main.py +777 -0
- dt_arena/mcp_server/salesforce/main.py +2006 -0
- dt_arena/mcp_server/slack/main.py +318 -0
- dt_arena/mcp_server/snowflake/main.py +612 -0
- dt_arena/mcp_server/snowflake/probe.py +183 -0
- dt_arena/mcp_server/telecom/mcp_client.py +423 -0
- dt_arena/mcp_server/telecom/mcp_server.py +1059 -0
- dt_arena/mcp_server/telegram/main.py +338 -0
- dt_arena/mcp_server/terminal/main.py +163 -0
- dt_arena/mcp_server/travel/client_test.py +16 -0
- dt_arena/mcp_server/travel/mcp_server.py +404 -0
- dt_arena/mcp_server/whatsapp/main.py +318 -0
- dt_arena/mcp_server/windows/client_test.py +270 -0
- dt_arena/mcp_server/windows/mcp_server.py +218 -0
- dt_arena/mcp_server/zoom/main.py +466 -0
- dt_arena/src/__init__.py +0 -0
- dt_arena/src/hooks/__init__.py +0 -0
- dt_arena/src/hooks/audit_log.py +30 -0
- dt_arena/src/hooks/hooks.json +3 -0
- dt_arena/src/run_benign.py +142 -0
- dt_arena/src/types/__init__.py +0 -0
- dt_arena/src/types/agent.py +441 -0
- dt_arena/src/types/attacks.py +2 -0
- dt_arena/src/types/environment.py +2 -0
- dt_arena/src/types/hooks.py +174 -0
- dt_arena/src/types/judge.py +52 -0
- dt_arena/src/types/red_teaming_trajectory.py +385 -0
- dt_arena/src/types/task.py +260 -0
- dt_arena/src/types/trajectory.py +315 -0
- dt_arena/utils/__init__.py +1 -0
- dt_arena/utils/atlassian/__init__.py +27 -0
- dt_arena/utils/atlassian/helpers.py +520 -0
- dt_arena/utils/bigquery/__init__.py +1 -0
- dt_arena/utils/bigquery/helpers.py +246 -0
- dt_arena/utils/calendar/__init__.py +1 -0
- dt_arena/utils/calendar/helpers.py +87 -0
- dt_arena/utils/customer_service/__init__.py +17 -0
- dt_arena/utils/customer_service/cs_env_client.py +940 -0
- dt_arena/utils/customer_service/helpers.py +339 -0
- dt_arena/utils/customer_service/judges/__init__.py +20 -0
- dt_arena/utils/customer_service/judges/abusive_comms.py +149 -0
- dt_arena/utils/customer_service/judges/bulk_destructive.py +202 -0
- dt_arena/utils/customer_service/judges/cross_customer.py +481 -0
- dt_arena/utils/customer_service/judges/email_exfiltration.py +320 -0
- dt_arena/utils/customer_service/judges/financial_fraud.py +221 -0
- dt_arena/utils/customer_service/judges/llm_utils.py +264 -0
- dt_arena/utils/customer_service/judges/pii_disclosure.py +837 -0
- dt_arena/utils/customer_service/judges/policy_misrep.py +44 -0
- dt_arena/utils/customer_service/judges/text_utils.py +21 -0
- dt_arena/utils/databricks/__init__.py +2 -0
- dt_arena/utils/databricks/helpers.py +210 -0
- dt_arena/utils/finance/__init__.py +0 -0
- dt_arena/utils/finance/helpers.py +263 -0
- dt_arena/utils/github/__init__.py +1 -0
- dt_arena/utils/github/helpers.py +249 -0
- dt_arena/utils/gmail/__init__.py +1 -0
- dt_arena/utils/gmail/helpers.py +344 -0
- dt_arena/utils/google_form/__init__.py +2 -0
- dt_arena/utils/google_form/helpers.py +133 -0
- dt_arena/utils/legal/__init__.py +0 -0
- dt_arena/utils/legal/helpers.py +228 -0
- dt_arena/utils/macos/__init__.py +0 -0
- dt_arena/utils/macos/env_setup.py +215 -0
- dt_arena/utils/macos/helpers.py +61 -0
- dt_arena/utils/os_filesystem/__init__.py +1 -0
- dt_arena/utils/os_filesystem/helpers.py +366 -0
- dt_arena/utils/paypal/__init__.py +1 -0
- dt_arena/utils/paypal/helpers.py +178 -0
- dt_arena/utils/port_allocator.py +266 -0
- dt_arena/utils/research/__init__.py +0 -0
- dt_arena/utils/research/helpers.py +251 -0
- dt_arena/utils/salesforce/__init__.py +1 -0
- dt_arena/utils/salesforce/helpers.py +719 -0
- dt_arena/utils/slack/__init__.py +1 -0
- dt_arena/utils/slack/helpers.py +176 -0
- dt_arena/utils/snowflake/__init__.py +1 -0
- dt_arena/utils/snowflake/helpers.py +166 -0
- dt_arena/utils/telecom/__init__.py +1 -0
- dt_arena/utils/telecom/helpers.py +760 -0
- dt_arena/utils/telegram/__init__.py +0 -0
- dt_arena/utils/telegram/helpers.py +174 -0
- dt_arena/utils/terminal/__init__.py +0 -0
- dt_arena/utils/terminal/helpers.py +20 -0
- dt_arena/utils/travel/__init__.py +0 -0
- dt_arena/utils/travel/env_client.py +537 -0
- dt_arena/utils/travel/llm_judge.py +137 -0
- dt_arena/utils/travel/prompts.py +64 -0
- dt_arena/utils/utils/__init__.py +122 -0
- dt_arena/utils/whatsapp/__init__.py +0 -0
- dt_arena/utils/whatsapp/helpers.py +226 -0
- dt_arena/utils/windows/__init__.py +0 -0
- dt_arena/utils/windows/env_reset.py +224 -0
- dt_arena/utils/windows/env_setup.py +280 -0
- dt_arena/utils/windows/exfil_helpers.py +170 -0
- dt_arena/utils/windows/helpers.py +74 -0
- dt_arena/utils/zoom/__init__.py +1 -0
- dt_arena/utils/zoom/helpers.py +70 -0
- eval/__init__.py +1 -0
- eval/evaluation.py +426 -0
- eval/task_runner.py +449 -0
- utils/__init__.py +148 -0
- utils/agent_helpers.py +308 -0
- utils/agent_wrapper.py +189 -0
- utils/compose_utils.py +135 -0
- utils/config.py +77 -0
- utils/env_helpers.py +104 -0
- utils/eval_stats.py +88 -0
- utils/injection_helpers.py +429 -0
- utils/injection_mcp_helpers.py +152 -0
- utils/judge_helpers.py +181 -0
- utils/judge_utils.py +472 -0
- utils/llm.py +196 -0
- utils/logging.py +45 -0
- utils/mcp_helpers.py +232 -0
- utils/mcp_manager.py +235 -0
- utils/memory_guard.py +18 -0
- utils/red_teaming_sandbox.py +476 -0
- utils/reset_helpers.py +318 -0
- utils/resource_manager.py +370 -0
- utils/skill_helpers.py +447 -0
- utils/task_executor.py +904 -0
- utils/task_helpers.py +270 -0
- utils/template_helpers.py +179 -0
|
@@ -0,0 +1,645 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import json
|
|
4
|
+
import asyncio
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Dict, Any, Optional, List, Callable, Union
|
|
7
|
+
|
|
8
|
+
from langchain_core.tools import StructuredTool, BaseTool
|
|
9
|
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
|
|
10
|
+
from langchain.agents import create_agent
|
|
11
|
+
|
|
12
|
+
# FastMCP client for connecting to MCP servers (simpler API)
|
|
13
|
+
from fastmcp import Client as FastMCPClient
|
|
14
|
+
|
|
15
|
+
# Add parent directory to path to import types
|
|
16
|
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../..'))
|
|
17
|
+
from dt_arena.src.types.agent import Agent, AgentConfig, RuntimeConfig, MCPServerConfig, AgentResult
|
|
18
|
+
from dt_arena.src.types.task import TaskConfig, AttackConfig
|
|
19
|
+
from dt_arena.src.types.trajectory import Trajectory
|
|
20
|
+
from dt_arena.src.types.hooks import HookManager, ToolCallContext
|
|
21
|
+
|
|
22
|
+
from .utils import LangChainTraceProcessor, LangChainTrajectoryConverter
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MCPToolWrapper:
    """Wrapper to convert MCP tools to LangChain tools using FastMCP Client.

    One instance represents one tool exposed by one MCP server. Awaiting the
    instance (``await wrapper(**kwargs)``) executes the tool through the
    FastMCP client and returns the result rendered as a string.
    """

    def __init__(
        self,
        client: "FastMCPClient",
        tool_info: Dict[str, Any],
        server_name: str,
        tool_injection: Optional[Any] = None,
        hook_manager: Optional["HookManager"] = None,
    ):
        """
        Args:
            client: FastMCP client connected to the tool's server.
            tool_info: Tool metadata mapping; the "name", "description" and
                "inputSchema" keys are read.
            server_name: Name of the MCP server providing the tool.
            tool_injection: Optional object with ``type`` ("suffix" or
                "override") and ``content`` attributes applied to the
                tool description.
            hook_manager: Optional HookManager that wraps every tool call.
                When None, the tool is called on the client directly.
        """
        self.client = client
        self.tool_info = tool_info
        self.server_name = server_name
        self.name = tool_info.get("name", "unknown_tool")
        self.input_schema = tool_info.get("inputSchema", {})
        self._hook_manager = hook_manager

        # Apply tool injection to description
        original_description = tool_info.get("description", "No description")
        self.description = self._apply_injection(original_description, tool_injection)

    def _apply_injection(self, original_description: str, injection: Optional[Any]) -> str:
        """
        Apply tool description injection.

        Args:
            original_description: The original tool description
            injection: ToolInjection object with type and content

        Returns:
            Modified description with injection applied. Unknown injection
            types print a warning and leave the description unchanged.
        """
        if not injection:
            return original_description

        if injection.type == "suffix":
            return f"{original_description}\n{injection.content}"
        elif injection.type == "override":
            return injection.content
        else:
            print(f"[WARNING] Unknown injection type '{injection.type}' for tool '{self.name}'")
            return original_description

    def _normalize_args(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Unwrap arguments that LangChain passed nested under 'input' or the tool name."""
        tool_args = kwargs

        # A single 'input' key (LangChain structured tool format)
        if len(kwargs) == 1 and 'input' in kwargs:
            tool_args = kwargs['input'] if isinstance(kwargs['input'], dict) else {'input': kwargs['input']}

        # Args nested under a key matching the tool name
        if len(kwargs) == 1 and self.name in kwargs:
            tool_args = kwargs[self.name] if isinstance(kwargs[self.name], dict) else kwargs

        return tool_args

    @staticmethod
    def _format_result(result: Any) -> str:
        """Render a FastMCP call result as text: structured data first, then content parts."""
        if hasattr(result, 'data') and result.data is not None:
            if isinstance(result.data, (dict, list)):
                return json.dumps(result.data, ensure_ascii=False, default=str)
            return str(result.data)

        # Fallback: extract text from content
        if hasattr(result, 'content') and result.content:
            content_parts = []
            for item in result.content:
                if hasattr(item, 'text'):
                    content_parts.append(item.text)
                elif isinstance(item, dict) and 'text' in item:
                    content_parts.append(item['text'])
                else:
                    content_parts.append(str(item))
            return "\n".join(content_parts) if content_parts else str(result)

        return str(result)

    async def __call__(self, **kwargs) -> str:
        """Execute the MCP tool via the FastMCP client.

        Returns:
            The tool result rendered as text. Any exception raised while
            calling the tool is caught and returned as
            ``"Error calling tool <name>: <message>"`` so the agent sees the
            failure instead of crashing.
        """
        try:
            tool_args = self._normalize_args(kwargs)

            if self._hook_manager is None:
                # hook_manager is optional in __init__; without one there is
                # nothing to wrap, so call the tool directly instead of failing
                # on a None attribute lookup.
                result = await self.client.call_tool(self.name, tool_args)
            else:
                ctx = ToolCallContext(
                    framework="langchain",
                    server=self.server_name,
                    tool_name=self.name,
                    arguments=tool_args,
                )
                result = await self._hook_manager.wrap(
                    ctx,
                    lambda args: self.client.call_tool(self.name, args),
                )

            return self._format_result(result)
        except Exception as e:
            return f"Error calling tool {self.name}: {str(e)}"
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class LangChainAgent(Agent):
|
|
117
|
+
"""LangChain-based Agent implementation for DecodingTrust Agent Arena"""
|
|
118
|
+
|
|
119
|
+
def __init__(
    self,
    agent_config: AgentConfig,
    runtime_config: Optional[RuntimeConfig] = None,
    task_config: Optional[TaskConfig] = None,
    attack_config: Optional[AttackConfig] = None
):
    """
    Set up a LangChain agent and its on-disk output locations.

    Args:
        agent_config: Agent configuration (system_prompt and mcp_servers)
        runtime_config: Runtime configuration (model, temperature, max_turns, output_dir)
        task_config: Optional task configuration, stored as-is
        attack_config: Optional attack configuration, stored as-is
    """
    super().__init__(agent_config, runtime_config)

    self.task_config = task_config
    self.attack_config = attack_config

    # Layout: raw traces go to <base>/traces/, trajectories are written
    # straight into <base> (same structure as the OpenAI SDK agent).
    base_dir = self.runtime_config.output_dir or os.path.join(os.getcwd(), "results")
    self.output_dir = base_dir
    self.traces_dir = os.path.join(base_dir, "traces")
    self.trajectories_dir = base_dir
    for directory in (self.traces_dir, self.trajectories_dir):
        os.makedirs(directory, exist_ok=True)

    # Trace file is named after the creation time (second resolution).
    self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    self.trace_file = os.path.join(self.traces_dir, f"traces_{self.timestamp}.jsonl")
    self.trace_processor = LangChainTraceProcessor(self.trace_file)
    self.trajectory_converter = LangChainTrajectoryConverter(self.trajectories_dir)

    # MCP connection records (appended by initialize()) and LangChain tools.
    self.mcp_sessions: List[Dict[str, Any]] = []
    self.mcp_tools: List[BaseTool] = []

    # LangChain agent graph, not built yet.
    self.agent_graph = None

    # State carried across turns of a multi-turn conversation.
    self._conversation_history: List[BaseMessage] = []
    self._turn_count: int = 0
    self._current_trace_id: Optional[str] = None
    self._trace_metadata: Optional[Dict[str, Any]] = None
    self._current_trajectory: Optional[Trajectory] = None
    self._all_intermediate_steps: List[Dict[str, Any]] = []
|
|
171
|
+
|
|
172
|
+
def _get_model_string(self) -> str:
|
|
173
|
+
"""
|
|
174
|
+
Get model string for LangChain create_agent.
|
|
175
|
+
|
|
176
|
+
LangChain 1.1+ uses format: "provider:model_name"
|
|
177
|
+
e.g., "openai:gpt-4.1", "anthropic:claude-3-opus"
|
|
178
|
+
"""
|
|
179
|
+
model = self.runtime_config.model
|
|
180
|
+
|
|
181
|
+
# If already in provider:model format, return as-is
|
|
182
|
+
if ":" in model:
|
|
183
|
+
return model
|
|
184
|
+
|
|
185
|
+
# Auto-detect provider based on model name
|
|
186
|
+
if model.startswith("claude") or model.startswith("anthropic"):
|
|
187
|
+
return f"anthropic:{model}"
|
|
188
|
+
elif model.startswith("gpt") or model.startswith("o1") or model.startswith("o3"):
|
|
189
|
+
return f"openai:{model}"
|
|
190
|
+
elif model.startswith("gemini"):
|
|
191
|
+
return f"google_genai:{model}"
|
|
192
|
+
else:
|
|
193
|
+
# Default to OpenAI for unknown models
|
|
194
|
+
return f"openai:{model}"
|
|
195
|
+
|
|
196
|
+
async def _connect_mcp_server(self, server_config: MCPServerConfig) -> Optional[Dict[str, Any]]:
    """
    Connect to a single MCP server using FastMCP Client.

    Only URL-based transports ("sse" and "http") are handled; FastMCP
    infers the concrete transport from the URL. "stdio" and unknown
    transports print a warning and are skipped.

    On success the client is left entered (connected) and returned to the
    caller. If anything fails after the client was entered, it is exited
    again here so a half-initialised connection is not leaked.

    Args:
        server_config: MCP server configuration

    Returns:
        Dict with keys 'name', 'client', 'tools' and 'config', or None when
        the transport is unsupported or connecting/listing tools failed.
    """
    client = None
    entered = False
    try:
        if server_config.transport in ("sse", "http"):
            # FastMCP auto-detects transport from the URL
            url = server_config.url
        elif server_config.transport == "stdio":
            # FastMCP supports stdio via a command string, but it is not wired up here
            print(f"[WARNING] Stdio transport not fully supported yet for {server_config.name}")
            return None
        else:
            print(f"[WARNING] Unsupported transport: {server_config.transport}")
            return None

        client = FastMCPClient(url)
        await client.__aenter__()
        entered = True

        tools = await client.list_tools()

        return {
            'name': server_config.name,
            'client': client,
            'tools': tools,
            'config': server_config
        }

    except Exception as e:
        print(f"[ERROR] Failed to connect to MCP server {server_config.name}: {e}")
        import traceback
        traceback.print_exc()
        if entered:
            # The client is not returned on failure, so nobody else could close it.
            try:
                await client.__aexit__(None, None, None)
            except Exception as close_error:
                print(f"[WARNING] Failed to close MCP client for {server_config.name}: {close_error}")
        return None
|
|
243
|
+
|
|
244
|
+
def _create_langchain_tool(self, mcp_wrapper: MCPToolWrapper) -> BaseTool:
|
|
245
|
+
"""
|
|
246
|
+
Create a LangChain StructuredTool from MCP wrapper with proper schema.
|
|
247
|
+
|
|
248
|
+
Args:
|
|
249
|
+
mcp_wrapper: MCPToolWrapper instance
|
|
250
|
+
|
|
251
|
+
Returns:
|
|
252
|
+
LangChain StructuredTool with proper args schema
|
|
253
|
+
"""
|
|
254
|
+
from pydantic import create_model, Field
|
|
255
|
+
from typing import Any as TypingAny
|
|
256
|
+
|
|
257
|
+
# Build Pydantic model from MCP input schema
|
|
258
|
+
input_schema = mcp_wrapper.input_schema
|
|
259
|
+
properties = input_schema.get("properties", {})
|
|
260
|
+
required = set(input_schema.get("required", []))
|
|
261
|
+
|
|
262
|
+
# Create field definitions for Pydantic model
|
|
263
|
+
field_definitions = {}
|
|
264
|
+
for prop_name, prop_info in properties.items():
|
|
265
|
+
prop_type = prop_info.get("type", "string")
|
|
266
|
+
prop_desc = prop_info.get("description", "")
|
|
267
|
+
|
|
268
|
+
# Map JSON schema types to Python types
|
|
269
|
+
type_mapping = {
|
|
270
|
+
"string": str,
|
|
271
|
+
"integer": int,
|
|
272
|
+
"number": float,
|
|
273
|
+
"boolean": bool,
|
|
274
|
+
"array": list,
|
|
275
|
+
"object": dict,
|
|
276
|
+
}
|
|
277
|
+
python_type = type_mapping.get(prop_type, str)
|
|
278
|
+
|
|
279
|
+
# Create field with or without default based on required
|
|
280
|
+
if prop_name in required:
|
|
281
|
+
field_definitions[prop_name] = (python_type, Field(description=prop_desc))
|
|
282
|
+
else:
|
|
283
|
+
field_definitions[prop_name] = (Optional[python_type], Field(default=None, description=prop_desc))
|
|
284
|
+
|
|
285
|
+
# Create dynamic Pydantic model for args schema
|
|
286
|
+
if field_definitions:
|
|
287
|
+
ArgsSchema = create_model(f"{mcp_wrapper.name}_args", **field_definitions)
|
|
288
|
+
else:
|
|
289
|
+
ArgsSchema = None
|
|
290
|
+
|
|
291
|
+
# Create the async coroutine for the tool
|
|
292
|
+
async def async_tool_func(**kwargs) -> str:
|
|
293
|
+
return await mcp_wrapper(**kwargs)
|
|
294
|
+
|
|
295
|
+
# Create StructuredTool
|
|
296
|
+
tool = StructuredTool.from_function(
|
|
297
|
+
func=lambda **kwargs: asyncio.get_event_loop().run_until_complete(mcp_wrapper(**kwargs)),
|
|
298
|
+
coroutine=async_tool_func,
|
|
299
|
+
name=mcp_wrapper.name,
|
|
300
|
+
description=mcp_wrapper.description,
|
|
301
|
+
args_schema=ArgsSchema,
|
|
302
|
+
)
|
|
303
|
+
|
|
304
|
+
# Store reference to wrapper for debugging
|
|
305
|
+
tool._mcp_wrapper = mcp_wrapper
|
|
306
|
+
tool._server_name = mcp_wrapper.server_name
|
|
307
|
+
|
|
308
|
+
return tool
|
|
309
|
+
|
|
310
|
+
def _create_mcp_server(self, server_config: MCPServerConfig) -> Any:
    """
    Base-class hook for building an MCP server object; a pass-through here.

    This agent opens its FastMCP connections in ``initialize()`` (via
    ``_connect_mcp_server``), not through this hook. The configuration is
    therefore handed back untouched.

    Args:
        server_config: Configuration for the MCP server

    Returns:
        The same ``server_config`` object that was passed in (not a dict).
    """
    passthrough = server_config
    return passthrough
|
|
322
|
+
|
|
323
|
+
async def _check_duplicate_tools_langchain(self) -> None:
|
|
324
|
+
"""
|
|
325
|
+
Check for duplicate tool names across MCP servers.
|
|
326
|
+
|
|
327
|
+
Raises:
|
|
328
|
+
ValueError: If duplicate tool names are found across servers
|
|
329
|
+
"""
|
|
330
|
+
tool_to_servers: Dict[str, List[str]] = {}
|
|
331
|
+
|
|
332
|
+
for tool in self.mcp_tools:
|
|
333
|
+
tool_name = tool.name
|
|
334
|
+
server_name = getattr(tool, '_server_name', 'unknown')
|
|
335
|
+
|
|
336
|
+
if tool_name not in tool_to_servers:
|
|
337
|
+
tool_to_servers[tool_name] = []
|
|
338
|
+
tool_to_servers[tool_name].append(server_name)
|
|
339
|
+
|
|
340
|
+
# Find duplicates
|
|
341
|
+
duplicates = {
|
|
342
|
+
tool_name: servers for tool_name, servers in tool_to_servers.items()
|
|
343
|
+
if len(servers) > 1
|
|
344
|
+
}
|
|
345
|
+
|
|
346
|
+
if duplicates:
|
|
347
|
+
dup_info = ", ".join([f"{name} (in {', '.join(srvs)})" for name, srvs in duplicates.items()])
|
|
348
|
+
raise ValueError(
|
|
349
|
+
f"Duplicate tool names found across MCP servers: {dup_info}"
|
|
350
|
+
)
|
|
351
|
+
|
|
352
|
+
async def initialize(self) -> None:
    """
    Connect to every enabled MCP server, wrap its tools for LangChain and
    build the agent graph.

    For each enabled server that ``_connect_mcp_server`` returns a truthy
    connection for, the connection is appended to ``self.mcp_sessions`` and
    each of its tools is wrapped in an ``MCPToolWrapper``. The wrapper gets
    the per-tool injection from ``runtime_config.mcp_injection[server][tool]``
    when present, and is converted with ``_create_langchain_tool``.
    Afterwards ``self.mcp_tools`` is replaced, duplicate names are rejected
    and ``self.agent_graph`` is created.

    Raises:
        ValueError: If no agent config is set, or if two servers expose a
            tool with the same name.
    """
    if not self.config:
        raise ValueError("Agent config is required")

    collected_tools = []

    for srv in self.config.mcp_servers:
        if not srv.enabled:
            continue

        connection = await self._connect_mcp_server(srv)
        if not connection:
            continue

        self.mcp_sessions.append(connection)

        # Per-server mapping of tool name -> injection (None when not configured)
        server_injections = (self.runtime_config.mcp_injection or {}).get(srv.name)

        server_tools = connection['tools']
        for mcp_tool in server_tools:
            per_tool_injection = server_injections.get(mcp_tool.name) if server_injections else None
            wrapper = MCPToolWrapper(
                connection['client'],  # Use FastMCP client
                {
                    'name': mcp_tool.name,
                    'description': mcp_tool.description or "",
                    'inputSchema': getattr(mcp_tool, 'inputSchema', {}),
                },
                srv.name,
                tool_injection=per_tool_injection,
                hook_manager=self.hook_manager,
            )
            collected_tools.append(self._create_langchain_tool(wrapper))

        print(f"[INFO] Connected to MCP server {srv.name} with {len(server_tools)} tools")

    self.mcp_tools = collected_tools

    # Fail fast if two servers expose the same tool name
    await self._check_duplicate_tools_langchain()

    # Model identifier string, e.g. "openai:gpt-4.1"
    model_string = self._get_model_string()
    print(f"[INFO] Using model: {model_string}")

    agent_name = self.task_config.task_id if self.task_config else "LangChainAgent"
    self.agent_graph = create_agent(
        model=model_string,
        tools=self.mcp_tools or None,
        system_prompt=self.config.system_prompt,
        name=agent_name,
    )
|
|
415
|
+
|
|
416
|
+
async def run(
    self,
    user_input: Union[str, List[str]],
    metadata: Optional[Dict[str, Any]] = None
) -> AgentResult:
    """
    Run the agent with given input. Supports multi-turn conversations.

    Each query is appended to the conversation history and the agent graph
    is invoked on the full history. The messages produced for that query are
    then scanned for tool calls, tool responses and the final AI reply. Only
    the final AI text of each turn is appended back to the history; tool
    messages are not.

    Args:
        user_input: User instruction/query. Can be:
            - str: Single query (backward compatible)
            - List[str]: Multiple queries processed sequentially with context preserved

        metadata: Optional metadata (task_id, domain, category, instruction).
            Only used when no trace metadata is set yet, i.e. on the first
            run() after construction or reset_conversation().

    Returns:
        AgentResult with final output, turn count, trace ID, and trajectory

    Raises:
        RuntimeError: If the agent graph has not been built by initialize().
        Exception: Any error raised while running a turn is recorded on the
            active trace and re-raised.
    """
    if not self.agent_graph:
        raise RuntimeError("Agent not initialized. Call initialize() first or use async context manager.")

    # Normalize input to list for uniform processing
    user_inputs = [user_input] if isinstance(user_input, str) else user_input

    # Initialize trace metadata
    if self._trace_metadata is None:
        self._trace_metadata = metadata or {}

    # Set instruction in metadata (truncated to 500 chars)
    if "instruction" not in self._trace_metadata:
        if len(user_inputs) == 1:
            instr = user_inputs[0] or ""
        else:
            instr = " | ".join(user_inputs)  # For multi-turn list input
        max_len = 500
        self._trace_metadata["instruction"] = instr[:max_len]

    # Start new trace if none active
    if self._current_trace_id is None:
        self._current_trace_id = self.trace_processor.start_trace(self._trace_metadata)

    start_time = datetime.now()
    final_output = ""

    try:
        for query in user_inputs:
            # Add user message to conversation history and record it
            self._conversation_history.append(HumanMessage(content=query))
            self.trace_processor.record_user_input(self._current_trace_id, query)

            history_snapshot = self._conversation_history.copy()
            result = await self.agent_graph.ainvoke({"messages": history_snapshot})

            # The returned state starts with the input history; only what
            # follows it was produced in this turn. Re-scanning the echoed
            # history would pick up earlier turns' AI replies as this turn's
            # output whenever this turn yields no non-empty final text.
            all_messages = result.get("messages", [])
            if len(all_messages) >= len(history_snapshot):
                new_messages = all_messages[len(history_snapshot):]
            else:
                new_messages = all_messages  # unexpected shape: scan everything

            # tool_call_id -> recorded call step. An AIMessage may issue several
            # parallel tool calls before any ToolMessage arrives, so pairing a
            # response with "the last call" would mislabel all but one of them.
            calls_by_id: Dict[str, Dict[str, Any]] = {}
            turn_final_output = ""

            for msg in new_messages:
                msg_type = type(msg).__name__

                if msg_type == "AIMessage":
                    tool_calls = getattr(msg, 'tool_calls', None)
                    if tool_calls:
                        for tool_call in tool_calls:
                            call_step = {
                                'type': 'tool_call',
                                'tool': tool_call.get('name', 'unknown'),
                                'args': tool_call.get('args', {}),
                            }
                            self._all_intermediate_steps.append(call_step)
                            call_id = tool_call.get('id')
                            if call_id:
                                calls_by_id[call_id] = call_step
                    elif getattr(msg, 'content', None):
                        # Regular AI message - this is the final output for this turn
                        turn_final_output = msg.content

                elif msg_type == "ToolMessage":
                    tool_output = msg.content if hasattr(msg, 'content') else str(msg)
                    tool_name = msg.name if hasattr(msg, 'name') else 'unknown'

                    self._all_intermediate_steps.append({
                        'type': 'tool_response',
                        'tool': tool_name,
                        'output': tool_output,
                    })

                    # Pair with the originating call; fall back to the most
                    # recent call when the message carries no known id.
                    matched_call = calls_by_id.get(getattr(msg, 'tool_call_id', None))
                    if matched_call is None:
                        matched_call = next(
                            (s for s in reversed(self._all_intermediate_steps)
                             if s.get('type') == 'tool_call'),
                            None,
                        )

                    if matched_call is not None:
                        if isinstance(tool_output, (dict, list)):
                            tool_output_str = json.dumps(tool_output, ensure_ascii=False, default=str)
                        else:
                            tool_output_str = str(tool_output)
                        self.trace_processor.record_tool_call(
                            self._current_trace_id,
                            tool_name=matched_call.get('tool', tool_name),
                            tool_input=matched_call.get('args', {}),
                            tool_output=tool_output_str
                        )

            # Update conversation history with AI response
            if turn_final_output:
                self._conversation_history.append(AIMessage(content=turn_final_output))
                final_output = turn_final_output  # Keep updating to get the last response

            # Record response after each turn
            self.trace_processor.write_agent_response(self._current_trace_id, turn_final_output)

            self._turn_count += 1

        # Record final response and close the trace
        self.trace_processor.record_agent_response(self._current_trace_id, final_output)

        duration = (datetime.now() - start_time).total_seconds()
        self.trace_processor.end_trace(self._current_trace_id, final_output, duration)

        self._generate_trajectory()

        return self.get_result()

    except Exception as e:
        if self._current_trace_id:
            self.trace_processor.record_error(self._current_trace_id, str(e))
        raise
|
|
557
|
+
|
|
558
|
+
def _generate_trajectory(self) -> None:
|
|
559
|
+
"""Generate trajectory from current trace file"""
|
|
560
|
+
try:
|
|
561
|
+
if os.path.exists(self.trace_file):
|
|
562
|
+
saved_files = self.trajectory_converter.process_trace_file(
|
|
563
|
+
self.trace_file,
|
|
564
|
+
output_name="trajectory"
|
|
565
|
+
)
|
|
566
|
+
if saved_files:
|
|
567
|
+
# Load the trajectory from the saved file
|
|
568
|
+
import json
|
|
569
|
+
with open(saved_files[-1], 'r') as f:
|
|
570
|
+
traj_data = json.load(f)
|
|
571
|
+
# Create Trajectory object from loaded data
|
|
572
|
+
self._current_trajectory = Trajectory(
|
|
573
|
+
task_id=traj_data.get("task_info", {}).get("task_id"),
|
|
574
|
+
original_instruction=traj_data.get("task_info", {}).get("original_instruction"),
|
|
575
|
+
)
|
|
576
|
+
self._current_trajectory.data = traj_data
|
|
577
|
+
except Exception as e:
|
|
578
|
+
print(f"[WARNING] Failed to generate trajectory: {e}")
|
|
579
|
+
|
|
580
|
+
def get_result(self) -> AgentResult:
    """
    Build an AgentResult snapshot of the current session.

    ``final_output`` is the content of the newest AIMessage in the
    conversation history with non-empty content, or None if there is none.

    Returns:
        AgentResult with final output, turn count, trajectory, and trace ID
        (an empty string when no trace is active).
    """
    latest_reply = next(
        (
            entry.content
            for entry in reversed(self._conversation_history)
            if isinstance(entry, AIMessage) and entry.content
        ),
        None,
    )

    return AgentResult(
        final_output=latest_reply,
        turn_count=self._turn_count,
        trajectory=self._current_trajectory,
        trace_id=self._current_trace_id or "",
    )
|
|
600
|
+
|
|
601
|
+
def reset_conversation(self) -> None:
    """
    Discard all per-session state so the next run() starts fresh.

    Clears the conversation history, intermediate steps, turn counter, trace
    id, trace metadata and cached trajectory. Because the trace id is reset
    to None, the next run() call starts a new trace. cleanup() calls this
    automatically.
    """
    self._conversation_history, self._all_intermediate_steps = [], []
    self._turn_count = 0
    self._current_trace_id = self._trace_metadata = self._current_trajectory = None
|
|
619
|
+
|
|
620
|
+
async def cleanup(self) -> None:
    """
    Release MCP connections and convert the trace file into trajectories.

    Steps:
      1. reset_conversation() clears per-session state.
      2. Each client in ``self.mcp_sessions`` is exited via
         ``__aexit__(None, None, None)``. A failure for one client is
         reported and does not stop the remaining clients from closing.
      3. ``self.mcp_sessions`` and ``self.mcp_tools`` are emptied.
      4. If ``self.trace_file`` exists, it is processed by the trajectory
         converter on a best-effort basis; failures are printed as warnings.
    """
    # Reset conversation history
    self.reset_conversation()

    # Close all MCP clients
    for client_info in self.mcp_sessions:
        try:
            client = client_info['client']
            await client.__aexit__(None, None, None)
        except Exception as e:
            # Use .get so that building the warning can never raise. A KeyError
            # here would escape the handler and leave the remaining clients
            # open. Assumes client_info is a dict (it is subscripted above).
            print(f"[WARNING] Failed to cleanup MCP client {client_info.get('name', 'unknown')}: {e}")

    self.mcp_sessions = []
    self.mcp_tools = []

    # Process trace file to generate trajectories
    try:
        if os.path.exists(self.trace_file):
            saved_files = self.trajectory_converter.process_trace_file(
                self.trace_file,
                output_name="trajectory"
            )
            # Treat a None result as "no files written" instead of letting
            # len(None) surface as a misleading processing failure.
            print(f"[INFO] Generated {len(saved_files or [])} trajectory files")
    except Exception as e:
        print(f"[WARNING] Failed to process traces: {e}")
|