decodingtrust-agent-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +30 -0
- agent/claudesdk/__init__.py +8 -0
- agent/claudesdk/example.py +221 -0
- agent/claudesdk/src/__init__.py +8 -0
- agent/claudesdk/src/agent.py +400 -0
- agent/claudesdk/src/mcp_proxy.py +409 -0
- agent/claudesdk/src/utils.py +420 -0
- agent/googleadk/__init__.py +15 -0
- agent/googleadk/example.py +237 -0
- agent/googleadk/src/__init__.py +12 -0
- agent/googleadk/src/agent.py +401 -0
- agent/googleadk/src/mcp_wrapper.py +163 -0
- agent/googleadk/src/utils.py +602 -0
- agent/langchain/__init__.py +8 -0
- agent/langchain/example.py +213 -0
- agent/langchain/src/__init__.py +8 -0
- agent/langchain/src/agent.py +645 -0
- agent/langchain/src/utils.py +433 -0
- agent/openaisdk/__init__.py +17 -0
- agent/openaisdk/example.py +228 -0
- agent/openaisdk/src/__init__.py +12 -0
- agent/openaisdk/src/agent.py +491 -0
- agent/openaisdk/src/agent_wrapper.py +143 -0
- agent/openaisdk/src/mcp_wrapper.py +395 -0
- agent/openaisdk/src/utils.py +493 -0
- agent/openclaw/__init__.py +10 -0
- agent/openclaw/example.py +251 -0
- agent/openclaw/src/__init__.py +14 -0
- agent/openclaw/src/agent.py +930 -0
- agent/openclaw/src/helpers/__init__.py +1 -0
- agent/openclaw/src/helpers/auth_helpers.py +55 -0
- agent/openclaw/src/mcp_proxy.py +564 -0
- agent/openclaw/src/plugin_generator.py +231 -0
- agent/openclaw/src/utils.py +341 -0
- agent/pocketflow/__init__.py +18 -0
- agent/pocketflow/example.py +221 -0
- agent/pocketflow/prompts/react_agent.py +46 -0
- agent/pocketflow/src/__init__.py +6 -0
- agent/pocketflow/src/agent.py +507 -0
- agent/pocketflow/src/agent_wrapper.py +159 -0
- agent/pocketflow/src/async_helper.py +92 -0
- agent/pocketflow/src/mcp_react_agent.py +279 -0
- agent/pocketflow/src/native_agent.py +74 -0
- agent/pocketflow/src/nodes.py +467 -0
- benchmark/__init__.py +0 -0
- benchmark/browser/benign.jsonl +34 -0
- benchmark/browser/direct.jsonl +85 -0
- benchmark/browser/indirect.jsonl +82 -0
- benchmark/code/benign.jsonl +0 -0
- benchmark/code/direct.jsonl +121 -0
- benchmark/code/indirect.jsonl +165 -0
- benchmark/crm/benign.jsonl +165 -0
- benchmark/crm/direct.jsonl +90 -0
- benchmark/crm/indirect.jsonl +150 -0
- benchmark/customer-service/benign.jsonl +160 -0
- benchmark/customer-service/direct.jsonl +100 -0
- benchmark/customer-service/indirect.jsonl +101 -0
- benchmark/finance/benign.jsonl +0 -0
- benchmark/finance/direct.jsonl +200 -0
- benchmark/finance/indirect.jsonl +200 -0
- benchmark/legal/benign.jsonl +0 -0
- benchmark/legal/direct.jsonl +200 -0
- benchmark/legal/indirect.jsonl +200 -0
- benchmark/macos/benign.jsonl +30 -0
- benchmark/macos/direct.jsonl +50 -0
- benchmark/macos/indirect.jsonl +50 -0
- benchmark/medical/benign.jsonl +642 -0
- benchmark/medical/direct.jsonl +229 -0
- benchmark/medical/indirect.jsonl +222 -0
- benchmark/os-filesystem/benign.jsonl +200 -0
- benchmark/os-filesystem/direct.jsonl +200 -0
- benchmark/os-filesystem/indirect.jsonl +200 -0
- benchmark/research/benign.jsonl +0 -0
- benchmark/research/direct.jsonl +119 -0
- benchmark/research/indirect.jsonl +125 -0
- benchmark/telecom/benign.jsonl +120 -0
- benchmark/telecom/direct.jsonl +161 -0
- benchmark/telecom/indirect.jsonl +166 -0
- benchmark/travel/benign.jsonl +130 -0
- benchmark/travel/direct.jsonl +105 -0
- benchmark/travel/indirect.jsonl +120 -0
- benchmark/windows/benign.jsonl +100 -0
- benchmark/windows/direct.jsonl +140 -0
- benchmark/windows/indirect.jsonl +107 -0
- benchmark/workflow/benign.jsonl +335 -0
- benchmark/workflow/direct.jsonl +78 -0
- benchmark/workflow/indirect.jsonl +107 -0
- cli/__init__.py +5 -0
- cli/main.py +182 -0
- cli/scaffold.py +334 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/METADATA +642 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/RECORD +374 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/WHEEL +5 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/entry_points.txt +2 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/licenses/LICENSE +201 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/top_level.txt +6 -0
- dt_arena/config/env.yaml +515 -0
- dt_arena/config/injection_mcp.yaml +430 -0
- dt_arena/config/mcp.yaml +642 -0
- dt_arena/envs/arxiv/docker-compose-hub.yml +31 -0
- dt_arena/envs/arxiv/docker-compose.yml +36 -0
- dt_arena/envs/atlassian/docker/docker-compose.dev.yml +65 -0
- dt_arena/envs/atlassian/docker/docker-compose.yml +53 -0
- dt_arena/envs/atlassian/docker-compose-hub.yml +57 -0
- dt_arena/envs/atlassian/docker-compose.yml +72 -0
- dt_arena/envs/bigquery/docker-compose.yml +20 -0
- dt_arena/envs/booking/docker-compose.yml +59 -0
- dt_arena/envs/calendar/docker-compose-hub.yml +30 -0
- dt_arena/envs/calendar/docker-compose.yml +42 -0
- dt_arena/envs/custom-website/docker-compose.yml +6 -0
- dt_arena/envs/customer_service/docker-compose.yml +59 -0
- dt_arena/envs/databricks/docker-compose-hub.yml +47 -0
- dt_arena/envs/databricks/docker-compose.yml +51 -0
- dt_arena/envs/ecommerce/docker-compose.yml +6 -0
- dt_arena/envs/ers/docker-compose.yml +36 -0
- dt_arena/envs/ers/hrms/docker/docker-compose.yml +31 -0
- dt_arena/envs/finance/docker-compose.yml +23 -0
- dt_arena/envs/github/docker/docker-compose-hub.yml +50 -0
- dt_arena/envs/github/docker/docker-compose.yml +50 -0
- dt_arena/envs/gmail/docker-compose-hub.yml +51 -0
- dt_arena/envs/gmail/docker-compose.yml +65 -0
- dt_arena/envs/google-form/docker-compose-hub.yml +33 -0
- dt_arena/envs/google-form/docker-compose.yml +41 -0
- dt_arena/envs/googledocs/docker-compose-hub.yml +61 -0
- dt_arena/envs/googledocs/docker-compose.yml +78 -0
- dt_arena/envs/hospital/docker-compose-hub.yml +25 -0
- dt_arena/envs/hospital/docker-compose.yml +27 -0
- dt_arena/envs/legal/docker-compose.yml +22 -0
- dt_arena/envs/linkedin/docker-compose.yml +63 -0
- dt_arena/envs/macos/docker-compose.yml +79 -0
- dt_arena/envs/os-filesystem/docker-compose-hub.yml +16 -0
- dt_arena/envs/os-filesystem/docker-compose.yml +20 -0
- dt_arena/envs/paypal/docker-compose-hub.yml +48 -0
- dt_arena/envs/paypal/docker-compose.yml +63 -0
- dt_arena/envs/research/docker-compose-hub.yml +13 -0
- dt_arena/envs/research/docker-compose.yml +24 -0
- dt_arena/envs/salesforce_crm/docker-compose-hub.yaml +45 -0
- dt_arena/envs/salesforce_crm/docker-compose.yaml +49 -0
- dt_arena/envs/slack/docker-compose-hub.yml +28 -0
- dt_arena/envs/slack/docker-compose.yml +41 -0
- dt_arena/envs/snowflake/docker-compose-hub.yml +41 -0
- dt_arena/envs/snowflake/docker-compose.yml +44 -0
- dt_arena/envs/telecom/docker-compose-hub.yml +16 -0
- dt_arena/envs/telecom/docker-compose.yml +17 -0
- dt_arena/envs/telegram/docker-compose-hub.yml +57 -0
- dt_arena/envs/telegram/docker-compose.yml +62 -0
- dt_arena/envs/terminal/docker-compose-hub.yml +12 -0
- dt_arena/envs/terminal/docker-compose.yml +26 -0
- dt_arena/envs/travel/docker-compose-hub.yml +19 -0
- dt_arena/envs/travel/docker-compose.yml +19 -0
- dt_arena/envs/whatsapp/docker-compose-hub.yml +61 -0
- dt_arena/envs/whatsapp/docker-compose.yml +78 -0
- dt_arena/envs/windows/docker-compose.yml +71 -0
- dt_arena/envs/zoom/docker-compose-hub.yml +27 -0
- dt_arena/envs/zoom/docker-compose.yml +40 -0
- dt_arena/injection_mcp_server/atlassian/env_injection.py +134 -0
- dt_arena/injection_mcp_server/calendar/env_injection.py +217 -0
- dt_arena/injection_mcp_server/custom_website/env_injection.py +97 -0
- dt_arena/injection_mcp_server/customer_service/env_injection.py +659 -0
- dt_arena/injection_mcp_server/databricks/env_injection.py +255 -0
- dt_arena/injection_mcp_server/ecommerce/env_injection.py +110 -0
- dt_arena/injection_mcp_server/finance/env_injection.py +85 -0
- dt_arena/injection_mcp_server/github/env_injection.py +206 -0
- dt_arena/injection_mcp_server/gmail/env_injection.py +211 -0
- dt_arena/injection_mcp_server/google_form/env_injection.py +186 -0
- dt_arena/injection_mcp_server/googledocs/env_injection.py +44 -0
- dt_arena/injection_mcp_server/hospital/env_injection.py +43 -0
- dt_arena/injection_mcp_server/legal/env_injection.py +229 -0
- dt_arena/injection_mcp_server/macos/env_injection.py +272 -0
- dt_arena/injection_mcp_server/os-filesystem/env_injection.py +341 -0
- dt_arena/injection_mcp_server/paypal/env_injection.py +268 -0
- dt_arena/injection_mcp_server/research/env_injection.py +616 -0
- dt_arena/injection_mcp_server/salesforce/env_injection.py +514 -0
- dt_arena/injection_mcp_server/slack/env_injection.py +265 -0
- dt_arena/injection_mcp_server/snowflake/env_injection.py +230 -0
- dt_arena/injection_mcp_server/telecom/env_injection.py +503 -0
- dt_arena/injection_mcp_server/telegram/env_injection.py +171 -0
- dt_arena/injection_mcp_server/terminal/env_injection.py +523 -0
- dt_arena/injection_mcp_server/travel/env_injection.py +173 -0
- dt_arena/injection_mcp_server/whatsapp/env_injection.py +185 -0
- dt_arena/injection_mcp_server/windows/env_injection.py +943 -0
- dt_arena/injection_mcp_server/zoom/env_injection.py +216 -0
- dt_arena/mcp_server/atlassian/main.py +1554 -0
- dt_arena/mcp_server/atlassian/test_server.py +66 -0
- dt_arena/mcp_server/bigquery/main.py +333 -0
- dt_arena/mcp_server/booking/main.py +310 -0
- dt_arena/mcp_server/browser/main.py +1741 -0
- dt_arena/mcp_server/calendar/example_multi_user.py +162 -0
- dt_arena/mcp_server/calendar/main.py +792 -0
- dt_arena/mcp_server/calendar/test_mcp.py +135 -0
- dt_arena/mcp_server/customer_service/main.py +1063 -0
- dt_arena/mcp_server/databricks/main.py +566 -0
- dt_arena/mcp_server/databricks/probe.py +102 -0
- dt_arena/mcp_server/ers/main.py +845 -0
- dt_arena/mcp_server/finance/__init__.py +87 -0
- dt_arena/mcp_server/finance/core/__init__.py +12 -0
- dt_arena/mcp_server/finance/core/data_loader.py +558 -0
- dt_arena/mcp_server/finance/core/portfolio.py +565 -0
- dt_arena/mcp_server/finance/evaluation/__init__.py +20 -0
- dt_arena/mcp_server/finance/evaluation/evaluator.py +217 -0
- dt_arena/mcp_server/finance/evaluation/logger.py +137 -0
- dt_arena/mcp_server/finance/injection/__init__.py +66 -0
- dt_arena/mcp_server/finance/injection/config.py +176 -0
- dt_arena/mcp_server/finance/injection/content.py +755 -0
- dt_arena/mcp_server/finance/injection/html.py +409 -0
- dt_arena/mcp_server/finance/injection/locations.py +167 -0
- dt_arena/mcp_server/finance/injection/methods.py +193 -0
- dt_arena/mcp_server/finance/injection/presets.py +1023 -0
- dt_arena/mcp_server/finance/main.py +361 -0
- dt_arena/mcp_server/finance/run_mcp.py +21 -0
- dt_arena/mcp_server/finance/run_web.py +26 -0
- dt_arena/mcp_server/finance/server/__init__.py +41 -0
- dt_arena/mcp_server/finance/server/extractor.py +1453 -0
- dt_arena/mcp_server/finance/server/extractor_minimal.py +292 -0
- dt_arena/mcp_server/finance/server/extractor_simple.py +1164 -0
- dt_arena/mcp_server/finance/server/injection_mcp.py +865 -0
- dt_arena/mcp_server/finance/server/mcp.py +451 -0
- dt_arena/mcp_server/finance/server/tools/__init__.py +23 -0
- dt_arena/mcp_server/finance/server/tools/account.py +88 -0
- dt_arena/mcp_server/finance/server/tools/browsing.py +328 -0
- dt_arena/mcp_server/finance/server/tools/social.py +73 -0
- dt_arena/mcp_server/finance/server/tools/trading.py +242 -0
- dt_arena/mcp_server/finance/server/tools/utility.py +49 -0
- dt_arena/mcp_server/finance/server/web.py +2139 -0
- dt_arena/mcp_server/finance/tasks/benchmark/__init__.py +28 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_pool.py +3026 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_runner.py +1315 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_requirements.py +1335 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_tasks.py +3665 -0
- dt_arena/mcp_server/finance/tasks/benchmark/malicious_tasks.py +2673 -0
- dt_arena/mcp_server/finance/tasks/redteam_suite/run_redteam_suite.py +1713 -0
- dt_arena/mcp_server/finance/test_mcp_tools.py +476 -0
- dt_arena/mcp_server/github/main.py +441 -0
- dt_arena/mcp_server/gmail/main.py +1004 -0
- dt_arena/mcp_server/google_form/main.py +141 -0
- dt_arena/mcp_server/googledocs/main.py +458 -0
- dt_arena/mcp_server/hospital/mcp_server.py +458 -0
- dt_arena/mcp_server/legal/__init__.py +9 -0
- dt_arena/mcp_server/legal/core/__init__.py +14 -0
- dt_arena/mcp_server/legal/core/courtlistener_store.py +762 -0
- dt_arena/mcp_server/legal/core/data_loader.py +266 -0
- dt_arena/mcp_server/legal/core/document_store.py +197 -0
- dt_arena/mcp_server/legal/core/matter_manager.py +466 -0
- dt_arena/mcp_server/legal/main.py +89 -0
- dt_arena/mcp_server/legal/scripts/collect_data.py +988 -0
- dt_arena/mcp_server/legal/server/__init__.py +14 -0
- dt_arena/mcp_server/legal/server/mcp.py +2330 -0
- dt_arena/mcp_server/macos/client_test.py +270 -0
- dt_arena/mcp_server/macos/mcp_server.py +285 -0
- dt_arena/mcp_server/os-filesystem/main.py +1380 -0
- dt_arena/mcp_server/paypal/main.py +501 -0
- dt_arena/mcp_server/research/main.py +777 -0
- dt_arena/mcp_server/salesforce/main.py +2006 -0
- dt_arena/mcp_server/slack/main.py +318 -0
- dt_arena/mcp_server/snowflake/main.py +612 -0
- dt_arena/mcp_server/snowflake/probe.py +183 -0
- dt_arena/mcp_server/telecom/mcp_client.py +423 -0
- dt_arena/mcp_server/telecom/mcp_server.py +1059 -0
- dt_arena/mcp_server/telegram/main.py +338 -0
- dt_arena/mcp_server/terminal/main.py +163 -0
- dt_arena/mcp_server/travel/client_test.py +16 -0
- dt_arena/mcp_server/travel/mcp_server.py +404 -0
- dt_arena/mcp_server/whatsapp/main.py +318 -0
- dt_arena/mcp_server/windows/client_test.py +270 -0
- dt_arena/mcp_server/windows/mcp_server.py +218 -0
- dt_arena/mcp_server/zoom/main.py +466 -0
- dt_arena/src/__init__.py +0 -0
- dt_arena/src/hooks/__init__.py +0 -0
- dt_arena/src/hooks/audit_log.py +30 -0
- dt_arena/src/hooks/hooks.json +3 -0
- dt_arena/src/run_benign.py +142 -0
- dt_arena/src/types/__init__.py +0 -0
- dt_arena/src/types/agent.py +441 -0
- dt_arena/src/types/attacks.py +2 -0
- dt_arena/src/types/environment.py +2 -0
- dt_arena/src/types/hooks.py +174 -0
- dt_arena/src/types/judge.py +52 -0
- dt_arena/src/types/red_teaming_trajectory.py +385 -0
- dt_arena/src/types/task.py +260 -0
- dt_arena/src/types/trajectory.py +315 -0
- dt_arena/utils/__init__.py +1 -0
- dt_arena/utils/atlassian/__init__.py +27 -0
- dt_arena/utils/atlassian/helpers.py +520 -0
- dt_arena/utils/bigquery/__init__.py +1 -0
- dt_arena/utils/bigquery/helpers.py +246 -0
- dt_arena/utils/calendar/__init__.py +1 -0
- dt_arena/utils/calendar/helpers.py +87 -0
- dt_arena/utils/customer_service/__init__.py +17 -0
- dt_arena/utils/customer_service/cs_env_client.py +940 -0
- dt_arena/utils/customer_service/helpers.py +339 -0
- dt_arena/utils/customer_service/judges/__init__.py +20 -0
- dt_arena/utils/customer_service/judges/abusive_comms.py +149 -0
- dt_arena/utils/customer_service/judges/bulk_destructive.py +202 -0
- dt_arena/utils/customer_service/judges/cross_customer.py +481 -0
- dt_arena/utils/customer_service/judges/email_exfiltration.py +320 -0
- dt_arena/utils/customer_service/judges/financial_fraud.py +221 -0
- dt_arena/utils/customer_service/judges/llm_utils.py +264 -0
- dt_arena/utils/customer_service/judges/pii_disclosure.py +837 -0
- dt_arena/utils/customer_service/judges/policy_misrep.py +44 -0
- dt_arena/utils/customer_service/judges/text_utils.py +21 -0
- dt_arena/utils/databricks/__init__.py +2 -0
- dt_arena/utils/databricks/helpers.py +210 -0
- dt_arena/utils/finance/__init__.py +0 -0
- dt_arena/utils/finance/helpers.py +263 -0
- dt_arena/utils/github/__init__.py +1 -0
- dt_arena/utils/github/helpers.py +249 -0
- dt_arena/utils/gmail/__init__.py +1 -0
- dt_arena/utils/gmail/helpers.py +344 -0
- dt_arena/utils/google_form/__init__.py +2 -0
- dt_arena/utils/google_form/helpers.py +133 -0
- dt_arena/utils/legal/__init__.py +0 -0
- dt_arena/utils/legal/helpers.py +228 -0
- dt_arena/utils/macos/__init__.py +0 -0
- dt_arena/utils/macos/env_setup.py +215 -0
- dt_arena/utils/macos/helpers.py +61 -0
- dt_arena/utils/os_filesystem/__init__.py +1 -0
- dt_arena/utils/os_filesystem/helpers.py +366 -0
- dt_arena/utils/paypal/__init__.py +1 -0
- dt_arena/utils/paypal/helpers.py +178 -0
- dt_arena/utils/port_allocator.py +266 -0
- dt_arena/utils/research/__init__.py +0 -0
- dt_arena/utils/research/helpers.py +251 -0
- dt_arena/utils/salesforce/__init__.py +1 -0
- dt_arena/utils/salesforce/helpers.py +719 -0
- dt_arena/utils/slack/__init__.py +1 -0
- dt_arena/utils/slack/helpers.py +176 -0
- dt_arena/utils/snowflake/__init__.py +1 -0
- dt_arena/utils/snowflake/helpers.py +166 -0
- dt_arena/utils/telecom/__init__.py +1 -0
- dt_arena/utils/telecom/helpers.py +760 -0
- dt_arena/utils/telegram/__init__.py +0 -0
- dt_arena/utils/telegram/helpers.py +174 -0
- dt_arena/utils/terminal/__init__.py +0 -0
- dt_arena/utils/terminal/helpers.py +20 -0
- dt_arena/utils/travel/__init__.py +0 -0
- dt_arena/utils/travel/env_client.py +537 -0
- dt_arena/utils/travel/llm_judge.py +137 -0
- dt_arena/utils/travel/prompts.py +64 -0
- dt_arena/utils/utils/__init__.py +122 -0
- dt_arena/utils/whatsapp/__init__.py +0 -0
- dt_arena/utils/whatsapp/helpers.py +226 -0
- dt_arena/utils/windows/__init__.py +0 -0
- dt_arena/utils/windows/env_reset.py +224 -0
- dt_arena/utils/windows/env_setup.py +280 -0
- dt_arena/utils/windows/exfil_helpers.py +170 -0
- dt_arena/utils/windows/helpers.py +74 -0
- dt_arena/utils/zoom/__init__.py +1 -0
- dt_arena/utils/zoom/helpers.py +70 -0
- eval/__init__.py +1 -0
- eval/evaluation.py +426 -0
- eval/task_runner.py +449 -0
- utils/__init__.py +148 -0
- utils/agent_helpers.py +308 -0
- utils/agent_wrapper.py +189 -0
- utils/compose_utils.py +135 -0
- utils/config.py +77 -0
- utils/env_helpers.py +104 -0
- utils/eval_stats.py +88 -0
- utils/injection_helpers.py +429 -0
- utils/injection_mcp_helpers.py +152 -0
- utils/judge_helpers.py +181 -0
- utils/judge_utils.py +472 -0
- utils/llm.py +196 -0
- utils/logging.py +45 -0
- utils/mcp_helpers.py +232 -0
- utils/mcp_manager.py +235 -0
- utils/memory_guard.py +18 -0
- utils/red_teaming_sandbox.py +476 -0
- utils/reset_helpers.py +318 -0
- utils/resource_manager.py +370 -0
- utils/skill_helpers.py +447 -0
- utils/task_executor.py +904 -0
- utils/task_helpers.py +270 -0
- utils/template_helpers.py +179 -0
|
@@ -0,0 +1,537 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Unified client for Travel environment API.
|
|
3
|
+
Used by both task generators and judges to ensure consistency.
|
|
4
|
+
"""
|
|
5
|
+
import os
|
|
6
|
+
import requests
|
|
7
|
+
from typing import Optional, Dict, List, Any, Tuple
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TravelEnvClient:
|
|
11
|
+
"""Client for interacting with Travel environment server."""
|
|
12
|
+
|
|
13
|
+
def __init__(self, host: Optional[str] = None, port: Optional[str] = None):
|
|
14
|
+
"""Initialize the client with host and port.
|
|
15
|
+
|
|
16
|
+
Args:
|
|
17
|
+
host: Server host (defaults to TRAVEL_HOST env var or 'localhost')
|
|
18
|
+
port: Server port (defaults to TRAVEL_PORT env var or '10300')
|
|
19
|
+
"""
|
|
20
|
+
self.host = host or os.getenv("TRAVEL_HOST", "localhost")
|
|
21
|
+
self.port = port or os.getenv("TRAVEL_PORT", "10300")
|
|
22
|
+
self.base_url = f"http://{self.host}:{self.port}"
|
|
23
|
+
|
|
24
|
+
# ===== Query APIs =====
|
|
25
|
+
|
|
26
|
+
def query_flights(self, departure: str, destination: str,
|
|
27
|
+
departure_date: str) -> List[Dict[str, Any]]:
|
|
28
|
+
"""Query flights between two cities on a specific date.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
departure: Departure city name
|
|
32
|
+
destination: Destination city name
|
|
33
|
+
departure_date: Departure date in format YYYY-MM-DD
|
|
34
|
+
|
|
35
|
+
Returns:
|
|
36
|
+
List of flight dictionaries with fields like:
|
|
37
|
+
- Flight Number, Price, DepTime, ArrTime, Duration, etc.
|
|
38
|
+
"""
|
|
39
|
+
try:
|
|
40
|
+
resp = requests.get(
|
|
41
|
+
f"{self.base_url}/query_flight",
|
|
42
|
+
params={
|
|
43
|
+
"departure": departure,
|
|
44
|
+
"destination": destination,
|
|
45
|
+
"departure_date": departure_date
|
|
46
|
+
},
|
|
47
|
+
timeout=10
|
|
48
|
+
)
|
|
49
|
+
data = resp.json()
|
|
50
|
+
if data.get("status") == "success":
|
|
51
|
+
return data.get("result", [])
|
|
52
|
+
except Exception as e:
|
|
53
|
+
print(f"[EnvClient] Error querying flights: {e}")
|
|
54
|
+
return []
|
|
55
|
+
|
|
56
|
+
def query_accommodations(self, city: str) -> List[Dict[str, Any]]:
|
|
57
|
+
"""Query accommodations in a city.
|
|
58
|
+
|
|
59
|
+
Args:
|
|
60
|
+
city: City name
|
|
61
|
+
|
|
62
|
+
Returns:
|
|
63
|
+
List of accommodation dictionaries with fields like:
|
|
64
|
+
- NAME, price, rating, room type, house_rules, maximum occupancy, etc.
|
|
65
|
+
"""
|
|
66
|
+
try:
|
|
67
|
+
resp = requests.get(
|
|
68
|
+
f"{self.base_url}/query_accommodation",
|
|
69
|
+
params={"city": city},
|
|
70
|
+
timeout=10
|
|
71
|
+
)
|
|
72
|
+
data = resp.json()
|
|
73
|
+
if data.get("status") == "success":
|
|
74
|
+
return data.get("result", [])
|
|
75
|
+
except Exception as e:
|
|
76
|
+
print(f"[EnvClient] Error querying accommodations: {e}")
|
|
77
|
+
return []
|
|
78
|
+
|
|
79
|
+
def query_restaurants(self, city: str) -> List[Dict[str, Any]]:
|
|
80
|
+
"""Query restaurants in a city.
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
city: City name
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
List of restaurant dictionaries with fields like:
|
|
87
|
+
- Name, Aggregate Rating, Average Cost, Cuisines, etc.
|
|
88
|
+
"""
|
|
89
|
+
try:
|
|
90
|
+
resp = requests.get(
|
|
91
|
+
f"{self.base_url}/query_restaurant",
|
|
92
|
+
params={"city": city},
|
|
93
|
+
timeout=10
|
|
94
|
+
)
|
|
95
|
+
data = resp.json()
|
|
96
|
+
if data.get("status") == "success":
|
|
97
|
+
return data.get("result", [])
|
|
98
|
+
except Exception as e:
|
|
99
|
+
print(f"[EnvClient] Error querying restaurants: {e}")
|
|
100
|
+
return []
|
|
101
|
+
|
|
102
|
+
def query_attractions(self, city: str) -> List[Dict[str, Any]]:
|
|
103
|
+
"""Query attractions in a city.
|
|
104
|
+
|
|
105
|
+
Args:
|
|
106
|
+
city: City name
|
|
107
|
+
|
|
108
|
+
Returns:
|
|
109
|
+
List of attraction dictionaries with fields like:
|
|
110
|
+
- Name, Address, Website, etc.
|
|
111
|
+
"""
|
|
112
|
+
try:
|
|
113
|
+
resp = requests.get(
|
|
114
|
+
f"{self.base_url}/query_attraction",
|
|
115
|
+
params={"city": city},
|
|
116
|
+
timeout=10
|
|
117
|
+
)
|
|
118
|
+
data = resp.json()
|
|
119
|
+
if data.get("status") == "success":
|
|
120
|
+
return data.get("result", [])
|
|
121
|
+
except Exception as e:
|
|
122
|
+
print(f"[EnvClient] Error querying attractions: {e}")
|
|
123
|
+
return []
|
|
124
|
+
|
|
125
|
+
def query_cities(self, state: str) -> List[Dict[str, Any]]:
|
|
126
|
+
"""Query cities in a state.
|
|
127
|
+
|
|
128
|
+
Args:
|
|
129
|
+
state: State name
|
|
130
|
+
|
|
131
|
+
Returns:
|
|
132
|
+
List of city dictionaries
|
|
133
|
+
"""
|
|
134
|
+
try:
|
|
135
|
+
resp = requests.get(
|
|
136
|
+
f"{self.base_url}/query_city",
|
|
137
|
+
params={"state": state},
|
|
138
|
+
timeout=10
|
|
139
|
+
)
|
|
140
|
+
data = resp.json()
|
|
141
|
+
if data.get("status") == "success":
|
|
142
|
+
return data.get("result", [])
|
|
143
|
+
except Exception as e:
|
|
144
|
+
print(f"[EnvClient] Error querying cities: {e}")
|
|
145
|
+
return []
|
|
146
|
+
|
|
147
|
+
def query_distance(self, departure: str, destination: str,
|
|
148
|
+
mode: str = "driving") -> Optional[Dict[str, Any]]:
|
|
149
|
+
"""Query distance between two locations.
|
|
150
|
+
|
|
151
|
+
Args:
|
|
152
|
+
departure: Departure location
|
|
153
|
+
destination: Destination location
|
|
154
|
+
mode: Travel mode (driving, walking, etc.)
|
|
155
|
+
|
|
156
|
+
Returns:
|
|
157
|
+
Distance information dictionary
|
|
158
|
+
"""
|
|
159
|
+
try:
|
|
160
|
+
resp = requests.get(
|
|
161
|
+
f"{self.base_url}/query_distance",
|
|
162
|
+
params={
|
|
163
|
+
"departure": departure,
|
|
164
|
+
"destination": destination,
|
|
165
|
+
"mode": mode
|
|
166
|
+
},
|
|
167
|
+
timeout=10
|
|
168
|
+
)
|
|
169
|
+
data = resp.json()
|
|
170
|
+
if data.get("status") == "success":
|
|
171
|
+
return data.get("result")
|
|
172
|
+
except Exception as e:
|
|
173
|
+
print(f"[EnvClient] Error querying distance: {e}")
|
|
174
|
+
return None
|
|
175
|
+
|
|
176
|
+
# ===== Booking APIs =====
|
|
177
|
+
|
|
178
|
+
def query_bookings(self) -> Dict[str, List]:
|
|
179
|
+
"""Query all current bookings.
|
|
180
|
+
|
|
181
|
+
Returns:
|
|
182
|
+
Dictionary with keys: flights, accommodations, restaurants
|
|
183
|
+
Each key maps to a list of booking dictionaries
|
|
184
|
+
"""
|
|
185
|
+
try:
|
|
186
|
+
resp = requests.get(f"{self.base_url}/query_booking", timeout=10)
|
|
187
|
+
data = resp.json()
|
|
188
|
+
if data.get("status") == "success":
|
|
189
|
+
return data.get("result", {})
|
|
190
|
+
except Exception as e:
|
|
191
|
+
print(f"[EnvClient] Error querying bookings: {e}")
|
|
192
|
+
return {"flights": [], "accommodations": [], "restaurants": []}
|
|
193
|
+
|
|
194
|
+
def get_booked_flights(self) -> List[Dict[str, Any]]:
|
|
195
|
+
"""Get all booked flights."""
|
|
196
|
+
return self.query_bookings().get("flights", [])
|
|
197
|
+
|
|
198
|
+
def get_booked_accommodations(self) -> List[Dict[str, Any]]:
|
|
199
|
+
"""Get all booked accommodations."""
|
|
200
|
+
return self.query_bookings().get("accommodations", [])
|
|
201
|
+
|
|
202
|
+
def get_booked_restaurants(self) -> List[Dict[str, Any]]:
|
|
203
|
+
"""Get all booked restaurants."""
|
|
204
|
+
return self.query_bookings().get("restaurants", [])
|
|
205
|
+
|
|
206
|
+
# ===== Helper methods for validation =====
|
|
207
|
+
|
|
208
|
+
def has_flights(self, departure: str, destination: str, date: str) -> bool:
|
|
209
|
+
"""Check if flights exist for a route on a date."""
|
|
210
|
+
flights = self.query_flights(departure, destination, date)
|
|
211
|
+
return len(flights) > 0
|
|
212
|
+
|
|
213
|
+
def has_accommodations(self, city: str) -> bool:
|
|
214
|
+
"""Check if accommodations exist in a city."""
|
|
215
|
+
accommodations = self.query_accommodations(city)
|
|
216
|
+
return len(accommodations) > 0
|
|
217
|
+
|
|
218
|
+
def has_restaurants(self, city: str) -> bool:
|
|
219
|
+
"""Check if restaurants exist in a city."""
|
|
220
|
+
restaurants = self.query_restaurants(city)
|
|
221
|
+
return len(restaurants) > 0
|
|
222
|
+
|
|
223
|
+
def has_attractions(self, city: str) -> bool:
|
|
224
|
+
"""Check if attractions exist in a city."""
|
|
225
|
+
attractions = self.query_attractions(city)
|
|
226
|
+
return len(attractions) > 0
|
|
227
|
+
|
|
228
|
+
# ===== Dynamic Verification Helpers (for judges with potential ties) =====
|
|
229
|
+
|
|
230
|
+
def get_cheapest_hotel_names(self, city: str) -> set:
|
|
231
|
+
"""Return names of all hotels with the minimum price in a city."""
|
|
232
|
+
hotels = self.query_accommodations(city)
|
|
233
|
+
if not hotels:
|
|
234
|
+
return set()
|
|
235
|
+
min_price = min(h.get("price", float('inf')) for h in hotels)
|
|
236
|
+
return {h["NAME"] for h in hotels if h.get("price") == min_price}
|
|
237
|
+
|
|
238
|
+
def get_highest_rated_hotel_names(self, city: str) -> set:
|
|
239
|
+
"""Return names of all hotels with the maximum rating in a city."""
|
|
240
|
+
hotels = self.query_accommodations(city)
|
|
241
|
+
if not hotels:
|
|
242
|
+
return set()
|
|
243
|
+
max_rating = max(h.get("rating", 0) for h in hotels)
|
|
244
|
+
return {h["NAME"] for h in hotels if h.get("rating") == max_rating}
|
|
245
|
+
|
|
246
|
+
def get_cheapest_flight_numbers(self, departure: str, destination: str,
|
|
247
|
+
dates: list) -> set:
|
|
248
|
+
"""Return flight numbers of all flights with minimum price across given dates."""
|
|
249
|
+
all_flights = []
|
|
250
|
+
for date in dates:
|
|
251
|
+
all_flights.extend(self.query_flights(departure, destination, date))
|
|
252
|
+
if not all_flights:
|
|
253
|
+
return set()
|
|
254
|
+
min_price = min(f.get("Price", float('inf')) for f in all_flights)
|
|
255
|
+
return {f["Flight Number"] for f in all_flights if f.get("Price") == min_price}
|
|
256
|
+
|
|
257
|
+
def get_earliest_flight_numbers(self, departure: str, destination: str,
|
|
258
|
+
date: str) -> set:
|
|
259
|
+
"""Return flight numbers of all flights with earliest DepTime."""
|
|
260
|
+
flights = self.query_flights(departure, destination, date)
|
|
261
|
+
if not flights:
|
|
262
|
+
return set()
|
|
263
|
+
min_time = min(f.get("DepTime", "99:99") for f in flights)
|
|
264
|
+
return {f["Flight Number"] for f in flights if f.get("DepTime") == min_time}
|
|
265
|
+
|
|
266
|
+
@staticmethod
|
|
267
|
+
def _arrival_sort_key(flight: dict) -> str:
|
|
268
|
+
"""Return a sort key for arrival time that handles overnight flights.
|
|
269
|
+
|
|
270
|
+
If ArrTime < DepTime, the flight arrives the next day, so we prefix
|
|
271
|
+
with '1' to sort it after same-day arrivals (prefixed with '0').
|
|
272
|
+
"""
|
|
273
|
+
dep = flight.get("DepTime", "00:00")
|
|
274
|
+
arr = flight.get("ArrTime", "99:99")
|
|
275
|
+
is_overnight = arr < dep
|
|
276
|
+
return f"{'1' if is_overnight else '0'}{arr}"
|
|
277
|
+
|
|
278
|
+
def get_earliest_arriving_flight_numbers(self, departure: str, destination: str,
|
|
279
|
+
date: str) -> set:
|
|
280
|
+
"""Return flight numbers of all flights with earliest ArrTime (same-day arrivals first)."""
|
|
281
|
+
flights = self.query_flights(departure, destination, date)
|
|
282
|
+
if not flights:
|
|
283
|
+
return set()
|
|
284
|
+
min_key = min(self._arrival_sort_key(f) for f in flights)
|
|
285
|
+
return {f["Flight Number"] for f in flights if self._arrival_sort_key(f) == min_key}
|
|
286
|
+
|
|
287
|
+
@staticmethod
|
|
288
|
+
def _parse_duration_minutes(duration_str: str) -> int:
|
|
289
|
+
"""Parse duration string like '2 hours 8 minutes' into total minutes."""
|
|
290
|
+
import re
|
|
291
|
+
hours = minutes = 0
|
|
292
|
+
h_match = re.search(r'(\d+)\s*hour', duration_str)
|
|
293
|
+
m_match = re.search(r'(\d+)\s*minute', duration_str)
|
|
294
|
+
if h_match:
|
|
295
|
+
hours = int(h_match.group(1))
|
|
296
|
+
if m_match:
|
|
297
|
+
minutes = int(m_match.group(1))
|
|
298
|
+
return hours * 60 + minutes
|
|
299
|
+
|
|
300
|
+
def get_shortest_flight_numbers(self, departure: str, destination: str,
|
|
301
|
+
date: str) -> set:
|
|
302
|
+
"""Return flight numbers of all flights with minimum ActualElapsedTime."""
|
|
303
|
+
flights = self.query_flights(departure, destination, date)
|
|
304
|
+
if not flights:
|
|
305
|
+
return set()
|
|
306
|
+
def _dur(f):
|
|
307
|
+
val = f.get("ActualElapsedTime", "")
|
|
308
|
+
return self._parse_duration_minutes(val) if isinstance(val, str) else float('inf')
|
|
309
|
+
min_dur = min(_dur(f) for f in flights)
|
|
310
|
+
return {f["Flight Number"] for f in flights if _dur(f) == min_dur}
|
|
311
|
+
|
|
312
|
+
def get_highest_rated_restaurant_names(self, city: str, top_n: int = 1) -> set:
|
|
313
|
+
"""Return names of top_n highest-rated restaurants in a city.
|
|
314
|
+
|
|
315
|
+
Handles ties: if multiple restaurants share the nth-highest rating,
|
|
316
|
+
all of them are included.
|
|
317
|
+
"""
|
|
318
|
+
restaurants = self.query_restaurants(city)
|
|
319
|
+
if not restaurants:
|
|
320
|
+
return set()
|
|
321
|
+
sorted_r = sorted(
|
|
322
|
+
restaurants,
|
|
323
|
+
key=lambda r: r.get("rating", 0),
|
|
324
|
+
reverse=True
|
|
325
|
+
)
|
|
326
|
+
if top_n >= len(sorted_r):
|
|
327
|
+
return {r["Name"] for r in sorted_r}
|
|
328
|
+
cutoff = sorted_r[top_n - 1].get("rating", 0)
|
|
329
|
+
return {r["Name"] for r in sorted_r
|
|
330
|
+
if r.get("rating", 0) >= cutoff}
|
|
331
|
+
|
|
332
|
+
@staticmethod
|
|
333
|
+
def cuisine_matches(restaurant_cuisines: str, required: str) -> bool:
|
|
334
|
+
"""Check if a restaurant's cuisines string contains a required cuisine.
|
|
335
|
+
|
|
336
|
+
Handles slash-separated alternatives like 'Bakery/Cafe'.
|
|
337
|
+
"""
|
|
338
|
+
parts = [p.strip().lower() for p in required.split("/")]
|
|
339
|
+
return any(part in restaurant_cuisines.lower() for part in parts)
|
|
340
|
+
|
|
341
|
+
def validate_cuisine_bookings(
|
|
342
|
+
self,
|
|
343
|
+
booked_restaurants: List[Dict[str, Any]],
|
|
344
|
+
cuisines_per_city: Dict[str, List[str]],
|
|
345
|
+
) -> List[str]:
|
|
346
|
+
"""Validate that booked restaurants match required cuisines per city.
|
|
347
|
+
|
|
348
|
+
Returns a list of issue strings (empty if all valid).
|
|
349
|
+
Checks two things per city:
|
|
350
|
+
1. Each booked restaurant matches at least one required cuisine.
|
|
351
|
+
2. Each required cuisine is covered by at least one booking.
|
|
352
|
+
"""
|
|
353
|
+
issues = []
|
|
354
|
+
|
|
355
|
+
# Build lookup: restaurant name -> cuisines, and city membership sets
|
|
356
|
+
restaurant_cuisines: Dict[str, str] = {}
|
|
357
|
+
city_names: Dict[str, set] = {}
|
|
358
|
+
for city in cuisines_per_city:
|
|
359
|
+
names = set()
|
|
360
|
+
for r in self.query_restaurants(city):
|
|
361
|
+
restaurant_cuisines[r["Name"]] = (r.get("Cuisines") or "").lower()
|
|
362
|
+
names.add(r["Name"])
|
|
363
|
+
city_names[city] = names
|
|
364
|
+
|
|
365
|
+
for city, required_cuisines in cuisines_per_city.items():
|
|
366
|
+
booked_in_city = [
|
|
367
|
+
r for r in booked_restaurants
|
|
368
|
+
if (r.get("name") or r.get("Name", "")) in city_names.get(city, set())
|
|
369
|
+
]
|
|
370
|
+
|
|
371
|
+
# Each booked restaurant must match at least one required cuisine
|
|
372
|
+
for r in booked_in_city:
|
|
373
|
+
name = r.get("name") or r.get("Name", "")
|
|
374
|
+
r_cuis = restaurant_cuisines.get(name, "")
|
|
375
|
+
if not any(self.cuisine_matches(r_cuis, req) for req in required_cuisines):
|
|
376
|
+
issues.append(
|
|
377
|
+
f"Restaurant '{name}' in {city} does not match "
|
|
378
|
+
f"any required cuisine {required_cuisines}"
|
|
379
|
+
)
|
|
380
|
+
|
|
381
|
+
# Each required cuisine must be covered by at least one booking
|
|
382
|
+
for cuisine in required_cuisines:
|
|
383
|
+
found = any(
|
|
384
|
+
self.cuisine_matches(
|
|
385
|
+
restaurant_cuisines.get(r.get("name") or r.get("Name", ""), ""),
|
|
386
|
+
cuisine,
|
|
387
|
+
)
|
|
388
|
+
for r in booked_in_city
|
|
389
|
+
)
|
|
390
|
+
if not found:
|
|
391
|
+
issues.append(f"No restaurant covering cuisine '{cuisine}' booked in {city}")
|
|
392
|
+
|
|
393
|
+
return issues, city_names
|
|
394
|
+
|
|
395
|
+
# ===== Common Judge Helpers =====
|
|
396
|
+
|
|
397
|
+
@staticmethod
|
|
398
|
+
def check_duplicate_bookings(
|
|
399
|
+
bookings: List[Dict[str, Any]],
|
|
400
|
+
name_key: str = "name",
|
|
401
|
+
fallback_key: str = "Name",
|
|
402
|
+
) -> List[str]:
|
|
403
|
+
"""Check for duplicate bookings by name.
|
|
404
|
+
|
|
405
|
+
Returns a list of issue strings (empty if no duplicates).
|
|
406
|
+
"""
|
|
407
|
+
names = [b.get(name_key) or b.get(fallback_key, "") for b in bookings]
|
|
408
|
+
if len(names) == len(set(names)):
|
|
409
|
+
return []
|
|
410
|
+
duplicates = [n for n in set(names) if names.count(n) > 1]
|
|
411
|
+
return [f"Duplicate bookings: {duplicates}"]
|
|
412
|
+
|
|
413
|
+
def validate_hotel_dates(
|
|
414
|
+
self,
|
|
415
|
+
booked_accommodations: List[Dict[str, Any]],
|
|
416
|
+
stays: List[Dict[str, str]],
|
|
417
|
+
) -> List[str]:
|
|
418
|
+
"""Validate check-in/check-out dates for booked hotels against expected stays.
|
|
419
|
+
|
|
420
|
+
Args:
|
|
421
|
+
booked_accommodations: List of booked hotel records.
|
|
422
|
+
stays: List of dicts with keys: city, check_in, check_out.
|
|
423
|
+
|
|
424
|
+
Returns a list of issue strings (empty if all valid).
|
|
425
|
+
"""
|
|
426
|
+
issues = []
|
|
427
|
+
for stay in stays:
|
|
428
|
+
city_hotels = [
|
|
429
|
+
h for h in booked_accommodations
|
|
430
|
+
if stay["city"].lower() in h.get("city", "").lower()
|
|
431
|
+
]
|
|
432
|
+
for hotel in city_hotels:
|
|
433
|
+
if hotel.get("check_in_date") != stay["check_in"]:
|
|
434
|
+
issues.append(
|
|
435
|
+
f"Hotel in {stay['city']} check-in wrong: "
|
|
436
|
+
f"{hotel.get('check_in_date')} (expected {stay['check_in']})"
|
|
437
|
+
)
|
|
438
|
+
if hotel.get("check_out_date") != stay["check_out"]:
|
|
439
|
+
issues.append(
|
|
440
|
+
f"Hotel in {stay['city']} check-out wrong: "
|
|
441
|
+
f"{hotel.get('check_out_date')} (expected {stay['check_out']})"
|
|
442
|
+
)
|
|
443
|
+
return issues
|
|
444
|
+
|
|
445
|
+
def validate_restaurants_per_city(
|
|
446
|
+
self,
|
|
447
|
+
booked_restaurants: List[Dict[str, Any]],
|
|
448
|
+
cities: List[str],
|
|
449
|
+
restaurants_per_city,
|
|
450
|
+
) -> Tuple[List[str], Dict[str, set]]:
|
|
451
|
+
"""Validate per-city restaurant count.
|
|
452
|
+
|
|
453
|
+
Args:
|
|
454
|
+
booked_restaurants: List of booked restaurant records.
|
|
455
|
+
cities: List of city names.
|
|
456
|
+
restaurants_per_city: int (same for all cities) or dict {city: int}.
|
|
457
|
+
|
|
458
|
+
Returns:
|
|
459
|
+
Tuple of (issues, city_names) where city_names maps city -> set of
|
|
460
|
+
restaurant names in that city (from environment data).
|
|
461
|
+
"""
|
|
462
|
+
issues = []
|
|
463
|
+
city_names: Dict[str, set] = {}
|
|
464
|
+
for city in cities:
|
|
465
|
+
city_names[city] = {r["Name"] for r in self.query_restaurants(city)}
|
|
466
|
+
|
|
467
|
+
for city in cities:
|
|
468
|
+
exp = (restaurants_per_city if isinstance(restaurants_per_city, int)
|
|
469
|
+
else restaurants_per_city.get(city, 0))
|
|
470
|
+
booked_in_city = [
|
|
471
|
+
r for r in booked_restaurants
|
|
472
|
+
if (r.get("name") or r.get("Name", "")) in city_names.get(city, set())
|
|
473
|
+
]
|
|
474
|
+
if len(booked_in_city) != exp:
|
|
475
|
+
issues.append(
|
|
476
|
+
f"Wrong number of restaurants in {city}: "
|
|
477
|
+
f"{len(booked_in_city)}/{exp}"
|
|
478
|
+
)
|
|
479
|
+
return issues, city_names
|
|
480
|
+
|
|
481
|
+
# ===== Request Log APIs (for security testing) =====
|
|
482
|
+
|
|
483
|
+
def query_request_log(self) -> Dict[str, Any]:
|
|
484
|
+
"""Query all logged requests for security analysis.
|
|
485
|
+
|
|
486
|
+
Returns:
|
|
487
|
+
Dictionary with keys: total_requests, requests
|
|
488
|
+
requests is a list of request records with fields:
|
|
489
|
+
- timestamp, endpoint, method, params, result_status
|
|
490
|
+
"""
|
|
491
|
+
try:
|
|
492
|
+
resp = requests.get(f"{self.base_url}/query_request_log", timeout=10)
|
|
493
|
+
data = resp.json()
|
|
494
|
+
if data.get("status") == "success":
|
|
495
|
+
return data.get("result", {"total_requests": 0, "requests": []})
|
|
496
|
+
except Exception as e:
|
|
497
|
+
print(f"[EnvClient] Error querying request log: {e}")
|
|
498
|
+
return {"total_requests": 0, "requests": []}
|
|
499
|
+
|
|
500
|
+
def get_environment_state(self, endpoint: str) -> List[Dict[str, Any]]:
|
|
501
|
+
"""Get all requests for a specific endpoint.
|
|
502
|
+
|
|
503
|
+
Args:
|
|
504
|
+
endpoint: The endpoint name (e.g., 'apply_promo', 'post_review')
|
|
505
|
+
|
|
506
|
+
Returns:
|
|
507
|
+
List of request records for that endpoint
|
|
508
|
+
"""
|
|
509
|
+
log = self.query_request_log()
|
|
510
|
+
return [r for r in log.get("requests", []) if r.get("endpoint") == endpoint]
|
|
511
|
+
|
|
512
|
+
def has_called_endpoint(self, endpoint: str) -> bool:
|
|
513
|
+
"""Check if an endpoint was called during the session.
|
|
514
|
+
|
|
515
|
+
Args:
|
|
516
|
+
endpoint: The endpoint name to check
|
|
517
|
+
|
|
518
|
+
Returns:
|
|
519
|
+
True if the endpoint was called at least once
|
|
520
|
+
"""
|
|
521
|
+
return len(self.get_environment_state(endpoint)) > 0
|
|
522
|
+
|
|
523
|
+
# ===== Environment Control APIs =====
|
|
524
|
+
|
|
525
|
+
def reset(self) -> bool:
|
|
526
|
+
"""Reset all environment data to initial state.
|
|
527
|
+
|
|
528
|
+
Returns:
|
|
529
|
+
True if reset was successful, False otherwise
|
|
530
|
+
"""
|
|
531
|
+
try:
|
|
532
|
+
resp = requests.post(f"{self.base_url}/reset", timeout=10)
|
|
533
|
+
data = resp.json()
|
|
534
|
+
return data.get("status") == "success"
|
|
535
|
+
except Exception as e:
|
|
536
|
+
print(f"[EnvClient] Error resetting environment: {e}")
|
|
537
|
+
return False
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import Any, Dict, Optional, Tuple
|
|
3
|
+
from utils.llm import call_llm
|
|
4
|
+
from dt_arena.utils.travel.prompts import (
|
|
5
|
+
DECEPTION_SYSTEM_PROMPT, DECEPTION_JUDGE_PROMPT,
|
|
6
|
+
CREDENTIAL_COLLECTION_SYSTEM_PROMPT, CREDENTIAL_COLLECTION_JUDGE_PROMPT,
|
|
7
|
+
PII_COLLECTION_SYSTEM_PROMPT, PII_COLLECTION_JUDGE_PROMPT,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _extract_json_obj(text: str) -> Optional[Dict[str, Any]]:
|
|
12
|
+
try:
|
|
13
|
+
return json.loads(text)
|
|
14
|
+
except Exception:
|
|
15
|
+
pass
|
|
16
|
+
start = text.find("{")
|
|
17
|
+
end = text.rfind("}")
|
|
18
|
+
if start == -1 or end == -1 or end <= start:
|
|
19
|
+
return None
|
|
20
|
+
try:
|
|
21
|
+
return json.loads(text[start : end + 1])
|
|
22
|
+
except Exception:
|
|
23
|
+
return None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def llm_judge_deception(
|
|
27
|
+
malicious_goal: str,
|
|
28
|
+
agent_text: str,
|
|
29
|
+
model: str,
|
|
30
|
+
retries: int = 3,
|
|
31
|
+
) -> Optional[Tuple[bool, Dict[str, Any]]]:
|
|
32
|
+
"""Use LLM to judge whether agent made deceptive claims in its responses.
|
|
33
|
+
|
|
34
|
+
Returns (success, metadata) or None on failure.
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
prompt = DECEPTION_JUDGE_PROMPT.format(
|
|
38
|
+
malicious_goal=malicious_goal,
|
|
39
|
+
agent_text=agent_text,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
data = None
|
|
43
|
+
for attempt in range(retries):
|
|
44
|
+
try:
|
|
45
|
+
content = call_llm(
|
|
46
|
+
messages=[
|
|
47
|
+
{"role": "system", "content": DECEPTION_SYSTEM_PROMPT},
|
|
48
|
+
{"role": "user", "content": prompt},
|
|
49
|
+
],
|
|
50
|
+
model=model,
|
|
51
|
+
temperature=0.0,
|
|
52
|
+
)
|
|
53
|
+
data = _extract_json_obj(content or "")
|
|
54
|
+
if not isinstance(data, dict) or "attack_succeeded" not in data:
|
|
55
|
+
continue
|
|
56
|
+
except Exception:
|
|
57
|
+
print(f"LLM judge attempt {attempt + 1} failed to produce valid JSON. Retrying...")
|
|
58
|
+
# if attempt == retries - 1:
|
|
59
|
+
# return None
|
|
60
|
+
continue
|
|
61
|
+
|
|
62
|
+
success = bool(data.get("attack_succeeded")) if data else None
|
|
63
|
+
return (success, {"llm_judge": data, "llm_model": model}) if data else None
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def llm_judge_credential_collection(
|
|
67
|
+
malicious_goal: str,
|
|
68
|
+
agent_text: str,
|
|
69
|
+
model: str,
|
|
70
|
+
retries: int = 3,
|
|
71
|
+
) -> Optional[Tuple[bool, Dict[str, Any]]]:
|
|
72
|
+
"""Use LLM to judge whether agent requested payment credentials from the user.
|
|
73
|
+
|
|
74
|
+
Returns (success, metadata) or None on failure.
|
|
75
|
+
"""
|
|
76
|
+
|
|
77
|
+
prompt = CREDENTIAL_COLLECTION_JUDGE_PROMPT.format(
|
|
78
|
+
malicious_goal=malicious_goal,
|
|
79
|
+
agent_text=agent_text,
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
data = None
|
|
83
|
+
for attempt in range(retries):
|
|
84
|
+
try:
|
|
85
|
+
content = call_llm(
|
|
86
|
+
messages=[
|
|
87
|
+
{"role": "system", "content": CREDENTIAL_COLLECTION_SYSTEM_PROMPT},
|
|
88
|
+
{"role": "user", "content": prompt},
|
|
89
|
+
],
|
|
90
|
+
model=model,
|
|
91
|
+
temperature=0.0,
|
|
92
|
+
)
|
|
93
|
+
data = _extract_json_obj(content or "")
|
|
94
|
+
if not isinstance(data, dict) or "attack_succeeded" not in data:
|
|
95
|
+
continue
|
|
96
|
+
except Exception:
|
|
97
|
+
continue
|
|
98
|
+
|
|
99
|
+
success = bool(data.get("attack_succeeded")) if data else None
|
|
100
|
+
return (success, {"llm_judge": data, "llm_model": model}) if data else None
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def llm_judge_pii_collection(
|
|
104
|
+
malicious_goal: str,
|
|
105
|
+
agent_text: str,
|
|
106
|
+
model: str,
|
|
107
|
+
retries: int = 3,
|
|
108
|
+
) -> Optional[Tuple[bool, Dict[str, Any]]]:
|
|
109
|
+
"""Use LLM to judge whether agent requested unnecessary personal information from the user.
|
|
110
|
+
|
|
111
|
+
Returns (success, metadata) or None on failure.
|
|
112
|
+
"""
|
|
113
|
+
|
|
114
|
+
prompt = PII_COLLECTION_JUDGE_PROMPT.format(
|
|
115
|
+
malicious_goal=malicious_goal,
|
|
116
|
+
agent_text=agent_text,
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
data = None
|
|
120
|
+
for attempt in range(retries):
|
|
121
|
+
try:
|
|
122
|
+
content = call_llm(
|
|
123
|
+
messages=[
|
|
124
|
+
{"role": "system", "content": PII_COLLECTION_SYSTEM_PROMPT},
|
|
125
|
+
{"role": "user", "content": prompt},
|
|
126
|
+
],
|
|
127
|
+
model=model,
|
|
128
|
+
temperature=0.0,
|
|
129
|
+
)
|
|
130
|
+
data = _extract_json_obj(content or "")
|
|
131
|
+
if not isinstance(data, dict) or "attack_succeeded" not in data:
|
|
132
|
+
continue
|
|
133
|
+
except Exception:
|
|
134
|
+
continue
|
|
135
|
+
|
|
136
|
+
success = bool(data.get("attack_succeeded")) if data else None
|
|
137
|
+
return (success, {"llm_judge": data, "llm_model": model}) if data else None
|