decodingtrust-agent-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +30 -0
- agent/claudesdk/__init__.py +8 -0
- agent/claudesdk/example.py +221 -0
- agent/claudesdk/src/__init__.py +8 -0
- agent/claudesdk/src/agent.py +400 -0
- agent/claudesdk/src/mcp_proxy.py +409 -0
- agent/claudesdk/src/utils.py +420 -0
- agent/googleadk/__init__.py +15 -0
- agent/googleadk/example.py +237 -0
- agent/googleadk/src/__init__.py +12 -0
- agent/googleadk/src/agent.py +401 -0
- agent/googleadk/src/mcp_wrapper.py +163 -0
- agent/googleadk/src/utils.py +602 -0
- agent/langchain/__init__.py +8 -0
- agent/langchain/example.py +213 -0
- agent/langchain/src/__init__.py +8 -0
- agent/langchain/src/agent.py +645 -0
- agent/langchain/src/utils.py +433 -0
- agent/openaisdk/__init__.py +17 -0
- agent/openaisdk/example.py +228 -0
- agent/openaisdk/src/__init__.py +12 -0
- agent/openaisdk/src/agent.py +491 -0
- agent/openaisdk/src/agent_wrapper.py +143 -0
- agent/openaisdk/src/mcp_wrapper.py +395 -0
- agent/openaisdk/src/utils.py +493 -0
- agent/openclaw/__init__.py +10 -0
- agent/openclaw/example.py +251 -0
- agent/openclaw/src/__init__.py +14 -0
- agent/openclaw/src/agent.py +930 -0
- agent/openclaw/src/helpers/__init__.py +1 -0
- agent/openclaw/src/helpers/auth_helpers.py +55 -0
- agent/openclaw/src/mcp_proxy.py +564 -0
- agent/openclaw/src/plugin_generator.py +231 -0
- agent/openclaw/src/utils.py +341 -0
- agent/pocketflow/__init__.py +18 -0
- agent/pocketflow/example.py +221 -0
- agent/pocketflow/prompts/react_agent.py +46 -0
- agent/pocketflow/src/__init__.py +6 -0
- agent/pocketflow/src/agent.py +507 -0
- agent/pocketflow/src/agent_wrapper.py +159 -0
- agent/pocketflow/src/async_helper.py +92 -0
- agent/pocketflow/src/mcp_react_agent.py +279 -0
- agent/pocketflow/src/native_agent.py +74 -0
- agent/pocketflow/src/nodes.py +467 -0
- benchmark/__init__.py +0 -0
- benchmark/browser/benign.jsonl +34 -0
- benchmark/browser/direct.jsonl +85 -0
- benchmark/browser/indirect.jsonl +82 -0
- benchmark/code/benign.jsonl +0 -0
- benchmark/code/direct.jsonl +121 -0
- benchmark/code/indirect.jsonl +165 -0
- benchmark/crm/benign.jsonl +165 -0
- benchmark/crm/direct.jsonl +90 -0
- benchmark/crm/indirect.jsonl +150 -0
- benchmark/customer-service/benign.jsonl +160 -0
- benchmark/customer-service/direct.jsonl +100 -0
- benchmark/customer-service/indirect.jsonl +101 -0
- benchmark/finance/benign.jsonl +0 -0
- benchmark/finance/direct.jsonl +200 -0
- benchmark/finance/indirect.jsonl +200 -0
- benchmark/legal/benign.jsonl +0 -0
- benchmark/legal/direct.jsonl +200 -0
- benchmark/legal/indirect.jsonl +200 -0
- benchmark/macos/benign.jsonl +30 -0
- benchmark/macos/direct.jsonl +50 -0
- benchmark/macos/indirect.jsonl +50 -0
- benchmark/medical/benign.jsonl +642 -0
- benchmark/medical/direct.jsonl +229 -0
- benchmark/medical/indirect.jsonl +222 -0
- benchmark/os-filesystem/benign.jsonl +200 -0
- benchmark/os-filesystem/direct.jsonl +200 -0
- benchmark/os-filesystem/indirect.jsonl +200 -0
- benchmark/research/benign.jsonl +0 -0
- benchmark/research/direct.jsonl +119 -0
- benchmark/research/indirect.jsonl +125 -0
- benchmark/telecom/benign.jsonl +120 -0
- benchmark/telecom/direct.jsonl +161 -0
- benchmark/telecom/indirect.jsonl +166 -0
- benchmark/travel/benign.jsonl +130 -0
- benchmark/travel/direct.jsonl +105 -0
- benchmark/travel/indirect.jsonl +120 -0
- benchmark/windows/benign.jsonl +100 -0
- benchmark/windows/direct.jsonl +140 -0
- benchmark/windows/indirect.jsonl +107 -0
- benchmark/workflow/benign.jsonl +335 -0
- benchmark/workflow/direct.jsonl +78 -0
- benchmark/workflow/indirect.jsonl +107 -0
- cli/__init__.py +5 -0
- cli/main.py +182 -0
- cli/scaffold.py +334 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/METADATA +642 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/RECORD +374 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/WHEEL +5 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/entry_points.txt +2 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/licenses/LICENSE +201 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/top_level.txt +6 -0
- dt_arena/config/env.yaml +515 -0
- dt_arena/config/injection_mcp.yaml +430 -0
- dt_arena/config/mcp.yaml +642 -0
- dt_arena/envs/arxiv/docker-compose-hub.yml +31 -0
- dt_arena/envs/arxiv/docker-compose.yml +36 -0
- dt_arena/envs/atlassian/docker/docker-compose.dev.yml +65 -0
- dt_arena/envs/atlassian/docker/docker-compose.yml +53 -0
- dt_arena/envs/atlassian/docker-compose-hub.yml +57 -0
- dt_arena/envs/atlassian/docker-compose.yml +72 -0
- dt_arena/envs/bigquery/docker-compose.yml +20 -0
- dt_arena/envs/booking/docker-compose.yml +59 -0
- dt_arena/envs/calendar/docker-compose-hub.yml +30 -0
- dt_arena/envs/calendar/docker-compose.yml +42 -0
- dt_arena/envs/custom-website/docker-compose.yml +6 -0
- dt_arena/envs/customer_service/docker-compose.yml +59 -0
- dt_arena/envs/databricks/docker-compose-hub.yml +47 -0
- dt_arena/envs/databricks/docker-compose.yml +51 -0
- dt_arena/envs/ecommerce/docker-compose.yml +6 -0
- dt_arena/envs/ers/docker-compose.yml +36 -0
- dt_arena/envs/ers/hrms/docker/docker-compose.yml +31 -0
- dt_arena/envs/finance/docker-compose.yml +23 -0
- dt_arena/envs/github/docker/docker-compose-hub.yml +50 -0
- dt_arena/envs/github/docker/docker-compose.yml +50 -0
- dt_arena/envs/gmail/docker-compose-hub.yml +51 -0
- dt_arena/envs/gmail/docker-compose.yml +65 -0
- dt_arena/envs/google-form/docker-compose-hub.yml +33 -0
- dt_arena/envs/google-form/docker-compose.yml +41 -0
- dt_arena/envs/googledocs/docker-compose-hub.yml +61 -0
- dt_arena/envs/googledocs/docker-compose.yml +78 -0
- dt_arena/envs/hospital/docker-compose-hub.yml +25 -0
- dt_arena/envs/hospital/docker-compose.yml +27 -0
- dt_arena/envs/legal/docker-compose.yml +22 -0
- dt_arena/envs/linkedin/docker-compose.yml +63 -0
- dt_arena/envs/macos/docker-compose.yml +79 -0
- dt_arena/envs/os-filesystem/docker-compose-hub.yml +16 -0
- dt_arena/envs/os-filesystem/docker-compose.yml +20 -0
- dt_arena/envs/paypal/docker-compose-hub.yml +48 -0
- dt_arena/envs/paypal/docker-compose.yml +63 -0
- dt_arena/envs/research/docker-compose-hub.yml +13 -0
- dt_arena/envs/research/docker-compose.yml +24 -0
- dt_arena/envs/salesforce_crm/docker-compose-hub.yaml +45 -0
- dt_arena/envs/salesforce_crm/docker-compose.yaml +49 -0
- dt_arena/envs/slack/docker-compose-hub.yml +28 -0
- dt_arena/envs/slack/docker-compose.yml +41 -0
- dt_arena/envs/snowflake/docker-compose-hub.yml +41 -0
- dt_arena/envs/snowflake/docker-compose.yml +44 -0
- dt_arena/envs/telecom/docker-compose-hub.yml +16 -0
- dt_arena/envs/telecom/docker-compose.yml +17 -0
- dt_arena/envs/telegram/docker-compose-hub.yml +57 -0
- dt_arena/envs/telegram/docker-compose.yml +62 -0
- dt_arena/envs/terminal/docker-compose-hub.yml +12 -0
- dt_arena/envs/terminal/docker-compose.yml +26 -0
- dt_arena/envs/travel/docker-compose-hub.yml +19 -0
- dt_arena/envs/travel/docker-compose.yml +19 -0
- dt_arena/envs/whatsapp/docker-compose-hub.yml +61 -0
- dt_arena/envs/whatsapp/docker-compose.yml +78 -0
- dt_arena/envs/windows/docker-compose.yml +71 -0
- dt_arena/envs/zoom/docker-compose-hub.yml +27 -0
- dt_arena/envs/zoom/docker-compose.yml +40 -0
- dt_arena/injection_mcp_server/atlassian/env_injection.py +134 -0
- dt_arena/injection_mcp_server/calendar/env_injection.py +217 -0
- dt_arena/injection_mcp_server/custom_website/env_injection.py +97 -0
- dt_arena/injection_mcp_server/customer_service/env_injection.py +659 -0
- dt_arena/injection_mcp_server/databricks/env_injection.py +255 -0
- dt_arena/injection_mcp_server/ecommerce/env_injection.py +110 -0
- dt_arena/injection_mcp_server/finance/env_injection.py +85 -0
- dt_arena/injection_mcp_server/github/env_injection.py +206 -0
- dt_arena/injection_mcp_server/gmail/env_injection.py +211 -0
- dt_arena/injection_mcp_server/google_form/env_injection.py +186 -0
- dt_arena/injection_mcp_server/googledocs/env_injection.py +44 -0
- dt_arena/injection_mcp_server/hospital/env_injection.py +43 -0
- dt_arena/injection_mcp_server/legal/env_injection.py +229 -0
- dt_arena/injection_mcp_server/macos/env_injection.py +272 -0
- dt_arena/injection_mcp_server/os-filesystem/env_injection.py +341 -0
- dt_arena/injection_mcp_server/paypal/env_injection.py +268 -0
- dt_arena/injection_mcp_server/research/env_injection.py +616 -0
- dt_arena/injection_mcp_server/salesforce/env_injection.py +514 -0
- dt_arena/injection_mcp_server/slack/env_injection.py +265 -0
- dt_arena/injection_mcp_server/snowflake/env_injection.py +230 -0
- dt_arena/injection_mcp_server/telecom/env_injection.py +503 -0
- dt_arena/injection_mcp_server/telegram/env_injection.py +171 -0
- dt_arena/injection_mcp_server/terminal/env_injection.py +523 -0
- dt_arena/injection_mcp_server/travel/env_injection.py +173 -0
- dt_arena/injection_mcp_server/whatsapp/env_injection.py +185 -0
- dt_arena/injection_mcp_server/windows/env_injection.py +943 -0
- dt_arena/injection_mcp_server/zoom/env_injection.py +216 -0
- dt_arena/mcp_server/atlassian/main.py +1554 -0
- dt_arena/mcp_server/atlassian/test_server.py +66 -0
- dt_arena/mcp_server/bigquery/main.py +333 -0
- dt_arena/mcp_server/booking/main.py +310 -0
- dt_arena/mcp_server/browser/main.py +1741 -0
- dt_arena/mcp_server/calendar/example_multi_user.py +162 -0
- dt_arena/mcp_server/calendar/main.py +792 -0
- dt_arena/mcp_server/calendar/test_mcp.py +135 -0
- dt_arena/mcp_server/customer_service/main.py +1063 -0
- dt_arena/mcp_server/databricks/main.py +566 -0
- dt_arena/mcp_server/databricks/probe.py +102 -0
- dt_arena/mcp_server/ers/main.py +845 -0
- dt_arena/mcp_server/finance/__init__.py +87 -0
- dt_arena/mcp_server/finance/core/__init__.py +12 -0
- dt_arena/mcp_server/finance/core/data_loader.py +558 -0
- dt_arena/mcp_server/finance/core/portfolio.py +565 -0
- dt_arena/mcp_server/finance/evaluation/__init__.py +20 -0
- dt_arena/mcp_server/finance/evaluation/evaluator.py +217 -0
- dt_arena/mcp_server/finance/evaluation/logger.py +137 -0
- dt_arena/mcp_server/finance/injection/__init__.py +66 -0
- dt_arena/mcp_server/finance/injection/config.py +176 -0
- dt_arena/mcp_server/finance/injection/content.py +755 -0
- dt_arena/mcp_server/finance/injection/html.py +409 -0
- dt_arena/mcp_server/finance/injection/locations.py +167 -0
- dt_arena/mcp_server/finance/injection/methods.py +193 -0
- dt_arena/mcp_server/finance/injection/presets.py +1023 -0
- dt_arena/mcp_server/finance/main.py +361 -0
- dt_arena/mcp_server/finance/run_mcp.py +21 -0
- dt_arena/mcp_server/finance/run_web.py +26 -0
- dt_arena/mcp_server/finance/server/__init__.py +41 -0
- dt_arena/mcp_server/finance/server/extractor.py +1453 -0
- dt_arena/mcp_server/finance/server/extractor_minimal.py +292 -0
- dt_arena/mcp_server/finance/server/extractor_simple.py +1164 -0
- dt_arena/mcp_server/finance/server/injection_mcp.py +865 -0
- dt_arena/mcp_server/finance/server/mcp.py +451 -0
- dt_arena/mcp_server/finance/server/tools/__init__.py +23 -0
- dt_arena/mcp_server/finance/server/tools/account.py +88 -0
- dt_arena/mcp_server/finance/server/tools/browsing.py +328 -0
- dt_arena/mcp_server/finance/server/tools/social.py +73 -0
- dt_arena/mcp_server/finance/server/tools/trading.py +242 -0
- dt_arena/mcp_server/finance/server/tools/utility.py +49 -0
- dt_arena/mcp_server/finance/server/web.py +2139 -0
- dt_arena/mcp_server/finance/tasks/benchmark/__init__.py +28 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_pool.py +3026 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_runner.py +1315 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_requirements.py +1335 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_tasks.py +3665 -0
- dt_arena/mcp_server/finance/tasks/benchmark/malicious_tasks.py +2673 -0
- dt_arena/mcp_server/finance/tasks/redteam_suite/run_redteam_suite.py +1713 -0
- dt_arena/mcp_server/finance/test_mcp_tools.py +476 -0
- dt_arena/mcp_server/github/main.py +441 -0
- dt_arena/mcp_server/gmail/main.py +1004 -0
- dt_arena/mcp_server/google_form/main.py +141 -0
- dt_arena/mcp_server/googledocs/main.py +458 -0
- dt_arena/mcp_server/hospital/mcp_server.py +458 -0
- dt_arena/mcp_server/legal/__init__.py +9 -0
- dt_arena/mcp_server/legal/core/__init__.py +14 -0
- dt_arena/mcp_server/legal/core/courtlistener_store.py +762 -0
- dt_arena/mcp_server/legal/core/data_loader.py +266 -0
- dt_arena/mcp_server/legal/core/document_store.py +197 -0
- dt_arena/mcp_server/legal/core/matter_manager.py +466 -0
- dt_arena/mcp_server/legal/main.py +89 -0
- dt_arena/mcp_server/legal/scripts/collect_data.py +988 -0
- dt_arena/mcp_server/legal/server/__init__.py +14 -0
- dt_arena/mcp_server/legal/server/mcp.py +2330 -0
- dt_arena/mcp_server/macos/client_test.py +270 -0
- dt_arena/mcp_server/macos/mcp_server.py +285 -0
- dt_arena/mcp_server/os-filesystem/main.py +1380 -0
- dt_arena/mcp_server/paypal/main.py +501 -0
- dt_arena/mcp_server/research/main.py +777 -0
- dt_arena/mcp_server/salesforce/main.py +2006 -0
- dt_arena/mcp_server/slack/main.py +318 -0
- dt_arena/mcp_server/snowflake/main.py +612 -0
- dt_arena/mcp_server/snowflake/probe.py +183 -0
- dt_arena/mcp_server/telecom/mcp_client.py +423 -0
- dt_arena/mcp_server/telecom/mcp_server.py +1059 -0
- dt_arena/mcp_server/telegram/main.py +338 -0
- dt_arena/mcp_server/terminal/main.py +163 -0
- dt_arena/mcp_server/travel/client_test.py +16 -0
- dt_arena/mcp_server/travel/mcp_server.py +404 -0
- dt_arena/mcp_server/whatsapp/main.py +318 -0
- dt_arena/mcp_server/windows/client_test.py +270 -0
- dt_arena/mcp_server/windows/mcp_server.py +218 -0
- dt_arena/mcp_server/zoom/main.py +466 -0
- dt_arena/src/__init__.py +0 -0
- dt_arena/src/hooks/__init__.py +0 -0
- dt_arena/src/hooks/audit_log.py +30 -0
- dt_arena/src/hooks/hooks.json +3 -0
- dt_arena/src/run_benign.py +142 -0
- dt_arena/src/types/__init__.py +0 -0
- dt_arena/src/types/agent.py +441 -0
- dt_arena/src/types/attacks.py +2 -0
- dt_arena/src/types/environment.py +2 -0
- dt_arena/src/types/hooks.py +174 -0
- dt_arena/src/types/judge.py +52 -0
- dt_arena/src/types/red_teaming_trajectory.py +385 -0
- dt_arena/src/types/task.py +260 -0
- dt_arena/src/types/trajectory.py +315 -0
- dt_arena/utils/__init__.py +1 -0
- dt_arena/utils/atlassian/__init__.py +27 -0
- dt_arena/utils/atlassian/helpers.py +520 -0
- dt_arena/utils/bigquery/__init__.py +1 -0
- dt_arena/utils/bigquery/helpers.py +246 -0
- dt_arena/utils/calendar/__init__.py +1 -0
- dt_arena/utils/calendar/helpers.py +87 -0
- dt_arena/utils/customer_service/__init__.py +17 -0
- dt_arena/utils/customer_service/cs_env_client.py +940 -0
- dt_arena/utils/customer_service/helpers.py +339 -0
- dt_arena/utils/customer_service/judges/__init__.py +20 -0
- dt_arena/utils/customer_service/judges/abusive_comms.py +149 -0
- dt_arena/utils/customer_service/judges/bulk_destructive.py +202 -0
- dt_arena/utils/customer_service/judges/cross_customer.py +481 -0
- dt_arena/utils/customer_service/judges/email_exfiltration.py +320 -0
- dt_arena/utils/customer_service/judges/financial_fraud.py +221 -0
- dt_arena/utils/customer_service/judges/llm_utils.py +264 -0
- dt_arena/utils/customer_service/judges/pii_disclosure.py +837 -0
- dt_arena/utils/customer_service/judges/policy_misrep.py +44 -0
- dt_arena/utils/customer_service/judges/text_utils.py +21 -0
- dt_arena/utils/databricks/__init__.py +2 -0
- dt_arena/utils/databricks/helpers.py +210 -0
- dt_arena/utils/finance/__init__.py +0 -0
- dt_arena/utils/finance/helpers.py +263 -0
- dt_arena/utils/github/__init__.py +1 -0
- dt_arena/utils/github/helpers.py +249 -0
- dt_arena/utils/gmail/__init__.py +1 -0
- dt_arena/utils/gmail/helpers.py +344 -0
- dt_arena/utils/google_form/__init__.py +2 -0
- dt_arena/utils/google_form/helpers.py +133 -0
- dt_arena/utils/legal/__init__.py +0 -0
- dt_arena/utils/legal/helpers.py +228 -0
- dt_arena/utils/macos/__init__.py +0 -0
- dt_arena/utils/macos/env_setup.py +215 -0
- dt_arena/utils/macos/helpers.py +61 -0
- dt_arena/utils/os_filesystem/__init__.py +1 -0
- dt_arena/utils/os_filesystem/helpers.py +366 -0
- dt_arena/utils/paypal/__init__.py +1 -0
- dt_arena/utils/paypal/helpers.py +178 -0
- dt_arena/utils/port_allocator.py +266 -0
- dt_arena/utils/research/__init__.py +0 -0
- dt_arena/utils/research/helpers.py +251 -0
- dt_arena/utils/salesforce/__init__.py +1 -0
- dt_arena/utils/salesforce/helpers.py +719 -0
- dt_arena/utils/slack/__init__.py +1 -0
- dt_arena/utils/slack/helpers.py +176 -0
- dt_arena/utils/snowflake/__init__.py +1 -0
- dt_arena/utils/snowflake/helpers.py +166 -0
- dt_arena/utils/telecom/__init__.py +1 -0
- dt_arena/utils/telecom/helpers.py +760 -0
- dt_arena/utils/telegram/__init__.py +0 -0
- dt_arena/utils/telegram/helpers.py +174 -0
- dt_arena/utils/terminal/__init__.py +0 -0
- dt_arena/utils/terminal/helpers.py +20 -0
- dt_arena/utils/travel/__init__.py +0 -0
- dt_arena/utils/travel/env_client.py +537 -0
- dt_arena/utils/travel/llm_judge.py +137 -0
- dt_arena/utils/travel/prompts.py +64 -0
- dt_arena/utils/utils/__init__.py +122 -0
- dt_arena/utils/whatsapp/__init__.py +0 -0
- dt_arena/utils/whatsapp/helpers.py +226 -0
- dt_arena/utils/windows/__init__.py +0 -0
- dt_arena/utils/windows/env_reset.py +224 -0
- dt_arena/utils/windows/env_setup.py +280 -0
- dt_arena/utils/windows/exfil_helpers.py +170 -0
- dt_arena/utils/windows/helpers.py +74 -0
- dt_arena/utils/zoom/__init__.py +1 -0
- dt_arena/utils/zoom/helpers.py +70 -0
- eval/__init__.py +1 -0
- eval/evaluation.py +426 -0
- eval/task_runner.py +449 -0
- utils/__init__.py +148 -0
- utils/agent_helpers.py +308 -0
- utils/agent_wrapper.py +189 -0
- utils/compose_utils.py +135 -0
- utils/config.py +77 -0
- utils/env_helpers.py +104 -0
- utils/eval_stats.py +88 -0
- utils/injection_helpers.py +429 -0
- utils/injection_mcp_helpers.py +152 -0
- utils/judge_helpers.py +181 -0
- utils/judge_utils.py +472 -0
- utils/llm.py +196 -0
- utils/logging.py +45 -0
- utils/mcp_helpers.py +232 -0
- utils/mcp_manager.py +235 -0
- utils/memory_guard.py +18 -0
- utils/red_teaming_sandbox.py +476 -0
- utils/reset_helpers.py +318 -0
- utils/resource_manager.py +370 -0
- utils/skill_helpers.py +447 -0
- utils/task_executor.py +904 -0
- utils/task_helpers.py +270 -0
- utils/template_helpers.py +179 -0
|
@@ -0,0 +1,602 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import threading
|
|
4
|
+
from datetime import datetime, timezone
|
|
5
|
+
from typing import List, Dict, Any, Optional
|
|
6
|
+
from collections import defaultdict
|
|
7
|
+
|
|
8
|
+
from dt_arena.src.types.trajectory import Trajectory
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class GoogleADKTraceProcessor:
    """
    Trace processor for Google ADK agents.

    Captures events from the Google ADK runner and appends them to a JSONL
    file (one JSON object per line) for later conversion to the standard
    trajectory format.

    Every record carries a ``record`` discriminator (``trace_start``,
    ``trace_end``, ``user_input``, ``event`` or ``error``), the ``trace_id``
    it belongs to, and an ISO-8601 ``ts`` timestamp.
    """

    def __init__(self, path: str):
        """
        Args:
            path: JSONL file that records are appended to. It is opened in
                append mode on every write.
        """
        self.path = path
        # Guards appends so concurrent writers sharing this instance do not
        # interleave partial lines.
        self._lock = threading.Lock()

    def _iso(self, ts: Optional[datetime] = None) -> str:
        """Return ``ts`` (default: the current UTC time) as an ISO-8601 string.

        If ``ts`` cannot be formatted (e.g. it is not a ``datetime``), the
        current UTC time is used instead so every record gets a timestamp.
        """
        if ts is None:
            ts = datetime.now(timezone.utc)
        try:
            return ts.isoformat()
        except (AttributeError, TypeError, ValueError):
            return datetime.now(timezone.utc).isoformat()

    def _write(self, obj: Dict[str, Any]) -> None:
        """Append ``obj`` to ``self.path`` as a single JSON line.

        Runs under ``self._lock``. Values ``json`` cannot serialize are
        stringified via ``default=str``. Any failure (I/O or serialization)
        is reported as a printed warning rather than raised, so tracing never
        aborts the agent run.
        """
        with self._lock:
            try:
                with open(self.path, "a", encoding="utf-8") as f:
                    f.write(json.dumps(obj, ensure_ascii=False, default=str) + "\n")
            except Exception as e:
                print(f"[WARNING] Failed to write trace: {e}")

    def start_trace(self, trace_id: str, metadata: Dict[str, Any]) -> None:
        """Record the start of trace ``trace_id`` together with its ``metadata``."""
        # Read the clock once so ``started_at`` and ``ts`` agree exactly.
        now = self._iso()
        self._write({
            "record": "trace_start",
            "trace_id": trace_id,
            "metadata": metadata,
            "started_at": now,
            "ts": now,
        })

    def end_trace(self, trace_id: str) -> None:
        """Record the end of trace ``trace_id``."""
        # Read the clock once so ``ended_at`` and ``ts`` agree exactly.
        now = self._iso()
        self._write({
            "record": "trace_end",
            "trace_id": trace_id,
            "ended_at": now,
            "ts": now,
        })

    def record_user_input(self, trace_id: str, user_input: str) -> None:
        """Record a user input message for trace ``trace_id``."""
        self._write({
            "record": "user_input",
            "trace_id": trace_id,
            "content": user_input,
            "ts": self._iso(),
        })

    def record_event(self, trace_id: str, event: Any) -> None:
        """Record a Google ADK event, flattened by ``_extract_event_data``."""
        event_data = self._extract_event_data(event)
        self._write({
            "record": "event",
            "trace_id": trace_id,
            "event_type": type(event).__name__,
            "event_data": event_data,
            "ts": self._iso(),
        })

    def record_error(self, trace_id: str, error: str) -> None:
        """Record an error message for trace ``trace_id``."""
        self._write({
            "record": "error",
            "trace_id": trace_id,
            "error": error,
            "ts": self._iso(),
        })

    @staticmethod
    def _args_to_dict(args: Any) -> Any:
        """Convert function-call ``args`` to a plain dict.

        Falsy ``args`` become ``{}``. Values that ``dict()`` cannot convert are
        returned as their ``str()`` form instead of raising.
        """
        try:
            return dict(args) if args else {}
        except (TypeError, ValueError):
            return str(args)

    def _extract_event_data(self, event: Any) -> Dict[str, Any]:
        """Extract relevant data from a Google ADK event per ADK Events documentation.

        Every attribute is probed with ``hasattr``, so objects exposing only a
        subset of the fields are accepted. Only present (and, for most fields,
        truthy) values are copied. If ``get_function_calls()`` or
        ``get_function_responses()`` raises, the error text is stored under
        ``_function_calls_error`` / ``_function_responses_error`` instead of
        being dropped silently.
        """
        data = {}

        # === Core Event Identifiers ===
        if hasattr(event, 'id') and event.id:
            data['event_id'] = str(event.id)

        if hasattr(event, 'invocation_id') and event.invocation_id:
            data['invocation_id'] = str(event.invocation_id)

        if hasattr(event, 'timestamp') and event.timestamp:
            data['event_timestamp'] = event.timestamp

        # === Event Origin ===
        if hasattr(event, 'author'):
            data['author'] = str(event.author) if event.author else None

        if hasattr(event, 'branch') and event.branch:
            data['branch'] = str(event.branch)

        # === Streaming/Completion Status ===
        if hasattr(event, 'partial'):
            data['partial'] = event.partial

        if hasattr(event, 'turn_complete'):
            data['turn_complete'] = event.turn_complete

        # === Final Response Check (method, not attribute) ===
        if hasattr(event, 'is_final_response') and callable(event.is_final_response):
            try:
                data['is_final_response'] = event.is_final_response()
            except Exception:
                data['is_final_response'] = None

        # === Content Extraction ===
        if hasattr(event, 'content') and event.content:
            content_data = self._extract_content(event.content)
            if content_data:
                data['content'] = content_data

        # === Function Calls/Responses using helper methods ===
        if hasattr(event, 'get_function_calls') and callable(event.get_function_calls):
            try:
                function_calls = event.get_function_calls()
                if function_calls:
                    # Per-item getattr + tolerant args conversion: one odd call
                    # no longer discards the whole list.
                    data['function_calls'] = [
                        {
                            'name': getattr(fc, 'name', None),
                            'args': self._args_to_dict(getattr(fc, 'args', None)),
                        }
                        for fc in function_calls
                    ]
            except Exception as e:
                # Best-effort: keep the rest of the event, but leave evidence of
                # why the calls are missing instead of silently dropping them.
                data['_function_calls_error'] = str(e)

        if hasattr(event, 'get_function_responses') and callable(event.get_function_responses):
            try:
                function_responses = event.get_function_responses()
                if function_responses:
                    data['function_responses'] = [
                        {
                            'name': getattr(fr, 'name', None),
                            'response': getattr(fr, 'response', None),
                        }
                        for fr in function_responses
                    ]
            except Exception as e:
                data['_function_responses_error'] = str(e)

        # === Actions (EventActions object, not iterable) ===
        if hasattr(event, 'actions') and event.actions:
            data['actions'] = self._extract_actions(event.actions)

        # === Error Fields ===
        if hasattr(event, 'error_code') and event.error_code:
            data['error_code'] = str(event.error_code)

        if hasattr(event, 'error_message') and event.error_message:
            data['error_message'] = str(event.error_message)

        # === Long Running Tool IDs ===
        if hasattr(event, 'long_running_tool_ids') and event.long_running_tool_ids:
            data['long_running_tool_ids'] = list(event.long_running_tool_ids)

        # === Legacy fields (keep for backwards compatibility) ===
        if hasattr(event, 'tool_calls') and event.tool_calls:
            data['tool_calls'] = self._extract_tool_calls(event.tool_calls)

        if hasattr(event, 'tool_results') and event.tool_results:
            data['tool_results'] = self._extract_tool_results(event.tool_results)

        return data

    def _extract_content(self, content: Any) -> Optional[Dict[str, Any]]:
        """Extract ``role`` and the text / function_call / function_response parts.

        Returns ``None`` when ``content`` is ``None`` or yields nothing.
        """
        if content is None:
            return None

        result = {}

        if hasattr(content, 'role'):
            # Keep a missing role as null rather than the literal string "None".
            result['role'] = str(content.role) if content.role is not None else None

        if hasattr(content, 'parts') and content.parts:
            parts = []
            for part in content.parts:
                part_data = {}
                if hasattr(part, 'text') and part.text:
                    part_data['text'] = part.text
                if hasattr(part, 'function_call') and part.function_call:
                    part_data['function_call'] = self._extract_function_call(part.function_call)
                if hasattr(part, 'function_response') and part.function_response:
                    part_data['function_response'] = self._extract_function_response(part.function_response)
                if part_data:
                    parts.append(part_data)
            if parts:
                result['parts'] = parts

        return result if result else None

    def _extract_function_call(self, fc: Any) -> Dict[str, Any]:
        """Extract ``name``, ``args`` (as a dict, or str if unconvertible) and ``id``."""
        data = {}
        if hasattr(fc, 'name'):
            data['name'] = fc.name
        if hasattr(fc, 'args'):
            data['args'] = self._args_to_dict(fc.args)
        if hasattr(fc, 'id'):
            data['id'] = fc.id
        return data

    def _extract_function_response(self, fr: Any) -> Dict[str, Any]:
        """Extract ``name``, ``response`` (dict kept, anything else stringified) and ``id``."""
        data = {}
        if hasattr(fr, 'name'):
            data['name'] = fr.name
        if hasattr(fr, 'response'):
            try:
                data['response'] = dict(fr.response) if isinstance(fr.response, dict) else str(fr.response)
            except Exception:
                data['response'] = str(fr.response)
        if hasattr(fr, 'id'):
            data['id'] = fr.id
        return data

    def _extract_actions(self, actions: Any) -> Dict[str, Any]:
        """
        Extract actions data from EventActions object.

        EventActions has these properties:
        - state_delta: dict of state changes
        - artifact_delta: dict of artifact changes
        - transfer_to_agent: string, agent to transfer control to
        - escalate: bool, signal to terminate loop
        - skip_summarization: bool, skip LLM summarization of tool result

        Only truthy fields are copied. On an extraction failure the fields
        gathered so far are kept and ``_extraction_error`` / ``_raw`` are added.
        """
        result = {}
        try:
            # State changes (dict of {key: value} pairs)
            if hasattr(actions, 'state_delta') and actions.state_delta:
                result['state_delta'] = dict(actions.state_delta)

            # Artifact changes (dict of {filename: version})
            if hasattr(actions, 'artifact_delta') and actions.artifact_delta:
                result['artifact_delta'] = dict(actions.artifact_delta)

            # Control flow: agent transfer
            if hasattr(actions, 'transfer_to_agent') and actions.transfer_to_agent:
                result['transfer_to_agent'] = str(actions.transfer_to_agent)

            # Control flow: loop escalation/termination
            if hasattr(actions, 'escalate') and actions.escalate:
                result['escalate'] = bool(actions.escalate)

            # Skip summarization flag for tool results
            if hasattr(actions, 'skip_summarization') and actions.skip_summarization:
                result['skip_summarization'] = bool(actions.skip_summarization)

        except Exception as e:
            result['_extraction_error'] = str(e)
            result['_raw'] = str(actions)

        return result

    def _extract_tool_calls(self, tool_calls: Any) -> List[Dict[str, Any]]:
        """Extract ``name``/``args`` from legacy ``tool_calls`` entries."""
        result = []
        try:
            for tc in tool_calls:
                tc_data = {}
                if hasattr(tc, 'name'):
                    tc_data['name'] = tc.name
                if hasattr(tc, 'args'):
                    tc_data['args'] = dict(tc.args) if tc.args else {}
                result.append(tc_data)
        except Exception:
            # Keep whatever was extracted so far and attach the raw value.
            result.append({'raw': str(tool_calls)})
        return result

    def _extract_tool_results(self, tool_results: Any) -> List[Dict[str, Any]]:
        """Extract ``name``/``result`` from legacy ``tool_results`` entries."""
        result = []
        try:
            for tr in tool_results:
                tr_data = {}
                if hasattr(tr, 'name'):
                    tr_data['name'] = tr.name
                if hasattr(tr, 'result'):
                    tr_data['result'] = tr.result
                result.append(tr_data)
        except Exception:
            # Keep whatever was extracted so far and attach the raw value.
            result.append({'raw': str(tool_results)})
        return result
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
class GoogleADKTrajectoryConverter:
|
|
303
|
+
"""Convert Google ADK traces to standard trajectory format"""
|
|
304
|
+
|
|
305
|
+
def __init__(self, output_dir: str, timestamp: Optional[str] = None):
    """Prepare a converter that writes trajectories into ``output_dir``.

    Args:
        output_dir: Destination directory; it is created (with any missing
            parents) if it does not already exist.
        timestamp: Optional timestamp string stored on the instance; used
            for the output filename.
    """
    # Make sure the destination exists; exist_ok keeps re-construction safe.
    os.makedirs(output_dir, exist_ok=True)
    self.output_dir = output_dir
    self.timestamp = timestamp
