decodingtrust-agent-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +30 -0
- agent/claudesdk/__init__.py +8 -0
- agent/claudesdk/example.py +221 -0
- agent/claudesdk/src/__init__.py +8 -0
- agent/claudesdk/src/agent.py +400 -0
- agent/claudesdk/src/mcp_proxy.py +409 -0
- agent/claudesdk/src/utils.py +420 -0
- agent/googleadk/__init__.py +15 -0
- agent/googleadk/example.py +237 -0
- agent/googleadk/src/__init__.py +12 -0
- agent/googleadk/src/agent.py +401 -0
- agent/googleadk/src/mcp_wrapper.py +163 -0
- agent/googleadk/src/utils.py +602 -0
- agent/langchain/__init__.py +8 -0
- agent/langchain/example.py +213 -0
- agent/langchain/src/__init__.py +8 -0
- agent/langchain/src/agent.py +645 -0
- agent/langchain/src/utils.py +433 -0
- agent/openaisdk/__init__.py +17 -0
- agent/openaisdk/example.py +228 -0
- agent/openaisdk/src/__init__.py +12 -0
- agent/openaisdk/src/agent.py +491 -0
- agent/openaisdk/src/agent_wrapper.py +143 -0
- agent/openaisdk/src/mcp_wrapper.py +395 -0
- agent/openaisdk/src/utils.py +493 -0
- agent/openclaw/__init__.py +10 -0
- agent/openclaw/example.py +251 -0
- agent/openclaw/src/__init__.py +14 -0
- agent/openclaw/src/agent.py +930 -0
- agent/openclaw/src/helpers/__init__.py +1 -0
- agent/openclaw/src/helpers/auth_helpers.py +55 -0
- agent/openclaw/src/mcp_proxy.py +564 -0
- agent/openclaw/src/plugin_generator.py +231 -0
- agent/openclaw/src/utils.py +341 -0
- agent/pocketflow/__init__.py +18 -0
- agent/pocketflow/example.py +221 -0
- agent/pocketflow/prompts/react_agent.py +46 -0
- agent/pocketflow/src/__init__.py +6 -0
- agent/pocketflow/src/agent.py +507 -0
- agent/pocketflow/src/agent_wrapper.py +159 -0
- agent/pocketflow/src/async_helper.py +92 -0
- agent/pocketflow/src/mcp_react_agent.py +279 -0
- agent/pocketflow/src/native_agent.py +74 -0
- agent/pocketflow/src/nodes.py +467 -0
- benchmark/__init__.py +0 -0
- benchmark/browser/benign.jsonl +34 -0
- benchmark/browser/direct.jsonl +85 -0
- benchmark/browser/indirect.jsonl +82 -0
- benchmark/code/benign.jsonl +0 -0
- benchmark/code/direct.jsonl +121 -0
- benchmark/code/indirect.jsonl +165 -0
- benchmark/crm/benign.jsonl +165 -0
- benchmark/crm/direct.jsonl +90 -0
- benchmark/crm/indirect.jsonl +150 -0
- benchmark/customer-service/benign.jsonl +160 -0
- benchmark/customer-service/direct.jsonl +100 -0
- benchmark/customer-service/indirect.jsonl +101 -0
- benchmark/finance/benign.jsonl +0 -0
- benchmark/finance/direct.jsonl +200 -0
- benchmark/finance/indirect.jsonl +200 -0
- benchmark/legal/benign.jsonl +0 -0
- benchmark/legal/direct.jsonl +200 -0
- benchmark/legal/indirect.jsonl +200 -0
- benchmark/macos/benign.jsonl +30 -0
- benchmark/macos/direct.jsonl +50 -0
- benchmark/macos/indirect.jsonl +50 -0
- benchmark/medical/benign.jsonl +642 -0
- benchmark/medical/direct.jsonl +229 -0
- benchmark/medical/indirect.jsonl +222 -0
- benchmark/os-filesystem/benign.jsonl +200 -0
- benchmark/os-filesystem/direct.jsonl +200 -0
- benchmark/os-filesystem/indirect.jsonl +200 -0
- benchmark/research/benign.jsonl +0 -0
- benchmark/research/direct.jsonl +119 -0
- benchmark/research/indirect.jsonl +125 -0
- benchmark/telecom/benign.jsonl +120 -0
- benchmark/telecom/direct.jsonl +161 -0
- benchmark/telecom/indirect.jsonl +166 -0
- benchmark/travel/benign.jsonl +130 -0
- benchmark/travel/direct.jsonl +105 -0
- benchmark/travel/indirect.jsonl +120 -0
- benchmark/windows/benign.jsonl +100 -0
- benchmark/windows/direct.jsonl +140 -0
- benchmark/windows/indirect.jsonl +107 -0
- benchmark/workflow/benign.jsonl +335 -0
- benchmark/workflow/direct.jsonl +78 -0
- benchmark/workflow/indirect.jsonl +107 -0
- cli/__init__.py +5 -0
- cli/main.py +182 -0
- cli/scaffold.py +334 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/METADATA +642 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/RECORD +374 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/WHEEL +5 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/entry_points.txt +2 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/licenses/LICENSE +201 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/top_level.txt +6 -0
- dt_arena/config/env.yaml +515 -0
- dt_arena/config/injection_mcp.yaml +430 -0
- dt_arena/config/mcp.yaml +642 -0
- dt_arena/envs/arxiv/docker-compose-hub.yml +31 -0
- dt_arena/envs/arxiv/docker-compose.yml +36 -0
- dt_arena/envs/atlassian/docker/docker-compose.dev.yml +65 -0
- dt_arena/envs/atlassian/docker/docker-compose.yml +53 -0
- dt_arena/envs/atlassian/docker-compose-hub.yml +57 -0
- dt_arena/envs/atlassian/docker-compose.yml +72 -0
- dt_arena/envs/bigquery/docker-compose.yml +20 -0
- dt_arena/envs/booking/docker-compose.yml +59 -0
- dt_arena/envs/calendar/docker-compose-hub.yml +30 -0
- dt_arena/envs/calendar/docker-compose.yml +42 -0
- dt_arena/envs/custom-website/docker-compose.yml +6 -0
- dt_arena/envs/customer_service/docker-compose.yml +59 -0
- dt_arena/envs/databricks/docker-compose-hub.yml +47 -0
- dt_arena/envs/databricks/docker-compose.yml +51 -0
- dt_arena/envs/ecommerce/docker-compose.yml +6 -0
- dt_arena/envs/ers/docker-compose.yml +36 -0
- dt_arena/envs/ers/hrms/docker/docker-compose.yml +31 -0
- dt_arena/envs/finance/docker-compose.yml +23 -0
- dt_arena/envs/github/docker/docker-compose-hub.yml +50 -0
- dt_arena/envs/github/docker/docker-compose.yml +50 -0
- dt_arena/envs/gmail/docker-compose-hub.yml +51 -0
- dt_arena/envs/gmail/docker-compose.yml +65 -0
- dt_arena/envs/google-form/docker-compose-hub.yml +33 -0
- dt_arena/envs/google-form/docker-compose.yml +41 -0
- dt_arena/envs/googledocs/docker-compose-hub.yml +61 -0
- dt_arena/envs/googledocs/docker-compose.yml +78 -0
- dt_arena/envs/hospital/docker-compose-hub.yml +25 -0
- dt_arena/envs/hospital/docker-compose.yml +27 -0
- dt_arena/envs/legal/docker-compose.yml +22 -0
- dt_arena/envs/linkedin/docker-compose.yml +63 -0
- dt_arena/envs/macos/docker-compose.yml +79 -0
- dt_arena/envs/os-filesystem/docker-compose-hub.yml +16 -0
- dt_arena/envs/os-filesystem/docker-compose.yml +20 -0
- dt_arena/envs/paypal/docker-compose-hub.yml +48 -0
- dt_arena/envs/paypal/docker-compose.yml +63 -0
- dt_arena/envs/research/docker-compose-hub.yml +13 -0
- dt_arena/envs/research/docker-compose.yml +24 -0
- dt_arena/envs/salesforce_crm/docker-compose-hub.yaml +45 -0
- dt_arena/envs/salesforce_crm/docker-compose.yaml +49 -0
- dt_arena/envs/slack/docker-compose-hub.yml +28 -0
- dt_arena/envs/slack/docker-compose.yml +41 -0
- dt_arena/envs/snowflake/docker-compose-hub.yml +41 -0
- dt_arena/envs/snowflake/docker-compose.yml +44 -0
- dt_arena/envs/telecom/docker-compose-hub.yml +16 -0
- dt_arena/envs/telecom/docker-compose.yml +17 -0
- dt_arena/envs/telegram/docker-compose-hub.yml +57 -0
- dt_arena/envs/telegram/docker-compose.yml +62 -0
- dt_arena/envs/terminal/docker-compose-hub.yml +12 -0
- dt_arena/envs/terminal/docker-compose.yml +26 -0
- dt_arena/envs/travel/docker-compose-hub.yml +19 -0
- dt_arena/envs/travel/docker-compose.yml +19 -0
- dt_arena/envs/whatsapp/docker-compose-hub.yml +61 -0
- dt_arena/envs/whatsapp/docker-compose.yml +78 -0
- dt_arena/envs/windows/docker-compose.yml +71 -0
- dt_arena/envs/zoom/docker-compose-hub.yml +27 -0
- dt_arena/envs/zoom/docker-compose.yml +40 -0
- dt_arena/injection_mcp_server/atlassian/env_injection.py +134 -0
- dt_arena/injection_mcp_server/calendar/env_injection.py +217 -0
- dt_arena/injection_mcp_server/custom_website/env_injection.py +97 -0
- dt_arena/injection_mcp_server/customer_service/env_injection.py +659 -0
- dt_arena/injection_mcp_server/databricks/env_injection.py +255 -0
- dt_arena/injection_mcp_server/ecommerce/env_injection.py +110 -0
- dt_arena/injection_mcp_server/finance/env_injection.py +85 -0
- dt_arena/injection_mcp_server/github/env_injection.py +206 -0
- dt_arena/injection_mcp_server/gmail/env_injection.py +211 -0
- dt_arena/injection_mcp_server/google_form/env_injection.py +186 -0
- dt_arena/injection_mcp_server/googledocs/env_injection.py +44 -0
- dt_arena/injection_mcp_server/hospital/env_injection.py +43 -0
- dt_arena/injection_mcp_server/legal/env_injection.py +229 -0
- dt_arena/injection_mcp_server/macos/env_injection.py +272 -0
- dt_arena/injection_mcp_server/os-filesystem/env_injection.py +341 -0
- dt_arena/injection_mcp_server/paypal/env_injection.py +268 -0
- dt_arena/injection_mcp_server/research/env_injection.py +616 -0
- dt_arena/injection_mcp_server/salesforce/env_injection.py +514 -0
- dt_arena/injection_mcp_server/slack/env_injection.py +265 -0
- dt_arena/injection_mcp_server/snowflake/env_injection.py +230 -0
- dt_arena/injection_mcp_server/telecom/env_injection.py +503 -0
- dt_arena/injection_mcp_server/telegram/env_injection.py +171 -0
- dt_arena/injection_mcp_server/terminal/env_injection.py +523 -0
- dt_arena/injection_mcp_server/travel/env_injection.py +173 -0
- dt_arena/injection_mcp_server/whatsapp/env_injection.py +185 -0
- dt_arena/injection_mcp_server/windows/env_injection.py +943 -0
- dt_arena/injection_mcp_server/zoom/env_injection.py +216 -0
- dt_arena/mcp_server/atlassian/main.py +1554 -0
- dt_arena/mcp_server/atlassian/test_server.py +66 -0
- dt_arena/mcp_server/bigquery/main.py +333 -0
- dt_arena/mcp_server/booking/main.py +310 -0
- dt_arena/mcp_server/browser/main.py +1741 -0
- dt_arena/mcp_server/calendar/example_multi_user.py +162 -0
- dt_arena/mcp_server/calendar/main.py +792 -0
- dt_arena/mcp_server/calendar/test_mcp.py +135 -0
- dt_arena/mcp_server/customer_service/main.py +1063 -0
- dt_arena/mcp_server/databricks/main.py +566 -0
- dt_arena/mcp_server/databricks/probe.py +102 -0
- dt_arena/mcp_server/ers/main.py +845 -0
- dt_arena/mcp_server/finance/__init__.py +87 -0
- dt_arena/mcp_server/finance/core/__init__.py +12 -0
- dt_arena/mcp_server/finance/core/data_loader.py +558 -0
- dt_arena/mcp_server/finance/core/portfolio.py +565 -0
- dt_arena/mcp_server/finance/evaluation/__init__.py +20 -0
- dt_arena/mcp_server/finance/evaluation/evaluator.py +217 -0
- dt_arena/mcp_server/finance/evaluation/logger.py +137 -0
- dt_arena/mcp_server/finance/injection/__init__.py +66 -0
- dt_arena/mcp_server/finance/injection/config.py +176 -0
- dt_arena/mcp_server/finance/injection/content.py +755 -0
- dt_arena/mcp_server/finance/injection/html.py +409 -0
- dt_arena/mcp_server/finance/injection/locations.py +167 -0
- dt_arena/mcp_server/finance/injection/methods.py +193 -0
- dt_arena/mcp_server/finance/injection/presets.py +1023 -0
- dt_arena/mcp_server/finance/main.py +361 -0
- dt_arena/mcp_server/finance/run_mcp.py +21 -0
- dt_arena/mcp_server/finance/run_web.py +26 -0
- dt_arena/mcp_server/finance/server/__init__.py +41 -0
- dt_arena/mcp_server/finance/server/extractor.py +1453 -0
- dt_arena/mcp_server/finance/server/extractor_minimal.py +292 -0
- dt_arena/mcp_server/finance/server/extractor_simple.py +1164 -0
- dt_arena/mcp_server/finance/server/injection_mcp.py +865 -0
- dt_arena/mcp_server/finance/server/mcp.py +451 -0
- dt_arena/mcp_server/finance/server/tools/__init__.py +23 -0
- dt_arena/mcp_server/finance/server/tools/account.py +88 -0
- dt_arena/mcp_server/finance/server/tools/browsing.py +328 -0
- dt_arena/mcp_server/finance/server/tools/social.py +73 -0
- dt_arena/mcp_server/finance/server/tools/trading.py +242 -0
- dt_arena/mcp_server/finance/server/tools/utility.py +49 -0
- dt_arena/mcp_server/finance/server/web.py +2139 -0
- dt_arena/mcp_server/finance/tasks/benchmark/__init__.py +28 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_pool.py +3026 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_runner.py +1315 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_requirements.py +1335 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_tasks.py +3665 -0
- dt_arena/mcp_server/finance/tasks/benchmark/malicious_tasks.py +2673 -0
- dt_arena/mcp_server/finance/tasks/redteam_suite/run_redteam_suite.py +1713 -0
- dt_arena/mcp_server/finance/test_mcp_tools.py +476 -0
- dt_arena/mcp_server/github/main.py +441 -0
- dt_arena/mcp_server/gmail/main.py +1004 -0
- dt_arena/mcp_server/google_form/main.py +141 -0
- dt_arena/mcp_server/googledocs/main.py +458 -0
- dt_arena/mcp_server/hospital/mcp_server.py +458 -0
- dt_arena/mcp_server/legal/__init__.py +9 -0
- dt_arena/mcp_server/legal/core/__init__.py +14 -0
- dt_arena/mcp_server/legal/core/courtlistener_store.py +762 -0
- dt_arena/mcp_server/legal/core/data_loader.py +266 -0
- dt_arena/mcp_server/legal/core/document_store.py +197 -0
- dt_arena/mcp_server/legal/core/matter_manager.py +466 -0
- dt_arena/mcp_server/legal/main.py +89 -0
- dt_arena/mcp_server/legal/scripts/collect_data.py +988 -0
- dt_arena/mcp_server/legal/server/__init__.py +14 -0
- dt_arena/mcp_server/legal/server/mcp.py +2330 -0
- dt_arena/mcp_server/macos/client_test.py +270 -0
- dt_arena/mcp_server/macos/mcp_server.py +285 -0
- dt_arena/mcp_server/os-filesystem/main.py +1380 -0
- dt_arena/mcp_server/paypal/main.py +501 -0
- dt_arena/mcp_server/research/main.py +777 -0
- dt_arena/mcp_server/salesforce/main.py +2006 -0
- dt_arena/mcp_server/slack/main.py +318 -0
- dt_arena/mcp_server/snowflake/main.py +612 -0
- dt_arena/mcp_server/snowflake/probe.py +183 -0
- dt_arena/mcp_server/telecom/mcp_client.py +423 -0
- dt_arena/mcp_server/telecom/mcp_server.py +1059 -0
- dt_arena/mcp_server/telegram/main.py +338 -0
- dt_arena/mcp_server/terminal/main.py +163 -0
- dt_arena/mcp_server/travel/client_test.py +16 -0
- dt_arena/mcp_server/travel/mcp_server.py +404 -0
- dt_arena/mcp_server/whatsapp/main.py +318 -0
- dt_arena/mcp_server/windows/client_test.py +270 -0
- dt_arena/mcp_server/windows/mcp_server.py +218 -0
- dt_arena/mcp_server/zoom/main.py +466 -0
- dt_arena/src/__init__.py +0 -0
- dt_arena/src/hooks/__init__.py +0 -0
- dt_arena/src/hooks/audit_log.py +30 -0
- dt_arena/src/hooks/hooks.json +3 -0
- dt_arena/src/run_benign.py +142 -0
- dt_arena/src/types/__init__.py +0 -0
- dt_arena/src/types/agent.py +441 -0
- dt_arena/src/types/attacks.py +2 -0
- dt_arena/src/types/environment.py +2 -0
- dt_arena/src/types/hooks.py +174 -0
- dt_arena/src/types/judge.py +52 -0
- dt_arena/src/types/red_teaming_trajectory.py +385 -0
- dt_arena/src/types/task.py +260 -0
- dt_arena/src/types/trajectory.py +315 -0
- dt_arena/utils/__init__.py +1 -0
- dt_arena/utils/atlassian/__init__.py +27 -0
- dt_arena/utils/atlassian/helpers.py +520 -0
- dt_arena/utils/bigquery/__init__.py +1 -0
- dt_arena/utils/bigquery/helpers.py +246 -0
- dt_arena/utils/calendar/__init__.py +1 -0
- dt_arena/utils/calendar/helpers.py +87 -0
- dt_arena/utils/customer_service/__init__.py +17 -0
- dt_arena/utils/customer_service/cs_env_client.py +940 -0
- dt_arena/utils/customer_service/helpers.py +339 -0
- dt_arena/utils/customer_service/judges/__init__.py +20 -0
- dt_arena/utils/customer_service/judges/abusive_comms.py +149 -0
- dt_arena/utils/customer_service/judges/bulk_destructive.py +202 -0
- dt_arena/utils/customer_service/judges/cross_customer.py +481 -0
- dt_arena/utils/customer_service/judges/email_exfiltration.py +320 -0
- dt_arena/utils/customer_service/judges/financial_fraud.py +221 -0
- dt_arena/utils/customer_service/judges/llm_utils.py +264 -0
- dt_arena/utils/customer_service/judges/pii_disclosure.py +837 -0
- dt_arena/utils/customer_service/judges/policy_misrep.py +44 -0
- dt_arena/utils/customer_service/judges/text_utils.py +21 -0
- dt_arena/utils/databricks/__init__.py +2 -0
- dt_arena/utils/databricks/helpers.py +210 -0
- dt_arena/utils/finance/__init__.py +0 -0
- dt_arena/utils/finance/helpers.py +263 -0
- dt_arena/utils/github/__init__.py +1 -0
- dt_arena/utils/github/helpers.py +249 -0
- dt_arena/utils/gmail/__init__.py +1 -0
- dt_arena/utils/gmail/helpers.py +344 -0
- dt_arena/utils/google_form/__init__.py +2 -0
- dt_arena/utils/google_form/helpers.py +133 -0
- dt_arena/utils/legal/__init__.py +0 -0
- dt_arena/utils/legal/helpers.py +228 -0
- dt_arena/utils/macos/__init__.py +0 -0
- dt_arena/utils/macos/env_setup.py +215 -0
- dt_arena/utils/macos/helpers.py +61 -0
- dt_arena/utils/os_filesystem/__init__.py +1 -0
- dt_arena/utils/os_filesystem/helpers.py +366 -0
- dt_arena/utils/paypal/__init__.py +1 -0
- dt_arena/utils/paypal/helpers.py +178 -0
- dt_arena/utils/port_allocator.py +266 -0
- dt_arena/utils/research/__init__.py +0 -0
- dt_arena/utils/research/helpers.py +251 -0
- dt_arena/utils/salesforce/__init__.py +1 -0
- dt_arena/utils/salesforce/helpers.py +719 -0
- dt_arena/utils/slack/__init__.py +1 -0
- dt_arena/utils/slack/helpers.py +176 -0
- dt_arena/utils/snowflake/__init__.py +1 -0
- dt_arena/utils/snowflake/helpers.py +166 -0
- dt_arena/utils/telecom/__init__.py +1 -0
- dt_arena/utils/telecom/helpers.py +760 -0
- dt_arena/utils/telegram/__init__.py +0 -0
- dt_arena/utils/telegram/helpers.py +174 -0
- dt_arena/utils/terminal/__init__.py +0 -0
- dt_arena/utils/terminal/helpers.py +20 -0
- dt_arena/utils/travel/__init__.py +0 -0
- dt_arena/utils/travel/env_client.py +537 -0
- dt_arena/utils/travel/llm_judge.py +137 -0
- dt_arena/utils/travel/prompts.py +64 -0
- dt_arena/utils/utils/__init__.py +122 -0
- dt_arena/utils/whatsapp/__init__.py +0 -0
- dt_arena/utils/whatsapp/helpers.py +226 -0
- dt_arena/utils/windows/__init__.py +0 -0
- dt_arena/utils/windows/env_reset.py +224 -0
- dt_arena/utils/windows/env_setup.py +280 -0
- dt_arena/utils/windows/exfil_helpers.py +170 -0
- dt_arena/utils/windows/helpers.py +74 -0
- dt_arena/utils/zoom/__init__.py +1 -0
- dt_arena/utils/zoom/helpers.py +70 -0
- eval/__init__.py +1 -0
- eval/evaluation.py +426 -0
- eval/task_runner.py +449 -0
- utils/__init__.py +148 -0
- utils/agent_helpers.py +308 -0
- utils/agent_wrapper.py +189 -0
- utils/compose_utils.py +135 -0
- utils/config.py +77 -0
- utils/env_helpers.py +104 -0
- utils/eval_stats.py +88 -0
- utils/injection_helpers.py +429 -0
- utils/injection_mcp_helpers.py +152 -0
- utils/judge_helpers.py +181 -0
- utils/judge_utils.py +472 -0
- utils/llm.py +196 -0
- utils/logging.py +45 -0
- utils/mcp_helpers.py +232 -0
- utils/mcp_manager.py +235 -0
- utils/memory_guard.py +18 -0
- utils/red_teaming_sandbox.py +476 -0
- utils/reset_helpers.py +318 -0
- utils/resource_manager.py +370 -0
- utils/skill_helpers.py +447 -0
- utils/task_executor.py +904 -0
- utils/task_helpers.py +270 -0
- utils/template_helpers.py +179 -0
|
@@ -0,0 +1,491 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
from typing import Dict, Any, Optional, List, Union
|
|
4
|
+
|
|
5
|
+
from openai import AsyncOpenAI
|
|
6
|
+
from agents import Agent as OpenAIAgent, Runner, ModelSettings, set_trace_processors, function_tool
|
|
7
|
+
from agents import OpenAIChatCompletionsModel
|
|
8
|
+
from agents.tracing import trace
|
|
9
|
+
|
|
10
|
+
from dt_arena.src.types.agent import (
|
|
11
|
+
Agent,
|
|
12
|
+
AgentConfig,
|
|
13
|
+
RuntimeConfig,
|
|
14
|
+
MCPServerConfig,
|
|
15
|
+
AgentResult,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
from .mcp_wrapper import MCPServerSse, MCPServerStreamableHttp, MCPServerStdio
|
|
19
|
+
from dt_arena.src.types.trajectory import Trajectory
|
|
20
|
+
|
|
21
|
+
from .utils import OpenAITraceProcessor, OpenAIAgentTrajectory
|
|
22
|
+
from utils.skill_helpers import (
|
|
23
|
+
create_injected_skills_directory,
|
|
24
|
+
cleanup_temp_directory,
|
|
25
|
+
parse_skill_metadata,
|
|
26
|
+
load_skill_full_content,
|
|
27
|
+
scan_available_skills,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Read timeout (seconds) applied to every MCP client session, regardless of
# transport. Override with the MCP_CLIENT_TIMEOUT_SECONDS environment variable;
# the value is parsed once, at import time, with float().
MCP_CLIENT_TIMEOUT_SECONDS: float = float(
    os.environ.get("MCP_CLIENT_TIMEOUT_SECONDS", "30")
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class OpenAISDKAgent(Agent):
|
|
36
|
+
"""General Agent Workflow based on OpenAI Agents SDK"""
|
|
37
|
+
|
|
38
|
+
def __init__(
    self,
    agent_config: AgentConfig,
    runtime_config: Optional[RuntimeConfig] = None,
):
    """
    Create an agent backed by the OpenAI Agents SDK.

    Lays out the output directories (``<output_dir>/traces`` for trace JSONL
    files, ``<output_dir>`` itself for trajectories), installs a local trace
    processor, and resets the per-run state consumed by ``run()``.

    Args:
        agent_config: Agent configuration (system_prompt and mcp_servers)
        runtime_config: Runtime configuration (model, temperature, max_turns, output_dir)
    """
    super().__init__(agent_config, runtime_config)

    # --- output layout -------------------------------------------------
    # Falls back to ./results under the current working directory.
    self.output_dir = self.runtime_config.output_dir or os.path.join(os.getcwd(), "results")
    self.traces_dir = os.path.join(self.output_dir, "traces")
    self.trajectories_dir = self.output_dir
    for directory in (self.traces_dir, self.trajectories_dir):
        os.makedirs(directory, exist_ok=True)

    # --- tracing -------------------------------------------------------
    self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    self.trace_file = os.path.join(self.traces_dir, f"traces_{self.timestamp}.jsonl")
    self.trace_processor = OpenAITraceProcessor(self.trace_file)
    # Install only the local processor, replacing the default OpenAI backend
    # processor. NOTE(review): set_trace_processors is a module-level SDK call,
    # so it presumably affects the whole process, not just this instance.
    set_trace_processors([self.trace_processor])

    # Converts recorded traces into trajectory files.
    self.trajectory_converter = OpenAIAgentTrajectory(self.trajectories_dir, self.timestamp)

    # --- multi-turn state ----------------------------------------------
    # Last RunResult; its to_input_list() seeds the next turn's history.
    self._last_result = None
    # Trace context kept open across turns.
    self._active_trace = None
    self._active_trace_id = None
    self._trace_metadata = None

    # --- current execution ---------------------------------------------
    self._current_trajectory: Optional[Trajectory] = None
    self._turn_count: int = 0

    # "<server>:<tool>" -> description, filled only in debug mode.
    self._tool_descriptions: Dict[str, str] = {}

    # Temp directory holding (possibly injected) skills, if any.
    self._skill_temp_dir: Optional[str] = None
|
|
94
|
+
|
|
95
|
+
async def _setup_skills(self) -> None:
|
|
96
|
+
"""Setup skill directories and apply any skill injections."""
|
|
97
|
+
skill_directories = self.config.skill_directories if self.config else []
|
|
98
|
+
skill_injection = self.runtime_config.skill_injection
|
|
99
|
+
|
|
100
|
+
has_create_mode = skill_injection and any(
|
|
101
|
+
any(inj.mode == "create" for inj in injs)
|
|
102
|
+
for injs in skill_injection.values()
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
if not skill_directories and not has_create_mode:
|
|
106
|
+
return
|
|
107
|
+
|
|
108
|
+
self._skill_temp_dir = create_injected_skills_directory(
|
|
109
|
+
source_skill_dirs=skill_directories,
|
|
110
|
+
skill_injection=skill_injection,
|
|
111
|
+
skill_subpath=".agents/skills",
|
|
112
|
+
base_dir=self.output_dir,
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
def _build_tools_for_skill(self) -> list:
|
|
116
|
+
"""Build a load_skill function tool that returns full SKILL.md content."""
|
|
117
|
+
tools = []
|
|
118
|
+
if not self._skill_temp_dir:
|
|
119
|
+
return tools
|
|
120
|
+
|
|
121
|
+
skills_dir = os.path.join(self._skill_temp_dir, ".agents", "skills")
|
|
122
|
+
|
|
123
|
+
@function_tool
|
|
124
|
+
def load_skill(skill_name: str) -> str:
|
|
125
|
+
"""Load a skill by name and return its full instructions.
|
|
126
|
+
|
|
127
|
+
Use this tool to retrieve the complete content of a skill file.
|
|
128
|
+
After loading, follow the skill's instructions to complete the task.
|
|
129
|
+
|
|
130
|
+
Args:
|
|
131
|
+
skill_name: The name of the skill to load.
|
|
132
|
+
|
|
133
|
+
Returns:
|
|
134
|
+
The full content of the skill's SKILL.md file, or an error message if not found.
|
|
135
|
+
"""
|
|
136
|
+
available = scan_available_skills(skills_dir)
|
|
137
|
+
for skill in available:
|
|
138
|
+
if skill.get("name") == skill_name:
|
|
139
|
+
return load_skill_full_content(skill["path"])
|
|
140
|
+
names = [s.get("name", "") for s in available]
|
|
141
|
+
return f"Skill '{skill_name}' not found. Available skills: {names}"
|
|
142
|
+
|
|
143
|
+
tools.append(load_skill)
|
|
144
|
+
return tools
|
|
145
|
+
|
|
146
|
+
def _build_instructions_for_skill(self) -> str:
|
|
147
|
+
"""Build agent instructions, appending skill hints if skills are configured."""
|
|
148
|
+
instructions = self.config.system_prompt
|
|
149
|
+
if not self._skill_temp_dir:
|
|
150
|
+
return instructions
|
|
151
|
+
|
|
152
|
+
skills_dir = os.path.join(self._skill_temp_dir, ".agents", "skills")
|
|
153
|
+
skill_entries = []
|
|
154
|
+
if os.path.isdir(skills_dir):
|
|
155
|
+
for item in sorted(os.listdir(skills_dir)):
|
|
156
|
+
skill_path = os.path.join(skills_dir, item)
|
|
157
|
+
if not os.path.isdir(skill_path):
|
|
158
|
+
continue
|
|
159
|
+
skill_file = os.path.join(skill_path, "SKILL.md")
|
|
160
|
+
if not os.path.exists(skill_file):
|
|
161
|
+
continue
|
|
162
|
+
metadata = parse_skill_metadata(skill_file)
|
|
163
|
+
if metadata:
|
|
164
|
+
skill_entries.append(f"# Skill: {item}\n\n{metadata}")
|
|
165
|
+
else:
|
|
166
|
+
skill_entries.append(f"# Skill: {item}")
|
|
167
|
+
|
|
168
|
+
if skill_entries:
|
|
169
|
+
skills_list = "\n".join(skill_entries)
|
|
170
|
+
instructions += (
|
|
171
|
+
"\n\n## Instructions for Using Skills\n"
|
|
172
|
+
"You have access to the `load_skill` tool that retrieves full skill instructions, "
|
|
173
|
+
"whose metadata are defined in the system prompt. "
|
|
174
|
+
"Based on the task requirements, if you need to use a skill, "
|
|
175
|
+
"call `load_skill` with the skill name to get its full instructions, then follow them.\n\n"
|
|
176
|
+
f"## Available skills:\n{skills_list}"
|
|
177
|
+
)
|
|
178
|
+
return instructions
|
|
179
|
+
|
|
180
|
+
async def initialize(self) -> None:
    """Prepare skills, connect MCP servers and build the SDK agent.

    Raises:
        ValueError: If the agent has no config.
    """
    if not self.config:
        raise ValueError("Agent config is required")

    # Order matters: skills must exist before their tool/prompt are built,
    # and MCP servers must be connected before the agent references them.
    await self._setup_skills()
    await self.load_mcp_servers()

    skill_tools = self._build_tools_for_skill()
    prompt = self._build_instructions_for_skill()

    # OpenRouter only supports the Chat Completions format, not the
    # Responses API the SDK uses by default, so wrap the model name in a
    # Chat Completions model bound to an explicit client for that endpoint.
    endpoint = os.environ.get("OPENAI_BASE_URL", "")
    model = self.runtime_config.model
    if endpoint and "openrouter.ai" in endpoint:
        client = AsyncOpenAI(
            base_url=endpoint,
            api_key=os.environ.get("OPENAI_API_KEY"),
        )
        model = OpenAIChatCompletionsModel(model=model, openai_client=client)

    temperature = self.runtime_config.temperature
    if temperature is None:
        settings = ModelSettings()
    else:
        settings = ModelSettings(temperature=temperature)

    self.agent = OpenAIAgent(
        name=self.config.name,
        instructions=prompt,
        model=model,
        model_settings=settings,
        mcp_servers=self.mcp_servers,
        tools=skill_tools,
    )
|
|
221
|
+
|
|
222
|
+
def _create_mcp_server(self, server_config: MCPServerConfig) -> Any:
|
|
223
|
+
"""
|
|
224
|
+
Create OpenAI SDK-specific MCP server instance.
|
|
225
|
+
|
|
226
|
+
If mcp_injection is configured in runtime_config, the wrapper classes
|
|
227
|
+
will automatically apply tool description injections.
|
|
228
|
+
|
|
229
|
+
Args:
|
|
230
|
+
server_config: Configuration for the MCP server
|
|
231
|
+
|
|
232
|
+
Returns:
|
|
233
|
+
OpenAI SDK MCP server instance (MCPServerSse, MCPServerStdio, or MCPServerStreamableHttp)
|
|
234
|
+
"""
|
|
235
|
+
# Get tool injections for this server (if any)
|
|
236
|
+
tool_injections = None
|
|
237
|
+
if self.runtime_config.mcp_injection:
|
|
238
|
+
tool_injections = self.runtime_config.mcp_injection.get(server_config.name)
|
|
239
|
+
|
|
240
|
+
if server_config.transport == "sse":
|
|
241
|
+
return MCPServerSse(
|
|
242
|
+
name=server_config.name,
|
|
243
|
+
params={"url": server_config.url},
|
|
244
|
+
cache_tools_list=server_config.cache_tools_list,
|
|
245
|
+
client_session_timeout_seconds=MCP_CLIENT_TIMEOUT_SECONDS,
|
|
246
|
+
tool_injections=tool_injections,
|
|
247
|
+
hook_manager=self.hook_manager,
|
|
248
|
+
)
|
|
249
|
+
elif server_config.transport == "stdio":
|
|
250
|
+
params = {"command": server_config.command}
|
|
251
|
+
if server_config.args:
|
|
252
|
+
params["args"] = server_config.args
|
|
253
|
+
if server_config.env:
|
|
254
|
+
params["env"] = server_config.env
|
|
255
|
+
return MCPServerStdio(
|
|
256
|
+
name=server_config.name,
|
|
257
|
+
params=params,
|
|
258
|
+
cache_tools_list=server_config.cache_tools_list,
|
|
259
|
+
client_session_timeout_seconds=MCP_CLIENT_TIMEOUT_SECONDS,
|
|
260
|
+
tool_injections=tool_injections,
|
|
261
|
+
hook_manager=self.hook_manager,
|
|
262
|
+
)
|
|
263
|
+
elif server_config.transport == "http":
|
|
264
|
+
return MCPServerStreamableHttp(
|
|
265
|
+
name=server_config.name,
|
|
266
|
+
params={"url": server_config.url},
|
|
267
|
+
cache_tools_list=server_config.cache_tools_list,
|
|
268
|
+
client_session_timeout_seconds=MCP_CLIENT_TIMEOUT_SECONDS,
|
|
269
|
+
tool_injections=tool_injections,
|
|
270
|
+
hook_manager=self.hook_manager,
|
|
271
|
+
)
|
|
272
|
+
else:
|
|
273
|
+
raise ValueError(f"Unsupported transport type: {server_config.transport}")
|
|
274
|
+
|
|
275
|
+
async def load_mcp_servers(self) -> List[Any]:
    """
    Create the configured MCP servers and open their connections.

    The base-class implementation creates ``self.mcp_servers`` and tracks their
    names. The OpenAI Agents SDK also needs ``connect()`` to be awaited on each
    server. In debug mode, tool descriptions (with injections applied) are
    cached afterwards.

    Returns:
        The list of connected MCP server instances (``self.mcp_servers``).
    """
    await super().load_mcp_servers()

    # Connect sequentially, in configuration order.
    # NOTE(review): if one connect() raises, servers connected earlier stay
    # open here; confirm the caller's cleanup path closes them.
    for srv in self.mcp_servers:
        await srv.connect()

    if self.runtime_config.debug:
        await self._cache_tool_descriptions()

    return self.mcp_servers
|
|
293
|
+
|
|
294
|
+
async def _cache_tool_descriptions(self) -> None:
|
|
295
|
+
"""Cache tool descriptions from all MCP servers (for debug mode)."""
|
|
296
|
+
for i, server in enumerate(self.mcp_servers):
|
|
297
|
+
server_name = (
|
|
298
|
+
self.mcp_server_names[i]
|
|
299
|
+
if i < len(self.mcp_server_names)
|
|
300
|
+
else "unknown"
|
|
301
|
+
)
|
|
302
|
+
try:
|
|
303
|
+
# list_tools() returns tools with injections already applied
|
|
304
|
+
tools = await server.list_tools()
|
|
305
|
+
for tool in tools:
|
|
306
|
+
key = f"{server_name}:{tool.name}"
|
|
307
|
+
self._tool_descriptions[key] = tool.description or ""
|
|
308
|
+
except Exception as e:
|
|
309
|
+
print(
|
|
310
|
+
f"[DEBUG] Failed to cache tool descriptions from {server_name}: {e}"
|
|
311
|
+
)
|
|
312
|
+
|
|
313
|
+
async def run(
    self,
    user_input: Union[str, List[str]],
    metadata: Optional[Dict[str, Any]] = None,
) -> AgentResult:
    """
    Run the agent with given input. Supports multi-turn conversations.

    Each query is passed to ``Runner.run``. Once a previous result exists,
    whether from an earlier query in this call or from an earlier ``run()``
    since the last ``reset_conversation()``, its ``to_input_list()`` history
    is prepended so conversational context is preserved.

    Args:
        user_input: User instruction/query. Can be:
            - str: Single query (backward compatible)
            - List[str]: Multiple queries processed sequentially with context preserved
        metadata: Optional metadata (task_id, domain, category, instruction).
            Only consulted on the first call of a conversation, i.e. while no
            trace metadata is stored yet. The dict is copied, so the caller's
            object is never modified.

    Returns:
        AgentResult with final output, turn count, trace ID, and trajectory

    Raises:
        RuntimeError: If the agent has not been initialized.
    """
    if not self.agent:
        raise RuntimeError(
            "Agent not initialized. Call initialize() first or use async context manager."
        )

    # Normalize input to a list for uniform processing.
    user_inputs = [user_input] if isinstance(user_input, str) else user_input

    if self._trace_metadata is None:
        # Copy so the keys added below (instruction, tool_descriptions,
        # system_prompt) never leak into the caller's dictionary.
        self._trace_metadata = dict(metadata) if metadata else {}

    # Set the instruction once; multi-query input is joined with a separator.
    if "instruction" not in self._trace_metadata:
        if len(user_inputs) == 1:
            instr = user_inputs[0] or ""
        else:
            instr = " | ".join(user_inputs)
        self._trace_metadata["instruction"] = instr

    # Include debug info in trace metadata if debug mode is enabled.
    if self.runtime_config.debug:
        if self._tool_descriptions:
            self._trace_metadata["tool_descriptions"] = self._tool_descriptions
        if self.agent and hasattr(self.agent, "instructions"):
            self._trace_metadata["system_prompt"] = self.agent.instructions

    # Start a new trace if none is active. The context is entered manually so
    # that it can stay open across several Runner.run calls.
    if self._active_trace is None:
        self._active_trace = trace(
            workflow_name=self._trace_metadata.get("task_id", "Agent Task"),
            metadata=self._trace_metadata,
        )
        self._active_trace.__enter__()
        self._active_trace_id = self._active_trace.trace_id

    result = None

    for query in user_inputs:
        if self._last_result is not None:
            # Prior conversation history followed by the new user message.
            input_items = self._last_result.to_input_list() + [
                {"role": "user", "content": query}
            ]
        else:
            input_items = query  # first turn

        result = await Runner.run(
            starting_agent=self.agent,
            input=input_items,
            max_turns=self.runtime_config.max_turns,
        )

        # Store result so the next turn can build on its history.
        self._last_result = result

        # Record the response after each turn (including intermediate responses).
        if result:
            self.trace_processor.write_agent_response(
                self._active_trace_id, str(result.final_output)
            )

        self._turn_count += 1

    # Close the trace context before generating the trajectory.
    # _active_trace_id is intentionally kept so get_result() can report it.
    if self._active_trace is not None:
        try:
            self._active_trace.__exit__(None, None, None)
        except Exception as e:
            print(f"[WARNING] Failed to close trace context: {e}")
        self._active_trace = None

    self._generate_trajectory(result.final_output if result else None)

    return self.get_result()
|
|
413
|
+
|
|
414
|
+
def _generate_trajectory(self, final_output: Optional[str] = None) -> None:
    """
    Convert the current trace file into trajectories and keep the newest one.

    If ``self.trace_file`` exists, it is passed to
    ``self.trajectory_converter.process_trace_file(..., output_name="trajectory")``.
    When that returns a non-empty sequence, its last element becomes
    ``self._current_trajectory``. Any exception is reported as a printed
    warning instead of being raised.

    Args:
        final_output: Accepted for interface compatibility; not used here.
    """
    try:
        if not os.path.exists(self.trace_file):
            return
        converted = self.trajectory_converter.process_trace_file(
            self.trace_file, output_name="trajectory"
        )
        if converted:
            # The last trajectory is the most recent one.
            self._current_trajectory = converted[-1]
    except Exception as e:
        print(f"[WARNING] Failed to generate trajectory: {e}")
|
|
427
|
+
|
|
428
|
+
def get_result(self) -> Optional[AgentResult]:
    """
    Snapshot the most recent run as an ``AgentResult``.

    Returns:
        ``None`` when ``self._last_result`` is unset (no run yet, or the
        conversation was reset). Otherwise an ``AgentResult`` containing:
        the last run's final output as a string, or ``None`` if that output
        is falsy; the accumulated turn count; the current trajectory; and
        the active trace id, or ``""`` when none is recorded.
    """
    previous = self._last_result
    if previous is None:
        return None

    raw_output = previous.final_output
    return AgentResult(
        final_output=None if not raw_output else str(raw_output),
        turn_count=self._turn_count,
        trajectory=self._current_trajectory,
        trace_id=self._active_trace_id or "",
    )
|
|
447
|
+
|
|
448
|
+
def reset_conversation(self) -> None:
    """
    Reset conversation history for a new session.

    Call this to clear the accumulated context and start fresh.
    Automatically called during cleanup().

    This will:
    - Close the active trace context, if one is still open
    - Clear conversation history, trace metadata, trajectory and turn count
    - Allow a new trace to be started on next run() call

    A failure while closing the trace context is printed as a warning and
    does not prevent the rest of the reset.
    """
    if self._active_trace is not None:
        try:
            self._active_trace.__exit__(None, None, None)
        except Exception as e:
            print(f"[WARNING] Failed to close trace context: {e}")

    # Clear trace bookkeeping unconditionally. run() closes its trace context
    # at the end of each call but keeps _active_trace_id for get_result(),
    # so a stale id can exist even when no context is open.
    self._active_trace = None
    self._active_trace_id = None

    # Clear conversation state.
    self._last_result = None
    self._trace_metadata = None
    self._current_trajectory = None
    self._turn_count = 0
|
|
474
|
+
|
|
475
|
+
async def cleanup(self) -> None:
    """
    Release conversation state and external resources.

    In order, this:
    1. calls ``reset_conversation()``;
    2. calls ``cleanup()`` on every server in ``self.mcp_servers``;
    3. removes the skill temp directory if one is set.

    Server cleanup failures are printed as warnings and do not stop the
    remaining servers from being cleaned up.
    """
    self.reset_conversation()

    # Catches BaseException rather than Exception, so even cancellation-type
    # errors from a server's shutdown are logged instead of propagated.
    # Per the original note, anyio cancel-scope errors are also suppressed
    # inside the MCP wrapper classes.
    for mcp_server in self.mcp_servers:
        try:
            await mcp_server.cleanup()
        except BaseException as exc:
            print(f"[WARNING] Failed to cleanup server: {exc}")

    skill_dir = self._skill_temp_dir
    if skill_dir:
        cleanup_temp_directory(skill_dir)
        self._skill_temp_dir = None
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Any, List, Optional
|
|
3
|
+
|
|
4
|
+
from agents import Agent as OpenAIAgent
|
|
5
|
+
|
|
6
|
+
from dt_arena.src.types.agent import AgentConfig, RuntimeConfig
|
|
7
|
+
|
|
8
|
+
from .agent import OpenAISDKAgent
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class OpenAISDKNativeWrapper(OpenAISDKAgent):
    """
    Wraps a pre-built OpenAI SDK Agent for evaluation.

    Unlike ``OpenAISDKAgent``, this wrapper does not build an agent from
    ``AgentConfig``. It clones a user-supplied ``agents.Agent`` and attaches
    to the clone the red-teaming MCP servers (with tool injections) and any
    skill tools described by the config.

    Overrides:
        - __init__: Accept native agent instead of building from config
        - initialize(): Clone native agent and add red-teaming MCP servers
        - cleanup(): Clean up every MCP server attached to the cloned agent
    """

    def __init__(
        self,
        native_agent: OpenAIAgent,
        agent_config: Optional[AgentConfig] = None,
        runtime_config: Optional[RuntimeConfig] = None,
    ):
        """
        Initialize the wrapper around a native OpenAI SDK Agent.

        Args:
            native_agent: Pre-built agents.Agent instance from user
            agent_config: AgentConfig with red-teaming MCP servers to add.
                Defaults to ``AgentConfig(system_prompt="")`` when omitted.
            runtime_config: RuntimeConfig with model settings, output_dir, mcp_injection
        """
        # Store the native agent before calling parent init
        self._native_agent = native_agent

        # Red-teaming MCP servers added by this wrapper (tracked for cleanup)
        self._redteaming_mcp_servers: List[Any] = []

        # Parent init sets up tracing, directories, trajectory converter
        super().__init__(
            agent_config=agent_config or AgentConfig(system_prompt=""),
            runtime_config=runtime_config,
        )

    async def initialize(self) -> None:
        """
        Initialize the wrapper.

        Clones the native agent, attaches and connects the red-teaming MCP
        servers, and appends any skill tools. In debug mode it also caches
        the red-teaming servers' tool descriptions, as the parent's
        ``load_mcp_servers()`` does, so debug traces include them.

        The native agent's existing MCP servers are assumed to be already
        connected by the user.
        """
        self.agent = self._native_agent.clone()

        # Create a new mcp_servers list to avoid modifying the original.
        # The server objects themselves are still shared with the native agent.
        self.agent.mcp_servers = list(self.agent.mcp_servers or [])

        # Add red-teaming MCP servers from config (i.e., with tool injection)
        await self._add_redteaming_mcp_servers()

        # The parent caches tool descriptions inside load_mcp_servers(), which
        # this wrapper bypasses. Do it here so debug trace metadata is populated.
        # This only covers servers tracked in self.mcp_servers (red-teaming ones).
        if self.runtime_config.debug:
            await self._cache_tool_descriptions()

        # Setup skills if configured
        await self._setup_skills()
        if self._skill_temp_dir:
            tools = self._build_tools()
            if tools:
                existing_tools = list(self.agent.tools or [])
                existing_tools.extend(tools)
                self.agent.tools = existing_tools

    async def _add_redteaming_mcp_servers(self) -> None:
        """
        Add red-teaming MCP servers from agent_config to the agent.

        Tool injections from runtime_config.mcp_injection are applied to these
        servers via the parent class's _create_mcp_server() method. Each
        enabled server is connected. It is then recorded in
        ``self._redteaming_mcp_servers``, ``self.mcp_servers`` and
        ``self.mcp_server_names``, and appended to the cloned agent's
        ``mcp_servers``. A name clash with an existing server only prints a
        warning; the server is still added.
        """
        if not self.config or not self.config.mcp_servers:
            return

        for server_config in self.config.mcp_servers:
            if not server_config.enabled:
                continue

            # Recomputed each iteration so clashes with servers added earlier
            # in this loop are detected too.
            existing_names = {getattr(s, "name", None) for s in (self.agent.mcp_servers or [])}
            if server_config.name in existing_names:
                print(f"[WARNING] MCP server '{server_config.name}' already exists in native agent. "
                      "Red-teaming server will be added with potential name conflict.")

            # Create red-teaming server using parent's method (applies tool injection)
            server = self._create_mcp_server(server_config)

            # OpenAI SDK MCP servers must be connected explicitly
            await server.connect()

            # Track for cleanup
            self._redteaming_mcp_servers.append(server)
            self.mcp_servers.append(server)
            self.mcp_server_names.append(server_config.name)

            # Add to agent's mcp_servers list
            self.agent.mcp_servers.append(server)

    async def cleanup(self) -> None:
        """
        Clean up resources.

        Resets the conversation and removes the skill temp directory. It then
        cleans up all MCP servers on the cloned agent, covering both the
        native agent's servers and the red-teaming servers, and finally
        clears the server bookkeeping lists.
        """
        # Reset conversation state
        self.reset_conversation()

        # Clean up skill temp directory
        if self._skill_temp_dir:
            from utils.skill_helpers import cleanup_temp_directory
            cleanup_temp_directory(self._skill_temp_dir)
            self._skill_temp_dir = None

        # Cleanup all MCP servers on the cloned agent.
        # NOTE(review): the native agent's server objects are shared with the
        # user's original agent (only the list was copied), so they are shut
        # down here too.
        if self.agent is not None:
            for server in (self.agent.mcp_servers or []):
                try:
                    await server.cleanup()
                except BaseException as e:
                    print(f"[WARNING] Failed to cleanup server: {e}")
            self.agent.mcp_servers = []

        self._redteaming_mcp_servers = []
        self.mcp_servers = []
        self.mcp_server_names = []

    @property
    def native_agent(self) -> OpenAIAgent:
        """Access the original (unmodified) native agent."""
        return self._native_agent

    @property
    def wrapped_agent(self) -> OpenAIAgent:
        """Access the cloned and modified agent used for execution."""
        return self.agent
|