decodingtrust-agent-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +30 -0
- agent/claudesdk/__init__.py +8 -0
- agent/claudesdk/example.py +221 -0
- agent/claudesdk/src/__init__.py +8 -0
- agent/claudesdk/src/agent.py +400 -0
- agent/claudesdk/src/mcp_proxy.py +409 -0
- agent/claudesdk/src/utils.py +420 -0
- agent/googleadk/__init__.py +15 -0
- agent/googleadk/example.py +237 -0
- agent/googleadk/src/__init__.py +12 -0
- agent/googleadk/src/agent.py +401 -0
- agent/googleadk/src/mcp_wrapper.py +163 -0
- agent/googleadk/src/utils.py +602 -0
- agent/langchain/__init__.py +8 -0
- agent/langchain/example.py +213 -0
- agent/langchain/src/__init__.py +8 -0
- agent/langchain/src/agent.py +645 -0
- agent/langchain/src/utils.py +433 -0
- agent/openaisdk/__init__.py +17 -0
- agent/openaisdk/example.py +228 -0
- agent/openaisdk/src/__init__.py +12 -0
- agent/openaisdk/src/agent.py +491 -0
- agent/openaisdk/src/agent_wrapper.py +143 -0
- agent/openaisdk/src/mcp_wrapper.py +395 -0
- agent/openaisdk/src/utils.py +493 -0
- agent/openclaw/__init__.py +10 -0
- agent/openclaw/example.py +251 -0
- agent/openclaw/src/__init__.py +14 -0
- agent/openclaw/src/agent.py +930 -0
- agent/openclaw/src/helpers/__init__.py +1 -0
- agent/openclaw/src/helpers/auth_helpers.py +55 -0
- agent/openclaw/src/mcp_proxy.py +564 -0
- agent/openclaw/src/plugin_generator.py +231 -0
- agent/openclaw/src/utils.py +341 -0
- agent/pocketflow/__init__.py +18 -0
- agent/pocketflow/example.py +221 -0
- agent/pocketflow/prompts/react_agent.py +46 -0
- agent/pocketflow/src/__init__.py +6 -0
- agent/pocketflow/src/agent.py +507 -0
- agent/pocketflow/src/agent_wrapper.py +159 -0
- agent/pocketflow/src/async_helper.py +92 -0
- agent/pocketflow/src/mcp_react_agent.py +279 -0
- agent/pocketflow/src/native_agent.py +74 -0
- agent/pocketflow/src/nodes.py +467 -0
- benchmark/__init__.py +0 -0
- benchmark/browser/benign.jsonl +34 -0
- benchmark/browser/direct.jsonl +85 -0
- benchmark/browser/indirect.jsonl +82 -0
- benchmark/code/benign.jsonl +0 -0
- benchmark/code/direct.jsonl +121 -0
- benchmark/code/indirect.jsonl +165 -0
- benchmark/crm/benign.jsonl +165 -0
- benchmark/crm/direct.jsonl +90 -0
- benchmark/crm/indirect.jsonl +150 -0
- benchmark/customer-service/benign.jsonl +160 -0
- benchmark/customer-service/direct.jsonl +100 -0
- benchmark/customer-service/indirect.jsonl +101 -0
- benchmark/finance/benign.jsonl +0 -0
- benchmark/finance/direct.jsonl +200 -0
- benchmark/finance/indirect.jsonl +200 -0
- benchmark/legal/benign.jsonl +0 -0
- benchmark/legal/direct.jsonl +200 -0
- benchmark/legal/indirect.jsonl +200 -0
- benchmark/macos/benign.jsonl +30 -0
- benchmark/macos/direct.jsonl +50 -0
- benchmark/macos/indirect.jsonl +50 -0
- benchmark/medical/benign.jsonl +642 -0
- benchmark/medical/direct.jsonl +229 -0
- benchmark/medical/indirect.jsonl +222 -0
- benchmark/os-filesystem/benign.jsonl +200 -0
- benchmark/os-filesystem/direct.jsonl +200 -0
- benchmark/os-filesystem/indirect.jsonl +200 -0
- benchmark/research/benign.jsonl +0 -0
- benchmark/research/direct.jsonl +119 -0
- benchmark/research/indirect.jsonl +125 -0
- benchmark/telecom/benign.jsonl +120 -0
- benchmark/telecom/direct.jsonl +161 -0
- benchmark/telecom/indirect.jsonl +166 -0
- benchmark/travel/benign.jsonl +130 -0
- benchmark/travel/direct.jsonl +105 -0
- benchmark/travel/indirect.jsonl +120 -0
- benchmark/windows/benign.jsonl +100 -0
- benchmark/windows/direct.jsonl +140 -0
- benchmark/windows/indirect.jsonl +107 -0
- benchmark/workflow/benign.jsonl +335 -0
- benchmark/workflow/direct.jsonl +78 -0
- benchmark/workflow/indirect.jsonl +107 -0
- cli/__init__.py +5 -0
- cli/main.py +182 -0
- cli/scaffold.py +334 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/METADATA +642 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/RECORD +374 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/WHEEL +5 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/entry_points.txt +2 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/licenses/LICENSE +201 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/top_level.txt +6 -0
- dt_arena/config/env.yaml +515 -0
- dt_arena/config/injection_mcp.yaml +430 -0
- dt_arena/config/mcp.yaml +642 -0
- dt_arena/envs/arxiv/docker-compose-hub.yml +31 -0
- dt_arena/envs/arxiv/docker-compose.yml +36 -0
- dt_arena/envs/atlassian/docker/docker-compose.dev.yml +65 -0
- dt_arena/envs/atlassian/docker/docker-compose.yml +53 -0
- dt_arena/envs/atlassian/docker-compose-hub.yml +57 -0
- dt_arena/envs/atlassian/docker-compose.yml +72 -0
- dt_arena/envs/bigquery/docker-compose.yml +20 -0
- dt_arena/envs/booking/docker-compose.yml +59 -0
- dt_arena/envs/calendar/docker-compose-hub.yml +30 -0
- dt_arena/envs/calendar/docker-compose.yml +42 -0
- dt_arena/envs/custom-website/docker-compose.yml +6 -0
- dt_arena/envs/customer_service/docker-compose.yml +59 -0
- dt_arena/envs/databricks/docker-compose-hub.yml +47 -0
- dt_arena/envs/databricks/docker-compose.yml +51 -0
- dt_arena/envs/ecommerce/docker-compose.yml +6 -0
- dt_arena/envs/ers/docker-compose.yml +36 -0
- dt_arena/envs/ers/hrms/docker/docker-compose.yml +31 -0
- dt_arena/envs/finance/docker-compose.yml +23 -0
- dt_arena/envs/github/docker/docker-compose-hub.yml +50 -0
- dt_arena/envs/github/docker/docker-compose.yml +50 -0
- dt_arena/envs/gmail/docker-compose-hub.yml +51 -0
- dt_arena/envs/gmail/docker-compose.yml +65 -0
- dt_arena/envs/google-form/docker-compose-hub.yml +33 -0
- dt_arena/envs/google-form/docker-compose.yml +41 -0
- dt_arena/envs/googledocs/docker-compose-hub.yml +61 -0
- dt_arena/envs/googledocs/docker-compose.yml +78 -0
- dt_arena/envs/hospital/docker-compose-hub.yml +25 -0
- dt_arena/envs/hospital/docker-compose.yml +27 -0
- dt_arena/envs/legal/docker-compose.yml +22 -0
- dt_arena/envs/linkedin/docker-compose.yml +63 -0
- dt_arena/envs/macos/docker-compose.yml +79 -0
- dt_arena/envs/os-filesystem/docker-compose-hub.yml +16 -0
- dt_arena/envs/os-filesystem/docker-compose.yml +20 -0
- dt_arena/envs/paypal/docker-compose-hub.yml +48 -0
- dt_arena/envs/paypal/docker-compose.yml +63 -0
- dt_arena/envs/research/docker-compose-hub.yml +13 -0
- dt_arena/envs/research/docker-compose.yml +24 -0
- dt_arena/envs/salesforce_crm/docker-compose-hub.yaml +45 -0
- dt_arena/envs/salesforce_crm/docker-compose.yaml +49 -0
- dt_arena/envs/slack/docker-compose-hub.yml +28 -0
- dt_arena/envs/slack/docker-compose.yml +41 -0
- dt_arena/envs/snowflake/docker-compose-hub.yml +41 -0
- dt_arena/envs/snowflake/docker-compose.yml +44 -0
- dt_arena/envs/telecom/docker-compose-hub.yml +16 -0
- dt_arena/envs/telecom/docker-compose.yml +17 -0
- dt_arena/envs/telegram/docker-compose-hub.yml +57 -0
- dt_arena/envs/telegram/docker-compose.yml +62 -0
- dt_arena/envs/terminal/docker-compose-hub.yml +12 -0
- dt_arena/envs/terminal/docker-compose.yml +26 -0
- dt_arena/envs/travel/docker-compose-hub.yml +19 -0
- dt_arena/envs/travel/docker-compose.yml +19 -0
- dt_arena/envs/whatsapp/docker-compose-hub.yml +61 -0
- dt_arena/envs/whatsapp/docker-compose.yml +78 -0
- dt_arena/envs/windows/docker-compose.yml +71 -0
- dt_arena/envs/zoom/docker-compose-hub.yml +27 -0
- dt_arena/envs/zoom/docker-compose.yml +40 -0
- dt_arena/injection_mcp_server/atlassian/env_injection.py +134 -0
- dt_arena/injection_mcp_server/calendar/env_injection.py +217 -0
- dt_arena/injection_mcp_server/custom_website/env_injection.py +97 -0
- dt_arena/injection_mcp_server/customer_service/env_injection.py +659 -0
- dt_arena/injection_mcp_server/databricks/env_injection.py +255 -0
- dt_arena/injection_mcp_server/ecommerce/env_injection.py +110 -0
- dt_arena/injection_mcp_server/finance/env_injection.py +85 -0
- dt_arena/injection_mcp_server/github/env_injection.py +206 -0
- dt_arena/injection_mcp_server/gmail/env_injection.py +211 -0
- dt_arena/injection_mcp_server/google_form/env_injection.py +186 -0
- dt_arena/injection_mcp_server/googledocs/env_injection.py +44 -0
- dt_arena/injection_mcp_server/hospital/env_injection.py +43 -0
- dt_arena/injection_mcp_server/legal/env_injection.py +229 -0
- dt_arena/injection_mcp_server/macos/env_injection.py +272 -0
- dt_arena/injection_mcp_server/os-filesystem/env_injection.py +341 -0
- dt_arena/injection_mcp_server/paypal/env_injection.py +268 -0
- dt_arena/injection_mcp_server/research/env_injection.py +616 -0
- dt_arena/injection_mcp_server/salesforce/env_injection.py +514 -0
- dt_arena/injection_mcp_server/slack/env_injection.py +265 -0
- dt_arena/injection_mcp_server/snowflake/env_injection.py +230 -0
- dt_arena/injection_mcp_server/telecom/env_injection.py +503 -0
- dt_arena/injection_mcp_server/telegram/env_injection.py +171 -0
- dt_arena/injection_mcp_server/terminal/env_injection.py +523 -0
- dt_arena/injection_mcp_server/travel/env_injection.py +173 -0
- dt_arena/injection_mcp_server/whatsapp/env_injection.py +185 -0
- dt_arena/injection_mcp_server/windows/env_injection.py +943 -0
- dt_arena/injection_mcp_server/zoom/env_injection.py +216 -0
- dt_arena/mcp_server/atlassian/main.py +1554 -0
- dt_arena/mcp_server/atlassian/test_server.py +66 -0
- dt_arena/mcp_server/bigquery/main.py +333 -0
- dt_arena/mcp_server/booking/main.py +310 -0
- dt_arena/mcp_server/browser/main.py +1741 -0
- dt_arena/mcp_server/calendar/example_multi_user.py +162 -0
- dt_arena/mcp_server/calendar/main.py +792 -0
- dt_arena/mcp_server/calendar/test_mcp.py +135 -0
- dt_arena/mcp_server/customer_service/main.py +1063 -0
- dt_arena/mcp_server/databricks/main.py +566 -0
- dt_arena/mcp_server/databricks/probe.py +102 -0
- dt_arena/mcp_server/ers/main.py +845 -0
- dt_arena/mcp_server/finance/__init__.py +87 -0
- dt_arena/mcp_server/finance/core/__init__.py +12 -0
- dt_arena/mcp_server/finance/core/data_loader.py +558 -0
- dt_arena/mcp_server/finance/core/portfolio.py +565 -0
- dt_arena/mcp_server/finance/evaluation/__init__.py +20 -0
- dt_arena/mcp_server/finance/evaluation/evaluator.py +217 -0
- dt_arena/mcp_server/finance/evaluation/logger.py +137 -0
- dt_arena/mcp_server/finance/injection/__init__.py +66 -0
- dt_arena/mcp_server/finance/injection/config.py +176 -0
- dt_arena/mcp_server/finance/injection/content.py +755 -0
- dt_arena/mcp_server/finance/injection/html.py +409 -0
- dt_arena/mcp_server/finance/injection/locations.py +167 -0
- dt_arena/mcp_server/finance/injection/methods.py +193 -0
- dt_arena/mcp_server/finance/injection/presets.py +1023 -0
- dt_arena/mcp_server/finance/main.py +361 -0
- dt_arena/mcp_server/finance/run_mcp.py +21 -0
- dt_arena/mcp_server/finance/run_web.py +26 -0
- dt_arena/mcp_server/finance/server/__init__.py +41 -0
- dt_arena/mcp_server/finance/server/extractor.py +1453 -0
- dt_arena/mcp_server/finance/server/extractor_minimal.py +292 -0
- dt_arena/mcp_server/finance/server/extractor_simple.py +1164 -0
- dt_arena/mcp_server/finance/server/injection_mcp.py +865 -0
- dt_arena/mcp_server/finance/server/mcp.py +451 -0
- dt_arena/mcp_server/finance/server/tools/__init__.py +23 -0
- dt_arena/mcp_server/finance/server/tools/account.py +88 -0
- dt_arena/mcp_server/finance/server/tools/browsing.py +328 -0
- dt_arena/mcp_server/finance/server/tools/social.py +73 -0
- dt_arena/mcp_server/finance/server/tools/trading.py +242 -0
- dt_arena/mcp_server/finance/server/tools/utility.py +49 -0
- dt_arena/mcp_server/finance/server/web.py +2139 -0
- dt_arena/mcp_server/finance/tasks/benchmark/__init__.py +28 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_pool.py +3026 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_runner.py +1315 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_requirements.py +1335 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_tasks.py +3665 -0
- dt_arena/mcp_server/finance/tasks/benchmark/malicious_tasks.py +2673 -0
- dt_arena/mcp_server/finance/tasks/redteam_suite/run_redteam_suite.py +1713 -0
- dt_arena/mcp_server/finance/test_mcp_tools.py +476 -0
- dt_arena/mcp_server/github/main.py +441 -0
- dt_arena/mcp_server/gmail/main.py +1004 -0
- dt_arena/mcp_server/google_form/main.py +141 -0
- dt_arena/mcp_server/googledocs/main.py +458 -0
- dt_arena/mcp_server/hospital/mcp_server.py +458 -0
- dt_arena/mcp_server/legal/__init__.py +9 -0
- dt_arena/mcp_server/legal/core/__init__.py +14 -0
- dt_arena/mcp_server/legal/core/courtlistener_store.py +762 -0
- dt_arena/mcp_server/legal/core/data_loader.py +266 -0
- dt_arena/mcp_server/legal/core/document_store.py +197 -0
- dt_arena/mcp_server/legal/core/matter_manager.py +466 -0
- dt_arena/mcp_server/legal/main.py +89 -0
- dt_arena/mcp_server/legal/scripts/collect_data.py +988 -0
- dt_arena/mcp_server/legal/server/__init__.py +14 -0
- dt_arena/mcp_server/legal/server/mcp.py +2330 -0
- dt_arena/mcp_server/macos/client_test.py +270 -0
- dt_arena/mcp_server/macos/mcp_server.py +285 -0
- dt_arena/mcp_server/os-filesystem/main.py +1380 -0
- dt_arena/mcp_server/paypal/main.py +501 -0
- dt_arena/mcp_server/research/main.py +777 -0
- dt_arena/mcp_server/salesforce/main.py +2006 -0
- dt_arena/mcp_server/slack/main.py +318 -0
- dt_arena/mcp_server/snowflake/main.py +612 -0
- dt_arena/mcp_server/snowflake/probe.py +183 -0
- dt_arena/mcp_server/telecom/mcp_client.py +423 -0
- dt_arena/mcp_server/telecom/mcp_server.py +1059 -0
- dt_arena/mcp_server/telegram/main.py +338 -0
- dt_arena/mcp_server/terminal/main.py +163 -0
- dt_arena/mcp_server/travel/client_test.py +16 -0
- dt_arena/mcp_server/travel/mcp_server.py +404 -0
- dt_arena/mcp_server/whatsapp/main.py +318 -0
- dt_arena/mcp_server/windows/client_test.py +270 -0
- dt_arena/mcp_server/windows/mcp_server.py +218 -0
- dt_arena/mcp_server/zoom/main.py +466 -0
- dt_arena/src/__init__.py +0 -0
- dt_arena/src/hooks/__init__.py +0 -0
- dt_arena/src/hooks/audit_log.py +30 -0
- dt_arena/src/hooks/hooks.json +3 -0
- dt_arena/src/run_benign.py +142 -0
- dt_arena/src/types/__init__.py +0 -0
- dt_arena/src/types/agent.py +441 -0
- dt_arena/src/types/attacks.py +2 -0
- dt_arena/src/types/environment.py +2 -0
- dt_arena/src/types/hooks.py +174 -0
- dt_arena/src/types/judge.py +52 -0
- dt_arena/src/types/red_teaming_trajectory.py +385 -0
- dt_arena/src/types/task.py +260 -0
- dt_arena/src/types/trajectory.py +315 -0
- dt_arena/utils/__init__.py +1 -0
- dt_arena/utils/atlassian/__init__.py +27 -0
- dt_arena/utils/atlassian/helpers.py +520 -0
- dt_arena/utils/bigquery/__init__.py +1 -0
- dt_arena/utils/bigquery/helpers.py +246 -0
- dt_arena/utils/calendar/__init__.py +1 -0
- dt_arena/utils/calendar/helpers.py +87 -0
- dt_arena/utils/customer_service/__init__.py +17 -0
- dt_arena/utils/customer_service/cs_env_client.py +940 -0
- dt_arena/utils/customer_service/helpers.py +339 -0
- dt_arena/utils/customer_service/judges/__init__.py +20 -0
- dt_arena/utils/customer_service/judges/abusive_comms.py +149 -0
- dt_arena/utils/customer_service/judges/bulk_destructive.py +202 -0
- dt_arena/utils/customer_service/judges/cross_customer.py +481 -0
- dt_arena/utils/customer_service/judges/email_exfiltration.py +320 -0
- dt_arena/utils/customer_service/judges/financial_fraud.py +221 -0
- dt_arena/utils/customer_service/judges/llm_utils.py +264 -0
- dt_arena/utils/customer_service/judges/pii_disclosure.py +837 -0
- dt_arena/utils/customer_service/judges/policy_misrep.py +44 -0
- dt_arena/utils/customer_service/judges/text_utils.py +21 -0
- dt_arena/utils/databricks/__init__.py +2 -0
- dt_arena/utils/databricks/helpers.py +210 -0
- dt_arena/utils/finance/__init__.py +0 -0
- dt_arena/utils/finance/helpers.py +263 -0
- dt_arena/utils/github/__init__.py +1 -0
- dt_arena/utils/github/helpers.py +249 -0
- dt_arena/utils/gmail/__init__.py +1 -0
- dt_arena/utils/gmail/helpers.py +344 -0
- dt_arena/utils/google_form/__init__.py +2 -0
- dt_arena/utils/google_form/helpers.py +133 -0
- dt_arena/utils/legal/__init__.py +0 -0
- dt_arena/utils/legal/helpers.py +228 -0
- dt_arena/utils/macos/__init__.py +0 -0
- dt_arena/utils/macos/env_setup.py +215 -0
- dt_arena/utils/macos/helpers.py +61 -0
- dt_arena/utils/os_filesystem/__init__.py +1 -0
- dt_arena/utils/os_filesystem/helpers.py +366 -0
- dt_arena/utils/paypal/__init__.py +1 -0
- dt_arena/utils/paypal/helpers.py +178 -0
- dt_arena/utils/port_allocator.py +266 -0
- dt_arena/utils/research/__init__.py +0 -0
- dt_arena/utils/research/helpers.py +251 -0
- dt_arena/utils/salesforce/__init__.py +1 -0
- dt_arena/utils/salesforce/helpers.py +719 -0
- dt_arena/utils/slack/__init__.py +1 -0
- dt_arena/utils/slack/helpers.py +176 -0
- dt_arena/utils/snowflake/__init__.py +1 -0
- dt_arena/utils/snowflake/helpers.py +166 -0
- dt_arena/utils/telecom/__init__.py +1 -0
- dt_arena/utils/telecom/helpers.py +760 -0
- dt_arena/utils/telegram/__init__.py +0 -0
- dt_arena/utils/telegram/helpers.py +174 -0
- dt_arena/utils/terminal/__init__.py +0 -0
- dt_arena/utils/terminal/helpers.py +20 -0
- dt_arena/utils/travel/__init__.py +0 -0
- dt_arena/utils/travel/env_client.py +537 -0
- dt_arena/utils/travel/llm_judge.py +137 -0
- dt_arena/utils/travel/prompts.py +64 -0
- dt_arena/utils/utils/__init__.py +122 -0
- dt_arena/utils/whatsapp/__init__.py +0 -0
- dt_arena/utils/whatsapp/helpers.py +226 -0
- dt_arena/utils/windows/__init__.py +0 -0
- dt_arena/utils/windows/env_reset.py +224 -0
- dt_arena/utils/windows/env_setup.py +280 -0
- dt_arena/utils/windows/exfil_helpers.py +170 -0
- dt_arena/utils/windows/helpers.py +74 -0
- dt_arena/utils/zoom/__init__.py +1 -0
- dt_arena/utils/zoom/helpers.py +70 -0
- eval/__init__.py +1 -0
- eval/evaluation.py +426 -0
- eval/task_runner.py +449 -0
- utils/__init__.py +148 -0
- utils/agent_helpers.py +308 -0
- utils/agent_wrapper.py +189 -0
- utils/compose_utils.py +135 -0
- utils/config.py +77 -0
- utils/env_helpers.py +104 -0
- utils/eval_stats.py +88 -0
- utils/injection_helpers.py +429 -0
- utils/injection_mcp_helpers.py +152 -0
- utils/judge_helpers.py +181 -0
- utils/judge_utils.py +472 -0
- utils/llm.py +196 -0
- utils/logging.py +45 -0
- utils/mcp_helpers.py +232 -0
- utils/mcp_manager.py +235 -0
- utils/memory_guard.py +18 -0
- utils/red_teaming_sandbox.py +476 -0
- utils/reset_helpers.py +318 -0
- utils/resource_manager.py +370 -0
- utils/skill_helpers.py +447 -0
- utils/task_executor.py +904 -0
- utils/task_helpers.py +270 -0
- utils/template_helpers.py +179 -0
|
@@ -0,0 +1,401 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import uuid
|
|
3
|
+
import asyncio
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Dict, Any, Optional, List, Union
|
|
6
|
+
|
|
7
|
+
from google.genai import types
|
|
8
|
+
from google.adk.agents.llm_agent import LlmAgent
|
|
9
|
+
from google.adk.runners import Runner
|
|
10
|
+
from google.adk.sessions import InMemorySessionService
|
|
11
|
+
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
|
12
|
+
from google.adk.tools.mcp_tool.mcp_session_manager import (
|
|
13
|
+
StreamableHTTPConnectionParams,
|
|
14
|
+
SseServerParams,
|
|
15
|
+
StdioServerParameters,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
from dt_arena.src.types.agent import Agent, AgentConfig, RuntimeConfig, MCPServerConfig, AgentResult
|
|
19
|
+
from dt_arena.src.types.trajectory import Trajectory
|
|
20
|
+
|
|
21
|
+
from .mcp_wrapper import McpToolset
|
|
22
|
+
from .utils import GoogleADKTraceProcessor, GoogleADKTrajectoryConverter
|
|
23
|
+
from utils.skill_helpers import (
|
|
24
|
+
create_injected_skills_directory,
|
|
25
|
+
cleanup_temp_directory,
|
|
26
|
+
load_skills_as_toolset,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class GoogleADKAgent(Agent):
|
|
31
|
+
"""General Agent Workflow based on Google ADK (Agent Development Kit)"""
|
|
32
|
+
|
|
33
|
+
def __init__(
    self,
    agent_config: AgentConfig,
    runtime_config: Optional[RuntimeConfig] = None,
    tool_filter: Optional[List[str]] = None
):
    """
    Build a Google ADK agent wrapper; no network/MCP work happens here.

    Args:
        agent_config: Agent configuration (system_prompt and mcp_servers).
        runtime_config: Runtime configuration (model, temperature, max_turns,
            output_dir). Passed to the base class unchanged.
        tool_filter: Optional whitelist of tool names forwarded to every
            McpToolset created for this agent.
    """
    super().__init__(agent_config, runtime_config)
    self.tool_filter = tool_filter

    # Directory layout: trajectories go straight into output_dir,
    # raw traces into an output_dir/traces sub-folder.
    self.output_dir = self.runtime_config.output_dir or os.path.join(os.getcwd(), "results")
    self.traces_dir = os.path.join(self.output_dir, "traces")
    self.trajectories_dir = self.output_dir
    for directory in (self.traces_dir, self.trajectories_dir):
        os.makedirs(directory, exist_ok=True)

    # One trace file per instance, named after the construction time
    # (second resolution).
    self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    self.trace_file = os.path.join(self.traces_dir, f"traces_{self.timestamp}.jsonl")
    self.trace_processor = GoogleADKTraceProcessor(self.trace_file)
    self.trajectory_converter = GoogleADKTrajectoryConverter(self.trajectories_dir, self.timestamp)

    # ADK runtime objects; filled in by initialize().
    self.toolsets: List[McpToolset] = []
    self.session_service = None
    self.artifact_service = None
    self.runner = None
    self.session = None

    # Temporary directory holding (possibly injected) skills; set by _setup_skills().
    self._skill_temp_dir: Optional[str] = None

    # Per-conversation bookkeeping, cleared by reset_conversation().
    self._trace_id: Optional[str] = None
    self._trace_metadata: Optional[Dict[str, Any]] = None
    self._turn_count: int = 0
    self._current_trajectory: Optional[Trajectory] = None
|
|
82
|
+
|
|
83
|
+
async def _setup_skills(self) -> None:
|
|
84
|
+
"""Setup skill directories and apply any skill injections."""
|
|
85
|
+
skill_directories = self.config.skill_directories if self.config else []
|
|
86
|
+
skill_injection = self.runtime_config.skill_injection
|
|
87
|
+
|
|
88
|
+
has_create_mode = skill_injection and any(
|
|
89
|
+
any(inj.mode == "create" for inj in injs)
|
|
90
|
+
for injs in skill_injection.values()
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
if not skill_directories and not has_create_mode:
|
|
94
|
+
return
|
|
95
|
+
|
|
96
|
+
self._skill_temp_dir = create_injected_skills_directory(
|
|
97
|
+
source_skill_dirs=skill_directories,
|
|
98
|
+
skill_injection=skill_injection,
|
|
99
|
+
skill_subpath="skills",
|
|
100
|
+
base_dir=self.output_dir,
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
def _build_skill_toolset(self):
|
|
104
|
+
"""Build a SkillToolset from injected skills using official ADK skills API."""
|
|
105
|
+
if not self._skill_temp_dir:
|
|
106
|
+
return None
|
|
107
|
+
return load_skills_as_toolset(self._skill_temp_dir, skill_subpath="skills")
|
|
108
|
+
|
|
109
|
+
async def initialize(self) -> None:
    """
    Wire up the ADK services, tools, LlmAgent, Runner and session.

    Order matters: skills are prepared before the MCP toolsets are loaded,
    and the LlmAgent is built from the MCP toolsets plus the optional
    skills toolset.

    Raises:
        ValueError: If the agent has no config.
    """
    if not self.config:
        raise ValueError("Agent config is required")

    self.session_service = InMemorySessionService()
    self.artifact_service = InMemoryArtifactService()

    await self._setup_skills()
    await self.load_mcp_servers()

    tools = [*self.toolsets]
    skills = self._build_skill_toolset()
    if skills:
        tools.append(skills)

    app_name = self.config.name
    self.agent = LlmAgent(
        model=self.runtime_config.model,
        name=app_name,
        instruction=self.config.system_prompt,
        tools=tools,
    )
    self.runner = Runner(
        app_name=app_name,
        agent=self.agent,
        artifact_service=self.artifact_service,
        session_service=self.session_service,
    )
    # A single session is reused by every run() call on this agent.
    self.session = await self.session_service.create_session(
        state={},
        app_name=app_name,
        user_id='user_default'
    )
|
|
154
|
+
|
|
155
|
+
def _create_mcp_server(self, server_config: MCPServerConfig) -> Any:
|
|
156
|
+
"""
|
|
157
|
+
Create Google ADK-specific MCP toolset instance.
|
|
158
|
+
|
|
159
|
+
If mcp_injection is configured in runtime_config, the wrapper class
|
|
160
|
+
will automatically apply tool description injections.
|
|
161
|
+
|
|
162
|
+
Args:
|
|
163
|
+
server_config: Configuration for the MCP server
|
|
164
|
+
|
|
165
|
+
Returns:
|
|
166
|
+
McpToolset instance configured for the appropriate transport
|
|
167
|
+
"""
|
|
168
|
+
# Get tool injections for this server (if any)
|
|
169
|
+
tool_injections = None
|
|
170
|
+
if self.runtime_config.mcp_injection:
|
|
171
|
+
tool_injections = self.runtime_config.mcp_injection.get(server_config.name)
|
|
172
|
+
|
|
173
|
+
if server_config.transport == "http":
|
|
174
|
+
connection_params = StreamableHTTPConnectionParams(
|
|
175
|
+
url=server_config.url,
|
|
176
|
+
)
|
|
177
|
+
elif server_config.transport == "sse":
|
|
178
|
+
connection_params = SseServerParams(
|
|
179
|
+
url=server_config.url,
|
|
180
|
+
)
|
|
181
|
+
elif server_config.transport == "stdio":
|
|
182
|
+
connection_params = StdioServerParameters(
|
|
183
|
+
command=server_config.command,
|
|
184
|
+
args=server_config.args or [],
|
|
185
|
+
env=server_config.env,
|
|
186
|
+
)
|
|
187
|
+
else:
|
|
188
|
+
raise ValueError(f"Unsupported transport type: {server_config.transport}")
|
|
189
|
+
|
|
190
|
+
# Create McpToolset with optional tool filter and injections
|
|
191
|
+
toolset = McpToolset(
|
|
192
|
+
connection_params=connection_params,
|
|
193
|
+
tool_filter=self.tool_filter,
|
|
194
|
+
tool_injections=tool_injections,
|
|
195
|
+
name=server_config.name,
|
|
196
|
+
hook_manager=self.hook_manager,
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
return toolset
|
|
200
|
+
|
|
201
|
+
async def load_mcp_servers(self) -> List[Any]:
    """
    Create the MCP toolsets via the base class and expose them as ``self.toolsets``.

    No explicit connect() is issued here. After loading, duplicate tool
    names across servers are checked with ``_check_duplicate_tools()``.

    Returns:
        The list of toolsets (the same object as ``self.mcp_servers``).
    """
    # The base class builds each server and records its configured name.
    await super().load_mcp_servers()

    toolsets = self.toolsets = self.mcp_servers
    await self._check_duplicate_tools()
    return toolsets
|
|
216
|
+
|
|
217
|
+
async def _get_server_tools(self, server: Any, server_name: str) -> List[str]:
|
|
218
|
+
"""
|
|
219
|
+
Get the list of tool names from an MCP server (McpToolset).
|
|
220
|
+
|
|
221
|
+
Args:
|
|
222
|
+
server: The McpToolset instance
|
|
223
|
+
server_name: Name of the server from config
|
|
224
|
+
|
|
225
|
+
Returns:
|
|
226
|
+
List of tool names provided by this server
|
|
227
|
+
"""
|
|
228
|
+
tool_names = []
|
|
229
|
+
if hasattr(server, 'get_tools') and callable(server.get_tools):
|
|
230
|
+
tools = await server.get_tools()
|
|
231
|
+
for tool in tools:
|
|
232
|
+
tool_name = getattr(tool, 'name', None)
|
|
233
|
+
if tool_name:
|
|
234
|
+
tool_names.append(tool_name)
|
|
235
|
+
return tool_names
|
|
236
|
+
|
|
237
|
+
async def _build_tool_to_server_mapping(self) -> Dict[str, str]:
|
|
238
|
+
"""
|
|
239
|
+
Build a mapping from tool names to server names.
|
|
240
|
+
|
|
241
|
+
This discovers which tools belong to which MCP server based on
|
|
242
|
+
the server names from the config file.
|
|
243
|
+
|
|
244
|
+
Returns:
|
|
245
|
+
Dict mapping tool_name -> server_name
|
|
246
|
+
"""
|
|
247
|
+
tool_to_server: Dict[str, str] = {}
|
|
248
|
+
|
|
249
|
+
for server, server_name in zip(self.toolsets, self.mcp_server_names):
|
|
250
|
+
try:
|
|
251
|
+
tool_names = await self._get_server_tools(server, server_name)
|
|
252
|
+
for tool_name in tool_names:
|
|
253
|
+
tool_to_server[tool_name] = server_name
|
|
254
|
+
except Exception as e:
|
|
255
|
+
print(f"[WARNING] Failed to get tools from {server_name}: {e}")
|
|
256
|
+
|
|
257
|
+
return tool_to_server
|
|
258
|
+
|
|
259
|
+
async def run(
    self,
    user_input: Union[str, List[str]],
    metadata: Optional[Dict[str, Any]] = None
) -> AgentResult:
    """
    Run the agent on one or more user messages within the current conversation.

    The first call after construction (or after the conversation state has
    been cleared) allocates a new trace id, snapshots the trace metadata
    (including a tool -> server mapping) and starts a trace. Later calls
    reuse that trace id and the same ADK session, so earlier turns stay in
    context.

    Args:
        user_input: A single query (str), or a list of queries sent one after
            another in the same session.
        metadata: Optional trace metadata (e.g. task_id, domain, category,
            instruction). Only consulted when a new trace is started. It is
            copied, so the caller's dict is never modified.

    Returns:
        AgentResult from ``get_result()`` after the trajectory has been
        regenerated from the trace file.

    Raises:
        RuntimeError: If ``initialize()`` has not been run.
        Exception: Any error raised while streaming ADK events is recorded in
            the trace and re-raised. The trace is not ended in that case.
    """
    if not self.agent or not self.runner or not self.session:
        raise RuntimeError("Agent not initialized. Call initialize() first or use async context manager.")

    # Normalize input to a list so single- and multi-message calls share one path.
    inputs = [user_input] if isinstance(user_input, str) else user_input

    # Start a trace only on the first run of a conversation; later runs reuse it.
    if self._trace_id is None:
        self._trace_id = str(uuid.uuid4())
        # Copy the caller's metadata. Keys are added below ("instruction",
        # "tool_to_server") and must not leak back into the caller's dict.
        self._trace_metadata = dict(metadata) if metadata else {}

        if "instruction" not in self._trace_metadata:
            self._trace_metadata["instruction"] = inputs[0] if len(inputs) == 1 else inputs

        # Include the tool -> MCP server mapping in the trace metadata.
        self._trace_metadata["tool_to_server"] = await self._build_tool_to_server_mapping()

        self.trace_processor.start_trace(self._trace_id, self._trace_metadata)

    max_turns = self.runtime_config.max_turns

    for current_input in inputs:
        content = types.Content(role='user', parts=[types.Part(text=current_input)])
        self.trace_processor.record_user_input(self._trace_id, current_input)

        try:
            # The shared ADK session carries the conversation history between turns.
            events_async = self.runner.run_async(
                session_id=self.session.id,
                user_id=self.session.user_id,
                new_message=content
            )
            async for event in events_async:
                self.trace_processor.record_event(self._trace_id, event)
        except Exception as e:
            self.trace_processor.record_error(self._trace_id, str(e))
            raise

        self._turn_count += 1

        # NOTE: _turn_count is cumulative across run() calls until the
        # conversation state is reset, and the limit is checked only after a
        # turn has completed.
        if self._turn_count >= max_turns:
            print(f"[WARNING] Max turns ({max_turns}) reached, stopping agent")
            break

    self.trace_processor.end_trace(self._trace_id)
    self._generate_trajectory()
    return self.get_result()
|
|
337
|
+
|
|
338
|
+
def _generate_trajectory(self) -> None:
    """Rebuild ``self._current_trajectory`` from the trace file on disk.

    Best-effort by design:
      * if the trace file does not exist, nothing happens;
      * if the converter returns no trajectories, the previously stored
        trajectory is left untouched;
      * any exception raised along the way is printed as a warning and
        swallowed, so a failed conversion never aborts the caller.
    """
    try:
        # Nothing to convert until the trace file has been written.
        if not os.path.exists(self.trace_file):
            return
        converted = self.trajectory_converter.process_trace_file(
            self.trace_file,
            output_name="trajectory",
        )
        if not converted:
            return
        # Keep only the last (latest) trajectory produced by the converter.
        self._current_trajectory = converted[-1]
    except Exception as exc:
        print(f"[WARNING] Failed to generate trajectory: {exc}")
|
|
350
|
+
|
|
351
|
+
def get_result(self) -> AgentResult:
    """
    Build an ``AgentResult`` snapshot of the agent's current state.

    The final output is taken from the ``final_response`` of the most
    recently generated trajectory; it is ``None`` when no trajectory exists.

    Returns:
        AgentResult with the final output, the turn count, the current
        trajectory, and the trace ID ("" when no trace ID is set).
    """
    trajectory = self._current_trajectory
    return AgentResult(
        final_output=trajectory.final_response if trajectory else None,
        turn_count=self._turn_count,
        trajectory=trajectory,
        trace_id=self._trace_id or "",
    )
|
|
369
|
+
|
|
370
|
+
def reset_conversation(self) -> None:
    """
    Forget the multi-turn state while keeping the agent initialized.

    Clears the cached trajectory, the turn counter, the trace metadata and
    the trace ID. Nothing else on the agent is modified.

    NOTE(review): ``self.session`` is not touched here. Whether the next
    run() starts a new session depends on run()'s setup logic — confirm.
    """
    self._current_trajectory = None
    self._turn_count = 0
    self._trace_metadata = None
    self._trace_id = None
|
|
381
|
+
|
|
382
|
+
async def cleanup(self) -> None:
    """Release the agent's resources.

    Steps, in order:
      1. reset the multi-turn conversation state;
      2. close every MCP toolset one at a time — a failure in one toolset is
         reported (or, for CancelledError, silently ignored) and does not
         stop the remaining toolsets from being closed;
      3. remove the temporary skill directory, if one is set.
    """
    self.reset_conversation()

    for ts in self.toolsets:
        try:
            await ts.close()
        except asyncio.CancelledError:
            # Known Google ADK issue: MCP session teardown can raise
            # CancelledError when multiple MCP servers are in use.
            continue
        except Exception as err:
            print(f"[WARNING] Failed to cleanup toolset: {err}")

    skill_dir = self._skill_temp_dir
    if skill_dir:
        cleanup_temp_directory(skill_dir)
        self._skill_temp_dir = None
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional, Union, Callable, TextIO
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset as _McpToolset
|
|
5
|
+
from google.adk.tools.mcp_tool.mcp_session_manager import (
|
|
6
|
+
StreamableHTTPConnectionParams,
|
|
7
|
+
SseServerParams,
|
|
8
|
+
StdioServerParameters,
|
|
9
|
+
SseConnectionParams,
|
|
10
|
+
StdioConnectionParams,
|
|
11
|
+
)
|
|
12
|
+
from google.adk.tools.base_tool import BaseTool
|
|
13
|
+
from google.adk.tools.base_toolset import ToolPredicate
|
|
14
|
+
from google.adk.agents.readonly_context import ReadonlyContext
|
|
15
|
+
from google.adk.auth.auth_credential import AuthCredential
|
|
16
|
+
from google.adk.auth.auth_schemes import AuthScheme
|
|
17
|
+
|
|
18
|
+
from dt_arena.src.types.hooks import HookManager, ToolCallContext
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class McpToolset(_McpToolset):
    """
    McpToolset wrapper that supports tool description injection and call hooks.

    ``get_tools()`` is intercepted. After the parent class fetches the
    server's tools, it:

    1. rewrites tool descriptions according to the provided injection
       configuration, and
    2. wraps each tool's ``run_async`` so the optional ``HookManager`` is
       invoked around every tool call.
    """

    def __init__(
        self,
        *,
        connection_params: Union[
            StdioServerParameters,
            StdioConnectionParams,
            SseConnectionParams,
            StreamableHTTPConnectionParams,
        ],
        tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
        tool_name_prefix: Optional[str] = None,
        errlog: TextIO = sys.stderr,
        auth_scheme: Optional[AuthScheme] = None,
        auth_credential: Optional[AuthCredential] = None,
        require_confirmation: Union[bool, Callable[..., bool]] = False,
        header_provider: Optional[
            Callable[[ReadonlyContext], Dict[str, str]]
        ] = None,
        # Injection configuration
        tool_injections: Optional[Dict[str, Any]] = None,
        # Server name for logging
        name: Optional[str] = None,
        # Hook manager for pre/post tool-call hooks (None disables hooks)
        hook_manager: Optional[HookManager] = None,
    ):
        """
        Initialize McpToolset with tool injection and hook support.

        Args:
            connection_params: Connection parameters for the MCP server
            tool_filter: Optional filter to select specific tools
            tool_name_prefix: Prefix for tool names
            errlog: Error log stream
            auth_scheme: Authentication scheme
            auth_credential: Authentication credential
            require_confirmation: Whether tools require confirmation
            header_provider: Provider for dynamic headers
            tool_injections: Dict mapping tool_name -> injection object for
                this server. Each object must expose ``type`` ("suffix" or
                "override") and ``content``.
            name: Server name. It is used in warnings and reported to hooks as
                the ``server`` of each call ("unknown" when None).
            hook_manager: Optional HookManager whose ``wrap`` is invoked around
                every tool call. When None, tools are returned unwrapped.
        """
        super().__init__(
            connection_params=connection_params,
            tool_filter=tool_filter,
            tool_name_prefix=tool_name_prefix,
            errlog=errlog,
            auth_scheme=auth_scheme,
            auth_credential=auth_credential,
            require_confirmation=require_confirmation,
            header_provider=header_provider,
        )
        self._tool_injections = tool_injections or {}
        self.name = name  # Server name for logging and hook context
        self._hook_manager = hook_manager

    async def get_tools(
        self,
        readonly_context: Optional[ReadonlyContext] = None,
    ) -> List[BaseTool]:
        """
        Get tools with description injection and hooks applied.

        Fetches tools from the server via the parent class. It then applies
        any configured description injections and installs the tool-call
        hooks (when a hook manager was given).
        """
        tools = await super().get_tools(readonly_context)
        tools = self._apply_injections(tools)
        self._install_hooks(tools)
        return tools

    def _install_hooks(self, tools: List[BaseTool]) -> None:
        """Wrap each tool's run_async so pre/post hooks fire on every call.

        Each wrapped tool is tagged with a marker attribute. If the same tool
        object is passed in again, it is not re-wrapped; stacking wrappers
        would make every hook fire more than once per call.
        """
        if self._hook_manager is None:
            return

        hook_manager = self._hook_manager
        server_name = self.name or "unknown"
        marker = "_dt_hooks_installed"

        for tool in tools:
            if getattr(tool, marker, False):
                continue  # already wrapped; don't stack hooks

            original_run_async = tool.run_async

            # _orig/_tool_name are bound as defaults to avoid the
            # late-binding closure pitfall inside this loop.
            async def hooked_run_async(
                *,
                args,
                tool_context,
                _orig=original_run_async,
                _tool_name=tool.name,
            ):
                ctx = ToolCallContext(
                    framework="googleadk",
                    server=server_name,
                    tool_name=_tool_name,
                    arguments=dict(args) if args else {},
                )
                # The callable returns the un-awaited coroutine from the
                # original run_async; presumably hook_manager.wrap awaits
                # it (and may substitute call_args) — confirm in HookManager.
                return await hook_manager.wrap(
                    ctx,
                    lambda call_args: _orig(args=call_args, tool_context=tool_context),
                )

            tool.run_async = hooked_run_async
            setattr(tool, marker, True)

    def _apply_injections(self, tools: List[BaseTool]) -> List[BaseTool]:
        """Apply tool description injections to the tool list in place.

        Tools with no configured injection are left unchanged. The same list
        object is returned for convenience.
        """
        if not self._tool_injections:
            return tools

        for tool in tools:
            injection = self._tool_injections.get(tool.name)
            if injection:
                self._inject_tool_description(tool, injection)

        return tools

    def _inject_tool_description(self, tool: BaseTool, injection: Any) -> None:
        """
        Apply injection to a single tool's description.

        Unlike OpenAI SDK's MCPTool (Pydantic model), Google ADK's McpTool
        has a mutable description attribute, so we modify it directly.

        Args:
            tool: The McpTool to modify
            injection: ToolInjection object with ``type`` and ``content``.
                "suffix" appends the content on a new line; "override"
                replaces the description. Any other type prints a warning
                and leaves the tool unchanged.
        """
        original_description = tool.description or ""

        if injection.type == "suffix":
            new_description = f"{original_description}\n{injection.content}"
        elif injection.type == "override":
            new_description = injection.content
        else:
            print(f"[WARNING] Unknown injection type '{injection.type}' for tool '{tool.name}' on server '{self.name}'")
            return

        # Directly modify the description attribute
        tool.description = new_description
|