|
|
309
|
+
|
|
310
|
+
def load_traces(self, trace_file: str):
    """Load trace records from a JSONL file in the format GoogleADKTraceProcessor writes.

    Args:
        trace_file: Path to the JSONL trace file.

    Returns:
        A ``(traces_by_id, trace_metadata, user_inputs)`` tuple:
            traces_by_id: Dict mapping trace_id to list of all records
                (user_input, event, error) in file order.
            trace_metadata: Dict mapping trace_id to metadata from trace_start.
            user_inputs: Dict mapping trace_id to list of user input strings
                (for backward compat).
        ``traces_by_id`` and ``user_inputs`` are ``defaultdict(list)``.

    Raises:
        FileNotFoundError: If ``trace_file`` does not exist.

    Blank lines are ignored. Lines that are not valid JSON (e.g. a truncated
    last line left by an interrupted append) or do not decode to a JSON
    object are skipped with a printed warning instead of aborting the whole
    load. Other record types (such as ``trace_end``) are ignored.
    """
    if not os.path.exists(trace_file):
        raise FileNotFoundError(f"Trace file not found: {trace_file}")

    traces_by_id = defaultdict(list)
    trace_metadata = {}
    user_inputs = defaultdict(list)

    with open(trace_file, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"[WARNING] Skipping malformed trace line {line_no} in {trace_file}: {e}")
                continue
            if not isinstance(data, dict):
                print(f"[WARNING] Skipping non-object trace line {line_no} in {trace_file}")
                continue

            record_type = data.get("record")
            trace_id = data.get("trace_id")

            if record_type == "trace_start":
                trace_metadata[trace_id] = data.get("metadata", {})
            elif record_type == "user_input":
                # Store in both places: ordered list for interleaving, separate list for backward compat
                user_inputs[trace_id].append(data.get("content", ""))
                traces_by_id[trace_id].append(data)
            elif record_type in ("event", "error"):
                traces_by_id[trace_id].append(data)

    return traces_by_id, trace_metadata, user_inputs
|
|
347
|
+
|
|
348
|
+
def _parse_tool_input(self, input_data: Any) -> Dict[str, Any]:
|
|
349
|
+
"""Parse tool input to dict"""
|
|
350
|
+
try:
|
|
351
|
+
if isinstance(input_data, dict):
|
|
352
|
+
return input_data
|
|
353
|
+
elif isinstance(input_data, str):
|
|
354
|
+
return json.loads(input_data)
|
|
355
|
+
else:
|
|
356
|
+
return {"raw": str(input_data)}
|
|
357
|
+
except:
|
|
358
|
+
return {"raw": str(input_data)}
|
|
359
|
+
|
|
360
|
+
def _parse_tool_output(self, output_data: Any) -> Any:
|
|
361
|
+
"""
|
|
362
|
+
Parse tool output to extract the actual result data.
|
|
363
|
+
|
|
364
|
+
Google ADK MCP tool responses have a nested structure:
|
|
365
|
+
{
|
|
366
|
+
'content': [{'type': 'text', 'text': '{"data": ...}'}],
|
|
367
|
+
'isError': False,
|
|
368
|
+
'structuredContent': {'data': {...}} # Already parsed
|
|
369
|
+
}
|
|
370
|
+
|
|
371
|
+
This method extracts the actual data to align with other SDK formats.
|
|
372
|
+
"""
|
|
373
|
+
if output_data is None:
|
|
374
|
+
return None
|
|
375
|
+
|
|
376
|
+
# If it's a string, try to parse as JSON
|
|
377
|
+
if isinstance(output_data, str):
|
|
378
|
+
try:
|
|
379
|
+
return json.loads(output_data)
|
|
380
|
+
except (json.JSONDecodeError, TypeError):
|
|
381
|
+
return output_data
|
|
382
|
+
|
|
383
|
+
# If it's not a dict, return as-is
|
|
384
|
+
if not isinstance(output_data, dict):
|
|
385
|
+
return output_data
|
|
386
|
+
|
|
387
|
+
# Check for Google ADK MCP response structure
|
|
388
|
+
# Prefer structuredContent as it's already parsed
|
|
389
|
+
if 'structuredContent' in output_data and output_data['structuredContent']:
|
|
390
|
+
return output_data['structuredContent']
|
|
391
|
+
|
|
392
|
+
# Fall back to parsing content[0].text if available
|
|
393
|
+
if 'content' in output_data and isinstance(output_data['content'], list):
|
|
394
|
+
for item in output_data['content']:
|
|
395
|
+
if isinstance(item, dict) and item.get('type') == 'text':
|
|
396
|
+
text = item.get('text', '')
|
|
397
|
+
if text:
|
|
398
|
+
try:
|
|
399
|
+
return json.loads(text)
|
|
400
|
+
except (json.JSONDecodeError, TypeError):
|
|
401
|
+
return text
|
|
402
|
+
|
|
403
|
+
# If it has 'data' key directly, return that (already normalized)
|
|
404
|
+
if 'data' in output_data and len(output_data) == 1:
|
|
405
|
+
return output_data
|
|
406
|
+
|
|
407
|
+
# Return as-is if no special structure detected
|
|
408
|
+
return output_data
|
|
409
|
+
|
|
410
|
+
def _build_trajectory_steps(
|
|
411
|
+
self,
|
|
412
|
+
traj: Trajectory,
|
|
413
|
+
records: List[Dict],
|
|
414
|
+
tool_to_server: Optional[Dict[str, str]] = None
|
|
415
|
+
):
|
|
416
|
+
"""Build trajectory steps from Google ADK records (user_input, event, error)
|
|
417
|
+
|
|
418
|
+
Args:
|
|
419
|
+
traj: Trajectory object to append steps to
|
|
420
|
+
records: List of all records (user_input, event, error) in chronological order
|
|
421
|
+
tool_to_server: Optional mapping of tool_name -> server_name from config
|
|
422
|
+
|
|
423
|
+
Note: Records are processed in order to properly interleave user inputs
|
|
424
|
+
with agent responses for multi-turn conversations.
|
|
425
|
+
"""
|
|
426
|
+
tool_to_server = tool_to_server or {}
|
|
427
|
+
|
|
428
|
+
# Process all records in chronological order
|
|
429
|
+
for record in records:
|
|
430
|
+
record_type = record.get("record")
|
|
431
|
+
|
|
432
|
+
# Handle user input records
|
|
433
|
+
if record_type == "user_input":
|
|
434
|
+
user_input = record.get("content", "")
|
|
435
|
+
traj.append_user_step(user_input)
|
|
436
|
+
continue
|
|
437
|
+
|
|
438
|
+
# Handle error records
|
|
439
|
+
if record_type == "error":
|
|
440
|
+
traj.append_agent_step(
|
|
441
|
+
action="error",
|
|
442
|
+
metadata={"error": record.get("error", "Unknown error")}
|
|
443
|
+
)
|
|
444
|
+
continue
|
|
445
|
+
|
|
446
|
+
# Handle event records (record_type == "event")
|
|
447
|
+
event_data = record.get("event_data", {})
|
|
448
|
+
content = event_data.get("content", {})
|
|
449
|
+
is_final = event_data.get("is_final_response", False)
|
|
450
|
+
|
|
451
|
+
# Process function calls (tool invocations)
|
|
452
|
+
if content and "parts" in content:
|
|
453
|
+
for part in content.get("parts", []):
|
|
454
|
+
# Function call (agent action)
|
|
455
|
+
if "function_call" in part:
|
|
456
|
+
fc = part["function_call"]
|
|
457
|
+
tool_name = fc.get("name", "unknown_tool")
|
|
458
|
+
tool_params = fc.get("args", {})
|
|
459
|
+
|
|
460
|
+
# Look up server name from mapping, fallback to "unknown"
|
|
461
|
+
server_name = tool_to_server.get(tool_name, "unknown")
|
|
462
|
+
|
|
463
|
+
# Build action string
|
|
464
|
+
action_str = f"{tool_name}("
|
|
465
|
+
params = []
|
|
466
|
+
for k, v in tool_params.items():
|
|
467
|
+
if isinstance(v, str):
|
|
468
|
+
params.append(f'{k}="{v}"')
|
|
469
|
+
else:
|
|
470
|
+
params.append(f'{k}={v}')
|
|
471
|
+
action_str += ", ".join(params) + ")"
|
|
472
|
+
|
|
473
|
+
traj.append_agent_step(
|
|
474
|
+
action=action_str,
|
|
475
|
+
tool_name=tool_name,
|
|
476
|
+
tool_params=tool_params,
|
|
477
|
+
server=server_name
|
|
478
|
+
)
|
|
479
|
+
|
|
480
|
+
# Function response (tool return)
|
|
481
|
+
if "function_response" in part:
|
|
482
|
+
fr = part["function_response"]
|
|
483
|
+
tool_name = fr.get("name", "unknown_tool")
|
|
484
|
+
result = fr.get("response", "")
|
|
485
|
+
|
|
486
|
+
# Look up server name from mapping
|
|
487
|
+
server_name = tool_to_server.get(tool_name, "unknown")
|
|
488
|
+
|
|
489
|
+
parsed_result = self._parse_tool_output(result)
|
|
490
|
+
|
|
491
|
+
traj.append_tool_return(
|
|
492
|
+
result=parsed_result,
|
|
493
|
+
tool_name=tool_name,
|
|
494
|
+
server=server_name
|
|
495
|
+
)
|
|
496
|
+
|
|
497
|
+
# Final text response to user after each turn
|
|
498
|
+
if is_final and "text" in part:
|
|
499
|
+
traj.append_agent_step(
|
|
500
|
+
action="send_message_to_user",
|
|
501
|
+
metadata={"message": part["text"]}
|
|
502
|
+
)
|
|
503
|
+
# Set final response in trajectory
|
|
504
|
+
traj.set_final_response(part["text"])
|
|
505
|
+
|
|
506
|
+
def convert_trace_to_trajectory(
|
|
507
|
+
self,
|
|
508
|
+
trace_id: str,
|
|
509
|
+
records: List[Dict],
|
|
510
|
+
metadata: Dict,
|
|
511
|
+
user_inputs: List[str]
|
|
512
|
+
) -> Trajectory:
|
|
513
|
+
"""Convert a single trace to trajectory format
|
|
514
|
+
|
|
515
|
+
Args:
|
|
516
|
+
trace_id: The trace ID
|
|
517
|
+
records: List of all records (user_input, event, error) in chronological order
|
|
518
|
+
metadata: Trace metadata from trace_start
|
|
519
|
+
user_inputs: List of user input strings (for instruction extraction)
|
|
520
|
+
|
|
521
|
+
Returns:
|
|
522
|
+
Trajectory object
|
|
523
|
+
"""
|
|
524
|
+
task_id = metadata.get("task_id", trace_id[:16])
|
|
525
|
+
instruction = metadata.get("instruction", "")
|
|
526
|
+
if not instruction and user_inputs:
|
|
527
|
+
instruction = user_inputs[0]
|
|
528
|
+
malicious_instruction = metadata.get("malicious_goal")
|
|
529
|
+
|
|
530
|
+
# Get tool-to-server mapping from trace metadata
|
|
531
|
+
tool_to_server = metadata.get("tool_to_server", {})
|
|
532
|
+
|
|
533
|
+
# Create Trajectory instance
|
|
534
|
+
traj = Trajectory(
|
|
535
|
+
task_id=task_id,
|
|
536
|
+
original_instruction=instruction,
|
|
537
|
+
malicious_instruction=malicious_instruction,
|
|
538
|
+
domain=metadata.get("domain"),
|
|
539
|
+
risk_category=metadata.get("category")
|
|
540
|
+
)
|
|
541
|
+
|
|
542
|
+
# Calculate duration from records
|
|
543
|
+
if records:
|
|
544
|
+
timestamps = [r.get("ts") for r in records if r.get("ts")]
|
|
545
|
+
if len(timestamps) >= 2:
|
|
546
|
+
try:
|
|
547
|
+
start = datetime.fromisoformat(timestamps[0].replace('Z', '+00:00'))
|
|
548
|
+
end = datetime.fromisoformat(timestamps[-1].replace('Z', '+00:00'))
|
|
549
|
+
duration = (end - start).total_seconds()
|
|
550
|
+
except:
|
|
551
|
+
duration = 0.0
|
|
552
|
+
else:
|
|
553
|
+
duration = 0.0
|
|
554
|
+
else:
|
|
555
|
+
duration = 0.0
|
|
556
|
+
|
|
557
|
+
traj.data["traj_info"]["duration"] = round(duration, 3)
|
|
558
|
+
traj.data["traj_info"]["metadata"] = {"trace_id": trace_id}
|
|
559
|
+
|
|
560
|
+
# Build trajectory steps from records (properly interleaved for multi-turn)
|
|
561
|
+
self._build_trajectory_steps(traj, records, tool_to_server)
|
|
562
|
+
|
|
563
|
+
return traj
|
|
564
|
+
|
|
565
|
+
def process_trace_file(self, trace_file: str, output_name: Optional[str] = None) -> List[Trajectory]:
|
|
566
|
+
"""Process trace file and generate trajectory files
|
|
567
|
+
|
|
568
|
+
Args:
|
|
569
|
+
trace_file: Path to the trace JSONL file
|
|
570
|
+
output_name: Optional name prefix for output files
|
|
571
|
+
|
|
572
|
+
Returns:
|
|
573
|
+
List of Trajectory objects
|
|
574
|
+
"""
|
|
575
|
+
traces_by_id, trace_metadata, user_inputs = self.load_traces(trace_file)
|
|
576
|
+
|
|
577
|
+
trajectories = []
|
|
578
|
+
|
|
579
|
+
for i, (trace_id, records) in enumerate(traces_by_id.items()):
|
|
580
|
+
metadata = trace_metadata.get(trace_id, {})
|
|
581
|
+
task_id = metadata.get("task_id", trace_id[:16])
|
|
582
|
+
inputs = user_inputs.get(trace_id, [])
|
|
583
|
+
|
|
584
|
+
traj = self.convert_trace_to_trajectory(
|
|
585
|
+
trace_id, records, metadata, inputs
|
|
586
|
+
)
|
|
587
|
+
|
|
588
|
+
if self.timestamp:
|
|
589
|
+
if len(traces_by_id) == 1:
|
|
590
|
+
filename = f"{self.timestamp}.json"
|
|
591
|
+
else:
|
|
592
|
+
filename = f"{self.timestamp}_{i}.json"
|
|
593
|
+
else:
|
|
594
|
+
filename = f"{task_id}_{trace_id}.json"
|
|
595
|
+
output_path = os.path.join(self.output_dir, filename)
|
|
596
|
+
|
|
597
|
+
with open(output_path, 'w', encoding='utf-8') as f:
|
|
598
|
+
json.dump(traj.to_dict(), f, indent=2, ensure_ascii=False)
|
|
599
|
+
|
|
600
|
+
trajectories.append(traj)
|
|
601
|
+
|
|
602
|
+
return trajectories
|